├── .github ├── dependabot.yml └── workflows │ ├── build.yml │ ├── codspeed.yml │ └── release.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE.txt ├── Makefile ├── README.md ├── docs ├── api.md ├── examples.ipynb ├── index.md ├── reference.md └── requirements.in ├── graphique ├── __init__.py ├── core.py ├── inputs.py ├── interface.py ├── middleware.py ├── models.py ├── py.typed ├── scalars.py ├── service.py └── shell.py ├── mkdocs.yml ├── package.json ├── pyproject.toml ├── requirements.in └── tests ├── __init__.py ├── conftest.py ├── federated.py ├── fixtures ├── alltypes.parquet ├── nofields.parquet ├── partitioned │ ├── north=0 │ │ ├── west=0 │ │ │ └── 18ed2d55859f4e5aabd025832d04a421-0.parquet │ │ └── west=1 │ │ │ └── 18ed2d55859f4e5aabd025832d04a421-0.parquet │ └── north=1 │ │ ├── west=0 │ │ └── 18ed2d55859f4e5aabd025832d04a421-0.parquet │ │ └── west=1 │ │ └── 18ed2d55859f4e5aabd025832d04a421-0.parquet ├── zip_db.parquet └── zipcodes.parquet ├── requirements.in ├── test_bench.py ├── test_core.py ├── test_dataset.py ├── test_models.py └── test_service.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "weekly" 8 | 9 | - package-ecosystem: "pip" 10 | directory: "/" 11 | schedule: 12 | interval: "weekly" 13 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | branches: [main] 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: ['3.10', '3.11', '3.12', '3.13'] 16 | arrow-version: [''] 17 | include: 18 | - python-version: 3.x 19 | arrow-version: '--pre --only-binary :all: ' 20 | steps: 21 | - uses: actions/checkout@v4 22 | - uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - run: pip install --extra-index-url https://pypi.fury.io/arrow-nightlies/ ${{ matrix.arrow-version }}pyarrow 26 | - run: pip install -r tests/requirements.in 27 | - run: make check 28 | - run: coverage xml 29 | - uses: codecov/codecov-action@v5 30 | with: 31 | token: ${{ secrets.CODECOV_TOKEN }} 32 | 33 | lint: 34 | runs-on: ubuntu-latest 35 | steps: 36 | - uses: actions/checkout@v4 37 | - uses: actions/setup-python@v5 38 | with: 39 | python-version: 3.x 40 | - run: pip install ruff mypy 41 | - run: make lint 42 | 43 | docs: 44 | runs-on: ubuntu-latest 45 | steps: 46 | - uses: actions/checkout@v4 47 | - uses: actions/setup-python@v5 48 | with: 49 | python-version: 3.x 50 | - run: pip install -r docs/requirements.in 51 | - run: npm install 52 | - run: make html 53 | -------------------------------------------------------------------------------- /.github/workflows/codspeed.yml: -------------------------------------------------------------------------------- 1 | name: codspeed-benchmarks 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [main] 7 | pull_request: 8 | branches: [main] 9 | 10 | jobs: 11 | benchmarks: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | - uses: actions/setup-python@v5 16 | with: 17 | python-version: 3.x 18 | - run: pip install -r tests/requirements.in 19 | - uses: CodSpeedHQ/action@v3 20 | with: 21 | token: ${{ secrets.CODSPEED_TOKEN }} 22 | run: 
make bench 23 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-latest 11 | permissions: write-all 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v5 15 | with: 16 | python-version: 3.x 17 | - run: pip install build -r docs/requirements.in 18 | - run: python -m build 19 | - run: npm install 20 | - run: make html 21 | - run: PYTHONPATH=$PWD python -m mkdocs gh-deploy --force 22 | - uses: pypa/gh-action-pypi-publish@release/v1 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .coverage 3 | site/ 4 | docs/schema.* 5 | node_modules/ 6 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). 5 | 6 | ## Unreleased 7 | ### Changed 8 | * Pyarrow >=19 required 9 | * Python >=3.10 required 10 | 11 | ## [1.8](https://pypi.org/project/graphique/1.8/) - 2024-11-01 12 | ### Changed 13 | * Pyarrow >=18 required 14 | * `isodate` dependency for durations 15 | * Acero engine used for scanning 16 | * Grouping defaults to parallelized but unordered 17 | * Partitioning supports arbitrary functions 18 | * `group` optimized for dictionary arrays 19 | * `rank` optimized for out-of-core 20 | 21 | ## [1.7](https://pypi.org/project/graphique/1.7/) - 2024-07-19 22 | ### Changed 23 | * Pyarrow >=17 required 24 | * Partitioning supports original indices 25 | * Acero engine declaration 26 | * `Duration` format improvements 27 | 28 | ### Fixed 29 | * Strawberry >=0.236 compatible 30 | 31 | ## [1.6](https://pypi.org/project/graphique/1.6/) - 2024-04-30 32 | ### Changed 33 | * Pyarrow >=16 required 34 | * `group` optimized for datasets 35 | * `Duration` scalar 36 | 37 | ### Removed 38 | * `Interval` type 39 | 40 | ## [1.5](https://pypi.org/project/graphique/1.5/) - 2024-01-24 41 | ### Changed 42 | * Pyarrow >=15 required 43 | 44 | ### Fixed 45 | * Strawberry >=0.212 compatible 46 | * Starlette >=0.36 compatible 47 | 48 | ## [1.4](https://pypi.org/project/graphique/1.4/) - 2023-11-05 49 | ### Changed 50 | * Pyarrow >=14 required 51 | * Python >=3.9 required 52 | * `group` optimized for memory 53 | 54 | ### Removed 55 | * `fragments` replaced by `group` 56 | * `min` and `max` replaced by `rank` 57 | * `partition` replaced by `runs` 58 | * `list` aggregation must be explicit 59 | * `group` list functions are in `apply` 60 | 61 | ## [1.3](https://pypi.org/project/graphique/1.3/) - 2023-08-25 62 | ### Changed 63 | * Pyarrow >=13 required 64 | * List filtering and sorting moved to functions and optimized 65 | * Dataset filtering, grouping, and sorting on fragments optimized 66 | * `group` can aggregate entire table 67 | 68 | ### Added 69 | * `flatten` field for list columns 70 | * `rank` field for min and max filtering 71 | * Schema extensions for metrics and deprecations 72 | * `optional` field for partial query results 73 | * `dropNull`, `fillNull`, and `size` fields 74 | * Command-line utilities 75 | * 
Allow datasets with invalid field names 76 | 77 | ### Deprecated 78 | * `fragments` field deprecated and functionality moved to `group` field 79 | * Implicit list aggregation on `group` deprecated 80 | * `partition` field deprecated and renamed to `runs` 81 | 82 | ## [1.2](https://pypi.org/project/graphique/1.2/) - 2023-05-07 83 | ### Changed 84 | * Pyarrow >=12 required 85 | * Grouping fragments optimized 86 | * Group by empty columns 87 | * Batch sorting and grouping into lists 88 | 89 | ## [1.1](https://pypi.org/project/graphique/1.1/) - 2023-01-29 90 | * Pyarrow >=11 required 91 | * Python >=3.8 required 92 | * Scannable functions added 93 | * List aggregations deprecated 94 | * Group by fragments 95 | * Month day nano interval array 96 | * `min` and `max` fields memory optimized 97 | 98 | ## [1.0](https://pypi.org/project/graphique/1.0/) - 2022-10-28 99 | * Pyarrow >=10 required 100 | * Dataset schema introspection 101 | * Dataset scanning with selection and projection 102 | * Binary search on sorted columns 103 | * List aggregation, filtering, and sorting optimizations 104 | * Compute functions generalized 105 | * Multiple datasets and federation 106 | * Provisional dataset `join` and `take` 107 | 108 | ## [0.9](https://pypi.org/project/graphique/0.9/) - 2022-08-04 109 | * Pyarrow >=9 required 110 | * Multi-directional sorting 111 | * Removed unnecessary interfaces 112 | * Filtering has stricter typing 113 | 114 | ## [0.8](https://pypi.org/project/graphique/0.8/) - 2022-05-08 115 | * Pyarrow >=8 required 116 | * Grouping and aggregation integrated 117 | * `AbstractTable` interface renamed to `Dataset` 118 | * `Binary` scalar renamed to `Base64` 119 | 120 | ## [0.7](https://pypi.org/project/graphique/0.7/) - 2022-02-04 121 | * Pyarrow >=7 required 122 | * `FILTERS` use query syntax and trigger reading the dataset 123 | * `FEDERATED` field configuration 124 | * List columns support sorting and filtering 125 | * Group by and aggregate optimizations 126 | * Dataset scanning 127 | 128 | ## [0.6](https://pypi.org/project/graphique/0.6/) - 2021-10-28 129 | * Pyarrow >=6 required 130 | * Group by optimized and replaced `unique` field 131 | * Dictionary related optimizations 132 | * Null consistency with arrow `count` functions 133 | 134 | ## [0.5](https://pypi.org/project/graphique/0.5/) - 2021-08-06 135 | * Pyarrow >=5 required 136 | * Stricter validation of inputs 137 | * Columns can be cast to another arrow data type 138 | * Grouping uses large list arrays with 64-bit counts 139 | * Datasets are read on-demand or optionally at startup 140 | 141 | ## [0.4](https://pypi.org/project/graphique/0.4/) - 2021-05-16 142 | * Pyarrow >=4 required 143 | * `sort` updated to use new native routines 144 | * `partition` tables by adjacent values and differences 145 | * `filter` supports unknown column types using tagged union pattern 146 | * `Groups` replaced with `Table.tables` and `Table.aggregate` fields 147 | * Tagged unions used for `filter`, `apply`, and `partition` functions 148 | 149 | ## [0.3](https://pypi.org/project/graphique/0.3/) - 2021-01-31 150 | * Pyarrow >=3 required 151 | * `any` and `all` fields 152 | * String column `split` field 153 | 154 | ## [0.2](https://pypi.org/project/graphique/0.2/) - 2020-11-26 155 | * Pyarrow >= 2 required 156 | * `ListColumn` and `StructColumn` types 157 | * `Groups` type with `aggregate` field 158 | * `group` and `unique` optimized 159 | * Statistical fields: `mode`, `stddev`, `variance` 160 | * `is_in`, `min`, and `max` optimized 161 | 
-------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2022 Aric Coady 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | check: 2 | python -m pytest -s --cov 3 | 4 | bench: 5 | python -m pytest --codspeed 6 | 7 | lint: 8 | ruff check . 9 | ruff format --check . 10 | mypy -p graphique 11 | 12 | html: docs/schema.md 13 | PYTHONPATH=$(PWD) python -m mkdocs build 14 | 15 | docs/schema.md: docs/schema.graphql 16 | ./node_modules/.bin/graphql-markdown \ 17 | --title "Example Schema" \ 18 | --no-toc \ 19 | --prologue "Generated from a test fixture of zipcodes." \ 20 | $? > $@ 21 | 22 | docs/schema.graphql: graphique/*.py 23 | PARQUET_PATH=tests/fixtures/zipcodes.parquet strawberry export-schema graphique.service:app.schema > $@ 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![image](https://img.shields.io/pypi/v/graphique.svg)](https://pypi.org/project/graphique/) 2 | ![image](https://img.shields.io/pypi/pyversions/graphique.svg) 3 | [![image](https://pepy.tech/badge/graphique)](https://pepy.tech/project/graphique) 4 | ![image](https://img.shields.io/pypi/status/graphique.svg) 5 | [![build](https://github.com/coady/graphique/actions/workflows/build.yml/badge.svg)](https://github.com/coady/graphique/actions/workflows/build.yml) 6 | [![image](https://codecov.io/gh/coady/graphique/branch/main/graph/badge.svg)](https://codecov.io/gh/coady/graphique/) 7 | [![CodeQL](https://github.com/coady/graphique/actions/workflows/github-code-scanning/codeql/badge.svg)](https://github.com/coady/graphique/actions/workflows/github-code-scanning/codeql) 8 | [![CodSpeed Badge](https://img.shields.io/endpoint?url=https://codspeed.io/badge.json)](https://codspeed.io/coady/graphique) 9 | [![image](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) 10 | [![image](https://mypy-lang.org/static/mypy_badge.svg)](https://mypy-lang.org/) 11 | 12 | [GraphQL](https://graphql.org) service for [arrow](https://arrow.apache.org) tables and [parquet](https://parquet.apache.org) data sets. The schema for a query API is derived automatically. 13 | 14 | ## Usage 15 | ```console 16 | % env PARQUET_PATH=... uvicorn graphique.service:app 17 | ``` 18 | 19 | Open http://localhost:8000/ to try out the API in [GraphiQL](https://github.com/graphql/graphiql/tree/main/packages/graphiql#readme). There is a test fixture at `./tests/fixtures/zipcodes.parquet`. 20 | 21 | ```console 22 | % env PARQUET_PATH=... 
strawberry export-schema graphique.service:app.schema 23 | ``` 24 | outputs the graphql schema for a parquet data set. 25 | 26 | ### Configuration 27 | Graphique uses [Starlette's config](https://www.starlette.io/config/): in environment variables or a `.env` file. Config variables are used as input to a [parquet dataset](https://arrow.apache.org/docs/python/dataset.html). 28 | 29 | * PARQUET_PATH: path to the parquet directory or file 30 | * FEDERATED = '': field name to extend type `Query` with a federated `Table` 31 | * DEBUG = False: run service in debug mode, which includes metrics 32 | * COLUMNS = None: list of names, or mapping of aliases, of columns to select 33 | * FILTERS = None: json `filter` query for which rows to read at startup 34 | 35 | For more options create a custom [ASGI](https://asgi.readthedocs.io/en/latest/index.html) app. Call graphique's `GraphQL` on an arrow [Dataset](https://arrow.apache.org/docs/python/api/dataset.html), [Scanner](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html), or [Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html). The GraphQL `Table` type will be the root Query type. 36 | 37 | Supply a mapping of names to datasets for multiple roots, and to enable federation. 38 | 39 | ```python 40 | import pyarrow.dataset as ds 41 | from graphique import GraphQL 42 | 43 | source = ds.dataset(...) 44 | app = GraphQL(source) # Table is root query type 45 | app = GraphQL.federated({<name>: source, ...}, keys={<name>: [<keys>], ...}) # Tables on federated fields 46 | ``` 47 | 48 | Start like any ASGI app. 49 | 50 | ```console 51 | uvicorn <module>:app 52 | ``` 53 | 54 | Configuration options exist to provide a convenient no-code solution, but are subject to change in the future. Using a custom app is recommended for production usage. 55 | 56 | ### API 57 | #### types 58 | * `Dataset`: interface for an arrow dataset, scanner, or table. 59 | * `Table`: implements the `Dataset` interface. Adds typed `row`, `columns`, and `filter` fields from introspecting the schema. 60 | * `Column`: interface for an arrow column (a.k.a. ChunkedArray). Each arrow data type has a corresponding column implementation: Boolean, Int, Long, Float, Decimal, Date, Datetime, Time, Duration, Base64, String, List, Struct. All columns have a `values` field for their list of scalars. Additional fields vary by type. 61 | * `Row`: scalar fields. Arrow tables are column-oriented, and graphique encourages that usage for performance. A single `row` field is provided for convenience, but a field for a list of rows is not. Requesting parallel columns is far more efficient. 
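As a sketch against the bundled zipcodes fixture (its `state`, `city`, and `county` fields appear in the example notebook), requesting parallel columns looks like this:

```
{
  filter(state: {eq: "CA"}) {
    columns {
      city { values }
      county { values }
    }
  }
}
```

Each selected column comes back as a single `values` array, rather than as a list of per-row objects.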
62 | 63 | #### selection 64 | * `slice`: contiguous selection of rows 65 | * `filter`: select rows with simple predicates 66 | * `scan`: select rows and project columns with expressions 67 | 68 | #### projection 69 | * `columns`: provides a field for every `Column` in the schema 70 | * `column`: access a column of any type by name 71 | * `row`: provides a field for each scalar of a single row 72 | * `apply`: transform columns by applying a function 73 | * `join`: join tables by key columns 74 | 75 | #### aggregation 76 | * `group`: group by given columns, and aggregate the others 77 | * `runs`: partition on adjacent values in given columns, transforming the others into list columns 78 | * `tables`: return a list of tables by splitting on the scalars in list columns 79 | * `flatten`: flatten list columns with repeated scalars 80 | 81 | #### ordering 82 | * `sort`: sort table by given columns 83 | * `rank`: select rows with smallest or largest values 84 | 85 | ### Performance 86 | Graphique relies on native [PyArrow](https://arrow.apache.org/docs/python/index.html) routines wherever possible. Otherwise it falls back to using [NumPy](https://numpy.org/doc/stable/) or custom optimizations. 87 | 88 | By default, datasets are read on-demand, with only the necessary rows and columns scanned. Although graphique is a running service, [parquet is performant](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html) at reading a subset of data. Optionally specify `FILTERS` in the json `filter` format to read a subset of rows at startup, trading-off memory for latency. An empty filter (`{}`) will read the whole table. 89 | 90 | Specifying `COLUMNS` will limit memory usage when reading at startup (`FILTERS`). There is little speed difference as unused columns are inherently ignored. Optional aliasing can also be used for camel casing. 91 | 92 | If index columns are detected in the schema metadata, then an initial `filter` will also attempt a binary search on tables. 93 | 94 | ## Installation 95 | ```console 96 | % pip install graphique[server] 97 | ``` 98 | 99 | ## Dependencies 100 | * pyarrow 101 | * strawberry-graphql[asgi,cli] 102 | * numpy 103 | * isodate 104 | * uvicorn (or other [ASGI server](https://asgi.readthedocs.io/en/latest/implementations.html)) 105 | 106 | ## Tests 107 | 100% branch coverage. 108 | 109 | ```console 110 | % pytest [--cov] 111 | ``` 112 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | ## Types 2 | A typed schema is automatically generated from the arrow table and its columns. However, advanced usage of tables often creates new columns - or changes the type of existing ones - and therefore falls outside the schema. Fields which create columns also allow aliasing, otherwise the column is replaced. 3 | 4 | ### Output 5 | A column within the schema can be accessed by `Table.columns`. 6 | ``` 7 | { 8 | columns { 9 | <name> { ... } 10 | } 11 | } 12 | ``` 13 | 14 | Any column can be accessed by name using `Dataset.column` and [inline fragments](https://graphql.org/learn/queries/#inline-fragments). 15 | ``` 16 | { 17 | column(name: "...") { 18 | ... on Column { ... } 19 | } 20 | } 21 | ``` 22 | 23 | ### Input 24 | Input types don't have the equivalent of inline fragments, but GraphQL is converging on the [OneOf input pattern](https://github.com/graphql/graphql-spec/pull/825). Effectively the type of the field becomes the name of the field. 
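For example (a sketch using the zipcodes fixture from the example notebook), the compute function in a `scan` expression is chosen by the name of the input field, `eq` here, much as an inline fragment chooses a type on output:

```
{
  scan(filter: {eq: [{name: "county"}, {name: "city"}]}) { ... }
}
```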
25 | 26 | `Dataset.scan` has flexible selection and projection. 27 | ``` 28 | { 29 | scan(filter: { ... }, columns: [{ ... }, ...]) { ... } 30 | } 31 | ``` 32 | 33 | `Table.filter` provides a friendlier interface for simple queries on columns within the schema. 34 | ``` 35 | { 36 | filter(<name>: { ... }, ...) { ... } 37 | } 38 | ``` 39 | 40 | Note list inputs allow passing a single value, [coercing the input](https://spec.graphql.org/October2021/#sec-List.Input-Coercion) to a list of 1. 41 | 42 | ## Batches 43 | Datasets and scanners are processed in batches when possible, instead of loading the table into memory. 44 | 45 | * `group`, `scan`, and `filter` - native parallel batch processing 46 | * `sort` with `length` 47 | * `apply` with `list` functions 48 | * `rank` 49 | * `flatten` 50 | 51 | ## Partitions 52 | Partitioned datasets use fragment keys when possible. 53 | 54 | * `group` on fragment keys with counts 55 | * `rank` and `sort` with length on fragment keys 56 | 57 | ## Column selection 58 | Each field resolver transforms a table or array as needed. When working with an embedded library like [pandas](https://pandas.pydata.org), it's common to select a working set of columns for efficiency. GraphQL, however, has the advantage of knowing the entire query up front, so there is no `select` field: column selection is done automatically at every level of resolvers. 59 | 60 | ## List Arrays 61 | Arrow ListArrays are supported as ListColumns. `group: {aggregate: {list: ...}}` and `runs` leverage that feature to transform columns into ListColumns, which can be accessed via inline fragments and further aggregated. Though `group` hash aggregate functions are more efficient than creating lists. 62 | 63 | * `tables` returns a list of tables based on the list scalars. 64 | * `flatten` flattens the list columns and repeats the scalar columns as needed. 65 | * `apply(list: {filter: ..., sort: ..., rank: ...})` applies vector functions to the list scalars. 66 | 67 | The lists in use must all have the same value lengths, which is naturally the case when they are the result of grouping. Iterating scalars (in Python) is not ideal, but it can be faster than re-aggregating, depending on the average list size. 68 | 69 | ## Dictionary Arrays 70 | Arrow has dictionary-encoded arrays as a space optimization, but doesn't natively support some builtin functions on them. Support for dictionaries is extended, and often faster by only having to apply functions to the unique values. 71 | 72 | ## Nulls 73 | GraphQL continues the long tradition of confusing ["optional" with "nullable"](https://github.com/graphql/graphql-spec/issues/872). Graphique strives to be explicit regarding what may be omitted versus what may be null. 74 | 75 | ### Output 76 | Arrow has first-class support for nulls, so array scalars are nullable. Non-null scalars are used where relevant. 77 | 78 | Columns and rows are nullable to allow partial query results. `Dataset.optional` enables [client controlled nullability](https://github.com/graphql/graphql-spec/issues/867). 79 | 80 | ### Input 81 | Default values and non-null types are used wherever possible. 
When an input is optional and has no natural default, there are two cases to distinguish: 82 | 83 | * if null is expected and semantically different, the input's description explains null behavior 84 | * otherwise the input has an `@optional` directive, and explicit null behavior is undefined 85 | -------------------------------------------------------------------------------- /docs/examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pyarrow.dataset as ds\n", 10 | "from graphique import GraphQL\n", 11 | "\n", 12 | "\n", 13 | "def execute(query):\n", 14 | " result = app.schema.execute_sync(query, root_value=app.root_value, context_value={})\n", 15 | " for error in result.errors or []:\n", 16 | " raise ValueError(error)\n", 17 | " return result.data\n", 18 | "\n", 19 | "\n", 20 | "format = ds.ParquetFileFormat(read_options={'dictionary_columns': ['state']})\n", 21 | "dataset = ds.dataset('../tests/fixtures/zipcodes.parquet', format=format)\n", 22 | "app = GraphQL(dataset)" 23 | ] 24 | }, 25 | { 26 | "attachments": {}, 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "### Introspect the dataset." 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "execute(\"\"\"{\n", 40 | " length\n", 41 | " schema {\n", 42 | " names\n", 43 | " types\n", 44 | " partitioning\n", 45 | " index\n", 46 | " }\n", 47 | "}\"\"\")" 48 | ] 49 | }, 50 | { 51 | "attachments": {}, 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "### Loading options\n", 56 | "* Scanner with camel-cased fields (not relevant in this dataset)\n", 57 | "* Table already read" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "import pyarrow.compute as pc\n", 67 | "from strawberry.utils.str_converters import to_camel_case\n", 68 | "\n", 69 | "columns = {to_camel_case(name): pc.field(name) for name in dataset.schema.names}\n", 70 | "GraphQL(dataset.scanner(columns=columns))" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "GraphQL(dataset.to_table())" 80 | ] 81 | }, 82 | { 83 | "attachments": {}, 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "### Find California counties with the most cities.\n", 88 | "* `filter` state by \"CA\"\n", 89 | "* `group` by county\n", 90 | " * aggregate distinct count of cities\n", 91 | "* `sort` by city counts descending\n", 92 | "* access `columns`\n", 93 | " * `county` is still known in the schema\n", 94 | " * cities is a new `column` accessed through an inline fragment" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "execute(\"\"\"{\n", 104 | " filter(state: {eq: \"CA\"}) {\n", 105 | " group(by: \"county\", aggregate: {countDistinct: {name: \"city\", alias: \"cities\"}}) {\n", 106 | " sort(by: \"-cities\", length: 5) {\n", 107 | " columns {\n", 108 | " county {\n", 109 | " values\n", 110 | " }\n", 111 | " }\n", 112 | " cities: column(name: \"cities\") {\n", 113 | " ... 
on LongColumn {\n", 114 | " values\n", 115 | " }\n", 116 | " }\n", 117 | " }\n", 118 | " }\n", 119 | " }\n", 120 | "}\"\"\")" 121 | ] 122 | }, 123 | { 124 | "attachments": {}, 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "### Find states with cities which match the name of their county.\n", 129 | "* `scan` instead of `filter`, because comparing two columns is not a \"simple\" query\n", 130 | "* `Column.unique` instead of `group`, because no other aggregates are needed" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "execute(\"\"\"{\n", 140 | " scan(filter: {eq: [{name: \"county\"}, {name: \"city\"}]}) {\n", 141 | " columns {\n", 142 | " state {\n", 143 | " unique {\n", 144 | " length\n", 145 | " values\n", 146 | " }\n", 147 | " }\n", 148 | " }\n", 149 | " }\n", 150 | "}\"\"\")" 151 | ] 152 | }, 153 | { 154 | "attachments": {}, 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "### States which have no cities which match the name of their county.\n", 159 | "The opposite of the previous example. Filtering rows would drop needed data; the \"zeros\" have to be counted.\n", 160 | "* `scan` with projected column matching names instead of filtering\n", 161 | "* `group` by state\n", 162 | " * aggregate whether there are `any` matches\n", 163 | "* `scan` for no matches\n", 164 | "* access column\n" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "execute(\"\"\"{\n", 174 | " scan(columns: {alias: \"match\", eq: [{name: \"county\"}, {name: \"city\"}]}) {\n", 175 | " group(by: \"state\", aggregate: {any: {name: \"match\"}}) {\n", 176 | " scan(filter: {inv: {name: \"match\"}}) {\n", 177 | " columns {\n", 178 | " state {\n", 179 | " values\n", 180 | " }\n", 181 | " }\n", 182 | " }\n", 183 | " }\n", 184 | " }\n", 185 | "}\"\"\")" 186 | ] 187 | } 188 | ], 189 | "metadata": { 190 | "kernelspec": { 191 | "display_name": "Python 3.10.6 ('.venv': venv)", 192 | "language": "python", 193 | "name": "python3" 194 | }, 195 | "language_info": { 196 | "codemirror_mode": { 197 | "name": "ipython", 198 | "version": 3 199 | }, 200 | "file_extension": ".py", 201 | "mimetype": "text/x-python", 202 | "name": "python", 203 | "nbconvert_exporter": "python", 204 | "pygments_lexer": "ipython3", 205 | "version": "3.11.4" 206 | }, 207 | "orig_nbformat": 4, 208 | "vscode": { 209 | "interpreter": { 210 | "hash": "fe1d9005a8a33982f05f67810ca98c5c9c7de363fa0f442feea70330697eb4e5" 211 | } 212 | } 213 | }, 214 | "nbformat": 4, 215 | "nbformat_minor": 2 216 | } 217 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /docs/reference.md: -------------------------------------------------------------------------------- 1 | ::: graphique.core.ListChunk 2 | 3 | ::: graphique.core.Column 4 | 5 | ::: graphique.core.Table 6 | 7 | ::: graphique.core.Nodes 8 | 9 | ::: graphique.interface.Dataset 10 | 11 | ::: graphique.middleware.GraphQL 12 | -------------------------------------------------------------------------------- /docs/requirements.in: -------------------------------------------------------------------------------- 1 | -r ../requirements.in 2 | 
mkdocstrings[python] 3 | mkdocs-jupyter 4 | -------------------------------------------------------------------------------- /graphique/__init__.py: -------------------------------------------------------------------------------- 1 | from .middleware import GraphQL # noqa: F401 2 | -------------------------------------------------------------------------------- /graphique/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core utilities that add pandas-esque features to arrow arrays and tables. 3 | 4 | Arrow forbids subclassing, so the classes are for logical grouping. 5 | Their methods are called as functions. 6 | """ 7 | 8 | import bisect 9 | import contextlib 10 | import functools 11 | import inspect 12 | import itertools 13 | import operator 14 | import json 15 | from collections.abc import Callable, Iterable, Iterator, Mapping 16 | from dataclasses import dataclass 17 | from typing import TypeAlias, get_type_hints 18 | import numpy as np 19 | import pyarrow as pa 20 | import pyarrow.acero as ac 21 | import pyarrow.compute as pc 22 | import pyarrow.dataset as ds 23 | from typing_extensions import Self 24 | 25 | Array: TypeAlias = pa.Array | pa.ChunkedArray 26 | Batch: TypeAlias = pa.RecordBatch | pa.Table 27 | bit_any = functools.partial(functools.reduce, operator.or_) 28 | bit_all = functools.partial(functools.reduce, operator.and_) 29 | 30 | 31 | class Agg: 32 | """Aggregation options.""" 33 | 34 | option_map = { 35 | 'all': pc.ScalarAggregateOptions, 36 | 'any': pc.ScalarAggregateOptions, 37 | 'approximate_median': pc.ScalarAggregateOptions, 38 | 'count': pc.CountOptions, 39 | 'count_distinct': pc.CountOptions, 40 | 'distinct': pc.CountOptions, 41 | 'first': pc.ScalarAggregateOptions, 42 | 'first_last': pc.ScalarAggregateOptions, 43 | 'last': pc.ScalarAggregateOptions, 44 | 'list': type(None), 45 | 'max': pc.ScalarAggregateOptions, 46 | 'mean': pc.ScalarAggregateOptions, 47 | 'min': pc.ScalarAggregateOptions, 48 | 'min_max': pc.ScalarAggregateOptions, 49 | 'one': type(None), 50 | 'product': pc.ScalarAggregateOptions, 51 | 'stddev': pc.VarianceOptions, 52 | 'sum': pc.ScalarAggregateOptions, 53 | 'tdigest': pc.TDigestOptions, 54 | 'variance': pc.VarianceOptions, 55 | } 56 | ordered = {'first', 'last'} 57 | 58 | def __init__(self, name: str, alias: str = '', **options): 59 | self.name = name 60 | self.alias = alias or name 61 | self.options = options 62 | 63 | def func_options(self, func: str) -> pc.FunctionOptions: 64 | return self.option_map[func.removeprefix('hash_')](**self.options) 65 | 66 | 67 | @dataclass(frozen=True, slots=True) 68 | class Compare: 69 | """Comparable wrapper for bisection search.""" 70 | 71 | value: object 72 | 73 | def __lt__(self, other): 74 | return self.value < other.as_py() 75 | 76 | def __gt__(self, other): 77 | return self.value > other.as_py() 78 | 79 | 80 | def sort_key(name: str) -> tuple: 81 | """Parse sort order.""" 82 | return name.lstrip('-'), ('descending' if name.startswith('-') else 'ascending') 83 | 84 | 85 | def register(func: Callable, kind: str = 'scalar') -> pc.Function: 86 | """Register user defined function by kind.""" 87 | doc = inspect.getdoc(func) 88 | doc = {'summary': doc.splitlines()[0], 'description': doc} # type: ignore 89 | annotations = dict(get_type_hints(func)) 90 | result = annotations.pop('return') 91 | with contextlib.suppress(pa.ArrowKeyError): # apache/arrow#{31611,31612} 92 | getattr(pc, f'register_{kind}_function')(func, func.__name__, doc, annotations, result) 93 | 
return pc.get_function(func.__name__) 94 | 95 | 96 | @register 97 | def digitize( 98 | ctx, 99 | array: pa.float64(), # type: ignore 100 | bins: pa.list_(pa.float64()), # type: ignore 101 | right: pa.bool_(), # type: ignore 102 | ) -> pa.int64(): # type: ignore 103 | """Return the indices of the bins to which each value in input array belongs.""" 104 | return pa.array(np.digitize(array, bins.values, right.as_py())) 105 | 106 | 107 | class ListChunk(pa.lib.BaseListArray): 108 | def from_counts(counts: pa.IntegerArray, values: pa.Array) -> pa.LargeListArray: 109 | """Return list array by converting counts into offsets.""" 110 | mask = None 111 | if counts.null_count: 112 | mask, counts = counts.is_null(), counts.fill_null(0) 113 | offsets = pa.concat_arrays([pa.array([0], counts.type), pc.cumulative_sum_checked(counts)]) 114 | cls = pa.LargeListArray if offsets.type == 'int64' else pa.ListArray 115 | return cls.from_arrays(offsets, values, mask=mask) 116 | 117 | def from_scalars(values: Iterable) -> pa.LargeListArray: 118 | """Return list array from array scalars.""" 119 | return ListChunk.from_counts(pa.array(map(len, values)), pa.concat_arrays(values)) 120 | 121 | def element(self, index: int) -> pa.Array: 122 | """element at index of each list scalar; defaults to null""" 123 | with contextlib.suppress(ValueError): 124 | return pc.list_element(self, index) 125 | size = -index if index < 0 else index + 1 126 | if isinstance(self, pa.ChunkedArray): 127 | self = self.combine_chunks() 128 | mask = np.asarray(Column.fill_null(pc.list_value_length(self), 0)) < size 129 | offsets = np.asarray(self.offsets[1:] if index < 0 else self.offsets[:-1]) 130 | return pc.list_flatten(self).take(pa.array(offsets + index, mask=mask)) 131 | 132 | def first(self) -> pa.Array: 133 | """first value of each list scalar""" 134 | return ListChunk.element(self, 0) 135 | 136 | def last(self) -> pa.Array: 137 | """last value of each list scalar""" 138 | return ListChunk.element(self, -1) 139 | 140 | def scalars(self) -> Iterable: 141 | empty = pa.array([], self.type.value_type) 142 | return (scalar.values or empty for scalar in self) 143 | 144 | def map_list(self, func: Callable, **kwargs) -> pa.lib.BaseListArray: 145 | """Return list array by mapping function across scalars, with null handling.""" 146 | values = [func(value, **kwargs) for value in ListChunk.scalars(self)] 147 | return ListChunk.from_scalars(values) 148 | 149 | def inner_flatten(self) -> pa.lib.BaseListArray: 150 | """Return flattened inner lists from a nested list array.""" 151 | offsets = self.values.offsets.take(self.offsets) 152 | return type(self).from_arrays(offsets, self.values.values) 153 | 154 | def aggregate(self, **funcs: pc.FunctionOptions | None) -> pa.RecordBatch: 155 | """Return aggregated scalars by grouping each hash function on the parent indices. 156 | 157 | If there are empty or null scalars, then the result must be padded with null defaults and 158 | reordered. If the function is a `count`, then the default is 0. 
159 | """ 160 | columns = {'key': pc.list_parent_indices(self), '': pc.list_flatten(self)} 161 | items = [('', name, funcs[name]) for name in funcs] 162 | table = pa.table(columns).group_by(['key']).aggregate(items) 163 | indices, table = table['key'], table.remove_column(table.schema.get_field_index('key')) 164 | (batch,) = table.to_batches() 165 | if len(batch) == len(self): # no empty or null scalars 166 | return batch 167 | mask = pc.equal(pc.list_value_length(self), 0) 168 | empties = pc.indices_nonzero(Column.fill_null(mask, True)) 169 | indices = pa.chunked_array(indices.chunks + [empties.cast(indices.type)]) 170 | columns = {} 171 | for field in batch.schema: 172 | scalar = pa.scalar(0 if 'count' in field.name else None, field.type) 173 | columns[field.name] = pa.repeat(scalar, len(empties)) 174 | table = pa.concat_tables([table, pa.table(columns)]).combine_chunks() 175 | return table.to_batches()[0].take(pc.sort_indices(indices)) 176 | 177 | def min_max(self, **options) -> pa.Array: 178 | if pa.types.is_dictionary(self.type.value_type): 179 | (self,) = ListChunk.aggregate(self, distinct=None) 180 | self = type(self).from_arrays(self.offsets, self.values.dictionary_decode()) 181 | return ListChunk.aggregate(self, min_max=pc.ScalarAggregateOptions(**options))[0] 182 | 183 | def min(self, **options) -> pa.Array: 184 | """min value of each list scalar""" 185 | return ListChunk.min_max(self, **options).field('min') 186 | 187 | def max(self, **options) -> pa.Array: 188 | """max value of each list scalar""" 189 | return ListChunk.min_max(self, **options).field('max') 190 | 191 | def mode(self, **options) -> pa.Array: 192 | """modes of each list scalar""" 193 | return ListChunk.map_list(self, pc.mode, **options) 194 | 195 | def quantile(self, **options) -> pa.Array: 196 | """quantiles of each list scalar""" 197 | return ListChunk.map_list(self, pc.quantile, **options) 198 | 199 | def index(self, **options) -> pa.Array: 200 | """index for first occurrence of each list scalar""" 201 | return pa.array(pc.index(value, **options) for value in ListChunk.scalars(self)) 202 | 203 | @register 204 | def list_all(ctx, self: pa.list_(pa.bool_())) -> pa.bool_(): # type: ignore 205 | """Test whether all elements in a boolean array evaluate to true.""" 206 | return ListChunk.aggregate(self, all=None)[0] 207 | 208 | @register 209 | def list_any(ctx, self: pa.list_(pa.bool_())) -> pa.bool_(): # type: ignore 210 | """Test whether any element in a boolean array evaluates to true.""" 211 | return ListChunk.aggregate(self, any=None)[0] 212 | 213 | 214 | class Column(pa.ChunkedArray): 215 | """Chunked array interface as a namespace of functions.""" 216 | 217 | def is_list_type(self): 218 | funcs = pa.types.is_list, pa.types.is_large_list, pa.types.is_fixed_size_list 219 | return any(func(self.type) for func in funcs) 220 | 221 | def call_indices(self, func: Callable) -> Array: 222 | if not pa.types.is_dictionary(self.type): 223 | return func(self) 224 | array = self.combine_chunks() 225 | return pa.DictionaryArray.from_arrays(func(array.indices), array.dictionary) 226 | 227 | def fill_null_backward(self) -> Array: 228 | """`fill_null_backward` with dictionary support.""" 229 | return Column.call_indices(self, pc.fill_null_backward) 230 | 231 | def fill_null_forward(self) -> Array: 232 | """`fill_null_forward` with dictionary support.""" 233 | return Column.call_indices(self, pc.fill_null_forward) 234 | 235 | def fill_null(self, value) -> pa.ChunkedArray: 236 | """Optimized `fill_null` to check `null_count`.""" 
237 | return self.fill_null(value) if self.null_count else self 238 | 239 | def sort_values(self) -> Array: 240 | if not pa.types.is_dictionary(self.type): 241 | return self 242 | array = self if isinstance(self, pa.Array) else self.combine_chunks() 243 | return pc.rank(array.dictionary, 'ascending').take(array.indices) 244 | 245 | def pairwise_diff(self, period: int = 1) -> Array: 246 | """`pairwise_diff` with chunked array support.""" 247 | return pc.pairwise_diff(self.combine_chunks(), period) 248 | 249 | def diff(self, func: Callable = pc.subtract, period: int = 1) -> Array: 250 | """Compute first order difference of an array. 251 | 252 | Unlike `pairwise_diff`, does not return leading nulls. 253 | """ 254 | return func(self[period:], self[:-period]) 255 | 256 | def run_offsets(self, predicate: Callable = pc.not_equal, *args) -> pa.IntegerArray: 257 | """Run-end encode array with leading zero, suitable for list offsets. 258 | 259 | Args: 260 | predicate: binary function applied to adjacent values 261 | *args: apply binary function to scalar, using `subtract` as the difference function 262 | """ 263 | ends = [pa.array([True])] 264 | mask = predicate(Column.diff(self), *args) if args else Column.diff(self, predicate) 265 | return pc.indices_nonzero(pa.chunked_array(ends + mask.chunks + ends)) 266 | 267 | def index(self, value, start=0, end=None) -> int: 268 | """Return the first index of a value.""" 269 | with contextlib.suppress(NotImplementedError): 270 | return self.index(value, start, end).as_py() # type: ignore 271 | offset = start 272 | for chunk in self[start:end].iterchunks(): 273 | index = chunk.dictionary.index(value).as_py() 274 | if index >= 0: 275 | index = chunk.indices.index(index).as_py() 276 | if index >= 0: 277 | return offset + index 278 | offset += len(chunk) 279 | return -1 280 | 281 | def range(self, lower=None, upper=None, include_lower=True, include_upper=False) -> slice: 282 | """Return slice within range from a sorted array, by default a half-open interval.""" 283 | method = bisect.bisect_left if include_lower else bisect.bisect_right 284 | start = 0 if lower is None else method(self, Compare(lower)) 285 | method = bisect.bisect_right if include_upper else bisect.bisect_left 286 | stop = None if upper is None else method(self, Compare(upper), start) 287 | return slice(start, stop) 288 | 289 | def find(self, *values) -> Iterator[slice]: 290 | """Generate slices of matching rows from a sorted array.""" 291 | stop = 0 292 | for value in map(Compare, sorted(values)): 293 | start = bisect.bisect_left(self, value, stop) 294 | stop = bisect.bisect_right(self, value, start) 295 | yield slice(start, stop) 296 | 297 | 298 | class Table(pa.Table): 299 | """Table interface as a namespace of functions.""" 300 | 301 | def map_batch(self, func: Callable, *args, **kwargs) -> pa.Table: 302 | return pa.Table.from_batches(func(batch, *args, **kwargs) for batch in self.to_batches()) 303 | 304 | def columns(self) -> dict: 305 | """Return columns as a dictionary.""" 306 | return dict(zip(self.schema.names, self)) 307 | 308 | def union(*tables: Batch) -> Batch: 309 | """Return table with union of columns.""" 310 | columns: dict = {} 311 | for table in tables: 312 | columns |= Table.columns(table) 313 | return type(tables[0]).from_pydict(columns) 314 | 315 | def range(self, name: str, lower=None, upper=None, **includes) -> pa.Table: 316 | """Return rows within range, by default a half-open interval. 317 | 318 | Assumes the table is sorted by the column name, i.e., indexed. 
319 | """ 320 | return self[Column.range(self[name], lower, upper, **includes)] 321 | 322 | def is_in(self, name: str, *values) -> pa.Table: 323 | """Return rows which matches one of the values. 324 | 325 | Assumes the table is sorted by the column name, i.e., indexed. 326 | """ 327 | slices = list(Column.find(self[name], *values)) or [slice(0)] 328 | return pa.concat_tables(self[slc] for slc in slices) 329 | 330 | def not_equal(self, name: str, value) -> pa.Table: 331 | """Return rows which don't match the value. 332 | 333 | Assumes the table is sorted by the column name, i.e., indexed. 334 | """ 335 | (slc,) = Column.find(self[name], value) 336 | return pa.concat_tables([self[: slc.start], self[slc.stop :]]) 337 | 338 | def from_offsets(self, offsets: pa.IntegerArray, mask=None) -> pa.RecordBatch: 339 | """Return record batch with columns converted into list columns.""" 340 | cls = pa.LargeListArray if offsets.type == 'int64' else pa.ListArray 341 | if isinstance(self, pa.Table): 342 | (self,) = self.combine_chunks().to_batches() or [pa.record_batch([], self.schema)] 343 | arrays = [cls.from_arrays(offsets, array, mask=mask) for array in self] 344 | return pa.RecordBatch.from_arrays(arrays, self.schema.names) 345 | 346 | def from_counts(self, counts: pa.IntegerArray) -> pa.RecordBatch: 347 | """Return record batch with columns converted into list columns.""" 348 | mask = None 349 | if counts.null_count: 350 | mask, counts = counts.is_null(), counts.fill_null(0) 351 | offsets = pa.concat_arrays([pa.array([0], counts.type), pc.cumulative_sum_checked(counts)]) 352 | return Table.from_offsets(self, offsets, mask=mask) 353 | 354 | def runs(self, *names: str, **predicates: tuple) -> tuple: 355 | """Return table grouped by pairwise differences, and corresponding counts. 
356 | 357 | Args: 358 | *names: columns to partition by `not_equal` which will return scalars 359 | **predicates: pairwise predicates with optional args which will return list arrays; 360 | if the predicate has args, it will be called on the differences 361 | """ 362 | offsets = pa.chunked_array( 363 | Column.run_offsets(self[name], *predicates.get(name, ())) 364 | for name in names + tuple(predicates) 365 | ) 366 | offsets = offsets.unique().sort() 367 | scalars = self.select(names).take(offsets[:-1]) 368 | lists = self.select(set(self.schema.names) - set(names)) 369 | table = Table.union(scalars, Table.from_offsets(lists, offsets)) 370 | return table, Column.diff(offsets) 371 | 372 | def list_fields(self) -> set: 373 | return {field.name for field in self.schema if Column.is_list_type(field)} 374 | 375 | def list_value_length(self) -> pa.Array: 376 | lists = Table.list_fields(self) 377 | if not lists: 378 | raise ValueError(f"no list columns available: {self.schema.names}") 379 | counts, *others = (pc.list_value_length(self[name]) for name in lists) 380 | if any(counts != other for other in others): 381 | raise ValueError(f"list columns have different value lengths: {lists}") 382 | return counts if isinstance(counts, pa.Array) else counts.chunk(0) 383 | 384 | def map_list(self, func: Callable, *args, **kwargs) -> Batch: 385 | """Return table with function mapped across list scalars.""" 386 | batches: Iterable = Table.split(self.select(Table.list_fields(self))) 387 | batches = [None if batch is None else func(batch, *args, **kwargs) for batch in batches] 388 | counts = pa.array(None if batch is None else len(batch) for batch in batches) 389 | table = pa.Table.from_batches(batch for batch in batches if batch is not None) 390 | return Table.union(self, Table.from_counts(table, counts)) 391 | 392 | def sort_indices( 393 | self, *names: str, length: int | None = None, null_placement: str = 'at_end' 394 | ) -> pa.Array: 395 | """Return indices which would sort the table by columns, optimized for fixed length.""" 396 | func = functools.partial(pc.sort_indices, null_placement=null_placement) 397 | if length is not None and length < len(self): 398 | func = functools.partial(pc.select_k_unstable, k=length) 399 | keys = dict(map(sort_key, names)) 400 | table = pa.table({name: Column.sort_values(self[name]) for name in keys}) 401 | return func(table, sort_keys=keys.items()) if table else pa.array([], 'int64') 402 | 403 | def sort( 404 | self, 405 | *names: str, 406 | length: int | None = None, 407 | indices: str = '', 408 | null_placement: str = 'at_end', 409 | ) -> Batch: 410 | """Return table sorted by columns, optimized for fixed length. 
411 | 412 | Args: 413 | *names: columns to sort by 414 | length: maximum number of rows to return 415 | indices: include original indices in the table 416 | """ 417 | if length == 1 and not indices: 418 | return Table.min_max(self, *names)[:1] 419 | indices_ = Table.sort_indices(self, *names, length=length, null_placement=null_placement) 420 | table = self.take(indices_) 421 | if indices: 422 | table = table.append_column(indices, indices_) 423 | func = lambda name: not name.startswith('-') and not self[name].null_count # noqa: E731 424 | metadata = {'index_columns': list(itertools.takewhile(func, names))} 425 | return table.replace_schema_metadata({'pandas': json.dumps(metadata)}) 426 | 427 | def filter_list(self, expr: ds.Expression) -> Batch: 428 | """Return table with list columns filtered within scalars.""" 429 | fields = Table.list_fields(self) 430 | tables = [ 431 | None if batch is None else pa.Table.from_batches([batch]).filter(expr).select(fields) 432 | for batch in Table.split(self) 433 | ] 434 | counts = pa.array(None if table is None else len(table) for table in tables) 435 | table = pa.concat_tables(table for table in tables if table is not None) 436 | return Table.union(self, Table.from_counts(table, counts)) 437 | 438 | def min_max(self, *names: str) -> Self: 439 | """Return table filtered by minimum or maximum values.""" 440 | for key, order in map(sort_key, names): 441 | field, asc = pc.field(key), (order == 'ascending') 442 | ((value,),) = Nodes.group(self, _=(key, ('min' if asc else 'max'), None)).to_table() 443 | self = self.filter(field <= value if asc else field >= value) 444 | return self 445 | 446 | def rank(self, k: int, *names: str) -> Self: 447 | """Return table filtered by values within dense rank, similar to `select_k_unstable`.""" 448 | if k == 1: 449 | return Table.min_max(self, *names) 450 | keys = dict(map(sort_key, names)) 451 | table = Nodes.group(self, *keys).to_table() 452 | table = table.take(pc.select_k_unstable(table, k, keys.items())) 453 | exprs = [] 454 | for key, order in keys.items(): 455 | field, asc = pc.field(key), (order == 'ascending') 456 | exprs.append(field <= pc.max(table[key]) if asc else field >= pc.min(table[key])) 457 | return self.filter(bit_all(exprs)) 458 | 459 | def fragments(self, *names, counts: str = '') -> pa.Table: 460 | """Return selected fragment keys in a table.""" 461 | try: 462 | expr = self._scan_options.get('filter') 463 | if expr is not None: # raise ValueError if filter references other fields 464 | ds.dataset([], schema=self.partitioning.schema).scanner(filter=expr) 465 | except (AttributeError, ValueError): 466 | return pa.table({}) 467 | fragments = self._get_fragments(expr) 468 | parts = [ds.get_partition_keys(frag.partition_expression) for frag in fragments] 469 | names, table = set(names), pa.Table.from_pylist(parts) # type: ignore 470 | keys = [name for name in table.schema.names if name in names] 471 | table = table.group_by(keys, use_threads=False).aggregate([]) 472 | if not counts: 473 | return table 474 | if not table.schema: 475 | return table.append_column(counts, pa.array([self.count_rows()])) 476 | exprs = [bit_all(pc.field(key) == row[key] for key in row) for row in table.to_pylist()] 477 | column = [self.filter(expr).count_rows() for expr in exprs] 478 | return table.append_column(counts, pa.array(column)) 479 | 480 | def rank_keys(self, k: int, *names: str, dense: bool = True) -> tuple: 481 | """Return expression and unmatched fields for partitioned dataset which filters by rank. 
482 | 483 | Args: 484 | k: max dense rank or length 485 | *names: columns to rank by 486 | dense: use dense rank; false indicates sorting 487 | """ 488 | keys = dict(map(sort_key, names)) 489 | table = Table.fragments(self, *keys, counts='' if dense else '_') 490 | keys = {name: keys[name] for name in table.schema.names if name in keys} 491 | if not keys: 492 | return None, names 493 | if dense: 494 | table = table.take(pc.select_k_unstable(table, k, keys.items())) 495 | else: 496 | table = table.sort_by(keys.items()) 497 | totals = itertools.accumulate(table['_'].to_pylist()) 498 | counts = (count for count, total in enumerate(totals, 1) if total >= k) 499 | table = table[: next(counts, None)].remove_column(len(table) - 1) 500 | exprs = [bit_all(pc.field(key) == row[key] for key in row) for row in table.to_pylist()] 501 | remaining = names[len(keys) :] 502 | if remaining or not dense: # fields with a single value are no longer needed 503 | selectors = [len(table[key].unique()) > 1 for key in keys] 504 | remaining = tuple(itertools.compress(names, selectors)) + remaining 505 | return bit_any(exprs[: len(table)]), remaining 506 | 507 | def flatten(self, indices: str = '') -> Iterator[pa.RecordBatch]: 508 | """Generate batches with list arrays flattened, optionally with parent indices.""" 509 | offset = 0 510 | for batch in self.to_batches(): 511 | _ = Table.list_value_length(batch) 512 | indices_ = pc.list_parent_indices(batch[Table.list_fields(batch).pop()]) 513 | arrays = [ 514 | pc.list_flatten(array) if Column.is_list_type(array) else array.take(indices_) 515 | for array in batch 516 | ] 517 | columns = dict(zip(batch.schema.names, arrays)) 518 | if indices: 519 | columns[indices] = pc.add(indices_, offset) 520 | offset += len(batch) 521 | yield pa.RecordBatch.from_pydict(columns) 522 | 523 | def split(self) -> Iterator[pa.RecordBatch | None]: 524 | """Generate tables from splitting list scalars.""" 525 | lists = Table.list_fields(self) 526 | scalars = set(self.schema.names) - lists 527 | for index, count in enumerate(Table.list_value_length(self).to_pylist()): 528 | if count is None: 529 | yield None 530 | else: 531 | row = {name: pa.repeat(self[name][index], count) for name in scalars} 532 | row |= {name: self[name][index].values for name in lists} 533 | yield pa.RecordBatch.from_pydict(row) 534 | 535 | def size(self) -> str: 536 | """Return buffer size in readable units.""" 537 | size, prefix = self.nbytes, '' 538 | for prefix in itertools.takewhile(lambda _: size >= 1e3, 'kMGT'): 539 | size /= 1e3 540 | return f'{size:n} {prefix}B' 541 | 542 | 543 | class Nodes(ac.Declaration): 544 | """[Acero](https://arrow.apache.org/docs/python/api/acero.html) engine declaration. 545 | 546 | Provides a `Scanner` interface with no "oneshot" limitation. 
547 | """ 548 | 549 | option_map = { 550 | 'table_source': ac.TableSourceNodeOptions, 551 | 'scan': ac.ScanNodeOptions, 552 | 'filter': ac.FilterNodeOptions, 553 | 'project': ac.ProjectNodeOptions, 554 | 'aggregate': ac.AggregateNodeOptions, 555 | 'order_by': ac.OrderByNodeOptions, 556 | 'hashjoin': ac.HashJoinNodeOptions, 557 | } 558 | to_batches = ac.Declaration.to_reader # source compatibility 559 | 560 | def __init__(self, name, *args, inputs=None, **options): 561 | super().__init__(name, self.option_map[name](*args, **options), inputs) 562 | 563 | def scan(self, columns: Iterable[str]) -> Self: 564 | """Return projected source node, supporting datasets and tables.""" 565 | if isinstance(self, ds.Dataset): 566 | expr = self._scan_options.get('filter') 567 | self = Nodes('scan', self, columns=columns) 568 | if expr is not None: 569 | self = self.apply('filter', expr) 570 | elif isinstance(self, pa.Table): 571 | self = Nodes('table_source', self) 572 | elif isinstance(self, pa.RecordBatch): 573 | self = Nodes('table_source', pa.table(self)) 574 | if isinstance(columns, Mapping): 575 | return self.apply('project', columns.values(), columns) 576 | return self.apply('project', map(pc.field, columns)) 577 | 578 | @property 579 | def schema(self) -> pa.Schema: 580 | """projected schema""" 581 | with self.to_reader() as reader: 582 | return reader.schema 583 | 584 | def scanner(self, **options) -> ds.Scanner: 585 | return ds.Scanner.from_batches(self.to_reader(**options)) 586 | 587 | def count_rows(self) -> int: 588 | """Count matching rows.""" 589 | return self.scanner().count_rows() 590 | 591 | def head(self, num_rows: int, **options) -> pa.Table: 592 | """Load the first N rows.""" 593 | return self.scanner(**options).head(num_rows) 594 | 595 | def take(self, indices: Iterable[int], **options) -> pa.Table: 596 | """Select rows by index.""" 597 | return self.scanner(**options).take(indices) 598 | 599 | def apply(self, name: str, *args, **options) -> Self: 600 | """Add a node by name.""" 601 | return type(self)(name, *args, inputs=[self], **options) 602 | 603 | filter = functools.partialmethod(apply, 'filter') 604 | 605 | def group(self, *names, **aggs: tuple) -> Self: 606 | """Add `aggregate` node with dictionary support. 607 | 608 | Also supports datasets because aggregation determines the projection. 609 | """ 610 | aggregates, targets = [], set(names) 611 | for name, (target, _, _) in aggs.items(): 612 | aggregates.append(aggs[name] + (name,)) 613 | targets.update([target] if isinstance(target, str) else target) 614 | columns = {name: pc.field(name) for name in targets} 615 | for name in columns: 616 | field = self.schema.field(name) 617 | if pa.types.is_dictionary(field.type): 618 | columns[name] = columns[name].cast(field.type.value_type) 619 | return Nodes.scan(self, columns).apply('aggregate', aggregates, names) 620 | -------------------------------------------------------------------------------- /graphique/inputs.py: -------------------------------------------------------------------------------- 1 | """ 2 | GraphQL input types. 
3 | """ 4 | 5 | from __future__ import annotations 6 | import functools 7 | import inspect 8 | import operator 9 | from collections.abc import Callable, Iterable 10 | from datetime import date, datetime, time, timedelta 11 | from decimal import Decimal 12 | from typing import Generic, TypeVar, no_type_check 13 | import pyarrow as pa 14 | import pyarrow.compute as pc 15 | import pyarrow.dataset as ds 16 | import strawberry 17 | from strawberry import UNSET 18 | from strawberry.annotation import StrawberryAnnotation 19 | from strawberry.types.arguments import StrawberryArgument 20 | from strawberry.schema_directive import Location 21 | from strawberry.types.field import StrawberryField 22 | from strawberry.scalars import JSON 23 | from typing_extensions import Self 24 | from .core import Agg 25 | from .scalars import Long 26 | 27 | T = TypeVar('T') 28 | 29 | 30 | class links: 31 | compute = 'https://arrow.apache.org/docs/python/api/compute.html' 32 | type = '[arrow type](https://arrow.apache.org/docs/python/api/datatypes.html)' 33 | 34 | 35 | class Input: 36 | """Common utilities for input types.""" 37 | 38 | nullables: set = set() 39 | 40 | def keys(self): 41 | for name, value in self.__dict__.items(): 42 | if value is None and name not in self.nullables: 43 | raise TypeError(f"`{self.__class__.__name__}.{name}` is optional, not nullable") 44 | if value is not UNSET: 45 | yield name 46 | 47 | def __getitem__(self, name): 48 | value = getattr(self, name) 49 | return dict(value) if hasattr(value, 'keys') else value 50 | 51 | 52 | def use_doc(decorator: Callable, **kwargs): 53 | return lambda func: decorator(description=inspect.getdoc(func), **kwargs)(func) 54 | 55 | 56 | @use_doc( 57 | strawberry.schema_directive, 58 | locations=[Location.ARGUMENT_DEFINITION, Location.INPUT_FIELD_DEFINITION], 59 | ) 60 | class optional: 61 | """This input is optional, not nullable. 62 | If the client insists on sending an explicit null value, the behavior is undefined. 
63 | """ 64 | 65 | 66 | @use_doc( 67 | strawberry.schema_directive, 68 | locations=[Location.ARGUMENT_DEFINITION, Location.FIELD_DEFINITION], 69 | ) 70 | class provisional: 71 | """Provisional feature; subject to change in the future.""" 72 | 73 | 74 | def default_field( 75 | default=UNSET, func: Callable | None = None, nullable: bool = False, **kwargs 76 | ) -> StrawberryField: 77 | """Use dataclass `default_factory` for `UNSET` or mutables.""" 78 | if func is not None: 79 | kwargs['description'] = inspect.getdoc(func).splitlines()[0] # type: ignore 80 | if not nullable and default is UNSET: 81 | kwargs.setdefault('directives', []).append(optional()) 82 | return strawberry.field(default_factory=type(default), **kwargs) 83 | 84 | 85 | @strawberry.input(description="predicates for scalars") 86 | class Filter(Generic[T], Input): 87 | eq: list[T | None] | None = default_field( 88 | description="== or `isin`; `null` is equivalent to arrow `is_null`.", nullable=True 89 | ) 90 | ne: T | None = default_field( 91 | description="!=; `null` is equivalent to arrow `is_valid`.", nullable=True 92 | ) 93 | lt: T | None = default_field(description="<") 94 | le: T | None = default_field(description="<=") 95 | gt: T | None = default_field(description=r"\>") 96 | ge: T | None = default_field(description=r"\>=") 97 | 98 | nullables = {'eq', 'ne'} 99 | 100 | @classmethod 101 | def resolve_args(cls, types: dict) -> Iterable[StrawberryArgument]: 102 | """Generate dynamically resolved arguments for filter field.""" 103 | for name in types: 104 | annotation = StrawberryAnnotation(cls[types[name]]) # type: ignore 105 | if types[name] not in (list, dict): 106 | yield StrawberryArgument(name, name, annotation, default={}) 107 | 108 | 109 | @strawberry.input(description=f"name and optional alias for [compute functions]({links.compute})") 110 | class Field(Agg): 111 | name: str = strawberry.field(description="column name") 112 | alias: str = strawberry.field(default='', description="output column name") 113 | 114 | __init__ = Agg.__init__ 115 | 116 | def __init_subclass__(cls): 117 | cls.__init__ = cls.__init__ 118 | 119 | 120 | @strawberry.input 121 | class Cumulative(Field): 122 | start: float = 0.0 123 | skip_nulls: bool = False 124 | checked: bool = False 125 | 126 | 127 | @strawberry.input 128 | class Pairwise(Field): 129 | period: int = 1 130 | checked: bool = False 131 | 132 | 133 | @strawberry.input 134 | class Index(Field): 135 | value: JSON 136 | start: Long = 0 137 | end: Long | None = None 138 | 139 | 140 | @strawberry.input 141 | class Mode(Field): 142 | n: int = 1 143 | skip_nulls: bool = True 144 | min_count: int = 1 145 | 146 | 147 | @strawberry.input 148 | class Quantile(Field): 149 | q: list[float] = (0.5,) # type: ignore 150 | interpolation: str = 'linear' 151 | skip_nulls: bool = True 152 | min_count: int = 1 153 | 154 | 155 | @strawberry.input 156 | class Rank(Field): 157 | sort_keys: str = 'ascending' 158 | null_placement: str = 'at_end' 159 | tiebreaker: str = 'first' 160 | 161 | 162 | @strawberry.input 163 | class Sort: 164 | by: list[str] 165 | length: Long | None = None 166 | 167 | 168 | @strawberry.input 169 | class Ranked: 170 | by: list[str] 171 | max: int = 1 172 | 173 | 174 | @strawberry.input(description=f"[functions]({links.compute}#structural-transforms) for lists") 175 | class ListFunction(Input): 176 | deprecation = "List scalar functions will be moved to `scan(...: {list: ...})`" 177 | 178 | filter: Expression = default_field({}, description="filter within list scalars") 179 | 
sort: Sort | None = default_field(description="sort within list scalars") 180 | rank: Ranked | None = default_field(description="select by dense rank within list scalars") 181 | index: Index | None = default_field(func=pc.index, deprecation_reason=deprecation) 182 | mode: Mode | None = default_field(func=pc.mode, deprecation_reason=deprecation) 183 | quantile: Quantile | None = default_field(func=pc.quantile, deprecation_reason=deprecation) 184 | 185 | def keys(self): 186 | return set(super().keys()) - {'filter', 'sort', 'rank'} 187 | 188 | 189 | @strawberry.input(description=f"options for count [aggregation]({links.compute}#aggregations)") 190 | class CountAggregate(Field): 191 | mode: str = 'only_valid' 192 | 193 | 194 | @strawberry.input(description=f"options for scalar [aggregation]({links.compute}#aggregations)") 195 | class ScalarAggregate(Field): 196 | skip_nulls: bool = True 197 | 198 | 199 | @strawberry.input(description=f"options for variance [aggregation]({links.compute}#aggregations)") 200 | class VarianceAggregate(ScalarAggregate): 201 | ddof: int = 0 202 | 203 | 204 | @strawberry.input(description=f"options for tdigest [aggregation]({links.compute}#aggregations)") 205 | class TDigestAggregate(ScalarAggregate): 206 | q: list[float] = (0.5,) # type: ignore 207 | delta: int = 100 208 | buffer_size: int = 500 209 | 210 | 211 | @strawberry.input 212 | class ScalarAggregates(Input): 213 | all: list[ScalarAggregate] = default_field([], func=pc.all) 214 | any: list[ScalarAggregate] = default_field([], func=pc.any) 215 | approximate_median: list[ScalarAggregate] = default_field([], func=pc.approximate_median) 216 | count: list[CountAggregate] = default_field([], func=pc.count) 217 | count_distinct: list[CountAggregate] = default_field([], func=pc.count_distinct) 218 | first: list[ScalarAggregate] = default_field([], func=pc.first) 219 | first_last: list[ScalarAggregate] = default_field([], func=pc.first_last) 220 | last: list[ScalarAggregate] = default_field([], func=pc.last) 221 | max: list[ScalarAggregate] = default_field([], func=pc.max) 222 | mean: list[ScalarAggregate] = default_field([], func=pc.mean) 223 | min: list[ScalarAggregate] = default_field([], func=pc.min) 224 | min_max: list[ScalarAggregate] = default_field([], func=pc.min_max) 225 | product: list[ScalarAggregate] = default_field([], func=pc.product) 226 | stddev: list[VarianceAggregate] = default_field([], func=pc.stddev) 227 | sum: list[ScalarAggregate] = default_field([], func=pc.sum) 228 | tdigest: list[TDigestAggregate] = default_field([], func=pc.tdigest) 229 | variance: list[VarianceAggregate] = default_field([], func=pc.variance) 230 | 231 | def keys(self): 232 | return (key.rstrip('_') for key in super().keys() if self[key]) 233 | 234 | def __getitem__(self, name): 235 | return super().__getitem__('list_' if name == 'list' else name) 236 | 237 | 238 | @strawberry.input 239 | class HashAggregates(ScalarAggregates): 240 | distinct: list[CountAggregate] = default_field( 241 | [], description="distinct values within each scalar" 242 | ) 243 | list_: list[Field] = default_field([], name='list', description="all values within each scalar") 244 | one: list[Field] = default_field([], description="arbitrary value within each scalar") 245 | 246 | 247 | @use_doc(strawberry.input) 248 | class Diff(Input): 249 | """Discrete difference predicates, applied in forwards direction (array[i + 1] ? array[i]). 250 | 251 | By default compares by not equal. Specifying `null` with a predicate compares pairwise. 
252 | A float computes the discrete difference first; durations may be in float seconds. 253 | """ 254 | 255 | name: str 256 | less: float | None = default_field(name='lt', description="<", nullable=True) 257 | less_equal: float | None = default_field(name='le', description="<=", nullable=True) 258 | greater: float | None = default_field(name='gt', description=r"\>", nullable=True) 259 | greater_equal: float | None = default_field(name='ge', description=r"\>=", nullable=True) 260 | 261 | nullables = {'less', 'less_equal', 'greater', 'greater_equal'} 262 | 263 | 264 | @use_doc(strawberry.input) 265 | class Expression: 266 | """[Dataset expression](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html) 267 | used for [scanning](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Scanner.html). 268 | 269 | Expects one of: a field `name`, a scalar, or an operator with expressions. Single values can be passed for an 270 | [input `List`](https://spec.graphql.org/October2021/#sec-List.Input-Coercion). 271 | * `eq` with a list scalar is equivalent to `isin` 272 | * `eq` with a `null` scalar is equivalent `is_null` 273 | * `ne` with a `null` scalar is equivalent to `is_valid` 274 | """ 275 | 276 | name: list[str] = default_field([], description="field name(s)") 277 | cast: str = strawberry.field(default='', description=f"cast as {links.type}") 278 | safe: bool = strawberry.field(default=True, description="check for conversion errors on cast") 279 | value: JSON | None = default_field( 280 | description="JSON scalar; also see typed scalars", nullable=True 281 | ) 282 | kleene: bool = strawberry.field(default=False, description="use kleene logic for booleans") 283 | checked: bool = strawberry.field(default=False, description="check for overflow errors") 284 | 285 | base64: list[bytes] = default_field([]) 286 | date_: list[date] = default_field([], name='date') 287 | datetime_: list[datetime] = default_field([], name='datetime') 288 | decimal: list[Decimal] = default_field([]) 289 | duration: list[timedelta] = default_field([]) 290 | time_: list[time] = default_field([], name='time') 291 | 292 | eq: list[Expression] = default_field([], description="==") 293 | ne: list[Expression] = default_field([], description="!=") 294 | lt: list[Expression] = default_field([], description="<") 295 | le: list[Expression] = default_field([], description="<=") 296 | gt: list[Expression] = default_field([], description=r"\>") 297 | ge: list[Expression] = default_field([], description=r"\>=") 298 | inv: Expression | None = default_field(description="~") 299 | 300 | abs: Expression | None = default_field(func=pc.abs) 301 | add: list[Expression] = default_field([], func=pc.add) 302 | divide: list[Expression] = default_field([], func=pc.divide) 303 | multiply: list[Expression] = default_field([], func=pc.multiply) 304 | negate: Expression | None = default_field(func=pc.negate) 305 | power: list[Expression] = default_field([], func=pc.power) 306 | sign: Expression | None = default_field(func=pc.sign) 307 | subtract: list[Expression] = default_field([], func=pc.subtract) 308 | 309 | bit_wise: BitWise | None = default_field(description="bit-wise functions") 310 | rounding: Rounding | None = default_field(description="rounding functions") 311 | log: Log | None = default_field(description="logarithmic functions") 312 | trig: Trig | None = default_field(description="trigonometry functions") 313 | element_wise: ElementWise | None = default_field(description="element-wise aggregate functions") 
314 | 315 | and_: list[Expression] = default_field([], name='and', description="&") 316 | and_not: list[Expression] = default_field([], func=pc.and_not) 317 | or_: list[Expression] = default_field([], name='or', description="|") 318 | xor: list[Expression] = default_field([], func=pc.xor) 319 | 320 | utf8: Utf8 | None = default_field(description="utf8 string functions") 321 | string_is_ascii: Expression | None = default_field(func=pc.string_is_ascii) 322 | substring: MatchSubstring | None = default_field(description="match substring functions") 323 | 324 | binary: Binary | None = default_field(description="binary functions") 325 | set_lookup: SetLookup | None = default_field(description="set lookup functions") 326 | 327 | is_finite: Expression | None = default_field(func=pc.is_finite) 328 | is_inf: Expression | None = default_field(func=pc.is_inf) 329 | is_nan: Expression | None = default_field(func=pc.is_nan) 330 | true_unless_null: Expression | None = default_field(func=pc.true_unless_null) 331 | 332 | case_when: list[Expression] = default_field([], func=pc.case_when) 333 | choose: list[Expression] = default_field([], func=pc.choose) 334 | coalesce: list[Expression] = default_field([], func=pc.coalesce) 335 | if_else: list[Expression] = default_field([], func=pc.if_else) 336 | 337 | temporal: Temporal | None = default_field(description="temporal functions") 338 | 339 | replace_with_mask: list[Expression] = default_field([], func=pc.replace_with_mask) 340 | 341 | list: Lists | None = default_field(description="list array functions") 342 | 343 | unaries = ('inv', 'abs', 'negate', 'sign', 'string_is_ascii', 'is_finite', 'is_inf', 'is_nan') 344 | associatives = ('add', 'multiply', 'and_', 'or_', 'xor') 345 | variadics = ('eq', 'ne', 'lt', 'le', 'gt', 'ge', 'divide', 'power', 'subtract', 'and_not') 346 | variadics += ('case_when', 'choose', 'coalesce', 'if_else', 'replace_with_mask') # type: ignore 347 | scalars = ('base64', 'date_', 'datetime_', 'decimal', 'duration', 'time_') 348 | groups = ('bit_wise', 'rounding', 'log', 'trig', 'element_wise', 'utf8', 'substring', 'binary') 349 | groups += ('set_lookup', 'temporal', 'list') # type: ignore 350 | 351 | def to_arrow(self) -> ds.Expression | None: 352 | """Transform GraphQL expression into a dataset expression.""" 353 | fields = [] 354 | if self.name: 355 | fields.append(pc.field(*self.name)) 356 | for name in self.scalars: 357 | scalars = list(map(self.getscalar, getattr(self, name))) 358 | if scalars: 359 | fields.append(scalars[0] if len(scalars) == 1 else scalars) 360 | if self.value is not UNSET: 361 | fields.append(self.getscalar(self.value)) 362 | for op in self.associatives: 363 | exprs = [expr.to_arrow() for expr in getattr(self, op)] 364 | if exprs: 365 | fields.append(functools.reduce(self.getfunc(op), exprs)) 366 | for op in self.variadics: 367 | exprs = [expr.to_arrow() for expr in getattr(self, op)] 368 | if exprs: 369 | if op == 'eq' and isinstance(exprs[-1], list): 370 | field = ds.Expression.isin(*exprs) 371 | elif exprs[-1] is None and op in ('eq', 'ne'): 372 | field, _ = exprs 373 | field = field.is_null() if op == 'eq' else field.is_valid() 374 | else: 375 | field = self.getfunc(op)(*exprs) 376 | fields.append(field) 377 | for group in operator.attrgetter(*self.groups)(self): 378 | if group is not UNSET: 379 | fields += group.to_fields() 380 | for op in self.unaries: 381 | expr = getattr(self, op) 382 | if expr is not UNSET: 383 | fields.append(self.getfunc(op)(expr.to_arrow())) 384 | if not fields: 385 | return None 386 | 
if len(fields) > 1: 387 | raise ValueError(f"conflicting inputs: {', '.join(map(str, fields))}") 388 | (field,) = fields 389 | cast = self.cast and isinstance(field, ds.Expression) 390 | return field.cast(self.cast, self.safe) if cast else field 391 | 392 | def getscalar(self, value): 393 | return pa.scalar(value, self.cast) if self.cast else value 394 | 395 | def getfunc(self, name): 396 | if self.kleene: 397 | name = name.rstrip('_') + '_kleene' 398 | if self.checked: 399 | name += '_checked' 400 | if name.endswith('_'): # `and_` and `or_` functions differ from operators 401 | return getattr(operator, name) 402 | return getattr(pc if hasattr(pc, name) else operator, name) 403 | 404 | @classmethod 405 | @no_type_check 406 | def from_query(cls, **queries: Filter) -> Self: 407 | """Transform query syntax into an Expression input.""" 408 | exprs = [] 409 | for name, query in queries.items(): 410 | field = cls(name=[name]) 411 | exprs += (cls(**{op: [field, cls(value=value)]}) for op, value in dict(query).items()) 412 | return cls(and_=exprs) 413 | 414 | 415 | @strawberry.input(description="an `Expression` with an optional alias") 416 | class Projection(Expression): 417 | alias: str = strawberry.field(default='', description="name of projected column") 418 | 419 | 420 | class Fields: 421 | """Fields grouped by naming conventions or common options.""" 422 | 423 | prefix: str = '' 424 | 425 | def to_fields(self) -> Iterable[ds.Expression]: 426 | funcs, arguments, options = [], [], {} 427 | for field in self.__strawberry_definition__.fields: # type: ignore 428 | value = getattr(self, field.name) 429 | if isinstance(value, Expression): 430 | value = [value] 431 | if not isinstance(value, (list, type(UNSET))): 432 | options[field.name] = value 433 | elif value: 434 | funcs.append(self.getfunc(field.name)) 435 | arguments.append([expr.to_arrow() for expr in value]) 436 | for func, args in zip(funcs, arguments): 437 | keys = set(options) & set(inspect.signature(func).parameters) 438 | yield func(*args, **{key: options[key] for key in keys}) 439 | 440 | def getfunc(self, name): 441 | return getattr(pc, self.prefix + name) 442 | 443 | 444 | @strawberry.input(description="Bit-wise functions.") 445 | class BitWise(Fields): 446 | and_: list[Expression] = default_field([], name='and', func=pc.bit_wise_and) 447 | not_: list[Expression] = default_field([], name='not', func=pc.bit_wise_not) 448 | or_: list[Expression] = default_field([], name='or', func=pc.bit_wise_or) 449 | xor: list[Expression] = default_field([], func=pc.bit_wise_xor) 450 | shift_left: list[Expression] = default_field([], func=pc.shift_left) 451 | shift_right: list[Expression] = default_field([], func=pc.shift_right) 452 | 453 | def getfunc(self, name): 454 | return getattr(pc, name if name.startswith('shift') else 'bit_wise_' + name.rstrip('_')) 455 | 456 | 457 | @strawberry.input(description="Rounding functions.") 458 | class Rounding(Fields): 459 | ceil: Expression | None = default_field(func=pc.ceil) 460 | floor: Expression | None = default_field(func=pc.floor) 461 | trunc: Expression | None = default_field(func=pc.trunc) 462 | 463 | round: Expression | None = default_field(func=pc.round) 464 | ndigits: int = 0 465 | round_mode: str = 'half_to_even' 466 | multiple: float = 1.0 467 | 468 | def getfunc(self, name): 469 | if name == 'round' and self.multiple != 1.0: 470 | name = 'round_to_multiple' 471 | return getattr(pc, name) 472 | 473 | 474 | @strawberry.input(description="Logarithmic functions.") 475 | class Log(Fields): 476 | ln: 
Expression | None = default_field(func=pc.ln) 477 | log1p: Expression | None = default_field(func=pc.log1p) 478 | logb: list[Expression] = default_field([], func=pc.logb) 479 | 480 | 481 | @strawberry.input(description="Trigonometry functions.") 482 | class Trig(Fields): 483 | checked: bool = strawberry.field(default=False, description="check for overflow errors") 484 | 485 | acos: Expression | None = default_field(func=pc.acos) 486 | asin: Expression | None = default_field(func=pc.asin) 487 | atan: Expression | None = default_field(func=pc.atan) 488 | atan2: list[Expression] = default_field([], func=pc.atan2) 489 | cos: Expression | None = default_field(func=pc.cos) 490 | sin: Expression | None = default_field(func=pc.sin) 491 | tan: Expression | None = default_field(func=pc.tan) 492 | 493 | def getfunc(self, name): 494 | return getattr(pc, name + ('_checked' * self.checked)) 495 | 496 | 497 | @strawberry.input(description="Element-wise aggregate functions.") 498 | class ElementWise(Fields): 499 | min_element_wise: list[Expression] = default_field([], name='min', func=pc.min_element_wise) 500 | max_element_wise: list[Expression] = default_field([], name='max', func=pc.max_element_wise) 501 | skip_nulls: bool = True 502 | 503 | 504 | @strawberry.input(description="Utf8 string functions.") 505 | class Utf8(Fields): 506 | is_alnum: Expression | None = default_field(func=pc.utf8_is_alnum) 507 | is_alpha: Expression | None = default_field(func=pc.utf8_is_alpha) 508 | is_decimal: Expression | None = default_field(func=pc.utf8_is_decimal) 509 | is_digit: Expression | None = default_field(func=pc.utf8_is_digit) 510 | is_lower: Expression | None = default_field(func=pc.utf8_is_lower) 511 | is_numeric: Expression | None = default_field(func=pc.utf8_is_numeric) 512 | is_printable: Expression | None = default_field(func=pc.utf8_is_printable) 513 | is_space: Expression | None = default_field(func=pc.utf8_is_space) 514 | is_title: Expression | None = default_field(func=pc.utf8_is_title) 515 | is_upper: Expression | None = default_field(func=pc.utf8_is_upper) 516 | 517 | capitalize: Expression | None = default_field(func=pc.utf8_capitalize) 518 | length: Expression | None = default_field(func=pc.utf8_length) 519 | lower: Expression | None = default_field(func=pc.utf8_lower) 520 | reverse: Expression | None = default_field(func=pc.utf8_reverse) 521 | swapcase: Expression | None = default_field(func=pc.utf8_swapcase) 522 | title: Expression | None = default_field(func=pc.utf8_title) 523 | upper: Expression | None = default_field(func=pc.utf8_upper) 524 | 525 | ltrim: Expression | None = default_field(func=pc.utf8_ltrim) 526 | rtrim: Expression | None = default_field(func=pc.utf8_rtrim) 527 | trim: Expression | None = default_field(func=pc.utf8_trim) 528 | characters: str = default_field('', description="trim options; by default trims whitespace") 529 | 530 | replace_slice: Expression | None = default_field(func=pc.utf8_replace_slice) 531 | slice_codeunits: Expression | None = default_field(func=pc.utf8_slice_codeunits) 532 | start: int = 0 533 | stop: int | None = UNSET 534 | step: int = 1 535 | replacement: str = '' 536 | 537 | center: Expression | None = default_field(func=pc.utf8_center) 538 | lpad: Expression | None = default_field(func=pc.utf8_lpad) 539 | rpad: Expression | None = default_field(func=pc.utf8_rpad) 540 | width: int = 0 541 | padding: str = '' 542 | 543 | def getfunc(self, name): 544 | if name.endswith('trim') and not self.characters: 545 | name += '_whitespace' 546 | return getattr(pc, 
'utf8_' + name) 547 | 548 | 549 | @strawberry.input(description="Binary functions.") 550 | class Binary(Fields): 551 | length: Expression | None = default_field(func=pc.binary_length) 552 | repeat: list[Expression] = default_field([], func=pc.binary_repeat) 553 | reverse: Expression | None = default_field(func=pc.binary_reverse) 554 | 555 | join: list[Expression] = default_field([], func=pc.binary_join) 556 | join_element_wise: list[Expression] = default_field([], func=pc.binary_join_element_wise) 557 | null_handling: str = 'emit_null' 558 | null_replacement: str = '' 559 | 560 | replace_slice: Expression | None = default_field(func=pc.binary_replace_slice) 561 | start: int = 0 562 | stop: int = 0 563 | replacement: bytes = b'' 564 | 565 | prefix = 'binary_' 566 | 567 | 568 | @strawberry.input(description="Match substring functions.") 569 | class MatchSubstring(Fields): 570 | count_substring: Expression | None = default_field(name='count', func=pc.count_substring) 571 | ends_with: Expression | None = default_field(func=pc.ends_with) 572 | find_substring: Expression | None = default_field(name='find', func=pc.find_substring) 573 | match_substring: Expression | None = default_field(name='match', func=pc.match_substring) 574 | starts_with: Expression | None = default_field(func=pc.starts_with) 575 | replace_substring: Expression | None = default_field(name='replace', func=pc.replace_substring) 576 | split_pattern: Expression | None = default_field(name='split', func=pc.split_pattern) 577 | extract: Expression | None = default_field(func=pc.extract_regex) 578 | pattern: str = '' 579 | ignore_case: bool = False 580 | regex: bool = False 581 | replacement: str = '' 582 | max_replacements: int | None = None 583 | max_splits: int | None = None 584 | reverse: bool = False 585 | 586 | def getfunc(self, name): 587 | if name == 'split_pattern' and not self.pattern: 588 | name = 'utf8_split_whitespace' 589 | return getattr(pc, name + ('_regex' * self.regex)) 590 | 591 | 592 | @strawberry.input(description="Temporal functions.") 593 | class Temporal(Fields): 594 | day: Expression | None = default_field(func=pc.day) 595 | day_of_year: Expression | None = default_field(func=pc.day_of_year) 596 | hour: Expression | None = default_field(func=pc.hour) 597 | iso_week: Expression | None = default_field(func=pc.iso_week) 598 | iso_year: Expression | None = default_field(func=pc.iso_year) 599 | iso_calendar: Expression | None = default_field(func=pc.iso_calendar) 600 | is_leap_year: Expression | None = default_field(func=pc.is_leap_year) 601 | 602 | microsecond: Expression | None = default_field(func=pc.microsecond) 603 | millisecond: Expression | None = default_field(func=pc.millisecond) 604 | minute: Expression | None = default_field(func=pc.minute) 605 | month: Expression | None = default_field(func=pc.month) 606 | nanosecond: Expression | None = default_field(func=pc.nanosecond) 607 | quarter: Expression | None = default_field(func=pc.quarter) 608 | second: Expression | None = default_field(func=pc.second) 609 | subsecond: Expression | None = default_field(func=pc.subsecond) 610 | us_week: Expression | None = default_field(func=pc.us_week) 611 | us_year: Expression | None = default_field(func=pc.us_year) 612 | year: Expression | None = default_field(func=pc.year) 613 | year_month_day: Expression | None = default_field(func=pc.year_month_day) 614 | 615 | day_time_interval_between: list[Expression] = default_field( 616 | [], func=pc.day_time_interval_between 617 | ) 618 | days_between: list[Expression] = 
default_field([], func=pc.days_between) 619 | hours_between: list[Expression] = default_field([], func=pc.hours_between) 620 | microseconds_between: list[Expression] = default_field([], func=pc.microseconds_between) 621 | milliseconds_between: list[Expression] = default_field([], func=pc.milliseconds_between) 622 | minutes_between: list[Expression] = default_field([], func=pc.minutes_between) 623 | month_day_nano_interval_between: list[Expression] = default_field( 624 | [], func=pc.month_day_nano_interval_between 625 | ) 626 | month_interval_between: list[Expression] = default_field([], func=pc.month_interval_between) 627 | nanoseconds_between: list[Expression] = default_field([], func=pc.nanoseconds_between) 628 | quarters_between: list[Expression] = default_field([], func=pc.quarters_between) 629 | seconds_between: list[Expression] = default_field([], func=pc.seconds_between) 630 | weeks_between: list[Expression] = default_field([], func=pc.weeks_between) 631 | years_between: list[Expression] = default_field([], func=pc.years_between) 632 | 633 | ceil_temporal: Expression | None = default_field(name='ceil', func=pc.ceil_temporal) 634 | floor_temporal: Expression | None = default_field(name='floor', func=pc.floor_temporal) 635 | round_temporal: Expression | None = default_field(name='round', func=pc.round_temporal) 636 | multiple: int = 1 637 | unit: str = 'day' 638 | week_starts_monday: bool = True 639 | ceil_is_strictly_greater: bool = False 640 | calendar_based_origin: bool = False 641 | 642 | week: Expression | None = default_field(func=pc.week) 643 | count_from_zero: bool | None = UNSET 644 | first_week_is_fully_in_year: bool = False 645 | 646 | day_of_week: Expression | None = default_field(func=pc.day_of_week) 647 | week_start: int = 1 648 | 649 | strftime: Expression | None = default_field(func=pc.strftime) 650 | strptime: Expression | None = default_field(func=pc.strptime) 651 | format: str = '%Y-%m-%dT%H:%M:%S' 652 | locale: str = 'C' 653 | error_is_null: bool = False 654 | 655 | assume_timezone: Expression | None = default_field(func=pc.assume_timezone) 656 | timezone: str = '' 657 | ambiguous: str = 'raise' 658 | nonexistent: str = 'raise' 659 | 660 | 661 | @strawberry.input(description="Set lookup functions.") 662 | class SetLookup(Fields): 663 | index_in: list[Expression] = default_field([], func=pc.index_in) 664 | digitize: list[Expression] = default_field( 665 | [], 666 | description="numpy [digitize](https://numpy.org/doc/stable/reference/generated/numpy.digitize.html)", 667 | ) 668 | skip_nulls: bool = False 669 | right: bool = False 670 | 671 | def to_fields(self) -> Iterable[ds.Expression]: 672 | if self.index_in: 673 | values, value_set = [expr.to_arrow() for expr in self.index_in] 674 | yield pc.index_in(values, pa.array(value_set), skip_nulls=self.skip_nulls) 675 | if self.digitize: 676 | values, value_set = [expr.to_arrow() for expr in self.digitize] 677 | args = values.cast('float64'), list(map(float, value_set)), self.right # type: ignore 678 | yield ds.Expression._call('digitize', list(args)) 679 | 680 | 681 | @strawberry.input(description="List array functions.") 682 | class Lists(Fields): 683 | element: list[Expression] = default_field([], func=pc.list_element) 684 | value_length: Expression | None = default_field(func=pc.list_value_length) 685 | # user defined functions 686 | all: Expression | None = default_field(func=pc.all) 687 | any: Expression | None = default_field(func=pc.any) 688 | 689 | slice: Expression | None = default_field(func=pc.list_slice) 690 
| start: int = 0 691 | stop: int | None = None 692 | step: int = 1 693 | return_fixed_size_list: bool | None = None 694 | 695 | prefix = 'list_' 696 | 697 | def getfunc(self, name): 698 | if name in ('element', 'value_length', 'slice'): # built-ins 699 | return super().getfunc(name) 700 | return lambda *args: ds.Expression._call(self.prefix + name, list(args)) 701 | -------------------------------------------------------------------------------- /graphique/interface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Primary Dataset interface. 3 | 4 | Doesn't require knowledge of the schema. 5 | """ 6 | 7 | # mypy: disable-error-code=valid-type 8 | import collections 9 | import inspect 10 | import itertools 11 | from collections.abc import Callable, Iterable, Iterator, Mapping, Sized 12 | from datetime import timedelta 13 | from typing import Annotated, TypeAlias, no_type_check 14 | import pyarrow as pa 15 | import pyarrow.compute as pc 16 | import pyarrow.dataset as ds 17 | import strawberry.asgi 18 | from strawberry import Info 19 | from strawberry.extensions.utils import get_path_from_info 20 | from typing_extensions import Self 21 | from .core import Agg, Batch, Column as C, ListChunk, Nodes, Table as T 22 | from .inputs import CountAggregate, Cumulative, Diff, Expression, Field, Filter 23 | from .inputs import HashAggregates, ListFunction, Pairwise, Projection, Rank 24 | from .inputs import ScalarAggregate, TDigestAggregate, VarianceAggregate, links, provisional 25 | from .models import Column, doc_field 26 | from .scalars import Long 27 | 28 | Source: TypeAlias = ds.Dataset | Nodes | ds.Scanner | pa.Table 29 | 30 | 31 | def references(field) -> Iterator: 32 | """Generate every possible column reference from strawberry `SelectedField`.""" 33 | if isinstance(field, str): 34 | yield field.lstrip('-') 35 | elif isinstance(field, Iterable): 36 | for value in field: 37 | yield from references(value) 38 | if isinstance(field, Mapping): 39 | for value in field.values(): 40 | yield from references(value) 41 | else: 42 | for name in ('name', 'arguments', 'selections'): 43 | yield from references(getattr(field, name, [])) 44 | 45 | 46 | def doc_argument(annotation, func: Callable, **kwargs): 47 | """Use function doc for argument description.""" 48 | kwargs['description'] = inspect.getdoc(func).splitlines()[0] # type: ignore 49 | return Annotated[annotation, strawberry.argument(**kwargs)] 50 | 51 | 52 | @strawberry.type(description="dataset schema") 53 | class Schema: 54 | names: list[str] = strawberry.field(description="field names") 55 | types: list[str] = strawberry.field( 56 | description="[arrow types](https://arrow.apache.org/docs/python/api/datatypes.html), corresponding to `names`" 57 | ) 58 | partitioning: list[str] = strawberry.field(description="partition keys") 59 | index: list[str] = strawberry.field(description="sorted index columns") 60 | 61 | 62 | @strawberry.interface(description="an arrow dataset, scanner, or table") 63 | class Dataset: 64 | def __init__(self, source: Source): 65 | self.source = source 66 | 67 | def references(self, info: Info, level: int = 0) -> set: 68 | """Return set of every possible future column reference.""" 69 | fields = info.selected_fields 70 | for _ in range(level): 71 | fields = itertools.chain(*[field.selections for field in fields]) 72 | return set(itertools.chain(*map(references, fields))) & set(self.schema().names) 73 | 74 | def select(self, info: Info) -> Source: 75 | """Return source with only the 
columns necessary to proceed.""" 76 | names = list(self.references(info)) 77 | if len(names) >= len(self.schema().names): 78 | return self.source 79 | if isinstance(self.source, ds.Scanner): 80 | schema = self.source.projected_schema 81 | return ds.Scanner.from_batches(self.source.to_batches(), schema=schema, columns=names) 82 | if isinstance(self.source, pa.Table): 83 | return self.source.select(names) 84 | return Nodes.scan(self.source, names) 85 | 86 | def to_table(self, info: Info, length: int | None = None) -> pa.Table: 87 | """Return table with only the rows and columns necessary to proceed.""" 88 | source = self.select(info) 89 | if isinstance(source, pa.Table): 90 | return source 91 | if length is None: 92 | return self.add_metric(info, source.to_table(), mode='read') 93 | return self.add_metric(info, source.head(length), mode='head') 94 | 95 | @classmethod 96 | @no_type_check 97 | def resolve_reference(cls, info: Info, **keys) -> Self: 98 | """Return table from federated keys.""" 99 | self = getattr(info.root_value, cls.field) 100 | queries = {name: Filter(eq=[keys[name]]) for name in keys} 101 | return self.filter(info, **queries) 102 | 103 | def columns(self, info: Info) -> dict: 104 | """fields for each column""" 105 | table = self.to_table(info) 106 | return {name: Column.cast(table[name]) for name in table.schema.names} 107 | 108 | def row(self, info: Info, index: int = 0) -> dict: 109 | """Return scalar values at index.""" 110 | table = self.to_table(info, index + 1 if index >= 0 else None) 111 | row = {} 112 | for name in table.schema.names: 113 | scalar = table[name][index] 114 | columnar = isinstance(scalar, pa.ListScalar) 115 | row[name] = Column.fromscalar(scalar) if columnar else scalar.as_py() 116 | return row 117 | 118 | def filter(self, info: Info, **queries: Filter) -> Self: 119 | """Return table with rows which match all queries. 120 | 121 | See `scan(filter: ...)` for more advanced queries. 
Additional feature: sorted tables 122 | support binary search 123 | """ 124 | source = self.source 125 | prev = info.path.prev 126 | search = isinstance(source, pa.Table) and (prev is None or prev.typename == 'Query') 127 | for name in self.schema().index if search else []: 128 | assert not source[name].null_count, f"search requires non-null column: {name}" 129 | query = dict(queries.pop(name)) 130 | if 'eq' in query: 131 | source = T.is_in(source, name, *query['eq']) 132 | if 'ne' in query: 133 | source = T.not_equal(source, name, query['ne']) 134 | lower, upper = query.get('gt'), query.get('lt') 135 | includes = {'include_lower': False, 'include_upper': False} 136 | if 'ge' in query and (lower is None or query['ge'] > lower): 137 | lower, includes['include_lower'] = query['ge'], True 138 | if 'le' in query and (upper is None or query['le'] > upper): 139 | upper, includes['include_upper'] = query['le'], True 140 | if {lower, upper} != {None}: 141 | source = T.range(source, name, lower, upper, **includes) 142 | if len(query.pop('eq', [])) != 1 or query: 143 | break 144 | return type(self)(source).scan(info, filter=Expression.from_query(**queries)) 145 | 146 | @doc_field 147 | def type(self) -> str: 148 | """[arrow type](https://arrow.apache.org/docs/python/api/dataset.html#classes)""" 149 | return type(self.source).__name__ 150 | 151 | @doc_field 152 | def schema(self) -> Schema: 153 | """dataset schema""" 154 | source = self.source 155 | schema = source.projected_schema if isinstance(source, ds.Scanner) else source.schema 156 | partitioning = getattr(source, 'partitioning', None) 157 | index = (schema.pandas_metadata or {}).get('index_columns', []) 158 | return Schema( 159 | names=schema.names, 160 | types=schema.types, 161 | partitioning=partitioning.schema.names if partitioning else [], 162 | index=[name for name in index if isinstance(name, str)], 163 | ) # type: ignore 164 | 165 | @doc_field 166 | def optional(self) -> Self | None: 167 | """Nullable field to stop error propagation, enabling partial query results. 168 | 169 | Will be replaced by client controlled nullability. 170 | """ 171 | return self 172 | 173 | @staticmethod 174 | def add_metric(info: Info, table: pa.Table, **data): 175 | """Add memory usage and other metrics to context with path info.""" 176 | path = tuple(get_path_from_info(info)) 177 | info.context.setdefault('metrics', {})[path] = dict(data, memory=T.size(table)) 178 | return table 179 | 180 | @doc_field 181 | def length(self) -> Long: 182 | """number of rows""" 183 | return len(self.source) if isinstance(self.source, Sized) else self.source.count_rows() 184 | 185 | @doc_field 186 | def any(self, info: Info, length: Long = 1) -> bool: 187 | """Return whether there are at least `length` rows. 188 | 189 | May be significantly faster than `length` for out-of-core data. 190 | """ 191 | table = self.to_table(info, length) 192 | return len(table) >= length 193 | 194 | @doc_field 195 | def size(self) -> Long | None: 196 | """buffer size in bytes; null if table is not loaded""" 197 | return getattr(self.source, 'nbytes', None) 198 | 199 | @doc_field( 200 | name="column name(s); multiple names access nested struct fields", 201 | cast=f"cast array to {links.type}", 202 | safe="check for conversion errors on cast", 203 | ) 204 | def column( 205 | self, info: Info, name: list[str], cast: str = '', safe: bool = True 206 | ) -> Column | None: 207 | """Return column of any type by name. 208 | 209 | This is typically only needed for aliased or casted columns. 
210 | If the column is in the schema, `columns` can be used instead. 211 | """ 212 | if isinstance(self.source, pa.Table) and len(name) == 1: 213 | column = self.source.column(*name) 214 | return Column.cast(column.cast(cast, safe) if cast else column) 215 | column = Projection(alias='_', name=name, cast=cast, safe=safe) # type: ignore 216 | source = self.scan(info, Expression(), [column]).source 217 | return Column.cast(*(source if isinstance(source, pa.Table) else source.to_table())) 218 | 219 | @doc_field( 220 | offset="number of rows to skip; negative value skips from the end", 221 | length="maximum number of rows to return", 222 | reverse="reverse order after slicing; forces a copy", 223 | ) 224 | def slice( 225 | self, info: Info, offset: Long = 0, length: Long | None = None, reverse: bool = False 226 | ) -> Self: 227 | """Return zero-copy slice of table. 228 | 229 | Can also be sued to force loading a dataset. 230 | """ 231 | table = self.to_table(info, length and (offset + length if offset >= 0 else None)) 232 | table = table[offset:][:length] # `slice` bug: ARROW-15412 233 | return type(self)(table[::-1] if reverse else table) 234 | 235 | @doc_field( 236 | by="column names; empty will aggregate into a single row table", 237 | counts="optionally include counts in an aliased column", 238 | ordered="optionally disable parallelization to maintain ordering", 239 | aggregate="aggregation functions applied to other columns", 240 | ) 241 | def group( 242 | self, 243 | info: Info, 244 | by: list[str] = [], 245 | counts: str = '', 246 | ordered: bool = False, 247 | aggregate: HashAggregates = {}, # type: ignore 248 | ) -> Self: 249 | """Return table grouped by columns. 250 | 251 | See `column` for accessing any column which has changed type. See `tables` to split on any 252 | aggregated list columns. 253 | """ 254 | if not any(aggregate.keys()): 255 | fragments = T.fragments(self.source, *by, counts=counts) 256 | if set(fragments.schema.names) >= set(by): 257 | return type(self)(fragments) 258 | prefix = 'hash_' if by else '' 259 | aggs: dict = {counts: ([], prefix + 'count_all', None)} if counts else {} 260 | for func, values in dict(aggregate).items(): 261 | ordered = ordered or func in Agg.ordered 262 | for agg in values: 263 | aggs[agg.alias] = (agg.name, prefix + func, agg.func_options(func)) 264 | source = self.to_table(info) if isinstance(self.source, ds.Scanner) else self.source 265 | source = Nodes.group(source, *by, **aggs) 266 | if ordered: 267 | source = self.add_metric(info, source.to_table(use_threads=False), mode='group') 268 | return type(self)(source) 269 | 270 | @doc_field( 271 | by="column names", 272 | split="optional predicates to split on; scalars are compared to pairwise difference", 273 | counts="optionally include counts in an aliased column", 274 | ) 275 | @no_type_check 276 | def runs( 277 | self, info: Info, by: list[str] = [], split: list[Diff] = [], counts: str = '' 278 | ) -> Self: 279 | """Return table grouped by pairwise differences. 280 | 281 | Differs from `group` by relying on adjacency, and is typically faster. Other columns are 282 | transformed into list columns. See `column` and `tables` to further access lists. 
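
        An illustrative query sketch (the `state` and `ts` columns are assumed):

            { runs(by: ["state"], split: [{name: "ts", gt: 3600.0}], counts: "n") { length } }

        This groups adjacent rows sharing a `state`, starting a new group wherever
        consecutive `ts` values differ by more than an hour (float seconds for timestamps).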
283 | """ 284 | table = self.to_table(info) 285 | predicates = {} 286 | for diff in map(dict, split): 287 | name = diff.pop('name') 288 | ((func, value),) = diff.items() 289 | if pa.types.is_timestamp(table.field(name).type): 290 | value = timedelta(seconds=value) 291 | predicates[name] = (getattr(pc, func), value)[: 1 if value is None else 2] 292 | table, counts_ = T.runs(table, *by, **predicates) 293 | return type(self)(table.append_column(counts, counts_) if counts else table) 294 | 295 | @doc_field( 296 | by="column names; prefix with `-` for descending order", 297 | length="maximum number of rows to return; may be significantly faster but is unstable", 298 | null_placement="where nulls in input should be sorted; incompatible with `length`", 299 | ) 300 | def sort( 301 | self, 302 | info: Info, 303 | by: list[str], 304 | length: Long | None = None, 305 | null_placement: str = 'at_end', 306 | ) -> Self: 307 | """Return table slice sorted by specified columns. 308 | 309 | Optimized for length == 1; matches min or max values. 310 | """ 311 | kwargs = dict(length=length, null_placement=null_placement) 312 | if isinstance(self.source, pa.Table) or length is None: 313 | table = self.to_table(info) 314 | else: 315 | expr, by = T.rank_keys(self.source, length, *by, dense=False) 316 | if expr is not None: 317 | self = type(self)(self.source.filter(expr)) 318 | source = self.select(info) 319 | if not by: 320 | return type(self)(self.add_metric(info, source.head(length), mode='head')) 321 | table = T.map_batch(source, T.sort, *by, **kwargs) 322 | self.add_metric(info, table, mode='batch') 323 | return type(self)(T.sort(table, *by, **kwargs)) # type: ignore 324 | 325 | @doc_field( 326 | by="column names; prefix with `-` for descending order", 327 | max="maximum dense rank to select; optimized for == 1 (min or max)", 328 | ) 329 | def rank(self, info: Info, by: list[str], max: int = 1) -> Self: 330 | """Return table selected by maximum dense rank.""" 331 | source = self.to_table(info) if isinstance(self.source, ds.Scanner) else self.source 332 | expr, by = T.rank_keys(source, max, *by) 333 | if expr is not None: 334 | source = source.filter(expr) 335 | return type(self)(T.rank(source, max, *by) if by else source) 336 | 337 | @staticmethod 338 | def apply_list(table: Batch, list_: ListFunction) -> Batch: 339 | expr = list_.filter.to_arrow() if list_.filter else None 340 | if expr is not None: 341 | table = T.filter_list(table, expr) 342 | if list_.rank: 343 | table = T.map_list(table, T.rank, list_.rank.max, *list_.rank.by) 344 | if list_.sort: 345 | table = T.map_list(table, T.sort, *list_.sort.by, length=list_.sort.length) 346 | columns = {} 347 | for func, field in dict(list_).items(): 348 | columns[field.alias] = getattr(ListChunk, func)(table[field.name], **field.options) 349 | return T.union(table, pa.RecordBatch.from_pydict(columns)) 350 | 351 | @doc_field 352 | @no_type_check 353 | def apply( 354 | self, 355 | info: Info, 356 | cumulative_max: doc_argument(list[Cumulative], func=pc.cumulative_max) = [], 357 | cumulative_mean: doc_argument(list[Cumulative], func=pc.cumulative_mean) = [], 358 | cumulative_min: doc_argument(list[Cumulative], func=pc.cumulative_min) = [], 359 | cumulative_prod: doc_argument(list[Cumulative], func=pc.cumulative_prod) = [], 360 | cumulative_sum: doc_argument(list[Cumulative], func=pc.cumulative_sum) = [], 361 | fill_null_backward: doc_argument(list[Field], func=pc.fill_null_backward) = [], 362 | fill_null_forward: doc_argument(list[Field], 
func=pc.fill_null_forward) = [], 363 | pairwise_diff: doc_argument(list[Pairwise], func=pc.pairwise_diff) = [], 364 | rank: doc_argument(list[Rank], func=pc.rank) = [], 365 | list_: Annotated[ 366 | ListFunction, 367 | strawberry.argument(name='list', description="functions for list arrays."), 368 | ] = {}, 369 | ) -> Self: 370 | """Return view of table with vector functions applied across columns. 371 | 372 | Applied functions load arrays into memory as needed. See `scan` for scalar functions, 373 | which do not require loading. 374 | """ 375 | table = T.map_batch(self.select(info), self.apply_list, list_) 376 | self.add_metric(info, table, mode='batch') 377 | columns = {} 378 | funcs = pc.cumulative_max, pc.cumulative_mean, pc.cumulative_min, pc.cumulative_prod 379 | funcs += pc.cumulative_sum, C.fill_null_backward, C.fill_null_forward, C.pairwise_diff 380 | funcs += (pc.rank,) 381 | for func in funcs: 382 | for field in locals()[func.__name__]: 383 | callable = func 384 | if field.options.pop('checked', False): 385 | callable = getattr(pc, func.__name__ + '_checked') 386 | columns[field.alias] = callable(table[field.name], **field.options) 387 | return type(self)(T.union(table, pa.table(columns))) 388 | 389 | @doc_field 390 | def flatten(self, info: Info, indices: str = '') -> Self: 391 | """Return table with list arrays flattened. 392 | 393 | At least one list column must be referenced, and all list columns must have the same lengths. 394 | """ 395 | table = pa.Table.from_batches(T.flatten(self.select(info), indices)) 396 | return type(self)(self.add_metric(info, table, mode='batch')) 397 | 398 | @doc_field 399 | def tables(self, info: Info) -> list[Self | None]: # type: ignore 400 | """Return a list of tables by splitting list columns. 401 | 402 | At least one list column must be referenced, and all list columns must have the same lengths. 
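
        An illustrative query sketch, typically following a `group` with a `list` aggregate
        (the `state` and `prices` columns are assumed):

            { group(by: ["state"], aggregate: {list: [{name: "prices"}]})
              { tables { row { state } columns { prices { length } } } } }

        Each returned table corresponds to one group, with `prices` expanded back into rows.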
403 | """ 404 | for batch in self.select(info).to_batches(): 405 | for row in T.split(batch): 406 | yield None if row is None else type(self)(pa.Table.from_batches([row])) 407 | 408 | @doc_field 409 | def aggregate( 410 | self, 411 | info: Info, 412 | approximate_median: doc_argument(list[ScalarAggregate], func=pc.approximate_median) = [], 413 | count: doc_argument(list[CountAggregate], func=pc.count) = [], 414 | count_distinct: doc_argument(list[CountAggregate], func=pc.count_distinct) = [], 415 | distinct: Annotated[ 416 | list[CountAggregate], 417 | strawberry.argument(description="distinct values within each scalar"), 418 | ] = [], 419 | first: doc_argument(list[Field], func=ListChunk.first) = [], 420 | last: doc_argument(list[Field], func=ListChunk.last) = [], 421 | max: doc_argument(list[ScalarAggregate], func=pc.max) = [], 422 | mean: doc_argument(list[ScalarAggregate], func=pc.mean) = [], 423 | min: doc_argument(list[ScalarAggregate], func=pc.min) = [], 424 | product: doc_argument(list[ScalarAggregate], func=pc.product) = [], 425 | stddev: doc_argument(list[VarianceAggregate], func=pc.stddev) = [], 426 | sum: doc_argument(list[ScalarAggregate], func=pc.sum) = [], 427 | tdigest: doc_argument(list[TDigestAggregate], func=pc.tdigest) = [], 428 | variance: doc_argument(list[VarianceAggregate], func=pc.variance) = [], 429 | ) -> Self: 430 | """Return table with scalar aggregate functions applied to list columns.""" 431 | table = self.to_table(info) 432 | columns = T.columns(table) 433 | agg_fields: dict = collections.defaultdict(dict) 434 | keys: tuple = 'approximate_median', 'count', 'count_distinct', 'distinct', 'first', 'last' 435 | keys += 'max', 'mean', 'min', 'product', 'stddev', 'sum', 'tdigest', 'variance' 436 | for key in keys: 437 | func = getattr(ListChunk, key, None) 438 | for agg in locals()[key]: 439 | if func is None or key == 'sum': # `sum` is a method on `Array`` 440 | agg_fields[agg.name][key] = agg 441 | else: 442 | columns[agg.alias] = func(table[agg.name], **agg.options) 443 | for name, aggs in agg_fields.items(): 444 | funcs = {key: agg.func_options(key) for key, agg in aggs.items()} 445 | batch = ListChunk.aggregate(table[name], **funcs) 446 | columns.update(zip([agg.alias for agg in aggs.values()], batch)) 447 | return type(self)(pa.table(columns)) 448 | 449 | aggregate.deprecation_reason = ListFunction.deprecation 450 | 451 | @doc_field(filter="selected rows", columns="projected columns") 452 | def scan(self, info: Info, filter: Expression = {}, columns: list[Projection] = []) -> Self: # type: ignore 453 | """Select rows and project columns without memory usage.""" 454 | expr = filter.to_arrow() 455 | projection = {name: pc.field(name) for name in self.references(info, level=1)} 456 | projection |= {col.alias or '.'.join(col.name): col.to_arrow() for col in columns} 457 | if '' in projection: 458 | raise ValueError(f"projected columns need a name or alias: {projection['']}") 459 | if isinstance(self.source, ds.Scanner): 460 | options = dict(schema=self.source.projected_schema, filter=expr, columns=projection) 461 | scanner = ds.Scanner.from_batches(self.source.to_batches(), **options) 462 | return type(self)(self.add_metric(info, scanner.to_table(), mode='batch')) 463 | source = self.source if expr is None else self.source.filter(expr) 464 | return type(self)(Nodes.scan(source, projection) if columns else source) 465 | 466 | @doc_field( 467 | right="name of right table; must be on root Query type", 468 | keys="column names used as keys on the left side", 469 | 
right_keys="column names used as keys on the right side; defaults to left side.", 470 | join_type="the kind of join: 'left semi', 'right semi', 'left anti', 'right anti', 'inner', 'left outer', 'right outer', 'full outer'", 471 | left_suffix="add suffix to left column names; for preventing collisions", 472 | right_suffix="add suffix to right column names; for preventing collisions.", 473 | coalesce_keys="omit duplicate keys", 474 | ) 475 | def join( 476 | self, 477 | info: Info, 478 | right: str, 479 | keys: list[str], 480 | right_keys: list[str] | None = None, 481 | join_type: str = 'left outer', 482 | left_suffix: str = '', 483 | right_suffix: str = '', 484 | coalesce_keys: bool = True, 485 | ) -> Self: 486 | """Provisional: [join](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html#pyarrow.dataset.Dataset.join) this table with another table on the root Query type.""" 487 | left, right = ( 488 | root.source if isinstance(root.source, ds.Dataset) else root.to_table(info) 489 | for root in (self, getattr(info.root_value, right)) 490 | ) 491 | table = left.join( 492 | right, 493 | keys=keys, 494 | right_keys=right_keys, 495 | join_type=join_type, 496 | left_suffix=left_suffix, 497 | right_suffix=right_suffix, 498 | coalesce_keys=coalesce_keys, 499 | ) 500 | return type(self)(table) 501 | 502 | join.directives = [provisional()] 503 | 504 | @doc_field 505 | def take(self, info: Info, indices: list[Long]) -> Self: 506 | """Select rows from indices.""" 507 | table = self.select(info).take(indices) 508 | return type(self)(self.add_metric(info, table, mode='take')) 509 | 510 | @doc_field 511 | def drop_null(self, info: Info) -> Self: 512 | """Remove missing values from referenced columns in the table.""" 513 | if isinstance(self.source, pa.Table): 514 | return type(self)(pc.drop_null(self.to_table(info))) 515 | table = T.map_batch(self.select(info), pc.drop_null) 516 | return type(self)(self.add_metric(info, table, mode='batch')) 517 | -------------------------------------------------------------------------------- /graphique/middleware.py: -------------------------------------------------------------------------------- 1 | """ 2 | ASGI GraphQL utilities. 
3 | """ 4 | 5 | import warnings 6 | from collections.abc import Iterable, Mapping 7 | from datetime import timedelta 8 | from keyword import iskeyword 9 | import pyarrow.dataset as ds 10 | import strawberry.asgi 11 | from strawberry import Info, UNSET 12 | from strawberry.extensions import tracing 13 | from strawberry.utils.str_converters import to_camel_case 14 | from .inputs import Filter 15 | from .interface import Dataset, Source 16 | from .models import Column, doc_field 17 | from .scalars import Long, py_type, scalar_map 18 | 19 | 20 | class MetricsExtension(tracing.ApolloTracingExtension): 21 | """Human-readable metrics from apollo tracing.""" 22 | 23 | def get_results(self) -> dict: 24 | tracing = super().get_results()['tracing'] 25 | metrics = self.execution_context.context.get('metrics', {}) 26 | resolvers = [] 27 | for resolver in tracing['execution']['resolvers']: # pragma: no cover 28 | path = tuple(resolver['path']) 29 | resolvers.append({'path': path, 'duration': self.duration(resolver)}) 30 | resolvers[-1].update(metrics.get(path, {})) 31 | metrics = {'duration': self.duration(tracing), 'execution': {'resolvers': resolvers}} 32 | return {'metrics': metrics} 33 | 34 | @staticmethod 35 | def duration(data: dict) -> str | None: 36 | return data['duration'] and str(timedelta(microseconds=data['duration'] / 1e3)) 37 | 38 | 39 | class GraphQL(strawberry.asgi.GraphQL): 40 | """ASGI GraphQL app with root value(s). 41 | 42 | Args: 43 | root: root dataset to attach as the Query type 44 | debug: enable timing extension 45 | **kwargs: additional `asgi.GraphQL` options 46 | """ 47 | 48 | options = dict(types=Column.registry.values(), scalar_overrides=scalar_map) 49 | 50 | def __init__(self, root: Source, debug: bool = False, **kwargs): 51 | options: dict = dict(self.options, extensions=(MetricsExtension,) * bool(debug)) 52 | if type(root).__name__ == 'Query': 53 | self.root_value = root 54 | options['enable_federation_2'] = True 55 | schema = strawberry.federation.Schema(type(self.root_value), **options) 56 | else: 57 | self.root_value = implemented(root) 58 | schema = strawberry.Schema(type(self.root_value), **options) 59 | super().__init__(schema, debug=debug, **kwargs) 60 | 61 | async def get_root_value(self, request): 62 | return self.root_value 63 | 64 | @classmethod 65 | def federated(cls, roots: Mapping[str, Source], keys: Mapping[str, Iterable] = {}, **kwargs): 66 | """Construct GraphQL app with multiple federated datasets. 
67 | 68 | Args: 69 | roots: mapping of field names to root datasets 70 | keys: mapping of optional federation keys for each root 71 | **kwargs: additional `asgi.GraphQL` options 72 | """ 73 | root_values = {name: implemented(roots[name], name, keys.get(name, ())) for name in roots} 74 | annotations = {name: type(root_values[name]) for name in root_values} 75 | Query = type('Query', (), {'__annotations__': annotations}) 76 | return cls(strawberry.type(Query)(**root_values), **kwargs) 77 | 78 | 79 | def implemented(root: Source, name: str = '', keys: Iterable = ()): 80 | """Return type which extends the Dataset interface with knowledge of the schema.""" 81 | schema = root.projected_schema if isinstance(root, ds.Scanner) else root.schema 82 | types = {field.name: py_type(field.type) for field in schema} 83 | types = {name: types[name] for name in types if name.isidentifier() and not iskeyword(name)} 84 | if invalid := set(schema.names) - set(types): 85 | warnings.warn(f'invalid field names: {invalid}') 86 | prefix = to_camel_case(name.title()) 87 | 88 | namespace = {name: strawberry.field(default=UNSET, name=name) for name in types} 89 | annotations = {name: Column.registry[types[name]] | None for name in types} 90 | cls = type(prefix + 'Columns', (), dict(namespace, __annotations__=annotations)) 91 | Columns = strawberry.type(cls, description="fields for each column") 92 | 93 | namespace = {name: strawberry.field(default=UNSET, name=name) for name in types} 94 | annotations = {name: (Column if cls is list else cls) | None for name, cls in types.items()} 95 | cls = type(prefix + 'Row', (), dict(namespace, __annotations__=annotations)) 96 | Row = strawberry.type(cls, description="scalar fields") 97 | 98 | class Table(Dataset): 99 | __init__ = Dataset.__init__ 100 | field = name 101 | 102 | def columns(self, info: Info) -> Columns: # type: ignore 103 | """fields for each column""" 104 | return Columns(**super().columns(info)) 105 | 106 | def row(self, info: Info, index: Long = 0) -> Row | None: # type: ignore 107 | """Return scalar values at index.""" 108 | row = super().row(info, index) 109 | for name, value in row.items(): 110 | if isinstance(value, Column) and types[name] is not list: 111 | raise TypeError(f"Field `{name}` cannot represent `Column` value") 112 | return Row(**row) 113 | 114 | if types: 115 | for field in ('filter', 'columns', 'row'): 116 | setattr(Table, field, doc_field(getattr(Table, field))) 117 | Table.filter.type = Table 118 | Table.filter.base_resolver.arguments = list(Filter.resolve_args(types)) 119 | options = dict(name=prefix + 'Table', description="a dataset with a derived schema") 120 | if name: 121 | return strawberry.federation.type(Table, keys=keys, **options)(root) 122 | return strawberry.type(Table, **options)(root) 123 | -------------------------------------------------------------------------------- /graphique/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | GraphQL output types and resolvers. 
3 | """ 4 | 5 | from __future__ import annotations 6 | import functools 7 | import inspect 8 | from collections.abc import Callable 9 | from datetime import date, datetime, time, timedelta 10 | from decimal import Decimal 11 | from typing import Annotated, Generic, TypeVar, TYPE_CHECKING, get_args 12 | import pyarrow as pa 13 | import pyarrow.compute as pc 14 | import strawberry 15 | from strawberry import Info 16 | from strawberry.types.field import StrawberryField 17 | from .core import Column as C 18 | from .inputs import links 19 | from .scalars import Long, py_type, scalar_map 20 | 21 | if TYPE_CHECKING: # pragma: no cover 22 | from .interface import Dataset 23 | T = TypeVar('T') 24 | 25 | 26 | def selections(*fields) -> set: 27 | """Return field name selections from strawberry `SelectedField`.""" 28 | return {selection.name for field in fields for selection in field.selections} 29 | 30 | 31 | def doc_field(func: Callable | None = None, **kwargs: str) -> StrawberryField: 32 | """Return strawberry field with argument and docstring descriptions.""" 33 | if func is None: 34 | return functools.partial(doc_field, **kwargs) # type: ignore 35 | for name in kwargs: 36 | argument = strawberry.argument(description=kwargs[name]) 37 | func.__annotations__[name] = Annotated[func.__annotations__[name], argument] 38 | return strawberry.field(func, description=inspect.getdoc(func)) 39 | 40 | 41 | def compute_field(func: Callable): 42 | """Wrap compute function with its description.""" 43 | doc = inspect.getdoc(getattr(pc, func.__name__)) 44 | return strawberry.field(func, description=doc.splitlines()[0]) # type: ignore 45 | 46 | 47 | @strawberry.interface(description="an arrow array") 48 | class Column: 49 | registry = {} # type: ignore 50 | 51 | def __init__(self, array): 52 | self.array = array 53 | 54 | def __init_subclass__(cls): 55 | cls.__init__ = cls.__init__ 56 | 57 | @classmethod 58 | def register(cls, *scalars): 59 | if cls is Column: 60 | return lambda cls: cls.register(*scalars) or cls 61 | # strawberry#1921: scalar python names are prepended to column name 62 | generic = issubclass(cls, Generic) 63 | for scalar in scalars: 64 | cls.registry[scalar] = cls[scalar_map.get(scalar, scalar)] if generic else cls 65 | 66 | @strawberry.field(description=links.type) 67 | def type(self) -> str: 68 | return str(self.array.type) 69 | 70 | @doc_field 71 | def length(self) -> Long: 72 | """array length""" 73 | return len(self.array) 74 | 75 | @doc_field 76 | def size(self) -> Long: 77 | """buffer size in bytes""" 78 | return self.array.nbytes 79 | 80 | @classmethod 81 | def cast(cls, array: pa.ChunkedArray) -> Column: 82 | """Return typed column based on array type.""" 83 | return cls.registry[py_type(array.type)](array) 84 | 85 | @classmethod 86 | def fromscalar(cls, scalar: pa.ListScalar) -> Column | None: 87 | return None if scalar.values is None else cls.cast(pa.chunked_array([scalar.values])) 88 | 89 | @compute_field 90 | def count(self, mode: str = 'only_valid') -> Long: 91 | return pc.count(self.array, mode=mode).as_py() 92 | 93 | @classmethod 94 | def resolve_type(cls, obj, info, *_) -> str: 95 | config = Info(info, None).schema.config 96 | args = get_args(getattr(obj, '__orig_class__', None)) 97 | return config.name_converter.from_generic(obj.__strawberry_definition__, args) 98 | 99 | 100 | @strawberry.type(description="unique values and counts") 101 | class Set(Generic[T]): 102 | length = doc_field(Column.length) 103 | counts: list[Long] = strawberry.field(description="list of counts") 104 | 
105 | def __init__(self, array, counts=pa.array([])): 106 | self.array, self.counts = array, counts.to_pylist() 107 | 108 | @doc_field 109 | def values(self) -> list[T | None]: 110 | """list of values""" 111 | return self.array.to_pylist() 112 | 113 | 114 | @Column.register(timedelta, pa.MonthDayNano) 115 | @strawberry.type(name='Column', description="column of elapsed times") 116 | class NominalColumn(Generic[T], Column): 117 | values = doc_field(Set.values) 118 | 119 | @compute_field 120 | def count_distinct(self, mode: str = 'only_valid') -> Long: 121 | return pc.count_distinct(self.array, mode=mode).as_py() 122 | 123 | @strawberry.field(description=Set.__strawberry_definition__.description) # type: ignore 124 | def unique(self, info: Info) -> Set[T]: 125 | if 'counts' in selections(*info.selected_fields): 126 | return Set(*self.array.value_counts().flatten()) 127 | return Set(self.array.unique()) 128 | 129 | @doc_field 130 | def value(self, index: Long = 0) -> T | None: 131 | """scalar value at index""" 132 | return self.array[index].as_py() 133 | 134 | @compute_field 135 | def drop_null(self) -> list[T]: 136 | return self.array.drop_null().to_pylist() 137 | 138 | 139 | @Column.register(date, datetime, time, bytes) 140 | @strawberry.type(name='Column', description="column of ordinal values") 141 | class OrdinalColumn(NominalColumn[T]): 142 | @compute_field 143 | def first(self, skip_nulls: bool = True, min_count: int = 0) -> T | None: 144 | return pc.first(self.array, skip_nulls=skip_nulls, min_count=min_count).as_py() 145 | 146 | @compute_field 147 | def last(self, skip_nulls: bool = True, min_count: int = 0) -> T | None: 148 | return pc.last(self.array, skip_nulls=skip_nulls, min_count=min_count).as_py() 149 | 150 | @compute_field 151 | def min(self, skip_nulls: bool = True, min_count: int = 0) -> T | None: 152 | return pc.min(self.array, skip_nulls=skip_nulls, min_count=min_count).as_py() 153 | 154 | @compute_field 155 | def max(self, skip_nulls: bool = True, min_count: int = 0) -> T | None: 156 | return pc.max(self.array, skip_nulls=skip_nulls, min_count=min_count).as_py() 157 | 158 | @compute_field 159 | def index(self, value: T, start: Long = 0, end: Long | None = None) -> Long: 160 | return C.index(self.array, value, start, end) 161 | 162 | @compute_field 163 | def fill_null(self, value: T) -> list[T]: 164 | return self.array.fill_null(value).to_pylist() 165 | 166 | 167 | @Column.register(str) 168 | @strawberry.type(name='ingColumn', description="column of strings") 169 | class StringColumn(OrdinalColumn[T]): ... 
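The registry built up by `Column.register` above is what lets the interface map arrow types to typed GraphQL columns: `Column.cast` looks up `py_type(array.type)` and instantiates the matching subclass. A minimal sketch of that lookup, assuming graphique and pyarrow are installed:

import pyarrow as pa
from graphique.models import Column

arr = pa.chunked_array([[1.5, 2.5, None]])  # arrow double maps to python float
col = Column.cast(arr)                      # resolves Column.registry[float], i.e. RatioColumn[float]
assert col.array is arr                     # the typed column wraps the same chunked array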
170 | 171 | 172 | @strawberry.type 173 | class IntervalColumn(OrdinalColumn[T]): 174 | @compute_field 175 | def mode(self, n: int = 1, skip_nulls: bool = True, min_count: int = 0) -> Set[T]: 176 | return Set(*pc.mode(self.array, n, skip_nulls=skip_nulls, min_count=min_count).flatten()) 177 | 178 | @compute_field 179 | def sum(self, skip_nulls: bool = True, min_count: int = 0) -> T | None: 180 | return pc.sum(self.array, skip_nulls=skip_nulls, min_count=min_count).as_py() 181 | 182 | @compute_field 183 | def product(self, skip_nulls: bool = True, min_count: int = 0) -> T | None: 184 | return pc.product(self.array, skip_nulls=skip_nulls, min_count=min_count).as_py() 185 | 186 | @compute_field 187 | def mean(self, skip_nulls: bool = True, min_count: int = 0) -> float | None: 188 | return pc.mean(self.array, skip_nulls=skip_nulls, min_count=min_count).as_py() 189 | 190 | @compute_field 191 | def indices_nonzero(self) -> list[Long]: 192 | return pc.indices_nonzero(self.array).to_pylist() 193 | 194 | 195 | @Column.register(float, Decimal) 196 | @strawberry.type(name='Column', description="column of floats or decimals") 197 | class RatioColumn(IntervalColumn[T]): 198 | @compute_field 199 | def stddev(self, ddof: int = 0, skip_nulls: bool = True, min_count: int = 0) -> float | None: 200 | return pc.stddev(self.array, ddof=ddof, skip_nulls=skip_nulls, min_count=min_count).as_py() 201 | 202 | @compute_field 203 | def variance(self, ddof: int = 0, skip_nulls: bool = True, min_count: int = 0) -> float | None: 204 | options = {'skip_nulls': skip_nulls, 'min_count': min_count} 205 | return pc.variance(self.array, ddof=ddof, **options).as_py() 206 | 207 | @compute_field 208 | def quantile( 209 | self, 210 | q: list[float] = [0.5], 211 | interpolation: str = 'linear', 212 | skip_nulls: bool = True, 213 | min_count: int = 0, 214 | ) -> list[float | None]: 215 | options = {'skip_nulls': skip_nulls, 'min_count': min_count} 216 | return pc.quantile(self.array, q=q, interpolation=interpolation, **options).to_pylist() 217 | 218 | @compute_field 219 | def tdigest( 220 | self, 221 | q: list[float] = [0.5], 222 | delta: int = 100, 223 | buffer_size: int = 500, 224 | skip_nulls: bool = True, 225 | min_count: int = 0, 226 | ) -> list[float | None]: 227 | options = {'buffer_size': buffer_size, 'skip_nulls': skip_nulls, 'min_count': min_count} 228 | return pc.tdigest(self.array, q=q, delta=delta, **options).to_pylist() 229 | 230 | 231 | @Column.register(bool) 232 | @strawberry.type(name='eanColumn', description="column of booleans") 233 | class BooleanColumn(IntervalColumn[T]): 234 | @compute_field 235 | def any(self, skip_nulls: bool = True, min_count: int = 1) -> bool | None: 236 | return pc.any(self.array, skip_nulls=skip_nulls, min_count=min_count).as_py() 237 | 238 | @compute_field 239 | def all(self, skip_nulls: bool = True, min_count: int = 1) -> bool | None: 240 | return pc.all(self.array, skip_nulls=skip_nulls, min_count=min_count).as_py() 241 | 242 | 243 | @Column.register(int, Long) 244 | @strawberry.type(name='Column', description="column of integers") 245 | class IntColumn(RatioColumn[T]): 246 | @doc_field 247 | def take_from( 248 | self, info: Info, field: str 249 | ) -> Annotated['Dataset', strawberry.lazy('.interface')] | None: 250 | """Select indices from a table on the root Query type.""" 251 | root = getattr(info.root_value, field) 252 | return type(root)(root.select(info).take(self.array.combine_chunks())) 253 | 254 | 255 | @Column.register(list) 256 | @strawberry.type(description="column of 
lists") 257 | class ListColumn(Column): 258 | @doc_field 259 | def value(self, index: Long = 0) -> Column | None: 260 | """scalar column at index""" 261 | return self.fromscalar(self.array[index]) 262 | 263 | @doc_field 264 | def values(self) -> list[Column | None]: 265 | """list of columns""" 266 | return list(map(self.fromscalar, self.array)) 267 | 268 | @compute_field 269 | def drop_null(self) -> list[Column]: 270 | return map(self.fromscalar, self.array.drop_null()) # type: ignore 271 | 272 | @doc_field 273 | def flatten(self) -> Column: 274 | """concatenation of all sub-lists""" 275 | return self.cast(pc.list_flatten(self.array)) 276 | 277 | 278 | @Column.register(dict) 279 | @strawberry.type(description="column of structs") 280 | class StructColumn(Column): 281 | @doc_field 282 | def value(self, index: Long = 0) -> dict | None: 283 | """scalar json object at index""" 284 | return self.array[index].as_py() 285 | 286 | @doc_field 287 | def names(self) -> list[str]: 288 | """field names""" 289 | return [field.name for field in self.array.type] 290 | 291 | @doc_field(name="field name(s); multiple names access nested fields") 292 | def column(self, name: list[str]) -> Column | None: 293 | """Return struct field as a column.""" 294 | return self.cast(pc.struct_field(self.array, name)) 295 | -------------------------------------------------------------------------------- /graphique/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coady/graphique/17e5922e8420acef18858d204250d78a518c5e4c/graphique/py.typed -------------------------------------------------------------------------------- /graphique/scalars.py: -------------------------------------------------------------------------------- 1 | """ 2 | GraphQL scalars. 
3 | """ 4 | 5 | import functools 6 | from datetime import date, datetime, time, timedelta 7 | from decimal import Decimal 8 | import isodate 9 | import pyarrow as pa 10 | import strawberry 11 | 12 | 13 | def parse_long(value) -> int: 14 | if isinstance(value, int): 15 | return value 16 | raise TypeError(f"Long cannot represent value: {value}") 17 | 18 | 19 | def parse_duration(value): 20 | duration = isodate.parse_duration(value) 21 | if isinstance(duration, timedelta) and set(value.partition('T')[0]).isdisjoint('YM'): 22 | return duration 23 | months = getattr(duration, 'years', 0) * 12 + getattr(duration, 'months', 0) 24 | nanoseconds = duration.seconds * 1_000_000_000 + duration.microseconds * 1_000 25 | return pa.MonthDayNano([months, duration.days, nanoseconds]) 26 | 27 | 28 | duration_isoformat = functools.singledispatch(isodate.duration_isoformat) 29 | 30 | 31 | @duration_isoformat.register 32 | def _(mdn: pa.MonthDayNano) -> str: 33 | value = isodate.duration_isoformat( 34 | isodate.Duration(months=mdn.months, days=mdn.days, microseconds=mdn.nanoseconds // 1_000) 35 | ) 36 | return value if mdn.months else value.replace('P', 'P0M') 37 | 38 | 39 | Long = strawberry.scalar(int, name='Long', description="64-bit int", parse_value=parse_long) 40 | Duration = strawberry.scalar( 41 | timedelta | pa.MonthDayNano, 42 | name='Duration', 43 | description="Duration (isoformat)", 44 | specified_by_url="https://en.wikipedia.org/wiki/ISO_8601#Durations", 45 | serialize=duration_isoformat, 46 | parse_value=parse_duration, 47 | ) 48 | scalar_map = { 49 | bytes: strawberry.scalars.Base64, 50 | dict: strawberry.scalars.JSON, 51 | timedelta: Duration, 52 | pa.MonthDayNano: Duration, 53 | } 54 | 55 | type_map = { 56 | pa.lib.Type_BOOL: bool, 57 | pa.lib.Type_UINT8: int, 58 | pa.lib.Type_INT8: int, 59 | pa.lib.Type_UINT16: int, 60 | pa.lib.Type_INT16: int, 61 | pa.lib.Type_UINT32: Long, 62 | pa.lib.Type_INT32: int, 63 | pa.lib.Type_UINT64: Long, 64 | pa.lib.Type_INT64: Long, 65 | pa.lib.Type_HALF_FLOAT: float, 66 | pa.lib.Type_FLOAT: float, 67 | pa.lib.Type_DOUBLE: float, 68 | pa.lib.Type_DECIMAL32: Decimal, 69 | pa.lib.Type_DECIMAL64: Decimal, 70 | pa.lib.Type_DECIMAL128: Decimal, 71 | pa.lib.Type_DECIMAL256: Decimal, 72 | pa.lib.Type_DATE32: date, 73 | pa.lib.Type_DATE64: date, 74 | pa.lib.Type_TIMESTAMP: datetime, 75 | pa.lib.Type_TIME32: time, 76 | pa.lib.Type_TIME64: time, 77 | pa.lib.Type_DURATION: timedelta, 78 | pa.lib.Type_INTERVAL_MONTH_DAY_NANO: pa.MonthDayNano, 79 | pa.lib.Type_BINARY: bytes, 80 | pa.lib.Type_FIXED_SIZE_BINARY: bytes, 81 | pa.lib.Type_LARGE_BINARY: bytes, 82 | pa.lib.Type_STRING: str, 83 | pa.lib.Type_LARGE_STRING: str, 84 | pa.lib.Type_LIST: list, 85 | pa.lib.Type_FIXED_SIZE_LIST: list, 86 | pa.lib.Type_LARGE_LIST: list, 87 | pa.lib.Type_STRUCT: dict, 88 | } 89 | 90 | 91 | def py_type(dt: pa.DataType) -> type: 92 | return type_map[(dt.value_type if pa.types.is_dictionary(dt) else dt).id] 93 | -------------------------------------------------------------------------------- /graphique/service.py: -------------------------------------------------------------------------------- 1 | """ 2 | Default GraphQL service. 3 | 4 | Copy and customize as needed. 
Demonstrates: 5 | * federation versus root type 6 | * datasets, scanners, and tables 7 | * filtering and projection 8 | """ 9 | 10 | import json 11 | from pathlib import Path 12 | import pyarrow as pa 13 | import pyarrow.dataset as ds 14 | from starlette.config import Config 15 | from graphique.inputs import Expression 16 | from graphique import GraphQL 17 | 18 | config = Config('.env' if Path('.env').is_file() else None) 19 | PARQUET_PATH = Path(config('PARQUET_PATH')).resolve() 20 | FEDERATED = config('FEDERATED', default='') 21 | DEBUG = config('DEBUG', cast=bool, default=False) 22 | COLUMNS = config('COLUMNS', cast=json.loads, default=None) 23 | FILTERS = config('FILTERS', cast=json.loads, default=None) 24 | 25 | root = ds.dataset(PARQUET_PATH, partitioning='hive' if PARQUET_PATH.is_dir() else None) 26 | 27 | if isinstance(COLUMNS, dict): 28 | COLUMNS = {alias: ds.field(name) for alias, name in COLUMNS.items()} 29 | elif COLUMNS: 30 | root = root.replace_schema(pa.schema(map(root.schema.field, COLUMNS), root.schema.metadata)) 31 | COLUMNS = None 32 | if FILTERS is not None: 33 | root = root.to_table(columns=COLUMNS, filter=Expression.from_query(**FILTERS).to_arrow()) 34 | elif COLUMNS: 35 | root = root.scanner(columns=COLUMNS) 36 | 37 | if FEDERATED: 38 | app = GraphQL.federated({FEDERATED: root}, debug=DEBUG) 39 | else: 40 | app = GraphQL(root, debug=DEBUG) 41 | -------------------------------------------------------------------------------- /graphique/shell.py: -------------------------------------------------------------------------------- 1 | """ 2 | Partition datasets out-of-core, in parquet hive format. 3 | 4 | It follows a 2-pass strategy. First, batches are scanned and partitioned into fragments, with 5 | multiple parts per fragment. 6 | 7 | Second, the partitioned dataset is rewritten to merge parts. Often the built-in `write_dataset` is 8 | sufficient once partitioned, but there is a `fragments` option to optimize for memory or show 9 | progress on the second pass. 10 | """ 11 | 12 | import operator 13 | import shutil 14 | from pathlib import Path 15 | from typing import Annotated, Callable 16 | import numpy as np 17 | import pyarrow.compute as pc 18 | import pyarrow.dataset as ds 19 | import typer # type: ignore 20 | from tqdm import tqdm # type: ignore 21 | 22 | 23 | def sort_key(name: str) -> tuple: 24 | """Parse sort order.""" 25 | return name.lstrip('-'), ('descending' if name.startswith('-') else 'ascending') 26 | 27 | 28 | def write_batches( 29 | scanner: ds.Scanner, base_dir: str, *partitioning: str, indices: str = '', **options 30 | ): 31 | """Partition dataset by batches. 32 | 33 | Optionally include original indices. 
34 | """ 35 | options.update(format='parquet', partitioning=partitioning) 36 | options.update(partitioning_flavor='hive', existing_data_behavior='overwrite_or_ignore') 37 | with tqdm(total=scanner.count_rows(), desc="Batches") as pbar: 38 | for index, batch in enumerate(scanner.to_batches()): 39 | if indices: 40 | batch = batch.append_column(indices, pc.add(np.arange(len(batch)), pbar.n)) 41 | options['basename_template'] = f'part-{index}-{{i}}.parquet' 42 | ds.write_dataset(batch, base_dir, **options) 43 | pbar.update(len(batch)) 44 | 45 | 46 | def write_fragments(dataset: ds.Dataset, base_dir: str, func: Callable | None = None, **options): 47 | """Rewrite partition files by fragment to consolidate, optionally transforming.""" 48 | options['format'] = 'parquet' 49 | exprs = {Path(frag.path).parent: frag.partition_expression for frag in dataset.get_fragments()} 50 | offset = len(dataset.partitioning.schema) 51 | for path in tqdm(exprs, desc="Fragments"): 52 | part_dir = Path(base_dir, *path.parts[-offset:]) 53 | part = dataset.filter(exprs[path]) 54 | ds.write_dataset(func(part) if func else part, part_dir, **options) 55 | 56 | 57 | def partition( 58 | src: Annotated[str, typer.Argument(help="source path")], 59 | dest: Annotated[str, typer.Argument(help="destination path")], 60 | partitioning: Annotated[list[str], typer.Argument(help="partition keys")], 61 | fragments: Annotated[bool, typer.Option(help="iterate over fragments")] = False, 62 | sort: Annotated[list[str], typer.Option(help="sort keys; will load fragments")] = [], 63 | ): 64 | """Partition dataset by keys.""" 65 | temp = Path(dest) / 'temp' 66 | write_batches(ds.dataset(src, partitioning='hive'), str(temp), *partitioning) 67 | dataset = ds.dataset(temp, partitioning='hive') 68 | options = dict(partitioning_flavor='hive', existing_data_behavior='overwrite_or_ignore') 69 | if sorting := list(map(sort_key, sort)): 70 | write_fragments(dataset, dest, operator.methodcaller('sort_by', sorting)) 71 | elif fragments: 72 | write_fragments(dataset, dest) 73 | else: 74 | with tqdm(desc="Partitions"): 75 | ds.write_dataset(dataset, dest, partitioning=partitioning, **options) 76 | shutil.rmtree(temp) 77 | 78 | 79 | if __name__ == '__main__': 80 | partition.__doc__ = __doc__ 81 | typer.run(partition) 82 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: graphique 2 | site_url: https://coady.github.io/graphique/ 3 | site_description: GraphQL service for arrow tables and parquet data sets. 
4 | theme: material 5 | 6 | repo_name: coady/graphique 7 | repo_url: https://github.com/coady/graphique 8 | edit_uri: "" 9 | 10 | nav: 11 | - Introduction: index.md 12 | - GraphQL API: api.md 13 | - Example Schema: schema.md 14 | - Example Queries: examples.ipynb 15 | - Core Reference: reference.md 16 | 17 | plugins: 18 | - search 19 | - mkdocstrings: 20 | handlers: 21 | python: 22 | options: 23 | show_root_heading: true 24 | - mkdocs-jupyter: 25 | execute: true 26 | allow_errors: false 27 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "devDependencies": { 3 | "graphql-markdown": ">=7" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "graphique" 7 | version = "1.8" 8 | dynamic = ["dependencies"] 9 | description = "GraphQL service for arrow tables and parquet data sets." 10 | readme = "README.md" 11 | requires-python = ">=3.10" 12 | license = {file = "LICENSE.txt"} 13 | authors = [{name = "Aric Coady", email = "aric.coady@gmail.com"}] 14 | keywords = ["graphql", "arrow", "parquet"] 15 | classifiers = [ 16 | "Development Status :: 5 - Production/Stable", 17 | "Intended Audience :: Developers", 18 | "Intended Audience :: Science/Research", 19 | "License :: OSI Approved :: Apache Software License", 20 | "Operating System :: OS Independent", 21 | "Programming Language :: Python :: 3", 22 | "Programming Language :: Python :: 3.10", 23 | "Programming Language :: Python :: 3.11", 24 | "Programming Language :: Python :: 3.12", 25 | "Programming Language :: Python :: 3.13", 26 | "Topic :: Database :: Database Engines/Servers", 27 | "Topic :: Internet :: WWW/HTTP :: HTTP Servers", 28 | "Topic :: Internet :: WWW/HTTP :: Indexing/Search", 29 | "Topic :: Software Development :: Libraries :: Python Modules", 30 | "Typing :: Typed", 31 | ] 32 | 33 | [project.urls] 34 | Homepage = "https://github.com/coady/graphique" 35 | Documentation = "https://coady.github.io/graphique" 36 | Changelog = "https://github.com/coady/graphique/blob/main/CHANGELOG.md" 37 | Issues = "https://github.com/coady/graphique/issues" 38 | 39 | [project.optional-dependencies] 40 | server = ["uvicorn[standard]"] 41 | cli = ["tqdm"] 42 | 43 | [tool.setuptools.dynamic] 44 | dependencies = {file = "requirements.in"} 45 | 46 | [tool.ruff] 47 | line-length = 100 48 | 49 | [tool.ruff.format] 50 | quote-style = "preserve" 51 | 52 | [[tool.mypy.overrides]] 53 | module = ["numpy.*", "pyarrow.*", "strawberry.*", "starlette.*", "isodate.*"] 54 | ignore_missing_imports = true 55 | 56 | [tool.coverage.run] 57 | source = ["graphique"] 58 | branch = true 59 | omit = ["graphique/shell.py"] 60 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | pyarrow>=19 2 | strawberry-graphql[asgi,cli]>=0.236 3 | numpy 4 | isodate>=0.7 5 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coady/graphique/17e5922e8420acef18858d204250d78a518c5e4c/tests/__init__.py 
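Tying the packaging metadata to the default service: the `server` extra pulls in uvicorn, and `graphique.service` builds its ASGI `app` from the `PARQUET_PATH` setting read via starlette's `Config`. A minimal launch sketch, assuming the extra is installed; the fixture path here is only illustrative:

import os
os.environ['PARQUET_PATH'] = 'tests/fixtures/zipcodes.parquet'  # must be set before the import

import uvicorn
from graphique.service import app  # reads PARQUET_PATH (or a .env file) at import time

uvicorn.run(app)  # requires the `server` extra (uvicorn)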
-------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from importlib import metadata 5 | from pathlib import Path 6 | import pyarrow.dataset as ds 7 | import pytest 8 | 9 | fixtures = Path(__file__).parent / 'fixtures' 10 | 11 | 12 | def pytest_report_header(config): 13 | return [f'{name}: {metadata.version(name)}' for name in ('pyarrow', 'strawberry-graphql')] 14 | 15 | 16 | class TestClient: 17 | def __init__(self, app): 18 | self.app = app 19 | 20 | def _execute(self, query): 21 | root_value = self.app.root_value 22 | return self.app.schema.execute_sync(query, root_value=root_value, context_value={}) 23 | 24 | def execute(self, query): 25 | result = self._execute(query) 26 | for error in result.errors or []: 27 | raise ValueError(error) 28 | return result.data 29 | 30 | 31 | def load(path, **vars): 32 | os.environ.update(vars, PARQUET_PATH=str(fixtures / path)) 33 | sys.modules.pop('graphique.service', None) 34 | sys.modules.pop('graphique.settings', None) 35 | from graphique.service import app 36 | 37 | for var in vars: 38 | del os.environ[var] 39 | return app 40 | 41 | 42 | @pytest.fixture(scope='module') 43 | def table(): 44 | return ds.dataset(fixtures / 'zipcodes.parquet').to_table() 45 | 46 | 47 | @pytest.fixture(scope='module') 48 | def client(): 49 | filters = json.dumps({'zipcode': {'gt': 0}}) 50 | app = load('zipcodes.parquet', FILTERS=filters) 51 | return TestClient(app) 52 | 53 | 54 | @pytest.fixture(params=[None, ['zipcode', 'state', 'county']], scope='module') 55 | def dsclient(request): 56 | app = load('zipcodes.parquet', COLUMNS=json.dumps(request.param)) 57 | return TestClient(app) 58 | 59 | 60 | @pytest.fixture(scope='module') 61 | def partclient(request): 62 | app = load('partitioned') 63 | return TestClient(app) 64 | 65 | 66 | @pytest.fixture(scope='module') 67 | def fedclient(): 68 | from .federated import app 69 | 70 | return TestClient(app) 71 | 72 | 73 | @pytest.fixture(scope='module') 74 | def aliasclient(): 75 | columns = {'snakeId': 'snake_id', 'camelId': 'camelId'} 76 | app = load('alltypes.parquet', COLUMNS=json.dumps(columns)) 77 | return TestClient(app) 78 | 79 | 80 | @pytest.fixture(scope='module') 81 | def executor(): 82 | app = load('alltypes.parquet', FILTERS='{}') 83 | return TestClient(app).execute 84 | -------------------------------------------------------------------------------- /tests/federated.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pyarrow.dataset as ds 3 | from graphique import GraphQL, core 4 | 5 | fixtures = Path(__file__).parent / 'fixtures' 6 | dataset = ds.dataset(fixtures / 'zipcodes.parquet') 7 | roots = { 8 | 'zipcodes': core.Nodes.scan(dataset, dataset.schema.names), 9 | 'states': core.Table.sort(dataset.to_table(), 'state', 'county', indices='indices'), 10 | 'zip_db': ds.dataset(fixtures / 'zip_db.parquet'), 11 | } 12 | app = GraphQL.federated(roots, keys={'zipcodes': ['zipcode']}) 13 | -------------------------------------------------------------------------------- /tests/fixtures/alltypes.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coady/graphique/17e5922e8420acef18858d204250d78a518c5e4c/tests/fixtures/alltypes.parquet 
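The `TestClient` in conftest.py and the roots in tests/federated.py show the full loop: construct a `GraphQL` app, then execute queries against its schema with the app's root value. A short sketch of one such query against the federated roots, assuming it is run from the repository root so the fixture paths resolve:

from tests.federated import app

result = app.schema.execute_sync(
    '{ zipcodes { length } states { length } zipDb { length } }',
    root_value=app.root_value,
    context_value={},
)
print(result.errors or result.data)  # field names are camel-cased, so zip_db becomes zipDb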
-------------------------------------------------------------------------------- /tests/fixtures/nofields.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coady/graphique/17e5922e8420acef18858d204250d78a518c5e4c/tests/fixtures/nofields.parquet -------------------------------------------------------------------------------- /tests/fixtures/partitioned/north=0/west=0/18ed2d55859f4e5aabd025832d04a421-0.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coady/graphique/17e5922e8420acef18858d204250d78a518c5e4c/tests/fixtures/partitioned/north=0/west=0/18ed2d55859f4e5aabd025832d04a421-0.parquet -------------------------------------------------------------------------------- /tests/fixtures/partitioned/north=0/west=1/18ed2d55859f4e5aabd025832d04a421-0.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coady/graphique/17e5922e8420acef18858d204250d78a518c5e4c/tests/fixtures/partitioned/north=0/west=1/18ed2d55859f4e5aabd025832d04a421-0.parquet -------------------------------------------------------------------------------- /tests/fixtures/partitioned/north=1/west=0/18ed2d55859f4e5aabd025832d04a421-0.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coady/graphique/17e5922e8420acef18858d204250d78a518c5e4c/tests/fixtures/partitioned/north=1/west=0/18ed2d55859f4e5aabd025832d04a421-0.parquet -------------------------------------------------------------------------------- /tests/fixtures/partitioned/north=1/west=1/18ed2d55859f4e5aabd025832d04a421-0.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coady/graphique/17e5922e8420acef18858d204250d78a518c5e4c/tests/fixtures/partitioned/north=1/west=1/18ed2d55859f4e5aabd025832d04a421-0.parquet -------------------------------------------------------------------------------- /tests/fixtures/zip_db.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coady/graphique/17e5922e8420acef18858d204250d78a518c5e4c/tests/fixtures/zip_db.parquet -------------------------------------------------------------------------------- /tests/fixtures/zipcodes.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coady/graphique/17e5922e8420acef18858d204250d78a518c5e4c/tests/fixtures/zipcodes.parquet -------------------------------------------------------------------------------- /tests/requirements.in: -------------------------------------------------------------------------------- 1 | -r ../requirements.in 2 | pytest-cov 3 | pytest-codspeed 4 | -------------------------------------------------------------------------------- /tests/test_bench.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from graphique.core import Nodes, Table as T 3 | 4 | 5 | @pytest.mark.benchmark 6 | def test_group(table): 7 | Nodes('table_source', table).group('state', 'county', 'city') 8 | T.runs(table, 'state', 'county', 'city') 9 | 10 | 11 | @pytest.mark.benchmark 12 | def test_rank(table): 13 | T.rank(table, 1, 'state', 'county', 'city') 14 | T.rank(table, 10, 'state', 'county', 'city') 15 | 16 | 17 | @pytest.mark.benchmark 18 | def 
test_sort(table): 19 | T.sort(table, 'state', 'county', 'city', length=1) 20 | T.sort(table, 'state', 'county', 'city', length=10) 21 | T.sort(table, 'state', 'county', 'city') 22 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | import pyarrow as pa 2 | import pyarrow.compute as pc 3 | import pyarrow.dataset as ds 4 | import pytest 5 | from graphique.core import ListChunk, Nodes, Column as C, Table as T 6 | from graphique.scalars import parse_duration, duration_isoformat 7 | 8 | 9 | def test_duration(): 10 | assert duration_isoformat(parse_duration('P1Y1M1DT1H1M1.1S')) == 'P13M1DT1H1M1.1S' 11 | assert duration_isoformat(parse_duration('P1M1DT1H1M1.1S')) == 'P1M1DT1H1M1.1S' 12 | assert duration_isoformat(parse_duration('P1DT1H1M1.1S')) == 'P1DT1H1M1.1S' 13 | assert duration_isoformat(parse_duration('PT1H1M1.1S')) == 'PT1H1M1.1S' 14 | assert duration_isoformat(parse_duration('PT1M1.1S')) == 'PT1M1.1S' 15 | assert duration_isoformat(parse_duration('PT1.1S')) == 'PT1.1S' 16 | assert duration_isoformat(parse_duration('PT1S')) == 'PT1S' 17 | assert duration_isoformat(parse_duration('P0D')) == 'P0D' 18 | assert duration_isoformat(parse_duration('PT0S')) == 'P0D' 19 | assert duration_isoformat(parse_duration('P0MT')) == 'P0M0D' 20 | assert duration_isoformat(parse_duration('P0YT')) == 'P0M0D' 21 | with pytest.raises(ValueError): 22 | duration_isoformat(parse_duration('T1H')) 23 | with pytest.raises(ValueError): 24 | duration_isoformat(parse_duration('P1H')) 25 | 26 | 27 | def test_dictionary(table): 28 | array = table['state'].dictionary_encode() 29 | table = pa.table({'state': array}) 30 | assert T.sort(table, 'state')['state'][0].as_py() == 'AK' 31 | array = pa.chunked_array([['a', 'b'], ['a', 'b', None]]).dictionary_encode() 32 | assert C.fill_null_backward(array) == array.combine_chunks() 33 | assert C.fill_null_forward(array)[-1].as_py() == 'b' 34 | assert C.fill_null(array[3:], "c").to_pylist() == list('bc') 35 | assert C.fill_null(array[:3], "c").to_pylist() == list('aba') 36 | assert C.sort_values(array.combine_chunks()).to_pylist() == [1, 2, 1, 2, None] 37 | 38 | 39 | def test_chunks(): 40 | array = pa.chunked_array([pa.array(list(chunk)).dictionary_encode() for chunk in ('aba', 'ca')]) 41 | assert C.index(array, 'a') == 0 42 | assert C.index(array, 'c') == 3 43 | assert C.index(array, 'a', start=3) == 4 44 | assert C.index(array, 'b', start=2) == -1 45 | 46 | 47 | def test_lists(): 48 | array = pa.array([[2, 1], [0, 0], [None], [], None]) 49 | assert ListChunk.first(array).to_pylist() == [2, 0, None, None, None] 50 | assert ListChunk.element(array, -2).to_pylist() == [2, 0, None, None, None] 51 | assert ListChunk.last(array).to_pylist() == [1, 0, None, None, None] 52 | assert ListChunk.last(pa.chunked_array([array])).to_pylist() == [1, 0, None, None, None] 53 | assert ListChunk.element(array, 1).to_pylist() == [1, 0, None, None, None] 54 | assert ListChunk.min(array).to_pylist() == [1, 0, None, None, None] 55 | assert ListChunk.max(array).to_pylist() == [2, 0, None, None, None] 56 | assert ListChunk.mode(array)[0].as_py() == [{'mode': 1, 'count': 1}] 57 | assert ListChunk.quantile(array).to_pylist() == [[1.5], [0.0], [None], [None], [None]] 58 | quantile = ListChunk.quantile(array, q=[0.75]) 59 | assert quantile.to_pylist() == [[1.75], [0.0], [None], [None], [None]] 60 | array = pa.array([[True, True], [False, False], [None], [], None]) 61 | array = 
pa.ListArray.from_arrays([0, 2, 3], pa.array(["a", "b", None]).dictionary_encode()) 62 | assert ListChunk.min(array).to_pylist() == ["a", None] 63 | assert ListChunk.max(array).to_pylist() == ["b", None] 64 | assert C.is_list_type(pa.FixedSizeListArray.from_arrays([], 1)) 65 | array = pa.array([[list('ab'), ['c']], [list('de')]]) 66 | assert ListChunk.inner_flatten(array).to_pylist() == [list('abc'), list('de')] 67 | batch = T.from_offsets(pa.record_batch([list('abcde')], ['col']), pa.array([0, 3, 5])) 68 | assert batch['col'].to_pylist() == [list('abc'), list('de')] 69 | assert not T.from_offsets(pa.table({}), pa.array([0])) 70 | array = ListChunk.from_counts(pa.array([3, None, 2]), list('abcde')) 71 | assert array.to_pylist() == [list('abc'), None, list('de')] 72 | with pytest.raises(ValueError): 73 | T.list_value_length(pa.table({'x': array, 'y': array[::-1]})) 74 | 75 | 76 | def test_membership(): 77 | array = pa.chunked_array([[1, 1]]) 78 | assert C.index(array, 1) == C.index(array, 1, end=1) == 0 79 | assert C.index(array, 1, start=1) == 1 80 | assert C.index(array, 1, start=2) == -1 81 | 82 | 83 | def test_nodes(table): 84 | dataset = ds.dataset(table).filter(pc.field('state') == 'CA') 85 | (column,) = Nodes.scan(dataset, columns={'_': pc.field('state')}).to_table() 86 | assert column.unique().to_pylist() == ['CA'] 87 | table = Nodes.group(dataset, 'county', 'city', counts=([], 'hash_count_all', None)).to_table() 88 | assert len(table) == 1241 89 | assert pc.sum(table['counts']).as_py() == 2647 90 | scanner = Nodes.scan(dataset, columns=['state']) 91 | assert scanner.schema.names == ['state'] 92 | assert scanner.group('state').to_table() == pa.table({'state': ['CA']}) 93 | assert scanner.count_rows() == 2647 94 | assert scanner.head(3) == pa.table({'state': ['CA'] * 3}) 95 | assert scanner.take([0, 2]) == pa.table({'state': ['CA'] * 2}) 96 | 97 | 98 | def test_runs(table): 99 | groups, counts = T.runs(table, 'state') 100 | assert len(groups) == len(counts) == 66 101 | assert pc.sum(counts).as_py() == 41700 102 | assert groups['state'][0].as_py() == 'NY' 103 | assert groups['county'][0].values.to_pylist() == ['Suffolk', 'Suffolk'] 104 | groups, counts = T.runs(table, 'state', 'county') 105 | assert len(groups) == len(counts) == 22751 106 | groups, counts = T.runs(table, zipcode=(pc.greater, 100)) 107 | assert len(groups) == len(counts) == 59 108 | tbl = T.sort(table, 'state', 'longitude') 109 | groups, counts = T.runs(tbl, 'state', longitude=(pc.greater, 1.0)) 110 | assert len(groups) == len(counts) == 62 111 | assert groups['state'].value_counts()[0].as_py() == {'values': 'AK', 'counts': 7} 112 | assert groups['longitude'][:2].to_pylist() == [[-174.213333], [-171.701685]] 113 | groups, counts = T.runs(tbl, 'state', longitude=(pc.less,)) 114 | assert len(groups) == len(counts) == 52 115 | 116 | 117 | def test_sort(table): 118 | data = T.sort(table, 'state').to_pydict() 119 | assert (data['state'][0], data['county'][0]) == ('AK', 'Anchorage') 120 | data = T.sort(table, 'state', 'county', length=1).to_pydict() 121 | assert (data['state'], data['county']) == (['AK'], ['Aleutians East']) 122 | tbl = T.sort(table, 'state', '-county', length=2, null_placement='at_start') 123 | assert tbl.schema.pandas_metadata == {'index_columns': ['state']} 124 | assert tbl['state'].to_pylist() == ['AK'] * 2 125 | assert tbl['county'].to_pylist() == ['Yukon Koyukuk'] * 2 126 | counts = T.rank(table, 1, 'state')['state'].value_counts().to_pylist() 127 | assert counts == [{'values': 'AK', 'counts': 273}] 
128 | counts = T.rank(table, 1, 'state', '-county')['county'].value_counts().to_pylist() 129 | assert counts == [{'values': 'Yukon Koyukuk', 'counts': 30}] 130 | counts = T.rank(table, 2, 'state')['state'].value_counts().to_pylist() 131 | assert counts == [{'values': 'AL', 'counts': 838}, {'values': 'AK', 'counts': 273}] 132 | counts = T.rank(table, 2, 'state', '-county')['county'].value_counts().to_pylist() 133 | assert counts == [{'counts': 30, 'values': 'Yukon Koyukuk'}, {'counts': 1, 'values': 'Yakutat'}] 134 | table = pa.table({'x': [list('ab'), [], None, ['c']]}) 135 | (column,) = T.map_list(table, T.sort, '-x', length=2) 136 | assert column.to_pylist() == [list('ba'), [], None, ['c']] 137 | 138 | 139 | def test_numeric(): 140 | array = pa.array([0.0, 10.0, 20.0]) 141 | scalar = pa.scalar([10.0]) 142 | assert pc.call_function('digitize', [array, scalar, False]).to_pylist() == [0, 1, 1] 143 | assert pc.call_function('digitize', [array, scalar, True]).to_pylist() == [0, 0, 1] 144 | 145 | 146 | def test_list(): 147 | array = pa.array([[False], [True], [False, True]]) 148 | assert pc.call_function('list_all', [array]).to_pylist() == [False, True, False] 149 | assert pc.call_function('list_any', [array]).to_pylist() == [False, True, True] 150 | 151 | 152 | def test_not_implemented(): 153 | dictionary = pa.array(['']).dictionary_encode() 154 | with pytest.raises((NotImplementedError, TypeError)): 155 | pc.sort_indices(pa.table({'': dictionary}), [('', 'ascending')]) 156 | with pytest.raises(NotImplementedError): 157 | dictionary.index('') 158 | with pytest.raises(NotImplementedError): 159 | pc.first_last(dictionary) 160 | with pytest.raises(NotImplementedError): 161 | pc.min_max(dictionary) 162 | with pytest.raises(NotImplementedError): 163 | pc.count_distinct(dictionary) 164 | with pytest.raises(NotImplementedError): 165 | pc.utf8_lower(dictionary) 166 | with pytest.raises(NotImplementedError): 167 | pa.StructArray.from_arrays([], []).dictionary_encode() 168 | for index in (-1, 1): 169 | with pytest.raises(ValueError): 170 | pc.list_element(pa.array([[0]]), index) 171 | with pytest.raises(NotImplementedError): 172 | pc.any([0]) 173 | array = pa.array(list('aba')) 174 | with pytest.raises(NotImplementedError): 175 | pa.table({'': array.dictionary_encode()}).group_by('').aggregate([('', 'min')]) 176 | with pytest.raises(NotImplementedError): 177 | pa.table({'': array}).group_by('').aggregate([('', 'any')]) 178 | agg = 'value', 'max', pc.ScalarAggregateOptions(min_count=4) 179 | table = pa.table({'value': list('abc'), 'key': [0, 1, 0]}) 180 | table = table.group_by('key').aggregate([agg]) 181 | assert table['value_max'].to_pylist() == list('cb') # min_count has no effect 182 | for name in ('one', 'list', 'distinct'): 183 | assert not hasattr(pc, name) 184 | with pytest.raises(NotImplementedError): 185 | pc.fill_null_forward(dictionary) 186 | with pytest.raises(NotImplementedError): 187 | pa.table({'': list('aba')}).group_by([]).aggregate([('', 'first'), ('', 'last')]) 188 | with pytest.raises(ValueError): 189 | pc.pairwise_diff(pa.chunked_array([[0]])) 190 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import pytest 3 | from graphique import middleware 4 | from .conftest import load 5 | 6 | 7 | def test_extensions(): 8 | ext = middleware.MetricsExtension(type('', (), {'context': {}})) 9 | for name in ('operation', 'parse', 
'validate'): 10 | assert list(getattr(ext, 'on_' + name)()) == [None] 11 | assert set(ext.get_results()['metrics']) == {'duration', 'execution'} 12 | 13 | 14 | def test_filter(dsclient): 15 | data = dsclient.execute('{ column(name: "state") { length } }') 16 | assert data == {'column': {'length': 41700}} 17 | data = dsclient.execute('{ length row { state } }') 18 | assert data == {'length': 41700, 'row': {'state': 'NY'}} 19 | data = dsclient.execute('{ filter(state: {eq: ["CA", "NY"]}) { length } }') 20 | assert data == {'filter': {'length': 4852}} 21 | data = dsclient.execute('{ filter(state: {ne: "CA"}) { length } }') 22 | assert data == {'filter': {'length': 39053}} 23 | data = dsclient.execute('{ filter { length } }') 24 | assert data == {'filter': {'length': 41700}} 25 | data = dsclient.execute('{ filter(state: {ne: null}) { length } }') 26 | assert data == {'filter': {'length': 41700}} 27 | data = dsclient.execute('{ dropNull { length } }') 28 | assert data == {'dropNull': {'length': 41700}} 29 | 30 | 31 | def test_search(dsclient): 32 | data = dsclient.execute('{ filter(zipcode: {lt: 10000}) { length } }') 33 | assert data == {'filter': {'length': 3224}} 34 | data = dsclient.execute('{ filter(zipcode: {}) { length } }') 35 | assert data == {'filter': {'length': 41700}} 36 | data = dsclient.execute('{ filter(zipcode: {}) { row { zipcode } } }') 37 | assert data == {'filter': {'row': {'zipcode': 501}}} 38 | data = dsclient.execute("""{ filter(zipcode: {gt: 90000}) { filter(state: {eq: "CA"}) { 39 | length } } }""") 40 | assert data == {'filter': {'filter': {'length': 2647}}} 41 | data = dsclient.execute("""{ filter(zipcode: {gt: 90000}) { filter(state: {eq: "CA"}) { 42 | length row { zipcode } } } }""") 43 | assert data == {'filter': {'filter': {'length': 2647, 'row': {'zipcode': 90001}}}} 44 | data = dsclient.execute("""{ filter(zipcode: {lt: 90000}) { filter(state: {eq: "CA"}) { 45 | group(by: "county") { length } } } }""") 46 | assert data == {'filter': {'filter': {'group': {'length': 0}}}} 47 | 48 | 49 | def test_slice(dsclient): 50 | data = dsclient.execute('{ slice(length: 3) { length } }') 51 | assert data == {'slice': {'length': 3}} 52 | data = dsclient.execute('{ slice(offset: -3) { length } }') 53 | assert data == {'slice': {'length': 3}} 54 | data = dsclient.execute('{ slice { length } }') 55 | assert data == {'slice': {'length': 41700}} 56 | data = dsclient.execute('{ take(indices: [0]) { row { zipcode } } }') 57 | assert data == {'take': {'row': {'zipcode': 501}}} 58 | data = dsclient.execute('{ any many: any(length: 50000)}') 59 | assert data == {'any': True, 'many': False} 60 | data = dsclient.execute('{ size }') 61 | assert data == {'size': None} 62 | data = dsclient.execute('{ slice { size } }') 63 | assert data == {'slice': {'size': 0}} 64 | 65 | 66 | def test_group(dsclient): 67 | data = dsclient.execute( 68 | """{ group(by: ["state"], aggregate: {min: {name: "county"}}) { row { state county } } }""" 69 | ) 70 | assert data == {'group': {'row': {'state': 'NY', 'county': 'Albany'}}} 71 | data = dsclient.execute("""{ group(by: ["state"], counts: "c") { slice(length: 1) { 72 | column(name: "c") { ... 
on LongColumn { values } } } } }""") 73 | assert data == {'group': {'slice': {'column': {'values': [2205]}}}} 74 | data = dsclient.execute( 75 | '{ group(by: ["state"], aggregate: {first: {name: "county"}}) { row { county } } }' 76 | ) 77 | assert data == {'group': {'row': {'county': 'Suffolk'}}} 78 | data = dsclient.execute( 79 | '{ group(by: ["state"], aggregate: {one: {name: "county"}}) { row { county } } }' 80 | ) 81 | assert data['group']['row']['county'] 82 | data = dsclient.execute( 83 | """{ group(by: ["state"], aggregate: {mean: {name: "zipcode"}}) { slice(length: 1) { 84 | column(name: "zipcode") { ... on FloatColumn { values } } } } }""" 85 | ) 86 | assert data == {'group': {'slice': {'column': {'values': [pytest.approx(12614.62721)]}}}} 87 | data = dsclient.execute( 88 | """{ group(by: ["state"], aggregate: {list: {name: "zipcode"}}) { aggregate(mean: {name: "zipcode"}) { 89 | slice(length: 1) { column(name: "zipcode") { ... on FloatColumn { values } } } } } }""" 90 | ) 91 | assert data == { 92 | 'group': {'aggregate': {'slice': {'column': {'values': [pytest.approx(12614.62721)]}}}} 93 | } 94 | data = dsclient.execute("""{ group(aggregate: {min: {alias: "st", name: "state"}}) { 95 | column(name: "st") { ... on StringColumn { values } } } }""") 96 | assert data == {'group': {'column': {'values': ['AK']}}} 97 | 98 | 99 | def test_list(partclient): 100 | data = partclient.execute( 101 | """{ group(by: "state", aggregate: {distinct: {alias: "counties", name: "county"}}) { 102 | tables { row { state } column(name: "counties") { length } } } } """ 103 | ) 104 | (table,) = [table for table in data['group']['tables'] if table['row']['state'] == 'PR'] 105 | assert table == {'row': {'state': 'PR'}, 'column': {'length': 78}} 106 | data = partclient.execute("""{ group(by: "north", aggregate: {distinct: {name: "west"}}) { 107 | tables { row { north } columns { west { length } } } } }""") 108 | tables = data['group']['tables'] 109 | assert {table['row']['north'] for table in tables} == {0, 1} 110 | assert [table['columns'] for table in tables] == [{'west': {'length': 2}}] * 2 111 | 112 | 113 | def test_fragments(partclient): 114 | data = partclient.execute('{ group(by: ["north", "west"]) { columns { north { values } } } }') 115 | data = partclient.execute( 116 | '{ group(by: ["north", "west"], counts: "c") { column(name: "c") { ... 
on LongColumn { values } } } }' 117 | ) 118 | assert data == {'group': {'column': {'values': [9301, 11549, 11549, 9301]}}} 119 | data = partclient.execute('{ rank(by: "north") { row { north } } }') 120 | assert data == {'rank': {'row': {'north': 0}}} 121 | data = partclient.execute('{ rank(by: ["-north", "-zipcode"]) { row { zipcode } } }') 122 | assert data == {'rank': {'row': {'zipcode': 99950}}} 123 | data = partclient.execute('{ sort(by: "north", length: 1) { row { north } } }') 124 | assert data == {'sort': {'row': {'north': 0}}} 125 | data = partclient.execute( 126 | '{ group(by: ["north"], aggregate: {max: {name: "zipcode"}}) { row { north zipcode } } }' 127 | ) 128 | assert data['group']['row']['zipcode'] >= 96898 129 | data = partclient.execute( 130 | '{ group(by: [], aggregate: {min: {name: "state"}}) { length row { state } } }' 131 | ) 132 | assert data == {'group': {'length': 1, 'row': {'state': 'AK'}}} 133 | data = partclient.execute( 134 | """{ group(by: ["north", "west"], aggregate: {distinct: {name: "city"}, mean: {name: "zipcode"}}) { 135 | length column(name: "city") { type } } }""" 136 | ) 137 | assert data == {'group': {'length': 4, 'column': {'type': 'list'}}} 138 | data = partclient.execute("""{ group(by: "north", aggregate: {countDistinct: {name: "west"}}) { 139 | column(name: "west") { ... on LongColumn { values } } } }""") 140 | assert data == {'group': {'column': {'values': [2, 2]}}} 141 | data = partclient.execute( 142 | '{ group(by: "north", counts: "c") { column(name: "c") { ... on LongColumn { values } } } }' 143 | ) 144 | assert data == {'group': {'column': {'values': [20850, 20850]}}} 145 | 146 | 147 | def test_schema(dsclient): 148 | schema = dsclient.execute('{ schema { names types partitioning } }')['schema'] 149 | assert set(schema['names']) >= {'zipcode', 'state', 'county'} 150 | assert set(schema['types']) >= {'int32', 'string'} 151 | assert len(schema['partitioning']) in (0, 6) 152 | data = dsclient.execute('{ scan(filter: {}) { type } }') 153 | assert data == {'scan': {'type': 'FileSystemDataset'}} 154 | data = dsclient.execute('{ scan(columns: {name: "zipcode"}) { type } }') 155 | assert data == {'scan': {'type': 'Nodes'}} 156 | result = dsclient._execute('{ length optional { tables { length } } }') 157 | assert result.data == {'length': 41700, 'optional': None} 158 | assert len(result.errors) == 1 159 | 160 | 161 | def test_scan(dsclient): 162 | data = dsclient.execute( 163 | '{ scan(columns: {name: "zipcode", alias: "zip"}) { column(name: "zip") { type } } }' 164 | ) 165 | assert data == {'scan': {'column': {'type': 'int32'}}} 166 | data = dsclient.execute( 167 | '{ scan(filter: {eq: [{name: "county"}, {name: "state"}]}) { length } }' 168 | ) 169 | assert data == {'scan': {'length': 0}} 170 | data = dsclient.execute('{ scan(filter: {eq: [{name: "zipcode"}, {value: null}]}) { length } }') 171 | assert data == {'scan': {'length': 0}} 172 | data = dsclient.execute( 173 | '{ scan(filter: {inv: {ne: [{name: "zipcode"}, {value: null}]}}) { length } }' 174 | ) 175 | assert data == {'scan': {'length': 0}} 176 | data = dsclient.execute( 177 | '{ scan(filter: {eq: [{name: "state"} {value: "CA", cast: "string"}]}) { length } }' 178 | ) 179 | assert data == {'scan': {'length': 2647}} 180 | data = dsclient.execute( 181 | '{ scan(filter: {eq: [{name: "state"} {value: ["CA", "OR"]}]}) { length } }' 182 | ) 183 | assert data == {'scan': {'length': 3131}} 184 | with pytest.raises(ValueError, match="conflicting inputs"): 185 | dsclient.execute('{ scan(filter: 
{name: "state", value: "CA"}) { length } }') 186 | with pytest.raises(ValueError, match="name or alias"): 187 | dsclient.execute('{ scan(columns: {}) { length } }') 188 | data = dsclient.execute("""{ scan(filter: {eq: [{name: "state"}, {value: "CA"}]}) 189 | { scan(filter: {eq: [{name: "county"}, {value: "Santa Clara"}]}) 190 | { length row { county } } } }""") 191 | assert data == {'scan': {'scan': {'length': 108, 'row': {'county': 'Santa Clara'}}}} 192 | data = dsclient.execute("""{ scan(filter: {or: [{eq: [{name: "state"}, {value: "CA"}]}, 193 | {eq: [{name: "county"}, {value: "Santa Clara"}]}]}) { length } }""") 194 | assert data == {'scan': {'length': 2647}} 195 | 196 | 197 | def test_rank(partclient): 198 | data = partclient.execute('{ rank(by: ["state"]) { length row { state } } }') 199 | assert data == {'rank': {'length': 273, 'row': {'state': 'AK'}}} 200 | data = partclient.execute('{ rank(by: ["-state", "-county"]) { length row { state county } } }') 201 | assert data == {'rank': {'length': 4, 'row': {'state': 'WY', 'county': 'Weston'}}} 202 | data = partclient.execute('{ sort(by: "state", length: 3) { columns { state { values } } } }') 203 | assert data == {'sort': {'columns': {'state': {'values': ['AK'] * 3}}}} 204 | data = partclient.execute('{ rank(by: "north") { length } }') 205 | assert data == {'rank': {'length': 20850}} 206 | data = partclient.execute('{ rank(by: "north", max: 2) { length } }') 207 | assert data == {'rank': {'length': 41700}} 208 | data = partclient.execute('{ rank(by: ["north", "west"]) { length } }') 209 | assert data == {'rank': {'length': 9301}} 210 | data = partclient.execute('{ rank(by: ["north", "west"], max: 2) { length } }') 211 | assert data == {'rank': {'length': 20850}} 212 | data = partclient.execute('{ rank(by: ["north", "west"], max: 3) { length } }') 213 | assert data == {'rank': {'length': 32399}} 214 | data = partclient.execute( 215 | '{ rank(by: ["north", "state"], max: 2) { columns { state { unique { values } } } } }' 216 | ) 217 | assert data == {'rank': {'columns': {'state': {'unique': {'values': ['AL', 'AR']}}}}} 218 | data = partclient.execute('{ sort(by: "north", length: 3) { length } }') 219 | assert data == {'sort': {'length': 3}} 220 | data = partclient.execute('{ sort(by: "north", length: 50000) { length } }') 221 | assert data == {'sort': {'length': 41700}} 222 | 223 | 224 | def test_root(): 225 | app = load('zipcodes.parquet', FEDERATED='test') 226 | assert asyncio.run(app.get_root_value(None)) is app.root_value 227 | assert app.root_value.test 228 | with pytest.warns(UserWarning): 229 | assert load('nofields.parquet', FEDERATED='test') 230 | 231 | 232 | def test_federation(fedclient): 233 | data = fedclient.execute( 234 | '{ _service { sdl } zipcodes { __typename length } zipDb { __typename length } }' 235 | ) 236 | assert data['_service']['sdl'] 237 | assert data['zipcodes'] == {'__typename': 'ZipcodesTable', 'length': 41700} 238 | assert data['zipDb'] == {'__typename': 'ZipDbTable', 'length': 42724} 239 | 240 | data = fedclient.execute("""{ zipcodes { scan(columns: {name: "zipcode", cast: "int64"}) { 241 | join(right: "zip_db", keys: "zipcode", rightKeys: "zip") { length schema { names } } } } }""") 242 | table = data['zipcodes']['scan']['join'] 243 | assert table['length'] == 41700 244 | assert set(table['schema']['names']) > {'zipcode', 'timezone', 'latitude'} 245 | data = fedclient.execute( 246 | """{ zipcodes { scan(columns: {alias: "zip", name: "zipcode", cast: "int64"}) { 247 | join(right: "zip_db", keys: "zip", 
joinType: "right outer") { length schema { names } } } } }""" 248 | ) 249 | table = data['zipcodes']['scan']['join'] 250 | assert table['length'] == 42724 251 | assert set(table['schema']['names']) > {'zip', 'timezone', 'latitude'} 252 | 253 | data = fedclient.execute( 254 | """{ _entities(representations: {__typename: "ZipcodesTable", zipcode: 90001}) { 255 | ... on ZipcodesTable { length type row { state } } } }""" 256 | ) 257 | assert data == {'_entities': [{'length': 1, 'type': 'Nodes', 'row': {'state': 'CA'}}]} 258 | data = fedclient.execute("""{ states { filter(state: {eq: "CA"}) { columns { indices { 259 | takeFrom(field: "zipcodes") { __typename column(name: "state") { length } } } } } } }""") 260 | table = data['states']['filter']['columns']['indices']['takeFrom'] 261 | assert table == {'__typename': 'ZipcodesTable', 'column': {'length': 2647}} 262 | 263 | 264 | def test_sorted(fedclient): 265 | data = fedclient.execute( 266 | '{ states { filter(state: {eq: "CA"}, county: {eq: "Santa Clara"}) { length } } }' 267 | ) 268 | assert data == {'states': {'filter': {'length': 108}}} 269 | data = fedclient.execute( 270 | '{ states { filter(state: {eq: ["CA", "OR"]}, county: {eq: "Santa Clara"}) { length } } }' 271 | ) 272 | assert data == {'states': {'filter': {'length': 108}}} 273 | data = fedclient.execute( 274 | '{ states { filter(state: {le: "CA"}, county: {eq: "Santa Clara"}) { length } } }' 275 | ) 276 | assert data == {'states': {'filter': {'length': 108}}} 277 | data = fedclient.execute('{ states { filter { filter(state: {eq: "CA"}) { length } } } }') 278 | assert data == {'states': {'filter': {'filter': {'length': 2647}}}} 279 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_camel(aliasclient): 5 | data = aliasclient.execute('{ schema { index names } }') 6 | assert data == {'schema': {'index': [], 'names': ['snakeId', 'camelId']}} 7 | data = aliasclient.execute('{ row { snakeId } columns { snakeId { type } } }') 8 | assert data == {'row': {'snakeId': 1}, 'columns': {'snakeId': {'type': 'int64'}}} 9 | data = aliasclient.execute('{ filter(snakeId: {eq: 1}) { length } }') 10 | assert data == {'filter': {'length': 1}} 11 | data = aliasclient.execute('{ filter(camelId: {eq: 1}) { length } }') 12 | assert data == {'filter': {'length': 1}} 13 | data = aliasclient.execute('{ group(by: "camelId") { length } }') 14 | assert data == {'group': {'length': 2}} 15 | data = aliasclient.execute('{ rank(by: "camelId") { length } }') 16 | assert data == {'rank': {'length': 1}} 17 | 18 | 19 | def test_snake(executor): 20 | data = executor('{ schema { names } }') 21 | assert 'snake_id' in data['schema']['names'] 22 | data = executor('{ row { snake_id } columns { snake_id { type } } }') 23 | assert data == {'row': {'snake_id': 1}, 'columns': {'snake_id': {'type': 'int64'}}} 24 | data = executor('{ filter(snake_id: {eq: 1}) { length } }') 25 | assert data == {'filter': {'length': 1}} 26 | data = executor('{ filter(camelId: {eq: 1}) { length } }') 27 | assert data == {'filter': {'length': 1}} 28 | 29 | 30 | def test_columns(executor): 31 | def execute(query): 32 | return executor(f'{{ columns {query} }}')['columns'] 33 | 34 | for name in ('uint8', 'int8', 'uint16', 'int16', 'int32'): 35 | assert execute(f'{{ {name} {{ values }} }}') == {name: {'values': [0, None]}} 36 | assert execute(f'{{ {name} {{ index(value: 0) }} }}') == 
{name: {'index': 0}} 37 | data = execute(f'{{ {name} {{ dropNull }} }}') 38 | assert data == {name: {'dropNull': [0]}} 39 | assert execute(f'{{ {name} {{ type }} }}') == {name: {'type': name}} 40 | assert execute(f'{{ {name} {{ min max }} }}') 41 | for name in ('uint32', 'uint64', 'int64'): 42 | assert execute(f'{{ {name} {{ values }} }}') == {name: {'values': [0, None]}} 43 | assert execute(f'{{ {name} {{ index(value: 0) }} }}') == {name: {'index': 0}} 44 | data = execute(f'{{ {name} {{ dropNull }} }}') 45 | assert data == {name: {'dropNull': [0]}} 46 | assert execute(f'{{ {name} {{ min max }} }}') 47 | 48 | for name in ('float', 'double'): 49 | assert execute(f'{{ {name} {{ values }} }}') == {name: {'values': [0.0, None]}} 50 | assert execute(f'{{ {name} {{ index(value: 0.0) }} }}') == {name: {'index': 0}} 51 | data = execute(f'{{ {name} {{ dropNull }} }}') 52 | assert data == {name: {'dropNull': [0.0]}} 53 | assert execute(f'{{ {name} {{ min max }} }}') 54 | 55 | for name in ('date32', 'date64'): 56 | assert execute(f'{{ {name} {{ values }} }}') == {name: {'values': ['1970-01-01', None]}} 57 | assert execute(f'{{ {name} {{ index(value: "1970-01-01") }} }}') == {name: {'index': 0}} 58 | data = execute(f'{{ {name} {{ dropNull }} }}') 59 | assert data == {name: {'dropNull': ['1970-01-01']}} 60 | assert execute(f'{{ {name} {{ min max }} }}') 61 | assert execute(f'{{ {name} {{ first last }} }}') 62 | 63 | data = execute('{ timestamp { values } }') 64 | assert data == {'timestamp': {'values': ['1970-01-01T00:00:00', None]}} 65 | data = execute(f'{{ {name} {{ dropNull }} }}') 66 | assert data == {name: {'dropNull': ['1970-01-01']}} 67 | assert execute('{ timestamp { index(value: "1970-01-01") } }') == {'timestamp': {'index': 0}} 68 | assert execute('{ timestamp { min max } }') 69 | 70 | for name in ('time32', 'time64'): 71 | assert execute(f'{{ {name} {{ values }} }}') == {name: {'values': ['00:00:00', None]}} 72 | assert execute(f'{{ {name} {{ index(value: "00:00:00") }} }}') == {name: {'index': 0}} 73 | data = execute(f'{{ {name} {{ dropNull }} }}') 74 | assert data == {name: {'dropNull': ['00:00:00']}} 75 | assert execute(f'{{ {name} {{ min max }} }}') 76 | 77 | for name in ('binary', 'string'): 78 | assert execute(f'{{ {name} {{ values }} }}') == {name: {'values': ['', None]}} 79 | assert execute(f'{{ {name} {{ index(value: "") }} }}') == {name: {'index': 0}} 80 | data = execute(f'{{ {name} {{ dropNull }} }}') 81 | assert data == {name: {'dropNull': ['']}} 82 | data = execute(f'{{ {name} {{ fillNull(value: "") }} }}') 83 | assert data == {name: {'fillNull': ['', '']}} 84 | 85 | assert execute('{ string { type } }') == { 86 | 'string': {'type': 'dictionary'} 87 | } 88 | 89 | 90 | def test_boolean(executor): 91 | def execute(query): 92 | return executor(f'{{ columns {query} }}')['columns'] 93 | 94 | assert execute('{ bool { values } }') == {'bool': {'values': [False, None]}} 95 | assert execute('{ bool { index(value: false) } }') == {'bool': {'index': 0}} 96 | assert execute('{ bool { index(value: false, start: 1, end: 2) } }') == {'bool': {'index': -1}} 97 | assert execute('{ bool { type } }') == {'bool': {'type': 'bool'}} 98 | assert execute('{ bool { unique { length } } }') == {'bool': {'unique': {'length': 2}}} 99 | assert execute('{ bool { any all } }') == {'bool': {'any': False, 'all': False}} 100 | assert execute('{ bool { indicesNonzero } }') == {'bool': {'indicesNonzero': []}} 101 | 102 | data = executor('{ scan(filter: {xor: [{name: "bool"}, {inv: {name: "bool"}}]}) { length } }') 
103 | assert data == {'scan': {'length': 1}} 104 | data = executor( 105 | """{ scan(columns: {alias: "bool", andNot: [{inv: {name: "bool"}}, {name: "bool"}], kleene: true}) 106 | { columns { bool { values } } } }""" 107 | ) 108 | assert data == {'scan': {'columns': {'bool': {'values': [True, None]}}}} 109 | 110 | 111 | def test_decimal(executor): 112 | def execute(query): 113 | return executor(f'{{ columns {query} }}')['columns'] 114 | 115 | assert execute('{ decimal { values } }') == {'decimal': {'values': ['0', None]}} 116 | assert execute('{ decimal { min max } }') 117 | assert execute('{ decimal { indicesNonzero } }') == {'decimal': {'indicesNonzero': []}} 118 | assert execute('{ decimal { index(value: 0) } }') 119 | data = executor( 120 | '{ sort(by: "decimal", nullPlacement: "at_start") { columns { decimal { values } } } }' 121 | ) 122 | assert data == {'sort': {'columns': {'decimal': {'values': [None, '0']}}}} 123 | data = executor('{ rank(by: "decimal") { columns { decimal { values } } } }') 124 | assert data == {'rank': {'columns': {'decimal': {'values': ['0']}}}} 125 | data = executor('{ rank(by: "-decimal") { columns { decimal { values } } } }') 126 | assert data == {'rank': {'columns': {'decimal': {'values': ['0']}}}} 127 | 128 | 129 | def test_numeric(executor): 130 | for name in ('int32', 'int64', 'float'): 131 | data = executor(f'{{ columns {{ {name} {{ mean stddev variance }} }} }}') 132 | assert data == {'columns': {name: {'mean': 0.0, 'stddev': 0.0, 'variance': 0.0}}} 133 | data = executor(f'{{ columns {{ {name} {{ mode {{ values }} }} }} }}') 134 | assert data == {'columns': {name: {'mode': {'values': [0]}}}} 135 | data = executor(f'{{ columns {{ {name} {{ mode(n: 2) {{ counts }} }} }} }}') 136 | assert data == {'columns': {name: {'mode': {'counts': [1]}}}} 137 | data = executor(f'{{ columns {{ {name} {{ quantile }} }} }}') 138 | assert data == {'columns': {name: {'quantile': [0.0]}}} 139 | data = executor(f'{{ columns {{ {name} {{ tdigest }} }} }}') 140 | assert data == {'columns': {name: {'tdigest': [0.0]}}} 141 | data = executor(f'{{ columns {{ {name} {{ product }} }} }}') 142 | assert data == {'columns': {name: {'product': 0.0}}} 143 | data = executor(f'{{ columns {{ {name} {{ indicesNonzero }} }} }}') 144 | assert data == {'columns': {name: {'indicesNonzero': []}}} 145 | 146 | data = executor("""{ scan(columns: {alias: "int32", elementWise: {min: {name: "int32"}}}) { 147 | columns { int32 { values } } } }""") 148 | assert data == {'scan': {'columns': {'int32': {'values': [0, None]}}}} 149 | data = executor('{ column(name: "float", cast: "int32") { type } }') 150 | assert data == {'column': {'type': 'int32'}} 151 | data = executor("""{ scan(columns: {alias: "int32", negate: {checked: true, name: "int32"}}) { 152 | columns { int32 { values } } } }""") 153 | assert data == {'scan': {'columns': {'int32': {'values': [0, None]}}}} 154 | data = executor( 155 | """{ scan(columns: {alias: "float", coalesce: [{name: "float"}, {name: "int32"}]}) { 156 | columns { float { values } } } }""" 157 | ) 158 | assert data == {'scan': {'columns': {'float': {'values': [0.0, None]}}}} 159 | data = executor("""{ scan(columns: {bitWise: {not: {name: "int32"}}, alias: "int32"}) { 160 | columns { int32 { values } } } }""") 161 | assert data == {'scan': {'columns': {'int32': {'values': [-1, None]}}}} 162 | data = executor( 163 | """{ scan(columns: {bitWise: {or: [{name: "int32"}, {name: "int64"}]}, alias: "int64"}) { 164 | columns { int64 { values } } } }""" 165 | ) 166 | assert data == 
{'scan': {'columns': {'int64': {'values': [0, None]}}}} 167 | 168 | 169 | def test_datetime(executor): 170 | for name in ('timestamp', 'date32'): 171 | data = executor( 172 | f"""{{ scan(columns: {{alias: "year", temporal: {{year: {{name: "{name}"}}}}}}) 173 | {{ column(name: "year") {{ ... on LongColumn {{ values }} }} }} }}""" 174 | ) 175 | assert data == {'scan': {'column': {'values': [1970, None]}}} 176 | data = executor( 177 | f"""{{ scan(columns: {{alias: "quarter", temporal: {{quarter: {{name: "{name}"}}}}}}) 178 | {{ column(name: "quarter") {{ ... on LongColumn {{ values }} }} }} }}""" 179 | ) 180 | assert data == {'scan': {'column': {'values': [1, None]}}} 181 | data = executor(f"""{{ scan(columns: {{alias: "{name}", 182 | temporal: {{yearsBetween: [{{name: "{name}"}}, {{name: "{name}"}}]}}}}) 183 | {{ column(name: "{name}") {{ ... on LongColumn {{ values }} }} }} }}""") 184 | assert data == {'scan': {'column': {'values': [0, None]}}} 185 | data = executor( 186 | """{ scan(columns: {alias: "timestamp", temporal: {strftime: {name: "timestamp"}}}) { 187 | column(name: "timestamp") { type } } }""" 188 | ) 189 | assert data == {'scan': {'column': {'type': 'string'}}} 190 | for name in ('timestamp', 'time32'): 191 | data = executor( 192 | f"""{{ scan(columns: {{alias: "hour", temporal: {{hour: {{name: "{name}"}}}}}}) 193 | {{ column(name: "hour") {{ ... on LongColumn {{ values }} }} }} }}""" 194 | ) 195 | assert data == {'scan': {'column': {'values': [0, None]}}} 196 | data = executor( 197 | f"""{{ scan(columns: {{alias: "subsecond", temporal: {{subsecond: {{name: "{name}"}}}}}}) 198 | {{ column(name: "subsecond") {{ ... on FloatColumn {{ values }} }} }} }}""" 199 | ) 200 | assert data == {'scan': {'column': {'values': [0.0, None]}}} 201 | data = executor(f"""{{ scan(columns: {{alias: "hours", 202 | temporal: {{hoursBetween: [{{name: "{name}"}}, {{name: "{name}"}}]}}}}) 203 | {{ column(name: "hours") {{ ... on LongColumn {{ values }} }} }} }}""") 204 | assert data == {'scan': {'column': {'values': [0, None]}}} 205 | with pytest.raises(ValueError): 206 | executor('{ columns { time64 { between(unit: "hours") { values } } } }') 207 | data = executor( 208 | """{ scan(columns: {alias: "timestamp", temporal: {assumeTimezone: {name: "timestamp"}, timezone: "UTC"}}) { 209 | columns { timestamp { values } } } }""" 210 | ) 211 | dates = data['scan']['columns']['timestamp']['values'] 212 | assert dates == ['1970-01-01T00:00:00+00:00', None] 213 | data = executor( 214 | """{ scan(columns: {alias: "time32", temporal: {round: {name: "time32"}, unit: "hour"}}) { 215 | columns { time32 { values } } } }""" 216 | ) 217 | assert data == {'scan': {'columns': {'time32': {'values': ['00:00:00', None]}}}} 218 | 219 | 220 | def test_duration(executor): 221 | data = executor( 222 | """{ scan(columns: {alias: "diff", checked: true, subtract: [{name: "timestamp"}, {name: "timestamp"}]}) 223 | { column(name: "diff") { ... on DurationColumn { unique { values } } } } }""" 224 | ) 225 | assert data == {'scan': {'column': {'unique': {'values': ['P0D', None]}}}} 226 | data = executor('{ runs(split: [{name: "timestamp", gt: 0.0}]) { length } }') 227 | assert data == {'runs': {'length': 1}} 228 | data = executor( 229 | """{ scan(columns: {alias: "diff", temporal: 230 | {monthDayNanoIntervalBetween: [{name: "timestamp"}, {name: "timestamp"}]}}) 231 | { column(name: "diff") { ... 
on DurationColumn { values } } } }""" 232 | ) 233 | assert data == {'scan': {'column': {'values': ['P0M0D', None]}}} 234 | 235 | 236 | def test_list(executor): 237 | data = executor('{ columns { list { length type } } }') 238 | assert data == {'columns': {'list': {'length': 2, 'type': 'list'}}} 239 | data = executor('{ columns { list { values { length } } } }') 240 | assert data == {'columns': {'list': {'values': [{'length': 3}, None]}}} 241 | data = executor('{ columns { list { dropNull { length } } } }') 242 | assert data == {'columns': {'list': {'dropNull': [{'length': 3}]}}} 243 | data = executor('{ row { list { ... on IntColumn { values } } } }') 244 | assert data == {'row': {'list': {'values': [0, 1, 2]}}} 245 | data = executor('{ row(index: -1) { list { ... on IntColumn { values } } } }') 246 | assert data == {'row': {'list': None}} 247 | data = executor("""{ aggregate(approximateMedian: {name: "list"}) { 248 | column(name: "list") { ... on FloatColumn { values } } } }""") 249 | assert data == {'aggregate': {'column': {'values': [1.0, None]}}} 250 | data = executor("""{ aggregate(tdigest: {name: "list", q: [0.25, 0.75]}) { 251 | columns { list { flatten { ... on FloatColumn { values } } } } } }""") 252 | column = data['aggregate']['columns']['list'] 253 | assert column == {'flatten': {'values': [0.0, 2.0]}} 254 | 255 | data = executor("""{ apply(list: {quantile: {name: "list", q: 0.5}}) { 256 | columns { list { flatten { ... on FloatColumn { values } } } } } }""") 257 | assert data == {'apply': {'columns': {'list': {'flatten': {'values': [1.0, None]}}}}} 258 | data = executor("""{ apply(list: {index: {name: "list", value: 1}}) { 259 | column(name: "list") { ... on LongColumn { values } } } }""") 260 | assert data == {'apply': {'column': {'values': [1, -1]}}} 261 | data = executor( 262 | """{ scan(columns: {list: {element: [{name: "list"}, {value: 1}]}, alias: "value"}) { 263 | column(name: "value") { ... on IntColumn { values } } } }""" 264 | ) 265 | assert data == {'scan': {'column': {'values': [1, None]}}} 266 | data = executor("""{ scan(columns: {list: {slice: {name: "list"}, stop: 1}, alias: "value"}) { 267 | column(name: "value") { ... on ListColumn { flatten { ... on IntColumn { values } } } } } }""") 268 | assert data == {'scan': {'column': {'flatten': {'values': [0]}}}} 269 | data = executor("""{ aggregate(distinct: {name: "list", mode: "only_null"}) 270 | { columns { list { flatten { length } } } } }""") 271 | assert data['aggregate']['columns']['list'] == {'flatten': {'length': 0}} 272 | data = executor("""{ apply(list: {filter: {ne: [{name: "list"}, {value: 1}]}}) { 273 | columns { list { values { ... on IntColumn { values } } } } } }""") 274 | column = data['apply']['columns']['list'] 275 | assert column == {'values': [{'values': [0, 2]}, None]} 276 | data = executor('{ apply(list: {mode: {name: "list"}}) { column(name: "list") { type } } }') 277 | assert data['apply']['column']['type'] == 'large_list>' 278 | data = executor( 279 | """{ aggregate(stddev: {name: "list"}, variance: {name: "list", alias: "var", ddof: 1}) { 280 | column(name: "list") { ... on FloatColumn { values } } 281 | var: column(name: "var") { ... 
on FloatColumn { values } } } }""" 282 | ) 283 | assert data['aggregate']['column']['values'] == [pytest.approx((2 / 3) ** 0.5), None] 284 | assert data['aggregate']['var']['values'] == [1, None] 285 | data = executor( 286 | """{ runs(by: "int32") { scan(columns: {binary: {join: [{name: "binary"}, {base64: ""}]}, alias: "binary"}) { 287 | column(name: "binary") { ... on Base64Column { values } } } } }""" 288 | ) 289 | assert data == {'runs': {'scan': {'column': {'values': [None]}}}} 290 | data = executor('{ columns { list { value { type } } } }') 291 | assert data == {'columns': {'list': {'value': {'type': 'int32'}}}} 292 | data = executor('{ tables { column(name: "list") { type } } }') 293 | assert data == {'tables': [{'column': {'type': 'int32'}}, None]} 294 | data = executor( 295 | '{ apply(list: {rank: {by: "list"}}) { columns { list { values { length } } } } }' 296 | ) 297 | assert data == {'apply': {'columns': {'list': {'values': [{'length': 1}, None]}}}} 298 | data = executor( 299 | '{ apply(list: {rank: {by: "list", max: 2}}) { columns { list { flatten { ... on IntColumn { values } } } } } }' 300 | ) 301 | assert data == {'apply': {'columns': {'list': {'flatten': {'values': [0, 1]}}}}} 302 | 303 | 304 | def test_struct(executor): 305 | data = executor('{ columns { struct { names column(name: "x") { length } } } }') 306 | assert data == {'columns': {'struct': {'names': ['x', 'y'], 'column': {'length': 2}}}} 307 | data = executor("""{ scan(columns: {alias: "leaf", name: ["struct", "x"]}) { 308 | column(name: "leaf") { ... on IntColumn { values } } } }""") 309 | assert data == {'scan': {'column': {'values': [0, None]}}} 310 | data = executor('{ column(name: ["struct", "x"]) { type } }') 311 | assert data == {'column': {'type': 'int32'}} 312 | data = executor('{ row { struct } columns { struct { value } } }') 313 | assert data['row']['struct'] == data['columns']['struct']['value'] == {'x': 0, 'y': None} 314 | with pytest.raises(ValueError, match="must be BOOL"): 315 | executor( 316 | '{ scan(filter: {caseWhen: [{name: "struct"}, {name: "int32"}, {name: "float"}]}) {type } }' 317 | ) 318 | 319 | 320 | def test_dictionary(executor): 321 | data = executor('{ column(name: "string") { length } }') 322 | assert data == {'column': {'length': 2}} 323 | data = executor("""{ group(by: ["string"], aggregate: {list: {name: "camelId"}}) { tables { 324 | columns { string { values } } column(name: "camelId") { length } } } }""") 325 | assert data['group']['tables'] == [ 326 | {'columns': {'string': {'values': ['']}}, 'column': {'length': 1}}, 327 | {'columns': {'string': {'values': [None]}}, 'column': {'length': 1}}, 328 | ] 329 | data = executor("""{ group(by: ["camelId"], aggregate: {countDistinct: {name: "string"}}) { 330 | column(name: "string") { ... 
on LongColumn { values } } } }""") 331 | assert data == {'group': {'column': {'values': [1, 0]}}} 332 | data = executor( 333 | """{ scan(columns: {alias: "string", coalesce: [{name: "string"}, {value: ""}]}) { 334 | columns { string { values } } } }""" 335 | ) 336 | assert data == {'scan': {'columns': {'string': {'values': ['', '']}}}} 337 | 338 | 339 | def test_selections(executor): 340 | data = executor('{ slice { length } slice { sort(by: "snake_id") { length } } }') 341 | assert data == {'slice': {'length': 2, 'sort': {'length': 2}}} 342 | data = executor('{ dropNull { length } }') 343 | assert data == {'dropNull': {'length': 2}} 344 | data = executor('{ dropNull { columns { float { values } } } }') 345 | assert data == {'dropNull': {'columns': {'float': {'values': [0.0]}}}} 346 | 347 | 348 | def test_conditions(executor): 349 | data = executor( 350 | """{ scan(columns: {alias: "bool", ifElse: [{name: "bool"}, {name: "int32"}, {name: "float"}]}) { 351 | column(name: "bool") { type } } }""" 352 | ) 353 | assert data == {'scan': {'column': {'type': 'float'}}} 354 | with pytest.raises(ValueError, match="no kernel"): 355 | executor("""{ scan(columns: {alias: "bool", 356 | ifElse: [{name: "struct"}, {name: "int32"}, {name: "float"}]}) { slice { type } } }""") 357 | 358 | 359 | def test_long(executor): 360 | with pytest.raises(ValueError, match="Long cannot represent value"): 361 | executor('{ filter(int64: {eq: 0.0}) { length } }') 362 | 363 | 364 | def test_base64(executor): 365 | data = executor("""{ scan(columns: {alias: "binary", binary: {length: {name: "binary"}}}) { 366 | column(name: "binary") { ...on IntColumn { values } } } }""") 367 | assert data == {'scan': {'column': {'values': [0, None]}}} 368 | data = executor( 369 | """{ scan(columns: {alias: "binary", binary: {repeat: [{name: "binary"}, {value: 2}]}}) { 370 | columns { binary { values } } } }""" 371 | ) 372 | assert data == {'scan': {'columns': {'binary': {'values': ['', None]}}}} 373 | data = executor( 374 | '{ apply(fillNullForward: {name: "binary"}) { columns { binary { values } } } }' 375 | ) 376 | assert data == {'apply': {'columns': {'binary': {'values': ['', '']}}}} 377 | data = executor( 378 | """{ scan(columns: {alias: "binary", coalesce: [{name: "binary"}, {base64: "Xw=="}]}) { 379 | columns { binary { values } } } }""" 380 | ) 381 | assert data == {'scan': {'columns': {'binary': {'values': ['', 'Xw==']}}}} 382 | data = executor("""{ scan(columns: {alias: "binary", binary: {joinElementWise: [ 383 | {name: "binary"}, {name: "binary"}, {base64: "Xw=="}], nullHandling: "replace"}}) { 384 | columns { binary { values } } } }""") 385 | assert data == {'scan': {'columns': {'binary': {'values': ['Xw==', 'Xw==']}}}} 386 | data = executor("""{ scan(columns: {alias: "binary", binary: {replaceSlice: {name: "binary"} 387 | start: 0, stop: 1, replacement: "Xw=="}}) { columns { binary { values } } } }""") 388 | assert data == {'scan': {'columns': {'binary': {'values': ['Xw==', None]}}}} 389 | data = executor('{ scan(filter: {eq: [{name: "binary"}, {base64: "Xw=="}]}) { length } }') 390 | assert data == {'scan': {'length': 0}} 391 | -------------------------------------------------------------------------------- /tests/test_service.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_slice(client): 5 | data = client.execute('{ length slice(length: 3) { columns { zipcode { values } } } }') 6 | assert data == {'length': 41700, 'slice': {'columns': {'zipcode': 
{'values': [501, 544, 601]}}}} 7 | data = client.execute('{ slice(offset: 1) { columns { zipcode { values } } } }') 8 | zipcodes = data['slice']['columns']['zipcode']['values'] 9 | assert zipcodes[0] == 544 10 | assert len(zipcodes) == 41699 11 | data = client.execute('{ slice(offset: -1, reverse: true) { columns { zipcode { values } } } }') 12 | assert data['slice']['columns']['zipcode']['values'] == [99950] 13 | data = client.execute('{ columns { zipcode { count } } }') 14 | assert data['columns']['zipcode']['count'] == 41700 15 | data = client.execute('{ columns { zipcode { count(mode: "only_null") } } }') 16 | assert data['columns']['zipcode']['count'] == 0 17 | 18 | 19 | def test_ints(client): 20 | data = client.execute('{ columns { zipcode { values sum mean } } }') 21 | zipcodes = data['columns']['zipcode'] 22 | assert len(zipcodes['values']) == 41700 23 | assert zipcodes['sum'] == 2066562337 24 | assert zipcodes['mean'] == pytest.approx(49557.849808) 25 | data = client.execute('{ columns { zipcode { values min max unique { values counts } } } }') 26 | zipcodes = data['columns']['zipcode'] 27 | assert len(zipcodes['values']) == 41700 28 | assert zipcodes['min'] == 501 29 | assert zipcodes['max'] == 99950 30 | assert len(zipcodes['unique']['values']) == 41700 31 | assert set(zipcodes['unique']['counts']) == {1} 32 | data = client.execute('{ columns { zipcode { size } } }') 33 | assert data['columns']['zipcode']['size'] > 0 34 | 35 | 36 | def test_floats(client): 37 | data = client.execute('{ columns { latitude { values sum mean } } }') 38 | latitudes = data['columns']['latitude'] 39 | assert len(latitudes['values']) == 41700 40 | assert latitudes['sum'] == pytest.approx(1606220.07592) 41 | assert latitudes['mean'] == pytest.approx(38.518467) 42 | data = client.execute('{ columns { longitude { min max } } }') 43 | longitudes = data['columns']['longitude'] 44 | assert longitudes['min'] == pytest.approx(-174.21333) 45 | assert longitudes['max'] == pytest.approx(-65.301389) 46 | data = client.execute('{ columns { latitude { quantile(q: [0.5]) } } }') 47 | (quantile,) = data['columns']['latitude']['quantile'] 48 | assert quantile == pytest.approx(39.12054) 49 | data = client.execute( 50 | """{ scan(columns: {alias: "latitude", elementWise: {min: [{name: "latitude"}, {name: "longitude"}]}}) { 51 | columns { latitude { min } } } }""" 52 | ) 53 | assert data == {'scan': {'columns': {'latitude': {'min': pytest.approx(-174.213333)}}}} 54 | data = client.execute( 55 | """{scan(columns: {alias: "l", setLookup: {digitize: [{name: "latitude"}, {value: [40]}]}}) { 56 | column(name: "l") { ... 
on LongColumn { unique { values counts } } } } }""" 57 | ) 58 | assert data == {"scan": {"column": {"unique": {"values": [1, 0], "counts": [17955, 23745]}}}} 59 | data = client.execute( 60 | '{ scan(columns: {alias: "latitude", log: {logb: [{name: "latitude"}, {value: 3}]}}) { row { latitude } } }' 61 | ) 62 | assert data == {'scan': {'row': {'latitude': pytest.approx(3.376188)}}} 63 | data = client.execute( 64 | '{ scan(columns: {alias: "latitude", rounding: {round: {name: "latitude"}}}) {row { latitude } } }' 65 | ) 66 | assert data == {'scan': {'row': {'latitude': 41.0}}} 67 | data = client.execute( 68 | '{ scan(columns: {alias: "latitude", rounding: {round: {name: "latitude"}, multiple: 2.0}}) {row { latitude } } }' 69 | ) 70 | assert data == {'scan': {'row': {'latitude': 40.0}}} 71 | data = client.execute( 72 | '{ scan(columns: {alias: "latitude", trig: {sin: {name: "latitude"}}}) {row { latitude } } }' 73 | ) 74 | assert data == {'scan': {'row': {'latitude': pytest.approx(0.02273553)}}} 75 | data = client.execute('{ scan(filter: {isFinite: {name: "longitude"}}) { length } }') 76 | assert data == {'scan': {'length': 41700}} 77 | data = client.execute('{ column(name: "latitude", cast: "int32", safe: false) { type } }') 78 | assert data == {'column': {'type': 'int32'}} 79 | 80 | 81 | def test_strings(client): 82 | data = client.execute("""{ columns { 83 | state { values unique { values counts } countDistinct } 84 | county { unique { length values } } 85 | city { min max } 86 | } }""") 87 | states = data['columns']['state'] 88 | assert len(states['values']) == 41700 89 | assert len(states['unique']['values']) == states['countDistinct'] == 52 90 | assert sum(states['unique']['counts']) == 41700 91 | counties = data['columns']['county'] 92 | assert len(counties['unique']['values']) == counties['unique']['length'] == 1920 93 | assert data['columns']['city'] == {'min': 'Aaronsburg', 'max': 'Zwolle'} 94 | data = client.execute("""{ filter(state: {eq: "CA"}) { 95 | scan(filter: {gt: [{utf8: {length: {name: "city"}}}, {value: 23}]}) { length } } }""") 96 | assert data == {'filter': {'scan': {'length': 1}}} 97 | data = client.execute( 98 | '{ scan(columns: {utf8: {swapcase: {name: "city"}}, alias: "city"}) { row { city } } }' 99 | ) 100 | assert data == {'scan': {'row': {'city': 'hOLTSVILLE'}}} 101 | data = client.execute( 102 | '{ scan(columns: {utf8: {capitalize: {name: "state"}}, alias: "state"}) { row { state } } }' 103 | ) 104 | assert data == {'scan': {'row': {'state': 'Ny'}}} 105 | data = client.execute('{ scan(filter: {utf8: {isLower: {name: "city"}}}) { length } }') 106 | assert data == {'scan': {'length': 0}} 107 | data = client.execute('{ scan(filter: {utf8: {isTitle: {name: "city"}}}) { length } }') 108 | assert data == {'scan': {'length': 41700}} 109 | data = client.execute( 110 | """{ scan(columns: {alias: "city", substring: {match: {name: "city"}, pattern: "Mountain"}}) 111 | { scan(filter: {name: "city"}) { length } } }""" 112 | ) 113 | assert data == {'scan': {'scan': {'length': 88}}} 114 | data = client.execute( 115 | """{ scan(filter: {substring: {match: {name: "city"}, pattern: "mountain", ignoreCase: true}}) 116 | { length } }""" 117 | ) 118 | assert data == {'scan': {'length': 88}} 119 | data = client.execute( 120 | """{ scan(filter: {substring: {match: {name: "city"}, pattern: "^Mountain", regex: true}}) 121 | { length } }""" 122 | ) 123 | assert data == {'scan': {'length': 42}} 124 | data = client.execute( 125 | """{ scan(columns: {alias: "idx", setLookup: {indexIn: 
[{name: "state"}, {value: ["CA", "OR"]}]}}) 126 | { column(name: "idx") { ... on IntColumn { unique { values } } } } }""" 127 | ) 128 | assert data == {'scan': {'column': {'unique': {'values': [None, 0, 1]}}}} 129 | 130 | 131 | def test_string_methods(client): 132 | data = client.execute( 133 | """{ scan(columns: {alias: "split", substring: {split: {name: "city"}, pattern: "-", maxSplits: 1}}) { 134 | column(name: "split") { type } } }""" 135 | ) 136 | assert data == {'scan': {'column': {'type': 'list'}}} 137 | data = client.execute("""{ scan(columns: {alias: "split", substring: {split: {name: "city"}}}) { 138 | column(name: "split") { type } } }""") 139 | assert data == {'scan': {'column': {'type': 'list'}}} 140 | data = client.execute( 141 | """{ scan(columns: {alias: "state", utf8: {trim: {name: "state"}, characters: "C"}}) { 142 | columns { state { values } } } }""" 143 | ) 144 | states = data['scan']['columns']['state']['values'] 145 | assert 'CA' not in states and 'A' in states 146 | data = client.execute( 147 | '{ scan(columns: {alias: "state", utf8: {ltrim: {name: "state"}}}) { length } }' 148 | ) 149 | assert data == {'scan': {'length': 41700}} 150 | data = client.execute( 151 | """{ scan(columns: {alias: "state", utf8: {center: {name: "state"}, width: 4, padding: "_"}}) 152 | { row { state } } }""" 153 | ) 154 | assert data == {'scan': {'row': {'state': '_NY_'}}} 155 | data = client.execute( 156 | """{ scan(columns: {alias: "state", utf8: {replaceSlice: {name: "state"}, start: 0, stop: 2, replacement: ""}}) 157 | { columns { state { unique { values } } } } }""" 158 | ) 159 | assert data == {'scan': {'columns': {'state': {'unique': {'values': ['']}}}}} 160 | data = client.execute( 161 | '{ scan(columns: {alias: "state", utf8: {sliceCodeunits: {name: "state"}, start: 0, stop: 1}}) { row { state } } }' 162 | ) 163 | assert data == {'scan': {'row': {'state': 'N'}}} 164 | data = client.execute( 165 | """{ scan(columns: {alias: "state", substring: {replace: {name: "state"}, pattern: "C", replacement: "A"}}) 166 | { columns { state { values } } } }""" 167 | ) 168 | assert 'AA' in data['scan']['columns']['state']['values'] 169 | 170 | 171 | def test_search(client): 172 | data = client.execute('{ schema { index } filter { length } }') 173 | assert data == {'schema': {'index': ['zipcode']}, 'filter': {'length': 41700}} 174 | data = client.execute('{ filter(zipcode: {eq: 501}) { columns { zipcode { values } } } }') 175 | assert data == {'filter': {'columns': {'zipcode': {'values': [501]}}}} 176 | data = client.execute('{ filter(zipcode: {ne: 501}) { length } }') 177 | assert data['filter']['length'] == 41699 178 | 179 | data = client.execute('{ filter(zipcode: {ge: 99929}) { columns { zipcode { values } } } }') 180 | assert data == {'filter': {'columns': {'zipcode': {'values': [99929, 99950]}}}} 181 | data = client.execute('{ filter(zipcode: {lt: 601}) { columns { zipcode { values } } } }') 182 | assert data == {'filter': {'columns': {'zipcode': {'values': [501, 544]}}}} 183 | data = client.execute( 184 | '{ filter(zipcode: {gt: 501, le: 601}) { columns { zipcode { values } } } }' 185 | ) 186 | assert data == {'filter': {'columns': {'zipcode': {'values': [544, 601]}}}} 187 | 188 | data = client.execute('{ filter(zipcode: {eq: []}) { length } }') 189 | assert data == {'filter': {'length': 0}} 190 | data = client.execute('{ filter(zipcode: {eq: [0]}) { length } }') 191 | assert data == {'filter': {'length': 0}} 192 | data = client.execute( 193 | '{ filter(zipcode: {eq: [501, 601]}) { columns 
{ zipcode { values } } } }' 194 | ) 195 | assert data == {'filter': {'columns': {'zipcode': {'values': [501, 601]}}}} 196 | data = client.execute('{ slice(reverse: true) { filter(zipcode: {ge: 90000}) { length } } }') 197 | assert data == {'slice': {'filter': {'length': 4275}}} 198 | 199 | 200 | def test_filter(client): 201 | data = client.execute('{ filter { length } }') 202 | assert data['filter']['length'] == 41700 203 | data = client.execute('{ filter(city: {eq: "Mountain View"}) { length } }') 204 | assert data['filter']['length'] == 11 205 | data = client.execute('{ filter(state: {ne: "CA"}) { length } }') 206 | assert data['filter']['length'] == 39053 207 | data = client.execute('{ filter(city: {eq: "Mountain View"}, state: {le: "CA"}) { length } }') 208 | assert data['filter']['length'] == 7 209 | data = client.execute('{ filter(state: {eq: null}) { columns { state { values } } } }') 210 | assert data['filter']['columns']['state']['values'] == [] 211 | data = client.execute( 212 | '{ scan(filter: {le: [{abs: {name: "longitude"}}, {value: 66}]}) { length } }' 213 | ) 214 | assert data['scan']['length'] == 30 215 | with pytest.raises(ValueError, match="optional, not nullable"): 216 | client.execute('{ filter(city: {le: null}) { length } }') 217 | 218 | 219 | def test_scan(client): 220 | data = client.execute('{ scan(filter: {eq: [{name: "county"}, {name: "city"}]}) { length } }') 221 | assert data['scan']['length'] == 2805 222 | data = client.execute("""{ scan(filter: {or: [{eq: [{name: "state"}, {name: "county"}]}, 223 | {eq: [{name: "county"}, {name: "city"}]}]}) { length } }""") 224 | assert data['scan']['length'] == 2805 225 | data = client.execute( 226 | """{ scan(columns: {alias: "zipcode", add: [{name: "zipcode"}, {name: "zipcode"}]}) 227 | { columns { zipcode { min } } } }""" 228 | ) 229 | assert data['scan']['columns']['zipcode']['min'] == 1002 230 | data = client.execute( 231 | """{ scan(columns: {alias: "zipcode", subtract: [{name: "zipcode"}, {name: "zipcode"}]}) 232 | { columns { zipcode { unique { values } } } } }""" 233 | ) 234 | assert data['scan']['columns']['zipcode']['unique']['values'] == [0] 235 | data = client.execute( 236 | """{ scan(columns: {alias: "product", multiply: [{name: "latitude"}, {name: "longitude"}]}) 237 | { scan(filter: {gt: [{name: "product"}, {value: 0}]}) { length } } }""" 238 | ) 239 | assert data['scan']['scan']['length'] == 0 240 | data = client.execute( 241 | '{ scan(columns: {name: "zipcode", cast: "float"}) { column(name: "zipcode") { type } } }' 242 | ) 243 | assert data['scan']['column']['type'] == 'float' 244 | data = client.execute( 245 | '{ scan(filter: {inv: {eq: [{name: "state"}, {value: "CA"}]}}) { length } }' 246 | ) 247 | assert data == {'scan': {'length': 39053}} 248 | data = client.execute( 249 | '{ scan(columns: {name: "latitude", cast: "int32", safe: false}) { column(name: "latitude") { type } } }' 250 | ) 251 | assert data == {'scan': {'column': {'type': 'int32'}}} 252 | data = client.execute( 253 | """{ scan(columns: {alias: "longitude", elementWise: {max: [{name: "longitude"}, {name: "latitude"}]}}) 254 | { columns { longitude { min } } } }""" 255 | ) 256 | assert data['scan']['columns']['longitude']['min'] == pytest.approx(17.963333) 257 | data = client.execute( 258 | """{ scan(columns: {alias: "latitude", elementWise: {min: [{name: "longitude"}, {name: "latitude"}]}}) 259 | { columns { latitude { max } } } }""" 260 | ) 261 | assert data['scan']['columns']['latitude']['max'] == pytest.approx(-65.301389) 262 | data = 
client.execute( 263 | """{ scan(columns: {alias: "state", elementWise: {min: [{name: "state"}, {name: "county"}], skipNulls: false}}) 264 | { columns { state { values } } } }""" 265 | ) 266 | assert data['scan']['columns']['state']['values'][0] == 'NY' 267 | 268 | 269 | def test_apply(client): 270 | data = client.execute( 271 | """{ scan(columns: {alias: "city", substring: {find: {name: "city"}, pattern: "mountain"}}) 272 | { column(name: "city") { ... on IntColumn { unique { values } } } } }""" 273 | ) 274 | assert data['scan']['column']['unique']['values'] == [-1] 275 | data = client.execute( 276 | """{ scan(columns: {alias: "city", substring: {count: {name: "city"}, pattern: "mountain", ignoreCase: true}}) 277 | { column(name: "city") { ... on IntColumn { unique { values } } } } }""" 278 | ) 279 | assert data['scan']['column']['unique']['values'] == [0, 1] 280 | data = client.execute("""{ scan(columns: {alias: "state", binary: {joinElementWise: [ 281 | {name: "state"}, {name: "county"}, {value: " "}]}}) { columns { state { values } } } }""") 282 | assert data['scan']['columns']['state']['values'][0] == 'NY Suffolk' 283 | data = client.execute("""{ apply(cumulativeSum: {name: "zipcode", skipNulls: false}) 284 | { columns { zipcode { value(index: -1) } } } }""") 285 | assert data == {'apply': {'columns': {'zipcode': {'value': 2066562337}}}} 286 | data = client.execute("""{ apply(cumulativeSum: {name: "zipcode", checked: true}) 287 | { columns { zipcode { value(index: -1) } } } }""") 288 | assert data == {'apply': {'columns': {'zipcode': {'value': 2066562337}}}} 289 | data = client.execute( 290 | '{ apply(pairwiseDiff: {name: "zipcode"}) { columns { zipcode { value } } } }' 291 | ) 292 | assert data == {'apply': {'columns': {'zipcode': {'value': None}}}} 293 | data = client.execute('{ apply(rank: {name: "zipcode"}) { row { zipcode } } }') 294 | assert data == {'apply': {'row': {'zipcode': 1}}} 295 | data = client.execute( 296 | """{ apply(rank: {name: "zipcode", sortKeys: "descending", nullPlacement: "at_start", tiebreaker: "dense"}) 297 | { row { zipcode } } }""" 298 | ) 299 | assert data == {'apply': {'row': {'zipcode': 41700}}} 300 | 301 | 302 | def test_sort(client): 303 | with pytest.raises(ValueError, match="is required"): 304 | client.execute('{ sort { columns { state { values } } } }') 305 | data = client.execute('{ sort(by: ["state"]) { columns { state { values } } } }') 306 | assert data['sort']['columns']['state']['values'][0] == 'AK' 307 | data = client.execute('{ sort(by: "-state") { columns { state { values } } } }') 308 | assert data['sort']['columns']['state']['values'][0] == 'WY' 309 | data = client.execute('{ sort(by: ["state"], length: 1) { columns { state { values } } } }') 310 | assert data['sort']['columns']['state']['values'] == ['AK'] 311 | data = client.execute('{ sort(by: "-state", length: 1) { columns { state { values } } } }') 312 | assert data['sort']['columns']['state']['values'] == ['WY'] 313 | data = client.execute('{ sort(by: ["state", "county"]) { columns { county { values } } } }') 314 | assert data['sort']['columns']['county']['values'][0] == 'Aleutians East' 315 | data = client.execute( 316 | """{ sort(by: ["-state", "-county"], length: 1) { columns { county { values } } } }""" 317 | ) 318 | assert data['sort']['columns']['county']['values'] == ['Weston'] 319 | data = client.execute('{ sort(by: ["state"], length: 2) { columns { state { values } } } }') 320 | assert data['sort']['columns']['state']['values'] == ['AK', 'AK'] 321 | data = client.execute( 
322 | """{ group(by: ["state"], aggregate: {list: {name: "county"}}) { apply(list: {sort: {by: ["county"]}}) 323 | { aggregate(first: [{name: "county"}]) { row { state county } } } } }""" 324 | ) 325 | assert data['group']['apply']['aggregate']['row'] == {'state': 'NY', 'county': 'Albany'} 326 | data = client.execute( 327 | """{ group(by: ["state"], aggregate: {list: {name: "county"}}) { apply(list: {sort: {by: ["-county"], length: 1}}) 328 | { aggregate(first: [{name: "county"}]) { row { state county } } } } }""" 329 | ) 330 | assert data['group']['apply']['aggregate']['row'] == {'state': 'NY', 'county': 'Yates'} 331 | data = client.execute( 332 | """{ group(by: ["state"], aggregate: {list: {name: "county"}}) { apply(list: {sort: {by: "county", length: 2}}) 333 | { row { state } column(name: "county") { ... on ListColumn { value { length } } } } } }""" 334 | ) 335 | assert data['group']['apply'] == {'row': {'state': 'NY'}, 'column': {'value': {'length': 2}}} 336 | 337 | 338 | def test_group(client): 339 | with pytest.raises(ValueError, match="list"): 340 | client.execute('{ group(by: ["state"]) { tables { length } } }') 341 | with pytest.raises(ValueError, match="cannot represent"): 342 | client.execute('{ group(by: "state", aggregate: {list: {name: "city"}}) { row { city } } }') 343 | data = client.execute( 344 | """{ group(by: ["state"], ordered: true, aggregate: {list: {name: "county"}}) { length tables { length 345 | columns { state { values } county { min max } } } 346 | scan(columns: {list: {valueLength: {name: "county"}}, alias: "c"}) { 347 | column(name: "c") { ... on IntColumn { values } } } } }""" 348 | ) 349 | assert len(data['group']['tables']) == data['group']['length'] == 52 350 | table = data['group']['tables'][0] 351 | assert table['length'] == data['group']['scan']['column']['values'][0] == 2205 352 | assert set(table['columns']['state']['values']) == {'NY'} 353 | assert table['columns']['county'] == {'min': 'Albany', 'max': 'Yates'} 354 | data = client.execute( 355 | """{ group(by: ["state", "county"], counts: "counts", aggregate: {list: {name: "city"}}) { 356 | scan(filter: {gt: [{name: "counts"}, {value: 200}]}) { 357 | aggregate(min: [{name: "city", alias: "min"}], max: [{name: "city", alias: "max"}]) { 358 | min: column(name: "min") { ... on StringColumn { values } } 359 | max: column(name: "max") { ... 
on StringColumn { values } } } } } }""" 360 | ) 361 | agg = data['group']['scan']['aggregate'] 362 | assert agg['min']['values'] == ['Naval Anacost Annex', 'Alsip', 'Alief', 'Acton'] 363 | assert agg['max']['values'] == ['Washington Navy Yard', 'Worth', 'Webster', 'Woodland Hills'] 364 | data = client.execute( 365 | """{ group(by: ["state", "county"], counts: "c", aggregate: {list: [{name: "zipcode"}, {name: "latitude"}, {name: "longitude"}]}) { 366 | sort(by: ["-c"], length: 4) { aggregate(sum: [{name: "latitude"}], mean: [{name: "longitude"}]) { 367 | columns { latitude { values } longitude { values } } 368 | column(name: "zipcode") { type } } } } }""" 369 | ) 370 | agg = data['group']['sort']['aggregate'] 371 | assert agg['column']['type'] == 'list' 372 | assert all(latitude > 1000 for latitude in agg['columns']['latitude']['values']) 373 | assert all(77 > longitude > -119 for longitude in agg['columns']['longitude']['values']) 374 | data = client.execute("""{ scan(columns: {name: "zipcode", cast: "bool"}) 375 | { group(by: ["state"], aggregate: {list: {name: "zipcode"}}) { slice(length: 3) { 376 | scan(columns: [{alias: "a", list: {any: {name: "zipcode"}}}, {alias: "b", list: {all: {name: "zipcode"}}}]) { 377 | a: column(name: "a") { ... on BooleanColumn { values } } 378 | b: column(name: "b") { ... on BooleanColumn { values } } 379 | column(name: "zipcode") { type } } } } } }""") 380 | assert data['scan']['group']['slice']['scan'] == { 381 | 'a': {'values': [True, True, True]}, 382 | 'b': {'values': [True, True, True]}, 383 | 'column': {'type': 'list'}, 384 | } 385 | data = client.execute("""{ sc: group(by: ["state", "county"]) { length } 386 | cs: group(by: ["county", "state"]) { length } }""") 387 | assert data['sc']['length'] == data['cs']['length'] == 3216 388 | 389 | 390 | def test_flatten(client): 391 | data = client.execute( 392 | '{ group(by: "state", aggregate: {list: {name: "city"}}) { flatten { columns { city { type } } } } }' 393 | ) 394 | assert data == {'group': {'flatten': {'columns': {'city': {'type': 'string'}}}}} 395 | data = client.execute( 396 | """{ group(by: "state", aggregate: {list: {name: "city"}}) { flatten(indices: "idx") { columns { city { type } } 397 | column(name: "idx") { ... on LongColumn { unique { values counts } } } } } }""" 398 | ) 399 | idx = data['group']['flatten']['column']['unique'] 400 | assert idx['values'] == list(range(52)) 401 | assert sum(idx['counts']) == 41700 402 | assert idx['counts'][0] == 2205 403 | 404 | 405 | def test_aggregate(client): 406 | data = client.execute( 407 | """{ group(by: ["state"] counts: "c", aggregate: {first: [{name: "county"}] 408 | countDistinct: [{name: "city", alias: "cd"}]}) { slice(length: 3) { 409 | c: column(name: "c") { ... on LongColumn { values } } 410 | cd: column(name: "cd") { ... on LongColumn { values } } 411 | columns { state { values } county { values } } } } }""" 412 | ) 413 | assert data['group']['slice'] == { 414 | 'c': {'values': [2205, 176, 703]}, 415 | 'cd': {'values': [1612, 99, 511]}, 416 | 'columns': { 417 | 'state': {'values': ['NY', 'PR', 'MA']}, 418 | 'county': {'values': ['Suffolk', 'Adjuntas', 'Hampden']}, 419 | }, 420 | } 421 | data = client.execute( 422 | """{ group(by: ["state", "county"], aggregate: {list: {name: "city"}, min: {name: "city", alias: "first"}}) { 423 | aggregate(max: {name: "city", alias: "last"}) { slice(length: 3) { 424 | first: column(name: "first") { ... on StringColumn { values } } 425 | last: column(name: "last") { ... 
on StringColumn { values } } } } } }""" 426 | ) 427 | assert data['group']['aggregate']['slice'] == { 428 | 'first': {'values': ['Amagansett', 'Adjuntas', 'Aguada']}, 429 | 'last': {'values': ['Yaphank', 'Adjuntas', 'Aguada']}, 430 | } 431 | 432 | 433 | def test_runs(client): 434 | data = client.execute("""{ runs(by: ["state"]) { aggregate { length columns { state { values } } 435 | column(name: "county") { type } } } }""") 436 | agg = data['runs']['aggregate'] 437 | assert agg['length'] == 66 438 | assert agg['columns']['state']['values'][:3] == ['NY', 'PR', 'MA'] 439 | assert agg['column']['type'] == 'list' 440 | data = client.execute("""{ sort(by: ["state", "longitude"]) { 441 | runs(by: ["state"], split: [{name: "longitude", gt: 1.0}]) { 442 | length columns { state { values } } 443 | column(name: "longitude") { type } } } }""") 444 | groups = data['sort']['runs'] 445 | assert groups['length'] == 62 446 | assert groups['columns']['state']['values'][:7] == ['AK'] * 7 447 | assert groups['column']['type'] == 'list' 448 | data = client.execute("""{ runs(split: [{name: "state", lt: null}]) { 449 | length column(name: "state") { 450 | ... on ListColumn {values { ... on StringColumn { values } } } } } }""") 451 | assert data['runs']['length'] == 34 452 | assert data['runs']['column']['values'][0]['values'] == (['NY'] * 2) + (['PR'] * 176) 453 | data = client.execute("""{ runs(by: ["state"]) { sort(by: ["state"]) { 454 | columns { state { values } } } } }""") 455 | assert data['runs']['sort']['columns']['state']['values'][-2:] == ['WY', 'WY'] 456 | data = client.execute("""{ runs(by: ["state"]) { slice(offset: 2, length: 2) { 457 | aggregate(count: {name: "zipcode", alias: "c"}) { 458 | column(name: "c") { ... on LongColumn { values } } columns { state { values } } } } } }""") 459 | agg = data['runs']['slice']['aggregate'] 460 | assert agg['column']['values'] == [701, 91] 461 | assert agg['columns']['state']['values'] == ['MA', 'RI'] 462 | data = client.execute("""{ runs(by: ["state"]) { 463 | apply(list: {filter: {gt: [{name: "zipcode"}, {value: 90000}]}}) { 464 | column(name: "zipcode") { type } } } }""") 465 | assert data['runs']['apply']['column']['type'] == 'large_list' 466 | data = client.execute("""{ runs(by: ["state"], counts: "c") { filter(state: {eq: "NY"}) { 467 | column(name: "c") { ... on LongColumn { values } } columns { state { values } } } } }""") 468 | agg = data['runs']['filter'] 469 | assert agg['column']['values'] == [2, 1, 2202] 470 | assert agg['columns']['state']['values'] == ['NY'] * 3 471 | data = client.execute('{ runs(split: {name: "zipcode", lt: 0}) { length } }') 472 | assert data == {'runs': {'length': 1}} 473 | data = client.execute('{ runs(split: {name: "zipcode", lt: null}) { length } }') 474 | assert data == {'runs': {'length': 1}} 475 | 476 | 477 | def test_rows(client): 478 | with pytest.raises(ValueError, match="out of bounds"): 479 | client.execute('{ row(index: 100000) { zipcode } }') 480 | data = client.execute('{ row { state } }') 481 | assert data == {'row': {'state': 'NY'}} 482 | data = client.execute('{ row(index: -1) { state } }') 483 | assert data == {'row': {'state': 'AK'}} 484 | --------------------------------------------------------------------------------