├── .github
└── workflows
│ ├── pypi.yml
│ ├── ray_nightly_test.yml
│ ├── raydp.yml
│ └── raydp_nightly.yml
├── .gitignore
├── LICENSE
├── README.md
├── SECURITY.md
├── bin
└── raydp-submit
├── build.sh
├── core
├── agent
│ ├── pom.xml
│ └── src
│ │ └── main
│ │ └── java
│ │ └── org
│ │ ├── apache
│ │ └── spark
│ │ │ └── raydp
│ │ │ └── Agent.java
│ │ └── slf4j
│ │ └── impl
│ │ └── StaticLoggerBinder.java
├── javastyle-suppressions.xml
├── javastyle.xml
├── pom.xml
├── raydp-main
│ ├── pom.xml
│ └── src
│ │ └── main
│ │ ├── java
│ │ └── org
│ │ │ └── apache
│ │ │ └── spark
│ │ │ ├── deploy
│ │ │ └── raydp
│ │ │ │ ├── ExternalShuffleServiceUtils.java
│ │ │ │ └── RayAppMasterUtils.java
│ │ │ └── raydp
│ │ │ ├── RayDPUtils.java
│ │ │ ├── RayExecutorUtils.java
│ │ │ └── SparkOnRayConfigs.java
│ │ ├── resources
│ │ └── META-INF
│ │ │ └── services
│ │ │ └── org.apache.spark.scheduler.ExternalClusterManager
│ │ ├── scala
│ │ └── org
│ │ │ └── apache
│ │ │ └── spark
│ │ │ ├── RayDPException.scala
│ │ │ ├── deploy
│ │ │ ├── SparkSubmit.scala
│ │ │ └── raydp
│ │ │ │ ├── AppMasterEntryPoint.scala
│ │ │ │ ├── AppMasterJavaBridge.scala
│ │ │ │ ├── ApplicationDescription.scala
│ │ │ │ ├── ApplicationInfo.scala
│ │ │ │ ├── ApplicationState.scala
│ │ │ │ ├── Messages.scala
│ │ │ │ ├── RayAppMaster.scala
│ │ │ │ ├── RayDPDriverAgent.scala
│ │ │ │ └── RayExternalShuffleService.scala
│ │ │ ├── executor
│ │ │ └── RayDPExecutor.scala
│ │ │ ├── rdd
│ │ │ ├── RayDatasetRDD.scala
│ │ │ └── RayObjectRefRDD.scala
│ │ │ ├── scheduler
│ │ │ └── cluster
│ │ │ │ └── raydp
│ │ │ │ ├── RayClusterManager.scala
│ │ │ │ └── RayCoarseGrainedSchedulerBackend.scala
│ │ │ ├── sql
│ │ │ └── raydp
│ │ │ │ ├── ObjectStoreReader.scala
│ │ │ │ └── ObjectStoreWriter.scala
│ │ │ └── util
│ │ │ └── DependencyUtils.scala
│ │ └── test
│ │ └── org
│ │ └── apache
│ │ └── spark
│ │ └── scheduler
│ │ └── cluster
│ │ └── raydp
│ │ └── TestRayCoarseGrainedSchedulerBackend.java
├── scalastyle.xml
└── shims
│ ├── common
│ ├── pom.xml
│ └── src
│ │ └── main
│ │ └── scala
│ │ ├── com
│ │ └── intel
│ │ │ └── raydp
│ │ │ └── shims
│ │ │ ├── SparkShimLoader.scala
│ │ │ ├── SparkShimProvider.scala
│ │ │ └── SparkShims.scala
│ │ └── org
│ │ └── apache
│ │ └── spark
│ │ └── executor
│ │ └── RayDPExecutorBackendFactory.scala
│ ├── pom.xml
│ ├── spark322
│ ├── pom.xml
│ └── src
│ │ └── main
│ │ ├── resources
│ │ └── META-INF
│ │ │ └── services
│ │ │ └── com.intel.raydp.shims.SparkShimProvider
│ │ └── scala
│ │ ├── com
│ │ └── intel
│ │ │ └── raydp
│ │ │ └── shims
│ │ │ ├── SparkShimProvider.scala
│ │ │ └── SparkShims.scala
│ │ └── org
│ │ └── apache
│ │ └── spark
│ │ ├── TaskContextUtils.scala
│ │ ├── executor
│ │ └── RayDPSpark322ExecutorBackendFactory.scala
│ │ └── sql
│ │ └── SparkSqlUtils.scala
│ ├── spark330
│ ├── pom.xml
│ └── src
│ │ └── main
│ │ ├── resources
│ │ └── META-INF
│ │ │ └── services
│ │ │ └── com.intel.raydp.shims.SparkShimProvider
│ │ └── scala
│ │ ├── com
│ │ └── intel
│ │ │ └── raydp
│ │ │ └── shims
│ │ │ ├── SparkShimProvider.scala
│ │ │ └── SparkShims.scala
│ │ └── org
│ │ └── apache
│ │ └── spark
│ │ ├── TaskContextUtils.scala
│ │ ├── executor
│ │ ├── RayCoarseGrainedExecutorBackend.scala
│ │ └── RayDPSpark330ExecutorBackendFactory.scala
│ │ └── sql
│ │ └── SparkSqlUtils.scala
│ ├── spark340
│ ├── pom.xml
│ └── src
│ │ └── main
│ │ ├── resources
│ │ └── META-INF
│ │ │ └── services
│ │ │ └── com.intel.raydp.shims.SparkShimProvider
│ │ └── scala
│ │ ├── com
│ │ └── intel
│ │ │ └── raydp
│ │ │ └── shims
│ │ │ ├── SparkShimProvider.scala
│ │ │ └── SparkShims.scala
│ │ └── org
│ │ └── apache
│ │ └── spark
│ │ ├── TaskContextUtils.scala
│ │ ├── executor
│ │ ├── RayCoarseGrainedExecutorBackend.scala
│ │ └── RayDPSpark340ExecutorBackendFactory.scala
│ │ └── sql
│ │ └── SparkSqlUtils.scala
│ └── spark350
│ ├── pom.xml
│ └── src
│ └── main
│ ├── resources
│ └── META-INF
│ │ └── services
│ │ └── com.intel.raydp.shims.SparkShimProvider
│ └── scala
│ ├── com
│ └── intel
│ │ └── raydp
│ │ └── shims
│ │ ├── SparkShimProvider.scala
│ │ └── SparkShims.scala
│ └── org
│ └── apache
│ └── spark
│ ├── TaskContextUtils.scala
│ ├── executor
│ ├── RayCoarseGrainedExecutorBackend.scala
│ └── RayDPSpark350ExecutorBackendFactory.scala
│ └── sql
│ └── SparkSqlUtils.scala
├── doc
├── mpi.md
└── spark_on_ray.md
├── docker
├── Dockerfile
├── README.md
├── build-docker.sh
└── legacy.yaml
├── examples
├── README.md
├── data_process.py
├── fake_nyctaxi.csv
├── horovod_nyctaxi.py
├── pytorch_dlrm.ipynb
├── pytorch_nyctaxi.py
├── random_nyctaxi.py
├── raydp-submit.py
├── raytrain_nyctaxi.py
├── tensorflow_nyctaxi.py
├── tensorflow_titanic.ipynb
└── xgboost_ray_nyctaxi.py
├── python
├── MANIFEST.in
├── pylintrc
├── raydp
│ ├── __init__.py
│ ├── context.py
│ ├── estimator.py
│ ├── mpi
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── mpi_job.py
│ │ ├── mpi_worker.py
│ │ ├── network
│ │ │ ├── __init__.py
│ │ │ ├── network.proto
│ │ │ ├── network_pb2.py
│ │ │ └── network_pb2_grpc.py
│ │ └── utils.py
│ ├── ray_cluster_resources.py
│ ├── services.py
│ ├── spark
│ │ ├── __init__.py
│ │ ├── dataset.py
│ │ ├── interfaces.py
│ │ ├── parallel_iterator_worker.py
│ │ ├── ray_cluster.py
│ │ └── ray_cluster_master.py
│ ├── tests
│ │ ├── conftest.py
│ │ ├── test_data_owner_transfer.py
│ │ ├── test_mpi.py
│ │ ├── test_spark_cluster.py
│ │ ├── test_spark_utils.py
│ │ ├── test_tf.py
│ │ ├── test_torch.py
│ │ ├── test_torch_sequential.py
│ │ └── test_xgboost.py
│ ├── tf
│ │ ├── __init__.py
│ │ └── estimator.py
│ ├── torch
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── estimator.py
│ │ ├── torch_metrics.py
│ │ └── torch_ml_dataset.py
│ ├── utils.py
│ ├── versions.py
│ └── xgboost
│ │ ├── __init__.py
│ │ └── estimator.py
└── setup.py
└── tutorials
├── dataset
└── healthcare-dataset-stroke-data.csv
├── pytorch_example.ipynb
└── raytrain_example.ipynb
/.github/workflows/pypi.yml:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | name: RayDP PyPi
19 |
20 | on:
21 | schedule:
22 | - cron: '0 0 * * *'
23 | # can manually trigger the workflow
24 | workflow_dispatch:
25 |
26 | permissions: # added using https://github.com/step-security/secure-repo
27 | contents: read
28 |
29 | jobs:
30 | build-and-publish:
31 | # do not run in forks
32 | if: ${{ github.repository_owner == 'oap-project' }}
33 | name: build wheel and upload
34 | runs-on: ubuntu-latest
35 | steps:
36 | - uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master
37 | - name: Set up Python 3.9
38 | uses: actions/setup-python@0f07f7f756721ebd886c2462646a35f78a8bc4de # v1.2.4
39 | with:
40 | python-version: 3.9
41 | - name: Set up JDK 1.8
42 | uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4
43 | with:
44 | java-version: 1.8
45 | - name: days since the commit date
46 | run: |
47 | :
48 | timestamp=$(git log --no-walk --date=unix --format=%cd $GITHUB_SHA)
49 | days=$(( ( $(date --utc +%s) - $timestamp ) / 86400 ))
50 | if [ $days -eq 0 ]; then
51 | echo COMMIT_TODAY=true >> $GITHUB_ENV
52 | fi
53 | - name: Build wheel
54 | if: env.COMMIT_TODAY == 'true'
55 | env:
56 | RAYDP_BUILD_MODE: nightly
57 | run: pip install wheel grpcio-tools && ./build.sh
58 | - name: Upload
59 | if: env.COMMIT_TODAY == 'true'
60 | uses: pypa/gh-action-pypi-publish@e53eb8b103ffcb59469888563dc324e3c8ba6f06 # release/v1
61 | with:
62 | password: ${{ secrets.PYPI_API_TOKEN }}
63 |
--------------------------------------------------------------------------------
/.github/workflows/raydp.yml:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | name: RayDP CI
19 |
20 | on:
21 | push:
22 | branches: [main, master]
23 | pull_request:
24 | branches: [main, master]
25 | workflow_dispatch:
26 |
27 | permissions: # added using https://github.com/step-security/secure-repo
28 | contents: read
29 |
30 | jobs:
31 | build-and-test:
32 | strategy:
33 | matrix:
34 | os: [ubuntu-latest]
35 | python-version: [3.9, 3.10.14]
36 | spark-version: [3.2.4, 3.3.2, 3.4.0, 3.5.0]
37 | ray-version: [2.34.0, 2.40.0]
38 |
39 | runs-on: ${{ matrix.os }}
40 |
41 | steps:
42 | - uses: actions/checkout@ee0669bd1cc54295c223e0bb666b733df41de1c5 # v2.7.0
43 | - name: Set up Python ${{ matrix.python-version }}
44 | uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4
45 | with:
46 | python-version: ${{ matrix.python-version }}
47 | - name: Set up JDK 1.8
48 | uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4
49 | with:
50 | java-version: 1.8
51 | - name: Install extra dependencies for macOS
52 | if: matrix.os == 'macos-latest'
53 | run: |
54 | brew install pkg-config
55 | brew install libuv libomp mpich
56 | - name: Install extra dependencies for Ubuntu
57 | if: matrix.os == 'ubuntu-latest'
58 | run: |
59 | sudo apt-get install -y mpich
60 | - name: Cache pip - Ubuntu
61 | if: matrix.os == 'ubuntu-latest'
62 | uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2
63 | with:
64 | path: ~/.cache/pip
65 | key: ${{ matrix.os }}-${{ matrix.python-version }}-pip
66 | - name: Cache pip - MacOS
67 | if: matrix.os == 'macos-latest'
68 | uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2
69 | with:
70 | path: ~/Library/Caches/pip
71 | key: ${{ matrix.os }}-${{ matrix.python-version }}-pip
72 | - name: Install dependencies
73 | run: |
74 | python -m pip install --upgrade pip
75 | pip install wheel
76 | pip install "numpy<1.24"
77 | pip install "pydantic<2.0"
78 | SUBVERSION=$(python -c 'import sys; print(sys.version_info[1])')
79 | if [ "$(uname -s)" == "Linux" ]
80 | then
81 | pip install torch --index-url https://download.pytorch.org/whl/cpu
82 | else
83 | pip install torch
84 | fi
85 | pip install pyarrow "ray[train]==${{ matrix.ray-version }}" tqdm pytest tensorflow==2.13.1 tabulate grpcio-tools wget
86 | pip install "xgboost_ray[default]<=0.1.13"
87 | pip install "xgboost<=2.0.3"
88 | pip install torchmetrics
89 | - name: Cache Maven
90 | uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2
91 | with:
92 | path: ~/.m2
93 | key: ${{ matrix.os }}-m2-${{ hashFiles('core/pom.xml') }}
94 | - name: Build and install
95 | env:
96 | GITHUB_CI: 1
97 | run: |
98 | pip install pyspark==${{ matrix.spark-version }}
99 | ./build.sh
100 | pip install dist/raydp-*.whl
101 | - name: Lint
102 | run: |
103 | pip install pylint==2.8.3
104 | pylint --rcfile=python/pylintrc python/raydp
105 | pylint --rcfile=python/pylintrc examples/*.py
106 | - name: Test with pytest
107 | run: |
108 | ray start --head --num-cpus 6
109 | PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python pytest python/raydp/tests/ -v
110 | ray stop --force
111 | - name: Test Examples
112 | run: |
113 | ray start --head
114 | python examples/raydp-submit.py
115 | ray stop
116 | python examples/pytorch_nyctaxi.py
117 | python examples/tensorflow_nyctaxi.py
118 | python examples/xgboost_ray_nyctaxi.py
119 | # python examples/raytrain_nyctaxi.py
120 | python examples/data_process.py
121 |
--------------------------------------------------------------------------------
/.github/workflows/raydp_nightly.yml:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | name: Legacy raydp_nightly PyPi
19 |
20 | on:
21 | schedule:
22 | - cron: '0 0 * * *'
23 | # can manually trigger the workflow
24 | workflow_dispatch:
25 |
26 | permissions: # added using https://github.com/step-security/secure-repo
27 | contents: read
28 |
29 | jobs:
30 | build-and-publish:
31 | # do not run in forks
32 | if: ${{ github.repository_owner == 'oap-project' }}
33 | name: build wheel and upload
34 | runs-on: ubuntu-latest
35 | steps:
36 | - uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master
37 | - name: Set up Python 3.9
38 | uses: actions/setup-python@0f07f7f756721ebd886c2462646a35f78a8bc4de # v1.2.4
39 | with:
40 | python-version: 3.9
41 | - name: Set up JDK 1.8
42 | uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4
43 | with:
44 | java-version: 1.8
45 | - name: days since the commit date
46 | run: |
47 | :
48 | timestamp=$(git log --no-walk --date=unix --format=%cd $GITHUB_SHA)
49 | days=$(( ( $(date --utc +%s) - $timestamp ) / 86400 ))
50 | if [ $days -eq 0 ]; then
51 | echo COMMIT_TODAY=true >> $GITHUB_ENV
52 | fi
53 | - name: Build wheel
54 | if: env.COMMIT_TODAY == 'true'
55 | env:
56 | RAYDP_PACKAGE_NAME: raydp_nightly
57 | run: pip install wheel grpcio-tools && ./build.sh
58 | - name: Upload
59 | if: env.COMMIT_TODAY == 'true'
60 | uses: pypa/gh-action-pypi-publish@e53eb8b103ffcb59469888563dc324e3c8ba6f06 # release/v1
61 | with:
62 | password: ${{ secrets.PYPI_API_TOKEN }}
63 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.iml
3 |
4 | __pycache__/
5 | build/
6 | dist/
7 | *.egg-info/
8 | *.eggs/
9 |
10 | dev/.tmp_dir/
11 | target/
12 | *.jar
13 |
14 | .DS_Store
15 |
16 | .vscode
17 | examples/.ipynb_checkpoints/
18 | .python-version
19 |
20 | # Vim temp files
21 | *.swp
22 | *.swo
23 | *.parquet
24 | *.crc
25 | _SUCCESS
26 |
27 | .metals/
28 | .bloop/
29 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 |
3 | ## Report a Vulnerability
4 |
5 | Please report security issues or vulnerabilities to the [Intel® Security Center].
6 |
7 | For more information on how Intel® works to resolve security issues, see
8 | [Vulnerability Handling Guidelines].
9 |
10 | [Intel® Security Center]:https://www.intel.com/security
11 |
12 | [Vulnerability Handling Guidelines]:https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html
13 |
--------------------------------------------------------------------------------
/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #
4 | # Licensed to the Apache Software Foundation (ASF) under one or more
5 | # contributor license agreements. See the NOTICE file distributed with
6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0
8 | # (the "License"); you may not use this file except in compliance with
9 | # the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | #
19 |
20 | set -ex
21 |
22 | if ! command -v mvn &> /dev/null
23 | then
24 | echo "mvn could not be found, please install maven first"
25 |     exit 1
26 | else
27 | mvn_path=`which mvn`
28 | echo "Using ${mvn_path} for build core module"
29 | fi
30 |
31 | CURRENT_DIR="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
32 | DIST_PATH=${CURRENT_DIR}/dist/
33 |
34 | if [[ ! -d ${DIST_PATH} ]];
35 | then
36 | mkdir ${DIST_PATH}
37 | fi
38 |
39 | # build core part
40 | CORE_DIR="${CURRENT_DIR}/core"
41 | pushd ${CORE_DIR}
42 | if [[ -z $GITHUB_CI ]];
43 | then
44 | mvn clean package -q -DskipTests
45 | else
46 | mvn verify -q
47 | fi
48 | popd # core dir
49 |
50 | # build python part
51 | RAYDP_PACKAGE_NAME=${RAYDP_PACKAGE_NAME:-raydp}
52 | PYTHON_DIR="${CURRENT_DIR}/python"
53 |
54 | if [[ -d "${PYTHON_DIR}/build" ]];
55 | then
56 | rm -rf "${PYTHON_DIR}/build"
57 | fi
58 |
59 | pushd ${PYTHON_DIR}
60 | python setup.py bdist_wheel
61 | cp ${PYTHON_DIR}/dist/${RAYDP_PACKAGE_NAME}-* ${DIST_PATH}
62 | popd # python dir
63 |
64 | set +ex
65 |
--------------------------------------------------------------------------------
/core/agent/pom.xml:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>

  <parent>
    <groupId>com.intel</groupId>
    <artifactId>raydp-parent</artifactId>
    <version>1.7.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
  </parent>

  <artifactId>raydp-agent</artifactId>
  <name>RayDP Java Agent</name>
  <packaging>jar</packaging>

  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <version>3.8.0</version>
        <configuration>
          <source>1.8</source>
          <target>1.8</target>
        </configuration>
      </plugin>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-jar-plugin</artifactId>
        <version>3.3.0</version>
        <configuration>
          <archive>
            <!-- manifest entry name assumed: a java agent jar needs Premain-Class -->
            <manifestEntries>
              <Premain-Class>org.apache.spark.raydp.Agent</Premain-Class>
            </manifestEntries>
          </archive>
        </configuration>
      </plugin>
    </plugins>
  </build>

  <dependencies>
    <dependency>
      <groupId>org.apache.logging.log4j</groupId>
      <artifactId>log4j-core</artifactId>
      <version>2.17.1</version>
    </dependency>
    <dependency>
      <groupId>org.apache.logging.log4j</groupId>
      <artifactId>log4j-slf4j-impl</artifactId>
      <version>2.17.1</version>
    </dependency>
    <dependency>
      <groupId>org.slf4j</groupId>
      <artifactId>slf4j-api</artifactId>
      <version>1.7.32</version>
    </dependency>
    <dependency>
      <groupId>com.intel</groupId>
      <artifactId>raydp</artifactId>
      <version>${project.version}</version>
    </dependency>
  </dependencies>
</project>
--------------------------------------------------------------------------------
/core/agent/src/main/java/org/apache/spark/raydp/Agent.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.raydp;
19 |
20 | import org.slf4j.LoggerFactory;
21 |
22 | import java.io.File;
23 | import java.io.FileOutputStream;
24 | import java.io.IOException;
25 | import java.io.OutputStream;
26 | import java.io.OutputStreamWriter;
27 | import java.io.PrintStream;
28 | import java.io.Writer;
29 | import java.lang.instrument.Instrumentation;
30 | import java.lang.management.ManagementFactory;
31 | import java.nio.charset.Charset;
32 | import java.nio.charset.StandardCharsets;
33 |
34 |
35 | public class Agent {
36 |
37 | public static final PrintStream DEFAULT_ERR_PS = System.err;
38 |
39 | public static final PrintStream DEFAULT_OUT_PS = System.out;
40 |
41 | public static void premain(String agentArgs, Instrumentation inst)
42 | throws IOException {
43 | // Redirect the system output/error streams so that annoying SLF4J warnings
44 | // and other logs emitted while binding the SLF4J factory
45 | // don't show up in spark-shell.
46 | // Instead, the warnings and logs are kept in
47 | // <ray.logging.dir>/logs/slf4j-<pid>.log
48 |
49 | String pid = ManagementFactory.getRuntimeMXBean().getName()
50 | .split("@")[0];
51 | String logDir = System.getProperty("ray.logging.dir");
52 | if (logDir == null) {
53 | logDir = "/tmp/ray/session_latest/logs";
54 | System.getProperties().put("ray.logging.dir", logDir);
55 | }
56 |
57 | File parentDir = new File(logDir);
58 | if (!parentDir.exists()) {
59 | boolean flag = parentDir.mkdirs();
60 | if (!flag) {
61 | throw new RuntimeException("Error create log dir.");
62 | }
63 | }
64 |
65 | File logFile = new File(parentDir, "/slf4j-" + pid + ".log");
66 | try (PrintStream ps = new PrintStream(logFile, "UTF-8")) {
67 | System.setOut(ps);
68 | System.setErr(ps);
69 | // slf4j binding
70 | LoggerFactory.getLogger(Agent.class);
71 | } catch (Exception e) {
72 | e.printStackTrace();
73 | } finally {
74 | System.out.flush();
75 | System.err.flush();
76 | // restore system output/error stream
77 | System.setErr(DEFAULT_ERR_PS);
78 | System.setOut(DEFAULT_OUT_PS);
79 | }
80 | // The block below writes ':job_id:<job id>' as the first line of a log file whose name
81 | // is prefixed with 'java-worker', as required by https://github.com/ray-project/ray/pull/31772.
82 | // It is a workaround for a ray 2.3.[0-1] issue that is going to be fixed by https://github.com/ray-project/ray/pull/33665.
83 | String jobId = System.getenv("RAY_JOB_ID");
84 | String rayAddress = System.getProperty("ray.address");
85 | if (jobId != null && rayAddress != null) {
86 | String prefix = "java-worker";
87 | // TODO: uncomment after the ray PR #33665 released
88 | // String prefix = System.getProperty("ray.logging.file-prefix", "java-worker");
89 | // if ("java-worker".equals(prefix)) {
90 | File file = new File(new String((logDir + "/" + prefix + "-" + jobId + "-" + pid + ".log")
91 | .getBytes(Charset.forName("UTF-8")), "UTF-8"));
92 | try (OutputStream out = new FileOutputStream(file);
93 | Writer writer = new OutputStreamWriter(out, StandardCharsets.UTF_8)) {
94 | writer.write(":job_id:" + jobId + "\n");
95 | }
96 | // }
97 | }
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/core/agent/src/main/java/org/slf4j/impl/StaticLoggerBinder.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.slf4j.impl;
19 |
20 | import org.apache.spark.raydp.Agent;
21 | import org.apache.spark.raydp.SparkOnRayConfigs;
22 | import org.slf4j.ILoggerFactory;
23 | import org.slf4j.spi.LoggerFactoryBinder;
24 |
25 | import java.io.PrintStream;
26 | import java.util.HashMap;
27 | import java.util.Map;
28 |
29 | /**
30 | * A delegation class to bind to slf4j so that we can have a chance to choose
31 | * which underlying log4j framework to use.
32 | */
33 | public class StaticLoggerBinder implements LoggerFactoryBinder {
34 |
35 | // for compatibility check
36 | public static final String REQUESTED_API_VERSION;
37 |
38 | private static final Class<?> LOGFACTORY_CLASS;
39 |
40 | private static final ILoggerFactory FACTORY;
41 |
42 | private static final StaticLoggerBinder _INSTANCE = new StaticLoggerBinder();
43 |
44 | private static final Map<String, String> LOG_FACTORY_CLASSES
45 | = new HashMap<>();
46 |
47 | private static PrintStream subSystemErr;
48 |
49 | private static PrintStream subSystemOut;
50 |
51 | static {
52 | subSystemErr = System.err;
53 | subSystemOut = System.out;
54 |
55 | LOG_FACTORY_CLASSES.put("log4j",
56 | "org.slf4j.impl.Log4jLoggerFactory"); // log4j 1
57 | LOG_FACTORY_CLASSES.put("log4j2",
58 | "org.apache.logging.slf4j.Log4jLoggerFactory"); // log4j 2
59 |
60 | String factoryClzStr = System
61 | .getProperty(SparkOnRayConfigs.LOG4J_FACTORY_CLASS_KEY, "");
62 | if (factoryClzStr.length() == 0) {
63 | System.err.println("ERROR: system property '"
64 | + SparkOnRayConfigs.LOG4J_FACTORY_CLASS_KEY
65 | + "' needs to be specified for slf4j binding");
66 | LOGFACTORY_CLASS = null;
67 | FACTORY = null;
68 | } else {
69 | String mappedClsStr = LOG_FACTORY_CLASSES.get(factoryClzStr);
70 | if (mappedClsStr == null) {
71 | mappedClsStr = factoryClzStr;
72 | }
73 | // restore to system default stream so that log4j console appender
74 | // can be correctly set
75 | System.setErr(Agent.DEFAULT_ERR_PS);
76 | System.setOut(Agent.DEFAULT_OUT_PS);
77 | Class<?> tempClass = null;
78 | try {
79 | tempClass = Class.forName(mappedClsStr);
80 | } catch (Exception e) {
81 | e.printStackTrace();
82 | } finally {
83 | LOGFACTORY_CLASS = tempClass;
84 | }
85 | StringBuilder sb = new StringBuilder();
86 | sb.append("mapped factory class: ").append(mappedClsStr)
87 | .append(". load ");
88 | if (LOGFACTORY_CLASS != null) {
89 | sb.append(LOGFACTORY_CLASS.getName());
90 | try {
91 | String loc = LOGFACTORY_CLASS.getProtectionDomain().getCodeSource()
92 | .getLocation().toURI().toString();
93 | sb.append(" from ").append(loc);
94 | } catch (Exception e) {
95 | e.printStackTrace();
96 | }
97 | } else {
98 | sb.append("failed");
99 | }
100 |
101 | ILoggerFactory tmpFactory = null;
102 | try {
103 | tmpFactory = (ILoggerFactory) tempClass.newInstance();
104 | } catch (Exception e) {
105 | e.printStackTrace();
106 | } finally {
107 | FACTORY = tmpFactory;
108 | }
109 | // set to substitute stream for capturing remaining logs
110 | System.setErr(subSystemErr);
111 | System.setOut(subSystemOut);
112 | System.out.println(sb);
113 | }
114 | REQUESTED_API_VERSION = "1.6.66";
115 | }
116 |
117 | public static final StaticLoggerBinder getSingleton() {
118 | return _INSTANCE;
119 | }
120 |
121 | @Override
122 | public ILoggerFactory getLoggerFactory() {
123 | // restore to system default stream so that log4j console appender
124 | // can be correctly set
125 | if (System.out != Agent.DEFAULT_OUT_PS) {
126 | System.setOut(Agent.DEFAULT_OUT_PS);
127 | }
128 | if (System.err != Agent.DEFAULT_ERR_PS) {
129 | System.setErr(Agent.DEFAULT_ERR_PS);
130 | }
131 | return FACTORY;
132 | }
133 |
134 | @Override
135 | public String getLoggerFactoryClassStr() {
136 | return LOGFACTORY_CLASS.getName();
137 | }
138 | }
139 |
--------------------------------------------------------------------------------
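
The binding above is selected purely through a JVM system property that StaticLoggerBinder reads in its static initializer; in RayDP it is normally passed as a -D option when the actor's JVM is created (see AppMasterJavaBridge.scala later in this excerpt). A minimal, hypothetical Java sketch of the same selection, assuming the property is set before the first logger is requested (the literal key string lives in SparkOnRayConfigs and is not shown here):

    // Choose the log4j 2 binding; "log4j2" maps to
    // org.apache.logging.slf4j.Log4jLoggerFactory in LOG_FACTORY_CLASSES above.
    System.setProperty(SparkOnRayConfigs.LOG4J_FACTORY_CLASS_KEY, "log4j2");
    org.slf4j.Logger log = org.slf4j.LoggerFactory.getLogger("raydp");
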
/core/javastyle-suppressions.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/ExternalShuffleServiceUtils.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp;
19 |
20 | import java.util.List;
21 |
22 | import io.ray.api.ActorHandle;
23 | import io.ray.api.Ray;
24 |
25 | public class ExternalShuffleServiceUtils {
26 | public static ActorHandle<RayExternalShuffleService> createShuffleService(
27 | String node, List<String> options) {
28 | return Ray.actor(RayExternalShuffleService::new)
29 | .setResource("node:" + node, 0.01)
30 | .setJvmOptions(options).remote();
31 | }
32 |
33 | public static void startShuffleService(
34 | ActorHandle<RayExternalShuffleService> handle) {
35 | handle.task(RayExternalShuffleService::start).remote();
36 | }
37 |
38 | public static void stopShuffleService(
39 | ActorHandle<RayExternalShuffleService> handle) {
40 | handle.task(RayExternalShuffleService::stop).remote();
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
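
For context, a hedged sketch of how these helpers are meant to be called from the driver side; `nodeIp` is illustrative, and the JVM options would normally come from RayExternalShuffleService.getShuffleConf (shown later in this excerpt):

    // Create a shuffle-service actor pinned to a node, start it, and stop it on shutdown.
    java.util.List<String> opts =
        java.util.Arrays.asList("-Dspark.shuffle.service.port=7337");
    ActorHandle<RayExternalShuffleService> svc =
        ExternalShuffleServiceUtils.createShuffleService(nodeIp, opts);
    ExternalShuffleServiceUtils.startShuffleService(svc);
    // ... later, during shutdown:
    ExternalShuffleServiceUtils.stopShuffleService(svc);
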
/core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/RayAppMasterUtils.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp;
19 |
20 | import java.util.List;
21 | import java.util.Map;
22 |
23 | import io.ray.api.ActorHandle;
24 | import io.ray.api.Ray;
25 | import io.ray.api.call.ActorCreator;
26 | import org.apache.spark.raydp.SparkOnRayConfigs;
27 |
28 | public class RayAppMasterUtils {
29 | public static ActorHandle<RayAppMaster> createAppMaster(
30 | String cp,
31 | String name,
32 | List<String> jvmOptions,
33 | Map<String, Double> appMasterResource) {
34 | ActorCreator<RayAppMaster> creator = Ray.actor(RayAppMaster::new, cp);
35 | if (name != null) {
36 | creator.setName(name);
37 | }
38 | jvmOptions.add("-cp");
39 | jvmOptions.add(cp);
40 | creator.setJvmOptions(jvmOptions);
41 | for (Map.Entry<String, Double> resource : appMasterResource.entrySet()) {
42 | String resourceName = resource.getKey()
43 | .substring(SparkOnRayConfigs.SPARK_MASTER_ACTOR_RESOURCE_PREFIX.length() + 1);
44 | creator.setResource(resourceName, resource.getValue());
45 | }
46 |
47 | return creator.remote();
48 | }
49 |
50 | public static String getMasterUrl(
51 | ActorHandle<RayAppMaster> handle) {
52 | return handle.task(RayAppMaster::getMasterUrl).remote().get();
53 | }
54 |
55 | public static Map getRestartedExecutors(
56 | ActorHandle<RayAppMaster> handle) {
57 | return handle.task(RayAppMaster::getRestartedExecutors).remote().get();
58 | }
59 |
60 | public static void stopAppMaster(
61 | ActorHandle<RayAppMaster> handle) {
62 | handle.task(RayAppMaster::stop).remote().get();
63 | handle.kill();
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/java/org/apache/spark/raydp/RayDPUtils.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.raydp;
19 |
20 | import io.ray.api.ObjectRef;
21 | import io.ray.api.Ray;
22 | import io.ray.api.id.ObjectId;
23 | import io.ray.runtime.AbstractRayRuntime;
24 | import io.ray.runtime.object.ObjectRefImpl;
25 |
26 | public class RayDPUtils {
27 |
28 | /**
29 | * Convert ObjectRef to subclass ObjectRefImpl. Throw RuntimeException if it is not instance
30 | * of ObjectRefImpl. We can't import the ObjectRefImpl in scala code, so we do the
31 | * conversion at here.
32 | */
33 | public static <T> ObjectRefImpl<T> convert(ObjectRef<T> obj) {
34 | if (obj instanceof ObjectRefImpl) {
35 | return (ObjectRefImpl<T>) obj;
36 | } else {
37 | throw new RuntimeException(obj.getClass() + " is not ObjectRefImpl");
38 | }
39 | }
40 |
41 | /**
42 | * Create ObjectRef from Array[Byte] and register ownership.
43 | * We can't import the ObjectRefImpl in Scala code, so we do the conversion here.
44 | */
45 | public static <T> ObjectRef<T> readBinary(byte[] obj, Class<T> clazz, byte[] ownerAddress) {
46 | ObjectId id = new ObjectId(obj);
47 | ObjectRefImpl<T> ref = new ObjectRefImpl<>(id, clazz, false);
48 | AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
49 | runtime.getObjectStore().registerOwnershipInfoAndResolveFuture(
50 | id, null, ownerAddress
51 | );
52 | return ref;
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
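
For context, a hypothetical usage sketch of the two helpers above; `idBytes` and `ownerAddressBytes` are assumed to have been produced elsewhere (for example by the Scala-side ObjectStoreWriter), and the variable names are illustrative:

    // Rebuild an ObjectRef from a serialized ObjectId plus its owner's address,
    // then fetch the value through Ray's object store.
    io.ray.api.ObjectRef<byte[]> ref =
        RayDPUtils.readBinary(idBytes, byte[].class, ownerAddressBytes);
    byte[] payload = ref.get();

    // Downcast to the concrete ObjectRefImpl when the implementation type is needed;
    // convert throws a RuntimeException for any other ObjectRef implementation.
    io.ray.runtime.object.ObjectRefImpl<byte[]> impl = RayDPUtils.convert(ref);
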
/core/raydp-main/src/main/java/org/apache/spark/raydp/RayExecutorUtils.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.raydp;
19 |
20 | import io.ray.api.ActorHandle;
21 | import io.ray.api.ObjectRef;
22 | import io.ray.api.Ray;
23 | import io.ray.api.call.ActorCreator;
24 | import java.util.Map;
25 | import java.util.List;
26 |
27 | import io.ray.api.placementgroup.PlacementGroup;
28 | import io.ray.runtime.object.ObjectRefImpl;
29 | import org.apache.spark.executor.RayDPExecutor;
30 |
31 | public class RayExecutorUtils {
32 | /**
33 | * Convert from MB to Ray memory units. The memory unit in Ray is bytes.
34 | */
35 |
36 | private static double toMemoryUnits(int memoryInMB) {
37 | double result = 1.0 * memoryInMB * 1024 * 1024;
38 | return Math.round(result);
39 | }
40 |
41 | public static ActorHandle<RayDPExecutor> createExecutorActor(
42 | String executorId,
43 | String appMasterURL,
44 | double cores,
45 | int memoryInMB,
46 | Map<String, Double> resources,
47 | PlacementGroup placementGroup,
48 | int bundleIndex,
49 | List<String> javaOpts) {
50 | ActorCreator<RayDPExecutor> creator = Ray.actor(
51 | RayDPExecutor::new, executorId, appMasterURL);
52 | creator.setName("raydp-executor-" + executorId);
53 | creator.setJvmOptions(javaOpts);
54 | creator.setResource("CPU", cores);
55 | creator.setResource("memory", toMemoryUnits(memoryInMB));
56 |
57 | for (Map.Entry<String, Double> entry : resources.entrySet()) {
58 | creator.setResource(entry.getKey(), entry.getValue());
59 | }
60 | if (placementGroup != null) {
61 | creator.setPlacementGroup(placementGroup, bundleIndex);
62 | }
63 | creator.setMaxRestarts(-1);
64 | creator.setMaxTaskRetries(-1);
65 | creator.setMaxConcurrency(2);
66 | return creator.remote();
67 | }
68 |
69 | public static void setUpExecutor(
70 | ActorHandle<RayDPExecutor> handler,
71 | String appId,
72 | String driverUrl,
73 | int cores,
74 | String classPathEntries) {
75 | handler.task(RayDPExecutor::startUp,
76 | appId, driverUrl, cores, classPathEntries).remote();
77 | }
78 |
79 | public static String[] getBlockLocations(
80 | ActorHandle<RayDPExecutor> handler,
81 | int rddId,
82 | int numPartitions) {
83 | return handler.task(RayDPExecutor::getBlockLocations,
84 | rddId, numPartitions).remote().get();
85 | }
86 |
87 | public static ObjectRef<byte[]> getRDDPartition(
88 | ActorHandle<RayDPExecutor> handle,
89 | int rddId,
90 | int partitionId,
91 | String schema,
92 | String driverAgentUrl) {
93 | return (ObjectRefImpl<byte[]>) handle.task(
94 | RayDPExecutor::getRDDPartition,
95 | rddId, partitionId, schema, driverAgentUrl).remote();
96 | }
97 |
98 | public static void exitExecutor(
99 | ActorHandle<RayDPExecutor> handle
100 | ) {
101 | handle.task(RayDPExecutor::stop).remote();
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
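
For context, a hedged sketch of the driver-side call sequence these helpers expose; the argument values are illustrative, and in the real flow they come from RayCoarseGrainedSchedulerBackend, which is not part of this excerpt:

    // Create one executor actor with 2 CPUs and 4 GB of memory, no placement group,
    // then ask it to connect back to the driver.
    ActorHandle<RayDPExecutor> executor = RayExecutorUtils.createExecutorActor(
        "0",                              // executorId
        appMasterUrl,                     // URL returned by RayAppMasterUtils.getMasterUrl
        2.0,                              // Ray CPU resource for the actor
        4096,                             // executor memory in MB
        java.util.Collections.emptyMap(), // extra custom resources
        null,                             // placement group (optional)
        -1,                               // bundle index (ignored when no placement group)
        java.util.Arrays.asList("-Dspark.executor.cores=2"));
    RayExecutorUtils.setUpExecutor(executor, appId, driverUrl, 2, classPathEntries);
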
/core/raydp-main/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager:
--------------------------------------------------------------------------------
1 | org.apache.spark.scheduler.cluster.raydp.RayClusterManager
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/RayDPException.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark
19 |
20 | class RayDPException(message: String, cause: Throwable)
21 | extends SparkException(message, cause) {
22 | def this(message: String) = this(message, null)
23 | }
24 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/AppMasterEntryPoint.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import java.io.{DataOutputStream, File, FileOutputStream}
21 | import java.net.InetAddress
22 | import java.nio.file.Files
23 |
24 | import scala.util.Try
25 |
26 | import py4j.GatewayServer
27 |
28 | import org.apache.spark.internal.Logging
29 |
30 |
31 | class AppMasterEntryPoint {
32 | private val appMaster: AppMasterJavaBridge = new AppMasterJavaBridge()
33 |
34 | def getAppMasterBridge(): AppMasterJavaBridge = {
35 | appMaster
36 | }
37 | }
38 |
39 | object AppMasterEntryPoint extends Logging {
40 | private val localhost = InetAddress.getLoopbackAddress()
41 |
42 | def getGatewayServer(): GatewayServer = {
43 | new GatewayServer.GatewayServerBuilder()
44 | .javaPort(0)
45 | .javaAddress(localhost)
46 | .entryPoint(new AppMasterEntryPoint())
47 | .build()
48 | }
49 |
50 | def main(args: Array[String]): Unit = {
51 |
52 | var server = getGatewayServer()
53 |
54 | while(true) {
55 | if (!Try(server.start()).isFailure) {
56 | val boundPort: Int = server.getListeningPort()
57 | if (boundPort == -1) {
58 | logError(s"${server.getClass} failed to bind; exiting")
59 | System.exit(1)
60 | } else {
61 | logDebug(s"Started PythonGatewayServer on port $boundPort")
62 | }
63 |
64 |
65 | val connectionInfoPath = new File(sys.env("_RAYDP_APPMASTER_CONN_INFO_PATH"))
66 | val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(),
67 | "connection", ".info").toFile()
68 |
69 | val dos = new DataOutputStream(new FileOutputStream(tmpPath))
70 | dos.writeInt(boundPort)
71 | dos.close()
72 |
73 | if (!tmpPath.renameTo(connectionInfoPath)) {
74 | logError(s"Unable to write connection information to $connectionInfoPath.")
75 | System.exit(1)
76 | }
77 |
78 | // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
79 | while (System.in.read() != -1) {
80 | // Do nothing
81 | }
82 | logDebug("Exiting due to broken pipe from Python driver")
83 | System.exit(0)
84 | } else {
85 | server.shutdown()
86 | logError(s"${server.getClass} failed to bind; retrying...")
87 | Thread.sleep(1000)
88 | server = getGatewayServer()
89 | }
90 | }
91 |
92 |
93 |
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/AppMasterJavaBridge.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import java.util.Map
21 |
22 | import scala.collection.JavaConverters._
23 |
24 | import io.ray.api.{ActorHandle, Ray}
25 |
26 | import org.apache.spark.raydp.SparkOnRayConfigs
27 |
28 | class AppMasterJavaBridge {
29 | private var handle: ActorHandle[RayAppMaster] = null
30 |
31 | def startUpAppMaster(extra_cp: String, sparkProps: Map[String, Any]): Unit = {
32 | if (handle == null) {
33 | // init ray, we should set the config by java properties
34 | Ray.init()
35 | val name = RayAppMaster.ACTOR_NAME
36 | val sparkJvmOptions = sparkProps.asScala.toMap.filter(
37 | e => !SparkOnRayConfigs.SPARK_DRIVER_EXTRA_JAVA_OPTIONS.equals(e._1))
38 | .map {
39 | case (k, v) =>
40 | if (!SparkOnRayConfigs.SPARK_JAVAAGENT.equals(k)) {
41 | "-D" + k + "=" + v
42 | } else {
43 | "-javaagent:" + v
44 | }
45 | }.toBuffer
46 |
47 | val appMasterResources = sparkProps.asScala.filter {
48 | case (k, v) => k.startsWith(SparkOnRayConfigs.SPARK_MASTER_ACTOR_RESOURCE_PREFIX)
49 | }.map{ case (k, v) => k->double2Double(v.toString.toDouble) }.asJava
50 |
51 | handle = RayAppMasterUtils.createAppMaster(
52 | extra_cp, name,
53 | (sparkJvmOptions ++ Seq(SparkOnRayConfigs.RAYDP_LOGFILE_PREFIX_CFG)).asJava,
54 | appMasterResources)
55 | }
56 | }
57 |
58 | def getMasterUrl(): String = {
59 | if (handle == null) {
60 | throw new RuntimeException("You should create the RayAppMaster handle first")
61 | }
62 | RayAppMasterUtils.getMasterUrl(handle)
63 | }
64 |
65 | def stop(): Unit = {
66 | if (handle != null) {
67 | RayAppMasterUtils.stopAppMaster(handle)
68 | Ray.shutdown()
69 | handle = null
70 | }
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationDescription.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import scala.collection.Map
21 |
22 | private[spark] case class Command(
23 | driverUrl: String,
24 | environment: Map[String, String],
25 | classPathEntries: Seq[String],
26 | libraryPathEntries: Seq[String],
27 | javaOpts: Seq[String]) {
28 |
29 | def withNewJavaOpts(newJavaOptions: Seq[String]): Command = {
30 | Command(driverUrl, environment, classPathEntries, libraryPathEntries, newJavaOptions)
31 | }
32 | }
33 |
34 | private[spark] case class ApplicationDescription(
35 | name: String,
36 | numExecutors: Int,
37 | coresPerExecutor: Option[Int],
38 | memoryPerExecutorMB: Int,
39 | rayActorCPU: Double,
40 | command: Command,
41 | user: String = System.getProperty("user.name", ""),
42 | resourceReqsPerExecutor: Map[String, Double] = Map.empty) {
43 |
44 | def withNewCommand(newCommand: Command): ApplicationDescription = {
45 | ApplicationDescription(name = name,
46 | numExecutors = numExecutors, coresPerExecutor = coresPerExecutor,
47 | memoryPerExecutorMB = memoryPerExecutorMB, command = newCommand, user = user,
48 | resourceReqsPerExecutor = resourceReqsPerExecutor,
49 | rayActorCPU = rayActorCPU)
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationState.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | object ApplicationState extends Enumeration {
21 |
22 | type ApplicationState = Value
23 |
24 | val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value
25 | }
26 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/Messages.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import org.apache.spark.rpc.RpcEndpointRef
21 |
22 | private[deploy] sealed trait RayDPDeployMessage extends Serializable
23 |
24 | case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef)
25 | extends RayDPDeployMessage
26 |
27 | case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends RayDPDeployMessage
28 |
29 | case class UnregisterApplication(appId: String) extends RayDPDeployMessage
30 |
31 | case class RegisterExecutor(executorId: String, nodeIp: String) extends RayDPDeployMessage
32 |
33 | case class ExecutorStarted(executorId: String) extends RayDPDeployMessage
34 |
35 | case class RequestExecutors(appId: String, requestedTotal: Int) extends RayDPDeployMessage
36 |
37 | case class KillExecutors(appId: String, executorIds: Seq[String]) extends RayDPDeployMessage
38 |
39 | case class RequestAddPendingRestartedExecutor(executorId: String)
40 | extends RayDPDeployMessage
41 |
42 | case class AddPendingRestartedExecutorReply(newExecutorId: Option[String])
43 | extends RayDPDeployMessage
44 |
45 | case class RecacheRDD(rddId: Int) extends RayDPDeployMessage
46 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayDPDriverAgent.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import io.ray.runtime.config.RayConfig
21 |
22 | import org.apache.spark.{SecurityManager, SparkConf, SparkContext}
23 | import org.apache.spark.internal.Logging
24 | import org.apache.spark.rpc._
25 |
26 |
27 | class RayDPDriverAgent() {
28 | private val spark = SparkContext.getOrCreate()
29 | private var endpoint: RpcEndpointRef = _
30 | private var rpcEnv: RpcEnv = _
31 | private val conf: SparkConf = new SparkConf()
32 |
33 | init
34 |
35 | def init(): Unit = {
36 | val securityMgr = new SecurityManager(conf)
37 | val host = RayConfig.create().nodeIp
38 | rpcEnv = RpcEnv.create(
39 | RayAppMaster.ENV_NAME,
40 | host,
41 | host,
42 | 0,
43 | conf,
44 | securityMgr,
45 | // limit to single-thread
46 | numUsableCores = 1,
47 | clientMode = false)
48 | // register endpoint
49 | endpoint = rpcEnv.setupEndpoint(RayDPDriverAgent.ENDPOINT_NAME,
50 | new RayDPDriverAgentEndpoint(rpcEnv))
51 | }
52 |
53 | def getDriverAgentEndpointUrl(): String = {
54 | RpcEndpointAddress(rpcEnv.address, RayDPDriverAgent.ENDPOINT_NAME).toString
55 | }
56 |
57 | class RayDPDriverAgentEndpoint(override val rpcEnv: RpcEnv)
58 | extends ThreadSafeRpcEndpoint with Logging {
59 | override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
60 | case RecacheRDD(rddId) =>
61 | // TODO if multiple blocks get lost, should call this only once
62 | // SparkEnv.get.blockManagerMaster.getLocationsAndStatus()
63 | spark.getPersistentRDDs.map {
64 | case (id, rdd) =>
65 | if (id == rddId) {
66 | rdd.count
67 | }
68 | }
69 | context.reply(true)
70 | }
71 | }
72 |
73 | }
74 |
75 | object RayDPDriverAgent {
76 | val ENDPOINT_NAME = "RAYDP_DRIVER_AGENT"
77 | }
78 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayExternalShuffleService.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import io.ray.api.Ray;
21 |
22 | import org.apache.spark.{SecurityManager, SparkConf}
23 | import org.apache.spark.deploy.ExternalShuffleService
24 | import org.apache.spark.internal.Logging
25 |
26 | class RayExternalShuffleService() extends Logging {
27 | val conf = new SparkConf()
28 | val mgr = new SecurityManager(conf)
29 | val instance = new ExternalShuffleService(conf, mgr)
30 |
31 | def start(): Unit = {
32 | instance.start()
33 | }
34 |
35 | def stop(): Unit = {
36 | instance.stop()
37 | Ray.exitActor()
38 | }
39 | }
40 |
41 | object RayExternalShuffleService {
42 | def getShuffleConf(conf: SparkConf): Array[String] = {
43 | // all conf needed by external shuffle service
44 | var shuffleConf = conf.getAll.filter {
45 | case (k, v) => k.startsWith("spark.shuffle")
46 | }.map {
47 | case (k, v) =>
48 | "-D" + k + "=" + v
49 | }
50 | val localDirKey = "spark.local.dir"
51 | if (conf.contains(localDirKey)) {
52 | shuffleConf = shuffleConf :+
53 | "-D" + localDirKey + "=" + conf.get(localDirKey)
54 | }
55 | shuffleConf
56 | }
57 | }
58 |
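A small illustration of what getShuffleConf returns; the property values below are made up, but the shape (one -D JVM option per spark.shuffle* setting, plus spark.local.dir when present) follows directly from the code above:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.shuffle.service.port", "7337")
      .set("spark.local.dir", "/tmp/raydp")
    // Array("-Dspark.shuffle.service.port=7337", "-Dspark.local.dir=/tmp/raydp")
    val jvmOpts = RayExternalShuffleService.getShuffleConf(conf)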
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.rdd
19 |
20 | import java.util.List
21 |
22 | import scala.collection.JavaConverters._
23 |
24 | import io.ray.runtime.generated.Common.Address
25 |
26 | import org.apache.spark.{Partition, SparkContext, TaskContext}
27 | import org.apache.spark.api.java.JavaSparkContext
28 | import org.apache.spark.raydp.RayDPUtils
29 | import org.apache.spark.sql.raydp.ObjectStoreReader
30 |
31 | private[spark] class RayDatasetRDDPartition(val ref: Array[Byte], idx: Int) extends Partition {
32 | val index = idx
33 | }
34 |
35 | private[spark]
36 | class RayDatasetRDD(
37 | jsc: JavaSparkContext,
38 | @transient val objectIds: List[Array[Byte]],
39 | locations: List[Array[Byte]])
40 | extends RDD[Array[Byte]](jsc.sc, Nil) {
41 |
42 | override def getPartitions: Array[Partition] = {
43 | objectIds.asScala.zipWithIndex.map { case (k, i) =>
44 | new RayDatasetRDDPartition(k, i).asInstanceOf[Partition]
45 | }.toArray
46 | }
47 |
48 | override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
49 | val ref = split.asInstanceOf[RayDatasetRDDPartition].ref
50 | ObjectStoreReader.getBatchesFromStream(ref, locations.get(split.index))
51 | }
52 |
53 | override def getPreferredLocations(split: Partition): Seq[String] = {
54 | val address = Address.parseFrom(locations.get(split.index))
55 | Seq(address.getIpAddress())
56 | }
57 | }
58 |
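A rough usage sketch, assuming refs and owners are parallel java.util.Lists prepared on the Ray side (one serialized ObjectRef holding an Arrow stream plus its owner's address per batch), and that schemaJson is the DataFrame schema string expected by ObjectStoreReader.RayDatasetToDataFrame:

    // One Ray object becomes one Spark partition; preferred locality is the
    // IP address of the Ray node that owns the object.
    val rdd = new RayDatasetRDD(jsc, refs, owners)
    val df = ObjectStoreReader.RayDatasetToDataFrame(spark, rdd, schemaJson)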
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayObjectRefRDD.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.rdd
19 |
20 | import java.util.List
21 |
22 | import scala.collection.JavaConverters._
23 |
24 | import io.ray.runtime.generated.Common.Address
25 |
26 | import org.apache.spark.{Partition, SparkContext, TaskContext}
27 | import org.apache.spark.raydp.RayDPUtils
28 | import org.apache.spark.sql.Row
29 |
30 | private[spark] class RayObjectRefRDDPartition(idx: Int) extends Partition {
31 | val index = idx
32 | }
33 |
34 | private[spark]
35 | class RayObjectRefRDD(
36 | sc: SparkContext,
37 | locations: List[Array[Byte]])
38 | extends RDD[Row](sc, Nil) {
39 |
40 | override def getPartitions: Array[Partition] = {
41 | (0 until locations.size()).map { i =>
42 | new RayObjectRefRDDPartition(i).asInstanceOf[Partition]
43 | }.toArray
44 | }
45 |
46 | override def compute(split: Partition, context: TaskContext): Iterator[Row] = {
47 | (Row(split.index) :: Nil).iterator
48 | }
49 |
50 | override def getPreferredLocations(split: Partition): Seq[String] = {
51 | Seq(Address.parseFrom(locations.get(split.index)).getIpAddress())
52 | }
53 | }
54 |
55 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayClusterManager.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.scheduler.cluster.raydp
19 |
20 | import org.apache.spark.SparkContext
21 | import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
22 |
23 | private[spark] class RayClusterManager extends ExternalClusterManager {
24 |
25 | override def canCreate(masterURL: String): Boolean = {
26 | masterURL.startsWith("ray")
27 | }
28 |
29 | override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = {
30 | new TaskSchedulerImpl(sc)
31 | }
32 |
33 | override def createSchedulerBackend(
34 | sc: SparkContext,
35 | masterURL: String,
36 | scheduler: TaskScheduler): SchedulerBackend = {
37 | new RayCoarseGrainedSchedulerBackend(
38 | sc,
39 | scheduler.asInstanceOf[TaskSchedulerImpl],
40 | masterURL)
41 | }
42 |
43 | override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = {
44 | scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend)
45 | }
46 | }
47 |
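For context, a minimal sketch of how this cluster manager gets selected: Spark discovers it through its META-INF/services registration of ExternalClusterManager, and canCreate above accepts any master URL beginning with "ray". The exact URL below is illustrative only; in practice RayDP's Python entry point composes the master string:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("ray://127.0.0.1:10001")   // any "ray..." master routes to RayClusterManager
      .appName("raydp-example")
      .getOrCreate()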
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.sql.raydp
19 |
20 | import java.io.ByteArrayInputStream
21 | import java.nio.channels.{Channels, ReadableByteChannel}
22 | import java.util.List
23 |
24 | import com.intel.raydp.shims.SparkShimLoader
25 |
26 | import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
27 | import org.apache.spark.raydp.RayDPUtils
28 | import org.apache.spark.rdd.{RayDatasetRDD, RayObjectRefRDD}
29 | import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
30 | import org.apache.spark.sql.catalyst.expressions.GenericRow
31 | import org.apache.spark.sql.execution.arrow.ArrowConverters
32 | import org.apache.spark.sql.types.{IntegerType, StructType}
33 |
34 | object ObjectStoreReader {
35 | def createRayObjectRefDF(
36 | spark: SparkSession,
37 | locations: List[Array[Byte]]): DataFrame = {
38 | val rdd = new RayObjectRefRDD(spark.sparkContext, locations)
39 | val schema = new StructType().add("idx", IntegerType)
40 | spark.createDataFrame(rdd, schema)
41 | }
42 |
43 | def RayDatasetToDataFrame(
44 | sparkSession: SparkSession,
45 | rdd: RayDatasetRDD,
46 | schema: String): DataFrame = {
47 | SparkShimLoader.getSparkShims.toDataFrame(JavaRDD.fromRDD(rdd), schema, sparkSession)
48 | }
49 |
50 | def getBatchesFromStream(
51 | ref: Array[Byte],
52 | ownerAddress: Array[Byte]): Iterator[Array[Byte]] = {
53 | val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]], ownerAddress)
54 | ArrowConverters.getBatchesFromStream(
55 | Channels.newChannel(new ByteArrayInputStream(objectRef.get)))
56 | }
57 | }
58 |
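A minimal sketch of createRayObjectRefDF, assuming spark is an active SparkSession and owners is a java.util.List of serialized Ray owner addresses (one per desired partition):

    // Yields a single-column DataFrame ("idx", IntegerType) with one row per
    // entry in `owners`; each row is placed on the owning Ray node via
    // RayObjectRefRDD.getPreferredLocations.
    val df = ObjectStoreReader.createRayObjectRefDF(spark, owners)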
--------------------------------------------------------------------------------
/core/raydp-main/src/main/test/org/apache/spark/scheduler/cluster/raydp/TestRayCoarseGrainedSchedulerBackend.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | package org.apache.spark.scheduler.cluster.raydp;
18 |
19 | import org.apache.spark.SparkConf;
20 | import org.junit.jupiter.api.Test;
21 |
22 | import org.apache.spark.scheduler.cluster.SchedulerBackendUtils;
23 |
24 | import static org.junit.jupiter.api.Assertions.assertEquals;
25 |
26 | /**
27 | * This class performs unit testing on some methods in `RayCoarseGrainedSchedulerBackend`.
28 | */
29 | public class TestRayCoarseGrainedSchedulerBackend {
30 |
31 | // Test using the default value.
32 | @Test
33 | public void testExecutorNumberWithDefaultConfig() {
34 | SparkConf conf = new SparkConf();
35 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2);
36 | assertEquals(2, executorNumber);
37 | }
38 |
39 | // Test using a negative value.
40 | @Test
41 | public void testExecutorNumberWithNegativeConfig() {
42 | SparkConf conf = new SparkConf();
43 | conf.set("spark.dynamicAllocation.initialExecutors", "-1");
44 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2);
45 | assertEquals(2, executorNumber);
46 | }
47 |
48 | // Test using reasonable values.
49 | @Test
50 | public void testExecutorNumberWithValidConfig() {
51 | SparkConf conf = new SparkConf();
52 | conf.set("spark.executor.instances", "5");
53 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2);
54 | assertEquals(5, executorNumber);
55 | }
56 |
57 | // Test using dynamic values.
58 | @Test
59 | public void testExecutorNumberWithDynamicConfig() {
60 | SparkConf conf = new SparkConf();
61 | conf.set("spark.dynamicAllocation.enabled", "true");
62 | conf.set("spark.dynamicAllocation.minExecutors", "3");
63 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2);
64 | assertEquals(3, executorNumber);
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/core/shims/common/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 |
8 | com.intel
9 | raydp-shims
10 | 1.7.0-SNAPSHOT
11 | ../pom.xml
12 |
13 |
14 | raydp-shims-common
15 | RayDP Shims Common
16 | 1.7.0-SNAPSHOT
17 | jar
18 |
19 |
20 |
21 |
22 | org.scalastyle
23 | scalastyle-maven-plugin
24 |
25 |
26 | net.alchim31.maven
27 | scala-maven-plugin
28 | 3.2.2
29 |
30 |
31 | scala-compile-first
32 | process-resources
33 |
34 | compile
35 |
36 |
37 |
38 | scala-test-compile-first
39 | process-test-resources
40 |
41 | testCompile
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 | org.apache.spark
52 | spark-sql_${scala.binary.version}
53 | ${spark.version}
54 | provided
55 |
56 |
57 | com.google.protobuf
58 | protobuf-java
59 |
60 |
61 |
62 |
63 | org.apache.spark
64 | spark-core_${scala.binary.version}
65 | ${spark.version}
66 | provided
67 |
68 |
69 | org.xerial.snappy
70 | snappy-java
71 |
72 |
73 | org.apache.commons
74 | commons-compress
75 |
76 |
77 | org.apache.commons
78 | commons-text
79 |
80 |
81 | org.apache.ivy
82 | ivy
83 |
84 |
85 | log4j
86 | log4j
87 |
88 |
89 |
90 |
91 | org.xerial.snappy
92 | snappy-java
93 | ${snappy.version}
94 |
95 |
96 | org.apache.commons
97 | commons-text
98 | ${commons.text.version}
99 |
100 |
101 | com.google.protobuf
102 | protobuf-java
103 | ${protobuf.version}
104 |
105 |
106 | org.apache.ivy
107 | ivy
108 | ${ivy.version}
109 |
110 |
111 | org.apache.commons
112 | commons-compress
113 | ${commons.compress.version}
114 |
115 |
116 |
117 |
--------------------------------------------------------------------------------
/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShimLoader.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims
19 |
20 | import java.util.ServiceLoader
21 |
22 | import scala.collection.JavaConverters._
23 |
24 | import org.apache.spark.SPARK_VERSION_SHORT
25 | import org.apache.spark.internal.Logging
26 |
27 | object SparkShimLoader extends Logging {
28 | private var sparkShims: SparkShims = null
29 | private var sparkShimProviderClass: String = null
30 |
31 | def getSparkShims: SparkShims = {
32 | if (sparkShims == null) {
33 | val provider = getSparkShimProvider()
34 | sparkShims = provider.createShim
35 | }
36 | sparkShims
37 | }
38 |
39 | def getSparkVersion: String = {
40 | SPARK_VERSION_SHORT
41 | }
42 |
43 | def setSparkShimProviderClass(providerClass: String): Unit = {
44 | sparkShimProviderClass = providerClass
45 | }
46 |
47 | private def loadSparkShimProvider(): SparkShimProvider = {
48 | // Match and load the Shim provider for the current Spark version.
49 | val sparkVersion = getSparkVersion
50 | logInfo(s"Loading Spark Shims for version: $sparkVersion")
51 |
52 | // Load and filter the providers based on version
53 | val shimProviders =
54 | ServiceLoader.load(classOf[SparkShimProvider]).asScala.filter(_.matches(sparkVersion))
55 | if (shimProviders.size > 1) {
56 | throw new IllegalStateException(s"More than one SparkShimProvider found: $shimProviders")
57 | }
58 |
59 | val shimProvider = shimProviders.headOption match {
60 | case Some(shimProvider) => shimProvider
61 | case None =>
62 | throw new IllegalStateException(s"No Spark Shim Provider found for $sparkVersion")
63 | }
64 | logInfo(s"Using Shim provider: $shimProvider")
65 | shimProvider
66 | }
67 |
68 | private def getSparkShimProvider(): SparkShimProvider = {
69 | if (sparkShimProviderClass != null) {
70 | logInfo(s"Using Spark Shim Provider specified by $sparkShimProviderClass. ")
71 | val providerClass = Class.forName(sparkShimProviderClass)
72 | val providerConstructor = providerClass.getConstructor()
73 | providerConstructor.newInstance().asInstanceOf[SparkShimProvider]
74 | } else {
75 | loadSparkShimProvider()
76 | }
77 | }
78 | }
79 |
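A short example of the override hook above, pinning the spark322 provider that ships in this repo instead of relying on ServiceLoader discovery:

    // Must be called before the first getSparkShims; otherwise the provider is
    // resolved from META-INF/services based on SPARK_VERSION_SHORT.
    SparkShimLoader.setSparkShimProviderClass(
      "com.intel.raydp.shims.spark322.SparkShimProvider")
    val shims = SparkShimLoader.getSparkShims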
--------------------------------------------------------------------------------
/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims
19 |
20 | /**
21 | * Provider interface for matching and retrieving the Shims of a specific Spark version
22 | */
23 | trait SparkShimProvider {
24 | def matches(version: String): Boolean
25 | def createShim: SparkShims
26 | }
27 |
--------------------------------------------------------------------------------
/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims
19 |
20 | import org.apache.arrow.vector.types.pojo.Schema
21 | import org.apache.spark.{SparkEnv, TaskContext}
22 | import org.apache.spark.api.java.JavaRDD
23 | import org.apache.spark.executor.RayDPExecutorBackendFactory
24 | import org.apache.spark.sql.types.StructType
25 | import org.apache.spark.sql.{DataFrame, SparkSession}
26 |
27 | sealed abstract class ShimDescriptor
28 |
29 | case class SparkShimDescriptor(major: Int, minor: Int, patch: Int) extends ShimDescriptor {
30 | override def toString(): String = s"$major.$minor.$patch"
31 | }
32 |
33 | trait SparkShims {
34 | def getShimDescriptor: ShimDescriptor
35 |
36 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame
37 |
38 | def getExecutorBackendFactory(): RayDPExecutorBackendFactory
39 |
40 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext
41 |
42 | def toArrowSchema(schema: StructType, timeZoneId: String): Schema
43 | }
44 |
--------------------------------------------------------------------------------
/core/shims/common/src/main/scala/org/apache/spark/executor/RayDPExecutorBackendFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.rpc.RpcEnv
24 | import org.apache.spark.resource.ResourceProfile
25 |
26 | trait RayDPExecutorBackendFactory {
27 | def createExecutorBackend(
28 | rpcEnv: RpcEnv,
29 | driverUrl: String,
30 | executorId: String,
31 | bindAddress: String,
32 | hostname: String,
33 | cores: Int,
34 | userClassPath: Seq[URL],
35 | env: SparkEnv,
36 | resourcesFileOpt: Option[String],
37 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend
38 | }
39 |
--------------------------------------------------------------------------------
/core/shims/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 |
8 | com.intel
9 | raydp-parent
10 | 1.7.0-SNAPSHOT
11 | ../pom.xml
12 |
13 |
14 | raydp-shims
15 | RayDP Shims
16 | pom
17 |
18 |
19 | common
20 | spark322
21 | spark330
22 | spark340
23 | spark350
24 |
25 |
26 |
27 | 2.12
28 | 4.3.0
29 | 3.2.2
30 |
31 |
32 |
33 |
34 |
35 | net.alchim31.maven
36 | scala-maven-plugin
37 | ${scala.plugin.version}
38 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/core/shims/spark322/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 |
8 | com.intel
9 | raydp-shims
10 | 1.7.0-SNAPSHOT
11 | ../pom.xml
12 |
13 |
14 | raydp-shims-spark322
15 | RayDP Shims for Spark 3.2.2
16 | jar
17 |
18 |
19 | 2.12.15
20 | 2.13.5
21 |
22 |
23 |
24 |
25 |
26 | org.scalastyle
27 | scalastyle-maven-plugin
28 |
29 |
30 | net.alchim31.maven
31 | scala-maven-plugin
32 | 3.2.2
33 |
34 |
35 | scala-compile-first
36 | process-resources
37 |
38 | compile
39 |
40 |
41 |
42 | scala-test-compile-first
43 | process-test-resources
44 |
45 | testCompile
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | src/main/resources
55 |
56 |
57 |
58 |
59 |
60 |
61 | com.intel
62 | raydp-shims-common
63 | ${project.version}
64 | compile
65 |
66 |
67 | org.apache.spark
68 | spark-sql_${scala.binary.version}
69 | ${spark322.version}
70 | provided
71 |
72 |
73 | com.google.protobuf
74 | protobuf-java
75 |
76 |
77 |
78 |
79 | org.apache.spark
80 | spark-core_${scala.binary.version}
81 | ${spark322.version}
82 | provided
83 |
84 |
85 | org.xerial.snappy
86 | snappy-java
87 |
88 |
89 | org.apache.commons
90 | commons-compress
91 |
92 |
93 | org.apache.commons
94 | commons-text
95 |
96 |
97 | org.apache.ivy
98 | ivy
99 |
100 |
101 | log4j
102 | log4j
103 |
104 |
105 |
106 |
107 | org.xerial.snappy
108 | snappy-java
109 | ${snappy.version}
110 |
111 |
112 | org.apache.commons
113 | commons-compress
114 | ${commons.compress.version}
115 |
116 |
117 | org.apache.commons
118 | commons-text
119 | ${commons.text.version}
120 |
121 |
122 | org.apache.ivy
123 | ivy
124 | ${ivy.version}
125 |
126 |
127 | com.google.protobuf
128 | protobuf-java
129 | ${protobuf.version}
130 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
/core/shims/spark322/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider:
--------------------------------------------------------------------------------
1 | com.intel.raydp.shims.spark322.SparkShimProvider
2 |
--------------------------------------------------------------------------------
/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark322
19 |
20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
21 |
22 | object SparkShimProvider {
23 | val SPARK311_DESCRIPTOR = SparkShimDescriptor(3, 1, 1)
24 | val SPARK312_DESCRIPTOR = SparkShimDescriptor(3, 1, 2)
25 | val SPARK313_DESCRIPTOR = SparkShimDescriptor(3, 1, 3)
26 | val SPARK320_DESCRIPTOR = SparkShimDescriptor(3, 2, 0)
27 | val SPARK321_DESCRIPTOR = SparkShimDescriptor(3, 2, 1)
28 | val SPARK322_DESCRIPTOR = SparkShimDescriptor(3, 2, 2)
29 | val SPARK323_DESCRIPTOR = SparkShimDescriptor(3, 2, 3)
30 | val SPARK324_DESCRIPTOR = SparkShimDescriptor(3, 2, 4)
31 | val DESCRIPTOR_STRINGS =
32 | Seq(s"$SPARK311_DESCRIPTOR", s"$SPARK312_DESCRIPTOR", s"$SPARK313_DESCRIPTOR",
33 | s"$SPARK320_DESCRIPTOR", s"$SPARK321_DESCRIPTOR", s"$SPARK322_DESCRIPTOR",
34 | s"$SPARK323_DESCRIPTOR", s"$SPARK324_DESCRIPTOR")
35 | val DESCRIPTOR = SPARK323_DESCRIPTOR
36 | }
37 |
38 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
39 | def createShim: SparkShims = {
40 | new Spark322Shims()
41 | }
42 |
43 | def matches(version: String): Boolean = {
44 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark322
19 |
20 | import org.apache.spark.{SparkEnv, TaskContext}
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.executor.RayDPExecutorBackendFactory
23 | import org.apache.spark.executor.spark322._
24 | import org.apache.spark.spark322.TaskContextUtils
25 | import org.apache.spark.sql.{DataFrame, SparkSession}
26 | import org.apache.spark.sql.spark322.SparkSqlUtils
27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
28 | import org.apache.arrow.vector.types.pojo.Schema
29 | import org.apache.spark.sql.types.StructType
30 |
31 | class Spark322Shims extends SparkShims {
32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
33 |
34 | override def toDataFrame(
35 | rdd: JavaRDD[Array[Byte]],
36 | schema: String,
37 | session: SparkSession): DataFrame = {
38 | SparkSqlUtils.toDataFrame(rdd, schema, session)
39 | }
40 |
41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
42 | new RayDPSpark322ExecutorBackendFactory()
43 | }
44 |
45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
46 | TaskContextUtils.getDummyTaskContext(partitionId, env)
47 | }
48 |
49 | override def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark322/src/main/scala/org/apache/spark/TaskContextUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.spark322
19 |
20 | import java.util.Properties
21 |
22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
23 | import org.apache.spark.memory.TaskMemoryManager
24 |
25 | object TaskContextUtils {
26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
27 | new TaskContextImpl(0, 0, partitionId, -1024, 0,
28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
29 | }
30 | }
31 |
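A sketch of where this helper (and its spark330/spark340 twins) is typically useful, assuming the calling code sits under the org.apache.spark package (TaskContext.setTaskContext is package-private) and a SparkEnv is already running:

    import org.apache.spark.{SparkEnv, TaskContext}
    import org.apache.spark.spark322.TaskContextUtils

    // Install a throwaway TaskContext so code that calls TaskContext.get()
    // (e.g. Arrow batch readers) can run outside of a real Spark task.
    val ctx = TaskContextUtils.getDummyTaskContext(0, SparkEnv.get)
    TaskContext.setTaskContext(ctx)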
--------------------------------------------------------------------------------
/core/shims/spark322/src/main/scala/org/apache/spark/executor/RayDPSpark322ExecutorBackendFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor.spark322
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.executor.CoarseGrainedExecutorBackend
24 | import org.apache.spark.executor.RayDPExecutorBackendFactory
25 | import org.apache.spark.resource.ResourceProfile
26 | import org.apache.spark.rpc.RpcEnv
27 |
28 | class RayDPSpark322ExecutorBackendFactory
29 | extends RayDPExecutorBackendFactory {
30 | override def createExecutorBackend(
31 | rpcEnv: RpcEnv,
32 | driverUrl: String,
33 | executorId: String,
34 | bindAddress: String,
35 | hostname: String,
36 | cores: Int,
37 | userClassPath: Seq[URL],
38 | env: SparkEnv,
39 | resourcesFileOpt: Option[String],
40 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
41 | new CoarseGrainedExecutorBackend(
42 | rpcEnv,
43 | driverUrl,
44 | executorId,
45 | bindAddress,
46 | hostname,
47 | cores,
48 | userClassPath,
49 | env,
50 | resourcesFileOpt,
51 | resourceProfile)
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.sql.spark322
19 |
20 | import org.apache.arrow.vector.types.pojo.Schema
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
23 | import org.apache.spark.sql.execution.arrow.ArrowConverters
24 | import org.apache.spark.sql.types.StructType
25 | import org.apache.spark.sql.util.ArrowUtils
26 |
27 | object SparkSqlUtils {
28 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
29 | ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session))
30 | }
31 |
32 | def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
33 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/core/shims/spark330/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 |
8 | com.intel
9 | raydp-shims
10 | 1.7.0-SNAPSHOT
11 | ../pom.xml
12 |
13 |
14 | raydp-shims-spark330
15 | RayDP Shims for Spark 3.3.0
16 | jar
17 |
18 |
19 | 2.12.15
20 | 2.13.5
21 |
22 |
23 |
24 |
25 |
26 | org.scalastyle
27 | scalastyle-maven-plugin
28 |
29 |
30 | net.alchim31.maven
31 | scala-maven-plugin
32 | 3.2.2
33 |
34 |
35 | scala-compile-first
36 | process-resources
37 |
38 | compile
39 |
40 |
41 |
42 | scala-test-compile-first
43 | process-test-resources
44 |
45 | testCompile
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | src/main/resources
55 |
56 |
57 |
58 |
59 |
60 |
61 | com.intel
62 | raydp-shims-common
63 | ${project.version}
64 | compile
65 |
66 |
67 | org.apache.spark
68 | spark-sql_${scala.binary.version}
69 | ${spark330.version}
70 | provided
71 |
72 |
73 | com.google.protobuf
74 | protobuf-java
75 |
76 |
77 |
78 |
79 | org.apache.spark
80 | spark-core_${scala.binary.version}
81 | ${spark330.version}
82 | provided
83 |
84 |
85 | org.xerial.snappy
86 | snappy-java
87 |
88 |
89 | io.netty
90 | netty-handler
91 |
92 |
93 | org.apache.commons
94 | commons-text
95 |
96 |
97 | org.apache.ivy
98 | ivy
99 |
100 |
101 |
102 |
103 | org.xerial.snappy
104 | snappy-java
105 | ${snappy.version}
106 |
107 |
108 | io.netty
109 | netty-handler
110 | ${netty.version}
111 |
112 |
113 | org.apache.commons
114 | commons-text
115 | ${commons.text.version}
116 |
117 |
118 | org.apache.ivy
119 | ivy
120 | ${ivy.version}
121 |
122 |
123 | com.google.protobuf
124 | protobuf-java
125 | ${protobuf.version}
126 |
127 |
128 |
129 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider:
--------------------------------------------------------------------------------
1 | com.intel.raydp.shims.spark330.SparkShimProvider
2 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark330
19 |
20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
21 |
22 | object SparkShimProvider {
23 | val SPARK330_DESCRIPTOR = SparkShimDescriptor(3, 3, 0)
24 | val SPARK331_DESCRIPTOR = SparkShimDescriptor(3, 3, 1)
25 | val SPARK332_DESCRIPTOR = SparkShimDescriptor(3, 3, 2)
26 | val SPARK333_DESCRIPTOR = SparkShimDescriptor(3, 3, 3)
27 | val DESCRIPTOR_STRINGS = Seq(s"$SPARK330_DESCRIPTOR", s"$SPARK331_DESCRIPTOR",
28 | s"$SPARK332_DESCRIPTOR", s"$SPARK333_DESCRIPTOR")
29 | val DESCRIPTOR = SPARK332_DESCRIPTOR
30 | }
31 |
32 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
33 | def createShim: SparkShims = {
34 | new Spark330Shims()
35 | }
36 |
37 | def matches(version: String): Boolean = {
38 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark330
19 |
20 | import org.apache.spark.{SparkEnv, TaskContext}
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.executor.RayDPExecutorBackendFactory
23 | import org.apache.spark.executor.spark330._
24 | import org.apache.spark.spark330.TaskContextUtils
25 | import org.apache.spark.sql.{DataFrame, SparkSession}
26 | import org.apache.spark.sql.spark330.SparkSqlUtils
27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
28 | import org.apache.arrow.vector.types.pojo.Schema
29 | import org.apache.spark.sql.types.StructType
30 |
31 | class Spark330Shims extends SparkShims {
32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
33 |
34 | override def toDataFrame(
35 | rdd: JavaRDD[Array[Byte]],
36 | schema: String,
37 | session: SparkSession): DataFrame = {
38 | SparkSqlUtils.toDataFrame(rdd, schema, session)
39 | }
40 |
41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
42 | new RayDPSpark330ExecutorBackendFactory()
43 | }
44 |
45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
46 | TaskContextUtils.getDummyTaskContext(partitionId, env)
47 | }
48 |
49 | override def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/scala/org/apache/spark/TaskContextUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.spark330
19 |
20 | import java.util.Properties
21 |
22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
23 | import org.apache.spark.memory.TaskMemoryManager
24 |
25 | object TaskContextUtils {
26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
27 | new TaskContextImpl(0, 0, partitionId, -1024, 0,
28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.resource.ResourceProfile
24 | import org.apache.spark.rpc.RpcEnv
25 |
26 | class RayCoarseGrainedExecutorBackend(
27 | rpcEnv: RpcEnv,
28 | driverUrl: String,
29 | executorId: String,
30 | bindAddress: String,
31 | hostname: String,
32 | cores: Int,
33 | userClassPath: Seq[URL],
34 | env: SparkEnv,
35 | resourcesFileOpt: Option[String],
36 | resourceProfile: ResourceProfile)
37 | extends CoarseGrainedExecutorBackend(
38 | rpcEnv,
39 | driverUrl,
40 | executorId,
41 | bindAddress,
42 | hostname,
43 | cores,
44 | env,
45 | resourcesFileOpt,
46 | resourceProfile) {
47 |
48 | override def getUserClassPath: Seq[URL] = userClassPath
49 |
50 | }
51 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayDPSpark330ExecutorBackendFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor.spark330
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.executor._
24 | import org.apache.spark.resource.ResourceProfile
25 | import org.apache.spark.rpc.RpcEnv
26 |
27 | class RayDPSpark330ExecutorBackendFactory
28 | extends RayDPExecutorBackendFactory {
29 | override def createExecutorBackend(
30 | rpcEnv: RpcEnv,
31 | driverUrl: String,
32 | executorId: String,
33 | bindAddress: String,
34 | hostname: String,
35 | cores: Int,
36 | userClassPath: Seq[URL],
37 | env: SparkEnv,
38 | resourcesFileOpt: Option[String],
39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
40 | new RayCoarseGrainedExecutorBackend(
41 | rpcEnv,
42 | driverUrl,
43 | executorId,
44 | bindAddress,
45 | hostname,
46 | cores,
47 | userClassPath,
48 | env,
49 | resourcesFileOpt,
50 | resourceProfile)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.sql.spark330
19 |
20 | import org.apache.arrow.vector.types.pojo.Schema
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
23 | import org.apache.spark.sql.execution.arrow.ArrowConverters
24 | import org.apache.spark.sql.types.StructType
25 | import org.apache.spark.sql.util.ArrowUtils
26 |
27 | object SparkSqlUtils {
28 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
29 | ArrowConverters.toDataFrame(rdd, schema, session)
30 | }
31 |
32 | def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
33 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/core/shims/spark340/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 |
8 | com.intel
9 | raydp-shims
10 | 1.7.0-SNAPSHOT
11 | ../pom.xml
12 |
13 |
14 | raydp-shims-spark340
15 | RayDP Shims for Spark 3.4.0
16 | jar
17 |
18 |
19 | 2.12.15
20 | 2.13.5
21 |
22 |
23 |
24 |
25 |
26 | org.scalastyle
27 | scalastyle-maven-plugin
28 |
29 |
30 | net.alchim31.maven
31 | scala-maven-plugin
32 | 3.2.2
33 |
34 |
35 | scala-compile-first
36 | process-resources
37 |
38 | compile
39 |
40 |
41 |
42 | scala-test-compile-first
43 | process-test-resources
44 |
45 | testCompile
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | src/main/resources
55 |
56 |
57 |
58 |
59 |
60 |
61 | com.intel
62 | raydp-shims-common
63 | ${project.version}
64 | compile
65 |
66 |
67 | org.apache.spark
68 | spark-sql_${scala.binary.version}
69 | ${spark340.version}
70 | provided
71 |
72 |
73 | org.apache.spark
74 | spark-core_${scala.binary.version}
75 | ${spark340.version}
76 | provided
77 |
78 |
79 | org.xerial.snappy
80 | snappy-java
81 |
82 |
83 | io.netty
84 | netty-handler
85 |
86 |
87 |
88 |
89 | org.xerial.snappy
90 | snappy-java
91 | ${snappy.version}
92 |
93 |
94 | io.netty
95 | netty-handler
96 | ${netty.version}
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider:
--------------------------------------------------------------------------------
1 | com.intel.raydp.shims.spark340.SparkShimProvider
2 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark340
19 |
20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
21 |
22 | object SparkShimProvider {
23 | val SPARK340_DESCRIPTOR = SparkShimDescriptor(3, 4, 0)
24 | val SPARK341_DESCRIPTOR = SparkShimDescriptor(3, 4, 1)
25 | val SPARK342_DESCRIPTOR = SparkShimDescriptor(3, 4, 2)
26 | val SPARK343_DESCRIPTOR = SparkShimDescriptor(3, 4, 3)
27 | val SPARK344_DESCRIPTOR = SparkShimDescriptor(3, 4, 4)
28 | val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR",
29 | s"$SPARK343_DESCRIPTOR", s"$SPARK344_DESCRIPTOR")
30 | val DESCRIPTOR = SPARK341_DESCRIPTOR
31 | }
32 |
33 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
34 | def createShim: SparkShims = {
35 | new Spark340Shims()
36 | }
37 |
38 | def matches(version: String): Boolean = {
39 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark340
19 |
20 | import org.apache.spark.{SparkEnv, TaskContext}
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.executor.RayDPExecutorBackendFactory
23 | import org.apache.spark.executor.spark340._
24 | import org.apache.spark.spark340.TaskContextUtils
25 | import org.apache.spark.sql.{DataFrame, SparkSession}
26 | import org.apache.spark.sql.spark340.SparkSqlUtils
27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
28 | import org.apache.arrow.vector.types.pojo.Schema
29 | import org.apache.spark.sql.types.StructType
30 |
31 | class Spark340Shims extends SparkShims {
32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
33 |
34 | override def toDataFrame(
35 | rdd: JavaRDD[Array[Byte]],
36 | schema: String,
37 | session: SparkSession): DataFrame = {
38 | SparkSqlUtils.toDataFrame(rdd, schema, session)
39 | }
40 |
41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
42 | new RayDPSpark340ExecutorBackendFactory()
43 | }
44 |
45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
46 | TaskContextUtils.getDummyTaskContext(partitionId, env)
47 | }
48 |
49 | override def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/org/apache/spark/TaskContextUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.spark340
19 |
20 | import java.util.Properties
21 |
22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
23 | import org.apache.spark.memory.TaskMemoryManager
24 |
25 | object TaskContextUtils {
26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 0,
28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.resource.ResourceProfile
24 | import org.apache.spark.rpc.RpcEnv
25 |
26 | class RayCoarseGrainedExecutorBackend(
27 | rpcEnv: RpcEnv,
28 | driverUrl: String,
29 | executorId: String,
30 | bindAddress: String,
31 | hostname: String,
32 | cores: Int,
33 | userClassPath: Seq[URL],
34 | env: SparkEnv,
35 | resourcesFileOpt: Option[String],
36 | resourceProfile: ResourceProfile)
37 | extends CoarseGrainedExecutorBackend(
38 | rpcEnv,
39 | driverUrl,
40 | executorId,
41 | bindAddress,
42 | hostname,
43 | cores,
44 | env,
45 | resourcesFileOpt,
46 | resourceProfile) {
47 |
48 | override def getUserClassPath: Seq[URL] = userClassPath
49 |
50 | }
51 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayDPSpark340ExecutorBackendFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor.spark340
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.executor._
24 | import org.apache.spark.resource.ResourceProfile
25 | import org.apache.spark.rpc.RpcEnv
26 |
27 | class RayDPSpark340ExecutorBackendFactory
28 | extends RayDPExecutorBackendFactory {
29 | override def createExecutorBackend(
30 | rpcEnv: RpcEnv,
31 | driverUrl: String,
32 | executorId: String,
33 | bindAddress: String,
34 | hostname: String,
35 | cores: Int,
36 | userClassPath: Seq[URL],
37 | env: SparkEnv,
38 | resourcesFileOpt: Option[String],
39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
40 | new RayCoarseGrainedExecutorBackend(
41 | rpcEnv,
42 | driverUrl,
43 | executorId,
44 | bindAddress,
45 | hostname,
46 | cores,
47 | userClassPath,
48 | env,
49 | resourcesFileOpt,
50 | resourceProfile)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.sql.spark340
19 |
20 | import org.apache.arrow.vector.types.pojo.Schema
21 | import org.apache.spark.TaskContext
22 | import org.apache.spark.api.java.JavaRDD
23 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
24 | import org.apache.spark.sql.execution.arrow.ArrowConverters
25 | import org.apache.spark.sql.types._
26 | import org.apache.spark.sql.util.ArrowUtils
27 |
28 | object SparkSqlUtils {
29 | def toDataFrame(
30 | arrowBatchRDD: JavaRDD[Array[Byte]],
31 | schemaString: String,
32 | session: SparkSession): DataFrame = {
33 | val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
34 | val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
35 | val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
36 | val context = TaskContext.get()
37 | ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
38 | }
39 | session.internalCreateDataFrame(rdd.setName("arrow"), schema)
40 | }
41 |
42 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
43 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/core/shims/spark350/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |   <modelVersion>4.0.0</modelVersion>
6 | 
7 |   <parent>
8 |     <groupId>com.intel</groupId>
9 |     <artifactId>raydp-shims</artifactId>
10 |     <version>1.7.0-SNAPSHOT</version>
11 |     <relativePath>../pom.xml</relativePath>
12 |   </parent>
13 | 
14 |   <artifactId>raydp-shims-spark350</artifactId>
15 |   <name>RayDP Shims for Spark 3.5.0</name>
16 |   <packaging>jar</packaging>
17 | 
18 |   <properties>
19 |     <scala.version>2.12.15</scala.version>
20 |     <scala213.version>2.13.5</scala213.version>
21 |   </properties>
22 | 
23 |   <build>
24 |     <plugins>
25 |       <plugin>
26 |         <groupId>org.scalastyle</groupId>
27 |         <artifactId>scalastyle-maven-plugin</artifactId>
28 |       </plugin>
29 |       <plugin>
30 |         <groupId>net.alchim31.maven</groupId>
31 |         <artifactId>scala-maven-plugin</artifactId>
32 |         <version>3.2.2</version>
33 |         <executions>
34 |           <execution>
35 |             <id>scala-compile-first</id>
36 |             <phase>process-resources</phase>
37 |             <goals>
38 |               <goal>compile</goal>
39 |             </goals>
40 |           </execution>
41 |           <execution>
42 |             <id>scala-test-compile-first</id>
43 |             <phase>process-test-resources</phase>
44 |             <goals>
45 |               <goal>testCompile</goal>
46 |             </goals>
47 |           </execution>
48 |         </executions>
49 |       </plugin>
50 |     </plugins>
51 | 
52 |     <resources>
53 |       <resource>
54 |         <directory>src/main/resources</directory>
55 |       </resource>
56 |     </resources>
57 |   </build>
58 | 
59 |   <dependencies>
60 |     <dependency>
61 |       <groupId>com.intel</groupId>
62 |       <artifactId>raydp-shims-common</artifactId>
63 |       <version>${project.version}</version>
64 |       <scope>compile</scope>
65 |     </dependency>
66 |     <dependency>
67 |       <groupId>org.apache.spark</groupId>
68 |       <artifactId>spark-sql_${scala.binary.version}</artifactId>
69 |       <version>${spark350.version}</version>
70 |       <scope>provided</scope>
71 |     </dependency>
72 |     <dependency>
73 |       <groupId>org.apache.spark</groupId>
74 |       <artifactId>spark-core_${scala.binary.version}</artifactId>
75 |       <version>${spark350.version}</version>
76 |       <scope>provided</scope>
77 |       <exclusions>
78 |         <exclusion>
79 |           <groupId>org.xerial.snappy</groupId>
80 |           <artifactId>snappy-java</artifactId>
81 |         </exclusion>
82 |         <exclusion>
83 |           <groupId>io.netty</groupId>
84 |           <artifactId>netty-handler</artifactId>
85 |         </exclusion>
86 |       </exclusions>
87 |     </dependency>
88 |     <dependency>
89 |       <groupId>org.xerial.snappy</groupId>
90 |       <artifactId>snappy-java</artifactId>
91 |       <version>${snappy.version}</version>
92 |     </dependency>
93 |     <dependency>
94 |       <groupId>io.netty</groupId>
95 |       <artifactId>netty-handler</artifactId>
96 |       <version>${netty.version}</version>
97 |     </dependency>
98 |   </dependencies>
99 | </project>
100 | 
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider:
--------------------------------------------------------------------------------
1 | com.intel.raydp.shims.spark350.SparkShimProvider
2 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark350
19 |
20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
21 |
22 | object SparkShimProvider {
23 | val SPARK350_DESCRIPTOR = SparkShimDescriptor(3, 5, 0)
24 | val SPARK351_DESCRIPTOR = SparkShimDescriptor(3, 5, 1)
25 | val SPARK352_DESCRIPTOR = SparkShimDescriptor(3, 5, 2)
26 | val SPARK353_DESCRIPTOR = SparkShimDescriptor(3, 5, 3)
27 | val SPARK354_DESCRIPTOR = SparkShimDescriptor(3, 5, 4)
28 | val DESCRIPTOR_STRINGS = Seq(s"$SPARK350_DESCRIPTOR", s"$SPARK351_DESCRIPTOR", s"$SPARK352_DESCRIPTOR",
29 | s"$SPARK353_DESCRIPTOR", s"$SPARK354_DESCRIPTOR")
30 | val DESCRIPTOR = SPARK350_DESCRIPTOR
31 | }
32 |
33 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
34 | def createShim: SparkShims = {
35 | new Spark350Shims()
36 | }
37 |
38 | def matches(version: String): Boolean = {
39 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark350
19 |
20 | import org.apache.spark.{SparkEnv, TaskContext}
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.executor.RayDPExecutorBackendFactory
23 | import org.apache.spark.executor.spark350._
24 | import org.apache.spark.spark350.TaskContextUtils
25 | import org.apache.spark.sql.{DataFrame, SparkSession}
26 | import org.apache.spark.sql.spark350.SparkSqlUtils
27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
28 | import org.apache.arrow.vector.types.pojo.Schema
29 | import org.apache.spark.sql.types.StructType
30 |
31 | class Spark350Shims extends SparkShims {
32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
33 |
34 | override def toDataFrame(
35 | rdd: JavaRDD[Array[Byte]],
36 | schema: String,
37 | session: SparkSession): DataFrame = {
38 | SparkSqlUtils.toDataFrame(rdd, schema, session)
39 | }
40 |
41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
42 | new RayDPSpark350ExecutorBackendFactory()
43 | }
44 |
45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
46 | TaskContextUtils.getDummyTaskContext(partitionId, env)
47 | }
48 |
49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/org/apache/spark/TaskContextUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.spark350
19 |
20 | import java.util.Properties
21 |
22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
23 | import org.apache.spark.memory.TaskMemoryManager
24 |
25 | object TaskContextUtils {
26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 0,
28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.resource.ResourceProfile
24 | import org.apache.spark.rpc.RpcEnv
25 |
26 | class RayCoarseGrainedExecutorBackend(
27 | rpcEnv: RpcEnv,
28 | driverUrl: String,
29 | executorId: String,
30 | bindAddress: String,
31 | hostname: String,
32 | cores: Int,
33 | userClassPath: Seq[URL],
34 | env: SparkEnv,
35 | resourcesFileOpt: Option[String],
36 | resourceProfile: ResourceProfile)
37 | extends CoarseGrainedExecutorBackend(
38 | rpcEnv,
39 | driverUrl,
40 | executorId,
41 | bindAddress,
42 | hostname,
43 | cores,
44 | env,
45 | resourcesFileOpt,
46 | resourceProfile) {
47 |
48 | override def getUserClassPath: Seq[URL] = userClassPath
49 |
50 | }
51 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayDPSpark350ExecutorBackendFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor.spark350
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.executor._
24 | import org.apache.spark.resource.ResourceProfile
25 | import org.apache.spark.rpc.RpcEnv
26 |
27 | class RayDPSpark350ExecutorBackendFactory
28 | extends RayDPExecutorBackendFactory {
29 | override def createExecutorBackend(
30 | rpcEnv: RpcEnv,
31 | driverUrl: String,
32 | executorId: String,
33 | bindAddress: String,
34 | hostname: String,
35 | cores: Int,
36 | userClassPath: Seq[URL],
37 | env: SparkEnv,
38 | resourcesFileOpt: Option[String],
39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
40 | new RayCoarseGrainedExecutorBackend(
41 | rpcEnv,
42 | driverUrl,
43 | executorId,
44 | bindAddress,
45 | hostname,
46 | cores,
47 | userClassPath,
48 | env,
49 | resourcesFileOpt,
50 | resourceProfile)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.sql.spark350
19 |
20 | import org.apache.arrow.vector.types.pojo.Schema
21 | import org.apache.spark.TaskContext
22 | import org.apache.spark.api.java.JavaRDD
23 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
24 | import org.apache.spark.sql.execution.arrow.ArrowConverters
25 | import org.apache.spark.sql.types._
26 | import org.apache.spark.sql.util.ArrowUtils
27 |
28 | object SparkSqlUtils {
29 | def toDataFrame(
30 | arrowBatchRDD: JavaRDD[Array[Byte]],
31 | schemaString: String,
32 | session: SparkSession): DataFrame = {
33 | val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
34 | val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
35 | val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
36 | val context = TaskContext.get()
37 | ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, false, context)
38 | }
39 | session.internalCreateDataFrame(rdd.setName("arrow"), schema)
40 | }
41 |
42 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
43 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false)
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/doc/mpi.md:
--------------------------------------------------------------------------------
1 | # MPI on Ray
2 |
3 | RayDP also provides a simple API for running MPI jobs on top of Ray. Currently, three types of MPI are supported: `intel_mpi`, `openmpi` and `mpich`. To use the following API, make sure the chosen type of MPI is installed on each Ray worker node.
4 |
5 | ### API
6 |
7 | ```python
8 | def create_mpi_job(job_name: str,
9 | world_size: int,
10 | num_cpus_per_process: int,
11 | num_processes_per_node: int,
12 | mpi_script_prepare_fn: Callable = None,
13 | timeout: int = 1,
14 | mpi_type: str = "intel_mpi",
15 | placement_group=None,
16 | placement_group_bundle_indexes: List[int] = None) -> MPIJob:
17 | """ Create a MPI Job
18 |
19 | :param job_name: the job name
20 | :param world_size: the world size
21 | :param num_cpus_per_process: num cpus per process; this is used to request resources from Ray
22 | :param num_processes_per_node: num processes per node
23 | :param mpi_script_prepare_fn: a function used to create the mpi script; it will be passed an
24 | MPIJobContext instance. The default script is used if it is not provided.
25 | :param timeout: the timeout used to wait for job creation
26 | :param mpi_type: the mpi type; currently only openmpi, intel_mpi and mpich are supported
27 | :param placement_group: the placement_group used to request mpi resources
28 | :param placement_group_bundle_indexes: the length of this should be equal to
29 | world_size / num_processes_per_node if it is provided.
30 | """
31 | ```
32 |
33 | ### Create a simple MPI Job
34 |
35 | ```python
36 | from raydp.mpi import create_mpi_job, MPIJobContext, WorkerContext
37 |
38 | # Define the MPI job. We create an MPIJob with world_size 4, where each process requires 2 CPUs.
39 | # We set num_processes_per_node to 2, so the processes will be strictly spread across two nodes.
40 | 
41 | # You can also specify a placement group to reserve the resources for the MPI job. num_cpus_per_process
42 | # will be ignored if the placement group is provided, and the size of
43 | # placement_group_bundle_indexes should be equal to world_size // num_processes_per_node.
44 | job = create_mpi_job(job_name="example",
45 | world_size=4,
46 | num_cpus_per_process=2,
47 | num_processes_per_node=2,
48 | timeout=5,
49 | mpi_type="intel_mpi",
50 | placement_group=None,
51 | placement_group_bundle_indexes=None)
52 |
53 | # Start the MPI job. This will start up the MPI processes and connect them to the Ray cluster
54 | job.start()
55 |
56 | # define the MPI task function
57 | def func(context: WorkerContext):
58 | return context.job_id
59 |
60 | # Run the MPI task. This is a blocking operation, and the result is a list of world_size elements.
61 | results = job.run(func)
62 |
63 | # stop the MPI job
64 | job.stop()
65 | ```
66 |
67 | ### Use `with` to auto start/stop the MPIJob
68 | ```python
69 | with create_mpi_job(job_name="example",
70 | world_size=4,
71 | num_cpus_per_process=2,
72 | num_processes_per_node=2,
73 | timeout=5,
74 | mpi_type="intel_mpi") as job:
75 | def f(context: WorkerContext):
76 | return context.job_id
77 | results = job.run(f)
78 | ```
79 |
80 | ### Specify the MPI script and environments
81 |
82 | You can customize the MPI job environment and the MPI script with the `mpi_script_prepare_fn` argument.
83 |
84 | ```python
85 | def script_prepare_fn(context: MPIJobContext):
86 | context.add_env("OMP_NUM_THREADS", "2")
87 | default_script = ["mpirun", "--allow-run-as-root", "--tag-output", "-H",
88 | ",".join(context.hosts), "-N", f"{context.num_procs_per_node}"]
89 | return default_script
90 |
91 | job = create_mpi_job(job_name="example",
92 | world_size=4,
93 | num_cpus_per_process=2,
94 | num_processes_per_node=2,
95 | timeout=5,
96 | mpi_type="intel_mpi",
97 | mpi_script_prepare_fn=script_prepare_fn)
98 | ```
99 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM rayproject/ray:latest@sha256:c864e37f4ce516ff49425f69cac5503a51e84c333d30928416714a2c3da55b43
2 |
3 | ARG HTTP_PROXY
4 | ARG HTTPS_PROXY
5 |
6 | # set http_proxy & https_proxy
7 | ENV http_proxy=${HTTP_PROXY}
8 | ENV https_proxy=${HTTPS_PROXY}
9 |
10 | # install java, create workdir and install raydp
11 | # You can change raydp to raydp-nightly if you want to try the master branch code
12 | RUN sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get update -y \
13 | && sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get install -y openjdk-8-jdk \
14 | && sudo mkdir /raydp \
15 | && sudo chown -R ray /raydp \
16 | && $HOME/anaconda3/bin/pip --no-cache-dir install raydp
17 |
18 | WORKDIR /raydp
19 |
20 | # unset http_proxy & https_proxy
21 | ENV http_proxy=
22 | ENV https_proxy=
23 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | # Running RayDP on k8s cluster
2 |
3 | ## Build docker image
4 | Build the docker image for use in K8S with the following command; this will create an image tagged `oap-project/raydp:latest`
5 | ```shell
6 | # under ${RAYDP_HOME}/docker
7 | ./build-docker.sh
8 | ```
9 |
10 | You can install our nightly build with `pip install raydp --pre` or `pip install raydp-nightly`. To install raydp-nightly in the image, modify the following code in the `Dockerfile`:
11 | ```Dockerfile
12 | RUN sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get update -y \
13 | && sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get install -y openjdk-8-jdk \
14 | && sudo mkdir /raydp \
15 | && sudo chown -R ray /raydp \
16 | && $HOME/anaconda3/bin/pip --no-cache-dir install raydp-nightly
17 | ```
18 |
19 | Meanwhile, you should also install all of your application's dependencies in the `Dockerfile`. If suitable, you can change the base image to `ray-ml`:
20 | ```Dockerfile
21 | FROM rayproject/ray-ml:latest
22 | ```
23 |
24 | Then, you can push the built image to a registry or distribute it to the k8s worker nodes.
25 |
26 | ## Deploy ray cluster with Helm
27 | You need to create a Helm chart first. To start with, check out this [example ray cluster Helm chart](https://github.com/ray-project/kuberay/tree/master/helm-chart/ray-cluster). You can clone this repo and copy this directory, then modify `values.yaml` to use the previously built image.
28 |
29 | ```yaml
30 | image:
31 | repository: oap-project/raydp
32 | tag: latest
33 | pullPolicy: IfNotPresent
34 | ```
35 |
36 | You can also change other fields in this file to specify the number of workers, etc.
37 |
38 | Then, deploy the KubeRay operator first; please refer to [here](https://docs.ray.io/en/latest/cluster/kubernetes/getting-started.html#kuberay-quickstart) for instructions. You can now deploy a Ray cluster with RayDP installed via `helm install ray-cluster PATH_to_CHART`.
39 |
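As a rough sketch, assuming the example chart was copied to a local `./ray-cluster` directory (a hypothetical path) and its `values.yaml` was updated as above:

```shell
# install the Ray cluster chart after the KubeRay operator is running
helm install ray-cluster ./ray-cluster
```
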
40 | ## Access the cluster
41 | Check [here](https://docs.ray.io/en/master/cluster/kubernetes/getting-started.html#running-applications-on-a-ray-cluster) to see how to run applications on the cluster you just deployed.
42 |
43 | ## Legacy
44 | If you are using Ray versions before 2.0, you can try this command.
45 | ```shell
46 | ray up ${RAYDP_HOME}/docker/legacy.yaml
47 | ```
--------------------------------------------------------------------------------
/docker/build-docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #
4 | # Licensed to the Apache Software Foundation (ASF) under one or more
5 | # contributor license agreements. See the NOTICE file distributed with
6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0
8 | # (the "License"); you may not use this file except in compliance with
9 | # the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | #
19 |
20 | docker build --build-arg HTTP_PROXY=${http_proxy} \
21 | --build-arg HTTPS_PROXY=${https_proxy} \
22 | -t oap-project/raydp:latest .
23 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # RayDP Examples
2 | Here are a few examples showing how RayDP works together with other libraries, such as PyTorch, Tensorflow, XGBoost and Horovod.
3 |
4 | In order to run these examples, you may need to install the corresponding dependencies. For installation guides, please refer to their homepages. Notice that [xgboost_ray](https://github.com/ray-project/xgboost_ray) is needed to run the xgboost example. In addition, if you are running the examples in a Ray cluster, all nodes should have the dependencies installed.
5 |
6 | ## NYC Taxi Fare Prediction Dataset
7 | We have a few examples which use this dataset.
8 | You can run our examples right away after you clone our repo, because we include a small example dataset generated randomly using `examples/random_nyctaxi.py`. The generated dataset just demonstrates that our examples work; the trained models might not be meaningful.
9 |
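If you want to regenerate the fake dataset with a different size, a minimal invocation looks like the following (the `--num-records` flag defaults to 2000, and the script writes `fake_nyctaxi.csv` next to itself):

```shell
# under ${RAYDP_HOME}/examples
python random_nyctaxi.py --num-records 2000
```
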
10 | The original dataset can be downloaded [here](https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data). After you download it, please modify the variable `NYC_TRAIN_CSV` in `data_process.py` and point it to where `train.csv` is saved.
11 |
12 | ## Horovod
13 | To run the example, please install horovod via `pip install horovod[pytorch, ray]`. In addition, `HOROVOD_WITH_PYTORCH` and `HOROVOD_WITH_GLOO` should be set to `1` before running pip. Notice that macOS users need to first install `libuv` via `brew install libuv`. Please refer to [here](https://horovod.readthedocs.io/en/stable/install_include.html) for details.
14 |
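For example, an install command that sets the build flags mentioned above might look like:

```shell
# build Horovod with the PyTorch and Gloo backends enabled
HOROVOD_WITH_PYTORCH=1 HOROVOD_WITH_GLOO=1 pip install 'horovod[pytorch,ray]'
```
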
15 | When running `horovod_nyctaxi.py`, do not use `horovodrun`. Check [here](https://horovod.readthedocs.io/en/stable/ray_include.html) for more information.
16 |
17 | ## RaySGD Example
18 | In the RaySGD example, we demonstrate how to use our `MLDataset` API. After we use Spark to transform the dataset, we call `RayMLDataset.from_spark` to write the Spark DataFrames into the Ray object store in Apache Arrow format. We then convert the data to a `pandas` DataFrame, ideally zero-copy. Finally, it can be consumed by any framework that supports the `numpy` format, such as PyTorch or Tensorflow. `MLDataset` is partitioned, or sharded, just like Spark DataFrames, and the two are not required to have the same number of partitions. However, the number of shards of `MLDataset` should be the same as the number of workers of `TorchTrainer` or `TFTrainer`, so that each worker is mapped to a shard.
19 |
--------------------------------------------------------------------------------
/examples/pytorch_nyctaxi.py:
--------------------------------------------------------------------------------
1 | import ray
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import raydp
7 | from raydp.torch import TorchEstimator
8 | from raydp.utils import random_split
9 |
10 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV
11 | from typing import List, Dict
12 |
13 | # First, you need to init or connect to a Ray cluster.
14 | # Note that you should set include_java to True.
15 | # For more config info on Ray, please refer to the Ray doc:
16 | # https://docs.ray.io/en/latest/package-ref.html
17 | # ray.init(address="auto")
18 | ray.init(address="local", num_cpus=4)
19 |
20 | # After initializing the Ray cluster, you can use the RayDP API to get a Spark session
21 | app_name = "NYC Taxi Fare Prediction with RayDP"
22 | num_executors = 1
23 | cores_per_executor = 1
24 | memory_per_executor = "500M"
25 | spark = raydp.init_spark(app_name, num_executors, cores_per_executor, memory_per_executor)
26 |
27 | # Then you can write code as if you were using Spark
28 | # The dataset can be downloaded from:
29 | # https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data
30 | # Here we just use a subset of the training data
31 | data = spark.read.format("csv").option("header", "true") \
32 | .option("inferSchema", "true") \
33 | .load(NYC_TRAIN_CSV)
34 | # Set spark timezone for processing datetime
35 | spark.conf.set("spark.sql.session.timeZone", "UTC")
36 | # Transform the dataset
37 | data = nyc_taxi_preprocess(data)
38 | # Split data into train_dataset and test_dataset
39 | train_df, test_df = random_split(data, [0.9, 0.1], 0)
40 | features = [field.name for field in list(train_df.schema) if field.name != "fare_amount"]
41 | # Define a neural network model
42 | class NYC_Model(nn.Module):
43 | def __init__(self, cols):
44 | super().__init__()
45 | self.fc1 = nn.Linear(cols, 256)
46 | self.fc2 = nn.Linear(256, 128)
47 | self.fc3 = nn.Linear(128, 64)
48 | self.fc4 = nn.Linear(64, 16)
49 | self.fc5 = nn.Linear(16, 1)
50 | self.bn1 = nn.BatchNorm1d(256)
51 | self.bn2 = nn.BatchNorm1d(128)
52 | self.bn3 = nn.BatchNorm1d(64)
53 | self.bn4 = nn.BatchNorm1d(16)
54 |
55 | def forward(self, x):
56 | x = F.relu(self.fc1(x))
57 | x = self.bn1(x)
58 | x = F.relu(self.fc2(x))
59 | x = self.bn2(x)
60 | x = F.relu(self.fc3(x))
61 | x = self.bn3(x)
62 | x = F.relu(self.fc4(x))
63 | x = self.bn4(x)
64 | x = self.fc5(x)
65 | return x
66 |
67 | nyc_model = NYC_Model(len(features))
68 | criterion = nn.SmoothL1Loss()
69 | optimizer = torch.optim.Adam(nyc_model.parameters(), lr=0.001)
70 | # Create a distributed estimator based on the raydp api
71 | estimator = TorchEstimator(num_workers=1, model=nyc_model, optimizer=optimizer, loss=criterion,
72 | feature_columns=features, feature_types=torch.float,
73 | label_column="fare_amount", label_type=torch.float,
74 | batch_size=64, num_epochs=30,
75 | metrics_name = ["MeanAbsoluteError", "MeanSquaredError"],
76 | use_ccl=False)
77 | # Train the model
78 | estimator.fit_on_spark(train_df, test_df)
79 | # Get the trained model
80 | model = estimator.get_model()
81 | # shutdown raydp and ray
82 | raydp.stop_spark()
83 | ray.shutdown()
84 |
--------------------------------------------------------------------------------
/examples/random_nyctaxi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import numpy as np
5 | import pandas as pd
6 |
7 | base_date = np.datetime64("2010-01-01 00:00:00")
8 |
9 | parser = argparse.ArgumentParser(description="Random NYC Taxi Generator")
10 | parser.add_argument(
11 | "--num-records",
12 | type=int,
13 | default=2000,
14 | metavar="N",
15 | help="number of records to generate (default: 2000)")
16 |
17 | args = parser.parse_args()
18 |
19 | N = args.num_records
20 |
21 | fare_amount = np.random.uniform(3.0, 50.0, size=N)
22 | pick_long = np.random.uniform(-74.2, -73.8, size=N)
23 | pick_lat = np.random.uniform(40.7, 40.8, size=N)
24 | drop_long = np.random.uniform(-74.2, -73.8, size=N)
25 | drop_lat = np.random.uniform(40.7, 40.8, size=N)
26 | passenger_count = np.random.randint(1, 5, size=N)
27 | date = np.random.randint(0, 157680000, size=N) + base_date
28 | date = np.array([t.item().strftime("%Y-%m-%d %H:%M:%S UTC") for t in date])
29 | key = ["fake_key"] * N
30 | df = pd.DataFrame({
31 | "key": key,
32 | "fare_amount":fare_amount,
33 | "pickup_datetime": date,
34 | "pickup_longitude": pick_long,
35 | "pickup_latitude": pick_lat,
36 | "dropoff_longitude": drop_long,
37 | "dropoff_latitude": drop_lat,
38 | "passenger_count": passenger_count
39 | })
40 | csv_path = os.path.dirname(os.path.realpath(__file__)) + "/fake_nyctaxi.csv"
41 | df.to_csv(csv_path, index=False)
42 |
--------------------------------------------------------------------------------
/examples/raydp-submit.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname
2 | import sys
3 | import json
4 | import subprocess
5 | import ray
6 | import pyspark
7 |
8 | ray.init(address="auto")
9 | node = ray.worker.global_worker.node
10 | options = {}
11 | options["ray"] = {}
12 | options["ray"]["run-mode"] = "CLUSTER"
13 | options["ray"]["node-ip"] = node.node_ip_address
14 | options["ray"]["address"] = node.address
15 | options["ray"]["session-dir"] = node.get_session_dir_path()
16 |
17 | ray.shutdown()
18 | conf_path = dirname(__file__) + "/ray.conf"
19 | with open(conf_path, "w") as f:
20 | json.dump(options, f)
21 | command = ["bin/raydp-submit", "--ray-conf", conf_path]
22 | command += ["--conf", "spark.executor.cores=1"]
23 | command += ["--conf", "spark.executor.instances=1"]
24 | command += ["--conf", "spark.executor.memory=500m"]
25 | example_path = dirname(pyspark.__file__)
26 | # run SparkPi as example
27 | command.append(example_path + "/examples/src/main/python/pi.py")
28 | sys.exit(subprocess.run(command, check=True).returncode)
29 |
--------------------------------------------------------------------------------
/examples/tensorflow_nyctaxi.py:
--------------------------------------------------------------------------------
1 | import ray
2 | from tensorflow import keras
3 | from tensorflow.keras.callbacks import Callback
4 |
5 | import raydp
6 | from raydp.tf import TFEstimator
7 | from raydp.utils import random_split
8 |
9 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV
10 | from typing import List, Dict
11 | # First, you need to init or connect to a Ray cluster.
12 | # Note that you should set include_java to True.
13 | # For more config info on Ray, please refer to the Ray doc:
14 | # https://docs.ray.io/en/latest/package-ref.html
15 | # ray.init(address="auto")
16 | ray.init(address="local", num_cpus=6)
17 |
18 | # After initializing the Ray cluster, you can use the RayDP API to get a Spark session
19 | app_name = "NYC Taxi Fare Prediction with RayDP"
20 | num_executors = 1
21 | cores_per_executor = 1
22 | memory_per_executor = "500M"
23 | spark = raydp.init_spark(app_name, num_executors, cores_per_executor, memory_per_executor)
24 |
25 | # Then you can write code as if you were using Spark
26 | # The dataset can be downloaded from:
27 | # https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data
28 | # Here we just use a subset of the training data
29 | data = spark.read.format("csv").option("header", "true") \
30 | .option("inferSchema", "true") \
31 | .load(NYC_TRAIN_CSV)
32 | # Set spark timezone for processing datetime
33 | spark.conf.set("spark.sql.session.timeZone", "UTC")
34 | # Transform the dataset
35 | data = nyc_taxi_preprocess(data)
36 | data = data.cache()
37 | # Split data into train_dataset and test_dataset
38 | train_df, test_df = random_split(data, [0.9, 0.1], 0)
39 | features = [field.name for field in list(train_df.schema) if field.name != "fare_amount"]
40 |
41 | # Define the keras model
42 | model = keras.Sequential(
43 | [
44 | keras.layers.InputLayer(input_shape=(len(features),)),
45 | keras.layers.Flatten(),
46 | keras.layers.Dense(256, activation="relu"),
47 | keras.layers.BatchNormalization(),
48 | keras.layers.Dense(128, activation="relu"),
49 | keras.layers.BatchNormalization(),
50 | keras.layers.Dense(64, activation="relu"),
51 | keras.layers.BatchNormalization(),
52 | keras.layers.Dense(32, activation="relu"),
53 | keras.layers.BatchNormalization(),
54 | keras.layers.Dense(16, activation="relu"),
55 | keras.layers.BatchNormalization(),
56 | keras.layers.Dense(1),
57 | ]
58 | )
59 |
60 | class PrintingCallback(Callback):
61 | def handle_result(self, results: List[Dict], **info):
62 | print(results)
63 |
64 | # Define the optimizer and loss function
65 | # Then create the tensorflow estimator provided by Raydp
66 | adam = keras.optimizers.Adam(learning_rate=0.001)
67 | loss = keras.losses.MeanSquaredError()
68 | estimator = TFEstimator(num_workers=2, model=model, optimizer=adam, loss=loss,
69 | merge_feature_columns=True, metrics=["mae"],
70 | feature_columns=features, label_columns="fare_amount",
71 | batch_size=256, num_epochs=10, callbacks=[PrintingCallback()])
72 |
73 | # Train the model
74 | estimator.fit_on_spark(train_df, test_df)
75 | # Get the model
76 | model = estimator.get_model()
77 | # shutdown raydp and ray
78 | raydp.stop_spark()
79 | ray.shutdown()
80 |
--------------------------------------------------------------------------------
/examples/xgboost_ray_nyctaxi.py:
--------------------------------------------------------------------------------
1 | import ray
2 | import numpy as np
3 | # XGBoost on ray is needed to run this example.
4 | # Please refer to https://docs.ray.io/en/latest/xgboost-ray.html to install it.
5 | from xgboost_ray import RayDMatrix, train, RayParams
6 | import raydp
7 | from raydp.utils import random_split
8 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV
9 |
10 | # connect to ray cluster
11 | # ray.init(address="auto")
12 | ray.init(address="local", num_cpus=4)
13 | # After ray.init, you can use the raydp api to get a spark session
14 | app_name = "NYC Taxi Fare Prediction with RayDP"
15 | num_executors = 1
16 | cores_per_executor = 1
17 | memory_per_executor = "500M"
18 | spark = raydp.init_spark(app_name, num_executors, cores_per_executor, memory_per_executor)
19 | data = spark.read.format("csv").option("header", "true") \
20 | .option("inferSchema", "true") \
21 | .load(NYC_TRAIN_CSV)
22 | # Set spark timezone for processing datetime
23 | spark.conf.set("spark.sql.session.timeZone", "UTC")
24 | # Transform the dataset
25 | data = nyc_taxi_preprocess(data)
26 | # Split data into train_dataset and test_dataset
27 | train_df, test_df = random_split(data, [0.9, 0.1], 0)
28 | # Convert spark dataframe into ray dataset
29 | train_dataset = ray.data.from_spark(train_df)
30 | test_dataset = ray.data.from_spark(test_df)
31 | # Then convert them into DMatrix used by xgboost
32 | dtrain = RayDMatrix(train_dataset, label="fare_amount")
33 | dtest = RayDMatrix(test_dataset, label="fare_amount")
34 | # Configure the XGBoost model
35 | config = {
36 | "tree_method": "hist",
37 | "eval_metric": ["logloss", "error"],
38 | }
39 | evals_result = {}
40 | # Train the model
41 | bst = train(
42 | config,
43 | dtrain,
44 | evals=[(dtest, "eval")],
45 | evals_result=evals_result,
46 | ray_params=RayParams(max_actor_restarts=1, num_actors=1, cpus_per_actor=1),
47 | num_boost_round=10)
48 | # print evaluation stats
49 | print("Final validation error: {:.4f}".format(
50 | evals_result["eval"]["error"][-1]))
51 | raydp.stop_spark()
52 | ray.shutdown()
53 |
--------------------------------------------------------------------------------
/python/MANIFEST.in:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | include README.md
19 | recursive-include raydp/jars *.jar
20 | global-exclude *.py[cod] __pycache__ .DS_Store
--------------------------------------------------------------------------------
/python/raydp/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from raydp.context import init_spark, stop_spark
19 |
20 | __version__ = "1.7.0.dev0"
21 |
22 | __all__ = ["init_spark", "stop_spark"]
23 |
--------------------------------------------------------------------------------
/python/raydp/estimator.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from abc import ABC, abstractmethod
19 | from typing import Any, NoReturn, Optional
20 |
21 |
22 |
23 | class EstimatorInterface(ABC):
24 | """
25 | A scikit-learn like API.
26 | """
27 |
28 | @abstractmethod
29 | def fit(self,
30 | train_ds,
31 | evaluate_ds = None) -> NoReturn:
32 | """Train or evaluate the model.
33 |
34 | :param train_ds: the model will train on the MLDataset
35 | :param evaluate_ds: if this is provided, the model will evaluate on the MLDataset
36 | """
37 |
38 | @abstractmethod
39 | def get_model(self) -> Any:
40 | """Get the trained model
41 |
42 | :return the model
43 | """
44 |
--------------------------------------------------------------------------------
/python/raydp/mpi/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 |
19 | from typing import Callable, List
20 |
21 | from .mpi_job import MPIJob, MPIType, IntelMPIJob, OpenMPIJob, MPICHJob, MPIJobContext
22 | from .mpi_worker import WorkerContext
23 |
24 |
25 | def _get_mpi_type(mpi_type: str) -> MPIType:
26 | if mpi_type.strip().lower() == "openmpi":
27 | return MPIType.OPEN_MPI
28 | elif mpi_type.strip().lower() == "intel_mpi":
29 | return MPIType.INTEL_MPI
30 | elif mpi_type.strip().lower() == "mpich":
31 | return MPIType.MPICH
32 | else:
33 | return None
34 |
35 |
36 | def create_mpi_job(job_name: str,
37 | world_size: int,
38 | num_cpus_per_process: int,
39 | num_processes_per_node: int,
40 | mpi_script_prepare_fn: Callable = None,
41 | timeout: int = 1,
42 | mpi_type: str = "intel_mpi",
43 | placement_group=None,
44 | placement_group_bundle_indexes: List[int] = None) -> MPIJob:
45 | """Create a MPI Job
46 |
47 | :param job_name: the job name
48 | :param world_size: the world size
49 | :param num_cpus_per_process: num cpus per process; this is used to request resources from Ray
50 | :param num_processes_per_node: num processes per node
51 | :param mpi_script_prepare_fn: a function used to create the mpi script; it will be passed an
52 | MPIJobContext instance. The default script is used if it is not provided.
53 | :param timeout: the timeout used to wait for job creation
54 | :param mpi_type: the mpi type; currently only openmpi, intel_mpi and mpich are supported
55 | :param placement_group: the placement_group used to request mpi resources
56 | :param placement_group_bundle_indexes: the length of this should be equal to
57 | world_size / num_processes_per_node if it is provided.
58 | """
59 | mpi_type = _get_mpi_type(mpi_type)
60 | if mpi_type == MPIType.OPEN_MPI:
61 | return OpenMPIJob(mpi_type=MPIType.OPEN_MPI,
62 | job_name=job_name,
63 | world_size=world_size,
64 | num_cpus_per_process=num_cpus_per_process,
65 | num_processes_per_node=num_processes_per_node,
66 | mpi_script_prepare_fn=mpi_script_prepare_fn,
67 | timeout=timeout,
68 | placement_group=placement_group,
69 | placement_group_bundle_indexes=placement_group_bundle_indexes)
70 | elif mpi_type == MPIType.INTEL_MPI:
71 | return IntelMPIJob(mpi_type=MPIType.INTEL_MPI,
72 | job_name=job_name,
73 | world_size=world_size,
74 | num_cpus_per_process=num_cpus_per_process,
75 | num_processes_per_node=num_processes_per_node,
76 | mpi_script_prepare_fn=mpi_script_prepare_fn,
77 | timeout=timeout,
78 | placement_group=placement_group,
79 | placement_group_bundle_indexes=placement_group_bundle_indexes)
80 | elif mpi_type == MPIType.MPICH:
81 | return MPICHJob(mpi_type=MPIType.MPICH,
82 | job_name=job_name,
83 | world_size=world_size,
84 | num_cpus_per_process=num_cpus_per_process,
85 | num_processes_per_node=num_processes_per_node,
86 | mpi_script_prepare_fn=mpi_script_prepare_fn,
87 | timeout=timeout,
88 | placement_group=placement_group,
89 | placement_group_bundle_indexes=placement_group_bundle_indexes)
90 | else:
91 | raise Exception(f"MPI type: {mpi_type} not supported now")
92 |
93 |
94 | __all__ = ["create_mpi_job", "MPIJobContext", "WorkerContext"]
95 |
--------------------------------------------------------------------------------
/python/raydp/mpi/constants.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from os import path
19 |
20 | MPI_TYPE = "raydp_mpi_type"
21 | MPI_JOB_ID = "raydp_mpi_job_id"
22 | MPI_DRIVER_HOST = "raydp_mpi_driver_host"
23 | MPI_DRIVER_PORT = "raydp_mpi_driver_port"
24 |
25 | MAXIMUM_WAIT_TIME_OUT = "raydp_maximum_wait_time_out"
26 |
27 | _current_dir = path.dirname(path.realpath(__file__))
28 | MPI_MAIN_CLASS_PATH = path.join(_current_dir, "mpi_worker.py")
29 |
--------------------------------------------------------------------------------
/python/raydp/mpi/network/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import sys
19 | import os
20 |
21 | dir_path = os.path.dirname(os.path.realpath(__file__))
22 | sys.path.append(str(dir_path))
23 |
--------------------------------------------------------------------------------
/python/raydp/mpi/network/network.proto:
--------------------------------------------------------------------------------
1 | //
2 | // Licensed to the Apache Software Foundation (ASF) under one or more
3 | // contributor license agreements. See the NOTICE file distributed with
4 | // this work for additional information regarding copyright ownership.
5 | // The ASF licenses this file to You under the Apache License, Version 2.0
6 | // (the "License"); you may not use this file except in compliance with
7 | // the License. You may obtain a copy of the License at
8 | //
9 | // http://www.apache.org/licenses/LICENSE-2.0
10 | //
11 | // Unless required by applicable law or agreed to in writing, software
12 | // distributed under the License is distributed on an "AS IS" BASIS,
13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | // See the License for the specific language governing permissions and
15 | // limitations under the License.
16 | //
17 |
18 | // Syntax version
19 | syntax = "proto3";
20 |
21 | // Driver Service definition
22 | service DriverService {
23 | // register the worker process with the driver, which is used to tell that the worker has started up
24 | rpc RegisterWorker (RegisterWorkerRequest) returns (RegisterWorkerReply);
25 | // register the worker service host and port
26 | rpc RegisterWorkerService (RegisterWorkerServiceRequest) returns (RegisterWorkerServiceReply);
27 | // register the function result
28 | rpc RegisterFuncResult (FunctionResult) returns (Empty);
29 | }
30 |
31 | // Worker Service
32 | service WorkerService {
33 | // run the given function
34 | rpc RunFunction (Function) returns (Empty);
35 | // stop the worker service
36 | rpc Stop (Empty) returns (Empty);
37 | }
38 |
39 | message RegisterWorkerRequest {
40 | // the job id
41 | string job_id = 1;
42 | // the world rank id
43 | int32 world_rank = 2;
44 | }
45 |
46 | message RegisterWorkerReply {
47 | // all node addresses, used to determine the current node ip address
48 | repeated string node_addresses = 3;
49 | }
50 |
51 | message RegisterWorkerServiceRequest {
52 | // the world rank
53 | int32 world_rank = 1;
54 | // the worker service listening ip
55 | string worker_ip = 2;
56 | // the worker service listening port
57 | int32 worker_port = 3;
58 | }
59 |
60 | message RegisterWorkerServiceReply {
61 | // the ray redis address
62 | string ray_address = 1;
63 | // the ray redis password
64 | string redis_password = 2;
65 | }
66 |
67 | message Function {
68 | // the function id
69 | int32 func_id = 1;
70 | // the serialized python function
71 | bytes func = 2;
72 | }
73 |
74 | message FunctionResult {
75 | int32 world_rank = 1;
76 | // the function id
77 | int32 func_id = 2;
78 | // the function results
79 | bytes result = 3;
80 | }
81 |
82 | message Empty {
83 | }
84 |
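A hedged illustration of the registration flow described by the services above: each worker first calls DriverService.RegisterWorker to announce itself, then reports its own WorkerService address. The sketch assumes the stubs were generated from this file with grpcio-tools as network_pb2 / network_pb2_grpc; the helper name is hypothetical.

    # Hypothetical worker-side sketch; not RayDP's actual mpi_worker implementation.
    import grpc

    import network_pb2
    import network_pb2_grpc


    def register_with_driver(driver_host: str, driver_port: int,
                             job_id: str, world_rank: int):
        # Open a channel to the driver service defined in network.proto.
        channel = grpc.insecure_channel(f"{driver_host}:{driver_port}")
        stub = network_pb2_grpc.DriverServiceStub(channel)
        # Tell the driver that this worker process has started up.
        reply = stub.RegisterWorker(
            network_pb2.RegisterWorkerRequest(job_id=job_id, world_rank=world_rank))
        # The reply carries all node addresses, used to pick this node's IP.
        return reply.node_addresses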
--------------------------------------------------------------------------------
/python/raydp/mpi/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import os
19 | import select
20 | import subprocess
21 | import threading
22 | import time
23 | from typing import List
24 |
25 | import grpc
26 | import netifaces
27 |
28 |
29 | class StoppableThread(threading.Thread):
30 |
31 | def __init__(self, group=None, target=None, name=None,
32 | args=(), kwargs=None, *, daemon=None):
33 | super().__init__(group, target, name, args, kwargs, daemon=daemon)
34 | self._stop_event = threading.Event()
35 |
36 | def stop(self):
37 | self._stop_event.set()
38 |
39 | def stopped(self):
40 | return self._stop_event.is_set()
41 |
42 |
43 | def run_cmd(cmd: str, env, failed_callback):
44 | # pylint: disable=R1732
45 | proc = subprocess.Popen(cmd,
46 | shell=True,
47 | stdin=subprocess.DEVNULL,
48 | stdout=subprocess.PIPE,
49 | stderr=subprocess.PIPE,
50 | env=env,
51 | start_new_session=True)
52 |
53 | def check_failed():
54 | # check whether the process has finished
55 | while not threading.current_thread().stopped():
56 | ret_code = proc.poll()
57 | if ret_code:
58 | failed_callback()
59 | raise Exception(f"mpirun failed: {ret_code}")
60 |
61 | if ret_code == 0:
62 | break
63 |
64 | time.sleep(1)
65 |
66 | check_thread = StoppableThread(target=check_failed)
67 |
68 | def redirect_stream(streams):
69 | while not threading.current_thread().stopped() and streams:
70 | readable, _, _ = select.select(streams, [], [], 0.5)
71 | for stream in readable:
72 | if not stream:
73 | continue
74 | line = stream.readline()
75 | if not line:
76 | streams.remove(stream)
77 | else:
78 | print(line.decode().strip("\n"))
79 |
80 | redirect_thread = StoppableThread(target=redirect_stream, args=([proc.stdout, proc.stderr],))
81 | check_thread.start()
82 | redirect_thread.start()
83 | return proc, check_thread, redirect_thread
84 |
85 |
86 | def create_insecure_channel(address,
87 | options=None,
88 | compression=None):
89 | """Disable the http proxy when create channel"""
90 | # disable http proxy
91 | if options is not None:
92 | need_add = True
93 | for k, v in options:
94 | if k == "grpc.enable_http_proxy":
95 | need_add = False
96 | break
97 | if need_add:
98 | options = (*options, ("grpc.enable_http_proxy", 0))
99 | else:
100 | options = (("grpc.enable_http_proxy", 0),)
101 |
102 | return grpc.insecure_channel(
103 | address, options, compression)
104 |
105 |
106 | def get_environ_value(key: str) -> str:
107 | """Get value from environ, raise exception if the key not existed"""
108 | assert key in os.environ, f"{key} should be set in the environ"
109 | return os.environ[key]
110 |
111 |
112 | def get_node_ip_address(node_addresses: List[str]) -> str:
113 | found = None
114 | for interface in netifaces.interfaces():
115 | addrs = netifaces.ifaddresses(interface)
116 | addresses = addrs.get(netifaces.AF_INET, None)
117 | if not addresses:
118 | continue
119 | for inet_addr in addresses:
120 | address = inet_addr.get("addr", None)
121 | if address in node_addresses:
122 | found = address
123 | return found
124 |
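For reference, a minimal usage sketch of run_cmd and the StoppableThread monitors it returns; the mpirun command and the failure callback below are placeholders, not the actual RayDP driver logic.

    # Hypothetical sketch: launch a command, stream its output, watch for failure.
    import os

    from raydp.mpi.utils import run_cmd


    def on_mpirun_failed():
        print("mpirun exited with a non-zero code")

    proc, check_thread, redirect_thread = run_cmd(
        "mpirun -n 2 hostname", env=os.environ.copy(),
        failed_callback=on_mpirun_failed)
    proc.wait()
    # Ask the monitoring threads to stop once the process has finished.
    check_thread.stop()
    redirect_thread.stop()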
--------------------------------------------------------------------------------
/python/raydp/ray_cluster_resources.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from typing import Dict, List
19 |
20 | import time
21 | import ray
22 | from ray.ray_constants import MEMORY_RESOURCE_UNIT_BYTES
23 |
24 |
25 | class ClusterResources:
26 | # TODO: make this configurable
27 | refresh_interval = 0.1
28 | latest_refresh_time = time.time() - refresh_interval
29 | node_to_resources = {}
30 | item_keys_mapping = {"num_cpus": "CPU"}
31 | label_name = "__ray_spark_node_label"
32 |
33 | @classmethod
34 | def total_alive_nodes(cls):
35 | cls._refresh()
36 | return len(cls.node_to_resources)
37 |
38 | @classmethod
39 | def satisfy(cls, request: Dict[str, float]) -> List[str]:
40 | cls._refresh()
41 | satisfied = []
42 | for host_name, resources in cls.node_to_resources.items():
43 | if cls._compare_two_dict(resources, request):
44 | satisfied.append(resources[cls.label_name])
45 |
46 | return satisfied
47 |
48 | @classmethod
49 | def _refresh(cls):
50 | if (time.time() - cls.latest_refresh_time) < cls.refresh_interval:
51 | return
52 |
53 | for node in ray.nodes():
54 | if node["Alive"]:
55 | host_name = node["NodeManagerHostname"]
56 | resources = node["Resources"]
57 | for key in resources:
58 | if key.startswith("node:"):
59 | resources[cls.label_name] = key
60 | break
61 | assert cls.label_name in resources,\
62 | f"{resources} should contain a resource likes: 'node:10.0.0.131': 1.0"
63 | cls.node_to_resources[host_name] = resources
64 | cls.latest_refresh_time = time.time()
65 |
66 | @classmethod
67 | def _compare_two_dict(cls, available: Dict[str, float], request: Dict[str, float]) -> bool:
68 | for k, v in request.items():
69 | k = cls.item_keys_mapping.get(k, k)
70 | if k not in available:
71 | return False
72 |
73 | if k == "memory":
74 | v = int(v / MEMORY_RESOURCE_UNIT_BYTES)
75 |
76 | if available[k] < v:
77 | return False
78 |
79 | return True
80 |
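As an illustration of the request format satisfy() expects (keys follow item_keys_mapping, memory is given in bytes), a hypothetical caller could probe the cluster like this; the numbers are arbitrary.

    # Hypothetical sketch: ask which alive nodes can host 2 CPUs and 1 GiB of memory.
    import ray

    from raydp.ray_cluster_resources import ClusterResources

    ray.init()  # connect to (or start) a Ray cluster first
    request = {"num_cpus": 2, "memory": 1 * 1024 * 1024 * 1024}
    candidate_node_labels = ClusterResources.satisfy(request)
    if not candidate_node_labels:
        print("no alive node can satisfy the request")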
--------------------------------------------------------------------------------
/python/raydp/services.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from abc import ABC, abstractmethod
19 | from typing import Any, Dict, NoReturn
20 |
21 |
22 | class Cluster(ABC):
23 | """
24 | This is the base class for all concrete clusters, such as SparkCluster or FlinkCluster.
25 | :param master_resources_requirement: The resource requirements for the master service.
26 | """
27 | def __init__(self, master_resources_requirement):
28 | # the master node lives on the same node as the ray driver, and we can specify its
29 | # resource limits separately, so we don't count it here.
30 | self._num_nodes = 0
31 |
32 | @abstractmethod
33 | def _set_up_master(self,
34 | resources: Dict[str, float],
35 | kwargs: Dict[Any, Any]):
36 | """
37 | Subclasses should implement this to set up the master node.
38 | """
39 |
40 | def add_worker(self,
41 | resources_requirement: Dict[str, float],
42 | **kwargs: Dict[Any, Any]):
43 | """
44 | Add one worker to the cluster.
45 | :param resources_requirement: The resource requirements for the worker service.
46 | """
47 | try:
48 | self._set_up_worker(resources_requirement, kwargs)
49 | except:
50 | self.stop()
51 | raise
52 |
53 | @abstractmethod
54 | def _set_up_worker(self,
55 | resources: Dict[str, float],
56 | kwargs: Dict[str, str]):
57 | """
58 | Subclasses should implement this to set up a worker node.
59 | """
60 |
61 | @abstractmethod
62 | def get_cluster_url(self) -> str:
63 | """
64 | Return the cluster url, e.g. spark://master-node:7077
65 | """
66 |
67 | @abstractmethod
68 | def stop(self):
69 | """
70 | Stop cluster
71 | """
72 |
73 |
74 | class ClusterMaster(ABC):
75 |
76 | @abstractmethod
77 | def start_up(self) -> NoReturn:
78 | pass
79 |
80 | @abstractmethod
81 | def get_master_url(self) -> str:
82 | pass
83 |
84 | @abstractmethod
85 | def get_host(self) -> str:
86 | pass
87 |
88 | @abstractmethod
89 | def stop(self):
90 | pass
91 |
92 |
93 | class ClusterWorker(ABC):
94 |
95 | @abstractmethod
96 | def start_up(self) -> str:
97 | """
98 | :return: the error message, or None if start-up succeeded
99 | """
100 |
101 | @abstractmethod
102 | def get_host(self) -> str:
103 | pass
104 |
105 | @abstractmethod
106 | def stop(self):
107 | pass
108 |
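For clarity, a minimal hypothetical subclass shows how the Cluster template methods fit together; the real implementation is SparkCluster in raydp/spark/ray_cluster.py, and this sketch only illustrates the contract.

    # Hypothetical sketch of a Cluster subclass; illustrates which hooks a concrete
    # cluster must provide, not RayDP's actual SparkCluster.
    from typing import Any, Dict

    from raydp.services import Cluster


    class DummyCluster(Cluster):
        def __init__(self, master_resources_requirement: Dict[str, float]):
            super().__init__(master_resources_requirement)
            self._master_url = None
            self._set_up_master(master_resources_requirement, {})

        def _set_up_master(self, resources: Dict[str, float], kwargs: Dict[Any, Any]):
            # A real cluster would start the master service here.
            self._master_url = "dummy://localhost:7077"

        def _set_up_worker(self, resources: Dict[str, float], kwargs: Dict[str, str]):
            # A real cluster would start a worker and connect it to the master.
            self._num_nodes += 1

        def get_cluster_url(self) -> str:
            return self._master_url

        def stop(self):
            self._num_nodes = 0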
--------------------------------------------------------------------------------
/python/raydp/spark/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from .dataset import PartitionObjectsOwner, \
19 | get_raydp_master_owner, \
20 | spark_dataframe_to_ray_dataset, \
21 | ray_dataset_to_spark_dataframe, \
22 | from_spark_recoverable
23 | from .interfaces import SparkEstimatorInterface
24 | from .ray_cluster import SparkCluster
25 |
26 | __all__ = [
27 | "SparkCluster",
28 | "SparkEstimatorInterface",
29 | "PartitionObjectsOwner",
30 | "get_raydp_master_owner",
31 | "spark_dataframe_to_ray_dataset",
32 | "ray_dataset_to_spark_dataframe",
33 | "from_spark_recoverable"
34 | ]
35 |
--------------------------------------------------------------------------------
/python/raydp/spark/interfaces.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from typing import NoReturn
19 | from typing import Optional, Union
20 |
21 | from raydp.utils import convert_to_spark
22 |
23 | DF = Union["pyspark.sql.DataFrame", "pyspark.pandas.DataFrame"]
24 | OPTIONAL_DF = Union[Optional["pyspark.sql.DataFrame"], Optional["pyspark.pandas.DataFrame"]]
25 |
26 |
27 | class SparkEstimatorInterface:
28 | def _check_and_convert(self, df):
29 | train_df, _ = convert_to_spark(df)
30 | return train_df
31 |
32 | def fit_on_spark(self,
33 | train_df: DF,
34 | evaluate_df: OPTIONAL_DF = None) -> NoReturn:
35 | """Fit and evaluate the model on the Spark or koalas DataFrame.
36 |
37 | :param train_df the DataFrame which the model will train on.
38 | :param evaluate_df the optional DataFrame which the model evaluate on it
39 | """
40 |
--------------------------------------------------------------------------------
/python/raydp/spark/parallel_iterator_worker.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | from typing import Any
19 |
20 | from ray.util.iter import ParallelIteratorWorker
21 |
22 |
23 | class ParallelIteratorWorkerWithLen(ParallelIteratorWorker):
24 | def __init__(self, item_generator: Any, repeat: bool, num_records: int):
25 | super().__init__(item_generator, repeat)
26 | self.num_records = num_records
27 |
28 | def __len__(self):
29 | return self.num_records
30 |
--------------------------------------------------------------------------------
/python/raydp/tests/conftest.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import logging
19 | import subprocess
20 | import time
21 |
22 | import pyspark
23 | import pytest
24 | import ray
25 | import raydp
26 | from pyspark.sql import SparkSession
27 |
28 |
29 | def quiet_logger():
30 | py4j_logger = logging.getLogger("py4j")
31 | py4j_logger.setLevel(logging.WARNING)
32 |
33 | koalas_logger = logging.getLogger("koalas")
34 | koalas_logger.setLevel(logging.WARNING)
35 |
36 |
37 | @pytest.fixture(scope="function")
38 | def spark_session(request):
39 | spark = SparkSession.builder.master("local[2]").appName("RayDP test").getOrCreate()
40 | request.addfinalizer(lambda: spark.stop())
41 | quiet_logger()
42 | return spark
43 |
44 |
45 | @pytest.fixture(scope="function", params=["local", "ray://localhost:10001"])
46 | def ray_cluster(request):
47 | ray.shutdown()
48 | if request.param == "local":
49 | ray.init(address="local", num_cpus=6, include_dashboard=False)
50 | else:
51 | ray.init(address=request.param)
52 | request.addfinalizer(lambda: ray.shutdown())
53 |
54 |
55 | @pytest.fixture(scope="function", params=["local", "ray://localhost:10001"])
56 | def spark_on_ray_small(request):
57 | ray.shutdown()
58 | if request.param == "local":
59 | ray.init(address="local", num_cpus=6, include_dashboard=False)
60 | else:
61 | ray.init(address=request.param)
62 | node_ip = ray.util.get_node_ip_address()
63 | spark = raydp.init_spark("test", 1, 1, "500M", configs={
64 | "spark.driver.host": node_ip,
65 | "spark.driver.bindAddress": node_ip
66 | })
67 |
68 | def stop_all():
69 | spark.stop()
70 | raydp.stop_spark()
71 | time.sleep(5)
72 | ray.shutdown()
73 |
74 | request.addfinalizer(stop_all)
75 | return spark
76 |
77 |
78 | @pytest.fixture(scope="function", params=["local", "ray://localhost:10001"])
79 | def spark_on_ray_2_executors(request):
80 | ray.shutdown()
81 | if request.param == "local":
82 | ray.init(address="local", num_cpus=6, include_dashboard=False)
83 | else:
84 | ray.init(address=request.param)
85 | node_ip = ray.util.get_node_ip_address()
86 | spark = raydp.init_spark("test", 2, 1, "500M", configs={
87 | "spark.driver.host": node_ip,
88 | "spark.driver.bindAddress": node_ip
89 | })
90 |
91 | def stop_all():
92 | spark.stop()
93 | raydp.stop_spark()
94 | time.sleep(5)
95 | ray.shutdown()
96 |
97 | request.addfinalizer(stop_all)
98 | return spark
99 |
100 | @pytest.fixture(scope='session')
101 | def custom_spark_dir(tmp_path_factory) -> str:
102 | working_dir = tmp_path_factory.mktemp("spark").as_posix()
103 |
104 | # Leave the if-branches verbose in case the distribution name changes in a future release.
105 | # Please make sure the version here is not the most recent release, so the file is available
106 | # in the archive download. The latest release's download URL (https://dlcdn.apache.org/spark/*)
107 | # will be moved to the archive when the next release comes out, which would break the test.
108 | if pyspark.__version__ == "3.2.1":
109 | spark_distribution = 'spark-3.2.1-bin-hadoop3.2'
110 | elif pyspark.__version__ == "3.1.3":
111 | spark_distribution = 'spark-3.1.3-bin-hadoop3.2'
112 | else:
113 | raise Exception(f"Unsupported Spark version {pyspark.__version__}.")
114 |
115 | file_extension = 'tgz'
116 | spark_distribution_file = f"{working_dir}/{spark_distribution}.{file_extension}"
117 |
118 | import wget
119 |
120 | wget.download(
121 | f"https://archive.apache.org/dist/spark/spark-{pyspark.__version__}/{spark_distribution}.{file_extension}",
122 | spark_distribution_file)
123 | subprocess.check_output(['tar', 'xzvf', spark_distribution_file, '--directory', working_dir])
124 | return f"{working_dir}/{spark_distribution}"
125 |
--------------------------------------------------------------------------------
/python/raydp/tests/test_tf.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import pyspark
19 | import pytest
20 | import os
21 | import sys
22 | import shutil
23 |
24 | import tensorflow as tf
25 | import tensorflow.keras as keras
26 |
27 | from pyspark.sql.functions import rand
28 |
29 | from raydp.tf import TFEstimator
30 | from raydp.utils import random_split
31 |
32 | @pytest.mark.parametrize("use_fs_directory", [True, False])
33 | def test_tf_estimator(spark_on_ray_small, use_fs_directory):
34 | spark = spark_on_ray_small
35 |
36 | # ---------------- data process with Spark ------------
37 | # calculate y = 3 * x + 4
38 | df: pyspark.sql.DataFrame = spark.range(0, 100000)
39 | df = df.withColumn("x", rand() * 100) # add x column
40 | df = df.withColumn("y", df.x * 3 + rand() + 4) # add y column
41 | df = df.select(df.x, df.y)
42 |
43 | train_df, test_df = random_split(df, [0.7, 0.3])
44 |
45 | # create model
46 | model = keras.Sequential(
47 | [
48 | keras.layers.InputLayer(input_shape=()),
49 | # Add feature dimension, expanding (batch_size,) to (batch_size, 1).
50 | keras.layers.Flatten(),
51 | keras.layers.Dense(1),
52 | ]
53 | )
54 |
55 | optimizer = keras.optimizers.Adam(0.01)
56 | loss = keras.losses.MeanSquaredError()
57 |
58 | estimator = TFEstimator(num_workers=2,
59 | model=model,
60 | optimizer=optimizer,
61 | loss=loss,
62 | metrics=["accuracy", "mse"],
63 | feature_columns="x",
64 | label_columns="y",
65 | batch_size=1000,
66 | num_epochs=2,
67 | use_gpu=False)
68 |
69 | if use_fs_directory:
70 | dir = os.path.dirname(__file__) + "/test_tf"
71 | uri = "file://" + dir
72 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri)
73 | else:
74 | estimator.fit_on_spark(train_df, test_df)
75 | model = estimator.get_model()
76 | result = model(tf.constant([0, 0]))
77 | assert result.shape == (2, 1)
78 | if use_fs_directory:
79 | shutil.rmtree(dir)
80 |
81 | if __name__ == "__main__":
82 | # sys.exit(pytest.main(["-v", __file__]))
83 | import ray, raydp
84 | ray.init()
85 | spark = raydp.init_spark('a', 6, 1, '500m')
86 | test_tf_estimator(spark, False)
87 |
--------------------------------------------------------------------------------
/python/raydp/tests/test_torch.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import pytest
19 | import os
20 | import sys
21 | import shutil
22 | import torch
23 |
24 | # https://spark.apache.org/docs/latest/api/python/migration_guide/koalas_to_pyspark.html
25 | # import databricks.koalas as ks
26 | import pyspark.pandas as ps
27 |
28 | from raydp.torch import TorchEstimator
29 | from raydp.utils import random_split
30 |
31 | @pytest.mark.parametrize("use_fs_directory", [True, False])
32 | def test_torch_estimator(spark_on_ray_small, use_fs_directory):
33 | # ---------------- data process with koalas ------------
34 | spark = spark_on_ray_small
35 |
36 | # calculate z = 3 * x + 4 * y + 5
37 | df: ps.DataFrame = ps.range(0, 100000)
38 | df["x"] = df["id"] + 100
39 | df["y"] = df["id"] + 1000
40 | df["z"] = df["x"] * 3 + df["y"] * 4 + 5
41 | df = df.astype("float")
42 |
43 | train_df, test_df = random_split(df, [0.7, 0.3])
44 |
45 | # ---------------- ray sgd -------------------------
46 | # create the model
47 | class LinearModel(torch.nn.Module):
48 | def __init__(self):
49 | super(LinearModel, self).__init__()
50 | self.linear = torch.nn.Linear(2, 1)
51 |
52 | def forward(self, x):
53 | return self.linear(x)
54 |
55 | model = LinearModel()
56 | # create the optimizer
57 | optimizer = torch.optim.Adam(model.parameters())
58 | # create the loss
59 | loss = torch.nn.MSELoss()
60 | # create lr_scheduler
61 |
62 | def lr_scheduler_creator(optimizer, config):
63 | return torch.optim.lr_scheduler.MultiStepLR(
64 | optimizer, milestones=[150, 250, 350], gamma=0.1)
65 |
66 | # create the estimator
67 | estimator = TorchEstimator(num_workers=2,
68 | model=model,
69 | optimizer=optimizer,
70 | loss=loss,
71 | lr_scheduler_creator=lr_scheduler_creator,
72 | feature_columns=["x", "y"],
73 | feature_types=torch.float,
74 | label_column="z",
75 | label_type=torch.float,
76 | batch_size=1000,
77 | num_epochs=2,
78 | use_gpu=False)
79 |
80 | # train the model
81 | if use_fs_directory:
82 | dir = os.path.dirname(__file__) + "/test_torch"
83 | uri = "file://" + dir
84 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri)
85 | else:
86 | estimator.fit_on_spark(train_df, test_df)
87 | model = estimator.get_model()
88 | result = model(torch.Tensor([[0, 0], [1, 1]]))
89 | assert result.shape == (2, 1)
90 | if use_fs_directory:
91 | shutil.rmtree(dir)
92 |
93 |
94 | if __name__ == "__main__":
95 | sys.exit(pytest.main(["-v", __file__]))
96 |
--------------------------------------------------------------------------------
/python/raydp/tests/test_torch_sequential.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import pytest
19 | import sys
20 | import torch
21 | import raydp
22 | from raydp.torch import TorchEstimator
23 |
24 | def test_torch_estimator(spark_on_ray_small):
25 | # prepare the data
26 | customers = [
27 | (1,'James', 21, 6),
28 | (2, "Liz", 25, 8),
29 | (3, "John", 31, 6),
30 | (4, "Jennifer", 45, 7),
31 | (5, "Robert", 41, 5),
32 | (6, "Sandra", 45, 8)
33 | ]
34 | df = spark_on_ray_small.createDataFrame(customers, ["cID", "name", "age", "grade"])
35 |
36 | # create the model
37 | model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2,1))
38 | optimizer = torch.optim.Adam(model.parameters())
39 | loss = torch.nn.MSELoss()
40 |
41 | # configure the estimator
42 | estimator = TorchEstimator(
43 | model = model,
44 | optimizer = optimizer,
45 | loss = loss,
46 | num_workers = 3,
47 | num_epochs = 5,
48 | feature_columns = ["age"],
49 | feature_types = torch.float,
50 | label_column = "grade",
51 | label_type = torch.float,
52 | batch_size = 1
53 | )
54 | estimator.fit_on_spark(df)
55 |
56 | if __name__ == "__main__":
57 | sys.exit(pytest.main(["-v", __file__]))
--------------------------------------------------------------------------------
/python/raydp/tests/test_xgboost.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import os
19 | import sys
20 | import shutil
21 | import platform
22 | import pytest
23 | import pyspark
24 | import numpy as np
25 | from pyspark.sql.functions import rand
26 |
27 | from raydp.xgboost import XGBoostEstimator
28 | from raydp.utils import random_split
29 |
30 | @pytest.mark.parametrize("use_fs_directory", [True, False])
31 | def test_xgb_estimator(spark_on_ray_small, use_fs_directory):
32 | if platform.system() == "Darwin":
33 | pytest.skip("Skip xgboost test on MacOS")
34 | spark = spark_on_ray_small
35 |
36 | # calculate z = 3 * x + 4 * y + 5
37 | df: pyspark.sql.DataFrame = spark.range(0, 100000)
38 | df = df.withColumn("x", rand() * 100) # add x column
39 | df = df.withColumn("y", rand() * 1000) # ad y column
40 | df = df.withColumn("z", df.x * 3 + df.y * 4 + rand() + 5) # ad z column
41 | df = df.select(df.x, df.y, df.z)
42 |
43 | train_df, test_df = random_split(df, [0.7, 0.3])
44 | params = {}
45 | estimator = XGBoostEstimator(params, "z", resources_per_worker={"CPU": 1})
46 | if use_fs_directory:
47 | dir = os.path.dirname(os.path.realpath(__file__)) + "/test_xgboost"
48 | uri = "file://" + dir
49 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri)
50 | else:
51 | estimator.fit_on_spark(train_df, test_df)
52 | print(estimator.get_model().inplace_predict(np.asarray([[1,2]])))
53 | if use_fs_directory:
54 | shutil.rmtree(dir)
55 |
56 | if __name__ == '__main__':
57 | import ray, raydp
58 | ray.init(address="auto")
59 | spark = raydp.init_spark('test_xgboost', 1, 1, '500m')
60 | test_xgb_estimator(spark, True)
--------------------------------------------------------------------------------
/python/raydp/tf/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from .estimator import TFEstimator
19 |
20 | __all__ = ["TFEstimator"]
21 |
--------------------------------------------------------------------------------
/python/raydp/torch/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from .estimator import TorchEstimator
19 |
20 | __all__ = ["TorchEstimator"]
21 |
--------------------------------------------------------------------------------
/python/raydp/torch/config.py:
--------------------------------------------------------------------------------
1 | from ray.train.torch.config import _TorchBackend
2 | from ray.train.torch.config import TorchConfig as RayTorchConfig
3 | from ray.train._internal.worker_group import WorkerGroup
4 | from dataclasses import dataclass
5 | import sys
6 | # The package importlib_metadata is in a different place, depending on the Python version.
7 | if sys.version_info < (3, 8):
8 | import importlib_metadata
9 | else:
10 | import importlib.metadata as importlib_metadata
11 |
12 | @dataclass
13 | class TorchConfig(RayTorchConfig):
14 |
15 | @property
16 | def backend_cls(self):
17 | return EnableCCLBackend
18 |
19 | def libs_import():
20 | """try to import IPEX and oneCCL.
21 | """
22 | try:
23 | import intel_extension_for_pytorch
24 | except ImportError as ipex_not_exist:
25 | raise ImportError(
26 | "Please install intel_extension_for_pytorch"
27 | ) from ipex_not_exist
28 | try:
29 | ccl_version = importlib_metadata.version("oneccl_bind_pt")
30 | if ccl_version >= "1.12":
31 | # pylint: disable-all
32 | import oneccl_bindings_for_pytorch
33 | else:
34 | import torch_ccl
35 | except ImportError as ccl_not_exist:
36 | raise ImportError(
37 | "Please install torch-ccl"
38 | ) from ccl_not_exist
39 |
40 | class EnableCCLBackend(_TorchBackend):
41 |
42 | def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
43 | for i in range(len(worker_group)):
44 | worker_group.execute_single_async(i, libs_import)
45 | super().on_start(worker_group, backend_config)
46 |
--------------------------------------------------------------------------------
/python/raydp/torch/torch_metrics.py:
--------------------------------------------------------------------------------
1 | import sys
2 | module = sys.modules[__name__]
3 |
4 | def try_import_torchmetrics():
5 | """Tries importing torchmetrics and returns the module (or None).
6 | Returns:
7 | torchmetrics modules.
8 | """
9 | try:
10 | # pylint: disable=import-outside-toplevel
11 | import torchmetrics
12 |
13 | return torchmetrics
14 | except ImportError as torchmetrics_not_exist:
15 | raise ImportError(
16 | "Could not import torchmetrics! Raydp TorchEstimator requires "
17 | "you to install torchmetrics: "
18 | "`pip install torchmetrics`."
19 | ) from torchmetrics_not_exist
20 |
21 | class TorchMetric():
22 | def __init__(self, metrics_name, metrics_config):
23 | torchmetrics = try_import_torchmetrics()
24 | self._metrics_name = metrics_name
25 | self._metrics_func = {}
26 | if self._metrics_name is not None:
27 | assert isinstance(metrics_name, list), "metrics_name must be a list"
28 | for metric in self._metrics_name:
29 | if isinstance(metric, torchmetrics.Metric):
30 | self._metrics_func[metric.__class__.__name__] = metric
31 | elif isinstance(metric, str) and hasattr(torchmetrics, metric):
32 | if metrics_config is not None and metrics_config[metric] is not None:
33 | self._metrics_func[metric] = getattr(torchmetrics, metric)(
34 | **metrics_config[metric])
35 | else:
36 | self._metrics_func[metric] = getattr(torchmetrics, metric)()
37 | else:
38 | raise Exception(
39 | "Unsupported parameter, we only support list of "
40 | "torchmetrics.Metric instances or arr of torchmetrics.")
41 |
42 | def update(self, preds, targets):
43 | for metric in self._metrics_func:
44 | self._metrics_func[metric].update(preds, targets)
45 |
46 | def compute(self):
47 | epoch_res = {}
48 | for metric in self._metrics_func:
49 | epoch_res[metric] = self._metrics_func[metric].compute().item()
50 |
51 | return epoch_res
52 |
53 | def reset(self):
54 | for metric in self._metrics_func:
55 | self._metrics_func[metric].reset()
56 |
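A short hypothetical usage sketch of TorchMetric, mirroring how an estimator might track metrics per epoch; the metric names are illustrative.

    # Hypothetical sketch: wrap two torchmetrics by name, feed predictions per batch,
    # then read and reset the results at the end of an epoch.
    import torch

    from raydp.torch.torch_metrics import TorchMetric

    metrics = TorchMetric(metrics_name=["MeanSquaredError", "MeanAbsoluteError"],
                          metrics_config=None)
    preds = torch.tensor([2.5, 0.0, 2.0])
    targets = torch.tensor([3.0, -0.5, 2.0])
    metrics.update(preds, targets)
    print(metrics.compute())  # e.g. {"MeanSquaredError": ..., "MeanAbsoluteError": ...}
    metrics.reset()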
--------------------------------------------------------------------------------
/python/raydp/torch/torch_ml_dataset.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | import logging
19 | import queue
20 | import threading
21 | from typing import Callable
22 |
23 | import ray
24 | from ray.util.data import MLDataset
25 | from torch.utils.data import IterableDataset
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 |
30 | class TorchMLDataset(IterableDataset):
31 | def __init__(self,
32 | ds: MLDataset,
33 | collate_fn: Callable,
34 | shuffle: bool = False,
35 | shuffle_seed: int = None):
36 | super().__init__()
37 | self.ds = ds
38 | self.collate_fn = collate_fn
39 | self.shuffle = shuffle
40 | self.shuffle_seed = shuffle_seed or 1
41 |
42 | def __iter__(self):
43 | it = self.ds.gather_async(batch_ms=0, num_async=self.ds.num_shards())
44 | it = iter(it)
45 | for pdf in it:
46 | if self.shuffle:
47 | pdf = pdf.sample(frac=1.0, random_state=self.shuffle_seed)
48 | yield self.collate_fn(pdf)
49 |
50 | def __len__(self):
51 | all_actors = []
52 | for actor_set in self.ds.actor_sets:
53 | all_actors.extend(actor_set.actors)
54 | assert len(all_actors) > 0
55 | if "__len__" in dir(all_actors[0]):
56 | # This is a hacky way to get the length of the iterator
57 | num_records = sum([ray.get(actor.__len__.remote()) for actor in all_actors])
58 | else:
59 | logger.warning("The MLDataset has not provide the __len__ method, we will iter all "
60 | "data to count the number of rows. This should be pretty slowly.")
61 | it = self.ds.gather_async(batch_ms=0, num_async=self.ds.num_shards())
62 | it = iter(it)
63 | num_records = 0
64 | for pdf in it:
65 | num_records += pdf.shape[0]
66 | return num_records
67 |
68 |
69 | class PrefetchedDataLoader:
70 | def __init__(self, base_loader, max_size: int = 5):
71 | self.base_loader = base_loader
72 | self.max_size = max_size
73 | self.queue = queue.Queue(maxsize=max_size)
74 | self.fetcher = None
75 | self.fetcher_stop = threading.Event()
76 |
77 | def _setup(self):
78 | if self.fetcher is not None:
79 | self.fetcher_stop.set()
80 | if self.queue is not None and not self.queue.empty():
81 | self.queue.get()
82 | self.queue = queue.Queue(maxsize=self.max_size)
83 | self.fetcher = None
84 | self.fetcher_stop.clear()
85 |
86 | it = iter(self.base_loader)
87 |
88 | def fetch_task():
89 | while not self.fetcher_stop.is_set():
90 | try:
91 | got_data = next(it)
92 | self.queue.put(got_data)
93 | except StopIteration:
94 | self.queue.put(None)
95 | break
96 | except: # pylint: disable=W0707, W0706
97 | raise
98 | self.fetcher = threading.Thread(target=fetch_task)
99 | self.fetcher.start()
100 |
101 | def __iter__(self):
102 | self._setup()
103 | while True:
104 | fetched_data = self.queue.get()
105 | if fetched_data is not None:
106 | yield fetched_data
107 | else:
108 | break
109 |
110 | def __len__(self):
111 | return len(self.base_loader)
112 |
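PrefetchedDataLoader simply overlaps data loading with training by pulling batches into a bounded queue on a background thread. A hedged sketch of wrapping an ordinary torch DataLoader; the dataset here is a stand-in.

    # Hypothetical sketch: prefetch batches so the training loop does not wait on I/O.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from raydp.torch.torch_ml_dataset import PrefetchedDataLoader

    dataset = TensorDataset(torch.randn(1000, 2), torch.randn(1000, 1))
    loader = DataLoader(dataset, batch_size=100)

    prefetched = PrefetchedDataLoader(loader, max_size=5)
    for features, labels in prefetched:
        pass  # a training step would consume the prefetched batch here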
--------------------------------------------------------------------------------
/python/raydp/versions.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import re
19 | import pyspark
20 |
21 |
22 | # log4j1 if spark version <= 3.2, otherwise, log4j2
23 | SPARK_LOG4J_VERSION = "log4j"
24 | SPARK_LOG4J_CONFIG_FILE_NAME_KEY = "log4j.configurationFile"
25 | SPARK_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j-default.properties"
26 | _spark_ver = re.search("\\d+\\.\\d+", pyspark.version.__version__)
27 | if _spark_ver.group(0) > "3.2":
28 | SPARK_LOG4J_VERSION = "log4j2"
29 | SPARK_LOG4J_CONFIG_FILE_NAME_KEY = "log4j2.configurationFile"
30 | SPARK_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j2-default.properties"
31 |
32 | # support ray >= 2.1, they all use log4j2
33 | RAY_LOG4J_VERSION = "log4j2"
34 | RAY_LOG4J_CONFIG_FILE_NAME_KEY = "log4j2.configurationFile"
35 | RAY_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j2.xml"
36 |
--------------------------------------------------------------------------------
/python/raydp/xgboost/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from .estimator import XGBoostEstimator
19 |
20 | __all__ = ["XGBoostEstimator"]
21 |
--------------------------------------------------------------------------------