├── .github └── workflows │ ├── pypi.yml │ ├── ray_nightly_test.yml │ ├── raydp.yml │ └── raydp_nightly.yml ├── .gitignore ├── LICENSE ├── README.md ├── SECURITY.md ├── bin └── raydp-submit ├── build.sh ├── core ├── agent │ ├── pom.xml │ └── src │ │ └── main │ │ └── java │ │ └── org │ │ ├── apache │ │ └── spark │ │ │ └── raydp │ │ │ └── Agent.java │ │ └── slf4j │ │ └── impl │ │ └── StaticLoggerBinder.java ├── javastyle-suppressions.xml ├── javastyle.xml ├── pom.xml ├── raydp-main │ ├── pom.xml │ └── src │ │ └── main │ │ ├── java │ │ └── org │ │ │ └── apache │ │ │ └── spark │ │ │ ├── deploy │ │ │ └── raydp │ │ │ │ ├── ExternalShuffleServiceUtils.java │ │ │ │ └── RayAppMasterUtils.java │ │ │ └── raydp │ │ │ ├── RayDPUtils.java │ │ │ ├── RayExecutorUtils.java │ │ │ └── SparkOnRayConfigs.java │ │ ├── resources │ │ └── META-INF │ │ │ └── services │ │ │ └── org.apache.spark.scheduler.ExternalClusterManager │ │ ├── scala │ │ └── org │ │ │ └── apache │ │ │ └── spark │ │ │ ├── RayDPException.scala │ │ │ ├── deploy │ │ │ ├── SparkSubmit.scala │ │ │ └── raydp │ │ │ │ ├── AppMasterEntryPoint.scala │ │ │ │ ├── AppMasterJavaBridge.scala │ │ │ │ ├── ApplicationDescription.scala │ │ │ │ ├── ApplicationInfo.scala │ │ │ │ ├── ApplicationState.scala │ │ │ │ ├── Messages.scala │ │ │ │ ├── RayAppMaster.scala │ │ │ │ ├── RayDPDriverAgent.scala │ │ │ │ └── RayExternalShuffleService.scala │ │ │ ├── executor │ │ │ └── RayDPExecutor.scala │ │ │ ├── rdd │ │ │ ├── RayDatasetRDD.scala │ │ │ └── RayObjectRefRDD.scala │ │ │ ├── scheduler │ │ │ └── cluster │ │ │ │ └── raydp │ │ │ │ ├── RayClusterManager.scala │ │ │ │ └── RayCoarseGrainedSchedulerBackend.scala │ │ │ ├── sql │ │ │ └── raydp │ │ │ │ ├── ObjectStoreReader.scala │ │ │ │ └── ObjectStoreWriter.scala │ │ │ └── util │ │ │ └── DependencyUtils.scala │ │ └── test │ │ └── org │ │ └── apache │ │ └── spark │ │ └── scheduler │ │ └── cluster │ │ └── raydp │ │ └── TestRayCoarseGrainedSchedulerBackend.java ├── scalastyle.xml └── shims │ ├── common │ ├── pom.xml │ └── src │ │ └── main │ │ └── scala │ │ ├── com │ │ └── intel │ │ │ └── raydp │ │ │ └── shims │ │ │ ├── SparkShimLoader.scala │ │ │ ├── SparkShimProvider.scala │ │ │ └── SparkShims.scala │ │ └── org │ │ └── apache │ │ └── spark │ │ └── executor │ │ └── RayDPExecutorBackendFactory.scala │ ├── pom.xml │ ├── spark322 │ ├── pom.xml │ └── src │ │ └── main │ │ ├── resources │ │ └── META-INF │ │ │ └── services │ │ │ └── com.intel.raydp.shims.SparkShimProvider │ │ └── scala │ │ ├── com │ │ └── intel │ │ │ └── raydp │ │ │ └── shims │ │ │ ├── SparkShimProvider.scala │ │ │ └── SparkShims.scala │ │ └── org │ │ └── apache │ │ └── spark │ │ ├── TaskContextUtils.scala │ │ ├── executor │ │ └── RayDPSpark322ExecutorBackendFactory.scala │ │ └── sql │ │ └── SparkSqlUtils.scala │ ├── spark330 │ ├── pom.xml │ └── src │ │ └── main │ │ ├── resources │ │ └── META-INF │ │ │ └── services │ │ │ └── com.intel.raydp.shims.SparkShimProvider │ │ └── scala │ │ ├── com │ │ └── intel │ │ │ └── raydp │ │ │ └── shims │ │ │ ├── SparkShimProvider.scala │ │ │ └── SparkShims.scala │ │ └── org │ │ └── apache │ │ └── spark │ │ ├── TaskContextUtils.scala │ │ ├── executor │ │ ├── RayCoarseGrainedExecutorBackend.scala │ │ └── RayDPSpark330ExecutorBackendFactory.scala │ │ └── sql │ │ └── SparkSqlUtils.scala │ ├── spark340 │ ├── pom.xml │ └── src │ │ └── main │ │ ├── resources │ │ └── META-INF │ │ │ └── services │ │ │ └── com.intel.raydp.shims.SparkShimProvider │ │ └── scala │ │ ├── com │ │ └── intel │ │ │ └── raydp │ │ │ └── shims │ │ │ ├── 
SparkShimProvider.scala │ │ │ └── SparkShims.scala │ │ └── org │ │ └── apache │ │ └── spark │ │ ├── TaskContextUtils.scala │ │ ├── executor │ │ ├── RayCoarseGrainedExecutorBackend.scala │ │ └── RayDPSpark340ExecutorBackendFactory.scala │ │ └── sql │ │ └── SparkSqlUtils.scala │ └── spark350 │ ├── pom.xml │ └── src │ └── main │ ├── resources │ └── META-INF │ │ └── services │ │ └── com.intel.raydp.shims.SparkShimProvider │ └── scala │ ├── com │ └── intel │ │ └── raydp │ │ └── shims │ │ ├── SparkShimProvider.scala │ │ └── SparkShims.scala │ └── org │ └── apache │ └── spark │ ├── TaskContextUtils.scala │ ├── executor │ ├── RayCoarseGrainedExecutorBackend.scala │ └── RayDPSpark350ExecutorBackendFactory.scala │ └── sql │ └── SparkSqlUtils.scala ├── doc ├── mpi.md └── spark_on_ray.md ├── docker ├── Dockerfile ├── README.md ├── build-docker.sh └── legacy.yaml ├── examples ├── README.md ├── data_process.py ├── fake_nyctaxi.csv ├── horovod_nyctaxi.py ├── pytorch_dlrm.ipynb ├── pytorch_nyctaxi.py ├── random_nyctaxi.py ├── raydp-submit.py ├── raytrain_nyctaxi.py ├── tensorflow_nyctaxi.py ├── tensorflow_titanic.ipynb └── xgboost_ray_nyctaxi.py ├── python ├── MANIFEST.in ├── pylintrc ├── raydp │ ├── __init__.py │ ├── context.py │ ├── estimator.py │ ├── mpi │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── mpi_job.py │ │ ├── mpi_worker.py │ │ ├── network │ │ │ ├── __init__.py │ │ │ ├── network.proto │ │ │ ├── network_pb2.py │ │ │ └── network_pb2_grpc.py │ │ └── utils.py │ ├── ray_cluster_resources.py │ ├── services.py │ ├── spark │ │ ├── __init__.py │ │ ├── dataset.py │ │ ├── interfaces.py │ │ ├── parallel_iterator_worker.py │ │ ├── ray_cluster.py │ │ └── ray_cluster_master.py │ ├── tests │ │ ├── conftest.py │ │ ├── test_data_owner_transfer.py │ │ ├── test_mpi.py │ │ ├── test_spark_cluster.py │ │ ├── test_spark_utils.py │ │ ├── test_tf.py │ │ ├── test_torch.py │ │ ├── test_torch_sequential.py │ │ └── test_xgboost.py │ ├── tf │ │ ├── __init__.py │ │ └── estimator.py │ ├── torch │ │ ├── __init__.py │ │ ├── config.py │ │ ├── estimator.py │ │ ├── torch_metrics.py │ │ └── torch_ml_dataset.py │ ├── utils.py │ ├── versions.py │ └── xgboost │ │ ├── __init__.py │ │ └── estimator.py └── setup.py └── tutorials ├── dataset └── healthcare-dataset-stroke-data.csv ├── pytorch_example.ipynb └── raytrain_example.ipynb /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | name: RayDP PyPi 19 | 20 | on: 21 | schedule: 22 | - cron: '0 0 * * *' 23 | # can manually trigger the workflow 24 | workflow_dispatch: 25 | 26 | permissions: # added using https://github.com/step-security/secure-repo 27 | contents: read 28 | 29 | jobs: 30 | build-and-publish: 31 | # do not run in forks 32 | if: ${{ github.repository_owner == 'oap-project' }} 33 | name: build wheel and upload 34 | runs-on: ubuntu-latest 35 | steps: 36 | - uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master 37 | - name: Set up Python 3.9 38 | uses: actions/setup-python@0f07f7f756721ebd886c2462646a35f78a8bc4de # v1.2.4 39 | with: 40 | python-version: 3.9 41 | - name: Set up JDK 1.8 42 | uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4 43 | with: 44 | java-version: 1.8 45 | - name: days since the commit date 46 | run: | 47 | : 48 | timestamp=$(git log --no-walk --date=unix --format=%cd $GITHUB_SHA) 49 | days=$(( ( $(date --utc +%s) - $timestamp ) / 86400 )) 50 | if [ $days -eq 0 ]; then 51 | echo COMMIT_TODAY=true >> $GITHUB_ENV 52 | fi 53 | - name: Build wheel 54 | if: env.COMMIT_TODAY == 'true' 55 | env: 56 | RAYDP_BUILD_MODE: nightly 57 | run: pip install wheel grpcio-tools && ./build.sh 58 | - name: Upload 59 | if: env.COMMIT_TODAY == 'true' 60 | uses: pypa/gh-action-pypi-publish@e53eb8b103ffcb59469888563dc324e3c8ba6f06 # release/v1 61 | with: 62 | password: ${{ secrets.PYPI_API_TOKEN }} 63 | -------------------------------------------------------------------------------- /.github/workflows/raydp.yml: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | name: RayDP CI 19 | 20 | on: 21 | push: 22 | branches: [main, master] 23 | pull_request: 24 | branches: [main, master] 25 | workflow_dispatch: 26 | 27 | permissions: # added using https://github.com/step-security/secure-repo 28 | contents: read 29 | 30 | jobs: 31 | build-and-test: 32 | strategy: 33 | matrix: 34 | os: [ubuntu-latest] 35 | python-version: [3.9, 3.10.14] 36 | spark-version: [3.2.4, 3.3.2, 3.4.0, 3.5.0] 37 | ray-version: [2.34.0, 2.40.0] 38 | 39 | runs-on: ${{ matrix.os }} 40 | 41 | steps: 42 | - uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0 43 | - name: Set up Python ${{ matrix.python-version }} 44 | uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4 45 | with: 46 | python-version: ${{ matrix.python-version }} 47 | - name: Set up JDK 1.8 48 | uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4 49 | with: 50 | java-version: 1.8 51 | - name: Install extra dependencies for macOS 52 | if: matrix.os == 'macos-latest' 53 | run: | 54 | brew install pkg-config 55 | brew install libuv libomp mpich 56 | - name: Install extra dependencies for Ubuntu 57 | if: matrix.os == 'ubuntu-latest' 58 | run: | 59 | sudo apt-get install -y mpich 60 | - name: Cache pip - Ubuntu 61 | if: matrix.os == 'ubuntu-latest' 62 | uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2 63 | with: 64 | path: ~/.cache/pip 65 | key: ${{ matrix.os }}-${{ matrix.python-version }}-pip 66 | - name: Cache pip - MacOS 67 | if: matrix.os == 'macos-latest' 68 | uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2 69 | with: 70 | path: ~/Library/Caches/pip 71 | key: ${{ matrix.os }}-${{ matrix.python-version }}-pip 72 | - name: Install dependencies 73 | run: | 74 | python -m pip install --upgrade pip 75 | pip install wheel 76 | pip install "numpy<1.24" 77 | pip install "pydantic<2.0" 78 | SUBVERSION=$(python -c 'import sys; print(sys.version_info[1])') 79 | if [ "$(uname -s)" == "Linux" ] 80 | then 81 | pip install torch --index-url https://download.pytorch.org/whl/cpu 82 | else 83 | pip install torch 84 | fi 85 | pip install pyarrow "ray[train]==${{ matrix.ray-version }}" tqdm pytest tensorflow==2.13.1 tabulate grpcio-tools wget 86 | pip install "xgboost_ray[default]<=0.1.13" 87 | pip install "xgboost<=2.0.3" 88 | pip install torchmetrics 89 | - name: Cache Maven 90 | uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2 91 | with: 92 | path: ~/.m2 93 | key: ${{ matrix.os }}-m2-${{ hashFiles('core/pom.xml') }} 94 | - name: Build and install 95 | env: 96 | GITHUB_CI: 1 97 | run: | 98 | pip install pyspark==${{ matrix.spark-version }} 99 | ./build.sh 100 | pip install dist/raydp-*.whl 101 | - name: Lint 102 | run: | 103 | pip install pylint==2.8.3 104 | pylint --rcfile=python/pylintrc python/raydp 105 | pylint --rcfile=python/pylintrc examples/*.py 106 | - name: Test with pytest 107 | run: | 108 | ray start --head --num-cpus 6 109 | PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python pytest python/raydp/tests/ -v 110 | ray stop --force 111 | - name: Test Examples 112 | run: | 113 | ray start --head 114 | python examples/raydp-submit.py 115 | ray stop 116 | python examples/pytorch_nyctaxi.py 117 | python examples/tensorflow_nyctaxi.py 118 | python examples/xgboost_ray_nyctaxi.py 119 | # python examples/raytrain_nyctaxi.py 120 | python examples/data_process.py 121 | -------------------------------------------------------------------------------- /.github/workflows/raydp_nightly.yml: 
-------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | name: Legacy raydp_nightly PyPi 19 | 20 | on: 21 | schedule: 22 | - cron: '0 0 * * *' 23 | # can manually trigger the workflow 24 | workflow_dispatch: 25 | 26 | permissions: # added using https://github.com/step-security/secure-repo 27 | contents: read 28 | 29 | jobs: 30 | build-and-publish: 31 | # do not run in forks 32 | if: ${{ github.repository_owner == 'oap-project' }} 33 | name: build wheel and upload 34 | runs-on: ubuntu-latest 35 | steps: 36 | - uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master 37 | - name: Set up Python 3.9 38 | uses: actions/setup-python@0f07f7f756721ebd886c2462646a35f78a8bc4de # v1.2.4 39 | with: 40 | python-version: 3.9 41 | - name: Set up JDK 1.8 42 | uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4 43 | with: 44 | java-version: 1.8 45 | - name: days since the commit date 46 | run: | 47 | : 48 | timestamp=$(git log --no-walk --date=unix --format=%cd $GITHUB_SHA) 49 | days=$(( ( $(date --utc +%s) - $timestamp ) / 86400 )) 50 | if [ $days -eq 0 ]; then 51 | echo COMMIT_TODAY=true >> $GITHUB_ENV 52 | fi 53 | - name: Build wheel 54 | if: env.COMMIT_TODAY == 'true' 55 | env: 56 | RAYDP_PACKAGE_NAME: raydp_nightly 57 | run: pip install wheel grpcio-tools && ./build.sh 58 | - name: Upload 59 | if: env.COMMIT_TODAY == 'true' 60 | uses: pypa/gh-action-pypi-publish@e53eb8b103ffcb59469888563dc324e3c8ba6f06 # release/v1 61 | with: 62 | password: ${{ secrets.PYPI_API_TOKEN }} 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.iml 3 | 4 | __pycache__/ 5 | build/ 6 | dist/ 7 | *.egg-info/ 8 | *.eggs/ 9 | 10 | dev/.tmp_dir/ 11 | target/ 12 | *.jar 13 | 14 | .DS_Store 15 | 16 | .vscode 17 | examples/.ipynb_checkpoints/ 18 | .python-version 19 | 20 | # Vim temp files 21 | *.swp 22 | *.swo 23 | *.parquet 24 | *.crc 25 | _SUCCESS 26 | 27 | .metals/ 28 | .bloop/ 29 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Report a Vulnerability 4 | 5 | Please report security issues or vulnerabilities to the [Intel® Security Center]. 6 | 7 | For more information on how Intel® works to resolve security issues, see 8 | [Vulnerability Handling Guidelines]. 
9 | 10 | [Intel® Security Center]:https://www.intel.com/security 11 | 12 | [Vulnerability Handling Guidelines]:https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html 13 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | set -ex 21 | 22 | if ! command -v mvn &> /dev/null 23 | then 24 | echo "mvn could not be found, please install maven first" 25 | exit 26 | else 27 | mvn_path=`which mvn` 28 | echo "Using ${mvn_path} for build core module" 29 | fi 30 | 31 | CURRENT_DIR="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" 32 | DIST_PATH=${CURRENT_DIR}/dist/ 33 | 34 | if [[ ! -d ${DIST_PATH} ]]; 35 | then 36 | mkdir ${DIST_PATH} 37 | fi 38 | 39 | # build core part 40 | CORE_DIR="${CURRENT_DIR}/core" 41 | pushd ${CORE_DIR} 42 | if [[ -z $GITHUB_CI ]]; 43 | then 44 | mvn clean package -q -DskipTests 45 | else 46 | mvn verify -q 47 | fi 48 | popd # core dir 49 | 50 | # build python part 51 | RAYDP_PACKAGE_NAME=${RAYDP_PACKAGE_NAME:-raydp} 52 | PYTHON_DIR="${CURRENT_DIR}/python" 53 | 54 | if [[ -d "${PYTHON_DIR}/build" ]]; 55 | then 56 | rm -rf "${PYTHON_DIR}/build" 57 | fi 58 | 59 | pushd ${PYTHON_DIR} 60 | python setup.py bdist_wheel 61 | cp ${PYTHON_DIR}/dist/${RAYDP_PACKAGE_NAME}-* ${DIST_PATH} 62 | popd # python dir 63 | 64 | set +ex 65 | -------------------------------------------------------------------------------- /core/agent/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-parent 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-agent 15 | RayDP Java Agent 16 | jar 17 | 18 | 19 | 20 | 21 | org.apache.maven.plugins 22 | maven-compiler-plugin 23 | 3.8.0 24 | 25 | 1.8 26 | 1.8 27 | 28 | 29 | 30 | org.apache.maven.plugins 31 | maven-jar-plugin 32 | 3.3.0 33 | 34 | 35 | 36 | org.apache.spark.raydp.Agent 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | org.apache.logging.log4j 47 | log4j-core 48 | 2.17.1 49 | 50 | 51 | org.apache.logging.log4j 52 | log4j-slf4j-impl 53 | 2.17.1 54 | 55 | 56 | org.slf4j 57 | slf4j-api 58 | 1.7.32 59 | 60 | 61 | com.intel 62 | raydp 63 | ${project.version} 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /core/agent/src/main/java/org/apache/spark/raydp/Agent.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. 
See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.raydp; 19 | 20 | import org.slf4j.LoggerFactory; 21 | 22 | import java.io.File; 23 | import java.io.FileOutputStream; 24 | import java.io.IOException; 25 | import java.io.OutputStream; 26 | import java.io.OutputStreamWriter; 27 | import java.io.PrintStream; 28 | import java.io.Writer; 29 | import java.lang.instrument.Instrumentation; 30 | import java.lang.management.ManagementFactory; 31 | import java.nio.charset.Charset; 32 | import java.nio.charset.StandardCharsets; 33 | 34 | 35 | public class Agent { 36 | 37 | public static final PrintStream DEFAULT_ERR_PS = System.err; 38 | 39 | public static final PrintStream DEFAULT_OUT_PS = System.out; 40 | 41 | public static void premain(String agentArgs, Instrumentation inst) 42 | throws IOException { 43 | // redirect system output/error stream so that annoying SLF4J warnings 44 | // and other logs during binding 45 | // SLF4J factory don't show in spark-shell 46 | // Instead, the warnings and logs are kept in 47 | // /logs/slf4j-.log 48 | 49 | String pid = ManagementFactory.getRuntimeMXBean().getName() 50 | .split("@")[0]; 51 | String logDir = System.getProperty("ray.logging.dir"); 52 | if (logDir == null) { 53 | logDir = "/tmp/ray/session_latest/logs"; 54 | System.getProperties().put("ray.logging.dir", logDir); 55 | } 56 | 57 | File parentDir = new File(logDir); 58 | if (!parentDir.exists()) { 59 | boolean flag = parentDir.mkdirs(); 60 | if (!flag) { 61 | throw new RuntimeException("Error create log dir."); 62 | } 63 | } 64 | 65 | File logFile = new File(parentDir, "/slf4j-" + pid + ".log"); 66 | try (PrintStream ps = new PrintStream(logFile, "UTF-8")) { 67 | System.setOut(ps); 68 | System.setErr(ps); 69 | // slf4j binding 70 | LoggerFactory.getLogger(Agent.class); 71 | } catch (Exception e) { 72 | e.printStackTrace(); 73 | } finally { 74 | System.out.flush(); 75 | System.err.flush(); 76 | // restore system output/error stream 77 | System.setErr(DEFAULT_ERR_PS); 78 | System.setOut(DEFAULT_OUT_PS); 79 | } 80 | // below is to write ':job_id:' to first line of log file prefixed with 'java-worker' as required by 81 | // PR, https://github.com/ray-project/ray/pull/31772. 82 | // It's a workaround of the ray 2.3.[0-1] issue going to be fixed by https://github.com/ray-project/ray/pull/33665. 
83 | String jobId = System.getenv("RAY_JOB_ID"); 84 | String rayAddress = System.getProperty("ray.address"); 85 | if (jobId != null && rayAddress != null) { 86 | String prefix = "java-worker"; 87 | // TODO: uncomment after the ray PR #33665 released 88 | // String prefix = System.getProperty("ray.logging.file-prefix", "java-worker"); 89 | // if ("java-worker".equals(prefix)) { 90 | File file = new File(new String((logDir + "/" + prefix + "-" + jobId + "-" + pid + ".log") 91 | .getBytes(Charset.forName("UTF-8")), "UTF-8")); 92 | try (OutputStream out = new FileOutputStream(file); 93 | Writer writer = new OutputStreamWriter(out, StandardCharsets.UTF_8)) { 94 | writer.write(":job_id:" + jobId + "\n"); 95 | } 96 | // } 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /core/agent/src/main/java/org/slf4j/impl/StaticLoggerBinder.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.slf4j.impl; 19 | 20 | import org.apache.spark.raydp.Agent; 21 | import org.apache.spark.raydp.SparkOnRayConfigs; 22 | import org.slf4j.ILoggerFactory; 23 | import org.slf4j.spi.LoggerFactoryBinder; 24 | 25 | import java.io.PrintStream; 26 | import java.util.HashMap; 27 | import java.util.Map; 28 | 29 | /** 30 | * A delegation class to bind to slf4j so that we can have a chance to choose 31 | * which underlying log4j framework to use. 
32 | */ 33 | public class StaticLoggerBinder implements LoggerFactoryBinder { 34 | 35 | // for compatibility check 36 | public static final String REQUESTED_API_VERSION; 37 | 38 | private static final Class LOGFACTORY_CLASS; 39 | 40 | private static final ILoggerFactory FACTORY; 41 | 42 | private static final StaticLoggerBinder _INSTANCE = new StaticLoggerBinder(); 43 | 44 | private static final Map LOG_FACTORY_CLASSES 45 | = new HashMap<>(); 46 | 47 | private static PrintStream subSystemErr; 48 | 49 | private static PrintStream subSystemOut; 50 | 51 | static { 52 | subSystemErr = System.err; 53 | subSystemOut = System.out; 54 | 55 | LOG_FACTORY_CLASSES.put("log4j", 56 | "org.slf4j.impl.Log4jLoggerFactory"); // log4j 1 57 | LOG_FACTORY_CLASSES.put("log4j2", 58 | "org.apache.logging.slf4j.Log4jLoggerFactory"); // log4j 2 59 | 60 | String factoryClzStr = System 61 | .getProperty(SparkOnRayConfigs.LOG4J_FACTORY_CLASS_KEY, ""); 62 | if (factoryClzStr.length() == 0) { 63 | System.err.println("ERROR: system property '" 64 | + SparkOnRayConfigs.LOG4J_FACTORY_CLASS_KEY 65 | + "' needs to be specified for slf4j binding"); 66 | LOGFACTORY_CLASS = null; 67 | FACTORY = null; 68 | } else { 69 | String mappedClsStr = LOG_FACTORY_CLASSES.get(factoryClzStr); 70 | if (mappedClsStr == null) { 71 | mappedClsStr = factoryClzStr; 72 | } 73 | // restore to system default stream so that log4j console appender 74 | // can be correctly set 75 | System.setErr(Agent.DEFAULT_ERR_PS); 76 | System.setOut(Agent.DEFAULT_OUT_PS); 77 | Class tempClass = null; 78 | try { 79 | tempClass = Class.forName(mappedClsStr); 80 | } catch (Exception e) { 81 | e.printStackTrace(); 82 | } finally { 83 | LOGFACTORY_CLASS = tempClass; 84 | } 85 | StringBuilder sb = new StringBuilder(); 86 | sb.append("mapped factory class: ").append(mappedClsStr) 87 | .append(". 
load "); 88 | if (LOGFACTORY_CLASS != null) { 89 | sb.append(LOGFACTORY_CLASS.getName()); 90 | try { 91 | String loc = LOGFACTORY_CLASS.getProtectionDomain().getCodeSource() 92 | .getLocation().toURI().toString(); 93 | sb.append(" from ").append(loc); 94 | } catch (Exception e) { 95 | e.printStackTrace(); 96 | } 97 | } else { 98 | sb.append("failed"); 99 | } 100 | 101 | ILoggerFactory tmpFactory = null; 102 | try { 103 | tmpFactory = (ILoggerFactory) tempClass.newInstance(); 104 | } catch (Exception e) { 105 | e.printStackTrace(); 106 | } finally { 107 | FACTORY = tmpFactory; 108 | } 109 | // set to substitute stream for capturing remaining logs 110 | System.setErr(subSystemErr); 111 | System.setOut(subSystemOut); 112 | System.out.println(sb); 113 | } 114 | REQUESTED_API_VERSION = "1.6.66"; 115 | } 116 | 117 | public static final StaticLoggerBinder getSingleton() { 118 | return _INSTANCE; 119 | } 120 | 121 | @Override 122 | public ILoggerFactory getLoggerFactory() { 123 | // restore to system default stream so that log4j console appender 124 | // can be correctly set 125 | if (System.out != Agent.DEFAULT_OUT_PS) { 126 | System.setOut(Agent.DEFAULT_OUT_PS); 127 | } 128 | if (System.err != Agent.DEFAULT_ERR_PS) { 129 | System.setErr(Agent.DEFAULT_ERR_PS); 130 | } 131 | return FACTORY; 132 | } 133 | 134 | @Override 135 | public String getLoggerFactoryClassStr() { 136 | return LOGFACTORY_CLASS.getName(); 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /core/javastyle-suppressions.xml: -------------------------------------------------------------------------------- 1 | 17 | 18 | 21 | 22 | 29 | 30 | 31 | 33 | 35 | 37 | 39 | 41 | 43 | 45 | 47 | 49 | 51 | 52 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/ExternalShuffleServiceUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.deploy.raydp; 19 | 20 | import java.util.List; 21 | 22 | import io.ray.api.ActorHandle; 23 | import io.ray.api.Ray; 24 | 25 | public class ExternalShuffleServiceUtils { 26 | public static ActorHandle<RayExternalShuffleService> createShuffleService( 27 | String node, List<String> options) { 28 | return Ray.actor(RayExternalShuffleService::new) 29 | .setResource("node:" + node, 0.01) 30 | .setJvmOptions(options).remote(); 31 | } 32 | 33 | public static void startShuffleService( 34 | ActorHandle<RayExternalShuffleService> handle) { 35 | handle.task(RayExternalShuffleService::start).remote(); 36 | } 37 | 38 | public static void stopShuffleService( 39 | ActorHandle<RayExternalShuffleService> handle) { 40 | handle.task(RayExternalShuffleService::stop).remote(); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/RayAppMasterUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | package org.apache.spark.deploy.raydp; 19 | 20 | import java.util.List; 21 | import java.util.Map; 22 | 23 | import io.ray.api.ActorHandle; 24 | import io.ray.api.Ray; 25 | import io.ray.api.call.ActorCreator; 26 | import org.apache.spark.raydp.SparkOnRayConfigs; 27 | 28 | public class RayAppMasterUtils { 29 | public static ActorHandle<RayAppMaster> createAppMaster( 30 | String cp, 31 | String name, 32 | List<String> jvmOptions, 33 | Map<String, Double> appMasterResource) { 34 | ActorCreator<RayAppMaster> creator = Ray.actor(RayAppMaster::new, cp); 35 | if (name != null) { 36 | creator.setName(name); 37 | } 38 | jvmOptions.add("-cp"); 39 | jvmOptions.add(cp); 40 | creator.setJvmOptions(jvmOptions); 41 | for (Map.Entry<String, Double> resource : appMasterResource.entrySet()) { 42 | String resourceName = resource.getKey() 43 | .substring(SparkOnRayConfigs.SPARK_MASTER_ACTOR_RESOURCE_PREFIX.length() + 1); 44 | creator.setResource(resourceName, resource.getValue()); 45 | } 46 | 47 | return creator.remote(); 48 | } 49 | 50 | public static String getMasterUrl( 51 | ActorHandle<RayAppMaster> handle) { 52 | return handle.task(RayAppMaster::getMasterUrl).remote().get(); 53 | } 54 | 55 | public static Map getRestartedExecutors( 56 | ActorHandle<RayAppMaster> handle) { 57 | return handle.task(RayAppMaster::getRestartedExecutors).remote().get(); 58 | } 59 | 60 | public static void stopAppMaster( 61 | ActorHandle<RayAppMaster> handle) { 62 | handle.task(RayAppMaster::stop).remote().get(); 63 | handle.kill(); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/java/org/apache/spark/raydp/RayDPUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.raydp; 19 | 20 | import io.ray.api.ObjectRef; 21 | import io.ray.api.Ray; 22 | import io.ray.api.id.ObjectId; 23 | import io.ray.runtime.AbstractRayRuntime; 24 | import io.ray.runtime.object.ObjectRefImpl; 25 | 26 | public class RayDPUtils { 27 | 28 | /** 29 | * Convert ObjectRef to its subclass ObjectRefImpl. Throws a RuntimeException if it is not an 30 | * instance of ObjectRefImpl. We can't import ObjectRefImpl in Scala code, so we do the 31 | * conversion here. 32 | */ 33 | public static <T> ObjectRefImpl<T> convert(ObjectRef<T> obj) { 34 | if (obj instanceof ObjectRefImpl) { 35 | return (ObjectRefImpl<T>) obj; 36 | } else { 37 | throw new RuntimeException(obj.getClass() + " is not ObjectRefImpl"); 38 | } 39 | } 40 | 41 | /** 42 | * Create ObjectRef from Array[Byte] and register ownership. 43 | * We can't import ObjectRefImpl in Scala code, so we do the conversion here.
44 | */ 45 | public static <T> ObjectRef<T> readBinary(byte[] obj, Class<T> clazz, byte[] ownerAddress) { 46 | ObjectId id = new ObjectId(obj); 47 | ObjectRefImpl<T> ref = new ObjectRefImpl<>(id, clazz, false); 48 | AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal(); 49 | runtime.getObjectStore().registerOwnershipInfoAndResolveFuture( 50 | id, null, ownerAddress 51 | ); 52 | return ref; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/java/org/apache/spark/raydp/RayExecutorUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.raydp; 19 | 20 | import io.ray.api.ActorHandle; 21 | import io.ray.api.ObjectRef; 22 | import io.ray.api.Ray; 23 | import io.ray.api.call.ActorCreator; 24 | import java.util.Map; 25 | import java.util.List; 26 | 27 | import io.ray.api.placementgroup.PlacementGroup; 28 | import io.ray.runtime.object.ObjectRefImpl; 29 | import org.apache.spark.executor.RayDPExecutor; 30 | 31 | public class RayExecutorUtils { 32 | /** 33 | * Convert from MB to Ray memory units.
The memory unit in Ray is bytes. 34 | */ 35 | 36 | private static double toMemoryUnits(int memoryInMB) { 37 | double result = 1.0 * memoryInMB * 1024 * 1024; 38 | return Math.round(result); 39 | } 40 | 41 | public static ActorHandle<RayDPExecutor> createExecutorActor( 42 | String executorId, 43 | String appMasterURL, 44 | double cores, 45 | int memoryInMB, 46 | Map<String, Double> resources, 47 | PlacementGroup placementGroup, 48 | int bundleIndex, 49 | List<String> javaOpts) { 50 | ActorCreator<RayDPExecutor> creator = Ray.actor( 51 | RayDPExecutor::new, executorId, appMasterURL); 52 | creator.setName("raydp-executor-" + executorId); 53 | creator.setJvmOptions(javaOpts); 54 | creator.setResource("CPU", cores); 55 | creator.setResource("memory", toMemoryUnits(memoryInMB)); 56 | 57 | for (Map.Entry<String, Double> entry : resources.entrySet()) { 58 | creator.setResource(entry.getKey(), entry.getValue()); 59 | } 60 | if (placementGroup != null) { 61 | creator.setPlacementGroup(placementGroup, bundleIndex); 62 | } 63 | creator.setMaxRestarts(-1); 64 | creator.setMaxTaskRetries(-1); 65 | creator.setMaxConcurrency(2); 66 | return creator.remote(); 67 | } 68 | 69 | public static void setUpExecutor( 70 | ActorHandle<RayDPExecutor> handler, 71 | String appId, 72 | String driverUrl, 73 | int cores, 74 | String classPathEntries) { 75 | handler.task(RayDPExecutor::startUp, 76 | appId, driverUrl, cores, classPathEntries).remote(); 77 | } 78 | 79 | public static String[] getBlockLocations( 80 | ActorHandle<RayDPExecutor> handler, 81 | int rddId, 82 | int numPartitions) { 83 | return handler.task(RayDPExecutor::getBlockLocations, 84 | rddId, numPartitions).remote().get(); 85 | } 86 | 87 | public static ObjectRef getRDDPartition( 88 | ActorHandle<RayDPExecutor> handle, 89 | int rddId, 90 | int partitionId, 91 | String schema, 92 | String driverAgentUrl) { 93 | return (ObjectRefImpl) handle.task( 94 | RayDPExecutor::getRDDPartition, 95 | rddId, partitionId, schema, driverAgentUrl).remote(); 96 | } 97 | 98 | public static void exitExecutor( 99 | ActorHandle<RayDPExecutor> handle 100 | ) { 101 | handle.task(RayDPExecutor::stop).remote(); 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager: -------------------------------------------------------------------------------- 1 | org.apache.spark.scheduler.cluster.raydp.RayClusterManager -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/RayDPException.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | package org.apache.spark 19 | 20 | class RayDPException(message: String, cause: Throwable) 21 | extends SparkException(message, cause) { 22 | def this(message: String) = this(message, null) 23 | } 24 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/AppMasterEntryPoint.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import java.io.{DataOutputStream, File, FileOutputStream} 21 | import java.net.InetAddress 22 | import java.nio.file.Files 23 | 24 | import scala.util.Try 25 | 26 | import py4j.GatewayServer 27 | 28 | import org.apache.spark.internal.Logging 29 | 30 | 31 | class AppMasterEntryPoint { 32 | private val appMaster: AppMasterJavaBridge = new AppMasterJavaBridge() 33 | 34 | def getAppMasterBridge(): AppMasterJavaBridge = { 35 | appMaster 36 | } 37 | } 38 | 39 | object AppMasterEntryPoint extends Logging { 40 | private val localhost = InetAddress.getLoopbackAddress() 41 | 42 | def getGatewayServer(): GatewayServer = { 43 | new GatewayServer.GatewayServerBuilder() 44 | .javaPort(0) 45 | .javaAddress(localhost) 46 | .entryPoint(new AppMasterEntryPoint()) 47 | .build() 48 | } 49 | 50 | def main(args: Array[String]): Unit = { 51 | 52 | var server = getGatewayServer() 53 | 54 | while(true) { 55 | if (!Try(server.start()).isFailure) { 56 | val boundPort: Int = server.getListeningPort() 57 | if (boundPort == -1) { 58 | logError(s"${server.getClass} failed to bind; exiting") 59 | System.exit(1) 60 | } else { 61 | logDebug(s"Started PythonGatewayServer on port $boundPort") 62 | } 63 | 64 | 65 | val connectionInfoPath = new File(sys.env("_RAYDP_APPMASTER_CONN_INFO_PATH")) 66 | val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(), 67 | "connection", ".info").toFile() 68 | 69 | val dos = new DataOutputStream(new FileOutputStream(tmpPath)) 70 | dos.writeInt(boundPort) 71 | dos.close() 72 | 73 | if (!tmpPath.renameTo(connectionInfoPath)) { 74 | logError(s"Unable to write connection information to $connectionInfoPath.") 75 | System.exit(1) 76 | } 77 | 78 | // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: 79 | while (System.in.read() != -1) { 80 | // Do nothing 81 | } 82 | logDebug("Exiting due to broken pipe from Python driver") 83 | System.exit(0) 84 | } else { 85 | server.shutdown() 86 | logError(s"${server.getClass} failed to bind; retrying...") 87 | Thread.sleep(1000) 88 | server = getGatewayServer() 89 | } 90 | } 91 | 92 | 93 | 94 | } 95 | } 96 | 
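Note on the gateway handshake in AppMasterEntryPoint above: it starts a py4j GatewayServer on an ephemeral loopback port, writes the bound port into the file named by the _RAYDP_APPMASTER_CONN_INFO_PATH environment variable as a single DataOutputStream.writeInt value, exposes getAppMasterBridge() as the entry point, and keeps running until its stdin reaches EOF. The Python sketch below only illustrates how a client could complete that handshake under those assumptions; the classpath and Spark properties passed to startUpAppMaster are placeholders, and RayDP's own client (python/raydp/spark/ray_cluster_master.py, not shown in this section) may do this differently.

import os
import struct

from py4j.java_gateway import JavaGateway, GatewayParameters

# The JVM side writes the gateway port with DataOutputStream.writeInt:
# a single 4-byte, big-endian integer.
conn_info_path = os.environ["_RAYDP_APPMASTER_CONN_INFO_PATH"]
with open(conn_info_path, "rb") as f:
    (port,) = struct.unpack(">i", f.read(4))

# The server binds to the loopback address; auto_convert lets py4j pass the
# Python dict below as the java.util.Map that startUpAppMaster expects.
gateway = JavaGateway(gateway_parameters=GatewayParameters(
    address="127.0.0.1", port=port, auto_convert=True))
bridge = gateway.entry_point.getAppMasterBridge()

# Placeholder classpath and Spark properties, purely for illustration.
bridge.startUpAppMaster("/path/to/raydp-jars/*", {"spark.executor.instances": "2"})
print(bridge.getMasterUrl())
bridge.stop()

Because the JVM process exits as soon as its stdin hits EOF, whichever Python process launched AppMasterEntryPoint has to keep that pipe open for as long as the app master is needed.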
-------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/AppMasterJavaBridge.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import java.util.Map 21 | 22 | import scala.collection.JavaConverters._ 23 | 24 | import io.ray.api.{ActorHandle, Ray} 25 | 26 | import org.apache.spark.raydp.SparkOnRayConfigs 27 | 28 | class AppMasterJavaBridge { 29 | private var handle: ActorHandle[RayAppMaster] = null 30 | 31 | def startUpAppMaster(extra_cp: String, sparkProps: Map[String, Any]): Unit = { 32 | if (handle == null) { 33 | // init ray, we should set the config by java properties 34 | Ray.init() 35 | val name = RayAppMaster.ACTOR_NAME 36 | val sparkJvmOptions = sparkProps.asScala.toMap.filter( 37 | e => !SparkOnRayConfigs.SPARK_DRIVER_EXTRA_JAVA_OPTIONS.equals(e._1)) 38 | .map { 39 | case (k, v) => 40 | if (!SparkOnRayConfigs.SPARK_JAVAAGENT.equals(k)) { 41 | "-D" + k + "=" + v 42 | } else { 43 | "-javaagent:" + v 44 | } 45 | }.toBuffer 46 | 47 | val appMasterResources = sparkProps.asScala.filter { 48 | case (k, v) => k.startsWith(SparkOnRayConfigs.SPARK_MASTER_ACTOR_RESOURCE_PREFIX) 49 | }.map{ case (k, v) => k->double2Double(v.toString.toDouble) }.asJava 50 | 51 | handle = RayAppMasterUtils.createAppMaster( 52 | extra_cp, name, 53 | (sparkJvmOptions ++ Seq(SparkOnRayConfigs.RAYDP_LOGFILE_PREFIX_CFG)).asJava, 54 | appMasterResources) 55 | } 56 | } 57 | 58 | def getMasterUrl(): String = { 59 | if (handle == null) { 60 | throw new RuntimeException("You should create the RayAppMaster handle first") 61 | } 62 | RayAppMasterUtils.getMasterUrl(handle) 63 | } 64 | 65 | def stop(): Unit = { 66 | if (handle != null) { 67 | RayAppMasterUtils.stopAppMaster(handle) 68 | Ray.shutdown() 69 | handle = null 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationDescription.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import scala.collection.Map 21 | 22 | private[spark] case class Command( 23 | driverUrl: String, 24 | environment: Map[String, String], 25 | classPathEntries: Seq[String], 26 | libraryPathEntries: Seq[String], 27 | javaOpts: Seq[String]) { 28 | 29 | def withNewJavaOpts(newJavaOptions: Seq[String]): Command = { 30 | Command(driverUrl, environment, classPathEntries, libraryPathEntries, newJavaOptions) 31 | } 32 | } 33 | 34 | private[spark] case class ApplicationDescription( 35 | name: String, 36 | numExecutors: Int, 37 | coresPerExecutor: Option[Int], 38 | memoryPerExecutorMB: Int, 39 | rayActorCPU: Double, 40 | command: Command, 41 | user: String = System.getProperty("user.name", ""), 42 | resourceReqsPerExecutor: Map[String, Double] = Map.empty) { 43 | 44 | def withNewCommand(newCommand: Command): ApplicationDescription = { 45 | ApplicationDescription(name = name, 46 | numExecutors = numExecutors, coresPerExecutor = coresPerExecutor, 47 | memoryPerExecutorMB = memoryPerExecutorMB, command = newCommand, user = user, 48 | resourceReqsPerExecutor = resourceReqsPerExecutor, 49 | rayActorCPU = rayActorCPU) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationState.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | object ApplicationState extends Enumeration { 21 | 22 | type ApplicationState = Value 23 | 24 | val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value 25 | } 26 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/Messages.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import org.apache.spark.rpc.RpcEndpointRef 21 | 22 | private[deploy] sealed trait RayDPDeployMessage extends Serializable 23 | 24 | case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef) 25 | extends RayDPDeployMessage 26 | 27 | case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends RayDPDeployMessage 28 | 29 | case class UnregisterApplication(appId: String) extends RayDPDeployMessage 30 | 31 | case class RegisterExecutor(executorId: String, nodeIp: String) extends RayDPDeployMessage 32 | 33 | case class ExecutorStarted(executorId: String) extends RayDPDeployMessage 34 | 35 | case class RequestExecutors(appId: String, requestedTotal: Int) extends RayDPDeployMessage 36 | 37 | case class KillExecutors(appId: String, executorIds: Seq[String]) extends RayDPDeployMessage 38 | 39 | case class RequestAddPendingRestartedExecutor(executorId: String) 40 | extends RayDPDeployMessage 41 | 42 | case class AddPendingRestartedExecutorReply(newExecutorId: Option[String]) 43 | extends RayDPDeployMessage 44 | 45 | case class RecacheRDD(rddId: Int) extends RayDPDeployMessage 46 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayDPDriverAgent.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import io.ray.runtime.config.RayConfig 21 | 22 | import org.apache.spark.{SecurityManager, SparkConf, SparkContext} 23 | import org.apache.spark.internal.Logging 24 | import org.apache.spark.rpc._ 25 | 26 | 27 | class RayDPDriverAgent() { 28 | private val spark = SparkContext.getOrCreate() 29 | private var endpoint: RpcEndpointRef = _ 30 | private var rpcEnv: RpcEnv = _ 31 | private val conf: SparkConf = new SparkConf() 32 | 33 | init 34 | 35 | def init(): Unit = { 36 | val securityMgr = new SecurityManager(conf) 37 | val host = RayConfig.create().nodeIp 38 | rpcEnv = RpcEnv.create( 39 | RayAppMaster.ENV_NAME, 40 | host, 41 | host, 42 | 0, 43 | conf, 44 | securityMgr, 45 | // limit to single-thread 46 | numUsableCores = 1, 47 | clientMode = false) 48 | // register endpoint 49 | endpoint = rpcEnv.setupEndpoint(RayDPDriverAgent.ENDPOINT_NAME, 50 | new RayDPDriverAgentEndpoint(rpcEnv)) 51 | } 52 | 53 | def getDriverAgentEndpointUrl(): String = { 54 | RpcEndpointAddress(rpcEnv.address, RayDPDriverAgent.ENDPOINT_NAME).toString 55 | } 56 | 57 | class RayDPDriverAgentEndpoint(override val rpcEnv: RpcEnv) 58 | extends ThreadSafeRpcEndpoint with Logging { 59 | override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { 60 | case RecacheRDD(rddId) => 61 | // TODO if multiple blocks get lost, should call this only once 62 | // SparkEnv.get.blockManagerMaster.getLocationsAndStatus() 63 | spark.getPersistentRDDs.map { 64 | case (id, rdd) => 65 | if (id == rddId) { 66 | rdd.count 67 | } 68 | } 69 | context.reply(true) 70 | } 71 | } 72 | 73 | } 74 | 75 | object RayDPDriverAgent { 76 | val ENDPOINT_NAME = "RAYDP_DRIVER_AGENT" 77 | } 78 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayExternalShuffleService.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import io.ray.api.Ray; 21 | 22 | import org.apache.spark.{SecurityManager, SparkConf} 23 | import org.apache.spark.deploy.ExternalShuffleService 24 | import org.apache.spark.internal.Logging 25 | 26 | class RayExternalShuffleService() extends Logging { 27 | val conf = new SparkConf() 28 | val mgr = new SecurityManager(conf) 29 | val instance = new ExternalShuffleService(conf, mgr) 30 | 31 | def start(): Unit = { 32 | instance.start() 33 | } 34 | 35 | def stop(): Unit = { 36 | instance.stop() 37 | Ray.exitActor() 38 | } 39 | } 40 | 41 | object RayExternalShuffleService { 42 | def getShuffleConf(conf: SparkConf): Array[String] = { 43 | // all conf needed by external shuffle service 44 | var shuffleConf = conf.getAll.filter { 45 | case (k, v) => k.startsWith("spark.shuffle") 46 | }.map { 47 | case (k, v) => 48 | "-D" + k + "=" + v 49 | } 50 | val localDirKey = "spark.local.dir" 51 | if (conf.contains(localDirKey)) { 52 | shuffleConf = shuffleConf :+ 53 | "-D" + localDirKey + "=" + conf.get(localDirKey) 54 | } 55 | shuffleConf 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.rdd 19 | 20 | import java.util.List; 21 | 22 | import scala.collection.JavaConverters._ 23 | 24 | import io.ray.runtime.generated.Common.Address 25 | 26 | import org.apache.spark.{Partition, SparkContext, TaskContext} 27 | import org.apache.spark.api.java.JavaSparkContext 28 | import org.apache.spark.raydp.RayDPUtils 29 | import org.apache.spark.sql.raydp.ObjectStoreReader 30 | 31 | private[spark] class RayDatasetRDDPartition(val ref: Array[Byte], idx: Int) extends Partition { 32 | val index = idx 33 | } 34 | 35 | private[spark] 36 | class RayDatasetRDD( 37 | jsc: JavaSparkContext, 38 | @transient val objectIds: List[Array[Byte]], 39 | locations: List[Array[Byte]]) 40 | extends RDD[Array[Byte]](jsc.sc, Nil) { 41 | 42 | override def getPartitions: Array[Partition] = { 43 | objectIds.asScala.zipWithIndex.map { case (k, i) => 44 | new RayDatasetRDDPartition(k, i).asInstanceOf[Partition] 45 | }.toArray 46 | } 47 | 48 | override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { 49 | val ref = split.asInstanceOf[RayDatasetRDDPartition].ref 50 | ObjectStoreReader.getBatchesFromStream(ref, locations.get(split.index)) 51 | } 52 | 53 | override def getPreferredLocations(split: Partition): Seq[String] = { 54 | val address = Address.parseFrom(locations.get(split.index)) 55 | Seq(address.getIpAddress()) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/rdd/RayObjectRefRDD.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
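RayDatasetRDD maps each serialized Ray ObjectRef to exactly one Spark partition, reads the Arrow stream stored in that object at compute time, and reports the ref owner's node as the preferred location. The sketch below only illustrates the constructor contract: `refs` and `ownerAddresses` are assumed to be the serialized object references and owner addresses handed over from the Ray side, and the wrapper object is hypothetical.

// RayDatasetRDD is private[spark], so real callers live inside the org.apache.spark namespace.
package org.apache.spark.rdd

import java.util.{List => JList}
import org.apache.spark.api.java.JavaSparkContext

object RayDatasetRDDSketch {
  def toRdd(jsc: JavaSparkContext,
            refs: JList[Array[Byte]],
            ownerAddresses: JList[Array[Byte]]): RayDatasetRDD =
    new RayDatasetRDD(jsc, refs, ownerAddresses)   // one Spark partition per Ray object ref
}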
16 | */ 17 | 18 | package org.apache.spark.rdd 19 | 20 | import java.util.List; 21 | 22 | import scala.collection.JavaConverters._ 23 | 24 | import io.ray.runtime.generated.Common.Address 25 | 26 | import org.apache.spark.{Partition, SparkContext, TaskContext} 27 | import org.apache.spark.raydp.RayDPUtils 28 | import org.apache.spark.sql.Row 29 | 30 | private[spark] class RayObjectRefRDDPartition(idx: Int) extends Partition { 31 | val index = idx 32 | } 33 | 34 | private[spark] 35 | class RayObjectRefRDD( 36 | sc: SparkContext, 37 | locations: List[Array[Byte]]) 38 | extends RDD[Row](sc, Nil) { 39 | 40 | override def getPartitions: Array[Partition] = { 41 | (0 until locations.size()).map { i => 42 | new RayObjectRefRDDPartition(i).asInstanceOf[Partition] 43 | }.toArray 44 | } 45 | 46 | override def compute(split: Partition, context: TaskContext): Iterator[Row] = { 47 | (Row(split.index) :: Nil).iterator 48 | } 49 | 50 | override def getPreferredLocations(split: Partition): Seq[String] = { 51 | Seq(Address.parseFrom(locations.get(split.index)).getIpAddress()) 52 | } 53 | } 54 | 55 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayClusterManager.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.scheduler.cluster.raydp 19 | 20 | import org.apache.spark.SparkContext 21 | import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} 22 | 23 | private[spark] class RayClusterManager extends ExternalClusterManager { 24 | 25 | override def canCreate(masterURL: String): Boolean = { 26 | masterURL.startsWith("ray") 27 | } 28 | 29 | override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { 30 | new TaskSchedulerImpl(sc) 31 | } 32 | 33 | override def createSchedulerBackend( 34 | sc: SparkContext, 35 | masterURL: String, 36 | scheduler: TaskScheduler): SchedulerBackend = { 37 | new RayCoarseGrainedSchedulerBackend( 38 | sc, 39 | scheduler.asInstanceOf[TaskSchedulerImpl], 40 | masterURL) 41 | } 42 | 43 | override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { 44 | scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
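RayClusterManager is registered as an ExternalClusterManager through META-INF/services, and Spark routes a job to it purely by the master-URL prefix: anything starting with "ray" is accepted, everything else stays with the built-in managers. The sketch below only restates that routing rule; the URLs are illustrative, and in practice the master value is set up by `raydp.init_spark` on the Python side rather than by hand.

// Sketch of the only routing decision RayClusterManager makes.
def routedToRayDP(masterURL: String): Boolean = masterURL.startsWith("ray")

routedToRayDP("ray://example-appmaster:6379")   // true  (illustrative URL)
routedToRayDP("local[*]")                        // false (handled by Spark's built-in managers)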
16 | */ 17 | 18 | package org.apache.spark.sql.raydp 19 | 20 | import java.io.ByteArrayInputStream 21 | import java.nio.channels.{Channels, ReadableByteChannel} 22 | import java.util.List 23 | 24 | import com.intel.raydp.shims.SparkShimLoader 25 | 26 | import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} 27 | import org.apache.spark.raydp.RayDPUtils 28 | import org.apache.spark.rdd.{RayDatasetRDD, RayObjectRefRDD} 29 | import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext} 30 | import org.apache.spark.sql.catalyst.expressions.GenericRow 31 | import org.apache.spark.sql.execution.arrow.ArrowConverters 32 | import org.apache.spark.sql.types.{IntegerType, StructType} 33 | 34 | object ObjectStoreReader { 35 | def createRayObjectRefDF( 36 | spark: SparkSession, 37 | locations: List[Array[Byte]]): DataFrame = { 38 | val rdd = new RayObjectRefRDD(spark.sparkContext, locations) 39 | val schema = new StructType().add("idx", IntegerType) 40 | spark.createDataFrame(rdd, schema) 41 | } 42 | 43 | def RayDatasetToDataFrame( 44 | sparkSession: SparkSession, 45 | rdd: RayDatasetRDD, 46 | schema: String): DataFrame = { 47 | SparkShimLoader.getSparkShims.toDataFrame(JavaRDD.fromRDD(rdd), schema, sparkSession) 48 | } 49 | 50 | def getBatchesFromStream( 51 | ref: Array[Byte], 52 | ownerAddress: Array[Byte]): Iterator[Array[Byte]] = { 53 | val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]], ownerAddress) 54 | ArrowConverters.getBatchesFromStream( 55 | Channels.newChannel(new ByteArrayInputStream(objectRef.get))) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/test/org/apache/spark/scheduler/cluster/raydp/TestRayCoarseGrainedSchedulerBackend.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.scheduler.cluster.raydp; 18 | 19 | import org.apache.spark.SparkConf; 20 | import org.junit.jupiter.api.Test; 21 | 22 | import org.apache.spark.scheduler.cluster.SchedulerBackendUtils; 23 | 24 | import static org.junit.jupiter.api.Assertions.assertEquals; 25 | 26 | /** 27 | * This class performs unit testing on some methods in `RayCoarseGrainedSchedulerBackend`. 28 | */ 29 | public class TestRayCoarseGrainedSchedulerBackend { 30 | 31 | // Test using the default value. 32 | @Test 33 | public void testExecutorNumberWithDefaultConfig() { 34 | SparkConf conf = new SparkConf(); 35 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2); 36 | assertEquals(2, executorNumber); 37 | } 38 | 39 | // Test using a negative value. 
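createRayObjectRefDF turns the location list into a single-column DataFrame named "idx" with one row (and one partition) per Ray object ref, while RayDatasetToDataFrame and getBatchesFromStream handle the actual Arrow payloads. A usage sketch for the first helper; `spark` and `locations` are assumed to be supplied by the caller and are not defined in the file above.

import java.util.{List => JList}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.raydp.ObjectStoreReader

def indexDF(spark: SparkSession, locations: JList[Array[Byte]]): DataFrame = {
  // Schema is `idx: int`; the row count equals locations.size(), one row per partition.
  ObjectStoreReader.createRayObjectRefDF(spark, locations)
}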
40 | @Test 41 | public void testExecutorNumberWithNegativeConfig() { 42 | SparkConf conf = new SparkConf(); 43 | conf.set("spark.dynamicAllocation.initialExecutors", "-1"); 44 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2); 45 | assertEquals(2, executorNumber); 46 | } 47 | 48 | // Test using reasonable values. 49 | @Test 50 | public void testExecutorNumberWithValidConfig() { 51 | SparkConf conf = new SparkConf(); 52 | conf.set("spark.executor.instances", "5"); 53 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2); 54 | assertEquals(5, executorNumber); 55 | } 56 | 57 | // Test using dynamic values. 58 | @Test 59 | public void testExecutorNumberWithDynamicConfig() { 60 | SparkConf conf = new SparkConf(); 61 | conf.set("spark.dynamicAllocation.enabled", "true"); 62 | conf.set("spark.dynamicAllocation.minExecutors", "3"); 63 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2); 64 | assertEquals(3, executorNumber); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /core/shims/common/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-shims 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-shims-common 15 | RayDP Shims Common 16 | 1.7.0-SNAPSHOT 17 | jar 18 | 19 | 20 | 21 | 22 | org.scalastyle 23 | scalastyle-maven-plugin 24 | 25 | 26 | net.alchim31.maven 27 | scala-maven-plugin 28 | 3.2.2 29 | 30 | 31 | scala-compile-first 32 | process-resources 33 | 34 | compile 35 | 36 | 37 | 38 | scala-test-compile-first 39 | process-test-resources 40 | 41 | testCompile 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | org.apache.spark 52 | spark-sql_${scala.binary.version} 53 | ${spark.version} 54 | provided 55 | 56 | 57 | com.google.protobuf 58 | protobuf-java 59 | 60 | 61 | 62 | 63 | org.apache.spark 64 | spark-core_${scala.binary.version} 65 | ${spark.version} 66 | provided 67 | 68 | 69 | org.xerial.snappy 70 | snappy-java 71 | 72 | 73 | org.apache.commons 74 | commons-compress 75 | 76 | 77 | org.apache.commons 78 | commons-text 79 | 80 | 81 | org.apache.ivy 82 | ivy 83 | 84 | 85 | log4j 86 | log4j 87 | 88 | 89 | 90 | 91 | org.xerial.snappy 92 | snappy-java 93 | ${snappy.version} 94 | 95 | 96 | org.apache.commons 97 | commons-text 98 | ${commons.text.version} 99 | 100 | 101 | com.google.protobuf 102 | protobuf-java 103 | ${protobuf.version} 104 | 105 | 106 | org.apache.ivy 107 | ivy 108 | ${ivy.version} 109 | 110 | 111 | org.apache.commons 112 | commons-compress 113 | ${commons.compress.version} 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShimLoader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
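Taken together, the four tests pin down how the initial executor count is resolved: the supplied default wins when nothing is configured, a negative `spark.dynamicAllocation.initialExecutors` falls back to that default, an explicit `spark.executor.instances` is honored, and with dynamic allocation enabled the minimum executor count is used. The same expectations, restated as a Scala sketch (SchedulerBackendUtils is private[spark], so the sketch sits in the Spark namespace):

package org.apache.spark.scheduler.cluster

import org.apache.spark.SparkConf

object InitialExecutorsSketch {
  val default  = SchedulerBackendUtils.getInitialTargetExecutorNumber(new SparkConf(), 2)   // 2
  val explicit = SchedulerBackendUtils.getInitialTargetExecutorNumber(
    new SparkConf().set("spark.executor.instances", "5"), 2)                                // 5
  val dynamic  = SchedulerBackendUtils.getInitialTargetExecutorNumber(
    new SparkConf().set("spark.dynamicAllocation.enabled", "true")
      .set("spark.dynamicAllocation.minExecutors", "3"), 2)                                 // 3
}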
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.intel.raydp.shims 19 | 20 | import java.util.ServiceLoader 21 | 22 | import scala.collection.JavaConverters._ 23 | 24 | import org.apache.spark.SPARK_VERSION_SHORT 25 | import org.apache.spark.internal.Logging 26 | 27 | object SparkShimLoader extends Logging { 28 | private var sparkShims: SparkShims = null 29 | private var sparkShimProviderClass: String = null 30 | 31 | def getSparkShims: SparkShims = { 32 | if (sparkShims == null) { 33 | val provider = getSparkShimProvider() 34 | sparkShims = provider.createShim 35 | } 36 | sparkShims 37 | } 38 | 39 | def getSparkVersion: String = { 40 | SPARK_VERSION_SHORT 41 | } 42 | 43 | def setSparkShimProviderClass(providerClass: String): Unit = { 44 | sparkShimProviderClass = providerClass 45 | } 46 | 47 | private def loadSparkShimProvider(): SparkShimProvider = { 48 | // Match and load Shim provider for current Spark version. 49 | val sparkVersion = getSparkVersion 50 | logInfo(s"Loading Spark Shims for version: $sparkVersion") 51 | 52 | // Load and filter the providers based on version 53 | val shimProviders = 54 | ServiceLoader.load(classOf[SparkShimProvider]).asScala.filter(_.matches(sparkVersion)) 55 | if (shimProviders.size > 1) { 56 | throw new IllegalStateException(s"More than one SparkShimProvider found: $shimProviders") 57 | } 58 | 59 | val shimProvider = shimProviders.headOption match { 60 | case Some(shimProvider) => shimProvider 61 | case None => 62 | throw new IllegalStateException(s"No Spark Shim Provider found for $sparkVersion") 63 | } 64 | logInfo(s"Using Shim provider: $shimProviders") 65 | shimProvider 66 | } 67 | 68 | private def getSparkShimProvider(): SparkShimProvider = { 69 | if (sparkShimProviderClass != null) { 70 | logInfo(s"Using Spark Shim Provider specified by $sparkShimProviderClass. ") 71 | val providerClass = Class.forName(sparkShimProviderClass) 72 | val providerConstructor = providerClass.getConstructor() 73 | providerConstructor.newInstance().asInstanceOf[SparkShimProvider] 74 | } else { 75 | loadSparkShimProvider() 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
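RayDP code never instantiates a concrete shim directly: SparkShimLoader service-loads the provider whose matches() accepts the running Spark version, builds the SparkShims instance once, and caches it. A minimal usage sketch of the call sites this enables:

import com.intel.raydp.shims.SparkShimLoader

val shims   = SparkShimLoader.getSparkShims          // resolved via ServiceLoader, then cached
val factory = shims.getExecutorBackendFactory()      // Spark-version-specific executor backend factory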
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.intel.raydp.shims 19 | 20 | /** 21 | * Provider interface for matching and retrieving the Shims of a specific Spark version 22 | */ 23 | trait SparkShimProvider { 24 | def matches(version:String): Boolean 25 | def createShim: SparkShims 26 | } 27 | -------------------------------------------------------------------------------- /core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.intel.raydp.shims 19 | 20 | import org.apache.arrow.vector.types.pojo.Schema 21 | import org.apache.spark.{SparkEnv, TaskContext} 22 | import org.apache.spark.api.java.JavaRDD 23 | import org.apache.spark.executor.RayDPExecutorBackendFactory 24 | import org.apache.spark.sql.types.StructType 25 | import org.apache.spark.sql.{DataFrame, SparkSession} 26 | 27 | sealed abstract class ShimDescriptor 28 | 29 | case class SparkShimDescriptor(major: Int, minor: Int, patch: Int) extends ShimDescriptor { 30 | override def toString(): String = s"$major.$minor.$patch" 31 | } 32 | 33 | trait SparkShims { 34 | def getShimDescriptor: ShimDescriptor 35 | 36 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame 37 | 38 | def getExecutorBackendFactory(): RayDPExecutorBackendFactory 39 | 40 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext 41 | 42 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema 43 | } 44 | -------------------------------------------------------------------------------- /core/shims/common/src/main/scala/org/apache/spark/executor/RayDPExecutorBackendFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
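Supporting another Spark line therefore means implementing exactly these two traits: the provider declares which version strings it matches and constructs the shim, and the shim adapts the four version-sensitive operations. The skeleton below is a hypothetical example for an imaginary Spark 9.9.9 module; every name in it is illustrative and the version-specific bodies are left as ???.

package com.intel.raydp.shims.spark999   // hypothetical module, sketch only

import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}

object SparkShimProvider {
  val DESCRIPTOR = SparkShimDescriptor(9, 9, 9)
  val DESCRIPTOR_STRINGS = Seq(s"$DESCRIPTOR")
}

class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
  def createShim: SparkShims = new Spark999Shims()
  def matches(version: String): Boolean =
    SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
}

class Spark999Shims extends SparkShims {
  override def getShimDescriptor = SparkShimProvider.DESCRIPTOR
  // The members below would delegate to the 9.9.9-specific Spark APIs, mirroring the
  // existing spark3xx modules; ??? stands in for those version-specific bodies.
  override def toDataFrame(rdd: org.apache.spark.api.java.JavaRDD[Array[Byte]],
                           schema: String,
                           session: org.apache.spark.sql.SparkSession) = ???
  override def getExecutorBackendFactory() = ???
  override def getDummyTaskContext(partitionId: Int, env: org.apache.spark.SparkEnv) = ???
  override def toArrowSchema(schema: org.apache.spark.sql.types.StructType, timeZoneId: String) = ???
}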
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.executor 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.rpc.RpcEnv 24 | import org.apache.spark.resource.ResourceProfile 25 | 26 | trait RayDPExecutorBackendFactory { 27 | def createExecutorBackend( 28 | rpcEnv: RpcEnv, 29 | driverUrl: String, 30 | executorId: String, 31 | bindAddress: String, 32 | hostname: String, 33 | cores: Int, 34 | userClassPath: Seq[URL], 35 | env: SparkEnv, 36 | resourcesFileOpt: Option[String], 37 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend 38 | } 39 | -------------------------------------------------------------------------------- /core/shims/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-parent 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-shims 15 | RayDP Shims 16 | pom 17 | 18 | 19 | common 20 | spark322 21 | spark330 22 | spark340 23 | spark350 24 | 25 | 26 | 27 | 2.12 28 | 4.3.0 29 | 3.2.2 30 | 31 | 32 | 33 | 34 | 35 | net.alchim31.maven 36 | scala-maven-plugin 37 | ${scala.plugin.version} 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /core/shims/spark322/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-shims 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-shims-spark322 15 | RayDP Shims for Spark 3.2.2 16 | jar 17 | 18 | 19 | 2.12.15 20 | 2.13.5 21 | 22 | 23 | 24 | 25 | 26 | org.scalastyle 27 | scalastyle-maven-plugin 28 | 29 | 30 | net.alchim31.maven 31 | scala-maven-plugin 32 | 3.2.2 33 | 34 | 35 | scala-compile-first 36 | process-resources 37 | 38 | compile 39 | 40 | 41 | 42 | scala-test-compile-first 43 | process-test-resources 44 | 45 | testCompile 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | src/main/resources 55 | 56 | 57 | 58 | 59 | 60 | 61 | com.intel 62 | raydp-shims-common 63 | ${project.version} 64 | compile 65 | 66 | 67 | org.apache.spark 68 | spark-sql_${scala.binary.version} 69 | ${spark322.version} 70 | provided 71 | 72 | 73 | com.google.protobuf 74 | protobuf-java 75 | 76 | 77 | 78 | 79 | org.apache.spark 80 | spark-core_${scala.binary.version} 81 | ${spark322.version} 82 | provided 83 | 84 | 85 | org.xerial.snappy 86 | snappy-java 87 | 88 | 89 | org.apache.commons 90 | commons-compress 91 | 92 | 93 | org.apache.commons 94 | commons-text 95 | 96 | 97 | org.apache.ivy 98 | ivy 99 | 100 | 101 | log4j 102 | log4j 103 | 104 | 105 | 106 | 107 | org.xerial.snappy 108 | snappy-java 109 | ${snappy.version} 110 | 111 | 112 | org.apache.commons 113 | commons-compress 114 | ${commons.compress.version} 115 | 116 | 117 | org.apache.commons 118 | commons-text 119 | ${commons.text.version} 120 | 121 | 122 | org.apache.ivy 123 | ivy 124 | ${ivy.version} 125 | 126 | 127 | com.google.protobuf 128 | protobuf-java 129 | ${protobuf.version} 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /core/shims/spark322/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider: -------------------------------------------------------------------------------- 1 | com.intel.raydp.shims.spark322.SparkShimProvider 2 | -------------------------------------------------------------------------------- 
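Each shim module is wired into the loader through a META-INF/services resource (the spark322 entry above is one example): the file is named after the com.intel.raydp.shims.SparkShimProvider trait and contains the fully qualified name of the module's provider class. Discovery can also be bypassed by naming a provider explicitly; the class name in the sketch below is the spark330 provider, used here purely for illustration.

import com.intel.raydp.shims.SparkShimLoader

// Force a specific provider before the first getSparkShims call instead of
// relying on ServiceLoader discovery (e.g. in tests).
SparkShimLoader.setSparkShimProviderClass("com.intel.raydp.shims.spark330.SparkShimProvider")
val shims = SparkShimLoader.getSparkShims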
/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.intel.raydp.shims.spark322 19 | 20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} 21 | 22 | object SparkShimProvider { 23 | val SPARK311_DESCRIPTOR = SparkShimDescriptor(3, 1, 1) 24 | val SPARK312_DESCRIPTOR = SparkShimDescriptor(3, 1, 2) 25 | val SPARK313_DESCRIPTOR = SparkShimDescriptor(3, 1, 3) 26 | val SPARK320_DESCRIPTOR = SparkShimDescriptor(3, 2, 0) 27 | val SPARK321_DESCRIPTOR = SparkShimDescriptor(3, 2, 1) 28 | val SPARK322_DESCRIPTOR = SparkShimDescriptor(3, 2, 2) 29 | val SPARK323_DESCRIPTOR = SparkShimDescriptor(3, 2, 3) 30 | val SPARK324_DESCRIPTOR = SparkShimDescriptor(3, 2, 4) 31 | val DESCRIPTOR_STRINGS = 32 | Seq(s"$SPARK311_DESCRIPTOR", s"$SPARK312_DESCRIPTOR" ,s"$SPARK313_DESCRIPTOR", 33 | s"$SPARK320_DESCRIPTOR", s"$SPARK321_DESCRIPTOR", s"$SPARK322_DESCRIPTOR", 34 | s"$SPARK323_DESCRIPTOR", s"$SPARK324_DESCRIPTOR") 35 | val DESCRIPTOR = SPARK323_DESCRIPTOR 36 | } 37 | 38 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { 39 | def createShim: SparkShims = { 40 | new Spark322Shims() 41 | } 42 | 43 | def matches(version: String): Boolean = { 44 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims.spark322 19 | 20 | import org.apache.spark.{SparkEnv, TaskContext} 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.executor.RayDPExecutorBackendFactory 23 | import org.apache.spark.executor.spark322._ 24 | import org.apache.spark.spark322.TaskContextUtils 25 | import org.apache.spark.sql.{DataFrame, SparkSession} 26 | import org.apache.spark.sql.spark322.SparkSqlUtils 27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims} 28 | import org.apache.arrow.vector.types.pojo.Schema 29 | import org.apache.spark.sql.types.StructType 30 | 31 | class Spark322Shims extends SparkShims { 32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR 33 | 34 | override def toDataFrame( 35 | rdd: JavaRDD[Array[Byte]], 36 | schema: String, 37 | session: SparkSession): DataFrame = { 38 | SparkSqlUtils.toDataFrame(rdd, schema, session) 39 | } 40 | 41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { 42 | new RayDPSpark322ExecutorBackendFactory() 43 | } 44 | 45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 46 | TaskContextUtils.getDummyTaskContext(partitionId, env) 47 | } 48 | 49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark322/src/main/scala/org/apache/spark/TaskContextUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.spark322 19 | 20 | import java.util.Properties 21 | 22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} 23 | import org.apache.spark.memory.TaskMemoryManager 24 | 25 | object TaskContextUtils { 26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /core/shims/spark322/src/main/scala/org/apache/spark/executor/RayDPSpark322ExecutorBackendFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
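getDummyTaskContext exists because some Spark internals (such as the Arrow batch conversion helpers) insist on an active TaskContext even when they are called outside a real task, and the TaskContextImpl constructor arity changes between Spark releases, so each shim builds the dummy context itself. A hedged usage sketch going through the loader:

import org.apache.spark.SparkEnv
import com.intel.raydp.shims.SparkShimLoader

// Sketch: obtain a throwaway TaskContext for partition 0 through the version-matched shim.
val ctx = SparkShimLoader.getSparkShims.getDummyTaskContext(0, SparkEnv.get)
// ctx can then be handed to APIs that require a TaskContext outside a running task.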
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.executor.spark322 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.executor.CoarseGrainedExecutorBackend 24 | import org.apache.spark.executor.RayDPExecutorBackendFactory 25 | import org.apache.spark.resource.ResourceProfile 26 | import org.apache.spark.rpc.RpcEnv 27 | 28 | class RayDPSpark322ExecutorBackendFactory 29 | extends RayDPExecutorBackendFactory { 30 | override def createExecutorBackend( 31 | rpcEnv: RpcEnv, 32 | driverUrl: String, 33 | executorId: String, 34 | bindAddress: String, 35 | hostname: String, 36 | cores: Int, 37 | userClassPath: Seq[URL], 38 | env: SparkEnv, 39 | resourcesFileOpt: Option[String], 40 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { 41 | new CoarseGrainedExecutorBackend( 42 | rpcEnv, 43 | driverUrl, 44 | executorId, 45 | bindAddress, 46 | hostname, 47 | cores, 48 | userClassPath, 49 | env, 50 | resourcesFileOpt, 51 | resourceProfile) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.sql.spark322 19 | 20 | import org.apache.arrow.vector.types.pojo.Schema 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} 23 | import org.apache.spark.sql.execution.arrow.ArrowConverters 24 | import org.apache.spark.sql.types.StructType 25 | import org.apache.spark.sql.util.ArrowUtils 26 | 27 | object SparkSqlUtils { 28 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = { 29 | ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session)) 30 | } 31 | 32 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 33 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /core/shims/spark330/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-shims 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-shims-spark330 15 | RayDP Shims for Spark 3.3.0 16 | jar 17 | 18 | 19 | 2.12.15 20 | 2.13.5 21 | 22 | 23 | 24 | 25 | 26 | org.scalastyle 27 | scalastyle-maven-plugin 28 | 29 | 30 | net.alchim31.maven 31 | scala-maven-plugin 32 | 3.2.2 33 | 34 | 35 | scala-compile-first 36 | process-resources 37 | 38 | compile 39 | 40 | 41 | 42 | scala-test-compile-first 43 | process-test-resources 44 | 45 | testCompile 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | src/main/resources 55 | 56 | 57 | 58 | 59 | 60 | 61 | com.intel 62 | raydp-shims-common 63 | ${project.version} 64 | compile 65 | 66 | 67 | org.apache.spark 68 | spark-sql_${scala.binary.version} 69 | ${spark330.version} 70 | provided 71 | 72 | 73 | com.google.protobuf 74 | protobuf-java 75 | 76 | 77 | 78 | 79 | org.apache.spark 80 | spark-core_${scala.binary.version} 81 | ${spark330.version} 82 | provided 83 | 84 | 85 | org.xerial.snappy 86 | snappy-java 87 | 88 | 89 | io.netty 90 | netty-handler 91 | 92 | 93 | org.apache.commons 94 | commons-text 95 | 96 | 97 | org.apache.ivy 98 | ivy 99 | 100 | 101 | 102 | 103 | org.xerial.snappy 104 | snappy-java 105 | ${snappy.version} 106 | 107 | 108 | io.netty 109 | netty-handler 110 | ${netty.version} 111 | 112 | 113 | org.apache.commons 114 | commons-text 115 | ${commons.text.version} 116 | 117 | 118 | org.apache.ivy 119 | ivy 120 | ${ivy.version} 121 | 122 | 123 | com.google.protobuf 124 | protobuf-java 125 | ${protobuf.version} 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider: -------------------------------------------------------------------------------- 1 | com.intel.raydp.shims.spark330.SparkShimProvider 2 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
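The SparkSqlUtils variants isolate the one API that keeps shifting: ArrowConverters.toDataFrame takes an SQLContext on Spark 3.2, takes the SparkSession directly on Spark 3.3, and is rebuilt by hand from the Arrow batch iterator in the spark340 shim further below. Callers stay version-agnostic by always going through the shim; in the sketch, `arrowBatches` and `schemaJson` are assumed inputs (serialized Arrow record batches and the schema rendered as JSON via StructType.json).

import com.intel.raydp.shims.SparkShimLoader
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.sql.{DataFrame, SparkSession}

def fromArrow(arrowBatches: JavaRDD[Array[Byte]],
              schemaJson: String,
              spark: SparkSession): DataFrame =
  SparkShimLoader.getSparkShims.toDataFrame(arrowBatches, schemaJson, spark)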
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.intel.raydp.shims.spark330 19 | 20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} 21 | 22 | object SparkShimProvider { 23 | val SPARK330_DESCRIPTOR = SparkShimDescriptor(3, 3, 0) 24 | val SPARK331_DESCRIPTOR = SparkShimDescriptor(3, 3, 1) 25 | val SPARK332_DESCRIPTOR = SparkShimDescriptor(3, 3, 2) 26 | val SPARK333_DESCRIPTOR = SparkShimDescriptor(3, 3, 3) 27 | val DESCRIPTOR_STRINGS = Seq(s"$SPARK330_DESCRIPTOR", s"$SPARK331_DESCRIPTOR", 28 | s"$SPARK332_DESCRIPTOR", s"$SPARK333_DESCRIPTOR") 29 | val DESCRIPTOR = SPARK332_DESCRIPTOR 30 | } 31 | 32 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { 33 | def createShim: SparkShims = { 34 | new Spark330Shims() 35 | } 36 | 37 | def matches(version: String): Boolean = { 38 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims.spark330 19 | 20 | import org.apache.spark.{SparkEnv, TaskContext} 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.executor.RayDPExecutorBackendFactory 23 | import org.apache.spark.executor.spark330._ 24 | import org.apache.spark.spark330.TaskContextUtils 25 | import org.apache.spark.sql.{DataFrame, SparkSession} 26 | import org.apache.spark.sql.spark330.SparkSqlUtils 27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims} 28 | import org.apache.arrow.vector.types.pojo.Schema 29 | import org.apache.spark.sql.types.StructType 30 | 31 | class Spark330Shims extends SparkShims { 32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR 33 | 34 | override def toDataFrame( 35 | rdd: JavaRDD[Array[Byte]], 36 | schema: String, 37 | session: SparkSession): DataFrame = { 38 | SparkSqlUtils.toDataFrame(rdd, schema, session) 39 | } 40 | 41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { 42 | new RayDPSpark330ExecutorBackendFactory() 43 | } 44 | 45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 46 | TaskContextUtils.getDummyTaskContext(partitionId, env) 47 | } 48 | 49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/org/apache/spark/TaskContextUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.spark330 19 | 20 | import java.util.Properties 21 | 22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} 23 | import org.apache.spark.memory.TaskMemoryManager 24 | 25 | object TaskContextUtils { 26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.executor 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.resource.ResourceProfile 24 | import org.apache.spark.rpc.RpcEnv 25 | 26 | class RayCoarseGrainedExecutorBackend( 27 | rpcEnv: RpcEnv, 28 | driverUrl: String, 29 | executorId: String, 30 | bindAddress: String, 31 | hostname: String, 32 | cores: Int, 33 | userClassPath: Seq[URL], 34 | env: SparkEnv, 35 | resourcesFileOpt: Option[String], 36 | resourceProfile: ResourceProfile) 37 | extends CoarseGrainedExecutorBackend( 38 | rpcEnv, 39 | driverUrl, 40 | executorId, 41 | bindAddress, 42 | hostname, 43 | cores, 44 | env, 45 | resourcesFileOpt, 46 | resourceProfile) { 47 | 48 | override def getUserClassPath: Seq[URL] = userClassPath 49 | 50 | } 51 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/org/apache/spark/executor/RayDPSpark330ExecutorBackendFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
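This RayCoarseGrainedExecutorBackend subclass exists because, unlike the Spark 3.2 backend (see the spark322 factory earlier), the base CoarseGrainedExecutorBackend here no longer receives the user class path through its constructor; the shim keeps the Seq[URL] itself and serves it via the getUserClassPath override. A small sketch of how such a class-path argument can be assembled; the jar paths are illustrative.

import java.net.URL
import java.nio.file.Paths

// Sketch: turn user jar paths (illustrative values) into the Seq[URL] the backend expects.
val userJars = Seq("/opt/app/deps.jar", "/opt/app/job.jar")
val userClassPath: Seq[URL] = userJars.map(p => Paths.get(p).toUri.toURL)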
16 | */ 17 | 18 | package org.apache.spark.executor.spark330 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.executor._ 24 | import org.apache.spark.resource.ResourceProfile 25 | import org.apache.spark.rpc.RpcEnv 26 | 27 | class RayDPSpark330ExecutorBackendFactory 28 | extends RayDPExecutorBackendFactory { 29 | override def createExecutorBackend( 30 | rpcEnv: RpcEnv, 31 | driverUrl: String, 32 | executorId: String, 33 | bindAddress: String, 34 | hostname: String, 35 | cores: Int, 36 | userClassPath: Seq[URL], 37 | env: SparkEnv, 38 | resourcesFileOpt: Option[String], 39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { 40 | new RayCoarseGrainedExecutorBackend( 41 | rpcEnv, 42 | driverUrl, 43 | executorId, 44 | bindAddress, 45 | hostname, 46 | cores, 47 | userClassPath, 48 | env, 49 | resourcesFileOpt, 50 | resourceProfile) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.sql.spark330 19 | 20 | import org.apache.arrow.vector.types.pojo.Schema 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} 23 | import org.apache.spark.sql.execution.arrow.ArrowConverters 24 | import org.apache.spark.sql.types.StructType 25 | import org.apache.spark.sql.util.ArrowUtils 26 | 27 | object SparkSqlUtils { 28 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = { 29 | ArrowConverters.toDataFrame(rdd, schema, session) 30 | } 31 | 32 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 33 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /core/shims/spark340/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-shims 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-shims-spark340 15 | RayDP Shims for Spark 3.4.0 16 | jar 17 | 18 | 19 | 2.12.15 20 | 2.13.5 21 | 22 | 23 | 24 | 25 | 26 | org.scalastyle 27 | scalastyle-maven-plugin 28 | 29 | 30 | net.alchim31.maven 31 | scala-maven-plugin 32 | 3.2.2 33 | 34 | 35 | scala-compile-first 36 | process-resources 37 | 38 | compile 39 | 40 | 41 | 42 | scala-test-compile-first 43 | process-test-resources 44 | 45 | testCompile 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | src/main/resources 55 | 56 | 57 | 58 | 59 | 60 | 61 | com.intel 62 | raydp-shims-common 63 | ${project.version} 64 | compile 65 | 66 | 67 | org.apache.spark 68 | spark-sql_${scala.binary.version} 69 | ${spark340.version} 70 | provided 71 | 72 | 73 | org.apache.spark 74 | spark-core_${scala.binary.version} 75 | ${spark340.version} 76 | provided 77 | 78 | 79 | org.xerial.snappy 80 | snappy-java 81 | 82 | 83 | io.netty 84 | netty-handler 85 | 86 | 87 | 88 | 89 | org.xerial.snappy 90 | snappy-java 91 | ${snappy.version} 92 | 93 | 94 | io.netty 95 | netty-handler 96 | ${netty.version} 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider: -------------------------------------------------------------------------------- 1 | com.intel.raydp.shims.spark340.SparkShimProvider 2 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
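The two helpers take the Spark schema in different shapes: toDataFrame wants it serialized as a JSON string, while toArrowSchema converts a StructType into an Arrow Schema for a given session time zone. A short sketch tying the two representations together; the column names and the "UTC" time zone are illustrative.

import com.intel.raydp.shims.SparkShimLoader
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

val schema      = new StructType().add("id", IntegerType).add("name", StringType)
val schemaJson  = schema.json                                // string form fed to toDataFrame
val arrowSchema = SparkShimLoader.getSparkShims
  .toArrowSchema(schema, timeZoneId = "UTC")                 // Arrow form used for record batches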
16 | */ 17 | 18 | package com.intel.raydp.shims.spark340 19 | 20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} 21 | 22 | object SparkShimProvider { 23 | val SPARK340_DESCRIPTOR = SparkShimDescriptor(3, 4, 0) 24 | val SPARK341_DESCRIPTOR = SparkShimDescriptor(3, 4, 1) 25 | val SPARK342_DESCRIPTOR = SparkShimDescriptor(3, 4, 2) 26 | val SPARK343_DESCRIPTOR = SparkShimDescriptor(3, 4, 3) 27 | val SPARK344_DESCRIPTOR = SparkShimDescriptor(3, 4, 4) 28 | val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR", 29 | s"$SPARK343_DESCRIPTOR", s"$SPARK344_DESCRIPTOR") 30 | val DESCRIPTOR = SPARK341_DESCRIPTOR 31 | } 32 | 33 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { 34 | def createShim: SparkShims = { 35 | new Spark340Shims() 36 | } 37 | 38 | def matches(version: String): Boolean = { 39 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims.spark340 19 | 20 | import org.apache.spark.{SparkEnv, TaskContext} 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.executor.RayDPExecutorBackendFactory 23 | import org.apache.spark.executor.spark340._ 24 | import org.apache.spark.spark340.TaskContextUtils 25 | import org.apache.spark.sql.{DataFrame, SparkSession} 26 | import org.apache.spark.sql.spark340.SparkSqlUtils 27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims} 28 | import org.apache.arrow.vector.types.pojo.Schema 29 | import org.apache.spark.sql.types.StructType 30 | 31 | class Spark340Shims extends SparkShims { 32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR 33 | 34 | override def toDataFrame( 35 | rdd: JavaRDD[Array[Byte]], 36 | schema: String, 37 | session: SparkSession): DataFrame = { 38 | SparkSqlUtils.toDataFrame(rdd, schema, session) 39 | } 40 | 41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { 42 | new RayDPSpark340ExecutorBackendFactory() 43 | } 44 | 45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 46 | TaskContextUtils.getDummyTaskContext(partitionId, env) 47 | } 48 | 49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/org/apache/spark/TaskContextUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.spark340 19 | 20 | import java.util.Properties 21 | 22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} 23 | import org.apache.spark.memory.TaskMemoryManager 24 | 25 | object TaskContextUtils { 26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 0, 28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.executor 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.resource.ResourceProfile 24 | import org.apache.spark.rpc.RpcEnv 25 | 26 | class RayCoarseGrainedExecutorBackend( 27 | rpcEnv: RpcEnv, 28 | driverUrl: String, 29 | executorId: String, 30 | bindAddress: String, 31 | hostname: String, 32 | cores: Int, 33 | userClassPath: Seq[URL], 34 | env: SparkEnv, 35 | resourcesFileOpt: Option[String], 36 | resourceProfile: ResourceProfile) 37 | extends CoarseGrainedExecutorBackend( 38 | rpcEnv, 39 | driverUrl, 40 | executorId, 41 | bindAddress, 42 | hostname, 43 | cores, 44 | env, 45 | resourcesFileOpt, 46 | resourceProfile) { 47 | 48 | override def getUserClassPath: Seq[URL] = userClassPath 49 | 50 | } 51 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/org/apache/spark/executor/RayDPSpark340ExecutorBackendFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.executor.spark340 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.executor._ 24 | import org.apache.spark.resource.ResourceProfile 25 | import org.apache.spark.rpc.RpcEnv 26 | 27 | class RayDPSpark340ExecutorBackendFactory 28 | extends RayDPExecutorBackendFactory { 29 | override def createExecutorBackend( 30 | rpcEnv: RpcEnv, 31 | driverUrl: String, 32 | executorId: String, 33 | bindAddress: String, 34 | hostname: String, 35 | cores: Int, 36 | userClassPath: Seq[URL], 37 | env: SparkEnv, 38 | resourcesFileOpt: Option[String], 39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { 40 | new RayCoarseGrainedExecutorBackend( 41 | rpcEnv, 42 | driverUrl, 43 | executorId, 44 | bindAddress, 45 | hostname, 46 | cores, 47 | userClassPath, 48 | env, 49 | resourcesFileOpt, 50 | resourceProfile) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.sql.spark340 19 | 20 | import org.apache.arrow.vector.types.pojo.Schema 21 | import org.apache.spark.TaskContext 22 | import org.apache.spark.api.java.JavaRDD 23 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} 24 | import org.apache.spark.sql.execution.arrow.ArrowConverters 25 | import org.apache.spark.sql.types._ 26 | import org.apache.spark.sql.util.ArrowUtils 27 | 28 | object SparkSqlUtils { 29 | def toDataFrame( 30 | arrowBatchRDD: JavaRDD[Array[Byte]], 31 | schemaString: String, 32 | session: SparkSession): DataFrame = { 33 | val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] 34 | val timeZoneId = session.sessionState.conf.sessionLocalTimeZone 35 | val rdd = arrowBatchRDD.rdd.mapPartitions { iter => 36 | val context = TaskContext.get() 37 | ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) 38 | } 39 | session.internalCreateDataFrame(rdd.setName("arrow"), schema) 40 | } 41 | 42 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 43 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /core/shims/spark350/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-shims 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-shims-spark350 15 | RayDP Shims for Spark 3.5.0 16 | jar 17 | 18 | 19 | 2.12.15 20 | 2.13.5 21 | 22 | 23 | 24 | 25 | 26 | org.scalastyle 27 | scalastyle-maven-plugin 28 | 29 | 30 | net.alchim31.maven 31 | scala-maven-plugin 32 | 3.2.2 33 | 34 | 35 | scala-compile-first 36 | process-resources 37 | 38 | compile 39 | 40 | 41 | 42 | scala-test-compile-first 43 | process-test-resources 44 | 45 | testCompile 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | src/main/resources 55 | 56 | 57 | 58 | 59 | 60 | 61 | com.intel 62 | raydp-shims-common 63 | ${project.version} 64 | compile 65 | 66 | 67 | org.apache.spark 68 | spark-sql_${scala.binary.version} 69 | ${spark350.version} 70 | provided 71 | 72 | 73 | org.apache.spark 74 | spark-core_${scala.binary.version} 75 | ${spark350.version} 76 | provided 77 | 78 | 79 | org.xerial.snappy 80 | snappy-java 81 | 82 | 83 | io.netty 84 | netty-handler 85 | 86 | 87 | 88 | 89 | org.xerial.snappy 90 | snappy-java 91 | ${snappy.version} 92 | 93 | 94 | io.netty 95 | netty-handler 96 | ${netty.version} 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider: -------------------------------------------------------------------------------- 1 | com.intel.raydp.shims.spark350.SparkShimProvider 2 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.intel.raydp.shims.spark350 19 | 20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} 21 | 22 | object SparkShimProvider { 23 | val SPARK350_DESCRIPTOR = SparkShimDescriptor(3, 5, 0) 24 | val SPARK351_DESCRIPTOR = SparkShimDescriptor(3, 5, 1) 25 | val SPARK352_DESCRIPTOR = SparkShimDescriptor(3, 5, 2) 26 | val SPARK353_DESCRIPTOR = SparkShimDescriptor(3, 5, 3) 27 | val SPARK354_DESCRIPTOR = SparkShimDescriptor(3, 5, 4) 28 | val DESCRIPTOR_STRINGS = Seq(s"$SPARK350_DESCRIPTOR", s"$SPARK351_DESCRIPTOR", s"$SPARK352_DESCRIPTOR", 29 | s"$SPARK353_DESCRIPTOR", s"$SPARK354_DESCRIPTOR") 30 | val DESCRIPTOR = SPARK350_DESCRIPTOR 31 | } 32 | 33 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { 34 | def createShim: SparkShims = { 35 | new Spark350Shims() 36 | } 37 | 38 | def matches(version: String): Boolean = { 39 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims.spark350 19 | 20 | import org.apache.spark.{SparkEnv, TaskContext} 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.executor.RayDPExecutorBackendFactory 23 | import org.apache.spark.executor.spark350._ 24 | import org.apache.spark.spark350.TaskContextUtils 25 | import org.apache.spark.sql.{DataFrame, SparkSession} 26 | import org.apache.spark.sql.spark350.SparkSqlUtils 27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims} 28 | import org.apache.arrow.vector.types.pojo.Schema 29 | import org.apache.spark.sql.types.StructType 30 | 31 | class Spark350Shims extends SparkShims { 32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR 33 | 34 | override def toDataFrame( 35 | rdd: JavaRDD[Array[Byte]], 36 | schema: String, 37 | session: SparkSession): DataFrame = { 38 | SparkSqlUtils.toDataFrame(rdd, schema, session) 39 | } 40 | 41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { 42 | new RayDPSpark350ExecutorBackendFactory() 43 | } 44 | 45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 46 | TaskContextUtils.getDummyTaskContext(partitionId, env) 47 | } 48 | 49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/org/apache/spark/TaskContextUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.spark350 19 | 20 | import java.util.Properties 21 | 22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} 23 | import org.apache.spark.memory.TaskMemoryManager 24 | 25 | object TaskContextUtils { 26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 0, 28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.executor 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.resource.ResourceProfile 24 | import org.apache.spark.rpc.RpcEnv 25 | 26 | class RayCoarseGrainedExecutorBackend( 27 | rpcEnv: RpcEnv, 28 | driverUrl: String, 29 | executorId: String, 30 | bindAddress: String, 31 | hostname: String, 32 | cores: Int, 33 | userClassPath: Seq[URL], 34 | env: SparkEnv, 35 | resourcesFileOpt: Option[String], 36 | resourceProfile: ResourceProfile) 37 | extends CoarseGrainedExecutorBackend( 38 | rpcEnv, 39 | driverUrl, 40 | executorId, 41 | bindAddress, 42 | hostname, 43 | cores, 44 | env, 45 | resourcesFileOpt, 46 | resourceProfile) { 47 | 48 | override def getUserClassPath: Seq[URL] = userClassPath 49 | 50 | } 51 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/org/apache/spark/executor/RayDPSpark350ExecutorBackendFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.executor.spark350 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.executor._ 24 | import org.apache.spark.resource.ResourceProfile 25 | import org.apache.spark.rpc.RpcEnv 26 | 27 | class RayDPSpark350ExecutorBackendFactory 28 | extends RayDPExecutorBackendFactory { 29 | override def createExecutorBackend( 30 | rpcEnv: RpcEnv, 31 | driverUrl: String, 32 | executorId: String, 33 | bindAddress: String, 34 | hostname: String, 35 | cores: Int, 36 | userClassPath: Seq[URL], 37 | env: SparkEnv, 38 | resourcesFileOpt: Option[String], 39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { 40 | new RayCoarseGrainedExecutorBackend( 41 | rpcEnv, 42 | driverUrl, 43 | executorId, 44 | bindAddress, 45 | hostname, 46 | cores, 47 | userClassPath, 48 | env, 49 | resourcesFileOpt, 50 | resourceProfile) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.spark350 19 | 20 | import org.apache.arrow.vector.types.pojo.Schema 21 | import org.apache.spark.TaskContext 22 | import org.apache.spark.api.java.JavaRDD 23 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} 24 | import org.apache.spark.sql.execution.arrow.ArrowConverters 25 | import org.apache.spark.sql.types._ 26 | import org.apache.spark.sql.util.ArrowUtils 27 | 28 | object SparkSqlUtils { 29 | def toDataFrame( 30 | arrowBatchRDD: JavaRDD[Array[Byte]], 31 | schemaString: String, 32 | session: SparkSession): DataFrame = { 33 | val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] 34 | val timeZoneId = session.sessionState.conf.sessionLocalTimeZone 35 | val rdd = arrowBatchRDD.rdd.mapPartitions { iter => 36 | val context = TaskContext.get() 37 | ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, false, context) 38 | } 39 | session.internalCreateDataFrame(rdd.setName("arrow"), schema) 40 | } 41 | 42 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 43 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /doc/mpi.md: -------------------------------------------------------------------------------- 1 | # MPI on Ray 2 | 3 | RayDP also provides a simple API for running MPI jobs on top of Ray.
Currently, we support three types of MPI: `intel_mpi`, `openmpi` and `MPICH`. To use the following API, make sure you have installed the given type of MPI on each Ray worker node. 4 | 5 | ### API 6 | 7 | ```python 8 | def create_mpi_job(job_name: str, 9 | world_size: int, 10 | num_cpus_per_process: int, 11 | num_processes_per_node: int, 12 | mpi_script_prepare_fn: Callable = None, 13 | timeout: int = 1, 14 | mpi_type: str = "intel_mpi", 15 | placement_group=None, 16 | placement_group_bundle_indexes: List[int] = None) -> MPIJob: 17 | """ Create a MPI Job 18 | 19 | :param job_name: the job name 20 | :param world_size: the world size 21 | :param num_cpus_per_process: num cpus per process, this is used to request resources from Ray 22 | :param num_processes_per_node: num processes per node 23 | :param mpi_script_prepare_fn: a function used to create the mpi script, it will be passed a 24 | MPIJobContext instance. It will use the default script if not provided. 25 | :param timeout: the timeout used to wait for job creation 26 | :param mpi_type: the mpi type, currently only openmpi, intel_mpi and mpich are supported 27 | :param placement_group: the placement_group for requesting mpi resources 28 | :param placement_group_bundle_indexes: this should be equal to 29 | world_size / num_processes_per_node if provided. 30 | """ 31 | ``` 32 | 33 | ### Create a simple MPI Job 34 | 35 | ```python 36 | from raydp.mpi import create_mpi_job, MPIJobContext, WorkerContext 37 | 38 | # Define the MPI Job. We want to create an MPIJob with a world_size of 4, and each process requires 2 cpus. 39 | # We have set num_processes_per_node to 2, so the processes will be strictly spread across two nodes. 40 | 41 | # You could also specify the placement group to reserve the resources for the MPI job. The num_cpus_per_process 42 | # will be ignored if the placement group is provided. And the size of 43 | # placement_group_bundle_indexes should be equal to world_size // num_processes_per_node. 44 | job = create_mpi_job(job_name="example", 45 | world_size=4, 46 | num_cpus_per_process=2, 47 | num_processes_per_node=2, 48 | timeout=5, 49 | mpi_type="intel_mpi", 50 | placement_group=None, 51 | placement_group_bundle_indexes=None) 52 | 53 | # Start the MPI Job. This will start up the MPI processes and connect them to the ray cluster 54 | job.start() 55 | 56 | # define the MPI task function 57 | def func(context: WorkerContext): 58 | return context.job_id 59 | 60 | # run the MPI task. This is a blocking operation, and the result is a list of world_size values. 61 | results = job.run(func) 62 | 63 | # stop the MPI job 64 | job.stop() 65 | ``` 66 | 67 | ### Use `with` to auto start/stop the MPIJob 68 | ```python 69 | with create_mpi_job(job_name="example", 70 | world_size=4, 71 | num_cpus_per_process=2, 72 | num_processes_per_node=2, 73 | timeout=5, 74 | mpi_type="intel_mpi") as job: 75 | def f(context: WorkerContext): 76 | return context.job_id 77 | results = job.run(f) 78 | ``` 79 | 80 | ### Specify the MPI script and environments 81 | 82 | You could customize the MPI job environments and MPI scripts with the `mpi_script_prepare_fn` argument.
83 | 84 | ```python 85 | def script_prepare_fn(context: MPIJobContext): 86 | context.add_env("OMP_NUM_THREADS", "2") 87 | default_script = ["mpirun", "--allow-run-as-root", "--tag-output", "-H", 88 | ",".join(context.hosts), "-N", f"{context.num_procs_per_node}"] 89 | return default_script 90 | 91 | job = create_mpi_job(job_name="example", 92 | world_size=4, 93 | num_cpus_per_process=2, 94 | num_processes_per_node=2, 95 | timeout=5, 96 | mpi_type="intel_mpi", 97 | mpi_script_prepare_fn=script_prepare_fn) 98 | ``` 99 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rayproject/ray:latest@sha256:c864e37f4ce516ff49425f69cac5503a51e84c333d30928416714a2c3da55b43 2 | 3 | ARG HTTP_PROXY 4 | ARG HTTPS_PROXY 5 | 6 | # set http_proxy & https_proxy 7 | ENV http_proxy=${HTTP_PROXY} 8 | ENV https_proxy=${HTTPS_PROXY} 9 | 10 | # install java, create workdir and install raydp 11 | # You could change the raydp to raydp-nightly if you want to try the master branch code 12 | RUN sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get update -y \ 13 | && sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get install -y openjdk-8-jdk \ 14 | && sudo mkdir /raydp \ 15 | && sudo chown -R ray /raydp \ 16 | && $HOME/anaconda3/bin/pip --no-cache-dir install raydp 17 | 18 | WORKDIR /raydp 19 | 20 | # unset http_proxy & https_proxy 21 | ENV http_proxy= 22 | ENV https_proxy= 23 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Running RayDP on k8s cluster 2 | 3 | ## Build docker image 4 | Build the docker image to use in K8S with the following command, and this will create an image tagged `oap-project/raydp:latest`. 5 | ```shell 6 | # under ${RAYDP_HOME}/docker 7 | ./build-docker.sh 8 | ``` 9 | 10 | You can install our nightly build with `pip install raydp --pre` or `pip install raydp-nightly`. To install raydp-nightly in the image, modify the following code in `Dockerfile`: 11 | ```Dockerfile 12 | RUN sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get update -y \ 13 | && sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get install -y openjdk-8-jdk \ 14 | && sudo mkdir /raydp \ 15 | && sudo chown -R ray /raydp \ 16 | && $HOME/anaconda3/bin/pip --no-cache-dir install raydp-nightly 17 | ``` 18 | 19 | Meanwhile, you should install all dependencies of your application in the `Dockerfile`. If suitable, you can change the base image to `ray-ml`: 20 | ```Dockerfile 21 | FROM rayproject/ray-ml:latest 22 | ``` 23 | 24 | Then, you can push the built image to a repository or distribute it to the k8s worker nodes. 25 | 26 | ## Deploy ray cluster with Helm 27 | You need to create a Helm chart first. To start with, check out this [example ray cluster Helm chart](https://github.com/ray-project/kuberay/tree/master/helm-chart/ray-cluster). You can clone this repo and copy this directory, then modify `values.yaml` to use the previously built image. 28 | 29 | ```yaml 30 | image: 31 | repository: oap-project/raydp 32 | tag: latest 33 | pullPolicy: IfNotPresent 34 | ``` 35 | 36 | You can also change other fields in this file to specify the number of workers, etc.
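As a sketch only, the same values can presumably be overridden on the command line when you run the `helm install` step described below, instead of editing `values.yaml`. The `image.*` keys mirror the snippet above, while `worker.replicas` is an assumed field name for the worker count; check the chart's own `values.yaml` for the exact keys of your chart version.

```shell
# Hypothetical override of chart values at install time (PATH_to_CHART is the
# chart directory referenced below). image.repository / image.tag mirror the
# values.yaml snippet above; worker.replicas is an assumed key for the number
# of worker pods.
helm install ray-cluster PATH_to_CHART \
  --set image.repository=oap-project/raydp \
  --set image.tag=latest \
  --set worker.replicas=3
```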
37 | 38 | Then, you need to deploy the KubeRay operator first; please refer to [here](https://docs.ray.io/en/latest/cluster/kubernetes/getting-started.html#kuberay-quickstart) for instructions. You can now deploy a Ray cluster with RayDP installed via `helm install ray-cluster PATH_to_CHART`. 39 | 40 | ## Access the cluster 41 | Check [here](https://docs.ray.io/en/master/cluster/kubernetes/getting-started.html#running-applications-on-a-ray-cluster) to see how to run applications on the cluster you just deployed. 42 | 43 | ## Legacy 44 | If you are using Ray versions before 2.0, you can try this command. 45 | ```shell 46 | ray up ${RAYDP_HOME}/docker/legacy.yaml 47 | ``` -------------------------------------------------------------------------------- /docker/build-docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | docker build --build-arg HTTP_PROXY=${http_proxy} \ 21 | --build-arg HTTPS_PROXY=${https_proxy} \ 22 | -t oap-project/raydp:latest . 23 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # RayDP Examples 2 | Here are a few examples showing how RayDP works together with other libraries, such as PyTorch, Tensorflow, XGBoost and Horovod. 3 | 4 | In order to run these examples, you may need to install the corresponding dependencies. For installation guides, please refer to their homepages. Notice that we need to install [xgboost_ray](https://github.com/ray-project/xgboost_ray) to run the xgboost example. In addition, if you are running the examples in a ray cluster, all nodes should have the dependencies installed. 5 | 6 | ## NYC Taxi Fare Prediction Dataset 7 | We have a few examples which use this dataset. 8 | You can run our examples right away after you clone our repo, because we include a small example dataset generated randomly using `examples/random_nyctaxi.py`. The generated dataset just demonstrates that our examples can work, but the trained models might not be meaningful. 9 | 10 | The original dataset can be downloaded [here](https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data). After you download it, please modify the variable `NYC_TRAIN_CSV` in `data_process.py` and point it to where `train.csv` is saved. 11 | 12 | ## Horovod 13 | To run the example, please install horovod via `pip install horovod[pytorch, ray]`. In addition, `HOROVOD_WITH_PYTORCH` and `HOROVOD_WITH_GLOO` should be set to `1` before running pip. Notice that macOS users need to first install `libuv` via `brew install libuv`.
Please refer to [here](https://horovod.readthedocs.io/en/stable/install_include.html) for details. 14 | 15 | When running `horovod_nyctaxi.py`, do not use `horovodrun`. Check [here](https://horovod.readthedocs.io/en/stable/ray_include.html) for more information. 16 | 17 | ## RaySGD Example 18 | In the RaySGD example, we demonstrate how to use our `MLDataset` API. After we use Spark to transform the dataset, we call `RayMLDataset.from_spark` to write the Spark DataFrames into the Ray object store, using the Apache Arrow format. We then convert the data to a `pandas` DataFrame, hopefully zero-copy. Finally, they can be consumed by any framework that supports the `numpy` format, such as PyTorch or Tensorflow. `MLDataset` is partitioned, or sharded, just like Spark DataFrames. Their numbers of partitions are not required to be the same. However, the number of shards of `MLDataset` should be the same as the number of workers of `TorchTrainer` or `TFTrainer`, so that each worker is mapped to a shard. 19 | -------------------------------------------------------------------------------- /examples/pytorch_nyctaxi.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import raydp 7 | from raydp.torch import TorchEstimator 8 | from raydp.utils import random_split 9 | 10 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV 11 | from typing import List, Dict 12 | 13 | # First, you need to init or connect to a ray cluster. 14 | # Note that you should set include_java to True. 15 | # For more config info in ray, please refer to the ray doc: 16 | # https://docs.ray.io/en/latest/package-ref.html 17 | # ray.init(address="auto") 18 | ray.init(address="local", num_cpus=4) 19 | 20 | # After initializing the ray cluster, you can use the raydp api to get a spark session 21 | app_name = "NYC Taxi Fare Prediction with RayDP" 22 | num_executors = 1 23 | cores_per_executor = 1 24 | memory_per_executor = "500M" 25 | spark = raydp.init_spark(app_name, num_executors, cores_per_executor, memory_per_executor) 26 | 27 | # Then you can code as if you are using spark 28 | # The dataset can be downloaded from: 29 | # https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data 30 | # Here we just use a subset of the training data 31 | data = spark.read.format("csv").option("header", "true") \ 32 | .option("inferSchema", "true") \ 33 | .load(NYC_TRAIN_CSV) 34 | # Set spark timezone for processing datetime 35 | spark.conf.set("spark.sql.session.timeZone", "UTC") 36 | # Transform the dataset 37 | data = nyc_taxi_preprocess(data) 38 | # Split data into train_dataset and test_dataset 39 | train_df, test_df = random_split(data, [0.9, 0.1], 0) 40 | features = [field.name for field in list(train_df.schema) if field.name != "fare_amount"] 41 | # Define a neural network model 42 | class NYC_Model(nn.Module): 43 | def __init__(self, cols): 44 | super().__init__() 45 | self.fc1 = nn.Linear(cols, 256) 46 | self.fc2 = nn.Linear(256, 128) 47 | self.fc3 = nn.Linear(128, 64) 48 | self.fc4 = nn.Linear(64, 16) 49 | self.fc5 = nn.Linear(16, 1) 50 | self.bn1 = nn.BatchNorm1d(256) 51 | self.bn2 = nn.BatchNorm1d(128) 52 | self.bn3 = nn.BatchNorm1d(64) 53 | self.bn4 = nn.BatchNorm1d(16) 54 | 55 | def forward(self, x): 56 | x = F.relu(self.fc1(x)) 57 | x = self.bn1(x) 58 | x = F.relu(self.fc2(x)) 59 | x = self.bn2(x) 60 | x = F.relu(self.fc3(x)) 61 | x = self.bn3(x) 62 | x = F.relu(self.fc4(x)) 63 | x = self.bn4(x) 64 | x =
self.fc5(x) 65 | return x 66 | 67 | nyc_model = NYC_Model(len(features)) 68 | criterion = nn.SmoothL1Loss() 69 | optimizer = torch.optim.Adam(nyc_model.parameters(), lr=0.001) 70 | # Create a distributed estimator based on the raydp api 71 | estimator = TorchEstimator(num_workers=1, model=nyc_model, optimizer=optimizer, loss=criterion, 72 | feature_columns=features, feature_types=torch.float, 73 | label_column="fare_amount", label_type=torch.float, 74 | batch_size=64, num_epochs=30, 75 | metrics_name = ["MeanAbsoluteError", "MeanSquaredError"], 76 | use_ccl=False) 77 | # Train the model 78 | estimator.fit_on_spark(train_df, test_df) 79 | # Get the trained model 80 | model = estimator.get_model() 81 | # shutdown raydp and ray 82 | raydp.stop_spark() 83 | ray.shutdown() 84 | -------------------------------------------------------------------------------- /examples/random_nyctaxi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | base_date = np.datetime64("2010-01-01 00:00:00") 8 | 9 | parser = argparse.ArgumentParser(description="Random NYC taxi Generator") 10 | parser.add_argument( 11 | "--num-records", 12 | type=int, 13 | default=2000, 14 | metavar="N", 15 | help="number of records to generate (default: 2000)") 16 | 17 | args = parser.parse_args() 18 | 19 | N = args.num_records 20 | 21 | fare_amount = np.random.uniform(3.0, 50.0, size=N) 22 | pick_long = np.random.uniform(-74.2, -73.8, size=N) 23 | pick_lat = np.random.uniform(40.7, 40.8, size=N) 24 | drop_long = np.random.uniform(-74.2, -73.8, size=N) 25 | drop_lat = np.random.uniform(40.7, 40.8, size=N) 26 | passenger_count = np.random.randint(1, 5, size=N) 27 | date = np.random.randint(0, 157680000, size=N) + base_date 28 | date = np.array([t.item().strftime("%Y-%m-%d %H:%M:%S UTC") for t in date]) 29 | key = ["fake_key"] * N 30 | df = pd.DataFrame({ 31 | "key": key, 32 | "fare_amount": fare_amount, 33 | "pickup_datetime": date, 34 | "pickup_longitude": pick_long, 35 | "pickup_latitude": pick_lat, 36 | "dropoff_longitude": drop_long, 37 | "dropoff_latitude": drop_lat, 38 | "passenger_count": passenger_count 39 | }) 40 | csv_path = os.path.dirname(os.path.realpath(__file__)) + "/fake_nyctaxi.csv" 41 | df.to_csv(csv_path, index=False) 42 | -------------------------------------------------------------------------------- /examples/raydp-submit.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname 2 | import sys 3 | import json 4 | import subprocess 5 | import ray 6 | import pyspark 7 | 8 | ray.init(address="auto") 9 | node = ray.worker.global_worker.node 10 | options = {} 11 | options["ray"] = {} 12 | options["ray"]["run-mode"] = "CLUSTER" 13 | options["ray"]["node-ip"] = node.node_ip_address 14 | options["ray"]["address"] = node.address 15 | options["ray"]["session-dir"] = node.get_session_dir_path() 16 | 17 | ray.shutdown() 18 | conf_path = dirname(__file__) + "/ray.conf" 19 | with open(conf_path, "w") as f: 20 | json.dump(options, f) 21 | command = ["bin/raydp-submit", "--ray-conf", conf_path] 22 | command += ["--conf", "spark.executor.cores=1"] 23 | command += ["--conf", "spark.executor.instances=1"] 24 | command += ["--conf", "spark.executor.memory=500m"] 25 | example_path = dirname(pyspark.__file__) 26 | # run SparkPi as example 27 | command.append(example_path + "/examples/src/main/python/pi.py") 28 | sys.exit(subprocess.run(command,
check=True).returncode) 29 | -------------------------------------------------------------------------------- /examples/tensorflow_nyctaxi.py: -------------------------------------------------------------------------------- 1 | import ray 2 | from tensorflow import keras 3 | from tensorflow.keras.callbacks import Callback 4 | 5 | import raydp 6 | from raydp.tf import TFEstimator 7 | from raydp.utils import random_split 8 | 9 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV 10 | from typing import List, Dict 11 | # First, you need to init or connect to a ray cluster. 12 | # Note that you should set include_java to True. 13 | # For more config info in ray, please refer to the ray doc: 14 | # https://docs.ray.io/en/latest/package-ref.html 15 | # ray.init(address="auto") 16 | ray.init(address="local", num_cpus=6) 17 | 18 | # After initializing the ray cluster, you can use the raydp api to get a spark session 19 | app_name = "NYC Taxi Fare Prediction with RayDP" 20 | num_executors = 1 21 | cores_per_executor = 1 22 | memory_per_executor = "500M" 23 | spark = raydp.init_spark(app_name, num_executors, cores_per_executor, memory_per_executor) 24 | 25 | # Then you can code as if you are using spark 26 | # The dataset can be downloaded from: 27 | # https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data 28 | # Here we just use a subset of the training data 29 | data = spark.read.format("csv").option("header", "true") \ 30 | .option("inferSchema", "true") \ 31 | .load(NYC_TRAIN_CSV) 32 | # Set spark timezone for processing datetime 33 | spark.conf.set("spark.sql.session.timeZone", "UTC") 34 | # Transform the dataset 35 | data = nyc_taxi_preprocess(data) 36 | data = data.cache() 37 | # Split data into train_dataset and test_dataset 38 | train_df, test_df = random_split(data, [0.9, 0.1], 0) 39 | features = [field.name for field in list(train_df.schema) if field.name != "fare_amount"] 40 | 41 | # Define the keras model 42 | model = keras.Sequential( 43 | [ 44 | keras.layers.InputLayer(input_shape=(len(features),)), 45 | keras.layers.Flatten(), 46 | keras.layers.Dense(256, activation="relu"), 47 | keras.layers.BatchNormalization(), 48 | keras.layers.Dense(128, activation="relu"), 49 | keras.layers.BatchNormalization(), 50 | keras.layers.Dense(64, activation="relu"), 51 | keras.layers.BatchNormalization(), 52 | keras.layers.Dense(32, activation="relu"), 53 | keras.layers.BatchNormalization(), 54 | keras.layers.Dense(16, activation="relu"), 55 | keras.layers.BatchNormalization(), 56 | keras.layers.Dense(1), 57 | ] 58 | ) 59 | 60 | class PrintingCallback(Callback): 61 | def handle_result(self, results: List[Dict], **info): 62 | print(results) 63 | 64 | # Define the optimizer and loss function 65 | # Then create the tensorflow estimator provided by Raydp 66 | adam = keras.optimizers.Adam(learning_rate=0.001) 67 | loss = keras.losses.MeanSquaredError() 68 | estimator = TFEstimator(num_workers=2, model=model, optimizer=adam, loss=loss, 69 | merge_feature_columns=True, metrics=["mae"], 70 | feature_columns=features, label_columns="fare_amount", 71 | batch_size=256, num_epochs=10, callbacks=[PrintingCallback()]) 72 | 73 | # Train the model 74 | estimator.fit_on_spark(train_df, test_df) 75 | # Get the model 76 | model = estimator.get_model() 77 | # shutdown raydp and ray 78 | raydp.stop_spark() 79 | ray.shutdown() 80 | -------------------------------------------------------------------------------- /examples/xgboost_ray_nyctaxi.py:
-------------------------------------------------------------------------------- 1 | import ray 2 | import numpy as np 3 | # XGBoost on ray is needed to run this example. 4 | # Please refer to https://docs.ray.io/en/latest/xgboost-ray.html to install it. 5 | from xgboost_ray import RayDMatrix, train, RayParams 6 | import raydp 7 | from raydp.utils import random_split 8 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV 9 | 10 | # connect to ray cluster 11 | # ray.init(address="auto") 12 | ray.init(address="local", num_cpus=4) 13 | # After ray.init, you can use the raydp api to get a spark session 14 | app_name = "NYC Taxi Fare Prediction with RayDP" 15 | num_executors = 1 16 | cores_per_executor = 1 17 | memory_per_executor = "500M" 18 | spark = raydp.init_spark(app_name, num_executors, cores_per_executor, memory_per_executor) 19 | data = spark.read.format("csv").option("header", "true") \ 20 | .option("inferSchema", "true") \ 21 | .load(NYC_TRAIN_CSV) 22 | # Set spark timezone for processing datetime 23 | spark.conf.set("spark.sql.session.timeZone", "UTC") 24 | # Transform the dataset 25 | data = nyc_taxi_preprocess(data) 26 | # Split data into train_dataset and test_dataset 27 | train_df, test_df = random_split(data, [0.9, 0.1], 0) 28 | # Convert spark dataframe into ray dataset 29 | train_dataset = ray.data.from_spark(train_df) 30 | test_dataset = ray.data.from_spark(test_df) 31 | # Then convert them into DMatrix used by xgboost 32 | dtrain = RayDMatrix(train_dataset, label="fare_amount") 33 | dtest = RayDMatrix(test_dataset, label="fare_amount") 34 | # Configure the XGBoost model 35 | config = { 36 | "tree_method": "hist", 37 | "eval_metric": ["logloss", "error"], 38 | } 39 | evals_result = {} 40 | # Train the model 41 | bst = train( 42 | config, 43 | dtrain, 44 | evals=[(dtest, "eval")], 45 | evals_result=evals_result, 46 | ray_params=RayParams(max_actor_restarts=1, num_actors=1, cpus_per_actor=1), 47 | num_boost_round=10) 48 | # print evaluation stats 49 | print("Final validation error: {:.4f}".format( 50 | evals_result["eval"]["error"][-1])) 51 | raydp.stop_spark() 52 | ray.shutdown() 53 | -------------------------------------------------------------------------------- /python/MANIFEST.in: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | include README.md 19 | recursive-include raydp/jars *.jar 20 | global-exclude *.py[cod] __pycache__ .DS_Store -------------------------------------------------------------------------------- /python/raydp/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from raydp.context import init_spark, stop_spark 19 | 20 | __version__ = "1.7.0.dev0" 21 | 22 | __all__ = ["init_spark", "stop_spark"] 23 | -------------------------------------------------------------------------------- /python/raydp/estimator.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from abc import ABC, abstractmethod 19 | from typing import Any, NoReturn, Optional 20 | 21 | 22 | 23 | class EstimatorInterface(ABC): 24 | """ 25 | A scikit-learn like API. 26 | """ 27 | 28 | @abstractmethod 29 | def fit(self, 30 | train_ds, 31 | evaluate_ds = None) -> NoReturn: 32 | """Train or evaluate the model. 33 | 34 | :param train_ds: the model will train on the MLDataset 35 | :param evaluate_ds: if this is provided, the model will evaluate on the MLDataset 36 | """ 37 | 38 | @abstractmethod 39 | def get_model(self) -> Any: 40 | """Get the trained model 41 | 42 | :return the model 43 | """ 44 | -------------------------------------------------------------------------------- /python/raydp/mpi/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 
5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | 19 | from typing import Callable, List 20 | 21 | from .mpi_job import MPIJob, MPIType, IntelMPIJob, OpenMPIJob, MPICHJob, MPIJobContext 22 | from .mpi_worker import WorkerContext 23 | 24 | 25 | def _get_mpi_type(mpi_type: str) -> MPIType: 26 | if mpi_type.strip().lower() == "openmpi": 27 | return MPIType.OPEN_MPI 28 | elif mpi_type.strip().lower() == "intel_mpi": 29 | return MPIType.INTEL_MPI 30 | elif mpi_type.strip().lower() == "mpich": 31 | return MPIType.MPICH 32 | else: 33 | return None 34 | 35 | 36 | def create_mpi_job(job_name: str, 37 | world_size: int, 38 | num_cpus_per_process: int, 39 | num_processes_per_node: int, 40 | mpi_script_prepare_fn: Callable = None, 41 | timeout: int = 1, 42 | mpi_type: str = "intel_mpi", 43 | placement_group=None, 44 | placement_group_bundle_indexes: List[int] = None) -> MPIJob: 45 | """Create a MPI Job 46 | 47 | :param job_name: the job name 48 | :param world_size: the world size 49 | :param num_cpus_per_process: num cpus per process, this used to request resource from Ray 50 | :param num_processes_per_node: num processes per node 51 | :param mpi_script_prepare_fn: a function used to create mpi script, it will pass in a 52 | MPIJobContext instance. It will use the default script if not provides. 53 | :param timeout: the timeout used to wait for job creation 54 | :param mpi_type: the mpi type, now only support openmpi, intel_mpi and MPICH 55 | :param placement_group: the placement_group for request mpi resources 56 | :param placement_group_bundle_indexes: this should be equal with 57 | world_size / num_processes_per_node if provides. 
58 | """ 59 | mpi_type = _get_mpi_type(mpi_type) 60 | if mpi_type == MPIType.OPEN_MPI: 61 | return OpenMPIJob(mpi_type=MPIType.OPEN_MPI, 62 | job_name=job_name, 63 | world_size=world_size, 64 | num_cpus_per_process=num_cpus_per_process, 65 | num_processes_per_node=num_processes_per_node, 66 | mpi_script_prepare_fn=mpi_script_prepare_fn, 67 | timeout=timeout, 68 | placement_group=placement_group, 69 | placement_group_bundle_indexes=placement_group_bundle_indexes) 70 | elif mpi_type == MPIType.INTEL_MPI: 71 | return IntelMPIJob(mpi_type=MPIType.INTEL_MPI, 72 | job_name=job_name, 73 | world_size=world_size, 74 | num_cpus_per_process=num_cpus_per_process, 75 | num_processes_per_node=num_processes_per_node, 76 | mpi_script_prepare_fn=mpi_script_prepare_fn, 77 | timeout=timeout, 78 | placement_group=placement_group, 79 | placement_group_bundle_indexes=placement_group_bundle_indexes) 80 | elif mpi_type == MPIType.MPICH: 81 | return MPICHJob(mpi_type=MPIType.MPICH, 82 | job_name=job_name, 83 | world_size=world_size, 84 | num_cpus_per_process=num_cpus_per_process, 85 | num_processes_per_node=num_processes_per_node, 86 | mpi_script_prepare_fn=mpi_script_prepare_fn, 87 | timeout=timeout, 88 | placement_group=placement_group, 89 | placement_group_bundle_indexes=placement_group_bundle_indexes) 90 | else: 91 | raise Exception(f"MPI type: {mpi_type} not supported now") 92 | 93 | 94 | __all__ = ["create_mpi_job", "MPIJobContext", "WorkerContext"] 95 | -------------------------------------------------------------------------------- /python/raydp/mpi/constants.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from os import path 19 | 20 | MPI_TYPE = "raydp_mpi_type" 21 | MPI_JOB_ID = "raydp_mpi_job_id" 22 | MPI_DRIVER_HOST = "raydp_mpi_driver_host" 23 | MPI_DRIVER_PORT = "raydp_mpi_driver_port" 24 | 25 | MAXIMUM_WAIT_TIME_OUT = "raydp_maximum_wait_time_out" 26 | 27 | _current_dir = path.dirname(path.realpath(__file__)) 28 | MPI_MAIN_CLASS_PATH = path.join(_current_dir, "mpi_worker.py") 29 | -------------------------------------------------------------------------------- /python/raydp/mpi/network/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. 
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import sys 19 | import os 20 | 21 | dir_path = os.path.dirname(os.path.realpath(__file__)) 22 | sys.path.append(str(dir_path)) 23 | -------------------------------------------------------------------------------- /python/raydp/mpi/network/network.proto: -------------------------------------------------------------------------------- 1 | // 2 | // Licensed to the Apache Software Foundation (ASF) under one or more 3 | // contributor license agreements. See the NOTICE file distributed with 4 | // this work for additional information regarding copyright ownership. 5 | // The ASF licenses this file to You under the Apache License, Version 2.0 6 | // (the "License"); you may not use this file except in compliance with 7 | // the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 16 | // 17 | 18 | // Syntax version 19 | syntax = "proto3"; 20 | 21 | // Driver Service definition 22 | service DriverService { 23 | // register the worker process to driver which used to tell the worker has started up 24 | rpc RegisterWorker (RegisterWorkerRequest) returns (RegisterWorkerReply); 25 | // register the worker service host and port 26 | rpc RegisterWorkerService (RegisterWorkerServiceRequest) returns (RegisterWorkerServiceReply); 27 | // register the function result 28 | rpc RegisterFuncResult (FunctionResult) returns (Empty); 29 | } 30 | 31 | // Worker Service 32 | service WorkerService { 33 | // run the given function 34 | rpc RunFunction (Function) returns (Empty); 35 | // stop the worker service 36 | rpc Stop (Empty) returns (Empty); 37 | } 38 | 39 | message RegisterWorkerRequest { 40 | // the job id 41 | string job_id = 1; 42 | // the world rank id 43 | int32 world_rank = 2; 44 | } 45 | 46 | message RegisterWorkerReply { 47 | // the all node addresses and used to determine the current node ip adddress 48 | repeated string node_addresses = 3; 49 | } 50 | 51 | message RegisterWorkerServiceRequest { 52 | // the world rank 53 | int32 world_rank = 1; 54 | // the worker service listening ip 55 | string worker_ip = 2; 56 | // the worker service listening port 57 | int32 worker_port = 3; 58 | } 59 | 60 | message RegisterWorkerServiceReply { 61 | // the ray redis address 62 | string ray_address = 1; 63 | // the ray redis password 64 | string redis_password = 2; 65 | } 66 | 67 | message Function { 68 | // the function id 69 | int32 func_id = 1; 70 | // the serialized python function 71 | bytes func = 2; 72 | } 73 | 74 | message FunctionResult { 75 | int32 world_rank = 1; 76 | // the function id 77 | int32 func_id = 2; 78 | // the function results 79 | bytes result = 3; 80 | } 81 | 82 | message Empty { 83 | } 84 | 
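The checked-in `network_pb2.py` and `network_pb2_grpc.py` modules listed in the tree are the Python stubs generated from this proto file. The repo does not show the exact command used to produce them, but a minimal sketch with `grpcio-tools` (assuming it is installed, e.g. via `pip install grpcio-tools`) would look like this:

```shell
# Assumed regeneration workflow, not a script shipped with the repo.
# Produces network_pb2.py and network_pb2_grpc.py next to network.proto.
cd python/raydp/mpi/network
python -m grpc_tools.protoc -I. \
    --python_out=. \
    --grpc_python_out=. \
    network.proto
```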
-------------------------------------------------------------------------------- /python/raydp/mpi/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import os 19 | import select 20 | import subprocess 21 | import threading 22 | import time 23 | from typing import List 24 | 25 | import grpc 26 | import netifaces 27 | 28 | 29 | class StoppableThread(threading.Thread): 30 | 31 | def __init__(self, group=None, target=None, name=None, 32 | args=(), kwargs=None, *, daemon=None): 33 | super().__init__(group, target, name, args, kwargs, daemon=daemon) 34 | self._stop_event = threading.Event() 35 | 36 | def stop(self): 37 | self._stop_event.set() 38 | 39 | def stopped(self): 40 | return self._stop_event.is_set() 41 | 42 | 43 | def run_cmd(cmd: str, env, failed_callback): 44 | # pylint: disable=R1732 45 | proc = subprocess.Popen(cmd, 46 | shell=True, 47 | stdin=subprocess.DEVNULL, 48 | stdout=subprocess.PIPE, 49 | stderr=subprocess.PIPE, 50 | env=env, 51 | start_new_session=True) 52 | 53 | def check_failed(): 54 | # check whether the process has finished 55 | while not threading.current_thread().stopped(): 56 | ret_code = proc.poll() 57 | if ret_code: 58 | failed_callback() 59 | raise Exception(f"mpirun failed: {ret_code}") 60 | 61 | if ret_code == 0: 62 | break 63 | 64 | time.sleep(1) 65 | 66 | check_thread = StoppableThread(target=check_failed) 67 | 68 | def redirect_stream(streams): 69 | while not threading.current_thread().stopped() and streams: 70 | readable, _, _ = select.select(streams, [], [], 0.5) 71 | for stream in readable: 72 | if not stream: 73 | continue 74 | line = stream.readline() 75 | if not line: 76 | streams.remove(stream) 77 | else: 78 | print(line.decode().strip("\n")) 79 | 80 | redirect_thread = StoppableThread(target=redirect_stream, args=([proc.stdout, proc.stderr],)) 81 | check_thread.start() 82 | redirect_thread.start() 83 | return proc, check_thread, redirect_thread 84 | 85 | 86 | def create_insecure_channel(address, 87 | options=None, 88 | compression=None): 89 | """Disable the http proxy when create channel""" 90 | # disable http proxy 91 | if options is not None: 92 | need_add = True 93 | for k, v in options: 94 | if k == "grpc.enable_http_proxy": 95 | need_add = False 96 | break 97 | if need_add: 98 | options = (*options, ("grpc.enable_http_proxy", 0)) 99 | else: 100 | options = (("grpc.enable_http_proxy", 0),) 101 | 102 | return grpc.insecure_channel( 103 | address, options, compression) 104 | 105 | 106 | def get_environ_value(key: str) -> str: 107 | """Get value from environ, raise exception if the key not existed""" 108 | assert key in os.environ, f"{key} should be set in the 
environ" 109 | return os.environ[key] 110 | 111 | 112 | def get_node_ip_address(node_addresses: List[str]) -> str: 113 | found = None 114 | for interface in netifaces.interfaces(): 115 | addrs = netifaces.ifaddresses(interface) 116 | addresses = addrs.get(netifaces.AF_INET, None) 117 | if not addresses: 118 | continue 119 | for inet_addr in addresses: 120 | address = inet_addr.get("addr", None) 121 | if address in node_addresses: 122 | found = address 123 | return found 124 | -------------------------------------------------------------------------------- /python/raydp/ray_cluster_resources.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from typing import Dict, List 19 | 20 | import ray 21 | import time 22 | from ray.ray_constants import MEMORY_RESOURCE_UNIT_BYTES 23 | 24 | 25 | class ClusterResources: 26 | # TODO: make this configurable 27 | refresh_interval = 0.1 28 | latest_refresh_time = time.time() - refresh_interval 29 | node_to_resources = {} 30 | item_keys_mapping = {"num_cpus": "CPU"} 31 | label_name = "__ray_spark_node_label" 32 | 33 | @classmethod 34 | def total_alive_nodes(cls): 35 | cls._refresh() 36 | return len(cls.node_to_resources) 37 | 38 | @classmethod 39 | def satisfy(cls, request: Dict[str, float]) -> List[str]: 40 | cls._refresh() 41 | satisfied = [] 42 | for host_name, resources in cls.node_to_resources.items(): 43 | if cls._compare_two_dict(resources, request): 44 | satisfied.append(resources[cls.label_name]) 45 | 46 | return satisfied 47 | 48 | @classmethod 49 | def _refresh(cls): 50 | if (time.time() - cls.latest_refresh_time) < cls.refresh_interval: 51 | return 52 | 53 | for node in ray.nodes(): 54 | if node["Alive"]: 55 | host_name = node["NodeManagerHostname"] 56 | resources = node["Resources"] 57 | for key in resources: 58 | if key.startswith("node:"): 59 | resources[cls.label_name] = key 60 | break 61 | assert cls.label_name in resources,\ 62 | f"{resources} should contain a resource likes: 'node:10.0.0.131': 1.0" 63 | cls.node_to_resources[host_name] = resources 64 | cls.latest_refresh_time = time.time() 65 | 66 | @classmethod 67 | def _compare_two_dict(cls, available: Dict[str, float], request: Dict[str, float]) -> bool: 68 | for k, v in request.items(): 69 | k = cls.item_keys_mapping.get(k, k) 70 | if k not in available: 71 | return False 72 | 73 | if k == "memory": 74 | v = int(v / MEMORY_RESOURCE_UNIT_BYTES) 75 | 76 | if available[k] < v: 77 | return False 78 | 79 | return True 80 | -------------------------------------------------------------------------------- /python/raydp/services.py: -------------------------------------------------------------------------------- 1 | # 
2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from abc import ABC, abstractmethod 19 | from typing import Any, Dict, NoReturn 20 | 21 | 22 | class Cluster(ABC): 23 | """ 24 | This is the base class for all specified cluster, such as SparkCluster, FlinkCluster. 25 | :param master_resources_requirement: The resources requirement for the master service. 26 | """ 27 | def __init__(self, master_resources_requirement): 28 | # the master node is live as same as ray driver node. And we can specify the resources 29 | # limitation for master node. So we don't count it. 30 | self._num_nodes = 0 31 | 32 | @abstractmethod 33 | def _set_up_master(self, 34 | resources: Dict[str, float], 35 | kwargs: Dict[Any, Any]): 36 | """ 37 | Subcluster should implement this to set up master node. 38 | """ 39 | 40 | def add_worker(self, 41 | resources_requirement: Dict[str, float], 42 | **kwargs: Dict[Any, Any]): 43 | """ 44 | Add one worker to the cluster. 45 | :param resources_requirement: The resource requirements for the worker service. 46 | """ 47 | try: 48 | self._set_up_worker(resources_requirement, kwargs) 49 | except: 50 | self.stop() 51 | raise 52 | 53 | @abstractmethod 54 | def _set_up_worker(self, 55 | resources: Dict[str, float], 56 | kwargs: Dict[str, str]): 57 | """ 58 | Subcluster should implement this to set up worker node. 59 | """ 60 | 61 | @abstractmethod 62 | def get_cluster_url(self) -> str: 63 | """ 64 | Return the cluster url, eg: spark://master-node:7077 65 | """ 66 | 67 | @abstractmethod 68 | def stop(self): 69 | """ 70 | Stop cluster 71 | """ 72 | 73 | 74 | class ClusterMaster(ABC): 75 | 76 | @abstractmethod 77 | def start_up(self) -> NoReturn: 78 | pass 79 | 80 | @abstractmethod 81 | def get_master_url(self) -> str: 82 | pass 83 | 84 | @abstractmethod 85 | def get_host(self) -> str: 86 | pass 87 | 88 | @abstractmethod 89 | def stop(self): 90 | pass 91 | 92 | 93 | class ClusterWorker(ABC): 94 | 95 | @abstractmethod 96 | def start_up(self) -> str: 97 | """ 98 | :return: error message, return None if succeeded 99 | """ 100 | 101 | @abstractmethod 102 | def get_host(self) -> str: 103 | pass 104 | 105 | @abstractmethod 106 | def stop(self): 107 | pass 108 | -------------------------------------------------------------------------------- /python/raydp/spark/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 
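The `Cluster` contract in services.py above is easiest to read through a toy subclass. The sketch below is illustrative only (the real implementation is `SparkCluster` in spark/ray_cluster.py); it just shows how `add_worker` delegates to `_set_up_worker` and how the abstract hooks fit together:

```python
# Illustrative-only subclass of the Cluster ABC defined in services.py above.
from typing import Any, Dict

from raydp.services import Cluster


class ToyCluster(Cluster):
    def __init__(self):
        super().__init__(master_resources_requirement=None)
        self._workers = []
        self._set_up_master(resources={}, kwargs={})

    def _set_up_master(self, resources: Dict[str, float], kwargs: Dict[Any, Any]):
        self._master_url = "toy://localhost:7077"

    def _set_up_worker(self, resources: Dict[str, float], kwargs: Dict[str, str]):
        # add_worker() calls this hook and stops the cluster if it raises.
        self._workers.append(resources)

    def get_cluster_url(self) -> str:
        return self._master_url

    def stop(self):
        self._workers.clear()


cluster = ToyCluster()
cluster.add_worker({"num_cpus": 1})
print(cluster.get_cluster_url())
```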
5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from .dataset import PartitionObjectsOwner, \ 19 | get_raydp_master_owner, \ 20 | spark_dataframe_to_ray_dataset, \ 21 | ray_dataset_to_spark_dataframe, \ 22 | from_spark_recoverable 23 | from .interfaces import SparkEstimatorInterface 24 | from .ray_cluster import SparkCluster 25 | 26 | __all__ = [ 27 | "SparkCluster", 28 | "SparkEstimatorInterface", 29 | "PartitionObjectsOwner", 30 | "get_raydp_master_owner", 31 | "spark_dataframe_to_ray_dataset", 32 | "ray_dataset_to_spark_dataframe", 33 | "from_spark_recoverable" 34 | ] 35 | -------------------------------------------------------------------------------- /python/raydp/spark/interfaces.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from typing import NoReturn 19 | from typing import Optional, Union 20 | 21 | from raydp.utils import convert_to_spark 22 | 23 | DF = Union["pyspark.sql.DataFrame", "pyspark.pandas.DataFrame"] 24 | OPTIONAL_DF = Union[Optional["pyspark.sql.DataFrame"], Optional["pyspark.pandas.DataFrame"]] 25 | 26 | 27 | class SparkEstimatorInterface: 28 | def _check_and_convert(self, df): 29 | train_df, _ = convert_to_spark(df) 30 | return train_df 31 | 32 | def fit_on_spark(self, 33 | train_df: DF, 34 | evaluate_df: OPTIONAL_DF = None) -> NoReturn: 35 | """Fit and evaluate the model on the Spark or koalas DataFrame. 36 | 37 | :param train_df the DataFrame which the model will train on. 38 | :param evaluate_df the optional DataFrame which the model evaluate on it 39 | """ 40 | -------------------------------------------------------------------------------- /python/raydp/spark/parallel_iterator_worker.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 
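The names re-exported from `raydp.spark` above are the user-facing conversion helpers. A minimal sketch of a round trip, assuming a reachable Ray cluster; the reverse direction is shown via Ray Data's public `Dataset.to_spark`, and keyword spellings may differ slightly across releases:

```python
# Sketch of converting between a Spark DataFrame and a Ray Dataset,
# using spark_dataframe_to_ray_dataset exported from raydp.spark above.
import ray
import raydp
from raydp.spark import spark_dataframe_to_ray_dataset

ray.init(address="auto")
spark = raydp.init_spark("convert-demo", num_executors=1,
                         executor_cores=1, executor_memory="500M")

df = spark.range(0, 1000)
ds = spark_dataframe_to_ray_dataset(df)   # Spark DataFrame -> Ray Dataset
df2 = ds.to_spark(spark)                  # Ray Dataset -> Spark DataFrame
print(df2.count())

raydp.stop_spark()
ray.shutdown()
```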
5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | from typing import Any 19 | 20 | from ray.util.iter import ParallelIteratorWorker 21 | 22 | 23 | class ParallelIteratorWorkerWithLen(ParallelIteratorWorker): 24 | def __init__(self, item_generator: Any, repeat: bool, num_records: int): 25 | super().__init__(item_generator, repeat) 26 | self.num_records = num_records 27 | 28 | def __len__(self): 29 | return self.num_records 30 | -------------------------------------------------------------------------------- /python/raydp/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | import logging 19 | import subprocess 20 | import time 21 | 22 | import pyspark 23 | import pytest 24 | import ray 25 | import raydp 26 | from pyspark.sql import SparkSession 27 | 28 | 29 | def quiet_logger(): 30 | py4j_logger = logging.getLogger("py4j") 31 | py4j_logger.setLevel(logging.WARNING) 32 | 33 | koalas_logger = logging.getLogger("koalas") 34 | koalas_logger.setLevel(logging.WARNING) 35 | 36 | 37 | @pytest.fixture(scope="function") 38 | def spark_session(request): 39 | spark = SparkSession.builder.master("local[2]").appName("RayDP test").getOrCreate() 40 | request.addfinalizer(lambda: spark.stop()) 41 | quiet_logger() 42 | return spark 43 | 44 | 45 | @pytest.fixture(scope="function", params=["local", "ray://localhost:10001"]) 46 | def ray_cluster(request): 47 | ray.shutdown() 48 | if request.param == "local": 49 | ray.init(address="local", num_cpus=6, include_dashboard=False) 50 | else: 51 | ray.init(address=request.param) 52 | request.addfinalizer(lambda: ray.shutdown()) 53 | 54 | 55 | @pytest.fixture(scope="function", params=["local", "ray://localhost:10001"]) 56 | def spark_on_ray_small(request): 57 | ray.shutdown() 58 | if request.param == "local": 59 | ray.init(address="local", num_cpus=6, include_dashboard=False) 60 | else: 61 | ray.init(address=request.param) 62 | node_ip = ray.util.get_node_ip_address() 63 | spark = raydp.init_spark("test", 1, 1, "500M", configs={ 64 | "spark.driver.host": node_ip, 65 | "spark.driver.bindAddress": node_ip 66 | }) 67 | 68 | def stop_all(): 69 | spark.stop() 70 | raydp.stop_spark() 71 | time.sleep(5) 72 | ray.shutdown() 73 | 74 | request.addfinalizer(stop_all) 75 | return spark 76 | 77 | 78 | @pytest.fixture(scope="function", params=["local", "ray://localhost:10001"]) 79 | def spark_on_ray_2_executors(request): 80 | ray.shutdown() 81 | if request.param == "local": 82 | ray.init(address="local", num_cpus=6, include_dashboard=False) 83 | else: 84 | ray.init(address=request.param) 85 | node_ip = ray.util.get_node_ip_address() 86 | spark = raydp.init_spark("test", 2, 1, "500M", configs={ 87 | "spark.driver.host": node_ip, 88 | "spark.driver.bindAddress": node_ip 89 | }) 90 | 91 | def stop_all(): 92 | spark.stop() 93 | raydp.stop_spark() 94 | time.sleep(5) 95 | ray.shutdown() 96 | 97 | request.addfinalizer(stop_all) 98 | return spark 99 | 100 | @pytest.fixture(scope='session') 101 | def custom_spark_dir(tmp_path_factory) -> str: 102 | working_dir = tmp_path_factory.mktemp("spark").as_posix() 103 | 104 | # Leave the if more verbose just in case the distribution name changed in the future. 105 | # Please make sure the version here is not the most recent release, so the file is available 106 | # in the archive download. Latest release's download URL (https://dlcdn.apache.org/spark/*) 107 | # will be changed to archive when the next release come out and break the test. 
108 | if pyspark.__version__ == "3.2.1": 109 | spark_distribution = 'spark-3.2.1-bin-hadoop3.2' 110 | elif pyspark.__version__ == "3.1.3": 111 | spark_distribution = 'spark-3.1.3-bin-hadoop3.2' 112 | else: 113 | raise Exception(f"Unsupported Spark version {pyspark.__version__}.") 114 | 115 | file_extension = 'tgz' 116 | spark_distribution_file = f"{working_dir}/{spark_distribution}.{file_extension}" 117 | 118 | import wget 119 | 120 | wget.download( 121 | f"https://archive.apache.org/dist/spark/spark-{pyspark.__version__}/{spark_distribution}.{file_extension}", 122 | spark_distribution_file) 123 | subprocess.check_output(['tar', 'xzvf', spark_distribution_file, '--directory', working_dir]) 124 | return f"{working_dir}/{spark_distribution}" 125 | -------------------------------------------------------------------------------- /python/raydp/tests/test_tf.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import pyspark 19 | import pytest 20 | import os 21 | import sys 22 | import shutil 23 | 24 | import tensorflow as tf 25 | import tensorflow.keras as keras 26 | 27 | from pyspark.sql.functions import rand 28 | 29 | from raydp.tf import TFEstimator 30 | from raydp.utils import random_split 31 | 32 | @pytest.mark.parametrize("use_fs_directory", [True, False]) 33 | def test_tf_estimator(spark_on_ray_small, use_fs_directory): 34 | spark = spark_on_ray_small 35 | 36 | # ---------------- data process with Spark ------------ 37 | # calculate y = 3 * x + 4 38 | df: pyspark.sql.DataFrame = spark.range(0, 100000) 39 | df = df.withColumn("x", rand() * 100) # add x column 40 | df = df.withColumn("y", df.x * 3 + rand() + 4) # add y column 41 | df = df.select(df.x, df.y) 42 | 43 | train_df, test_df = random_split(df, [0.7, 0.3]) 44 | 45 | # create model 46 | model = keras.Sequential( 47 | [ 48 | keras.layers.InputLayer(input_shape=()), 49 | # Add feature dimension, expanding (batch_size,) to (batch_size, 1). 
50 | keras.layers.Flatten(), 51 | keras.layers.Dense(1), 52 | ] 53 | ) 54 | 55 | optimizer = keras.optimizers.Adam(0.01) 56 | loss = keras.losses.MeanSquaredError() 57 | 58 | estimator = TFEstimator(num_workers=2, 59 | model=model, 60 | optimizer=optimizer, 61 | loss=loss, 62 | metrics=["accuracy", "mse"], 63 | feature_columns="x", 64 | label_columns="y", 65 | batch_size=1000, 66 | num_epochs=2, 67 | use_gpu=False) 68 | 69 | if use_fs_directory: 70 | dir = os.path.dirname(__file__) + "/test_tf" 71 | uri = "file://" + dir 72 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri) 73 | else: 74 | estimator.fit_on_spark(train_df, test_df) 75 | model = estimator.get_model() 76 | result = model(tf.constant([0, 0])) 77 | assert result.shape == (2, 1) 78 | if use_fs_directory: 79 | shutil.rmtree(dir) 80 | 81 | if __name__ == "__main__": 82 | # sys.exit(pytest.main(["-v", __file__])) 83 | import ray, raydp 84 | ray.init() 85 | spark = raydp.init_spark('a', 6, 1, '500m') 86 | test_tf_estimator(spark, False) 87 | -------------------------------------------------------------------------------- /python/raydp/tests/test_torch.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | import pytest 19 | import os 20 | import sys 21 | import shutil 22 | import torch 23 | 24 | # https://spark.apache.org/docs/latest/api/python/migration_guide/koalas_to_pyspark.html 25 | # import databricks.koalas as ks 26 | import pyspark.pandas as ps 27 | 28 | from raydp.torch import TorchEstimator 29 | from raydp.utils import random_split 30 | 31 | @pytest.mark.parametrize("use_fs_directory", [True, False]) 32 | def test_torch_estimator(spark_on_ray_small, use_fs_directory): 33 | # ---------------- data process with koalas ------------ 34 | spark = spark_on_ray_small 35 | 36 | # calculate z = 3 * x + 4 * y + 5 37 | df: ps.DataFrame = ps.range(0, 100000) 38 | df["x"] = df["id"] + 100 39 | df["y"] = df["id"] + 1000 40 | df["z"] = df["x"] * 3 + df["y"] * 4 + 5 41 | df = df.astype("float") 42 | 43 | train_df, test_df = random_split(df, [0.7, 0.3]) 44 | 45 | # ---------------- ray sgd ------------------------- 46 | # create the model 47 | class LinearModel(torch.nn.Module): 48 | def __init__(self): 49 | super(LinearModel, self).__init__() 50 | self.linear = torch.nn.Linear(2, 1) 51 | 52 | def forward(self, x): 53 | return self.linear(x) 54 | 55 | model = LinearModel() 56 | # create the optimizer 57 | optimizer = torch.optim.Adam(model.parameters()) 58 | # create the loss 59 | loss = torch.nn.MSELoss() 60 | # create lr_scheduler 61 | 62 | def lr_scheduler_creator(optimizer, config): 63 | return torch.optim.lr_scheduler.MultiStepLR( 64 | optimizer, milestones=[150, 250, 350], gamma=0.1) 65 | 66 | # create the estimator 67 | estimator = TorchEstimator(num_workers=2, 68 | model=model, 69 | optimizer=optimizer, 70 | loss=loss, 71 | lr_scheduler_creator=lr_scheduler_creator, 72 | feature_columns=["x", "y"], 73 | feature_types=torch.float, 74 | label_column="z", 75 | label_type=torch.float, 76 | batch_size=1000, 77 | num_epochs=2, 78 | use_gpu=False) 79 | 80 | # train the model 81 | if use_fs_directory: 82 | dir = os.path.dirname(__file__) + "/test_torch" 83 | uri = "file://" + dir 84 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri) 85 | else: 86 | estimator.fit_on_spark(train_df, test_df) 87 | model = estimator.get_model() 88 | result = model(torch.Tensor([[0, 0], [1, 1]])) 89 | assert result.shape == (2, 1) 90 | if use_fs_directory: 91 | shutil.rmtree(dir) 92 | 93 | 94 | if __name__ == "__main__": 95 | sys.exit(pytest.main(["-v", __file__])) 96 | -------------------------------------------------------------------------------- /python/raydp/tests/test_torch_sequential.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | import pytest 19 | import sys 20 | import torch 21 | import raydp 22 | from raydp.torch import TorchEstimator 23 | 24 | def test_torch_estimator(spark_on_ray_small): 25 | ##prepare the data 26 | customers = [ 27 | (1,'James', 21, 6), 28 | (2, "Liz", 25, 8), 29 | (3, "John", 31, 6), 30 | (4, "Jennifer", 45, 7), 31 | (5, "Robert", 41, 5), 32 | (6, "Sandra", 45, 8) 33 | ] 34 | df = spark_on_ray_small.createDataFrame(customers, ["cID", "name", "age", "grade"]) 35 | 36 | ##create model 37 | model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2,1)) 38 | optimizer = torch.optim.Adam(model.parameters()) 39 | loss = torch.nn.MSELoss() 40 | 41 | #config 42 | estimator = TorchEstimator( 43 | model = model, 44 | optimizer = optimizer, 45 | loss = loss, 46 | num_workers = 3, 47 | num_epochs = 5, 48 | feature_columns = ["age"], 49 | feature_types = torch.float, 50 | label_column = "grade", 51 | label_type = torch.float, 52 | batch_size = 1 53 | ) 54 | estimator.fit_on_spark(df) 55 | 56 | if __name__ == "__main__": 57 | sys.exit(pytest.main(["-v", __file__])) -------------------------------------------------------------------------------- /python/raydp/tests/test_xgboost.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | import os 19 | import sys 20 | import shutil 21 | import platform 22 | import pytest 23 | import pyspark 24 | import numpy as np 25 | from pyspark.sql.functions import rand 26 | 27 | from raydp.xgboost import XGBoostEstimator 28 | from raydp.utils import random_split 29 | 30 | @pytest.mark.parametrize("use_fs_directory", [True, False]) 31 | def test_xgb_estimator(spark_on_ray_small, use_fs_directory): 32 | if platform.system() == "Darwin": 33 | pytest.skip("Skip xgboost test on MacOS") 34 | spark = spark_on_ray_small 35 | 36 | # calculate z = 3 * x + 4 * y + 5 37 | df: pyspark.sql.DataFrame = spark.range(0, 100000) 38 | df = df.withColumn("x", rand() * 100) # add x column 39 | df = df.withColumn("y", rand() * 1000) # ad y column 40 | df = df.withColumn("z", df.x * 3 + df.y * 4 + rand() + 5) # ad z column 41 | df = df.select(df.x, df.y, df.z) 42 | 43 | train_df, test_df = random_split(df, [0.7, 0.3]) 44 | params = {} 45 | estimator = XGBoostEstimator(params, "z", resources_per_worker={"CPU": 1}) 46 | if use_fs_directory: 47 | dir = os.path.dirname(os.path.realpath(__file__)) + "/test_xgboost" 48 | uri = "file://" + dir 49 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri) 50 | else: 51 | estimator.fit_on_spark(train_df, test_df) 52 | print(estimator.get_model().inplace_predict(np.asarray([[1,2]]))) 53 | if use_fs_directory: 54 | shutil.rmtree(dir) 55 | 56 | if __name__ == '__main__': 57 | import ray, raydp 58 | ray.init(address="auto") 59 | spark = raydp.init_spark('test_xgboost', 1, 1, '500m') 60 | test_xgb_estimator(spark, True) -------------------------------------------------------------------------------- /python/raydp/tf/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from .estimator import TFEstimator 19 | 20 | __all__ = ["TFEstimator"] 21 | -------------------------------------------------------------------------------- /python/raydp/torch/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. 
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from .estimator import TorchEstimator 19 | 20 | __all__ = ["TorchEstimator"] 21 | -------------------------------------------------------------------------------- /python/raydp/torch/config.py: -------------------------------------------------------------------------------- 1 | from ray.train.torch.config import _TorchBackend 2 | from ray.train.torch.config import TorchConfig as RayTorchConfig 3 | from ray.train._internal.worker_group import WorkerGroup 4 | from dataclasses import dataclass 5 | import sys 6 | # The package importlib_metadata is in a different place, depending on the Python version. 7 | if sys.version_info < (3, 8): 8 | import importlib_metadata 9 | else: 10 | import importlib.metadata as importlib_metadata 11 | 12 | @dataclass 13 | class TorchConfig(RayTorchConfig): 14 | 15 | @property 16 | def backend_cls(self): 17 | return EnableCCLBackend 18 | 19 | def libs_import(): 20 | """try to import IPEX and oneCCL. 21 | """ 22 | try: 23 | import intel_extension_for_pytorch 24 | except ImportError: 25 | raise ImportError( 26 | "Please install intel_extension_for_pytorch" 27 | ) 28 | try: 29 | ccl_version = importlib_metadata.version("oneccl_bind_pt") 30 | if ccl_version >= "1.12": 31 | # pylint: disable-all 32 | import oneccl_bindings_for_pytorch 33 | else: 34 | import torch_ccl 35 | except ImportError as ccl_not_exist: 36 | raise ImportError( 37 | "Please install torch-ccl" 38 | ) from ccl_not_exist 39 | 40 | class EnableCCLBackend(_TorchBackend): 41 | 42 | def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig): 43 | for i in range(len(worker_group)): 44 | worker_group.execute_single_async(i, libs_import) 45 | super().on_start(worker_group, backend_config) 46 | -------------------------------------------------------------------------------- /python/raydp/torch/torch_metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | module = sys.modules[__name__] 3 | 4 | def try_import_torchmetrics(): 5 | """Tries importing torchmetrics and returns the module (or None). 6 | Returns: 7 | torchmetrics modules. 8 | """ 9 | try: 10 | # pylint: disable=import-outside-toplevel 11 | import torchmetrics 12 | 13 | return torchmetrics 14 | except ImportError as torchmetrics_not_exist: 15 | raise ImportError( 16 | "Could not import torchmetrics! Raydp TorchEstimator requires " 17 | "you to install torchmetrics: " 18 | "`pip install torchmetrics`." 
19 | ) from torchmetrics_not_exist 20 | 21 | class TorchMetric(): 22 | def __init__(self, metrics_name, metrics_config): 23 | torchmetrics = try_import_torchmetrics() 24 | self._metrics_name = metrics_name 25 | self._metrics_func = {} 26 | if self._metrics_name is not None: 27 | assert isinstance(metrics_name, list), "metrics_name must be a list" 28 | for metric in self._metrics_name: 29 | if isinstance(metric, torchmetrics.Metric): 30 | self._metrics_func[metric.__class__.__name__] = metric 31 | elif isinstance(metric, str) and hasattr(torchmetrics, metric): 32 | if metrics_config is not None and metrics_config[metric] is not None: 33 | self._metrics_func[metric] = getattr(torchmetrics, metric)( 34 | **metrics_config[metric]) 35 | else: 36 | self._metrics_func[metric] = getattr(torchmetrics, metric)() 37 | else: 38 | raise Exception( 39 | "Unsupported parameter, we only support list of " 40 | "torchmetrics.Metric instances or arr of torchmetrics.") 41 | 42 | def update(self, preds, targets): 43 | for metric in self._metrics_func: 44 | self._metrics_func[metric].update(preds, targets) 45 | 46 | def compute(self): 47 | epoch_res = {} 48 | for metric in self._metrics_func: 49 | epoch_res[metric] = self._metrics_func[metric].compute().item() 50 | 51 | return epoch_res 52 | 53 | def reset(self): 54 | for metric in self._metrics_func: 55 | self._metrics_func[metric].reset() 56 | -------------------------------------------------------------------------------- /python/raydp/torch/torch_ml_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
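A short, hypothetical illustration of the `TorchMetric` wrapper in torch_metrics.py above: metrics are named by their torchmetrics class, and per-metric keyword arguments come from the config dict keyed by the same name (here none are needed):

```python
# Hypothetical usage of the TorchMetric wrapper defined in torch_metrics.py above.
import torch

from raydp.torch.torch_metrics import TorchMetric

# Metric names resolve to torchmetrics classes; config entries may be None
# when the metric needs no constructor arguments.
metric = TorchMetric(
    metrics_name=["MeanSquaredError", "MeanAbsoluteError"],
    metrics_config={"MeanSquaredError": None, "MeanAbsoluteError": None})

preds = torch.tensor([2.5, 0.0, 2.0])
targets = torch.tensor([3.0, -0.5, 2.0])
metric.update(preds, targets)
print(metric.compute())  # e.g. {"MeanSquaredError": ..., "MeanAbsoluteError": ...}
metric.reset()
```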
16 | 17 | 18 | import logging 19 | import queue 20 | import threading 21 | from typing import Callable 22 | 23 | import ray 24 | from ray.util.data import MLDataset 25 | from torch.utils.data import IterableDataset 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | class TorchMLDataset(IterableDataset): 31 | def __init__(self, 32 | ds: MLDataset, 33 | collate_fn: Callable, 34 | shuffle: bool = False, 35 | shuffle_seed: int = None): 36 | super().__init__() 37 | self.ds = ds 38 | self.collate_fn = collate_fn 39 | self.shuffle = shuffle 40 | self.shuffle_seed = shuffle_seed or 1 41 | 42 | def __iter__(self): 43 | it = self.ds.gather_async(batch_ms=0, num_async=self.ds.num_shards()) 44 | it = iter(it) 45 | for pdf in it: 46 | if self.shuffle: 47 | pdf = pdf.sample(frac=1.0, random_state=self.shuffle_seed) 48 | yield self.collate_fn(pdf) 49 | 50 | def __len__(self): 51 | all_actors = [] 52 | for actor_set in self.ds.actor_sets: 53 | all_actors.extend(actor_set.actors) 54 | assert len(all_actors) > 0 55 | if "__len__" in dir(all_actors[0]): 56 | # This is a very hack method to get the length of the iterator 57 | num_records = sum([ray.get(actor.__len__.remote()) for actor in all_actors]) 58 | else: 59 | logger.warning("The MLDataset has not provide the __len__ method, we will iter all " 60 | "data to count the number of rows. This should be pretty slowly.") 61 | it = self.ds.gather_async(batch_ms=0, num_async=self.ds.num_shards()) 62 | it = iter(it) 63 | num_records = 0 64 | for pdf in it: 65 | num_records += pdf.shape[0] 66 | return num_records 67 | 68 | 69 | class PrefetchedDataLoader: 70 | def __init__(self, base_loader, max_size: int = 5): 71 | self.base_loader = base_loader 72 | self.max_size = max_size 73 | self.queue = queue.Queue(maxsize=max_size) 74 | self.fetcher = None 75 | self.fetcher_stop = threading.Event() 76 | 77 | def _setup(self): 78 | if self.fetcher is not None: 79 | self.fetcher_stop.set() 80 | if self.queue is not None and not self.queue.empty(): 81 | self.queue.get() 82 | self.queue = queue.Queue(maxsize=self.max_size) 83 | self.fetcher = None 84 | self.fetcher_stop.clear() 85 | 86 | it = iter(self.base_loader) 87 | 88 | def fetch_task(): 89 | while not self.fetcher_stop.is_set(): 90 | try: 91 | got_data = next(it) 92 | self.queue.put(got_data) 93 | except StopIteration: 94 | self.queue.put(None) 95 | break 96 | except: # pylint: disable=W0707, W0706 97 | raise 98 | self.fetcher = threading.Thread(target=fetch_task) 99 | self.fetcher.start() 100 | 101 | def __iter__(self): 102 | self._setup() 103 | while True: 104 | fetched_data = self.queue.get() 105 | if fetched_data is not None: 106 | yield fetched_data 107 | else: 108 | break 109 | 110 | def __len__(self): 111 | return len(self.base_loader) 112 | -------------------------------------------------------------------------------- /python/raydp/versions.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. 
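To make the prefetching wrapper in torch_ml_dataset.py above concrete, here is a small hypothetical example that wraps an ordinary PyTorch DataLoader built on synthetic data:

```python
# Hypothetical example: PrefetchedDataLoader keeps up to max_size batches queued
# in a background thread while the training loop consumes them.
import torch
from torch.utils.data import DataLoader, TensorDataset

from raydp.torch.torch_ml_dataset import PrefetchedDataLoader

dataset = TensorDataset(torch.randn(1000, 2), torch.randn(1000, 1))
base_loader = DataLoader(dataset, batch_size=100)

loader = PrefetchedDataLoader(base_loader, max_size=5)
for features, labels in loader:   # batches are fetched ahead by a worker thread
    pass
print(len(loader))                # delegates to len(base_loader) -> 10
```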
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import re 19 | import pyspark 20 | 21 | 22 | # log4j1 if spark version <= 3.2, otherwise, log4j2 23 | SPARK_LOG4J_VERSION = "log4j" 24 | SPARK_LOG4J_CONFIG_FILE_NAME_KEY = "log4j.configurationFile" 25 | SPARK_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j-default.properties" 26 | _spark_ver = re.search("\\d+\\.\\d+", pyspark.version.__version__) 27 | if _spark_ver.group(0) > "3.2": 28 | SPARK_LOG4J_VERSION = "log4j2" 29 | SPARK_LOG4J_CONFIG_FILE_NAME_KEY = "log4j2.configurationFile" 30 | SPARK_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j2-default.properties" 31 | 32 | # support ray >= 2.1, they all use log4j2 33 | RAY_LOG4J_VERSION = "log4j2" 34 | RAY_LOG4J_CONFIG_FILE_NAME_KEY = "log4j2.configurationFile" 35 | RAY_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j2.xml" 36 | -------------------------------------------------------------------------------- /python/raydp/xgboost/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from .estimator import XGBoostEstimator 19 | 20 | __all__ = ["XGBoostEstimator"] 21 | --------------------------------------------------------------------------------
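The version switch in versions.py above lets callers build the right `-Dlog4j...` JVM flag for whichever Log4j generation the installed PySpark ships with. A hypothetical illustration of consuming those constants (the way the flag is assembled here is only an example, not the project's actual launch code):

```python
# Hypothetical illustration of consuming the constants from raydp/versions.py.
from raydp.versions import (SPARK_LOG4J_CONFIG_FILE_NAME_DEFAULT,
                            SPARK_LOG4J_CONFIG_FILE_NAME_KEY,
                            SPARK_LOG4J_VERSION)

# With PySpark <= 3.2 this prints "-Dlog4j.configurationFile=log4j-default.properties",
# with newer PySpark "-Dlog4j2.configurationFile=log4j2-default.properties".
jvm_option = f"-D{SPARK_LOG4J_CONFIG_FILE_NAME_KEY}={SPARK_LOG4J_CONFIG_FILE_NAME_DEFAULT}"
print(SPARK_LOG4J_VERSION, jvm_option)
```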