├── .flake8
├── .github
│   └── workflows
│       ├── benchmarks.yml
│       └── ci.yml
├── .gitignore
├── .licenserc.yaml
├── .style.yapf
├── LICENSE
├── README.md
├── benchmark
│   └── basic.py
├── examples
│   ├── run_tft_iris_example.py
│   ├── tft_iris_example
│   │   ├── data
│   │   │   └── input_data
│   │   │       └── input_data.csv
│   │   ├── model_training_pipeline.py
│   │   └── preprocessing.py
│   └── word_count_metrics.py
├── ray_beam_runner
│   ├── __init__.py
│   ├── collection.py
│   ├── custom_actor_pool.py
│   ├── overrides.py
│   ├── portability
│   │   ├── __init__.py
│   │   ├── context_management.py
│   │   ├── execution.py
│   │   ├── execution_test.py
│   │   ├── ray_fn_runner.py
│   │   ├── ray_runner_test.py
│   │   └── state.py
│   ├── ray_runner.py
│   ├── serialization.py
│   ├── translator.py
│   └── util.py
├── requirements_dev.txt
├── scripts
│   └── format.sh
└── setup.py

/.flake8:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | 
18 | [flake8]
19 | max-line-length = 88
20 | inline-quotes = "
21 | ignore =
22 |     C408
23 |     E121
24 |     E123
25 |     E126
26 |     E203
27 |     E226
28 |     E24
29 |     E704
30 |     W503
31 |     W504
32 |     W605
33 |     I
34 |     N
35 | avoid-escape = no
36 | per-file-ignores =
37 |     *ray_runner_test.py: B008
--------------------------------------------------------------------------------
/.github/workflows/benchmarks.yml:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # 17 | # 18 | name: Benchmarks 19 | 20 | on: 21 | push: 22 | branches: [master] 23 | 24 | jobs: 25 | Benchmarks: 26 | runs-on: ubuntu-latest 27 | 28 | steps: 29 | - uses: actions/checkout@v3 30 | with: 31 | fetch-depth: 0 32 | - name: Install python 33 | uses: actions/setup-python@v2 34 | with: 35 | python-version: 3.8 36 | - name: Install Ray Beam Runner 37 | run: | 38 | pip install -e .[test] 39 | - name: Install dependencies 40 | run: | 41 | python -m pip install --upgrade pip 42 | pip install -r requirements_dev.txt 43 | pip install -U "ray @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl" 44 | - name: Run Benchmark 45 | run: | 46 | ray start --head 47 | pytest -v benchmark/basic.py --benchmark-only --benchmark-json bm-result.json 48 | cat bm-result.json 49 | ray stop 50 | - name: Store benchmark result 51 | uses: benchmark-action/github-action-benchmark@v1 52 | with: 53 | name: Python Benchmark with pytest-benchmark 54 | tool: 'pytest' 55 | output-file-path: bm-result.json 56 | github-token: ${{ secrets.GITHUB_TOKEN }} 57 | auto-push: true 58 | # Show alert with commit comment on detecting possible performance regression 59 | alert-threshold: '200%' 60 | comment-on-alert: true 61 | fail-on-alert: true 62 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | # 18 | name: CI 19 | 20 | on: 21 | pull_request: 22 | branches: [ master ] 23 | 24 | jobs: 25 | 26 | Code-Health: 27 | runs-on: ubuntu-latest 28 | 29 | steps: 30 | - uses: actions/checkout@v3 31 | with: 32 | fetch-depth: 0 33 | - name: Install python 34 | uses: actions/setup-python@v2 35 | with: 36 | python-version: 3.8 37 | - name: Install Ray Beam Runner 38 | run: | 39 | pip install -e .[test] 40 | - name: Install dependencies 41 | run: | 42 | python -m pip install --upgrade pip 43 | pip install -r requirements_dev.txt 44 | - name: Format 45 | run: | 46 | bash scripts/format.sh 47 | 48 | Tests: 49 | runs-on: ubuntu-latest 50 | strategy: 51 | matrix: 52 | python-version: [{"v":"3.8", "whl":"38"}, {"v": "3.9", "whl": "39"}, {"v": "3.10", "whl": "310"}] 53 | 54 | steps: 55 | - uses: actions/checkout@v3 56 | with: 57 | fetch-depth: 0 58 | - name: Install python ${{ matrix.python-version.v }} 59 | uses: actions/setup-python@v3 60 | with: 61 | python-version: ${{ matrix.python-version.v }} 62 | - name: Install Ray Beam Runner 63 | run: | 64 | pip install -e .[test] 65 | - name: Install dependencies 66 | run: | 67 | python -m pip install --upgrade pip 68 | pip install -r requirements_dev.txt 69 | pip install -U "ray @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp${{ matrix.python-version.whl }}-cp${{ matrix.python-version.whl}}-manylinux2014_x86_64.whl" 70 | - name: Run Portability tests 71 | run: | 72 | pytest -r A ray_beam_runner/portability/ray_runner_test.py ray_beam_runner/portability/execution_test.py 73 | 74 | LicenseCheck: 75 | name: License Check 76 | 77 | runs-on: ubuntu-latest 78 | 79 | steps: 80 | - name: Checkout repository 81 | uses: actions/checkout@v3 82 | with: 83 | fetch-depth: 0 84 | - name: Check License Header 85 | uses: apache/skywalking-eyes@985866ce7e324454f61e22eb2db2e998db09d6f3 86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | # 18 | # Python byte code files 19 | *.pyc 20 | python/.eggs 21 | 22 | # Backup files 23 | *.bak 24 | 25 | # Emacs temporary files 26 | *~ 27 | *# 28 | 29 | # Debug symbols 30 | *.pdb 31 | 32 | # Visual Studio files 33 | /packages 34 | *.suo 35 | *.user 36 | *.VC.db 37 | *.VC.opendb 38 | 39 | # Protobuf-generated files 40 | *_pb2.py 41 | *.pb.h 42 | *.pb.cc 43 | 44 | # Ray cluster configuration 45 | scripts/nodes.txt 46 | 47 | # OS X folder attributes 48 | .DS_Store 49 | 50 | # Debug files 51 | *.dSYM/ 52 | *.su 53 | 54 | # Python setup files 55 | *.egg-info 56 | 57 | # Compressed files 58 | *.gz 59 | 60 | # Datasets from examples 61 | **/MNIST_data/ 62 | **/cifar-10-batches-bin/ 63 | 64 | # Generated documentation files 65 | /doc/_build 66 | /doc/source/_static/thumbs 67 | /doc/source/tune/generated_guides/ 68 | 69 | # User-specific stuff: 70 | .idea/ 71 | 72 | # Pytest Cache 73 | **/.pytest_cache 74 | **/.cache 75 | .benchmarks 76 | python-driver-* 77 | 78 | # Vscode 79 | .vscode/ 80 | 81 | *.iml 82 | 83 | # python virtual env 84 | venv 85 | 86 | # pyenv version file 87 | .python-version 88 | 89 | # Vim 90 | .*.swp 91 | *.swp 92 | tags 93 | 94 | # Emacs 95 | .#* 96 | 97 | # tools 98 | tools/prometheus* 99 | 100 | # ray project files 101 | project-id 102 | .mypy_cache/ 103 | 104 | # Downloaded test data 105 | *.csv 106 | *.csv.gz 107 | *.parquet 108 | -------------------------------------------------------------------------------- /.licenserc.yaml: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | header: 18 | license: 19 | spdx-id: Apache-2.0 20 | copyright-owner: Apache Software Foundation 21 | 22 | paths-ignore: 23 | - 'dist' 24 | - 'licenses' 25 | - '**/*.md' 26 | - 'LICENSE' 27 | - 'NOTICE' 28 | - '**/*.csv' 29 | 30 | comment: never 31 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. 
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | [style] 19 | based_on_style=pep8 20 | allow_split_before_dict_value=False 21 | join_multiple_lines=False 22 | allow_multiline_lambdas=True 23 | 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 203 | -------------------------------------------------------------------------------- 204 | 205 | Code in python/ray/rllib/{evolution_strategies, dqn} adapted from 206 | https://github.com/openai (MIT License) 207 | 208 | Copyright (c) 2016 OpenAI (http://openai.com) 209 | 210 | Permission is hereby granted, free of charge, to any person obtaining a copy 211 | of this software and associated documentation files (the "Software"), to deal 212 | in the Software without restriction, including without limitation the rights 213 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 214 | copies of the Software, and to permit persons to whom the Software is 215 | furnished to do so, subject to the following conditions: 216 | 217 | The above copyright notice and this permission notice shall be included in 218 | all copies or substantial portions of the Software. 219 | 220 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 221 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 222 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 223 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 224 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 225 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 226 | THE SOFTWARE. 227 | 228 | -------------------------------------------------------------------------------- 229 | 230 | Code in python/ray/rllib/impala/vtrace.py from 231 | https://github.com/deepmind/scalable_agent 232 | 233 | Copyright 2018 Google LLC 234 | 235 | Licensed under the Apache License, Version 2.0 (the "License"); 236 | you may not use this file except in compliance with the License. 237 | You may obtain a copy of the License at 238 | 239 | https://www.apache.org/licenses/LICENSE-2.0 240 | 241 | Unless required by applicable law or agreed to in writing, software 242 | distributed under the License is distributed on an "AS IS" BASIS, 243 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 244 | See the License for the specific language governing permissions and 245 | limitations under the License. 246 | 247 | -------------------------------------------------------------------------------- 248 | Code in python/ray/rllib/ars is adapted from https://github.com/modestyachts/ARS 249 | 250 | Copyright (c) 2018, ARS contributors (Horia Mania, Aurelia Guy, Benjamin Recht) 251 | All rights reserved. 252 | 253 | Redistribution and use of ARS in source and binary forms, with or without 254 | modification, are permitted provided that the following conditions are met: 255 | 256 | 1. Redistributions of source code must retain the above copyright notice, this 257 | list of conditions and the following disclaimer. 258 | 259 | 2. Redistributions in binary form must reproduce the above copyright notice, 260 | this list of conditions and the following disclaimer in the documentation and/or 261 | other materials provided with the distribution. 262 | 263 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 264 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 265 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 266 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
267 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
268 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
269 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
270 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
271 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
272 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
273 | 
274 | ------------------
275 | Code in python/ray/prometheus_exporter.py is adapted from https://github.com/census-instrumentation/opencensus-python/blob/master/contrib/opencensus-ext-prometheus/opencensus/ext/prometheus/stats_exporter/__init__.py
276 | 
277 | # Copyright 2018, OpenCensus Authors
278 | #
279 | # Licensed under the Apache License, Version 2.0 (the "License");
280 | # you may not use this file except in compliance with the License.
281 | # You may obtain a copy of the License at
282 | #
283 | # http://www.apache.org/licenses/LICENSE-2.0
284 | #
285 | # Unless required by applicable law or agreed to in writing, software
286 | # distributed under the License is distributed on an "AS IS" BASIS,
287 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
288 | # See the License for the specific language governing permissions and
289 | # limitations under the License.
290 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Ray-based Apache Beam Runner
2 | ============================
3 | This is a WIP proof-of-concept implementation. It undergoes frequent breaking changes and should not be used in production.
4 | 
5 | 
6 | ## The Portable Ray Beam Runner
7 | 
8 | The directory `ray_beam_runner/portability` contains a prototype for an implementation of a [Beam](https://beam.apache.org)
9 | runner for [Ray](https://ray.io) that relies on Beam's portability framework.
10 | 
11 | ### Install and build from source
12 | 
13 | To install the Ray Beam runner from a clone of this repository, follow these steps:
14 | 
15 | ```shell
16 | # First create a virtual environment to install and run Python dependencies
17 | virtualenv venv
18 | . venv/bin/activate
19 | 
20 | # Install development dependencies for the project
21 | pip install -r requirements_dev.txt
22 | 
23 | # Create a local installation to include test dependencies
24 | pip install -e .[test]
25 | # Or, if you see error messages like "zsh: no matches found: .[test]", try:
26 | pip install -e '.[test]'
27 | ```
28 | 
29 | ### Testing
30 | 
31 | The project has extensive unit tests that can run in a local environment. Tests that verify the basic runner
32 | functionality exist in the file `ray_beam_runner/portability/ray_runner_test.py`.
33 | 
34 | **To run the runner functionality test suite** for the Ray Beam Runner, run the following command:
35 | 
36 | ```shell
37 | pytest ray_beam_runner/portability/ray_runner_test.py
38 | ```
39 | 
40 | To run all local unit tests, you can simply run `pytest` from the root directory.
41 | 
42 | ### Found issues? Want to help?
43 | 
44 | Please file any problems with the runner in [this repository's issue section](https://github.com/ray-project/ray_beam_runner/issues).
45 | If you would like to help, please **look at the issues as well**. You can pick up one of them and try to implement
46 | it.
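
### Example: running a pipeline on the Ray runner

The snippet below is a minimal, illustrative sketch of how a pipeline can be pointed at the portable runner; it assumes a local Ray instance and mirrors the working example in `examples/word_count_metrics.py`:

```python
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

import ray
import ray_beam_runner.portability.ray_fn_runner as ray_runner

# Start a small local Ray instance for the runner to use.
ray.init(num_cpus=1, include_dashboard=False, ignore_reinit_error=True)

# Pass the portable Ray runner to the pipeline like any other Beam runner.
with beam.Pipeline(runner=ray_runner.RayFnApiRunner(),
                   options=PipelineOptions()) as p:
    (p
     | beam.Create(["a", "b", "a"])
     | beam.Map(lambda word: (word, 1))
     | beam.CombinePerKey(sum)
     | beam.Map(print))

ray.shutdown()
```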
47 | 48 | ### Performance testing 49 | 50 | ```shell 51 | # TODO: Write these tests and document how to run them. 52 | ``` -------------------------------------------------------------------------------- /benchmark/basic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Licensed to the Apache Software Foundation (ASF) under one or more 4 | # contributor license agreements. See the NOTICE file distributed with 5 | # this work for additional information regarding copyright ownership. 6 | # The ASF licenses this file to You under the Apache License, Version 2.0 7 | # (the "License"); you may not use this file except in compliance with 8 | # the License. You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | import ray 20 | from apache_beam.tools import fn_api_runner_microbenchmark as bm 21 | 22 | 23 | def run_pipeline(): 24 | bm.run_benchmark(1, 10, 100, True) 25 | return None 26 | 27 | 28 | def test_basic_benchmark(benchmark): 29 | ray.init() 30 | 31 | res = benchmark(run_pipeline) 32 | 33 | ray.shutdown() 34 | assert res is None 35 | -------------------------------------------------------------------------------- /examples/run_tft_iris_example.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | #
17 | 
18 | import tempfile
19 | 
20 | import apache_beam as beam
21 | from apache_beam.options.pipeline_options import PipelineOptions
22 | from apache_beam.runners.direct import DirectRunner
23 | import tensorflow_transform.beam as tft_beam
24 | 
25 | from ray_beam_runner.ray_runner import RayRunner
26 | 
27 | from tft_iris_example.preprocessing import (
28 |     Split, create_raw_metadata, analyze_and_transform, write_tfrecords,
29 |     write_transform_artefacts)
30 | 
31 | INPUT_FILENAME = "tft_iris_example/data/input_data/input_data.csv"
32 | OUTPUT_FILENAME = "tft_iris_example/data/output/preprocessed_data"
33 | OUTPUT_TRANSFORM_FUNCTION_FOLDER = "tft_iris_example/" \
34 |     "data/output/transform_artfcts"
35 | 
36 | 
37 | def run_transformation_pipeline(raw_input_location, transformed_data_location,
38 |                                 transform_artefact_location):
39 |     pipeline_options = PipelineOptions(["--parallelism=1"])
40 |     runner_cls = RayRunner
41 | 
42 |     # pipeline_options = PipelineOptions()
43 |     # runner_cls = DirectRunner
44 | 
45 |     with beam.Pipeline(
46 |             runner=runner_cls(), options=pipeline_options) as pipeline:
47 |         with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
48 |             raw_data = (pipeline | beam.io.ReadFromText(
49 |                 raw_input_location, skip_header_lines=1)
50 |                         | beam.ParDo(Split()))
51 |             raw_metadata = create_raw_metadata()
52 |             raw_dataset = (raw_data, raw_metadata)
53 |             transformed_dataset, transform_fn = analyze_and_transform(
54 |                 raw_dataset)
55 |             # transformed_dataset[0] | beam.Map(print)
56 |             write_tfrecords(transformed_dataset, transformed_data_location)
57 |             write_transform_artefacts(transform_fn,
58 |                                       transform_artefact_location)
59 | 
60 | 
61 | if __name__ == "__main__":
62 |     run_transformation_pipeline(
63 |         raw_input_location=INPUT_FILENAME,
64 |         transformed_data_location=OUTPUT_FILENAME,
65 |         transform_artefact_location=OUTPUT_TRANSFORM_FUNCTION_FOLDER)
66 | 
--------------------------------------------------------------------------------
/examples/tft_iris_example/data/input_data/input_data.csv:
--------------------------------------------------------------------------------
1 | sepal_length,sepal_width,petal_length,petal_width,target
2 | 5.1,3.5,1.4,0.2,Setosa
3 | 4.9,3.0,1.4,0.2,Setosa
4 | 4.7,3.2,1.3,0.2,Setosa
5 | 4.6,3.1,1.5,0.2,Setosa
6 | 5.0,3.6,1.4,0.2,Setosa
7 | 5.4,3.9,1.7,0.4,Setosa
8 | 4.6,3.4,1.4,0.3,Setosa
9 | 5.0,3.4,1.5,0.2,Setosa
10 | 4.4,2.9,1.4,0.2,Setosa
11 | 4.9,3.1,1.5,0.1,Setosa
12 | 5.4,3.7,1.5,0.2,Setosa
13 | 4.8,3.4,1.6,0.2,Setosa
14 | 4.8,3.0,1.4,0.1,Setosa
15 | 4.3,3.0,1.1,0.1,Setosa
16 | 5.8,4.0,1.2,0.2,Setosa
17 | 5.7,4.4,1.5,0.4,Setosa
18 | 5.4,3.9,1.3,0.4,Setosa
19 | 5.1,3.5,1.4,0.3,Setosa
20 | 5.7,3.8,1.7,0.3,Setosa
21 | 5.1,3.8,1.5,0.3,Setosa
22 | 5.4,3.4,1.7,0.2,Setosa
23 | 5.1,3.7,1.5,0.4,Setosa
24 | 4.6,3.6,1.0,0.2,Setosa
25 | 5.1,3.3,1.7,0.5,Setosa
26 | 4.8,3.4,1.9,0.2,Setosa
27 | 5.0,3.0,1.6,0.2,Setosa
28 | 5.0,3.4,1.6,0.4,Setosa
29 | 5.2,3.5,1.5,0.2,Setosa
30 | 5.2,3.4,1.4,0.2,Setosa
31 | 4.7,3.2,1.6,0.2,Setosa
32 | 4.8,3.1,1.6,0.2,Setosa
33 | 5.4,3.4,1.5,0.4,Setosa
34 | 5.2,4.1,1.5,0.1,Setosa
35 | 5.5,4.2,1.4,0.2,Setosa
36 | 4.9,3.1,1.5,0.2,Setosa
37 | 
5.0,3.2,1.2,0.2,Setosa 38 | 5.5,3.5,1.3,0.2,Setosa 39 | 4.9,3.6,1.4,0.1,Setosa 40 | 4.4,3.0,1.3,0.2,Setosa 41 | 5.1,3.4,1.5,0.2,Setosa 42 | 5.0,3.5,1.3,0.3,Setosa 43 | 4.5,2.3,1.3,0.3,Setosa 44 | 4.4,3.2,1.3,0.2,Setosa 45 | 5.0,3.5,1.6,0.6,Setosa 46 | 5.1,3.8,1.9,0.4,Setosa 47 | 4.8,3.0,1.4,0.3,Setosa 48 | 5.1,3.8,1.6,0.2,Setosa 49 | 4.6,3.2,1.4,0.2,Setosa 50 | 5.3,3.7,1.5,0.2,Setosa 51 | 5.0,3.3,1.4,0.2,Setosa 52 | 7.0,3.2,4.7,1.4,Versicolour 53 | 6.4,3.2,4.5,1.5,Versicolour 54 | 6.9,3.1,4.9,1.5,Versicolour 55 | 5.5,2.3,4.0,1.3,Versicolour 56 | 6.5,2.8,4.6,1.5,Versicolour 57 | 5.7,2.8,4.5,1.3,Versicolour 58 | 6.3,3.3,4.7,1.6,Versicolour 59 | 4.9,2.4,3.3,1.0,Versicolour 60 | 6.6,2.9,4.6,1.3,Versicolour 61 | 5.2,2.7,3.9,1.4,Versicolour 62 | 5.0,2.0,3.5,1.0,Versicolour 63 | 5.9,3.0,4.2,1.5,Versicolour 64 | 6.0,2.2,4.0,1.0,Versicolour 65 | 6.1,2.9,4.7,1.4,Versicolour 66 | 5.6,2.9,3.6,1.3,Versicolour 67 | 6.7,3.1,4.4,1.4,Versicolour 68 | 5.6,3.0,4.5,1.5,Versicolour 69 | 5.8,2.7,4.1,1.0,Versicolour 70 | 6.2,2.2,4.5,1.5,Versicolour 71 | 5.6,2.5,3.9,1.1,Versicolour 72 | 5.9,3.2,4.8,1.8,Versicolour 73 | 6.1,2.8,4.0,1.3,Versicolour 74 | 6.3,2.5,4.9,1.5,Versicolour 75 | 6.1,2.8,4.7,1.2,Versicolour 76 | 6.4,2.9,4.3,1.3,Versicolour 77 | 6.6,3.0,4.4,1.4,Versicolour 78 | 6.8,2.8,4.8,1.4,Versicolour 79 | 6.7,3.0,5.0,1.7,Versicolour 80 | 6.0,2.9,4.5,1.5,Versicolour 81 | 5.7,2.6,3.5,1.0,Versicolour 82 | 5.5,2.4,3.8,1.1,Versicolour 83 | 5.5,2.4,3.7,1.0,Versicolour 84 | 5.8,2.7,3.9,1.2,Versicolour 85 | 6.0,2.7,5.1,1.6,Versicolour 86 | 5.4,3.0,4.5,1.5,Versicolour 87 | 6.0,3.4,4.5,1.6,Versicolour 88 | 6.7,3.1,4.7,1.5,Versicolour 89 | 6.3,2.3,4.4,1.3,Versicolour 90 | 5.6,3.0,4.1,1.3,Versicolour 91 | 5.5,2.5,4.0,1.3,Versicolour 92 | 5.5,2.6,4.4,1.2,Versicolour 93 | 6.1,3.0,4.6,1.4,Versicolour 94 | 5.8,2.6,4.0,1.2,Versicolour 95 | 5.0,2.3,3.3,1.0,Versicolour 96 | 5.6,2.7,4.2,1.3,Versicolour 97 | 5.7,3.0,4.2,1.2,Versicolour 98 | 5.7,2.9,4.2,1.3,Versicolour 99 | 6.2,2.9,4.3,1.3,Versicolour 100 | 5.1,2.5,3.0,1.1,Versicolour 101 | 5.7,2.8,4.1,1.3,Versicolour 102 | 6.3,3.3,6.0,2.5,Virginica 103 | 5.8,2.7,5.1,1.9,Virginica 104 | 7.1,3.0,5.9,2.1,Virginica 105 | 6.3,2.9,5.6,1.8,Virginica 106 | 6.5,3.0,5.8,2.2,Virginica 107 | 7.6,3.0,6.6,2.1,Virginica 108 | 4.9,2.5,4.5,1.7,Virginica 109 | 7.3,2.9,6.3,1.8,Virginica 110 | 6.7,2.5,5.8,1.8,Virginica 111 | 7.2,3.6,6.1,2.5,Virginica 112 | 6.5,3.2,5.1,2.0,Virginica 113 | 6.4,2.7,5.3,1.9,Virginica 114 | 6.8,3.0,5.5,2.1,Virginica 115 | 5.7,2.5,5.0,2.0,Virginica 116 | 5.8,2.8,5.1,2.4,Virginica 117 | 6.4,3.2,5.3,2.3,Virginica 118 | 6.5,3.0,5.5,1.8,Virginica 119 | 7.7,3.8,6.7,2.2,Virginica 120 | 7.7,2.6,6.9,2.3,Virginica 121 | 6.0,2.2,5.0,1.5,Virginica 122 | 6.9,3.2,5.7,2.3,Virginica 123 | 5.6,2.8,4.9,2.0,Virginica 124 | 7.7,2.8,6.7,2.0,Virginica 125 | 6.3,2.7,4.9,1.8,Virginica 126 | 6.7,3.3,5.7,2.1,Virginica 127 | 7.2,3.2,6.0,1.8,Virginica 128 | 6.2,2.8,4.8,1.8,Virginica 129 | 6.1,3.0,4.9,1.8,Virginica 130 | 6.4,2.8,5.6,2.1,Virginica 131 | 7.2,3.0,5.8,1.6,Virginica 132 | 7.4,2.8,6.1,1.9,Virginica 133 | 7.9,3.8,6.4,2.0,Virginica 134 | 6.4,2.8,5.6,2.2,Virginica 135 | 6.3,2.8,5.1,1.5,Virginica 136 | 6.1,2.6,5.6,1.4,Virginica 137 | 7.7,3.0,6.1,2.3,Virginica 138 | 6.3,3.4,5.6,2.4,Virginica 139 | 6.4,3.1,5.5,1.8,Virginica 140 | 6.0,3.0,4.8,1.8,Virginica 141 | 6.9,3.1,5.4,2.1,Virginica 142 | 6.7,3.1,5.6,2.4,Virginica 143 | 6.9,3.1,5.1,2.3,Virginica 144 | 5.8,2.7,5.1,1.9,Virginica 145 | 6.8,3.2,5.9,2.3,Virginica 146 | 6.7,3.3,5.7,2.5,Virginica 147 | 6.7,3.0,5.2,2.3,Virginica 148 | 
6.3,2.5,5.0,1.9,Virginica
149 | 6.5,3.0,5.2,2.0,Virginica
150 | 6.2,3.4,5.4,2.3,Virginica
151 | 5.9,3.0,5.1,1.8,Virginica
--------------------------------------------------------------------------------
/examples/tft_iris_example/model_training_pipeline.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | # flake8: noqa
18 | import tensorflow as tf
19 | import tensorflow_transform as tft
20 | from functools import partial
21 | import os
22 | import numpy as np
23 | from tensorflow.keras.layers import Input, Dense
24 | from tensorflow.keras import Model
25 | 
26 | # VARIANT = "with_actors"
27 | # VARIANT = "with_tasks"
28 | # VARIANT = "with_directrunner"
29 | VARIANT = "."
30 | 
31 | TFT_OUTPUT_DIRECTORY = f"data/output/{VARIANT}/transform_artfcts/"
32 | TFRECORD_PATH = f"data/output/{VARIANT}/preprocessed_data/"
33 | 
34 | np.random.seed(1234)
35 | tf.random.set_seed(1234)
36 | 
37 | 
38 | def _parse_function(proto, fs):
39 |     parsed_features = tf.io.parse_single_example(proto, fs)
40 |     return [
41 |         parsed_features["petal_length_normalized"],
42 |         parsed_features["petal_width_normalized"],
43 |         parsed_features["sepal_length_normalized"],
44 |         parsed_features["sepal_width_normalized"]
45 |     ], parsed_features["target"]
46 | 
47 | 
48 | def get_parse_function(tft_output_dir):
49 |     tf_transform_output = tft.TFTransformOutput(tft_output_dir)
50 |     feature_spec = tf_transform_output.transformed_feature_spec()
51 |     parse_func = lambda proto: _parse_function(proto=proto, fs=feature_spec)
52 |     return parse_func
53 | 
54 | 
55 | def load_dataset(input_path, tft_output_dir, batch_size=10, shuffle_buffer=2):
56 |     _parse_fn = get_parse_function(tft_output_dir=tft_output_dir)
57 |     filenames = [
58 |         os.path.join(input_path, x) for x in tf.io.gfile.listdir(input_path)
59 |     ]
60 |     dataset = tf.data.TFRecordDataset(filenames)
61 |     dataset = dataset.map(_parse_fn)
62 |     dataset = dataset.batch(batch_size)
63 |     return dataset
64 | 
65 | 
66 | dataset = load_dataset(
67 |     input_path=TFRECORD_PATH, tft_output_dir=TFT_OUTPUT_DIRECTORY)
68 | 
69 | 
70 | def build_model(input_shape=4, validation_split=0.1):
71 |     input_layer = Input(shape=(input_shape, ))
72 |     hidden_layer = Dense(64, activation="relu")
73 |     hidden_output = hidden_layer(input_layer)
74 |     output_layer = Dense(3, activation="softmax")
75 |     # Feed the hidden layer's output (not the raw input) into the softmax head.
76 |     output = output_layer(hidden_output)
77 |     model = Model(inputs=input_layer, outputs=output)
78 |     model.compile(
79 |         loss="categorical_crossentropy", optimizer="adam", metrics="accuracy")
80 |     return model
81 | 
82 | 
83 | model = build_model()
84 | _ = model.fit(dataset, epochs=100, verbose=False)
85 | 
86 | input_args = {
87 | 
"sepal_length": np.array([[5.1]]), 88 | "sepal_width": np.array([[3.5]]), 89 | "petal_length": np.array([[1.4]]), 90 | "petal_width": np.array([[0.2]]), 91 | } 92 | 93 | signature_dict = { 94 | "petal_width": [ 95 | tf.TensorSpec(shape=[], dtype=tf.float32, name="petal_width") 96 | ], 97 | "petal_length": [ 98 | tf.TensorSpec(shape=[], dtype=tf.float32, name="petal_length") 99 | ], 100 | "sepal_width": [ 101 | tf.TensorSpec(shape=[], dtype=tf.float32, name="sepal_width") 102 | ], 103 | "sepal_length": [ 104 | tf.TensorSpec(shape=[], dtype=tf.float32, name="sepal_length") 105 | ], 106 | } 107 | 108 | 109 | class MyModel(tf.keras.Model): 110 | def __init__(self, model, tft_output_dir): 111 | super(MyModel, self).__init__() 112 | self.model = model 113 | self.tft_layer = tft.TFTransformOutput( 114 | tft_output_dir).transform_features_layer() 115 | 116 | def call(self, inputs): 117 | transformed = self.tft_layer(inputs) 118 | 119 | modelInput = tf.stack( 120 | [ 121 | transformed["petal_length_normalized"], 122 | transformed["petal_width_normalized"], 123 | transformed["sepal_length_normalized"], 124 | transformed["sepal_width_normalized"] 125 | ], 126 | axis=1) 127 | pred = self.model(modelInput) 128 | return {"prediction": pred} 129 | 130 | 131 | myModel = MyModel(model, TFT_OUTPUT_DIRECTORY) 132 | 133 | print(myModel.predict(input_args)) 134 | 135 | myModel.save(f"data/output/{VARIANT}/saved_model") 136 | -------------------------------------------------------------------------------- /examples/tft_iris_example/preprocessing.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | # flake8: noqa 18 | import pprint 19 | import tempfile 20 | import unicodedata 21 | import os 22 | 23 | import tensorflow as tf 24 | import tensorflow_transform as tft 25 | # import numpy as np 26 | import apache_beam as beam 27 | from apache_beam.options.pipeline_options import PipelineOptions 28 | 29 | import tensorflow_transform.beam as tft_beam 30 | from tensorflow_transform.tf_metadata import dataset_metadata 31 | # from tensorflow_transform.tf_metadata import dataset_schema 32 | from tensorflow_transform.tf_metadata import schema_utils 33 | 34 | INPUT_FILENAME = 'data/input/input_data.csv' 35 | OUTPUT_FILENAME = 'data/output/preprocessed_data' 36 | OUTPUT_TRANSFORM_FUNCTION_FOLDER = 'data/output/transform_artfcts' 37 | 38 | NUMERIC_FEATURE_KEYS = [ 39 | 'sepal_length', 'sepal_width', 'petal_length', 'petal_width' 40 | ] 41 | 42 | LABEL_KEY = 'target' 43 | 44 | 45 | def create_raw_metadata(): 46 | listFeatures = [(name, tf.io.FixedLenFeature([1], tf.float32)) 47 | for name in NUMERIC_FEATURE_KEYS 48 | ] + [(LABEL_KEY, tf.io.VarLenFeature(tf.string))] 49 | RAW_DATA_FEATURE_SPEC = dict(listFeatures) 50 | 51 | RAW_DATA_METADATA = tft.tf_metadata.dataset_metadata.DatasetMetadata( 52 | schema_utils.schema_from_feature_spec(RAW_DATA_FEATURE_SPEC)) 53 | return RAW_DATA_METADATA 54 | 55 | 56 | class Split(beam.DoFn): 57 | def process(self, element): 58 | import numpy as np 59 | sepal_length, sepal_width, petal_length, petal_width, target = element.split( 60 | ",") 61 | return [{ 62 | "sepal_length": np.array([float(sepal_length)]), 63 | "sepal_width": np.array([float(sepal_width)]), 64 | "petal_length": np.array([float(petal_length)]), 65 | "petal_width": np.array([float(petal_width)]), 66 | "target": target, 67 | }] 68 | 69 | 70 | def preprocess_fn(input_features): 71 | output_features = {} 72 | 73 | # Target feature 74 | # This is a SparseTensor because it is optional. Here we fill in a default 75 | # value when it is missing. This is useful when this column is missing during 76 | # inference 77 | sparse = tf.sparse.SparseTensor( 78 | indices=input_features[LABEL_KEY].indices, 79 | values=input_features[LABEL_KEY].values, 80 | dense_shape=[input_features[LABEL_KEY].dense_shape[0], 1]) 81 | dense = tf.sparse.to_dense(sp_input=sparse, default_value='') 82 | # # Reshaping from a batch of vectors of size 1 to a batch to scalars. 
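# (The squeeze below drops that trailing size-1 dimension, so that
# compute_and_apply_vocabulary receives a batch of scalar label strings
# rather than 1-element vectors.)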
83 | dense = tf.squeeze(dense, axis=1) 84 | dense_integerized = tft.compute_and_apply_vocabulary( 85 | dense, vocab_filename="label_index_map") 86 | output_features['target'] = tf.one_hot(dense_integerized, depth=3) 87 | 88 | # normalization of continuous variables 89 | output_features['sepal_length_normalized'] = tft.scale_to_z_score( 90 | input_features['sepal_length']) 91 | output_features['sepal_width_normalized'] = tft.scale_to_z_score( 92 | input_features['sepal_width']) 93 | output_features['petal_length_normalized'] = tft.scale_to_z_score( 94 | input_features['petal_length']) 95 | output_features['petal_width_normalized'] = tft.scale_to_z_score( 96 | input_features['petal_width']) 97 | return output_features 98 | 99 | 100 | def analyze_and_transform(raw_dataset, step="Default"): 101 | transformed_dataset, transform_fn = raw_dataset | "{} - Analyze & Transform".format( 102 | step) >> tft_beam.AnalyzeAndTransformDataset(preprocess_fn) 103 | 104 | return transformed_dataset, transform_fn 105 | 106 | 107 | def write_tfrecords(dataset, location, step="Default"): 108 | transformed_data, transformed_metadata = dataset 109 | (transformed_data 110 | | "{} - Write Transformed Data".format(step) >> 111 | beam.io.tfrecordio.WriteToTFRecord( 112 | file_path_prefix=os.path.join(location, "{}-".format(step)), 113 | file_name_suffix=".tfrecords", 114 | coder=tft.coders.example_proto_coder.ExampleProtoCoder( 115 | transformed_metadata.schema), 116 | )) 117 | 118 | 119 | def write_transform_artefacts(transform_fn, location): 120 | (transform_fn | "Write Transform Artifacts" >> 121 | tft_beam.tft_beam_io.transform_fn_io.WriteTransformFn(location)) 122 | 123 | 124 | def run_transformation_pipeline(raw_input_location, transformed_data_location, 125 | transform_artefact_location): 126 | pipeline_options = PipelineOptions() 127 | 128 | with beam.Pipeline(options=pipeline_options) as pipeline: 129 | with tft_beam.Context(temp_dir=tempfile.mkdtemp()): 130 | raw_data = (pipeline | beam.io.ReadFromText( 131 | raw_input_location, skip_header_lines=True) 132 | | beam.ParDo(Split())) 133 | raw_metadata = create_raw_metadata() 134 | raw_dataset = (raw_data, raw_metadata) 135 | transformed_dataset, transform_fn = analyze_and_transform( 136 | raw_dataset) 137 | # transformed_dataset[0] | beam.Map(print) 138 | write_tfrecords(transformed_dataset, transformed_data_location) 139 | write_transform_artefacts(transform_fn, 140 | transform_artefact_location) 141 | 142 | 143 | if __name__ == "__main__": 144 | run_transformation_pipeline( 145 | raw_input_location=INPUT_FILENAME, 146 | transformed_data_location=OUTPUT_FILENAME, 147 | transform_artefact_location=OUTPUT_TRANSFORM_FUNCTION_FOLDER) 148 | -------------------------------------------------------------------------------- /examples/word_count_metrics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. 
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | """A word-counting workflow.""" 19 | 20 | # pytype: skip-file 21 | 22 | # beam-playground: 23 | # name: WordCountWithMetrics 24 | # description: A word-counting workflow with metrics. 25 | # multifile: false 26 | # default_example: true 27 | # pipeline_options: --output output.txt 28 | # context_line: 48 29 | # categories: 30 | # - Combiners 31 | # - Options 32 | # - Metrics 33 | # - Quickstart 34 | # complexity: MEDIUM 35 | # tags: 36 | # - count 37 | # - metrics 38 | # - strings 39 | 40 | import argparse 41 | import apache_beam as beam 42 | from apache_beam.io import ReadFromText 43 | from apache_beam.io import WriteToText 44 | from apache_beam.metrics import Metrics 45 | from apache_beam.metrics.metric import MetricsFilter 46 | from apache_beam.options.pipeline_options import PipelineOptions 47 | from apache_beam.options.pipeline_options import SetupOptions 48 | 49 | 50 | class WordExtractingDoFn(beam.DoFn): 51 | """Parse each line of input text into words.""" 52 | 53 | def __init__(self): 54 | # TODO(BEAM-6158): Revert the workaround once we can pickle super() on py3. 55 | # super().__init__() 56 | beam.DoFn.__init__(self) 57 | self.words_counter = Metrics.counter(self.__class__, "words") 58 | self.word_lengths_counter = Metrics.counter(self.__class__, "word_lengths") 59 | self.word_lengths_dist = Metrics.distribution(self.__class__, "word_len_dist") 60 | self.empty_line_counter = Metrics.counter(self.__class__, "empty_lines") 61 | 62 | def process(self, element): 63 | """Returns an iterator over the words of this element. 64 | 65 | The element is a line of text. If the line is blank, note that, too. 66 | 67 | Args: 68 | element: the element being processed 69 | 70 | Returns: 71 | The processed element. 72 | """ 73 | import re 74 | 75 | text_line = element.strip() 76 | if not text_line: 77 | self.empty_line_counter.inc(1) 78 | words = re.findall(r"[\w\']+", text_line, re.UNICODE) 79 | for w in words: 80 | self.words_counter.inc() 81 | self.word_lengths_counter.inc(len(w)) 82 | self.word_lengths_dist.update(len(w)) 83 | return words 84 | 85 | 86 | def main(argv=None, save_main_session=True, runner=None): 87 | """Main entry point; defines and runs the wordcount pipeline.""" 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument( 90 | "--input", dest="input", default="input.txt", help="Input file to process." 91 | ) 92 | parser.add_argument( 93 | "--output", 94 | dest="output", 95 | default="output.txt", 96 | help="Output file to write results to.", 97 | ) 98 | known_args, pipeline_args = parser.parse_known_args(argv) 99 | 100 | # We use the save_main_session option because one or more DoFn's in this 101 | # workflow rely on global context (e.g., a module imported at module level). 
102 | pipeline_options = PipelineOptions(pipeline_args) 103 | pipeline_options.view_as(SetupOptions).save_main_session = save_main_session 104 | if runner == "ray": 105 | p = beam.Pipeline(runner=ray_runner.RayFnApiRunner(), options=pipeline_options) 106 | else: 107 | p = beam.Pipeline(options=pipeline_options) 108 | 109 | # Read the text file[pattern] into a PCollection. 110 | lines = p | "read" >> ReadFromText(known_args.input) 111 | 112 | # Count the occurrences of each word. 113 | def count_ones(word_ones): 114 | (word, ones) = word_ones 115 | return (word, sum(ones)) 116 | 117 | counts = ( 118 | lines 119 | | "split" >> (beam.ParDo(WordExtractingDoFn()).with_output_types(str)) 120 | | "pair_with_one" >> beam.Map(lambda x: (x, 1)) 121 | | "group" >> beam.GroupByKey() 122 | | "count" >> beam.Map(count_ones) 123 | ) 124 | 125 | # Format the counts into a PCollection of strings. 126 | def format_result(word_count): 127 | (word, count) = word_count 128 | return "%s: %d" % (word, count) 129 | 130 | output = counts | "format" >> beam.Map(format_result) 131 | 132 | # Write the output using a "Write" transform that has side effects. 133 | # pylint: disable=expression-not-assigned 134 | output | "write" >> WriteToText(known_args.output) 135 | 136 | result = p.run() 137 | result.wait_until_finish() 138 | 139 | # Do not query metrics when creating a template which doesn't run 140 | if ( 141 | not hasattr(result, "has_job") or result.has_job # direct runner 142 | ): # not just a template creation 143 | # Query element-wise metrics, e.g., counter, distribution 144 | empty_lines_filter = MetricsFilter().with_name("empty_lines") 145 | query_result = result.metrics().query(empty_lines_filter) 146 | if query_result["counters"]: 147 | empty_lines_counter = query_result["counters"][0] 148 | print(f"number of empty lines:{empty_lines_counter.result}") 149 | 150 | word_lengths_filter = MetricsFilter().with_name("word_len_dist") 151 | query_result = result.metrics().query(word_lengths_filter) 152 | if query_result["distributions"]: 153 | word_lengths_dist = query_result["distributions"][0] 154 | print("average word length:%.2f" % word_lengths_dist.result.mean) 155 | print(f"min word length: {word_lengths_dist.result.min}") 156 | print(f"max word length: {word_lengths_dist.result.max}") 157 | 158 | # #Query non-user metrics, e.g., start_bundle_msecs, process_bundle_msecs 159 | # result_metrics = result.monitoring_metrics() 160 | # #import pytest 161 | # #pytest.set_trace() 162 | # all_metrics_via_monitoring_infos = result_metrics.query() 163 | # print(all_metrics_via_monitoring_infos['counters'][0].result) 164 | # print(all_metrics_via_monitoring_infos['counters'][1].result) 165 | 166 | 167 | if __name__ == "__main__": 168 | # logging.getLogger().setLevel(logging.DEBUG) 169 | # Test 1: Ray Runner 170 | print("Ray Runner--------------->") 171 | import ray_beam_runner.portability.ray_fn_runner as ray_runner 172 | import ray 173 | 174 | ray.init(num_cpus=1, include_dashboard=False, ignore_reinit_error=True) 175 | main(runner="ray") 176 | ray.shutdown() 177 | print("Direct Runner--------------->") 178 | # Test 2: Direct Runner 179 | main() 180 | -------------------------------------------------------------------------------- /ray_beam_runner/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. 
See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | -------------------------------------------------------------------------------- /ray_beam_runner/collection.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import ray 19 | 20 | from apache_beam.pvalue import PValue 21 | from apache_beam.typehints import Dict 22 | 23 | 24 | class CollectionMap: 25 | def __init__(self): 26 | self.ray_datasets: Dict[PValue, ray.data.Dataset] = {} 27 | 28 | def get(self, pvalue: PValue): 29 | return self.ray_datasets.get(pvalue, None) 30 | 31 | def set(self, pvalue: PValue, ray_dataset: ray.data.Dataset): 32 | self.ray_datasets[pvalue] = ray_dataset 33 | 34 | def has(self, pvalue: PValue): 35 | return pvalue in self.ray_datasets 36 | -------------------------------------------------------------------------------- /ray_beam_runner/custom_actor_pool.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | from typing import TypeVar, Iterable, Any 19 | 20 | import ray 21 | from ray.data.impl.compute import ComputeStrategy 22 | from ray.types import ObjectRef 23 | from ray.data.block import Block 24 | from ray.data.impl.block_list import BlockList 25 | from ray.data.impl.progress_bar import ProgressBar 26 | 27 | T = TypeVar("T") 28 | U = TypeVar("U") 29 | 30 | # A class type that implements __call__. 31 | CallableClass = type 32 | 33 | 34 | class CustomActorPool(ComputeStrategy): 35 | def __init__(self, worker_cls): 36 | self.workers = [] 37 | self.worker_cls = worker_cls 38 | 39 | def __del__(self): 40 | for w in self.workers: 41 | w.__ray_terminate__.remote() 42 | 43 | def apply(self, fn: Any, remote_args: dict, 44 | blocks: Iterable[Block]) -> Iterable[ObjectRef[Block]]: 45 | 46 | map_bar = ProgressBar("Map Progress", total=len(blocks)) 47 | 48 | if not remote_args: 49 | remote_args["num_cpus"] = 1 50 | BlockWorker = ray.remote(**remote_args)(self.worker_cls) 51 | 52 | self.workers = [BlockWorker.remote()] 53 | metadata_mapping = {} 54 | tasks = {w.ready.remote(): w for w in self.workers} 55 | ready_workers = set() 56 | blocks_in = [(b, m) for (b, m) in zip(blocks, blocks.get_metadata())] 57 | blocks_out = [] 58 | 59 | while len(blocks_out) < len(blocks): 60 | ready, _ = ray.wait( 61 | list(tasks), timeout=0.01, num_returns=1, fetch_local=False) 62 | if not ready: 63 | if len(ready_workers) / len(self.workers) > 0.8: 64 | w = BlockWorker.remote() 65 | self.workers.append(w) 66 | tasks[w.ready.remote()] = w 67 | map_bar.set_description( 68 | "Map Progress ({} actors {} pending)".format( 69 | len(ready_workers), 70 | len(self.workers) - len(ready_workers))) 71 | continue 72 | 73 | [obj_id] = ready 74 | worker = tasks[obj_id] 75 | del tasks[obj_id] 76 | 77 | # Process task result. 78 | if worker in ready_workers: 79 | blocks_out.append(obj_id) 80 | map_bar.update(1) 81 | else: 82 | ready_workers.add(worker) 83 | 84 | # Schedule a new task. 85 | if blocks_in: 86 | block_ref, meta_ref = worker.process_block.remote( 87 | *blocks_in.pop()) 88 | metadata_mapping[block_ref] = meta_ref 89 | tasks[block_ref] = worker 90 | 91 | new_metadata = ray.get([metadata_mapping[b] for b in blocks_out]) 92 | map_bar.close() 93 | return BlockList(blocks_out, new_metadata) 94 | -------------------------------------------------------------------------------- /ray_beam_runner/overrides.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | import typing 18 | 19 | from apache_beam import (pvalue, PTransform, Create, Reshuffle, Windowing, 20 | GroupByKey, ParDo) 21 | from apache_beam.io import Read 22 | from apache_beam.pipeline import PTransformOverride 23 | from apache_beam.runners.direct.direct_runner import _GroupAlsoByWindowDoFn 24 | from apache_beam.transforms.window import GlobalWindows 25 | from apache_beam import typehints 26 | from apache_beam.typehints import List, trivial_inference 27 | 28 | K = typing.TypeVar("K") 29 | V = typing.TypeVar("V") 30 | 31 | 32 | class _Create(PTransform): 33 | def __init__(self, values): 34 | self.values = values 35 | 36 | def expand(self, input_or_inputs): 37 | return pvalue.PCollection.from_(input_or_inputs) 38 | 39 | def get_windowing(self, inputs): 40 | # type: (typing.Any) -> Windowing 41 | return Windowing(GlobalWindows()) 42 | 43 | 44 | @typehints.with_input_types(K) 45 | @typehints.with_output_types(K) 46 | class _Reshuffle(PTransform): 47 | def expand(self, input_or_inputs): 48 | return pvalue.PCollection.from_(input_or_inputs) 49 | 50 | 51 | class _Read(PTransform): 52 | def __init__(self, source): 53 | self.source = source 54 | 55 | def expand(self, input_or_inputs): 56 | return pvalue.PCollection.from_(input_or_inputs) 57 | 58 | 59 | @typehints.with_input_types(typing.Tuple[K, V]) 60 | @typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]]) 61 | class _GroupByKeyOnly(PTransform): 62 | def expand(self, input_or_inputs): 63 | return pvalue.PCollection.from_(input_or_inputs) 64 | 65 | def infer_output_type(self, input_type): 66 | key_type, value_type = trivial_inference.key_value_types(input_type) 67 | return typehints.KV[key_type, typehints.Iterable[value_type]] 68 | 69 | 70 | @typehints.with_input_types(typing.Tuple[K, typing.Iterable[V]]) 71 | @typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]]) 72 | class _GroupAlsoByWindow(ParDo): 73 | def __init__(self, windowing): 74 | super(_GroupAlsoByWindow, self).__init__( 75 | _GroupAlsoByWindowDoFn(windowing)) 76 | self.windowing = windowing 77 | 78 | def expand(self, input_or_inputs): 79 | return pvalue.PCollection.from_(input_or_inputs) 80 | 81 | 82 | @typehints.with_input_types(typing.Tuple[K, V]) 83 | @typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]]) 84 | class _GroupByKey(PTransform): 85 | def expand(self, input_or_inputs): 86 | return ( 87 | input_or_inputs 88 | | "ReifyWindows" >> ParDo(GroupByKey.ReifyWindows()) 89 | | "GroupByKey" >> _GroupByKeyOnly() 90 | | "GroupByWindow" >> _GroupAlsoByWindow(input_or_inputs.windowing)) 91 | 92 | 93 | def _get_overrides() -> List[PTransformOverride]: 94 | class CreateOverride(PTransformOverride): 95 | def matches(self, applied_ptransform): 96 | return applied_ptransform.transform.__class__ == Create 97 | 98 | def get_replacement_transform_for_applied_ptransform( 99 | self, applied_ptransform): 100 | # Use specialized streaming implementation. 101 | transform = _Create(applied_ptransform.transform.values) 102 | return transform 103 | 104 | class ReshuffleOverride(PTransformOverride): 105 | def matches(self, applied_ptransform): 106 | return applied_ptransform.transform.__class__ == Reshuffle 107 | 108 | def get_replacement_transform_for_applied_ptransform( 109 | self, applied_ptransform): 110 | # Use specialized streaming implementation. 
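            # (Illustrative sketch: each override produced by _get_overrides()
            # pairs a matches() predicate with a replacement factory; a runner
            # applies them through Beam's standard replacement machinery,
            # roughly:
            #     pipeline.replace_all(_get_overrides())
            # after which matched transforms expand to the placeholder
            # classes defined above.)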
111 | transform = _Reshuffle() 112 | return transform 113 | 114 | class ReadOverride(PTransformOverride): 115 | def matches(self, applied_ptransform): 116 | return applied_ptransform.transform.__class__ == Read 117 | 118 | def get_replacement_transform_for_applied_ptransform( 119 | self, applied_ptransform): 120 | # Use specialized streaming implementation. 121 | transform = _Read(applied_ptransform.transform.source) 122 | return transform 123 | 124 | class GroupByKeyOverride(PTransformOverride): 125 | def matches(self, applied_ptransform): 126 | return applied_ptransform.transform.__class__ == GroupByKey 127 | 128 | def get_replacement_transform_for_applied_ptransform( 129 | self, applied_ptransform): 130 | # Use specialized streaming implementation. 131 | transform = _GroupByKey() 132 | return transform 133 | 134 | return [ 135 | CreateOverride(), 136 | ReshuffleOverride(), 137 | ReadOverride(), 138 | GroupByKeyOverride(), 139 | ] 140 | -------------------------------------------------------------------------------- /ray_beam_runner/portability/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | -------------------------------------------------------------------------------- /ray_beam_runner/portability/context_management.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
#
import typing
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners.portability.fn_api_runner import execution as fn_execution
from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner import worker_handlers
from apache_beam.runners.portability.fn_api_runner.execution import PartitionableBuffer
from apache_beam.runners.portability.fn_api_runner.fn_runner import OutputTimers
from apache_beam.runners.portability.fn_api_runner.translations import DataOutput
from apache_beam.runners.portability.fn_api_runner.translations import TimerFamilyId
from apache_beam.runners.worker import bundle_processor
from apache_beam.utils import proto_utils

from ray_beam_runner.portability.execution import RayRunnerExecutionContext


class RayBundleContextManager:
    def __init__(
        self,
        execution_context: RayRunnerExecutionContext,
        stage: translations.Stage,
    ) -> None:
        self.execution_context = execution_context
        self.stage = stage
        # self.extract_bundle_inputs_and_outputs()
        self.bundle_uid = self.execution_context.next_uid()

        # Properties that are lazily initialized
        self._process_bundle_descriptor = (
            None
        )  # type: Optional[beam_fn_api_pb2.ProcessBundleDescriptor]
        self._worker_handlers = (
            None
        )  # type: Optional[List[worker_handlers.WorkerHandler]]
        # a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map
        # is built after self._process_bundle_descriptor is initialized.
        # This field can be used to tell whether the current bundle has timers.
        self._timer_coder_ids = None  # type: Optional[Dict[Tuple[str, str], str]]

    def __reduce__(self):
        data = (self.execution_context, self.stage)

        def deserializer(args):
            # The deserializer must *return* the reconstructed object;
            # without the return, unpickling would silently produce None.
            return RayBundleContextManager(args[0], args[1])

        return (deserializer, data)

    @property
    def worker_handlers(self) -> List[worker_handlers.WorkerHandler]:
        return []

    def data_api_service_descriptor(
        self,
    ) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
        return endpoints_pb2.ApiServiceDescriptor(url="fake")

    def state_api_service_descriptor(
        self,
    ) -> Optional[endpoints_pb2.ApiServiceDescriptor]:
        return None

    @property
    def process_bundle_descriptor(self) -> beam_fn_api_pb2.ProcessBundleDescriptor:
        if self._process_bundle_descriptor is None:
            self._process_bundle_descriptor = (
                beam_fn_api_pb2.ProcessBundleDescriptor.FromString(
                    self._build_process_bundle_descriptor()
                )
            )
            self._timer_coder_ids = (
                fn_execution.BundleContextManager._build_timer_coders_id_map(self)
            )
        return self._process_bundle_descriptor

    def _build_process_bundle_descriptor(self):
        # Cannot be invoked until *after* _extract_endpoints is called.
        # Always populate the timer_api_service_descriptor.
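        # (Illustrative note: the descriptor built below snapshots everything
        # an SDK worker needs to execute this stage -- the stage's transforms,
        # plus the pipeline-wide PCollections, coders, windowing strategies,
        # and environments, all keyed by id -- so a bundle can be processed
        # without access to the full pipeline graph.)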
100 | pbd = beam_fn_api_pb2.ProcessBundleDescriptor( 101 | id=self.bundle_uid, 102 | transforms={ 103 | transform.unique_name: transform for transform in self.stage.transforms 104 | }, 105 | pcollections=dict( 106 | self.execution_context.pipeline_components.pcollections.items() 107 | ), 108 | coders=dict(self.execution_context.pipeline_components.coders.items()), 109 | windowing_strategies=dict( 110 | self.execution_context.pipeline_components.windowing_strategies.items() 111 | ), 112 | environments=dict( 113 | self.execution_context.pipeline_components.environments.items() 114 | ), 115 | state_api_service_descriptor=self.state_api_service_descriptor(), 116 | timer_api_service_descriptor=self.data_api_service_descriptor(), 117 | ) 118 | 119 | return pbd.SerializeToString() 120 | 121 | def get_bundle_inputs_and_outputs( 122 | self, 123 | ) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[TimerFamilyId, bytes]]: 124 | """Returns maps of transform names to PCollection identifiers. 125 | 126 | Also mutates IO stages to point to the data ApiServiceDescriptor. 127 | 128 | Returns: 129 | A tuple of (data_input, data_output, expected_timer_output) dictionaries. 130 | `data_input` is a dictionary mapping (transform_name, output_name) to a 131 | PCollection buffer; `data_output` is a dictionary mapping 132 | (transform_name, output_name) to a PCollection ID. 133 | `expected_timer_output` is a dictionary mapping transform_id and 134 | timer family ID to a buffer id for timers. 135 | """ 136 | return self.transform_to_buffer_coder, self.data_output, self.stage_timers 137 | 138 | def setup(self): 139 | transform_to_buffer_coder: typing.Dict[str, typing.Tuple[bytes, str]] = {} 140 | data_output = {} # type: DataOutput 141 | expected_timer_output = {} # type: OutputTimers 142 | for transform in self.stage.transforms: 143 | if transform.spec.urn in ( 144 | bundle_processor.DATA_INPUT_URN, 145 | bundle_processor.DATA_OUTPUT_URN, 146 | ): 147 | pcoll_id = transform.spec.payload 148 | if transform.spec.urn == bundle_processor.DATA_INPUT_URN: 149 | coder_id = self.execution_context.data_channel_coders[ 150 | translations.only_element(transform.outputs.values()) 151 | ] 152 | if pcoll_id == translations.IMPULSE_BUFFER: 153 | pcoll_id = transform.unique_name.encode("utf8") 154 | self.execution_context.pcollection_buffers.put( 155 | pcoll_id, [self.execution_context.encoded_impulse_ref] 156 | ) 157 | else: 158 | pass 159 | transform_to_buffer_coder[transform.unique_name] = ( 160 | pcoll_id, 161 | self.execution_context.safe_coders.get(coder_id, coder_id), 162 | ) 163 | elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN: 164 | data_output[transform.unique_name] = pcoll_id 165 | coder_id = self.execution_context.data_channel_coders[ 166 | translations.only_element(transform.inputs.values()) 167 | ] 168 | else: 169 | raise NotImplementedError 170 | data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id) 171 | transform.spec.payload = data_spec.SerializeToString() 172 | elif transform.spec.urn in translations.PAR_DO_URNS: 173 | payload = proto_utils.parse_Bytes( 174 | transform.spec.payload, beam_runner_api_pb2.ParDoPayload 175 | ) 176 | for timer_family_id in payload.timer_family_specs.keys(): 177 | expected_timer_output[ 178 | (transform.unique_name, timer_family_id) 179 | ] = translations.create_buffer_id(timer_family_id, "timers") 180 | self.transform_to_buffer_coder, self.data_output, self.stage_timers = ( 181 | transform_to_buffer_coder, 182 | data_output, 183 | expected_timer_output, 
        )
--------------------------------------------------------------------------------
/ray_beam_runner/portability/execution.py:
--------------------------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Set of utilities for execution of a pipeline by the RayRunner."""

# mypy: disallow-untyped-defs

import collections
import dataclasses
import itertools
import logging
import random
import typing
from typing import List
from typing import Mapping
from typing import Optional
from typing import Generator

import ray

import apache_beam
from apache_beam import coders
from apache_beam.metrics import monitoring_infos
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability.fn_api_runner import execution as fn_execution
from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner import watermark_manager
from apache_beam.runners.portability.fn_api_runner import worker_handlers
from apache_beam.runners.worker import bundle_processor

from ray_beam_runner.portability.state import RayStateManager
from ray_beam_runner.serialization import register_protobuf_serializers

_LOGGER = logging.getLogger(__name__)


@ray.remote(num_returns="dynamic")
def ray_execute_bundle(
    runner_context: "RayRunnerExecutionContext",
    input_bundle: "Bundle",
    transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]],
    expected_outputs: translations.DataOutput,
    stage_timers: Mapping[translations.TimerFamilyId, bytes],
    instruction_request_repr: Mapping[str, typing.Any],
    dry_run=False,
) -> Generator:
    # The generator yields, in order: the serialized InstructionResponse,
    # the number of outputs, then alternating (pcollection, data) pairs,
    # then the number of delayed applications, then alternating
    # (pcollection, data) pairs.

    register_protobuf_serializers()
    instruction_request = beam_fn_api_pb2.InstructionRequest(
        instruction_id=instruction_request_repr["instruction_id"],
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_id=instruction_request_repr[
                "process_descriptor_id"
            ],
            cache_tokens=[instruction_request_repr["cache_token"]],
        ),
    )
    output_buffers: Mapping[
        typing.Union[str, translations.TimerFamilyId], list
    ] = collections.defaultdict(list)
    process_bundle_id =
instruction_request.instruction_id 83 | 84 | worker_handler = _get_worker_handler( 85 | runner_context, instruction_request_repr["process_descriptor_id"] 86 | ) 87 | 88 | _send_timers(worker_handler, input_bundle, stage_timers, process_bundle_id) 89 | 90 | input_data = { 91 | k: _fetch_decode_data( 92 | runner_context, 93 | _get_input_id(transform_buffer_coder[k][0], k), 94 | transform_buffer_coder[k][1], 95 | objrefs, 96 | ) 97 | for k, objrefs in input_bundle.input_data.items() 98 | } 99 | 100 | for transform_id, elements in input_data.items(): 101 | data_out = worker_handler.data_conn.output_stream( 102 | process_bundle_id, transform_id 103 | ) 104 | for byte_stream in elements: 105 | data_out.write(byte_stream) 106 | data_out.close() 107 | 108 | expect_reads: List[typing.Union[str, translations.TimerFamilyId]] = list( 109 | expected_outputs.keys() 110 | ) 111 | expect_reads.extend(list(stage_timers.keys())) 112 | 113 | result_future = worker_handler.control_conn.push(instruction_request) 114 | 115 | for output in worker_handler.data_conn.input_elements( 116 | process_bundle_id, 117 | expect_reads, 118 | abort_callback=lambda: ( 119 | result_future.is_done() and bool(result_future.get().error) 120 | ), 121 | ): 122 | if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run: 123 | output_buffers[ 124 | stage_timers[(output.transform_id, output.timer_family_id)] 125 | ].append(output.timers) 126 | if isinstance(output, beam_fn_api_pb2.Elements.Data) and not dry_run: 127 | output_buffers[expected_outputs[output.transform_id]].append(output.data) 128 | 129 | result: beam_fn_api_pb2.InstructionResponse = result_future.get() 130 | 131 | if result.process_bundle.requires_finalization: 132 | finalize_request = beam_fn_api_pb2.InstructionRequest( 133 | finalize_bundle=beam_fn_api_pb2.FinalizeBundleRequest( 134 | instruction_id=process_bundle_id 135 | ) 136 | ) 137 | finalize_response = worker_handler.control_conn.push(finalize_request).get() 138 | if finalize_response.error: 139 | raise RuntimeError(finalize_response.error) 140 | 141 | returns = [result] 142 | 143 | returns.append(len(output_buffers)) 144 | for pcoll, buffer in output_buffers.items(): 145 | returns.append(pcoll) 146 | returns.append(buffer) 147 | 148 | # Now we collect all the deferred inputs remaining from bundle execution. 
149 | # Deferred inputs can be: 150 | # - timers 151 | # - SDK-initiated deferred applications of root elements 152 | # - # TODO: Runner-initiated deferred applications of root elements 153 | process_bundle_descriptor = runner_context.worker_manager.process_bundle_descriptor( 154 | instruction_request_repr["process_descriptor_id"] 155 | ) 156 | delayed_applications = _retrieve_delayed_applications( 157 | result, 158 | process_bundle_descriptor, 159 | runner_context, 160 | ) 161 | 162 | returns.append(len(delayed_applications)) 163 | for pcoll, buffer in delayed_applications.items(): 164 | returns.append(pcoll) 165 | returns.append(buffer) 166 | 167 | for ret in returns: 168 | yield ret 169 | 170 | 171 | def _get_source_transform_name( 172 | process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, 173 | transform_id: str, 174 | input_id: str, 175 | ) -> str: 176 | """Find the name of the source PTransform that feeds into the given 177 | (transform_id, input_id).""" 178 | input_pcoll = process_bundle_descriptor.transforms[transform_id].inputs[input_id] 179 | for ptransform_id, ptransform in process_bundle_descriptor.transforms.items(): 180 | # The GrpcRead is directly followed by the SDF/Process. 181 | if ( 182 | ptransform.spec.urn == bundle_processor.DATA_INPUT_URN 183 | and input_pcoll in ptransform.outputs.values() 184 | ): 185 | return ptransform_id 186 | 187 | # The GrpcRead is followed by SDF/Truncate -> SDF/Process. 188 | # We need to traverse the TRUNCATE_SIZED_RESTRICTION node in order 189 | # to find the original source PTransform. 190 | if ( 191 | ptransform.spec.urn 192 | == common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn 193 | and input_pcoll in ptransform.outputs.values() 194 | ): 195 | input_pcoll_ = translations.only_element( 196 | process_bundle_descriptor.transforms[ptransform_id].inputs.values() 197 | ) 198 | for ( 199 | ptransform_id_2, 200 | ptransform_2, 201 | ) in process_bundle_descriptor.transforms.items(): 202 | if ( 203 | ptransform_2.spec.urn == bundle_processor.DATA_INPUT_URN 204 | and input_pcoll_ in ptransform_2.outputs.values() 205 | ): 206 | return ptransform_id_2 207 | 208 | raise RuntimeError("No IO transform feeds %s" % transform_id) 209 | 210 | 211 | def _retrieve_delayed_applications( 212 | bundle_result: beam_fn_api_pb2.InstructionResponse, 213 | process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, 214 | runner_context: "RayRunnerExecutionContext", 215 | ): 216 | """Extract delayed applications from a bundle run. 217 | 218 | A delayed application represents a user-initiated checkpoint, where user code 219 | delays the consumption of a data element to checkpoint the previous elements 220 | in a bundle. 221 | """ 222 | delayed_bundles = {} 223 | for delayed_application in bundle_result.process_bundle.residual_roots: 224 | # TODO(pabloem): Time delay needed for streaming. For now we'll ignore it. 
225 | # time_delay = delayed_application.requested_time_delay 226 | source_transform = _get_source_transform_name( 227 | process_bundle_descriptor, 228 | delayed_application.application.transform_id, 229 | delayed_application.application.input_id, 230 | ) 231 | 232 | if source_transform not in delayed_bundles: 233 | delayed_bundles[source_transform] = [] 234 | delayed_bundles[source_transform].append( 235 | delayed_application.application.element 236 | ) 237 | 238 | for consumer, data in delayed_bundles.items(): 239 | delayed_bundles[consumer] = [data] 240 | 241 | return delayed_bundles 242 | 243 | 244 | def _get_input_id(buffer_id, transform_name): 245 | """Get the 'buffer_id' for the input data we're retrieving. 246 | 247 | For most data, the buffer ID is as expected, but for IMPULSE readers, the 248 | buffer ID is the consumer name. 249 | """ 250 | if isinstance(buffer_id, bytes) and ( 251 | buffer_id.startswith(b"materialize") 252 | or buffer_id.startswith(b"timer") 253 | or buffer_id.startswith(b"group") 254 | ): 255 | buffer_id = buffer_id 256 | else: 257 | buffer_id = transform_name.encode("ascii") 258 | return buffer_id 259 | 260 | 261 | def _fetch_decode_data( 262 | runner_context: "RayRunnerExecutionContext", 263 | buffer_id: bytes, 264 | coder_id: str, 265 | data_references: List[ray.ObjectRef], 266 | ): 267 | """Fetch a PCollection's data and decode it.""" 268 | if buffer_id.startswith(b"group"): 269 | _, pcoll_id = translations.split_buffer_id(buffer_id) 270 | transform = runner_context.pipeline_components.transforms[pcoll_id] 271 | out_pcoll = runner_context.pipeline_components.pcollections[ 272 | translations.only_element(transform.outputs.values()) 273 | ] 274 | windowing_strategy = runner_context.pipeline_components.windowing_strategies[ 275 | out_pcoll.windowing_strategy_id 276 | ] 277 | postcoder = runner_context.pipeline_context.coders[coder_id] 278 | precoder = coders.WindowedValueCoder( 279 | coders.TupleCoder( 280 | ( 281 | postcoder.wrapped_value_coder._coders[0], 282 | postcoder.wrapped_value_coder._coders[1]._elem_coder, 283 | ) 284 | ), 285 | postcoder.window_coder, 286 | ) 287 | buffer = fn_execution.GroupingBuffer( 288 | pre_grouped_coder=precoder, 289 | post_grouped_coder=postcoder, 290 | windowing=apache_beam.Windowing.from_runner_api(windowing_strategy, None), 291 | ) 292 | else: 293 | buffer = fn_execution.ListBuffer( 294 | coder_impl=runner_context.pipeline_context.coders[coder_id].get_impl() 295 | ) 296 | 297 | for block in ray.get(data_references): 298 | # TODO(pabloem): Stop using ListBuffer, and use different 299 | # buffers to pass data to Beam. 
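        # (Descriptive note: `buffer` is a GroupingBuffer for GBK outputs --
        # it decodes and regroups elements by key and window -- and a plain
        # ListBuffer otherwise; either way, append() takes encoded element
        # bytes.)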
300 | for elm in block: 301 | buffer.append(elm) 302 | return buffer 303 | 304 | 305 | def _send_timers( 306 | worker_handler: worker_handlers.WorkerHandler, 307 | input_bundle: "Bundle", 308 | stage_timers: Mapping[translations.TimerFamilyId, bytes], 309 | process_bundle_id, 310 | ) -> None: 311 | """Pass timers to the worker for processing.""" 312 | for transform_id, timer_family_id in stage_timers.keys(): 313 | timer_out = worker_handler.data_conn.output_timer_stream( 314 | process_bundle_id, transform_id, timer_family_id 315 | ) 316 | for timer in input_bundle.input_timers.get((transform_id, timer_family_id), []): 317 | timer_out.write(timer) 318 | timer_out.close() 319 | 320 | 321 | @ray.remote 322 | class _RayRunnerStats: 323 | def __init__(self): 324 | self._bundle_uid = 0 325 | 326 | def next_bundle(self): 327 | self._bundle_uid += 1 328 | return self._bundle_uid 329 | 330 | 331 | class RayWorkerHandlerManager: 332 | def __init__(self): 333 | self._process_bundle_descriptors = {} 334 | 335 | def register_process_bundle_descriptor(self, process_bundle_descriptor): 336 | ray_process_bundle_descriptor = process_bundle_descriptor 337 | self._process_bundle_descriptors[ 338 | ray_process_bundle_descriptor.id 339 | ] = ray_process_bundle_descriptor 340 | 341 | def process_bundle_descriptor(self, id): 342 | return self._process_bundle_descriptors[id] 343 | 344 | 345 | class RayStage(translations.Stage): 346 | @staticmethod 347 | def from_Stage(stage: translations.Stage): 348 | return RayStage( 349 | stage.name, 350 | stage.transforms, 351 | stage.downstream_side_inputs, 352 | # stage.must_follow, 353 | [], 354 | stage.parent, 355 | stage.environment, 356 | stage.forced_root, 357 | ) 358 | 359 | 360 | class PcollectionBufferManager: 361 | def __init__(self): 362 | self.buffers = collections.defaultdict(list) 363 | 364 | def put(self, pcoll, data_refs: List[ray.ObjectRef]): 365 | self.buffers[pcoll].extend(data_refs) 366 | 367 | def get(self, pcoll) -> List[ray.ObjectRef]: 368 | return self.buffers[pcoll] 369 | 370 | def clear(self, pcoll): 371 | self.buffers[pcoll].clear() 372 | 373 | 374 | @ray.remote 375 | class RayWatermarkManager(watermark_manager.WatermarkManager): 376 | def __init__(self): 377 | # the original WatermarkManager performs a lot of computation 378 | # in its __init__ method. Because Ray calls __init__ whenever 379 | # it deserializes an object, we'll move its setup elsewhere. 
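        # (Sketch of the intended call pattern, matching
        # RayRunnerExecutionContext.watermark_manager below:
        #     wm = RayWatermarkManager.remote()
        #     wm.setup.remote(stages)  # cheap and idempotent; see setup()
        # so deserialization stays fast and the heavy setup runs only once.)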
380 | self._initialized = False 381 | self._pcollections_by_name = {} 382 | self._stages_by_name = {} 383 | 384 | def setup(self, stages): 385 | if self._initialized: 386 | return 387 | logging.debug("initialized the RayWatermarkManager") 388 | self._initialized = True 389 | watermark_manager.WatermarkManager.setup(self, stages) 390 | 391 | 392 | class RayRunnerExecutionContext(object): 393 | def __init__( 394 | self, 395 | stages: List[translations.Stage], 396 | pipeline_components: beam_runner_api_pb2.Components, 397 | safe_coders: translations.SafeCoderMapping, 398 | data_channel_coders: Mapping[str, str], 399 | state_servicer: Optional[RayStateManager] = None, 400 | worker_manager: Optional[RayWorkerHandlerManager] = None, 401 | pcollection_buffers: PcollectionBufferManager = None, 402 | ) -> None: 403 | self.pcollection_buffers = pcollection_buffers or PcollectionBufferManager() 404 | self.state_servicer = state_servicer or RayStateManager() 405 | self.stages = [ 406 | RayStage.from_Stage(s) if not isinstance(s, RayStage) else s for s in stages 407 | ] 408 | self.side_input_descriptors_by_stage = ( 409 | fn_execution.FnApiRunnerExecutionContext._build_data_side_inputs_map(stages) 410 | ) 411 | self.pipeline_components = pipeline_components 412 | self.safe_coders = safe_coders 413 | self.data_channel_coders = data_channel_coders 414 | 415 | self.input_transform_to_buffer_id = { 416 | t.unique_name: bytes(t.spec.payload) 417 | for s in stages 418 | for t in s.transforms 419 | if t.spec.urn == bundle_processor.DATA_INPUT_URN 420 | } 421 | self._watermark_manager = RayWatermarkManager.remote() 422 | self.pipeline_context = pipeline_context.PipelineContext(pipeline_components) 423 | self.safe_windowing_strategies = { 424 | # TODO: Enable safe_windowing_strategy after 425 | # figuring out how to pickle the function. 426 | # id: self._make_safe_windowing_strategy(id) 427 | id: id 428 | for id in pipeline_components.windowing_strategies.keys() 429 | } 430 | self.stats = _RayRunnerStats.remote() 431 | self._uid = 0 432 | self.worker_manager = worker_manager or RayWorkerHandlerManager() 433 | self.timer_coder_ids = self._build_timer_coders_id_map() 434 | self.encoded_impulse_ref = ray.put([fn_execution.ENCODED_IMPULSE_VALUE]) 435 | 436 | @property 437 | def watermark_manager(self): 438 | # We don't need to wait for this line to execute with ray.get, 439 | # because any further calls to the watermark manager actor will 440 | # have to wait for it. 441 | self._watermark_manager.setup.remote(self.stages) 442 | return self._watermark_manager 443 | 444 | @staticmethod 445 | def next_uid(): 446 | # TODO(pabloem): Use stats actor for UIDs. 
447 | # return str(ray.get(self.stats.next_bundle.remote())) 448 | # self._uid += 1 449 | return str(random.randint(0, 11111111)) 450 | 451 | def _build_timer_coders_id_map(self): 452 | from apache_beam.utils import proto_utils 453 | 454 | timer_coder_ids = {} 455 | for ( 456 | transform_id, 457 | transform_proto, 458 | ) in self.pipeline_components.transforms.items(): 459 | if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn: 460 | pardo_payload = proto_utils.parse_Bytes( 461 | transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload 462 | ) 463 | for id, timer_family_spec in pardo_payload.timer_family_specs.items(): 464 | timer_coder_ids[ 465 | (transform_id, id) 466 | ] = timer_family_spec.timer_family_coder_id 467 | return timer_coder_ids 468 | 469 | def commit_side_inputs_to_state(self, data_side_input: translations.DataSideInput): 470 | """ 471 | Store side inputs in the state manager so that they can be accessed by workers. 472 | """ 473 | for (consuming_transform_id, tag), ( 474 | buffer_id, 475 | func_spec, 476 | ) in data_side_input.items(): 477 | _, pcoll_id = translations.split_buffer_id(buffer_id) 478 | value_coder = self.pipeline_context.coders[ 479 | self.safe_coders[self.data_channel_coders[pcoll_id]] 480 | ] 481 | 482 | elements_by_window = fn_execution.WindowGroupingBuffer( 483 | func_spec, value_coder 484 | ) 485 | 486 | # TODO: Fix this 487 | pcoll_buffer = ray.get(self.pcollection_buffers.get(buffer_id)) 488 | for bundle_items in pcoll_buffer: 489 | for bundle_item in bundle_items: 490 | elements_by_window.append(bundle_item) 491 | 492 | futures = [] 493 | if func_spec.urn == common_urns.side_inputs.ITERABLE.urn: 494 | for _, window, elements_data in elements_by_window.encoded_items(): 495 | state_key = beam_fn_api_pb2.StateKey( 496 | iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput( 497 | transform_id=consuming_transform_id, 498 | side_input_id=tag, 499 | window=window, 500 | ) 501 | ) 502 | futures.append( 503 | self.state_servicer.append_raw( 504 | state_key, elements_data 505 | )._object_ref 506 | ) 507 | elif func_spec.urn == common_urns.side_inputs.MULTIMAP.urn: 508 | for key, window, elements_data in elements_by_window.encoded_items(): 509 | state_key = beam_fn_api_pb2.StateKey( 510 | multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput( 511 | transform_id=consuming_transform_id, 512 | side_input_id=tag, 513 | window=window, 514 | key=key, 515 | ) 516 | ) 517 | futures.append( 518 | self.state_servicer.append_raw( 519 | state_key, elements_data 520 | )._object_ref 521 | ) 522 | else: 523 | raise ValueError("Unknown access pattern: '%s'" % func_spec.urn) 524 | 525 | ray.wait(futures, num_returns=len(futures)) 526 | 527 | 528 | def merge_stage_results( 529 | previous_result: beam_fn_api_pb2.InstructionResponse, 530 | last_result: beam_fn_api_pb2.InstructionResponse, 531 | ) -> beam_fn_api_pb2.InstructionResponse: 532 | """Merge InstructionResponse objects from executions of same stage bundles. 533 | 534 | This method is used to produce a global per-stage result object with 535 | aggregated metrics and results. 
536 | """ 537 | return ( 538 | last_result 539 | if previous_result is None 540 | else beam_fn_api_pb2.InstructionResponse( 541 | process_bundle=beam_fn_api_pb2.ProcessBundleResponse( 542 | monitoring_infos=monitoring_infos.consolidate( 543 | itertools.chain( 544 | previous_result.process_bundle.monitoring_infos, 545 | last_result.process_bundle.monitoring_infos, 546 | ) 547 | ) 548 | ), 549 | error=previous_result.error or last_result.error, 550 | ) 551 | ) 552 | 553 | 554 | def _get_worker_handler( 555 | runner_context: RayRunnerExecutionContext, bundle_descriptor_id 556 | ) -> worker_handlers.WorkerHandler: 557 | worker_handler = worker_handlers.EmbeddedWorkerHandler( 558 | None, # Unnecessary payload. 559 | runner_context.state_servicer, 560 | None, # Unnecessary provision info. 561 | runner_context.worker_manager, 562 | ) 563 | worker_handler.worker.bundle_processor_cache.register( 564 | runner_context.worker_manager.process_bundle_descriptor(bundle_descriptor_id) 565 | ) 566 | return worker_handler 567 | 568 | 569 | @dataclasses.dataclass 570 | class Bundle: 571 | input_timers: Mapping[translations.TimerFamilyId, fn_execution.PartitionableBuffer] 572 | input_data: Mapping[str, List[ray.ObjectRef]] 573 | -------------------------------------------------------------------------------- /ray_beam_runner/portability/execution_test.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | import hamcrest as hc 19 | import unittest 20 | 21 | import ray 22 | 23 | from apache_beam.portability.api import beam_fn_api_pb2 24 | from ray_beam_runner.portability.state import RayStateManager 25 | 26 | 27 | class StateHandlerTest(unittest.TestCase): 28 | SAMPLE_STATE_KEY = beam_fn_api_pb2.StateKey() 29 | SAMPLE_INPUT_DATA = [b"bobby" b"tables", b"drop table", b"where table_name > 12345"] 30 | 31 | @classmethod 32 | def setUpClass(cls) -> None: 33 | if not ray.is_initialized(): 34 | ray.init() 35 | 36 | @classmethod 37 | def tearDownClass(cls) -> None: 38 | ray.shutdown() 39 | 40 | def test_data_stored_properly(self): 41 | sh = RayStateManager() 42 | with sh.process_instruction_id("anyinstruction"): 43 | for data in StateHandlerTest.SAMPLE_INPUT_DATA: 44 | sh.append_raw(StateHandlerTest.SAMPLE_STATE_KEY, data) 45 | 46 | with sh.process_instruction_id("otherinstruction"): 47 | continuation_token = None 48 | all_data = [] 49 | while True: 50 | data, continuation_token = sh.get_raw( 51 | StateHandlerTest.SAMPLE_STATE_KEY, continuation_token 52 | ) 53 | all_data.append(data) 54 | if continuation_token is None: 55 | break 56 | 57 | hc.assert_that( 58 | all_data, hc.contains_exactly(*StateHandlerTest.SAMPLE_INPUT_DATA) 59 | ) 60 | 61 | def test_fresh_key(self): 62 | sh = RayStateManager() 63 | with sh.process_instruction_id("anyinstruction"): 64 | data, continuation_token = sh.get_raw(StateHandlerTest.SAMPLE_STATE_KEY) 65 | hc.assert_that(continuation_token, hc.equal_to(None)) 66 | hc.assert_that(data, hc.equal_to(b"")) 67 | -------------------------------------------------------------------------------- /ray_beam_runner/portability/ray_fn_runner.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | """A PipelineRunner using the SDK harness.""" 19 | # pytype: skip-file 20 | # mypy: check-untyped-defs 21 | import collections 22 | import copy 23 | import logging 24 | import typing 25 | from typing import Dict 26 | from typing import List 27 | from typing import Mapping 28 | from typing import Optional 29 | from typing import Tuple 30 | from typing import Union 31 | from typing import MutableMapping 32 | from typing import Iterable 33 | 34 | from apache_beam.coders.coder_impl import create_OutputStream 35 | from apache_beam.options import pipeline_options 36 | from apache_beam.options.value_provider import RuntimeValueProvider 37 | from apache_beam.pipeline import Pipeline 38 | from apache_beam.portability import common_urns 39 | from apache_beam.portability.api import beam_fn_api_pb2 40 | from apache_beam.portability.api import beam_runner_api_pb2 41 | from apache_beam.runners import runner 42 | from apache_beam.runners.common import group_by_key_input_visitor 43 | from apache_beam.runners.portability.fn_api_runner import execution 44 | from apache_beam.runners.portability.fn_api_runner import fn_runner 45 | from apache_beam.runners.portability.fn_api_runner import translations 46 | from apache_beam.runners.portability.fn_api_runner.execution import ListBuffer 47 | from apache_beam.transforms import environments 48 | from apache_beam.utils import proto_utils, timestamp 49 | from apache_beam.metrics import metric 50 | from apache_beam.metrics.execution import MetricResult 51 | from apache_beam.runners.portability import portable_metrics 52 | from apache_beam.portability.api import metrics_pb2 53 | 54 | import ray 55 | from ray_beam_runner.portability.context_management import RayBundleContextManager 56 | from ray_beam_runner.portability.execution import Bundle, _get_input_id 57 | from ray_beam_runner.portability.execution import ( 58 | ray_execute_bundle, 59 | merge_stage_results, 60 | ) 61 | from ray_beam_runner.portability.execution import RayRunnerExecutionContext 62 | from ray_beam_runner.serialization import register_protobuf_serializers 63 | 64 | _LOGGER = logging.getLogger(__name__) 65 | 66 | # This module is experimental. No backwards-compatibility guarantees. 
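# A minimal usage sketch (illustrative; mirrors examples/word_count_metrics.py
# and assumes Ray is initialized by the caller):
#
#     import ray
#     import apache_beam as beam
#     from ray_beam_runner.portability.ray_fn_runner import RayFnApiRunner
#
#     ray.init()
#     with beam.Pipeline(runner=RayFnApiRunner()) as p:
#         _ = p | beam.Create(["a", "b", "a"]) | beam.Map(print)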


def _setup_options(options: pipeline_options.PipelineOptions):
    """Perform any necessary checks and updates to input pipeline options."""

    # TODO(pabloem): Add input pipeline options
    RuntimeValueProvider.set_runtime_options({})

    experiments = options.view_as(pipeline_options.DebugOptions).experiments or []
    if "beam_fn_api" not in experiments:
        experiments.append("beam_fn_api")
    options.view_as(pipeline_options.DebugOptions).experiments = experiments


def _check_supported_requirements(
    pipeline_proto: beam_runner_api_pb2.Pipeline,
    supported_requirements: typing.Iterable[str],
):
    """Check that the input pipeline does not have unsupported requirements."""
    for requirement in pipeline_proto.requirements:
        if requirement not in supported_requirements:
            raise ValueError(
                "Unable to run pipeline with requirement: %s" % requirement
            )
    for transform in pipeline_proto.components.transforms.values():
        if transform.spec.urn == common_urns.primitives.TEST_STREAM.urn:
            raise NotImplementedError(transform.spec.urn)
        elif transform.spec.urn in translations.PAR_DO_URNS:
            payload = proto_utils.parse_Bytes(
                transform.spec.payload, beam_runner_api_pb2.ParDoPayload
            )
            for timer in payload.timer_family_specs.values():
                if timer.time_domain != beam_runner_api_pb2.TimeDomain.EVENT_TIME:
                    raise NotImplementedError(timer.time_domain)


def _pipeline_checks(
    pipeline: Pipeline,
    options: pipeline_options.PipelineOptions,
    supported_requirements: typing.Iterable[str],
):
    # This is sometimes needed if type checking is disabled
    # to enforce that the inputs (and outputs) of GroupByKey operations
    # are known to be KVs.
    pipeline.visit(
        group_by_key_input_visitor(
            not options.view_as(
                pipeline_options.TypeOptions
            ).allow_non_deterministic_key_coders
        )
    )

    pipeline_proto = pipeline.to_runner_api(
        default_environment=environments.EmbeddedPythonEnvironment.default()
    )
    fn_runner.FnApiRunner._validate_requirements(None, pipeline_proto)

    _check_supported_requirements(pipeline_proto, supported_requirements)
    return pipeline_proto


class RayFnApiRunner(runner.PipelineRunner):
    def __init__(
        self,
        is_drain=False,
    ) -> None:

        """Creates a new Ray Runner instance.

        Args:
          is_drain: whether to expand the SDF graph in drain mode.
        """
        super().__init__()
        # TODO: figure out if this is necessary (probably, later)
        self._progress_frequency = None
        self._cache_token_generator = fn_runner.FnApiRunner.get_cache_token_generator()
        self._is_drain = is_drain

    @staticmethod
    def supported_requirements():
        # type: () -> Tuple[str, ...]
        return (
            common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn,
            common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn,
            common_urns.requirements.REQUIRES_SPLITTABLE_DOFN.urn,
        )

    def run_pipeline(
        self, pipeline: Pipeline, options: pipeline_options.PipelineOptions
    ) -> "RayRunnerResult":

        # Check and set up input pipeline options
        _setup_options(options)

        # Check pipeline and convert into protocol buffer representation
        pipeline_proto = _pipeline_checks(
            pipeline, options, self.supported_requirements()
        )

        # Take the protocol buffer representation of the user's pipeline, and
        # apply optimizations.
        stage_context, stages = translations.create_and_optimize_stages(
            copy.deepcopy(pipeline_proto),
            phases=[
                # This is a list of transformations and optimizations to apply
                # to a pipeline.
                translations.annotate_downstream_side_inputs,
                translations.fix_side_input_pcoll_coders,
                translations.pack_combiners,
                translations.lift_combiners,
                translations.expand_sdf,
                translations.expand_gbk,
                translations.sink_flattens,
                translations.greedily_fuse,
                translations.read_to_impulse,
                translations.impulse_to_input,
                translations.sort_stages,
                translations.setup_timer_mapping,
                translations.populate_data_channel_coders,
            ],
            known_runner_urns=frozenset(
                [
                    common_urns.primitives.FLATTEN.urn,
                    common_urns.primitives.GROUP_BY_KEY.urn,
                ]
            ),
            use_state_iterables=False,
            is_drain=self._is_drain,
        )
        return self.execute_pipeline(stage_context, stages)

    def execute_pipeline(
        self,
        stage_context: translations.TransformContext,
        stages: List[translations.Stage],
    ) -> "RayRunnerResult":
        """Execute a pipeline represented by a list of stages and a context."""
        logging.info("Starting pipeline of %d stages.", len(stages))

        register_protobuf_serializers()
        runner_execution_context = RayRunnerExecutionContext(
            stages,
            stage_context.components,
            stage_context.safe_coders,
            stage_context.data_channel_coders,
        )

        # This queue holds 'bundles' that are ready to be processed.
        queue = collections.deque()

        # Per-stage metrics.
        monitoring_infos_by_stage: MutableMapping[
            str, Iterable["metrics_pb2.MonitoringInfo"]
        ] = {}

        for stage in stages:
            bundle_ctx = RayBundleContextManager(runner_execution_context, stage)
            result = self._run_stage(runner_execution_context, bundle_ctx, queue)
            monitoring_infos_by_stage[
                bundle_ctx.stage.name
            ] = result.process_bundle.monitoring_infos

        return RayRunnerResult(runner.PipelineState.DONE, monitoring_infos_by_stage)

    def _run_stage(
        self,
        runner_execution_context: RayRunnerExecutionContext,
        bundle_context_manager: RayBundleContextManager,
        ready_bundles: collections.deque,
    ) -> beam_fn_api_pb2.InstructionResponse:

        """Run an individual stage.

        Args:
          runner_execution_context: An object containing execution information for
            the pipeline.
          bundle_context_manager (RayBundleContextManager): A description of
            the stage to execute, and its context.
250 | """ 251 | bundle_context_manager.setup() 252 | runner_execution_context.worker_manager.register_process_bundle_descriptor( 253 | bundle_context_manager.process_bundle_descriptor 254 | ) 255 | input_timers: Mapping[ 256 | translations.TimerFamilyId, execution.PartitionableBuffer 257 | ] = {} 258 | 259 | input_data = { 260 | k: runner_execution_context.pcollection_buffers.get( 261 | _get_input_id(bundle_context_manager.transform_to_buffer_coder[k][0], k) 262 | ) 263 | for k in bundle_context_manager.transform_to_buffer_coder 264 | } 265 | 266 | final_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse] 267 | 268 | while True: 269 | ( 270 | last_result, 271 | fired_timers, 272 | delayed_applications, 273 | bundle_outputs, 274 | ) = self._run_bundle( 275 | runner_execution_context, 276 | bundle_context_manager, 277 | Bundle(input_timers=input_timers, input_data=input_data), 278 | ) 279 | 280 | final_result = merge_stage_results(final_result, last_result) 281 | if not delayed_applications and not fired_timers: 282 | break 283 | else: 284 | # TODO: Enable following assertion after watermarking is implemented 285 | # assert (ray.get( 286 | # runner_execution_context.watermark_manager 287 | # .get_stage_node.remote( 288 | # bundle_context_manager.stage.name)).output_watermark() 289 | # < timestamp.MAX_TIMESTAMP), ( 290 | # 'wrong timestamp for %s. ' 291 | # % ray.get( 292 | # runner_execution_context.watermark_manager 293 | # .get_stage_node.remote( 294 | # bundle_context_manager.stage.name))) 295 | input_data = delayed_applications 296 | input_timers = fired_timers 297 | 298 | # Store the required downstream side inputs into state so it is accessible 299 | # for the worker when it runs bundles that consume this stage's output. 300 | data_side_input = runner_execution_context.side_input_descriptors_by_stage.get( 301 | bundle_context_manager.stage.name, {} 302 | ) 303 | runner_execution_context.commit_side_inputs_to_state(data_side_input) 304 | 305 | return final_result 306 | 307 | def _run_bundle( 308 | self, 309 | runner_execution_context: RayRunnerExecutionContext, 310 | bundle_context_manager: RayBundleContextManager, 311 | input_bundle: Bundle, 312 | ) -> Tuple[ 313 | beam_fn_api_pb2.InstructionResponse, 314 | Dict[translations.TimerFamilyId, ListBuffer], 315 | Mapping[str, ray.ObjectRef], 316 | List[Union[str, translations.TimerFamilyId]], 317 | ]: 318 | """Execute a bundle, and return a result object, and deferred inputs.""" 319 | ( 320 | transform_to_buffer_coder, 321 | data_output, 322 | stage_timers, 323 | ) = bundle_context_manager.get_bundle_inputs_and_outputs() 324 | 325 | cache_token_generator = fn_runner.FnApiRunner.get_cache_token_generator( 326 | static=False 327 | ) 328 | 329 | process_bundle_descriptor = bundle_context_manager.process_bundle_descriptor 330 | 331 | # TODO(pabloem): Are there two different IDs? the Bundle ID and PBD ID? 
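        # (Descriptive note: the remote generator consumed below yields, in
        # order, the InstructionResponse, the number of output buffers,
        # alternating (pcoll, data_ref) pairs, the number of delayed
        # applications, and alternating (pcoll, data_ref) pairs; the loops
        # that follow unpack exactly that sequence.)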
332 | process_bundle_id = "bundle_%s" % process_bundle_descriptor.id 333 | 334 | pbd_id = process_bundle_descriptor.id 335 | result_generator_ref = ray_execute_bundle.remote( 336 | runner_execution_context, 337 | input_bundle, 338 | transform_to_buffer_coder, 339 | data_output, 340 | stage_timers, 341 | instruction_request_repr={ 342 | "instruction_id": process_bundle_id, 343 | "process_descriptor_id": pbd_id, 344 | "cache_token": next(cache_token_generator), 345 | }, 346 | ) 347 | result_generator = iter(ray.get(result_generator_ref)) 348 | result = ray.get(next(result_generator)) 349 | 350 | output = [] 351 | num_outputs = ray.get(next(result_generator)) 352 | for _ in range(num_outputs): 353 | pcoll = ray.get(next(result_generator)) 354 | data_ref = next(result_generator) 355 | output.append(pcoll) 356 | runner_execution_context.pcollection_buffers.put(pcoll, [data_ref]) 357 | 358 | delayed_applications = {} 359 | num_delayed_applications = ray.get(next(result_generator)) 360 | for _ in range(num_delayed_applications): 361 | pcoll = ray.get(next(result_generator)) 362 | data_ref = next(result_generator) 363 | delayed_applications[pcoll] = data_ref 364 | runner_execution_context.pcollection_buffers.put(pcoll, [data_ref]) 365 | 366 | ( 367 | watermarks_by_transform_and_timer_family, 368 | newly_set_timers, 369 | ) = self._collect_written_timers(bundle_context_manager) 370 | 371 | # TODO(pabloem): Add support for splitting of results. 372 | 373 | # After collecting deferred inputs, we 'pad' the structure with empty 374 | # buffers for other expected inputs. 375 | # if deferred_inputs or newly_set_timers: 376 | # # The worker will be waiting on these inputs as well. 377 | # for other_input in data_input: 378 | # if other_input not in deferred_inputs: 379 | # deferred_inputs[other_input] = ListBuffer( 380 | # coder_impl=bundle_context_manager.get_input_coder_impl( 381 | # other_input)) 382 | 383 | return result, newly_set_timers, delayed_applications, output 384 | 385 | @staticmethod 386 | def _collect_written_timers( 387 | bundle_context_manager: RayBundleContextManager, 388 | ) -> Tuple[ 389 | Dict[translations.TimerFamilyId, timestamp.Timestamp], 390 | Mapping[translations.TimerFamilyId, execution.PartitionableBuffer], 391 | ]: 392 | """Review output buffers, and collect written timers. 393 | This function reviews a stage that has just been run. The stage will have 394 | written timers to its output buffers. The function then takes the timers, 395 | and adds them to the `newly_set_timers` dictionary, and the 396 | timer_watermark_data dictionary. 397 | The function then returns the following two elements in a tuple: 398 | - timer_watermark_data: A dictionary mapping timer family to upcoming 399 | timestamp to fire. 400 | - newly_set_timers: A dictionary mapping timer family to timer buffers 401 | to be passed to the SDK upon firing. 
402 | """ 403 | timer_watermark_data = {} 404 | newly_set_timers = {} 405 | 406 | execution_context = bundle_context_manager.execution_context 407 | buffer_manager = execution_context.pcollection_buffers 408 | 409 | for ( 410 | transform_id, 411 | timer_family_id, 412 | ), buffer_id in bundle_context_manager.stage_timers.items(): 413 | timer_buffer = buffer_manager.get(buffer_id) 414 | 415 | coder_id = bundle_context_manager._timer_coder_ids[ 416 | (transform_id, timer_family_id) 417 | ] 418 | 419 | coder = execution_context.pipeline_context.coders[coder_id] 420 | timer_coder_impl = coder.get_impl() 421 | 422 | timers_by_key_tag_and_window = {} 423 | if len(timer_buffer) >= 1: 424 | written_timers = ray.get(timer_buffer[0]) 425 | # clear the timer buffer 426 | buffer_manager.clear(buffer_id) 427 | 428 | # deduplicate updates to the same timer 429 | for elements_timers in written_timers: 430 | for decoded_timer in timer_coder_impl.decode_all(elements_timers): 431 | key_tag_win = ( 432 | decoded_timer.user_key, 433 | decoded_timer.dynamic_timer_tag, 434 | decoded_timer.windows[0], 435 | ) 436 | if not decoded_timer.clear_bit: 437 | timers_by_key_tag_and_window[key_tag_win] = decoded_timer 438 | elif ( 439 | decoded_timer.clear_bit 440 | and key_tag_win in timers_by_key_tag_and_window 441 | ): 442 | del timers_by_key_tag_and_window[key_tag_win] 443 | if not timers_by_key_tag_and_window: 444 | continue 445 | 446 | out = create_OutputStream() 447 | for decoded_timer in timers_by_key_tag_and_window.values(): 448 | timer_coder_impl.encode_to_stream(decoded_timer, out, True) 449 | timer_watermark_data[(transform_id, timer_family_id)] = min( 450 | timer_watermark_data.get( 451 | (transform_id, timer_family_id), timestamp.MAX_TIMESTAMP 452 | ), 453 | decoded_timer.hold_timestamp, 454 | ) 455 | 456 | buf = ListBuffer(coder_impl=timer_coder_impl) 457 | buf.append(out.get()) 458 | newly_set_timers[(transform_id, timer_family_id)] = buf 459 | return timer_watermark_data, newly_set_timers 460 | 461 | 462 | class FnApiMetrics(metric.MetricResults): 463 | def __init__(self, step_monitoring_infos, user_metrics_only=True): 464 | """Used for querying metrics from the PipelineResult object. 465 | step_monitoring_infos: Per step metrics specified as MonitoringInfos. 466 | user_metrics_only: If true, includes user metrics only. 
467 | """ 468 | self._counters = {} 469 | self._distributions = {} 470 | self._gauges = {} 471 | self._user_metrics_only = user_metrics_only 472 | self._monitoring_infos = step_monitoring_infos 473 | 474 | for smi in step_monitoring_infos.values(): 475 | counters, distributions, gauges = portable_metrics.from_monitoring_infos( 476 | smi, user_metrics_only 477 | ) 478 | self._counters.update(counters) 479 | self._distributions.update(distributions) 480 | self._gauges.update(gauges) 481 | 482 | def query(self, filter=None): 483 | counters = [ 484 | MetricResult(k, v, v) 485 | for k, v in self._counters.items() 486 | if self.matches(filter, k) 487 | ] 488 | distributions = [ 489 | MetricResult(k, v, v) 490 | for k, v in self._distributions.items() 491 | if self.matches(filter, k) 492 | ] 493 | gauges = [ 494 | MetricResult(k, v, v) 495 | for k, v in self._gauges.items() 496 | if self.matches(filter, k) 497 | ] 498 | 499 | return { 500 | self.COUNTERS: counters, 501 | self.DISTRIBUTIONS: distributions, 502 | self.GAUGES: gauges, 503 | } 504 | 505 | def monitoring_infos(self): 506 | # type: () -> List[metrics_pb2.MonitoringInfo] 507 | return [item for sublist in self._monitoring_infos.values() for item in sublist] 508 | 509 | 510 | class RayRunnerResult(runner.PipelineResult): 511 | def __init__(self, state, monitoring_infos_by_stage): 512 | super().__init__(state) 513 | self._monitoring_infos_by_stage = monitoring_infos_by_stage 514 | self._metrics = None 515 | self._monitoring_metrics = None 516 | 517 | def wait_until_finish(self, duration=None): 518 | return None 519 | 520 | def metrics(self): 521 | """Returns a queryable object including user metrics only.""" 522 | if self._metrics is None: 523 | self._metrics = FnApiMetrics( 524 | self._monitoring_infos_by_stage, user_metrics_only=True 525 | ) 526 | return self._metrics 527 | 528 | def monitoring_metrics(self): 529 | """Returns a queryable object including all metrics.""" 530 | if self._monitoring_metrics is None: 531 | self._monitoring_metrics = FnApiMetrics( 532 | self._monitoring_infos_by_stage, user_metrics_only=False 533 | ) 534 | return self._monitoring_metrics 535 | -------------------------------------------------------------------------------- /ray_beam_runner/portability/state.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | """Library for streaming state management for Ray Beam Runner.""" 19 | 20 | import collections 21 | import contextlib 22 | from typing import Optional, Tuple, Iterator, TypeVar 23 | 24 | import ray 25 | from ray import ObjectRef 26 | from apache_beam.portability.api import beam_fn_api_pb2 27 | from apache_beam.runners.worker import sdk_worker 28 | 29 | T = TypeVar("T") 30 | 31 | 32 | class RayFuture(sdk_worker._Future[T]): 33 | """Wraps a ray ObjectRef in a beam sdk_worker._Future""" 34 | 35 | def __init__(self, object_ref): 36 | # type: (ObjectRef[T]) -> None 37 | self._object_ref: ObjectRef[T] = object_ref 38 | 39 | def wait(self, timeout=None): 40 | # type: (Optional[float]) -> bool 41 | try: 42 | # TODO: Is ray.get slower than ray.wait if we don't need the return value? 43 | ray.get(self._object_ref, timeout=timeout) 44 | # 45 | return True 46 | except ray.GetTimeoutError: 47 | return False 48 | 49 | def get(self, timeout=None): 50 | # type: (Optional[float]) -> T 51 | return ray.get(self._object_ref, timeout=timeout) 52 | 53 | def set(self, _value): 54 | # type: (T) -> sdk_worker._Future[T] 55 | raise NotImplementedError() 56 | 57 | 58 | @ray.remote 59 | class _ActorStateManager: 60 | def __init__(self): 61 | self._data = collections.defaultdict(lambda: []) 62 | 63 | def get_raw( 64 | self, 65 | state_key: str, 66 | continuation_token: Optional[bytes] = None, 67 | ) -> Tuple[bytes, Optional[bytes]]: 68 | if continuation_token: 69 | continuation_token = int(continuation_token) 70 | else: 71 | continuation_token = 0 72 | 73 | full_state = self._data[state_key] 74 | if len(full_state) == continuation_token: 75 | return b"", None 76 | 77 | if continuation_token + 1 == len(full_state): 78 | next_cont_token = None 79 | else: 80 | next_cont_token = str(continuation_token + 1).encode("utf8") 81 | 82 | return full_state[continuation_token], next_cont_token 83 | 84 | def append_raw(self, state_key: str, data: bytes): 85 | self._data[state_key].append(data) 86 | 87 | def clear(self, state_key: str): 88 | self._data[state_key] = [] 89 | 90 | 91 | class RayStateManager(sdk_worker.StateHandler): 92 | def __init__(self, state_actor: Optional[_ActorStateManager] = None): 93 | self._state_actor = state_actor or _ActorStateManager.remote() 94 | self._instruction_id: Optional[str] = None 95 | 96 | @staticmethod 97 | def _to_key(state_key: beam_fn_api_pb2.StateKey): 98 | return state_key.SerializeToString() 99 | 100 | def get_raw( 101 | self, 102 | state_key, # type: beam_fn_api_pb2.StateKey 103 | continuation_token=None, # type: Optional[bytes] 104 | ) -> Tuple[bytes, Optional[bytes]]: 105 | return ray.get( 106 | self._state_actor.get_raw.remote( 107 | RayStateManager._to_key(state_key), 108 | continuation_token, 109 | ) 110 | ) 111 | 112 | def append_raw(self, state_key: beam_fn_api_pb2.StateKey, data: bytes) -> RayFuture: 113 | return RayFuture( 114 | self._state_actor.append_raw.remote( 115 | RayStateManager._to_key(state_key), data 116 | ) 117 | ) 118 | 119 | def clear(self, state_key: beam_fn_api_pb2.StateKey) -> RayFuture: 120 | assert self._instruction_id is not None 121 | return RayFuture( 122 | self._state_actor.clear.remote(RayStateManager._to_key(state_key)) 123 | ) 124 | 125 | @contextlib.contextmanager 126 | def process_instruction_id(self, bundle_id: str) -> Iterator[None]: 127 | # Instruction id is not being used right now, 128 | # we only assert that it has been set before accessing state. 
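        # Usage sketch (hypothetical state key and bundle id; mirrors how the
        # Beam SDK harness drives a StateHandler):
        #
        #   state = RayStateManager()
        #   with state.process_instruction_id("bundle_1"):
        #       state.append_raw(state_key, b"data").wait()
        #       data, continuation_token = state.get_raw(state_key)
        #       state.clear(state_key).wait()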
129 |         self._instruction_id = bundle_id
130 |         yield
131 |         self._instruction_id = None
132 | 
133 |     def done(self):
134 |         pass
135 | 
--------------------------------------------------------------------------------
/ray_beam_runner/ray_runner.py:
--------------------------------------------------------------------------------
 1 | #
 2 | # Licensed to the Apache Software Foundation (ASF) under one or more
 3 | # contributor license agreements. See the NOTICE file distributed with
 4 | # this work for additional information regarding copyright ownership.
 5 | # The ASF licenses this file to You under the Apache License, Version 2.0
 6 | # (the "License"); you may not use this file except in compliance with
 7 | # the License. You may obtain a copy of the License at
 8 | #
 9 | #    http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | """RayRunner, executing on a Ray cluster.
18 | 
19 | """
20 | 
21 | # pytype: skip-file
22 | import logging
23 | 
24 | import ray
25 | from apache_beam.options.pipeline_options import PipelineOptions
26 | 
27 | from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
28 | from ray_beam_runner.collection import CollectionMap
29 | from ray_beam_runner.overrides import _get_overrides
30 | from ray_beam_runner.translator import TranslationExecutor
31 | from apache_beam.runners.runner import PipelineState, PipelineResult
32 | 
33 | __all__ = ["RayRunner"]
34 | 
35 | from apache_beam.typehints import Dict
36 | 
37 | _LOGGER = logging.getLogger(__name__)
38 | 
39 | 
40 | class RayRunnerOptions(PipelineOptions):
41 |     """RayRunner-specific execution options."""
42 | 
43 |     @classmethod
44 |     def _add_argparse_args(cls, parser):
45 |         parser.add_argument(
46 |             "--parallelism",
47 |             type=int,
48 |             default=1,
49 |             help="Parallelism for Read/Create operations")
50 | 
51 | 
52 | class RayRunner(BundleBasedDirectRunner):
53 |     """Executes a single pipeline on a Ray cluster."""
54 | 
55 |     @staticmethod
56 |     def is_fnapi_compatible():
57 |         return False
58 | 
59 |     def run_pipeline(self, pipeline, options):
60 |         """Execute the entire pipeline and return a RayPipelineResult."""
61 |         runner_options = options.view_as(RayRunnerOptions)
62 | 
63 |         collection_map = CollectionMap()
64 | 
65 |         # Override some transforms with custom transforms
66 |         overrides = _get_overrides()
67 |         pipeline.replace_all(overrides)
68 | 
69 |         # Execute transforms using Ray datasets
70 |         translation_executor = TranslationExecutor(
71 |             collection_map, parallelism=runner_options.parallelism)
72 |         pipeline.visit(translation_executor)
73 | 
74 |         named_graphs = [
75 |             transform.named_outputs()
76 |             for transform in pipeline.transforms_stack
77 |         ]
78 | 
79 |         outputs = {}
80 |         for named_outputs in named_graphs:
81 |             outputs.update(named_outputs)
82 | 
83 |         _LOGGER.info("Running pipeline with RayRunner.")
84 | 
85 |         result = RayPipelineResult(outputs)
86 | 
87 |         return result
88 | 
89 | 
90 | class RayPipelineResult(PipelineResult):
91 |     def __init__(self, named_outputs: Dict[str, ray.data.Dataset]):
92 |         super(RayPipelineResult, self).__init__(PipelineState.RUNNING)
93 |         self.named_outputs = named_outputs
94 | 
95 |     def __del__(self):
96 |         if self._state == PipelineState.RUNNING:
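            # self._state is RUNNING from construction and only changes when
            # wait_until_finish() completes, so reaching this branch means the
            # result object was dropped without waiting on the pipeline.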
 97 |             _LOGGER.warning(
 98 |                 "The RayPipelineResult is being garbage-collected while the "
 99 |                 "RayRunner is still running the corresponding pipeline. "
100 |                 "This may lead to incomplete execution of the pipeline if the "
101 |                 "main thread exits before pipeline completion. Consider using "
102 |                 "result.wait_until_finish() to wait for completion of "
103 |                 "pipeline execution.")
104 | 
105 |     def wait_until_finish(self, duration=None):
106 |         if not PipelineState.is_terminal(self.state):
107 |             if duration:
108 |                 raise NotImplementedError(
109 |                     "RayRunner does not support duration argument.")
110 |             try:
111 |                 objs = list(self.named_outputs.values())
112 |                 # num_returns must be an integer count of objects to wait for.
113 |                 ray.wait(objs, num_returns=len(objs))
114 |                 self._state = PipelineState.DONE
115 |             except Exception:  # pylint: disable=broad-except
116 |                 self._state = PipelineState.FAILED
117 |                 raise
118 |         return self._state
119 | 
--------------------------------------------------------------------------------
/ray_beam_runner/serialization.py:
--------------------------------------------------------------------------------
 1 | #
 2 | # Licensed to the Apache Software Foundation (ASF) under one or more
 3 | # contributor license agreements. See the NOTICE file distributed with
 4 | # this work for additional information regarding copyright ownership.
 5 | # The ASF licenses this file to You under the Apache License, Version 2.0
 6 | # (the "License"); you may not use this file except in compliance with
 7 | # the License. You may obtain a copy of the License at
 8 | #
 9 | #    http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | import ray
18 | 
19 | from apache_beam.portability.api import beam_runner_api_pb2, beam_fn_api_pb2
20 | 
21 | 
22 | def register_protobuf_serializers():
23 |     """
24 |     Register serializers for protobuf messages.
25 |     Note: Serializers are managed locally for each Ray worker.
26 |     """
27 |     # TODO(rkenmi): Figure out how to not repeat this call on workers?
28 |     pb_msg_map = {
29 |         msg_name: pb_module
30 |         for pb_module in [beam_fn_api_pb2, beam_runner_api_pb2]
31 |         for msg_name in pb_module.DESCRIPTOR.message_types_by_name.keys()
32 |     }
33 | 
34 |     def _serializer(message):
35 |         return message.SerializeToString()
36 | 
37 |     def _deserializer(pb_module, msg_name):
38 |         return lambda s: getattr(pb_module, msg_name).FromString(s)
39 | 
40 |     for msg_name, pb_module in pb_msg_map.items():
41 |         ray.util.register_serializer(
42 |             getattr(pb_module, msg_name),
43 |             serializer=_serializer,
44 |             deserializer=_deserializer(pb_module, msg_name),
45 |         )
--------------------------------------------------------------------------------
/ray_beam_runner/translator.py:
--------------------------------------------------------------------------------
 1 | # Licensed to the Apache Software Foundation (ASF) under one or more
 2 | # contributor license agreements. See the NOTICE file distributed with
 3 | # this work for additional information regarding copyright ownership.
 4 | # The ASF licenses this file to You under the Apache License, Version 2.0
 5 | # (the "License"); you may not use this file except in compliance with
 6 | # the License.
You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | from typing import Mapping, Sequence 17 | 18 | import ray.data 19 | 20 | from apache_beam import ( 21 | Create, 22 | Union, 23 | ParDo, 24 | Impulse, 25 | PTransform, 26 | WindowInto, 27 | Flatten, 28 | io, 29 | DoFn, 30 | ) 31 | from apache_beam.pipeline import AppliedPTransform, PipelineVisitor 32 | from apache_beam.pvalue import PBegin, TaggedOutput 33 | from apache_beam.runners.common import ( 34 | DoFnInvoker, 35 | DoFnSignature, 36 | DoFnContext, 37 | Receiver, 38 | _OutputProcessor, 39 | ) 40 | from apache_beam.transforms.sideinputs import SideInputMap 41 | from ray.data.block import Block, BlockMetadata, BlockAccessor 42 | 43 | from ray_beam_runner.collection import CollectionMap 44 | from ray_beam_runner.custom_actor_pool import CustomActorPool 45 | from ray_beam_runner.overrides import ( 46 | _Create, 47 | _Read, 48 | _Reshuffle, 49 | _GroupByKeyOnly, 50 | _GroupAlsoByWindow, 51 | ) 52 | from apache_beam.transforms.window import WindowFn, TimestampedValue, GlobalWindow 53 | from apache_beam.typehints import Optional 54 | from apache_beam.utils.windowed_value import WindowedValue 55 | 56 | 57 | def get_windowed_value(input_item, window_fn: WindowFn): 58 | if isinstance(input_item, TaggedOutput): 59 | input_item = input_item.value 60 | if isinstance(input_item, WindowedValue): 61 | windowed_value = input_item 62 | elif isinstance(input_item, TimestampedValue): 63 | assign_context = WindowFn.AssignContext(input_item.timestamp, input_item.value) 64 | windowed_value = WindowedValue( 65 | input_item.value, input_item.timestamp, window_fn.assign(assign_context) 66 | ) 67 | else: 68 | windowed_value = WindowedValue(input_item, 0, (GlobalWindow(),)) 69 | 70 | return windowed_value 71 | 72 | 73 | class RayDataTranslation(object): 74 | def __init__(self, applied_ptransform: AppliedPTransform, parallelism: int = 1): 75 | self.applied_ptransform = applied_ptransform 76 | self.parallelism = parallelism 77 | 78 | def apply( 79 | self, 80 | ray_ds: Union[None, ray.data.Dataset, Mapping[str, ray.data.Dataset]] = None, 81 | side_inputs: Optional[Sequence[ray.data.Dataset]] = None, 82 | ): 83 | raise NotImplementedError 84 | 85 | 86 | class RayNoop(RayDataTranslation): 87 | def apply( 88 | self, 89 | ray_ds: Union[None, ray.data.Dataset, Mapping[str, ray.data.Dataset]] = None, 90 | side_inputs: Optional[Sequence[ray.data.Dataset]] = None, 91 | ): 92 | return ray_ds 93 | 94 | 95 | class RayImpulse(RayDataTranslation): 96 | def apply( 97 | self, 98 | ray_ds: Union[None, ray.data.Dataset, Mapping[str, ray.data.Dataset]] = None, 99 | side_inputs: Optional[Sequence[ray.data.Dataset]] = None, 100 | ): 101 | assert ray_ds is None 102 | return ray.data.from_items([0], parallelism=self.parallelism) 103 | 104 | 105 | class RayCreate(RayDataTranslation): 106 | def apply( 107 | self, 108 | ray_ds: Union[None, ray.data.Dataset, Mapping[str, ray.data.Dataset]] = None, 109 | side_inputs: Optional[Sequence[ray.data.Dataset]] = None, 110 | ): 111 | assert ray_ds is None 112 | 113 | original_transform: Create = self.applied_ptransform.transform 114 | 115 | items = 
original_transform.values 116 | 117 | # Todo: parallelism should be configurable 118 | # Setting this to < 1 leads to errors for assert_that checks 119 | return ray.data.from_items(items, parallelism=self.parallelism) 120 | 121 | 122 | class RayRead(RayDataTranslation): 123 | def apply( 124 | self, 125 | ray_ds: Union[None, ray.data.Dataset, Mapping[str, ray.data.Dataset]] = None, 126 | side_inputs: Optional[Sequence[ray.data.Dataset]] = None, 127 | ): 128 | assert ray_ds is None 129 | 130 | original_transform: _Read = self.applied_ptransform.transform 131 | 132 | source = original_transform.source 133 | 134 | if isinstance(source, io.textio._TextSource): 135 | filename = source._pattern.value 136 | ray_ds = ray.data.read_text(filename, parallelism=self.parallelism) 137 | 138 | skip_lines = int(source._skip_header_lines) 139 | if skip_lines > 0: 140 | _, ray_ds = ray_ds.split_at_indices([skip_lines]) 141 | 142 | return ray_ds 143 | 144 | raise NotImplementedError("Could not read from source:", source) 145 | 146 | 147 | class RayReshuffle(RayDataTranslation): 148 | def apply( 149 | self, 150 | ray_ds: Union[None, ray.data.Dataset, Mapping[str, ray.data.Dataset]] = None, 151 | side_inputs: Optional[Sequence[ray.data.Dataset]] = None, 152 | ): 153 | assert ray_ds is not None 154 | return ray_ds.random_shuffle() 155 | 156 | 157 | class RayParDo(RayDataTranslation): 158 | def apply( 159 | self, 160 | ray_ds: Union[None, ray.data.Dataset, Mapping[str, ray.data.Dataset]] = None, 161 | side_inputs: Optional[Sequence[ray.data.Dataset]] = None, 162 | ): 163 | assert ray_ds is not None 164 | assert isinstance(ray_ds, ray.data.Dataset) 165 | assert self.applied_ptransform.transform is not None 166 | assert isinstance(self.applied_ptransform.transform, ParDo) 167 | 168 | # Get original function and side inputs 169 | transform = self.applied_ptransform.transform 170 | label = transform.label 171 | map_fn = transform.fn 172 | args = transform.args or [] 173 | kwargs = transform.kwargs or {} 174 | 175 | main_input = list(self.applied_ptransform.main_inputs.values())[0] 176 | window_fn = ( 177 | main_input.windowing.windowfn if hasattr(main_input, "windowing") else None 178 | ) 179 | 180 | class TaggingReceiver(Receiver): 181 | def __init__(self, tag, values): 182 | self.tag = tag 183 | self.values = values 184 | 185 | def receive(self, windowed_value): 186 | if self.tag: 187 | output = TaggedOutput(self.tag, windowed_value) 188 | else: 189 | output = windowed_value 190 | self.values.append(output) 191 | 192 | # We might want to have separate receivers at some point 193 | # For now, collect everything in one list and filter afterwards 194 | # class TaggedReceivers(dict): 195 | # def __missing__(self, tag): 196 | # self[tag] = receiver = SimpleReceiver() 197 | # return receiver 198 | 199 | class OneReceiver(dict): 200 | def __init__(self, values): 201 | self.values = values 202 | 203 | def __missing__(self, key): 204 | if key not in self: 205 | self[key] = TaggingReceiver(key, self.values) 206 | return self[key] 207 | 208 | class RayDoFnWorker(object): 209 | def __init__(self): 210 | self._is_setup = False 211 | 212 | self.context = DoFnContext(label, state=None) 213 | self.bundle_finalizer_param = DoFn.BundleFinalizerParam() 214 | do_fn_signature = DoFnSignature(map_fn) 215 | 216 | self.values = [] 217 | 218 | self.tagged_receivers = OneReceiver(self.values) 219 | self.window_fn = window_fn 220 | 221 | output_processor = _OutputProcessor( 222 | window_fn=self.window_fn, 223 | 
main_receivers=self.tagged_receivers[None], 224 | tagged_receivers=self.tagged_receivers, 225 | per_element_output_counter=None, 226 | ) 227 | 228 | self.do_fn_invoker = DoFnInvoker.create_invoker( 229 | do_fn_signature, 230 | output_processor, 231 | self.context, 232 | side_inputs, 233 | args, 234 | kwargs, 235 | user_state_context=None, 236 | bundle_finalizer_param=self.bundle_finalizer_param, 237 | ) 238 | 239 | def __del__(self): 240 | self.do_fn_invoker.invoke_teardown() 241 | 242 | def ready(self): 243 | return "ok" 244 | 245 | def process_batch(self, batch): 246 | if not self._is_setup: 247 | self.do_fn_invoker.invoke_setup() 248 | self._is_setup = True 249 | 250 | self.do_fn_invoker.invoke_start_bundle() 251 | 252 | # Clear return list 253 | self.values.clear() 254 | 255 | for input_item in batch: 256 | windowed_value = get_windowed_value(input_item, self.window_fn) 257 | self.do_fn_invoker.invoke_process(windowed_value) 258 | 259 | self.do_fn_invoker.invoke_finish_bundle() 260 | 261 | # map_fn.process may return multiple items 262 | ret = list(self.values) 263 | return ret 264 | 265 | @ray.method(num_returns=2) 266 | def process_block( 267 | self, block: Block, meta: BlockMetadata 268 | ) -> (Block, BlockMetadata): 269 | if not self._is_setup: 270 | map_fn.setup() 271 | self._is_setup = True 272 | 273 | new_block = self.process_batch(block) 274 | accessor = BlockAccessor.for_block(new_block) 275 | new_metadata = BlockMetadata( 276 | num_rows=accessor.num_rows(), 277 | size_bytes=accessor.size_bytes(), 278 | schema=accessor.schema(), 279 | input_files=meta.input_files, 280 | ) 281 | return new_block, new_metadata 282 | 283 | def simple_batch_dofn(batch): 284 | context = DoFnContext(label, state=None) 285 | bundle_finalizer_param = DoFn.BundleFinalizerParam() 286 | do_fn_signature = DoFnSignature(map_fn) 287 | 288 | values = [] 289 | 290 | tagged_receivers = OneReceiver(values) 291 | 292 | output_processor = _OutputProcessor( 293 | window_fn=window_fn, 294 | main_receivers=tagged_receivers[None], 295 | tagged_receivers=tagged_receivers, 296 | per_element_output_counter=None, 297 | ) 298 | 299 | do_fn_invoker = DoFnInvoker.create_invoker( 300 | do_fn_signature, 301 | output_processor, 302 | context, 303 | side_inputs, 304 | args, 305 | kwargs, 306 | user_state_context=None, 307 | bundle_finalizer_param=bundle_finalizer_param, 308 | ) 309 | 310 | # Invoke setup just in case 311 | do_fn_invoker.invoke_setup() 312 | do_fn_invoker.invoke_start_bundle() 313 | 314 | for input_item in batch: 315 | windowed_value = get_windowed_value(input_item, window_fn) 316 | do_fn_invoker.invoke_process(windowed_value) 317 | 318 | do_fn_invoker.invoke_finish_bundle() 319 | # Invoke teardown just in case 320 | do_fn_invoker.invoke_teardown() 321 | 322 | # This has to happen last as we might receive results 323 | # in invoke_finish_bundle() or invoke_teardown() 324 | ret = list(values) 325 | 326 | return ret 327 | 328 | # Todo: implement 329 | dofn_has_no_setup_or_teardown = True 330 | 331 | if dofn_has_no_setup_or_teardown: 332 | return ray_ds.map_batches(simple_batch_dofn) 333 | 334 | # The lambda fn is ignored as the RayDoFnWorker encapsulates the 335 | # actual logic in self.process_batch 336 | return ray_ds.map_batches( 337 | lambda batch: batch, compute=CustomActorPool(worker_cls=RayDoFnWorker) 338 | ) 339 | 340 | 341 | class RayGroupByKey(RayDataTranslation): 342 | def apply( 343 | self, 344 | ray_ds: Union[None, ray.data.Dataset, Mapping[str, ray.data.Dataset]] = None, 345 | side_inputs: 
Optional[Sequence[ray.data.Dataset]] = None, 346 | ): 347 | assert ray_ds is not None 348 | assert isinstance(ray_ds, ray.data.Dataset) 349 | 350 | # TODO(jjyao) Currently dataset doesn't handle 351 | # tuple groupby key so we wrap it in an object as a workaround. 352 | # This hack can be removed once dataset supports tuple groupby key. 353 | class KeyWrapper(object): 354 | def __init__(self, key): 355 | self.key = key 356 | 357 | def __lt__(self, other): 358 | return self.key < other.key 359 | 360 | def __eq__(self, other): 361 | return self.key == other.key 362 | 363 | def key(windowed_value): 364 | if not isinstance(windowed_value, WindowedValue): 365 | windowed_value = WindowedValue(windowed_value, 0, (GlobalWindow(),)) 366 | 367 | # Extract key from windowed value 368 | key, _ = windowed_value.value 369 | # We convert to strings here to support void keys 370 | return KeyWrapper(str(key) if key is None else key) 371 | 372 | def value(windowed_value): 373 | if not isinstance(windowed_value, WindowedValue): 374 | windowed_value = WindowedValue(windowed_value, 0, (GlobalWindow(),)) 375 | 376 | # Extract value from windowed value 377 | _, value = windowed_value.value 378 | return value 379 | 380 | return ( 381 | ray_ds.groupby(key) 382 | .aggregate( 383 | ray.data.aggregate.AggregateFn( 384 | init=lambda k: [], 385 | accumulate=lambda a, r: a + [value(r)], 386 | merge=lambda a1, a2: a1 + a2, 387 | ) 388 | ) 389 | .map(lambda r: (r[0].key, r[1])) 390 | ) 391 | 392 | 393 | class RayWindowInto(RayDataTranslation): 394 | def apply( 395 | self, 396 | ray_ds: Union[None, ray.data.Dataset, Mapping[str, ray.data.Dataset]] = None, 397 | side_inputs: Optional[Sequence[ray.data.Dataset]] = None, 398 | ): 399 | window_fn = self.applied_ptransform.transform.windowing.windowfn 400 | 401 | def to_windowed_value(item): 402 | if isinstance(item, WindowedValue): 403 | return item 404 | 405 | if isinstance(item, TimestampedValue): 406 | return WindowedValue( 407 | item.value, 408 | item.timestamp, 409 | window_fn.assign( 410 | WindowFn.AssignContext(item.timestamp, element=item.value) 411 | ), 412 | ) 413 | 414 | return item 415 | 416 | return ray_ds.map(to_windowed_value) 417 | 418 | 419 | class RayFlatten(RayDataTranslation): 420 | def apply( 421 | self, 422 | ray_ds: Union[None, ray.data.Dataset, Mapping[str, ray.data.Dataset]] = None, 423 | side_inputs: Optional[Sequence[ray.data.Dataset]] = None, 424 | ): 425 | assert ray_ds is not None 426 | assert isinstance(ray_ds, Mapping) 427 | assert len(ray_ds) >= 1 428 | 429 | keys = sorted(ray_ds.keys()) 430 | primary_key = keys.pop(0) 431 | 432 | primary_ds = ray_ds[primary_key] 433 | return primary_ds.union(*[ray_ds[key] for key in keys]).repartition(1) 434 | 435 | 436 | translations = { 437 | _Create: RayCreate, # Composite transform 438 | _Read: RayRead, 439 | Impulse: RayImpulse, 440 | _Reshuffle: RayReshuffle, 441 | ParDo: RayParDo, 442 | Flatten: RayFlatten, 443 | WindowInto: RayParDo, # RayWindowInto, 444 | _GroupByKeyOnly: RayGroupByKey, 445 | _GroupAlsoByWindow: RayParDo, 446 | # CoGroupByKey: RayCoGroupByKey, 447 | PTransform: RayNoop, # Todo: How to handle generic ptransforms? Map? 
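    # Transform types that do not appear in this mapping are skipped:
    # TranslationExecutor.get_translation() below returns None for them, and
    # execute() then treats the transform as a no-op.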
448 | } 449 | 450 | 451 | class TranslationExecutor(PipelineVisitor): 452 | def __init__(self, collection_map: CollectionMap, parallelism: int = 1): 453 | self._collection_map = collection_map 454 | self._parallelism = parallelism 455 | 456 | def enter_composite_transform(self, transform_node: AppliedPTransform) -> None: 457 | pass 458 | 459 | def leave_composite_transform(self, transform_node: AppliedPTransform) -> None: 460 | pass 461 | 462 | def visit_transform(self, transform_node: AppliedPTransform) -> None: 463 | self.execute(transform_node) 464 | 465 | def get_translation( 466 | self, applied_ptransform: AppliedPTransform 467 | ) -> Optional[RayDataTranslation]: 468 | # Sanity check 469 | type_ = type(applied_ptransform.transform) 470 | if type_ not in translations: 471 | return None 472 | 473 | translation_factory = translations[type_] 474 | translation = translation_factory( 475 | applied_ptransform, parallelism=self._parallelism 476 | ) 477 | 478 | return translation 479 | 480 | def execute(self, applied_ptransform: AppliedPTransform) -> bool: 481 | translation = self.get_translation(applied_ptransform) 482 | 483 | if not translation: 484 | # Warn? Debug output? 485 | return False 486 | 487 | named_inputs = {} 488 | for name, element in applied_ptransform.named_inputs().items(): 489 | if isinstance(element, PBegin): 490 | ray_ds = None 491 | else: 492 | ray_ds = self._collection_map.get(element) 493 | 494 | named_inputs[name] = ray_ds 495 | 496 | if len(named_inputs) == 0: 497 | ray_ds = None 498 | else: 499 | ray_ds = {} 500 | for name in applied_ptransform.main_inputs.keys(): 501 | ray_ds[name] = named_inputs.pop(name) 502 | 503 | if len(ray_ds) == 1: 504 | ray_ds = list(ray_ds.values())[0] 505 | 506 | class RayDatasetAccessor(object): 507 | def __init__(self, ray_ds: ray.data.Dataset, window_fn: WindowFn): 508 | self.ray_ds = ray_ds 509 | self.window_fn = window_fn 510 | 511 | def __iter__(self): 512 | for row in self.ray_ds.iter_rows(): 513 | yield get_windowed_value(row, self.window_fn) 514 | 515 | side_inputs = [] 516 | for side_input in applied_ptransform.side_inputs: 517 | side_ds = self._collection_map.get(side_input.pvalue) 518 | side_inputs.append( 519 | SideInputMap( 520 | type(side_input), 521 | side_input._view_options(), 522 | RayDatasetAccessor(side_ds, side_input._window_mapping_fn), 523 | ) 524 | ) 525 | 526 | def _visualize(ray_ds_dict): 527 | for name, ray_ds in ray_ds_dict.items(): 528 | if not ray_ds: 529 | out = ray_ds 530 | elif not isinstance(ray_ds, ray.data.Dataset): 531 | out = ray_ds 532 | else: 533 | out = ray.get(ray_ds.to_numpy_refs()) 534 | print(("DATA", name, out)) 535 | continue 536 | 537 | def _visualize_all(ray_ds): 538 | if isinstance(ray_ds, ray.data.Dataset) or not ray_ds: 539 | _visualize({"_main": ray_ds}) 540 | elif isinstance(ray_ds, list): 541 | _visualize((dict(enumerate(ray_ds)))) 542 | else: 543 | _visualize(ray_ds) 544 | 545 | print("-" * 80) 546 | print("APPLYING", applied_ptransform.full_label) 547 | print("-" * 80) 548 | print("MAIN INPUT") 549 | _visualize_all(ray_ds) 550 | print("SIDE INPUTS") 551 | _visualize_all([list(si._iterable) for si in side_inputs]) 552 | print("." * 40) 553 | result = translation.apply(ray_ds, side_inputs=side_inputs) 554 | print("." 
* 40) 555 | print("RESULT", applied_ptransform.full_label) 556 | _visualize_all(result) 557 | print("-" * 80) 558 | 559 | for name, element in applied_ptransform.named_outputs().items(): 560 | if isinstance(result, dict): 561 | out = result.get(name) 562 | else: 563 | out = result 564 | 565 | if out: 566 | if name != "None": 567 | # Side output 568 | out = out.filter( 569 | lambda x: isinstance(x, TaggedOutput) and x.tag == name 570 | ) 571 | out = out.map(lambda x: x.value) 572 | else: 573 | # Main output 574 | out = out.filter(lambda x: not isinstance(x, TaggedOutput)) 575 | 576 | self._collection_map.set(element, out) 577 | 578 | return True 579 | -------------------------------------------------------------------------------- /ray_beam_runner/util.py: -------------------------------------------------------------------------------- 1 | # Licensed to the Apache Software Foundation (ASF) under one or more 2 | # contributor license agreements. See the NOTICE file distributed with 3 | # this work for additional information regarding copyright ownership. 4 | # The ASF licenses this file to You under the Apache License, Version 2.0 5 | # (the "License"); you may not use this file except in compliance with 6 | # the License. You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | 17 | from apache_beam.pipeline import PipelineVisitor 18 | 19 | 20 | class PipelinePrinter(PipelineVisitor): 21 | def visit_value(self, value, producer_node): 22 | print(f"visit_value(value, {producer_node.full_label})") 23 | 24 | def visit_transform(self, transform_node): 25 | print(f"visit_transform({type(transform_node.transform)})") 26 | 27 | def enter_composite_transform(self, transform_node): 28 | print(f"enter_composite_transform({transform_node.full_label})") 29 | 30 | def leave_composite_transform(self, transform_node): 31 | print(f"leave_composite_transform({transform_node.full_label})") 32 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | click==8.0.1 2 | black==21.12b0 3 | flake8==3.9.1 4 | flake8-comprehensions 5 | flake8-quotes==2.0.0 6 | flake8-bugbear==21.9.2 7 | pylint==2.15.4 8 | pytest==7.1.2 9 | pyhamcrest==2.0.3 10 | pytest-benchmark 11 | apache-beam==2.42.0 12 | -------------------------------------------------------------------------------- /scripts/format.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Licensed to the Apache Software Foundation (ASF) under one or more 4 | # contributor license agreements. See the NOTICE file distributed with 5 | # this work for additional information regarding copyright ownership. 6 | # The ASF licenses this file to You under the Apache License, Version 2.0 7 | # (the "License"); you may not use this file except in compliance with 8 | # the License. 
You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | # Black formatter (if installed). This script formats all changed files from the last mergebase. 19 | # You are encouraged to run this locally before pushing changes for review. 20 | 21 | # Cause the script to exit if a single command fails 22 | set -eo pipefail 23 | 24 | FLAKE8_VERSION_REQUIRED="3.9.1" 25 | BLACK_VERSION_REQUIRED="21.12b0" 26 | 27 | check_command_exist() { 28 | VERSION="" 29 | case "$1" in 30 | black) 31 | VERSION=$BLACK_VERSION_REQUIRED 32 | ;; 33 | flake8) 34 | VERSION=$FLAKE8_VERSION_REQUIRED 35 | ;; 36 | *) 37 | echo "$1 is not a required dependency" 38 | exit 1 39 | esac 40 | if ! [ -x "$(command -v $1)" ]; then 41 | echo "$1 not installed. pip install $1==$VERSION" 42 | exit 1 43 | fi 44 | } 45 | 46 | check_command_exist black 47 | check_command_exist flake8 48 | 49 | # this stops git rev-parse from failing if we run this from the .git directory 50 | builtin cd "$(dirname "${BASH_SOURCE:-$0}")" 51 | 52 | ROOT="$(git rev-parse --show-toplevel)" 53 | builtin cd "$ROOT" || exit 1 54 | 55 | FLAKE8_VERSION=$(flake8 --version | head -n 1 | awk '{print $1}') 56 | BLACK_VERSION=$(black --version | awk '{print $2}') 57 | 58 | # params: tool name, tool version, required version 59 | tool_version_check() { 60 | if [[ $2 != $3 ]]; then 61 | echo "WARNING: Ray Beam Runner uses $1 $3, You are currently using $2. This might generate different results." 62 | fi 63 | } 64 | 65 | tool_version_check "flake8" "$FLAKE8_VERSION" "$FLAKE8_VERSION_REQUIRED" 66 | tool_version_check "black" "$BLACK_VERSION" "$BLACK_VERSION_REQUIRED" 67 | 68 | 69 | # Format specified files 70 | format_files() { 71 | black "$@" 72 | } 73 | 74 | # Format files that differ from main branch. Ignores dirs that are not slated 75 | # for autoformat yet. 76 | format_changed() { 77 | # The `if` guard ensures that the list of filenames is not empty, which 78 | # could cause the formatter to receive 0 positional arguments, making 79 | # Black error. 80 | # 81 | # `diff-filter=ACRM` and $MERGEBASE is to ensure we only format files that 82 | # exist on both branches. 83 | MERGEBASE="$(git merge-base upstream/master HEAD)" 84 | 85 | if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then 86 | git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ 87 | black 88 | git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ 89 | flake8 --config=.flake8 90 | fi 91 | } 92 | 93 | # Format all files, and print the diff to stdout for travis. 94 | format_all() { 95 | black ray_beam_runner/ 96 | flake8 --config=.flake8 ray_beam_runner 97 | pylint ray_beam_runner/ 98 | } 99 | 100 | # This flag formats individual files. --files *must* be the first command line 101 | # arg to use this option. 102 | if [[ "$1" == '--files' ]]; then 103 | format_files "${@:2}" 104 | # If `--all` is passed, then any further arguments are ignored and the 105 | # entire python directory is formatted. 106 | elif [[ "$1" == '--all' ]]; then 107 | format_all 108 | else 109 | # Add the upstream remote if it doesn't exist 110 | if ! 
git remote -v | grep -q upstream; then 111 | git remote add 'upstream' 'https://github.com/ray-project/ray_beam_runner.git' 112 | fi 113 | 114 | # Only fetch master since that's the branch we're diffing against. 115 | git fetch upstream master || true 116 | 117 | # Format only the files that changed in last commit. 118 | format_changed 119 | fi 120 | 121 | if ! git diff --quiet &>/dev/null; then 122 | echo 'Reformatted changed files. Please review and stage the changes.' 123 | echo 'Files updated:' 124 | echo 125 | 126 | git --no-pager diff --name-only 127 | 128 | exit 1 129 | fi 130 | 131 | echo 'Linting check finished successfully.' 132 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | from setuptools import find_packages, setup 18 | 19 | TEST_REQUIREMENTS = [ 20 | 'apache_beam[test]', 21 | 'pyhamcrest', 22 | 'pytest', 23 | ] 24 | 25 | setup( 26 | name="ray_beam", 27 | packages=find_packages(where=".", include="ray_beam_runner*"), 28 | version="0.0.1", 29 | author="Ray Team", 30 | description="An Apache Beam Runner using Ray.", 31 | long_description="An Apache Beam Runner based on the Ray " 32 | "distributed computing framework.", 33 | url="https://github.com/ray-project/ray_beam_runner", 34 | classifiers=[ 35 | "Programming Language :: Python :: 3.7", 36 | "Programming Language :: Python :: 3.8", 37 | "Programming Language :: Python :: 3.9", 38 | ], 39 | install_requires=[ 40 | "ray[data]", "apache_beam" 41 | ], 42 | extras_require={ 43 | 'test': TEST_REQUIREMENTS, 44 | } 45 | ) 46 | --------------------------------------------------------------------------------