├── core
├── shims
│ ├── spark322
│ │ ├── src
│ │ │ └── main
│ │ │ │ ├── resources
│ │ │ │ └── META-INF
│ │ │ │ │ └── services
│ │ │ │ │ └── com.intel.raydp.shims.SparkShimProvider
│ │ │ │ └── scala
│ │ │ │ ├── org
│ │ │ │ └── apache
│ │ │ │ │ └── spark
│ │ │ │ │ ├── TaskContextUtils.scala
│ │ │ │ │ ├── sql
│ │ │ │ │ └── SparkSqlUtils.scala
│ │ │ │ │ └── executor
│ │ │ │ │ └── RayDPSpark322ExecutorBackendFactory.scala
│ │ │ │ └── com
│ │ │ │ └── intel
│ │ │ │ └── raydp
│ │ │ │ └── shims
│ │ │ │ ├── SparkShimProvider.scala
│ │ │ │ └── SparkShims.scala
│ │ └── pom.xml
│ ├── spark330
│ │ ├── src
│ │ │ └── main
│ │ │ │ ├── resources
│ │ │ │ └── META-INF
│ │ │ │ │ └── services
│ │ │ │ │ └── com.intel.raydp.shims.SparkShimProvider
│ │ │ │ └── scala
│ │ │ │ ├── org
│ │ │ │ └── apache
│ │ │ │ │ └── spark
│ │ │ │ │ ├── TaskContextUtils.scala
│ │ │ │ │ ├── sql
│ │ │ │ │ └── SparkSqlUtils.scala
│ │ │ │ │ └── executor
│ │ │ │ │ ├── RayCoarseGrainedExecutorBackend.scala
│ │ │ │ │ └── RayDPSpark330ExecutorBackendFactory.scala
│ │ │ │ └── com
│ │ │ │ └── intel
│ │ │ │ └── raydp
│ │ │ │ └── shims
│ │ │ │ ├── SparkShimProvider.scala
│ │ │ │ └── SparkShims.scala
│ │ └── pom.xml
│ ├── spark340
│ │ ├── src
│ │ │ └── main
│ │ │ │ ├── resources
│ │ │ │ └── META-INF
│ │ │ │ │ └── services
│ │ │ │ │ └── com.intel.raydp.shims.SparkShimProvider
│ │ │ │ └── scala
│ │ │ │ ├── org
│ │ │ │ └── apache
│ │ │ │ │ └── spark
│ │ │ │ │ ├── TaskContextUtils.scala
│ │ │ │ │ ├── executor
│ │ │ │ │ ├── RayCoarseGrainedExecutorBackend.scala
│ │ │ │ │ └── RayDPSpark340ExecutorBackendFactory.scala
│ │ │ │ │ └── sql
│ │ │ │ │ └── SparkSqlUtils.scala
│ │ │ │ └── com
│ │ │ │ └── intel
│ │ │ │ └── raydp
│ │ │ │ └── shims
│ │ │ │ ├── SparkShimProvider.scala
│ │ │ │ └── SparkShims.scala
│ │ └── pom.xml
│ ├── spark350
│ │ ├── src
│ │ │ └── main
│ │ │ │ ├── resources
│ │ │ │ └── META-INF
│ │ │ │ │ └── services
│ │ │ │ │ └── com.intel.raydp.shims.SparkShimProvider
│ │ │ │ └── scala
│ │ │ │ ├── org
│ │ │ │ └── apache
│ │ │ │ │ └── spark
│ │ │ │ │ ├── TaskContextUtils.scala
│ │ │ │ │ ├── executor
│ │ │ │ │ ├── RayCoarseGrainedExecutorBackend.scala
│ │ │ │ │ └── RayDPSpark350ExecutorBackendFactory.scala
│ │ │ │ │ └── sql
│ │ │ │ │ └── SparkSqlUtils.scala
│ │ │ │ └── com
│ │ │ │ └── intel
│ │ │ │ └── raydp
│ │ │ │ └── shims
│ │ │ │ ├── SparkShimProvider.scala
│ │ │ │ └── SparkShims.scala
│ │ └── pom.xml
│ ├── common
│ │ └── src
│ │ │ └── main
│ │ │ └── scala
│ │ │ ├── com
│ │ │ └── intel
│ │ │ │ └── raydp
│ │ │ │ └── shims
│ │ │ │ ├── SparkShimProvider.scala
│ │ │ │ ├── SparkShims.scala
│ │ │ │ └── SparkShimLoader.scala
│ │ │ └── org
│ │ │ └── apache
│ │ │ └── spark
│ │ │ └── executor
│ │ │ └── RayDPExecutorBackendFactory.scala
│ └── pom.xml
├── raydp-main
│ └── src
│ │ └── main
│ │ ├── resources
│ │ └── META-INF
│ │ │ └── services
│ │ │ └── org.apache.spark.scheduler.ExternalClusterManager
│ │ ├── scala
│ │ └── org
│ │ │ └── apache
│ │ │ └── spark
│ │ │ ├── RayDPException.scala
│ │ │ ├── deploy
│ │ │ └── raydp
│ │ │ │ ├── ApplicationState.scala
│ │ │ │ ├── Messages.scala
│ │ │ │ ├── RayExternalShuffleService.scala
│ │ │ │ ├── ApplicationDescription.scala
│ │ │ │ ├── RayDPDriverAgent.scala
│ │ │ │ ├── AppMasterEntryPoint.scala
│ │ │ │ └── AppMasterJavaBridge.scala
│ │ │ ├── scheduler
│ │ │ └── cluster
│ │ │ │ └── raydp
│ │ │ │ └── RayClusterManager.scala
│ │ │ ├── rdd
│ │ │ ├── RayObjectRefRDD.scala
│ │ │ └── RayDatasetRDD.scala
│ │ │ └── sql
│ │ │ └── raydp
│ │ │ └── ObjectStoreReader.scala
│ │ ├── java
│ │ └── org
│ │ │ └── apache
│ │ │ └── spark
│ │ │ ├── deploy
│ │ │ └── raydp
│ │ │ │ ├── ExternalShuffleServiceUtils.java
│ │ │ │ └── RayAppMasterUtils.java
│ │ │ └── raydp
│ │ │ ├── RayDPUtils.java
│ │ │ └── RayExecutorUtils.java
│ │ └── test
│ │ └── org
│ │ └── apache
│ │ └── spark
│ │ └── scheduler
│ │ └── cluster
│ │ └── raydp
│ │ └── TestRayCoarseGrainedSchedulerBackend.java
├── agent
│ ├── pom.xml
│ └── src
│ │ └── main
│ │ └── java
│ │ └── org
│ │ └── apache
│ │ └── spark
│ │ └── raydp
│ │ └── Agent.java
└── javastyle-suppressions.xml
├── examples
├── test_pyfile.py
├── random_nyctaxi.py
├── README.md
├── raydp-submit.py
├── test_pyfiles_main.py
├── test_raydp_submit_pyfiles.py
├── xgboost_ray_nyctaxi.py
├── tensorflow_nyctaxi.py
└── pytorch_nyctaxi.py
├── .gitignore
├── SECURITY.md
├── docker
├── Dockerfile
├── build-docker.sh
└── README.md
├── python
├── raydp
│ ├── tf
│ │ └── __init__.py
│ ├── torch
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── torch_metrics.py
│ │ └── torch_ml_dataset.py
│ ├── xgboost
│ │ └── __init__.py
│ ├── __init__.py
│ ├── mpi
│ │ ├── network
│ │ │ ├── __init__.py
│ │ │ ├── network.proto
│ │ │ └── network_pb2.py
│ │ ├── constants.py
│ │ ├── utils.py
│ │ └── __init__.py
│ ├── spark
│ │ ├── parallel_iterator_worker.py
│ │ ├── __init__.py
│ │ └── interfaces.py
│ ├── estimator.py
│ ├── versions.py
│ ├── tests
│ │ ├── test_torch_sequential.py
│ │ ├── test_spark_master_memory.py
│ │ ├── test_xgboost.py
│ │ ├── test_tf.py
│ │ └── test_torch.py
│ ├── ray_cluster_resources.py
│ └── services.py
└── MANIFEST.in
├── .github
└── workflows
│ ├── pypi.yml
│ └── pypi_release.yml
└── doc
└── mpi.md
/core/shims/spark322/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider:
--------------------------------------------------------------------------------
1 | com.intel.raydp.shims.spark322.SparkShimProvider
2 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider:
--------------------------------------------------------------------------------
1 | com.intel.raydp.shims.spark330.SparkShimProvider
2 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider:
--------------------------------------------------------------------------------
1 | com.intel.raydp.shims.spark340.SparkShimProvider
2 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider:
--------------------------------------------------------------------------------
1 | com.intel.raydp.shims.spark350.SparkShimProvider
2 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager:
--------------------------------------------------------------------------------
1 | org.apache.spark.scheduler.cluster.raydp.RayClusterManager
--------------------------------------------------------------------------------
/examples/test_pyfile.py:
--------------------------------------------------------------------------------
1 | """
2 | Helper module for testing raydp-submit --py-files functionality.
3 | This module defines functions that will be used by the main.py script.
4 | """
5 |
6 | def compute_sum(numbers):
7 | """Compute sum to verify module functionality in Spark."""
8 | return sum(numbers)
9 |
--------------------------------------------------------------------------------
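
The tree above lists test_pyfiles_main.py as the driver that consumes this helper, but that script is not reproduced in this excerpt. A minimal, hypothetical sketch of such a driver, assuming it is launched with raydp-submit ... --py-files test_pyfile.py so the module is shipped to the executors:

from pyspark.sql import SparkSession

from test_pyfile import compute_sum  # importable because it was shipped via --py-files

if __name__ == "__main__":
    spark = SparkSession.builder.appName("pyfiles-check").getOrCreate()
    # Running compute_sum inside tasks proves the module reached the executors.
    result = spark.sparkContext.parallelize([[1, 2, 3], [4, 5]]).map(compute_sum).collect()
    assert result == [6, 9]
    spark.stop()
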
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.iml
3 |
4 | __pycache__/
5 | build/
6 | dist/
7 | *.egg-info/
8 | *.eggs/
9 |
10 | dev/.tmp_dir/
11 | target/
12 | *.jar
13 |
14 | .DS_Store
15 |
16 | .vscode
17 | examples/.ipynb_checkpoints/
18 | .python-version
19 |
20 | # Vim temp files
21 | *.swp
22 | *.swo
23 | *.parquet
24 | *.crc
25 | _SUCCESS
26 |
27 | .metals/
28 | .bloop/
29 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 |
3 | ## Report a Vulnerability
4 |
5 | Please report security issues or vulnerabilities to the [Intel® Security Center].
6 |
7 | For more information on how Intel® works to resolve security issues, see
8 | [Vulnerability Handling Guidelines].
9 |
10 | [Intel® Security Center]:https://www.intel.com/security
11 |
12 | [Vulnerability Handling Guidelines]:https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html
13 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM rayproject/ray:latest@sha256:c864e37f4ce516ff49425f69cac5503a51e84c333d30928416714a2c3da55b43
2 |
3 | ARG HTTP_PROXY
4 | ARG HTTPS_PROXY
5 |
6 | # set http_proxy & https_proxy
7 | ENV http_proxy=${HTTP_PROXY}
8 | ENV https_proxy=${HTTPS_PROXY}
9 |
10 | # install java, create workdir and install raydp
11 | # You can change raydp to raydp-nightly if you want to try the code from the master branch
12 | RUN sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get update -y \
13 | && sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get install -y openjdk-8-jdk \
14 | && sudo mkdir /raydp \
15 | && sudo chown -R ray /raydp \
16 | && $HOME/anaconda3/bin/pip --no-cache-dir install raydp
17 |
18 | WORKDIR /raydp
19 |
20 | # unset http_proxy & https_proxy
21 | ENV http_proxy=
22 | ENV https_proxy=
23 |
--------------------------------------------------------------------------------
/python/raydp/tf/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from .estimator import TFEstimator
19 |
20 | __all__ = ["TFEstimator"]
21 |
--------------------------------------------------------------------------------
/python/raydp/torch/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from .estimator import TorchEstimator
19 |
20 | __all__ = ["TorchEstimator"]
21 |
--------------------------------------------------------------------------------
/python/raydp/xgboost/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from .estimator import XGBoostEstimator
19 |
20 | __all__ = ["XGBoostEstimator"]
21 |
--------------------------------------------------------------------------------
/python/MANIFEST.in:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | include README.md
19 | recursive-include raydp/jars *.jar
20 | global-exclude *.py[cod] __pycache__ .DS_Store
--------------------------------------------------------------------------------
/python/raydp/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from raydp.context import init_spark, stop_spark
19 |
20 | __version__ = "1.7.0.dev0"
21 |
22 | __all__ = ["init_spark", "stop_spark"]
23 |
--------------------------------------------------------------------------------
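
For context, init_spark and stop_spark are the package's main entry points. A minimal usage sketch, assuming a local Ray instance and the documented init_spark parameters (app_name, num_executors, executor_cores, executor_memory):

import ray
import raydp

ray.init()
# Start a Spark cluster on top of Ray; init_spark returns a regular SparkSession.
spark = raydp.init_spark(app_name="raydp-example",
                         num_executors=2,
                         executor_cores=2,
                         executor_memory="2GB")
print(spark.range(0, 1000).count())
raydp.stop_spark()
ray.shutdown()
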
/python/raydp/mpi/network/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import sys
19 | import os
20 |
21 | dir_path = os.path.dirname(os.path.realpath(__file__))
22 | sys.path.append(str(dir_path))
23 |
--------------------------------------------------------------------------------
/docker/build-docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #
4 | # Licensed to the Apache Software Foundation (ASF) under one or more
5 | # contributor license agreements. See the NOTICE file distributed with
6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0
8 | # (the "License"); you may not use this file except in compliance with
9 | # the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | #
19 |
20 | docker build --build-arg HTTP_PROXY=${http_proxy} \
21 | --build-arg HTTPS_PROXY=${https_proxy} \
22 | -t oap-project/raydp:latest .
23 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/RayDPException.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark
19 |
20 | class RayDPException(message: String, cause: Throwable)
21 | extends SparkException(message, cause) {
22 | def this(message: String) = this(message, null)
23 | }
24 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationState.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | object ApplicationState extends Enumeration {
21 |
22 | type ApplicationState = Value
23 |
24 | val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value
25 | }
26 |
--------------------------------------------------------------------------------
/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims
19 |
20 | /**
21 | * Provider interface for matching and retrieving the Shims of a specific Spark version
22 | */
23 | trait SparkShimProvider {
24 | def matches(version:String): Boolean
25 | def createShim: SparkShims
26 | }
27 |
--------------------------------------------------------------------------------
/python/raydp/mpi/constants.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from os import path
19 |
20 | MPI_TYPE = "raydp_mpi_type"
21 | MPI_JOB_ID = "raydp_mpi_job_id"
22 | MPI_DRIVER_HOST = "raydp_mpi_driver_host"
23 | MPI_DRIVER_PORT = "raydp_mpi_driver_port"
24 |
25 | MAXIMUM_WAIT_TIME_OUT = "raydp_maximum_wait_time_out"
26 |
27 | _current_dir = path.dirname(path.realpath(__file__))
28 | MPI_MAIN_CLASS_PATH = path.join(_current_dir, "mpi_worker.py")
29 |
--------------------------------------------------------------------------------
/python/raydp/spark/parallel_iterator_worker.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | from typing import Any
19 |
20 | from ray.util.iter import ParallelIteratorWorker
21 |
22 |
23 | class ParallelIteratorWorkerWithLen(ParallelIteratorWorker):
24 | def __init__(self, item_generator: Any, repeat: bool, num_records: int):
25 | super().__init__(item_generator, repeat)
26 | self.num_records = num_records
27 |
28 | def __len__(self):
29 | return self.num_records
30 |
--------------------------------------------------------------------------------
/core/shims/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |   <modelVersion>4.0.0</modelVersion>
6 | 
7 |   <parent>
8 |     <groupId>com.intel</groupId>
9 |     <artifactId>raydp-parent</artifactId>
10 |     <version>1.7.0-SNAPSHOT</version>
11 |     <relativePath>../pom.xml</relativePath>
12 |   </parent>
13 | 
14 |   <artifactId>raydp-shims</artifactId>
15 |   <name>RayDP Shims</name>
16 |   <packaging>pom</packaging>
17 | 
18 |   <modules>
19 |     <module>common</module>
20 |     <module>spark322</module>
21 |     <module>spark330</module>
22 |     <module>spark340</module>
23 |     <module>spark350</module>
24 |   </modules>
25 | 
26 |   <properties>
27 |     <scala.binary.version>2.12</scala.binary.version>
28 |     <scala.plugin.version>4.3.0</scala.plugin.version>
29 |     <spark.version>3.2.2</spark.version>
30 |   </properties>
31 | 
32 |   <build>
33 |     <pluginManagement>
34 |       <plugins>
35 |         <plugin>
36 |           <groupId>net.alchim31.maven</groupId>
37 |           <artifactId>scala-maven-plugin</artifactId>
38 |           <version>${scala.plugin.version}</version>
39 |         </plugin>
40 |       </plugins>
41 |     </pluginManagement>
42 |   </build>
43 | </project>
--------------------------------------------------------------------------------
/core/shims/spark322/src/main/scala/org/apache/spark/TaskContextUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.spark322
19 |
20 | import java.util.Properties
21 |
22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
23 | import org.apache.spark.memory.TaskMemoryManager
24 |
25 | object TaskContextUtils {
26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
27 | new TaskContextImpl(0, 0, partitionId, -1024, 0,
28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/scala/org/apache/spark/TaskContextUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.spark330
19 |
20 | import java.util.Properties
21 |
22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
23 | import org.apache.spark.memory.TaskMemoryManager
24 |
25 | object TaskContextUtils {
26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
27 | new TaskContextImpl(0, 0, partitionId, -1024, 0,
28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/org/apache/spark/TaskContextUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.spark340
19 |
20 | import java.util.Properties
21 |
22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
23 | import org.apache.spark.memory.TaskMemoryManager
24 |
25 | object TaskContextUtils {
26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 0,
28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/org/apache/spark/TaskContextUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.spark350
19 |
20 | import java.util.Properties
21 |
22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
23 | import org.apache.spark.memory.TaskMemoryManager
24 |
25 | object TaskContextUtils {
26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 0,
28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem)
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/examples/random_nyctaxi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import numpy as np
5 | import pandas as pd
6 |
7 | base_date = np.datetime64("2010-01-01 00:00:00")
8 |
9 | parser = argparse.ArgumentParser(description="Random NYC Taxi Generator")
10 | parser.add_argument(
11 | "--num-records",
12 | type=int,
13 | default=2000,
14 | metavar="N",
15 | help="number of records to generate (default: 2000)")
16 |
17 | args = parser.parse_args()
18 |
19 | N = args.num_records
20 |
21 | fare_amount = np.random.uniform(3.0, 50.0, size=N)
22 | pick_long = np.random.uniform(-74.2, -73.8, size=N)
23 | pick_lat = np.random.uniform(40.7, 40.8, size=N)
24 | drop_long = np.random.uniform(-74.2, -73.8, size=N)
25 | drop_lat = np.random.uniform(40.7, 40.8, size=N)
26 | passenger_count = np.random.randint(1, 5, size=N)
27 | date = np.random.randint(0, 157680000, size=N) + base_date
28 | date = np.array([t.item().strftime("%Y-%m-%d %H:%M:%S UTC") for t in date])
29 | key = ["fake_key"] * N
30 | df = pd.DataFrame({
31 | "key": key,
32 | "fare_amount":fare_amount,
33 | "pickup_datetime": date,
34 | "pickup_longitude": pick_long,
35 | "pickup_latitude": pick_lat,
36 | "dropoff_longitude": drop_long,
37 | "dropoff_latitude": drop_lat,
38 | "passenger_count": passenger_count
39 | })
40 | csv_path = os.path.dirname(os.path.realpath(__file__)) + "/fake_nyctaxi.csv"
41 | df.to_csv(csv_path, index=False)
42 |
--------------------------------------------------------------------------------
/python/raydp/spark/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from .dataset import PartitionObjectsOwner, \
19 | get_raydp_master_owner, \
20 | spark_dataframe_to_ray_dataset, \
21 | ray_dataset_to_spark_dataframe, \
22 | from_spark_recoverable
23 | from .interfaces import SparkEstimatorInterface
24 | from .ray_cluster import SparkCluster
25 |
26 | __all__ = [
27 | "SparkCluster",
28 | "SparkEstimatorInterface",
29 | "PartitionObjectsOwner",
30 | "get_raydp_master_owner",
31 | "spark_dataframe_to_ray_dataset",
32 | "ray_dataset_to_spark_dataframe",
33 | "from_spark_recoverable"
34 | ]
35 |
--------------------------------------------------------------------------------
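
spark_dataframe_to_ray_dataset and ray_dataset_to_spark_dataframe bridge Spark DataFrames and Ray Datasets. A sketch of the Spark-to-Ray direction, assuming the function can be called with just the DataFrame (see raydp/spark/dataset.py, not included in this excerpt, for the full signature):

import ray
import raydp
from raydp.spark import spark_dataframe_to_ray_dataset

ray.init()
spark = raydp.init_spark(app_name="bridge-example",
                         num_executors=1,
                         executor_cores=1,
                         executor_memory="1GB")
df = spark.range(0, 100)
ds = spark_dataframe_to_ray_dataset(df)  # Ray Dataset backed by the Spark partitions
print(ds.count())
raydp.stop_spark()
ray.shutdown()
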
/python/raydp/estimator.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from abc import ABC, abstractmethod
19 | from typing import Any, NoReturn, Optional
20 |
21 |
22 |
23 | class EstimatorInterface(ABC):
24 | """
25 | A scikit-learn like API.
26 | """
27 |
28 | @abstractmethod
29 | def fit(self,
30 | train_ds,
31 | evaluate_ds = None) -> NoReturn:
32 | """Train or evaluate the model.
33 |
34 | :param train_ds: the MLDataset that the model will train on
35 | :param evaluate_ds: an optional MLDataset; if provided, the model will also be evaluated on it
36 | """
37 |
38 | @abstractmethod
39 | def get_model(self) -> Any:
40 | """Get the trained model
41 |
42 | :return: the trained model
43 | """
44 |
--------------------------------------------------------------------------------
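
To make the contract concrete, here is an illustrative-only subclass; the real implementations are the framework-specific estimators (raydp.torch, raydp.tf, raydp.xgboost), and for simplicity the "dataset" here is a plain iterable of numbers rather than an MLDataset:

from typing import Any

from raydp.estimator import EstimatorInterface


class MeanEstimator(EstimatorInterface):
    """Toy estimator whose model is simply the mean of the training values."""

    def __init__(self):
        self._mean = None

    def fit(self, train_ds, evaluate_ds=None):
        values = list(train_ds)
        self._mean = sum(values) / len(values)
        if evaluate_ds is not None:
            ev = list(evaluate_ds)
            # Report a simple evaluation metric: mean absolute deviation.
            print(sum(abs(v - self._mean) for v in ev) / len(ev))

    def get_model(self) -> Any:
        return self._mean


est = MeanEstimator()
est.fit([1.0, 2.0, 3.0])
assert est.get_model() == 2.0
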
/core/shims/common/src/main/scala/org/apache/spark/executor/RayDPExecutorBackendFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.rpc.RpcEnv
24 | import org.apache.spark.resource.ResourceProfile
25 |
26 | trait RayDPExecutorBackendFactory {
27 | def createExecutorBackend(
28 | rpcEnv: RpcEnv,
29 | driverUrl: String,
30 | executorId: String,
31 | bindAddress: String,
32 | hostname: String,
33 | cores: Int,
34 | userClassPath: Seq[URL],
35 | env: SparkEnv,
36 | resourcesFileOpt: Option[String],
37 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend
38 | }
39 |
--------------------------------------------------------------------------------
/python/raydp/versions.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import re
19 | import pyspark
20 |
21 |
22 | # Spark uses log4j 1.x if its version is <= 3.2, otherwise log4j 2.x
23 | SPARK_LOG4J_VERSION = "log4j"
24 | SPARK_LOG4J_CONFIG_FILE_NAME_KEY = "log4j.configurationFile"
25 | SPARK_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j-default.properties"
26 | _spark_ver = re.search("\\d+\\.\\d+", pyspark.version.__version__)
27 | if _spark_ver.group(0) > "3.2":
28 | SPARK_LOG4J_VERSION = "log4j2"
29 | SPARK_LOG4J_CONFIG_FILE_NAME_KEY = "log4j2.configurationFile"
30 | SPARK_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j2-default.properties"
31 |
32 | # We support Ray >= 2.1, which always uses log4j2
33 | RAY_LOG4J_VERSION = "log4j2"
34 | RAY_LOG4J_CONFIG_FILE_NAME_KEY = "log4j2.configurationFile"
35 | RAY_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j2.xml"
36 |
--------------------------------------------------------------------------------
/python/raydp/torch/config.py:
--------------------------------------------------------------------------------
1 | from ray.train.torch.config import _TorchBackend
2 | from ray.train.torch.config import TorchConfig as RayTorchConfig
3 | from ray.train._internal.worker_group import WorkerGroup
4 | from dataclasses import dataclass
5 | import sys
6 | # The package importlib_metadata is in a different place, depending on the Python version.
7 | if sys.version_info < (3, 8):
8 | import importlib_metadata
9 | else:
10 | import importlib.metadata as importlib_metadata
11 |
12 | @dataclass
13 | class TorchConfig(RayTorchConfig):
14 |
15 | @property
16 | def backend_cls(self):
17 | return EnableCCLBackend
18 |
19 | def libs_import():
20 | """try to import IPEX and oneCCL.
21 | """
22 | try:
23 | import intel_extension_for_pytorch
24 | except ImportError:
25 | raise ImportError(
26 | "Please install intel_extension_for_pytorch"
27 | )
28 | try:
29 | ccl_version = importlib_metadata.version("oneccl_bind_pt")
30 | if ccl_version >= "1.12":
31 | # pylint: disable-all
32 | import oneccl_bindings_for_pytorch
33 | else:
34 | import torch_ccl
35 | except ImportError as ccl_not_exist:
36 | raise ImportError(
37 | "Please install torch-ccl"
38 | ) from ccl_not_exist
39 |
40 | class EnableCCLBackend(_TorchBackend):
41 |
42 | def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
43 | for i in range(len(worker_group)):
44 | worker_group.execute_single_async(i, libs_import)
45 | super().on_start(worker_group, backend_config)
46 |
--------------------------------------------------------------------------------
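
How this config is meant to be used is only implied here: passing it to Ray Train makes every worker run libs_import before the process group is created. A sketch, assuming a recent Ray release where TorchTrainer accepts a torch_config argument and ScalingConfig is importable from ray.train; the backend name "ccl" is likewise an assumption:

import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

from raydp.torch.config import TorchConfig


def train_loop_per_worker():
    # A real loop would build the model/optimizer and call ray.train.report(...).
    pass


ray.init()
trainer = TorchTrainer(
    train_loop_per_worker,
    torch_config=TorchConfig(backend="ccl"),  # assumed backend name for oneCCL
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
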
/core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.sql.spark330
19 |
20 | import org.apache.arrow.vector.types.pojo.Schema
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
23 | import org.apache.spark.sql.execution.arrow.ArrowConverters
24 | import org.apache.spark.sql.types.StructType
25 | import org.apache.spark.sql.util.ArrowUtils
26 |
27 | object SparkSqlUtils {
28 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
29 | ArrowConverters.toDataFrame(rdd, schema, session)
30 | }
31 |
32 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
33 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.sql.spark322
19 |
20 | import org.apache.arrow.vector.types.pojo.Schema
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
23 | import org.apache.spark.sql.execution.arrow.ArrowConverters
24 | import org.apache.spark.sql.types.StructType
25 | import org.apache.spark.sql.util.ArrowUtils
26 |
27 | object SparkSqlUtils {
28 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = {
29 | ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session))
30 | }
31 |
32 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
33 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/python/raydp/spark/interfaces.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from typing import NoReturn
19 | from typing import Optional, Union
20 |
21 | from raydp.utils import convert_to_spark
22 |
23 | DF = Union["pyspark.sql.DataFrame", "pyspark.pandas.DataFrame"]
24 | OPTIONAL_DF = Union[Optional["pyspark.sql.DataFrame"], Optional["pyspark.pandas.DataFrame"]]
25 |
26 |
27 | class SparkEstimatorInterface:
28 | def _check_and_convert(self, df):
29 | train_df, _ = convert_to_spark(df)
30 | return train_df
31 |
32 | def fit_on_spark(self,
33 | train_df: DF,
34 | evaluate_df: OPTIONAL_DF = None) -> NoReturn:
35 | """Fit and evaluate the model on the Spark or koalas DataFrame.
36 |
37 | :param train_df: the Spark or koalas DataFrame that the model will train on
38 | :param evaluate_df: an optional DataFrame; if provided, the model will also be evaluated on it
39 | """
40 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.resource.ResourceProfile
24 | import org.apache.spark.rpc.RpcEnv
25 |
26 | class RayCoarseGrainedExecutorBackend(
27 | rpcEnv: RpcEnv,
28 | driverUrl: String,
29 | executorId: String,
30 | bindAddress: String,
31 | hostname: String,
32 | cores: Int,
33 | userClassPath: Seq[URL],
34 | env: SparkEnv,
35 | resourcesFileOpt: Option[String],
36 | resourceProfile: ResourceProfile)
37 | extends CoarseGrainedExecutorBackend(
38 | rpcEnv,
39 | driverUrl,
40 | executorId,
41 | bindAddress,
42 | hostname,
43 | cores,
44 | env,
45 | resourcesFileOpt,
46 | resourceProfile) {
47 |
48 | override def getUserClassPath: Seq[URL] = userClassPath
49 |
50 | }
51 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.resource.ResourceProfile
24 | import org.apache.spark.rpc.RpcEnv
25 |
26 | class RayCoarseGrainedExecutorBackend(
27 | rpcEnv: RpcEnv,
28 | driverUrl: String,
29 | executorId: String,
30 | bindAddress: String,
31 | hostname: String,
32 | cores: Int,
33 | userClassPath: Seq[URL],
34 | env: SparkEnv,
35 | resourcesFileOpt: Option[String],
36 | resourceProfile: ResourceProfile)
37 | extends CoarseGrainedExecutorBackend(
38 | rpcEnv,
39 | driverUrl,
40 | executorId,
41 | bindAddress,
42 | hostname,
43 | cores,
44 | env,
45 | resourcesFileOpt,
46 | resourceProfile) {
47 |
48 | override def getUserClassPath: Seq[URL] = userClassPath
49 |
50 | }
51 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.resource.ResourceProfile
24 | import org.apache.spark.rpc.RpcEnv
25 |
26 | class RayCoarseGrainedExecutorBackend(
27 | rpcEnv: RpcEnv,
28 | driverUrl: String,
29 | executorId: String,
30 | bindAddress: String,
31 | hostname: String,
32 | cores: Int,
33 | userClassPath: Seq[URL],
34 | env: SparkEnv,
35 | resourcesFileOpt: Option[String],
36 | resourceProfile: ResourceProfile)
37 | extends CoarseGrainedExecutorBackend(
38 | rpcEnv,
39 | driverUrl,
40 | executorId,
41 | bindAddress,
42 | hostname,
43 | cores,
44 | env,
45 | resourcesFileOpt,
46 | resourceProfile) {
47 |
48 | override def getUserClassPath: Seq[URL] = userClassPath
49 |
50 | }
51 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/ExternalShuffleServiceUtils.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp;
19 |
20 | import java.util.List;
21 |
22 | import io.ray.api.ActorHandle;
23 | import io.ray.api.Ray;
24 |
25 | public class ExternalShuffleServiceUtils {
26 |   public static ActorHandle<RayExternalShuffleService> createShuffleService(
27 |       String node, List<String> options) {
28 |     return Ray.actor(RayExternalShuffleService::new)
29 |         .setResource("node:" + node, 0.01)
30 |         .setJvmOptions(options).remote();
31 |   }
32 | 
33 |   public static void startShuffleService(
34 |       ActorHandle<RayExternalShuffleService> handle) {
35 |     handle.task(RayExternalShuffleService::start).remote();
36 |   }
37 | 
38 |   public static void stopShuffleService(
39 |       ActorHandle<RayExternalShuffleService> handle) {
40 |     handle.task(RayExternalShuffleService::stop).remote();
41 |   }
42 | }
43 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark330
19 |
20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
21 |
22 | object SparkShimProvider {
23 | val SPARK330_DESCRIPTOR = SparkShimDescriptor(3, 3, 0)
24 | val SPARK331_DESCRIPTOR = SparkShimDescriptor(3, 3, 1)
25 | val SPARK332_DESCRIPTOR = SparkShimDescriptor(3, 3, 2)
26 | val SPARK333_DESCRIPTOR = SparkShimDescriptor(3, 3, 3)
27 | val DESCRIPTOR_STRINGS = Seq(s"$SPARK330_DESCRIPTOR", s"$SPARK331_DESCRIPTOR",
28 | s"$SPARK332_DESCRIPTOR", s"$SPARK333_DESCRIPTOR")
29 | val DESCRIPTOR = SPARK332_DESCRIPTOR
30 | }
31 |
32 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
33 | def createShim: SparkShims = {
34 | new Spark330Shims()
35 | }
36 |
37 | def matches(version: String): Boolean = {
38 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark340
19 |
20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
21 |
22 | object SparkShimProvider {
23 | val SPARK340_DESCRIPTOR = SparkShimDescriptor(3, 4, 0)
24 | val SPARK341_DESCRIPTOR = SparkShimDescriptor(3, 4, 1)
25 | val SPARK342_DESCRIPTOR = SparkShimDescriptor(3, 4, 2)
26 | val SPARK343_DESCRIPTOR = SparkShimDescriptor(3, 4, 3)
27 | val SPARK344_DESCRIPTOR = SparkShimDescriptor(3, 4, 4)
28 | val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR",
29 | s"$SPARK343_DESCRIPTOR", s"$SPARK344_DESCRIPTOR")
30 | val DESCRIPTOR = SPARK341_DESCRIPTOR
31 | }
32 |
33 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
34 | def createShim: SparkShims = {
35 | new Spark340Shims()
36 | }
37 |
38 | def matches(version: String): Boolean = {
39 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims
19 |
20 | import org.apache.arrow.vector.types.pojo.Schema
21 | import org.apache.spark.{SparkEnv, TaskContext}
22 | import org.apache.spark.api.java.JavaRDD
23 | import org.apache.spark.executor.RayDPExecutorBackendFactory
24 | import org.apache.spark.sql.types.StructType
25 | import org.apache.spark.sql.{DataFrame, SparkSession}
26 |
27 | sealed abstract class ShimDescriptor
28 |
29 | case class SparkShimDescriptor(major: Int, minor: Int, patch: Int) extends ShimDescriptor {
30 | override def toString(): String = s"$major.$minor.$patch"
31 | }
32 |
33 | trait SparkShims {
34 | def getShimDescriptor: ShimDescriptor
35 |
36 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame
37 |
38 | def getExecutorBackendFactory(): RayDPExecutorBackendFactory
39 |
40 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext
41 |
42 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema
43 | }
44 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/scala/org/apache/spark/executor/RayDPSpark330ExecutorBackendFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor.spark330
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.executor._
24 | import org.apache.spark.resource.ResourceProfile
25 | import org.apache.spark.rpc.RpcEnv
26 |
27 | class RayDPSpark330ExecutorBackendFactory
28 | extends RayDPExecutorBackendFactory {
29 | override def createExecutorBackend(
30 | rpcEnv: RpcEnv,
31 | driverUrl: String,
32 | executorId: String,
33 | bindAddress: String,
34 | hostname: String,
35 | cores: Int,
36 | userClassPath: Seq[URL],
37 | env: SparkEnv,
38 | resourcesFileOpt: Option[String],
39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
40 | new RayCoarseGrainedExecutorBackend(
41 | rpcEnv,
42 | driverUrl,
43 | executorId,
44 | bindAddress,
45 | hostname,
46 | cores,
47 | userClassPath,
48 | env,
49 | resourcesFileOpt,
50 | resourceProfile)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/org/apache/spark/executor/RayDPSpark340ExecutorBackendFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor.spark340
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.executor._
24 | import org.apache.spark.resource.ResourceProfile
25 | import org.apache.spark.rpc.RpcEnv
26 |
27 | class RayDPSpark340ExecutorBackendFactory
28 | extends RayDPExecutorBackendFactory {
29 | override def createExecutorBackend(
30 | rpcEnv: RpcEnv,
31 | driverUrl: String,
32 | executorId: String,
33 | bindAddress: String,
34 | hostname: String,
35 | cores: Int,
36 | userClassPath: Seq[URL],
37 | env: SparkEnv,
38 | resourcesFileOpt: Option[String],
39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
40 | new RayCoarseGrainedExecutorBackend(
41 | rpcEnv,
42 | driverUrl,
43 | executorId,
44 | bindAddress,
45 | hostname,
46 | cores,
47 | userClassPath,
48 | env,
49 | resourcesFileOpt,
50 | resourceProfile)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/org/apache/spark/executor/RayDPSpark350ExecutorBackendFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor.spark350
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.executor._
24 | import org.apache.spark.resource.ResourceProfile
25 | import org.apache.spark.rpc.RpcEnv
26 |
27 | class RayDPSpark350ExecutorBackendFactory
28 | extends RayDPExecutorBackendFactory {
29 | override def createExecutorBackend(
30 | rpcEnv: RpcEnv,
31 | driverUrl: String,
32 | executorId: String,
33 | bindAddress: String,
34 | hostname: String,
35 | cores: Int,
36 | userClassPath: Seq[URL],
37 | env: SparkEnv,
38 | resourcesFileOpt: Option[String],
39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
40 | new RayCoarseGrainedExecutorBackend(
41 | rpcEnv,
42 | driverUrl,
43 | executorId,
44 | bindAddress,
45 | hostname,
46 | cores,
47 | userClassPath,
48 | env,
49 | resourcesFileOpt,
50 | resourceProfile)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayClusterManager.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.scheduler.cluster.raydp
19 |
20 | import org.apache.spark.SparkContext
21 | import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
22 |
23 | private[spark] class RayClusterManager extends ExternalClusterManager {
24 |
25 | override def canCreate(masterURL: String): Boolean = {
26 | masterURL.startsWith("ray")
27 | }
28 |
29 | override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = {
30 | new TaskSchedulerImpl(sc)
31 | }
32 |
33 | override def createSchedulerBackend(
34 | sc: SparkContext,
35 | masterURL: String,
36 | scheduler: TaskScheduler): SchedulerBackend = {
37 | new RayCoarseGrainedSchedulerBackend(
38 | sc,
39 | scheduler.asInstanceOf[TaskSchedulerImpl],
40 | masterURL)
41 | }
42 |
43 | override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = {
44 | scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend)
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayObjectRefRDD.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.rdd
19 |
20 | import java.util.List
21 |
22 | import scala.collection.JavaConverters._
23 |
24 | import io.ray.runtime.generated.Common.Address
25 |
26 | import org.apache.spark.{Partition, SparkContext, TaskContext}
27 | import org.apache.spark.raydp.RayDPUtils
28 | import org.apache.spark.sql.Row
29 |
30 | private[spark] class RayObjectRefRDDPartition(idx: Int) extends Partition {
31 | val index = idx
32 | }
33 |
34 | private[spark]
35 | class RayObjectRefRDD(
36 | sc: SparkContext,
37 | locations: List[Array[Byte]])
38 | extends RDD[Row](sc, Nil) {
39 |
40 | override def getPartitions: Array[Partition] = {
41 | (0 until locations.size()).map { i =>
42 | new RayObjectRefRDDPartition(i).asInstanceOf[Partition]
43 | }.toArray
44 | }
45 |
46 | override def compute(split: Partition, context: TaskContext): Iterator[Row] = {
47 | (Row(split.index) :: Nil).iterator
48 | }
49 |
50 | override def getPreferredLocations(split: Partition): Seq[String] = {
51 | Seq(Address.parseFrom(locations.get(split.index)).getIpAddress())
52 | }
53 | }
54 |
55 |
--------------------------------------------------------------------------------
/core/shims/spark322/src/main/scala/org/apache/spark/executor/RayDPSpark322ExecutorBackendFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.executor.spark322
19 |
20 | import java.net.URL
21 |
22 | import org.apache.spark.SparkEnv
23 | import org.apache.spark.executor.CoarseGrainedExecutorBackend
24 | import org.apache.spark.executor.RayDPExecutorBackendFactory
25 | import org.apache.spark.resource.ResourceProfile
26 | import org.apache.spark.rpc.RpcEnv
27 |
28 | class RayDPSpark322ExecutorBackendFactory
29 | extends RayDPExecutorBackendFactory {
30 | override def createExecutorBackend(
31 | rpcEnv: RpcEnv,
32 | driverUrl: String,
33 | executorId: String,
34 | bindAddress: String,
35 | hostname: String,
36 | cores: Int,
37 | userClassPath: Seq[URL],
38 | env: SparkEnv,
39 | resourcesFileOpt: Option[String],
40 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = {
41 | new CoarseGrainedExecutorBackend(
42 | rpcEnv,
43 | driverUrl,
44 | executorId,
45 | bindAddress,
46 | hostname,
47 | cores,
48 | userClassPath,
49 | env,
50 | resourcesFileOpt,
51 | resourceProfile)
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/Messages.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import org.apache.spark.rpc.RpcEndpointRef
21 |
22 | private[deploy] sealed trait RayDPDeployMessage extends Serializable
23 |
24 | case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef)
25 | extends RayDPDeployMessage
26 |
27 | case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends RayDPDeployMessage
28 |
29 | case class UnregisterApplication(appId: String) extends RayDPDeployMessage
30 |
31 | case class RegisterExecutor(executorId: String, nodeIp: String) extends RayDPDeployMessage
32 |
33 | case class ExecutorStarted(executorId: String) extends RayDPDeployMessage
34 |
35 | case class RequestExecutors(appId: String, requestedTotal: Int) extends RayDPDeployMessage
36 |
37 | case class KillExecutors(appId: String, executorIds: Seq[String]) extends RayDPDeployMessage
38 |
39 | case class RequestAddPendingRestartedExecutor(executorId: String)
40 | extends RayDPDeployMessage
41 |
42 | case class AddPendingRestartedExecutorReply(newExecutorId: Option[String])
43 | extends RayDPDeployMessage
44 |
45 | case class RecacheRDD(rddId: Int) extends RayDPDeployMessage
46 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.sql.spark340
19 |
20 | import org.apache.arrow.vector.types.pojo.Schema
21 | import org.apache.spark.TaskContext
22 | import org.apache.spark.api.java.JavaRDD
23 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
24 | import org.apache.spark.sql.execution.arrow.ArrowConverters
25 | import org.apache.spark.sql.types._
26 | import org.apache.spark.sql.util.ArrowUtils
27 |
28 | object SparkSqlUtils {
29 | def toDataFrame(
30 | arrowBatchRDD: JavaRDD[Array[Byte]],
31 | schemaString: String,
32 | session: SparkSession): DataFrame = {
33 | val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
34 | val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
35 | val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
36 | val context = TaskContext.get()
37 | ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
38 | }
39 | session.internalCreateDataFrame(rdd.setName("arrow"), schema)
40 | }
41 |
42 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
43 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/python/raydp/tests/test_torch_sequential.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import pytest
19 | import sys
20 | import torch
21 | import raydp
22 | from raydp.torch import TorchEstimator
23 |
24 | def test_torch_estimator(spark_on_ray_small):
25 | # Prepare the data
26 | customers = [
27 | (1, "James", 21, 6),
28 | (2, "Liz", 25, 8),
29 | (3, "John", 31, 6),
30 | (4, "Jennifer", 45, 7),
31 | (5, "Robert", 41, 5),
32 | (6, "Sandra", 45, 8)
33 | ]
34 | df = spark_on_ray_small.createDataFrame(customers, ["cID", "name", "age", "grade"])
35 |
36 | # Create the model
37 | model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2, 1))
38 | optimizer = torch.optim.Adam(model.parameters())
39 | loss = torch.nn.MSELoss()
40 |
41 | # Configure the estimator
42 | estimator = TorchEstimator(
43 | model = model,
44 | optimizer = optimizer,
45 | loss = loss,
46 | num_workers = 3,
47 | num_epochs = 5,
48 | feature_columns = ["age"],
49 | feature_types = torch.float,
50 | label_column = "grade",
51 | label_type = torch.float,
52 | batch_size = 1
53 | )
54 | estimator.fit_on_spark(df)
55 |
56 | if __name__ == "__main__":
57 | sys.exit(pytest.main(["-v", __file__]))
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayExternalShuffleService.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import io.ray.api.Ray
21 |
22 | import org.apache.spark.{SecurityManager, SparkConf}
23 | import org.apache.spark.deploy.ExternalShuffleService
24 | import org.apache.spark.internal.Logging
25 |
26 | class RayExternalShuffleService() extends Logging {
27 | val conf = new SparkConf()
28 | val mgr = new SecurityManager(conf)
29 | val instance = new ExternalShuffleService(conf, mgr)
30 |
31 | def start(): Unit = {
32 | instance.start()
33 | }
34 |
35 | def stop(): Unit = {
36 | instance.stop()
37 | Ray.exitActor()
38 | }
39 | }
40 |
41 | object RayExternalShuffleService {
42 | def getShuffleConf(conf: SparkConf): Array[String] = {
43 | // all conf needed by external shuffle service
44 | var shuffleConf = conf.getAll.filter {
45 | case (k, v) => k.startsWith("spark.shuffle")
46 | }.map {
47 | case (k, v) =>
48 | "-D" + k + "=" + v
49 | }
50 | val localDirKey = "spark.local.dir"
51 | if (conf.contains(localDirKey)) {
52 | shuffleConf = shuffleConf :+
53 | "-D" + localDirKey + "=" + conf.get(localDirKey)
54 | }
55 | shuffleConf
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.sql.spark350
19 |
20 | import org.apache.arrow.vector.types.pojo.Schema
21 | import org.apache.spark.TaskContext
22 | import org.apache.spark.api.java.JavaRDD
23 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession}
24 | import org.apache.spark.sql.execution.arrow.ArrowConverters
25 | import org.apache.spark.sql.types._
26 | import org.apache.spark.sql.util.ArrowUtils
27 |
28 | object SparkSqlUtils {
29 | def toDataFrame(
30 | arrowBatchRDD: JavaRDD[Array[Byte]],
31 | schemaString: String,
32 | session: SparkSession): DataFrame = {
33 | val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
34 | val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
35 | val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
36 | val context = TaskContext.get()
37 | ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, false, context)
38 | }
39 | session.internalCreateDataFrame(rdd.setName("arrow"), schema)
40 | }
41 |
42 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
43 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false)
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark322
19 |
20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
21 |
22 | object SparkShimProvider {
23 | val SPARK311_DESCRIPTOR = SparkShimDescriptor(3, 1, 1)
24 | val SPARK312_DESCRIPTOR = SparkShimDescriptor(3, 1, 2)
25 | val SPARK313_DESCRIPTOR = SparkShimDescriptor(3, 1, 3)
26 | val SPARK320_DESCRIPTOR = SparkShimDescriptor(3, 2, 0)
27 | val SPARK321_DESCRIPTOR = SparkShimDescriptor(3, 2, 1)
28 | val SPARK322_DESCRIPTOR = SparkShimDescriptor(3, 2, 2)
29 | val SPARK323_DESCRIPTOR = SparkShimDescriptor(3, 2, 3)
30 | val SPARK324_DESCRIPTOR = SparkShimDescriptor(3, 2, 4)
31 | val DESCRIPTOR_STRINGS =
32 | Seq(s"$SPARK311_DESCRIPTOR", s"$SPARK312_DESCRIPTOR" ,s"$SPARK313_DESCRIPTOR",
33 | s"$SPARK320_DESCRIPTOR", s"$SPARK321_DESCRIPTOR", s"$SPARK322_DESCRIPTOR",
34 | s"$SPARK323_DESCRIPTOR", s"$SPARK324_DESCRIPTOR")
35 | val DESCRIPTOR = SPARK323_DESCRIPTOR
36 | }
37 |
38 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
39 | def createShim: SparkShims = {
40 | new Spark322Shims()
41 | }
42 |
43 | def matches(version: String): Boolean = {
44 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark350
19 |
20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor}
21 |
22 | object SparkShimProvider {
23 | val SPARK350_DESCRIPTOR = SparkShimDescriptor(3, 5, 0)
24 | val SPARK351_DESCRIPTOR = SparkShimDescriptor(3, 5, 1)
25 | val SPARK352_DESCRIPTOR = SparkShimDescriptor(3, 5, 2)
26 | val SPARK353_DESCRIPTOR = SparkShimDescriptor(3, 5, 3)
27 | val SPARK354_DESCRIPTOR = SparkShimDescriptor(3, 5, 4)
28 | val SPARK355_DESCRIPTOR = SparkShimDescriptor(3, 5, 5)
29 | val SPARK356_DESCRIPTOR = SparkShimDescriptor(3, 5, 6)
30 | val SPARK357_DESCRIPTOR = SparkShimDescriptor(3, 5, 7)
31 | val DESCRIPTOR_STRINGS = Seq(
32 | s"$SPARK350_DESCRIPTOR", s"$SPARK351_DESCRIPTOR", s"$SPARK352_DESCRIPTOR",
33 | s"$SPARK353_DESCRIPTOR", s"$SPARK354_DESCRIPTOR", s"$SPARK355_DESCRIPTOR",
34 | s"$SPARK356_DESCRIPTOR", s"$SPARK357_DESCRIPTOR"
35 | )
36 | val DESCRIPTOR = SPARK350_DESCRIPTOR
37 | }
38 |
39 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider {
40 | def createShim: SparkShims = {
41 | new Spark350Shims()
42 | }
43 |
44 | def matches(version: String): Boolean = {
45 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version)
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationDescription.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import scala.collection.Map
21 |
22 | private[spark] case class Command(
23 | driverUrl: String,
24 | environment: Map[String, String],
25 | classPathEntries: Seq[String],
26 | libraryPathEntries: Seq[String],
27 | javaOpts: Seq[String]) {
28 |
29 | def withNewJavaOpts(newJavaOptions: Seq[String]): Command = {
30 | Command(driverUrl, environment, classPathEntries, libraryPathEntries, newJavaOptions)
31 | }
32 | }
33 |
34 | private[spark] case class ApplicationDescription(
35 | name: String,
36 | numExecutors: Int,
37 | coresPerExecutor: Option[Int],
38 | memoryPerExecutorMB: Int,
39 | rayActorCPU: Double,
40 | command: Command,
41 | user: String = System.getProperty("user.name", ""),
42 | resourceReqsPerExecutor: Map[String, Double] = Map.empty) {
43 |
44 | def withNewCommand(newCommand: Command): ApplicationDescription = {
45 | ApplicationDescription(name = name,
46 | numExecutors = numExecutors, coresPerExecutor = coresPerExecutor,
47 | memoryPerExecutorMB = memoryPerExecutorMB, command = newCommand, user = user,
48 | resourceReqsPerExecutor = resourceReqsPerExecutor,
49 | rayActorCPU = rayActorCPU)
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/python/raydp/tests/test_spark_master_memory.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pytest
3 | import ray
4 | import raydp
5 | from ray.cluster_utils import Cluster
6 | from ray.util.state import list_actors
7 |
8 |
9 | def test_spark_master_memory_custom(jdk17_extra_spark_configs):
10 | cluster = Cluster(
11 | initialize_head=True,
12 | head_node_args={
13 | "num_cpus": 2,
14 | "resources": {"master": 10},
15 | "include_dashboard": True,
16 | "dashboard_port": 8270,
17 | },
18 | )
19 | ray.init(address=cluster.address,
20 | dashboard_port=cluster.head_node.dashboard_grpc_port,
21 | include_dashboard=True)
22 |
23 | custom_memory = 100 * 1024 * 1024 # 100MB in bytes
24 | configs = jdk17_extra_spark_configs.copy()
25 | # Config under test: set Spark Master actor memory via RayDP config
26 | configs["spark.ray.raydp_spark_master.actor.resource.memory"] = str(custom_memory)
27 | # Also require the master custom resource so the actor is scheduled on the head
28 | configs["spark.ray.raydp_spark_master.actor.resource.master"] = "1"
29 |
30 | app_name = "test_spark_master_memory_custom"
31 |
32 | spark = raydp.init_spark(
33 | app_name=app_name,
34 | num_executors=1,
35 | executor_cores=1,
36 | executor_memory="500M",
37 | configs=configs,
38 | )
39 |
40 | # Trigger the Spark master / RayDPSparkMaster startup
41 | spark.createDataFrame([(1, 2)], ["a", "b"]).count()
42 |
43 | # RayDPSparkMaster name is app_name + RAYDP_SPARK_MASTER_SUFFIX
44 | master_actor_name = f"{app_name}_SPARK_MASTER"
45 |
46 | actor = ray.get_actor(master_actor_name)
47 | assert actor is not None
48 |
49 | # Query Ray state for this actor
50 | actor_state = list_actors(filters=[("actor_id", "=", actor._actor_id.hex())], detail=True)[0]
51 | resources = actor_state.required_resources
52 |
53 | assert resources["memory"] == custom_memory
54 | assert resources["master"] == 1
55 |
56 | spark.stop()
57 | raydp.stop_spark()
58 | ray.shutdown()
59 | cluster.shutdown()
60 |
61 |
62 | if __name__ == "__main__":
63 | sys.exit(pytest.main(["-v", __file__]))
64 |
65 |
66 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/java/org/apache/spark/raydp/RayDPUtils.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.raydp;
19 |
20 | import io.ray.api.ObjectRef;
21 | import io.ray.api.Ray;
22 | import io.ray.api.id.ObjectId;
23 | import io.ray.runtime.AbstractRayRuntime;
24 | import io.ray.runtime.object.ObjectRefImpl;
25 |
26 | public class RayDPUtils {
27 |
28 | /**
29 | * Convert ObjectRef to subclass ObjectRefImpl. Throw RuntimeException if it is not instance
30 | * of ObjectRefImpl. We can't import the ObjectRefImpl in scala code, so we do the
31 | * conversion at here.
32 | */
33 | public static <T> ObjectRefImpl<T> convert(ObjectRef<T> obj) {
34 | if (obj instanceof ObjectRefImpl) {
35 | return (ObjectRefImpl<T>) obj;
36 | } else {
37 | throw new RuntimeException(obj.getClass() + " is not ObjectRefImpl");
38 | }
39 | }
40 |
41 | /**
42 | * Creates an ObjectRef from a byte array and registers ownership.
43 | * We can't import ObjectRefImpl in Scala code, so we do the conversion here.
44 | */
45 | public static <T> ObjectRef<T> readBinary(byte[] obj, Class<T> clazz, byte[] ownerAddress) {
46 | ObjectId id = new ObjectId(obj);
47 | ObjectRefImpl<T> ref = new ObjectRefImpl<>(id, clazz, false);
48 | AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
49 | runtime.getObjectStore().registerOwnershipInfoAndResolveFuture(
50 | id, null, ownerAddress
51 | );
52 | return ref;
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 | # Running RayDP on k8s cluster
2 |
3 | ## Build docker image
4 | Build the Docker image to use in K8s with the following command; this will create an image tagged `oap-project/raydp:latest`:
5 | ```shell
6 | # under ${RAYDP_HOME}/docker
7 | ./build-docker.sh
8 | ```
9 |
10 | You can install our nightly build with `pip install raydp --pre` or `pip install raydp-nightly`. To install raydp-nightly in the image, modify the following code in the `Dockerfile`:
11 | ```Dockerfile
12 | RUN sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get update -y \
13 | && sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get install -y openjdk-8-jdk \
14 | && sudo mkdir /raydp \
15 | && sudo chown -R ray /raydp \
16 | && $HOME/anaconda3/bin/pip --no-cache-dir install raydp-nightly
17 | ```
18 |
19 | In addition, you should install all the dependencies of your application in the `Dockerfile`. If suitable, you can change the base image to `ray-ml`:
20 | ```Dockerfile
21 | FROM rayproject/ray-ml:latest
22 | ```
23 |
24 | Then, you can push the built image to a registry or distribute it to the k8s worker nodes.
25 |
26 | ## Deploy ray cluster with Helm
27 | You need to create a Helm chart first. To start with, check out this [example ray cluster Helm chart](https://github.com/ray-project/kuberay/tree/master/helm-chart/ray-cluster). You can clone this repo and copy this directory, then modify `values.yaml` to use the previously built image.
28 |
29 | ```yaml
30 | image:
31 | repository: oap-project/raydp
32 | tag: latest
33 | pullPolicy: IfNotPresent
34 | ```
35 |
36 | You can also change other fields in this file to specify the number of workers, etc.
37 |
38 | Then, you need to deploy the KubeRay operator; please refer to [here](https://docs.ray.io/en/latest/cluster/kubernetes/getting-started.html#kuberay-quickstart) for instructions. After that, you can deploy a Ray cluster with RayDP installed via `helm install ray-cluster PATH_to_CHART`.
39 |
40 | ## Access the cluster
41 | Check [here](https://docs.ray.io/en/master/cluster/kubernetes/getting-started.html#running-applications-on-a-ray-cluster) to see how to run applications on the cluster you just deployed.
42 |
43 | ## Legacy
44 | If you are using a Ray version before 2.0, you can try the following command.
45 | ```shell
46 | ray up ${RAYDP_HOME}/docker/legacy.yaml
47 | ```
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.rdd
19 |
20 | import java.util.List
21 |
22 | import scala.collection.JavaConverters._
23 |
24 | import io.ray.runtime.generated.Common.Address
25 |
26 | import org.apache.spark.{Partition, SparkContext, TaskContext}
27 | import org.apache.spark.api.java.JavaSparkContext
28 | import org.apache.spark.raydp.RayDPUtils
29 | import org.apache.spark.sql.raydp.ObjectStoreReader
30 |
31 | private[spark] class RayDatasetRDDPartition(val ref: Array[Byte], idx: Int) extends Partition {
32 | val index = idx
33 | }
34 |
35 | private[spark]
36 | class RayDatasetRDD(
37 | jsc: JavaSparkContext,
38 | @transient val objectIds: List[Array[Byte]],
39 | locations: List[Array[Byte]])
40 | extends RDD[Array[Byte]](jsc.sc, Nil) {
41 |
42 | override def getPartitions: Array[Partition] = {
43 | objectIds.asScala.zipWithIndex.map { case (k, i) =>
44 | new RayDatasetRDDPartition(k, i).asInstanceOf[Partition]
45 | }.toArray
46 | }
47 |
48 | override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
49 | val ref = split.asInstanceOf[RayDatasetRDDPartition].ref
50 | ObjectStoreReader.getBatchesFromStream(ref, locations.get(split.index))
51 | }
52 |
53 | override def getPreferredLocations(split: Partition): Seq[String] = {
54 | val address = Address.parseFrom(locations.get(split.index))
55 | Seq(address.getIpAddress())
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark322
19 |
20 | import org.apache.spark.{SparkEnv, TaskContext}
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.executor.RayDPExecutorBackendFactory
23 | import org.apache.spark.executor.spark322._
24 | import org.apache.spark.spark322.TaskContextUtils
25 | import org.apache.spark.sql.{DataFrame, SparkSession}
26 | import org.apache.spark.sql.spark322.SparkSqlUtils
27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
28 | import org.apache.arrow.vector.types.pojo.Schema
29 | import org.apache.spark.sql.types.StructType
30 |
31 | class Spark322Shims extends SparkShims {
32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
33 |
34 | override def toDataFrame(
35 | rdd: JavaRDD[Array[Byte]],
36 | schema: String,
37 | session: SparkSession): DataFrame = {
38 | SparkSqlUtils.toDataFrame(rdd, schema, session)
39 | }
40 |
41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
42 | new RayDPSpark322ExecutorBackendFactory()
43 | }
44 |
45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
46 | TaskContextUtils.getDummyTaskContext(partitionId, env)
47 | }
48 |
49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark330
19 |
20 | import org.apache.spark.{SparkEnv, TaskContext}
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.executor.RayDPExecutorBackendFactory
23 | import org.apache.spark.executor.spark330._
24 | import org.apache.spark.spark330.TaskContextUtils
25 | import org.apache.spark.sql.{DataFrame, SparkSession}
26 | import org.apache.spark.sql.spark330.SparkSqlUtils
27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
28 | import org.apache.arrow.vector.types.pojo.Schema
29 | import org.apache.spark.sql.types.StructType
30 |
31 | class Spark330Shims extends SparkShims {
32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
33 |
34 | override def toDataFrame(
35 | rdd: JavaRDD[Array[Byte]],
36 | schema: String,
37 | session: SparkSession): DataFrame = {
38 | SparkSqlUtils.toDataFrame(rdd, schema, session)
39 | }
40 |
41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
42 | new RayDPSpark330ExecutorBackendFactory()
43 | }
44 |
45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
46 | TaskContextUtils.getDummyTaskContext(partitionId, env)
47 | }
48 |
49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark340
19 |
20 | import org.apache.spark.{SparkEnv, TaskContext}
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.executor.RayDPExecutorBackendFactory
23 | import org.apache.spark.executor.spark340._
24 | import org.apache.spark.spark340.TaskContextUtils
25 | import org.apache.spark.sql.{DataFrame, SparkSession}
26 | import org.apache.spark.sql.spark340.SparkSqlUtils
27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
28 | import org.apache.arrow.vector.types.pojo.Schema
29 | import org.apache.spark.sql.types.StructType
30 |
31 | class Spark340Shims extends SparkShims {
32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
33 |
34 | override def toDataFrame(
35 | rdd: JavaRDD[Array[Byte]],
36 | schema: String,
37 | session: SparkSession): DataFrame = {
38 | SparkSqlUtils.toDataFrame(rdd, schema, session)
39 | }
40 |
41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
42 | new RayDPSpark340ExecutorBackendFactory()
43 | }
44 |
45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
46 | TaskContextUtils.getDummyTaskContext(partitionId, env)
47 | }
48 |
49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims.spark350
19 |
20 | import org.apache.spark.{SparkEnv, TaskContext}
21 | import org.apache.spark.api.java.JavaRDD
22 | import org.apache.spark.executor.RayDPExecutorBackendFactory
23 | import org.apache.spark.executor.spark350._
24 | import org.apache.spark.spark350.TaskContextUtils
25 | import org.apache.spark.sql.{DataFrame, SparkSession}
26 | import org.apache.spark.sql.spark350.SparkSqlUtils
27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims}
28 | import org.apache.arrow.vector.types.pojo.Schema
29 | import org.apache.spark.sql.types.StructType
30 |
31 | class Spark350Shims extends SparkShims {
32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR
33 |
34 | override def toDataFrame(
35 | rdd: JavaRDD[Array[Byte]],
36 | schema: String,
37 | session: SparkSession): DataFrame = {
38 | SparkSqlUtils.toDataFrame(rdd, schema, session)
39 | }
40 |
41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = {
42 | new RayDPSpark350ExecutorBackendFactory()
43 | }
44 |
45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = {
46 | TaskContextUtils.getDummyTaskContext(partitionId, env)
47 | }
48 |
49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = {
50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # RayDP Examples
2 | Here are a few examples showing how RayDP works together with other libraries, such as PyTorch, TensorFlow, XGBoost and Horovod.
3 |
4 | In order to run these examples, you may need to install the corresponding dependencies; for installation guides, please refer to their homepages. Notice that you need to install [xgboost_ray](https://github.com/ray-project/xgboost_ray) to run the XGBoost example. In addition, if you are running the examples on a Ray cluster, all nodes should have the dependencies installed.
5 |
6 | ## NYC Taxi Fare Prediction Dataset
7 | We have a few examples which use this dataset.
8 | You can run our examples right away after you clone our repo, because we include a small example dataset generated randomly using `examples/random_nyctaxi.py`. The generated dataset just demonstrates that our examples work; the trained models might not be meaningful.
9 |
10 | The original dataset can be downloaded [here](https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data). After you download it, please modify the variable `NYC_TRAIN_CSV` in `data_process.py` and point it to where `train.csv` is saved.
11 |
12 | ## Horovod
13 | To run the example, please install Horovod via `pip install horovod[pytorch,ray]`. In addition, `HOROVOD_WITH_PYTORCH` and `HOROVOD_WITH_GLOO` should be set to `1` before running pip. Notice that macOS users need to install `libuv` first via `brew install libuv`. Please refer to [here](https://horovod.readthedocs.io/en/stable/install_include.html) for details.
14 |
15 | When running `horovod_nyctaxi.py`, do not use `horovodrun`. Check [here](https://horovod.readthedocs.io/en/stable/ray_include.html) for more information.
16 |
17 | ## RaySGD Example
18 | In the RaySGD example, we demonstrate how to use our `MLDataset` API. After we use Spark to transform the dataset, we call `RayMLDataset.from_spark` to write the Spark DataFrames into the Ray object store in Apache Arrow format. We then convert the data to `pandas` DataFrames, ideally zero-copy. Finally, the data can be consumed by any framework that supports the `numpy` format, such as PyTorch or TensorFlow. An `MLDataset` is partitioned, or sharded, just like a Spark DataFrame, and the two are not required to have the same number of partitions. However, the number of shards of the `MLDataset` should match the number of workers of `TorchTrainer` or `TFTrainer`, so that each worker is mapped to one shard. A minimal sketch of this flow is shown below.
19 |
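20 | The following sketch is only illustrative: the `raydp.spark.RayMLDataset` import path, the `num_shards` parameter name and the `to_pandas` conversion call are assumptions based on older RayDP releases, so please check the RaySGD example script for the exact API.
21 | ```python
22 | import raydp
23 | from raydp.spark import RayMLDataset  # assumed import path
24 | 
25 | # Start Spark on Ray and do the usual Spark-side preprocessing.
26 | spark = raydp.init_spark(app_name="raysgd_example", num_executors=2, executor_cores=1, executor_memory="1G")
27 | df = spark.range(0, 1000).selectExpr("id AS x", "id * 2 AS y")
28 | 
29 | # Write the Spark DataFrame into the Ray object store as an MLDataset.
30 | # The shard count should equal the number of TorchTrainer/TFTrainer workers.
31 | ds = RayMLDataset.from_spark(df, num_shards=2)  # parameter name assumed
32 | 
33 | # The shards can then be converted to pandas and fed to PyTorch/TensorFlow.
34 | pandas_shards = ds.to_pandas()  # conversion call assumed; see the example script
35 | 
36 | raydp.stop_spark()
37 | ```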
--------------------------------------------------------------------------------
/python/raydp/tests/test_xgboost.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import os
19 | import sys
20 | import shutil
21 | import platform
22 | import pytest
23 | import pyspark
24 | import numpy as np
25 | from pyspark.sql.functions import rand
26 |
27 | from raydp.xgboost import XGBoostEstimator
28 | from raydp.utils import random_split
29 |
30 | @pytest.mark.parametrize("use_fs_directory", [True, False])
31 | def test_xgb_estimator(spark_on_ray_small, use_fs_directory):
32 | if platform.system() == "Darwin":
33 | pytest.skip("Skip xgboost test on MacOS")
34 | spark = spark_on_ray_small
35 |
36 | # calculate z = 3 * x + 4 * y + 5
37 | df: pyspark.sql.DataFrame = spark.range(0, 100000)
38 | df = df.withColumn("x", rand() * 100) # add x column
39 | df = df.withColumn("y", rand() * 1000) # ad y column
40 | df = df.withColumn("z", df.x * 3 + df.y * 4 + rand() + 5) # ad z column
41 | df = df.select(df.x, df.y, df.z)
42 |
43 | train_df, test_df = random_split(df, [0.7, 0.3])
44 | params = {}
45 | estimator = XGBoostEstimator(params, "z", resources_per_worker={"CPU": 1})
46 | if use_fs_directory:
47 | dir = os.path.dirname(os.path.realpath(__file__)) + "/test_xgboost"
48 | uri = "file://" + dir
49 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri)
50 | else:
51 | estimator.fit_on_spark(train_df, test_df)
52 | print(estimator.get_model().inplace_predict(np.asarray([[1,2]])))
53 | if use_fs_directory:
54 | shutil.rmtree(dir)
55 |
56 | if __name__ == '__main__':
57 | sys.exit(pytest.main(["-v", __file__]))
--------------------------------------------------------------------------------
/examples/raydp-submit.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname
2 | import sys
3 | import json
4 | import subprocess
5 | import shlex
6 | import ray
7 | import pyspark
8 |
9 | ray.init(address="auto")
10 | node = ray.worker.global_worker.node
11 | options = {}
12 | options["ray"] = {}
13 | options["ray"]["run-mode"] = "CLUSTER"
14 | options["ray"]["node-ip"] = node.node_ip_address
15 | options["ray"]["address"] = node.address
16 | options["ray"]["session-dir"] = node.get_session_dir_path()
17 |
18 | ray.shutdown()
19 | conf_path = dirname(__file__) + "/ray.conf"
20 | with open(conf_path, "w") as f:
21 | json.dump(options, f)
22 |
23 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access
24 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP.
25 | java_opts = " ".join([
26 | "-XX:+IgnoreUnrecognizedVMOptions",
27 | "--add-opens=java.base/java.lang=ALL-UNNAMED",
28 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",
29 | "--add-opens=java.base/java.io=ALL-UNNAMED",
30 | "--add-opens=java.base/java.net=ALL-UNNAMED",
31 | "--add-opens=java.base/java.nio=ALL-UNNAMED",
32 | "--add-opens=java.base/java.math=ALL-UNNAMED",
33 | "--add-opens=java.base/java.text=ALL-UNNAMED",
34 | "--add-opens=java.base/java.util=ALL-UNNAMED",
35 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED",
36 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED",
37 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",
38 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED",
39 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED",
40 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED",
41 | ])
42 |
43 | command = ["bin/raydp-submit", "--ray-conf", conf_path]
44 | command += ["--conf", "spark.executor.cores=1"]
45 | command += ["--conf", "spark.executor.instances=1"]
46 | command += ["--conf", "spark.executor.memory=500m"]
47 | command += ["--conf", f"spark.executor.extraJavaOptions={java_opts}"]
48 | command += ["--conf", f"spark.driver.extraJavaOptions={java_opts}"]
49 | command += ["--conf", f"spark.ray.raydp_app_master.extraJavaOptions={java_opts}"]
50 | example_path = dirname(pyspark.__file__)
51 | # run SparkPi as example
52 | command.append(example_path + "/examples/src/main/python/pi.py")
53 | cmd_str = " ".join(shlex.quote(arg) for arg in command)
54 | sys.exit(subprocess.run(cmd_str, check=True, shell=True).returncode)
55 |
--------------------------------------------------------------------------------
/python/raydp/torch/torch_metrics.py:
--------------------------------------------------------------------------------
1 | import sys
2 | module = sys.modules[__name__]
3 |
4 | def try_import_torchmetrics():
5 | """Tries importing torchmetrics and returns the module (or None).
6 | Returns:
7 | torchmetrics modules.
8 | """
9 | try:
10 | # pylint: disable=import-outside-toplevel
11 | import torchmetrics
12 |
13 | return torchmetrics
14 | except ImportError as torchmetrics_not_exist:
15 | raise ImportError(
16 | "Could not import torchmetrics! Raydp TorchEstimator requires "
17 | "you to install torchmetrics: "
18 | "`pip install torchmetrics`."
19 | ) from torchmetrics_not_exist
20 |
21 | class TorchMetric():
22 | def __init__(self, metrics_name, metrics_config):
23 | torchmetrics = try_import_torchmetrics()
24 | self._metrics_name = metrics_name
25 | self._metrics_func = {}
26 | if self._metrics_name is not None:
27 | assert isinstance(metrics_name, list), "metrics_name must be a list"
28 | for metric in self._metrics_name:
29 | if isinstance(metric, torchmetrics.Metric):
30 | self._metrics_func[metric.__class__.__name__] = metric
31 | elif isinstance(metric, str) and hasattr(torchmetrics, metric):
32 | if metrics_config is not None and metrics_config[metric] is not None:
33 | self._metrics_func[metric] = getattr(torchmetrics, metric)(
34 | **metrics_config[metric])
35 | else:
36 | self._metrics_func[metric] = getattr(torchmetrics, metric)()
37 | else:
38 | raise Exception(
39 | "Unsupported parameter, we only support list of "
40 | "torchmetrics.Metric instances or arr of torchmetrics.")
41 |
42 | def update(self, preds, targets):
43 | for metric in self._metrics_func:
44 | self._metrics_func[metric].update(preds, targets)
45 |
46 | def compute(self):
47 | epoch_res = {}
48 | for metric in self._metrics_func:
49 | epoch_res[metric] = self._metrics_func[metric].compute().item()
50 |
51 | return epoch_res
52 |
53 | def reset(self):
54 | for metric in self._metrics_func:
55 | self._metrics_func[metric].reset()
56 |
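
For reference, a minimal usage sketch of the `TorchMetric` wrapper above (the metric name and tensors are illustrative, and the import assumes the module is importable as `raydp.torch.torch_metrics`):

```python
import torch
from raydp.torch.torch_metrics import TorchMetric

# Metric names must be torchmetrics class names (or torchmetrics.Metric instances).
metric = TorchMetric(["MeanSquaredError"], None)
metric.update(torch.tensor([0.1, 0.9]), torch.tensor([0.0, 1.0]))
print(metric.compute())  # {"MeanSquaredError": <float>}
metric.reset()
```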
--------------------------------------------------------------------------------
/.github/workflows/pypi.yml:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | name: RayDP PyPi
19 |
20 | on:
21 | schedule:
22 | - cron: '0 0 * * *'
23 | # can manually trigger the workflow
24 | workflow_dispatch:
25 |
26 | permissions: # added using https://github.com/step-security/secure-repo
27 | contents: read
28 |
29 | jobs:
30 | build-and-publish:
31 | # do not run in forks
32 | if: ${{ github.repository_owner == 'ray-project' }}
33 | name: build wheel and upload
34 | runs-on: ubuntu-latest
35 | steps:
36 | - uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master
37 | - name: Set up Python 3.10
38 | uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4
39 | with:
40 | python-version: 3.10.14
41 | - name: Set up JDK 1.8
42 | uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4
43 | with:
44 | java-version: 1.8
45 | - name: days since the commit date
46 | run: |
47 | :
48 | timestamp=$(git log --no-walk --date=unix --format=%cd $GITHUB_SHA)
49 | days=$(( ( $(date --utc +%s) - $timestamp ) / 86400 ))
50 | if [ $days -eq 0 ]; then
51 | echo COMMIT_TODAY=true >> $GITHUB_ENV
52 | fi
53 | - name: Build wheel
54 | if: env.COMMIT_TODAY == 'true'
55 | env:
56 | RAYDP_BUILD_MODE: nightly
57 | run: pip install wheel grpcio-tools && ./build.sh
58 | - name: Upload
59 | if: env.COMMIT_TODAY == 'true'
60 | uses: pypa/gh-action-pypi-publish@v1.13.0
61 | with:
62 | password: ${{ secrets.PYPI_API_TOKEN }}
63 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.sql.raydp
19 |
20 | import java.io.ByteArrayInputStream
21 | import java.nio.channels.{Channels, ReadableByteChannel}
22 | import java.util.List
23 |
24 | import com.intel.raydp.shims.SparkShimLoader
25 |
26 | import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
27 | import org.apache.spark.raydp.RayDPUtils
28 | import org.apache.spark.rdd.{RayDatasetRDD, RayObjectRefRDD}
29 | import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext}
30 | import org.apache.spark.sql.catalyst.expressions.GenericRow
31 | import org.apache.spark.sql.execution.arrow.ArrowConverters
32 | import org.apache.spark.sql.types.{IntegerType, StructType}
33 |
34 | object ObjectStoreReader {
35 | def createRayObjectRefDF(
36 | spark: SparkSession,
37 | locations: List[Array[Byte]]): DataFrame = {
38 | val rdd = new RayObjectRefRDD(spark.sparkContext, locations)
39 | val schema = new StructType().add("idx", IntegerType)
40 | spark.createDataFrame(rdd, schema)
41 | }
42 |
43 | def RayDatasetToDataFrame(
44 | sparkSession: SparkSession,
45 | rdd: RayDatasetRDD,
46 | schema: String): DataFrame = {
47 | SparkShimLoader.getSparkShims.toDataFrame(JavaRDD.fromRDD(rdd), schema, sparkSession)
48 | }
49 |
50 | def getBatchesFromStream(
51 | ref: Array[Byte],
52 | ownerAddress: Array[Byte]): Iterator[Array[Byte]] = {
53 | val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]], ownerAddress)
54 | ArrowConverters.getBatchesFromStream(
55 | Channels.newChannel(new ByteArrayInputStream(objectRef.get)))
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/core/agent/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 | xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 | xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 | <modelVersion>4.0.0</modelVersion>
6 | 
7 | <parent>
8 | <groupId>com.intel</groupId>
9 | <artifactId>raydp-parent</artifactId>
10 | <version>1.7.0-SNAPSHOT</version>
11 | <relativePath>../pom.xml</relativePath>
12 | </parent>
13 | 
14 | <artifactId>raydp-agent</artifactId>
15 | <name>RayDP Java Agent</name>
16 | <packaging>jar</packaging>
17 | 
18 | <build>
19 | <plugins>
20 | <plugin>
21 | <groupId>org.apache.maven.plugins</groupId>
22 | <artifactId>maven-compiler-plugin</artifactId>
23 | <version>3.8.0</version>
24 | <configuration>
25 | <source>1.8</source>
26 | <target>1.8</target>
27 | </configuration>
28 | </plugin>
29 | <plugin>
30 | <groupId>org.apache.maven.plugins</groupId>
31 | <artifactId>maven-jar-plugin</artifactId>
32 | <version>3.3.0</version>
33 | <configuration>
34 | <archive>
35 | <manifestEntries>
36 | <Premain-Class>org.apache.spark.raydp.Agent</Premain-Class>
37 | </manifestEntries>
38 | </archive>
39 | </configuration>
40 | </plugin>
41 | </plugins>
42 | </build>
43 | 
44 | <dependencies>
45 | <dependency>
46 | <groupId>org.apache.logging.log4j</groupId>
47 | <artifactId>log4j-core</artifactId>
48 | <version>2.25.3</version>
49 | </dependency>
50 | <dependency>
51 | <groupId>org.apache.logging.log4j</groupId>
52 | <artifactId>log4j-slf4j-impl</artifactId>
53 | <version>2.17.1</version>
54 | </dependency>
55 | <dependency>
56 | <groupId>org.slf4j</groupId>
57 | <artifactId>slf4j-api</artifactId>
58 | <version>1.7.32</version>
59 | </dependency>
60 | <dependency>
61 | <groupId>com.intel</groupId>
62 | <artifactId>raydp</artifactId>
63 | <version>${project.version}</version>
64 | </dependency>
65 | </dependencies>
66 | </project>
--------------------------------------------------------------------------------
/core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/RayAppMasterUtils.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp;
19 |
20 | import java.util.List;
21 | import java.util.Map;
22 |
23 | import io.ray.api.ActorHandle;
24 | import io.ray.api.Ray;
25 | import io.ray.api.call.ActorCreator;
26 | import org.apache.spark.raydp.SparkOnRayConfigs;
27 |
28 | public class RayAppMasterUtils {
29 | public static ActorHandle<RayAppMaster> createAppMaster(
30 | String cp,
31 | String name,
32 | List<String> jvmOptions,
33 | Map<String, Double> appMasterResource) {
34 | ActorCreator<RayAppMaster> creator = Ray.actor(RayAppMaster::new, cp);
35 | if (name != null) {
36 | creator.setName(name);
37 | }
38 | jvmOptions.add("-cp");
39 | jvmOptions.add(cp);
40 | creator.setJvmOptions(jvmOptions);
41 | for (Map.Entry<String, Double> resource : appMasterResource.entrySet()) {
42 | String resourceName = resource.getKey()
43 | .substring(SparkOnRayConfigs.SPARK_MASTER_ACTOR_RESOURCE_PREFIX.length() + 1);
44 | creator.setResource(resourceName, resource.getValue());
45 | }
46 |
47 | return creator.remote();
48 | }
49 |
50 | public static String getMasterUrl(
51 | ActorHandle<RayAppMaster> handle) {
52 | return handle.task(RayAppMaster::getMasterUrl).remote().get();
53 | }
54 |
55 | public static Map getRestartedExecutors(
56 | ActorHandle<RayAppMaster> handle) {
57 | return handle.task(RayAppMaster::getRestartedExecutors).remote().get();
58 | }
59 |
60 | public static void stopAppMaster(
61 | ActorHandle<RayAppMaster> handle) {
62 | handle.task(RayAppMaster::stop).remote().get();
63 | handle.kill();
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/core/javastyle-suppressions.xml:
--------------------------------------------------------------------------------
1 |
17 |
18 |
21 |
22 |
29 |
30 |
31 |
33 |
35 |
37 |
39 |
41 |
43 |
45 |
47 |
49 |
51 |
52 |
--------------------------------------------------------------------------------
/python/raydp/mpi/network/network.proto:
--------------------------------------------------------------------------------
1 | //
2 | // Licensed to the Apache Software Foundation (ASF) under one or more
3 | // contributor license agreements. See the NOTICE file distributed with
4 | // this work for additional information regarding copyright ownership.
5 | // The ASF licenses this file to You under the Apache License, Version 2.0
6 | // (the "License"); you may not use this file except in compliance with
7 | // the License. You may obtain a copy of the License at
8 | //
9 | // http://www.apache.org/licenses/LICENSE-2.0
10 | //
11 | // Unless required by applicable law or agreed to in writing, software
12 | // distributed under the License is distributed on an "AS IS" BASIS,
13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | // See the License for the specific language governing permissions and
15 | // limitations under the License.
16 | //
17 |
18 | // Syntax version
19 | syntax = "proto3";
20 |
21 | // Driver Service definition
22 | service DriverService {
23 | // register the worker process with the driver, to tell the driver that the worker has started up
24 | rpc RegisterWorker (RegisterWorkerRequest) returns (RegisterWorkerReply);
25 | // register the worker service host and port
26 | rpc RegisterWorkerService (RegisterWorkerServiceRequest) returns (RegisterWorkerServiceReply);
27 | // register the function result
28 | rpc RegisterFuncResult (FunctionResult) returns (Empty);
29 | }
30 |
31 | // Worker Service
32 | service WorkerService {
33 | // run the given function
34 | rpc RunFunction (Function) returns (Empty);
35 | // stop the worker service
36 | rpc Stop (Empty) returns (Empty);
37 | }
38 |
39 | message RegisterWorkerRequest {
40 | // the job id
41 | string job_id = 1;
42 | // the world rank id
43 | int32 world_rank = 2;
44 | }
45 |
46 | message RegisterWorkerReply {
47 | // all of the node addresses, used to determine the current node's IP address
48 | repeated string node_addresses = 3;
49 | }
50 |
51 | message RegisterWorkerServiceRequest {
52 | // the world rank
53 | int32 world_rank = 1;
54 | // the worker service listening ip
55 | string worker_ip = 2;
56 | // the worker service listening port
57 | int32 worker_port = 3;
58 | }
59 |
60 | message RegisterWorkerServiceReply {
61 | // the ray redis address
62 | string ray_address = 1;
63 | // the ray redis password
64 | string redis_password = 2;
65 | }
66 |
67 | message Function {
68 | // the function id
69 | int32 func_id = 1;
70 | // the serialized python function
71 | bytes func = 2;
72 | }
73 |
74 | message FunctionResult {
75 | int32 world_rank = 1;
76 | // the function id
77 | int32 func_id = 2;
78 | // the function results
79 | bytes result = 3;
80 | }
81 |
82 | message Empty {
83 | }
84 |
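
For orientation, here is a minimal client-side sketch of the driver service defined above, using Python gRPC stubs. It assumes the modules generated by grpcio-tools (`network_pb2`, `network_pb2_grpc`) and a placeholder driver address:

```python
import grpc

import network_pb2
import network_pb2_grpc

# Placeholder address; in RayDP the driver's host and port are provided at runtime.
channel = grpc.insecure_channel("127.0.0.1:25000")
driver = network_pb2_grpc.DriverServiceStub(channel)

# A worker announces itself to the driver after it has started up.
reply = driver.RegisterWorker(
    network_pb2.RegisterWorkerRequest(job_id="job-0", world_rank=0))
print(reply.node_addresses)  # used to determine the current node's IP address
```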
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayDPDriverAgent.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import io.ray.runtime.config.RayConfig
21 |
22 | import org.apache.spark.{SecurityManager, SparkConf, SparkContext}
23 | import org.apache.spark.internal.Logging
24 | import org.apache.spark.rpc._
25 |
26 |
27 | class RayDPDriverAgent() {
28 | private val spark = SparkContext.getOrCreate()
29 | private var endpoint: RpcEndpointRef = _
30 | private var rpcEnv: RpcEnv = _
31 | private val conf: SparkConf = new SparkConf()
32 |
33 | init()
34 |
35 | def init(): Unit = {
36 | val securityMgr = new SecurityManager(conf)
37 | val host = RayConfig.create().nodeIp
38 | rpcEnv = RpcEnv.create(
39 | RayAppMaster.ENV_NAME,
40 | host,
41 | host,
42 | 0,
43 | conf,
44 | securityMgr,
45 | // limit to single-thread
46 | numUsableCores = 1,
47 | clientMode = false)
48 | // register endpoint
49 | endpoint = rpcEnv.setupEndpoint(RayDPDriverAgent.ENDPOINT_NAME,
50 | new RayDPDriverAgentEndpoint(rpcEnv))
51 | }
52 |
53 | def getDriverAgentEndpointUrl(): String = {
54 | RpcEndpointAddress(rpcEnv.address, RayDPDriverAgent.ENDPOINT_NAME).toString
55 | }
56 |
57 | class RayDPDriverAgentEndpoint(override val rpcEnv: RpcEnv)
58 | extends ThreadSafeRpcEndpoint with Logging {
59 | override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
60 | case RecacheRDD(rddId) =>
61 | // TODO if multiple blocks get lost, should call this only once
62 | // SparkEnv.get.blockManagerMaster.getLocationsAndStatus()
63 | spark.getPersistentRDDs.map {
64 | case (id, rdd) =>
65 | if (id == rddId) {
66 | rdd.count
67 | }
68 | }
69 | context.reply(true)
70 | }
71 | }
72 |
73 | }
74 |
75 | object RayDPDriverAgent {
76 | val ENDPOINT_NAME = "RAYDP_DRIVER_AGENT"
77 | }
78 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/test/org/apache/spark/scheduler/cluster/raydp/TestRayCoarseGrainedSchedulerBackend.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | package org.apache.spark.scheduler.cluster.raydp;
18 |
19 | import org.apache.spark.SparkConf;
20 | import org.junit.jupiter.api.Test;
21 |
22 | import org.apache.spark.scheduler.cluster.SchedulerBackendUtils;
23 |
24 | import static org.junit.jupiter.api.Assertions.assertEquals;
25 |
26 | /**
27 | * This class performs unit testing on some methods in `RayCoarseGrainedSchedulerBackend`.
28 | */
29 | public class TestRayCoarseGrainedSchedulerBackend {
30 |
31 | // Test using the default value.
32 | @Test
33 | public void testExecutorNumberWithDefaultConfig() {
34 | SparkConf conf = new SparkConf();
35 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2);
36 | assertEquals(2, executorNumber);
37 | }
38 |
39 | // Test using a negative value.
40 | @Test
41 | public void testExecutorNumberWithNegativeConfig() {
42 | SparkConf conf = new SparkConf();
43 | conf.set("spark.dynamicAllocation.initialExecutors", "-1");
44 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2);
45 | assertEquals(2, executorNumber);
46 | }
47 |
48 | // Test using reasonable values.
49 | @Test
50 | public void testExecutorNumberWithValidConfig() {
51 | SparkConf conf = new SparkConf();
52 | conf.set("spark.executor.instances", "5");
53 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2);
54 | assertEquals(5, executorNumber);
55 | }
56 |
57 | // Test using dynamic values.
58 | @Test
59 | public void testExecutorNumberWithDynamicConfig() {
60 | SparkConf conf = new SparkConf();
61 | conf.set("spark.dynamicAllocation.enabled", "true");
62 | conf.set("spark.dynamicAllocation.minExecutors", "3");
63 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2);
64 | assertEquals(3, executorNumber);
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/examples/test_pyfiles_main.py:
--------------------------------------------------------------------------------
1 | """
2 | Main script for testing raydp-submit --py-files functionality.
3 | This script imports and uses functions from test_pyfile.py.
4 | """
5 | import sys
6 | import raydp
7 |
8 |
9 | def main():
10 | """Test that py-files are properly distributed and accessible."""
11 | print("=" * 60)
12 | print("Testing raydp-submit --py-files functionality")
13 | print("=" * 60)
14 |
15 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access
16 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP.
17 | java_opts = " ".join([
18 | "-XX:+IgnoreUnrecognizedVMOptions",
19 | "--add-opens=java.base/java.lang=ALL-UNNAMED",
20 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",
21 | "--add-opens=java.base/java.io=ALL-UNNAMED",
22 | "--add-opens=java.base/java.net=ALL-UNNAMED",
23 | "--add-opens=java.base/java.nio=ALL-UNNAMED",
24 | "--add-opens=java.base/java.math=ALL-UNNAMED",
25 | "--add-opens=java.base/java.text=ALL-UNNAMED",
26 | "--add-opens=java.base/java.util=ALL-UNNAMED",
27 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED",
28 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED",
29 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",
30 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED",
31 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED",
32 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED",
33 | ])
34 | extra_configs = {
35 | "spark.executor.extraJavaOptions": java_opts,
36 | "spark.driver.extraJavaOptions": java_opts,
37 | "spark.ray.raydp_app_master.extraJavaOptions": java_opts,
38 | }
39 | spark = raydp.init_spark("Test PyFiles", 1, 1, "500M",
40 | configs=extra_configs)
41 |
42 | # Test: Use compute_sum in Spark executor context
43 | print("\nTesting py-files in Spark executor context...")
44 | # Create RDD and use function from test_pyfile in executors
45 | numbers_rdd = spark.sparkContext.parallelize([1, 2, 3, 4, 5], 2)
46 |
47 | # Map function that uses imported function
48 | def process_partition(partition):
49 | from test_pyfile import compute_sum # pylint: disable=import-outside-toplevel
50 | nums = list(partition)
51 | return [compute_sum(nums)]
52 |
53 | results = numbers_rdd.mapPartitions(process_partition).collect()
54 | total = sum(results)
55 | print(f"Sum from executors: {total}")
56 | assert total == 15, f"Expected 15, got {total}"
57 | print("✓ Functions from py-files work in Spark executors!")
58 |
59 | raydp.stop_spark()
60 |
61 | print("\n" + "=" * 60)
62 | print("Test passed! --py-files is working correctly!")
63 | print("=" * 60)
64 | return 0
65 |
66 |
67 | if __name__ == "__main__":
68 | sys.exit(main())
69 |
--------------------------------------------------------------------------------
/python/raydp/mpi/network/network_pb2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: network.proto
4 | """Generated protocol buffer code."""
5 | from google.protobuf import descriptor as _descriptor
6 | from google.protobuf import descriptor_pool as _descriptor_pool
7 | from google.protobuf import symbol_database as _symbol_database
8 | from google.protobuf.internal import builder as _builder
9 | # @@protoc_insertion_point(imports)
10 |
11 | _sym_db = _symbol_database.Default()
12 |
13 |
14 |
15 |
16 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rnetwork.proto\";\n\x15RegisterWorkerRequest\x12\x0e\n\x06job_id\x18\x01 \x01(\t\x12\x12\n\nworld_rank\x18\x02 \x01(\x05\"-\n\x13RegisterWorkerReply\x12\x16\n\x0enode_addresses\x18\x03 \x03(\t\"Z\n\x1cRegisterWorkerServiceRequest\x12\x12\n\nworld_rank\x18\x01 \x01(\x05\x12\x11\n\tworker_ip\x18\x02 \x01(\t\x12\x13\n\x0bworker_port\x18\x03 \x01(\x05\"I\n\x1aRegisterWorkerServiceReply\x12\x13\n\x0bray_address\x18\x01 \x01(\t\x12\x16\n\x0eredis_password\x18\x02 \x01(\t\")\n\x08\x46unction\x12\x0f\n\x07\x66unc_id\x18\x01 \x01(\x05\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"E\n\x0e\x46unctionResult\x12\x12\n\nworld_rank\x18\x01 \x01(\x05\x12\x0f\n\x07\x66unc_id\x18\x02 \x01(\x05\x12\x0e\n\x06result\x18\x03 \x01(\x0c\"\x07\n\x05\x45mpty2\xd3\x01\n\rDriverService\x12>\n\x0eRegisterWorker\x12\x16.RegisterWorkerRequest\x1a\x14.RegisterWorkerReply\x12S\n\x15RegisterWorkerService\x12\x1d.RegisterWorkerServiceRequest\x1a\x1b.RegisterWorkerServiceReply\x12-\n\x12RegisterFuncResult\x12\x0f.FunctionResult\x1a\x06.Empty2I\n\rWorkerService\x12 \n\x0bRunFunction\x12\t.Function\x1a\x06.Empty\x12\x16\n\x04Stop\x12\x06.Empty\x1a\x06.Emptyb\x06proto3')
17 |
18 | _globals = globals()
19 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
20 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'network_pb2', _globals)
21 | if _descriptor._USE_C_DESCRIPTORS == False:
22 |
23 | DESCRIPTOR._options = None
24 | _globals['_REGISTERWORKERREQUEST']._serialized_start=17
25 | _globals['_REGISTERWORKERREQUEST']._serialized_end=76
26 | _globals['_REGISTERWORKERREPLY']._serialized_start=78
27 | _globals['_REGISTERWORKERREPLY']._serialized_end=123
28 | _globals['_REGISTERWORKERSERVICEREQUEST']._serialized_start=125
29 | _globals['_REGISTERWORKERSERVICEREQUEST']._serialized_end=215
30 | _globals['_REGISTERWORKERSERVICEREPLY']._serialized_start=217
31 | _globals['_REGISTERWORKERSERVICEREPLY']._serialized_end=290
32 | _globals['_FUNCTION']._serialized_start=292
33 | _globals['_FUNCTION']._serialized_end=333
34 | _globals['_FUNCTIONRESULT']._serialized_start=335
35 | _globals['_FUNCTIONRESULT']._serialized_end=404
36 | _globals['_EMPTY']._serialized_start=406
37 | _globals['_EMPTY']._serialized_end=413
38 | _globals['_DRIVERSERVICE']._serialized_start=416
39 | _globals['_DRIVERSERVICE']._serialized_end=627
40 | _globals['_WORKERSERVICE']._serialized_start=629
41 | _globals['_WORKERSERVICE']._serialized_end=702
42 | # @@protoc_insertion_point(module_scope)
43 |
--------------------------------------------------------------------------------
/python/raydp/ray_cluster_resources.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from typing import Dict, List
19 |
20 | import ray
21 | import time
22 | from ray.ray_constants import MEMORY_RESOURCE_UNIT_BYTES
23 |
24 |
25 | class ClusterResources:
26 | # TODO: make this configurable
27 | refresh_interval = 0.1
28 | latest_refresh_time = time.time() - refresh_interval
29 | node_to_resources = {}
30 | item_keys_mapping = {"num_cpus": "CPU"}
31 | label_name = "__ray_spark_node_label"
32 |
33 | @classmethod
34 | def total_alive_nodes(cls):
35 | cls._refresh()
36 | return len(cls.node_to_resources)
37 |
38 | @classmethod
39 | def satisfy(cls, request: Dict[str, float]) -> List[str]:
40 | cls._refresh()
41 | satisfied = []
42 | for host_name, resources in cls.node_to_resources.items():
43 | if cls._compare_two_dict(resources, request):
44 | satisfied.append(resources[cls.label_name])
45 |
46 | return satisfied
47 |
48 | @classmethod
49 | def _refresh(cls):
50 | if (time.time() - cls.latest_refresh_time) < cls.refresh_interval:
51 | return
52 |
53 | for node in ray.nodes():
54 | if node["Alive"]:
55 | host_name = node["NodeManagerHostname"]
56 | resources = node["Resources"]
57 | for key in resources:
58 | if key.startswith("node:"):
59 | resources[cls.label_name] = key
60 | break
61 | assert cls.label_name in resources,\
62 | f"{resources} should contain a resource likes: 'node:10.0.0.131': 1.0"
63 | cls.node_to_resources[host_name] = resources
64 | cls.latest_refresh_time = time.time()
65 |
66 | @classmethod
67 | def _compare_two_dict(cls, available: Dict[str, float], request: Dict[str, float]) -> bool:
68 | for k, v in request.items():
69 | k = cls.item_keys_mapping.get(k, k)
70 | if k not in available:
71 | return False
72 |
73 | if k == "memory":
74 | v = int(v / MEMORY_RESOURCE_UNIT_BYTES)
75 |
76 | if available[k] < v:
77 | return False
78 |
79 | return True
80 |
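
As a rough illustration of how this helper is queried (assuming Ray is already running and reachable at `address="auto"`; the request values are arbitrary):

```python
import ray

from raydp.ray_cluster_resources import ClusterResources

ray.init(address="auto")
print(ClusterResources.total_alive_nodes())
# Returns the node labels (e.g. "node:10.0.0.131") of nodes whose resources can
# satisfy the request; "memory" is expressed in bytes.
print(ClusterResources.satisfy({"num_cpus": 1, "memory": 512 * 1024 * 1024}))
```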
--------------------------------------------------------------------------------
/core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShimLoader.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package com.intel.raydp.shims
19 |
20 | import java.util.ServiceLoader
21 |
22 | import scala.collection.JavaConverters._
23 |
24 | import org.apache.spark.SPARK_VERSION_SHORT
25 | import org.apache.spark.internal.Logging
26 |
27 | object SparkShimLoader extends Logging {
28 | private var sparkShims: SparkShims = null
29 | private var sparkShimProviderClass: String = null
30 |
31 | def getSparkShims: SparkShims = {
32 | if (sparkShims == null) {
33 | val provider = getSparkShimProvider()
34 | sparkShims = provider.createShim
35 | }
36 | sparkShims
37 | }
38 |
39 | def getSparkVersion: String = {
40 | SPARK_VERSION_SHORT
41 | }
42 |
43 | def setSparkShimProviderClass(providerClass: String): Unit = {
44 | sparkShimProviderClass = providerClass
45 | }
46 |
47 | private def loadSparkShimProvider(): SparkShimProvider = {
48 | // Match and load Shim provider for current Spark version.
49 | val sparkVersion = getSparkVersion
50 | logInfo(s"Loading Spark Shims for version: $sparkVersion")
51 |
52 | // Load and filter the providers based on version
53 | val shimProviders =
54 | ServiceLoader.load(classOf[SparkShimProvider]).asScala.filter(_.matches(sparkVersion))
55 | if (shimProviders.size > 1) {
56 | throw new IllegalStateException(s"More than one SparkShimProvider found: $shimProviders")
57 | }
58 |
59 | val shimProvider = shimProviders.headOption match {
60 | case Some(shimProvider) => shimProvider
61 | case None =>
62 | throw new IllegalStateException(s"No Spark Shim Provider found for $sparkVersion")
63 | }
64 | logInfo(s"Using Shim provider: $shimProviders")
65 | shimProvider
66 | }
67 |
68 | private def getSparkShimProvider(): SparkShimProvider = {
69 | if (sparkShimProviderClass != null) {
70 | logInfo(s"Using Spark Shim Provider specified by $sparkShimProviderClass. ")
71 | val providerClass = Class.forName(sparkShimProviderClass)
72 | val providerConstructor = providerClass.getConstructor()
73 | providerConstructor.newInstance().asInstanceOf[SparkShimProvider]
74 | } else {
75 | loadSparkShimProvider()
76 | }
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/python/raydp/tests/test_tf.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import pyspark
19 | import pytest
20 | import os
21 | import sys
22 | import shutil
23 |
24 | import tensorflow as tf
25 | import tensorflow.keras as keras
26 |
27 | from pyspark.sql.functions import rand
28 |
29 | from raydp.tf import TFEstimator
30 | from raydp.utils import random_split
31 |
32 | @pytest.mark.parametrize("use_fs_directory", [True, False])
33 | def test_tf_estimator(spark_on_ray_small, use_fs_directory):
34 | spark = spark_on_ray_small
35 |
36 | # ---------------- data process with Spark ------------
37 | # calculate y = 3 * x + 4
38 | df: pyspark.sql.DataFrame = spark.range(0, 100000)
39 | df = df.withColumn("x", rand() * 100) # add x column
40 | df = df.withColumn("y", df.x * 3 + rand() + 4) # add y column
41 | df = df.select(df.x, df.y)
42 |
43 | train_df, test_df = random_split(df, [0.7, 0.3])
44 |
45 | # create model
46 | model = keras.Sequential(
47 | [
48 | keras.layers.InputLayer(input_shape=()),
49 | # Add feature dimension, expanding (batch_size,) to (batch_size, 1).
50 | keras.layers.Flatten(),
51 | keras.layers.Dense(1),
52 | ]
53 | )
54 |
55 | optimizer = keras.optimizers.Adam(0.01)
56 | loss = keras.losses.MeanSquaredError()
57 |
58 | estimator = TFEstimator(num_workers=2,
59 | model=model,
60 | optimizer=optimizer,
61 | loss=loss,
62 | metrics=["accuracy", "mse"],
63 | feature_columns="x",
64 | label_columns="y",
65 | batch_size=1000,
66 | num_epochs=2,
67 | use_gpu=False)
68 |
69 | if use_fs_directory:
70 | dir = os.path.dirname(__file__) + "/test_tf"
71 | uri = "file://" + dir
72 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri)
73 | else:
74 | estimator.fit_on_spark(train_df, test_df)
75 | model = estimator.get_model()
76 | result = model(tf.constant([0, 0]))
77 | assert result.shape == (2, 1)
78 | if use_fs_directory:
79 | shutil.rmtree(dir)
80 |
81 | if __name__ == "__main__":
82 | sys.exit(pytest.main(["-v", __file__]))
83 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/AppMasterEntryPoint.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import java.io.{DataOutputStream, File, FileOutputStream}
21 | import java.net.InetAddress
22 | import java.nio.file.Files
23 |
24 | import scala.util.Try
25 |
26 | import py4j.GatewayServer
27 |
28 | import org.apache.spark.internal.Logging
29 |
30 |
31 | class AppMasterEntryPoint {
32 | private val appMaster: AppMasterJavaBridge = new AppMasterJavaBridge()
33 |
34 | def getAppMasterBridge(): AppMasterJavaBridge = {
35 | appMaster
36 | }
37 | }
38 |
39 | object AppMasterEntryPoint extends Logging {
40 | private val localhost = InetAddress.getLoopbackAddress()
41 |
42 | def getGatewayServer(): GatewayServer = {
43 | new GatewayServer.GatewayServerBuilder()
44 | .javaPort(0)
45 | .javaAddress(localhost)
46 | .entryPoint(new AppMasterEntryPoint())
47 | .build()
48 | }
49 |
50 | def main(args: Array[String]): Unit = {
51 |
52 | var server = getGatewayServer()
53 |
54 | while(true) {
55 | if (Try(server.start()).isSuccess) {
56 | val boundPort: Int = server.getListeningPort()
57 | if (boundPort == -1) {
58 | logError(s"${server.getClass} failed to bind; exiting")
59 | System.exit(1)
60 | } else {
61 | logDebug(s"Started PythonGatewayServer on port $boundPort")
62 | }
63 |
64 |
65 | val connectionInfoPath = new File(sys.env("_RAYDP_APPMASTER_CONN_INFO_PATH"))
66 | val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(),
67 | "connection", ".info").toFile()
68 |
69 | val dos = new DataOutputStream(new FileOutputStream(tmpPath))
70 | dos.writeInt(boundPort)
71 | dos.close()
72 |
73 | if (!tmpPath.renameTo(connectionInfoPath)) {
74 | logError(s"Unable to write connection information to $connectionInfoPath.")
75 | System.exit(1)
76 | }
77 |
78 | // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
79 | while (System.in.read() != -1) {
80 | // Do nothing
81 | }
82 | logDebug("Exiting due to broken pipe from Python driver")
83 | System.exit(0)
84 | } else {
85 | server.shutdown()
86 | logError(s"${server.getClass} failed to bind; retrying...")
87 | Thread.sleep(1000)
88 | server = getGatewayServer()
89 | }
90 | }
91 |
92 |
93 |
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/examples/test_raydp_submit_pyfiles.py:
--------------------------------------------------------------------------------
1 | """
2 | Test script for raydp-submit --py-files functionality.
3 | This script mimics raydp-submit.py but tests that --py-files works correctly.
4 | """
5 | from os.path import dirname, abspath, join
6 | import sys
7 | import json
8 | import subprocess
9 | import shlex
10 | import ray
11 |
12 | def main():
13 | print("Starting raydp-submit --py-files test...")
14 |
15 | # Initialize Ray and get cluster info
16 | ray.init(address="auto")
17 | node = ray.worker.global_worker.node
18 | options = {}
19 | options["ray"] = {}
20 | options["ray"]["run-mode"] = "CLUSTER"
21 | options["ray"]["node-ip"] = node.node_ip_address
22 | options["ray"]["address"] = node.address
23 | options["ray"]["session-dir"] = node.get_session_dir_path()
24 |
25 | ray.shutdown()
26 |
27 | # Write Ray configuration
28 | examples_dir = dirname(abspath(__file__))
29 | conf_path = join(examples_dir, "ray.conf")
30 | with open(conf_path, "w") as f:
31 | json.dump(options, f)
32 |
33 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access
34 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP.
35 | java_opts = " ".join([
36 | "-XX:+IgnoreUnrecognizedVMOptions",
37 | "--add-opens=java.base/java.lang=ALL-UNNAMED",
38 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",
39 | "--add-opens=java.base/java.io=ALL-UNNAMED",
40 | "--add-opens=java.base/java.net=ALL-UNNAMED",
41 | "--add-opens=java.base/java.nio=ALL-UNNAMED",
42 | "--add-opens=java.base/java.math=ALL-UNNAMED",
43 | "--add-opens=java.base/java.text=ALL-UNNAMED",
44 | "--add-opens=java.base/java.util=ALL-UNNAMED",
45 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED",
46 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED",
47 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",
48 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED",
49 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED",
50 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED",
51 | ])
52 |
53 | # Build raydp-submit command
54 | command = ["bin/raydp-submit", "--ray-conf", conf_path]
55 | command += ["--conf", "spark.executor.cores=1"]
56 | command += ["--conf", "spark.executor.instances=1"]
57 | command += ["--conf", "spark.executor.memory=500m"]
58 | command += ["--conf", f"spark.executor.extraJavaOptions={java_opts}"]
59 | command += ["--conf", f"spark.driver.extraJavaOptions={java_opts}"]
60 | command += ["--conf", f"spark.ray.raydp_app_master.extraJavaOptions={java_opts}"]
61 |
62 | # Add --py-files with test_pyfile.py
63 | test_pyfile_path = join(examples_dir, "test_pyfile.py")
64 | command += ["--py-files", test_pyfile_path]
65 |
66 | # Add the main script
67 | main_script_path = join(examples_dir, "test_pyfiles_main.py")
68 | command.append(main_script_path)
69 |
70 | # Execute the command
71 | print("\nExecuting command:")
72 | cmd_str = " ".join(shlex.quote(arg) for arg in command)
73 | print(cmd_str)
74 | print("\n" + "=" * 60)
75 |
76 | result = subprocess.run(cmd_str, check=True, shell=True)
77 |
78 | return result.returncode
79 |
80 |
81 | if __name__ == "__main__":
82 | sys.exit(main())
83 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/AppMasterJavaBridge.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.deploy.raydp
19 |
20 | import java.util.Map
21 |
22 | import scala.collection.JavaConverters._
23 |
24 | import io.ray.api.{ActorHandle, Ray}
25 |
26 | import org.apache.spark.raydp.SparkOnRayConfigs
27 |
28 | class AppMasterJavaBridge {
29 | private var handle: ActorHandle[RayAppMaster] = null
30 |
31 | def startUpAppMaster(extra_cp: String, sparkProps: Map[String, Any]): Unit = {
32 | if (handle == null) {
33 | // init Ray; the Ray configuration is expected to be set via Java system properties
34 | Ray.init()
35 | val name = RayAppMaster.ACTOR_NAME
36 | val sparkJvmOptions = sparkProps.asScala.toMap.filter(
37 | e => {
38 | !SparkOnRayConfigs.SPARK_DRIVER_EXTRA_JAVA_OPTIONS.equals(e._1) &&
39 | !SparkOnRayConfigs.SPARK_APP_MASTER_EXTRA_JAVA_OPTIONS.equals(e._1)
40 | })
41 | .map {
42 | case (k, v) =>
43 | if (!SparkOnRayConfigs.SPARK_JAVAAGENT.equals(k)) {
44 | "-D" + k + "=" + v
45 | } else {
46 | "-javaagent:" + v
47 | }
48 | }.toBuffer
49 |
50 | // Add raw JVM options from spark.ray.raydp_app_master.extraJavaOptions
51 | // (e.g., --add-opens for JDK 17)
52 | val appMasterExtraJavaOptions =
53 | sparkProps.get(SparkOnRayConfigs.SPARK_APP_MASTER_EXTRA_JAVA_OPTIONS)
54 | if (appMasterExtraJavaOptions != null) {
55 | val opts = appMasterExtraJavaOptions.toString.trim
56 | if (opts.nonEmpty) {
57 | sparkJvmOptions ++= opts.split("\\s+").toSeq
58 | }
59 | }
60 |
61 | val appMasterResources = sparkProps.asScala.filter {
62 | case (k, v) => k.startsWith(SparkOnRayConfigs.SPARK_MASTER_ACTOR_RESOURCE_PREFIX)
63 | }.map{ case (k, v) => k->double2Double(v.toString.toDouble) }.asJava
64 |
65 | handle = RayAppMasterUtils.createAppMaster(
66 | extra_cp, name,
67 | (sparkJvmOptions ++ Seq(SparkOnRayConfigs.RAYDP_LOGFILE_PREFIX_CFG)).asJava,
68 | appMasterResources)
69 | }
70 | }
71 |
72 | def getMasterUrl(): String = {
73 | if (handle == null) {
74 | throw new RuntimeException("You should create the RayAppMaster handle first")
75 | }
76 | RayAppMasterUtils.getMasterUrl(handle)
77 | }
78 |
79 | def stop(): Unit = {
80 | if (handle != null) {
81 | RayAppMasterUtils.stopAppMaster(handle)
82 | Ray.shutdown()
83 | handle = null
84 | }
85 | }
86 | }
87 |
--------------------------------------------------------------------------------
/core/shims/spark340/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 | xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 | xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 | <modelVersion>4.0.0</modelVersion>
6 | 
7 | <parent>
8 | <groupId>com.intel</groupId>
9 | <artifactId>raydp-shims</artifactId>
10 | <version>1.7.0-SNAPSHOT</version>
11 | <relativePath>../pom.xml</relativePath>
12 | </parent>
13 | 
14 | <artifactId>raydp-shims-spark340</artifactId>
15 | <name>RayDP Shims for Spark 3.4.0</name>
16 | <packaging>jar</packaging>
17 | 
18 | <properties>
19 | <scala212.version>2.12.15</scala212.version>
20 | <scala213.version>2.13.5</scala213.version>
21 | </properties>
22 | 
23 | <build>
24 | <plugins>
25 | <plugin>
26 | <groupId>org.scalastyle</groupId>
27 | <artifactId>scalastyle-maven-plugin</artifactId>
28 | </plugin>
29 | <plugin>
30 | <groupId>net.alchim31.maven</groupId>
31 | <artifactId>scala-maven-plugin</artifactId>
32 | <version>3.2.2</version>
33 | <executions>
34 | <execution>
35 | <id>scala-compile-first</id>
36 | <phase>process-resources</phase>
37 | <goals>
38 | <goal>compile</goal>
39 | </goals>
40 | </execution>
41 | <execution>
42 | <id>scala-test-compile-first</id>
43 | <phase>process-test-resources</phase>
44 | <goals>
45 | <goal>testCompile</goal>
46 | </goals>
47 | </execution>
48 | </executions>
49 | </plugin>
50 | </plugins>
51 | 
52 | <resources>
53 | <resource>
54 | <directory>src/main/resources</directory>
55 | </resource>
56 | </resources>
57 | </build>
58 | 
59 | <dependencies>
60 | <dependency>
61 | <groupId>com.intel</groupId>
62 | <artifactId>raydp-shims-common</artifactId>
63 | <version>${project.version}</version>
64 | <scope>compile</scope>
65 | </dependency>
66 | <dependency>
67 | <groupId>org.apache.spark</groupId>
68 | <artifactId>spark-sql_${scala.binary.version}</artifactId>
69 | <version>${spark340.version}</version>
70 | <scope>provided</scope>
71 | </dependency>
72 | <dependency>
73 | <groupId>org.apache.spark</groupId>
74 | <artifactId>spark-core_${scala.binary.version}</artifactId>
75 | <version>${spark340.version}</version>
76 | <scope>provided</scope>
77 | <exclusions>
78 | <exclusion>
79 | <groupId>org.xerial.snappy</groupId>
80 | <artifactId>snappy-java</artifactId>
81 | </exclusion>
82 | <exclusion>
83 | <groupId>io.netty</groupId>
84 | <artifactId>netty-handler</artifactId>
85 | </exclusion>
86 | </exclusions>
87 | </dependency>
88 | <dependency>
89 | <groupId>org.xerial.snappy</groupId>
90 | <artifactId>snappy-java</artifactId>
91 | <version>${snappy.version}</version>
92 | </dependency>
93 | <dependency>
94 | <groupId>io.netty</groupId>
95 | <artifactId>netty-handler</artifactId>
96 | <version>${netty.version}</version>
97 | </dependency>
98 | </dependencies>
99 | </project>
--------------------------------------------------------------------------------
/core/shims/spark350/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 | xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 | xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 | <modelVersion>4.0.0</modelVersion>
6 | 
7 | <parent>
8 | <groupId>com.intel</groupId>
9 | <artifactId>raydp-shims</artifactId>
10 | <version>1.7.0-SNAPSHOT</version>
11 | <relativePath>../pom.xml</relativePath>
12 | </parent>
13 | 
14 | <artifactId>raydp-shims-spark350</artifactId>
15 | <name>RayDP Shims for Spark 3.5.0</name>
16 | <packaging>jar</packaging>
17 | 
18 | <properties>
19 | <scala212.version>2.12.15</scala212.version>
20 | <scala213.version>2.13.5</scala213.version>
21 | </properties>
22 | 
23 | <build>
24 | <plugins>
25 | <plugin>
26 | <groupId>org.scalastyle</groupId>
27 | <artifactId>scalastyle-maven-plugin</artifactId>
28 | </plugin>
29 | <plugin>
30 | <groupId>net.alchim31.maven</groupId>
31 | <artifactId>scala-maven-plugin</artifactId>
32 | <version>3.2.2</version>
33 | <executions>
34 | <execution>
35 | <id>scala-compile-first</id>
36 | <phase>process-resources</phase>
37 | <goals>
38 | <goal>compile</goal>
39 | </goals>
40 | </execution>
41 | <execution>
42 | <id>scala-test-compile-first</id>
43 | <phase>process-test-resources</phase>
44 | <goals>
45 | <goal>testCompile</goal>
46 | </goals>
47 | </execution>
48 | </executions>
49 | </plugin>
50 | </plugins>
51 | 
52 | <resources>
53 | <resource>
54 | <directory>src/main/resources</directory>
55 | </resource>
56 | </resources>
57 | </build>
58 | 
59 | <dependencies>
60 | <dependency>
61 | <groupId>com.intel</groupId>
62 | <artifactId>raydp-shims-common</artifactId>
63 | <version>${project.version}</version>
64 | <scope>compile</scope>
65 | </dependency>
66 | <dependency>
67 | <groupId>org.apache.spark</groupId>
68 | <artifactId>spark-sql_${scala.binary.version}</artifactId>
69 | <version>${spark350.version}</version>
70 | <scope>provided</scope>
71 | </dependency>
72 | <dependency>
73 | <groupId>org.apache.spark</groupId>
74 | <artifactId>spark-core_${scala.binary.version}</artifactId>
75 | <version>${spark350.version}</version>
76 | <scope>provided</scope>
77 | <exclusions>
78 | <exclusion>
79 | <groupId>org.xerial.snappy</groupId>
80 | <artifactId>snappy-java</artifactId>
81 | </exclusion>
82 | <exclusion>
83 | <groupId>io.netty</groupId>
84 | <artifactId>netty-handler</artifactId>
85 | </exclusion>
86 | </exclusions>
87 | </dependency>
88 | <dependency>
89 | <groupId>org.xerial.snappy</groupId>
90 | <artifactId>snappy-java</artifactId>
91 | <version>${snappy.version}</version>
92 | </dependency>
93 | <dependency>
94 | <groupId>io.netty</groupId>
95 | <artifactId>netty-handler</artifactId>
96 | <version>${netty.version}</version>
97 | </dependency>
98 | </dependencies>
99 | </project>
--------------------------------------------------------------------------------
/examples/xgboost_ray_nyctaxi.py:
--------------------------------------------------------------------------------
1 | import ray
2 | import numpy as np
3 | # XGBoost on ray is needed to run this example.
4 | # Please refer to https://docs.ray.io/en/latest/xgboost-ray.html to install it.
5 | from xgboost_ray import RayDMatrix, train, RayParams
6 | import raydp
7 | from raydp.utils import random_split
8 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV
9 |
10 | # connect to ray cluster
11 | # ray.init(address="auto")
12 | ray.init(address="local", num_cpus=4)
13 | # After ray.init, you can use the raydp api to get a spark session
14 | app_name = "NYC Taxi Fare Prediction with RayDP"
15 | num_executors = 1
16 | cores_per_executor = 1
17 | memory_per_executor = "500M"
18 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access
19 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP.
20 | java_opts = " ".join([
21 | "-XX:+IgnoreUnrecognizedVMOptions",
22 | "--add-opens=java.base/java.lang=ALL-UNNAMED",
23 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",
24 | "--add-opens=java.base/java.io=ALL-UNNAMED",
25 | "--add-opens=java.base/java.net=ALL-UNNAMED",
26 | "--add-opens=java.base/java.nio=ALL-UNNAMED",
27 | "--add-opens=java.base/java.math=ALL-UNNAMED",
28 | "--add-opens=java.base/java.text=ALL-UNNAMED",
29 | "--add-opens=java.base/java.util=ALL-UNNAMED",
30 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED",
31 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED",
32 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",
33 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED",
34 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED",
35 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED",
36 | ])
37 | extra_configs = {
38 | "spark.executor.extraJavaOptions": java_opts,
39 | "spark.driver.extraJavaOptions": java_opts,
40 | "spark.ray.raydp_app_master.extraJavaOptions": java_opts,
41 | }
42 | spark = raydp.init_spark(app_name, num_executors,
43 | cores_per_executor, memory_per_executor,
44 | configs=extra_configs)
45 | data = spark.read.format("csv").option("header", "true") \
46 | .option("inferSchema", "true") \
47 | .load(NYC_TRAIN_CSV)
48 | # Set spark timezone for processing datetime
49 | spark.conf.set("spark.sql.session.timeZone", "UTC")
50 | # Transform the dataset
51 | data = nyc_taxi_preprocess(data)
52 | # Split data into train_dataset and test_dataset
53 | train_df, test_df = random_split(data, [0.9, 0.1], 0)
54 | # Convert spark dataframe into ray dataset
55 | train_dataset = ray.data.from_spark(train_df)
56 | test_dataset = ray.data.from_spark(test_df)
57 | # Then convert them into DMatrix used by xgboost
58 | dtrain = RayDMatrix(train_dataset, label="fare_amount")
59 | dtest = RayDMatrix(test_dataset, label="fare_amount")
60 | # Configure the XGBoost model
61 | config = {
62 | "tree_method": "hist",
63 | "eval_metric": ["logloss", "error"],
64 | }
65 | evals_result = {}
66 | # Train the model
67 | bst = train(
68 | config,
69 | dtrain,
70 | evals=[(dtest, "eval")],
71 | evals_result=evals_result,
72 | ray_params=RayParams(max_actor_restarts=1, num_actors=1, cpus_per_actor=1),
73 | num_boost_round=10)
74 | # print evaluation stats
75 | print("Final validation error: {:.4f}".format(
76 | evals_result["eval"]["error"][-1]))
77 | raydp.stop_spark()
78 | ray.shutdown()
79 |
--------------------------------------------------------------------------------
/python/raydp/services.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from abc import ABC, abstractmethod
19 | from typing import Any, Dict, NoReturn
20 |
21 |
22 | class Cluster(ABC):
23 | """
24 | This is the base class for all concrete clusters, such as SparkCluster and FlinkCluster.
25 | :param master_resources_requirement: The resource requirements for the master service.
26 | """
27 | def __init__(self, master_resources_requirement):
28 | # The master runs alongside the Ray driver node, and we can specify the resource
29 | # limits for the master separately, so we don't count it here.
30 | self._num_nodes = 0
31 |
32 | @abstractmethod
33 | def _set_up_master(self,
34 | resources: Dict[str, float],
35 | kwargs: Dict[Any, Any]):
36 | """
37 | Subclasses should implement this to set up the master node.
38 | """
39 |
40 | def add_worker(self,
41 | resources_requirement: Dict[str, float],
42 | **kwargs: Dict[Any, Any]):
43 | """
44 | Add one worker to the cluster.
45 | :param resources_requirement: The resource requirements for the worker service.
46 | """
47 | try:
48 | self._set_up_worker(resources_requirement, kwargs)
49 | except:
50 | self.stop()
51 | raise
52 |
53 | @abstractmethod
54 | def _set_up_worker(self,
55 | resources: Dict[str, float],
56 | kwargs: Dict[str, str]):
57 | """
58 | Subclasses should implement this to set up a worker node.
59 | """
60 |
61 | @abstractmethod
62 | def get_cluster_url(self) -> str:
63 | """
64 | Return the cluster URL, e.g. spark://master-node:7077
65 | """
66 |
67 | @abstractmethod
68 | def stop(self):
69 | """
70 | Stop cluster
71 | """
72 |
73 |
74 | class ClusterMaster(ABC):
75 |
76 | @abstractmethod
77 | def start_up(self) -> None:
78 | pass
79 |
80 | @abstractmethod
81 | def get_master_url(self) -> str:
82 | pass
83 |
84 | @abstractmethod
85 | def get_host(self) -> str:
86 | pass
87 |
88 | @abstractmethod
89 | def stop(self):
90 | pass
91 |
92 |
93 | class ClusterWorker(ABC):
94 |
95 | @abstractmethod
96 | def start_up(self) -> str:
97 | """
98 | :return: the error message, or None if it succeeded
99 | """
100 |
101 | @abstractmethod
102 | def get_host(self) -> str:
103 | pass
104 |
105 | @abstractmethod
106 | def stop(self):
107 | pass
108 |
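109 | # The class below is an illustrative sketch only (an assumption, not part of RayDP):
110 | # it shows how a concrete cluster could satisfy the abstract Cluster contract above.
111 | # The names DummyCluster and "dummy://" are hypothetical placeholders.
112 | class DummyCluster(Cluster):
113 |     def __init__(self, master_resources_requirement=None):
114 |         super().__init__(master_resources_requirement)
115 |         self._master = None
116 |         self._workers = []
117 |         self._set_up_master(master_resources_requirement or {}, {})
118 | 
119 |     def _set_up_master(self, resources: Dict[str, float], kwargs: Dict[Any, Any]):
120 |         self._master = resources  # e.g. start a master actor with these resources
121 | 
122 |     def _set_up_worker(self, resources: Dict[str, float], kwargs: Dict[str, str]):
123 |         self._workers.append(resources)  # e.g. start a worker actor
124 |         self._num_nodes += 1
125 | 
126 |     def get_cluster_url(self) -> str:
127 |         return "dummy://master:7077"
128 | 
129 |     def stop(self):
130 |         self._master = None
131 |         self._workers.clear()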
--------------------------------------------------------------------------------
/python/raydp/tests/test_torch.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from pyspark.sql import SparkSession
19 | import pytest
20 | import os
21 | import sys
22 | import shutil
23 | import torch
24 |
25 | # https://spark.apache.org/docs/latest/api/python/migration_guide/koalas_to_pyspark.html
26 | # import databricks.koalas as ks
27 | import pyspark.pandas as ps
28 |
29 | from raydp.torch import TorchEstimator
30 | from raydp.utils import random_split
31 |
32 | @pytest.mark.parametrize("use_fs_directory", [True, False])
33 | def test_torch_estimator(spark_on_ray_small, use_fs_directory):
34 | # ---------------- data process with koalas ------------
35 | spark: SparkSession = spark_on_ray_small
36 |
37 | # calculate z = 3 * x + 4 * y + 5
38 | df = ps.range(0, 100000)
39 | df["x"] = df["id"] + 100
40 | df["y"] = df["id"] + 1000
41 | df["z"] = df["x"] * 3 + df["y"] * 4 + 5
42 | df = df.astype("float")
43 |
44 | train_df, test_df = random_split(df, [0.7, 0.3])
45 |
46 | # ---------------- ray sgd -------------------------
47 | # create the model
48 | class LinearModel(torch.nn.Module):
49 | def __init__(self):
50 | super(LinearModel, self).__init__()
51 | self.linear = torch.nn.Linear(2, 1)
52 |
53 | def forward(self, x):
54 | return self.linear(x)
55 |
56 | model = LinearModel()
57 | # create the optimizer
58 | optimizer = torch.optim.Adam(model.parameters())
59 | # create the loss
60 | loss = torch.nn.MSELoss()
61 | # create lr_scheduler
62 |
63 | def lr_scheduler_creator(optimizer, config):
64 | return torch.optim.lr_scheduler.MultiStepLR(
65 | optimizer, milestones=[150, 250, 350], gamma=0.1)
66 |
67 | # create the estimator
68 | estimator = TorchEstimator(num_workers=2,
69 | model=model,
70 | optimizer=optimizer,
71 | loss=loss,
72 | lr_scheduler_creator=lr_scheduler_creator,
73 | feature_columns=["x", "y"],
74 | feature_types=torch.float,
75 | label_column="z",
76 | label_type=torch.float,
77 | batch_size=1000,
78 | num_epochs=2,
79 | use_gpu=False)
80 |
81 | # train the model
82 | if use_fs_directory:
83 | dir = os.path.dirname(__file__) + "/test_torch"
84 | uri = "file://" + dir
85 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri)
86 | else:
87 | estimator.fit_on_spark(train_df, test_df)
88 | model = estimator.get_model()
89 | result = model(torch.Tensor([[0, 0], [1, 1]]))
90 | assert result.shape == (2, 1)
91 | if use_fs_directory:
92 | shutil.rmtree(dir)
93 |
94 |
95 | if __name__ == "__main__":
96 | sys.exit(pytest.main(["-v", __file__]))
97 |
--------------------------------------------------------------------------------
/core/raydp-main/src/main/java/org/apache/spark/raydp/RayExecutorUtils.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.raydp;
19 |
20 | import io.ray.api.ActorHandle;
21 | import io.ray.api.ObjectRef;
22 | import io.ray.api.Ray;
23 | import io.ray.api.call.ActorCreator;
24 | import java.util.Map;
25 | import java.util.List;
26 |
27 | import io.ray.api.placementgroup.PlacementGroup;
28 | import io.ray.runtime.object.ObjectRefImpl;
29 | import org.apache.spark.executor.RayDPExecutor;
30 |
31 | public class RayExecutorUtils {
32 | /**
33 | * Convert from MB to Ray memory units. The memory unit in Ray is bytes.
34 | */
35 |
36 | private static double toMemoryUnits(int memoryInMB) {
37 | double result = 1.0 * memoryInMB * 1024 * 1024;
38 | return Math.round(result);
39 | }
40 |
41 | public static ActorHandle createExecutorActor(
42 | String executorId,
43 | String appMasterURL,
44 | double cores,
45 | int memoryInMB,
46 | Map resources,
47 | PlacementGroup placementGroup,
48 | int bundleIndex,
49 | List javaOpts) {
50 | ActorCreator creator = Ray.actor(
51 | RayDPExecutor::new, executorId, appMasterURL);
52 | creator.setName("raydp-executor-" + executorId);
53 | creator.setJvmOptions(javaOpts);
54 | creator.setResource("CPU", cores);
55 | creator.setResource("memory", toMemoryUnits(memoryInMB));
56 |
57 | for (Map.Entry entry: resources.entrySet()) {
58 | creator.setResource(entry.getKey(), entry.getValue());
59 | }
60 | if (placementGroup != null) {
61 | creator.setPlacementGroup(placementGroup, bundleIndex);
62 | }
63 | creator.setMaxRestarts(-1);
64 | creator.setMaxTaskRetries(-1);
65 | creator.setMaxConcurrency(2);
66 | return creator.remote();
67 | }
68 |
69 | public static void setUpExecutor(
70 | ActorHandle handler,
71 | String appId,
72 | String driverUrl,
73 | int cores,
74 | String classPathEntries) {
75 | handler.task(RayDPExecutor::startUp,
76 | appId, driverUrl, cores, classPathEntries).remote();
77 | }
78 |
79 | public static String[] getBlockLocations(
80 | ActorHandle handler,
81 | int rddId,
82 | int numPartitions) {
83 | return handler.task(RayDPExecutor::getBlockLocations,
84 | rddId, numPartitions).remote().get();
85 | }
86 |
87 | public static ObjectRef getRDDPartition(
88 | ActorHandle handle,
89 | int rddId,
90 | int partitionId,
91 | String schema,
92 | String driverAgentUrl) {
93 | return (ObjectRefImpl) handle.task(
94 | RayDPExecutor::getRDDPartition,
95 | rddId, partitionId, schema, driverAgentUrl).remote();
96 | }
97 |
98 | public static void exitExecutor(
99 | ActorHandle handle
100 | ) {
101 | handle.task(RayDPExecutor::stop).remote();
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/doc/mpi.md:
--------------------------------------------------------------------------------
1 | # MPI on Ray
2 |
3 | RayDP also provides a simple API for running MPI jobs on top of Ray. Currently, three types of MPI are supported: `intel_mpi`, `openmpi` and `MPICH`. To use the following API, make sure the chosen MPI type is installed on each Ray worker node.
4 |
5 | ### API
6 |
7 | ```python
8 | def create_mpi_job(job_name: str,
9 | world_size: int,
10 | num_cpus_per_process: int,
11 | num_processes_per_node: int,
12 | mpi_script_prepare_fn: Callable = None,
13 | timeout: int = 1,
14 | mpi_type: str = "intel_mpi",
15 | placement_group=None,
16 | placement_group_bundle_indexes: List[int] = None) -> MPIJob:
17 | """ Create a MPI Job
18 |
19 | :param job_name: the job name
20 | :param world_size: the world size
21 | :param num_cpus_per_process: num cpus per process, used to request resources from Ray
22 | :param num_processes_per_node: num processes per node
23 | :param mpi_script_prepare_fn: a function used to create the mpi script; it will be passed an
24 |        MPIJobContext instance. The default script is used if not provided.
25 | :param timeout: the timeout used to wait for job creation
26 | :param mpi_type: the mpi type, currently only openmpi, intel_mpi and mpich are supported
27 | :param placement_group: the placement group used to request mpi resources
28 | :param placement_group_bundle_indexes: if provided, its length should be equal to
29 |        world_size / num_processes_per_node.
30 | """
31 | ```
32 |
33 | ### Create a simple MPI Job
34 |
35 | ```python
36 | from raydp.mpi import create_mpi_job, MPIJobContext, WorkerContext
37 |
38 | # Define the MPI job. We create an MPIJob with world_size 4, and each process requires 2 cpus.
39 | # Since num_processes_per_node is set to 2, the processes will be strictly spread across two nodes.
40 |
41 | # You can also specify a placement group to reserve the resources for the MPI job. num_cpus_per_process
42 | # is ignored if a placement group is provided, and the size of
43 | # placement_group_bundle_indexes should be equal to world_size // num_processes_per_node.
44 | job = create_mpi_job(job_name="example",
45 | world_size=4,
46 | num_cpus_per_process=2,
47 | num_processes_per_node=2,
48 | timeout=5,
49 | mpi_type="intel_mpi",
50 | placement_group=None,
51 | placement_group_bundle_indexes=None)
52 |
53 | # Start the MPI job. This will start up the MPI processes and connect them to the Ray cluster.
54 | job.start()
55 |
56 | # define the MPI task function
57 | def func(context: WorkerContext):
58 | return context.job_id
59 |
60 | # Run the MPI task. This is a blocking operation, and the result is a list with world_size entries.
61 | results = job.run(func)
62 |
63 | # stop the MPI job
64 | job.stop()
65 | ```
66 |
67 | ### Use `with` to auto start/stop the MPIJob
68 | ```python
69 | with create_mpi_job(job_name="example",
70 | world_size=4,
71 | num_cpus_per_process=2,
72 | num_processes_per_node=2,
73 | timeout=5,
74 | mpi_type="intel_mpi") as job:
75 | def f(context: WorkerContext):
76 | return context.job_id
77 | results = job.run(f)
78 | ```
79 |
80 | ### Specify the MPI script and environments
81 |
82 | You can customize the MPI job environment and the MPI script with the `mpi_script_prepare_fn` argument.
83 |
84 | ```python
85 | def script_prepare_fn(context: MPIJobContext):
86 | context.add_env("OMP_NUM_THREADS", "2")
87 | default_script = ["mpirun", "--allow-run-as-root", "--tag-output", "-H",
88 | ",".join(context.hosts), "-N", f"{context.num_procs_per_node}"]
89 | return default_script
90 |
91 | job = create_mpi_job(job_name="example",
92 | world_size=4,
93 | num_cpus_per_process=2,
94 | num_processes_per_node=2,
95 | timeout=5,
96 | mpi_type="intel_mpi",
97 | mpi_script_prepare_fn=script_prepare_fn)
98 | ```
99 |
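100 | ### Use MPI communication inside the task function
101 | 
102 | The task function runs inside the MPI processes launched by `mpirun`, so normal MPI
103 | collectives can be used there. The snippet below is only an illustrative sketch: it
104 | assumes `mpi4py` is installed on every Ray worker node and is built against the same
105 | MPI type you pass to `create_mpi_job`.
106 | 
107 | ```python
108 | from mpi4py import MPI
109 | 
110 | from raydp.mpi import create_mpi_job, WorkerContext
111 | 
112 | def rank_sum(context: WorkerContext):
113 |     comm = MPI.COMM_WORLD
114 |     # every process contributes its rank; all of them receive the same sum back
115 |     return comm.allreduce(comm.Get_rank(), op=MPI.SUM)
116 | 
117 | with create_mpi_job(job_name="rank_sum",
118 |                     world_size=4,
119 |                     num_cpus_per_process=1,
120 |                     num_processes_per_node=2,
121 |                     timeout=5,
122 |                     mpi_type="openmpi") as job:
123 |     # one entry per MPI process: [6, 6, 6, 6]
124 |     results = job.run(rank_sum)
125 | ```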
--------------------------------------------------------------------------------
/core/agent/src/main/java/org/apache/spark/raydp/Agent.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.raydp;
19 |
20 | import org.slf4j.LoggerFactory;
21 |
22 | import java.io.File;
23 | import java.io.FileOutputStream;
24 | import java.io.IOException;
25 | import java.io.OutputStream;
26 | import java.io.OutputStreamWriter;
27 | import java.io.PrintStream;
28 | import java.io.Writer;
29 | import java.lang.instrument.Instrumentation;
30 | import java.lang.management.ManagementFactory;
31 | import java.nio.charset.Charset;
32 | import java.nio.charset.StandardCharsets;
33 |
34 |
35 | public class Agent {
36 |
37 | public static final PrintStream DEFAULT_ERR_PS = System.err;
38 |
39 | public static final PrintStream DEFAULT_OUT_PS = System.out;
40 |
41 | public static void premain(String agentArgs, Instrumentation inst)
42 | throws IOException {
43 | // Redirect the system output/error streams so that annoying SLF4J warnings
44 | // and other logs emitted while binding the SLF4J factory
45 | // don't show up in spark-shell.
46 | // Instead, the warnings and logs are kept in
47 | // <ray.logging.dir>/slf4j-<pid>.log
48 |
49 | String pid = ManagementFactory.getRuntimeMXBean().getName()
50 | .split("@")[0];
51 | String logDir = System.getProperty("ray.logging.dir");
52 | if (logDir == null) {
53 | logDir = "/tmp/ray/session_latest/logs";
54 | System.getProperties().put("ray.logging.dir", logDir);
55 | }
56 |
57 | File parentDir = new File(logDir);
58 | if (!parentDir.exists()) {
59 | boolean flag = parentDir.mkdirs();
60 | if (!flag) {
61 | throw new RuntimeException("Error creating log dir.");
62 | }
63 | }
64 |
65 | File logFile = new File(parentDir, "/slf4j-" + pid + ".log");
66 | try (PrintStream ps = new PrintStream(logFile, "UTF-8")) {
67 | System.setOut(ps);
68 | System.setErr(ps);
69 | // slf4j binding
70 | LoggerFactory.getLogger(Agent.class);
71 | } catch (Exception e) {
72 | e.printStackTrace();
73 | } finally {
74 | System.out.flush();
75 | System.err.flush();
76 | // restore system output/error stream
77 | System.setErr(DEFAULT_ERR_PS);
78 | System.setOut(DEFAULT_OUT_PS);
79 | }
80 | // Below writes ':job_id:<jobId>' as the first line of a log file prefixed with 'java-worker', as required by
81 | // Ray PR https://github.com/ray-project/ray/pull/31772.
82 | // It's a workaround for a Ray 2.3.[0-1] issue that will be fixed by https://github.com/ray-project/ray/pull/33665.
83 | String jobId = System.getenv("RAY_JOB_ID");
84 | String rayAddress = System.getProperty("ray.address");
85 | if (jobId != null && rayAddress != null) {
86 | String prefix = "java-worker";
87 | // TODO: uncomment after the ray PR #33665 released
88 | // String prefix = System.getProperty("ray.logging.file-prefix", "java-worker");
89 | // if ("java-worker".equals(prefix)) {
90 | File file = new File(new String((logDir + "/" + prefix + "-" + jobId + "-" + pid + ".log")
91 | .getBytes(Charset.forName("UTF-8")), "UTF-8"));
92 | try (OutputStream out = new FileOutputStream(file);
93 | Writer writer = new OutputStreamWriter(out, StandardCharsets.UTF_8)) {
94 | writer.write(":job_id:" + jobId + "\n");
95 | }
96 | // }
97 | }
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/.github/workflows/pypi_release.yml:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | name: RayDP PyPI Release
19 |
20 | on:
21 | workflow_dispatch:
22 | inputs:
23 | tag:
24 | description: 'Git tag to publish (e.g., v1.7.0 or 1.7.0)'
25 | required: true
26 | type: string
27 |
28 | permissions: # added using https://github.com/step-security/secure-repo
29 | contents: read
30 |
31 | jobs:
32 | build-and-publish:
33 | # do not run in forks
34 | if: ${{ github.repository_owner == 'ray-project' }}
35 | name: build wheel and upload release
36 | runs-on: ubuntu-latest
37 | env:
38 | PYSPARK_VERSION: "3.5.7"
39 | RAY_VERSION: "2.40.0"
40 | steps:
41 | - uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master
42 | with:
43 | ref: ${{ inputs.tag }}
44 | fetch-depth: 0
45 | - name: Set up Python 3.10
46 | uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4
47 | with:
48 | python-version: 3.10.14
49 | - name: Set up JDK 1.8
50 | uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4
51 | with:
52 | java-version: 1.8
53 | - name: Install extra dependencies for Ubuntu
54 | run: |
55 | sudo apt-get install -y mpich
56 | - name: Cache pip
57 | uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2
58 | with:
59 | path: ~/.cache/pip
60 | key: ubuntu-latest-3.10.14-pip
61 | - name: Install dependencies
62 | run: |
63 | python -m pip install --upgrade pip
64 | pip install wheel
65 | pip install "numpy<1.24" "click<8.3.0"
66 | pip install "pydantic<2.0"
67 | pip install torch --index-url https://download.pytorch.org/whl/cpu
68 | pip install pyarrow "ray[train,default]==${{ env.RAY_VERSION }}" tqdm pytest tensorflow==2.13.1 tabulate grpcio-tools wget
69 | pip install "xgboost_ray[default]<=0.1.13"
70 | pip install "xgboost<=2.0.3"
71 | pip install torchmetrics
72 | - name: Cache Maven
73 | uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2
74 | with:
75 | path: ~/.m2
76 | key: ubuntu-latest-m2-${{ hashFiles('core/pom.xml') }}
77 | - name: Build and install
78 | env:
79 | GITHUB_CI: 1
80 | run: |
81 | pip install pyspark==${{ env.PYSPARK_VERSION }}
82 | ./build.sh
83 | pip install dist/raydp-*.whl
84 | - name: Lint
85 | run: |
86 | pip install pylint==2.8.3
87 | pylint --rcfile=python/pylintrc python/raydp
88 | pylint --rcfile=python/pylintrc examples/*.py
89 | - name: Test with pytest
90 | run: |
91 | ray start --head --num-cpus 6
92 | pytest python/raydp/tests/ -v
93 | ray stop --force
94 | - name: Test Examples
95 | run: |
96 | ray start --head
97 | python examples/raydp-submit.py
98 | python examples/test_raydp_submit_pyfiles.py
99 | ray stop
100 | python examples/pytorch_nyctaxi.py
101 | python examples/tensorflow_nyctaxi.py
102 | python examples/xgboost_ray_nyctaxi.py
103 | # python examples/raytrain_nyctaxi.py
104 | python examples/data_process.py
105 | - name: Upload to PyPI
106 | uses: pypa/gh-action-pypi-publish@v1.13.0
107 | with:
108 | password: ${{ secrets.PYPI_API_TOKEN }}
109 |
--------------------------------------------------------------------------------
/python/raydp/torch/torch_ml_dataset.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | import logging
19 | import queue
20 | import threading
21 | from typing import Callable
22 |
23 | import ray
24 | from ray.util.data import MLDataset
25 | from torch.utils.data import IterableDataset
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 |
30 | class TorchMLDataset(IterableDataset):
31 | def __init__(self,
32 | ds: MLDataset,
33 | collate_fn: Callable,
34 | shuffle: bool = False,
35 | shuffle_seed: int = None):
36 | super().__init__()
37 | self.ds = ds
38 | self.collate_fn = collate_fn
39 | self.shuffle = shuffle
40 | self.shuffle_seed = shuffle_seed or 1
41 |
42 | def __iter__(self):
43 | it = self.ds.gather_async(batch_ms=0, num_async=self.ds.num_shards())
44 | it = iter(it)
45 | for pdf in it:
46 | if self.shuffle:
47 | pdf = pdf.sample(frac=1.0, random_state=self.shuffle_seed)
48 | yield self.collate_fn(pdf)
49 |
50 | def __len__(self):
51 | all_actors = []
52 | for actor_set in self.ds.actor_sets:
53 | all_actors.extend(actor_set.actors)
54 | assert len(all_actors) > 0
55 | if "__len__" in dir(all_actors[0]):
56 | # This is a hacky way to get the total number of records from the shard actors
57 | num_records = sum([ray.get(actor.__len__.remote()) for actor in all_actors])
58 | else:
59 | logger.warning("The MLDataset does not provide the __len__ method; we will iterate over "
60 | "all data to count the number of rows. This could be pretty slow.")
61 | it = self.ds.gather_async(batch_ms=0, num_async=self.ds.num_shards())
62 | it = iter(it)
63 | num_records = 0
64 | for pdf in it:
65 | num_records += pdf.shape[0]
66 | return num_records
67 |
68 |
69 | class PrefetchedDataLoader:
70 | def __init__(self, base_loader, max_size: int = 5):
71 | self.base_loader = base_loader
72 | self.max_size = max_size
73 | self.queue = queue.Queue(maxsize=max_size)
74 | self.fetcher = None
75 | self.fetcher_stop = threading.Event()
76 |
77 | def _setup(self):
78 | if self.fetcher is not None:
79 | self.fetcher_stop.set()
80 | if self.queue is not None and not self.queue.empty():
81 | self.queue.get()
82 | self.queue = queue.Queue(maxsize=self.max_size)
83 | self.fetcher = None
84 | self.fetcher_stop.clear()
85 |
86 | it = iter(self.base_loader)
87 |
88 | def fetch_task():
89 | while not self.fetcher_stop.is_set():
90 | try:
91 | got_data = next(it)
92 | self.queue.put(got_data)
93 | except StopIteration:
94 | self.queue.put(None)
95 | break
96 | except: # pylint: disable=W0707, W0706
97 | raise
98 | self.fetcher = threading.Thread(target=fetch_task)
99 | self.fetcher.start()
100 |
101 | def __iter__(self):
102 | self._setup()
103 | while True:
104 | fetched_data = self.queue.get()
105 | if fetched_data is not None:
106 | yield fetched_data
107 | else:
108 | break
109 |
110 | def __len__(self):
111 | return len(self.base_loader)
112 |
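113 | 
114 | # Illustrative usage sketch (an assumption, not part of the original module):
115 | # PrefetchedDataLoader wraps any iterable loader and prefetches batches on a
116 | # background thread while the consumer is busy. The random tensors below are
117 | # hypothetical stand-in data.
118 | if __name__ == "__main__":
119 |     import torch
120 |     from torch.utils.data import DataLoader, TensorDataset
121 | 
122 |     base_loader = DataLoader(TensorDataset(torch.randn(100, 2)), batch_size=10)
123 |     prefetched = PrefetchedDataLoader(base_loader, max_size=5)
124 |     for batch in prefetched:
125 |         pass  # consume batches; prefetching overlaps with consumption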
--------------------------------------------------------------------------------
/core/shims/spark330/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 |
8 | com.intel
9 | raydp-shims
10 | 1.7.0-SNAPSHOT
11 | ../pom.xml
12 |
13 |
14 | raydp-shims-spark330
15 | RayDP Shims for Spark 3.3.0
16 | jar
17 |
18 |
19 | 2.12.15
20 | 2.13.5
21 |
22 |
23 |
24 |
25 |
26 | org.scalastyle
27 | scalastyle-maven-plugin
28 |
29 |
30 | net.alchim31.maven
31 | scala-maven-plugin
32 | 3.2.2
33 |
34 |
35 | scala-compile-first
36 | process-resources
37 |
38 | compile
39 |
40 |
41 |
42 | scala-test-compile-first
43 | process-test-resources
44 |
45 | testCompile
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | src/main/resources
55 |
56 |
57 |
58 |
59 |
60 |
61 | com.intel
62 | raydp-shims-common
63 | ${project.version}
64 | compile
65 |
66 |
67 | org.apache.spark
68 | spark-sql_${scala.binary.version}
69 | ${spark330.version}
70 | provided
71 |
72 |
73 | com.google.protobuf
74 | protobuf-java
75 |
76 |
77 |
78 |
79 | org.apache.spark
80 | spark-core_${scala.binary.version}
81 | ${spark330.version}
82 | provided
83 |
84 |
85 | org.xerial.snappy
86 | snappy-java
87 |
88 |
89 | io.netty
90 | netty-handler
91 |
92 |
93 | org.apache.commons
94 | commons-text
95 |
96 |
97 | org.apache.ivy
98 | ivy
99 |
100 |
101 |
102 |
103 | org.xerial.snappy
104 | snappy-java
105 | ${snappy.version}
106 |
107 |
108 | io.netty
109 | netty-handler
110 | ${netty.version}
111 |
112 |
113 | org.apache.commons
114 | commons-text
115 | ${commons.text.version}
116 |
117 |
118 | org.apache.ivy
119 | ivy
120 | ${ivy.version}
121 |
122 |
123 | com.google.protobuf
124 | protobuf-java
125 | ${protobuf.version}
126 |
127 |
128 |
129 |
--------------------------------------------------------------------------------
/python/raydp/mpi/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import os
19 | import select
20 | import subprocess
21 | import threading
22 | import time
23 | from typing import List
24 |
25 | import grpc
26 | import netifaces
27 |
28 |
29 | class StoppableThread(threading.Thread):
30 |
31 | def __init__(self, group=None, target=None, name=None,
32 | args=(), kwargs=None, *, daemon=None):
33 | super().__init__(group, target, name, args, kwargs, daemon=daemon)
34 | self._stop_event = threading.Event()
35 |
36 | def stop(self):
37 | self._stop_event.set()
38 |
39 | def stopped(self):
40 | return self._stop_event.is_set()
41 |
42 |
43 | def run_cmd(cmd: str, env, failed_callback):
44 | # pylint: disable=R1732
45 | proc = subprocess.Popen(cmd,
46 | shell=True,
47 | stdin=subprocess.DEVNULL,
48 | stdout=subprocess.PIPE,
49 | stderr=subprocess.PIPE,
50 | env=env,
51 | start_new_session=True)
52 |
53 | def check_failed():
54 | # check whether the process has finished
55 | while not threading.current_thread().stopped():
56 | ret_code = proc.poll()
57 | if ret_code:
58 | failed_callback()
59 | raise Exception(f"mpirun failed: {ret_code}")
60 |
61 | if ret_code == 0:
62 | break
63 |
64 | time.sleep(1)
65 |
66 | check_thread = StoppableThread(target=check_failed)
67 |
68 | def redirect_stream(streams):
69 | while not threading.current_thread().stopped() and streams:
70 | readable, _, _ = select.select(streams, [], [], 0.5)
71 | for stream in readable:
72 | if not stream:
73 | continue
74 | line = stream.readline()
75 | if not line:
76 | streams.remove(stream)
77 | else:
78 | print(line.decode().strip("\n"))
79 |
80 | redirect_thread = StoppableThread(target=redirect_stream, args=([proc.stdout, proc.stderr],))
81 | check_thread.start()
82 | redirect_thread.start()
83 | return proc, check_thread, redirect_thread
84 |
85 |
86 | def create_insecure_channel(address,
87 | options=None,
88 | compression=None):
89 | """Disable the http proxy when create channel"""
90 | # disable http proxy
91 | if options is not None:
92 | need_add = True
93 | for k, v in options:
94 | if k == "grpc.enable_http_proxy":
95 | need_add = False
96 | break
97 | if need_add:
98 | options = (*options, ("grpc.enable_http_proxy", 0))
99 | else:
100 | options = (("grpc.enable_http_proxy", 0),)
101 |
102 | return grpc.insecure_channel(
103 | address, options, compression)
104 |
105 |
106 | def get_environ_value(key: str) -> str:
107 | """Get value from environ, raise exception if the key not existed"""
108 | assert key in os.environ, f"{key} should be set in the environ"
109 | return os.environ[key]
110 |
111 |
112 | def get_node_ip_address(node_addresses: List[str]) -> str:
113 | found = None
114 | for interface in netifaces.interfaces():
115 | addrs = netifaces.ifaddresses(interface)
116 | addresses = addrs.get(netifaces.AF_INET, None)
117 | if not addresses:
118 | continue
119 | for inet_addr in addresses:
120 | address = inet_addr.get("addr", None)
121 | if address in node_addresses:
122 | found = address
123 | return found
124 |
--------------------------------------------------------------------------------
/core/shims/spark322/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
5 | 4.0.0
6 |
7 |
8 | com.intel
9 | raydp-shims
10 | 1.7.0-SNAPSHOT
11 | ../pom.xml
12 |
13 |
14 | raydp-shims-spark322
15 | RayDP Shims for Spark 3.2.2
16 | jar
17 |
18 |
19 | 2.12.15
20 | 2.13.5
21 |
22 |
23 |
24 |
25 |
26 | org.scalastyle
27 | scalastyle-maven-plugin
28 |
29 |
30 | net.alchim31.maven
31 | scala-maven-plugin
32 | 3.2.2
33 |
34 |
35 | scala-compile-first
36 | process-resources
37 |
38 | compile
39 |
40 |
41 |
42 | scala-test-compile-first
43 | process-test-resources
44 |
45 | testCompile
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | src/main/resources
55 |
56 |
57 |
58 |
59 |
60 |
61 | com.intel
62 | raydp-shims-common
63 | ${project.version}
64 | compile
65 |
66 |
67 | org.apache.spark
68 | spark-sql_${scala.binary.version}
69 | ${spark322.version}
70 | provided
71 |
72 |
73 | com.google.protobuf
74 | protobuf-java
75 |
76 |
77 |
78 |
79 | org.apache.spark
80 | spark-core_${scala.binary.version}
81 | ${spark322.version}
82 | provided
83 |
84 |
85 | org.xerial.snappy
86 | snappy-java
87 |
88 |
89 | org.apache.commons
90 | commons-compress
91 |
92 |
93 | org.apache.commons
94 | commons-text
95 |
96 |
97 | org.apache.ivy
98 | ivy
99 |
100 |
101 | log4j
102 | log4j
103 |
104 |
105 |
106 |
107 | org.xerial.snappy
108 | snappy-java
109 | ${snappy.version}
110 |
111 |
112 | org.apache.commons
113 | commons-compress
114 | ${commons.compress.version}
115 |
116 |
117 | org.apache.commons
118 | commons-text
119 | ${commons.text.version}
120 |
121 |
122 | org.apache.ivy
123 | ivy
124 | ${ivy.version}
125 |
126 |
127 | com.google.protobuf
128 | protobuf-java
129 | ${protobuf.version}
130 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
/examples/tensorflow_nyctaxi.py:
--------------------------------------------------------------------------------
1 | import ray
2 | from tensorflow import keras
3 | from tensorflow.keras.callbacks import Callback
4 |
5 | import raydp
6 | from raydp.tf import TFEstimator
7 | from raydp.utils import random_split
8 |
9 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV
10 | from typing import List, Dict
11 | # First, you need to init or connect to a Ray cluster.
12 | # Note that you should set include_java to True.
13 | # For more Ray config info, please refer to the Ray doc:
14 | # https://docs.ray.io/en/latest/package-ref.html
15 | # ray.init(address="auto")
16 | ray.init(address="local", num_cpus=6)
17 |
18 | # After initializing the Ray cluster, you can use the RayDP API to get a Spark session
19 | app_name = "NYC Taxi Fare Prediction with RayDP"
20 | num_executors = 1
21 | cores_per_executor = 1
22 | memory_per_executor = "500M"
23 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access
24 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP.
25 | java_opts = " ".join([
26 | "-XX:+IgnoreUnrecognizedVMOptions",
27 | "--add-opens=java.base/java.lang=ALL-UNNAMED",
28 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",
29 | "--add-opens=java.base/java.io=ALL-UNNAMED",
30 | "--add-opens=java.base/java.net=ALL-UNNAMED",
31 | "--add-opens=java.base/java.nio=ALL-UNNAMED",
32 | "--add-opens=java.base/java.math=ALL-UNNAMED",
33 | "--add-opens=java.base/java.text=ALL-UNNAMED",
34 | "--add-opens=java.base/java.util=ALL-UNNAMED",
35 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED",
36 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED",
37 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",
38 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED",
39 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED",
40 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED",
41 | ])
42 | extra_configs = {
43 | "spark.executor.extraJavaOptions": java_opts,
44 | "spark.driver.extraJavaOptions": java_opts,
45 | "spark.ray.raydp_app_master.extraJavaOptions": java_opts,
46 | }
47 | spark = raydp.init_spark(app_name, num_executors,
48 | cores_per_executor, memory_per_executor,
49 | configs=extra_configs)
50 |
51 | # Then you can code as if you are using Spark
52 | # The dataset can be downloaded from:
53 | # https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data
54 | # Here we just use a subset of the training data
55 | data = spark.read.format("csv").option("header", "true") \
56 | .option("inferSchema", "true") \
57 | .load(NYC_TRAIN_CSV)
58 | # Set spark timezone for processing datetime
59 | spark.conf.set("spark.sql.session.timeZone", "UTC")
60 | # Transform the dataset
61 | data = nyc_taxi_preprocess(data)
62 | data = data.cache()
63 | # Split data into train_dataset and test_dataset
64 | train_df, test_df = random_split(data, [0.9, 0.1], 0)
65 | features = [field.name for field in list(train_df.schema) if field.name != "fare_amount"]
66 |
67 | # Define the keras model
68 | model = keras.Sequential(
69 | [
70 | keras.layers.InputLayer(input_shape=(len(features),)),
71 | keras.layers.Flatten(),
72 | keras.layers.Dense(256, activation="relu"),
73 | keras.layers.BatchNormalization(),
74 | keras.layers.Dense(128, activation="relu"),
75 | keras.layers.BatchNormalization(),
76 | keras.layers.Dense(64, activation="relu"),
77 | keras.layers.BatchNormalization(),
78 | keras.layers.Dense(32, activation="relu"),
79 | keras.layers.BatchNormalization(),
80 | keras.layers.Dense(16, activation="relu"),
81 | keras.layers.BatchNormalization(),
82 | keras.layers.Dense(1),
83 | ]
84 | )
85 |
86 | class PrintingCallback(Callback):
87 | def handle_result(self, results: List[Dict], **info):
88 | print(results)
89 |
90 | # Define the optimizer and loss function
91 | # Then create the tensorflow estimator provided by Raydp
92 | adam = keras.optimizers.Adam(learning_rate=0.001)
93 | loss = keras.losses.MeanSquaredError()
94 | estimator = TFEstimator(num_workers=2, model=model, optimizer=adam, loss=loss,
95 | merge_feature_columns=True, metrics=["mae"],
96 | feature_columns=features, label_columns="fare_amount",
97 | batch_size=256, num_epochs=10, callbacks=[PrintingCallback()])
98 |
99 | # Train the model
100 | estimator.fit_on_spark(train_df, test_df)
101 | # Get the model
102 | model = estimator.get_model()
103 | # shutdown raydp and ray
104 | raydp.stop_spark()
105 | ray.shutdown()
106 |
--------------------------------------------------------------------------------
/python/raydp/mpi/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 |
19 | from typing import Callable, List
20 |
21 | from .mpi_job import MPIJob, MPIType, IntelMPIJob, OpenMPIJob, MPICHJob, MPIJobContext
22 | from .mpi_worker import WorkerContext
23 |
24 |
25 | def _get_mpi_type(mpi_type: str) -> MPIType:
26 | if mpi_type.strip().lower() == "openmpi":
27 | return MPIType.OPEN_MPI
28 | elif mpi_type.strip().lower() == "intel_mpi":
29 | return MPIType.INTEL_MPI
30 | elif mpi_type.strip().lower() == "mpich":
31 | return MPIType.MPICH
32 | else:
33 | return None
34 |
35 |
36 | def create_mpi_job(job_name: str,
37 | world_size: int,
38 | num_cpus_per_process: int,
39 | num_processes_per_node: int,
40 | mpi_script_prepare_fn: Callable = None,
41 | timeout: int = 1,
42 | mpi_type: str = "intel_mpi",
43 | placement_group=None,
44 | placement_group_bundle_indexes: List[int] = None) -> MPIJob:
45 | """Create a MPI Job
46 |
47 | :param job_name: the job name
48 | :param world_size: the world size
49 | :param num_cpus_per_process: num cpus per process, used to request resources from Ray
50 | :param num_processes_per_node: num processes per node
51 | :param mpi_script_prepare_fn: a function used to create the mpi script; it will be passed an
52 |        MPIJobContext instance. The default script is used if not provided.
53 | :param timeout: the timeout used to wait for job creation
54 | :param mpi_type: the mpi type, currently only openmpi, intel_mpi and MPICH are supported
55 | :param placement_group: the placement group used to request mpi resources
56 | :param placement_group_bundle_indexes: if provided, its length should be equal to
57 |        world_size / num_processes_per_node.
58 | """
59 | mpi_type = _get_mpi_type(mpi_type)
60 | if mpi_type == MPIType.OPEN_MPI:
61 | return OpenMPIJob(mpi_type=MPIType.OPEN_MPI,
62 | job_name=job_name,
63 | world_size=world_size,
64 | num_cpus_per_process=num_cpus_per_process,
65 | num_processes_per_node=num_processes_per_node,
66 | mpi_script_prepare_fn=mpi_script_prepare_fn,
67 | timeout=timeout,
68 | placement_group=placement_group,
69 | placement_group_bundle_indexes=placement_group_bundle_indexes)
70 | elif mpi_type == MPIType.INTEL_MPI:
71 | return IntelMPIJob(mpi_type=MPIType.INTEL_MPI,
72 | job_name=job_name,
73 | world_size=world_size,
74 | num_cpus_per_process=num_cpus_per_process,
75 | num_processes_per_node=num_processes_per_node,
76 | mpi_script_prepare_fn=mpi_script_prepare_fn,
77 | timeout=timeout,
78 | placement_group=placement_group,
79 | placement_group_bundle_indexes=placement_group_bundle_indexes)
80 | elif mpi_type == MPIType.MPICH:
81 | return MPICHJob(mpi_type=MPIType.MPICH,
82 | job_name=job_name,
83 | world_size=world_size,
84 | num_cpus_per_process=num_cpus_per_process,
85 | num_processes_per_node=num_processes_per_node,
86 | mpi_script_prepare_fn=mpi_script_prepare_fn,
87 | timeout=timeout,
88 | placement_group=placement_group,
89 | placement_group_bundle_indexes=placement_group_bundle_indexes)
90 | else:
91 | raise Exception(f"MPI type: {mpi_type} not supported now")
92 |
93 |
94 | __all__ = ["create_mpi_job", "MPIJobContext", "WorkerContext"]
95 |
--------------------------------------------------------------------------------
/examples/pytorch_nyctaxi.py:
--------------------------------------------------------------------------------
1 | import ray
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import raydp
7 | from raydp.torch import TorchEstimator
8 | from raydp.utils import random_split
9 |
10 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV
11 | from typing import List, Dict
12 |
13 | # First, you need to init or connect to a Ray cluster.
14 | # Note that you should set include_java to True.
15 | # For more Ray config info, please refer to the Ray doc:
16 | # https://docs.ray.io/en/latest/package-ref.html
17 | # ray.init(address="auto")
18 | ray.init(address="local", num_cpus=4)
19 |
20 | # After initializing the Ray cluster, you can use the RayDP API to get a Spark session
21 | app_name = "NYC Taxi Fare Prediction with RayDP"
22 | num_executors = 1
23 | cores_per_executor = 1
24 | memory_per_executor = "500M"
25 |
26 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access
27 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP.
28 | java_opts = " ".join([
29 | "-XX:+IgnoreUnrecognizedVMOptions",
30 | "--add-opens=java.base/java.lang=ALL-UNNAMED",
31 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",
32 | "--add-opens=java.base/java.io=ALL-UNNAMED",
33 | "--add-opens=java.base/java.net=ALL-UNNAMED",
34 | "--add-opens=java.base/java.nio=ALL-UNNAMED",
35 | "--add-opens=java.base/java.math=ALL-UNNAMED",
36 | "--add-opens=java.base/java.text=ALL-UNNAMED",
37 | "--add-opens=java.base/java.util=ALL-UNNAMED",
38 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED",
39 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED",
40 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED",
41 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED",
42 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED",
43 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED",
44 | ])
45 | extra_configs = {
46 | "spark.executor.extraJavaOptions": java_opts,
47 | "spark.driver.extraJavaOptions": java_opts,
48 | "spark.ray.raydp_app_master.extraJavaOptions": java_opts,
49 | }
50 | spark = raydp.init_spark(app_name, num_executors,
51 | cores_per_executor, memory_per_executor,
52 | configs=extra_configs)
53 |
54 | # Then you can code as if you are using Spark
55 | # The dataset can be downloaded from:
56 | # https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data
57 | # Here we just use a subset of the training data
58 | data = spark.read.format("csv").option("header", "true") \
59 | .option("inferSchema", "true") \
60 | .load(NYC_TRAIN_CSV)
61 | # Set spark timezone for processing datetime
62 | spark.conf.set("spark.sql.session.timeZone", "UTC")
63 | # Transform the dataset
64 | data = nyc_taxi_preprocess(data)
65 | # Split data into train_dataset and test_dataset
66 | train_df, test_df = random_split(data, [0.9, 0.1], 0)
67 | features = [field.name for field in list(train_df.schema) if field.name != "fare_amount"]
68 | # Define a neural network model
69 | class NYC_Model(nn.Module):
70 | def __init__(self, cols):
71 | super().__init__()
72 | self.fc1 = nn.Linear(cols, 256)
73 | self.fc2 = nn.Linear(256, 128)
74 | self.fc3 = nn.Linear(128, 64)
75 | self.fc4 = nn.Linear(64, 16)
76 | self.fc5 = nn.Linear(16, 1)
77 | self.bn1 = nn.BatchNorm1d(256)
78 | self.bn2 = nn.BatchNorm1d(128)
79 | self.bn3 = nn.BatchNorm1d(64)
80 | self.bn4 = nn.BatchNorm1d(16)
81 |
82 | def forward(self, x):
83 | x = F.relu(self.fc1(x))
84 | x = self.bn1(x)
85 | x = F.relu(self.fc2(x))
86 | x = self.bn2(x)
87 | x = F.relu(self.fc3(x))
88 | x = self.bn3(x)
89 | x = F.relu(self.fc4(x))
90 | x = self.bn4(x)
91 | x = self.fc5(x)
92 | return x
93 |
94 | nyc_model = NYC_Model(len(features))
95 | criterion = nn.SmoothL1Loss()
96 | optimizer = torch.optim.Adam(nyc_model.parameters(), lr=0.001)
97 | # Create a distributed estimator based on the raydp api
98 | estimator = TorchEstimator(num_workers=1, model=nyc_model, optimizer=optimizer, loss=criterion,
99 | feature_columns=features, feature_types=torch.float,
100 | label_column="fare_amount", label_type=torch.float,
101 | batch_size=64, num_epochs=30,
102 | metrics_name = ["MeanAbsoluteError", "MeanSquaredError"],
103 | use_ccl=False)
104 | # Train the model
105 | estimator.fit_on_spark(train_df, test_df)
106 | # Get the trained model
107 | model = estimator.get_model()
108 | # shutdown raydp and ray
109 | raydp.stop_spark()
110 | ray.shutdown()
111 |
--------------------------------------------------------------------------------