├── .dockerignore
├── .github
└── workflows
│ └── ci.yml
├── .gitignore
├── Dockerfile
├── LICENSE.txt
├── NOTICE
├── README.md
├── data
└── .gitignore
├── helm-chart
├── .helmignore
├── Chart.yaml
├── install_helm_chart.sh
├── secrets
│ ├── .gitignore
│ ├── README.md
│ ├── opaque
│ │ ├── .gitignore
│ │ └── README.md
│ └── tls
│ │ ├── .gitignore
│ │ └── README.md
├── templates
│ ├── NOTES.txt
│ ├── _helpers.tpl
│ ├── deployment.yaml
│ ├── http-service.yaml
│ ├── secret.yaml
│ ├── service.yaml
│ └── tls-secret.yaml
└── values.yaml
├── logs
└── .gitignore
├── pyproject.toml
├── scripts
├── start_flight_ibis_server.sh
└── test_flight_ibis.sh
├── spark_connector
└── flight-spark-source-1.0-SNAPSHOT-shaded.jar
├── src
└── flight_ibis
│ ├── __init__.py
│ ├── client.py
│ ├── config.py
│ ├── constants.py
│ ├── data_logic_ibis.py
│ ├── examples
│ ├── client_arrow_example.py
│ ├── client_cookbook_example.py
│ ├── server_arrow_example.py
│ └── server_cookbook_example.py
│ ├── server.py
│ ├── setup
│ ├── __init__.py
│ ├── create_local_duckdb_database.py
│ ├── mtls_utilities.py
│ └── tls_utilities.py
│ └── spark_client_test.py
└── tls
└── .gitignore
/.dockerignore:
--------------------------------------------------------------------------------
1 | venv/
2 | build/
3 | src/flight_ibis.egg-info
4 | .idea/
5 | tls/*
6 | data/*
7 | logs/*
8 | Dockerfile
9 | .gitignore
10 | k8s/
11 | helm-chart/
12 | certbot/
13 | .github/
14 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: flight-ibis-ci
2 |
3 | on:
4 | workflow_dispatch:
5 | release:
6 | types:
7 | - published
8 |
9 | env:
10 | DOCKER_IMAGE_NAME: gizmodata/flight-ibis
11 |
12 | jobs:
13 | docker:
14 | name: Build Docker images
15 | strategy:
16 | matrix:
17 | include:
18 | - platform: amd64
19 | runner: buildjet-8vcpu-ubuntu-2204
20 | - platform: arm64
21 | runner: buildjet-8vcpu-ubuntu-2204-arm
22 | runs-on: ${{ matrix.runner }}
23 | steps:
24 | - name: Login to Docker Hub
25 | uses: docker/login-action@v2
26 | with:
27 | username: ${{ secrets.DOCKERHUB_USERNAME }}
28 | password: ${{ secrets.DOCKERHUB_PASSWORD }}
29 |
30 | - name: Build and push
31 | uses: docker/build-push-action@v5
32 | with:
33 | platforms: linux/${{ matrix.platform }}
34 | push: true
35 | tags: |
36 | ${{ env.DOCKER_IMAGE_NAME }}:latest-${{ matrix.platform }}
37 | ${{ env.DOCKER_IMAGE_NAME }}:${{ github.ref_name }}-${{ matrix.platform }}
38 | no-cache: true
39 | provenance: false
40 |
41 | update-image-manifest:
42 | name: Update DockerHub image manifest to include all built platforms
43 | needs: docker
44 | runs-on: ubuntu-latest
45 | steps:
46 | - name: Login to Docker Hub
47 | uses: docker/login-action@v2
48 | with:
49 | username: ${{ secrets.DOCKERHUB_USERNAME }}
50 | password: ${{ secrets.DOCKERHUB_PASSWORD }}
51 |
52 | - name: Create and push manifest images
53 |         uses: Noelware/docker-manifest-action@master # TODO: pin to a tagged release (see the Releases tab) — tracking a mutable branch is a supply-chain risk and can break the build without warning
54 | with:
55 | inputs: ${{ env.DOCKER_IMAGE_NAME }}:latest,${{ env.DOCKER_IMAGE_NAME }}:${{ github.ref_name }}
56 | images: ${{ env.DOCKER_IMAGE_NAME }}:latest-amd64,${{ env.DOCKER_IMAGE_NAME }}:latest-arm64
57 | push: true
58 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv/
2 | build/
3 | src/flight_ibis.egg-info
4 | .idea/
5 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.12.7
2 |
3 | ARG TARGETPLATFORM
4 | ARG TARGETARCH
5 | ARG TARGETVARIANT
6 | RUN printf "I'm building for TARGETPLATFORM=${TARGETPLATFORM}" \
7 | && printf ", TARGETARCH=${TARGETARCH}" \
8 | && printf ", TARGETVARIANT=${TARGETVARIANT} \n" \
9 | && printf "With uname -s : " && uname -s \
10 | && printf "and uname -m : " && uname -m
11 |
12 | # Update OS and install packages
13 | RUN apt-get update --yes && \
14 | apt-get dist-upgrade --yes && \
15 | apt-get install --yes \
16 | screen \
17 | unzip \
18 | vim \
19 | zip && \
20 | apt-get clean && \
21 | rm -rf /var/lib/apt/lists/*
22 |
23 | # Setup the AWS Client
24 | WORKDIR /tmp
25 |
26 | RUN case ${TARGETPLATFORM} in \
27 | "linux/amd64") AWSCLI_FILE=https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip ;; \
28 | "linux/arm64") AWSCLI_FILE=https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip ;; \
29 | esac && \
30 | curl "${AWSCLI_FILE}" -o "awscliv2.zip" && \
31 | unzip awscliv2.zip && \
32 | ./aws/install && \
33 | rm -f awscliv2.zip
34 |
35 | # Install DuckDB CLI
36 | ARG DUCKDB_VERSION="1.1.2"
37 |
38 | RUN case ${TARGETPLATFORM} in \
39 | "linux/amd64") DUCKDB_FILE=https://github.com/duckdb/duckdb/releases/download/v${DUCKDB_VERSION}/duckdb_cli-linux-amd64.zip ;; \
40 | "linux/arm64") DUCKDB_FILE=https://github.com/duckdb/duckdb/releases/download/v${DUCKDB_VERSION}/duckdb_cli-linux-aarch64.zip ;; \
41 | esac && \
42 | curl --output /tmp/duckdb.zip --location ${DUCKDB_FILE} && \
43 | unzip /tmp/duckdb.zip -d /usr/bin && \
44 | rm /tmp/duckdb.zip
45 |
46 | # Create an application user
47 | RUN useradd app_user --create-home
48 |
49 | ARG APP_DIR="/opt/flight_ibis"
50 | RUN mkdir --parents ${APP_DIR} && \
51 | chown app_user:app_user ${APP_DIR}
52 |
53 | USER app_user
54 |
55 | WORKDIR ${APP_DIR}
56 |
57 | # Setup a Python Virtual environment
58 | ENV VIRTUAL_ENV=${APP_DIR}/venv
59 | RUN python3 -m venv ${VIRTUAL_ENV} && \
60 | echo ". ${VIRTUAL_ENV}/bin/activate" >> ~/.bashrc
61 |
62 | # Set the PATH so that the Python Virtual environment is referenced for subsequent RUN steps (hat tip: https://pythonspeed.com/articles/activate-virtualenv-dockerfile/)
63 | ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
64 |
65 | # Upgrade pip, setuptools, and wheel
66 | RUN pip install --upgrade setuptools pip wheel
67 |
68 | COPY --chown=app_user:app_user . .
69 |
70 | # Install the Flight Ibis package and remove the source files
71 | RUN pip install . && \
72 | rm -rf src build && \
73 | rm -f pyproject.toml
74 |
75 | # Build the seed database
76 | RUN flight-data-bootstrap
77 |
78 | # Run a test to ensure that the server works...
79 | RUN scripts/test_flight_ibis.sh
80 |
81 | # Open Flight server port
82 | EXPOSE 8815
83 |
84 | # Define our entrypoint to start Flight Ibis server
85 | ENTRYPOINT scripts/start_flight_ibis_server.sh
86 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2024 Gizmo Data LLC
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | This product includes software developed by Voltron Data.
2 | Modifications made by Philip Moore / Gizmo Data LLC in 2024.
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Arrow Flight Ibis Server demo
2 | [
](https://github.com/gizmodata/flight-ibis-demo)
3 | [
](https://hub.docker.com/repository/docker/gizmodata/flight-ibis/general)
4 | [](https://github.com/gizmodata/flight-ibis-demo/actions/workflows/ci.yml)
5 |
6 | ## Setup (to run locally)
7 |
8 | ### Install package
9 |
10 | #### 1. Clone the repo
11 | ```shell
12 | git clone https://github.com/gizmodata/flight-ibis-demo.git
13 |
14 | ```
15 |
16 | #### 2. Setup Python
17 | Create a new Python 3.8+ virtual environment and install the Flight server/client demo with:
18 | ```shell
19 | cd flight-ibis-demo
20 |
21 | # Create the virtual environment
22 | python3 -m venv .venv
23 |
24 | # Activate the virtual environment
25 | . .venv/bin/activate
26 |
27 | # Upgrade pip, setuptools, and wheel
28 | pip install --upgrade pip setuptools wheel
29 |
30 | # Install the flight-ibis demo
31 | pip install --editable .
32 |
33 | ```
34 |
35 | ### Note
36 | For the following commands - if you are running from source and using `--editable` mode (for development purposes) - you will need to set the PYTHONPATH environment variable as follows:
37 | ```shell
38 | export PYTHONPATH=$(pwd)/src
39 | ```
40 |
41 | #### 3. Create a sample TPC-H 1GB database (will take about 243MB of disk space due to compression)
42 | ```shell
43 | . .venv/bin/activate
44 | flight-data-bootstrap
45 |
46 | ```
47 | ## Running the Flight Ibis Server and Client
48 | We've provided 4 examples of how to run the Flight Ibis Server and Client - starting from the least to the most secure.
49 |
50 | ### Option 1) Running the Flight Ibis Server / Client demo without TLS (NOT secure)
51 |
52 | #### Run the example
53 | ##### 1. Run the Flight Server
54 | ```shell
55 | . .venv/bin/activate
56 | flight-server
57 |
58 | ```
59 |
60 | ##### 2. Open another terminal (leave the server running) - and run the Flight Client
61 | ```shell
62 | . .venv/bin/activate
63 | flight-client
64 |
65 | ```
66 |
67 | ### Option 2) Running the Flight Ibis Server / Client demo with TLS (somewhat secure)
68 |
69 | #### Run the example
70 | ##### 1. Generate a localhost TLS certificate keypair
71 | ```shell
72 | . .venv/bin/activate
73 | flight-create-tls-keypair
74 |
75 | ```
76 |
77 | ##### 2. Run the Flight Server with TLS enabled (using the keypair created in step #1 above)
78 | ```shell
79 | . .venv/bin/activate
80 | flight-server --tls=tls/server.crt tls/server.key
81 |
82 | ```
83 |
84 | ##### 3. Open another terminal (leave the server running) - and run the Flight Client with TLS enabled (trusting your cert created in step #1)
85 | ```shell
86 | . .venv/bin/activate
87 | flight-client --host=localhost \
88 | --tls \
89 | --tls-roots=tls/server.crt
90 |
91 | ```
92 |
93 | ### Option 3) Running the Flight Ibis Server / Client demo with TLS and MTLS authentication (more secure)
94 |
95 | #### Run the example
96 | ##### 1. Generate a localhost TLS certificate keypair
97 | ```shell
98 | . .venv/bin/activate
99 | flight-create-tls-keypair
100 |
101 | ```
102 |
103 | ##### 2. Generate a Certificate Authority (CA) keypair - used to sign client certificates
104 | ```shell
105 | . .venv/bin/activate
106 | flight-create-mtls-ca-keypair
107 |
108 | ```
109 |
110 | ##### 3. Generate a Client Certificate keypair (signed by the CA you just created in step #2 above)
111 | ```shell
112 | . .venv/bin/activate
113 | flight-create-mtls-client-keypair
114 |
115 | ```
116 |
117 | ##### 4. Run the Flight Server with TLS and MTLS enabled (using the certificates created in the steps above)
118 | ```shell
119 | . .venv/bin/activate
120 | flight-server --tls=tls/server.crt tls/server.key \
121 | --verify-client \
122 | --mtls=tls/ca.crt
123 |
124 | ```
125 |
126 | ##### 5. Open another terminal (leave the server running) - and run the Flight Client with TLS and MTLS enabled
127 | ```shell
128 | . .venv/bin/activate
129 | flight-client --host=localhost \
130 | --tls \
131 | --tls-roots=tls/server.crt \
132 | --mtls=tls/client.crt tls/client.key
133 | ```
134 |
135 | ### Option 4) Running the Flight Ibis Server / Client demo with username/password authentication, TLS, and MTLS authentication (most secure)
136 |
137 | #### Run the example
138 | ##### 1. Generate a localhost TLS certificate keypair
139 | ```shell
140 | . .venv/bin/activate
141 | flight-create-tls-keypair
142 |
143 | ```
144 |
145 | ##### 2. Generate a Certificate Authority (CA) keypair - used to sign client certificates
146 | ```shell
147 | . .venv/bin/activate
148 | flight-create-mtls-ca-keypair
149 |
150 | ```
151 |
152 | ##### 3. Generate a Client Certificate keypair (signed by the CA you just created in step #2 above)
153 | ```shell
154 | . .venv/bin/activate
155 | flight-create-mtls-client-keypair
156 |
157 | ```
158 |
159 | ##### 4. Run the Flight Server requiring a specified username/password - with TLS and MTLS enabled (using the certificates created in the steps above)
160 | ```shell
161 | . .venv/bin/activate
162 | flight-server --tls=tls/server.crt tls/server.key \
163 | --verify-client \
164 | --mtls=tls/ca.crt \
165 | --flight-username=test \
166 | --flight-password=testing123
167 |
168 | ```
169 |
170 | ##### 5. Open another terminal (leave the server running) - and run the Flight Client using the same username/password - with TLS and MTLS enabled
171 | ```shell
172 | . .venv/bin/activate
173 | flight-client --host=localhost \
174 | --tls \
175 | --tls-roots=tls/server.crt \
176 | --mtls=tls/client.crt tls/client.key \
177 | --flight-username=test \
178 | --flight-password=testing123
179 | ```
180 |
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
--------------------------------------------------------------------------------
/helm-chart/.helmignore:
--------------------------------------------------------------------------------
1 | # Patterns to ignore when building packages.
2 | # This supports shell glob matching, relative path matching, and
3 | # negation (prefixed with !). Only one pattern per line.
4 | .DS_Store
5 | # Common VCS dirs
6 | .git/
7 | .gitignore
8 | .bzr/
9 | .bzrignore
10 | .hg/
11 | .hgignore
12 | .svn/
13 | # Common backup files
14 | *.swp
15 | *.bak
16 | *.tmp
17 | *.orig
18 | *~
19 | # Various IDEs
20 | .project
21 | .idea/
22 | *.tmproj
23 | .vscode/
24 |
--------------------------------------------------------------------------------
/helm-chart/Chart.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v2
2 | name: flight-ibis
3 | description: A Helm chart for Kubernetes
4 |
5 | # A chart can be either an 'application' or a 'library' chart.
6 | #
7 | # Application charts are a collection of templates that can be packaged into versioned archives
8 | # to be deployed.
9 | #
10 | # Library charts provide useful utilities or functions for the chart developer. They're included as
11 | # a dependency of application charts to inject those utilities and functions into the rendering
12 | # pipeline. Library charts do not define any templates and therefore cannot be deployed.
13 | type: application
14 |
15 | # This is the chart version. This version number should be incremented each time you make changes
16 | # to the chart and its templates, including the app version.
17 | # Versions are expected to follow Semantic Versioning (https://semver.org/)
18 | version: 0.1.0
19 |
20 | # This is the version number of the application being deployed. This version number should be
21 | # incremented each time you make changes to the application. Versions are not expected to
22 | # follow Semantic Versioning. They should reflect the version the application is using.
23 | # It is recommended to use it with quotes.
24 | appVersion: "1.16.0"
25 |
--------------------------------------------------------------------------------
/helm-chart/install_helm_chart.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | kubectl config set-context --current --namespace=flight
4 |
5 | helm upgrade test --install .
6 |
--------------------------------------------------------------------------------
/helm-chart/secrets/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
--------------------------------------------------------------------------------
/helm-chart/secrets/README.md:
--------------------------------------------------------------------------------
1 | # Needed files to support the helm chart
2 | Use the "opaque" folder to create opaque secrets.
3 |
4 | Use the "tls" folder to create tls secrets (server.crt, server.key, ca.crt)
5 |
6 | The files are git ignored for security reasons.
7 |
--------------------------------------------------------------------------------
/helm-chart/secrets/opaque/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
--------------------------------------------------------------------------------
/helm-chart/secrets/opaque/README.md:
--------------------------------------------------------------------------------
1 | # Needed files to support the helm chart
2 | Create the following files in this folder with the EXACT names shown (no extensions):
3 | 1. FLIGHT_USERNAME
4 | 2. FLIGHT_PASSWORD
5 |
6 | You may set the contents of the files however you choose.
7 |
8 | The files are git ignored for security reasons.
9 |
--------------------------------------------------------------------------------
/helm-chart/secrets/tls/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
--------------------------------------------------------------------------------
/helm-chart/secrets/tls/README.md:
--------------------------------------------------------------------------------
1 | # Needed files to support the helm chart
2 | Create the following files in this folder with the EXACT names shown (no extensions):
3 | 1. server.crt - Full-chain TLS Public Certificate
4 | 2. server.key
5 |
6 | Descriptions:
7 | server.crt - Full-chain TLS Public Certificate
8 | server.key - TLS Private Key
9 |
10 | The files are git ignored for security reasons.
11 |
--------------------------------------------------------------------------------
/helm-chart/templates/NOTES.txt:
--------------------------------------------------------------------------------
1 | 1. Get the application URL by running these commands:
2 | {{- if contains "NodePort" .Values.service.type }}
3 | export NODE_PORT=$(kubectl get --namespace {{ .Release.Namespace }} -o jsonpath="{.spec.ports[0].nodePort}" services {{ include "flight-ibis.fullname" . }})
4 | export NODE_IP=$(kubectl get nodes --namespace {{ .Release.Namespace }} -o jsonpath="{.items[0].status.addresses[0].address}")
5 | echo http://$NODE_IP:$NODE_PORT
6 | {{- else if contains "LoadBalancer" .Values.service.type }}
7 | NOTE: It may take a few minutes for the LoadBalancer IP to be available.
8 |            You can watch the status of it by running 'kubectl get --namespace {{ .Release.Namespace }} svc -w {{ include "flight-ibis.fullname" . }}'
9 | export SERVICE_IP=$(kubectl get svc --namespace {{ .Release.Namespace }} {{ include "flight-ibis.fullname" . }} --template "{{"{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}"}}")
10 | echo http://$SERVICE_IP:{{ .Values.service.port }}
11 | {{- else if contains "ClusterIP" .Values.service.type }}
12 | export POD_NAME=$(kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "flight-ibis.name" . }},app.kubernetes.io/instance={{ .Release.Name }}" -o jsonpath="{.items[0].metadata.name}")
13 | export CONTAINER_PORT=$(kubectl get pod --namespace {{ .Release.Namespace }} $POD_NAME -o jsonpath="{.spec.containers[0].ports[0].containerPort}")
14 | echo "Visit http://127.0.0.1:8080 to use your application"
15 | kubectl --namespace {{ .Release.Namespace }} port-forward $POD_NAME 8080:$CONTAINER_PORT
16 | {{- end }}
17 |
--------------------------------------------------------------------------------
/helm-chart/templates/_helpers.tpl:
--------------------------------------------------------------------------------
1 | {{/*
2 | Expand the name of the chart.
3 | */}}
4 | {{- define "flight-ibis.name" -}}
5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
6 | {{- end }}
7 |
8 | {{/*
9 | Create a default fully qualified app name.
10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
11 | If release name contains chart name it will be used as a full name.
12 | */}}
13 | {{- define "flight-ibis.fullname" -}}
14 | {{- if .Values.fullnameOverride }}
15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
16 | {{- else }}
17 | {{- $name := default .Chart.Name .Values.nameOverride }}
18 | {{- if contains $name .Release.Name }}
19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }}
20 | {{- else }}
21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
22 | {{- end }}
23 | {{- end }}
24 | {{- end }}
25 |
26 | {{/*
27 | Create chart name and version as used by the chart label.
28 | */}}
29 | {{- define "flight-ibis.chart" -}}
30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
31 | {{- end }}
32 |
33 | {{/*
34 | Common labels
35 | */}}
36 | {{- define "flight-ibis.labels" -}}
37 | helm.sh/chart: {{ include "flight-ibis.chart" . }}
38 | {{ include "flight-ibis.selectorLabels" . }}
39 | {{- if .Chart.AppVersion }}
40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
41 | {{- end }}
42 | app.kubernetes.io/managed-by: {{ .Release.Service }}
43 | {{- end }}
44 |
45 | {{/*
46 | Selector labels
47 | */}}
48 | {{- define "flight-ibis.selectorLabels" -}}
49 | app.kubernetes.io/name: {{ include "flight-ibis.name" . }}
50 | app.kubernetes.io/instance: {{ .Release.Name }}
51 | {{- end }}
52 |
53 | {{/*
54 | Create the name of the service account to use
55 | */}}
56 | {{- define "flight-ibis.serviceAccountName" -}}
57 | {{- if .Values.serviceAccount.create }}
58 | {{- default (include "flight-ibis.fullname" .) .Values.serviceAccount.name }}
59 | {{- else }}
60 | {{- default "default" .Values.serviceAccount.name }}
61 | {{- end }}
62 | {{- end }}
63 |
--------------------------------------------------------------------------------
/helm-chart/templates/deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: {{ include "flight-ibis.fullname" . }}
5 | labels:
6 | {{- include "flight-ibis.labels" . | nindent 4 }}
7 | spec:
8 | replicas: {{ .Values.replicaCount }}
9 | selector:
10 | matchLabels:
11 | {{- include "flight-ibis.selectorLabels" . | nindent 6 }}
12 | template:
13 | metadata:
14 | annotations:
15 | checksum/opaque-secrets: {{ include (print $.Template.BasePath "/secret.yaml") . | sha256sum }}
16 | checksum/tls-secrets: {{ include (print $.Template.BasePath "/tls-secret.yaml") . | sha256sum }}
17 | {{- with .Values.podAnnotations }}
18 | {{- toYaml . | nindent 8 }}
19 | {{- end }}
20 | labels:
21 | {{- include "flight-ibis.selectorLabels" . | nindent 8 }}
22 | spec:
23 | {{- with .Values.imagePullSecrets }}
24 | imagePullSecrets:
25 | {{- toYaml . | nindent 8 }}
26 | {{- end }}
27 | securityContext:
28 | {{- toYaml .Values.podSecurityContext | nindent 8 }}
29 | volumes:
30 | - name: tls-volume
31 | secret:
32 | secretName: {{ include "flight-ibis.fullname" . }}-tls-secret
33 | optional: false
34 | containers:
35 | - name: {{ .Chart.Name }}
36 | command: [ "/bin/bash" ]
37 | # args: [ "-c", "sleep infinity" ]
38 | args: [ "-c", "flight-server" ]
39 | env:
40 | - name: FLIGHT_LOCATION
41 | value: {{ .Values.FlightServerConfig.FLIGHT_LOCATION }}
42 | - name: FLIGHT_PORT
43 | value: {{ .Values.FlightServerConfig.FLIGHT_PORT | quote }}
44 | - name: DATABASE_FILE
45 | value: {{ .Values.FlightServerConfig.DATABASE_FILE }}
46 | - name: DUCKDB_THREADS
47 | value: {{ .Values.FlightServerConfig.DUCKDB_THREADS | quote }}
48 | - name: DUCKDB_MEMORY_LIMIT
49 | value: {{ .Values.FlightServerConfig.DUCKDB_MEMORY_LIMIT }}
50 | - name: FLIGHT_TLS
51 | value: {{ .Values.FlightServerConfig.FLIGHT_TLS }}
52 | - name: FLIGHT_VERIFY_CLIENT
53 | value: {{ .Values.FlightServerConfig.FLIGHT_VERIFY_CLIENT | quote }}
54 | - name: FLIGHT_MTLS
55 | value: {{ .Values.FlightServerConfig.FLIGHT_MTLS }}
56 | - name: FLIGHT_USERNAME
57 | valueFrom:
58 | secretKeyRef:
59 | name: {{ include "flight-ibis.fullname" . }}-secret
60 | key: FLIGHT_USERNAME
61 | optional: false
62 | - name: FLIGHT_PASSWORD
63 | valueFrom:
64 | secretKeyRef:
65 | name: {{ include "flight-ibis.fullname" . }}-secret
66 | key: FLIGHT_PASSWORD
67 | optional: false
68 | - name: LOG_LEVEL
69 | value: {{ .Values.FlightServerConfig.LOG_LEVEL | default "INFO" }}
70 | - name: MAX_FLIGHT_ENDPOINTS
71 | value: {{ .Values.replicaCount | quote }}
72 | volumeMounts:
73 | - mountPath: /opt/flight_ibis/tls
74 | readOnly: true
75 | name: tls-volume
76 | securityContext:
77 | {{- toYaml .Values.securityContext | nindent 12 }}
78 | image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
79 | imagePullPolicy: {{ .Values.image.pullPolicy }}
80 | ports:
81 | - name: flight
82 | containerPort: {{ .Values.service.port }}
83 | protocol: TCP
84 | resources:
85 | {{- toYaml .Values.resources | nindent 12 }}
86 | {{- with .Values.nodeSelector }}
87 | nodeSelector:
88 | {{- toYaml . | nindent 8 }}
89 | {{- end }}
90 | {{- with .Values.affinity }}
91 | affinity:
92 | {{- toYaml . | nindent 8 }}
93 | {{- end }}
94 | {{- with .Values.tolerations }}
95 | tolerations:
96 | {{- toYaml . | nindent 8 }}
97 | {{- end }}
98 |
--------------------------------------------------------------------------------
/helm-chart/templates/http-service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: {{ include "flight-ibis.fullname" . }}-http
5 | labels:
6 | {{- include "flight-ibis.labels" . | nindent 4 }}
7 | spec:
8 | type: ClusterIP
9 | ports:
10 | - port: 80
11 | targetPort: http
12 | protocol: TCP
13 | name: http
14 | selector:
15 | {{- include "flight-ibis.selectorLabels" . | nindent 4 }}
16 |
--------------------------------------------------------------------------------
/helm-chart/templates/secret.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Secret
3 | metadata:
4 | name: {{ include "flight-ibis.fullname" . }}-secret
5 | labels:
6 | {{- include "flight-ibis.labels" . | nindent 4 }}
7 | type: Opaque
8 | data:
9 | {{ (.Files.Glob "secrets/opaque/*").AsSecrets | indent 2 }}
10 |
--------------------------------------------------------------------------------
/helm-chart/templates/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: {{ include "flight-ibis.fullname" . }}
5 | labels:
6 | {{- include "flight-ibis.labels" . | nindent 4 }}
7 | spec:
8 | type: {{ .Values.service.type }}
9 | ports:
10 | - port: {{ .Values.service.port }}
11 | targetPort: flight
12 | protocol: TCP
13 | name: flight
14 | selector:
15 | {{- include "flight-ibis.selectorLabels" . | nindent 4 }}
16 |
--------------------------------------------------------------------------------
/helm-chart/templates/tls-secret.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Secret
3 | metadata:
4 | name: {{ include "flight-ibis.fullname" . }}-tls-secret
5 | labels:
6 | {{- include "flight-ibis.labels" . | nindent 4 }}
7 | type: Opaque
8 | data:
9 | {{ (.Files.Glob "secrets/tls/*").AsSecrets | indent 2 }}
10 |
--------------------------------------------------------------------------------
/helm-chart/values.yaml:
--------------------------------------------------------------------------------
1 | # Default values for flight-ibis.
2 | # This is a YAML-formatted file.
3 | # Declare variables to be passed into your templates.
4 |
5 | replicaCount: 3
6 |
7 | image:
8 | repository: gizmodata/flight-ibis
9 | pullPolicy: Always
10 | # Overrides the image tag whose default is the chart appVersion.
11 | tag: latest
12 |
13 | FlightServerConfig:
14 | FLIGHT_LOCATION: flight-ibis.vdfieldeng.com:8815
15 | FLIGHT_PORT: 8815
16 | DATABASE_FILE: data/tpch.duckdb
17 | DUCKDB_THREADS: 4
18 | DUCKDB_MEMORY_LIMIT: 12GB
19 | FLIGHT_TLS: tls/server.crt tls/server.key
20 | FLIGHT_VERIFY_CLIENT: "TRUE"
21 | FLIGHT_MTLS: tls/ca.crt
22 | LOG_LEVEL: DEBUG
23 |
24 | imagePullSecrets: []
25 | nameOverride: ""
26 | fullnameOverride: ""
27 |
28 | serviceAccount:
29 | # Specifies whether a service account should be created
30 | create: true
31 | # Annotations to add to the service account
32 | annotations: {}
33 | # The name of the service account to use.
34 | # If not set and create is true, a name is generated using the fullname template
35 | name: ""
36 |
37 | podAnnotations: {}
38 |
39 | podSecurityContext: {}
40 | # fsGroup: 2000
41 |
42 | securityContext: {}
43 | # capabilities:
44 | # drop:
45 | # - ALL
46 | # readOnlyRootFilesystem: true
47 | # runAsNonRoot: true
48 | # runAsUser: 1000
49 |
50 | service:
51 | type: LoadBalancer
52 | port: 8815
53 |
54 | resources:
55 | limits:
56 | cpu: 4
57 | memory: 16Gi
58 | requests:
59 | cpu: 4
60 | memory: 16Gi
61 |
62 | nodeSelector:
63 | eks.amazonaws.com/nodegroup: arm64-node-pool-20231002184408630400000022
64 |
65 | tolerations:
66 | - key: sidewinder
67 | operator: Equal
68 | value: "true"
69 | effect: NoSchedule
70 |
71 | affinity: {}
72 |
--------------------------------------------------------------------------------
/logs/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.setuptools]
6 | include-package-data = true
7 |
8 | [tool.setuptools.packages.find]
9 | where = ["./src"] # list of folders that contain the packages (["."] by default)
10 | include = ["*"] # package names should match these glob patterns (["*"] by default)
11 |
12 | [tool.setuptools.package-data]
13 | "*" = ["*.yaml"]
14 |
15 | [project]
16 | name = "flight-ibis"
17 | version = "0.0.34"
18 | description = "An Apache Arrow Flight server/client example powered by Ibis and DuckDB"
19 | readme = "README.md"
20 | authors = [{ name = "Philip Moore", email = "prmoore77@hotmail.com" }]
21 | license = { file = "LICENSE.txt" }
22 | classifiers = [
23 | "License :: OSI Approved :: Apache Software License",
24 | "Programming Language :: Python",
25 | "Programming Language :: Python :: 3",
26 | ]
27 | keywords = ["flight-ibis", "flight", "ibis"]
28 | dependencies = [
29 | "pyarrow==17.0.*",
30 | "click==8.1.*",
31 | "ibis-framework[duckdb]==9.5.*",
32 | "munch==4.0.*",
33 | "pyOpenSSL==24.2.*",
34 | "cryptography==43.0.*",
35 | "pyjwt==2.9.*",
36 | "codetiming==1.4.*",
37 | "pyspark==3.5.*",
38 | "urllib3==2.2.*"
39 | ]
40 | requires-python = ">=3.10"
41 |
42 | [project.optional-dependencies]
43 | dev = ["bumpver", "pip-tools", "pytest"]
44 |
45 | [project.urls]
46 | Homepage = "https://github.com/gizmodata/flight-ibis-demo"
47 |
48 | [project.scripts]
49 | flight-server = "flight_ibis.server:run_flight_server"
50 | flight-client = "flight_ibis.client:run_flight_client"
51 | flight-spark-client = "flight_ibis.spark_client_test:run_spark_flight_client"
52 | flight-data-bootstrap = "flight_ibis.setup.create_local_duckdb_database:create_local_duckdb_database"
53 | flight-create-tls-keypair = "flight_ibis.setup.tls_utilities:click_create_tls_keypair"
54 | flight-create-mtls-ca-keypair = "flight_ibis.setup.mtls_utilities:create_ca_keypair"
55 | flight-create-mtls-client-keypair = "flight_ibis.setup.mtls_utilities:create_client_keypair"
56 |
57 | [tool.bumpver]
58 | current_version = "0.0.34"
59 | version_pattern = "MAJOR.MINOR.PATCH[PYTAGNUM]"
60 | commit_message = "bump version {old_version} -> {new_version}"
61 | commit = true
62 | tag = true
63 |
64 | [tool.bumpver.file_patterns]
65 | "pyproject.toml" = [
66 | '^version = "{version}"$',
67 | '^current_version = "{version}"$',
68 | ]
69 | "src/flight_ibis/__init__.py" = [
70 | '^__version__ = "{version}"$',
71 | ]
72 |
--------------------------------------------------------------------------------
/scripts/start_flight_ibis_server.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Resolve the directory this script lives in (quoted so paths with spaces work)
SCRIPT_DIR=$(dirname "${0}")

# Generate TLS certificates if they are not present...
# Abort if we cannot change into the project root - otherwise the keypair
# would be generated in whatever directory we happen to be in.
pushd "${SCRIPT_DIR}/.." || exit 1
if [ ! -f tls/server.crt ]
then
   echo "Generating TLS certs..."
   flight-create-tls-keypair
fi
popd

# Start the server
flight-server
16 |
--------------------------------------------------------------------------------
/scripts/test_flight_ibis.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Start the Flight Ibis Server - in the background...
flight-server &

# Sleep for a few seconds to allow the server to have time for initialization...
sleep 10

# Run the client
flight-client
RC=$?

# Report the client's result
if [ ${RC} -ne 0 ]; then
  echo "flight-client failed with return code ${RC}"
else
  echo "flight-client succeeded with return code ${RC}"
fi

# Stop the server...
kill %1

# Exit with the code of the python test...
exit ${RC}
25 |
--------------------------------------------------------------------------------
/spark_connector/flight-spark-source-1.0-SNAPSHOT-shaded.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gizmodata/flight-ibis-demo/e2543826a5212fa380cb0286ce6a500f4c478e82/spark_connector/flight-spark-source-1.0-SNAPSHOT-shaded.jar
--------------------------------------------------------------------------------
/src/flight_ibis/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.34"
2 |
--------------------------------------------------------------------------------
/src/flight_ibis/client.py:
--------------------------------------------------------------------------------
1 | import pyarrow as pa
2 | import pyarrow.flight
3 | from datetime import date, datetime
4 | import json
5 | import click
6 | import logging
7 | import os
8 | from codetiming import Timer
9 | from .config import TIMER_TEXT, get_logger
10 | from .constants import LOCALHOST, DEFAULT_FLIGHT_PORT, GRPC_TCP_SCHEME, GRPC_TLS_SCHEME
11 |
12 |
@click.command()
@click.option(
    "--host",
    type=str,
    default=LOCALHOST,
    help="Address or hostname of the Flight server to connect to"
)
@click.option(
    "--port",
    type=int,
    default=DEFAULT_FLIGHT_PORT,
    help="Port number of the Flight server to connect to"
)
@click.option(
    "--tls/--no-tls",
    type=bool,
    default=False,
    show_default=True,
    required=True,
    help="Connect to the server with tls"
)
@click.option(
    "--tls-verify/--no-tls-verify",
    type=bool,
    default=(os.getenv("TLS_VERIFY", "TRUE").upper() == "TRUE"),
    show_default=True,
    help="Verify the server's TLS certificate hostname and signature. Using --no-tls-verify is insecure, only use for development purposes!"
)
@click.option(
    "--tls-roots",
    type=str,
    default=None,
    show_default=True,
    help="Path to trusted TLS certificate(s)"
)
@click.option(
    "--mtls",
    nargs=2,
    default=None,
    metavar=('CERTFILE', 'KEYFILE'),
    help="Enable mutual TLS (mTLS) - provide the client certificate and private key files"
)
@click.option(
    "--flight-username",
    type=str,
    default=os.getenv("FLIGHT_USERNAME"),
    required=False,
    show_default=False,
    help="The username used to connect to the Flight server"
)
@click.option(
    "--flight-password",
    type=str,
    default=os.getenv("FLIGHT_PASSWORD"),
    required=False,
    show_default=False,
    help="The password used to connect to the Flight server"
)
@click.option(
    "--log-level",
    type=click.Choice(["INFO", "DEBUG", "WARNING", "CRITICAL"], case_sensitive=False),
    default=os.getenv("LOG_LEVEL", "INFO"),
    required=True,
    help="The logging level to use"
)
@click.option(
    "--log-file",
    type=str,
    required=False,
    help="The log file to write to. If None, will just log to stdout"
)
@click.option(
    "--log-file-mode",
    type=click.Choice(["a", "w"], case_sensitive=True),
    default="w",
    help="The log file mode, use value: a for 'append', and value: w to overwrite..."
)
@click.option(
    "--from-date",
    type=click.DateTime(formats=["%Y-%m-%d"]),
    default=date(year=1994, month=1, day=1).isoformat(),
    required=True,
    help="The from date to use for the data filter - in ISO format (example: 2020-11-01) - for 01-November-2020"
)
@click.option(
    "--to-date",
    type=click.DateTime(formats=["%Y-%m-%d"]),
    default=date(year=1995, month=12, day=31).isoformat(),
    required=True,
    help="The to date to use for the data filter - in ISO format (example: 2020-11-01) - for 01-November-2020"
)
@click.option(
    "--num-endpoints",
    type=int,
    default=1,
    required=True,
    help="The number of endpoints to use for downloading the data - more endpoints mean smaller chunks of data (for reducing memory usage, etc.)"
)
@click.option(
    "--custkey-filter-value",
    type=int,
    default=None,
    required=False,
    help="The value to use for the customer key filter - if None, then no filter will be applied"
)
@click.option(
    "--requested-columns",
    type=str,
    default="*",
    required=True,
    help="A comma-separated list of columns to request from the server - example value: 'o_orderkey,o_custkey' - if '*', then all columns will be requested"
)
def run_flight_client(host: str,
                      port: int,
                      tls: bool,
                      tls_verify: bool,
                      tls_roots: str,
                      mtls: list,
                      flight_username: str,
                      flight_password: str,
                      log_level: str,
                      log_file: str,
                      log_file_mode: str,
                      from_date: datetime,
                      to_date: datetime,
                      num_endpoints: int,
                      custkey_filter_value: int,
                      requested_columns: str
                      ):
    """Connect to an Arrow Flight server and run the "get_golden_rule_facts" command.

    Builds the connection (optionally with TLS / mTLS and basic auth), sends a
    JSON command descriptor describing the date range, optional customer-key
    filter and requested columns, then downloads every returned Flight
    endpoint and logs row/byte totals.
    """
    logger = get_logger(filename=log_file,
                        filemode=log_file_mode,
                        logger_name="flight_client",
                        log_level=getattr(logging, log_level.upper())
                        )

    with Timer(name="Flight Client test",
               text=TIMER_TEXT,
               initial_text=True,
               logger=logger.info
               ):
        # Log the call arguments - minus the password, which must not hit the logs
        redacted_locals = {key: value for key, value in locals().items() if key not in ["flight_password"
                                                                                        ]
                           }
        logger.info(msg=f"run_flight_client - was called with args: {redacted_locals}")

        scheme = GRPC_TCP_SCHEME
        connection_args = {}
        if tls:
            scheme = GRPC_TLS_SCHEME
            if not tls_verify:
                connection_args["disable_server_verification"] = True
                logger.warning("TLS verification is disabled, this is insecure, only use for development purposes!")

            if tls_roots:
                with open(tls_roots, "rb") as root_certs:
                    connection_args["tls_root_certs"] = root_certs.read()
        if mtls:
            # mTLS only makes sense over a TLS channel
            if not tls:
                raise RuntimeError("TLS must be enabled in order to use MTLS, aborting.")
            else:
                with open(mtls[0], "rb") as cert_file:
                    mtls_cert_chain = cert_file.read()
                with open(mtls[1], "rb") as key_file:
                    mtls_private_key = key_file.read()
                connection_args["cert_chain"] = mtls_cert_chain
                connection_args["private_key"] = mtls_private_key

        flight_server_uri = f"{scheme}://{host}:{port}"
        client = pyarrow.flight.FlightClient(location=flight_server_uri,
                                             **connection_args)

        logger.info(msg=f"Connected to Flight Server at location: {flight_server_uri}")

        options = None
        if flight_username and flight_password:
            # Basic-auth credentials are sent in a header - refuse to do so in clear text
            if not tls:
                raise RuntimeError("TLS must be enabled in order to use authentication, aborting.")
            token_pair = client.authenticate_basic_token(username=flight_username.encode(),
                                                         password=flight_password.encode(),
                                                         )
            logger.debug(f"token_pair = {token_pair}")
            # The returned token header is attached to every subsequent call
            options = pa.flight.FlightCallOptions(headers=[token_pair])

        # Display session authentication info (if applicable)
        if flight_username:
            action = pyarrow.flight.Action("who-am-i", b"")
            who_am_i_results = list(client.do_action(action=action, options=options))[0]
            authenticated_user = who_am_i_results.body.to_pybytes().decode()
            if authenticated_user:
                logger.info(f"Authenticated to the Flight Server as user: {authenticated_user}")

        # Build the JSON command payload understood by the server
        arg_dict = dict(num_endpoints=num_endpoints,
                        min_date=from_date.isoformat(),
                        max_date=to_date.isoformat()
                        )
        command_dict = dict(command="get_golden_rule_facts",
                            kwargs=arg_dict
                            )
        if custkey_filter_value:
            command_dict.update(filters=[{"column": "o_custkey",
                                          "operator": "=",
                                          "value": custkey_filter_value
                                          }
                                         ])

        if requested_columns != "*":
            command_dict.update(columns=requested_columns.split(","))

        command_descriptor = pa.flight.FlightDescriptor.for_command(command=json.dumps(command_dict))
        logger.info(msg=f"Command Descriptor: {command_descriptor}")
        # Read content of the dataset
        flight = client.get_flight_info(command_descriptor, options)

        logger.debug(msg=f"Flight Schema: {flight.schema}")
        total_endpoints = 0
        total_chunks = 0
        total_rows = 0
        total_bytes = 0

        # Fetch each endpoint in turn and accumulate download statistics
        for endpoint in flight.endpoints:
            with Timer(name=f"Fetch data from Flight Server end-point: {endpoint}",
                       text=TIMER_TEXT,
                       initial_text=True,
                       logger=logger.debug
                       ):
                total_endpoints += 1
                reader = client.do_get(endpoint.ticket, options)
                for chunk in reader:
                    total_chunks += 1
                    data = chunk.data
                    if data:
                        logger.debug(msg=f"Chunk size - rows: {data.num_rows} / bytes: {data.nbytes}")
                        logger.debug(msg=f"Endpoint: {endpoint} / Chunk: {total_chunks}")
                        logger.info(msg=data.to_pandas().head())
                        total_rows += data.num_rows
                        total_bytes += data.nbytes

        logger.info(msg=f"Got {total_rows} rows total ({total_bytes} bytes) - from {total_endpoints} endpoint(s) ({total_chunks} total chunk(s))")
251 |
252 |
# Allow running the module directly - equivalent to the `flight-client` console script
if __name__ == '__main__':
    run_flight_client()
255 |
--------------------------------------------------------------------------------
/src/flight_ibis/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 | from pathlib import Path
5 |
# Constants
SCRIPT_DIR = Path(__file__).parent.resolve()  # directory containing this module
# NOTE(review): LOG_DIR and DATA_DIR are resolved relative to the CURRENT WORKING
# DIRECTORY, not this module - running from another directory changes them.
LOG_DIR = Path("logs").resolve()
DATA_DIR = Path("data").resolve()
DUCKDB_DB_FILE = DATA_DIR / "tpch.duckdb"  # default TPC-H DuckDB database location
DUCKDB_THREADS = 4  # default DuckDB worker threads
DUCKDB_MEMORY_LIMIT = "4GB"  # default DuckDB memory cap
# codetiming.Timer text template: {name} is the timer name, {:.4f} the elapsed seconds
TIMER_TEXT = "{name}: Elapsed time: {:.4f} seconds"
DEFAULT_FLIGHT_ENDPOINTS: int = 1

# Logging Constants
LOGGING_DATEFMT = '%Y-%m-%d %H:%M:%S %Z'
# Level name taken from the LOG_LEVEL env var (e.g. "DEBUG"); falls back to INFO
LOGGING_LEVEL = getattr(logging, os.getenv("LOG_LEVEL", "INFO"))
LOGGING_FORMAT = "%(asctime)s - %(levelname)-8s %(message)s"
# Verbose variant used at DEBUG level - adds module/function/line/process/thread detail
EXTENDED_LOGGING_FORMAT = LOGGING_FORMAT + " / Module: '%(module)s' / Function: '%(funcName)s' / LineNo: %(lineno)d / Process: %(process)d - '%(processName)s' / Thread: %(thread)d - '%(threadName)s'"
BASIC_LOGGING_KWARGS = dict(format=LOGGING_FORMAT,
                            datefmt=LOGGING_DATEFMT,
                            level=LOGGING_LEVEL
                            )
STDOUT_LOGGING_KWARGS = dict(stream=sys.stdout)
# When TRUE (default), authorization headers are redacted from logs
LOGGING_REDACT_AUTHORIZATION_HEADER = (os.getenv("LOGGING_REDACT_AUTHORIZATION_HEADER", "TRUE").upper() == "TRUE")
27 |
28 |
def get_logger(filename: str = None,
               filemode: str = "a",
               logger_name: str = None,
               log_level: int = LOGGING_LEVEL
               ):
    """Build and return a configured logger that always logs to stdout.

    :param filename: optional log file name created under LOG_DIR; if None, console only
    :param filemode: open mode for the log file - "a" to append, "w" to overwrite
    :param logger_name: name of the logger to configure; None configures the root logger
    :param log_level: logging level (e.g. logging.DEBUG) for the logger and its handlers
    :return: the configured logging.Logger
    """
    logger = logging.getLogger(name=logger_name)
    logger.setLevel(log_level)

    # logging.getLogger returns the SAME object for a given name, so a repeat
    # call here used to stack another handler onto it and every message was
    # then emitted multiple times.  Clear previously-attached handlers so
    # (re)configuration is idempotent.
    logger.handlers.clear()

    # Create a formatter for the log messages - use the verbose format at DEBUG level
    logging_format = LOGGING_FORMAT
    if log_level == logging.DEBUG:
        logging_format = EXTENDED_LOGGING_FORMAT

    formatter = logging.Formatter(fmt=logging_format)

    # Create a stream handler to log to stdout
    console_handler = logging.StreamHandler()
    console_handler.setLevel(level=log_level)
    console_handler.setFormatter(fmt=formatter)

    logger.addHandler(hdlr=console_handler)

    # Create a file handler to log to a file
    if filename:
        # Ensure the log directory exists - FileHandler does not create parent dirs
        LOG_DIR.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(filename=LOG_DIR / filename,
                                           mode=filemode)
        file_handler.setLevel(level=log_level)
        file_handler.setFormatter(fmt=formatter)
        logger.addHandler(hdlr=file_handler)

    return logger
60 |
--------------------------------------------------------------------------------
/src/flight_ibis/constants.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 |
# Constants
LOCALHOST_IP_ADDRESS: str = "0.0.0.0"  # bind-to-all-interfaces address (despite the "localhost" name)
LOCALHOST: str = "localhost"
DEFAULT_FLIGHT_PORT: int = 8815  # conventional Arrow Flight port
GRPC_TCP_SCHEME: str = "grpc+tcp"  # No TLS enabled...
GRPC_TLS_SCHEME: str = "grpc+tls"  # TLS-secured gRPC transport
# Sentinel date bounds used when a caller supplies no explicit range
BEGINNING_OF_TIME: datetime = datetime(year=1994, month=1, day=1)
END_OF_TIME: datetime = datetime(year=9999, month=12, day=25)  # Merry last Christmas!
PYARROW_UNKNOWN: int = -1  # PyArrow's "unknown" value for total_records / total_bytes
# JWT issuer / audience claims used when minting and validating session tokens
JWT_ISS = "Flight Ibis"
JWT_AUD = "Flight Ibis"
15 |
--------------------------------------------------------------------------------
/src/flight_ibis/data_logic_ibis.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import ibis
4 | from ibis import _
5 | from datetime import datetime
6 | import pyarrow
7 | from ibis.expr.types import Scalar
8 |
9 | from .config import DUCKDB_DB_FILE, DUCKDB_THREADS, DUCKDB_MEMORY_LIMIT, TIMER_TEXT, get_logger
10 | from codetiming import Timer
11 |
# Constants
INNER_JOIN = "inner"
SEMI_JOIN = "semi"
MAX_ORDER_TOTALPRICE = 500_000.00  # orders above this total price are filtered out
MAX_PERCENT_RANK = 0.98  # customers above this order-count percentile are filtered out

# Ibis parameters (global in scope) - bound to concrete values at execution
# time via build_param_map, so the same expression can be re-run per bucket.
p_min_date = ibis.param(type="date")
p_max_date = ibis.param(type="date")
p_total_hash_buckets = ibis.param(type="int")
p_hash_bucket_num = ibis.param(type="int")
23 |
24 |
def build_customer_order_summary_expr(conn: ibis.BaseBackend) -> ibis.Expr:
    """Build a cached per-customer order-count summary expression.

    Counts orders per customer, ranks customers by order count with a
    percent rank, and keeps only those at or below MAX_PERCENT_RANK.
    The result is cached so downstream expressions can reuse it.
    """
    orders = conn.table("orders")

    # Aggregate global order counts for use in a subsequent filter
    per_customer_counts = orders.group_by(_.o_custkey).aggregate(count_star=_.count())
    ranked_counts = per_customer_counts.mutate(
        order_count_percent_rank=ibis.percent_rank().over(ibis.window(order_by=_.count_star))
    )
    customer_order_summary_expr = ranked_counts.filter(
        _.order_count_percent_rank <= MAX_PERCENT_RANK
    ).cache()

    return customer_order_summary_expr
39 |
40 |
def build_golden_rules_ibis_expression(conn: ibis.BaseBackend,
                                       customer_order_summary_expr: ibis.Expr) -> ibis.Expr:
    """Build the parameterized "golden rules" fact expression.

    Joins filtered orders to filtered lineitems over the TPC-H tables,
    applying the business-rule filters below.  The expression references the
    module-level Ibis parameters (p_min_date, p_max_date,
    p_total_hash_buckets, p_hash_bucket_num), which must be bound via
    build_param_map at execution time.

    :param conn: an Ibis backend connection exposing the TPC-H tables
    :param customer_order_summary_expr: output of build_customer_order_summary_expr
    :return: an unexecuted Ibis expression of the joined order/lineitem facts
    """
    # Cast decimal money/quantity columns to double up-front
    orders = conn.table("orders").mutate(o_totalprice=_.o_totalprice.cast("double"))
    lineitem = conn.table("lineitem").mutate(l_quantity=_.l_quantity.cast("double"),
                                             l_extendedprice=_.l_extendedprice.cast("double"),
                                             l_discount=_.l_discount.cast("double"),
                                             l_tax=_.l_tax.cast("double")
                                             )
    region = conn.table("region")
    nation = conn.table("nation")
    customer = conn.table("customer")
    part = conn.table("part")

    # Filter orders to the hash bucket asked for
    # Filter out orders larger than MAX_ORDER_TOTALPRICE
    # (bucket = hash(o_orderkey) % total buckets; bucket numbers are 1-based,
    # hence the "- 1" when comparing)
    orders_prelim = (orders
                     .filter(_.o_orderdate.between(lower=p_min_date, upper=p_max_date))
                     .mutate(hash_result=_.o_orderkey.hash())
                     .mutate(hash_bucket=(_.hash_result % p_total_hash_buckets))
                     .filter(_.hash_bucket == (p_hash_bucket_num - 1))
                     .filter(_.o_totalprice <= MAX_ORDER_TOTALPRICE)
                     .drop("o_comment", "hash_result", "hash_bucket")
                     )

    # Filter out orders with customers that have more orders than the MAX_PERCENT_RANK
    # This simulates filtering out store cards
    orders_pre_filtered = orders_prelim.filter((orders_prelim.o_custkey == customer_order_summary_expr.o_custkey).any())

    # Filter out the European region
    region_filtered = (region
                       .filter(_.r_name.notin(["EUROPE"])
                               )
                       )

    # Keep only nations in the remaining regions (semi-join via .any())
    nation_filtered = (nation
                       .filter((nation.n_regionkey == region_filtered.r_regionkey).any())
                       )

    # Keep only customers in the remaining nations
    customer_filtered = (customer
                         .filter((customer.c_nationkey == nation_filtered.n_nationkey).any())
                         )

    # Keep only orders belonging to the remaining customers
    orders_filtered = (orders_pre_filtered
                       .filter((orders_pre_filtered.o_custkey == customer_filtered.c_custkey).any())
                       )

    # Filter out Manufacturer#3 from parts
    part_filtered = (part
                     .filter(part.p_mfgr.notin(["Manufacturer#3"]
                                               )
                             )
                     )

    # Keep only lineitems referencing the remaining parts
    lineitem_filtered = (lineitem
                         .filter((lineitem.l_partkey == part_filtered.p_partkey).any())
                         )

    # Join lineitem to orders - keep the columns as well...
    golden_rule_facts_expr = (orders_filtered.join(right=lineitem_filtered,
                                                   predicates=orders_filtered.o_orderkey == lineitem_filtered.l_orderkey,
                                                   how=INNER_JOIN
                                                   )
                              )

    return golden_rule_facts_expr
106 |
107 |
def build_param_map(hash_bucket_num: int,
                    total_hash_buckets: int,
                    min_date: datetime,
                    max_date: datetime,
                    existing_logger=None,
                    log_file: str = None
                    ):
    """Bind the module-level Ibis parameters to concrete values.

    Returns a mapping suitable for the `params=` argument of Ibis
    execution calls.  Note: `existing_logger` and `log_file` are accepted
    for signature compatibility but are not used here.
    """
    bindings = {
        p_hash_bucket_num: hash_bucket_num,
        p_total_hash_buckets: total_hash_buckets,
        p_min_date: min_date,
        p_max_date: max_date,
    }
    return bindings
120 |
121 |
def execute_golden_rules(golden_rules_ibis_expression: ibis.Expr,
                         hash_bucket_num: int,
                         total_hash_buckets: int,
                         min_date: datetime,
                         max_date: datetime,
                         existing_logger=None,
                         log_file: str = None
                         ) -> pyarrow.Table:
    """Execute the golden-rules expression for one hash bucket and return PyArrow data.

    :param golden_rules_ibis_expression: output of build_golden_rules_ibis_expression
    :param hash_bucket_num: 1-based bucket number to execute
    :param total_hash_buckets: total number of hash buckets
    :param min_date: lower bound for o_orderdate
    :param max_date: upper bound for o_orderdate
    :param existing_logger: reuse this logger if given; otherwise a fresh one is built
    :param log_file: log file name for the fresh logger (only used when no existing_logger)
    :raises: re-raises any exception from expression execution after logging it
    """
    try:
        if existing_logger:
            logger = existing_logger
        else:
            logger = get_logger(filename=log_file,
                                filemode="w",
                                logger_name="data_logic"
                                )

        logger.debug(f"get_golden_rule_facts - was called with args: {locals()}")

        with Timer(name=f"Run Golden Rules Ibis Expression query against DuckDB back-end",
                   text=TIMER_TEXT,
                   initial_text=True,
                   logger=logger.debug
                   ):
            # Bind the module-level Ibis parameters for this bucket/date range
            expr_params = build_param_map(hash_bucket_num=hash_bucket_num,
                                          total_hash_buckets=total_hash_buckets,
                                          min_date=min_date,
                                          max_date=max_date
                                          )
            logger.debug(msg=("SQL for Ibis Expression: golden_rules_ibis_expression: \n"
                              f"{ibis.to_sql(expr=golden_rules_ibis_expression, params=expr_params)}"
                              )
                         )

            # NOTE(review): despite the name, to_pyarrow returns a pyarrow.Table here
            pyarrow_batches = (golden_rules_ibis_expression
                               .to_pyarrow(params=expr_params)
                               )

            logger.debug(f"get_golden_rule_facts - successfully converted Ibis expression to PyArrow.")

    except Exception as e:
        logger.exception(msg=f"get_golden_rule_facts - Exception: {str(e)}")
        raise
    else:
        return pyarrow_batches
    finally:
        logger.debug(msg=f"get_golden_rule_facts - Finally block")
        # Only tear down handlers we created ourselves - a caller-supplied
        # logger may still be in use elsewhere
        if not existing_logger:
            logger.handlers.clear()
171 |
172 |
# Smoke test: run the golden-rules query for every hash bucket against the
# local DuckDB database (requires DUCKDB_DB_FILE to exist - see the
# flight-data-bootstrap script).
if __name__ == '__main__':
    logger = get_logger()
    with Timer(name=f"Run Golden Rules test",
               text=TIMER_TEXT,
               initial_text=True,
               logger=logger.info
               ):
        TOTAL_HASH_BUCKETS: int = 11

        # Read-only connection so concurrent servers can share the database file
        conn = ibis.duckdb.connect(database=DUCKDB_DB_FILE,
                                   threads=DUCKDB_THREADS,
                                   memory_limit=DUCKDB_MEMORY_LIMIT,
                                   read_only=True
                                   )

        customer_order_summary_expr = build_customer_order_summary_expr(conn=conn)
        golden_rules_ibis_expression = build_golden_rules_ibis_expression(conn=conn,
                                                                          customer_order_summary_expr=customer_order_summary_expr
                                                                          )
        # Buckets are 1-based - execute each one and preview its rows
        for i in range(1, TOTAL_HASH_BUCKETS + 1):
            logger.info(msg=f"Bucket #: {i}")
            reader = execute_golden_rules(golden_rules_ibis_expression=golden_rules_ibis_expression,
                                          hash_bucket_num=i,
                                          total_hash_buckets=TOTAL_HASH_BUCKETS,
                                          min_date=datetime(year=1994, month=1, day=1),
                                          max_date=datetime(year=1997, month=12, day=31),
                                          existing_logger=logger
                                          )
            for chunk in reader:
                logger.info(msg=chunk.to_pandas().head(n=10))
203 |
--------------------------------------------------------------------------------
/src/flight_ibis/examples/client_arrow_example.py:
--------------------------------------------------------------------------------
1 | # Licensed to the Apache Software Foundation (ASF) under one
2 | # or more contributor license agreements. See the NOTICE file
3 | # distributed with this work for additional information
4 | # regarding copyright ownership. The ASF licenses this file
5 | # to you under the Apache License, Version 2.0 (the
6 | # "License"); you may not use this file except in compliance
7 | # with the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing,
12 | # software distributed under the License is distributed on an
13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 | # KIND, either express or implied. See the License for the
15 | # specific language governing permissions and limitations
16 | # under the License.
17 |
18 | """An example Flight CLI client."""
19 |
20 | import argparse
21 | import sys
22 |
23 | import pyarrow
24 | import pyarrow.flight
25 | import pyarrow.csv as csv
26 |
27 |
def list_flights(args, client, connection_args=None):
    """Print every flight the server advertises, then its available actions.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments (unused; kept for the common command signature).
    client : pyarrow.flight.FlightClient
        Connected Flight client to query.
    connection_args : dict, optional
        Extra connection kwargs (unused; kept for the common command
        signature). BUG FIX: was a mutable default ``{}`` shared across calls.
    """
    if connection_args is None:
        connection_args = {}
    print('Flights\n=======')
    for flight in client.list_flights():
        descriptor = flight.descriptor
        if descriptor.descriptor_type == pyarrow.flight.DescriptorType.PATH:
            print("Path:", descriptor.path)
        elif descriptor.descriptor_type == pyarrow.flight.DescriptorType.CMD:
            print("Command:", descriptor.command)
        else:
            print("Unknown descriptor type")

        # Negative counts mean the server did not report them.
        print("Total records:", end=" ")
        if flight.total_records >= 0:
            print(flight.total_records)
        else:
            print("Unknown")

        print("Total bytes:", end=" ")
        if flight.total_bytes >= 0:
            print(flight.total_bytes)
        else:
            print("Unknown")

        print("Number of endpoints:", len(flight.endpoints))
        print("Schema:")
        print(flight.schema)
        print('---')

    print('\nActions\n=======')
    for action in client.list_actions():
        print("Type:", action.type)
        print("Description:", action.description)
        print('---')
61 |
62 |
def do_action(args, client, connection_args=None):
    """Run the Flight action named by ``args.action_type`` with an empty body
    and print each result the server returns.

    BUG FIX: ``connection_args`` was a mutable default ``{}`` shared across
    calls; it is unused here but kept for the common command signature.
    """
    if connection_args is None:
        connection_args = {}
    try:
        buf = pyarrow.allocate_buffer(0)
        action = pyarrow.flight.Action(args.action_type, buf)
        print('Running action', args.action_type)
        for result in client.do_action(action):
            print("Got result", result.body.to_pybytes())
    except pyarrow.lib.ArrowIOError as e:
        print("Error calling action:", e)
72 |
73 |
def push_data(args, client, connection_args=None):
    """Upload a CSV file to the server via do_put, keyed by the file path.

    BUG FIXES: ``connection_args`` was a mutable default ``{}`` (unused but
    kept for the common command signature), and the stream writer is now
    closed even if the write fails.
    """
    if connection_args is None:
        connection_args = {}
    print('File Name:', args.file)
    my_table = csv.read_csv(args.file)
    print('Table rows=', str(len(my_table)))
    df = my_table.to_pandas()
    print(df.head())
    writer, _ = client.do_put(
        pyarrow.flight.FlightDescriptor.for_path(args.file), my_table.schema)
    try:
        writer.write_table(my_table)
    finally:
        writer.close()
84 |
85 |
def get_flight(args, client, connection_args=None):
    """Fetch a flight by path or command descriptor and print each
    endpoint's data as a pandas DataFrame.

    Parameters
    ----------
    args : argparse.Namespace
        Carries either ``path`` or ``command`` (mutually exclusive).
    client : pyarrow.flight.FlightClient
        Client used for the initial get_flight_info call.
    connection_args : dict, optional
        Connection kwargs (TLS material etc.) reused when dialing each
        endpoint location. BUG FIX: was a mutable default ``{}``.
    """
    if connection_args is None:
        connection_args = {}
    if args.path:
        descriptor = pyarrow.flight.FlightDescriptor.for_path(*args.path)
    else:
        descriptor = pyarrow.flight.FlightDescriptor.for_command(args.command)

    info = client.get_flight_info(descriptor)
    for endpoint in info.endpoints:
        print('Ticket:', endpoint.ticket)
        # Each endpoint may be served from a different location; dial each.
        for location in endpoint.locations:
            print(location)
            get_client = pyarrow.flight.FlightClient(location,
                                                     **connection_args)
            reader = get_client.do_get(endpoint.ticket)
            df = reader.read_pandas()
            print(df)
102 |
103 |
def _add_common_arguments(parser):
    """Attach the connection flags shared by every sub-command.

    Adds ``--tls``, ``--tls-roots``, ``--mtls`` (client cert + key), and the
    positional ``host`` argument (expected as "host:port" by main()).
    """
    parser.add_argument('--tls', action='store_true',
                        help='Enable transport-level security')
    parser.add_argument('--tls-roots', default=None,
                        help='Path to trusted TLS certificate(s)')
    # BUG FIX: the help text was a copy-paste of --tls's; --mtls actually
    # enables *mutual* TLS with a client certificate and key.
    parser.add_argument("--mtls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Enable mutual TLS (mTLS) using the given "
                             "client certificate and private key")
    parser.add_argument('host', type=str,
                        help="Address or hostname to connect to")
114 |
115 |
def main():
    """Parse the CLI, connect to a Flight server (optionally with TLS/mTLS),
    wait until it answers a 'healthcheck' action, then dispatch the chosen
    sub-command (list / do / put / get).
    """
    parser = argparse.ArgumentParser()
    subcommands = parser.add_subparsers()

    # 'list': enumerate flights and actions on the server.
    cmd_list = subcommands.add_parser('list')
    cmd_list.set_defaults(action='list')
    _add_common_arguments(cmd_list)
    cmd_list.add_argument('-l', '--list', action='store_true',
                          help="Print more details.")

    # 'do': run a named server action.
    cmd_do = subcommands.add_parser('do')
    cmd_do.set_defaults(action='do')
    _add_common_arguments(cmd_do)
    cmd_do.add_argument('action_type', type=str,
                        help="The action type to run.")

    # 'put': upload a CSV file as a table.
    cmd_put = subcommands.add_parser('put')
    cmd_put.set_defaults(action='put')
    _add_common_arguments(cmd_put)
    cmd_put.add_argument('file', type=str,
                         help="CSV file to upload.")

    # 'get': fetch a flight by path or by command descriptor.
    cmd_get = subcommands.add_parser('get')
    cmd_get.set_defaults(action='get')
    _add_common_arguments(cmd_get)
    cmd_get_descriptor = cmd_get.add_mutually_exclusive_group(required=True)
    cmd_get_descriptor.add_argument('-p', '--path', type=str, action='append',
                                    help="The path for the descriptor.")
    cmd_get_descriptor.add_argument('-c', '--command', type=str,
                                    help="The command for the descriptor.")

    args = parser.parse_args()
    if not hasattr(args, 'action'):
        # No sub-command given: show usage and exit non-zero.
        parser.print_help()
        sys.exit(1)

    commands = {
        'list': list_flights,
        'do': do_action,
        'get': get_flight,
        'put': push_data,
    }
    # Host is expected as "host:port"; a missing colon raises ValueError here.
    host, port = args.host.split(':')
    port = int(port)
    scheme = "grpc+tcp"
    connection_args = {}
    if args.tls:
        scheme = "grpc+tls"
        if args.tls_roots:
            with open(args.tls_roots, "rb") as root_certs:
                connection_args["tls_root_certs"] = root_certs.read()
    if args.mtls:
        # Client certificate + private key for mutual TLS.
        with open(args.mtls[0], "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(args.mtls[1], "rb") as key_file:
            tls_private_key = key_file.read()
        connection_args["cert_chain"] = tls_cert_chain
        connection_args["private_key"] = tls_private_key
    client = pyarrow.flight.FlightClient(f"{scheme}://{host}:{port}",
                                         **connection_args)
    # Poll the server's "healthcheck" action (1s timeout per attempt) until
    # it responds.
    # NOTE(review): any non-deadline ArrowIOError is swallowed and the loop
    # retries forever with no backoff or message — confirm this is intended.
    while True:
        try:
            action = pyarrow.flight.Action("healthcheck", b"")
            options = pyarrow.flight.FlightCallOptions(timeout=1)
            list(client.do_action(action, options=options))
            break
        except pyarrow.ArrowIOError as e:
            if "Deadline" in str(e):
                print("Server is not ready, waiting...")
    commands[args.action](args, client, connection_args)
186 |
187 |
# Script entry point.
if __name__ == '__main__':
    main()
190 |
--------------------------------------------------------------------------------
/src/flight_ibis/examples/client_cookbook_example.py:
--------------------------------------------------------------------------------
1 | import pyarrow as pa
2 | import pyarrow.flight
3 |
4 |
def main():
    """Upload a streamed dataset to a local Flight server, read it back,
    and report the row count."""
    flight_client = pa.flight.connect("grpc://0.0.0.0:8815")

    # Upload a new dataset
    NUM_BATCHES = 1024
    ROWS_PER_BATCH = 4096
    descriptor = pa.flight.FlightDescriptor.for_path("streamed.parquet")
    sample_batch = pa.record_batch([
        pa.array(range(ROWS_PER_BATCH)),
    ], names=["ints"])
    stream_writer, _ = flight_client.do_put(descriptor, sample_batch.schema)
    with stream_writer:
        for _ in range(NUM_BATCHES):
            stream_writer.write_batch(sample_batch)

    # Read content of the dataset back and count the rows received.
    flight = flight_client.get_flight_info(descriptor)
    stream_reader = flight_client.do_get(flight.endpoints[0].ticket)
    total_rows = sum(chunk.data.num_rows for chunk in stream_reader)
    print("Got", total_rows, "rows total, expected", NUM_BATCHES * ROWS_PER_BATCH)
27 |
28 |
# Script entry point.
if __name__ == '__main__':
    main()
31 |
--------------------------------------------------------------------------------
/src/flight_ibis/examples/server_arrow_example.py:
--------------------------------------------------------------------------------
1 | # Licensed to the Apache Software Foundation (ASF) under one
2 | # or more contributor license agreements. See the NOTICE file
3 | # distributed with this work for additional information
4 | # regarding copyright ownership. The ASF licenses this file
5 | # to you under the Apache License, Version 2.0 (the
6 | # "License"); you may not use this file except in compliance
7 | # with the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing,
12 | # software distributed under the License is distributed on an
13 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 | # KIND, either express or implied. See the License for the
15 | # specific language governing permissions and limitations
16 | # under the License.
17 |
18 | """An example Flight Python server."""
19 |
20 | import argparse
21 | import ast
22 | import threading
23 | import time
24 |
25 | import pyarrow
26 | import pyarrow.flight
27 |
28 |
class FlightServer(pyarrow.flight.FlightServerBase):
    """Example in-memory Flight server.

    Tables uploaded via do_put are held in ``self.flights`` keyed by their
    descriptor; clients can list them, fetch them back, or invoke the
    'healthcheck' / 'shutdown' actions.
    """

    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None):
        super(FlightServer, self).__init__(
            location, auth_handler, tls_certificates, verify_client,
            root_certificates)
        # Maps descriptor keys (see descriptor_to_key) to pyarrow Tables.
        self.flights = {}
        self.host = host
        self.tls_certificates = tls_certificates

    @classmethod
    def descriptor_to_key(cls, descriptor):
        """Return a hashable key (type, command, path tuple) for a descriptor.

        BUG FIX: the first parameter of this classmethod was misleadingly
        named ``self``; renamed to the conventional ``cls`` (callers that
        pass only the descriptor are unaffected).
        """
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or tuple()))

    def _make_flight_info(self, key, descriptor, table):
        """Build FlightInfo for a stored table, including its IPC-serialized size."""
        if self.tls_certificates:
            location = pyarrow.flight.Location.for_grpc_tls(
                self.host, self.port)
        else:
            location = pyarrow.flight.Location.for_grpc_tcp(
                self.host, self.port)
        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location]), ]

        # Measure serialized size without materializing the bytes.
        mock_sink = pyarrow.MockOutputStream()
        stream_writer = pyarrow.RecordBatchStreamWriter(
            mock_sink, table.schema)
        stream_writer.write_table(table)
        stream_writer.close()
        data_size = mock_sink.size()

        return pyarrow.flight.FlightInfo(table.schema,
                                         descriptor, endpoints,
                                         table.num_rows, data_size)

    def list_flights(self, context, criteria):
        """Yield FlightInfo for every stored table."""
        for key, table in self.flights.items():
            # key[1] is the command component; key[2] the path components.
            if key[1] is not None:
                descriptor = \
                    pyarrow.flight.FlightDescriptor.for_command(key[1])
            else:
                descriptor = pyarrow.flight.FlightDescriptor.for_path(*key[2])

            yield self._make_flight_info(key, descriptor, table)

    def get_flight_info(self, context, descriptor):
        """Return FlightInfo for one descriptor; raise KeyError if unknown."""
        key = FlightServer.descriptor_to_key(descriptor)
        if key in self.flights:
            table = self.flights[key]
            return self._make_flight_info(key, descriptor, table)
        raise KeyError('Flight not found.')

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole uploaded stream and store it under its descriptor key."""
        key = FlightServer.descriptor_to_key(descriptor)
        print(key)
        self.flights[key] = reader.read_all()
        print(self.flights[key])

    def do_get(self, context, ticket):
        """Stream back a stored table; the ticket is the repr() of its key."""
        key = ast.literal_eval(ticket.ticket.decode())
        if key not in self.flights:
            return None
        return pyarrow.flight.RecordBatchStream(self.flights[key])

    def list_actions(self, context):
        """Advertise the supported actions as (type, description) pairs."""
        return [
            ("clear", "Clear the stored flights."),
            ("shutdown", "Shut down this server."),
        ]

    def do_action(self, context, action):
        """Handle 'clear' (unimplemented), 'healthcheck' (no-op), 'shutdown'."""
        if action.type == "clear":
            raise NotImplementedError(
                "{} is not implemented.".format(action.type))
        elif action.type == "healthcheck":
            pass
        elif action.type == "shutdown":
            yield pyarrow.flight.Result(pyarrow.py_buffer(b'Shutdown!'))
            # Shut down on background thread to avoid blocking current
            # request
            threading.Thread(target=self._shutdown).start()
        else:
            raise KeyError("Unknown action {!r}".format(action.type))

    def _shutdown(self):
        """Shut down after a delay."""
        print("Server is shutting down...")
        time.sleep(2)
        self.shutdown()
119 |
120 |
def str_to_bool(value: str) -> bool:
    """Parse a CLI string into a bool.

    argparse's ``type=bool`` is broken for this purpose: any non-empty
    string — including "False" — is truthy. Accept the usual spellings
    explicitly and reject everything else.

    Raises
    ------
    argparse.ArgumentTypeError
        If the value is not a recognized boolean spelling.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got: {value!r}")


def main():
    """Parse CLI arguments and run the example Flight server (optionally
    with TLS and client verification)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost",
                        help="Address or hostname to listen on")
    parser.add_argument("--port", type=int, default=5005,
                        help="Port number to listen on")
    parser.add_argument("--tls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Enable transport-level security")
    # BUG FIX: was ``type=bool``, which treats any non-empty string
    # (e.g. "--verify_client False") as True.
    parser.add_argument("--verify_client", type=str_to_bool, default=False,
                        help="enable mutual TLS and verify the client if True")

    args = parser.parse_args()
    tls_certificates = []
    scheme = "grpc+tcp"
    if args.tls:
        scheme = "grpc+tls"
        with open(args.tls[0], "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(args.tls[1], "rb") as key_file:
            tls_private_key = key_file.read()
        tls_certificates.append((tls_cert_chain, tls_private_key))

    location = "{}://{}:{}".format(scheme, args.host, args.port)

    server = FlightServer(args.host, location,
                          tls_certificates=tls_certificates,
                          verify_client=args.verify_client)
    print("Serving on", location)
    server.serve()
151 |
152 |
# Script entry point.
if __name__ == '__main__':
    main()
155 |
--------------------------------------------------------------------------------
/src/flight_ibis/examples/server_cookbook_example.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | import pyarrow as pa
4 | import pyarrow.flight
5 | import pyarrow.parquet
6 | from config import logger
7 |
8 |
class FlightServer(pa.flight.FlightServerBase):
    """Flight server that persists each dataset as a Parquet file under a
    repository directory."""

    def __init__(self, location="grpc://0.0.0.0:8815",
                 repo=pathlib.Path("../datasets"), **kwargs):
        """Start serving at ``location``, storing datasets under ``repo``."""
        super(FlightServer, self).__init__(location, **kwargs)
        self._location = location
        self._repo = repo

    def _make_flight_info(self, dataset):
        """Build FlightInfo for one dataset from its Parquet footer metadata."""
        dataset_path = self._repo / dataset
        schema = pa.parquet.read_schema(dataset_path)
        metadata = pa.parquet.read_metadata(dataset_path)
        descriptor = pa.flight.FlightDescriptor.for_path(
            dataset.encode('utf-8')
        )
        logger.info(msg=f"Descriptor: {descriptor}")
        # The dataset name doubles as the ticket for do_get.
        endpoints = [pa.flight.FlightEndpoint(dataset, [self._location])]
        return pyarrow.flight.FlightInfo(schema,
                                         descriptor,
                                         endpoints,
                                         metadata.num_rows,
                                         metadata.serialized_size)

    def list_flights(self, context, criteria):
        """Yield FlightInfo for every file in the repository."""
        for dataset in self._repo.iterdir():
            yield self._make_flight_info(dataset.name)

    def get_flight_info(self, context, descriptor):
        """Return FlightInfo for the dataset named in the descriptor path."""
        return self._make_flight_info(descriptor.path[0].decode('utf-8'))

    def do_put(self, context, descriptor, reader, writer):
        """Persist an uploaded stream to <repo>/<dataset> as Parquet."""
        dataset = descriptor.path[0].decode('utf-8')
        dataset_path = self._repo / dataset
        # Read the uploaded data and write to Parquet incrementally.
        # BUG FIX: the ParquetWriter previously shadowed the 'writer' RPC
        # parameter (the Flight metadata writer); renamed to pq_writer.
        with dataset_path.open("wb") as sink:
            with pa.parquet.ParquetWriter(sink, reader.schema) as pq_writer:
                for chunk in reader:
                    pq_writer.write_table(pa.Table.from_batches([chunk.data]))

    def do_get(self, context, ticket):
        """Stream a stored Parquet dataset; the ticket is the dataset name."""
        dataset = ticket.ticket.decode('utf-8')
        # Stream data from a file
        dataset_path = self._repo / dataset
        reader = pa.parquet.ParquetFile(dataset_path)
        return pa.flight.GeneratorStream(
            reader.schema_arrow, reader.iter_batches())

    def list_actions(self, context):
        """Advertise supported actions as (type, description) pairs."""
        return [
            ("drop_dataset", "Delete a dataset."),
        ]

    def do_action(self, context, action):
        """Dispatch a custom action; only 'drop_dataset' is supported."""
        if action.type == "drop_dataset":
            self.do_drop_dataset(action.body.to_pybytes().decode('utf-8'))
        else:
            raise NotImplementedError

    def do_drop_dataset(self, dataset):
        """Delete the named dataset file from the repository."""
        dataset_path = self._repo / dataset
        dataset_path.unlink()
70 |
71 |
# Script entry point: create the server, make sure the dataset repository
# directory exists, and serve until interrupted.
if __name__ == '__main__':
    server = FlightServer()
    server._repo.mkdir(exist_ok=True)
    server.serve()
76 |
77 |
--------------------------------------------------------------------------------
/src/flight_ibis/server.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import functools
3 | import json
4 | import os
5 | import sys
6 | import uuid
7 | from copy import deepcopy
8 | from datetime import datetime, timezone, timedelta
9 | from functools import cached_property
10 | from pathlib import Path
11 | from threading import BoundedSemaphore
12 |
13 | import click
14 | import duckdb
15 | import ibis
16 | import jwt
17 | import pyarrow.compute
18 | import pyarrow.flight
19 | import pyarrow.parquet
20 | from OpenSSL import crypto
21 | from munch import Munch, munchify
22 | from pyarrow.flight import SchemaResult
23 |
24 | from . import __version__ as flight_server_version
25 | from .config import get_logger, logging, DUCKDB_DB_FILE, DUCKDB_THREADS, DUCKDB_MEMORY_LIMIT, DEFAULT_FLIGHT_ENDPOINTS, LOGGING_REDACT_AUTHORIZATION_HEADER
26 | from .constants import LOCALHOST_IP_ADDRESS, LOCALHOST, DEFAULT_FLIGHT_PORT, GRPC_TCP_SCHEME, GRPC_TLS_SCHEME, BEGINNING_OF_TIME, PYARROW_UNKNOWN, JWT_ISS, JWT_AUD
27 | from .data_logic_ibis import build_customer_order_summary_expr, build_golden_rules_ibis_expression, execute_golden_rules
28 |
# Bounded semaphore (max 1 holder) serializing access to the shared Ibis
# DuckDB connection: concurrent Flight clients must not use it simultaneously.
pool_sema = BoundedSemaphore(value=1)
31 |
32 |
class BasicAuthServerMiddlewareFactory(pyarrow.flight.ServerMiddlewareFactory):
    """
    Middleware that implements username-password authentication.

    On a successful Basic login it issues an RS256-signed JWT (signed with
    the server's TLS private key); subsequent calls present that JWT as a
    Bearer token, verified against the certificate's public key.

    Parameters
    ----------
    creds: Dict[str, str]
        A dictionary of username-password values to accept.
    cert: bytes
        PEM-encoded server certificate; its public key verifies issued JWTs.
    key: bytes
        PEM-encoded private key used to sign JWTs.
    logger:
        Logger used for authentication events.
    """

    @cached_property
    def class_name(self):
        # Short class name used to prefix log messages.
        return self.__class__.__name__

    def __init__(self,
                 creds: dict,
                 cert: bytes,
                 key: bytes,
                 logger):
        super().__init__()
        self.creds = creds

        # Extract the public key from the certificate
        pub_key = crypto.load_certificate(type=crypto.FILETYPE_PEM, buffer=cert).get_pubkey()
        self.public_key = crypto.dump_publickey(type=crypto.FILETYPE_PEM, pkey=pub_key)

        self.private_key = key
        self.logger = logger

    def start_call(self, info, headers):
        """Validate credentials at the start of every call."""
        logging_headers = deepcopy(headers)
        if LOGGING_REDACT_AUTHORIZATION_HEADER:
            logging_headers.update({"authorization": "<>"})

        self.logger.debug(msg=f"{self.class_name}.start_call - called with args: info={info}, headers={logging_headers}")
        # Search for the authentication header (case-insensitive)
        auth_header = None
        for header in headers:
            if header.lower() == "authorization":
                auth_header = headers[header][0]
                break

        if not auth_header:
            raise pyarrow.flight.FlightUnauthenticatedError("No credentials supplied")

        # The header has the structure "AuthType TokenValue", e.g.
        # "Basic " or "Bearer ".
        auth_type, _, value = auth_header.partition(" ")

        if auth_type == "Basic":
            # Initial "login". The user provided a username/password
            # combination encoded in the same way as HTTP Basic Auth.
            decoded = base64.b64decode(value).decode("utf-8")
            username, _, password = decoded.partition(':')
            if not password or password != self.creds.get(username):
                error_message = f"{self.class_name}.start_call - invalid username/password"
                self.logger.error(msg=error_message)
                raise pyarrow.flight.FlightUnauthenticatedError(error_message)

            # Create a JWT and sign it with our private key.
            # BUG FIX: iat/nbf previously used naive datetime.utcnow()
            # (deprecated since Python 3.12) while exp was timezone-aware;
            # all claims now use aware UTC timestamps consistently.
            now_utc = datetime.now(tz=timezone.utc)
            token = jwt.encode(payload=dict(jti=str(uuid.uuid4()),
                                            iss=JWT_ISS,
                                            sub=username,
                                            aud=JWT_AUD,
                                            iat=now_utc,
                                            nbf=now_utc - timedelta(minutes=1),
                                            exp=now_utc + timedelta(hours=24),
                                            ),
                               key=self.private_key,
                               algorithm="RS256"
                               )
            self.logger.info(msg=f"{self.class_name}.start_call - User: '{username}' successfully authenticated - issued JWT.")
            return BasicAuthServerMiddleware(token=token,
                                             username=username
                                             )
        elif auth_type == "Bearer":
            # An actual call. Validate the bearer token.
            try:
                decoded_jwt = jwt.decode(jwt=value,
                                         key=self.public_key,
                                         algorithms=["RS256"],
                                         issuer=JWT_ISS,
                                         audience=JWT_AUD
                                         )
            except Exception as e:
                # Log the underlying verification failure server-side and
                # chain it; the client still only sees a generic message.
                self.logger.error(msg=f"{self.class_name}.start_call - JWT verification failed: {str(e)}")
                raise pyarrow.flight.FlightUnauthenticatedError("Invalid token") from e
            else:
                subject = decoded_jwt.get("sub")
                self.logger.debug(msg=f"{self.class_name}.start_call - JWT with subject: '{subject}' was successfully verified")
                return BasicAuthServerMiddleware(token=value,
                                                 username=subject
                                                 )

        raise pyarrow.flight.FlightUnauthenticatedError("No credentials supplied")
128 |
129 |
class BasicAuthServerMiddleware(pyarrow.flight.ServerMiddleware):
    """Per-call middleware carrying the authenticated user's JWT."""

    def __init__(self, token: str, username: str):
        self.token = token
        self.username = username

    def sending_headers(self):
        """Return the authentication token to the client."""
        bearer_value = f"Bearer {self.token}"
        return {"authorization": bearer_value}

    def who_am_i(self):
        """Name of the user authenticated for this call."""
        return self.username
143 |
144 |
class NoOpAuthHandler(pyarrow.flight.ServerAuthHandler):
    """A no-op username/password auth handler.

    Exists only so the server answers the internal Handshake RPC that
    clients trigger via authenticate_basic_token; the real authentication
    is implemented in the middleware.
    """

    def authenticate(self, outgoing, incoming):
        """Accept every handshake without exchanging anything."""
        pass

    def is_valid(self, token):
        """Report an empty peer identity; middleware does real validation."""
        return ""
160 |
161 |
def debuggable(func):
    """Decorator enabling GUI (i.e. PyCharm) debugging inside the decorated
    Arrow Flight RPC Server function.

    See: https://github.com/apache/arrow/issues/36844
    for more details...
    """

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        try:
            import pydevd
            pydevd.connected = True
            pydevd.settrace(suspend=False)
        except ImportError:
            # pydevd absent: not running under a debugger, proceed normally.
            pass
        return func(*args, **kwargs)

    return wrapped
183 |
184 |
185 | class FlightServer(pyarrow.flight.FlightServerBase):
186 | @cached_property
187 | def class_name(self):
188 | return self.__class__.__name__
189 |
    def __init__(self,
                 host_uri: str,
                 location_uri: str,
                 max_endpoints: int,
                 database_file: Path,
                 duckdb_threads: int,
                 duckdb_memory_limit: str,
                 logger,
                 tls_certificates=None,
                 verify_client=False,
                 root_certificates=None,
                 auth_handler=None,
                 middleware=None,
                 log_level: str = None,
                 log_file: str = None,
                 log_file_mode: str = None
                 ):
        """Create the Flight-Ibis server.

        Parameters
        ----------
        host_uri : str
            URI the gRPC server binds to and listens on.
        location_uri : str
            URI advertised inside generated FlightEndpoints (may differ
            from host_uri, e.g. behind a proxy).
        max_endpoints : int
            Upper bound on endpoints (hash buckets) returned per flight.
        database_file : Path
            DuckDB database file; must already exist.
        duckdb_threads : int
            DuckDB thread count.
        duckdb_memory_limit : str
            DuckDB memory limit string.
        logger :
            Pre-configured logger instance.
        tls_certificates, verify_client, root_certificates, auth_handler, middleware :
            Passed through to pyarrow.flight.FlightServerBase.
        log_level, log_file, log_file_mode : str, optional
            Not used in this body beyond the init-args log line —
            presumably consumed by the caller/CLI; TODO confirm.

        Raises
        ------
        RuntimeError
            If ``database_file`` does not exist.
        """
        self.logger = logger

        # Log constructor args with certificate material redacted.
        redacted_locals = {key: value for key, value in locals().items() if key not in ["tls_certificates",
                                                                                        "root_certificates"
                                                                                        ]
                           }
        self.logger.info(msg=f"Flight Server init args: {redacted_locals}")

        if not database_file.exists():
            raise RuntimeError(f"The specified database file: '{database_file.as_posix()}' does not exist, aborting.")

        self.flights = {}
        self.tls_certificates = tls_certificates
        self.host_uri = host_uri
        self.location_uri = location_uri
        self.max_endpoints = max_endpoints

        # Get an Ibis DuckDB connection (read-only; access elsewhere is
        # serialized via the module-level pool_sema).
        self.ibis_connection = ibis.duckdb.connect(database=database_file,
                                                   threads=duckdb_threads,
                                                   memory_limit=duckdb_memory_limit,
                                                   read_only=True
                                                   )
        self.customer_order_summary_expr = build_customer_order_summary_expr(conn=self.ibis_connection)

        # Start the Flight RPC server now that the summary expression is built
        super(FlightServer, self).__init__(
            location=host_uri,
            auth_handler=auth_handler,
            middleware=middleware,
            tls_certificates=tls_certificates,
            verify_client=verify_client,
            root_certificates=root_certificates
        )

        self.golden_rules_ibis_expression = build_golden_rules_ibis_expression(conn=self.ibis_connection,
                                                                               customer_order_summary_expr=self.customer_order_summary_expr
                                                                               )
        # Startup banner: versions and effective database settings.
        self.logger.info(f"Running Flight-Ibis server - version: {flight_server_version}")
        self.logger.info(f"Using Python version: {sys.version}")
        self.logger.info(f"Using PyArrow version: {pyarrow.__version__}")
        self.logger.info(f"Using Ibis version: {ibis.__version__}")
        self.logger.info(f"Using DuckDB version: {duckdb.__version__}")
        self.logger.info("Database details:")
        self.logger.info(f"   Database file: {database_file.as_posix()}")
        self.logger.info(f"   Threads: {duckdb_threads}")
        self.logger.info(f"   Memory Limit: {duckdb_memory_limit}")
        self.logger.info(f"Serving on {self.host_uri} (generated end-points will refer to location: {self.location_uri})")
255 |
256 | def _get_reference_dataset(self) -> pyarrow.Table:
257 | self.logger.debug(msg="Attempting to acquire semaphore...")
258 | with pool_sema:
259 | self.logger.debug(msg="Semaphore successfully acquired...")
260 | pyarrow_table = execute_golden_rules(golden_rules_ibis_expression=self.golden_rules_ibis_expression,
261 | hash_bucket_num=1,
262 | total_hash_buckets=1,
263 | min_date=BEGINNING_OF_TIME,
264 | max_date=BEGINNING_OF_TIME,
265 | existing_logger=self.logger
266 | )
267 | self.logger.debug(msg="Semaphore released...")
268 |
269 | return pyarrow_table
270 |
    def _make_flight_info(self, descriptor: pyarrow.flight.FlightDescriptor) -> pyarrow.flight.FlightInfo:
        """Build FlightInfo whose endpoints split the work across hash buckets.

        The descriptor's command is parsed and validated, the effective
        bucket count is capped at self.max_endpoints, and one endpoint is
        emitted per bucket with its bucket number baked into the ticket.
        Record/byte totals are reported as unknown.
        """
        self.logger.debug(msg=f"{self.class_name}._make_flight_info - was called with args: {locals()}")

        command = self._get_descriptor_command(descriptor=descriptor)
        command_munch = self._check_command(command=command)
        # Cap the client-requested endpoint count (default: max_endpoints).
        command_munch.kwargs.total_hash_buckets = min(self.max_endpoints, command_munch.kwargs.get("num_endpoints", self.max_endpoints))

        self.logger.debug(msg=f"{self.class_name}._make_flight_info - descriptor: {descriptor}")

        schema = self._get_schema(descriptor=descriptor)

        endpoints = []
        for i in range(1, (command_munch.kwargs.total_hash_buckets + 1)):
            # Mutate the shared command then serialize a snapshot per
            # iteration, so each ticket carries its own hash_bucket_num.
            command_munch.kwargs.hash_bucket_num = i
            endpoints.append(pyarrow.flight.FlightEndpoint(json.dumps(command_munch.toDict()), [self.location_uri]))

        return pyarrow.flight.FlightInfo(schema=schema,
                                         descriptor=descriptor,
                                         endpoints=endpoints,
                                         total_records=PYARROW_UNKNOWN,
                                         total_bytes=PYARROW_UNKNOWN
                                         )
293 |
294 | def _check_command(self, command: dict) -> Munch:
295 | self.logger.debug(msg=f"{self.class_name}._check_command - was called with args: {locals()}")
296 | command_munch: Munch = munchify(x=command)
297 |
298 | if command_munch.command != "get_golden_rule_facts":
299 | error_message = f"{self.class_name}._check_command - Command: {command_munch.command} is not supported."
300 | self.logger.error(msg=error_message)
301 | raise RuntimeError(error_message)
302 | else:
303 | return command_munch
304 |
305 | def _get_descriptor_command(self, descriptor: pyarrow.flight.FlightDescriptor) -> dict:
306 | self.logger.debug(msg=f"{self.class_name}._get_descriptor_command - was called with args: {locals()}")
307 | try:
308 | command = json.loads(descriptor.command.decode('utf-8'))
309 | except Exception as e:
310 | self.logger.exception(msg=f"{self.class_name}._get_descriptor_command - failed with Exception: {str(e)}")
311 | raise
312 | else:
313 | self.logger.debug(msg=f"{self.class_name}._get_descriptor_command - returning: {command}")
314 | return command
315 |
316 | def _get_ticket_command(self, ticket: pyarrow.flight.Ticket) -> dict:
317 | self.logger.debug(msg=f"{self.class_name}._get_ticket_command - was called with args: {locals()}")
318 | command = json.loads(ticket.ticket.decode('utf-8'))
319 | self.logger.debug(msg=f"{self.class_name}._get_ticket_command - returning: {command}")
320 | return command
321 |
322 | @debuggable
323 | def get_flight_info(self, context: pyarrow.flight.ServerCallContext, descriptor: pyarrow.flight.FlightDescriptor) -> pyarrow.flight.FlightInfo:
324 | self.logger.info(msg=f"{self.class_name}.get_flight_info - was called with args: {locals()}")
325 | try:
326 | self.logger.info(msg=f"{self.class_name}.get_flight_info - called with context = {context}, descriptor = {descriptor}")
327 | flight_info = self._make_flight_info(descriptor=descriptor)
328 | except Exception as e:
329 | self.logger.exception(msg=(f"{self.class_name}.get_flight_info - with context = {context}, descriptor = {descriptor}"
330 | f"- failed with exception: {str(e)}"
331 | )
332 | )
333 | raise
334 | else:
335 | self.logger.info(msg=(f"{self.class_name}.get_flight_info - with context = {context}, descriptor = {descriptor}"
336 | f"- returning: FlightInfo ({dict(schema=flight_info.schema, endpoints=flight_info.endpoints)})"
337 | )
338 | )
339 | return flight_info
340 |
341 | def _get_command_munch_from_descriptor(self, descriptor: pyarrow.flight.FlightDescriptor) -> Munch:
342 | command = self._get_descriptor_command(descriptor=descriptor)
343 | command_munch = self._check_command(command=command)
344 |
345 | return command_munch
346 |
347 | def _get_schema(self, descriptor: pyarrow.flight.FlightDescriptor) -> pyarrow.Schema:
348 | self.logger.debug(msg=f"{self.class_name}._get_schema - was called with args: {locals()}")
349 | try:
350 | reference_dataset = self._get_reference_dataset()
351 | command_munch = self._get_command_munch_from_descriptor(descriptor=descriptor)
352 |
353 | if hasattr(command_munch, "columns"):
354 | schema = reference_dataset.select(command_munch.columns).schema
355 | else:
356 | schema = reference_dataset.schema
357 |
358 | except Exception as e:
359 | self.logger.exception(msg=(f"{self.class_name}._get_schema - failed with Exception: {str(e)}"))
360 | raise
361 | else:
362 | self.logger.debug(msg=(f"{self.class_name}._get_schema - with descriptor = {descriptor}"
363 | f"- returning: Schema ({dict(schema=schema)})"
364 | )
365 | )
366 |
367 | return schema
368 |
369 | @debuggable
370 | def get_schema(self, context: pyarrow.flight.ServerCallContext, descriptor: pyarrow.flight.FlightDescriptor) -> SchemaResult:
371 | self.logger.info(msg=f"{self.class_name}.get_schema - was called with args: {locals()}")
372 | schema = self._get_schema(descriptor=descriptor)
373 | self.logger.info(msg=(f"{self.class_name}.get_schema - with context = {context}, descriptor = {descriptor}"
374 | f"- returning: SchemaResult ({dict(schema=schema)})"
375 | )
376 | )
377 | return SchemaResult(schema)
378 |
379 | def _build_filter_expression(self, filter_munch: Munch) -> pyarrow.compute.Expression:
380 | self.logger.debug(msg=f"{self.class_name}._build_filter_expression - was called with args: {locals()}")
381 |
382 | field = pyarrow.compute.field(filter_munch.column)
383 |
384 | if filter_munch.operator == "=":
385 | filter_expr = field == filter_munch.value
386 | elif filter_munch.operator == "<":
387 | filter_expr = field < filter_munch.value
388 | elif filter_munch.operator == "<=":
389 | filter_expr = field <= filter_munch.value
390 | elif filter_munch.operator == ">":
391 | filter_expr = field > filter_munch.value
392 | elif filter_munch.operator == ">=":
393 | filter_expr = field >= filter_munch.value
394 | else:
395 | filter_expr = None
396 |
397 | return filter_expr
398 |
@debuggable
def do_get(self, context: pyarrow.flight.ServerCallContext, ticket: pyarrow.flight.Ticket) -> pyarrow.flight.FlightDataStream:
    """Flight RPC DoGet handler - execute the golden-rules query for a ticket.

    Parses/validates the ticket command, runs execute_golden_rules under the
    connection-limiting semaphore, then applies any client-requested filters
    and column projection before streaming the result back.

    Raises:
        Exception: re-raises anything thrown while servicing the request
        (after logging it).
    """
    self.logger.info(msg=f"{self.class_name}.do_get - was called with args: {locals()}")

    try:
        command = self._get_ticket_command(ticket=ticket)
        command_munch = self._check_command(command=command)

        golden_rule_kwargs = dict(golden_rules_ibis_expression=self.golden_rules_ibis_expression,
                                  hash_bucket_num=command_munch.kwargs.hash_bucket_num,
                                  total_hash_buckets=command_munch.kwargs.total_hash_buckets,
                                  min_date=datetime.fromisoformat(command_munch.kwargs.min_date),
                                  max_date=datetime.fromisoformat(command_munch.kwargs.max_date),
                                  existing_logger=self.logger
                                  )
        self.logger.debug(msg=f"{self.class_name}.do_get - calling get_golden_rule_facts with args: {str(golden_rule_kwargs)}")

        # The semaphore bounds concurrent query executions across requests.
        self.logger.debug(msg="Attempting to acquire semaphore...")
        with pool_sema:
            self.logger.debug(msg="Semaphore successfully acquired...")
            pyarrow_table = execute_golden_rules(**golden_rule_kwargs)
        self.logger.debug(msg="Semaphore released...")

        if hasattr(command_munch, "filters"):
            # AND together every filter the client sent that we can translate.
            combined_filter_expr = None
            for filter_spec in command_munch.filters:  # renamed: 'filter' shadowed the builtin
                filter_expr = self._build_filter_expression(filter_munch=filter_spec)

                if isinstance(filter_expr, pyarrow.compute.Expression):
                    # Explicit None check - Expression truthiness is not a
                    # reliable "is unset" signal.
                    if combined_filter_expr is None:
                        combined_filter_expr = filter_expr
                    else:
                        combined_filter_expr = combined_filter_expr & filter_expr

            # Bug fix: only apply the filter when at least one expression was
            # built - previously .filter(None) would raise if every supplied
            # filter used an unsupported operator.
            if combined_filter_expr is not None:
                pyarrow_table = pyarrow_table.filter(combined_filter_expr)

        if hasattr(command_munch, "columns"):
            pyarrow_table = pyarrow_table.select(command_munch.columns)

    except Exception as e:
        # Bug fix: the message previously (incorrectly) said 'get_flight_info'.
        error_message = f"{self.class_name}.do_get - Exception: {str(e)}"
        self.logger.exception(msg=error_message)
        raise
    else:
        self.logger.info(msg=f"{self.class_name}.do_get - context: {context} - ticket: {ticket} - returning a PyArrow RecordBatchReader with schema: {pyarrow_table.schema}")

        return pyarrow.flight.RecordBatchStream(data_source=pyarrow_table)
446 |
@debuggable
def do_action(self, context: pyarrow.flight.ServerCallContext, action: pyarrow.flight.Action) -> list:
    """Flight RPC DoAction handler - only the 'who-am-i' action is supported."""
    self.logger.info(msg=f"{self.class_name}.do_action - was called with args: {locals()}")

    # Guard clause: anything other than 'who-am-i' is unsupported.
    if action.type != "who-am-i":
        raise NotImplementedError

    self.logger.debug(msg=f"{self.class_name}.do_action - returning: {context.peer_identity()}")
    return [context.get_middleware('basic').who_am_i().encode()]
454 |
455 |
@click.command()
@click.option(
    "--host",
    type=str,
    default=os.getenv("FLIGHT_HOST", LOCALHOST_IP_ADDRESS),
    required=True,
    help="Address (or hostname) to listen on"
)
@click.option(
    "--location",
    type=str,
    default=os.getenv("FLIGHT_LOCATION", f"{LOCALHOST}:{os.getenv('FLIGHT_PORT', DEFAULT_FLIGHT_PORT)}"),
    required=True,
    help=("Address or hostname for TLS and endpoint generation. This is needed if running the Flight server behind a load balancer and/or "
          "a reverse proxy"
          )
)
@click.option(
    "--port",
    type=int,
    default=os.getenv("FLIGHT_PORT", DEFAULT_FLIGHT_PORT),
    required=True,
    help="Port number to listen on"
)
@click.option(
    "--max-endpoints",
    type=int,
    default=os.getenv("MAX_FLIGHT_ENDPOINTS", DEFAULT_FLIGHT_ENDPOINTS),
    required=True,
    help="The maximum number of Flight end-points to produce for get_flight_info. This is useful if running in Kubernetes with multiple replicas."
)
@click.option(
    "--database-file",
    type=str,
    default=os.getenv("DATABASE_FILE", DUCKDB_DB_FILE.as_posix()),
    required=True,
    help="The DuckDB database file used for servicing data requests..."
)
@click.option(
    "--duckdb-threads",
    type=int,
    required=True,
    default=os.getenv("DUCKDB_THREADS", DUCKDB_THREADS),
    help="The number of threads to use for the DuckDB connection."
)
@click.option(
    "--duckdb-memory-limit",
    type=str,
    required=True,
    default=os.getenv("DUCKDB_MEMORY_LIMIT", DUCKDB_MEMORY_LIMIT),
    help="The amount of memory to use for the DuckDB connection"
)
@click.option(
    "--tls",
    nargs=2,
    default=os.getenv("FLIGHT_TLS").split(" ") if os.getenv("FLIGHT_TLS") else None,
    required=False,
    metavar=('CERTFILE', 'KEYFILE'),
    help="Enable transport-level security"
)
@click.option(
    "--verify-client/--no-verify-client",
    type=bool,
    default=(os.getenv("FLIGHT_VERIFY_CLIENT", "False").upper() == "TRUE"),
    show_default=True,
    required=True,
    help="enable mutual TLS and verify the client if True"
)
@click.option(
    "--mtls",
    type=str,
    default=os.getenv("FLIGHT_MTLS"),
    required=False,
    help="If you provide verify-client, you must supply an MTLS CA Certificate file (public key only)"
)
@click.option(
    "--flight-username",
    type=str,
    default=os.getenv("FLIGHT_USERNAME"),
    required=False,
    show_default=False,
    help="If supplied, authentication will be required from clients to connect with this username"
)
@click.option(
    "--flight-password",
    type=str,
    default=os.getenv("FLIGHT_PASSWORD"),
    required=False,
    show_default=False,
    help="If supplied, authentication will be required from clients to connect with this password"
)
@click.option(
    "--log-level",
    type=click.Choice(["INFO", "DEBUG", "WARNING", "CRITICAL"], case_sensitive=False),
    default=os.getenv("LOG_LEVEL", "INFO"),
    required=True,
    help="The logging level to use"
)
@click.option(
    "--log-file",
    type=str,
    default=os.getenv("LOG_FILE"),
    required=False,
    help="The log file to write to. If None, will just log to stdout"
)
@click.option(
    "--log-file-mode",
    type=click.Choice(["a", "w"], case_sensitive=True),
    default=os.getenv("LOG_FILE_MODE", "w"),
    help="The log file mode, use value: a for 'append', and value: w to overwrite..."
)
def run_flight_server(host: str,
                      location: str,
                      port: int,
                      max_endpoints: int,
                      database_file: str,
                      duckdb_threads: int,
                      duckdb_memory_limit: str,
                      tls: list,
                      verify_client: bool,
                      mtls: str,
                      flight_username: str,
                      flight_password: str,
                      log_level: str,
                      log_file: str,
                      log_file_mode: str
                      ):
    """CLI entry point: configure and run the Flight Ibis server.

    Builds TLS/mTLS material and optional basic-auth middleware from the
    supplied options (all of which may also come from environment variables),
    constructs a FlightServer, and serves until shutdown.
    """
    # When --tls CERTFILE KEYFILE is supplied, switch to the TLS gRPC scheme
    # and load the certificate chain and private key from disk.
    tls_certificates = []
    scheme = GRPC_TCP_SCHEME
    if tls:
        scheme = GRPC_TLS_SCHEME
        with open(tls[0], "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(tls[1], "rb") as key_file:
            tls_private_key = key_file.read()
        tls_certificates.append((tls_cert_chain, tls_private_key))

    # Mutual TLS: requires both a CA certificate (--mtls) and server TLS.
    root_certificates = None
    if verify_client:
        if not mtls:
            raise RuntimeError("You MUST provide a CA certificate public key file path if 'verify_client' is True, aborting.")

        if not tls:
            raise RuntimeError("TLS must be enabled in order to use MTLS, aborting.")

        with open(mtls, "rb") as mtls_ca_file:
            root_certificates = mtls_ca_file.read()

    logger = get_logger(filename=log_file,
                        filemode=log_file_mode,
                        logger_name="flight_server",
                        log_level=getattr(logging, log_level.upper())
                        )

    # Basic authentication is only enabled when BOTH a username and password
    # are supplied, and it requires TLS (credentials travel over the wire).
    auth_handler = None
    middleware = None
    if flight_username and flight_password:
        if not tls:
            raise RuntimeError("TLS must be enabled in order to use authentication, aborting.")
        auth_handler = NoOpAuthHandler()
        # NOTE: tls_cert_chain / tls_private_key are only bound when --tls was
        # given; the RuntimeError above guarantees that here.
        middleware = dict(basic=BasicAuthServerMiddlewareFactory(creds={flight_username: flight_password},
                                                                 cert=tls_cert_chain,
                                                                 key=tls_private_key,
                                                                 logger=logger
                                                                 ))

    # host_uri is what the server binds to; location_uri is what is advertised
    # to clients in endpoints (may differ behind a load balancer/proxy).
    host_uri = f"{scheme}://{host}:{port}"
    location_uri = f"{scheme}://{location}"
    server = FlightServer(host_uri=host_uri,
                          location_uri=location_uri,
                          max_endpoints=max_endpoints,
                          database_file=Path(database_file),
                          duckdb_threads=duckdb_threads,
                          duckdb_memory_limit=duckdb_memory_limit,
                          logger=logger,
                          tls_certificates=tls_certificates,
                          verify_client=verify_client,
                          root_certificates=root_certificates,
                          auth_handler=auth_handler,
                          middleware=middleware,
                          log_level=log_level,
                          log_file=log_file,
                          log_file_mode=log_file_mode
                          )
    try:
        server.serve()
    except Exception as e:
        server.logger.exception(msg=f"Flight server had exception: {str(e)}")
        raise
    finally:
        # Always log shutdown and flush/close logging handlers.
        server.logger.warning(msg="Flight server shutdown")
        logging.shutdown()


if __name__ == '__main__':
    run_flight_server()
652 |
--------------------------------------------------------------------------------
/src/flight_ibis/setup/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gizmodata/flight-ibis-demo/e2543826a5212fa380cb0286ce6a500f4c478e82/src/flight_ibis/setup/__init__.py
--------------------------------------------------------------------------------
/src/flight_ibis/setup/create_local_duckdb_database.py:
--------------------------------------------------------------------------------
1 | import click
2 | import duckdb
3 | from ..config import DUCKDB_DB_FILE, get_logger
4 | from pathlib import Path
5 |
6 |
# Module-level logger using the project's default logging configuration.
logger = get_logger()
8 |
9 |
@click.command()
@click.option(
    "--database-file",
    type=str,
    default=DUCKDB_DB_FILE.as_posix(),
    help="The database file to create"
)
@click.option(
    "--scale-factor",
    type=float,
    default=1.0,
    help="TPC-H Scale Factor used to create the database file."
)
@click.option(
    "--overwrite/--no-overwrite",
    type=bool,
    default=False,
    show_default=True,
    required=True,
    help="Can we overwrite the target --database-file if it already exists..."
)
def create_local_duckdb_database(database_file: str,
                                 scale_factor: float,
                                 overwrite: bool):
    """Create a local DuckDB database file populated with TPC-H data.

    Raises:
        RuntimeError: if the target file exists and --overwrite was not given.
    """
    logger.info(msg=f"create_local_duckdb_database - was called with args: {locals()}")

    logger.info(msg=f"Using DuckDB version: {duckdb.__version__}")

    database_file_path = Path(database_file)

    if database_file_path.exists():
        if overwrite:
            logger.warning(msg=f"Deleting existing database file: '{database_file_path}'")
            database_file_path.unlink(missing_ok=True)
        else:
            raise RuntimeError(f"Database file: '{database_file_path}' - already exists, aborting b/c overwrite False")

    # Robustness: ensure the parent directory exists before DuckDB tries to
    # create the database file inside it.
    database_file_path.parent.mkdir(parents=True, exist_ok=True)

    # Get a DuckDB database connection (pass a string path for compatibility
    # with DuckDB versions that do not accept PathLike objects)
    with duckdb.connect(database=database_file_path.as_posix()) as conn:
        logger.info(msg=f"Creating DuckDB Database file: '{database_file_path}'")

        # Bug fix: the TPCH extension must be installed before it can be
        # loaded - load_extension alone fails on a fresh DuckDB install.
        conn.install_extension(extension="tpch")
        conn.load_extension(extension="tpch")

        # Generate the data
        sql_statement = f"CALL dbgen(sf={scale_factor})"
        logger.info(f"Running SQL: {sql_statement}")
        conn.execute(query=sql_statement)

    logger.info(msg=f"Successfully created DuckDB Database file: '{database_file_path}'")


if __name__ == '__main__':
    create_local_duckdb_database()
65 |
--------------------------------------------------------------------------------
/src/flight_ibis/setup/mtls_utilities.py:
--------------------------------------------------------------------------------
1 | import OpenSSL
2 | import os
3 | from pathlib import Path
4 | import click
5 |
# Constants
# pyOpenSSL's X509.set_version() takes the ZERO-based version field value,
# so (3 - 1) == 2 yields an X.509 "v3" certificate.
CERTIFICATE_VERSION: int = (3 - 1)

# Directory where generated certificates/keys are written by default.
TLS_DIR = Path("tls").resolve()
10 |
11 |
@click.command()
@click.option(
    "--ca-common-name",
    type=str,
    default="Flight Ibis Demo CA",
    required=True,
    help="The common name to create the Certificate Authority for."
)
@click.option(
    "--ca-cert-file",
    type=str,
    default=(TLS_DIR / "ca.crt").as_posix(),
    required=True,
    help="The CA certificate file to create."
)
@click.option(
    "--ca-key-file",
    type=str,
    default=(TLS_DIR / "ca.key").as_posix(),
    required=True,
    help="The CA key file to create."
)
@click.option(
    "--overwrite/--no-overwrite",
    type=bool,
    default=False,
    show_default=True,
    required=True,
    help="Can we overwrite the CA cert/key if they exist?"
)
def create_ca_keypair(ca_common_name: str,
                      ca_cert_file: str,
                      ca_key_file: str,
                      overwrite: bool
                      ):
    """Create a self-signed Certificate Authority (CA) keypair and write it to disk.

    Raises:
        RuntimeError: if either output file exists and --overwrite was not given.
    """
    ca_cert_file_path = Path(ca_cert_file)
    ca_key_file_path = Path(ca_key_file)

    if ca_cert_file_path.exists() or ca_key_file_path.exists():
        if not overwrite:
            raise RuntimeError(f"The CA Cert file(s): '{ca_cert_file_path.as_posix()}' or '{ca_key_file_path.as_posix()}' - exist - and overwrite is False, aborting.")
        else:
            ca_cert_file_path.unlink(missing_ok=True)
            ca_key_file_path.unlink(missing_ok=True)

    # Generate a new RSA key pair for the CA
    key = OpenSSL.crypto.PKey()
    key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048)

    # Generate a self-signed CA certificate (issuer == subject)
    ca_cert = OpenSSL.crypto.X509()
    ca_cert.get_subject().CN = ca_common_name
    ca_cert.set_version(CERTIFICATE_VERSION)
    ca_cert.set_serial_number(1000)
    ca_cert.gmtime_adj_notBefore(0)
    ca_cert.gmtime_adj_notAfter(365*24*60*60)  # 1 year validity
    ca_cert.set_issuer(ca_cert.get_subject())
    ca_cert.set_pubkey(key)
    ca_cert.add_extensions([
        OpenSSL.crypto.X509Extension(b"basicConstraints", True, b"CA:TRUE"),
        OpenSSL.crypto.X509Extension(b"subjectKeyIdentifier", False, b"hash", subject=ca_cert),
    ])
    ca_cert.sign(key, "sha256")

    # Robustness: ensure target directories exist before writing
    # (consistent with tls_utilities.create_tls_keypair).
    ca_cert_file_path.parent.mkdir(parents=True, exist_ok=True)
    ca_key_file_path.parent.mkdir(parents=True, exist_ok=True)

    # Write the CA certificate and key to disk (no redundant Path() re-wrap)
    with open(ca_cert_file_path, "wb") as f:
        f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca_cert))
    with open(ca_key_file_path, "wb") as f:
        f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key))

    print(f"Successfully created Certificate Authority (CA) keypair for Common Name (CN): '{ca_common_name}'")
    print(f"CA Cert file: {ca_cert_file_path.as_posix()}")
    print(f"CA Key file: {ca_key_file_path.as_posix()}")
85 |
86 |
@click.command()
@click.option(
    "--client-common-name",
    type=str,
    default="Flight Ibis Demo Client",
    required=True,
    help="The common name to create the Client certificate for."
)
@click.option(
    "--ca-cert-file",
    type=str,
    default=(TLS_DIR / "ca.crt").as_posix(),
    required=True,
    help="The CA certificate file used to sign the client certificate - it MUST exist."
)
@click.option(
    "--ca-key-file",
    type=str,
    default=(TLS_DIR / "ca.key").as_posix(),
    required=True,
    help="The CA key file used to sign the client certificate - it MUST exist."
)
@click.option(
    "--client-cert-file",
    type=str,
    default=(TLS_DIR / "client.crt").as_posix(),
    required=True,
    help="The Client certificate file to create."
)
@click.option(
    "--client-key-file",
    type=str,
    default=(TLS_DIR / "client.key").as_posix(),
    required=True,
    help="The Client key file to create."
)
@click.option(
    "--overwrite/--no-overwrite",
    type=bool,
    default=False,
    show_default=True,
    required=True,
    help="Can we overwrite the client cert/key if they exist?"
)
def create_client_keypair(client_common_name: str,
                          ca_cert_file: str,
                          ca_key_file: str,
                          client_cert_file: str,
                          client_key_file: str,
                          overwrite: bool
                          ):
    """Create a client certificate keypair signed by an existing CA.

    Raises:
        RuntimeError: if the CA cert/key files do not exist, or if the
        client cert/key files exist and --overwrite was not given.
    """
    ca_cert_file_path = Path(ca_cert_file)
    ca_key_file_path = Path(ca_key_file)

    if not ca_cert_file_path.exists():
        raise RuntimeError(f"The CA Cert file: '{ca_cert_file_path.as_posix()}' does not exist, aborting")

    if not ca_key_file_path.exists():
        raise RuntimeError(f"The CA Key file: '{ca_key_file_path.as_posix()}' does not exist, aborting")

    client_cert_file_path = Path(client_cert_file)
    client_key_file_path = Path(client_key_file)

    if client_cert_file_path.exists() or client_key_file_path.exists():
        if not overwrite:
            raise RuntimeError(f"The Client Cert file(s): '{client_cert_file_path.as_posix()}' or '{client_key_file_path.as_posix()}' - exist - and overwrite is False, aborting.")
        else:
            client_cert_file_path.unlink(missing_ok=True)
            client_key_file_path.unlink(missing_ok=True)

    # Generate a new RSA key pair for the client
    key = OpenSSL.crypto.PKey()
    key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048)

    # Generate a certificate signing request (CSR) for the client
    req = OpenSSL.crypto.X509Req()
    req.get_subject().CN = client_common_name
    req.set_pubkey(key)
    req.sign(key, "sha256")

    # Load the CA certificate and key from disk
    with open(ca_cert_file_path, "rb") as f:
        ca_cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, f.read())
    with open(ca_key_file_path, "rb") as f:
        ca_key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, f.read())

    # Create a new certificate for the client, signed by the CA
    client_cert = OpenSSL.crypto.X509()
    client_cert.set_version(CERTIFICATE_VERSION)
    client_cert.set_subject(req.get_subject())
    client_cert.set_serial_number(2000)
    client_cert.gmtime_adj_notBefore(0)
    client_cert.gmtime_adj_notAfter(365*24*60*60)  # 1 year validity
    client_cert.set_issuer(ca_cert.get_subject())
    client_cert.set_pubkey(req.get_pubkey())
    client_cert.add_extensions([
        OpenSSL.crypto.X509Extension(b"basicConstraints", True, b"CA:FALSE"),
        OpenSSL.crypto.X509Extension(b"subjectKeyIdentifier", False, b"hash", subject=client_cert),
    ])
    client_cert.sign(ca_key, "sha256")

    # Robustness: ensure target directories exist before writing
    # (consistent with tls_utilities.create_tls_keypair).
    client_cert_file_path.parent.mkdir(parents=True, exist_ok=True)
    client_key_file_path.parent.mkdir(parents=True, exist_ok=True)

    # Write the client certificate and key to disk
    with open(client_cert_file_path, "wb") as f:
        f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, client_cert))
    with open(client_key_file_path, "wb") as f:
        f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key))

    print(f"Successfully created Client certificate keypair for Common Name (CN): '{client_common_name}'")
    # Bug fix: the original message had garbled quoting ("'''s private key").
    print(f"Signed Client Certificate keypair with CA: '{ca_cert.get_subject()}'s private key")
    print(f"Client Cert file: {client_cert_file_path.as_posix()}")
    print(f"Client Key file: {client_key_file_path.as_posix()}")
198 |
--------------------------------------------------------------------------------
/src/flight_ibis/setup/tls_utilities.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 | """Utilities for generating self-signed TLS certificates using pyOpenSSL."""
3 |
4 | import os
5 | import socket
6 | from pathlib import Path
7 |
8 | import click
9 | from OpenSSL import crypto
10 |
11 | from ..config import get_logger
12 |
13 | # Constants
14 | TLS_DIR = Path("tls")
15 | DEFAULT_CERT_FILE = (TLS_DIR / "server.crt")
16 | DEFAULT_KEY_FILE = (TLS_DIR / "server.key")
17 |
18 |
19 | logger = get_logger()
20 |
21 |
def _gen_pyopenssl(common_name: str) -> tuple[bytes, bytes]:
    """Generate a self-signed certificate using pyOpenSSL.

    Returns a (certificate_pem, private_key_pem) pair of bytes.
    """
    # Create the RSA private key
    rsa_key = crypto.PKey()
    rsa_key.generate_key(crypto.TYPE_RSA, 2048)

    # Build the X.509 certificate; subject and issuer are identical
    # because the certificate is self-signed.
    certificate = crypto.X509()
    subject = certificate.get_subject()
    subject.CN = common_name
    certificate.set_issuer(subject)
    certificate.set_pubkey(rsa_key)

    # Random 64-bit serial number
    certificate.set_serial_number(int.from_bytes(os.urandom(8), "big"))

    # Validity window: from now through 5 years in the future
    certificate.gmtime_adj_notBefore(0)
    certificate.gmtime_adj_notAfter(5 * 365 * 24 * 60 * 60)  # 5 years

    # Subject alternative names: the common name, localhost, and - when it
    # differs from the common name - the local hostname, each with a wildcard.
    alt_names = [
        f"DNS:{common_name}",
        f"DNS:*.{common_name}",
        "DNS:localhost",
        "DNS:*.localhost"
    ]
    local_hostname = socket.gethostname()
    if local_hostname != common_name:
        alt_names += [f"DNS:{local_hostname}", f"DNS:*.{local_hostname}"]

    certificate.add_extensions([
        crypto.X509Extension(b"subjectAltName", False, ", ".join(alt_names).encode()),
        crypto.X509Extension(b"basicConstraints", True, b"CA:FALSE")
    ])

    # Self-sign with the private key
    certificate.sign(rsa_key, 'sha256')

    # Serialize both artifacts to PEM format
    return (
        crypto.dump_certificate(crypto.FILETYPE_PEM, certificate),
        crypto.dump_privatekey(crypto.FILETYPE_PEM, rsa_key),
    )
72 |
73 |
def gen_self_signed_cert(common_name: str) -> tuple[bytes, bytes]:
    """Return a self-signed (cert, key) pair as PEM-encoded bytes, generated via pyOpenSSL."""
    return _gen_pyopenssl(common_name=common_name)
77 |
78 |
def create_tls_keypair(
    cert_file: str,
    key_file: str,
    overwrite: bool,
    common_name: str
):
    """Create a self-signed TLS key pair and write it to disk.

    Args:
        cert_file: Path of the certificate file to create.
        key_file: Path of the private-key file to create.
        overwrite: Whether existing files may be deleted and replaced.
        common_name: Common Name (CN) for the generated certificate.

    Raises:
        RuntimeError: if either output file exists and overwrite is False.
    """
    logger.info(msg=f"create_tls_keypair was called with args: {locals()}")

    cert_file_path = Path(cert_file)
    key_file_path = Path(key_file)

    if cert_file_path.exists() or key_file_path.exists():
        if not overwrite:
            raise RuntimeError(
                f"The TLS Cert file(s): '{cert_file_path}' or '{key_file_path}' exist - "
                "and overwrite is False, aborting."
            )

        cert_file_path.unlink(missing_ok=True)
        key_file_path.unlink(missing_ok=True)

    cert, key = gen_self_signed_cert(common_name=common_name)

    # Bug fix: ensure BOTH parent directories exist - the original only
    # created the certificate's parent, so writing the key failed when the
    # two files targeted different (missing) directories.
    cert_file_path.parent.mkdir(parents=True, exist_ok=True)
    key_file_path.parent.mkdir(parents=True, exist_ok=True)

    with cert_file_path.open(mode="wb") as cert_file:
        cert_file.write(cert)

    with key_file_path.open(mode="wb") as key_file:
        key_file.write(key)

    logger.info("Created TLS Key pair successfully.")
    logger.info(f"Cert file path: {cert_file_path}")
    logger.info(f"Key file path: {key_file_path}")
113 |
114 |
@click.command()
@click.option(
    "--cert-file",
    type=str,
    default=DEFAULT_CERT_FILE,
    required=True,
    help="The TLS certificate file to create.",
)
@click.option(
    "--key-file",
    type=str,
    default=DEFAULT_KEY_FILE,
    required=True,
    help="The TLS key file to create.",
)
@click.option(
    "--overwrite/--no-overwrite",
    type=bool,
    default=False,
    show_default=True,
    required=True,
    help="Can we overwrite the cert/key if they exist?",
)
@click.option(
    "--common-name",
    type=str,
    default=socket.gethostname(),
    show_default=True,
    required=True,
    help="Set the common name (CN) for the certificate.",
)
def click_create_tls_keypair(cert_file: str,
                             key_file: str,
                             overwrite: bool,
                             common_name: str
                             ):
    """Command-line entry point for creating a self-signed TLS key pair."""
    # Pass arguments explicitly rather than via **locals() - same call,
    # but robust against future local variables in this function.
    create_tls_keypair(cert_file=cert_file,
                       key_file=key_file,
                       overwrite=overwrite,
                       common_name=common_name
                       )


if __name__ == "__main__":
    click_create_tls_keypair()
157 |
--------------------------------------------------------------------------------
/src/flight_ibis/spark_client_test.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from datetime import date, datetime
4 | from pathlib import Path
5 |
6 | import click
7 | from codetiming import Timer
8 | from pyspark.sql import SparkSession
9 | from pyspark import __version__ as pyspark_version
10 |
11 | from .config import logging, get_logger, TIMER_TEXT
12 | from .constants import LOCALHOST, GRPC_TCP_SCHEME, GRPC_TLS_SCHEME
13 |
14 |
@click.command()
@click.option(
    "--host",
    type=str,
    default=LOCALHOST,
    help="Address or hostname of the Flight server to connect to"
)
@click.option(
    "--port",
    type=int,
    default=8815,
    help="Port number of the Flight server to connect to"
)
@click.option(
    "--tls/--no-tls",
    type=bool,
    default=False,
    show_default=True,
    required=True,
    help="Connect to the server with tls"
)
@click.option(
    "--tls-roots",
    type=str,
    default=None,
    show_default=True,
    help="Path to trusted TLS certificate(s)"
)
@click.option(
    "--mtls",
    nargs=2,
    default=None,
    metavar=('CERTFILE', 'KEYFILE'),
    help="Enable mutual TLS (mTLS) client authentication"
)
@click.option(
    "--flight-username",
    type=str,
    default=os.getenv("FLIGHT_USERNAME"),
    required=False,
    show_default=False,
    help="The username used to connect to the Flight server"
)
@click.option(
    "--flight-password",
    type=str,
    default=os.getenv("FLIGHT_PASSWORD"),
    required=False,
    show_default=False,
    help="The password used to connect to the Flight server"
)
@click.option(
    "--log-level",
    type=click.Choice(["INFO", "DEBUG", "WARNING", "CRITICAL"], case_sensitive=False),
    default=os.getenv("LOG_LEVEL", "INFO"),
    required=True,
    help="The logging level to use"
)
@click.option(
    "--log-file",
    type=str,
    required=False,
    help="The log file to write to. If None, will just log to stdout"
)
@click.option(
    "--log-file-mode",
    type=click.Choice(["a", "w"], case_sensitive=True),
    default="w",
    help="The log file mode, use value: a for 'append', and value: w to overwrite..."
)
@click.option(
    "--from-date",
    type=click.DateTime(formats=["%Y-%m-%d"]),
    default=date(year=1994, month=1, day=1).isoformat(),
    required=True,
    help="The from date to use for the data filter - in ISO format (example: 2020-11-01) - for 01-November-2020"
)
@click.option(
    "--to-date",
    type=click.DateTime(formats=["%Y-%m-%d"]),
    default=date(year=1995, month=12, day=31).isoformat(),
    required=True,
    help="The to date to use for the data filter - in ISO format (example: 2020-11-01) - for 01-November-2020"
)
@click.option(
    "--num-endpoints",
    type=int,
    default=1,
    required=True,
    help="The number of server threads to use for pulling data"
)
@click.option(
    "--custkey-filter-value",
    type=int,
    default=None,
    required=False,
    help="The value to use for the customer key filter - if None, then no filter will be applied"
)
@click.option(
    "--requested-columns",
    type=str,
    default="*",
    required=True,
    help="A comma-separated list of columns to request from the server - example value: 'o_orderkey,o_custkey' - if '*', then all columns will be requested"
)
def run_spark_flight_client(host: str,
                            port: int,
                            tls: bool,
                            tls_roots: str,
                            mtls: list,
                            flight_username: str,
                            flight_password: str,
                            log_level: str,
                            log_file: str,
                            log_file_mode: str,
                            from_date: datetime,
                            to_date: datetime,
                            num_endpoints: int,
                            custkey_filter_value: int,
                            requested_columns: str
                            ):
    """Smoke-test the Flight server from Spark via the flight-spark-source connector.

    Builds a SparkSession, reads the 'get_golden_rule_facts' command from the
    Flight server as a DataFrame, optionally filters/projects it, and shows
    the first rows.
    """
    logger = get_logger(filename=log_file,
                        filemode=log_file_mode,
                        logger_name="flight_client",
                        log_level=getattr(logging, log_level.upper())
                        )

    with Timer(name="Spark Flight Client test",
               text=TIMER_TEXT,
               initial_text=True,
               logger=logger.info
               ):
        # Never log the Flight password.
        redacted_locals = {key: value for key, value in locals().items() if key not in ["flight_password"
                                                                                        ]
                           }
        logger.info(msg=f"run_spark_flight_client - was called with args: {redacted_locals}")

        logger.info(msg=f"Using PySpark version: {pyspark_version}")

        spark = (SparkSession
                 .builder
                 .appName("flight client")
                 .config("spark.jars", "spark_connector/flight-spark-source-1.0-SNAPSHOT-shaded.jar")
                 .getOrCreate())

        scheme = GRPC_TCP_SCHEME
        root_ca = ""
        if tls:
            scheme = GRPC_TLS_SCHEME
            # Load the root CA if it is present in the tls directory
            if tls_roots:
                root_ca_file = Path(tls_roots).resolve()
                if root_ca_file.exists():
                    with open(file=root_ca_file, mode="r") as file:
                        root_ca = file.read()
                    logger.info(msg=f"Root CA:\n{root_ca}")

        # Mutual TLS requires TLS; read the client cert/key pair when given.
        mtls_cert_chain = ""
        mtls_private_key = ""
        if mtls:
            if not tls:
                raise RuntimeError("TLS must be enabled in order to use MTLS, aborting.")
            else:
                with open(file=mtls[0], mode="r") as cert_file:
                    mtls_cert_chain = cert_file.read()
                with open(file=mtls[1], mode="r") as key_file:
                    mtls_private_key = key_file.read()

        uri = f'{scheme}://{host}:{port}'
        logger.info(msg=f"Using Flight RPC Server URI: '{uri}'")

        # The JSON-serialized command becomes the Flight descriptor.
        arg_dict = dict(num_endpoints=num_endpoints,
                        min_date=from_date.isoformat(),
                        max_date=to_date.isoformat()
                        )
        command_dict = dict(command="get_golden_rule_facts",
                            kwargs=arg_dict
                            )
        logger.info(msg=f"Command dict: {command_dict}")

        df = (spark.read.format('cdap.org.apache.arrow.flight.spark')
              .option('trustedCertificates', root_ca)
              .option('uri', uri)
              .option('username', flight_username)
              .option('password', flight_password)
              .option('clientCertificate', mtls_cert_chain)
              .option('clientKey', mtls_private_key)
              .load(
                  json.dumps(command_dict)
              )
              )

        # Apply filters if requested
        if custkey_filter_value:
            df = df.filter(df.o_custkey == custkey_filter_value)

        # Select only the requested columns from the Flight RPC Server
        # (a value of "*" selects all columns)
        df = df.select(requested_columns.split(","))

        # Show the dataframe
        df.show(n=10)


if __name__ == '__main__':
    run_spark_flight_client()
220 |
--------------------------------------------------------------------------------
/tls/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
--------------------------------------------------------------------------------