├── .github └── workflows │ ├── pr.yml │ └── release.yml ├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── bench ├── bench_flow.py └── early_sense.csv ├── dev-requirements.txt ├── docs ├── Makefile ├── api.rst ├── conf.py ├── examples.rst ├── index.rst ├── make.bat └── requirements.txt ├── integration ├── __init__.py ├── conftest.py ├── integration_test_utils.py ├── test_aggregation_integration.py ├── test_azure_filesystem_integration.py ├── test_filesystems_integration.py ├── test_flow_integration.py ├── test_kafka_integration.py ├── test_redis_specific.py ├── test_s3_filesystem_integration.py └── test_tdengine.py ├── requirements.txt ├── set-version.py ├── setup.cfg ├── setup.py ├── storey ├── __init__.py ├── aggregation_utils.py ├── aggregations.py ├── dataframe.py ├── drivers.py ├── dtypes.py ├── flow.py ├── queue.py ├── redis_driver.py ├── sources.py ├── sql_driver.py ├── steps │ ├── __init__.py │ ├── assertion.py │ ├── flatten.py │ ├── foreach.py │ ├── partition.py │ └── sample.py ├── table.py ├── targets.py ├── transformations │ └── __init__.py ├── utils.py └── windowed_store.py ├── tests.coveragerc └── tests ├── __init__.py ├── test-multiple-time-columns.csv ├── test-none-in-keyfield.csv ├── test-with-compact-timestamp.csv ├── test-with-none-values.csv ├── test-with-timestamp-microsecs.csv ├── test-with-timestamp-nanosecs.csv ├── test-with-timestamp.csv ├── test.csv ├── test.parquet ├── test_aggregate_by_key.py ├── test_aggregate_store.py ├── test_concurrent_execution.py ├── test_flow.py ├── test_queue.py ├── test_space_in_header.csv ├── test_space_in_header.parquet ├── test_steps.py ├── test_targets.py ├── test_types.py ├── test_utils.py ├── test_v3io.py └── test_windowed_store.py /.github/workflows/pr.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | name: CI 16 | 17 | on: pull_request_target 18 | 19 | jobs: 20 | lint: 21 | name: Lint code (Python ${{ matrix.python-version }}) 22 | runs-on: ubuntu-latest 23 | strategy: 24 | matrix: 25 | python-version: [3.9, 3.11] 26 | steps: 27 | - uses: actions/checkout@v3 28 | with: 29 | ref: refs/pull/${{ github.event.number }}/merge 30 | - name: Set up Python 31 | uses: actions/setup-python@v4 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | cache: pip 35 | - name: Install dependencies 36 | run: make dev-env 37 | - name: Lint 38 | run: make lint 39 | 40 | test: 41 | name: Unit tests (Python ${{ matrix.python-version }}) 42 | runs-on: ubuntu-latest 43 | strategy: 44 | matrix: 45 | python-version: [3.9, 3.11] 46 | steps: 47 | - uses: actions/checkout@v3 48 | with: 49 | ref: refs/pull/${{ github.event.number }}/merge 50 | - name: Set up Python 51 | uses: actions/setup-python@v4 52 | with: 53 | python-version: ${{ matrix.python-version }} 54 | cache: pip 55 | - name: Install dependencies 56 | run: make dev-env 57 | - name: Run unit tests 58 | run: make test 59 | 60 | integration: 61 | name: Integration tests (Python ${{ matrix.python-version }}) 62 | runs-on: [ self-hosted, Linux ] 63 | strategy: 64 | matrix: 65 | python-version: [3.9, 3.11] 66 | include: 67 | - python-version: 3.9 68 | image: python:3.9.18 69 | - python-version: 3.11 70 | image: python:3.11.8 71 | container: 72 | image: ${{ matrix.image }} 73 | steps: 74 | - uses: actions/checkout@v3 75 | with: 76 | ref: refs/pull/${{ github.event.number }}/merge 77 | - name: Install dependencies 78 | run: make dev-env 79 | - name: Run integration tests 80 | env: 81 | V3IO_API: ${{ secrets.V3IO_API }} 82 | V3IO_ACCESS_KEY: ${{ secrets.V3IO_ACCESS_KEY }} 83 | V3IO_FRAMESD: ${{ secrets.V3IO_FRAMESD }} 84 | run: make integration 85 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | name: Release 16 | 17 | on: 18 | release: 19 | types: 20 | - created 21 | 22 | jobs: 23 | # TODO: remove once 1.8.x is discontinued 24 | release-1-8-x: 25 | runs-on: ubuntu-latest 26 | if: github.ref_name == '1.8.x' 27 | container: 28 | image: python:3.9 29 | steps: 30 | - uses: actions/checkout@v3 31 | - name: Install dependencies 32 | run: make dev-env 33 | - name: lint 34 | run: make lint 35 | - name: test 36 | run: make test 37 | - name: Set version 38 | run: make set-version 39 | - name: Build binary wheel and source tarball 40 | run: make dist 41 | - name: Install publish dependencies 42 | run: python -m pip install twine~=6.1 43 | - name: Push to pypi 44 | run: | 45 | export TWINE_USERNAME=__token__ 46 | export TWINE_PASSWORD=${{ secrets.PYPI_TOKEN }} 47 | python -m twine upload dist/storey-*.whl 48 | 49 | release: 50 | runs-on: ubuntu-latest 51 | if: github.ref_name != '1.8.x' # TODO: remove once 1.8.x is discontinued 52 | container: 53 | image: python:3.11 54 | steps: 55 | - uses: actions/checkout@v3 56 | - name: Install dependencies 57 | run: make dev-env 58 | - name: lint 59 | run: make lint 60 | - name: test 61 | run: make test 62 | - name: Set version 63 | run: make set-version 64 | - name: Build binary wheel and source tarball 65 | run: make dist 66 | - name: Install publish dependencies 67 | run: python -m pip install twine~=6.1 68 | - name: Push to pypi 69 | run: | 70 | export TWINE_USERNAME=__token__ 71 | export TWINE_PASSWORD=${{ secrets.PYPI_TOKEN }} 72 | python -m twine upload dist/storey-*.whl 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | 4 | *.egg-info/ 5 | dist/ 6 | 7 | venv/ 8 | 9 | *.coverage 10 | coverage_reports/ 11 | 12 | .python-version 13 | 14 | docs/_build/ 15 | 16 | test.db 17 | bench-results.json 18 | 19 | .idea 20 | 21 | .vscode/* 22 | !.vscode/settings.json 23 | !.vscode/tasks.json 24 | !.vscode/launch.json 25 | !.vscode/extensions.json 26 | *.code-workspace 27 | 28 | .DS_Store 29 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: 3.9 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/conf.py 11 | 12 | formats: all 13 | 14 | python: 15 | version: 3.9 16 | install: 17 | - requirements: requirements.txt 18 | - requirements: dev-requirements.txt 19 | - requirements: docs/requirements.txt 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include requirements.txt dev-requirements.txt 4 | recursive-include tests *.py 5 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | .NOTPARALLEL: 16 | 17 | .PHONY: all 18 | all: 19 | $(error please pick a target) 20 | 21 | # We only want to format and lint checked in python files 22 | CHECKED_IN_PYTHON_FILES := $(shell git ls-files | grep '\.py$$') 23 | 24 | # Fallback 25 | ifeq ($(CHECKED_IN_PYTHON_FILES),) 26 | CHECKED_IN_PYTHON_FILES := . 27 | endif 28 | 29 | FLAKE8_OPTIONS := --max-line-length 120 --extend-ignore E203,W503 30 | BLACK_OPTIONS := --line-length 120 31 | ISORT_OPTIONS := --profile black 32 | 33 | .PHONY: fmt 34 | fmt: 35 | @echo "Running black fmt..." 36 | @python -m black $(BLACK_OPTIONS) $(CHECKED_IN_PYTHON_FILES) 37 | @echo "Running isort..." 38 | @python -m isort $(ISORT_OPTIONS) $(CHECKED_IN_PYTHON_FILES) 39 | 40 | .PHONY: lint 41 | lint: flake8 fmt-check 42 | 43 | .PHONY: fmt-check 44 | fmt-check: 45 | @echo "Running black check..." 46 | @python -m black $(BLACK_OPTIONS) --check --diff $(CHECKED_IN_PYTHON_FILES) 47 | @echo "Running isort check..." 48 | @python -m isort --check --diff $(ISORT_OPTIONS) $(CHECKED_IN_PYTHON_FILES) 49 | 50 | .PHONY: flake8 51 | flake8: 52 | @echo "Running flake8 lint..." 53 | @python -m flake8 $(FLAKE8_OPTIONS) $(CHECKED_IN_PYTHON_FILES) 54 | 55 | .PHONY: clean 56 | clean: 57 | find storey tests integration -name '*.pyc' -exec rm {} \; 58 | 59 | .PHONY: test 60 | test: clean 61 | python -m pytest --ignore=integration -rf -v . 62 | 63 | .PHONY: test-coverage 64 | test-coverage: clean 65 | rm -f coverage_reports/unit_tests.coverage 66 | COVERAGE_FILE=coverage_reports/unit_tests.coverage coverage run --rcfile=tests.coveragerc -m pytest --ignore=integration -rf -v . 
67 | @echo "Unit test coverage report:" 68 | COVERAGE_FILE=coverage_reports/unit_tests.coverage coverage report --rcfile=tests.coveragerc 69 | 70 | .PHONY: bench 71 | bench: 72 | find bench -name '*.pyc' -exec rm {} \; 73 | python -m pytest --benchmark-json bench-results.json -rf -v bench/*.py 74 | 75 | .PHONY: integration 76 | integration: clean 77 | python -m pytest -rf -v integration 78 | 79 | .PHONY: integration-coverage 80 | integration-coverage: clean 81 | rm -f coverage_reports/integration.coverage 82 | COVERAGE_FILE=coverage_reports/integration.coverage coverage run --rcfile=tests.coveragerc -m pytest -rf -v integration 83 | @echo "Integration test coverage report:" 84 | COVERAGE_FILE=coverage_reports/integration.coverage coverage report --rcfile=tests.coveragerc 85 | 86 | .PHONY: env 87 | env: 88 | python -m pip install -r requirements.txt 89 | 90 | .PHONY: dev-env 91 | dev-env: env 92 | python -m pip install -r dev-requirements.txt 93 | 94 | .PHONY: docs-env 95 | docs-env: 96 | python -m pip install -r docs/requirements.txt 97 | 98 | .PHONY: dist 99 | dist: dev-env 100 | python -m build --sdist --wheel --outdir dist/ . 101 | 102 | .PHONY: set-version 103 | set-version: 104 | python set-version.py 105 | 106 | .PHONY: docs 107 | docs: # Build html docs 108 | rm -f docs/external/*.md 109 | cd docs && make html 110 | 111 | .PHONY: coverage-combine 112 | coverage-combine: 113 | rm -f coverage_reports/combined.coverage 114 | COVERAGE_FILE=coverage_reports/combined.coverage coverage combine --keep coverage_reports/integration.coverage coverage_reports/unit_tests.coverage 115 | @echo "Full coverage report:" 116 | COVERAGE_FILE=coverage_reports/combined.coverage coverage report --rcfile=tests.coveragerc -i 117 | 118 | .PHONY: coverage 119 | coverage: test-coverage integration-coverage coverage-combine 120 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Storey 2 | 3 | [![CI](https://github.com/mlrun/storey/workflows/CI/badge.svg)](https://github.com/mlrun/storey/actions?query=workflow%3ACI) 4 | 5 | Storey is an asynchronous streaming library for real-time event processing and feature extraction. It's a component 6 | of mlrun. 7 | 8 | ▶ For more information, see [mlrun documentation](https://docs.mlrun.org/en/stable/), and the page on [storey 9 | transformations](https://docs.mlrun.org/en/latest/api/storey.transformations.html) in particular. 10 | -------------------------------------------------------------------------------- /bench/bench_flow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | import asyncio 16 | from datetime import datetime, timedelta 17 | 18 | import pandas as pd 19 | import pytest 20 | 21 | from storey import ( 22 | AggregateByKey, 23 | AsyncEmitSource, 24 | Batch, 25 | Complete, 26 | DataframeSource, 27 | Driver, 28 | FieldAggregator, 29 | Map, 30 | Reduce, 31 | SyncEmitSource, 32 | Table, 33 | build_flow, 34 | ) 35 | from storey.dtypes import SlidingWindows 36 | 37 | test_base_time = datetime.fromisoformat("2020-07-21T21:40:00+00:00") 38 | 39 | 40 | @pytest.mark.parametrize("n", [0, 1, 1000, 5000]) 41 | def test_simple_flow_n_events(benchmark, n): 42 | def inner(): 43 | controller = build_flow( 44 | [ 45 | SyncEmitSource(), 46 | Map(lambda x: x + 1), 47 | Reduce(0, lambda acc, x: acc + x), 48 | ] 49 | ).run() 50 | 51 | for i in range(n): 52 | controller.emit(i) 53 | controller.terminate() 54 | controller.await_termination() 55 | 56 | benchmark(inner) 57 | 58 | 59 | @pytest.mark.parametrize("n", [0, 1, 1000, 5000]) 60 | def test_simple_async_flow_n_events(benchmark, n): 61 | async def async_inner(): 62 | controller = build_flow( 63 | [ 64 | AsyncEmitSource(), 65 | Map(lambda x: x + 1), 66 | Reduce(0, lambda acc, x: acc + x), 67 | ] 68 | ).run() 69 | 70 | for i in range(n): 71 | await controller.emit(i) 72 | await controller.terminate() 73 | await controller.await_termination() 74 | 75 | def inner(): 76 | asyncio.run(async_inner()) 77 | 78 | benchmark(inner) 79 | 80 | 81 | @pytest.mark.parametrize("n", [0, 1, 1000, 5000]) 82 | def test_complete_flow_n_events(benchmark, n): 83 | def inner(): 84 | controller = build_flow([SyncEmitSource(), Map(lambda x: x + 1), Complete()]).run() 85 | 86 | for i in range(n): 87 | result = controller.emit(i, return_awaitable_result=True).await_result() 88 | assert result == i + 1 89 | controller.terminate() 90 | controller.await_termination() 91 | 92 | benchmark(inner) 93 | 94 | 95 | @pytest.mark.parametrize("n", [0, 1, 1000, 5000]) 96 | def test_aggregate_by_key_n_events(benchmark, n): 97 | def inner(): 98 | controller = build_flow( 99 | [ 100 | SyncEmitSource(), 101 | AggregateByKey( 102 | [ 103 | FieldAggregator( 104 | "number_of_stuff", 105 | "col1", 106 | ["sum", "avg", "min", "max"], 107 | SlidingWindows(["1h", "2h", "24h"], "10m"), 108 | ) 109 | ], 110 | Table("test", Driver()), 111 | ), 112 | ] 113 | ).run() 114 | 115 | for i in range(n): 116 | data = {"col1": i} 117 | controller.emit(data, "tal", test_base_time + timedelta(minutes=25 * i)) 118 | 119 | controller.terminate() 120 | controller.await_termination() 121 | 122 | benchmark(inner) 123 | 124 | 125 | @pytest.mark.parametrize("n", [0, 1, 1000, 5000]) 126 | def test_batch_n_events(benchmark, n): 127 | def inner(): 128 | controller = build_flow( 129 | [ 130 | SyncEmitSource(), 131 | Batch(4, 100), 132 | ] 133 | ).run() 134 | 135 | for i in range(n): 136 | controller.emit(i) 137 | 138 | controller.terminate() 139 | controller.await_termination() 140 | 141 | benchmark(inner) 142 | 143 | 144 | def test_aggregate_df_86420_events(benchmark): 145 | df = pd.read_csv("bench/early_sense.csv", parse_dates=["timestamp"]) 146 | 147 | def inner(): 148 | driver = Driver() 149 | table = Table("test", driver) 150 | 151 | controller = build_flow( 152 | [ 153 | DataframeSource(df, key_field="patient_id", time_field="timestamp"), 154 | AggregateByKey( 155 | [ 156 | FieldAggregator( 157 | "hr", 158 | "hr", 159 | ["avg", "min", "max"], 160 | SlidingWindows(["1h", "2h"], "10m"), 161 | ), 162 | FieldAggregator( 163 | "rr", 164 | "rr", 165 | ["avg", "min", "max"], 166 | 
SlidingWindows(["1h", "2h"], "10m"), 167 | ), 168 | FieldAggregator( 169 | "spo2", 170 | "spo2", 171 | ["avg", "min", "max"], 172 | SlidingWindows(["1h", "2h"], "10m"), 173 | ), 174 | ], 175 | table, 176 | ), 177 | ] 178 | ).run() 179 | 180 | controller.await_termination() 181 | 182 | benchmark(inner) 183 | 184 | 185 | def test_aggregate_df_86420_events_basic(benchmark): 186 | df = pd.read_csv("bench/early_sense.csv", parse_dates=["timestamp"]) 187 | 188 | def inner(): 189 | driver = Driver() 190 | table = Table("test", driver) 191 | 192 | controller = build_flow( 193 | [ 194 | DataframeSource(df, key_field="patient_id", time_field="timestamp"), 195 | AggregateByKey( 196 | [ 197 | FieldAggregator( 198 | "hr", 199 | "hr", 200 | ["sum", "count"], 201 | SlidingWindows(["1h", "2h"], "10m"), 202 | ), 203 | FieldAggregator( 204 | "rr", 205 | "rr", 206 | ["sum", "count"], 207 | SlidingWindows(["1h", "2h"], "10m"), 208 | ), 209 | FieldAggregator( 210 | "spo2", 211 | "spo2", 212 | ["sum", "count"], 213 | SlidingWindows(["1h", "2h"], "10m"), 214 | ), 215 | ], 216 | table, 217 | ), 218 | ] 219 | ).run() 220 | 221 | controller.await_termination() 222 | 223 | benchmark(inner) 224 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | build~=0.1.0 2 | black~=24.3 3 | flake8~=5.0 4 | flake8-bugbear~=22.9 5 | isort~=5.7 6 | pytest~=6.2.5 7 | coverage~=7.5 8 | pytest-benchmark~=3.2.3 9 | # Note: you might need a Lua installation to install lupa on Mac. 10 | # brew install lua 11 | lupa~=1.13 12 | fakeredis~=1.9 13 | redis~=4.3 14 | # in sqlalchemy>=2.0 there is breaking changes (such as in Table class autoload argument is removed) 15 | sqlalchemy~=1.4 16 | s3fs~=2023.9.2 17 | adlfs~=2023.9.0 18 | taos-ws-py~=0.3.2 19 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | Storey api 2 | ================ 3 | 4 | 5 | Targets: 6 | ******** 7 | 8 | .. automodule:: storey.targets 9 | :members: 10 | :show-inheritance: 11 | 12 | Sources: 13 | ******** 14 | .. automodule:: storey.sources 15 | :members: 16 | :show-inheritance: 17 | 18 | Transformations: 19 | **************** 20 | 21 | .. automodule:: storey.transformations 22 | :members: 23 | :show-inheritance: 24 | :imported-members: 25 | 26 | 27 | Miscellaneous: 28 | ************** 29 | 30 | .. automodule:: storey.drivers 31 | :members: 32 | :show-inheritance: 33 | 34 | .. 
automodule:: storey.dtypes 35 | :members: 36 | :show-inheritance: 37 | 38 | .. automodule:: storey.table 39 | :members: 40 | :show-inheritance: 41 | 42 | .. automodule:: storey.aggregations 43 | :members: 44 | :show-inheritance: 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("..")) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = "storey" 22 | copyright = "2023, Iguazio" 23 | author = "Iguazio" 24 | 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = [ 32 | "sphinx.ext.autodoc", 33 | ] 34 | 35 | # Add any paths that contain templates here, relative to this directory. 36 | templates_path = ["_templates"] 37 | 38 | # List of patterns, relative to source directory, that match files and 39 | # directories to ignore when looking for source files. 40 | # This pattern also affects html_static_path and html_extra_path. 41 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 42 | 43 | 44 | # -- Options for HTML output ------------------------------------------------- 45 | 46 | # The theme to use for HTML and HTML Help pages. See the documentation for 47 | # a list of builtin themes. 48 | # 49 | html_theme = "alabaster" 50 | try: 51 | import sphinx_rtd_theme # noqa 52 | 53 | html_theme = "sphinx_book_theme" 54 | except ImportError: 55 | pass 56 | 57 | # Add any paths that contain custom static files (such as style sheets) here, 58 | # relative to this directory. They are copied after the builtin static files, 59 | # so a file named "default.css" will overwrite the builtin "default.css". 60 | html_static_path = [] 61 | -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | Storey examples 2 | ============================ 3 | 4 | Example showing aggregation by key 5 | 6 | .. 
code-block:: python 7 | 8 | table = Table(setup_teardown_test, V3ioDriver(), partitioned_by_key=partitioned_by_key) 9 | 10 | controller = build_flow([ 11 | SyncEmitSource(), 12 | AggregateByKey([FieldAggregator("number_of_stuff", "col1", ["sum", "avg", "min", "max", "sqr"], 13 | SlidingWindows(['1h', '2h', '24h'], '10m'))], 14 | table), 15 | NoSqlTarget(table), 16 | Reduce([], lambda acc, x: append_return(acc, x)), 17 | ]).run() 18 | 19 | items_in_ingest_batch = 10 20 | for i in range(items_in_ingest_batch): 21 | data = {'col1': i} 22 | controller.emit(data, 'tal', test_base_time + timedelta(minutes=25 * i)) 23 | 24 | controller.terminate() 25 | result = controller.await_termination() 26 | 27 | 28 | Example showing join with V3IO table: 29 | 30 | .. code-block:: python 31 | 32 | table_path = "path_to_table" 33 | controller = build_flow([ 34 | SyncEmitSource(), 35 | Map(lambda x: x + 1), 36 | Filter(lambda x: x < 8), 37 | JoinWithV3IOTable(V3ioDriver(), lambda x: x, lambda x, y: y['age'], table_path), 38 | Reduce(0, lambda x, y: x + y) 39 | ]).run() 40 | for i in range(10): 41 | controller.emit(i) 42 | 43 | controller.terminate() 44 | result = controller.await_termination() 45 | 46 | 47 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Storey Package Documentation 2 | ================================== 3 | Introduction 4 | ================ 5 | .. _asyncio: https://docs.python.org/3/library/asyncio.html 6 | 7 | Storey is a streaming library for real-time event processing and feature extraction. It's based on asyncio_ and offers both synchronous and 8 | asynchronous APIs. Storey flows are graphs of steps that perform computational and IO tasks. A basic synchronous flow is created and run as 9 | follows: 10 | 11 | .. code-block:: python 12 | 13 | from storey import build_flow, SyncEmitSource, CSVTarget 14 | 15 | controller = build_flow([ 16 | SyncEmitSource(), 17 | CSVTarget('myfile.csv', columns=['n', 'n*10'], header=True) 18 | ]).run() 19 | 20 | for i in range(10): 21 | controller.emit({'n': i, 'n*10': 10 * i}) 22 | 23 | controller.terminate() 24 | controller.await_termination() 25 | 26 | This example constructs a flow that writes events to a CSV file, runs it, then pushes events into that flow. 27 | 28 | The same example can also be run from within an async context: 29 | 30 | .. code-block:: python 31 | 32 | from storey import build_flow, AsyncEmitSource, CSVTarget 33 | 34 | controller = build_flow([ 35 | AsyncEmitSource(), 36 | CSVTarget('myfile.csv', columns=['n', 'n*10'], header=True) 37 | ]).run() 38 | 39 | for i in range(10): 40 | await controller.emit({'n': i, 'n*10': 10 * i}) 41 | 42 | await controller.terminate() 43 | await controller.await_termination() 44 | 45 | The following more interesting example takes a dataframe, aggregates its data using a sliding window, and persists the result to a 46 | V3IO key-value store. 47 | 48 | ..
code-block:: python 49 | 50 | from storey import build_flow, DataframeSource, AggregateByKey, FieldAggregator, SlidingWindows, NoSqlTarget, V3ioDriver, Table 51 | 52 | table = Table(f'users/me/destination', V3ioDriver()) 53 | 54 | controller = build_flow([ 55 | DataframeSource(df, key_column='user_id', time_column='timestamp'), 56 | AggregateByKey([ 57 | FieldAggregator("feature1", "field1", ["avg", "min", "max"], 58 | SlidingWindows(['1h', '2h'], '10m')), 59 | FieldAggregator("feature2", "field2", ["avg", "min", "max"], 60 | SlidingWindows(['1h', '2h'], '10m')), 61 | FieldAggregator("feature3", "field3", ["avg", "min", "max"], 62 | SlidingWindows(['1h', '2h'], '10m')) 63 | ], 64 | table), 65 | NoSqlTarget(table, columns=['feature1', 'feature2', 'feature3']), 66 | ]).run() 67 | 68 | controller.await_termination() 69 | 70 | 71 | .. toctree:: 72 | :maxdepth: 2 73 | :caption: Contents: 74 | 75 | examples 76 | api 77 | 78 | Indices and tables 79 | ================== 80 | 81 | * :ref:`genindex` 82 | * :ref:`modindex` 83 | * :ref:`search` 84 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | numpydoc~=0.9 2 | recommonmark~=0.6 3 | sphinx~=5.3 4 | sphinx_rtd_theme~=1.2 5 | sphinx-copybutton~=0.3 6 | sphinx-togglebutton~=0.2.2 7 | Jinja2~=3.1 8 | myst-nb~=0.10 9 | sphinx-book-theme~=1.0.1 10 | markdown-it-py~=2.2 11 | mdit-py-plugins~=0.3 12 | ipython<9 13 | -------------------------------------------------------------------------------- /integration/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | -------------------------------------------------------------------------------- /integration/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import asyncio 16 | import os 17 | from datetime import datetime 18 | 19 | import fakeredis 20 | import pytest 21 | 22 | from integration.integration_test_utils import ( 23 | V3ioHeaders, 24 | _generate_table_name, 25 | create_temp_kv, 26 | create_temp_redis_kv, 27 | drivers_list, 28 | get_redis_client, 29 | recursive_delete, 30 | remove_redis_table, 31 | remove_sql_tables, 32 | ) 33 | from storey import V3ioDriver 34 | from storey.redis_driver import RedisDriver 35 | from storey.sql_driver import SQLDriver 36 | 37 | SQLITE_DB = "sqlite:///test.db" 38 | 39 | 40 | @pytest.fixture(params=drivers_list) 41 | def setup_teardown_test(request): 42 | # Setup 43 | if request.param == "SQLDriver" and request.fspath.basename != "test_flow_integration.py": 44 | pytest.skip("SQLDriver test only in test_flow_integration") 45 | test_context = ContextForTests(request.param, table_name=_generate_table_name()) 46 | 47 | # Test runs 48 | yield test_context 49 | 50 | # Teardown 51 | if test_context.driver_name == "V3ioDriver": 52 | asyncio.run(recursive_delete(test_context.table_name, V3ioHeaders())) 53 | elif test_context.driver_name == "RedisDriver": 54 | remove_redis_table(test_context.table_name) 55 | elif test_context.driver_name == "SQLDriver": 56 | remove_sql_tables() 57 | else: 58 | raise ValueError(f'Unsupported driver name "{test_context.driver_name}"') 59 | 60 | 61 | @pytest.fixture(params=drivers_list) 62 | def setup_kv_teardown_test(request): 63 | # Setup 64 | test_context = ContextForTests(request.param, table_name=_generate_table_name()) 65 | 66 | if test_context.driver_name == "V3ioDriver": 67 | asyncio.run(create_temp_kv(test_context.table_name)) 68 | elif test_context.driver_name == "RedisDriver": 69 | create_temp_redis_kv(test_context) 70 | elif test_context.driver_name == "SQLDriver": 71 | pytest.skip(msg="test not relevant for SQLDriver") 72 | else: 73 | raise ValueError(f'Unsupported driver name "{test_context.driver_name}"') 74 | 75 | # Test runs 76 | yield test_context 77 | 78 | # Teardown 79 | if test_context.driver_name == "V3ioDriver": 80 | asyncio.run(recursive_delete(test_context.table_name, V3ioHeaders())) 81 | elif test_context.driver_name == "RedisDriver": 82 | remove_redis_table(test_context.table_name) 83 | else: 84 | raise ValueError(f'Unsupported driver name "{test_context.driver_name}"') 85 | 86 | 87 | @pytest.fixture() 88 | def assign_stream_teardown_test(): 89 | # Setup 90 | stream_path = _generate_table_name("bigdata/storey_ci/stream_test") 91 | 92 | # Test runs 93 | yield stream_path 94 | 95 | # Teardown 96 | asyncio.run(recursive_delete(stream_path, V3ioHeaders())) 97 | 98 | 99 | # Can't call it TestContext because then pytest tries to run it as if 
it were a test suite 100 | class ContextForTests: 101 | def __init__(self, driver_name: str, table_name: str): 102 | self._driver_name = driver_name 103 | self._table_name = table_name 104 | # sqlite cant save time zone 105 | self.test_base_time = ( 106 | datetime.fromisoformat("2020-07-21T21:40:00+00:00") 107 | if driver_name != "SQLDriver" 108 | else datetime.fromisoformat("2020-07-21T21:40:00") 109 | ) 110 | 111 | self._redis_fake_server = None 112 | if driver_name == "RedisDriver": 113 | redis_url = os.environ.get("MLRUN_REDIS_URL") 114 | if not redis_url: 115 | # if we are using fakeredis, create fake-server to support tests involving multiple clients 116 | self._redis_fake_server = fakeredis.FakeServer() 117 | if driver_name == "SQLDriver": 118 | self._sql_db_path = SQLITE_DB 119 | self._sql_table_name = table_name.split("/")[-2] 120 | self._table_name = f"{SQLITE_DB}/{self._sql_table_name}" 121 | 122 | @property 123 | def table_name(self): 124 | return self._table_name 125 | 126 | @property 127 | def redis_fake_server(self): 128 | return self._redis_fake_server 129 | 130 | @property 131 | def driver_name(self): 132 | return self._driver_name 133 | 134 | @property 135 | def sql_db_path(self): 136 | return self._sql_db_path 137 | 138 | class AggregationlessV3ioDriver(V3ioDriver): 139 | def supports_aggregations(self): 140 | return False 141 | 142 | class AggregationlessRedisDriver(RedisDriver): 143 | def supports_aggregations(self): 144 | return False 145 | 146 | def driver(self, *args, primary_key=None, is_aggregationless_driver=False, time_fields=None, **kwargs): 147 | if self.driver_name == "V3ioDriver": 148 | v3io_driver_class = ContextForTests.AggregationlessV3ioDriver if is_aggregationless_driver else V3ioDriver 149 | return v3io_driver_class(*args, **kwargs) 150 | elif self.driver_name == "RedisDriver": 151 | redis_driver_class = ( 152 | ContextForTests.AggregationlessRedisDriver if is_aggregationless_driver else RedisDriver 153 | ) 154 | return redis_driver_class( 155 | *args, 156 | redis_client=get_redis_client(self.redis_fake_server), 157 | key_prefix="storey-test:", 158 | **kwargs, 159 | ) 160 | elif self.driver_name == "SQLDriver": 161 | if is_aggregationless_driver: 162 | sql_driver_class = SQLDriver 163 | return sql_driver_class(db_path=SQLITE_DB, primary_key=primary_key, time_fields=time_fields) 164 | else: 165 | pytest.skip("SQLDriver does not support aggregation") 166 | else: 167 | driver_name = self.driver_name 168 | raise ValueError(f'Unsupported driver name "{driver_name}"') 169 | -------------------------------------------------------------------------------- /integration/integration_test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | import base64 16 | import json 17 | import os 18 | import random 19 | import re 20 | import string 21 | from datetime import datetime, timedelta 22 | 23 | import aiohttp 24 | import fakeredis 25 | import pandas as pd 26 | import redis as r 27 | 28 | import integration.conftest 29 | from storey.drivers import NeedsV3ioAccess 30 | from storey.flow import V3ioError 31 | from storey.redis_driver import RedisDriver 32 | 33 | _non_int_char_pattern = re.compile(r"[^-0-9]") 34 | 35 | 36 | class V3ioHeaders(NeedsV3ioAccess): 37 | def __init__(self, **kwargs): 38 | super().__init__(**kwargs) 39 | self._get_item_headers = { 40 | "X-v3io-function": "GetItem", 41 | "X-v3io-session-key": self._access_key, 42 | } 43 | 44 | self._get_items_headers = { 45 | "X-v3io-function": "GetItems", 46 | "X-v3io-session-key": self._access_key, 47 | } 48 | 49 | self._put_item_headers = { 50 | "X-v3io-function": "PutItem", 51 | "X-v3io-session-key": self._access_key, 52 | } 53 | 54 | self._update_item_headers = { 55 | "X-v3io-function": "UpdateItem", 56 | "X-v3io-session-key": self._access_key, 57 | } 58 | 59 | self._put_records_headers = { 60 | "X-v3io-function": "PutRecords", 61 | "X-v3io-session-key": self._access_key, 62 | } 63 | 64 | self._create_stream_headers = { 65 | "X-v3io-function": "CreateStream", 66 | "X-v3io-session-key": self._access_key, 67 | } 68 | 69 | self._describe_stream_headers = { 70 | "X-v3io-function": "DescribeStream", 71 | "X-v3io-session-key": self._access_key, 72 | } 73 | 74 | self._seek_headers = { 75 | "X-v3io-function": "Seek", 76 | "X-v3io-session-key": self._access_key, 77 | } 78 | 79 | self._get_records_headers = { 80 | "X-v3io-function": "GetRecords", 81 | "X-v3io-session-key": self._access_key, 82 | } 83 | 84 | self._get_put_file_headers = {"X-v3io-session-key": self._access_key} 85 | 86 | 87 | def append_return(lst, x): 88 | lst.append(x) 89 | return lst 90 | 91 | 92 | def _generate_table_name(prefix="bigdata/storey_ci/Aggr_test"): 93 | random_table = "".join([random.choice(string.ascii_letters) for i in range(10)]) 94 | return f"{prefix}/{random_table}/" 95 | 96 | 97 | def get_redis_client(redis_fake_server=None): 98 | redis_url = os.environ.get("MLRUN_REDIS_URL") 99 | if redis_url: 100 | try: 101 | res = r.cluster.RedisCluster.from_url(redis_url) 102 | return res 103 | except r.cluster.RedisClusterException: 104 | return r.Redis.from_url(redis_url) 105 | else: 106 | return fakeredis.FakeRedis(decode_responses=True, server=redis_fake_server) 107 | 108 | 109 | def remove_redis_table(table_name): 110 | redis_client = get_redis_client() 111 | count = 0 112 | for key in redis_client.scan_iter(f"*storey-test:{table_name}*"): 113 | redis_client.delete(key) 114 | count += 1 115 | 116 | 117 | def remove_sql_tables(): 118 | import sqlalchemy as db 119 | 120 | engine = db.create_engine(integration.conftest.SQLITE_DB) 121 | with engine.connect(): 122 | metadata = db.MetaData() 123 | metadata.reflect(bind=engine) 124 | # drop them, if they exist 125 | metadata.drop_all(bind=engine, checkfirst=True) 126 | engine.dispose() 127 | 128 | 129 | drivers_list = ["V3ioDriver", "RedisDriver", "SQLDriver"] 130 | 131 | 132 | async def create_stream(stream_path): 133 | v3io_access = V3ioHeaders() 134 | connector = aiohttp.TCPConnector() 135 | client_session = aiohttp.ClientSession(connector=connector) 136 | request_body = json.dumps({"ShardCount": 2, "RetentionPeriodHours": 1}) 137 | response = await client_session.request( 138 | "POST", 139 | f"{v3io_access._webapi_url}/{stream_path}/", 140 | 
headers=v3io_access._create_stream_headers, 141 | data=request_body, 142 | ssl=False, 143 | ) 144 | assert response.status == 204, f"Bad response {await response.text()} to request {request_body}" 145 | 146 | 147 | def create_temp_redis_kv(setup_teardown_test): 148 | # Create the data we'll join with in Redis. 149 | table_path = setup_teardown_test.table_name 150 | redis_fake_server = setup_teardown_test.redis_fake_server 151 | redis_client = get_redis_client(redis_fake_server=redis_fake_server) 152 | 153 | for i in range(1, 10): 154 | key = RedisDriver.make_key("storey-test:", table_path, i) 155 | static_key = RedisDriver._static_data_key(key) 156 | redis_client.hset(static_key, mapping={"age": f"{10 - i}", "color": f"blue{i}"}) 157 | 158 | 159 | async def create_temp_kv(table_path): 160 | connector = aiohttp.TCPConnector() 161 | v3io_access = V3ioHeaders() 162 | client_session = aiohttp.ClientSession(connector=connector) 163 | for i in range(1, 10): 164 | request_body = json.dumps({"Item": {"age": {"N": f"{10 - i}"}, "color": {"S": f"blue{i}"}}}) 165 | response = await client_session.request( 166 | "PUT", 167 | f"{v3io_access._webapi_url}/{table_path}/{i}", 168 | headers=v3io_access._put_item_headers, 169 | data=request_body, 170 | ssl=False, 171 | ) 172 | assert response.status == 200, f"Bad response {await response.text()} to request {request_body}" 173 | 174 | 175 | def _v3io_parse_get_items_response(response_body): 176 | response_object = json.loads(response_body) 177 | i = 0 178 | for item in response_object["Items"]: 179 | parsed_item = {} 180 | for name, type_to_value in item.items(): 181 | for typ, value in type_to_value.items(): 182 | val = _convert_nginx_to_python_type(typ, value) 183 | parsed_item[name] = val 184 | response_object["Items"][i] = parsed_item 185 | i = i + 1 186 | return response_object 187 | 188 | 189 | # Deletes the entire table 190 | async def recursive_delete(path, v3io_access): 191 | connector = aiohttp.TCPConnector() 192 | client_session = aiohttp.ClientSession(connector=connector) 193 | 194 | try: 195 | has_more = True 196 | next_marker = "" 197 | while has_more: 198 | get_items_body = {"AttributesToGet": "__name", "Marker": next_marker} 199 | response = await client_session.put( 200 | f"{v3io_access._webapi_url}/{path}/", 201 | headers=v3io_access._get_items_headers, 202 | data=json.dumps(get_items_body), 203 | ssl=False, 204 | ) 205 | body = await response.text() 206 | if response.status == 200: 207 | res = _v3io_parse_get_items_response(body) 208 | for item in res["Items"]: 209 | await _delete_item( 210 | f'{v3io_access._webapi_url}/{path}/{item["__name"]}', 211 | v3io_access, 212 | client_session, 213 | ) 214 | 215 | has_more = "NextMarker" in res 216 | if has_more: 217 | next_marker = res["NextMarker"] 218 | elif response.status == 404: 219 | break 220 | else: 221 | raise V3ioError(f"Failed to delete table {path}. Response status code was {response.status}: {body}") 222 | 223 | await _delete_item(f"{v3io_access._webapi_url}/{path}/", v3io_access, client_session) 224 | finally: 225 | await client_session.close() 226 | 227 | 228 | async def _delete_item(path, v3io_access, client_session): 229 | response = await client_session.delete(path, headers=v3io_access._get_put_file_headers, ssl=False) 230 | if response.status >= 300 and response.status != 404 and response.status != 409: 231 | body = await response.text() 232 | raise V3ioError(f"Failed to delete item at {path}. 
Response status code was {response.status}: {body}") 233 | 234 | 235 | def _v3io_parse_get_item_response(response_body): 236 | response_object = json.loads(response_body)["Item"] 237 | for name, type_to_value in response_object.items(): 238 | val = None 239 | for typ, value in type_to_value.items(): 240 | val = _convert_nginx_to_python_type(typ, value) 241 | response_object[name] = val 242 | return response_object 243 | 244 | 245 | def _convert_nginx_to_python_type(typ, value): 246 | if typ == "S" or typ == "BOOL": 247 | return value 248 | elif typ == "N": 249 | if _non_int_char_pattern.search(value): 250 | return float(value) 251 | else: 252 | return int(value) 253 | elif typ == "B": 254 | return base64.b64decode(value) 255 | elif typ == "TS": 256 | splits = value.split(":", 1) 257 | secs = int(splits[0]) 258 | nanosecs = int(splits[1]) 259 | return datetime.utcfromtimestamp(secs + nanosecs / 1000000000) 260 | else: 261 | raise V3ioError(f"Type {typ} in get item response is not supported") 262 | 263 | 264 | def create_sql_table(schema, table_name, sql_db_path, key): 265 | import sqlalchemy as db 266 | 267 | engine = db.create_engine(sql_db_path) 268 | with engine.connect(): 269 | metadata = db.MetaData() 270 | columns = [] 271 | for col, col_type in schema.items(): 272 | if col_type == int: 273 | col_type = db.Integer 274 | elif col_type == str: 275 | col_type = db.String 276 | elif col_type == timedelta or col_type == pd.Timedelta: 277 | col_type = db.Interval 278 | elif col_type == datetime or col_type == pd.Timestamp: 279 | col_type = db.DATETIME 280 | elif col_type == bool: 281 | col_type = db.Boolean 282 | elif col_type == float: 283 | col_type = db.Float 284 | else: 285 | raise TypeError(f"Column '{col}' has unsupported type '{col_type}'") 286 | columns.append(db.Column(col, col_type, primary_key=(col in key))) 287 | 288 | db.Table(table_name, metadata, *columns) 289 | metadata.create_all(engine) 290 | -------------------------------------------------------------------------------- /integration/test_azure_filesystem_integration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | import asyncio 16 | import os 17 | import uuid 18 | 19 | import pandas as pd 20 | import pytest 21 | 22 | from storey import ( 23 | AsyncEmitSource, 24 | CSVSource, 25 | CSVTarget, 26 | FlatMap, 27 | Map, 28 | ParquetTarget, 29 | Reduce, 30 | SyncEmitSource, 31 | build_flow, 32 | ) 33 | 34 | from .integration_test_utils import _generate_table_name 35 | 36 | has_azure_credentials = ( 37 | os.getenv("AZURE_ACCOUNT_NAME") and os.getenv("AZURE_ACCOUNT_KEY") and os.getenv("AZURE_BLOB_STORE") 38 | ) 39 | if has_azure_credentials: 40 | storage_options = { 41 | "account_name": os.getenv("AZURE_ACCOUNT_NAME"), 42 | "account_key": os.getenv("AZURE_ACCOUNT_KEY"), 43 | } 44 | from adlfs import AzureBlobFileSystem 45 | 46 | 47 | @pytest.fixture() 48 | def azure_create_csv(): 49 | # Setup 50 | azure_blob = os.getenv("AZURE_BLOB_STORE") 51 | file_path = _generate_table_name(f"{azure_blob}/az_storey") 52 | 53 | _write_test_csv(file_path) 54 | 55 | # Test runs 56 | yield file_path 57 | 58 | # Teardown 59 | _delete_file(file_path) 60 | 61 | 62 | @pytest.fixture() 63 | def azure_teardown_file(): 64 | # Setup 65 | azure_blob = os.getenv("AZURE_BLOB_STORE") 66 | file_path = _generate_table_name(f"{azure_blob}/az_storey") 67 | 68 | # Test runs 69 | yield file_path 70 | 71 | # Teardown 72 | _delete_file(file_path) 73 | 74 | 75 | @pytest.fixture() 76 | def azure_setup_teardown_test(): 77 | # Setup 78 | table_name = _generate_table_name(f'{os.getenv("AZURE_BLOB_STORE")}/test') 79 | 80 | # Test runs 81 | yield table_name 82 | 83 | # Teardown 84 | azure_recursive_delete(table_name) 85 | 86 | 87 | def _write_test_csv(file_path): 88 | az_fs = AzureBlobFileSystem(**storage_options) 89 | data = "n,n*10\n0,0\n1,10\n2,20\n3,30\n4,40\n5,50\n6,60\n7,70\n8,80\n9,90\n" 90 | with az_fs.open(file_path, "w") as f: 91 | f.write(data) 92 | 93 | 94 | def _delete_file(path): 95 | az_fs = AzureBlobFileSystem(**storage_options) 96 | az_fs.delete(path) 97 | 98 | 99 | def azure_recursive_delete(path): 100 | az_fs = AzureBlobFileSystem(**storage_options) 101 | az_fs.rm(path, True) 102 | 103 | 104 | @pytest.mark.skipif(not has_azure_credentials, reason="No azure credentials found") 105 | def test_csv_reader_from_azure(azure_create_csv): 106 | controller = build_flow( 107 | [ 108 | CSVSource( 109 | f"az:///{azure_create_csv}", 110 | storage_options=storage_options, 111 | ), 112 | FlatMap(lambda x: x), 113 | Map(lambda x: int(x)), 114 | Reduce(0, lambda acc, x: acc + x), 115 | ] 116 | ).run() 117 | 118 | termination_result = controller.await_termination() 119 | assert termination_result == 495 120 | 121 | 122 | @pytest.mark.skipif(not has_azure_credentials, reason="No azure credentials found") 123 | def test_csv_reader_from_azure_error_on_file_not_found(): 124 | with pytest.raises(FileNotFoundError): 125 | controller = build_flow( 126 | [ 127 | CSVSource( 128 | f'az:///{os.getenv("AZURE_BLOB_STORE")}/idontexist.csv', 129 | storage_options=storage_options, 130 | ), 131 | ] 132 | ).run() 133 | controller.await_termination() 134 | 135 | 136 | async def async_test_write_csv_to_azure(azure_teardown_csv): 137 | controller = build_flow( 138 | [ 139 | AsyncEmitSource(), 140 | CSVTarget( 141 | f"az:///{azure_teardown_csv}", 142 | columns=["n", "n*10"], 143 | header=True, 144 | storage_options=storage_options, 145 | ), 146 | ] 147 | ).run() 148 | 149 | for i in range(10): 150 | await controller.emit([i, 10 * i]) 151 | 152 | await controller.terminate() 153 | await controller.await_termination() 154 | 155 | actual = 
AzureBlobFileSystem(**storage_options).open(azure_teardown_csv).read() 156 | 157 | expected = "n,n*10\n0,0\n1,10\n2,20\n3,30\n4,40\n5,50\n6,60\n7,70\n8,80\n9,90\n" 158 | assert actual.decode("utf-8") == expected 159 | 160 | 161 | @pytest.mark.skipif(not has_azure_credentials, reason="No azure credentials found") 162 | def test_write_csv_to_azure(azure_teardown_file): 163 | asyncio.run(async_test_write_csv_to_azure(azure_teardown_file)) 164 | 165 | 166 | @pytest.mark.skipif(not has_azure_credentials, reason="No azure credentials found") 167 | def test_write_csv_with_dict_to_azure(azure_teardown_file): 168 | file_path = f"az:///{azure_teardown_file}" 169 | controller = build_flow( 170 | [ 171 | SyncEmitSource(), 172 | CSVTarget( 173 | file_path, 174 | columns=["n", "n*10"], 175 | header=True, 176 | storage_options=storage_options, 177 | ), 178 | ] 179 | ).run() 180 | 181 | for i in range(10): 182 | controller.emit({"n": i, "n*10": 10 * i}) 183 | 184 | controller.terminate() 185 | controller.await_termination() 186 | 187 | actual = AzureBlobFileSystem(**storage_options).open(azure_teardown_file).read() 188 | expected = "n,n*10\n0,0\n1,10\n2,20\n3,30\n4,40\n5,50\n6,60\n7,70\n8,80\n9,90\n" 189 | assert actual.decode("utf-8") == expected 190 | 191 | 192 | @pytest.mark.skipif(not has_azure_credentials, reason="No azure credentials found") 193 | def test_write_csv_infer_columns_without_header_to_azure(azure_teardown_file): 194 | file_path = f"az:///{azure_teardown_file}" 195 | controller = build_flow([SyncEmitSource(), CSVTarget(file_path, storage_options=storage_options)]).run() 196 | 197 | for i in range(10): 198 | controller.emit({"n": i, "n*10": 10 * i}) 199 | 200 | controller.terminate() 201 | controller.await_termination() 202 | 203 | actual = AzureBlobFileSystem(**storage_options).open(azure_teardown_file).read() 204 | expected = "0,0\n1,10\n2,20\n3,30\n4,40\n5,50\n6,60\n7,70\n8,80\n9,90\n" 205 | assert actual.decode("utf-8") == expected 206 | 207 | 208 | @pytest.mark.skipif(not has_azure_credentials, reason="No azure credentials found") 209 | def test_write_csv_from_lists_with_metadata_and_column_pruning_to_azure( 210 | azure_teardown_file, 211 | ): 212 | file_path = f"az:///{azure_teardown_file}" 213 | controller = build_flow( 214 | [ 215 | SyncEmitSource(), 216 | CSVTarget( 217 | file_path, 218 | columns=["event_key=$key", "n*10"], 219 | header=True, 220 | storage_options=storage_options, 221 | ), 222 | ] 223 | ).run() 224 | 225 | for i in range(10): 226 | controller.emit({"n": i, "n*10": 10 * i}, key=f"key{i}") 227 | 228 | controller.terminate() 229 | controller.await_termination() 230 | 231 | actual = AzureBlobFileSystem(**storage_options).open(azure_teardown_file).read() 232 | expected = ( 233 | "event_key,n*10\n" 234 | "key0,0\n" 235 | "key1,10\n" 236 | "key2,20\n" 237 | "key3,30\n" 238 | "key4,40\n" 239 | "key5,50\n" 240 | "key6,60\n" 241 | "key7,70\n" 242 | "key8,80\n" 243 | "key9,90\n" 244 | ) 245 | assert actual.decode("utf-8") == expected 246 | 247 | 248 | @pytest.mark.skipif(not has_azure_credentials, reason="No azure credentials found") 249 | def test_write_to_parquet_to_azure(azure_setup_teardown_test): 250 | out_dir = f"az:///{azure_setup_teardown_test}" 251 | columns = ["my_int", "my_string"] 252 | controller = build_flow( 253 | [ 254 | SyncEmitSource(), 255 | ParquetTarget( 256 | out_dir, 257 | partition_cols="my_int", 258 | columns=columns, 259 | max_events=1, 260 | storage_options=storage_options, 261 | ), 262 | ] 263 | ).run() 264 | 265 | expected = [] 266 | for i in 
range(10): 267 | controller.emit([i, f"this is {i}"]) 268 | expected.append([i, f"this is {i}"]) 269 | expected = pd.DataFrame(expected, columns=columns) 270 | controller.terminate() 271 | controller.await_termination() 272 | 273 | read_back_df = pd.read_parquet(out_dir, columns=columns, storage_options=storage_options) 274 | read_back_df["my_int"] = read_back_df["my_int"].astype("int64") 275 | pd.testing.assert_frame_equal(read_back_df, expected) 276 | 277 | 278 | @pytest.mark.skipif(not has_azure_credentials, reason="No azure credentials found") 279 | def test_write_to_parquet_to_azure_single_file_on_termination( 280 | azure_setup_teardown_test, 281 | ): 282 | out_file = f"az:///{azure_setup_teardown_test}/out.parquet" 283 | columns = ["my_int", "my_string"] 284 | controller = build_flow( 285 | [ 286 | SyncEmitSource(), 287 | ParquetTarget(out_file, columns=columns, storage_options=storage_options), 288 | ] 289 | ).run() 290 | 291 | expected = [] 292 | for i in range(10): 293 | controller.emit([i, f"this is {i}"]) 294 | expected.append([i, f"this is {i}"]) 295 | expected = pd.DataFrame(expected, columns=columns) 296 | controller.terminate() 297 | controller.await_termination() 298 | 299 | read_back_df = pd.read_parquet(out_file, columns=columns, storage_options=storage_options) 300 | read_back_df["my_int"] = read_back_df["my_int"].astype("int64") 301 | pd.testing.assert_frame_equal(read_back_df, expected) 302 | 303 | 304 | @pytest.mark.skipif(not has_azure_credentials, reason="No azure credentials found") 305 | def test_write_to_parquet_to_azure_with_indices(azure_setup_teardown_test): 306 | out_file = f"az:///{azure_setup_teardown_test}/test_write_to_parquet_with_indices{uuid.uuid4().hex}.parquet" 307 | controller = build_flow( 308 | [ 309 | SyncEmitSource(), 310 | ParquetTarget( 311 | out_file, 312 | index_cols="event_key=$key", 313 | columns=["my_int", "my_string"], 314 | storage_options=storage_options, 315 | ), 316 | ] 317 | ).run() 318 | 319 | expected = [] 320 | for i in range(10): 321 | controller.emit([i, f"this is {i}"], key=f"key{i}") 322 | expected.append([f"key{i}", i, f"this is {i}"]) 323 | columns = ["event_key", "my_int", "my_string"] 324 | expected = pd.DataFrame(expected, columns=columns) 325 | expected.set_index(["event_key"], inplace=True) 326 | controller.terminate() 327 | controller.await_termination() 328 | 329 | read_back_df = pd.read_parquet(out_file, columns=columns, storage_options=storage_options) 330 | read_back_df["my_int"] = read_back_df["my_int"].astype("int64") 331 | pd.testing.assert_frame_equal(read_back_df, expected) 332 | -------------------------------------------------------------------------------- /integration/test_kafka_integration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | import asyncio 16 | import datetime 17 | import json 18 | import os 19 | from time import sleep 20 | 21 | import pytest 22 | 23 | from storey import AsyncEmitSource, Event, Reduce, SyncEmitSource, build_flow 24 | from storey.targets import KafkaTarget 25 | 26 | kafka_brokers = os.getenv("KAFKA_BROKERS") 27 | topic = "test_kafka_integration" 28 | 29 | if kafka_brokers: 30 | import kafka 31 | 32 | 33 | def append_return(lst, x): 34 | lst.append(x) 35 | return lst 36 | 37 | 38 | @pytest.fixture() 39 | def kafka_topic_setup_teardown(): 40 | # Setup 41 | kafka_admin_client = kafka.KafkaAdminClient(bootstrap_servers=kafka_brokers) 42 | kafka_consumer = kafka.KafkaConsumer(topic, bootstrap_servers=kafka_brokers, auto_offset_reset="earliest") 43 | try: 44 | kafka_admin_client.delete_topics([topic]) 45 | sleep(1) 46 | except kafka.errors.UnknownTopicOrPartitionError: 47 | pass 48 | kafka_admin_client.create_topics([kafka.admin.NewTopic(topic, 1, 1)]) 49 | 50 | # Test runs 51 | yield kafka_consumer 52 | 53 | # Teardown 54 | kafka_admin_client.delete_topics([topic]) 55 | kafka_admin_client.close() 56 | kafka_consumer.close() 57 | 58 | 59 | @pytest.mark.skipif( 60 | not kafka_brokers, 61 | reason="KAFKA_BROKERS must be defined to run kafka tests", 62 | ) 63 | def test_kafka_target(kafka_topic_setup_teardown): 64 | kafka_consumer = kafka_topic_setup_teardown 65 | 66 | controller = build_flow( 67 | [ 68 | SyncEmitSource(), 69 | KafkaTarget(kafka_brokers, topic, sharding_func=0, full_event=False), 70 | ] 71 | ).run() 72 | events = [] 73 | for i in range(100): 74 | key = None 75 | if i > 0: 76 | key = f"key{i}" 77 | event = Event({"hello": i, "time": datetime.datetime(2023, 12, 26)}, key) 78 | events.append(event) 79 | controller.emit(event) 80 | 81 | controller.terminate() 82 | controller.await_termination() 83 | 84 | kafka_consumer.subscribe([topic]) 85 | for event in events: 86 | record = next(kafka_consumer) 87 | if event.key is None: 88 | if event.key is None: 89 | assert record.key is None 90 | else: 91 | assert record.key.decode("UTF-8") == event.key 92 | assert record.value.decode("UTF-8") == json.dumps(event.body, default=str) 93 | 94 | 95 | async def async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key): 96 | kafka_consumer = kafka_topic_setup_teardown 97 | 98 | controller = build_flow( 99 | [ 100 | AsyncEmitSource(), 101 | KafkaTarget(kafka_brokers, topic, sharding_func=lambda _: partition_key, full_event=True), 102 | ] 103 | ).run() 104 | events = [] 105 | for i in range(10): 106 | event = Event(i, id=str(i)) 107 | events.append(event) 108 | await controller.emit(event) 109 | 110 | await asyncio.sleep(5) 111 | 112 | readback_records = [] 113 | kafka_consumer.subscribe([topic]) 114 | for event in events: 115 | record = next(kafka_consumer) 116 | if event.key is None: 117 | if event.key is None: 118 | if isinstance(partition_key, int): 119 | assert record.key is None 120 | else: 121 | assert record.key.decode("UTF-8") == partition_key 122 | else: 123 | assert record.key.decode("UTF-8") == event.key 124 | readback_records.append(json.loads(record.value.decode("UTF-8"))) 125 | 126 | controller = build_flow( 127 | [ 128 | AsyncEmitSource(), 129 | Reduce([], lambda acc, x: append_return(acc, x), full_event=True), 130 | ] 131 | ).run() 132 | for record in readback_records: 133 | await controller.emit(Event(record, id="some-new-id")) 134 | 135 | await controller.terminate() 136 | result = await controller.await_termination() 137 | 138 | assert len(result) 
== 10 139 | 140 | for i, record in enumerate(result): 141 | assert record.body == i 142 | assert record.id == str(i) 143 | 144 | 145 | @pytest.mark.skipif( 146 | not kafka_brokers, 147 | reason="KAFKA_BROKERS must be defined to run kafka tests", 148 | ) 149 | @pytest.mark.parametrize("partition_key", [0, "some_string"]) 150 | def test_async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key): 151 | asyncio.run(async_test_write_to_kafka_full_event_readback(kafka_topic_setup_teardown, partition_key)) 152 | -------------------------------------------------------------------------------- /integration/test_redis_specific.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import pytest 16 | 17 | from storey import JoinWithTable, NoSqlTarget, Reduce, SyncEmitSource, Table, build_flow 18 | from storey.redis_driver import RedisDriver 19 | 20 | from .integration_test_utils import append_return, get_redis_client 21 | 22 | 23 | @pytest.fixture() 24 | def redis(): 25 | return get_redis_client() 26 | 27 | 28 | def test_redis_driver_write(redis): 29 | try: 30 | table_name = "test_redis_driver_write" 31 | 32 | driver = RedisDriver(redis) 33 | controller = build_flow([SyncEmitSource(), NoSqlTarget(Table(table_name, driver))]).run() 34 | controller.emit({"col1": 0}, "key") 35 | controller.terminate() 36 | controller.await_termination() 37 | 38 | table_name = f"{table_name}/" 39 | hash_key = RedisDriver.make_key("storey:", table_name, "key") 40 | redis_key = RedisDriver._static_data_key(hash_key) 41 | 42 | cursor = 0 43 | data = {} 44 | while True: 45 | cursor, v = driver.redis.hscan(redis_key, cursor, match=f"[^{driver.INTERFNAL_FIELD_PREFIX}]*") 46 | data.update(v) 47 | if cursor == 0: 48 | break 49 | data_strings = {} 50 | for key, val in data.items(): 51 | if isinstance(key, bytes): 52 | data_strings[key.decode("utf-8")] = val.decode("utf-8") 53 | else: 54 | data_strings[key] = val 55 | 56 | assert data_strings == {"col1": "0"} 57 | finally: 58 | for key in driver.redis.scan_iter(f"*storey:{table_name}*"): 59 | driver.redis.delete(key) 60 | 61 | 62 | def test_redis_driver_join(redis): 63 | try: 64 | table_name = "test_redis_driver_join" 65 | 66 | driver = RedisDriver(redis) 67 | table = Table(table_name, driver) 68 | table_name = f"{table_name}/" 69 | 70 | # Create the data we'll join with in Redis. 
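# Note: the driver's "static data" key (built below via RedisDriver._static_data_key) is where plain, non-aggregate attributes live, so the join data is written to that hash.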
71 | hash_key = RedisDriver.make_key("storey:", table_name, "2") 72 | redis_key = RedisDriver._static_data_key(hash_key) 73 | 74 | driver.redis.hset(redis_key, mapping={"name": "1234"}) 75 | controller = build_flow( 76 | [ 77 | SyncEmitSource(), 78 | JoinWithTable(table, lambda x: x["col2"]), 79 | Reduce([], lambda acc, x: append_return(acc, x)), 80 | ] 81 | ).run() 82 | 83 | controller.emit({"col1": 1, "col2": "2"}, "key") 84 | controller.emit({"col1": 1, "col2": "2"}, "key") 85 | controller.terminate() 86 | termination_result = controller.await_termination() 87 | 88 | expected_result = [{"col1": 1, "col2": "2", "name": 1234}, {"col1": 1, "col2": "2", "name": 1234}] 89 | 90 | assert termination_result == expected_result 91 | finally: 92 | for key in driver.redis.scan_iter(f"*storey:{table_name}*"): 93 | driver.redis.delete(key) 94 | -------------------------------------------------------------------------------- /integration/test_s3_filesystem_integration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import asyncio 16 | import os 17 | import uuid 18 | 19 | import pandas as pd 20 | import pytest 21 | 22 | from storey import ( 23 | AsyncEmitSource, 24 | CSVSource, 25 | CSVTarget, 26 | FlatMap, 27 | Map, 28 | ParquetTarget, 29 | Reduce, 30 | SyncEmitSource, 31 | build_flow, 32 | ) 33 | 34 | from .integration_test_utils import _generate_table_name 35 | 36 | has_s3_credentials = os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY") and os.getenv("AWS_BUCKET") 37 | if has_s3_credentials: 38 | from s3fs import S3FileSystem 39 | 40 | 41 | @pytest.fixture() 42 | def s3_create_csv(): 43 | # Setup 44 | aws_bucket = os.getenv("AWS_BUCKET") 45 | file_path = _generate_table_name(f"{aws_bucket}/s3_storey") 46 | 47 | _write_test_csv(file_path) 48 | 49 | # Test runs 50 | yield file_path 51 | 52 | # Teardown 53 | _delete_file(file_path) 54 | 55 | 56 | @pytest.fixture() 57 | def s3_teardown_file(): 58 | # Setup 59 | aws_bucket = os.getenv("AWS_BUCKET") 60 | file_path = _generate_table_name(f"{aws_bucket}/s3_storey") 61 | 62 | # Test runs 63 | yield file_path 64 | 65 | # Teardown 66 | _delete_file(file_path) 67 | 68 | 69 | @pytest.fixture() 70 | def s3_setup_teardown_test(): 71 | # Setup 72 | table_name = _generate_table_name(f'{os.getenv("AWS_BUCKET")}/csv_test') 73 | 74 | # Test runs 75 | yield table_name 76 | 77 | # Teardown 78 | s3_recursive_delete(table_name) 79 | 80 | 81 | @pytest.fixture() 82 | def s3_teardown_file_in_bucket(): 83 | full_type_path = "" 84 | 85 | def _create_file_name(file_type): 86 | aws_bucket = os.getenv("AWS_BUCKET") 87 | nonlocal full_type_path 88 | full_type_path = f"{aws_bucket}/file_in_bucket.{file_type}" 89 | return full_type_path 90 | 91 | yield _create_file_name 92 | 93 | # Teardown 94 | _delete_file(full_type_path) 95 | 96 | 97 | def _write_test_csv(file_path): 98 | s3_fs = S3FileSystem() 99 | 
data = "n,n*10\n0,0\n1,10\n2,20\n3,30\n4,40\n5,50\n6,60\n7,70\n8,80\n9,90\n" 100 | with s3_fs.open(file_path, "w") as f: 101 | f.write(data) 102 | 103 | 104 | def _delete_file(path): 105 | s3_fs = S3FileSystem() 106 | s3_fs.delete(path) 107 | 108 | 109 | def s3_recursive_delete(path): 110 | s3_fs = S3FileSystem() 111 | s3_fs.rm(path, True) 112 | 113 | 114 | @pytest.mark.skipif(not has_s3_credentials, reason="No s3 credentials found") 115 | def test_csv_reader_from_s3(s3_create_csv): 116 | controller = build_flow( 117 | [ 118 | CSVSource(f"s3://{s3_create_csv}"), 119 | FlatMap(lambda x: x), 120 | Map(lambda x: int(x)), 121 | Reduce(0, lambda acc, x: acc + x), 122 | ] 123 | ).run() 124 | 125 | termination_result = controller.await_termination() 126 | assert termination_result == 495 127 | 128 | 129 | @pytest.mark.skipif(not has_s3_credentials, reason="No s3 credentials found") 130 | def test_csv_reader_from_s3_error_on_file_not_found(): 131 | with pytest.raises(FileNotFoundError): 132 | controller = build_flow( 133 | [ 134 | CSVSource(f's3://{os.getenv("AWS_BUCKET")}/idontexist.csv'), 135 | ] 136 | ).run() 137 | controller.await_termination() 138 | 139 | 140 | async def async_test_write_csv_to_s3(s3_teardown_csv): 141 | controller = build_flow( 142 | [ 143 | AsyncEmitSource(), 144 | CSVTarget(f"s3://{s3_teardown_csv}", columns=["n", "n*10"], header=True), 145 | ] 146 | ).run() 147 | 148 | for i in range(10): 149 | await controller.emit([i, 10 * i]) 150 | 151 | await controller.terminate() 152 | await controller.await_termination() 153 | 154 | actual = S3FileSystem().open(s3_teardown_csv).read() 155 | 156 | expected = "n,n*10\n0,0\n1,10\n2,20\n3,30\n4,40\n5,50\n6,60\n7,70\n8,80\n9,90\n" 157 | assert actual.decode("utf-8") == expected 158 | 159 | 160 | @pytest.mark.skipif(not has_s3_credentials, reason="No s3 credentials found") 161 | def test_write_csv_to_s3(s3_teardown_file): 162 | asyncio.run(async_test_write_csv_to_s3(s3_teardown_file)) 163 | 164 | 165 | @pytest.mark.skipif(not has_s3_credentials, reason="No s3 credentials found") 166 | def test_write_csv_with_dict_to_s3(s3_teardown_file): 167 | file_path = f"s3://{s3_teardown_file}" 168 | controller = build_flow([SyncEmitSource(), CSVTarget(file_path, columns=["n", "n*10"], header=True)]).run() 169 | 170 | for i in range(10): 171 | controller.emit({"n": i, "n*10": 10 * i}) 172 | 173 | controller.terminate() 174 | controller.await_termination() 175 | 176 | actual = S3FileSystem().open(s3_teardown_file).read() 177 | expected = "n,n*10\n0,0\n1,10\n2,20\n3,30\n4,40\n5,50\n6,60\n7,70\n8,80\n9,90\n" 178 | assert actual.decode("utf-8") == expected 179 | 180 | 181 | @pytest.mark.skipif(not has_s3_credentials, reason="No s3 credentials found") 182 | def test_write_csv_infer_columns_without_header_to_s3(s3_teardown_file): 183 | file_path = f"s3://{s3_teardown_file}" 184 | controller = build_flow([SyncEmitSource(), CSVTarget(file_path)]).run() 185 | 186 | for i in range(10): 187 | controller.emit({"n": i, "n*10": 10 * i}) 188 | 189 | controller.terminate() 190 | controller.await_termination() 191 | 192 | actual = S3FileSystem().open(s3_teardown_file).read() 193 | expected = "0,0\n1,10\n2,20\n3,30\n4,40\n5,50\n6,60\n7,70\n8,80\n9,90\n" 194 | assert actual.decode("utf-8") == expected 195 | 196 | 197 | @pytest.mark.skipif(not has_s3_credentials, reason="No s3 credentials found") 198 | def test_write_csv_from_lists_with_metadata_and_column_pruning_to_s3(s3_teardown_file): 199 | file_path = f"s3://{s3_teardown_file}" 200 | controller = build_flow( 201 | 
[ 202 | SyncEmitSource(), 203 | CSVTarget(file_path, columns=["event_key=$key", "n*10"], header=True), 204 | ] 205 | ).run() 206 | 207 | for i in range(10): 208 | controller.emit({"n": i, "n*10": 10 * i}, key=f"key{i}") 209 | 210 | controller.terminate() 211 | controller.await_termination() 212 | 213 | actual = S3FileSystem().open(s3_teardown_file).read() 214 | expected = ( 215 | "event_key,n*10\n" 216 | "key0,0\n" 217 | "key1,10\n" 218 | "key2,20\n" 219 | "key3,30\n" 220 | "key4,40\n" 221 | "key5,50\n" 222 | "key6,60\n" 223 | "key7,70\n" 224 | "key8,80\n" 225 | "key9,90\n" 226 | ) 227 | 228 | assert actual.decode("utf-8") == expected 229 | 230 | 231 | @pytest.mark.skipif(not has_s3_credentials, reason="No s3 credentials found") 232 | def test_write_to_parquet_to_s3(s3_setup_teardown_test): 233 | out_dir = f"s3://{s3_setup_teardown_test}" 234 | columns = ["my_int", "my_string"] 235 | controller = build_flow( 236 | [ 237 | SyncEmitSource(), 238 | ParquetTarget(out_dir, partition_cols="my_int", columns=columns, max_events=1), 239 | ] 240 | ).run() 241 | 242 | expected = [] 243 | for i in range(10): 244 | controller.emit([i, f"this is {i}"]) 245 | expected.append([i, f"this is {i}"]) 246 | expected = pd.DataFrame(expected, columns=columns) 247 | controller.terminate() 248 | controller.await_termination() 249 | 250 | read_back_df = pd.read_parquet(out_dir, columns=columns) 251 | read_back_df["my_int"] = read_back_df["my_int"].astype("int64") 252 | assert read_back_df.equals(expected), f"{read_back_df}\n!=\n{expected}" 253 | 254 | 255 | @pytest.mark.skipif(not has_s3_credentials, reason="No s3 credentials found") 256 | def test_write_to_parquet_to_s3_single_file_on_termination(s3_setup_teardown_test): 257 | out_file = f"s3://{s3_setup_teardown_test}myfile.pq" 258 | columns = ["my_int", "my_string"] 259 | controller = build_flow([SyncEmitSource(), ParquetTarget(out_file, columns=columns)]).run() 260 | 261 | expected = [] 262 | for i in range(10): 263 | controller.emit([i, f"this is {i}"]) 264 | expected.append([i, f"this is {i}"]) 265 | expected = pd.DataFrame(expected, columns=columns) 266 | controller.terminate() 267 | controller.await_termination() 268 | 269 | read_back_df = pd.read_parquet(out_file, columns=columns) 270 | assert read_back_df.equals(expected), f"{read_back_df}\n!=\n{expected}" 271 | 272 | 273 | @pytest.mark.skipif(not has_s3_credentials, reason="No s3 credentials found") 274 | def test_write_to_parquet_to_s3_with_indices(s3_setup_teardown_test): 275 | out_file = f"s3://{s3_setup_teardown_test}test_write_to_parquet_with_indices{uuid.uuid4().hex}/" 276 | controller = build_flow( 277 | [ 278 | SyncEmitSource(), 279 | ParquetTarget(out_file, index_cols="event_key=$key", columns=["my_int", "my_string"]), 280 | ] 281 | ).run() 282 | 283 | expected = [] 284 | for i in range(10): 285 | controller.emit([i, f"this is {i}"], key=f"key{i}") 286 | expected.append([f"key{i}", i, f"this is {i}"]) 287 | columns = ["event_key", "my_int", "my_string"] 288 | expected = pd.DataFrame(expected, columns=columns) 289 | expected.set_index(["event_key"], inplace=True) 290 | controller.terminate() 291 | controller.await_termination() 292 | 293 | read_back_df = pd.read_parquet(out_file, columns=columns) 294 | # when reading from buckets lines order can be scrambled 295 | read_back_df.sort_values("event_key", inplace=True) 296 | assert read_back_df.equals(expected), f"{read_back_df}\n!=\n{expected}" 297 | 298 | 299 | @pytest.mark.skipif(not has_s3_credentials, reason="No s3 credentials found") 300 | def 
test_write_csv_to_s3_bucket_directly(s3_teardown_file_in_bucket): 301 | file_path = f's3://{s3_teardown_file_in_bucket("csv")}' 302 | 303 | controller = build_flow([SyncEmitSource(), CSVTarget(file_path, columns=["n", "n*10"], header=True)]).run() 304 | 305 | for i in range(10): 306 | controller.emit({"n": i, "n*10": 10 * i}) 307 | 308 | controller.terminate() 309 | controller.await_termination() 310 | 311 | actual = S3FileSystem().open(s3_teardown_file_in_bucket("csv")).read() 312 | expected = "n,n*10\n0,0\n1,10\n2,20\n3,30\n4,40\n5,50\n6,60\n7,70\n8,80\n9,90\n" 313 | assert actual.decode("utf-8") == expected 314 | 315 | 316 | @pytest.mark.skipif(not has_s3_credentials, reason="No s3 credentials found") 317 | def test_write_parquet_to_s3_bucket_directly(s3_teardown_file_in_bucket): 318 | columns = ["my_int", "my_string"] 319 | 320 | file_path = f's3://{s3_teardown_file_in_bucket("parquet")}' 321 | # file_path = f'/tmp/{s3_teardown_file_in_bucket("parquet")}' 322 | controller = build_flow([SyncEmitSource(), ParquetTarget(file_path, columns=columns)]).run() 323 | 324 | expected = [] 325 | for i in range(10): 326 | controller.emit([i, f"this is {i}"]) 327 | expected.append([i, f"this is {i}"]) 328 | 329 | expected = pd.DataFrame(expected, columns=columns) 330 | controller.terminate() 331 | controller.await_termination() 332 | 333 | read_back_df = pd.read_parquet(file_path, columns=columns) 334 | 335 | assert read_back_df.equals(expected), f"{read_back_df}\n!=\n{expected}" 336 | -------------------------------------------------------------------------------- /integration/test_tdengine.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections.abc import Iterator 3 | from datetime import datetime, timezone 4 | from typing import Optional 5 | 6 | import pytest 7 | import taosws 8 | 9 | from storey import SyncEmitSource, build_flow 10 | from storey.targets import TDEngineTarget 11 | 12 | url = os.getenv("TDENGINE_URL") # e.g.: taosws://root:taosdata@localhost:6041 13 | user = os.getenv("TDENGINE_USER") 14 | password = os.getenv("TDENGINE_PASSWORD") 15 | has_tdengine_credentials = all([url, user, password]) or (url and url.startswith("taosws://")) 16 | 17 | pytestmark = pytest.mark.skipif(not has_tdengine_credentials, reason="Missing TDEngine URL, user, and/or password") 18 | 19 | TDEngineData = tuple[taosws.Connection, str, Optional[str], Optional[str], str, str] 20 | 21 | 22 | @pytest.fixture(params=[("ms", 10), ("us", 10)]) 23 | def tdengine(request: "pytest.FixtureRequest") -> Iterator[TDEngineData]: 24 | timestamp_precision = request.param[0] 25 | nchar_size = request.param[1] 26 | 27 | db_name = "storey" 28 | supertable_name = "test_supertable" 29 | 30 | if url.startswith("taosws://"): 31 | connection = taosws.connect(url) 32 | else: 33 | connection = taosws.connect(url=url, user=user, password=password) 34 | 35 | try: 36 | connection.execute(f"DROP DATABASE {db_name};") 37 | except taosws.QueryError as err: # websocket connection raises QueryError 38 | if "Database not exist" not in str(err): 39 | raise err 40 | 41 | connection.execute(f"CREATE DATABASE {db_name} PRECISION '{timestamp_precision}';") 42 | connection.execute(f"USE {db_name}") 43 | 44 | try: 45 | connection.execute(f"DROP STABLE {supertable_name};") 46 | except taosws.QueryError as err: # websocket connection raises QueryError 47 | if "STable not exist" not in str(err): 48 | raise err 49 | 50 | connection.execute( 51 | f"CREATE STABLE {supertable_name} (time 
TIMESTAMP, my_string NCHAR({nchar_size})) TAGS (my_int INT);" 52 | ) 53 | 54 | # Test runs 55 | yield connection, url, user, password, db_name, supertable_name 56 | 57 | # Teardown 58 | connection.execute(f"DROP DATABASE {db_name};") 59 | connection.close() 60 | 61 | 62 | @pytest.mark.parametrize("table_col", [None, "$key", "table"]) 63 | def test_tdengine_target(tdengine: TDEngineData, table_col: Optional[str]) -> None: 64 | connection, url, user, password, db_name, supertable_name = tdengine 65 | time_format = "%d/%m/%y %H:%M:%S UTC%z" 66 | 67 | table_name = "test_table" 68 | 69 | # Table is created automatically only when using a supertable 70 | if not table_col: 71 | connection.execute(f"CREATE TABLE {table_name} (time TIMESTAMP, my_string NCHAR(10), my_int INT);") 72 | 73 | controller = build_flow( 74 | [ 75 | SyncEmitSource(), 76 | TDEngineTarget( 77 | url=url, 78 | time_col="time", 79 | columns=["my_string"] if table_col else ["my_string", "my_int"], 80 | user=user, 81 | password=password, 82 | database=db_name, 83 | table=None if table_col else table_name, 84 | table_col=table_col, 85 | supertable=supertable_name if table_col else None, 86 | tag_cols=["my_int"] if table_col else None, 87 | time_format=time_format, 88 | max_events=10, 89 | ), 90 | ] 91 | ).run() 92 | 93 | date_time_str = "18/09/19 01:55:1" 94 | for i in range(5): 95 | timestamp = f"{date_time_str}{i} UTC-0000" 96 | event_body = {"time": timestamp, "my_int": i, "my_string": f"hello{i}"} 97 | event_key = None 98 | subtable_name = f"{table_name}{i}" 99 | if table_col == "$key": 100 | event_key = subtable_name 101 | elif table_col: 102 | event_body[table_col] = subtable_name 103 | controller.emit(event_body, event_key) 104 | 105 | controller.terminate() 106 | controller.await_termination() 107 | 108 | if table_col: 109 | query_table = supertable_name 110 | where_clause = " WHERE my_int > 0 AND my_int < 3" 111 | else: 112 | query_table = table_name 113 | where_clause = "" 114 | result = connection.query(f"SELECT * FROM {query_table} {where_clause} ORDER BY my_int;") 115 | result_list = [] 116 | for row in result: 117 | row = list(row) 118 | for field_index, field in enumerate(result.fields): 119 | typ = field.type() 120 | if typ == "TIMESTAMP": 121 | t = datetime.fromisoformat(row[field_index]) 122 | # websocket returns a timestamp with the local time zone 123 | t = t.astimezone(timezone.utc).replace(tzinfo=None) 124 | row[field_index] = t 125 | result_list.append(row) 126 | if table_col: 127 | expected_result = [ 128 | [datetime(2019, 9, 18, 1, 55, 11), "hello1", 1], 129 | [datetime(2019, 9, 18, 1, 55, 12), "hello2", 2], 130 | ] 131 | else: 132 | expected_result = [ 133 | [datetime(2019, 9, 18, 1, 55, 10), "hello0", 0], 134 | [datetime(2019, 9, 18, 1, 55, 11), "hello1", 1], 135 | [datetime(2019, 9, 18, 1, 55, 12), "hello2", 2], 136 | [datetime(2019, 9, 18, 1, 55, 13), "hello3", 3], 137 | [datetime(2019, 9, 18, 1, 55, 14), "hello4", 4], 138 | ] 139 | assert result_list == expected_result 140 | 141 | 142 | @pytest.mark.parametrize("tdengine", [("ms", 100)], indirect=["tdengine"]) 143 | def test_sql_injection(tdengine: TDEngineData) -> None: 144 | connection, url, user, password, db_name, supertable_name = tdengine 145 | # Create another table to be dropped via SQL injection 146 | tb_name = "dont_drop_me" 147 | connection.execute(f"CREATE TABLE IF NOT EXISTS {tb_name} USING {supertable_name} TAGS (101);") 148 | extra_table_query = f"SHOW TABLES LIKE '{tb_name}';" 149 | assert list(connection.query(extra_table_query)), "The 
extra table was not created" 150 | 151 | # Try dropping the table 152 | table_name = "test_table" 153 | table_col = "table" 154 | controller = build_flow( 155 | [ 156 | SyncEmitSource(), 157 | TDEngineTarget( 158 | url=url, 159 | time_col="time", 160 | columns=["my_string"], 161 | user=user, 162 | password=password, 163 | database=db_name, 164 | table_col=table_col, 165 | supertable=supertable_name, 166 | tag_cols=["my_int"], 167 | time_format="%d/%m/%y %H:%M:%S UTC%z", 168 | max_events=10, 169 | ), 170 | ] 171 | ).run() 172 | 173 | date_time_str = "18/09/19 01:55:1" 174 | for i in range(5): 175 | timestamp = f"{date_time_str}{i} UTC-0000" 176 | subtable_name = f"{table_name}{i}" 177 | event_body = {"time": timestamp, "my_int": i, "my_string": f"s); DROP TABLE {tb_name};"} 178 | event_body[table_col] = subtable_name 179 | controller.emit(event_body) 180 | 181 | controller.terminate() 182 | controller.await_termination() 183 | 184 | assert list(connection.query(extra_table_query)), "The extra table was dropped" 185 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp~=3.8 2 | v3io>=0.6.10, <0.8 3 | # exclude pandas 1.5.0 due to https://github.com/pandas-dev/pandas/issues/48767 4 | # and 1.5.* due to https://github.com/pandas-dev/pandas/issues/49203 5 | # pandas 2.2 requires sqlalchemy 2 6 | pandas>=1, !=1.5.*, <2.2 7 | # upper limit is just a safeguard - tested with numpy 1.26.2 8 | numpy>=1.16.5,<1.27 9 | # <18 is just a safeguard - no tests performed with pyarrow higher than 17 10 | pyarrow>=1,<18 11 | v3io-frames>=0.10.14, !=0.11.*, !=0.12.* 12 | fsspec>=0.6.2 13 | v3iofs~=0.1.17 14 | xxhash>=1 15 | nuclio-sdk>=0.5.3 16 | -------------------------------------------------------------------------------- /set-version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | from os import environ 16 | 17 | 18 | def set_version(): 19 | version = environ.get("GITHUB_REF") 20 | assert version, "GITHUB_REF is not defined" 21 | 22 | version = version.replace("refs/tags/v", "") 23 | 24 | lines = [] 25 | init_py = "storey/__init__.py" 26 | with open(init_py) as fp: 27 | for line in fp: 28 | if "__version__" in line: 29 | line = f'__version__ = "{version}"\n' 30 | lines.append(line) 31 | 32 | with open(init_py, "w") as out: 33 | out.write("".join(lines)) 34 | 35 | 36 | set_version() 37 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = 3 | .git 4 | venv 5 | max-line-length = 140 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from setuptools import find_packages, setup 16 | 17 | 18 | def version() -> str: 19 | with open("storey/__init__.py") as fp: 20 | for line in fp: 21 | if line.startswith("__version__"): 22 | _, version = line.split("=") 23 | return version.replace('"', "").strip() 24 | raise ValueError("Could not find package version") 25 | 26 | 27 | def load_deps(file_name): 28 | """Load dependencies from requirements file""" 29 | deps = [] 30 | with open(file_name) as fp: 31 | for line in fp: 32 | line = line.strip() 33 | if not line or line[0] == "#": 34 | continue 35 | deps.append(line) 36 | return deps 37 | 38 | 39 | install_requires = load_deps("requirements.txt") 40 | tests_require = load_deps("dev-requirements.txt") 41 | extras_require = { 42 | "kafka": ["kafka-python~=2.0"], 43 | "redis": ["redis~=4.3"], 44 | "sqlalchemy": ["sqlalchemy~=1.3"], 45 | "tdengine": ["taospy[ws]>=2,<3"], 46 | } 47 | 48 | 49 | with open("README.md") as fp: 50 | long_desc = fp.read() 51 | 52 | setup( 53 | name="storey", 54 | version=version(), 55 | description="Async flows", 56 | long_description=long_desc, 57 | long_description_content_type="text/markdown", 58 | author="Iguazio", 59 | author_email="yaronh@iguazio.com", 60 | license="Apache", 61 | url="https://github.com/mlrun/storey", 62 | packages=find_packages(include=["storey*"]), 63 | python_requires=">=3.9", 64 | install_requires=install_requires, 65 | extras_require=extras_require, 66 | classifiers=[ 67 | "Development Status :: 4 - Beta", 68 | "Intended Audience :: Developers", 69 | "License :: OSI Approved :: Apache Software License", 70 | "Operating System :: POSIX :: Linux", 71 | "Operating System :: Microsoft :: Windows", 72 | "Operating System :: MacOS", 73 | "Programming Language :: Python :: 3", 74 | "Programming Language :: Python :: 3.9", 75 | "Programming Language :: Python", 76 | "Topic :: Software Development :: Libraries :: Python Modules", 77 | "Topic :: Software Development :: Libraries", 78 | ], 
79 | tests_require=tests_require, 80 | ) 81 | -------------------------------------------------------------------------------- /storey/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | __version__ = "0.0.0+unstable" 16 | 17 | # Importing supported filesystems explicitly so that they will get registered as an fsspec filesystem 18 | import v3iofs # noqa: F401 19 | 20 | from .aggregations import AggregateByKey, QueryByKey # noqa: F401 21 | from .dataframe import ReduceToDataFrame, ToDataFrame # noqa: F401 22 | from .drivers import Driver, NoopDriver, V3ioDriver # noqa: F401 23 | from .dtypes import EmissionType # noqa: F401 24 | from .dtypes import EmitAfterDelay # noqa: F401 25 | from .dtypes import EmitAfterMaxEvent # noqa: F401 26 | from .dtypes import EmitAfterPeriod # noqa: F401 27 | from .dtypes import EmitAfterWindow # noqa: F401 28 | from .dtypes import EmitEveryEvent # noqa: F401 29 | from .dtypes import EmitPolicy # noqa: F401 30 | from .dtypes import Event # noqa: F401 31 | from .dtypes import FieldAggregator # noqa: F401 32 | from .dtypes import FixedWindows # noqa: F401 33 | from .dtypes import FixedWindowType # noqa: F401 34 | from .dtypes import LateDataHandling # noqa: F401 35 | from .dtypes import SlidingWindows # noqa: F401 36 | from .flow import Batch # noqa: F401 37 | from .flow import Choice # noqa: F401 38 | from .flow import Complete # noqa: F401 39 | from .flow import ConcurrentExecution # noqa: F401 40 | from .flow import Context # noqa: F401 41 | from .flow import Extend # noqa: F401 42 | from .flow import Filter # noqa: F401 43 | from .flow import FlatMap # noqa: F401 44 | from .flow import Flow # noqa: F401 45 | from .flow import FlowError # noqa: F401 46 | from .flow import HttpRequest # noqa: F401 47 | from .flow import HttpResponse # noqa: F401 48 | from .flow import JoinWithTable # noqa: F401 49 | from .flow import JoinWithV3IOTable # noqa: F401 50 | from .flow import Map # noqa: F401 51 | from .flow import MapClass # noqa: F401 52 | from .flow import MapWithState # noqa: F401 53 | from .flow import ParallelExecution # noqa: F401 54 | from .flow import ParallelExecutionRunnable # noqa: F401 55 | from .flow import Recover # noqa: F401 56 | from .flow import Reduce # noqa: F401 57 | from .flow import Rename # noqa: F401 58 | from .flow import SendToHttp # noqa: F401 59 | from .flow import build_flow # noqa: F401 60 | from .sources import AsyncEmitSource # noqa: F401 61 | from .sources import CSVSource # noqa: F401 62 | from .sources import DataframeSource # noqa: F401 63 | from .sources import ParquetSource # noqa: F401 64 | from .sources import SQLSource # noqa: F401 65 | from .sources import SyncEmitSource # noqa: F401 66 | from .sql_driver import SQLDriver # noqa: F401 67 | from .table import Table # noqa: F401 68 | from .targets import CSVTarget # noqa: F401 69 | from .targets import 
KafkaTarget # noqa: F401 70 | from .targets import NoSqlTarget # noqa: F401 71 | from .targets import ParquetTarget # noqa: F401 72 | from .targets import StreamTarget # noqa: F401 73 | from .targets import TDEngineTarget # noqa: F401 74 | from .targets import TSDBTarget # noqa: F401 75 | 76 | # clear module namespace 77 | del v3iofs 78 | -------------------------------------------------------------------------------- /storey/aggregation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import math 16 | 17 | _aggrTypeNone = 0 18 | _aggrTypeCount = 1 19 | _aggrTypeSum = 2 20 | _aggrTypeSqr = 4 21 | _aggrTypeMax = 8 22 | _aggrTypeMin = 16 23 | _aggrTypeLast = 32 24 | _aggrTypeFirst = 64 25 | 26 | # Derived aggregates 27 | _aggrTypeAvg = _aggrTypeCount | _aggrTypeSum 28 | _aggrTypeRate = _aggrTypeLast | 0x8000 29 | _aggrTypeStddev = _aggrTypeCount | _aggrTypeSum | _aggrTypeSqr 30 | _aggrTypeStdvar = _aggrTypeCount | _aggrTypeSum | _aggrTypeSqr | 0x8000 31 | _aggrTypeAll = 0xFFFF 32 | 33 | _raw_aggregates = [ 34 | _aggrTypeCount, 35 | _aggrTypeSum, 36 | _aggrTypeSqr, 37 | _aggrTypeMax, 38 | _aggrTypeMin, 39 | _aggrTypeLast, 40 | ] 41 | _raw_aggregates_by_name = { 42 | "count": _aggrTypeCount, 43 | "sum": _aggrTypeSum, 44 | "sqr": _aggrTypeSqr, 45 | "max": _aggrTypeMax, 46 | "min": _aggrTypeMin, 47 | "first": _aggrTypeFirst, 48 | "last": _aggrTypeLast, 49 | } 50 | _all_aggregates_by_name = { 51 | "count": _aggrTypeCount, 52 | "sum": _aggrTypeSum, 53 | "sqr": _aggrTypeSqr, 54 | "max": _aggrTypeMax, 55 | "min": _aggrTypeMin, 56 | "first": _aggrTypeFirst, 57 | "last": _aggrTypeLast, 58 | "avg": _aggrTypeAvg, 59 | "stdvar": _aggrTypeStdvar, 60 | "stddev": _aggrTypeStddev, 61 | } 62 | _all_aggregates_to_name = { 63 | _aggrTypeCount: "count", 64 | _aggrTypeSum: "sum", 65 | _aggrTypeSqr: "sqr", 66 | _aggrTypeMax: "max", 67 | _aggrTypeMin: "min", 68 | _aggrTypeFirst: "first", 69 | _aggrTypeLast: "last", 70 | _aggrTypeAvg: "avg", 71 | _aggrTypeStdvar: "stdvar", 72 | _aggrTypeStddev: "stddev", 73 | } 74 | 75 | 76 | def is_raw_aggregate(aggregate): 77 | return aggregate in _raw_aggregates_by_name 78 | 79 | 80 | def _avg(args): 81 | count = args[0] 82 | sum = args[1] 83 | if count == 0: 84 | return math.nan 85 | return sum / count 86 | 87 | 88 | def _stddev(args): 89 | count = args[0] 90 | if count == 0 or count == 1: 91 | return math.nan 92 | sum = args[1] 93 | sqr = args[2] 94 | 95 | return math.sqrt((count * sqr - sum * sum) / (count * (count - 1))) 96 | 97 | 98 | def _stdvar(args): 99 | count = args[0] 100 | if count == 0 or count == 1: 101 | return math.nan 102 | sum = args[1] 103 | sqr = args[2] 104 | return (count * sqr - sum * sum) / (count * (count - 1)) 105 | 106 | 107 | def get_virtual_aggregation_func(aggregation): 108 | if aggregation == "avg": 109 | return _avg 110 | if aggregation == "stdvar": 111 | return _stdvar 112 | if 
aggregation == "stddev": 113 | return _stddev 114 | 115 | raise TypeError(f'"{aggregation}" aggregator is not defined') 116 | 117 | 118 | def get_implied_aggregates(aggregate): 119 | aggrs = [] 120 | aggr_bits = _all_aggregates_by_name[aggregate] 121 | for raw_aggr in _raw_aggregates: 122 | if aggr_bits & raw_aggr == raw_aggr: 123 | aggrs.append(_all_aggregates_to_name[raw_aggr]) 124 | return aggrs 125 | 126 | 127 | def get_all_raw_aggregates_with_hidden(aggregates): 128 | raw_aggregates = {} 129 | 130 | for aggregate in aggregates: 131 | if is_raw_aggregate(aggregate): 132 | raw_aggregates[aggregate] = False 133 | else: 134 | for dependant_aggr in get_implied_aggregates(aggregate): 135 | if dependant_aggr not in raw_aggregates: 136 | raw_aggregates[dependant_aggr] = True 137 | 138 | return raw_aggregates 139 | 140 | 141 | def get_all_raw_aggregates(aggregates): 142 | return set(get_all_raw_aggregates_with_hidden(aggregates).keys()) 143 | 144 | 145 | def is_aggregation_name(name: str): 146 | return name in _all_aggregates_by_name 147 | -------------------------------------------------------------------------------- /storey/dataframe.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from typing import List, Optional 16 | 17 | import pandas as pd 18 | 19 | from .flow import Flow, _termination_obj 20 | 21 | 22 | class ReduceToDataFrame(Flow): 23 | """Builds a pandas DataFrame from events and returns that DataFrame on flow termination. 24 | 25 | :param index: Name of the column to be used as index. Optional. If not set, DataFrame will be range indexed. 26 | :param columns: List of column names to be passed as-is to the DataFrame constructor. Optional. 27 | :param insert_key_column_as: Name of the column to be inserted for event keys. Optional. 28 | If not set, event keys will not be inserted into the DataFrame. 29 | :param insert_id_column_as: Name of the column to be inserted for event IDs. Optional. 30 | If not set, event IDs will not be inserted into the DataFrame. 
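:param insert_processing_time_column_as: Name of the column to be inserted for event processing times. Optional. If not set, event processing times will not be inserted into the DataFrame.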
31 | 32 | for additional params, see documentation of :class:`storey.flow.Flow` 33 | 34 | """ 35 | 36 | def __init__( 37 | self, 38 | index: Optional[str] = None, 39 | columns: Optional[List[str]] = None, 40 | insert_key_column_as: Optional[str] = None, 41 | insert_processing_time_column_as: Optional[str] = None, 42 | insert_id_column_as: Optional[str] = None, 43 | **kwargs, 44 | ): 45 | super().__init__(**kwargs) 46 | self._index = index 47 | self._columns = columns 48 | self._insert_key_column_as = insert_key_column_as 49 | self._insert_processing_time_column_as = insert_processing_time_column_as 50 | self._insert_id_column_as = insert_id_column_as 51 | 52 | def _init(self): 53 | super()._init() 54 | self._key_column = [] 55 | self._processing_time_column = [] 56 | self._id_column = [] 57 | self._data = [] 58 | 59 | def to(self, outlet): 60 | """Pipe this step to next one. Throws exception since illegal""" 61 | raise ValueError("ToDataFrame is a terminal step. It cannot be piped further.") 62 | 63 | async def _do(self, event): 64 | if event is _termination_obj: 65 | df = pd.DataFrame(self._data, columns=self._columns) 66 | if not df.empty: 67 | if self._insert_key_column_as: 68 | df[self._insert_key_column_as] = pd.DataFrame(self._key_column) 69 | if self._insert_processing_time_column_as: 70 | df[self._insert_processing_time_column_as] = self._processing_time_column 71 | if self._insert_id_column_as: 72 | df[self._insert_id_column_as] = self._id_column 73 | if self._index: 74 | df.set_index(self._index, inplace=True) 75 | return df 76 | else: 77 | body = event.body 78 | if isinstance(body, dict) or isinstance(body, list): 79 | self._data.append(body) 80 | if self._insert_key_column_as: 81 | self._key_column.append(event.key) 82 | if self._insert_processing_time_column_as: 83 | self._processing_time_column.append(event.processing_time) 84 | if self._insert_id_column_as: 85 | self._id_column.append(event.id) 86 | else: 87 | raise ValueError(f"ToDataFrame step only supports input of type dictionary or list, not {type(body)}") 88 | 89 | 90 | class ToDataFrame(Flow): 91 | """Create pandas data frame from events. Can appear in the middle of the flow, as opposed to ReduceToDataFrame 92 | 93 | :param index: Name of the column to be used as index. Optional. If not set, DataFrame will be range indexed. 94 | :param columns: List of column names to be passed as-is to the DataFrame constructor. Optional. 95 | 96 | for additional params, see documentation of :class:`storey.flow.Flow` 97 | """ 98 | 99 | def __init__(self, index: Optional[str] = None, columns: Optional[List[str]] = None, **kwargs): 100 | super().__init__(**kwargs) 101 | self._index = index 102 | self._columns = columns 103 | 104 | async def _do(self, event): 105 | if event is _termination_obj: 106 | return await self._do_downstream(_termination_obj) 107 | else: 108 | df = pd.DataFrame(event.body, columns=self._columns) 109 | if self._index: 110 | df.set_index(self._index, inplace=True) 111 | event.body = df 112 | return await self._do_downstream(event) 113 | -------------------------------------------------------------------------------- /storey/dtypes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from datetime import datetime, timezone 16 | from enum import Enum 17 | from typing import Callable, List, Optional, Union 18 | 19 | import numpy 20 | 21 | from .aggregation_utils import get_all_raw_aggregates 22 | from .utils import bucketPerWindow, get_one_unit_of_duration, parse_duration 23 | 24 | _termination_obj = object() 25 | 26 | known_driver_schemes = ["v3io", "redis", "rediss"] 27 | 28 | 29 | class Event: 30 | """The basic unit of data in storey. All steps receive and emit events. 31 | 32 | :param body: the event payload, or data 33 | :param key: Event key. Used by steps that aggregate events by key, such as AggregateByKey. (Optional). Can be list 34 | :param processing_time: Event processing time. Defaults to the time the event was created, UTC. (Optional) 35 | :param id: Event identifier. Usually a unique identifier. (Optional) 36 | :param headers: Request headers (HTTP only) (Optional) 37 | :param method: Request method (HTTP only) (Optional) 38 | :param path: Request path (HTTP only) (Optional) 39 | :param content_type: Request content type (HTTP only) (Optional) 40 | :param awaitable_result: Generally not passed directly. (Optional) 41 | :type awaitable_result: AwaitableResult (Optional) 42 | """ 43 | 44 | def __init__( 45 | self, 46 | body: object, 47 | key: Optional[Union[str, List[str]]] = None, 48 | processing_time: Union[None, datetime, int, float] = None, 49 | id: Optional[str] = None, 50 | headers: Optional[dict] = None, 51 | method: Optional[str] = None, 52 | path: Optional[str] = "/", 53 | content_type=None, 54 | awaitable_result=None, 55 | ): 56 | self.body = body 57 | self.key = key 58 | if processing_time is not None and not isinstance(processing_time, datetime): 59 | if isinstance(processing_time, str): 60 | processing_time = datetime.fromisoformat(processing_time) 61 | elif isinstance(processing_time, (int, float)): 62 | processing_time = datetime.utcfromtimestamp(processing_time) 63 | else: 64 | raise TypeError( 65 | f"Event processing_time parameter must be a datetime, string, or int. " 66 | f"Got {type(processing_time)} instead." 
67 | ) 68 | self.processing_time = processing_time or datetime.now(timezone.utc) 69 | self.id = id 70 | self.headers = headers 71 | self.method = method 72 | self.path = path 73 | self.content_type = content_type 74 | self._awaitable_result = awaitable_result 75 | self.error = None 76 | 77 | def __eq__(self, other): 78 | if not isinstance(other, Event): 79 | return False 80 | 81 | return ( 82 | self.body == other.body 83 | and self.id == other.id 84 | and self.headers == other.headers 85 | and self.method == other.method 86 | and self.path == other.path 87 | and self.content_type == other.content_type 88 | ) # noqa: E127 89 | 90 | def __str__(self): 91 | return f"Event(id={self.id}, key={str(self.key)}, body={self.body})" 92 | 93 | 94 | class V3ioError(Exception): 95 | pass 96 | 97 | 98 | class RedisError(Exception): 99 | pass 100 | 101 | 102 | class FlowError(Exception): 103 | pass 104 | 105 | 106 | class TDEngineTypeError(TypeError): 107 | pass 108 | 109 | 110 | class TDEngineValueError(ValueError): 111 | pass 112 | 113 | 114 | class WindowBase: 115 | def __init__(self, window, period, window_str): 116 | self.window_millis = window 117 | self.period_millis = period 118 | self.window_str = window_str 119 | 120 | 121 | class FixedWindow(WindowBase): 122 | """ 123 | Time window representing fixed time interval. The interval will be divided to 10 periods 124 | 125 | :param window: Time window in the format [0-9]+[smhd] 126 | """ 127 | 128 | def __init__(self, window: str): 129 | window_millis = parse_duration(window) 130 | WindowBase.__init__(self, window_millis, window_millis / bucketPerWindow, window) 131 | 132 | def get_total_number_of_buckets(self): 133 | return bucketPerWindow * 2 134 | 135 | def get_window_start_time(self): 136 | return self.get_current_window() 137 | 138 | def get_current_window(self): 139 | return int((datetime.now().timestamp() * 1000) / self.window_millis) * self.window_millis 140 | 141 | def get_current_period(self): 142 | return int((datetime.now().timestamp() * 1000) / self.period_millis) * self.period_millis 143 | 144 | 145 | class SlidingWindow(WindowBase): 146 | """ 147 | Time window representing sliding time interval divided to periods. 
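For example, a window of "1h" with a period of "10m" is divided into 6 buckets.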
148 | 149 | :param window: Time window in the format [0-9]+[smhd] 150 | :param period: Number of buckets to use for the window [0-9]+[smhd] 151 | """ 152 | 153 | def __init__(self, window: str, period: str): 154 | window_millis, period_millis = parse_duration(window), parse_duration(period) 155 | if not window_millis % period_millis == 0: 156 | raise ValueError("period must be a divider of the window") 157 | 158 | WindowBase.__init__(self, window_millis, period_millis, window) 159 | 160 | def get_total_number_of_buckets(self): 161 | return int(self.window_millis / self.period_millis) 162 | 163 | def get_window_start_time(self): 164 | return datetime.now().timestamp() * 1000 165 | 166 | 167 | def get_window_optimal_size_millis(windows_tuples): 168 | windows_list = [] 169 | for window_tuple in windows_tuples: 170 | windows_list.append(window_tuple[0]) 171 | return numpy.lcm.reduce(windows_list) 172 | 173 | 174 | def get_window_optimal_period_millis(windows_tuples): 175 | windows_list = [] 176 | for window_tuple in windows_tuples: 177 | windows_list.append(window_tuple[0]) 178 | return numpy.gcd.reduce(windows_list) 179 | 180 | 181 | class WindowsBase: 182 | def __init__(self, period, windows): 183 | self.max_window_millis = windows[-1][0] 184 | self.smallest_window_millis = windows[0][0] 185 | self.period_millis = period 186 | self.windows = windows # list of tuples of the form (3600000, '1h') 187 | self.window_millis = get_window_optimal_size_millis(windows) 188 | self.total_number_of_buckets = int(self.window_millis / self.period_millis) 189 | 190 | def merge(self, new): 191 | if self.period_millis != new.period_millis: 192 | raise ValueError("Cannot use different periods for same aggregation") 193 | found_new_window = False 194 | for window in new.windows: 195 | if window not in self.windows: 196 | self.windows.append(window) 197 | found_new_window = True 198 | if found_new_window: 199 | if self.max_window_millis < new.max_window_millis: 200 | self.max_window_millis = new.max_window_millis 201 | if self.smallest_window_millis > new.smallest_window_millis: 202 | self.smallest_window_millis = new.smallest_window_millis 203 | if self.total_number_of_buckets < new.total_number_of_buckets: 204 | self.total_number_of_buckets = new.total_number_of_buckets 205 | sorted(set(self.windows), key=lambda tup: tup[0]) 206 | 207 | 208 | def sort_windows_and_convert_to_millis(windows): 209 | if len(windows) == 0: 210 | raise ValueError("Windows list can not be empty") 211 | 212 | if isinstance(windows[0], str): 213 | # Validate windows order 214 | windows_tuples = [(parse_duration(window), window) for window in windows] 215 | windows_tuples.sort(key=lambda tup: tup[0]) 216 | else: 217 | # Internally windows can be passed as tuples 218 | windows_tuples = windows 219 | return windows_tuples 220 | 221 | 222 | class FixedWindows(WindowsBase): 223 | """ 224 | List of time windows representing fixed time intervals. 225 | For example: 1h will represent 1h windows starting every round hour. 
226 | 227 | :param windows: List of time windows in the format [0-9]+[smhd] 228 | """ 229 | 230 | def __init__(self, windows: List[str]): 231 | windows_tuples = sort_windows_and_convert_to_millis(windows) 232 | # The period should be a divisor of the unit of the smallest window, 233 | # for example if the smallest request window is 2h, the period will be 1h / `bucketPerWindow` 234 | self.smallest_window_unit_millis = get_one_unit_of_duration(windows_tuples[0][1]) 235 | period = get_window_optimal_period_millis(windows_tuples) / bucketPerWindow 236 | WindowsBase.__init__(self, period, windows_tuples) 237 | 238 | def round_up_time_to_window(self, timestamp): 239 | return ( 240 | int(timestamp / self.smallest_window_unit_millis) * self.smallest_window_unit_millis 241 | + self.smallest_window_unit_millis 242 | ) 243 | 244 | def get_period_by_time(self, timestamp): 245 | return int(timestamp / self.period_millis) * self.period_millis 246 | 247 | def get_window_start_time_by_time(self, timestamp): 248 | return int(timestamp / self.window_millis) * self.window_millis 249 | 250 | def merge(self, new): 251 | if isinstance(new, FixedWindows): 252 | super(FixedWindows, self).merge(new) 253 | else: 254 | self.__init__(new.windows) 255 | 256 | 257 | class SlidingWindows(WindowsBase): 258 | """ 259 | List of time windows representing sliding time intervals. 260 | For example: 1h will represent 1h windows starting from the current time. 261 | 262 | :param windows: List of time windows in the format [0-9]+[smhd] 263 | :param period: Period in the format [0-9]+[smhd] 264 | """ 265 | 266 | def __init__(self, windows: List[str], period: Optional[str] = None): 267 | windows_tuples = sort_windows_and_convert_to_millis(windows) 268 | 269 | if period: 270 | period_millis = parse_duration(period) 271 | 272 | # Verify the given period is a divisor of the windows 273 | for window in windows_tuples: 274 | if not window[0] % period_millis == 0: 275 | raise ValueError( 276 | f"Period must be a divisor of every window, but period {period} does not divide {window}" 277 | ) 278 | else: 279 | # The period should be a divisor of the unit of the smallest window, 280 | # for example if the smallest request window is 2h, the period will be 1h / `bucketPerWindow` 281 | smallest_window_unit_millis = get_one_unit_of_duration(windows_tuples[0][1]) 282 | period_millis = smallest_window_unit_millis / bucketPerWindow 283 | 284 | WindowsBase.__init__(self, period_millis, windows_tuples) 285 | 286 | def get_window_start_time_by_time(self, timestamp): 287 | return int(timestamp / self.period_millis) * self.period_millis 288 | 289 | 290 | class EmissionType(Enum): 291 | All = 1 292 | Incremental = 2 293 | 294 | 295 | class EmitPolicy: 296 | def __init__(self, emission_type=EmissionType.All): 297 | self.emission_type = emission_type 298 | 299 | 300 | class EmitAfterPeriod(EmitPolicy): 301 | """ 302 | Emit event for next step after each period ends 303 | 304 | :param delay_in_seconds: Delay event emission by seconds (Optional) 305 | """ 306 | 307 | def __init__(self, delay_in_seconds: Optional[int] = 0, emission_type=EmissionType.All): 308 | self.delay_in_seconds = delay_in_seconds 309 | EmitPolicy.__init__(self, emission_type) 310 | 311 | @staticmethod 312 | def name(): 313 | return "afterPeriod" 314 | 315 | 316 | class EmitAfterWindow(EmitPolicy): 317 | """ 318 | Emit event for next step after each window ends 319 | 320 | :param delay_in_seconds: Delay event emission by seconds (Optional) 321 | """ 322 | 323 | def __init__(self, 
delay_in_seconds: Optional[int] = 0, emission_type=EmissionType.All): 324 | self.delay_in_seconds = delay_in_seconds 325 | EmitPolicy.__init__(self, emission_type) 326 | 327 | @staticmethod 328 | def name(): 329 | return "afterWindow" 330 | 331 | 332 | class EmitAfterMaxEvent(EmitPolicy): 333 | """ 334 | Emit the Nth event 335 | 336 | :param max_events: Which number of event to emit 337 | :param timeout_secs: Emit event after timeout expires even if it didn't reach max_events event (Optional) 338 | """ 339 | 340 | def __init__( 341 | self, 342 | max_events: int, 343 | timeout_secs: Optional[int] = None, 344 | emission_type=EmissionType.All, 345 | ): 346 | self.max_events = max_events 347 | self.timeout_secs = timeout_secs 348 | EmitPolicy.__init__(self, emission_type) 349 | 350 | @staticmethod 351 | def name(): 352 | return "maxEvents" 353 | 354 | 355 | class EmitAfterDelay(EmitPolicy): 356 | def __init__(self, delay_in_seconds, emission_type=EmissionType.All): 357 | self.delay_in_seconds = delay_in_seconds 358 | EmitPolicy.__init__(self, emission_type) 359 | 360 | @staticmethod 361 | def name(): 362 | return "afterDelay" 363 | 364 | 365 | class EmitEveryEvent(EmitPolicy): 366 | """ 367 | Emit every event 368 | """ 369 | 370 | @staticmethod 371 | def name(): 372 | return "everyEvent" 373 | 374 | pass 375 | 376 | 377 | def _dict_to_emit_policy(policy_dict): 378 | mode = policy_dict.pop("mode") 379 | if mode == EmitEveryEvent.name(): 380 | policy = EmitEveryEvent() 381 | elif mode == EmitAfterMaxEvent.name(): 382 | if "maxEvents" not in policy_dict: 383 | raise ValueError("maxEvents parameter must be specified for maxEvents emit policy") 384 | policy = EmitAfterMaxEvent(policy_dict.pop("maxEvents")) 385 | elif mode == EmitAfterDelay.name(): 386 | if "delay" not in policy_dict: 387 | raise ValueError("delay parameter must be specified for afterDelay emit policy") 388 | 389 | policy = EmitAfterDelay(policy_dict.pop("delay")) 390 | elif mode == EmitAfterWindow.name(): 391 | policy = EmitAfterWindow(delay_in_seconds=policy_dict.pop("delay", 0)) 392 | elif mode == EmitAfterPeriod.name(): 393 | policy = EmitAfterPeriod(delay_in_seconds=policy_dict.pop("delay", 0)) 394 | else: 395 | raise TypeError(f"unsupported emit policy type: {mode}") 396 | 397 | if policy_dict: 398 | raise ValueError(f"got unexpected arguments for emit policy: {policy_dict}") 399 | 400 | return policy 401 | 402 | 403 | class LateDataHandling(Enum): 404 | Nothing = 1 405 | Sort_before_emit = 2 406 | 407 | 408 | class FieldAggregator: 409 | """ 410 | Field Aggregator represents an set of aggregation features. 411 | 412 | :param name: Name for the feature. 413 | :param field: Field in the event body to aggregate. 414 | :param aggr: List of aggregates to apply. 415 | Valid values are: [count, sum, sqr, avg, max, min, last, first, sttdev, stdvar] 416 | :param windows: Time windows to aggregate the data by. 417 | :param aggr_filter: Filter specifying which events to aggregate. 
(Optional) 418 | :param max_value: Maximum value for the aggregation (Optional) 419 | """ 420 | 421 | def __init__( 422 | self, 423 | name: str, 424 | field: Union[str, Callable[[Event], object], None], 425 | aggr: List[str], 426 | windows: Union[FixedWindows, SlidingWindows], 427 | aggr_filter: Optional[Callable[[Event], bool]] = None, 428 | max_value: Optional[float] = None, 429 | ): 430 | if aggr_filter is not None and not callable(aggr_filter): 431 | raise TypeError(f"aggr_filter expected to be callable, got {type(aggr_filter)}") 432 | 433 | if callable(field): 434 | self.value_extractor = field 435 | elif isinstance(field, str): 436 | self.value_extractor = lambda element: element.get(field) 437 | 438 | self.name = name 439 | self.aggregations = aggr 440 | self.windows = windows 441 | self.aggr_filter = aggr_filter 442 | self.max_value = max_value 443 | 444 | def get_all_raw_aggregates(self): 445 | return get_all_raw_aggregates(self.aggregations) 446 | 447 | def should_aggregate(self, element): 448 | if not self.aggr_filter: 449 | return True 450 | 451 | return self.aggr_filter(element) 452 | 453 | 454 | class FixedWindowType(Enum): 455 | CurrentOpenWindow = 1 456 | LastClosedWindow = 2 457 | 458 | 459 | class _TDEngineField: 460 | def __init__( 461 | self, 462 | field: str, 463 | field_type: str, 464 | length: int, 465 | note: str, 466 | *args, 467 | ): 468 | self.field = field 469 | self.field_type = field_type 470 | self.length = length 471 | self.note = note 472 | -------------------------------------------------------------------------------- /storey/queue.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import asyncio 16 | import collections 17 | 18 | 19 | class AsyncQueue(asyncio.Queue): 20 | """ 21 | asyncio.Queue with a peek method added. 22 | """ 23 | 24 | async def peek(self): 25 | while self.empty(): 26 | getter = asyncio.get_running_loop().create_future() 27 | self._getters.append(getter) 28 | try: 29 | await getter 30 | except BaseException: 31 | getter.cancel() # Just in case getter is not done yet. 32 | try: 33 | # Clean self._getters from canceled getters. 34 | self._getters.remove(getter) 35 | except ValueError: 36 | # The getter could be removed from self._getters by a 37 | # previous put_nowait call. 38 | pass 39 | if not self.empty() and not getter.cancelled(): 40 | # We were woken up by put_nowait(), but can't take 41 | # the call. Wake up the next in line. 
42 | self._wakeup_next(self._getters) 43 | raise 44 | return self.peek_nowait() 45 | 46 | def peek_nowait(self): 47 | if self.empty(): 48 | raise asyncio.QueueEmpty 49 | item = self._peek() 50 | self._wakeup_next(self._putters) 51 | return item 52 | 53 | def _peek(self): 54 | return self._queue[0] 55 | 56 | 57 | def _release_waiter(waiter): 58 | if not waiter.done(): 59 | waiter.set_result(False) 60 | 61 | 62 | class SimpleAsyncQueue: 63 | """ 64 | A simple async queue with built-in timeout. 65 | """ 66 | 67 | def __init__(self, capacity): 68 | self._capacity = capacity 69 | self._deque = collections.deque() 70 | self._not_empty_futures = collections.deque() 71 | self._loop = asyncio.get_running_loop() 72 | 73 | async def get(self, timeout=None): 74 | if not self._deque: 75 | not_empty_future = asyncio.get_running_loop().create_future() 76 | self._not_empty_futures.append(not_empty_future) 77 | if timeout is None: 78 | await not_empty_future 79 | else: 80 | self._loop.call_later(timeout, _release_waiter, not_empty_future) 81 | got_result = await not_empty_future 82 | if not got_result: 83 | raise TimeoutError(f"Queue get() timed out after {timeout} seconds") 84 | 85 | result = self._deque.popleft() 86 | return result 87 | 88 | async def put(self, item): 89 | while self._not_empty_futures: 90 | not_empty_future = self._not_empty_futures.popleft() 91 | if not not_empty_future.done(): 92 | not_empty_future.set_result(True) 93 | break 94 | 95 | return self._deque.append(item) 96 | 97 | def empty(self): 98 | return len(self._deque) == 0 99 | -------------------------------------------------------------------------------- /storey/sql_driver.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | 16 | from typing import List, Union 17 | 18 | import pandas as pd 19 | 20 | from storey.drivers import Driver 21 | 22 | 23 | class SQLDriver(Driver): 24 | """ 25 | SQL database connector. 26 | :param primary_key: the primary key of the table, format .... or [, ,...] 
27 | :param db_path: database url 28 | :param time_fields: list of all fields that are timestamps 29 | """ 30 | 31 | def __init__(self, primary_key: Union[str, List[str]], db_path: str, time_fields: List[str] = None): 32 | self._db_path = db_path 33 | self._sql_connection = None 34 | self._primary_key = primary_key if isinstance(primary_key, list) else self._extract_list_of_keys(primary_key) 35 | self._time_fields = time_fields 36 | 37 | def _lazy_init(self): 38 | import sqlalchemy as db 39 | 40 | if not self._sql_connection: 41 | self._engine = db.create_engine(self._db_path) 42 | self._sql_connection = self._engine.connect() 43 | 44 | def _table(self, table_path): 45 | import sqlalchemy as db 46 | 47 | metadata = db.MetaData() 48 | 49 | return db.Table( 50 | table_path.split("/")[-1], 51 | metadata, 52 | autoload=True, 53 | autoload_with=self._engine, 54 | ) 55 | 56 | async def _save_key(self, container, table_path, key, aggr_item, partitioned_by_key, additional_data): 57 | import sqlalchemy as db 58 | 59 | self._lazy_init() 60 | key = self._extract_list_of_keys(key) 61 | for i in range(len(self._primary_key)): 62 | additional_data[self._primary_key[i]] = key[i] 63 | table = self._table(table_path) 64 | df = pd.DataFrame(additional_data, index=[0]) 65 | try: 66 | df.to_sql(table.name, con=self._sql_connection, if_exists="append", index=False) 67 | except db.exc.IntegrityError: 68 | self._update_by_key(key, additional_data, table) 69 | 70 | async def _load_aggregates_by_key(self, container, table_path, key): 71 | self._lazy_init() 72 | table = self._table(table_path) 73 | 74 | values = await self._get_all_fields(key, table) 75 | if not values: 76 | values = None 77 | return [None, values] 78 | 79 | async def _load_by_key(self, container, table_path, key, attributes): 80 | self._lazy_init() 81 | table = self._table(table_path) 82 | if attributes == "*": 83 | values = await self._get_all_fields(key, table) 84 | else: 85 | values = await self._get_specific_fields(key, table, attributes) 86 | return values 87 | 88 | async def close(self): 89 | if self._sql_connection: 90 | self._sql_connection.close() 91 | self._sql_connection = None 92 | 93 | async def _get_all_fields(self, key, table): 94 | import sqlalchemy as db 95 | 96 | key = self._extract_list_of_keys(key) 97 | select_object = db.select(table).where( 98 | db.and_(getattr(table.c, self._primary_key[i]) == key[i] for i in range(len(self._primary_key))) 99 | ) 100 | results = pd.read_sql(select_object, con=self._sql_connection, parse_dates=self._time_fields).to_dict( 101 | orient="records" 102 | ) 103 | 104 | return results[0] 105 | 106 | async def _get_specific_fields(self, key: str, table, attributes: List[str]): 107 | import sqlalchemy as db 108 | 109 | key = self._extract_list_of_keys(key) 110 | try: 111 | select_object = db.select(*[getattr(table.c, atr) for atr in attributes]).where( 112 | db.and_(getattr(table.c, self._primary_key[i]) == key[i] for i in range(len(self._primary_key))) 113 | ) 114 | results = pd.read_sql(select_object, con=self._sql_connection, parse_dates=self._time_fields).to_dict( 115 | orient="records" 116 | ) 117 | except Exception as e: 118 | raise RuntimeError(f"Failed to get key '{key}'") from e 119 | 120 | return results[0] 121 | 122 | def supports_aggregations(self): 123 | return False 124 | 125 | def _update_by_key(self, key, data, sql_table): 126 | import sqlalchemy as db 127 | 128 | self._sql_connection.execute( 129 | db.update(sql_table) 130 | .values({getattr(sql_table.c, k): v for k, v in data.items() 
if k not in self._primary_key}) 131 | .where(db.and_(getattr(sql_table.c, self._primary_key[i]) == key[i] for i in range(len(self._primary_key)))) 132 | ) 133 | 134 | @staticmethod 135 | def _extract_list_of_keys(key): 136 | if isinstance(key, str): 137 | key = key.split(".") 138 | elif not isinstance(key, list): 139 | key = [key] 140 | return key 141 | -------------------------------------------------------------------------------- /storey/steps/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from .assertion import Assert # noqa: F401 16 | from .flatten import Flatten # noqa: F401 17 | from .foreach import ForEach # noqa: F401 18 | from .partition import Partition # noqa: F401 19 | from .sample import EmitPeriod, SampleWindow # noqa: F401 20 | -------------------------------------------------------------------------------- /storey/steps/assertion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
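The primary_key formats accepted by SQLDriver above are easiest to see through _extract_list_of_keys. A minimal sketch, assuming a made-up SQLite URL and column names (constructing the driver opens no connection, since initialization is lazy):

from storey.sql_driver import SQLDriver

# A dot-separated string is split into its component key columns.
driver = SQLDriver(primary_key="first_name.last_name", db_path="sqlite:///example.db")
assert driver._primary_key == ["first_name", "last_name"]

# A list of column names is taken as-is for a composite key.
driver = SQLDriver(primary_key=["first_name", "last_name"], db_path="sqlite:///example.db")
assert driver._primary_key == ["first_name", "last_name"]

The private attribute is inspected here only for illustration; in a flow the driver is normally handed to a table or target step rather than queried directly.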
14 | # 15 | from dataclasses import dataclass 16 | from typing import Any, Callable, Collection, List 17 | 18 | from storey.dtypes import _termination_obj 19 | from storey.flow import Flow 20 | 21 | 22 | @dataclass 23 | class _Operator: 24 | str: str 25 | fn: Callable[[Any, Any], bool] 26 | 27 | def __call__(self, x, y): 28 | return self.fn(x, y) 29 | 30 | def __str__(self): 31 | return self.str 32 | 33 | 34 | _EQUALS = _Operator("==", lambda x, y: x == y) 35 | _NOT_EQUAL = _Operator("!=", lambda x, y: x != y) 36 | _GREATER_THAN = _Operator(">", lambda x, y: x > y) 37 | _LESS_THEN = _Operator("<", lambda x, y: x < y) 38 | _GREATER_OR_EQUAL = _Operator(">=", lambda x, y: x >= y) 39 | _LESS_OR_EQUAL = _Operator("<=", lambda x, y: x <= y) 40 | 41 | _IS_INTERSECT = _Operator("any of", lambda col1, col2: any((c in col2 for c in col1))) 42 | _IS_SUBSET = _Operator("all of", lambda col1, col2: all((c in col2 for c in col1))) 43 | _IS_IDENTITY = _Operator( 44 | "exactly", 45 | lambda col1, col2: len(col1) == len(col2) and _IS_SUBSET(col1, col2) and _IS_SUBSET(col2, col1), 46 | ) 47 | _IS_DISJOINT = _Operator("none of", lambda col1, col2: not _IS_INTERSECT(col1, col2)) 48 | 49 | _NOTHING = _Operator("do nothing", lambda x, y: False) 50 | 51 | 52 | class _Assertable: 53 | def __call__(self, event: Any): 54 | raise NotImplementedError 55 | 56 | def check(self): 57 | raise NotImplementedError 58 | 59 | 60 | class _AssertEventCount(_Assertable): 61 | def __init__(self, expected: int = 0, operator: _Operator = _NOTHING): 62 | self.expected: int = expected 63 | self.operator: _Operator = operator 64 | self.actual: int = 0 65 | 66 | def __call__(self, event): 67 | self.actual += 1 68 | 69 | def check(self): 70 | op = self.operator(self.actual, self.expected) 71 | assert op, f"Expected event count {self.operator} {self.expected}, got {self.actual} instead" 72 | 73 | 74 | class _AssertCollection(_Assertable): 75 | def __init__( 76 | self, 77 | expected: Collection[Any], 78 | operator: _Operator = _NOTHING, 79 | ): 80 | self.expected = expected 81 | self.operator: _Operator = operator 82 | self.actual = [] 83 | 84 | def __call__(self, event): 85 | self.actual.append(event) 86 | 87 | def check(self): 88 | op = self.operator(self.expected, self.actual) 89 | assert op, f"Expected {self.operator} {self.actual} in {self.expected}" 90 | 91 | 92 | class _AssertPredicate(_Assertable): 93 | def __init__(self, predicate: Callable[[Any], bool]): 94 | self.predicate = predicate 95 | 96 | def __call__(self, event): 97 | predicate = self.predicate(event) 98 | assert predicate, f"Predicate results in False for Event {event}" 99 | 100 | def check(self): 101 | pass 102 | 103 | 104 | class Assert(Flow): 105 | """Exposes an API for testing the flow between steps.""" 106 | 107 | def __init__(self, **kwargs): 108 | super().__init__(**kwargs) 109 | self.termination_assertions: List[_Assertable] = [] 110 | self.execution_assertions: List[_Assertable] = [] 111 | 112 | def each_event(self, predicate: Callable[[Any], bool]): 113 | self.execution_assertions.append(_AssertPredicate(predicate)) 114 | return self 115 | 116 | def greater_or_equal_to(self, expected: int): 117 | self.termination_assertions.append(_AssertEventCount(expected, _GREATER_OR_EQUAL)) 118 | return self 119 | 120 | def greater_than(self, expected: int): 121 | self.termination_assertions.append(_AssertEventCount(expected, _GREATER_THAN)) 122 | return self 123 | 124 | def less_than(self, expected: int): 125 | 
self.termination_assertions.append(_AssertEventCount(expected, _LESS_THEN)) 126 | return self 127 | 128 | def less_or_equal_to(self, expected: int): 129 | self.termination_assertions.append(_AssertEventCount(expected, _LESS_OR_EQUAL)) 130 | return self 131 | 132 | def exactly(self, expected: int): 133 | self.termination_assertions.append(_AssertEventCount(expected, _EQUALS)) 134 | return self 135 | 136 | def match_exactly(self, expected: Collection[Any]): 137 | self.termination_assertions.append(_AssertCollection(expected, _IS_IDENTITY)) 138 | return self 139 | 140 | def contains_all_of(self, expected: Collection[Any]): 141 | self.termination_assertions.append(_AssertCollection(expected, _IS_SUBSET)) 142 | return self 143 | 144 | def contains_any_of(self, expected: Collection[Any]): 145 | self.termination_assertions.append(_AssertCollection(expected, _IS_INTERSECT)) 146 | return self 147 | 148 | def contains_none_of(self, expected: Collection[Any]): 149 | self.termination_assertions.append(_AssertCollection(expected, _IS_DISJOINT)) 150 | return self 151 | 152 | async def _do(self, event): 153 | if event is _termination_obj: 154 | for assertion in self.termination_assertions: 155 | assertion.check() 156 | return await self._do_downstream(_termination_obj) 157 | 158 | element = event if self._full_event else event.body 159 | 160 | for assertion in self.execution_assertions: 161 | assertion(element) 162 | 163 | for assertion in self.termination_assertions: 164 | assertion(element) 165 | 166 | await self._do_downstream(event) 167 | -------------------------------------------------------------------------------- /storey/steps/flatten.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from storey import FlatMap 16 | 17 | 18 | def Flatten(**kwargs): 19 | """Flatten is equivalent to FlatMap(lambda x: x).""" 20 | 21 | # Please note that Flatten forces full_event=False, since otherwise we can't iterate the body of the event 22 | if kwargs: 23 | kwargs["full_event"] = False 24 | return FlatMap(lambda x: x, **kwargs) 25 | -------------------------------------------------------------------------------- /storey/steps/foreach.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
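A minimal sketch of wiring Assert (together with Flatten) into a test flow. SyncEmitSource is assumed to be storey's synchronous source; Reduce and build_flow are the same helpers the test suite below imports from storey.flow, and the numeric events are made up:

from storey import SyncEmitSource
from storey.flow import Reduce, build_flow
from storey.steps import Assert, Flatten

controller = build_flow(
    [
        SyncEmitSource(),
        Flatten(),  # [1, 2] is flattened into two events: 1 and 2
        Assert().exactly(4).contains_all_of([1, 2, 3, 4]),
        Reduce([], lambda acc, x: acc + [x]),
    ]
).run()

controller.emit([1, 2])
controller.emit([3, 4])
controller.terminate()
print(controller.await_termination())  # expected: [1, 2, 3, 4]

If a termination assertion fails, the AssertionError is raised when the flow terminates.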
14 | # 15 | from storey.flow import _UnaryFunctionFlow 16 | 17 | 18 | class ForEach(_UnaryFunctionFlow): 19 | """Applies given function on each event in the stream, passes original event downstream.""" 20 | 21 | async def _do_internal(self, element, fn_result): 22 | self._user_fn_output_to_event(element, fn_result) 23 | await self._do_downstream(element) 24 | -------------------------------------------------------------------------------- /storey/steps/partition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from collections import namedtuple 16 | from typing import Any, Callable 17 | 18 | from storey import Flow 19 | from storey.dtypes import _termination_obj 20 | 21 | Partitioned = namedtuple("Partitioned", ["left", "right"], defaults=[None, None]) 22 | 23 | 24 | class Partition(Flow): 25 | """ 26 | Partitions events by calling a predicate function on each event. Each processed event results in a `Partitioned` 27 | namedtuple of (left=Optional[Event], right=Optional[Event]). 28 | 29 | For a given event, if the predicate function results in `True`, the event is assigned to `left`. Otherwise, the 30 | event is assigned to `right`. 31 | 32 | :param predicate: A predicate function that results in a boolean. 33 | """ 34 | 35 | def __init__(self, predicate: Callable[[Any], bool], **kwargs): 36 | super().__init__(**kwargs) 37 | self.predicate = predicate 38 | 39 | async def _do(self, event): 40 | if event is _termination_obj: 41 | return await self._do_downstream(_termination_obj) 42 | else: 43 | if self.predicate(event): 44 | event.body = Partitioned(left=event.body, right=None) 45 | else: 46 | event.body = Partitioned(left=None, right=event.body) 47 | await self._do_downstream(event) 48 | -------------------------------------------------------------------------------- /storey/steps/sample.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
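A small sketch of Partition in use. Note that the predicate receives the whole Event, so the payload is reached via event.body; SyncEmitSource is assumed as the source, and the even/odd split is a made-up example:

from storey import SyncEmitSource
from storey.flow import Reduce, build_flow
from storey.steps import Partition

controller = build_flow(
    [
        SyncEmitSource(),
        Partition(lambda event: event.body % 2 == 0),  # even bodies go left, odd bodies go right
        Reduce([], lambda acc, x: acc + [x]),
    ]
).run()

for i in range(4):
    controller.emit(i)
controller.terminate()
# expected: Partitioned(left=0, right=None), Partitioned(left=None, right=1), ...
print(controller.await_termination())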
14 | # 15 | from enum import Enum 16 | from typing import Callable, Optional, Union 17 | 18 | from storey import Flow 19 | from storey.dtypes import Event, _termination_obj 20 | 21 | 22 | class EmitPeriod(Enum): 23 | FIRST = 1 24 | LAST = 2 25 | 26 | 27 | class SampleWindow(Flow): 28 | """ 29 | Emits a single event in a window of `window_size` events, in accordance with `emit_period` and 30 | `emit_before_termination`. 31 | 32 | :param window_size: The size of the window we want to sample a single event from. 33 | :param emit_period: What event should this step emit for each `window_size` (default: EmitPeriod.First). 34 | Available options: 35 | 1.1) EmitPeriod.FIRST - will emit the first event in a window `window_size` events. 36 | 1.2) EmitPeriod.LAST - will emit the last event in a window of `window_size` events. 37 | :param emit_before_termination: On termination signal, should the step emit the last event it seen (default: False). 38 | Available options: 39 | 2.1) True - The last event seen will be emitted downstream. 40 | 2.2) False - The last event seen will NOT be emitted downstream. 41 | :param key: The key by which events are sampled. By default (None), events are not sampled by key. 42 | Other options may be: 43 | Set to '$key' to sample events by the Event.key property. 44 | set to 'str' key to sample events by Event.body[str]. 45 | set a Callable[[Event], str] to sample events by a custom key extractor. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | window_size: int, 51 | emit_period: EmitPeriod = EmitPeriod.FIRST, 52 | emit_before_termination: bool = False, 53 | key: Optional[Union[str, Callable[[Event], str]]] = None, 54 | **kwargs, 55 | ): 56 | kwargs["full_event"] = True 57 | super().__init__(**kwargs) 58 | 59 | if window_size <= 1: 60 | raise ValueError(f"Expected window_size > 1, found {window_size}") 61 | 62 | if not isinstance(emit_period, EmitPeriod): 63 | raise ValueError(f"Expected emit_period of type `EmitPeriod`, got {type(emit_period)}") 64 | 65 | self._window_size = window_size 66 | self._emit_period = emit_period 67 | self._emit_before_termination = emit_before_termination 68 | self._per_key_count = dict() 69 | self._count = 0 70 | self._last_event = None 71 | self._extract_key: Callable[[Event], str] = self._create_key_extractor(key) 72 | 73 | @staticmethod 74 | def _create_key_extractor(key) -> Callable: 75 | if key is None: 76 | return lambda event: None 77 | elif callable(key): 78 | return key 79 | elif isinstance(key, str): 80 | if key == "$key": 81 | return lambda event: event.key 82 | else: 83 | return lambda event: event.body[key] 84 | else: 85 | raise ValueError(f"Unsupported key type {type(key)}") 86 | 87 | async def _do(self, event): 88 | if event is _termination_obj: 89 | if self._last_event is not None: 90 | await self._do_downstream(self._last_event) 91 | return await self._do_downstream(_termination_obj) 92 | else: 93 | key = self._extract_key(event) 94 | 95 | if key is not None: 96 | if key not in self._per_key_count: 97 | self._per_key_count[key] = 1 98 | else: 99 | self._per_key_count[key] += 1 100 | count = self._per_key_count[key] 101 | else: 102 | self._count += 1 103 | count = self._count 104 | 105 | if self._emit_before_termination: 106 | self._last_event = event 107 | 108 | if count == self._window_size: 109 | if key is not None: 110 | self._per_key_count[key] = 0 111 | else: 112 | self._count = 0 113 | if self._should_emit(count): 114 | self._last_event = None 115 | await self._do_downstream(event) 116 | 117 | def _should_emit(self, 
count): 118 | if self._emit_period == EmitPeriod.FIRST and count == 1: 119 | return True 120 | elif self._emit_period == EmitPeriod.LAST and count == self._window_size: 121 | return True 122 | return False 123 | -------------------------------------------------------------------------------- /storey/transformations/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from ..aggregations import AggregateByKey # noqa: F401 16 | from ..aggregations import QueryByKey # noqa: F401 17 | from ..dataframe import ToDataFrame # noqa: F401 18 | from ..flow import Batch # noqa: F401 19 | from ..flow import Choice # noqa: F401 20 | from ..flow import Extend # noqa: F401 21 | from ..flow import Filter # noqa: F401 22 | from ..flow import FlatMap # noqa: F401 23 | from ..flow import JoinWithTable # noqa: F401 24 | from ..flow import Map # noqa: F401 25 | from ..flow import MapClass # noqa: F401 26 | from ..flow import MapWithState # noqa: F401 27 | from ..flow import ReifyMetadata # noqa: F401 28 | from ..flow import SendToHttp # noqa: F401 29 | from ..flow import _Batching # noqa: F401 30 | from ..flow import _ConcurrentJobExecution # noqa: F401 31 | from ..flow import _FunctionWithStateFlow # noqa: F401 32 | from ..flow import _UnaryFunctionFlow # noqa: F401 33 | from ..steps import Assert # noqa: F401 34 | from ..steps import Flatten # noqa: F401 35 | from ..steps import ForEach # noqa: F401 36 | from ..steps import Partition # noqa: F401 37 | from ..steps import SampleWindow # noqa: F401 38 | -------------------------------------------------------------------------------- /storey/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
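A hedged sketch showing how the pieces re-exported above typically fit together: a FieldAggregator built over SlidingWindows (both from storey.dtypes, shown earlier) is handed to AggregateByKey. Table, NoopDriver, SyncEmitSource and the "amount"/"joe" values are assumptions for illustration and are not taken from the files in this dump:

from storey import SyncEmitSource
from storey.drivers import NoopDriver
from storey.dtypes import FieldAggregator, SlidingWindows
from storey.flow import Reduce, build_flow
from storey.table import Table
from storey.transformations import AggregateByKey

# Aggregate the "amount" field into sum and average over 1h and 24h sliding
# windows, bucketed into 10-minute periods (10m divides both windows).
aggregator = FieldAggregator(
    "total",
    "amount",
    ["sum", "avg"],
    SlidingWindows(["1h", "24h"], period="10m"),
)

controller = build_flow(
    [
        SyncEmitSource(),
        AggregateByKey([aggregator], Table("/mytable", NoopDriver())),
        Reduce([], lambda acc, x: acc + [x]),
    ]
).run()

controller.emit({"amount": 10}, "joe")
controller.emit({"amount": 5}, "joe")
controller.terminate()
# Each body should now also carry enriched fields such as "total_sum_1h".
print(controller.await_termination())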
14 | # 15 | import base64 16 | import hashlib 17 | import os 18 | import struct 19 | from array import array 20 | from datetime import datetime 21 | from typing import Optional 22 | from urllib.parse import urlparse 23 | 24 | import fsspec 25 | import pytz 26 | 27 | bucketPerWindow = 2 28 | schema_file_name = ".schema" 29 | 30 | serialize_event_marker = "full_event_wrapper" 31 | event_fields_to_serialize = ["key", "id"] 32 | 33 | 34 | def parse_duration(string_time): 35 | unit = string_time[-1] 36 | 37 | if unit == "s": 38 | multiplier = 1000 39 | elif unit == "m": 40 | multiplier = 60 * 1000 41 | elif unit == "h": 42 | multiplier = 60 * 60 * 1000 43 | elif unit == "d": 44 | multiplier = 24 * 60 * 60 * 1000 45 | else: 46 | raise ValueError(f'Failed to parse time "{string_time}"') 47 | 48 | return int(string_time[:-1]) * multiplier 49 | 50 | 51 | def get_one_unit_of_duration(string_time): 52 | unit = string_time[-1] 53 | 54 | if unit == "s": 55 | multiplier = 1000 56 | elif unit == "m": 57 | multiplier = 60 * 1000 58 | elif unit == "h": 59 | multiplier = 60 * 60 * 1000 60 | elif unit == "d": 61 | multiplier = 24 * 60 * 60 * 1000 62 | else: 63 | raise ValueError(f'Failed to parse time "{string_time}"') 64 | 65 | return multiplier 66 | 67 | 68 | def convert_array_tlv(a): 69 | """Gets the array typed array to convert to a blob value of an array, encode it to base64 from base10 with the 70 | following format: 71 | struct vn_object_item_array_md { 72 | uint32_t magic_no; #define MAGIC_NO 11223344 73 | uint16_t version_no; #define ARRAY_VERSION 1 74 | uint32_t array_size_in_bytes; # 8 x element num (8x10 = 80) 75 | enum node_query_filter_operand_type type; # int=11 (260), double=12 (261) 76 | }; 77 | :param a: array type (e.g - array('i', [1, 2, 3, 4, 5]) 78 | :return: blob value of an array 79 | """ 80 | array_type = 259 if a.typecode == "l" else 261 81 | size = len(a) 82 | if a.typecode == "l": 83 | values = struct.pack("l" * size, *a) 84 | else: 85 | values = struct.pack("d" * size, *a) 86 | structure = struct.pack("IhII", 11223344, 1, size * 8, array_type) 87 | converted_blob = base64.b64encode(structure + values) 88 | return converted_blob 89 | 90 | 91 | def extract_array_tlv(b): 92 | """ 93 | get's the blob value of an array, decode it from base64 to base10 and extract the type, length and value based 94 | on the structure - 95 | struct vn_object_item_array_md { 96 | uint32_t magic_no; #define MAGIC_NO 11223344 97 | uint16_t version_no; #define ARRAY_VERSION 1 98 | uint32_t array_size_in_bytes; # 8 x element num (8x10 = 80) 99 | enum node_query_filter_operand_type type; # int=11 (260), double=12 (261) 100 | }; 101 | :param b: blob value 102 | :return: array type array 103 | """ 104 | converted_blob = base64.b64decode(b) 105 | tl = converted_blob[:16] 106 | v = converted_blob[16:] 107 | structure = struct.unpack("IhII", tl) # I=unsigned_int, h=short 108 | size = int(structure[2] / 8) 109 | array_type = "l" if structure[3] == 259 else "d" 110 | if array_type == "l": 111 | values = [v for v in struct.unpack("{}".format("l" * size), v)] 112 | else: 113 | values = [v for v in struct.unpack("{}".format("d" * size), v)] 114 | return array(array_type[0], values) 115 | 116 | 117 | def _split_path(path): 118 | while path.startswith("/"): 119 | path = path[1:] 120 | 121 | parts = path.split("/", 1) 122 | if ":///" in path: 123 | parts = path.split(":///", 1) 124 | if len(parts) == 1: 125 | return parts[0], "/" 126 | else: 127 | return parts[0], f"/{parts[1]}" 128 | 129 | 130 | def 
get_remaining_path(url): 131 | remaining_path = url 132 | scheme = "" 133 | if "://" in url: 134 | parsed_url = urlparse(url) 135 | scheme = parsed_url.scheme.lower() 136 | if scheme in ("ds", "v3io"): 137 | remaining_path = parsed_url.path 138 | elif scheme in ["wasb", "wasbs"]: 139 | remaining_path = f"{parsed_url.username}{parsed_url.path}" 140 | else: 141 | remaining_path = f"{parsed_url.netloc}{parsed_url.path}" 142 | return scheme, remaining_path 143 | 144 | 145 | def url_to_file_system(url, storage_options): 146 | scheme, remaining_path = get_remaining_path(url) 147 | if url.startswith("ds://"): 148 | parsed_url = urlparse(url) 149 | if parsed_url.password: 150 | scheme = parsed_url.password 151 | else: 152 | raise ValueError("Datastore profile URL is expected to have underlying scheme embedded as password") 153 | 154 | if storage_options: 155 | return fsspec.filesystem(scheme, **storage_options), remaining_path 156 | else: 157 | return fsspec.filesystem(scheme), remaining_path 158 | 159 | 160 | class StoreyMissingDependencyError(Exception): 161 | pass 162 | 163 | 164 | def get_in(obj, keys, default=None): 165 | """ 166 | >>> get_in({'a': {'b': 1}}, 'a.b') 167 | 1 168 | """ 169 | if isinstance(keys, str): 170 | keys = keys.split(".") 171 | 172 | for key in keys: 173 | if not obj or key not in obj: 174 | return default 175 | obj = obj[key] 176 | return obj 177 | 178 | 179 | def update_in(obj, key, value): 180 | parts = key.split(".") if isinstance(key, str) else key 181 | for part in parts[:-1]: 182 | sub = obj.get(part, None) 183 | if sub is None: 184 | sub = obj[part] = {} 185 | obj = sub 186 | 187 | last_key = parts[-1] 188 | obj[last_key] = value 189 | 190 | 191 | def hash_list(list_to_hash): 192 | list_to_hash = [str(element) for element in list_to_hash] 193 | str_concatted = "".join(list_to_hash) 194 | sha1 = hashlib.sha1() 195 | sha1.update(str_concatted.encode("utf8")) 196 | return sha1.hexdigest() 197 | 198 | 199 | def stringify_key(key_list): 200 | if isinstance(key_list, list): 201 | if len(key_list) >= 3: 202 | return str(key_list[0]) + "." + hash_list(key_list[1:]) 203 | if len(key_list) == 2: 204 | return str(key_list[0]) + "." 
+ str(key_list[1]) 205 | return str(key_list[0]) 206 | else: 207 | return str(key_list) 208 | 209 | 210 | def _create_filter_tuple(dtime, attr, sign, list_tuples): 211 | if dtime is not None and attr: 212 | value = getattr(dtime, attr, None) 213 | tuple1 = (attr, sign, value) 214 | list_tuples.append(tuple1) 215 | 216 | 217 | def _find_filter_helper( 218 | list_partitions, 219 | dtime, 220 | sign, 221 | first_sign, 222 | first_uncommon, 223 | filters, 224 | filter_column=None, 225 | ): 226 | single_filter = [] 227 | if dtime is None: 228 | return 229 | if len(list_partitions) == 0 or first_uncommon is None: 230 | return 231 | last_partition = list_partitions[-1] 232 | if len(list_partitions) == 1 or last_partition == first_uncommon: 233 | return 234 | list_partitions_without_last_element = list_partitions[:-1] 235 | for partition in list_partitions_without_last_element: 236 | _create_filter_tuple(dtime, partition, "=", single_filter) 237 | if first_sign: 238 | # only for the first iteration we need to have ">="/"<=" instead of ">"/"<" 239 | _create_filter_tuple(dtime, last_partition, first_sign, single_filter) 240 | # start needs to be > and end needs to be "<=" 241 | if first_sign == "<=": 242 | tuple_last_range = (filter_column, first_sign, dtime) 243 | else: 244 | tuple_last_range = (filter_column, sign, dtime) 245 | single_filter.append(tuple_last_range) 246 | else: 247 | _create_filter_tuple(dtime, last_partition, sign, single_filter) 248 | _find_filter_helper(list_partitions_without_last_element, dtime, sign, None, first_uncommon, filters) 249 | filters.append(single_filter) 250 | 251 | 252 | def _get_filters_for_filter_column(start, end, filter_column, side_range): 253 | if start is not None: 254 | lower_limit_tuple = (filter_column, ">", start) 255 | side_range.append(lower_limit_tuple) 256 | if end is not None: 257 | upper_limit_tuple = (filter_column, "<=", end) 258 | side_range.append(upper_limit_tuple) 259 | 260 | 261 | def find_partitions(url, fs): 262 | # ML-1365. assuming the partitioning is symmetrical (for example both year=2020 and year=2021 directories will have 263 | # inner month partitions). 
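    # As an illustration (made-up layout): for files stored under
    # my_dir/year=2020/month=01/day=05/..., walking the first directory at each
    # level collects ["year", "month", "day"]; only names that are also legal
    # time units (year, month, day, hour, minute, second) are returned.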
264 | 265 | partitions = [] 266 | 267 | def _is_private(path): 268 | _, tail = os.path.split(path) 269 | return (tail.startswith("_") or tail.startswith(".")) and "=" not in tail 270 | 271 | def find_partition_helper(url, fs, partitions): 272 | content = fs.ls(url, detail=True) 273 | if len(content) == 0: 274 | return partitions 275 | # https://issues.apache.org/jira/browse/ARROW-1079 there could be some private dirs 276 | filtered_dirs = [x for x in content if not _is_private(x["name"])] 277 | if len(filtered_dirs) == 0: 278 | return partitions 279 | 280 | inner_dir = filtered_dirs[0]["name"] 281 | if fs.isfile(inner_dir): 282 | return partitions 283 | part = inner_dir.split("/")[-1].split("=") 284 | partitions.append(part[0]) 285 | find_partition_helper(inner_dir, fs, partitions) 286 | 287 | if fs.isfile(url): 288 | return partitions 289 | find_partition_helper(url, fs, partitions) 290 | 291 | legal_time_units = ["year", "month", "day", "hour", "minute", "second"] 292 | 293 | partitions_time_attributes = [j for j in legal_time_units if j in partitions] 294 | 295 | return partitions_time_attributes 296 | 297 | 298 | def find_filters(partitions_time_attributes, start, end, filters, filter_column): 299 | # this method build filters to be used by 300 | # https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html 301 | common_partitions = [] 302 | first_uncommon = None 303 | # finding the common attributes. for example for start=1.2.2018 08:53:15, end=5.2.2018 16:24:31, partitioned by 304 | # year, month, day, hour. common_partions=[year, month], first_uncommon=day 305 | for part in partitions_time_attributes: 306 | value_start = getattr(start, part, None) 307 | value_end = getattr(end, part, None) 308 | if value_end == value_start: 309 | common_partitions.append(part) 310 | else: 311 | first_uncommon = part 312 | break 313 | 314 | # for start=1.2.2018 08:53:15, end=5.2.2018 16:24:31, this method will append to filters 315 | # [(year=2018, month=2,day>=1, filter_column>1.2.2018 08:53:15)] 316 | _find_filter_helper( 317 | partitions_time_attributes, 318 | start, 319 | ">", 320 | ">=", 321 | first_uncommon, 322 | filters, 323 | filter_column, 324 | ) 325 | 326 | middle_range_filter = [] 327 | for partition in common_partitions: 328 | _create_filter_tuple(start, partition, "=", middle_range_filter) 329 | 330 | if len(filters) == 0: 331 | # creating only the middle range 332 | _create_filter_tuple(start, first_uncommon, ">=", middle_range_filter) 333 | _create_filter_tuple(end, first_uncommon, "<=", middle_range_filter) 334 | _get_filters_for_filter_column(start, end, filter_column, middle_range_filter) 335 | else: 336 | _create_filter_tuple(start, first_uncommon, ">", middle_range_filter) 337 | _create_filter_tuple(end, first_uncommon, "<", middle_range_filter) 338 | # for start=1.2.2018 08:53:15, end=5.2.2018 16:24:31, this will append to filters 339 | # [(year=2018, month=2, 1 self.max_time: 120 | self.max_time = t 121 | 122 | def __repr__(self): 123 | return str(self) 124 | 125 | def __str__(self): 126 | return f"{self.data} - {self.max_time}" 127 | 128 | 129 | # a class that accepts - window, (data, key, timestamp) 130 | class WindowedStoreElement: 131 | def __init__(self, key, window, late_data_handling): 132 | self.key = key 133 | self.late_data_handling = late_data_handling 134 | self.window = window 135 | self.features = {} 136 | self.first_bucket_start_time = self.window.get_window_start_time() 137 | self.last_bucket_start_time = ( 138 | 
self.first_bucket_start_time + (window.get_total_number_of_buckets() - 1) * window.period_millis 139 | ) 140 | 141 | def add(self, data, timestamp): 142 | # add a new point and aggregate 143 | for column_name in data: 144 | if column_name not in self.features: 145 | self.initialize_column(column_name) 146 | index = self.get_or_advance_bucket_index_by_timestamp(timestamp) 147 | self.features[column_name][index].add(timestamp, data[column_name]) 148 | 149 | def get_column_name(self, column, aggregation): 150 | return f"{column}_{aggregation}_{self.window.window_str}" 151 | 152 | def initialize_column(self, column): 153 | self.features[column] = [] 154 | for _ in range(self.window.get_total_number_of_buckets()): 155 | self.features[column].append(WindowBucket(self.late_data_handling)) 156 | 157 | def get_or_advance_bucket_index_by_timestamp(self, timestamp): 158 | if timestamp < self.last_bucket_start_time + self.window.period_millis: 159 | bucket_index = int((timestamp - self.first_bucket_start_time) / self.window.period_millis) 160 | return bucket_index 161 | else: 162 | self.advance_window_period(timestamp) 163 | return self.window.get_total_number_of_buckets() - 1 # return last index 164 | 165 | def advance_window_period(self, advance_to=None): 166 | if not advance_to: 167 | advance_to = datetime.now().timestamp() * 1000 168 | desired_bucket_index = int((advance_to - self.first_bucket_start_time) / self.window.period_millis) 169 | buckets_to_advnace = desired_bucket_index - (self.window.get_total_number_of_buckets() - 1) 170 | 171 | if buckets_to_advnace > 0: 172 | if buckets_to_advnace > self.window.get_total_number_of_buckets(): 173 | for column in self.features: 174 | self.initialize_column(column) 175 | else: 176 | for column in self.features: 177 | self.features[column] = self.features[column][buckets_to_advnace:] 178 | for _ in range(buckets_to_advnace): 179 | self.features[column].extend([WindowBucket(self.late_data_handling)]) 180 | 181 | self.first_bucket_start_time = self.first_bucket_start_time + buckets_to_advnace * self.window.period_millis 182 | self.last_bucket_start_time = self.last_bucket_start_time + buckets_to_advnace * self.window.period_millis 183 | 184 | def flush(self): 185 | for column in self.features: 186 | self.initialize_column(column) 187 | 188 | 189 | def aggregate(self, aggregation, old_value, new_value): 190 | if aggregation == "min": 191 | return min(old_value, new_value) 192 | elif aggregation == "max": 193 | return max(old_value, new_value) 194 | elif aggregation == "sum": 195 | return old_value + new_value 196 | elif aggregation == "count": 197 | return old_value + 1 198 | elif aggregation == "last": 199 | return new_value 200 | elif aggregation == "first": 201 | return old_value 202 | 203 | 204 | class WindowedStore: 205 | def __init__(self, window, late_data_handling): 206 | self.window = window 207 | self.late_data_handling = late_data_handling 208 | self.cache = {} 209 | 210 | def __iter__(self): 211 | return iter(self.cache.items()) 212 | 213 | def add(self, key, data, timestamp): 214 | if key not in self.cache: 215 | self.cache[key] = WindowedStoreElement(key, self.window, self.late_data_handling) 216 | 217 | if isinstance(timestamp, datetime): 218 | timestamp = timestamp.timestamp() * 1000 219 | self.cache[key].add(data, timestamp) 220 | 221 | def flush(self): 222 | for key in self.cache: 223 | self.cache[key].flush() 224 | -------------------------------------------------------------------------------- /tests.coveragerc: 
-------------------------------------------------------------------------------- 1 | [run] 2 | 3 | source=./storey 4 | 5 | omit = 6 | */tests/* 7 | */integration/* 8 | test_* 9 | 10 | relative_files = True 11 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | -------------------------------------------------------------------------------- /tests/test-multiple-time-columns.csv: -------------------------------------------------------------------------------- 1 | k,t1,s,t2 2 | m1,2020-06-27T10:23:08.420581,katya,2020-06-27T12:23:08.420581 3 | m2,2021-06-27T10:23:08.420581,dina,2021-06-27T10:21:08.420581 4 | -------------------------------------------------------------------------------- /tests/test-none-in-keyfield.csv: -------------------------------------------------------------------------------- 1 | k,t,v,b 2 | m1,15/02/2020 02:00:00,8,true 3 | ,16/02/2020 02:00:00,14,False 4 | -------------------------------------------------------------------------------- /tests/test-with-compact-timestamp.csv: -------------------------------------------------------------------------------- 1 | k,t,v,b 2 | m1,2020021502,8,true 3 | m2,2020021602,14,False 4 | -------------------------------------------------------------------------------- /tests/test-with-none-values.csv: -------------------------------------------------------------------------------- 1 | string,bool,bool_with_none,int_with_nan,float_with_nan,date_with_none 2 | a,True,False,1,2.3,2021-04-21 15:56:53.385444 3 | b,True,,,, 4 | -------------------------------------------------------------------------------- /tests/test-with-timestamp-microsecs.csv: -------------------------------------------------------------------------------- 1 | k,t 2 | m1,15/02/2020 02:03:04.123456 3 | m2,16/02/2020 02:03:04.123456 4 | -------------------------------------------------------------------------------- /tests/test-with-timestamp-nanosecs.csv: -------------------------------------------------------------------------------- 1 | k,t 2 | m1,15/02/2020 02:03:04.123456789 3 | m2,16/02/2020 02:03:04.123456789 4 | -------------------------------------------------------------------------------- /tests/test-with-timestamp.csv: -------------------------------------------------------------------------------- 1 | k,t,v,b 2 | m1,15/02/2020 02:00:00,8,true 3 | m2,16/02/2020 02:00:00,14,False 4 | -------------------------------------------------------------------------------- /tests/test.csv: -------------------------------------------------------------------------------- 1 | n1,n2,n3 2 | 1,2,3 3 | 4,5,6 4 | -------------------------------------------------------------------------------- /tests/test.parquet: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mlrun/storey/44553a4e5acb4144c35a33060da156f7e84c2d37/tests/test.parquet -------------------------------------------------------------------------------- /tests/test_aggregate_store.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from datetime import datetime 16 | 17 | from storey.dtypes import SlidingWindows 18 | from storey.table import ReadOnlyAggregationBuckets 19 | 20 | 21 | def _assert_buckets(window, base_time, initial_data, expected_data): 22 | aggr_buckets = ReadOnlyAggregationBuckets("test", "count", window, None, base_time, None, initial_data=initial_data) 23 | 24 | actual = [aggr_value.value for aggr_value in aggr_buckets.buckets] 25 | assert actual == expected_data 26 | 27 | 28 | def test_load_aggregation_bucket(): 29 | test_base_time = int(datetime.fromisoformat("2020-07-21T21:40:00+00:00").timestamp() * 1000) 30 | window = SlidingWindows(["1h"], "10m") 31 | curr_bucket_time = int(test_base_time / window.max_window_millis) * window.max_window_millis 32 | initial_data = { 33 | curr_bucket_time - window.max_window_millis: [1, 0, 0, 0, 1, 2], 34 | curr_bucket_time: [1, 0, 2, 1, 1, 0], 35 | } 36 | expected_data = [2, 1, 0, 2, 1, 1] 37 | _assert_buckets(window, test_base_time, initial_data, expected_data) 38 | 39 | 40 | def test_load_aggregation_bucket_data_two_stored_buckets_requested_data_newer_than_both(): 41 | test_base_time = int(datetime.fromisoformat("2020-07-21T21:40:00+00:00").timestamp() * 1000) 42 | window = SlidingWindows(["1h"], "10m") 43 | curr_bucket_time = int(test_base_time / window.max_window_millis) * window.max_window_millis 44 | initial_data = { 45 | curr_bucket_time - 3 * window.max_window_millis: [1, 0, 0, 0, 1, 2], 46 | curr_bucket_time - 2 * window.max_window_millis: [1, 0, 2, 1, 1, 1], 47 | } 48 | expected_data = [0, 0, 0, 0, 0, 0] 49 | _assert_buckets(window, test_base_time, initial_data, expected_data) 50 | 51 | 52 | def test_load_aggregation_bucket_data_two_stored_buckets_requested_data_newer(): 53 | test_base_time = int(datetime.fromisoformat("2020-07-21T21:25:00+00:00").timestamp() * 1000) 54 | window = SlidingWindows(["1h"], "10m") 55 | curr_bucket_time = int(test_base_time / window.max_window_millis) * window.max_window_millis 56 | initial_data = { 57 | curr_bucket_time - 2 * window.max_window_millis: [1, 0, 0, 0, 1, 2], 58 | curr_bucket_time - window.max_window_millis: [1, 0, 2, 2, 1, 1], 59 | } 60 | expected_data = [2, 1, 1, 0, 0, 0] 61 | _assert_buckets(window, test_base_time, initial_data, expected_data) 62 | 63 | 64 | def test_load_aggregation_bucket_data_two_stored_buckets_requested_data_older(): 65 | test_base_time = int(datetime.fromisoformat("2020-07-21T21:40:00+00:00").timestamp() * 1000) 66 | window = SlidingWindows(["1h"], "10m") 67 | curr_bucket_time = int(test_base_time / window.max_window_millis) * window.max_window_millis 68 | initial_data 
= { 69 | curr_bucket_time + window.max_window_millis: [1, 0, 0, 0, 1, 2], 70 | curr_bucket_time + 2 * window.max_window_millis: [1, 0, 2, 1, 1, 0], 71 | } 72 | expected_data = [0, 0, 0, 0, 0, 0] 73 | _assert_buckets(window, test_base_time, initial_data, expected_data) 74 | 75 | 76 | def test_load_aggregation_bucket_one_stored_bucket(): 77 | test_base_time = int(datetime.fromisoformat("2020-07-21T21:40:00+00:00").timestamp() * 1000) 78 | window = SlidingWindows(["1h"], "10m") 79 | curr_bucket_time = int(test_base_time / window.max_window_millis) * window.max_window_millis 80 | initial_data = {curr_bucket_time: [1, 0, 2, 1, 1, 0]} 81 | expected_data = [0, 1, 0, 2, 1, 1] 82 | _assert_buckets(window, test_base_time, initial_data, expected_data) 83 | 84 | 85 | def test_load_aggregation_bucket_one_stored_bucket_requested_data_newer(): 86 | test_base_time = int(datetime.fromisoformat("2020-07-21T21:25:00+00:00").timestamp() * 1000) 87 | window = SlidingWindows(["1h"], "10m") 88 | curr_bucket_time = int(test_base_time / window.max_window_millis) * window.max_window_millis 89 | initial_data = {curr_bucket_time - window.max_window_millis: [1, 0, 2, 1, 1, 3]} 90 | expected_data = [1, 1, 3, 0, 0, 0] 91 | _assert_buckets(window, test_base_time, initial_data, expected_data) 92 | 93 | 94 | def test_load_aggregation_bucket_one_stored_bucket_requested_data_older(): 95 | test_base_time = int(datetime.fromisoformat("2020-07-21T21:40:00+00:00").timestamp() * 1000) 96 | window = SlidingWindows(["1h"], "10m") 97 | curr_bucket_time = int(test_base_time / window.max_window_millis) * window.max_window_millis 98 | initial_data = {curr_bucket_time + window.max_window_millis: [1, 0, 2, 1, 1, 0]} 99 | expected_data = [0, 0, 0, 0, 0, 0] 100 | _assert_buckets(window, test_base_time, initial_data, expected_data) 101 | -------------------------------------------------------------------------------- /tests/test_concurrent_execution.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import time 3 | 4 | import pytest 5 | 6 | from storey import AsyncEmitSource 7 | from storey.flow import Complete, ConcurrentExecution, Reduce, build_flow 8 | from tests.test_flow import append_and_return 9 | 10 | event_processing_duration = 0.5 11 | 12 | 13 | class SomeContext: 14 | def __init__(self): 15 | self.fn = lambda x: x 16 | 17 | 18 | async def process_event_slow_asyncio(event, context): 19 | assert isinstance(context, SomeContext) and callable(context.fn) 20 | await asyncio.sleep(event_processing_duration) 21 | return event 22 | 23 | 24 | def process_event_slow_io(event, context): 25 | assert isinstance(context, SomeContext) and callable(context.fn) 26 | time.sleep(event_processing_duration) 27 | return event 28 | 29 | 30 | def process_event_slow_processing(event): 31 | start = time.monotonic() 32 | while time.monotonic() - start < event_processing_duration: 33 | pass 34 | return event 35 | 36 | 37 | async def async_test_concurrent_execution(concurrency_mechanism, event_processor, pass_context): 38 | controller = build_flow( 39 | [ 40 | AsyncEmitSource(), 41 | ConcurrentExecution( 42 | event_processor=event_processor, 43 | concurrency_mechanism=concurrency_mechanism, 44 | pass_context=pass_context, 45 | max_in_flight=10, 46 | context=SomeContext(), 47 | ), 48 | Reduce([], append_and_return), 49 | ] 50 | ).run() 51 | 52 | num_events = 8 53 | 54 | start = time.monotonic() 55 | for counter in range(num_events): 56 | await controller.emit(counter) 57 | 58 | await 
controller.terminate() 59 | result = await controller.await_termination() 60 | end = time.monotonic() 61 | 62 | assert result == list(range(num_events)) 63 | assert end - start > event_processing_duration, "Run time cannot be less than the time to process a single event" 64 | assert ( 65 | end - start < event_processing_duration * num_events 66 | ), "Run time must be less than the time to process all events in serial" 67 | 68 | 69 | @pytest.mark.parametrize( 70 | ["concurrency_mechanism", "event_processor", "pass_context"], 71 | [ 72 | ("asyncio", process_event_slow_asyncio, True), 73 | ("threading", process_event_slow_io, True), 74 | ("multiprocessing", process_event_slow_processing, False), 75 | ], 76 | ) 77 | def test_concurrent_execution(concurrency_mechanism, event_processor, pass_context): 78 | asyncio.run(async_test_concurrent_execution(concurrency_mechanism, event_processor, pass_context)) 79 | 80 | 81 | async def async_test_concurrent_execution_multiprocessing_and_complete(): 82 | controller = build_flow( 83 | [ 84 | AsyncEmitSource(), 85 | ConcurrentExecution( 86 | event_processor=process_event_slow_processing, 87 | concurrency_mechanism="multiprocessing", 88 | max_in_flight=2, 89 | ), 90 | Complete(), 91 | ] 92 | ).run() 93 | 94 | event_body = "hello" 95 | try: 96 | res = await controller.emit(event_body) 97 | assert res == event_body 98 | finally: 99 | await controller.terminate() 100 | await controller.await_termination() 101 | 102 | 103 | def test_concurrent_execution_multiprocessing_and_complete(): 104 | asyncio.run(async_test_concurrent_execution_multiprocessing_and_complete()) 105 | 106 | 107 | def test_concurrent_execution_multiprocessing_and_full_event(): 108 | with pytest.raises( 109 | ValueError, 110 | match='concurrency_mechanism="multiprocessing" may not be used in conjunction with full_event=True', 111 | ): 112 | ConcurrentExecution( 113 | event_processor=process_event_slow_processing, 114 | concurrency_mechanism="multiprocessing", 115 | full_event=True, 116 | ) 117 | -------------------------------------------------------------------------------- /tests/test_queue.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | 5 | from storey.queue import SimpleAsyncQueue 6 | 7 | 8 | async def async_test_simple_async_queue(): 9 | q = SimpleAsyncQueue(2) 10 | 11 | with pytest.raises(TimeoutError): 12 | await q.get(0) 13 | 14 | get_task = asyncio.create_task(q.get(1)) 15 | await q.put("x") 16 | assert await get_task == "x" 17 | 18 | await q.put("x") 19 | await q.put("y") 20 | put_task = asyncio.create_task(q.put("z")) 21 | assert await q.get() == "x" 22 | await put_task 23 | assert await q.get() == "y" 24 | assert await q.get() == "z" 25 | 26 | 27 | def test_simple_async_queue(): 28 | asyncio.run(async_test_simple_async_queue()) 29 | -------------------------------------------------------------------------------- /tests/test_space_in_header.csv: -------------------------------------------------------------------------------- 1 | header with space,n2,n3 2 | 1,2,3 3 | 4,5,6 4 | -------------------------------------------------------------------------------- /tests/test_space_in_header.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlrun/storey/44553a4e5acb4144c35a33060da156f7e84c2d37/tests/test_space_in_header.parquet -------------------------------------------------------------------------------- /tests/test_steps.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from pytest import fail 16 | 17 | from storey import SyncEmitSource, build_flow 18 | from storey.dtypes import Event 19 | from storey.steps import Assert, EmitPeriod, Flatten, ForEach, Partition, SampleWindow 20 | 21 | 22 | def test_assert_each_event(): 23 | try: 24 | controller = build_flow([SyncEmitSource(), Assert().each_event(lambda event: event > 10)]).run() 25 | controller.emit(1) 26 | controller.terminate() 27 | controller.await_termination() 28 | fail("Assert not failing", False) 29 | except AssertionError: 30 | pass 31 | 32 | 33 | def test_assert_greater_or_equal_to(): 34 | try: 35 | controller = build_flow([SyncEmitSource(), Assert().greater_or_equal_to(2)]).run() 36 | controller.emit(1) 37 | controller.terminate() 38 | controller.await_termination() 39 | fail("Assert not failing", False) 40 | except AssertionError: 41 | pass 42 | 43 | try: 44 | controller = build_flow([SyncEmitSource(), Assert().greater_or_equal_to(2)]).run() 45 | controller.emit(1) 46 | controller.emit(1) 47 | controller.terminate() 48 | controller.await_termination() 49 | except AssertionError: 50 | fail("Assert failed unexpectedly", False) 51 | 52 | 53 | def test_assert_greater_than(): 54 | try: 55 | controller = build_flow([SyncEmitSource(), Assert().greater_than(1)]).run() 56 | controller.emit(1) 57 | controller.terminate() 58 | controller.await_termination() 59 | fail("Assert not failing", False) 60 | except AssertionError: 61 | pass 62 | 63 | try: 64 | controller = build_flow([SyncEmitSource(), Assert().greater_than(1)]).run() 65 | controller.emit(1) 66 | controller.emit(1) 67 | controller.terminate() 68 | controller.await_termination() 69 | except AssertionError: 70 | fail("Assert failed unexpectedly", False) 71 | 72 | 73 | def test_assert_less_or_equal(): 74 | try: 75 | controller = build_flow([SyncEmitSource(), Assert().less_or_equal_to(2)]).run() 76 | controller.emit(1) 77 | controller.emit(2) 78 | controller.emit(3) 79 | controller.terminate() 80 | controller.await_termination() 81 | fail("Assert not failing", False) 82 | except AssertionError: 83 | pass 84 | 85 | try: 86 | controller = build_flow([SyncEmitSource(), Assert().less_or_equal_to(2)]).run() 87 | controller.emit(1) 88 | controller.emit(1) 89 | controller.terminate() 90 | controller.await_termination() 91 | except AssertionError: 92 | fail("Assert failed unexpectedly", False) 93 | 94 | 95 | def test_assert_exactly(): 96 | try: 97 | controller = build_flow([SyncEmitSource(), Assert().exactly(2)]).run() 98 | controller.emit(1) 99 | controller.terminate() 100 | controller.await_termination() 101 | fail("Assert not failing", False) 102 | except AssertionError: 103 | pass 104 | 105 | try: 106 | controller = build_flow([SyncEmitSource(), Assert().exactly(2)]).run() 107 | controller.emit(1) 108 | controller.emit(1) 109 | controller.emit(1) 110 | 
controller.terminate() 111 | controller.await_termination() 112 | fail("Assert not failing", False) 113 | except AssertionError: 114 | pass 115 | 116 | try: 117 | controller = build_flow([SyncEmitSource(), Assert().exactly(2)]).run() 118 | controller.emit(1) 119 | controller.emit(1) 120 | controller.terminate() 121 | controller.await_termination() 122 | except AssertionError: 123 | fail("Assert failed unexpectedly", False) 124 | 125 | 126 | def test_assert_match_exactly(): 127 | try: 128 | controller = build_flow([SyncEmitSource(), Assert(full_event=False).match_exactly([1, 1, 1])]).run() 129 | controller.emit(1) 130 | controller.emit(1) 131 | controller.terminate() 132 | controller.await_termination() 133 | fail("Assert not failing", False) 134 | except AssertionError: 135 | pass 136 | 137 | try: 138 | controller = build_flow([SyncEmitSource(), Assert().match_exactly([1, 1, 1])]).run() 139 | controller.emit(1) 140 | controller.emit(1) 141 | controller.emit(1) 142 | controller.terminate() 143 | controller.await_termination() 144 | except AssertionError: 145 | fail("Assert failed unexpectedly", False) 146 | 147 | 148 | def test_assert_all_of(): 149 | try: 150 | controller = build_flow( 151 | [ 152 | SyncEmitSource(), 153 | Assert().contains_all_of([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 154 | ] 155 | ).run() 156 | controller.emit([1, 2, 3]) 157 | controller.emit([4, 5, 6]) 158 | controller.terminate() 159 | controller.await_termination() 160 | fail("Assert not failing", False) 161 | except AssertionError: 162 | pass 163 | 164 | try: 165 | controller = build_flow( 166 | [ 167 | SyncEmitSource(), 168 | Assert().contains_all_of([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 169 | ] 170 | ).run() 171 | controller.emit([1, 2, 3]) 172 | controller.emit([4, 5, 6]) 173 | controller.emit([7, 8, 9]) 174 | controller.terminate() 175 | controller.await_termination() 176 | except AssertionError: 177 | fail("Assert failed unexpectedly", False) 178 | 179 | 180 | def test_assert_any_of(): 181 | try: 182 | controller = build_flow( 183 | [ 184 | SyncEmitSource(), 185 | Assert().contains_any_of([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 186 | ] 187 | ).run() 188 | controller.emit([10, 11, 12]) 189 | controller.terminate() 190 | controller.await_termination() 191 | fail("Assert not failing", False) 192 | except AssertionError: 193 | pass 194 | 195 | try: 196 | controller = build_flow( 197 | [ 198 | SyncEmitSource(), 199 | Assert().contains_any_of([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 200 | ] 201 | ).run() 202 | controller.emit([1, 2, 3]) 203 | controller.terminate() 204 | controller.await_termination() 205 | except AssertionError: 206 | fail("Assert failed unexpectedly", False) 207 | 208 | 209 | def test_assert_none_of(): 210 | try: 211 | controller = build_flow( 212 | [ 213 | SyncEmitSource(), 214 | Assert().contains_none_of([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 215 | ] 216 | ).run() 217 | controller.emit([10, 11, 12]) 218 | controller.terminate() 219 | controller.await_termination() 220 | except AssertionError: 221 | fail("Assert failed unexpectedly", False) 222 | 223 | try: 224 | controller = build_flow( 225 | [ 226 | SyncEmitSource(), 227 | Assert().contains_none_of([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), 228 | ] 229 | ).run() 230 | controller.emit([1, 2, 3]) 231 | controller.terminate() 232 | controller.await_termination() 233 | fail("Assert not failing", False) 234 | except AssertionError: 235 | pass 236 | 237 | 238 | def test_sample_emit_first(): 239 | controller = build_flow( 240 | [ 241 | SyncEmitSource(), 242 | 
Assert().exactly(5), 243 | SampleWindow(5), 244 | Assert().exactly(1).match_exactly([0]), 245 | ] 246 | ).run() 247 | 248 | for i in range(0, 5): 249 | controller.emit(i) 250 | controller.terminate() 251 | controller.await_termination() 252 | 253 | 254 | def test_sample_emit_first_with_emit_before_termination(): 255 | controller = build_flow( 256 | [ 257 | SyncEmitSource(), 258 | Assert().exactly(5), 259 | SampleWindow(5, emit_before_termination=True), 260 | Assert().exactly(2).match_exactly([0, 4]), 261 | ] 262 | ).run() 263 | 264 | for i in range(0, 5): 265 | controller.emit(i) 266 | controller.terminate() 267 | controller.await_termination() 268 | 269 | 270 | def test_sample_emit_last(): 271 | controller = build_flow( 272 | [ 273 | SyncEmitSource(), 274 | Assert().exactly(5), 275 | SampleWindow(5, emit_period=EmitPeriod.LAST), 276 | Assert().exactly(1).match_exactly([4]), 277 | ] 278 | ).run() 279 | 280 | for i in range(0, 5): 281 | controller.emit(i) 282 | controller.terminate() 283 | controller.await_termination() 284 | 285 | 286 | def test_sample_emit_last_with_emit_before_termination(): 287 | controller = build_flow( 288 | [ 289 | SyncEmitSource(), 290 | Assert().exactly(5), 291 | SampleWindow(5, emit_period=EmitPeriod.LAST, emit_before_termination=True), 292 | Assert().exactly(1).match_exactly([4]), 293 | ] 294 | ).run() 295 | 296 | for i in range(0, 5): 297 | controller.emit(i) 298 | controller.terminate() 299 | controller.await_termination() 300 | 301 | 302 | def test_sample_emit_event_per_key(): 303 | controller = build_flow( 304 | [ 305 | SyncEmitSource(key_field=str), 306 | Assert().exactly(25), 307 | SampleWindow(5, key="$key"), 308 | Assert().exactly(5).match_exactly([0, 1, 2, 3, 4]), 309 | ] 310 | ).run() 311 | 312 | for i in range(0, 25): 313 | key = f"key_{i % 5}" 314 | controller.emit(i, key=key) 315 | controller.terminate() 316 | controller.await_termination() 317 | 318 | 319 | def test_flatten(): 320 | controller = build_flow([SyncEmitSource(), Flatten(), Assert().contains_all_of([1, 2, 3, 4, 5, 6])]).run() 321 | 322 | controller.emit([1, 2, 3, 4, 5, 6]) 323 | controller.terminate() 324 | controller.await_termination() 325 | 326 | 327 | def test_flatten_forces_full_event_false(): 328 | controller = build_flow( 329 | [ 330 | SyncEmitSource(), 331 | Flatten(full_event=True), 332 | Assert().contains_all_of([1, 2, 3, 4, 5, 6]), 333 | ] 334 | ).run() 335 | 336 | controller.emit([1, 2, 3, 4, 5, 6]) 337 | controller.terminate() 338 | controller.await_termination() 339 | 340 | 341 | def test_foreach(): 342 | event_ids = set() 343 | controller = build_flow( 344 | [ 345 | SyncEmitSource(), 346 | ForEach(lambda e: event_ids.add(e.id), full_event=True), 347 | Assert(full_event=True).each_event(lambda event: event.id in event_ids), 348 | ] 349 | ).run() 350 | 351 | for i in range(0, 5): 352 | controller.emit(i) 353 | 354 | controller.terminate() 355 | controller.await_termination() 356 | 357 | 358 | def test_partition(): 359 | divisible_by_two = {2, 4, 6} 360 | not_divisible_by_two = {1, 3, 5} 361 | 362 | def check_partition(event: Event): 363 | first = event.body.left 364 | second = event.body.right 365 | 366 | if first is not None: 367 | return first in divisible_by_two and first not in not_divisible_by_two and second is None 368 | else: 369 | return second in not_divisible_by_two and second not in divisible_by_two 370 | 371 | controller = build_flow( 372 | [ 373 | SyncEmitSource(), 374 | Assert().exactly(6), 375 | Partition(lambda event: event.body % 2 == 0), 376 | 
Assert().exactly(6), 377 | Assert(full_event=True).each_event(lambda event: check_partition(event)), 378 | ] 379 | ).run() 380 | 381 | controller.emit(1) 382 | controller.emit(2) 383 | controller.emit(3) 384 | controller.emit(4) 385 | controller.emit(5) 386 | controller.emit(6) 387 | 388 | controller.terminate() 389 | controller.await_termination() 390 | -------------------------------------------------------------------------------- /tests/test_targets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional 16 | from unittest.mock import Mock 17 | 18 | import pytest 19 | import taosws 20 | 21 | from storey.dtypes import TDEngineValueError 22 | from storey.targets import TDEngineTarget 23 | 24 | 25 | class TestTDEngineTarget: 26 | @staticmethod 27 | def test_tags_mapping_consistency() -> None: 28 | for type_, func in TDEngineTarget._get_tdengine_type_to_tag_func().items(): 29 | assert func.__name__ == f"{type_.lower()}_to_tag" 30 | 31 | @staticmethod 32 | def test_columns_mapping_consistency() -> None: 33 | for type_, func in TDEngineTarget._get_tdengine_type_to_column_func().items(): 34 | if type_ == "TIMESTAMP": 35 | assert func.__name__.startswith("millis_timestamp") 36 | else: 37 | assert func.__name__.startswith(type_.lower()) 38 | assert func.__name__.endswith("_to_column") 39 | 40 | @staticmethod 41 | @pytest.mark.parametrize( 42 | ("database", "table", "supertable", "table_col", "tag_cols"), 43 | [ 44 | (None, None, "my_super_tb", "pass_this_check", ["also_this_one"]), 45 | ("mydb", None, "my super tb", "pass_this_check", ["also_this_one"]), 46 | ("_db", "9table", None, None, None), 47 | ("_db", " cars", None, None, None), 48 | ], 49 | ) 50 | def test_invalid_names( 51 | database: Optional[str], 52 | table: Optional[str], 53 | supertable: Optional[str], 54 | table_col: Optional[str], 55 | tag_cols: Optional[list[str]], 56 | ) -> None: 57 | with pytest.raises(TDEngineValueError): 58 | TDEngineTarget( 59 | url="taosws://root:taosdata@localhost:6041", 60 | time_col="ts", 61 | columns=["value"], 62 | table_col=table_col, 63 | tag_cols=tag_cols, 64 | database=database, 65 | table=table, 66 | supertable=supertable, 67 | ) 68 | 69 | @staticmethod 70 | @pytest.fixture 71 | def tdengine_target() -> TDEngineTarget: 72 | target = TDEngineTarget( 73 | url="taosws://root:taosdata@localhost:6041", 74 | time_col="ts", 75 | columns=["value"], 76 | database="test", 77 | table="d6241", 78 | ) 79 | 80 | target._connection = Mock() 81 | # The following test schema is obtained from the `taosBenchmark` data: 82 | # https://docs.tdengine.com/get-started/docker/#test-data-insert-performance 83 | # list(conn.query("describe test.d6241;")) 84 | target._connection.query = Mock( 85 | return_value=[ 86 | ("ts", "TIMESTAMP", 8, "", "delta-i", "lz4", "medium"), 87 | ("current", "FLOAT", 4, "", "delta-d", "lz4", "medium"), 88 | 
("voltage", "INT", 4, "", "simple8b", "lz4", "medium"), 89 | ("phase", "FLOAT", 4, "", "delta-d", "lz4", "medium"), 90 | ("groupid", "INT", 4, "TAG", "disabled", "disabled", "disabled"), 91 | ("location", "VARCHAR", 24, "TAG", "disabled", "disabled", "disabled"), 92 | ], 93 | ) 94 | return target 95 | 96 | @staticmethod 97 | def test_get_table_schema(tdengine_target: TDEngineTarget) -> None: 98 | """Test that the parsing works""" 99 | tags_schema, reg_cols_schema = tdengine_target._get_table_schema("d6241") 100 | assert tags_schema == [("groupid", taosws.int_to_tag), ("location", taosws.varchar_to_tag)] 101 | assert reg_cols_schema == [ 102 | ("ts", taosws.millis_timestamps_to_column), 103 | ("current", taosws.floats_to_column), 104 | ("voltage", taosws.ints_to_column), 105 | ("phase", taosws.floats_to_column), 106 | ] 107 | -------------------------------------------------------------------------------- /tests/test_types.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pytest 16 | 17 | from storey.dtypes import ( 18 | EmitAfterDelay, 19 | EmitAfterMaxEvent, 20 | EmitAfterPeriod, 21 | EmitAfterWindow, 22 | EmitEveryEvent, 23 | _dict_to_emit_policy, 24 | ) 25 | 26 | 27 | @pytest.mark.parametrize("emit_policy", [EmitEveryEvent, EmitAfterPeriod, EmitAfterWindow]) 28 | def test_emit_policy_basic(emit_policy): 29 | policy_dict = {"mode": emit_policy.name()} 30 | policy = _dict_to_emit_policy(policy_dict) 31 | assert type(policy) == emit_policy 32 | 33 | 34 | @pytest.mark.parametrize("emit_policy", [EmitAfterDelay, EmitAfterMaxEvent]) 35 | def test_emit_policy_bad_parameters(emit_policy): 36 | policy_dict = {"mode": emit_policy.name()} 37 | with pytest.raises(ValueError): 38 | _dict_to_emit_policy(policy_dict) 39 | 40 | 41 | def test_emit_policy_wrong_type(): 42 | policy_dict = {"mode": "d-o-g-g"} 43 | with pytest.raises(TypeError): 44 | _dict_to_emit_policy(policy_dict) 45 | 46 | 47 | def test_emit_policy_wrong_args(): 48 | policy_dict = {"mode": EmitAfterWindow.name(), "daily": 8} 49 | with pytest.raises(ValueError): 50 | _dict_to_emit_policy(policy_dict) 51 | 52 | 53 | def test_emit_policy_delay(): 54 | policy_dict = {"mode": EmitAfterDelay.name(), "delay": 8} 55 | policy = _dict_to_emit_policy(policy_dict) 56 | assert type(policy) == EmitAfterDelay 57 | assert policy.delay_in_seconds == 8 58 | 59 | 60 | def test_emit_policy_max_events(): 61 | policy_dict = {"mode": EmitAfterMaxEvent.name(), "maxEvents": 8} 62 | policy = _dict_to_emit_policy(policy_dict) 63 | assert type(policy) == EmitAfterMaxEvent 64 | assert policy.max_events == 8 65 | 66 | 67 | def test_emit_policy_window(): 68 | policy_dict = {"mode": EmitAfterWindow.name(), "delay": 8} 69 | policy = _dict_to_emit_policy(policy_dict) 70 | assert type(policy) == EmitAfterWindow 71 | assert policy.delay_in_seconds == 8 72 | 73 | 74 | def test_emit_policy_period(): 75 | 
policy_dict = {"mode": EmitAfterPeriod.name(), "delay": 8} 76 | policy = _dict_to_emit_policy(policy_dict) 77 | assert type(policy) == EmitAfterPeriod 78 | assert policy.delay_in_seconds == 8 79 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | import datetime 16 | 17 | from fsspec.implementations.local import LocalFileSystem 18 | 19 | from storey.utils import find_filters, get_remaining_path, url_to_file_system 20 | 21 | 22 | def test_get_path_utils(): 23 | url = "wasbs://mycontainer@myaccount.blob.core.windows.net/path/to/object.csv" 24 | schema, path = get_remaining_path(url) 25 | assert path == "mycontainer/path/to/object.csv" 26 | assert schema == "wasbs" 27 | 28 | 29 | def test_ds_get_path_utils(): 30 | url = "ds://:file@profile/path/to/object.csv" 31 | fs, path = url_to_file_system(url, "") 32 | assert path == "/path/to/object.csv" 33 | assert isinstance(fs, LocalFileSystem) 34 | 35 | 36 | def test_find_filters(): 37 | filters = [] 38 | find_filters([], datetime.datetime.min, datetime.datetime.max, filters, "time") 39 | assert filters == [[("time", ">", datetime.datetime.min), ("time", "<=", datetime.datetime.max)]] 40 | filters = [] 41 | find_filters([], None, datetime.datetime.max, filters, "time") 42 | assert filters == [[("time", "<=", datetime.datetime.max)]] 43 | filters = [] 44 | find_filters([], None, None, filters, None) 45 | assert filters == [[]] 46 | -------------------------------------------------------------------------------- /tests/test_v3io.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | import base64 16 | import json 17 | from datetime import datetime 18 | 19 | from integration.integration_test_utils import _v3io_parse_get_item_response 20 | 21 | 22 | def test_v3io_parse_get_item_response(): 23 | request = json.dumps( 24 | { 25 | "Item": { 26 | "int": {"N": "55"}, 27 | "float": {"N": "55.4"}, 28 | "string": {"S": "der die das"}, 29 | "boolean": {"BOOL": True}, 30 | "blob": {"B": base64.b64encode(b"message in a bottle").decode("ascii")}, 31 | "timestamp": {"TS": "1594289596:123456"}, 32 | } 33 | } 34 | ) 35 | response = _v3io_parse_get_item_response(request) 36 | expected = { 37 | "int": 55, 38 | "float": 55.4, 39 | "string": "der die das", 40 | "boolean": True, 41 | "blob": b"message in a bottle", 42 | "timestamp": datetime(2020, 7, 9, 10, 13, 16, 124), 43 | } 44 | assert response == expected 45 | -------------------------------------------------------------------------------- /tests/test_windowed_store.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Iguazio 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | from datetime import datetime, timedelta 16 | 17 | from storey import Filter, Reduce, SyncEmitSource, build_flow 18 | from storey.dtypes import EmissionType, SlidingWindow 19 | from storey.windowed_store import EmitAfterMaxEvent, Window 20 | 21 | 22 | def append_return(lst, x): 23 | lst.append(x) 24 | return lst 25 | 26 | 27 | def validate_window(expected, window): 28 | for elem in window: 29 | key = elem[0] 30 | data = elem[1] 31 | for column in data.features: 32 | index = 0 33 | for bucket in data.features[column]: 34 | if len(bucket.data) > 0: 35 | assert bucket.data == expected[key][index] 36 | 37 | index = index + 1 38 | 39 | 40 | def to_millis(t): 41 | return t.timestamp() * 1000 42 | 43 | 44 | def test_windowed_flow(): 45 | controller = build_flow( 46 | [ 47 | SyncEmitSource(), 48 | Filter(lambda x: x["col1"] > 3), 49 | Window( 50 | SlidingWindow("30m", "5m"), 51 | "time", 52 | EmitAfterMaxEvent(max_events=3, emission_type=EmissionType.Incremental), 53 | ), 54 | Reduce([], lambda acc, x: append_return(acc, x)), 55 | ] 56 | ).run() 57 | 58 | base_time = datetime.now() 59 | 60 | for i in range(10): 61 | data = {"col1": i, "time": base_time + timedelta(minutes=i)} 62 | controller.emit(data, f"{i % 2}") 63 | 64 | controller.terminate() 65 | window_list = controller.await_termination() 66 | assert len(window_list) == 2 67 | 68 | expected_window_1 = { 69 | "0": { 70 | 0: [(to_millis(base_time + timedelta(minutes=4)), 4)], 71 | 1: [(to_millis(base_time + timedelta(minutes=6)), 6)], 72 | }, 73 | "1": {0: [(to_millis(base_time + timedelta(minutes=5)), 5)]}, 74 | } 75 | 76 | expected_window_2 = { 77 | "0": {1: [(to_millis(base_time + timedelta(minutes=8)), 8)]}, 78 | "1": { 79 | 1: [ 80 | (to_millis(base_time + timedelta(minutes=7)), 7), 81 | (to_millis(base_time + timedelta(minutes=9)), 9), 82 | ] 83 | }, 84 | } 85 | 86 | validate_window(expected_window_1, 
window_list[0]) 87 | validate_window(expected_window_2, window_list[1]) 88 | --------------------------------------------------------------------------------