├── .github
│   └── workflows
│       ├── build_and_validate.yml
│       └── release.yml
├── .gitignore
├── AUTH.md
├── LICENSE
├── README.md
├── RELEASE.md
├── bigquery_frame
│   ├── __init__.py
│   ├── auth.py
│   ├── bigquery_builder.py
│   ├── cli
│   │   ├── __init__.py
│   │   └── diff.py
│   ├── column.py
│   ├── conf.py
│   ├── data_diff
│   │   ├── __init__.py
│   │   ├── compare_dataframes_impl.py
│   │   ├── diff_format_options.py
│   │   ├── diff_per_col.py
│   │   ├── diff_result.py
│   │   ├── diff_result_analyzer.py
│   │   ├── diff_result_summary.py
│   │   ├── diff_stats.py
│   │   ├── export.py
│   │   ├── package.py
│   │   └── schema_diff.py
│   ├── data_type_utils.py
│   ├── dataframe.py
│   ├── dataframe_writer.py
│   ├── exceptions.py
│   ├── field_utils.py
│   ├── fp
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── higher_order.py
│   │   ├── package.py
│   │   └── printable_function.py
│   ├── functions.py
│   ├── graph.py
│   ├── graph_impl
│   │   ├── __init__.py
│   │   └── connected_components.py
│   ├── grouped_data.py
│   ├── has_bigquery_client.py
│   ├── nested.py
│   ├── nested_impl
│   │   ├── __init__.py
│   │   ├── fields.py
│   │   ├── package.py
│   │   ├── print_schema.py
│   │   ├── schema_string.py
│   │   ├── select_impl.py
│   │   ├── unnest_all_fields.py
│   │   ├── unnest_field.py
│   │   └── with_fields.py
│   ├── printing.py
│   ├── py.typed
│   ├── special_characters.py
│   ├── temp_names.py
│   ├── transformations.py
│   ├── transformations_impl
│   │   ├── __init__.py
│   │   ├── analyze.py
│   │   ├── analyze_aggs.py
│   │   ├── flatten.py
│   │   ├── harmonize_dataframes.py
│   │   ├── normalize_arrays.py
│   │   ├── pivot_unpivot.py
│   │   ├── sort_all_arrays.py
│   │   ├── sort_columns.py
│   │   ├── transform_all_fields.py
│   │   └── union_dataframes.py
│   ├── units.py
│   └── utils.py
├── dev
│   └── bin
│       ├── run_linters.sh
│       ├── run_security_checks.sh
│       └── run_unit_tests.sh
├── examples
│   ├── data_diff
│   │   └── country_code_iso.py
│   ├── demo.py
│   └── pivot.py
├── pyproject.toml
├── ruff.toml
├── sonar-project.properties
├── test_working_dir
│   └── .gitkeep
└── tests
    ├── __init__.py
    ├── cli
    │   ├── __init__.py
    │   └── test_diff.py
    ├── conftest.py
    ├── data_diff
    │   ├── __init__.py
    │   └── test_compare_dataframes_impl.py
    ├── graph_impl
    │   ├── __init__.py
    │   └── test_connected_components.py
    ├── nested_impl
    │   ├── __init__.py
    │   ├── test_package.py
    │   ├── test_unnest_all_fields.py
    │   └── test_with_fields.py
    ├── test_bigquery_builder.py
    ├── test_column.py
    ├── test_dataframe.py
    ├── test_dataframe_writer.py
    ├── test_functions.py
    ├── test_grouped_data.py
    ├── test_has_bigquery_client.py
    ├── test_utils.py
    ├── transformations_impl
    │   ├── __init__.py
    │   ├── test_analyze.py
    │   ├── test_pivot_unpivot.py
    │   ├── test_transform_all_fields.py
    │   └── test_union_dataframes.py
    └── utils.py

--------------------------------------------------------------------------------
/.github/workflows/build_and_validate.yml:
--------------------------------------------------------------------------------
 1 | name: Build and Validate
 2 | on:
 3 |   push:
 4 | 
 5 | concurrency:
 6 |   group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}'
 7 |   cancel-in-progress: true
 8 | 
 9 | env:
10 |   PROJECT_NAME: bigquery_frame
11 |   POETRY_VERSION: "1.7.1"
12 | 
13 | jobs:
14 |   Build-and-Validate:
15 |     strategy:
16 |       fail-fast: false
17 |       matrix:
18 |         python-version: ["3.9", "3.10", "3.11", "3.12"]
19 |         os: [ubuntu-latest, macos-latest, windows-latest]
20 |     runs-on: ${{ matrix.os }}
21 |     steps:
22 |       - name: Checkout code
23 |         uses: actions/checkout@v3
24 | 
25 |       - name: Set up Python
26 |         uses: actions/setup-python@v4
27 |         with:
28 |           python-version: ${{ matrix.python-version }}
29 | 
30 |       - name: Install poetry
31 |         uses: abatilo/actions-poetry@v2.0.0
32 |         with:
33 |           poetry-version: ${{ env.POETRY_VERSION }}
34 | 
35 |       - name: Poetry lock
36 |         run: poetry lock
37 | 
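      # NOTE: poetry.lock is not committed (see .gitignore), so it is regenerated above
      # before restoring the poetry cache, whose key is derived from the lock file.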
38 |       - name: Set up poetry cache
39 |         uses: actions/setup-python@v4
40 |         with:
41 |           python-version: ${{ matrix.python-version }}
42 |           cache: 'poetry'
43 | 
44 |       - name: Install project
45 |         run: poetry install
46 | 
47 |       - name: Linter ruff (check format)
48 |         run: poetry run ruff format --check .
49 | 
50 |       - name: Linter ruff (replaces black, isort and flake8)
51 |         run: poetry run ruff check .
52 |         continue-on-error: true
53 | 
54 |       - name: Linter mypy
55 |         run: poetry run mypy ${{ env.PROJECT_NAME }}
56 |         # We run mypy but ignore the results as there are too many things to fix for now.
57 |         continue-on-error: true
58 | 
59 |       - name: Security safety
60 |         run: poetry run safety check
61 | 
62 |       - name: Run Unit Tests
63 |         env:
64 |           GCP_CREDENTIALS: ${{ secrets.GCP_CREDENTIALS }}
65 |         run: poetry run pytest --cov --cov-report=xml -n 6
66 | 
67 |       - name: SonarCloud Scan
68 |         uses: SonarSource/sonarcloud-github-action@master
69 |         env:
70 |           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}  # Needed to get PR information, if any
71 |           SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }}
72 |         if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.10'
73 | 
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
 1 | name: Release
 2 | 
 3 | on:
 4 |   push:
 5 |     tags:
 6 |       - 'v*.*.*'
 7 | 
 8 | env:
 9 |   POETRY_VERSION: "1.7.1"
10 | 
11 | 
12 | jobs:
13 |   release:
14 |     name: Release
15 |     runs-on: ubuntu-latest
16 |     steps:
17 |       - name: Checkout code
18 |         uses: actions/checkout@v3
19 | 
20 |       - name: Set up Python 3.10
21 |         uses: actions/setup-python@v4
22 |         with:
23 |           python-version: "3.10"
24 | 
25 |       - name: Install poetry
26 |         uses: abatilo/actions-poetry@v2.0.0
27 |         with:
28 |           poetry-version: ${{ env.POETRY_VERSION }}
29 | 
30 |       - name: Update PATH
31 |         run: echo "$HOME/.local/bin" >> $GITHUB_PATH
32 | 
33 |       - name: Build project for distribution
34 |         run: poetry build
35 | 
36 |       - name: Check Version
37 |         id: check-version
38 |         run: |
39 |           [[ "$(poetry version --short)" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]] \
40 |             || echo "prerelease=true" >> "$GITHUB_OUTPUT"
41 | 
42 |       - name: Create Release
43 |         uses: ncipollo/release-action@v1
44 |         with:
45 |           artifacts: "dist/*"
46 |           token: ${{ secrets.GITHUB_TOKEN }}
47 |           draft: false
48 |           prerelease: ${{ steps.check-version.outputs.prerelease == 'true' }}
49 | 
50 |       - name: Publish to PyPI
51 |         env:
52 |           POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }}
53 |         run: poetry publish
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Python stuff
 2 | __pycache__/
 3 | venv/
 4 | 
 5 | # PyCharm stuff
 6 | .idea/
 7 | 
 8 | # GCP credentials
 9 | gcp-credentials.json
10 | 
11 | # Unit Testing
12 | .coverage
13 | coverage.xml
14 | htmlcov/
15 | 
16 | # Python Poetry
17 | poetry.lock
18 | dist/
19 | 
20 | # project files
21 | /test_working_dir
22 | /diff_report.html
--------------------------------------------------------------------------------
/AUTH.md:
--------------------------------------------------------------------------------
 1 | # Authentication
 2 | 
 3 | This is a quick walkthrough of the possible ways to configure this project and make it run on your BigQuery project.
 4 | 
 5 | ## Disclaimer
 6 | 
 7 | Please remember this project is just a POC.
 8 | 
 9 | Be assured that the original author has no ill intent, but be warned that
10 | the original author declines all responsibility for any user suffering any loss,
11 | GCP billing cost increase, or prejudice caused by:
12 | - Any feature, bug or error in the code
13 | - Any malicious code introduced into this project by a third party using any means
14 | 
15 | (malicious fork, dependency injection, dependency confusion, etc.)
16 | 
17 | This documentation aims to inform users as clearly as possible about the supported ways to connect this project
18 | to the user's GCP account, and the risks associated with each method. To learn more about GCP authentication,
19 | please refer to the official GCP documentation:
20 | https://cloud.google.com/docs/authentication/best-practices-applications
21 | 
22 | 
23 | ## Method 1: Use application default credentials
24 | 
25 | - **Documentation**: https://cloud.google.com/bigquery/docs/authentication/getting-started
26 | - **Pros**: easy as pie
27 | - **Cons**: not very safe
28 | - **Advice**:
29 |   - Only do this with code that you trust
30 |   - Use this with a dummy GCP account that only has access to a sandbox project.
31 |   - Don't use this method in production
32 | 
33 | ### Step 1. Install gcloud
34 | 
35 | Follow the instructions here: https://cloud.google.com/sdk/docs/install
36 | 
37 | ### Step 2. Generate application-default login
38 | 
39 | _(This step is necessary if you run this locally, but can be skipped if you run
40 | this project directly from inside GCP, where application default credentials are pre-configured)_
41 | 
42 | Run this command in a terminal:
43 | ```shell
44 | gcloud auth application-default login
45 | ```
46 | 
47 | You can revoke the credentials at any time by running:
48 | ```shell
49 | gcloud auth application-default revoke
50 | ```
51 | 
52 | ### Step 3. Pass the name of your GCP project to bigquery-frame
53 | 
54 | #### Option A. Update bigquery-frame's configuration directly in your client code
55 | 
56 | ```python
57 | import bigquery_frame
58 | 
59 | bigquery_frame.conf.GCP_PROJECT = "Name of your BigQuery project"
60 | ```
61 | 
62 | #### _OR_
63 | 
64 | #### Option B. Set it as a variable in your environment
65 | 
66 | ```shell
67 | export GCP_PROJECT="Name of your BigQuery project"
68 | ```
69 | 
70 | ## Method 2: Use a service account
71 | 
72 | - **Documentation**: https://cloud.google.com/bigquery/docs/authentication/service-account-file
73 | - **Pros**: More secure
74 | - **Cons**: A little more work involved
75 | - **Advice**:
76 |   - We recommend this method
77 |   - Use a dedicated service account for this project
78 |   - Only give it the minimal access necessary for your tests
79 | 
80 | ### Step 1. Create a service account
81 | 
82 | Go to your project's Service Account page: https://console.cloud.google.com/iam-admin/serviceaccounts
83 | 
84 | _(Please make sure you select the correct project.)_
85 | 
86 | #### Create a new service account
87 | For example, you can call it `bigquery-frame-poc`.
88 | 
89 | #### Grant the following roles to the service account
90 | 
91 | You can grant it the following rights:
92 | - `BigQuery Job User` on the project you want.
93 | - `BigQuery Data Viewer` on the project (or just on the specific datasets) that you want.
94 | 
95 | _(If you want to grant access to a specific dataset, this can be done after the
96 | service account is created, directly in the BigQuery console, by clicking
97 | the "Share Dataset" button on a Dataset's panel)_
98 | 
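For those who prefer the command line, the same setup can be scripted with `gcloud`
(a sketch: `my-project` and the service account name are placeholders to adapt):

```shell
gcloud iam service-accounts create bigquery-frame-poc --project="my-project"
gcloud projects add-iam-policy-binding "my-project" \
  --member="serviceAccount:bigquery-frame-poc@my-project.iam.gserviceaccount.com" \
  --role="roles/bigquery.jobUser"
gcloud projects add-iam-policy-binding "my-project" \
  --member="serviceAccount:bigquery-frame-poc@my-project.iam.gserviceaccount.com" \
  --role="roles/bigquery.dataViewer"
```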
 99 | ### Step 2. Create and download a new JSON key for this service account
100 | 
101 | Once the service account is created, click on it, go to the "KEYS" tab,
102 | and click on the "ADD KEY" button. You will automatically download
103 | a JSON OAuth2 key file for this service account. Store it somewhere on your
104 | computer.
105 | 
106 | _(If you have forked this repo and stored the credentials inside,
107 | be careful not to commit them accidentally: use `.gitignore`)_
108 | 
109 | ### Step 3. Pass the JSON key to bigquery-frame
110 | 
111 | There are two possible variants here:
112 | - Method 2.A: pass the path to the JSON file to bigquery-frame
113 | - Method 2.B: pass the content of the JSON file directly to bigquery-frame
114 | 
115 | The first variant is generally simpler for setting up a local development environment, while the second
116 | is generally easier for setting up automated CI pipelines.
117 | 
118 | #### Method 2.A: pass the path to the JSON file to bigquery-frame
119 | 
120 | ##### Option 2.A.1: Update bigquery-frame's configuration directly in your client code
121 | 
122 | ```python
123 | import bigquery_frame
124 | 
125 | bigquery_frame.conf.GCP_CREDENTIALS_PATH = "Path to your service account credentials json file"
126 | ```
127 | 
128 | ##### _OR_
129 | 
130 | ##### Option 2.A.2: Set it as a variable in your environment
131 | 
132 | ```shell
133 | export GCP_CREDENTIALS_PATH="Path to your service account credentials json file"
134 | ```
135 | 
136 | #### Method 2.B: pass the content of the JSON file to bigquery-frame
137 | 
138 | When using this method, be careful not to accidentally include escaped newline characters `"\n"`
139 | in your JSON content.
140 | 
141 | ##### Option 2.B.1: Update bigquery-frame's configuration directly in your client code
142 | 
143 | ```python
144 | import bigquery_frame
145 | 
146 | bigquery_frame.conf.GCP_CREDENTIALS = "Content of your credentials json file"
147 | ```
148 | 
149 | ##### _OR_
150 | 
151 | ##### Option 2.B.2: Set it as a variable in your environment
152 | 
153 | ```shell
154 | export GCP_CREDENTIALS="Content of your credentials json file"
155 | ```
156 | 
157 | 
158 | ## Method 3: Do it your way
159 | 
160 | The constructor of the `BigQueryBuilder` class takes a `google.cloud.bigquery.Client`
161 | as an argument, allowing users to instantiate the client in any other way they
162 | might prefer.
163 | 
164 | 
--------------------------------------------------------------------------------
/RELEASE.md:
--------------------------------------------------------------------------------
 1 | The release is automatically handled by the [`.github/workflows/release.yml`](.github/workflows/release.yml) GitHub
 2 | action pipeline.
 3 | 
 4 | Make sure the version matches the upstream version and increase the last digit in the version number.
 5 | 
 6 | ### Bump version
 7 | 
 8 | We use the tool [bump-my-version](https://github.com/callowayproject/bump-my-version) to handle version changes.
9 | 10 | ``` 11 | # 0.1.0 -> 0.1.1 12 | poetry run bump-my-version bump patch 13 | 14 | # 0.1.1 -> 0.2.0 15 | poetry run bump-my-version bump minor 16 | ``` 17 | 18 | ### Release 19 | 20 | - [ ] add release notes to README 21 | - [ ] bumpversion 22 | - [ ] `git push` 23 | - [ ] check build 24 | - [ ] `git push --tags` 25 | - [ ] Check docs with `poetry run mkdocs serve` 26 | - [ ] `poetry run mkdocs gh-deploy` 27 | -------------------------------------------------------------------------------- /bigquery_frame/__init__.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame.bigquery_builder import BigQueryBuilder 2 | from bigquery_frame.column import Column 3 | from bigquery_frame.dataframe import DataFrame 4 | from bigquery_frame.utils import _ref 5 | 6 | _ref(BigQueryBuilder) 7 | _ref(Column) 8 | _ref(DataFrame) 9 | 10 | __version__ = "0.5.0" 11 | -------------------------------------------------------------------------------- /bigquery_frame/auth.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from json import JSONDecodeError 4 | from typing import Optional 5 | 6 | import google 7 | from google.cloud.bigquery import Client 8 | from google.oauth2 import service_account 9 | 10 | from bigquery_frame import conf 11 | 12 | 13 | def _get_bq_client_from_credential_files() -> Optional[Client]: 14 | gcp_credentials_path = os.getenv("GCP_CREDENTIALS_PATH") or conf.GCP_CREDENTIALS_PATH 15 | if gcp_credentials_path.endswith(".json"): 16 | credentials = service_account.Credentials.from_service_account_file( 17 | filename=gcp_credentials_path, 18 | scopes=["https://www.googleapis.com/auth/cloud-platform"], 19 | ) 20 | client = google.cloud.bigquery.Client(credentials=credentials, project=credentials.project_id) 21 | return client 22 | else: 23 | return None 24 | 25 | 26 | def _get_bq_client_from_credentials() -> Optional[Client]: 27 | gcp_credentials = os.getenv("GCP_CREDENTIALS") or conf.GCP_CREDENTIALS 28 | try: 29 | json_credentials = json.loads(gcp_credentials) 30 | except JSONDecodeError: 31 | return None 32 | credentials = service_account.Credentials.from_service_account_info( 33 | info=json_credentials, 34 | scopes=["https://www.googleapis.com/auth/cloud-platform"], 35 | ) 36 | client = google.cloud.bigquery.Client(credentials=credentials, project=credentials.project_id) 37 | return client 38 | 39 | 40 | def _get_bq_client_from_default() -> Client: 41 | gcp_project = os.getenv("GCP_PROJECT") or conf.GCP_PROJECT 42 | client = google.cloud.bigquery.Client(gcp_project) 43 | return client 44 | 45 | 46 | def get_bq_client(): 47 | client = _get_bq_client_from_credentials() 48 | if client is None: 49 | client = _get_bq_client_from_credential_files() 50 | if client is None: 51 | client = _get_bq_client_from_default() 52 | return client 53 | -------------------------------------------------------------------------------- /bigquery_frame/bigquery_builder.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Optional 2 | 3 | from google.cloud.bigquery import Client, SchemaField 4 | from google.cloud.bigquery.table import RowIterator 5 | 6 | import bigquery_frame 7 | from bigquery_frame.auth import get_bq_client 8 | from bigquery_frame.has_bigquery_client import HasBigQueryClient 9 | from bigquery_frame.temp_names import _get_temp_table_name 10 | from bigquery_frame.utils import indent, quote, strip_margin 11 | 12 | 
if TYPE_CHECKING:
13 |     from bigquery_frame import DataFrame
14 | 
15 | 
16 | class BigQueryBuilder(HasBigQueryClient):
17 |     def __init__(self, client: Optional[Client] = None, use_session: bool = True, debug: bool = False):
18 |         if client is None:
19 |             client = get_bq_client()
20 |         super().__init__(client, use_session)
21 |         self._views: dict[str, "DataFrame"] = {}
22 |         self._temp_tables: set[str] = set()
23 |         self.debug = debug
24 | 
25 |     def table(self, full_table_name: str) -> "DataFrame":
26 |         """Returns the specified table as a :class:`DataFrame`."""
27 |         from bigquery_frame import DataFrame
28 | 
29 |         query = f"""SELECT * FROM {quote(full_table_name)}"""
30 |         return DataFrame(query, alias=None, bigquery=self)
31 | 
32 |     def sql(self, sql_query: str) -> "DataFrame":
33 |         """Returns a :class:`DataFrame` representing the result of the given query."""
34 |         from bigquery_frame import DataFrame
35 | 
36 |         return DataFrame(sql_query, None, self)
37 | 
38 |     def _generate_header(self) -> str:
39 |         return f"/* This query was generated using bigquery-frame v{bigquery_frame.__version__} */\n"
40 | 
41 |     def _get_query_schema(self, query: str) -> list[SchemaField]:
42 |         query = self._generate_header() + query
43 |         return super()._get_query_schema(query)
44 | 
45 |     def _execute_query(self, query: str, use_query_cache=True) -> RowIterator:
46 |         query = self._generate_header() + query
47 |         return super()._execute_query(query, use_query_cache=use_query_cache)
48 | 
49 |     def _registerDataFrameAsTempView(self, df: "DataFrame", alias: str) -> None:
50 |         self._views[alias] = df
51 | 
52 |     def _registerDataFrameAsTempTable(self, df: "DataFrame", alias: Optional[str] = None) -> "DataFrame":
53 |         if alias is None:
54 |             alias = _get_temp_table_name()
55 |         query = f"CREATE OR REPLACE TEMP TABLE {quote(alias)} AS \n" + df.compile()
56 |         self._execute_query(query)
57 |         return self.table(alias)
58 | 
59 |     def _compile_views(self) -> dict[str, str]:
60 |         return {
61 |             alias: strip_margin(
62 |                 f"""{quote(alias)} AS (
63 |                 |{indent(df._compile_with_deps(), 2)}
64 |                 |)""",
65 |             )
66 |             for alias, df in self._views.items()
67 |         }
68 | 
69 |     def _check_alias(self, new_alias, deps: list[tuple[str, "DataFrame"]]) -> None:
70 |         """Checks that the alias follows BigQuery constraints, such as:
71 | 
72 |         - BigQuery does not allow having two CTEs with the same name in a query.
73 | 
74 |         :param new_alias:
75 |         :param deps:
76 |         :return: None
77 |         :raises: an Exception if something that does not comply with BigQuery's rules is found.
78 |         """
79 |         collisions = [alias for alias, df in list(self._views.items()) + deps if alias == new_alias]
80 |         if len(collisions) > 0:
81 |             raise ValueError(f"Duplicate alias {new_alias}")
82 | 
--------------------------------------------------------------------------------
/bigquery_frame/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/bigquery_frame/cli/__init__.py
--------------------------------------------------------------------------------
/bigquery_frame/cli/diff.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | from argparse import ArgumentParser
 3 | from typing import Optional
 4 | 
 5 | from bigquery_frame import BigQueryBuilder
 6 | from bigquery_frame.data_diff import compare_dataframes
 7 | from bigquery_frame.data_diff.diff_format_options import DEFAULT_NB_DIFFED_ROWS, DiffFormatOptions
 8 | from bigquery_frame.data_diff.export import DEFAULT_HTML_REPORT_OUTPUT_FILE_PATH
 9 | 
10 | 
11 | def main(argv: Optional[list[str]] = None) -> None:
12 |     if argv is None:
13 |         argv = sys.argv[1:]
14 |     if len(argv) == 0:
15 |         argv = ["--help"]
16 |     parser = ArgumentParser(description="Compare two BigQuery Tables and generate an HTML report", prog="bq-diff")
17 |     parser.add_argument(
18 |         "--tables",
19 |         nargs=2,
20 |         metavar=("LEFT_TABLE", "RIGHT_TABLE"),
21 |         type=str,
22 |         help="Fully qualified names of the two tables to compare",
23 |     )
24 |     parser.add_argument(
25 |         "--join-cols",
26 |         nargs="*",
27 |         default=None,
28 |         type=str,
29 |         help="Name of the fields used to join the DataFrames together. "
30 |         "Each row should be uniquely identifiable using these fields. "
31 |         "Fields inside repeated structs are also supported.",
32 |     )
33 |     parser.add_argument(
34 |         "--output",
35 |         default=None,
36 |         type=str,
37 |         help="Path of the HTML report to generate.",
38 |     )
39 |     parser.add_argument(
40 |         "--nb-top-values",
41 |         default=DEFAULT_NB_DIFFED_ROWS,
42 |         type=int,
43 |         help="Number of most frequent changes/values to display in the diff for each column "
44 |         f"(Default: {DEFAULT_NB_DIFFED_ROWS}).",
45 |     )
46 |     args = parser.parse_args(argv)
47 |     left_table, right_table = args.tables
48 |     bq = BigQueryBuilder()
49 |     left_df = bq.table(left_table)
50 |     right_df = bq.table(right_table)
51 |     diff_result = compare_dataframes(left_df, right_df, args.join_cols)
52 |     if args.output is not None:
53 |         output_path = args.output
54 |     else:
55 |         output_path = DEFAULT_HTML_REPORT_OUTPUT_FILE_PATH
56 |     diff_format_options = DiffFormatOptions(nb_diffed_rows=args.nb_top_values)
57 | 
58 |     diff_result.export_to_html(output_file_path=output_path, diff_format_options=diff_format_options)
--------------------------------------------------------------------------------
/bigquery_frame/conf.py:
--------------------------------------------------------------------------------
 1 | # This file may be edited by the user.
 2 | # Please read AUTH.md first before changing anything.
 3 | 
 4 | STRUCT_SEPARATOR = "."
 5 | REPETITION_MARKER = "!"
 6 | STRUCT_SEPARATOR_REPLACEMENT = "__STRUCT__"
 7 | REPETITION_MARKER_REPLACEMENT = "__ARRAY__"
 8 | 
 9 | # Method 1.
Set this variable here or set it as an environment variable 10 | GCP_PROJECT = "Name of your BigQuery project" 11 | 12 | 13 | # Method 2.A Set this variable here or set it as an environment variable 14 | GCP_CREDENTIALS_PATH = "Path to your service account credentials json file" 15 | 16 | 17 | # Method 2.B Set this variable here or set it as an environment variable 18 | GCP_CREDENTIALS = "Content of your credentials json file" 19 | -------------------------------------------------------------------------------- /bigquery_frame/data_diff/__init__.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame.data_diff.compare_dataframes_impl import compare_dataframes 2 | from bigquery_frame.utils import _ref 3 | 4 | _ref(compare_dataframes) 5 | -------------------------------------------------------------------------------- /bigquery_frame/data_diff/diff_format_options.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | DEFAULT_NB_DIFFED_ROWS = 10 4 | DEFAULT_MAX_STRING_LENGTH = 30 5 | DEFAULT_LEFT_DF_ALIAS = "left" 6 | DEFAULT_RIGHT_DF_ALIAS = "right" 7 | 8 | 9 | @dataclass 10 | class DiffFormatOptions: 11 | nb_diffed_rows: int = DEFAULT_NB_DIFFED_ROWS 12 | max_string_length: int = DEFAULT_MAX_STRING_LENGTH 13 | left_df_alias: str = DEFAULT_LEFT_DF_ALIAS 14 | right_df_alias: str = DEFAULT_RIGHT_DF_ALIAS 15 | -------------------------------------------------------------------------------- /bigquery_frame/data_diff/diff_result_summary.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from bigquery_frame import DataFrame 4 | from bigquery_frame.data_diff.schema_diff import SchemaDiffResult 5 | 6 | 7 | @dataclass 8 | class DiffStatsForColumn: 9 | column_name: str 10 | """Name of the column""" 11 | total: int 12 | """Total number of rows after joining the two DataFrames""" 13 | no_change: int 14 | """Number of rows where this column is identical in both DataFrames""" 15 | changed: int 16 | """Number of rows that are present in both DataFrames but where this column has different values""" 17 | only_in_left: int 18 | """Number of rows that are only present in the left DataFrame""" 19 | only_in_right: int 20 | """Number of rows that are only present in the right DataFrame""" 21 | 22 | 23 | @dataclass 24 | class DiffResultSummary: 25 | left_df_alias: str 26 | right_df_alias: str 27 | diff_per_col_df: DataFrame 28 | 29 | schema_diff_result: SchemaDiffResult 30 | join_cols: list[str] 31 | same_schema: bool 32 | same_data: bool 33 | total_nb_rows: int 34 | -------------------------------------------------------------------------------- /bigquery_frame/data_diff/diff_stats.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class DiffStats: 6 | total: int 7 | """Total number of rows after joining the two DataFrames""" 8 | no_change: int 9 | """Number of rows that are identical in both DataFrames""" 10 | changed: int 11 | """Number of rows that are present in both DataFrames but that have different values""" 12 | in_left: int 13 | """Number of rows in the left DataFrame""" 14 | in_right: int 15 | """Number of rows in the right DataFrame""" 16 | only_in_left: int 17 | """Number of rows that are only present in the left DataFrame""" 18 | only_in_right: int 19 | """Number of rows that are only present in the 
right DataFrame"""
20 | 
21 |     @property
22 |     def same_data(self) -> bool:
23 |         return self.no_change == self.total
24 | 
25 |     @property
26 |     def percent_changed(self) -> float:
27 |         return round(self.changed * 100.0 / self.total, 2)
28 | 
29 |     @property
30 |     def percent_no_change(self) -> float:
31 |         return round(self.no_change * 100.0 / self.total, 2)
32 | 
33 |     @property
34 |     def percent_only_in_left(self) -> float:
35 |         return round(self.only_in_left * 100.0 / self.total, 2)
36 | 
37 |     @property
38 |     def percent_only_in_right(self) -> float:
39 |         return round(self.only_in_right * 100.0 / self.total, 2)
40 | 
41 | 
42 | def print_diff_stats_shard(diff_stats: DiffStats, left_df_alias: str, right_df_alias: str) -> None:
43 |     nb_row_diff = diff_stats.in_right - diff_stats.in_left
44 |     if nb_row_diff != 0:
45 |         if nb_row_diff > 0:
46 |             more_less = "more"
47 |         else:
48 |             more_less = "less"
49 |         print("\nRow count changed: ")
50 |         print(f"{left_df_alias}: {diff_stats.in_left} rows")
51 |         print(f"{right_df_alias}: {diff_stats.in_right} rows ({abs(nb_row_diff)} {more_less})")
52 |         print("")
53 |     else:
54 |         print(f"Row count ok: {diff_stats.in_right} rows")
55 |         print("")
56 |     print(f"{diff_stats.no_change} ({diff_stats.percent_no_change}%) rows are identical")
57 |     print(f"{diff_stats.changed} ({diff_stats.percent_changed}%) rows have changed")
58 |     if diff_stats.only_in_left > 0:
59 |         print(f"{diff_stats.only_in_left} ({diff_stats.percent_only_in_left}%) rows are only in '{left_df_alias}'")
60 |     if diff_stats.only_in_right > 0:
61 |         print(f"{diff_stats.only_in_right} ({diff_stats.percent_only_in_right}%) rows are only in '{right_df_alias}'")
62 | 
--------------------------------------------------------------------------------
/bigquery_frame/data_diff/export.py:
--------------------------------------------------------------------------------
 1 | import tempfile
 2 | from pathlib import Path
 3 | from typing import Optional
 4 | 
 5 | import bigquery_frame
 6 | from bigquery_frame.data_diff.diff_result_summary import DiffResultSummary
 7 | 
 8 | DEFAULT_HTML_REPORT_OUTPUT_FILE_PATH = "diff_report.html"
 9 | DEFAULT_HTML_REPORT_ENCODING = "utf-8"
10 | 
11 | 
12 | def export_html_diff_report(
13 |     diff_result_summary: DiffResultSummary,
14 |     title: Optional[str] = None,
15 |     output_file_path: str = DEFAULT_HTML_REPORT_OUTPUT_FILE_PATH,
16 |     encoding: str = DEFAULT_HTML_REPORT_ENCODING,
17 | ) -> None:
18 |     from data_diff_viewer import DiffSummary, generate_report_string
19 | 
20 |     with tempfile.TemporaryDirectory() as temp_dir:
21 |         temp_dir_path = Path(temp_dir)
22 |         diff_per_col_parquet_path = temp_dir_path / "diff_per_col.parquet"
23 |         diff_result_summary.diff_per_col_df.toPandas().to_parquet(diff_per_col_parquet_path)
24 |         if title is None:
25 |             report_title = f"{diff_result_summary.left_df_alias} vs {diff_result_summary.right_df_alias}"
26 |         else:
27 |             report_title = title
28 |         column_names_diff = {k: v.value for k, v in diff_result_summary.schema_diff_result.column_names_diff.items()}
29 |         diff_summary = DiffSummary(
30 |             generated_with=f"{bigquery_frame.__name__}:{bigquery_frame.__version__}",
31 |             left_df_alias=diff_result_summary.left_df_alias,
32 |             right_df_alias=diff_result_summary.right_df_alias,
33 |             join_cols=diff_result_summary.join_cols,
34 |             same_schema=diff_result_summary.same_schema,
35 |             schema_diff_str=diff_result_summary.schema_diff_result.diff_str,
36 |             column_names_diff=column_names_diff,
37 |             same_data=diff_result_summary.same_data,
38 | 
total_nb_rows=diff_result_summary.total_nb_rows, 39 | ) 40 | report = generate_report_string(report_title, diff_summary, temp_dir_path, diff_per_col_parquet_path) 41 | output_path = Path(output_file_path) 42 | with output_path.open("w", encoding=encoding) as output: 43 | output.write(report) 44 | print(f"Report exported as {output_file_path}") 45 | -------------------------------------------------------------------------------- /bigquery_frame/data_diff/schema_diff.py: -------------------------------------------------------------------------------- 1 | import difflib 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | 5 | from google.cloud.bigquery import SchemaField 6 | 7 | from bigquery_frame import DataFrame 8 | from bigquery_frame.conf import REPETITION_MARKER 9 | from bigquery_frame.data_type_utils import flatten_schema 10 | from bigquery_frame.dataframe import is_nullable, is_repeated, is_struct 11 | from bigquery_frame.field_utils import has_same_granularity_as_any, is_parent_field_of_any 12 | 13 | 14 | class DiffPrefix(str, Enum): 15 | ADDED = "+" 16 | REMOVED = "-" 17 | UNCHANGED = " " 18 | 19 | def __repr__(self) -> str: 20 | return f"'{self.value}'" 21 | 22 | 23 | @dataclass 24 | class SchemaDiffResult: 25 | same_schema: bool 26 | diff_str: str 27 | nb_cols: int 28 | column_names_diff: dict[str, DiffPrefix] 29 | """The diff per column names. 30 | Used to determine which columns appeared or disappeared and the order in which the columns shall be displayed""" 31 | 32 | def display(self) -> None: 33 | if not self.same_schema: 34 | print(f"Schema has changed:\n{self.diff_str}") 35 | print("WARNING: columns that do not match both sides will be ignored") 36 | else: 37 | print(f"Schema: ok ({self.nb_cols})") 38 | 39 | @property 40 | def column_names(self) -> list[str]: 41 | return list(self.column_names_diff.keys()) 42 | 43 | 44 | def _schema_to_string( 45 | schema: list[SchemaField], 46 | include_nullable: bool = False, 47 | include_description: bool = False, 48 | ) -> list[str]: 49 | """Return a list of strings representing the schema 50 | 51 | Args: 52 | schema: A DataFrame's schema 53 | include_nullable: Indicate for each field if it is nullable 54 | include_description: Add field description 55 | 56 | Returns: 57 | 58 | Examples: 59 | >>> from bigquery_frame import BigQueryBuilder 60 | >>> bq = BigQueryBuilder() 61 | >>> df = bq.sql('''SELECT 1 as id, "a" as c1, 1 as c2''') 62 | >>> print('\\n'.join(_schema_to_string(df.schema))) 63 | id INTEGER 64 | c1 STRING 65 | c2 INTEGER 66 | >>> print('\\n'.join(_schema_to_string(df.schema, include_nullable=True))) 67 | id INTEGER (nullable) 68 | c1 STRING (nullable) 69 | c2 INTEGER (nullable) 70 | >>> schema = [ 71 | ... SchemaField(name='id', field_type='INTEGER', mode='NULLABLE', description='An id'), 72 | ... SchemaField(name='c1', field_type='STRING', mode='REQUIRED', description='A string column'), 73 | ... SchemaField(name='c2', field_type='INTEGER', mode='NULLABLE', description='An int column') 74 | ... 
]
 75 |     >>> print('\\n'.join(_schema_to_string(schema, include_nullable=True, include_description=True)))
 76 |     id INTEGER (nullable) An id
 77 |     c1 STRING (required) A string column
 78 |     c2 INTEGER (nullable) An int column
 79 |     >>> df = bq.sql('''SELECT 1 as id, STRUCT(2 as a, [STRUCT(3 as c, 4 as d, [5] as e)] as b) as s''')
 80 |     >>> print('\\n'.join(_schema_to_string(df.schema)))
 81 |     id INTEGER
 82 |     s STRUCT<a:INTEGER,b:ARRAY<STRUCT<c:INTEGER,d:INTEGER,e:ARRAY<INTEGER>>>>
 83 |     """
 84 | 
 85 |     def field_to_string(field: SchemaField, sep=":") -> str:
 86 |         if is_struct(field):
 87 |             tpe = f"""STRUCT<{",".join(field_to_string(f) for f in field.fields)}>"""
 88 |             if is_repeated(field):
 89 |                 tpe = f"""ARRAY<{tpe}>"""
 90 |         elif is_repeated(field):
 91 |             tpe = f"ARRAY<{field.field_type}>"
 92 |         else:
 93 |             tpe = f"{field.field_type}"
 94 |         return f"""{field.name}{sep}{tpe}"""
 95 | 
 96 |     def meta_str(field) -> str:
 97 |         s = ""
 98 |         if include_nullable:
 99 |             if is_nullable(field):
100 |                 s += " (nullable)"
101 |             else:
102 |                 s += " (required)"
103 |         if include_description:
104 |             s += f" {field.description}"
105 |         return s
106 | 
107 |     return [field_to_string(field, sep=" ") + meta_str(field) for field in schema]
108 | 
109 | 
110 | def diff_dataframe_schemas(left_df: DataFrame, right_df: DataFrame, join_cols: list[str]) -> SchemaDiffResult:
111 |     """Compares two DataFrame schemas and prints out the differences.
112 |     Ignores the nullable and description attributes.
113 | 
114 |     Args:
115 |         left_df: A DataFrame
116 |         right_df: Another DataFrame
117 |         join_cols: The list of column names that will be used for joining the two DataFrames together
118 | 
119 |     Returns:
120 |         A SchemaDiffResult object
121 | 
122 |     Examples:
123 |         >>> from bigquery_frame import BigQueryBuilder
124 |         >>> bq = BigQueryBuilder()
125 |         >>> left_df = bq.sql('''SELECT 1 as id, "" as c1, "" as c2, [STRUCT(2 as a, "" as b)] as c4''')
126 |         >>> right_df = bq.sql('''SELECT 1 as id, 2 as c1, "" as c3, [STRUCT(3 as a, "" as d)] as c4''')
127 |         >>> schema_diff_result = diff_dataframe_schemas(left_df, right_df, ["id"])
128 |         >>> schema_diff_result.display()
129 |         Schema has changed:
130 |         @@ -1,4 +1,4 @@
131 |         <BLANKLINE>
132 |          id INTEGER
133 |         -c1 STRING
134 |         -c2 STRING
135 |         -c4 ARRAY<STRUCT<a:INTEGER,b:STRING>>
136 |         +c1 INTEGER
137 |         +c3 STRING
138 |         +c4 ARRAY<STRUCT<a:INTEGER,d:STRING>>
139 |         WARNING: columns that do not match both sides will be ignored
140 |         >>> schema_diff_result.same_schema
141 |         False
142 |         >>> schema_diff_result.column_names_diff
143 |         {'id': ' ', 'c1': ' ', 'c2': '-', 'c3': '+', 'c4': ' '}
144 | 
145 |         >>> schema_diff_result = diff_dataframe_schemas(left_df, right_df, ["id", "c4!.a"])
146 |         >>> schema_diff_result.display()
147 |         Schema has changed:
148 |         @@ -1,5 +1,5 @@
149 |         <BLANKLINE>
150 |          id INTEGER
151 |         -c1 STRING
152 |         -c2 STRING
153 |         +c1 INTEGER
154 |         +c3 STRING
155 |          c4!.a INTEGER
156 |         -c4!.b STRING
157 |         +c4!.d STRING
158 |         WARNING: columns that do not match both sides will be ignored
159 |         >>> schema_diff_result.same_schema
160 |         False
161 |         >>> schema_diff_result.column_names_diff
162 |         {'id': ' ', 'c1': ' ', 'c2': '-', 'c3': '+', 'c4!.a': ' ', 'c4!.b': '-', 'c4!.d': '+'}
163 |     """
164 | 
165 |     def explode_schema_according_to_join_cols(schema: list[SchemaField]) -> list[SchemaField]:
166 |         exploded_schema = flatten_schema(schema, explode=True, keep_non_leaf_fields=True)
167 |         return [
168 |             field
169 |             for field in exploded_schema
170 |             if has_same_granularity_as_any(field.name, join_cols)
171 |             and not is_parent_field_of_any(field.name, join_cols)
172 |             and not (is_struct(field) and not is_repeated(field))
173 |             and not field.name.endswith(REPETITION_MARKER)
174 |         ]
175 | 
176 |     left_schema_flat_exploded = explode_schema_according_to_join_cols(left_df.schema)
177 |     right_schema_flat_exploded = explode_schema_according_to_join_cols(right_df.schema)
178 | 
179 |     left_schema: list[str] = _schema_to_string(left_schema_flat_exploded)
180 |     right_schema: list[str] = _schema_to_string(right_schema_flat_exploded)
181 |     left_columns_flat: list[str] = [field.name for field in left_schema_flat_exploded]
182 |     right_columns_flat: list[str] = [field.name for field in right_schema_flat_exploded]
183 | 
184 |     diff_str = list(difflib.unified_diff(left_schema, right_schema, n=10000))[2:]
185 |     column_names_diff = _diff_dataframe_column_names(left_columns_flat, right_columns_flat)
186 |     same_schema = len(diff_str) == 0
187 |     if same_schema:
188 |         diff_str = left_schema
189 |     return SchemaDiffResult(
190 |         same_schema=same_schema,
191 |         diff_str="\n".join(diff_str),
192 |         nb_cols=len(left_df.columns),
193 |         column_names_diff=column_names_diff,
194 |     )
195 | 
196 | 
197 | def _remove_potential_duplicates_from_diff(diff: list[str]) -> list[str]:
198 |     """In some cases (e.g. swapping the order of two columns), difflib.unified_diff produces results
199 |     where a column is added and then removed. This method replaces such duplicates with a single occurrence
200 |     of the column marked as unchanged. We keep the column ordering of the left side.
201 | 
202 |     Examples:
203 |         >>> _remove_potential_duplicates_from_diff([' id', ' col1', '+col4', '+col3', ' col2', '-col3', '-col4'])
204 |         [' id', ' col1', ' col2', ' col3', ' col4']
205 | 
206 |     """
207 |     plus = {row[1:] for row in diff if row[0] == DiffPrefix.ADDED}
208 |     minus = {row[1:] for row in diff if row[0] == DiffPrefix.REMOVED}
209 |     both = plus.intersection(minus)
210 |     return [
211 |         DiffPrefix.UNCHANGED + row[1:] if row[1:] in both else row
212 |         for row in diff
213 |         if (row[1:] not in both) or (row[0] == DiffPrefix.REMOVED)
214 |     ]
215 | 
216 | 
217 | def _diff_dataframe_column_names(left_col_names: list[str], right_col_names: list[str]) -> dict[str, DiffPrefix]:
218 |     """Compares the column names of two DataFrames.
219 | 
220 |     Returns a dict of column names that preserves the ordering of the left DataFrame when possible.
221 |     The column names are mapped to a character according to the following convention:
222 | 
223 |     - ' ' if the column exists in both DataFrames
224 |     - '-' if it only exists in the left DataFrame
225 |     - '+' if it only exists in the right DataFrame
226 | 
227 |     Args:
228 |         left_col_names: A list of column names
229 |         right_col_names: Another list of column names
230 | 
231 |     Returns:
232 |         A dict mapping each column name to a prefix character: ' ', '+' or '-'
233 | 
234 |     Examples:
235 |         >>> left_cols = ["id", "col1", "col2", "col3"]
236 |         >>> right_cols = ["id", "col1", "col4", "col3"]
237 |         >>> _diff_dataframe_column_names(left_cols, right_cols)
238 |         {'id': ' ', 'col1': ' ', 'col2': '-', 'col4': '+', 'col3': ' '}
239 |         >>> _diff_dataframe_column_names(left_cols, left_cols)
240 |         {'id': ' ', 'col1': ' ', 'col2': ' ', 'col3': ' '}
241 | 
242 |         >>> left_cols = ["id", "col1", "col2", "col3", "col4"]
243 |         >>> right_cols = ["id", "col1", "col4", "col3", "col2"]
244 |         >>> _diff_dataframe_column_names(left_cols, right_cols)
245 |         {'id': ' ', 'col1': ' ', 'col2': ' ', 'col3': ' ', 'col4': ' '}
246 | 
247 |     """
248 |     diff = list(difflib.unified_diff(left_col_names, right_col_names, n=10000))[2:]
249 |     same_schema = len(diff) == 0
250 |     if same_schema:
251 |         list_result = [DiffPrefix.UNCHANGED + s for s in left_col_names]
252 |     else:
253 |         list_result = _remove_potential_duplicates_from_diff(diff[1:])
254 |     return {s[1:]: DiffPrefix(s[0]) for s in list_result}
--------------------------------------------------------------------------------
/bigquery_frame/exceptions.py:
--------------------------------------------------------------------------------
 1 | class IllegalArgumentException(Exception):
 2 |     """Passed an illegal or inappropriate argument."""
 3 | 
 4 | 
 5 | class AnalysisException(Exception):
 6 |     """Exception raised when an anomaly is detected during the preparation of a transformation."""
 7 | 
 8 | 
 9 | class UnsupportedOperationException(Exception):
10 |     """Exception raised when the user performs an operation that is not supported."""
11 | 
12 | 
13 | class UnexpectedException(Exception):
14 |     """Exception raised when something that is not supposed to happen happens."""
15 | 
16 |     issue_submit_url = "https://github.com/FurcyPin/bigquery-frame/issues/new"
17 | 
18 |     def __init__(self, error: str) -> None:
19 |         msg = (
20 |             f"An unexpected error occurred: {error}"
21 |             f"\nPlease report a bug with the complete stacktrace at {self.issue_submit_url}"
22 |         )
23 |         Exception.__init__(self, msg)
24 | 
25 | 
26 | class DataframeComparatorException(Exception):
27 |     """Exception happening during data diff."""
28 | 
29 | 
30 | class CombinatorialExplosionError(DataframeComparatorException):
31 |     """Exception happening before a join when we detect that the join key is incorrect,
32 |     which would lead to a combinatorial explosion.
33 |     """
34 | 
--------------------------------------------------------------------------------
/bigquery_frame/field_utils.py:
--------------------------------------------------------------------------------
  1 | from bigquery_frame.conf import REPETITION_MARKER
  2 | 
  3 | 
  4 | def is_sub_field_or_equal(sub_field: str, field: str) -> bool:
  5 |     """Return True if `sub_field` is equal to `field` or is a sub-field of it
  6 | 
  7 |     >>> is_sub_field_or_equal("a", "a")
  8 |     True
  9 |     >>> is_sub_field_or_equal("a", "b")
 10 |     False
 11 | 
 12 |     >>> is_sub_field_or_equal("a.b", "a")
 13 |     True
 14 |     >>> is_sub_field_or_equal("a.b", "b")
 15 |     False
 16 | 
 17 |     >>> is_sub_field_or_equal("a", "a.b")
 18 |     False
 19 | 
 20 |     """
 21 |     return sub_field == field or sub_field.startswith(field + ".")
 22 | 
 23 | 
 24 | def is_sub_field_or_equal_to_any(sub_field: str, fields: list[str]) -> bool:
 25 |     """Return True if `sub_field` is equal to or a sub-field of any field in `fields`
 26 | 
 27 |     >>> is_sub_field_or_equal_to_any("a", ["a", "b"])
 28 |     True
 29 |     >>> is_sub_field_or_equal_to_any("a", ["b", "c"])
 30 |     False
 31 | 
 32 |     >>> is_sub_field_or_equal_to_any("a.b", ["a", "b"])
 33 |     True
 34 |     >>> is_sub_field_or_equal_to_any("a.b", ["b", "c"])
 35 |     False
 36 | 
 37 |     >>> is_sub_field_or_equal_to_any("a", ["a.b"])
 38 |     False
 39 | 
 40 |     >>> is_sub_field_or_equal_to_any("a", [])
 41 |     False
 42 | 
 43 |     """
 44 |     return any(is_sub_field_or_equal(sub_field, field) for field in fields)
 45 | 
 46 | 
 47 | def get_granularity(field: str) -> str:
 48 |     """Return the granularity of `field`: its prefix up to and including the last repetition marker (`!`), or '' if the field is not inside any repeated field
 49 | 
 50 |     >>> get_granularity("a")
 51 |     ''
 52 |     >>> get_granularity("s.z")
 53 |     ''
 54 |     >>> get_granularity("a!")
 55 |     'a!'
 56 |     >>> get_granularity("a!.b")
 57 |     'a!'
 58 |     >>> get_granularity("a!.b!.c")
 59 |     'a!.b!'
 60 | 
 61 |     """
 62 |     granularity = substring_before_last_occurrence(field, "!")
 63 |     if granularity == "":
 64 |         return granularity
 65 |     else:
 66 |         return granularity + "!"
67 | 68 | 69 | def has_same_granularity(field: str, other_field: str) -> bool: 70 | """Return True if `field` is at the same granularity level as `other_field` 71 | 72 | >>> has_same_granularity("a", "a") 73 | True 74 | >>> has_same_granularity("a", "b") 75 | True 76 | >>> has_same_granularity("a", "a.b") 77 | True 78 | >>> has_same_granularity("a.b", "a.a") 79 | True 80 | >>> has_same_granularity("a.b", "b.a") 81 | True 82 | >>> has_same_granularity("a.b", "a") 83 | True 84 | >>> has_same_granularity("a", "a.b") 85 | True 86 | 87 | >>> has_same_granularity("a!.a", "a!.b") 88 | True 89 | >>> has_same_granularity("a!.a", "a!.b.c") 90 | True 91 | >>> has_same_granularity("a!.b", "b.a") 92 | False 93 | >>> has_same_granularity("a!.b", "a") 94 | False 95 | >>> has_same_granularity("a", "a!.b") 96 | False 97 | 98 | """ 99 | return get_granularity(field) == get_granularity(other_field) 100 | 101 | 102 | def has_same_granularity_as_any(field: str, other_fields: list[str]) -> bool: 103 | """Return True if `field` is equal to or has the same parent as any field in `other_fields` 104 | 105 | >>> has_same_granularity_as_any("a", ["a", "b"]) 106 | True 107 | 108 | >>> has_same_granularity_as_any("a.b", ["a.a", "b"]) 109 | True 110 | >>> has_same_granularity_as_any("a.b", ["b", "c"]) 111 | True 112 | 113 | >>> has_same_granularity_as_any("a!.b", ["b", "c"]) 114 | False 115 | 116 | >>> has_same_granularity_as_any("a!.b!.c", ["a!.b!.d"]) 117 | True 118 | 119 | >>> has_same_granularity_as_any("a!.b.c", ["a"]) 120 | False 121 | >>> has_same_granularity_as_any("a.b!.c", ["a"]) 122 | False 123 | 124 | >>> has_same_granularity_as_any("a", []) 125 | False 126 | 127 | """ 128 | return any(has_same_granularity(field, other_field) for other_field in other_fields) 129 | 130 | 131 | def is_sub_field(sub_field: str, field: str) -> bool: 132 | """Return True if `sub_field` is a sub-field of `field`, or is equal to `field` 133 | 134 | >>> is_sub_field("a", "a") 135 | True 136 | >>> is_sub_field("a", "b") 137 | False 138 | 139 | >>> is_sub_field("a.b", "a") 140 | True 141 | >>> is_sub_field("a!.b", "a!") 142 | True 143 | >>> is_sub_field("a.b", "b") 144 | False 145 | >>> is_sub_field("a.b!", "a") 146 | False 147 | >>> is_sub_field("a!.b", "a") 148 | False 149 | 150 | >>> is_sub_field("a", "a.b") 151 | False 152 | 153 | >>> is_sub_field("a.b.c", "a.b") 154 | True 155 | >>> is_sub_field("a.b.c", "a") 156 | True 157 | 158 | """ 159 | return sub_field == field or ( 160 | sub_field.startswith(field + ".") and REPETITION_MARKER not in sub_field[len(field) :] 161 | ) 162 | 163 | 164 | def is_sub_field_of_any(direct_sub_field: str, fields: list[str]) -> bool: 165 | """Return True if `direct_sub_field` is a sub-field of any field in `fields` 166 | 167 | >>> is_sub_field_of_any("a", ["a", "b"]) 168 | True 169 | 170 | >>> is_sub_field_of_any("a.b", ["a", "b"]) 171 | True 172 | >>> is_sub_field_of_any("a!.b", ["a!", "b"]) 173 | True 174 | >>> is_sub_field_of_any("a!.b!", ["a!", "b!"]) 175 | False 176 | >>> is_sub_field_of_any("a.b", ["b", "c"]) 177 | False 178 | 179 | >>> is_sub_field_of_any("a.b.c", ["a.b"]) 180 | True 181 | 182 | >>> is_sub_field_of_any("a.b.c", ["a"]) 183 | True 184 | 185 | >>> is_sub_field_of_any("a", []) 186 | False 187 | 188 | """ 189 | return any(is_sub_field(direct_sub_field, field) for field in fields) 190 | 191 | 192 | def is_parent_field(field: str, other_field: str) -> bool: 193 | """Return True if `other_field` is a sub-field of `field` 194 | 195 | >>> is_parent_field("a", "a") 196 | False 197 | >>> 
is_parent_field("a", "b")
198 |     False
199 | 
200 |     >>> is_parent_field("a", "a.b")
201 |     True
202 |     >>> is_parent_field("b", "a.b")
203 |     False
204 | 
205 |     >>> is_parent_field("a.b", "a")
206 |     False
207 | 
208 |     """
209 |     return other_field.startswith((field + ".", field + "!."))
210 | 
211 | 
212 | def is_parent_field_of_any(field: str, other_fields: list[str]) -> bool:
213 |     """Return True if any field in `other_fields` is a sub-field of `field`
214 | 
215 |     >>> is_parent_field_of_any("a", ["a", "b"])
216 |     False
217 |     >>> is_parent_field_of_any("a", ["b", "c"])
218 |     False
219 | 
220 |     >>> is_parent_field_of_any("a", ["a.b", "b"])
221 |     True
222 |     >>> is_parent_field_of_any("a", ["b.a", "c.a"])
223 |     False
224 | 
225 |     >>> is_parent_field_of_any("a.b", ["a"])
226 |     False
227 | 
228 |     >>> is_parent_field_of_any("a", [])
229 |     False
230 | 
231 |     """
232 |     return any(is_parent_field(field, other_field) for other_field in other_fields)
233 | 
234 | 
235 | def substring_before_last_occurrence(s: str, sep: str) -> str:
236 |     """Returns the substring before the last occurrence of `sep` in `s`
237 | 
238 |     >>> substring_before_last_occurrence("abc", ".")
239 |     ''
240 |     >>> substring_before_last_occurrence("abc.d", ".")
241 |     'abc'
242 |     >>> substring_before_last_occurrence("abc.d.e", ".")
243 |     'abc.d'
244 | 
245 |     """
246 |     index = s.rfind(sep)
247 |     if index == -1:
248 |         return ""
249 |     else:
250 |         return s[:index]
251 | 
252 | 
253 | def substring_after_last_occurrence(s: str, sep: str) -> str:
254 |     """Returns the substring after the last occurrence of `sep` in `s`
255 | 
256 |     >>> substring_after_last_occurrence("abc", ".")
257 |     'abc'
258 |     >>> substring_after_last_occurrence("abc.d", ".")
259 |     'd'
260 |     >>> substring_after_last_occurrence("abc.d.e", ".")
261 |     'e'
262 | 
263 |     """
264 |     index = s.rfind(sep)
265 |     if index == -1:
266 |         return s
267 |     else:
268 |         return s[index + 1 :]
--------------------------------------------------------------------------------
/bigquery_frame/fp/README.md:
--------------------------------------------------------------------------------
1 | This package contains utility code used for Functional Programming, such as:
2 | - `bigquery_frame.fp.PrintableFunction`: a wrapper class for lambda functions that makes debugging much
3 |   easier when composing them together.
4 | - `bigquery_frame.fp.compose`: a function that composes two PrintableFunctions together.
5 | - `bigquery_frame.fp.higher_order`: a package that provides (and can generate) several reusable PrintableFunctions.
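For instance, composing two functions keeps both the behavior and a readable
description (a minimal sketch based on the doctests in `fp/package.py`):

```python
from bigquery_frame.fp import PrintableFunction, compose

inc = PrintableFunction(lambda x: x + 1, lambda s: f"{s} + 1")
double = PrintableFunction(lambda x: x * 2, lambda s: f"({s}) * 2")

h = compose(inc, double)  # h(x) = inc(double(x))
print(h)     # lambda x: (x) * 2 + 1
print(h(3))  # 7
```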
6 | 
--------------------------------------------------------------------------------
/bigquery_frame/fp/__init__.py:
--------------------------------------------------------------------------------
1 | from bigquery_frame.fp.package import compose
2 | from bigquery_frame.fp.printable_function import PrintableFunction
3 | from bigquery_frame.utils import _ref
4 | 
5 | _ref(PrintableFunction)
6 | _ref(compose)
--------------------------------------------------------------------------------
/bigquery_frame/fp/higher_order.py:
--------------------------------------------------------------------------------
 1 | from typing import Any, Callable, Optional
 2 | 
 3 | from bigquery_frame import Column, DataFrame, fp
 4 | from bigquery_frame import functions as f
 5 | from bigquery_frame.fp import PrintableFunction
 6 | from bigquery_frame.utils import quote
 7 | from bigquery_frame.utils import str_to_col as _str_to_col
 8 | 
 9 | 
10 | def alias(name: str) -> PrintableFunction:
11 |     """Return a PrintableFunction version of the `bigquery_frame.Column.alias` method"""
12 |     return PrintableFunction(lambda s: s.alias(name), lambda s: str(s) + f".alias({name!r})")
13 | 
14 | 
15 | identity = PrintableFunction(lambda s: s, lambda s: str(s))
16 | struct = PrintableFunction(lambda x: f.struct(x), lambda x: f"f.struct({x})")
17 | str_to_col = PrintableFunction(lambda x: _str_to_col(x), lambda s: str(s))
18 | 
19 | 
20 | def struct_get(key: str) -> PrintableFunction:
21 |     """Return a PrintableFunction that gets a struct's subfield, unless the struct is None,
22 |     in which case it returns a Column expression for the field itself.
23 | 
24 |     When the input is a DataFrame rather than a Column, field names containing `.` or `!`
25 |     are quoted before being accessed.
26 | 
27 |     Examples:
28 |         >>> struct_get("c")
29 |         lambda x: x['c']
30 |         >>> struct_get("c").alias(None)
31 |         "f.col('c')"
32 |     """
33 | 
34 |     def _safe_struct_get(s: Optional[Column], field: str) -> Column:
35 |         if s is None:
36 |             return f.col(field)
37 |         else:
38 |             if ("." in field or "!" in field) and isinstance(s, DataFrame):
39 |                 return s[quote(field)]
40 |             else:
41 |                 return s[field]
42 | 
43 |     def _safe_struct_get_alias(s: Optional[str], field: str) -> str:
44 |         if s is None:
45 |             return f"f.col({field!r})"
46 |         else:
47 |             return f"{s}[{field!r}]"
48 | 
49 |     return PrintableFunction(lambda s: _safe_struct_get(s, key), lambda s: _safe_struct_get_alias(s, key))
50 | 
51 | 
52 | def recursive_struct_get(keys: list[str]) -> PrintableFunction:
53 |     """Return a PrintableFunction that recursively applies get to a nested structure.
54 | 
55 |     Examples:
56 |         >>> recursive_struct_get([])
57 |         lambda x: x
58 |         >>> recursive_struct_get(["a", "b", "c"])
59 |         lambda x: x['a']['b']['c']
60 |         >>> recursive_struct_get(["a", "b", "c"]).alias(None)
61 |         "f.col('a')['b']['c']"
62 |     """
63 |     if len(keys) == 0:
64 |         return identity
65 |     else:
66 |         return fp.compose(recursive_struct_get(keys[1:]), struct_get(keys[0]))
67 | 
68 | 
69 | def transform(transformation: PrintableFunction) -> PrintableFunction:
70 |     """Return a PrintableFunction version of the `bigquery_frame.functions.transform` method,
71 |     which applies the given transformation to any array column.
72 |     """
73 |     return PrintableFunction(
74 |         lambda x: f.transform(x, transformation.func),
75 |         lambda x: f"f.transform({x}, lambda x: {transformation.alias('x')})",
76 |     )
77 | 
78 | 
79 | def _partial_box_right(func: Callable, args: Any) -> Callable:
80 |     """Given a function and an array of arguments, return a new function that takes an argument, adds it to the
81 |     array, and passes it to the original function.
82 |     """
83 |     if isinstance(args, str):
84 |         args = [args]
85 |     return lambda a: func([*args, a])
86 | 
87 | 
88 | def boxed_transform(transformation: PrintableFunction, parents: list[str]) -> PrintableFunction:
89 |     """Return a boxed PrintableFunction version of the `bigquery_frame.functions.transform` method,
90 |     which applies the given transformation to the array column located at `parents`, passing along the stack of enclosing values.
91 |     """
92 |     return PrintableFunction(
93 |         lambda x: f.transform(recursive_struct_get(parents)(x[-1]), _partial_box_right(transformation.func, x)),
94 |         lambda x: f"f.transform({recursive_struct_get(parents).alias(x[-1])}, "
95 |         f"lambda x: {_partial_box_right(transformation.alias, x)('x')})",
96 |     )
--------------------------------------------------------------------------------
/bigquery_frame/fp/package.py:
--------------------------------------------------------------------------------
 1 | from typing import Any, Callable, cast
 2 | 
 3 | from bigquery_frame.fp.printable_function import PrintableFunction
 4 | 
 5 | 
 6 | def __compose(f1: PrintableFunction, f2: PrintableFunction) -> PrintableFunction:
 7 |     """Composes together two PrintableFunctions.
 8 | 
 9 |     For instance, if `h = compose(g, f)`, then for every `x`, `h(x) = g(f(x))`.
10 | 
11 |     Args:
12 |         f1: A PrintableFunction
13 |         f2: A PrintableFunction
14 | 
15 |     Returns:
16 |         The composition of f1 with f2.
17 | 
18 |     Examples:
19 |         >>> f = PrintableFunction(lambda x: x+1, lambda s: f'{s} + 1')
20 |         >>> g = PrintableFunction(lambda x: x.cast("Double"), lambda s: f'({s}).cast("Double")')
21 |         >>> __compose(g, f)
22 |         lambda x: (x + 1).cast("Double")
23 |         >>> __compose(f, g)
24 |         lambda x: (x).cast("Double") + 1
25 | 
26 |         >>> h = PrintableFunction(lambda x: x*2, "h")
27 |         >>> __compose(h, f)
28 |         lambda x: h(x + 1)
29 |         >>> __compose(f, h)
30 |         h + 1
31 |         >>> __compose(h, h)
32 |         h(h)
33 |     """
34 | 
35 |     def f1f2(s: Any) -> Any:
36 |         return f1.func(f2.func(s))
37 | 
38 |     if callable(f1.alias) and callable(f2.alias):
39 |         c1 = cast(Callable, f1.alias)
40 |         c2 = cast(Callable, f2.alias)
41 |         return PrintableFunction(f1f2, lambda s: c1(c2(s)))
42 |     elif callable(f1.alias) and not callable(f2.alias):
43 |         c1 = cast(Callable, f1.alias)
44 |         a2 = str(f2.alias)
45 |         return PrintableFunction(f1f2, c1(a2))
46 |     elif not callable(f1.alias) and callable(f2.alias):
47 |         a1 = str(f1.alias)
48 |         c2 = cast(Callable, f2.alias)
49 |         return PrintableFunction(f1f2, lambda s: f"{a1}({c2(s)})")
50 |     else:
51 |         a1 = str(f1.alias)
52 |         a2 = str(f2.alias)
53 |         return PrintableFunction(f1f2, f"{a1}({a2})")
54 | 
55 | 
56 | def compose(f1: PrintableFunction, f2: PrintableFunction, *f3: PrintableFunction) -> PrintableFunction:
57 |     """Composes together two or more PrintableFunctions.
58 |     For instance, if `h = compose(g, f)`, then for every `x`, `h(x) = g(f(x))`.
59 | 
60 |     Args:
61 |         f1: A PrintableFunction
62 |         f2: A PrintableFunction
63 | 
64 |     Returns:
65 |         The composition of f1 with f2.
66 | 
67 |     Examples:
68 |         >>> f = PrintableFunction(lambda x: x+1, lambda s: f'{s} + 1')
69 |         >>> g = PrintableFunction(lambda x: x.cast("Double"), lambda s: f'({s}).cast("Double")')
70 |         >>> h = PrintableFunction(lambda x: x*2, lambda x: f"{x}*2")
71 |         >>> compose(f, g, h)
72 |         lambda x: (x*2).cast("Double") + 1
73 | 
74 |     """
75 |     res = __compose(f1, f2)
76 |     for f in f3:
77 |         res = __compose(res, f)
78 |     return res
79 | 
--------------------------------------------------------------------------------
/bigquery_frame/fp/printable_function.py:
--------------------------------------------------------------------------------
 1 | import inspect
 2 | from typing import Any, Callable, cast
 3 | 
 4 | 
 5 | class PrintableFunction:
 6 |     """Wrapper for anonymous functions, with a short description making them much more human-friendly when printed.
 7 | 
 8 |     Very useful when debugging, useless otherwise.
 9 | 
10 |     Args:
11 |         func: A function that takes a Column and returns a Column.
12 |         alias: A string, or a function that takes a string and returns a string.
13 | 
14 |     Examples:
15 |         >>> print(PrintableFunction(lambda s: s["c"], "my_function"))
16 |         my_function
17 | 
18 |         >>> print(PrintableFunction(lambda s: s["c"], lambda s: f'{s}["c"]'))
19 |         lambda x: x["c"]
20 | 
21 |         >>> func = PrintableFunction(lambda s: s.cast("Double"), lambda s: f'{s}.cast("Double")')
22 |         >>> print(func.alias("s"))
23 |         s.cast("Double")
24 | 
25 |     Composition:
26 | 
27 |         >>> f1 = PrintableFunction(lambda s: s.cast("Double"), lambda s: f'{s}.cast("Double")')
28 |         >>> f2 = PrintableFunction(lambda s: s * s, lambda s: f'{f"({s} * {s})"}')
29 |         >>> f2_then_f1 = PrintableFunction(lambda s: f1(f2(s)), lambda s: f1.alias(f2.alias(s)))
30 |         >>> print(f2_then_f1)
31 |         lambda x: (x * x).cast("Double")
32 |     """
33 | 
34 |     def __init__(self, func: Callable[[Any], Any], alias: Callable[[Any], str]) -> None:
35 |         self.func: Callable[[Any], Any] = func
36 |         self.alias: Callable[[Any], str] = alias
37 | 
38 |     def __repr__(self) -> str:
39 |         if callable(self.alias):
40 |             return f"lambda x: {cast(Callable, self.alias)('x')}"
41 |         else:
42 |             return cast(str, self.alias)
43 | 
44 |     def __call__(self, *args: Any, **kwargs: Any) -> Any:
45 |         return self.func(*args, **kwargs)
46 | 
47 |     def boxed(self) -> "PrintableFunction":
48 |         """Return a boxed version of this function."""
49 |         return PrintableFunction(box(self.func), box(self.alias))
50 | 
51 | 
52 | def arity(function: Callable) -> int:
53 |     """Return the arity (number of parameters) of the given function"""
54 |     sig = inspect.signature(function)
55 |     arity = len(sig.parameters)
56 |     return arity
57 | 
58 | 
59 | def box(func: Callable) -> Callable:
60 |     """Transform a constant or a function that takes any number `n` of arguments into a function that takes
61 |     a single argument of type array and passes the `n` right-most elements of that array to the original function.
62 | 63 | Examples: 64 | >>> func = lambda a, b, c: f"{a}.{b}.{c}" 65 | >>> boxed_func = box(func) 66 | >>> boxed_func(["1", "2", "3"]) 67 | '1.2.3' 68 | >>> boxed_func(["1", "2", "3", "4", "5"]) 69 | '3.4.5' 70 | 71 | >>> func_no_arg = lambda: "a" 72 | >>> boxed_func_no_arg = box(func_no_arg) 73 | >>> boxed_func_no_arg(["1", "2"]) 74 | 'a' 75 | 76 | >>> constant = "a" 77 | >>> boxed_constant = box(constant) 78 | >>> boxed_constant(["1", "2", "3", "4", "5"]) 79 | 'a' 80 | 81 | """ 82 | if not callable(func): 83 | return lambda x: func 84 | n = arity(func) 85 | if n > 0: 86 | return lambda x: func(*x[-n:]) 87 | else: 88 | return lambda x: func() 89 | -------------------------------------------------------------------------------- /bigquery_frame/graph.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame.graph_impl.connected_components import connected_components 2 | 3 | connected_components = connected_components 4 | -------------------------------------------------------------------------------- /bigquery_frame/graph_impl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/bigquery_frame/graph_impl/__init__.py -------------------------------------------------------------------------------- /bigquery_frame/graph_impl/connected_components.py: -------------------------------------------------------------------------------- 1 | from google.cloud.bigquery import SchemaField 2 | 3 | from bigquery_frame import DataFrame 4 | from bigquery_frame import functions as f 5 | from bigquery_frame.utils import assert_true 6 | 7 | _generic_star_sql = """ 8 | WITH edges AS ( 9 | SELECT l_node, r_node FROM {input} 10 | UNION ALL 11 | SELECT r_node, l_node FROM {input} 12 | UNION ALL 13 | SELECT l_node, l_node FROM {input} 14 | UNION ALL 15 | SELECT r_node, r_node FROM {input} 16 | ), 17 | neighborhood AS ( 18 | SELECT 19 | l_node, 20 | array_agg(distinct r_node) AS neighbors_and_self 21 | FROM edges 22 | WHERE l_node IS NOT NULL and r_node IS NOT NULL 23 | GROUP BY l_node 24 | ) 25 | SELECT 26 | l_node, 27 | (SELECT MIN(neighbor) FROM unnest(neighbors_and_self) AS neighbor) AS min_neighbors_and_self, 28 | ARRAY(SELECT * FROM unnest(neighbors_and_self) AS neighbor WHERE neighbor > l_node) AS larger_neighbors, 29 | ARRAY(SELECT * FROM unnest(neighbors_and_self) AS neighbor WHERE neighbor <= l_node) AS smaller_neighbors_and_self 30 | FROM neighborhood 31 | """ 32 | 33 | _large_star_sql = """ 34 | SELECT 35 | min_neighbors_and_self AS l_node, 36 | r_node 37 | FROM large_star 38 | JOIN unnest(larger_neighbors) AS r_node 39 | /* This adds back self-loop (N <=> N) relations */ 40 | UNION ALL 41 | SELECT 42 | l_node AS l_node, 43 | l_node AS r_node 44 | FROM large_star 45 | """ 46 | 47 | _small_star_sql = """ 48 | SELECT 49 | min_neighbors_and_self AS l_node, 50 | r_node 51 | FROM small_star 52 | JOIN UNNEST(smaller_neighbors_and_self) AS r_node 53 | /* This makes sure that self-loops (N <=> N) relations are only added once */ 54 | WHERE min_neighbors_and_self <> r_node OR l_node = min_neighbors_and_self 55 | """ 56 | 57 | 58 | def _large_star(df: DataFrame): 59 | df.createOrReplaceTempView("large_star_input") 60 | df.bigquery.sql(_generic_star_sql.format(input="large_star_input")).createOrReplaceTempView("large_star") 61 | return df.bigquery.sql(_large_star_sql) 62 | 63 | 64 | def _small_star(df: DataFrame): 65 | 
df.createOrReplaceTempView("small_star_input") 66 | df.bigquery.sql(_generic_star_sql.format(input="small_star_input")).createOrReplaceTempView("small_star") 67 | return df.bigquery.sql(_small_star_sql) 68 | 69 | 70 | def _star_loop(df: DataFrame): 71 | working_df = df 72 | changes = 1 73 | while changes > 0: 74 | new_df = _small_star(_large_star(working_df)).persist() 75 | working_df.createOrReplaceTempView("old_df") 76 | new_df.createOrReplaceTempView("new_df") 77 | changes = df.bigquery.sql( 78 | """ 79 | SELECT 1 80 | FROM new_df 81 | WHERE NOT EXISTS ( 82 | SELECT 1 FROM old_df WHERE new_df.l_node = old_df.l_node AND new_df.r_node = old_df.r_node 83 | )""", 84 | ).count() 85 | working_df = new_df 86 | return working_df.select("l_node", "r_node") 87 | 88 | 89 | def connected_components( 90 | df: DataFrame, 91 | node_name: str = "node_id", 92 | connected_component_col_name: str = "connected_component_id", 93 | ): 94 | """Compute the connected components of a non-directed graph. 95 | 96 | Given a DataFrame with two columns of the same type STRING or INTEGER representing the edges of a graph, 97 | this computes a new DataFrame containing two columns of the same type named using `node_name` and 98 | `connected_component_col_name`. 99 | 100 | This is an implementation of the Alternating Algorithm (large-star, small-star) described in the 2014 paper 101 | "Connected Components in MapReduce and Beyond" 102 | written by {rkiveris, silviol, mirrokni, rvibhor, sergeiv} @google.com 103 | 104 | PERFORMANCE AND COST CONSIDERATIONS 105 | ----------------------------------- 106 | This algorithm has been proved to converge in O(log(n)²) and is conjectured to converge in O(log(n)), where n 107 | is the number of nodes in the graph. It was the most performant known distributed connected component algorithm 108 | last time I checked (in 2017). 109 | 110 | This implementation persists temporary results at each iteration loop: for the BigQuery pricing, you should 111 | be expecting it to cost the equivalent of 15 to 30 scans on your input table. Since the input table has only 112 | two columns, this should be reasonable, and we recommend using INTEGER columns rather than STRING when possible. 113 | 114 | If your graph contains nodes with a very high number of neighbors, the algorithm may crash. It is recommended 115 | to apply a pre-filtering on your nodes and remove nodes with a pathologically high cardinality. 116 | You should also monitor actively the number of nodes filtered this way and their cardinality, as this could help 117 | you detect a data quality deterioration in your input graph. 118 | If the input graph contains duplicate edges, they will be automatically removed by the algorithm. 119 | 120 | If you want to have isolated nodes (nodes that have no neighbors) in the resulting graph, there is two possible 121 | ways to achieve this: 122 | A. Add self-loops edges to all your nodes in your input graph (it also works if you add edges between all the graph 123 | nodes and a fictitious node with id NULL) 124 | B. Only add edges between distinct nodes to your input, and perform a join between your input graph and the 125 | algorithm's output to find all the nodes that have disappeared. These will be the isolated nodes. 126 | Method B. requires a little more work but it should also be cheaper. 
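A minimal sketch of method B (`nodes_df` and `edges_df` are illustrative names, not part of this module's API):

    cc_df = connected_components(edges_df)  # edges_df links distinct nodes only
    nodes_df.createOrReplaceTempView("nodes")  # nodes_df lists every node id
    cc_df.createOrReplaceTempView("cc")
    # Nodes absent from the output had no neighbors: they are the isolated ones.
    isolated_df = nodes_df.bigquery.sql(
        "SELECT node_id FROM nodes WHERE node_id NOT IN (SELECT node_id FROM cc)"
    )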
127 | 128 | Example: 129 | >>> df = __get_test_df() 130 | >>> df.show() 131 | +--------+--------+ 132 | | l_node | r_node | 133 | +--------+--------+ 134 | | 1 | 8 | 135 | | 8 | 9 | 136 | | 5 | 8 | 137 | | 7 | 8 | 138 | | 3 | 7 | 139 | | 2 | 3 | 140 | | 4 | 6 | 141 | +--------+--------+ 142 | >>> connected_components(df, connected_component_col_name="cc_id").sort("node_id", "cc_id").show() 143 | +---------+-------+ 144 | | node_id | cc_id | 145 | +---------+-------+ 146 | | 1 | 1 | 147 | | 2 | 1 | 148 | | 3 | 1 | 149 | | 4 | 4 | 150 | | 5 | 1 | 151 | | 6 | 4 | 152 | | 7 | 1 | 153 | | 8 | 1 | 154 | | 9 | 1 | 155 | +---------+-------+ 156 | 157 | :param df: A DataFrame with two columns representing the edges of a graph 158 | :param node_name: Name of the column representing the node in the output DataFrame (default: "node_id") 159 | :param connected_component_col_name: Name of the column representing the connected component to which each node 160 | belongs in the output DataFrame (default: "connected_component_id") 161 | :return: A DataFrame with two columns named after `node_name` and `connected_component_col_name` 162 | """ 163 | assert_true(len(df.columns) == 2, "Input DataFrame must have two columns") 164 | l_field: SchemaField 165 | r_field: SchemaField 166 | [l_field, r_field] = df.schema 167 | assert_true( 168 | l_field.field_type == r_field.field_type, 169 | "The two columns of the input DataFrame must have the same type", 170 | ) 171 | assert_true( 172 | l_field.field_type in ["STRING", "INTEGER"], 173 | "The two columns of the input DataFrame must be of type STRING or INTEGER", 174 | ) 175 | [l_col, r_col] = df.columns 176 | df = df.select(f.col(l_col).alias("l_node"), f.col(r_col).alias("r_node")) 177 | res = _star_loop(df) 178 | return res.select(f.col("r_node").alias(node_name), f.col("l_node").alias(connected_component_col_name)) 179 | 180 | 181 | def __get_test_df() -> DataFrame: 182 | from bigquery_frame import BigQueryBuilder 183 | 184 | bq = BigQueryBuilder() 185 | df = bq.sql( 186 | """ 187 | SELECT * 188 | FROM UNNEST([ 189 | STRUCT(1 as l_node, 8 as r_node), 190 | STRUCT(8 as l_node, 9 as r_node), 191 | STRUCT(5 as l_node, 8 as r_node), 192 | STRUCT(7 as l_node, 8 as r_node), 193 | STRUCT(3 as l_node, 7 as r_node), 194 | STRUCT(2 as l_node, 3 as r_node), 195 | STRUCT(4 as l_node, 6 as r_node) 196 | ]) 197 | """, 198 | ) 199 | return df 200 | -------------------------------------------------------------------------------- /bigquery_frame/has_bigquery_client.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import sys 3 | import traceback 4 | from dataclasses import dataclass 5 | from typing import Callable, Optional, TypeVar, cast 6 | 7 | from google.api_core.exceptions import BadRequest, InternalServerError 8 | from google.cloud.bigquery import ConnectionProperty, QueryJob, QueryJobConfig, SchemaField 9 | from google.cloud.bigquery.client import Client 10 | from google.cloud.bigquery.table import RowIterator 11 | 12 | from bigquery_frame.units import bytes_to_human_readable 13 | from bigquery_frame.utils import number_lines, strip_margin 14 | 15 | DEFAULT_MAX_TRY_COUNT = 3 16 | 17 | ReturnType = TypeVar("ReturnType") 18 | 19 | 20 | @dataclass() 21 | class BigQueryStats: 22 | estimated_bytes_processed: int = 0 23 | """Estimate of the number of bytes to be processed, computed before the query runs in order to dimension the query's provisioning.""" 24 | 25 | total_bytes_processed: int = 0 26 | """Actual number of bytes processed by the query.""" 27 | 28 | total_bytes_billed: int = 0 29 | """Actual number of bytes billed for the query.
30 | This may exceed the number of bytes processed for small queries because BigQuery charges a minimum of 10 MiB 31 | per input table to account for the query overhead. However, queries with a "LIMIT 0" are completely free. 32 | For more details, please see the official 33 | BigQuery documentation on pricing. 34 | """ 35 | 36 | def add_job_stats(self, job: QueryJob): 37 | if job.estimated_bytes_processed is not None: 38 | self.estimated_bytes_processed += job.estimated_bytes_processed 39 | if job.total_bytes_processed is not None: 40 | self.total_bytes_processed += job.total_bytes_processed 41 | if job.total_bytes_billed is not None: 42 | self.total_bytes_billed += job.total_bytes_billed 43 | 44 | def human_readable_estimated_bytes_processed(self): 45 | return f"Estimated bytes processed : {bytes_to_human_readable(self.estimated_bytes_processed)}" 46 | 47 | def human_readable_total_bytes_billed(self): 48 | return f"Total bytes billed : {bytes_to_human_readable(self.total_bytes_billed)}" 49 | 50 | def human_readable_total_bytes_processed(self): 51 | return f"Total bytes processed : {bytes_to_human_readable(self.total_bytes_processed)}" 52 | 53 | def human_readable(self): 54 | return strip_margin( 55 | f""" 56 | |{self.human_readable_estimated_bytes_processed()} 57 | |{self.human_readable_total_bytes_processed()} 58 | |{self.human_readable_total_bytes_billed()} 59 | |""", 60 | ) 61 | 62 | 63 | class HasBigQueryClient: 64 | """Wrapper class for the BigQuery client 65 | 66 | This isolates all the logic of direct interaction with the BigQuery client, 67 | which makes the code's security easier to audit (although nothing can be really private in Python). 68 | """ 69 | 70 | def __init__(self, client: Client, use_session: bool = True, max_try_count: int = DEFAULT_MAX_TRY_COUNT): 71 | """Wrapper class for the BigQuery client 72 | 73 | :param client: A :class:`google.cloud.bigquery.client.Client` 74 | :param use_session: If set to true, all queries will be executed in the same session.
75 | This is necessary for reusing temporary tables across multiple queries 76 | """ 77 | self.max_try_count = max_try_count 78 | self.__use_session = use_session 79 | self.__client = client 80 | self.__session_id: Optional[str] = None 81 | self.__stats: BigQueryStats = BigQueryStats() 82 | 83 | def _get_session_id_after_query(self, job): 84 | if self.__use_session and self.__session_id is None and job.session_info is not None: 85 | self.__session_id = job.session_info.session_id 86 | 87 | def _set_session_id_before_query(self, job_config): 88 | if self.__use_session: 89 | if self.__session_id is None: 90 | job_config.create_session = True 91 | else: 92 | job_config.connection_properties = [ConnectionProperty("session_id", self.__session_id)] 93 | 94 | def _execute_job( 95 | self, 96 | query: str, 97 | action: Callable[[QueryJob], ReturnType], 98 | dry_run: bool, 99 | use_query_cache: bool, 100 | try_count: int = 1, 101 | ) -> ReturnType: 102 | job_config = QueryJobConfig(use_query_cache=use_query_cache, dry_run=dry_run) 103 | 104 | try: 105 | self._set_session_id_before_query(job_config) 106 | job = self.__client.query(query=query, job_config=job_config) 107 | self._get_session_id_after_query(job) 108 | res = action(job) 109 | except BadRequest as e: 110 | if len(query) < 1024 * 1000: 111 | e.message += "\nQuery:\n" + number_lines(query, 1) 112 | else: 113 | e.message += "\n(The query is too large to be displayed)\n" 114 | raise e 115 | except InternalServerError as e: 116 | try_count += 1 117 | if try_count <= self.max_try_count: 118 | traceback.print_exc(file=sys.stderr) 119 | else: 120 | raise e 121 | else: 122 | self.__stats.add_job_stats(job) 123 | return res 124 | print(f"Retrying query (Try n°{try_count}/{self.max_try_count})", file=sys.stderr) 125 | return self._execute_job(query, action, dry_run=dry_run, use_query_cache=use_query_cache, try_count=try_count) 126 | 127 | def _get_query_schema(self, query: str) -> list[SchemaField]: 128 | def action(job: QueryJob) -> list[SchemaField]: 129 | return cast(list[SchemaField], job.schema) 130 | 131 | return self._execute_job(query, action, dry_run=True, use_query_cache=False) 132 | 133 | def _execute_query(self, query: str, use_query_cache=True) -> RowIterator: 134 | def action(job: QueryJob) -> RowIterator: 135 | return job.result() 136 | 137 | return self._execute_job(query, action, dry_run=False, use_query_cache=use_query_cache) 138 | 139 | def close(self): 140 | self.__client.close() 141 | 142 | @property 143 | def stats(self): 144 | return copy.copy(self.__stats) 145 | -------------------------------------------------------------------------------- /bigquery_frame/nested.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame.nested_impl.fields import fields 2 | from bigquery_frame.nested_impl.print_schema import print_schema 3 | from bigquery_frame.nested_impl.schema_string import schema_string 4 | from bigquery_frame.nested_impl.select_impl import select 5 | from bigquery_frame.nested_impl.unnest_all_fields import unnest_all_fields 6 | from bigquery_frame.nested_impl.unnest_field import unnest_field 7 | from bigquery_frame.nested_impl.with_fields import with_fields 8 | from bigquery_frame.utils import _ref 9 | 10 | _ref(fields) 11 | _ref(print_schema) 12 | _ref(schema_string) 13 | _ref(select) 14 | _ref(unnest_all_fields) 15 | _ref(unnest_field) 16 | _ref(with_fields) 17 | -------------------------------------------------------------------------------- 
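# A minimal usage sketch of the `nested` facade above (not part of the repository's code;
# inputs and outputs follow the doctests of the implementation modules below):
from bigquery_frame import BigQueryBuilder, nested

bq = BigQueryBuilder()
df = bq.sql("SELECT 1 as id, [STRUCT(2 as a)] as s")
nested.print_schema(df)  # prints the flattened schema: 'id' and 's!.a'
assert nested.fields(df) == ["id", "s!.a"]
df2 = nested.with_fields(df, {"s!.b": lambda s: s["a"] + 1})  # adds field s!.b with value 3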
/bigquery_frame/nested_impl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/bigquery_frame/nested_impl/__init__.py -------------------------------------------------------------------------------- /bigquery_frame/nested_impl/fields.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import DataFrame 2 | from bigquery_frame.data_type_utils import flatten_schema 3 | 4 | 5 | def fields(df: DataFrame, keep_non_leaf_fields: bool = False) -> list[str]: 6 | """Return the name of all the fields (including nested sub-fields) in the given DataFrame. 7 | 8 | - Structs are flattened with a `.` after their name. 9 | - Arrays are flattened with a `!` character after their name. 10 | 11 | Args: 12 | df: A DataFrame 13 | keep_non_leaf_fields: If set, the fields of type array or struct are also included in the result 14 | 15 | Returns: 16 | The list of all flattened field names in this DataFrame 17 | 18 | Examples: 19 | >>> from bigquery_frame import BigQueryBuilder 20 | >>> bq = BigQueryBuilder() 21 | >>> df = bq.sql('''SELECT 22 | ... 1 as id, 23 | ... [STRUCT(2 as a, [STRUCT(3 as c, 4 as d)] as b, [5, 6] as e)] as s1, 24 | ... STRUCT(7 as f) as s2, 25 | ... ''') 26 | >>> df.printSchema() 27 | root 28 | |-- id: INTEGER (NULLABLE) 29 | |-- s1: RECORD (REPEATED) 30 | | |-- a: INTEGER (NULLABLE) 31 | | |-- b: RECORD (REPEATED) 32 | | | |-- c: INTEGER (NULLABLE) 33 | | | |-- d: INTEGER (NULLABLE) 34 | | |-- e: INTEGER (REPEATED) 35 | |-- s2: RECORD (NULLABLE) 36 | | |-- f: INTEGER (NULLABLE) 37 | 38 | >>> for field in fields(df): print(field) 39 | id 40 | s1!.a 41 | s1!.b!.c 42 | s1!.b!.d 43 | s1!.e! 44 | s2.f 45 | >>> for field in fields(df, keep_non_leaf_fields = True): print(field) 46 | id 47 | s1 48 | s1! 49 | s1!.a 50 | s1!.b 51 | s1!.b! 52 | s1!.b!.c 53 | s1!.b!.d 54 | s1!.e 55 | s1!.e! 56 | s2 57 | s2.f 58 | """ 59 | return [field.name for field in flatten_schema(df.schema, explode=True, keep_non_leaf_fields=keep_non_leaf_fields)] 60 | -------------------------------------------------------------------------------- /bigquery_frame/nested_impl/print_schema.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import DataFrame 2 | from bigquery_frame.nested_impl.schema_string import schema_string 3 | 4 | 5 | def print_schema(df: DataFrame) -> None: 6 | """Print the DataFrame's flattened schema to the standard output. 7 | 8 | - Structs are flattened with a `.` after their name. 9 | - Arrays are flattened with a `!` character after their name. 10 | 11 | Args: 12 | df: A DataFrame 13 | 14 | Examples: 15 | >>> from bigquery_frame import BigQueryBuilder 16 | >>> from bigquery_frame import nested 17 | >>> bq = BigQueryBuilder() 18 | >>> df = bq.sql('''SELECT 19 | ... 1 as id, 20 | ... [STRUCT(2 as a, [STRUCT(3 as c, 4 as d)] as b, [5, 6] as e)] as s1, 21 | ... STRUCT(7 as f) as s2 22 | ... 
''') 23 | >>> df.printSchema() 24 | root 25 | |-- id: INTEGER (NULLABLE) 26 | |-- s1: RECORD (REPEATED) 27 | | |-- a: INTEGER (NULLABLE) 28 | | |-- b: RECORD (REPEATED) 29 | | | |-- c: INTEGER (NULLABLE) 30 | | | |-- d: INTEGER (NULLABLE) 31 | | |-- e: INTEGER (REPEATED) 32 | |-- s2: RECORD (NULLABLE) 33 | | |-- f: INTEGER (NULLABLE) 34 | 35 | >>> nested.print_schema(df) 36 | root 37 | |-- id: INTEGER (nullable = true) 38 | |-- s1!.a: INTEGER (nullable = true) 39 | |-- s1!.b!.c: INTEGER (nullable = true) 40 | |-- s1!.b!.d: INTEGER (nullable = true) 41 | |-- s1!.e!: INTEGER (nullable = false) 42 | |-- s2.f: INTEGER (nullable = true) 43 | 44 | """ 45 | print(schema_string(df)) 46 | -------------------------------------------------------------------------------- /bigquery_frame/nested_impl/schema_string.py: -------------------------------------------------------------------------------- 1 | from google.cloud.bigquery import SchemaField 2 | 3 | from bigquery_frame import DataFrame 4 | from bigquery_frame.data_type_utils import flatten_schema 5 | from bigquery_frame.dataframe import is_nullable 6 | 7 | 8 | def _flat_schema_to_tree_string(schema: list[SchemaField]) -> str: 9 | """Generates a string representing a flat schema in tree format""" 10 | 11 | def str_gen_schema_field(schema_field: SchemaField, prefix: str) -> list[str]: 12 | res = [ 13 | f"{prefix}{schema_field.name}: {schema_field.field_type} " 14 | f"(nullable = {str(is_nullable(schema_field)).lower()})", 15 | ] 16 | return res 17 | 18 | def str_gen_schema(schema: list[SchemaField], prefix: str) -> list[str]: 19 | return [string for schema_field in schema for string in str_gen_schema_field(schema_field, prefix)] 20 | 21 | res = ["root", *str_gen_schema(schema, " |-- ")] 22 | 23 | return "\n".join(res) + "\n" 24 | 25 | 26 | def schema_string(df: DataFrame) -> str: 27 | """Write the DataFrame's flattened schema to a string. 28 | 29 | - Structs are flattened with a `.` after their name. 30 | - Arrays are flattened with a `!` character after their name. 31 | 32 | Args: 33 | df: A DataFrame 34 | 35 | Returns: 36 | a string representing the flattened schema 37 | 38 | Examples: 39 | >>> from bigquery_frame import BigQueryBuilder 40 | >>> from bigquery_frame import nested 41 | >>> bq = BigQueryBuilder() 42 | >>> df = bq.sql('''SELECT 43 | ... 1 as id, 44 | ... [STRUCT(2 as a, [STRUCT(3 as c, 4 as d)] as b, [5, 6] as e)] as s1, 45 | ... STRUCT(7 as f) as s2 46 | ... 
''') 47 | >>> df.printSchema() 48 | root 49 | |-- id: INTEGER (NULLABLE) 50 | |-- s1: RECORD (REPEATED) 51 | | |-- a: INTEGER (NULLABLE) 52 | | |-- b: RECORD (REPEATED) 53 | | | |-- c: INTEGER (NULLABLE) 54 | | | |-- d: INTEGER (NULLABLE) 55 | | |-- e: INTEGER (REPEATED) 56 | |-- s2: RECORD (NULLABLE) 57 | | |-- f: INTEGER (NULLABLE) 58 | 59 | >>> print(nested.schema_string(df)) 60 | root 61 | |-- id: INTEGER (nullable = true) 62 | |-- s1!.a: INTEGER (nullable = true) 63 | |-- s1!.b!.c: INTEGER (nullable = true) 64 | |-- s1!.b!.d: INTEGER (nullable = true) 65 | |-- s1!.e!: INTEGER (nullable = false) 66 | |-- s2.f: INTEGER (nullable = true) 67 | 68 | """ 69 | flat_schema = flatten_schema(df.schema, explode=True) 70 | return _flat_schema_to_tree_string(flat_schema) 71 | -------------------------------------------------------------------------------- /bigquery_frame/nested_impl/select_impl.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | 3 | from bigquery_frame import DataFrame 4 | from bigquery_frame.nested_impl.package import ColumnTransformation, resolve_nested_fields 5 | 6 | 7 | # Workaround: This file is temporarily called "select_impl.py" instead of "select.py" to work around a bug in PyCharm. 8 | # https://youtrack.jetbrains.com/issue/PY-58068 9 | def select(df: DataFrame, fields: Mapping[str, ColumnTransformation]) -> DataFrame: 10 | """Projects a set of expressions and returns a new [DataFrame][bigquery_frame.DataFrame]. 11 | 12 | This method is similar to the [DataFrame.select][bigquery_frame.DataFrame.select] method, with the extra 13 | capability of working on nested and repeated fields (structs and arrays). 14 | 15 | The syntax for field names works as follows: 16 | 17 | - "." is the separator for struct elements 18 | - "!" must be appended at the end of fields that are repeated (arrays) 19 | 20 | The following types of transformation are allowed: 21 | 22 | - String and column expressions can be used on any non-repeated field, even nested ones. 23 | - When working on repeated fields, transformations must be expressed as higher-order functions 24 | (e.g. lambda expressions). String and column expressions can be used on repeated fields as well, 25 | but their value will be repeated multiple times. 26 | - When working on multiple levels of nested arrays, higher-order functions may take multiple arguments, 27 | corresponding to each level of repetition (See Example 4.). 28 | - `None` can also be used to represent the identity transformation; this is useful to select a field without 29 | changing it and without having to repeat its name. 30 | 31 | Args: 32 | df: A DataFrame 33 | fields: A Dict(field_name, transformation_to_apply) 34 | 35 | Returns: 36 | A new DataFrame where only the specified fields have been selected and the corresponding 37 | transformations were applied to each of them.
38 | 39 | Examples: 40 | *Example 1: non-repeated fields* 41 | 42 | >>> from bigquery_frame import BigQueryBuilder 43 | >>> from bigquery_frame import functions as f 44 | >>> from bigquery_frame import nested 45 | >>> bq = BigQueryBuilder() 46 | >>> df = bq.sql('''SELECT 1 as id, STRUCT(2 as a, 3 as b) as s''') 47 | >>> df.printSchema() 48 | root 49 | |-- id: INTEGER (NULLABLE) 50 | |-- s: RECORD (NULLABLE) 51 | | |-- a: INTEGER (NULLABLE) 52 | | |-- b: INTEGER (NULLABLE) 53 | 54 | >>> df.show(simplify_structs=True) 55 | +----+--------+ 56 | | id | s | 57 | +----+--------+ 58 | | 1 | {2, 3} | 59 | +----+--------+ 60 | 61 | Transformations on non-repeated fields may be expressed as a string representing a column name, 62 | a Column expression or None. 63 | (In this example the column "id" will be dropped because it was not selected) 64 | >>> new_df = df.transform(nested.select, { 65 | ... "s.a": "s.a", # Column name (string) 66 | ... "s.b": None, # None: use to keep a column without having to repeat its name 67 | ... "s.c": f.col("s.a") + f.col("s.b") # Column expression 68 | ... }) 69 | >>> new_df.printSchema() 70 | root 71 | |-- s: RECORD (NULLABLE) 72 | | |-- a: INTEGER (NULLABLE) 73 | | |-- b: INTEGER (NULLABLE) 74 | | |-- c: INTEGER (NULLABLE) 75 | 76 | 77 | >>> new_df.show(simplify_structs=True) 78 | +-----------+ 79 | | s | 80 | +-----------+ 81 | | {2, 3, 5} | 82 | +-----------+ 83 | 84 | *Example 2: repeated fields* 85 | 86 | >>> df = bq.sql('SELECT 1 as id, [STRUCT(1 as a, 2 as b), STRUCT(3 as a, 4 as b)] as s') 87 | >>> nested.print_schema(df) 88 | root 89 | |-- id: INTEGER (nullable = true) 90 | |-- s!.a: INTEGER (nullable = true) 91 | |-- s!.b: INTEGER (nullable = true) 92 | 93 | >>> df.show(simplify_structs=True) 94 | +----+------------------+ 95 | | id | s | 96 | +----+------------------+ 97 | | 1 | [{1, 2}, {3, 4}] | 98 | +----+------------------+ 99 | 100 | Transformations on repeated fields must be expressed as higher-order 101 | functions (lambda expressions or named functions). 102 | The value passed to this function will correspond to the last repeated element. 103 | >>> df.transform(nested.select, { 104 | ... "s!.a": lambda s: s["a"], 105 | ... "s!.b": None, 106 | ... "s!.c": lambda s: s["a"] + s["b"] 107 | ... }).show(simplify_structs=True) 108 | +------------------------+ 109 | | s | 110 | +------------------------+ 111 | | [{1, 2, 3}, {3, 4, 7}] | 112 | +------------------------+ 113 | 114 | String and column expressions can be used on repeated fields as well, 115 | but their value will be repeated multiple times. 116 | >>> df.transform(nested.select, { 117 | ... "id": None, 118 | ... "s!.a": "id", 119 | ... "s!.b": f.lit(2) 120 | ... }).show(simplify_structs=True) 121 | +----+------------------+ 122 | | id | s | 123 | +----+------------------+ 124 | | 1 | [{1, 2}, {1, 2}] | 125 | +----+------------------+ 126 | 127 | *Example 3: field repeated twice* 128 | >>> df = bq.sql(''' 129 | ... SELECT 130 | ... 1 as id, 131 | ... [STRUCT([1, 2, 3] as e)] as s1, 132 | ... [STRUCT([4, 5, 6] as e)] as s2 133 | ... 
''') 134 | >>> nested.print_schema(df) 135 | root 136 | |-- id: INTEGER (nullable = true) 137 | |-- s1!.e!: INTEGER (nullable = false) 138 | |-- s2!.e!: INTEGER (nullable = false) 139 | 140 | >>> df.show(simplify_structs=True) 141 | +----+---------------+---------------+ 142 | | id | s1 | s2 | 143 | +----+---------------+---------------+ 144 | | 1 | [{[1, 2, 3]}] | [{[4, 5, 6]}] | 145 | +----+---------------+---------------+ 146 | 147 | Here, the lambda expression will be applied to the last repeated element `e`. 148 | >>> new_df = df.transform(nested.select, { 149 | ... "s1!.e!": None, 150 | ... "s2!.e!": lambda e : e.cast("FLOAT64") 151 | ... }) 152 | >>> nested.print_schema(new_df) 153 | root 154 | |-- s1!.e!: INTEGER (nullable = false) 155 | |-- s2!.e!: FLOAT (nullable = false) 156 | 157 | >>> new_df.show(simplify_structs=True) 158 | +---------------+---------------------+ 159 | | s1 | s2 | 160 | +---------------+---------------------+ 161 | | [{[1, 2, 3]}] | [{[4.0, 5.0, 6.0]}] | 162 | +---------------+---------------------+ 163 | 164 | *Example 4: Accessing multiple repetition levels* 165 | >>> df = bq.sql(''' 166 | ... SELECT 167 | ... 1 as id, 168 | ... [ 169 | ... STRUCT(2 as average, [1, 2, 3] as values), 170 | ... STRUCT(3 as average, [1, 2, 3, 4, 5] as values) 171 | ... ] as s1 172 | ... ''') 173 | >>> nested.print_schema(df) 174 | root 175 | |-- id: INTEGER (nullable = true) 176 | |-- s1!.average: INTEGER (nullable = true) 177 | |-- s1!.values!: INTEGER (nullable = false) 178 | 179 | >>> df.show(simplify_structs=True) 180 | +----+----------------------------------------+ 181 | | id | s1 | 182 | +----+----------------------------------------+ 183 | | 1 | [{2, [1, 2, 3]}, {3, [1, 2, 3, 4, 5]}] | 184 | +----+----------------------------------------+ 185 | 186 | Here, the transformation applied to "s1!.values!" takes two arguments. 187 | >>> new_df = df.transform(nested.select, { 188 | ... "id": None, 189 | ... "s1!.average": None, 190 | ... "s1!.values!": lambda s1, value : value - s1["average"] 191 | ... }) 192 | >>> new_df.show(simplify_structs=True) 193 | +----+-------------------------------------------+ 194 | | id | s1 | 195 | +----+-------------------------------------------+ 196 | | 1 | [{2, [-1, 0, 1]}, {3, [-2, -1, 0, 1, 2]}] | 197 | +----+-------------------------------------------+ 198 | 199 | Extra arguments can be added to the left for each repetition level, up to the root level. 200 | >>> new_df = df.transform(nested.select, { 201 | ... "id": None, 202 | ... "s1!.average": None, 203 | ... "s1!.values!": lambda root, s1, value : value - s1["average"] + root["id"] 204 | ... 
}) 205 | >>> new_df.show(simplify_structs=True) 206 | +----+-----------------------------------------+ 207 | | id | s1 | 208 | +----+-----------------------------------------+ 209 | | 1 | [{2, [0, 1, 2]}, {3, [-1, 0, 1, 2, 3]}] | 210 | +----+-----------------------------------------+ 211 | 212 | """ 213 | return df.select(*resolve_nested_fields(fields, starting_level=df)) 214 | -------------------------------------------------------------------------------- /bigquery_frame/nested_impl/unnest_all_fields.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from bigquery_frame import DataFrame, nested 4 | from bigquery_frame.field_utils import is_sub_field_or_equal_to_any 5 | from bigquery_frame.nested_impl.package import unnest_fields 6 | 7 | 8 | def unnest_all_fields(df: DataFrame, keep_columns: Optional[list[str]] = None) -> dict[str, DataFrame]: 9 | """Given a DataFrame, return a dict of {granularity: DataFrame} where all arrays have been recursively 10 | unnested (a.k.a. exploded). 11 | This produces one DataFrame for each possible granularity. 12 | 13 | For instance, given a DataFrame with the following flattened schema: 14 | id 15 | s1.a 16 | s2!.b 17 | s2!.c 18 | s2!.s3!.d 19 | s4!.e 20 | s4!.f 21 | 22 | This will produce a dict with four (granularity, DataFrame) entries: 23 | - '': DataFrame[id, s1.a] ('' corresponds to the root granularity) 24 | - 's2': DataFrame[s2!.b, s2!.c] 25 | - 's2!.s3': DataFrame[s2!.s3!.d] 26 | - 's4': DataFrame[s4!.e, s4!.f] 27 | 28 | !!! warning "Limitation: BigQuery does not support dots and exclamation marks in column names" 29 | For this reason, dots are replaced with the string "__STRUCT__" and exclamation marks are replaced 30 | with "__ARRAY__". When manipulating column names, you can use the utils methods from the 31 | module `bigquery_frame.special_characters` to reverse the replacement. 32 | 33 | Args: 34 | df: A DataFrame 35 | keep_columns: Names of columns that should be kept while unnesting 36 | 37 | Returns: 38 | A dict of (granularity, DataFrame) entries 39 | 40 | Examples: 41 | >>> from bigquery_frame import BigQueryBuilder 42 | >>> from bigquery_frame import nested 43 | >>> bq = BigQueryBuilder() 44 | >>> df = bq.sql(''' 45 | ... SELECT 46 | ... 1 as id, 47 | ... STRUCT(2 as a) as s1, 48 | ... [STRUCT(3 as b, 4 as c, [STRUCT(5 as d), STRUCT(6 as d)] as s3)] as s2, 49 | ... [STRUCT(7 as e, 8 as f), STRUCT(9 as e, 10 as f)] as s4 50 | ... ''') 51 | >>> df.show(simplify_structs=True) 52 | +----+-----+----------------------+-------------------+ 53 | | id | s1 | s2 | s4 | 54 | +----+-----+----------------------+-------------------+ 55 | | 1 | {2} | [{3, 4, [{5}, {6}]}] | [{7, 8}, {9, 10}] | 56 | +----+-----+----------------------+-------------------+ 57 | 58 | >>> nested.fields(df) 59 | ['id', 's1.a', 's2!.b', 's2!.c', 's2!.s3!.d', 's4!.e', 's4!.f'] 60 | >>> result_df_list = nested.unnest_all_fields(df, keep_columns=["id"]) 61 | >>> for cols, result_df in result_df_list.items(): 62 | ... print(cols) 63 | ... result_df.show() 64 | 65 | +----+---------------+ 66 | | id | s1__STRUCT__a | 67 | +----+---------------+ 68 | | 1 | 2 | 69 | +----+---------------+ 70 | s2! 71 | +----+------------------------+------------------------+ 72 | | id | s2__ARRAY____STRUCT__b | s2__ARRAY____STRUCT__c | 73 | +----+------------------------+------------------------+ 74 | | 1 | 3 | 4 | 75 | +----+------------------------+------------------------+ 76 | s2!.s3!
77 | +----+---------------------------------------------+ 78 | | id | s2__ARRAY____STRUCT__s3__ARRAY____STRUCT__d | 79 | +----+---------------------------------------------+ 80 | | 1 | 5 | 81 | | 1 | 6 | 82 | +----+---------------------------------------------+ 83 | s4! 84 | +----+------------------------+------------------------+ 85 | | id | s4__ARRAY____STRUCT__e | s4__ARRAY____STRUCT__f | 86 | +----+------------------------+------------------------+ 87 | | 1 | 7 | 8 | 88 | | 1 | 9 | 10 | 89 | +----+------------------------+------------------------+ 90 | """ 91 | if keep_columns is None: 92 | keep_columns = [] 93 | fields_to_unnest = [field for field in nested.fields(df) if not is_sub_field_or_equal_to_any(field, keep_columns)] 94 | return unnest_fields(df, fields_to_unnest, keep_fields=keep_columns) 95 | -------------------------------------------------------------------------------- /bigquery_frame/nested_impl/unnest_field.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from bigquery_frame import DataFrame 4 | from bigquery_frame.nested_impl.package import unnest_fields 5 | 6 | 7 | def unnest_field(df: DataFrame, field_name: str, keep_columns: Optional[list[str]] = None) -> DataFrame: 8 | """Given a DataFrame, return a new DataFrame where the specified column has been recursively 9 | unnested (a.k.a. exploded). 10 | 11 | !!! warning "Limitation: BigQuery does not support dots and exclamation marks in column names" 12 | For this reason, dots are replaced with the string "__STRUCT__" and exclamation marks are replaced 13 | with "__ARRAY__". When manipulating column names, you can use the utils methods from the 14 | module `bigquery_frame.special_characters` to reverse the replacement. 15 | 16 | Args: 17 | df: A DataFrame 18 | field_name: The name of a nested column to unnest 19 | keep_columns: List of column names to keep while unnesting 20 | 21 | Returns: 22 | A new DataFrame 23 | 24 | Examples: 25 | >>> from bigquery_frame import BigQueryBuilder 26 | >>> from bigquery_frame import nested 27 | >>> bq = BigQueryBuilder() 28 | >>> df = bq.sql(''' 29 | ... SELECT 30 | ... 1 as id, 31 | ... [STRUCT([1, 2] as a), STRUCT([3, 4] as a)] as arr 32 | ... ''') 33 | >>> df.show(simplify_structs=True) 34 | +----+----------------------+ 35 | | id | arr | 36 | +----+----------------------+ 37 | | 1 | [{[1, 2]}, {[3, 4]}] | 38 | +----+----------------------+ 39 | 40 | >>> nested.fields(df) 41 | ['id', 'arr!.a!'] 42 | >>> nested.unnest_field(df, 'arr!').show(simplify_structs=True) 43 | +--------------+ 44 | | arr__ARRAY__ | 45 | +--------------+ 46 | | {[1, 2]} | 47 | | {[3, 4]} | 48 | +--------------+ 49 | 50 | >>> nested.unnest_field(df, 'arr!.a!').show(simplify_structs=True) 51 | +----------------------------------+ 52 | | arr__ARRAY____STRUCT__a__ARRAY__ | 53 | +----------------------------------+ 54 | | 1 | 55 | | 2 | 56 | | 3 | 57 | | 4 | 58 | +----------------------------------+ 59 | 60 | >>> nested.unnest_field(df, 'arr!.a!', keep_columns=["id"]).show(simplify_structs=True) 61 | +----+----------------------------------+ 62 | | id | arr__ARRAY____STRUCT__a__ARRAY__ | 63 | +----+----------------------------------+ 64 | | 1 | 1 | 65 | | 1 | 2 | 66 | | 1 | 3 | 67 | | 1 | 4 | 68 | +----+----------------------------------+ 69 | 70 | >>> df = bq.sql(''' 71 | ... SELECT 72 | ... 1 as id, 73 | ... [ 74 | ... STRUCT([STRUCT("a1" as a, "b1" as b), STRUCT("a2" as a, "b1" as b)] as s2), 75 | ...
STRUCT([STRUCT("a3" as a, "b3" as b)] as s2) 76 | ... ] as s1 77 | ... ''') 78 | >>> df.show(simplify_structs=True) 79 | +----+----------------------------------------+ 80 | | id | s1 | 81 | +----+----------------------------------------+ 82 | | 1 | [{[{a1, b1}, {a2, b1}]}, {[{a3, b3}]}] | 83 | +----+----------------------------------------+ 84 | 85 | >>> nested.fields(df) 86 | ['id', 's1!.s2!.a', 's1!.s2!.b'] 87 | >>> nested.unnest_field(df, 's1!.s2!').show(simplify_structs=True) 88 | +----------------------------------+ 89 | | s1__ARRAY____STRUCT__s2__ARRAY__ | 90 | +----------------------------------+ 91 | | {a1, b1} | 92 | | {a2, b1} | 93 | | {a3, b3} | 94 | +----------------------------------+ 95 | """ 96 | if keep_columns is None: 97 | keep_columns = [] 98 | return next(iter(unnest_fields(df, field_name, keep_fields=keep_columns).values())) 99 | -------------------------------------------------------------------------------- /bigquery_frame/nested_impl/with_fields.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | 3 | from bigquery_frame import DataFrame, nested 4 | from bigquery_frame.nested_impl.package import AnyKindOfTransformation, resolve_nested_fields 5 | 6 | 7 | def with_fields(df: DataFrame, fields: Mapping[str, AnyKindOfTransformation]) -> DataFrame: 8 | """Return a new [DataFrame][pyspark.sql.DataFrame] by adding or replacing (when they already exist) columns. 9 | 10 | This method is similar to the [DataFrame.withColumn][bigquery_frame.DataFrame.withColumn] method, with the extra 11 | capability of working on nested and repeated fields (structs and arrays). 12 | 13 | The syntax for field names works as follows: 14 | 15 | - "." is the separator for struct elements 16 | - "!" must be appended at the end of fields that are repeated (arrays) 17 | 18 | The following types of transformation are allowed: 19 | 20 | - String and column expressions can be used on any non-repeated field, even nested ones. 21 | - When working on repeated fields, transformations must be expressed as higher order functions 22 | (e.g. lambda expressions). String and column expressions can be used on repeated fields as well, 23 | but their value will be repeated multiple times. 24 | - When working on multiple levels of nested arrays, higher order functions may take multiple arguments, 25 | corresponding to each level of repetition (See Example 5.). 26 | - `None` can also be used to represent the identity transformation, this is useful to select a field without 27 | changing and without having to repeat its name. 28 | 29 | Args: 30 | df: A DataFrame 31 | fields: A Dict(field_name, transformation_to_apply) 32 | 33 | Returns: 34 | A new DataFrame with the same fields as the input DataFrame, where the specified transformations have been 35 | applied to the corresponding fields. If a field name did not exist in the input DataFrame, 36 | it will be added to the output DataFrame. If it did exist, the original value will be replaced with the new one. 
37 | 38 | Examples: 39 | *Example 1: non-repeated fields* 40 | >>> from bigquery_frame import BigQueryBuilder 41 | >>> from bigquery_frame import functions as f 42 | >>> from bigquery_frame import nested 43 | >>> bq = BigQueryBuilder() 44 | >>> df = bq.sql('''SELECT 1 as id, STRUCT(2 as a, 3 as b) as s''') 45 | >>> nested.print_schema(df) 46 | root 47 | |-- id: INTEGER (nullable = true) 48 | |-- s.a: INTEGER (nullable = true) 49 | |-- s.b: INTEGER (nullable = true) 50 | 51 | >>> df.show(simplify_structs=True) 52 | +----+--------+ 53 | | id | s | 54 | +----+--------+ 55 | | 1 | {2, 3} | 56 | +----+--------+ 57 | 58 | Transformations on non-repeated fields may be expressed as a string representing a column name 59 | or a Column expression. 60 | >>> new_df = nested.with_fields(df, { 61 | ... "s.id": "id", # column name (string) 62 | ... "s.c": f.col("s.a") + f.col("s.b") # Column expression 63 | ... }) 64 | >>> new_df.printSchema() 65 | root 66 | |-- id: INTEGER (NULLABLE) 67 | |-- s: RECORD (NULLABLE) 68 | | |-- a: INTEGER (NULLABLE) 69 | | |-- b: INTEGER (NULLABLE) 70 | | |-- id: INTEGER (NULLABLE) 71 | | |-- c: INTEGER (NULLABLE) 72 | 73 | >>> new_df.show(simplify_structs=True) 74 | +----+--------------+ 75 | | id | s | 76 | +----+--------------+ 77 | | 1 | {2, 3, 1, 5} | 78 | +----+--------------+ 79 | 80 | *Example 2: repeated fields* 81 | >>> df = bq.sql(''' 82 | ... SELECT 83 | ... 1 as id, 84 | ... [STRUCT(1 as a, STRUCT(2 as c) as b), STRUCT(3 as a, STRUCT(4 as c) as b)] as s 85 | ... ''') 86 | >>> nested.print_schema(df) 87 | root 88 | |-- id: INTEGER (nullable = true) 89 | |-- s!.a: INTEGER (nullable = true) 90 | |-- s!.b.c: INTEGER (nullable = true) 91 | 92 | >>> df.show(simplify_structs=True) 93 | +----+----------------------+ 94 | | id | s | 95 | +----+----------------------+ 96 | | 1 | [{1, {2}}, {3, {4}}] | 97 | +----+----------------------+ 98 | 99 | Transformations on repeated fields must be expressed as 100 | higher-order functions (lambda expressions or named functions). 101 | The value passed to this function will correspond to the last repeated element. 102 | >>> new_df = df.transform(nested.with_fields, { 103 | ... "s!.b.d": lambda s: s["a"] + s["b"]["c"]} 104 | ... ) 105 | >>> nested.print_schema(new_df) 106 | root 107 | |-- id: INTEGER (nullable = true) 108 | |-- s!.a: INTEGER (nullable = true) 109 | |-- s!.b.c: INTEGER (nullable = true) 110 | |-- s!.b.d: INTEGER (nullable = true) 111 | 112 | >>> new_df.show(simplify_structs=True) 113 | +----+----------------------------+ 114 | | id | s | 115 | +----+----------------------------+ 116 | | 1 | [{1, {2, 3}}, {3, {4, 7}}] | 117 | +----+----------------------------+ 118 | 119 | String and column expressions can be used on repeated fields as well, 120 | but their value will be repeated multiple times. 121 | >>> df.transform(nested.with_fields, { 122 | ... "id": None, 123 | ... "s!.a": "id", 124 | ... "s!.b.c": f.lit(2) 125 | ... 
}).show(simplify_structs=True) 126 | +----+----------------------+ 127 | | id | s | 128 | +----+----------------------+ 129 | | 1 | [{1, {2}}, {1, {2}}] | 130 | +----+----------------------+ 131 | 132 | *Example 3: field repeated twice* 133 | >>> df = bq.sql('SELECT 1 as id, [STRUCT([1, 2, 3] as e)] as s') 134 | >>> nested.print_schema(df) 135 | root 136 | |-- id: INTEGER (nullable = true) 137 | |-- s!.e!: INTEGER (nullable = false) 138 | 139 | >>> df.show(simplify_structs=True) 140 | +----+---------------+ 141 | | id | s | 142 | +----+---------------+ 143 | | 1 | [{[1, 2, 3]}] | 144 | +----+---------------+ 145 | 146 | Here, the lambda expression will be applied to the last repeated element `e`. 147 | >>> df.transform(nested.with_fields, {"s!.e!": lambda e : e.cast("FLOAT64")}).show(simplify_structs=True) 148 | +----+---------------------+ 149 | | id | s | 150 | +----+---------------------+ 151 | | 1 | [{[1.0, 2.0, 3.0]}] | 152 | +----+---------------------+ 153 | 154 | *Example 4: Accessing multiple repetition levels* 155 | >>> df = bq.sql(''' 156 | ... SELECT 157 | ... 1 as id, 158 | ... [ 159 | ... STRUCT(2 as average, [1, 2, 3] as values), 160 | ... STRUCT(3 as average, [1, 2, 3, 4, 5] as values) 161 | ... ] as s1 162 | ... ''') 163 | >>> nested.print_schema(df) 164 | root 165 | |-- id: INTEGER (nullable = true) 166 | |-- s1!.average: INTEGER (nullable = true) 167 | |-- s1!.values!: INTEGER (nullable = false) 168 | 169 | >>> df.show(simplify_structs=True) 170 | +----+----------------------------------------+ 171 | | id | s1 | 172 | +----+----------------------------------------+ 173 | | 1 | [{2, [1, 2, 3]}, {3, [1, 2, 3, 4, 5]}] | 174 | +----+----------------------------------------+ 175 | 176 | Here, the transformation applied to "s1!.values!" takes two arguments. 177 | >>> new_df = df.transform(nested.with_fields, { 178 | ... "s1!.values!": lambda s1, value : value - s1["average"] 179 | ... }) 180 | >>> new_df.show(simplify_structs=True) 181 | +----+-------------------------------------------+ 182 | | id | s1 | 183 | +----+-------------------------------------------+ 184 | | 1 | [{2, [-1, 0, 1]}, {3, [-2, -1, 0, 1, 2]}] | 185 | +----+-------------------------------------------+ 186 | 187 | Extra arguments can be added to the left for each repetition level, up to the root level. 188 | >>> new_df = df.transform(nested.with_fields, { 189 | ... "s1!.values!": lambda root, s1, value : value - s1["average"] + root["id"] 190 | ... }) 191 | >>> new_df.show(simplify_structs=True) 192 | +----+-----------------------------------------+ 193 | | id | s1 | 194 | +----+-----------------------------------------+ 195 | | 1 | [{2, [0, 1, 2]}, {3, [-1, 0, 1, 2, 3]}] | 196 | +----+-----------------------------------------+ 197 | """ 198 | default_columns = {field: None for field in nested.fields(df)} 199 | fields = {**default_columns, **fields} 200 | return df.select(*resolve_nested_fields(fields, starting_level=df)) 201 | -------------------------------------------------------------------------------- /bigquery_frame/printing.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from google.cloud.bigquery.table import RowIterator 4 | from tabulate import tabulate 5 | 6 | 7 | def _struct_to_string_without_field_names(s: Any) -> str: 8 | """Transform an object into a string, but do not display the field names of dicts. 
9 | 10 | Args: 11 | s: The object to transform into a string 12 | 13 | Returns: 14 | A string 15 | 16 | Examples: 17 | >>> _struct_to_string_without_field_names({"a": 1, "b": 2}) 18 | '{1, 2}' 19 | >>> _struct_to_string_without_field_names({"a": [{"s": {"b": 1, "c": 2}}]}) 20 | '{[{{1, 2}}]}' 21 | """ 22 | if isinstance(s, list): 23 | return "[" + ", ".join(_struct_to_string_without_field_names(item) for item in s) + "]" 24 | elif isinstance(s, dict): 25 | return "{" + ", ".join(_struct_to_string_without_field_names(item) for item in s.values()) + "}" 26 | else: 27 | return str(s) 28 | 29 | 30 | def tabulate_results(it: RowIterator, format_args: dict = None, limit=None, simplify_structs=False) -> str: 31 | if format_args is None: 32 | format_args = { 33 | "tablefmt": "pretty", 34 | "missingval": "null", 35 | "stralign": "right", 36 | } 37 | headers = {field.name: field.name for field in it.schema} 38 | rows = list(it) 39 | nb_rows = len(rows) 40 | rows = rows[0:limit] 41 | if simplify_structs: 42 | rows = [[_struct_to_string_without_field_names(field) for field in row] for row in rows] 43 | res = tabulate(rows, headers=headers, **format_args) 44 | if limit is not None and nb_rows > limit: # limit=None means "show everything": no footer in that case 45 | plural = "s" if limit > 1 else "" 46 | res += f"\nonly showing top {limit} row{plural}" 47 | return res 48 | -------------------------------------------------------------------------------- /bigquery_frame/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/bigquery_frame/py.typed -------------------------------------------------------------------------------- /bigquery_frame/special_characters.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import Column 2 | from bigquery_frame import functions as f 3 | from bigquery_frame.conf import ( 4 | REPETITION_MARKER, 5 | REPETITION_MARKER_REPLACEMENT, 6 | STRUCT_SEPARATOR, 7 | STRUCT_SEPARATOR_REPLACEMENT, 8 | ) 9 | 10 | _replacement_mapping = { 11 | STRUCT_SEPARATOR: STRUCT_SEPARATOR_REPLACEMENT, 12 | REPETITION_MARKER: REPETITION_MARKER_REPLACEMENT, 13 | } 14 | _replacements = str.maketrans(_replacement_mapping) 15 | 16 | 17 | def _replace_special_characters_except_last_granularity(col_name: str) -> str: 18 | """Replace special characters except for the ones at the last granularity 19 | 20 | >>> _replace_special_characters_except_last_granularity("a.b.c") 21 | 'a.b.c' 22 | >>> _replace_special_characters_except_last_granularity("a!.b!.c.d") 23 | 'a__ARRAY____STRUCT__b__ARRAY__.c.d' 24 | 25 | """ 26 | index = col_name.rfind(REPETITION_MARKER) 27 | if index == -1: 28 | return col_name 29 | else: 30 | return _replace_special_characters(col_name[: index + 1]) + col_name[index + 1 :] 31 | 32 | 33 | def _replace_special_characters(col_name: str) -> str: 34 | """Replace special characters 35 | 36 | >>> _replace_special_characters("a.b!.c") 37 | 'a__STRUCT__b__ARRAY____STRUCT__c' 38 | """ 39 | return col_name.translate(_replacements) 40 | 41 | 42 | def _restore_special_characters(col_name: str) -> str: 43 | """Restore special characters 44 | 45 | >>> _restore_special_characters("a__STRUCT__b__ARRAY____STRUCT__c") 46 | 'a.b!.c' 47 | """ 48 | result = col_name 49 | for value, replacement in _replacement_mapping.items(): 50 | result = result.replace(replacement, value) 51 | return result 52 | 53 | 54 | def _restore_special_characters_from_col(col: Column) -> Column: 55 | return
f.regexp_replace( 56 | f.regexp_replace(col, STRUCT_SEPARATOR_REPLACEMENT, STRUCT_SEPARATOR), 57 | REPETITION_MARKER_REPLACEMENT, 58 | REPETITION_MARKER, 59 | ) 60 | -------------------------------------------------------------------------------- /bigquery_frame/temp_names.py: -------------------------------------------------------------------------------- 1 | DEFAULT_ALIAS_NAME = "_default_alias_{num}" 2 | DEFAULT_TEMP_COLUMN_NAME = "_default_temp_column_{num}" 3 | DEFAULT_TEMP_TABLE_NAME = "_default_temp_table_{num}" 4 | 5 | 6 | _alias_count = 0 7 | _temp_column_count = 0 8 | _temp_table_count = 0 9 | 10 | 11 | def _get_alias() -> str: 12 | global _alias_count 13 | _alias_count += 1 14 | return "{" + DEFAULT_ALIAS_NAME.format(num=_alias_count) + "}" 15 | 16 | 17 | def _get_temp_column_name() -> str: 18 | global _temp_column_count 19 | _temp_column_count += 1 20 | return DEFAULT_TEMP_COLUMN_NAME.format(num=_temp_column_count) 21 | 22 | 23 | def _get_temp_table_name() -> str: 24 | global _temp_table_count 25 | _temp_table_count += 1 26 | return DEFAULT_TEMP_TABLE_NAME.format(num=_temp_table_count) 27 | -------------------------------------------------------------------------------- /bigquery_frame/transformations.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame.transformations_impl.analyze import analyze 2 | from bigquery_frame.transformations_impl.flatten import flatten 3 | from bigquery_frame.transformations_impl.harmonize_dataframes import harmonize_dataframes 4 | from bigquery_frame.transformations_impl.normalize_arrays import normalize_arrays 5 | from bigquery_frame.transformations_impl.pivot_unpivot import pivot, unpivot 6 | from bigquery_frame.transformations_impl.sort_all_arrays import sort_all_arrays 7 | from bigquery_frame.transformations_impl.sort_columns import sort_columns 8 | from bigquery_frame.transformations_impl.transform_all_fields import transform_all_fields 9 | from bigquery_frame.transformations_impl.union_dataframes import union_dataframes 10 | 11 | analyze = analyze 12 | flatten = flatten 13 | harmonize_dataframes = harmonize_dataframes 14 | normalize_arrays = normalize_arrays 15 | pivot = pivot 16 | sort_all_arrays = sort_all_arrays 17 | sort_columns = sort_columns 18 | transform_all_fields = transform_all_fields 19 | union_dataframes = union_dataframes 20 | unpivot = unpivot 21 | -------------------------------------------------------------------------------- /bigquery_frame/transformations_impl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/bigquery_frame/transformations_impl/__init__.py -------------------------------------------------------------------------------- /bigquery_frame/transformations_impl/analyze_aggs.py: -------------------------------------------------------------------------------- 1 | from google.cloud.bigquery import SchemaField 2 | 3 | from bigquery_frame import Column 4 | from bigquery_frame import functions as f 5 | 6 | 7 | def _to_string(col: Column, field_type: str): 8 | if field_type == "BYTES": 9 | return f.to_base64(col) 10 | else: 11 | return col.cast("STRING") 12 | 13 | 14 | def column_number(col: str, schema_field: SchemaField, col_num: int) -> Column: # NOSONAR 15 | return f.lit(col_num).alias("column_number") 16 | 17 | 18 | def column_name(col: str, schema_field: SchemaField, col_num: int) -> Column: # NOSONAR 19 | return 
f.lit(schema_field.name).alias("column_name") 20 | 21 | 22 | def column_type(col: str, schema_field: SchemaField, col_num: int) -> Column: # NOSONAR 23 | return f.lit(schema_field.field_type).alias("column_type") 24 | 25 | 26 | def count(col: str, schema_field: SchemaField, col_num: int) -> Column: # NOSONAR 27 | return f.count(f.lit(1)).alias("count") 28 | 29 | 30 | def count_distinct(col: str, schema_field: SchemaField, col_num: int) -> Column: # NOSONAR 31 | return f.count_distinct(col).alias("count_distinct") 32 | 33 | 34 | def count_null(col: str, schema_field: SchemaField, col_num: int) -> Column: # NOSONAR 35 | return (f.count(f.lit(1)) - f.count(col)).alias("count_null") 36 | 37 | 38 | def min(col: str, schema_field: SchemaField, col_num: int) -> Column: # NOSONAR 39 | return _to_string(f.min(col), schema_field.field_type).alias("min") 40 | 41 | 42 | def max(col: str, schema_field: SchemaField, col_num: int) -> Column: # NOSONAR 43 | return _to_string(f.max(col), schema_field.field_type).alias("max") 44 | 45 | 46 | def approx_top_100(col: str, schema_field: SchemaField, col_num: int) -> Column: # NOSONAR 47 | column = f.coalesce(_to_string(f.col(col), schema_field.field_type), f.lit("NULL")) 48 | return f.expr(f"APPROX_TOP_COUNT({column.expr}, 100)").alias("approx_top_100") 49 | -------------------------------------------------------------------------------- /bigquery_frame/transformations_impl/flatten.py: -------------------------------------------------------------------------------- 1 | from google.cloud.bigquery import SchemaField 2 | 3 | from bigquery_frame import DataFrame 4 | from bigquery_frame import functions as f 5 | from bigquery_frame.dataframe import is_repeated, is_struct 6 | 7 | 8 | def flatten(df: DataFrame, struct_separator: str = "_") -> DataFrame: 9 | """Flattens all the struct columns of a DataFrame. 10 | Nested field names will be joined together using the specified separator. 11 | 12 | Examples: 13 | >>> from bigquery_frame import BigQueryBuilder 14 | >>> bq = BigQueryBuilder() 15 | >>> df = bq.sql('''SELECT 1 as id, STRUCT(1 as a, STRUCT(1 as c, 1 as d) as b) as s''') 16 | >>> df.printSchema() 17 | root 18 | |-- id: INTEGER (NULLABLE) 19 | |-- s: RECORD (NULLABLE) 20 | | |-- a: INTEGER (NULLABLE) 21 | | |-- b: RECORD (NULLABLE) 22 | | | |-- c: INTEGER (NULLABLE) 23 | | | |-- d: INTEGER (NULLABLE) 24 | 25 | >>> flatten(df).printSchema() 26 | root 27 | |-- id: INTEGER (NULLABLE) 28 | |-- s_a: INTEGER (NULLABLE) 29 | |-- s_b_c: INTEGER (NULLABLE) 30 | |-- s_b_d: INTEGER (NULLABLE) 31 | 32 | >>> flatten(df, "__").printSchema() 33 | root 34 | |-- id: INTEGER (NULLABLE) 35 | |-- s__a: INTEGER (NULLABLE) 36 | |-- s__b__c: INTEGER (NULLABLE) 37 | |-- s__b__d: INTEGER (NULLABLE) 38 | 39 | 40 | :param df: a DataFrame 41 | :param struct_separator: It might be useful to change the separator when some DataFrame's column names already 42 | contain dots 43 | :return: a flattened DataFrame 44 | """ 45 | # The idea is to recursively write a "SELECT s.b.c as s_b_c" for each nested column.
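    # For the docstring example above, this builds:
    #   [col("id"), col("s.a").alias("s_a"), col("s.b.c").alias("s_b_c"), col("s.b.d").alias("s_b_d")]
    # Repeated records are not expanded: the is_repeated(field) check below stops the recursion.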
46 | cols = [] 47 | 48 | def expand_struct(struct: list[SchemaField], col_stack: list[str]): 49 | for field in struct: 50 | if is_struct(field) and not is_repeated(field): 51 | expand_struct(field.fields, col_stack + [field.name]) 52 | else: 53 | col_expr = ".".join(col_stack + [field.name]) 54 | col_alias = struct_separator.join(col_stack + [field.name]) 55 | column = f.col(col_expr).alias(col_alias) 56 | cols.append(column) 57 | 58 | expand_struct(df.schema, col_stack=[]) 59 | return df.select(cols) 60 | -------------------------------------------------------------------------------- /bigquery_frame/transformations_impl/harmonize_dataframes.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from bigquery_frame import Column, DataFrame 4 | from bigquery_frame import functions as f 5 | from bigquery_frame.conf import REPETITION_MARKER, STRUCT_SEPARATOR 6 | from bigquery_frame.data_type_utils import flatten_schema, get_common_columns 7 | 8 | 9 | def harmonize_dataframes( 10 | left_df: DataFrame, 11 | right_df: DataFrame, 12 | common_columns: Optional[dict[str, Optional[str]]] = None, 13 | keep_missing_columns: bool = False, 14 | ) -> tuple[DataFrame, DataFrame]: 15 | """Given two DataFrames, returns two new corresponding DataFrames with the same schemas by applying the following 16 | changes: 17 | 18 | - Only common columns are kept 19 | - Columns are re-ordered to have the same ordering in both DataFrames 20 | - When matching columns have different types, their type is widened to the narrowest common type. 21 | This transformation is applied recursively on nested columns, including those inside 22 | repeated records (a.k.a. ARRAY<STRUCT<...>>). 23 | 24 | Args: 25 | left_df: A DataFrame 26 | right_df: A DataFrame 27 | common_columns: A dict of (column name, type). 28 | Column names must appear in both DataFrames, and each column will be cast into the corresponding type. 29 | keep_missing_columns: If set to true, the root columns of each DataFrame that do not exist in the other 30 | one are kept. 31 | 32 | Returns: 33 | Two new DataFrames with the same schema 34 | 35 | Examples: 36 | >>> from bigquery_frame import BigQueryBuilder 37 | >>> bq = BigQueryBuilder() 38 | >>> df1 = bq.sql('SELECT 1 as id, STRUCT(1 as a, [STRUCT(2 as c, 3 as d)] as b, [4, 5] as f) as s') 39 | >>> df2 = bq.sql( 40 | ... 'SELECT 1 as id, STRUCT(2 as a, [STRUCT(3.0 as c, "4" as d, 5 as e)] as b, [5.0, 6.0] as f) as s' 41 | ... ) 42 | >>> df1.union(df2).show() # doctest: +ELLIPSIS 43 | Traceback (most recent call last): 44 | ... 45 | google.api_core.exceptions.BadRequest: 400 Column 2 in UNION ALL has incompatible types: ...
46 | >>> df1, df2 = harmonize_dataframes(df1, df2) 47 | >>> df1.union(df2).show() 48 | +----+--------------------------------------------------------+ 49 | | id | s | 50 | +----+--------------------------------------------------------+ 51 | | 1 | {'a': 1, 'b': [{'c': 2.0, 'd': '3'}], 'f': [4.0, 5.0]} | 52 | | 1 | {'a': 2, 'b': [{'c': 3.0, 'd': '4'}], 'f': [5.0, 6.0]} | 53 | +----+--------------------------------------------------------+ 54 | """ 55 | left_schema_flat = flatten_schema(left_df.schema, explode=True) 56 | right_schema_flat = flatten_schema(right_df.schema, explode=True) 57 | if common_columns is None: 58 | common_columns = get_common_columns(left_schema_flat, right_schema_flat) 59 | 60 | left_only_columns = {} 61 | right_only_columns = {} 62 | if keep_missing_columns: 63 | left_cols = [field.name for field in left_schema_flat] 64 | right_cols = [field.name for field in right_schema_flat] 65 | left_cols_set = set(left_cols) 66 | right_cols_set = set(right_cols) 67 | left_only_columns = {col: None for col in left_cols if col not in right_cols_set} 68 | right_only_columns = {col: None for col in right_cols if col not in left_cols_set} 69 | 70 | def build_col(col_name: str, col_type: Optional[str]) -> Union[Column, Callable[[Column], Column]]: 71 | if col_name[-1] == REPETITION_MARKER: 72 | if col_type is not None: 73 | return lambda col: col.cast(col_type) 74 | else: 75 | return lambda col: col 76 | else: 77 | col = f.col(col_name.split(REPETITION_MARKER + STRUCT_SEPARATOR)[-1]) 78 | if col_type is not None: 79 | return col.cast(col_type) 80 | else: 81 | return col 82 | 83 | left_columns = {**common_columns, **left_only_columns} 84 | right_columns = {**common_columns, **right_only_columns} 85 | left_columns_dict = {col_name: build_col(col_name, col_type) for (col_name, col_type) in left_columns.items()} 86 | right_columns_dict = {col_name: build_col(col_name, col_type) for (col_name, col_type) in right_columns.items()} 87 | 88 | return left_df.select_nested_columns(left_columns_dict), right_df.select_nested_columns(right_columns_dict) 89 | -------------------------------------------------------------------------------- /bigquery_frame/transformations_impl/normalize_arrays.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from bigquery_frame import DataFrame 4 | from bigquery_frame.transformations_impl.sort_all_arrays import sort_all_arrays 5 | 6 | 7 | def normalize_arrays(df: DataFrame) -> DataFrame: 8 | """Given a DataFrame, sort all columns of type ARRAY (even nested ones) in a canonical order, 9 | making them comparable. 10 | 11 | !!! warning 12 | This method is deprecated since version 0.5.0 and will be removed in version 0.6.0. 13 | Please use transformations.sort_all_arrays instead.
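Migrating is a one-line change: replace `normalize_arrays(df)` with `transformations.sort_all_arrays(df)`. Both return the same result, since this function simply emits a DeprecationWarning and delegates to `sort_all_arrays` (see the function body below).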
14 | 15 | >>> from bigquery_frame import BigQueryBuilder 16 | >>> bq = BigQueryBuilder() 17 | >>> df = bq.sql('SELECT [3, 2, 1] as a') 18 | 19 | >>> df.show() 20 | +-----------+ 21 | | a | 22 | +-----------+ 23 | | [3, 2, 1] | 24 | +-----------+ 25 | >>> normalize_arrays(df).show() 26 | +-----------+ 27 | | a | 28 | +-----------+ 29 | | [1, 2, 3] | 30 | +-----------+ 31 | 32 | >>> df = bq.sql('SELECT [STRUCT(2 as a, 1 as b), STRUCT(1 as a, 2 as b), STRUCT(1 as a, 1 as b)] as s') 33 | >>> df.show() 34 | +--------------------------------------------------------+ 35 | | s | 36 | +--------------------------------------------------------+ 37 | | [{'a': 2, 'b': 1}, {'a': 1, 'b': 2}, {'a': 1, 'b': 1}] | 38 | +--------------------------------------------------------+ 39 | >>> normalize_arrays(df).show() 40 | +--------------------------------------------------------+ 41 | | s | 42 | +--------------------------------------------------------+ 43 | | [{'a': 1, 'b': 1}, {'a': 1, 'b': 2}, {'a': 2, 'b': 1}] | 44 | +--------------------------------------------------------+ 45 | 46 | >>> df = bq.sql('''SELECT [ 47 | ... STRUCT([STRUCT(2 as a, 2 as b), STRUCT(2 as a, 1 as b)] as l2), 48 | ... STRUCT([STRUCT(1 as a, 2 as b), STRUCT(1 as a, 1 as b)] as l2) 49 | ... ] as l1 50 | ... ''') 51 | >>> df.show() 52 | +----------------------------------------------------------------------------------------------+ 53 | | l1 | 54 | +----------------------------------------------------------------------------------------------+ 55 | | [{'l2': [{'a': 2, 'b': 2}, {'a': 2, 'b': 1}]}, {'l2': [{'a': 1, 'b': 2}, {'a': 1, 'b': 1}]}] | 56 | +----------------------------------------------------------------------------------------------+ 57 | >>> normalize_arrays(df).show() 58 | +----------------------------------------------------------------------------------------------+ 59 | | l1 | 60 | +----------------------------------------------------------------------------------------------+ 61 | | [{'l2': [{'a': 1, 'b': 1}, {'a': 1, 'b': 2}]}, {'l2': [{'a': 2, 'b': 1}, {'a': 2, 'b': 2}]}] | 62 | +----------------------------------------------------------------------------------------------+ 63 | 64 | >>> df = bq.sql('''SELECT [ 65 | ... STRUCT(STRUCT(2 as a, 2 as b) as s), 66 | ... STRUCT(STRUCT(1 as a, 2 as b) as s) 67 | ... ] as l1 68 | ... ''') 69 | >>> df.show() 70 | +----------------------------------------------------+ 71 | | l1 | 72 | +----------------------------------------------------+ 73 | | [{'s': {'a': 2, 'b': 2}}, {'s': {'a': 1, 'b': 2}}] | 74 | +----------------------------------------------------+ 75 | >>> normalize_arrays(df).show() 76 | +----------------------------------------------------+ 77 | | l1 | 78 | +----------------------------------------------------+ 79 | | [{'s': {'a': 1, 'b': 2}}, {'s': {'a': 2, 'b': 2}}] | 80 | +----------------------------------------------------+ 81 | 82 | :return: 83 | """ 84 | warning_message = ( 85 | "The method bigquery_frame.transformations.normalize_arrays is deprecated since version 0.5.0 " 86 | "and will be removed in version 0.6.0. " 87 | "Please use transformations.sort_all_arrays instead." 
88 | ) 89 | warnings.warn(warning_message, category=DeprecationWarning) 90 | return sort_all_arrays(df) 91 | -------------------------------------------------------------------------------- /bigquery_frame/transformations_impl/sort_all_arrays.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from google.cloud.bigquery import SchemaField 4 | 5 | from bigquery_frame import Column, DataFrame 6 | from bigquery_frame import functions as f 7 | from bigquery_frame.dataframe import is_repeated, is_struct 8 | from bigquery_frame.transformations_impl.transform_all_fields import transform_all_fields 9 | 10 | 11 | def sort_all_arrays(df: DataFrame) -> DataFrame: 12 | """Given a DataFrame, sort all fields of type `ARRAY` in a canonical order, making them comparable. 13 | This also applies to nested fields, even those inside other arrays. 14 | 15 | Args: 16 | df: A DataFrame 17 | 18 | Returns: 19 | A new DataFrame where all arrays have been sorted. 20 | 21 | Examples: 22 | *Example 1:* with a simple `ARRAY<INT>` 23 | 24 | >>> from bigquery_frame import BigQueryBuilder 25 | >>> bq = BigQueryBuilder() 26 | >>> df = bq.sql('SELECT 1 as id, [3, 2, 1] as a') 27 | >>> df.show() 28 | +----+-----------+ 29 | | id | a | 30 | +----+-----------+ 31 | | 1 | [3, 2, 1] | 32 | +----+-----------+ 33 | 34 | >>> sort_all_arrays(df).show() 35 | +----+-----------+ 36 | | id | a | 37 | +----+-----------+ 38 | | 1 | [1, 2, 3] | 39 | +----+-----------+ 40 | 41 | *Example 2:* with an `ARRAY<STRUCT<...>>` 42 | 43 | >>> df = bq.sql('SELECT [STRUCT(2 as a, 1 as b), STRUCT(1 as a, 2 as b), STRUCT(1 as a, 1 as b)] as s') 44 | >>> df.show(simplify_structs=True) 45 | +--------------------------+ 46 | | s | 47 | +--------------------------+ 48 | | [{2, 1}, {1, 2}, {1, 1}] | 49 | +--------------------------+ 50 | 51 | >>> df.transform(sort_all_arrays).show(simplify_structs=True) 52 | +--------------------------+ 53 | | s | 54 | +--------------------------+ 55 | | [{1, 1}, {1, 2}, {2, 1}] | 56 | +--------------------------+ 57 | 58 | *Example 3:* with an `ARRAY<STRUCT<STRUCT<...>>>` 59 | 60 | >>> df = bq.sql('''SELECT [ 61 | ... STRUCT(STRUCT(2 as a, 2 as b) as s), 62 | ... STRUCT(STRUCT(1 as a, 2 as b) as s) 63 | ... ] as l1 64 | ... ''') 65 | >>> df.show(simplify_structs=True) 66 | +----------------------+ 67 | | l1 | 68 | +----------------------+ 69 | | [{{2, 2}}, {{1, 2}}] | 70 | +----------------------+ 71 | 72 | >>> df.transform(sort_all_arrays).show(simplify_structs=True) 73 | +----------------------+ 74 | | l1 | 75 | +----------------------+ 76 | | [{{1, 2}}, {{2, 2}}] | 77 | +----------------------+ 78 | 79 | *Example 4:* with an `ARRAY<STRUCT<ARRAY<...>>>` 80 | 81 | As this example shows, the innermost arrays are sorted before the outermost arrays. 82 | 83 | >>> df = bq.sql('''SELECT [ 84 | ... STRUCT([STRUCT([4, 1] as b), STRUCT([3, 2] as b)] as a), 85 | ... STRUCT([STRUCT([2, 2] as b), STRUCT([2, 1] as b)] as a) 86 | ... ] as l1 87 | ... 
''') 88 | >>> df.show(simplify_structs=True) 89 | +--------------------------------------------------+ 90 | | l1 | 91 | +--------------------------------------------------+ 92 | | [{[{[4, 1]}, {[3, 2]}]}, {[{[2, 2]}, {[2, 1]}]}] | 93 | +--------------------------------------------------+ 94 | 95 | >>> df.transform(sort_all_arrays).show(simplify_structs=True) 96 | +--------------------------------------------------+ 97 | | l1 | 98 | +--------------------------------------------------+ 99 | | [{[{[1, 2]}, {[2, 2]}]}, {[{[1, 4]}, {[2, 3]}]}] | 100 | +--------------------------------------------------+ 101 | """ 102 | 103 | def sort_array(col: Column, field: SchemaField) -> Optional[Column]: 104 | def json_if_not_sortable(col: Column, _field: SchemaField) -> Column: 105 | if is_struct(_field) or is_repeated(_field): 106 | return f.expr(f"TO_JSON_STRING({col.expr})") 107 | else: 108 | return col 109 | 110 | if is_repeated(field): 111 | if is_struct(field): 112 | return f.sort_array( 113 | col, 114 | lambda c: [json_if_not_sortable(c[_field.name], _field) for _field in field.fields], 115 | ) 116 | else: 117 | return f.sort_array(col) 118 | else: 119 | return None 120 | 121 | return transform_all_fields(df, sort_array) 122 | -------------------------------------------------------------------------------- /bigquery_frame/transformations_impl/sort_columns.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import DataFrame 2 | 3 | 4 | def sort_columns(df: DataFrame) -> DataFrame: 5 | """Returns a new DataFrame where the order of columns has been sorted 6 | 7 | Examples: 8 | >>> from bigquery_frame import BigQueryBuilder 9 | >>> bq = BigQueryBuilder() 10 | >>> df = bq.sql('''SELECT 1 as b, 1 as a, 1 as c''') 11 | >>> df.printSchema() 12 | root 13 | |-- b: INTEGER (NULLABLE) 14 | |-- a: INTEGER (NULLABLE) 15 | |-- c: INTEGER (NULLABLE) 16 | 17 | >>> sort_columns(df).printSchema() 18 | root 19 | |-- a: INTEGER (NULLABLE) 20 | |-- b: INTEGER (NULLABLE) 21 | |-- c: INTEGER (NULLABLE) 22 | 23 | """ 24 | return df.select(*sorted(df.columns)) 25 | -------------------------------------------------------------------------------- /bigquery_frame/transformations_impl/transform_all_fields.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | 3 | from google.cloud.bigquery import SchemaField 4 | 5 | from bigquery_frame import Column, DataFrame 6 | from bigquery_frame.nested_impl.package import build_transformation_from_schema 7 | 8 | 9 | def transform_all_fields( 10 | df: DataFrame, 11 | transformation: Callable[[Column, SchemaField], Optional[Column]], 12 | ) -> DataFrame: 13 | """Apply a transformation to all nested fields of a DataFrame. 14 | 15 | !!! info 16 | This method is compatible with any schema. It recursively applies on structs and arrays 17 | and is compatible with field names containing special characters. 18 | 19 | !!! warning "BigQuery specificity" 20 | When applying a transformation on all columns of a given type, make sure to check 21 | that `schema_field.mode != "REPEATED"` otherwise the transformation will be applied on arrays 22 | containing this type too. 23 | *Explanation:* in BigQuery, columns of type `ARRAY<T>` are represented with a 24 | [SchemaField](https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.schema.SchemaField) 25 | with attributes `mode="REPEATED"` and `field_type="T"`.
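For instance, a transformation targeting integer fields should test both attributes, e.g. `if schema_field.field_type == "INTEGER" and schema_field.mode != "REPEATED": ...` (an illustrative sketch: without the mode check, the transformation would also be applied to `ARRAY<INTEGER>` columns; the doctest below uses this exact pattern).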
26 | 27 | Args: 28 | df: A DataFrame 29 | transformation: Transformation to apply to all fields of the DataFrame. The transformation must take as input 30 | a Column expression and the SchemaField of the corresponding field. 31 | 32 | Returns: 33 | A new DataFrame 34 | 35 | Examples: 36 | >>> from bigquery_frame import BigQueryBuilder 37 | >>> from bigquery_frame import nested 38 | >>> bq = BigQueryBuilder() 39 | 40 | >>> df = bq.sql('''SELECT 41 | ... "John" as name, 42 | ... [1, 2, 3] as s1, 43 | ... [STRUCT(1 as a), STRUCT(2 as a)] as s2, 44 | ... [STRUCT([1, 2] as a), STRUCT([3, 4] as a)] as s3, 45 | ... [ 46 | ... STRUCT([STRUCT(STRUCT(1 as c) as b), STRUCT(STRUCT(2 as c) as b)] as a), 47 | ... STRUCT([STRUCT(STRUCT(3 as c) as b), STRUCT(STRUCT(4 as c) as b)] as a) 48 | ... ] as s4 49 | ... ''') 50 | >>> nested.print_schema(df) 51 | root 52 | |-- name: STRING (nullable = true) 53 | |-- s1!: INTEGER (nullable = false) 54 | |-- s2!.a: INTEGER (nullable = true) 55 | |-- s3!.a!: INTEGER (nullable = false) 56 | |-- s4!.a!.b.c: INTEGER (nullable = true) 57 | 58 | >>> df.show(simplify_structs=True) 59 | +------+-----------+------------+----------------------+--------------------------------------+ 60 | | name | s1 | s2 | s3 | s4 | 61 | +------+-----------+------------+----------------------+--------------------------------------+ 62 | | John | [1, 2, 3] | [{1}, {2}] | [{[1, 2]}, {[3, 4]}] | [{[{{1}}, {{2}}]}, {[{{3}}, {{4}}]}] | 63 | +------+-----------+------------+----------------------+--------------------------------------+ 64 | >>> from bigquery_frame.dataframe import is_repeated 65 | >>> def cast_int_as_double(col: Column, schema_field: SchemaField): 66 | ... if schema_field.field_type == "INTEGER" and schema_field.mode != "REPEATED": 67 | ... 
return col.cast("FLOAT64") 68 | >>> new_df = df.transform(transform_all_fields, cast_int_as_double) 69 | >>> nested.print_schema(new_df) 70 | root 71 | |-- name: STRING (nullable = true) 72 | |-- s1!: FLOAT (nullable = false) 73 | |-- s2!.a: FLOAT (nullable = true) 74 | |-- s3!.a!: FLOAT (nullable = false) 75 | |-- s4!.a!.b.c: FLOAT (nullable = true) 76 | 77 | >>> new_df.show(simplify_structs=True) 78 | +------+-----------------+----------------+------------------------------+----------------------------------------------+ 79 | | name | s1 | s2 | s3 | s4 | 80 | +------+-----------------+----------------+------------------------------+----------------------------------------------+ 81 | | John | [1.0, 2.0, 3.0] | [{1.0}, {2.0}] | [{[1.0, 2.0]}, {[3.0, 4.0]}] | [{[{{1.0}}, {{2.0}}]}, {[{{3.0}}, {{4.0}}]}] | 82 | +------+-----------------+----------------+------------------------------+----------------------------------------------+ 83 | """ # noqa: E501 84 | root_transformation = build_transformation_from_schema( 85 | df.schema, 86 | column_transformation=transformation, 87 | ) 88 | return df.select(*root_transformation(df)) 89 | -------------------------------------------------------------------------------- /bigquery_frame/transformations_impl/union_dataframes.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import DataFrame 2 | from bigquery_frame.utils import quote 3 | 4 | 5 | def union_dataframes(dfs: list[DataFrame]) -> DataFrame: 6 | """Returns the union between multiple DataFrames""" 7 | if len(dfs) == 0: 8 | raise ValueError("input list is empty") 9 | query = "\nUNION ALL\n".join([f" SELECT * FROM {quote(df._alias)}" for df in dfs]) 10 | return DataFrame(query, alias=None, bigquery=dfs[0].bigquery, deps=dfs) 11 | -------------------------------------------------------------------------------- /bigquery_frame/units.py: -------------------------------------------------------------------------------- 1 | KIBI = 1024 2 | MULTIPLES = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"] 3 | 4 | 5 | def bytes_to_human_readable(byte_amount: int) -> str: 6 | """Transform an integer representing an amount of bytes into a human-readable string. 7 | 8 | >>> bytes_to_human_readable(128) 9 | '128.00 B' 10 | >>> bytes_to_human_readable(2048) 11 | '2.00 KiB' 12 | >>> bytes_to_human_readable(1000000000) 13 | '953.67 MiB' 14 | 15 | :param byte_amount: an amount of bytes 16 | :return: a human-readable string representation 17 | """ 18 | coef = 0 19 | float_amount = float(byte_amount) 20 | while float_amount > KIBI and coef < len(MULTIPLES) - 1: 21 | coef += 1 22 | float_amount = float_amount / 1024 23 | return f"{float_amount:0.2f} {MULTIPLES[coef]}B" 24 | -------------------------------------------------------------------------------- /bigquery_frame/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import re 3 | from collections.abc import Iterable 4 | from typing import TYPE_CHECKING, Optional, TypeVar, Union 5 | 6 | if TYPE_CHECKING: 7 | from bigquery_frame import Column 8 | from bigquery_frame.column import ColumnOrName, LitOrColumn 9 | 10 | T = TypeVar("T") 11 | K = TypeVar("K") 12 | V = TypeVar("V") 13 | 14 | MAX_JAVA_INT = 2147483647 15 | 16 | 17 | def strip_margin(text: str): 18 | """For every line in this string, strip a leading prefix consisting of whitespace, tabs and carriage returns 19 | followed by | from the line. 20 | 21 | If the first character is a newline, it is also removed.
22 | This method is inspired by Scala's String.stripMargin. 23 | 24 | Args: 25 | text: A multi-line string 26 | 27 | Returns: 28 | A stripped string 29 | 30 | Examples: 31 | >>> print(strip_margin(''' 32 | ... |a 33 | ... |b 34 | ... |c''')) 35 | a 36 | b 37 | c 38 | >>> print(strip_margin('''a 39 | ... |b 40 | ... |c 41 | ... |d''')) 42 | a 43 | b 44 | c 45 | d 46 | """ 47 | s = re.sub(r"\n[ \t\r]*\|", "\n", text) 48 | if s.startswith("\n"): 49 | return s[1:] 50 | else: 51 | return s 52 | 53 | 54 | def indent(text: str, nb: int) -> str: 55 | return " " * nb + text.replace("\n", "\n" + " " * nb) 56 | 57 | 58 | def group_by_key(items: Iterable[tuple[K, V]]) -> dict[K, list[V]]: 59 | """Group the values of a list of tuples by their key. 60 | 61 | Args: 62 | items: An iterable of tuples (key, value). 63 | 64 | Returns: 65 | A dictionary where the keys are the keys from the input tuples, 66 | and the values are lists of the corresponding values. 67 | 68 | Examples: 69 | >>> items = [('a', 1), ('b', 2), ('a', 3), ('c', 4), ('b', 5)] 70 | >>> group_by_key(items) 71 | {'a': [1, 3], 'b': [2, 5], 'c': [4]} 72 | >>> group_by_key([]) 73 | {} 74 | """ 75 | result: dict[K, list[V]] = {} 76 | for key, value in items: 77 | if key in result: 78 | result[key].append(value) 79 | else: 80 | result[key] = [value] 81 | return result 82 | 83 | 84 | def quote(string) -> str: 85 | """Add quotes around a column or table name to prevent collisions with SQL keywords. 86 | This method is idempotent: it does not add new quotes to an already quoted string. 87 | If the column name is a reference to a nested column (i.e. if it contains dots), each part is quoted separately. 88 | 89 | Examples: 90 | >>> quote("table") 91 | '`table`' 92 | >>> quote("`table`") 93 | '`table`' 94 | >>> quote("column.name") 95 | '`column`.`name`' 96 | >>> quote("*") 97 | '*' 98 | 99 | """ 100 | return ".".join(["`" + s + "`" if s != "*" else "*" for s in string.replace("`", "").split(".")]) 101 | 102 | 103 | def quote_columns(columns: list[str]) -> list[str]: 104 | """Puts every column name of the given list into quotes.""" 105 | return [quote(col) for col in columns] 106 | 107 | 108 | def str_to_col(args: "ColumnOrName") -> "Column": 109 | """Converts string or Column argument to Column types 110 | 111 | Examples: 112 | >>> str_to_col("id") 113 | Column<'`id`'> 114 | >>> from bigquery_frame import functions as f 115 | >>> str_to_col(f.expr("COUNT(1)")) 116 | Column<'COUNT(1)'> 117 | >>> str_to_col("*") 118 | Column<'*'> 119 | 120 | """ 121 | from bigquery_frame import functions as f 122 | 123 | if isinstance(args, str): 124 | return f.col(args) 125 | else: 126 | return args 127 | 128 | 129 | def str_to_cols(args: Iterable["ColumnOrName"]) -> list["Column"]: 130 | """Converts string or Column arguments to Column types 131 | 132 | Examples: 133 | >>> str_to_cols(["c1", "c2"]) 134 | [Column<'`c1`'>, Column<'`c2`'>] 135 | >>> from bigquery_frame import functions as f 136 | >>> str_to_col(f.expr("COUNT(1)")) 137 | Column<'COUNT(1)'> 138 | >>> str_to_col("*") 139 | Column<'*'> 140 | """ 141 | return [str_to_col(arg) for arg in args] 142 | 143 | 144 | def lit_to_col(args: "LitOrColumn") -> "Column": 145 | """Converts literal string or Column argument to Column type 146 | 147 | Examples: 148 | >>> lit_to_col("id") 149 | Column<'r\"\"\"id\"\"\"'> 150 | >>> from bigquery_frame import functions as f 151 | >>> lit_to_col(f.expr("COUNT(1)")) 152 | Column<'COUNT(1)'> 153 | >>> lit_to_col("*") 154 | Column<'r\"\"\"*\"\"\"'> 155 | """ 156 | from bigquery_frame import
Column 157 | from bigquery_frame import functions as f 158 | 159 | if isinstance(args, Column): 160 | return args 161 | else: 162 | return f.lit(args) 163 | 164 | 165 | def lit_to_cols(args: Iterable["LitOrColumn"]) -> list["Column"]: 166 | """Converts literal string or Column argument to Column type 167 | 168 | Examples: 169 | >>> lit_to_cols(["id", "c"]) 170 | [Column<'r\"\"\"id\"\"\"'>, Column<'r\"\"\"c\"\"\"'>] 171 | >>> from bigquery_frame import functions as f 172 | >>> lit_to_cols([f.expr("COUNT(1)"), "*"]) 173 | [Column<'COUNT(1)'>, Column<'r\"\"\"*\"\"\"'>] 174 | """ 175 | return [lit_to_col(arg) for arg in args] 176 | 177 | 178 | def number_lines(string: str, starting_index: int = 1) -> str: 179 | """Given a multi-line string, return a new string where each line is prepended with its number 180 | 181 | Example: 182 | >>> print(number_lines('Hello\\nWorld!')) 183 | 1: Hello 184 | 2: World! 185 | """ 186 | lines = string.split("\n") 187 | max_index = starting_index + len(lines) - 1 188 | nb_zeroes = int(math.log10(max_index)) + 1 189 | numbered_lines = [str(index + starting_index).zfill(nb_zeroes) + ": " + line for index, line in enumerate(lines)] 190 | return "\n".join(numbered_lines) 191 | 192 | 193 | def assert_true(assertion: bool, error: Optional[Union[str, BaseException]] = None) -> None: 194 | """Raise an Exception with the given error message if the given assertion is false. 195 | 196 | !!! tip 197 | This method is especially useful to get 100% coverage more easily, without having to write tests for every 198 | single assertion to cover the cases when they fail (which are generally just there to provide a more helpful 199 | error message to users when something that is not supposed to happen does happen) 200 | 201 | Args: 202 | assertion: The boolean result of an assertion 203 | error: An Exception or a message string (in which case an AssertionError with this message will be raised) 204 | 205 | >>> assert_true(3==3, "3 <> 4") 206 | >>> assert_true(3==4, "3 <> 4") 207 | Traceback (most recent call last): 208 | ... 209 | AssertionError: 3 <> 4 210 | >>> assert_true(3==4, ValueError("3 <> 4")) 211 | Traceback (most recent call last): 212 | ... 213 | ValueError: 3 <> 4 214 | >>> assert_true(3==4) 215 | Traceback (most recent call last): 216 | ... 217 | AssertionError 218 | """ 219 | if not assertion: 220 | if isinstance(error, BaseException): 221 | raise error 222 | elif isinstance(error, str): 223 | raise AssertionError(error) 224 | else: 225 | raise AssertionError 226 | 227 | 228 | def list_or_tuple_to_list(*columns: Union[list[T], T]) -> list[T]: 229 | """Convert a list or a tuple to a list 230 | 231 | >>> list_or_tuple_to_list() 232 | [] 233 | >>> list_or_tuple_to_list(1, 2) 234 | [1, 2] 235 | >>> list_or_tuple_to_list([1, 2]) 236 | [1, 2] 237 | >>> list_or_tuple_to_list([1, 2], [4, 5]) 238 | Traceback (most recent call last): 239 | ...
240 | TypeError: Wrong argument type: <class 'tuple'> 241 | """ 242 | assert_true(isinstance(columns, (list, tuple)), TypeError(f"Wrong argument type: {type(columns)}")) 243 | if len(columns) == 0: 244 | return [] 245 | if isinstance(columns[0], list): 246 | if len(columns) == 1: 247 | return columns[0] 248 | else: 249 | raise TypeError(f"Wrong argument type: {type(columns)}") 250 | else: 251 | return list(columns) 252 | 253 | 254 | def _ref(_: object) -> None: 255 | """Dummy function used to prevent 'optimize import' from dropping the methods imported""" 256 | -------------------------------------------------------------------------------- /dev/bin/run_linters.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | poetry run black . 5 | poetry run ruff format . 6 | poetry run ruff check bigquery_frame tests 7 | poetry run mypy bigquery_frame 8 | 9 | -------------------------------------------------------------------------------- /dev/bin/run_security_checks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | poetry run bandit . 5 | poetry run safety check 6 | -------------------------------------------------------------------------------- /dev/bin/run_unit_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | export GCP_CREDENTIALS_PATH="gcp-credentials.json" 5 | poetry run pytest --cov -n 32 "$@" 6 | -------------------------------------------------------------------------------- /examples/data_diff/country_code_iso.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import BigQueryBuilder 2 | from bigquery_frame.data_diff import DataframeComparator 3 | 4 | bq = BigQueryBuilder() 5 | 6 | ##################################################################################################################### 7 | # The input tables are snapshots of the table `bigquery-public-data.utility_us.country_code_iso` taken at a 6-day 8 | # interval, on 2022-09-21 and 2022-09-27. I noticed that the table had been updated with obviously incorrect data: 9 | # as you can see in the diff below, it looks like the continent_code and continent_name columns had been inverted. 10 | # A few hours later, the table was reverted to the same version as it was on the 2022-09-21.
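# For context: compare_df matches the rows of the two snapshots using the join_cols passed below, so country_name acts as the row key in this comparison.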
11 | ##################################################################################################################### 12 | 13 | old_df = bq.table("test_us.country_code_iso_snapshot_20220921") 14 | new_df = bq.table("test_us.country_code_iso_snapshot_20220927") 15 | result = DataframeComparator().compare_df(old_df, new_df, join_cols=["country_name"]) 16 | 17 | result.display(show_examples=True) 18 | # Schema: ok (10) 19 | # diff NOT ok 20 | # Summary: 21 | # Row count ok: 278 rows 22 | # 28 (10.07%) rows are identical 23 | # 250 (89.93%) rows have changed 24 | # 0%| | 0/1 [00:00<?, ?it/s] -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 18 | python = ">=3.9,<3.13" 19 | data-diff-viewer = "0.3.2" 20 | 21 | google-cloud-bigquery = "^3.14.1" 22 | google-cloud-bigquery-storage = "^2.24.0" 23 | tabulate = "^0.9.0" 24 | tqdm = "^4.64.0" 25 | 26 | [tool.poetry.dev-dependencies] 27 | types-tqdm = "^4.64.6" 28 | types-tabulate = "^0.8.11" 29 | types-setuptools = "^65.6.0.3" 30 | 31 | black = "^24.3.0" 32 | ruff = "^0.1.6" 33 | mypy = "^0.971" 34 | safety = "^2.1.1" 35 | 36 | pytest = "^7.4.4" 37 | pytest-cov = "^4.1.0" 38 | pytest-xdist="^3.5.0" 39 | 40 | pipdeptree = "2.2.1" 41 | bump-my-version = "^0.20.3" 42 | 43 | # Dependencies used by DataFrame.toPandas() 44 | pandas = "~2.2.2" 45 | pyarrow = "^16.0.0" 46 | db-dtypes = "^1.0.3" 47 | 48 | 49 | [build-system] 50 | requires = ["poetry-core==1.1.14"] 51 | build-backend = "poetry.core.masonry.api" 52 | 53 | [tool.black] 54 | line-length = 120 55 | 56 | [tool.isort] 57 | line_length = 120 58 | profile = "black" 59 | known_first_party = "bigquery_frame" 60 | 61 | [tool.coverage.run] 62 | branch = true 63 | omit = ["tests/*"] 64 | 65 | [tool.coverage.html] 66 | directory = "htmlcov" 67 | 68 | [tool.coverage.xml] 69 | output = "test_working_dir/coverage.xml" 70 | 71 | [tool.coverage.report] 72 | exclude_also = [ 73 | "if TYPE_CHECKING:" 74 | ] 75 | 76 | [tool.pytest.ini_options] 77 | addopts = [ 78 | "-ra", 79 | "--doctest-modules", 80 | "--junitxml=test_working_dir/test-results.xml" 81 | ] 82 | 83 | testpaths = [ 84 | "bigquery_frame", 85 | "tests" 86 | ] 87 | 88 | [tool.bumpversion] 89 | current_version = "0.5.0" 90 | commit = true 91 | message = "Bump version: {current_version} → {new_version}" 92 | tag = true 93 | tag_name = "v{new_version}" 94 | tag_message = "Bump version: {current_version} → {new_version}" 95 | parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)(\\.(?P<release>[a-z]+)(?P<build>\\d+))?"
96 | serialize = ["{major}.{minor}.{patch}"] 97 | 98 | [[tool.bumpversion.files]] 99 | filename = "pyproject.toml" 100 | search = "version = \"{current_version}\"" 101 | replace = "version = \"{new_version}\"" 102 | 103 | [[tool.bumpversion.files]] 104 | filename = "bigquery_frame/__init__.py" 105 | search = "__version__ = \"{current_version}\"" 106 | replace = "__version__ = \"{new_version}\"" 107 | 108 | [[tool.bumpversion.files]] 109 | filename = "sonar-project.properties" 110 | search = "sonar.projectVersion={current_version}" 111 | replace = "sonar.projectVersion={new_version}" 112 | 113 | 114 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 120 2 | target-version = "py39" 3 | 4 | select = [ 5 | "C", # mccabe 6 | "D", # pydocstyle 7 | "E", # pycodestyle errors 8 | "W", # pycodestyle warnings 9 | "I", # isort 10 | "F", # Pyflakes 11 | "N", # pep8-naming 12 | "UP", # pyupgrade 13 | "S", # bandit 14 | "YTT", 15 | "ANN", 16 | "ASYNC", 17 | "BLE", 18 | "FBT", 19 | "B", 20 | "A", 21 | "COM", 22 | # "CPY", # copyright notice at top of files 23 | "C4", 24 | "DTZ", 25 | "T10", 26 | "EM", 27 | "EXE", 28 | "ISC", 29 | "ICN", 30 | "G", 31 | "INP", 32 | "PIE", 33 | "T20", 34 | "PYI", 35 | "PT", 36 | "Q", 37 | "RSE", 38 | "RET", 39 | "SLF", 40 | "SLOT", 41 | "SIM", 42 | "TID", 43 | "TCH", 44 | "INT", 45 | "ARG", 46 | "PTH", 47 | "TD", 48 | "FIX", 49 | "ERA", 50 | "PD", # pandas-vet 51 | "PGH", 52 | "PL", 53 | "TRY", 54 | "FLY", 55 | "NPY", 56 | "AIR", 57 | "PERF", 58 | "RUF", 59 | ] 60 | 61 | ignore = [ 62 | "ARG005", # Unused lambda argument 63 | # If a method expects as argument a higher order function of type "A -> B", I find it more confusing to feed it 64 | # with "lambda: B.default_value" than with "lambda a: B.default_value", even if a is not used. 65 | 66 | "RET504", # Unnecessary assignment to `...` before `return` statement 67 | # Naming the return argument before returning it makes the code more readable and easier to debug.
68 | 69 | "RET505", # Unnecessary `else` after `return` statement 70 | "RET506", # Unnecessary `else` after `raise` statement 71 | # I find the functional-programming version less confusing than the imperative version: 72 | # 73 | # # Functional-programming version: 74 | # if P: 75 | # return B 76 | # else: 77 | # return C 78 | # 79 | # # Imperative version: 80 | # if P: 81 | # return B 82 | # return C 83 | # 84 | 85 | "SIM108", # Replace multiline if then else with one-liners 86 | # # Personally, I find this: 87 | # if predicate(): 88 | # x = 1 89 | # else: 90 | # x = 2 91 | # 92 | # # More readable than this: 93 | # x = 1 if predicate() else 2 94 | 95 | "FBT001", # Boolean-typed positional argument in function definition 96 | "FBT002", # Boolean default positional argument in function definition 97 | # These rules make sense but there are several cases where working around them makes the code more confusing, not less. 98 | # Plus, the Spark API itself does not follow this rule (for instance, df.show(10, true) works) 99 | 100 | "ANN101", # Missing type annotation for `self` in method 101 | # This rule is not necessary when self is automatically inferred by smart type checkers 102 | 103 | "D100", # Missing docstring in public module 104 | "D104", # Missing docstring in public package 105 | 106 | "E221", 107 | "E222", 108 | "E223", 109 | "E224", 110 | "E225", 111 | ] 112 | 113 | fixable = [ 114 | "A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", 115 | "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", 116 | "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", 117 | "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", 118 | "TID", "TRY", "UP", "YTT" 119 | ] 120 | unfixable = [] 121 | 122 | exclude = [ 123 | ".bzr", 124 | ".direnv", 125 | ".eggs", 126 | ".git", 127 | ".git-rewrite", 128 | ".hg", 129 | ".mypy_cache", 130 | ".nox", 131 | ".pants.d", 132 | ".pytype", 133 | ".ruff_cache", 134 | ".svn", 135 | ".tox", 136 | ".venv", 137 | "__pypackages__", 138 | "_build", 139 | "buck-out", 140 | "build", 141 | "dist", 142 | "node_modules", 143 | "venv", 144 | ] 145 | extend-exclude = [ 146 | "conftest.py", 147 | ] 148 | 149 | 150 | [mccabe] 151 | max-complexity = 10 152 | 153 | [lint.pydocstyle] 154 | convention = "google" 155 | 156 | [lint.isort] 157 | known-first-party = ["bigquery_frame"] 158 | 159 | [pylint] 160 | max-args = 8 # PLR0913: Too many arguments in function definition (8 > 5) 161 | 162 | [per-file-ignores] 163 | "tests/**/*.py" = [ 164 | # at least these three should be fine in tests: 165 | "S101", # asserts allowed in tests... 166 | "PLR2004", # Magic value used in comparison 167 | # "ARG", # Unused function args -> fixtures nevertheless are functionally relevant... 168 | # "FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize() 169 | # # The below are debatable 170 | # "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes 171 | ] 172 | 173 | 174 | -------------------------------------------------------------------------------- /sonar-project.properties: -------------------------------------------------------------------------------- 1 | sonar.projectKey=FurcyPin_bigquery-frame 2 | sonar.organization=furcypin 3 | 4 | # This is the name and version displayed in the SonarCloud UI. 5 | sonar.projectName=bigquery-frame 6 | sonar.projectVersion=0.5.0 7 | 8 | # Path is relative to the sonar-project.properties file. Replace "\" by "/" on Windows.
9 | sonar.sources=bigquery_frame 10 | sonar.tests=tests 11 | sonar.python.coverage.reportPaths=test_working_dir/coverage.xml 12 | sonar.python.xunit.reportPath=test_working_dir/test-results.xml 13 | sonar.python.version=3 14 | 15 | # Encoding of the source code. Default is default system encoding 16 | sonar.sourceEncoding=UTF-8 17 | -------------------------------------------------------------------------------- /test_working_dir/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/test_working_dir/.gitkeep -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/tests/__init__.py -------------------------------------------------------------------------------- /tests/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/tests/cli/__init__.py -------------------------------------------------------------------------------- /tests/cli/test_diff.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from bigquery_frame import BigQueryBuilder, DataFrame 4 | from bigquery_frame.cli import diff 5 | from tests.utils import captured_output 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def df(bq: BigQueryBuilder, random_test_dataset: str) -> DataFrame: 10 | df = bq.sql( 11 | """ 12 | SELECT * FROM UNNEST ([ 13 | STRUCT(1 as id, [STRUCT(1 as a, 2 as b, 3 as c)] as my_array), 14 | STRUCT(2 as id, [STRUCT(1 as a, 2 as b, 3 as c)] as my_array), 15 | STRUCT(3 as id, [STRUCT(1 as a, 2 as b, 3 as c)] as my_array) 16 | ]) 17 | """, 18 | ) 19 | return df 20 | 21 | 22 | @pytest.fixture(autouse=True) 23 | def t1(bq: BigQueryBuilder, random_test_dataset: str, df: DataFrame) -> str: 24 | df.write.save(f"{random_test_dataset}.t1") 25 | yield f"{random_test_dataset}.t1" 26 | bq._execute_query(f"DROP TABLE IF EXISTS {random_test_dataset}.t1") 27 | 28 | 29 | @pytest.fixture(autouse=True) 30 | def t2(bq: BigQueryBuilder, random_test_dataset: str, df: DataFrame) -> str: 31 | df.write.save(f"{random_test_dataset}.t2") 32 | yield f"{random_test_dataset}.t2" 33 | bq._execute_query(f"DROP TABLE IF EXISTS {random_test_dataset}.t2") 34 | 35 | 36 | def test_cli_diff(t1: str, t2: str): 37 | with captured_output() as (stdout, stderr): 38 | diff.main(["--tables", f"{t1}", f"{t2}", "--join-cols", "id"]) 39 | assert "Report exported as diff_report.html" in stdout.getvalue() 40 | 41 | 42 | def test_cli_diff_with_output_option(t1: str, t2: str): 43 | with captured_output() as (stdout, stderr): 44 | diff.main( 45 | ["--tables", f"{t1}", f"{t2}", "--join-cols", "id", "--output", "test_working_dir/test_cli_diff.html"], 46 | ) 47 | assert "Report exported as test_working_dir/test_cli_diff.html" in stdout.getvalue() 48 | 49 | 50 | def test_cli_diff_with_no_args(t1: str, t2: str): 51 | """WHEN the command is called with no argument, it should print the help and exit""" 52 | with captured_output() as (stdout, stderr): 53 | with pytest.raises(SystemExit): 54 | diff.main([]) 55 | assert "usage: bq-diff" in stdout.getvalue() 56 | 57 | 58 | def test_cli_diff_with_help_option(t1: str, t2: str): 
59 | """WHEN the command is called with the --help option, it should print the help and exit""" 60 | with captured_output() as (stdout, stderr): 61 | with pytest.raises(SystemExit): 62 | diff.main(["--help"]) 63 | assert "usage: bq-diff" in stdout.getvalue() 64 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from uuid import uuid4 2 | 3 | import pytest 4 | from google.cloud.bigquery import Client, Dataset 5 | 6 | from bigquery_frame.auth import get_bq_client 7 | from bigquery_frame.bigquery_builder import BigQueryBuilder 8 | 9 | 10 | @pytest.fixture(scope="session") 11 | def random_test_dataset() -> str: 12 | random_id = uuid4() 13 | return "test_dataset_" + str(random_id).replace("-", "_") 14 | 15 | 16 | @pytest.fixture(scope="session") 17 | def client(random_test_dataset: str) -> Client: 18 | client = get_bq_client() 19 | dataset = Dataset(f"{client.project}.{random_test_dataset}") 20 | dataset.location = "EU" 21 | client.create_dataset(dataset, exists_ok=True) 22 | yield client 23 | client.delete_dataset(dataset, delete_contents=True) 24 | client.close() 25 | 26 | 27 | @pytest.fixture() 28 | def bq(client: Client) -> BigQueryBuilder: 29 | bq = BigQueryBuilder(client) 30 | return bq 31 | -------------------------------------------------------------------------------- /tests/data_diff/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/tests/data_diff/__init__.py -------------------------------------------------------------------------------- /tests/graph_impl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/tests/graph_impl/__init__.py -------------------------------------------------------------------------------- /tests/graph_impl/test_connected_components.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import BigQueryBuilder 2 | from bigquery_frame import functions as f 3 | from bigquery_frame.graph import connected_components 4 | 5 | 6 | def test_connected_components(bq: BigQueryBuilder): 7 | df = bq.sql( 8 | """ 9 | SELECT * 10 | FROM UNNEST([ 11 | STRUCT(1 as L, 8 as R), 12 | STRUCT(8 as L, 9 as R), 13 | STRUCT(5 as L, 8 as R), 14 | STRUCT(7 as L, 8 as R), 15 | STRUCT(3 as L, 7 as R), 16 | STRUCT(2 as L, 3 as R), 17 | STRUCT(1 as L, 3 as R), 18 | STRUCT(1 as L, 2 as R), 19 | STRUCT(1 as L, 8 as R), 20 | STRUCT(4 as L, 6 as R), 21 | STRUCT(11 as L, 18 as R), 22 | STRUCT(18 as L, 19 as R), 23 | STRUCT(15 as L, 18 as R), 24 | STRUCT(17 as L, 18 as R), 25 | STRUCT(13 as L, 17 as R), 26 | STRUCT(12 as L, 13 as R), 27 | STRUCT(14 as L, 16 as R), 28 | STRUCT(11 as L, 18 as R), 29 | STRUCT(18 as L, 19 as R), 30 | STRUCT(15 as L, 18 as R), 31 | STRUCT(17 as L, 18 as R), 32 | STRUCT(13 as L, 17 as R), 33 | STRUCT(12 as L, 13 as R), 34 | STRUCT(14 as L, 16 as R) 35 | ]) 36 | """, 37 | ) 38 | expected_df = bq.sql( 39 | """ 40 | SELECT * 41 | FROM UNNEST([ 42 | STRUCT(1 as node_id, 1 as connected_component_id), 43 | STRUCT(2 as node_id, 1 as connected_component_id), 44 | STRUCT(3 as node_id, 1 as connected_component_id), 45 | STRUCT(4 as node_id, 4 as connected_component_id), 46 | STRUCT(5 as node_id, 1 as 
connected_component_id), 47 | STRUCT(6 as node_id, 4 as connected_component_id), 48 | STRUCT(7 as node_id, 1 as connected_component_id), 49 | STRUCT(8 as node_id, 1 as connected_component_id), 50 | STRUCT(9 as node_id, 1 as connected_component_id), 51 | STRUCT(11 as node_id, 11 as connected_component_id), 52 | STRUCT(12 as node_id, 11 as connected_component_id), 53 | STRUCT(13 as node_id, 11 as connected_component_id), 54 | STRUCT(14 as node_id, 14 as connected_component_id), 55 | STRUCT(15 as node_id, 11 as connected_component_id), 56 | STRUCT(16 as node_id, 14 as connected_component_id), 57 | STRUCT(17 as node_id, 11 as connected_component_id), 58 | STRUCT(18 as node_id, 11 as connected_component_id), 59 | STRUCT(19 as node_id, 11 as connected_component_id) 60 | ]) 61 | """, 62 | ) 63 | df = df.select(f.col("L").alias("l_node"), f.col("R").alias("r_node")).persist() 64 | actual_df = connected_components(df).sort("node_id", "connected_component_id") 65 | assert actual_df.collect() == expected_df.collect() 66 | -------------------------------------------------------------------------------- /tests/nested_impl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/tests/nested_impl/__init__.py -------------------------------------------------------------------------------- /tests/nested_impl/test_unnest_all_fields.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import BigQueryBuilder, nested 2 | from bigquery_frame.utils import strip_margin 3 | 4 | 5 | def test_unnest_fields_with_fields_having_same_name_inside_structs(bq: BigQueryBuilder): 6 | """GIVEN a DataFrame with fields in structs having the same name as root-level columns 7 | WHEN we apply unnest_fields on it 8 | THEN the result should be correct 9 | """ 10 | df = bq.sql( 11 | """ 12 | SELECT 13 | 1 as id, 14 | STRUCT(2 as id) as s1 15 | """, 16 | ) 17 | assert df.show_string(simplify_structs=True) == strip_margin( 18 | """ 19 | |+----+-----+ 20 | || id | s1 | 21 | |+----+-----+ 22 | || 1 | {2} | 23 | |+----+-----+""", 24 | ) 25 | assert nested.fields(df) == ["id", "s1.id"] 26 | result_df_list = nested.unnest_all_fields(df, keep_columns=["id"]) 27 | assert list(result_df_list.keys()) == [""] 28 | result_df_list[""].show() 29 | assert result_df_list[""].show_string(simplify_structs=True) == strip_margin( 30 | """ 31 | |+----+----------------+ 32 | || id | s1__STRUCT__id | 33 | |+----+----------------+ 34 | || 1 | 2 | 35 | |+----+----------------+""", 36 | ) 37 | 38 | 39 | def test_unnest_fields_with_fields_having_same_name_inside_array_structs(bq: BigQueryBuilder): 40 | """GIVEN a DataFrame with fields in array of struct having the same name as root-level columns 41 | WHEN we apply unnest_fields on it 42 | THEN the result should be correct 43 | """ 44 | df = bq.sql( 45 | """ 46 | SELECT 47 | 1 as id, 48 | STRUCT(2 as id) as s1, 49 | [STRUCT(3 as id, [STRUCT(5 as id), STRUCT(6 as id)] as s3)] as s2, 50 | """, 51 | ) 52 | assert df.show_string(simplify_structs=True) == strip_margin( 53 | """ 54 | |+----+-----+-------------------+ 55 | || id | s1 | s2 | 56 | |+----+-----+-------------------+ 57 | || 1 | {2} | [{3, [{5}, {6}]}] | 58 | |+----+-----+-------------------+""", 59 | ) 60 | 61 | assert nested.fields(df) == ["id", "s1.id", "s2!.id", "s2!.s3!.id"] 62 | result_df_list = nested.unnest_all_fields(df, keep_columns=["id"]) 63 | assert 
result_df_list[""].show_string() == strip_margin( 64 | """ 65 | |+----+----------------+ 66 | || id | s1__STRUCT__id | 67 | |+----+----------------+ 68 | || 1 | 2 | 69 | |+----+----------------+""", 70 | ) 71 | assert result_df_list["s2!"].show_string() == strip_margin( 72 | """ 73 | |+----+-------------------------+ 74 | || id | s2__ARRAY____STRUCT__id | 75 | |+----+-------------------------+ 76 | || 1 | 3 | 77 | |+----+-------------------------+""", 78 | ) 79 | assert result_df_list["s2!.s3!"].show_string() == strip_margin( 80 | """ 81 | |+----+----------------------------------------------+ 82 | || id | s2__ARRAY____STRUCT__s3__ARRAY____STRUCT__id | 83 | |+----+----------------------------------------------+ 84 | || 1 | 5 | 85 | || 1 | 6 | 86 | |+----+----------------------------------------------+""", 87 | ) 88 | 89 | 90 | def test_unnest_fields_with_fields_having_same_name_inside_array_structs_and_names_are_keywords(bq: BigQueryBuilder): 91 | """GIVEN a DataFrame with fields in array of struct having the same name as root-level columns 92 | AND if the names are reserved keywords 93 | WHEN we apply unnest_fields on it 94 | THEN the result should be correct 95 | """ 96 | df = bq.sql( 97 | """ 98 | SELECT 99 | 1 as `group`, 100 | STRUCT(2 as `group`) as s1, 101 | [STRUCT(3 as `group`, [STRUCT(5 as `group`), STRUCT(6 as `group`)] as s3)] as s2, 102 | """, 103 | ) 104 | assert df.show_string(simplify_structs=True) == strip_margin( 105 | """ 106 | |+-------+-----+-------------------+ 107 | || group | s1 | s2 | 108 | |+-------+-----+-------------------+ 109 | || 1 | {2} | [{3, [{5}, {6}]}] | 110 | |+-------+-----+-------------------+""", 111 | ) 112 | 113 | assert nested.fields(df) == ["group", "s1.group", "s2!.group", "s2!.s3!.group"] 114 | result_df_list = nested.unnest_all_fields(df, keep_columns=["group"]) 115 | assert result_df_list[""].show_string() == strip_margin( 116 | """ 117 | |+-------+-------------------+ 118 | || group | s1__STRUCT__group | 119 | |+-------+-------------------+ 120 | || 1 | 2 | 121 | |+-------+-------------------+""", 122 | ) 123 | assert result_df_list["s2!"].show_string() == strip_margin( 124 | """ 125 | |+-------+----------------------------+ 126 | || group | s2__ARRAY____STRUCT__group | 127 | |+-------+----------------------------+ 128 | || 1 | 3 | 129 | |+-------+----------------------------+""", 130 | ) 131 | assert result_df_list["s2!.s3!"].show_string() == strip_margin( 132 | """ 133 | |+-------+-------------------------------------------------+ 134 | || group | s2__ARRAY____STRUCT__s3__ARRAY____STRUCT__group | 135 | |+-------+-------------------------------------------------+ 136 | || 1 | 5 | 137 | || 1 | 6 | 138 | |+-------+-------------------------------------------------+""", 139 | ) 140 | -------------------------------------------------------------------------------- /tests/nested_impl/test_with_fields.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import BigQueryBuilder, nested 2 | from bigquery_frame import functions as f 3 | from bigquery_frame.nested_impl.schema_string import schema_string 4 | from bigquery_frame.utils import strip_margin 5 | 6 | 7 | def test_with_fields(bq: BigQueryBuilder): 8 | """GIVEN a DataFrame with nested fields 9 | WHEN we use with_fields to add a new field 10 | THEN the other fields should remain undisturbed 11 | """ 12 | df = bq.sql( 13 | """SELECT 14 | 1 as id, 15 | [STRUCT(2 as a, [STRUCT(3 as c, 4 as d)] as b, [5, 6] as e)] as s1, 16 | STRUCT(7 as f) as 
s2, 17 | [STRUCT([1, 2] as a), STRUCT([3, 4] as a)] as s3, 18 | [STRUCT([STRUCT(1 as e, 2 as f)] as a), STRUCT([STRUCT(3 as e, 4 as f)] as a)] as s4 19 | """, 20 | ) 21 | assert schema_string(df) == strip_margin( 22 | """ 23 | |root 24 | | |-- id: INTEGER (nullable = true) 25 | | |-- s1!.a: INTEGER (nullable = true) 26 | | |-- s1!.b!.c: INTEGER (nullable = true) 27 | | |-- s1!.b!.d: INTEGER (nullable = true) 28 | | |-- s1!.e!: INTEGER (nullable = false) 29 | | |-- s2.f: INTEGER (nullable = true) 30 | | |-- s3!.a!: INTEGER (nullable = false) 31 | | |-- s4!.a!.e: INTEGER (nullable = true) 32 | | |-- s4!.a!.f: INTEGER (nullable = true) 33 | |""", 34 | ) 35 | new_df = df.transform(nested.with_fields, {"s5.g": f.col("s2.f").cast("FLOAT64")}) 36 | assert schema_string(new_df) == strip_margin( 37 | """ 38 | |root 39 | | |-- id: INTEGER (nullable = true) 40 | | |-- s1!.a: INTEGER (nullable = true) 41 | | |-- s1!.b!.c: INTEGER (nullable = true) 42 | | |-- s1!.b!.d: INTEGER (nullable = true) 43 | | |-- s1!.e!: INTEGER (nullable = false) 44 | | |-- s2.f: INTEGER (nullable = true) 45 | | |-- s3!.a!: INTEGER (nullable = false) 46 | | |-- s4!.a!.e: INTEGER (nullable = true) 47 | | |-- s4!.a!.f: INTEGER (nullable = true) 48 | | |-- s5.g: FLOAT (nullable = true) 49 | |""", 50 | ) 51 | -------------------------------------------------------------------------------- /tests/test_bigquery_builder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from google.api_core.exceptions import BadRequest 3 | from google.cloud.bigquery import Client 4 | 5 | from bigquery_frame import BigQueryBuilder 6 | 7 | 8 | def test_without_debug(client: Client): 9 | """GIVEN a BigQueryBuilder 10 | WHEN we run an incorrect query without the debug mode 11 | THEN it will fail only at the end 12 | """ 13 | bq = BigQueryBuilder(client) 14 | df1 = bq.sql("""SELECT 1 as a""") 15 | df2 = df1.select("b") 16 | with pytest.raises(BadRequest): 17 | df2.show() 18 | 19 | 20 | def test_with_debug(client: Client): 21 | """GIVEN a BigQueryBuilder 22 | WHEN we run an incorrect query with the debug mode 23 | THEN it will fail immediately 24 | """ 25 | bq = BigQueryBuilder(client, debug=True) 26 | df1 = bq.sql("""SELECT 1 as a""") 27 | with pytest.raises(BadRequest): 28 | df1.select("b") 29 | -------------------------------------------------------------------------------- /tests/test_dataframe_writer.py: -------------------------------------------------------------------------------- 1 | import google 2 | import pytest 3 | from google.cloud.bigquery import Client 4 | 5 | from bigquery_frame import BigQueryBuilder 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def clean(bq: BigQueryBuilder, random_test_dataset: str): 10 | bq._execute_query(f"DROP TABLE IF EXISTS {random_test_dataset}.my_table") 11 | 12 | 13 | def test_write_with_mode_overwrite(bq: BigQueryBuilder, random_test_dataset: str): 14 | bq.sql("SELECT 1 as a").write.save(f"{random_test_dataset}.my_table", mode="OVERWRITE") 15 | bq.sql("SELECT 2 as b").write.mode("overwrite").save(f"{random_test_dataset}.my_table") 16 | df = bq.table(f"{random_test_dataset}.my_table") 17 | assert [r["b"] for r in df.collect()] == [2] 18 | 19 | 20 | def test_write_with_mode_append(bq: BigQueryBuilder, random_test_dataset: str): 21 | bq.sql("SELECT 1 as a").write.save(f"{random_test_dataset}.my_table", mode="OVERWRITE") 22 | bq.sql("SELECT 2 as a").write.mode("append").save(f"{random_test_dataset}.my_table") 23 | df = 
bq.table(f"{random_test_dataset}.my_table").orderBy("a") 24 | assert [r["a"] for r in df.collect()] == [1, 2] 25 | 26 | 27 | def test_write_with_mode_ignore(bq: BigQueryBuilder, random_test_dataset: str): 28 | bq.sql("SELECT 1 as a").write.save(f"{random_test_dataset}.my_table", mode="OVERWRITE") 29 | bq.sql("SELECT 2 as b").write.mode("ignore").save(f"{random_test_dataset}.my_table") 30 | df = bq.table(f"{random_test_dataset}.my_table") 31 | assert [r["a"] for r in df.collect()] == [1] 32 | 33 | 34 | def test_write_with_mode_error(bq: BigQueryBuilder, random_test_dataset: str): 35 | bq.sql("SELECT 1 as a").write.save(f"{random_test_dataset}.my_table", mode="OVERWRITE") 36 | with pytest.raises(google.api_core.exceptions.Conflict) as e: 37 | bq.sql("SELECT 2 as a").write.mode("error").save(f"{random_test_dataset}.my_table") 38 | assert "409 " in str(e.value) and "Already Exists" in str(e.value) 39 | 40 | 41 | def test_write_with_mode_errorifexists(bq: BigQueryBuilder, random_test_dataset: str): 42 | bq.sql("SELECT 1 as a").write.save(f"{random_test_dataset}.my_table", mode="OVERWRITE") 43 | with pytest.raises(google.api_core.exceptions.Conflict) as e: 44 | bq.sql("SELECT 2 as a").write.mode("errorifexists").save(f"{random_test_dataset}.my_table") 45 | assert "409 " in str(e.value) and "Already Exists" in str(e.value) 46 | 47 | 48 | def test_write_with_options(bq: BigQueryBuilder, random_test_dataset: str, client: Client): 49 | df = bq.sql("SELECT 1 as a") 50 | options = {"description": "this is a test table", "labels": {"org_unit": "development"}} 51 | df.write.save(f"{random_test_dataset}.my_table", mode="OVERWRITE", **options) 52 | table = client.get_table(f"{random_test_dataset}.my_table") 53 | assert table.description == "this is a test table" 54 | assert table.labels == {"org_unit": "development"} 55 | -------------------------------------------------------------------------------- /tests/test_grouped_data.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import BigQueryBuilder 2 | from bigquery_frame.utils import strip_margin 3 | 4 | 5 | def test_groupBy_with_struct_columns(bq: BigQueryBuilder): 6 | """GIVEN a DataFrame with nested fields 7 | WHEN we use a pivot.agg statement 8 | THEN the result should be correct 9 | """ 10 | df = bq.sql( 11 | """ 12 | SELECT * FROM UNNEST([ 13 | STRUCT(STRUCT(2 as age, "Alice" as name, 80 as height) as child), 14 | STRUCT(STRUCT(3 as age, "Alice" as name, 100 as height) as child), 15 | STRUCT(STRUCT(5 as age, "Bob" as name, 120 as height) as child), 16 | STRUCT(STRUCT(10 as age, "Bob" as name, 140 as height) as child) 17 | ]) 18 | """, 19 | ) 20 | assert df.groupBy("child.name").avg("child.age").sort("name").show_string() == strip_margin( 21 | """ 22 | |+-------+---------+ 23 | || name | avg_age | 24 | |+-------+---------+ 25 | || Alice | 2.5 | 26 | || Bob | 7.5 | 27 | |+-------+---------+""", 28 | ) 29 | 30 | assert df.groupBy().avg("child.age", "child.height").show_string() == strip_margin( 31 | """ 32 | |+---------+------------+ 33 | || avg_age | avg_height | 34 | |+---------+------------+ 35 | || 5.0 | 110.0 | 36 | |+---------+------------+""", 37 | ) 38 | 39 | 40 | def test_pivot_with_struct_columns(bq: BigQueryBuilder): 41 | """GIVEN a DataFrame with nested fields 42 | WHEN we use a pivot.agg statement 43 | THEN the result should be correct 44 | """ 45 | df = bq.sql( 46 | """ 47 | SELECT * FROM UNNEST([ 48 | STRUCT("expert" as training, STRUCT("dotNET" as course, 2012 as year, 10000 as 
earnings) as sales) , 49 | STRUCT("junior" as training, STRUCT("Java" as course, 2012 as year, 20000 as earnings) as sales) , 50 | STRUCT("expert" as training, STRUCT("dotNET" as course, 2012 as year, 5000 as earnings) as sales) , 51 | STRUCT("junior" as training, STRUCT("dotNET" as course, 2013 as year, 48000 as earnings) as sales) , 52 | STRUCT("expert" as training, STRUCT("Java" as course, 2013 as year, 30000 as earnings) as sales) 53 | ]) 54 | """, 55 | ) 56 | assert df.show_string(simplify_structs=True) == strip_margin( 57 | """ 58 | |+----------+-----------------------+ 59 | || training | sales | 60 | |+----------+-----------------------+ 61 | || expert | {dotNET, 2012, 10000} | 62 | || junior | {Java, 2012, 20000} | 63 | || expert | {dotNET, 2012, 5000} | 64 | || junior | {dotNET, 2013, 48000} | 65 | || expert | {Java, 2013, 30000} | 66 | |+----------+-----------------------+""", 67 | ) 68 | df1 = df.groupBy("training").pivot("sales.course", ["dotNET", "Java"]).sum("sales.earnings") 69 | assert df1.show_string(simplify_structs=True) == strip_margin( 70 | """ 71 | |+----------+--------+-------+ 72 | || training | dotNET | Java | 73 | |+----------+--------+-------+ 74 | || expert | 15000 | 30000 | 75 | || junior | 48000 | 20000 | 76 | |+----------+--------+-------+""", 77 | ) 78 | df2 = df.groupBy().pivot("sales.course", ["dotNET", "Java"]).sum("sales.earnings") 79 | assert df2.show_string(simplify_structs=True) == strip_margin( 80 | """ 81 | |+--------+-------+ 82 | || dotNET | Java | 83 | |+--------+-------+ 84 | || 15000 | 30000 | 85 | || 48000 | 20000 | 86 | |+--------+-------+""", 87 | ) 88 | df3 = df.groupBy("training").pivot("sales.course").sum("sales.earnings") 89 | assert df3.show_string(simplify_structs=True) == strip_margin( 90 | """ 91 | |+----------+-------+--------+ 92 | || training | Java | dotNET | 93 | |+----------+-------+--------+ 94 | || expert | 30000 | 15000 | 95 | || junior | 20000 | 48000 | 96 | |+----------+-------+--------+""", 97 | ) 98 | df4 = df.groupBy().pivot("sales.course").sum("sales.earnings") 99 | assert df4.show_string(simplify_structs=True) == strip_margin( 100 | """ 101 | |+-------+--------+ 102 | || Java | dotNET | 103 | |+-------+--------+ 104 | || 30000 | 15000 | 105 | || 20000 | 48000 | 106 | |+-------+--------+""", 107 | ) 108 | -------------------------------------------------------------------------------- /tests/test_has_bigquery_client.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | import pytest 4 | from google.api_core.exceptions import BadRequest, InternalServerError 5 | from google.cloud.bigquery import Client 6 | 7 | from bigquery_frame.has_bigquery_client import HasBigQueryClient 8 | from bigquery_frame.utils import strip_margin 9 | 10 | 11 | def test_error_handling(client: Client): 12 | """GIVEN a HasBigQueryClient 13 | WHEN we execute a query with an incorrect syntax 14 | THEN a BadRequest exception should be raised 15 | AND it should contain the numbered text of the query 16 | """ 17 | bq_client = HasBigQueryClient(client) 18 | bad_query = """bad query""" 19 | with pytest.raises(BadRequest) as e: 20 | bq_client._execute_query(bad_query) 21 | assert f"Query:\n1: {bad_query}" in e.value.message 22 | 23 | 24 | def test_runtime_error_handling(client: Client): 25 | """GIVEN a HasBigQueryClient 26 | WHEN we execute a query that compiles but fails at runtime 27 | THEN a BadRequest exception should be raised 28 | AND it should contain the numbered text of the 
query 29 | """ 30 | bq_client = HasBigQueryClient(client) 31 | bad_query = """SELECT (SELECT * FROM UNNEST ([1, 2]))""" 32 | with pytest.raises(BadRequest) as e: 33 | bq_client._execute_query(bad_query) 34 | assert f"Query:\n1: {bad_query}" in e.value.message 35 | 36 | 37 | def test_retry(client: Client): 38 | """GIVEN a HasBigQueryClient 39 | WHEN we execute a query and an InternalServerError happens 40 | THEN we retry the query 3 times 41 | """ 42 | 43 | def result_mock(*args, **kwargs): 44 | raise InternalServerError("This is a test error") 45 | 46 | bq_client = HasBigQueryClient(client) 47 | bad_query = """bad query""" 48 | with mock.patch("google.cloud.bigquery.job.query.QueryJob.result", side_effect=result_mock) as mocked_result: 49 | with pytest.raises(InternalServerError) as e: 50 | bq_client._execute_query(bad_query) 51 | assert mocked_result.call_count == 3 52 | assert "This is a test error" in e.value.message 53 | 54 | 55 | def test_stats_human_readable(client: Client, random_test_dataset: str): 56 | """GIVEN a HasBigQueryClient 57 | WHEN we execute a query that reads from a table 58 | AND we display the query stats in human_readable mode 59 | THEN they should be correctly displayed 60 | """ 61 | bq_client = HasBigQueryClient(client) 62 | bq_client._execute_query(f"CREATE OR REPLACE TABLE {random_test_dataset}.my_table AS SELECT 1 as a") 63 | bq_client._execute_query(f"SELECT * FROM {random_test_dataset}.my_table") 64 | bq_client._execute_query(f"SELECT * FROM {random_test_dataset}.my_table LIMIT 1") 65 | expected = strip_margin( 66 | """ 67 | |Estimated bytes processed : 16.00 B 68 | |Total bytes processed : 16.00 B 69 | |Total bytes billed : 20.00 MiB 70 | |""", 71 | ) 72 | assert bq_client.stats.human_readable() == expected 73 | 74 | 75 | def test_cache_enabled(client: Client, random_test_dataset: str): 76 | """GIVEN a HasBigQueryClient 77 | WHEN we execute a query that reads from a table with the query cache ENABLED 78 | THEN the number of bytes processed/billed SHOULD NOT increase the second time 79 | """ 80 | bq_client = HasBigQueryClient(client) 81 | bq_client._execute_query(f"CREATE OR REPLACE TABLE {random_test_dataset}.my_table AS SELECT 1 as a") 82 | bq_client._execute_query(f"SELECT * FROM {random_test_dataset}.my_table", use_query_cache=True) 83 | assert bq_client.stats.estimated_bytes_processed == 8 84 | assert bq_client.stats.total_bytes_processed == 8 85 | assert bq_client.stats.total_bytes_billed == 10 * 1024 * 1024 86 | bq_client._execute_query(f"SELECT * FROM {random_test_dataset}.my_table", use_query_cache=True) 87 | assert bq_client.stats.estimated_bytes_processed == 8 88 | assert bq_client.stats.total_bytes_processed == 8 89 | assert bq_client.stats.total_bytes_billed == 10 * 1024 * 1024 90 | 91 | 92 | def test_cache_disabled(client: Client, random_test_dataset: str): 93 | """GIVEN a HasBigQueryClient 94 | WHEN we execute a query that reads from a table with the query cache DISABLED 95 | THEN the number of bytes processed/billed SHOULD increase the second time 96 | """ 97 | bq_client = HasBigQueryClient(client) 98 | bq_client._execute_query(f"CREATE OR REPLACE TABLE {random_test_dataset}.my_table AS SELECT 1 as a") 99 | bq_client._execute_query(f"SELECT * FROM {random_test_dataset}.my_table", use_query_cache=True) 100 | assert bq_client.stats.estimated_bytes_processed == 8 101 | assert bq_client.stats.total_bytes_processed == 8 102 | assert bq_client.stats.total_bytes_billed == 10 * 1024 * 1024 103 | bq_client._execute_query(f"SELECT * FROM 
{random_test_dataset}.my_table", use_query_cache=False) 104 | assert bq_client.stats.estimated_bytes_processed == 8 * 2 105 | assert bq_client.stats.total_bytes_processed == 8 * 2 106 | assert bq_client.stats.total_bytes_billed == 10 * 1024 * 1024 * 2 107 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame.dataframe import strip_margin 2 | from bigquery_frame.utils import number_lines 3 | 4 | 5 | def test_number_lines(): 6 | s = "\n".join([str(i) for i in range(1, 10)]) 7 | expected = strip_margin( 8 | """ 9 | |1: 1 10 | |2: 2 11 | |3: 3 12 | |4: 4 13 | |5: 5 14 | |6: 6 15 | |7: 7 16 | |8: 8 17 | |9: 9""", 18 | ) 19 | assert number_lines(s) == expected 20 | 21 | s = "\n".join([str(i) for i in range(1, 11)]) 22 | expected = strip_margin( 23 | """ 24 | |01: 1 25 | |02: 2 26 | |03: 3 27 | |04: 4 28 | |05: 5 29 | |06: 6 30 | |07: 7 31 | |08: 8 32 | |09: 9 33 | |10: 10""", 34 | ) 35 | assert number_lines(s) == expected 36 | -------------------------------------------------------------------------------- /tests/transformations_impl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FurcyPin/bigquery-frame/5502060f8ce29044a89dc8ec14f4122247bae522/tests/transformations_impl/__init__.py -------------------------------------------------------------------------------- /tests/transformations_impl/test_analyze.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import BigQueryBuilder, Column 2 | from bigquery_frame import functions as f 3 | from bigquery_frame.transformations_impl.analyze import __get_test_df as get_test_df 4 | from bigquery_frame.transformations_impl.analyze import analyze 5 | from bigquery_frame.utils import strip_margin 6 | 7 | 8 | def _get_expected(): 9 | return strip_margin( 10 | """ 11 | |+---------------+------------------------+-------------+-------+----------------+------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------------------------+ 12 | || column_number | column_name | column_type | count | count_distinct | count_null | min | max | approx_top_100 | 13 | |+---------------+------------------------+-------------+-------+----------------+------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------------------------+ 14 | || 0 | id | INTEGER | 9 | 9 | 0 | 1 | 9 | [{1, 1}, {2, 1}, {3, 1}, {4, 1}, {5, 1}, {6, 1}, {7, 1}, {8, 1}, {9, 1}] | 15 | || 1 | name | STRING | 9 | 9 | 0 | Blastoise | Wartortle | [{Bulbasaur, 1}, {Ivysaur, 1}, {Venusaur, 1}, {Charmander, 1}, {Charmeleon, 1}, {Charizard, 1}, {Squirtle, 1}, {Wartortle, 1}, {Blastoise, 1}] | 16 | || 2 | types! 
| STRING | 13 | 5 | 0 | Fire | Water | [{Grass, 3}, {Poison, 3}, {Fire, 3}, {Water, 3}, {Flying, 1}] | 17 | || 3 | evolution.can_evolve | BOOLEAN | 9 | 2 | 0 | false | true | [{true, 6}, {false, 3}] | 18 | || 4 | evolution.evolves_from | INTEGER | 9 | 6 | 3 | 1 | 8 | [{NULL, 3}, {1, 1}, {2, 1}, {4, 1}, {5, 1}, {7, 1}, {8, 1}] | 19 | |+---------------+------------------------+-------------+-------+----------------+------------+-----------+-----------+------------------------------------------------------------------------------------------------------------------------------------------------+""", # noqa: E501 20 | ) 21 | 22 | 23 | def test_analyze(bq: BigQueryBuilder): 24 | df = get_test_df() 25 | actual = analyze(df) 26 | actual.show(simplify_structs=True) 27 | assert actual.show_string(simplify_structs=True) == _get_expected() 28 | 29 | 30 | def test_analyze_with_keyword_column_names(bq: BigQueryBuilder): 31 | """GIVEN a DataFrame containing field names that are reserved keywords 32 | WHEN we analyze it 33 | THEN no crash should occur 34 | """ 35 | query = """SELECT 1 as `FROM`, STRUCT('a' as `ALL`) as `UNION`""" 36 | df = bq.sql(query) 37 | actual = analyze(df) 38 | assert actual.show_string(simplify_structs=True) == strip_margin( 39 | """ 40 | |+---------------+-------------+-------------+-------+----------------+------------+-----+-----+----------------+ 41 | || column_number | column_name | column_type | count | count_distinct | count_null | min | max | approx_top_100 | 42 | |+---------------+-------------+-------------+-------+----------------+------------+-----+-----+----------------+ 43 | || 0 | FROM | INTEGER | 1 | 1 | 0 | 1 | 1 | [{1, 1}] | 44 | || 1 | UNION.ALL | STRING | 1 | 1 | 0 | a | a | [{a, 1}] | 45 | |+---------------+-------------+-------------+-------+----------------+------------+-----+-----+----------------+""", 46 | ) 47 | 48 | 49 | def test_analyze_with_array_struct_array(bq: BigQueryBuilder): 50 | """GIVEN a DataFrame containing an ARRAY<STRUCT<ARRAY<INT>>> 51 | WHEN we analyze it 52 | THEN no crash should occur 53 | """ 54 | query = """SELECT [STRUCT([1, 2, 3] as b)] as a""" 55 | df = bq.sql(query) 56 | actual = analyze(df) 57 | assert actual.show_string(simplify_structs=True) == strip_margin( 58 | """ 59 | |+---------------+-------------+-------------+-------+----------------+------------+-----+-----+--------------------------+ 60 | || column_number | column_name | column_type | count | count_distinct | count_null | min | max | approx_top_100 | 61 | |+---------------+-------------+-------------+-------+----------------+------------+-----+-----+--------------------------+ 62 | || 0 | a!.b!
| INTEGER | 3 | 3 | 0 | 1 | 3 | [{1, 1}, {2, 1}, {3, 1}] | 63 | |+---------------+-------------+-------------+-------+----------------+------------+-----+-----+--------------------------+""", # noqa: E501 64 | ) 65 | 66 | 67 | def test_analyze_with_bytes(bq: BigQueryBuilder): 68 | """GIVEN a DataFrame containing a column of type bytes 69 | WHEN we analyze it 70 | THEN no crash should occur 71 | """ 72 | query = r"""SELECT b'\377\340' as s""" 73 | df = bq.sql(query) 74 | actual = analyze(df) 75 | assert actual.show_string(simplify_structs=True) == strip_margin( 76 | """ 77 | |+---------------+-------------+-------------+-------+----------------+------------+------+------+----------------+ 78 | || column_number | column_name | column_type | count | count_distinct | count_null | min | max | approx_top_100 | 79 | |+---------------+-------------+-------------+-------+----------------+------------+------+------+----------------+ 80 | || 0 | s | BYTES | 1 | 1 | 0 | /+A= | /+A= | [{/+A=, 1}] | 81 | |+---------------+-------------+-------------+-------+----------------+------------+------+------+----------------+""", 82 | ) 83 | 84 | 85 | def test_analyze_with_nested_field_in_group_and_array_column(bq: BigQueryBuilder): 86 | """GIVEN a DataFrame containing a STRUCT and an array column 87 | WHEN we analyze it by grouping on a column inside this struct 88 | THEN no crash should occur 89 | """ 90 | query = """SELECT 1 as id, STRUCT(2 as b, 3 as c) as a, [1, 2, 3] as arr""" 91 | df = bq.sql(query) 92 | actual = analyze(df, group_by="a.b") 93 | assert actual.show_string(simplify_structs=True) == strip_margin( 94 | """ 95 | |+-------+---------------+-------------+-------------+-------+----------------+------------+-----+-----+--------------------------+ 96 | || group | column_number | column_name | column_type | count | count_distinct | count_null | min | max | approx_top_100 | 97 | |+-------+---------------+-------------+-------------+-------+----------------+------------+-----+-----+--------------------------+ 98 | || {2} | 0 | id | INTEGER | 1 | 1 | 0 | 1 | 1 | [{1, 1}] | 99 | || {2} | 2 | a.c | INTEGER | 1 | 1 | 0 | 3 | 3 | [{3, 1}] | 100 | || {2} | 3 | arr! 
| INTEGER | 3 | 3 | 0 | 1 | 3 | [{1, 1}, {2, 1}, {3, 1}] | 101 | |+-------+---------------+-------------+-------------+-------+----------------+------------+-----+-----+--------------------------+""", # noqa: E501 102 | ) 103 | 104 | 105 | def _build_huge_struct(value: Column, depth: int, width: int) -> Column: 106 | if depth == 0: 107 | return value.alias("s") 108 | return f.struct(*[_build_huge_struct(value, depth - 1, width).alias(f"c{i}") for i in range(width)]) 109 | 110 | 111 | def test_analyze_with_huge_table(bq: BigQueryBuilder): 112 | df = bq.sql( 113 | """ 114 | SELECT * FROM UNNEST([ 115 | STRUCT(1 as id), 116 | STRUCT(2 as id), 117 | STRUCT(3 as id) 118 | ]) 119 | """, 120 | ) 121 | DEPTH = 3 122 | WIDTH = 5 123 | df = df.select("id", _build_huge_struct(f.lit(1), depth=DEPTH, width=WIDTH).alias("s")).persist() 124 | actual = analyze(df) 125 | actual.show(simplify_structs=True) 126 | 127 | 128 | def test_analyze_with_chunks(bq: BigQueryBuilder): 129 | df = get_test_df() 130 | actual = analyze(df, _chunk_size=1) 131 | assert actual.show_string(simplify_structs=True) == _get_expected() 132 | -------------------------------------------------------------------------------- /tests/transformations_impl/test_pivot_unpivot.py: -------------------------------------------------------------------------------- 1 | from google.cloud.bigquery import Row 2 | 3 | from bigquery_frame import BigQueryBuilder 4 | from bigquery_frame import functions as f 5 | from bigquery_frame.transformations_impl.pivot_unpivot import __get_test_pivoted_df as get_test_pivoted_df 6 | from bigquery_frame.transformations_impl.pivot_unpivot import __get_test_unpivoted_df as get_test_unpivoted_df 7 | from bigquery_frame.transformations_impl.pivot_unpivot import pivot, unpivot 8 | 9 | 10 | def test_pivot_v2(bq: BigQueryBuilder): 11 | df = get_test_unpivoted_df(bq) 12 | pivoted = pivot( 13 | df, 14 | pivot_column="country", 15 | aggs=["sum(amount)"], 16 | ) 17 | expected = get_test_pivoted_df(bq) 18 | assert pivoted.collect() == expected.collect() 19 | 20 | 21 | def test_pivot_v2_case_sensitive(bq: BigQueryBuilder): 22 | df = get_test_unpivoted_df(bq) 23 | pivoted = pivot( 24 | df, 25 | pivot_column="COUNTRY", 26 | aggs=["SUM(AMOUNT)"], 27 | ) 28 | expected = get_test_pivoted_df(bq) 29 | assert pivoted.collect() == expected.collect() 30 | 31 | 32 | def test_pivot_v2_with_col_aggs(bq: BigQueryBuilder): 33 | df = get_test_unpivoted_df(bq) 34 | pivoted = pivot( 35 | df, 36 | pivot_column="country", 37 | aggs=[f.sum(f.col("amount"))], 38 | ) 39 | expected = get_test_pivoted_df(bq) 40 | assert pivoted.collect() == expected.collect() 41 | 42 | 43 | def test_pivot_v2_with_multiple_aggs(bq: BigQueryBuilder): 44 | df = get_test_unpivoted_df(bq) 45 | pivoted = pivot( 46 | df.drop("product"), 47 | pivot_column="country", 48 | aggs=["sum(amount) as total_amount", f.count(f.col("year")).alias("nb_years")], 49 | ) 50 | 51 | expected = [ 52 | Row( 53 | (9000, 8, 10200, 8, 9400, 8), 54 | { 55 | "total_amount_Canada": 0, 56 | "nb_years_Canada": 1, 57 | "total_amount_China": 2, 58 | "nb_years_China": 3, 59 | "total_amount_Mexico": 4, 60 | "nb_years_Mexico": 5, 61 | }, 62 | ), 63 | ] 64 | assert pivoted.collect() == expected 65 | 66 | 67 | def test_unpivot_v1(bq: BigQueryBuilder): 68 | df = get_test_pivoted_df(bq) 69 | unpivoted = unpivot( 70 | df, 71 | ["year", "product"], 72 | key_alias="country", 73 | value_alias="amount", 74 | implem_version=1, 75 | ) 76 | expected = get_test_unpivoted_df(bq) 77 | assert
unpivoted.collect() == expected.collect() 78 | 79 | 80 | def test_unpivot_v2(bq: BigQueryBuilder): 81 | df = get_test_pivoted_df(bq) 82 | unpivoted = unpivot( 83 | df, 84 | ["year", "product"], 85 | key_alias="country", 86 | value_alias="amount", 87 | implem_version=2, 88 | ) 89 | unpivoted = unpivoted.select("year", "product", "country", "amount") 90 | expected = get_test_unpivoted_df(bq) 91 | assert ( 92 | unpivoted.sort("year", "product", "country").collect() == expected.sort("year", "product", "country").collect() 93 | ) 94 | 95 | 96 | def test_unpivot_v1_exclude_nulls(bq: BigQueryBuilder): 97 | df = get_test_pivoted_df(bq) 98 | unpivoted = unpivot( 99 | df, 100 | ["year", "product"], 101 | key_alias="country", 102 | value_alias="amount", 103 | exclude_nulls=True, 104 | implem_version=1, 105 | ) 106 | expected = get_test_unpivoted_df(bq).where("amount IS NOT NULL") 107 | assert unpivoted.collect() == expected.collect() 108 | 109 | 110 | def test_unpivot_v2_exclude_nulls(bq: BigQueryBuilder): 111 | df = get_test_pivoted_df(bq) 112 | unpivoted = unpivot( 113 | df, 114 | ["year", "product"], 115 | key_alias="country", 116 | value_alias="amount", 117 | exclude_nulls=True, 118 | implem_version=2, 119 | ) 120 | unpivoted = unpivoted.select("year", "product", "country", "amount") 121 | expected = get_test_unpivoted_df(bq).where("amount IS NOT NULL") 122 | assert ( 123 | unpivoted.sort("year", "product", "country").collect() == expected.sort("year", "product", "country").collect() 124 | ) 125 | 126 | 127 | def test_unpivot_v1_with_no_pivot_column(bq: BigQueryBuilder): 128 | df = get_test_pivoted_df(bq).drop("year", "product") 129 | unpivoted = unpivot( 130 | df, 131 | pivot_columns=[], 132 | key_alias="country", 133 | value_alias="amount", 134 | implem_version=1, 135 | ) 136 | unpivoted.show() 137 | expected = get_test_unpivoted_df(bq).select("country", "amount") 138 | assert unpivoted.collect() == expected.collect() 139 | 140 | 141 | def test_unpivot_v2_with_no_pivot_column(bq: BigQueryBuilder): 142 | df = get_test_pivoted_df(bq).drop("year", "product") 143 | unpivoted = unpivot( 144 | df, 145 | pivot_columns=[], 146 | key_alias="country", 147 | value_alias="amount", 148 | implem_version=2, 149 | ) 150 | unpivoted = unpivoted.select("country", "amount") 151 | expected = get_test_unpivoted_df(bq).select("country", "amount") 152 | assert unpivoted.collect() == expected.collect() 153 | -------------------------------------------------------------------------------- /tests/transformations_impl/test_transform_all_fields.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from google.cloud.bigquery import SchemaField 4 | 5 | from bigquery_frame import BigQueryBuilder, Column, nested 6 | from bigquery_frame.transformations import transform_all_fields 7 | from bigquery_frame.utils import strip_margin 8 | 9 | WEIRD_CHARS = "_:<ù%µ> &é-+'è_çà=#|" 10 | 11 | 12 | def test_transform_all_fields_with_weird_column_names(bq: BigQueryBuilder): 13 | df = bq.sql( 14 | f"""SELECT 15 | "John" as `name`, 16 | [STRUCT(1 as `a{WEIRD_CHARS}`), STRUCT(2 as `a{WEIRD_CHARS}`)] as s1, 17 | [STRUCT([1, 2] as `a{WEIRD_CHARS}`), STRUCT([3, 4] as `a{WEIRD_CHARS}`)] as s2, 18 | [ 19 | STRUCT([ 20 | STRUCT(STRUCT(1 as `c{WEIRD_CHARS}`) as `b{WEIRD_CHARS}`), 21 | STRUCT(STRUCT(2 as `c{WEIRD_CHARS}`) as `b{WEIRD_CHARS}`) 22 | ] as `a{WEIRD_CHARS}`), 23 | STRUCT([ 24 | STRUCT(STRUCT(3 as `c{WEIRD_CHARS}`) as `b{WEIRD_CHARS}`), 25 | STRUCT(STRUCT(4 as 
`c{WEIRD_CHARS}`) as `b{WEIRD_CHARS}`) 26 | ] as `a{WEIRD_CHARS}`) 27 | ] as s3 28 | """, 29 | ) 30 | assert nested.schema_string(df) == strip_margin( 31 | f""" 32 | |root 33 | | |-- name: STRING (nullable = true) 34 | | |-- s1!.a{WEIRD_CHARS}: INTEGER (nullable = true) 35 | | |-- s2!.a{WEIRD_CHARS}!: INTEGER (nullable = false) 36 | | |-- s3!.a{WEIRD_CHARS}!.b{WEIRD_CHARS}.c{WEIRD_CHARS}: INTEGER (nullable = true) 37 | |""", 38 | ) 39 | assert df.show_string(simplify_structs=True) == strip_margin( 40 | """ 41 | |+------+------------+----------------------+--------------------------------------+ 42 | || name | s1 | s2 | s3 | 43 | |+------+------------+----------------------+--------------------------------------+ 44 | || John | [{1}, {2}] | [{[1, 2]}, {[3, 4]}] | [{[{{1}}, {{2}}]}, {[{{3}}, {{4}}]}] | 45 | |+------+------------+----------------------+--------------------------------------+""", 46 | ) 47 | 48 | def cast_int_as_double(col: Column, schema_field: SchemaField) -> Optional[Column]: 49 | if schema_field.field_type == "INTEGER" and schema_field.mode != "REPEATED": 50 | return col.cast("FLOAT64") 51 | 52 | actual = transform_all_fields(df, cast_int_as_double) 53 | assert nested.schema_string(actual) == strip_margin( 54 | f""" 55 | |root 56 | | |-- name: STRING (nullable = true) 57 | | |-- s1!.a{WEIRD_CHARS}: FLOAT (nullable = true) 58 | | |-- s2!.a{WEIRD_CHARS}!: FLOAT (nullable = false) 59 | | |-- s3!.a{WEIRD_CHARS}!.b{WEIRD_CHARS}.c{WEIRD_CHARS}: FLOAT (nullable = true) 60 | |""", 61 | ) 62 | assert actual.show_string(simplify_structs=True) == strip_margin( 63 | """ 64 | |+------+----------------+------------------------------+----------------------------------------------+ 65 | || name | s1 | s2 | s3 | 66 | |+------+----------------+------------------------------+----------------------------------------------+ 67 | || John | [{1.0}, {2.0}] | [{[1.0, 2.0]}, {[3.0, 4.0]}] | [{[{{1.0}}, {{2.0}}]}, {[{{3.0}}, {{4.0}}]}] | 68 | |+------+----------------+------------------------------+----------------------------------------------+""", 69 | ) 70 | -------------------------------------------------------------------------------- /tests/transformations_impl/test_union_dataframes.py: -------------------------------------------------------------------------------- 1 | from bigquery_frame import BigQueryBuilder 2 | from bigquery_frame.transformations_impl.union_dataframes import union_dataframes 3 | 4 | 5 | def test_union_dataframes(bq: BigQueryBuilder): 6 | df1 = bq.sql("""SELECT 1 as id""") 7 | df2 = bq.sql("""SELECT 2 as id""") 8 | df3 = bq.sql("""SELECT 3 as id""") 9 | actual = union_dataframes([df1, df2, df3]) 10 | expected = bq.sql("""SELECT id FROM UNNEST([1, 2, 3]) as id""") 11 | assert actual.collect() == expected.collect() 12 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from contextlib import contextmanager 3 | from io import StringIO 4 | 5 | 6 | @contextmanager 7 | def captured_output(): 8 | new_out, new_err = StringIO(), StringIO() 9 | old_out, old_err = sys.stdout, sys.stderr 10 | try: 11 | sys.stdout, sys.stderr = new_out, new_err 12 | yield sys.stdout, sys.stderr 13 | finally: 14 | sys.stdout, sys.stderr = old_out, old_err 15 | --------------------------------------------------------------------------------
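For reference, a minimal usage sketch (not a file from the repository) showing how the `captured_output` helper in tests/utils.py can be used to assert on what a DataFrame prints to stdout; the test name and the asserted substring are illustrative assumptions, and `bq` is the builder fixture used throughout the test suite:

from tests.utils import captured_output


def test_show_prints_dataframe(bq):
    df = bq.sql("SELECT 1 as a")
    with captured_output() as (out, err):
        df.show()  # writes the rendered table to the temporarily swapped stdout
    # out is the StringIO buffer that replaced sys.stdout inside the block,
    # so the full printed table can be inspected after the context exits.
    assert "1" in out.getvalue()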