├── core ├── shims │ ├── spark322 │ │ ├── src │ │ │ └── main │ │ │ │ ├── resources │ │ │ │ └── META-INF │ │ │ │ │ └── services │ │ │ │ │ └── com.intel.raydp.shims.SparkShimProvider │ │ │ │ └── scala │ │ │ │ ├── org │ │ │ │ └── apache │ │ │ │ │ └── spark │ │ │ │ │ ├── TaskContextUtils.scala │ │ │ │ │ ├── sql │ │ │ │ │ └── SparkSqlUtils.scala │ │ │ │ │ └── executor │ │ │ │ │ └── RayDPSpark322ExecutorBackendFactory.scala │ │ │ │ └── com │ │ │ │ └── intel │ │ │ │ └── raydp │ │ │ │ └── shims │ │ │ │ ├── SparkShimProvider.scala │ │ │ │ └── SparkShims.scala │ │ └── pom.xml │ ├── spark330 │ │ ├── src │ │ │ └── main │ │ │ │ ├── resources │ │ │ │ └── META-INF │ │ │ │ │ └── services │ │ │ │ │ └── com.intel.raydp.shims.SparkShimProvider │ │ │ │ └── scala │ │ │ │ ├── org │ │ │ │ └── apache │ │ │ │ │ └── spark │ │ │ │ │ ├── TaskContextUtils.scala │ │ │ │ │ ├── sql │ │ │ │ │ └── SparkSqlUtils.scala │ │ │ │ │ └── executor │ │ │ │ │ ├── RayCoarseGrainedExecutorBackend.scala │ │ │ │ │ └── RayDPSpark330ExecutorBackendFactory.scala │ │ │ │ └── com │ │ │ │ └── intel │ │ │ │ └── raydp │ │ │ │ └── shims │ │ │ │ ├── SparkShimProvider.scala │ │ │ │ └── SparkShims.scala │ │ └── pom.xml │ ├── spark340 │ │ ├── src │ │ │ └── main │ │ │ │ ├── resources │ │ │ │ └── META-INF │ │ │ │ │ └── services │ │ │ │ │ └── com.intel.raydp.shims.SparkShimProvider │ │ │ │ └── scala │ │ │ │ ├── org │ │ │ │ └── apache │ │ │ │ │ └── spark │ │ │ │ │ ├── TaskContextUtils.scala │ │ │ │ │ ├── executor │ │ │ │ │ ├── RayCoarseGrainedExecutorBackend.scala │ │ │ │ │ └── RayDPSpark340ExecutorBackendFactory.scala │ │ │ │ │ └── sql │ │ │ │ │ └── SparkSqlUtils.scala │ │ │ │ └── com │ │ │ │ └── intel │ │ │ │ └── raydp │ │ │ │ └── shims │ │ │ │ ├── SparkShimProvider.scala │ │ │ │ └── SparkShims.scala │ │ └── pom.xml │ ├── spark350 │ │ ├── src │ │ │ └── main │ │ │ │ ├── resources │ │ │ │ └── META-INF │ │ │ │ │ └── services │ │ │ │ │ └── com.intel.raydp.shims.SparkShimProvider │ │ │ │ └── scala │ │ │ │ ├── org │ │ │ │ └── apache │ │ │ │ │ └── spark │ │ │ │ │ ├── TaskContextUtils.scala │ │ │ │ │ ├── executor │ │ │ │ │ ├── RayCoarseGrainedExecutorBackend.scala │ │ │ │ │ └── RayDPSpark350ExecutorBackendFactory.scala │ │ │ │ │ └── sql │ │ │ │ │ └── SparkSqlUtils.scala │ │ │ │ └── com │ │ │ │ └── intel │ │ │ │ └── raydp │ │ │ │ └── shims │ │ │ │ ├── SparkShimProvider.scala │ │ │ │ └── SparkShims.scala │ │ └── pom.xml │ ├── common │ │ └── src │ │ │ └── main │ │ │ └── scala │ │ │ ├── com │ │ │ └── intel │ │ │ │ └── raydp │ │ │ │ └── shims │ │ │ │ ├── SparkShimProvider.scala │ │ │ │ ├── SparkShims.scala │ │ │ │ └── SparkShimLoader.scala │ │ │ └── org │ │ │ └── apache │ │ │ └── spark │ │ │ └── executor │ │ │ └── RayDPExecutorBackendFactory.scala │ └── pom.xml ├── raydp-main │ └── src │ │ └── main │ │ ├── resources │ │ └── META-INF │ │ │ └── services │ │ │ └── org.apache.spark.scheduler.ExternalClusterManager │ │ ├── scala │ │ └── org │ │ │ └── apache │ │ │ └── spark │ │ │ ├── RayDPException.scala │ │ │ ├── deploy │ │ │ └── raydp │ │ │ │ ├── ApplicationState.scala │ │ │ │ ├── Messages.scala │ │ │ │ ├── RayExternalShuffleService.scala │ │ │ │ ├── ApplicationDescription.scala │ │ │ │ ├── RayDPDriverAgent.scala │ │ │ │ ├── AppMasterEntryPoint.scala │ │ │ │ └── AppMasterJavaBridge.scala │ │ │ ├── scheduler │ │ │ └── cluster │ │ │ │ └── raydp │ │ │ │ └── RayClusterManager.scala │ │ │ ├── rdd │ │ │ ├── RayObjectRefRDD.scala │ │ │ └── RayDatasetRDD.scala │ │ │ └── sql │ │ │ └── raydp │ │ │ └── ObjectStoreReader.scala │ │ ├── java │ │ └── org │ │ │ └── apache │ │ │ └── 
spark │ │ │ ├── deploy │ │ │ └── raydp │ │ │ │ ├── ExternalShuffleServiceUtils.java │ │ │ │ └── RayAppMasterUtils.java │ │ │ └── raydp │ │ │ ├── RayDPUtils.java │ │ │ └── RayExecutorUtils.java │ │ └── test │ │ └── org │ │ └── apache │ │ └── spark │ │ └── scheduler │ │ └── cluster │ │ └── raydp │ │ └── TestRayCoarseGrainedSchedulerBackend.java ├── agent │ ├── pom.xml │ └── src │ │ └── main │ │ └── java │ │ └── org │ │ └── apache │ │ └── spark │ │ └── raydp │ │ └── Agent.java └── javastyle-suppressions.xml ├── examples ├── test_pyfile.py ├── random_nyctaxi.py ├── README.md ├── raydp-submit.py ├── test_pyfiles_main.py ├── test_raydp_submit_pyfiles.py ├── xgboost_ray_nyctaxi.py ├── tensorflow_nyctaxi.py └── pytorch_nyctaxi.py ├── .gitignore ├── SECURITY.md ├── docker ├── Dockerfile ├── build-docker.sh └── README.md ├── python ├── raydp │ ├── tf │ │ └── __init__.py │ ├── torch │ │ ├── __init__.py │ │ ├── config.py │ │ ├── torch_metrics.py │ │ └── torch_ml_dataset.py │ ├── xgboost │ │ └── __init__.py │ ├── __init__.py │ ├── mpi │ │ ├── network │ │ │ ├── __init__.py │ │ │ ├── network.proto │ │ │ └── network_pb2.py │ │ ├── constants.py │ │ ├── utils.py │ │ └── __init__.py │ ├── spark │ │ ├── parallel_iterator_worker.py │ │ ├── __init__.py │ │ └── interfaces.py │ ├── estimator.py │ ├── versions.py │ ├── tests │ │ ├── test_torch_sequential.py │ │ ├── test_spark_master_memory.py │ │ ├── test_xgboost.py │ │ ├── test_tf.py │ │ └── test_torch.py │ ├── ray_cluster_resources.py │ └── services.py └── MANIFEST.in ├── .github └── workflows │ ├── pypi.yml │ └── pypi_release.yml └── doc └── mpi.md /core/shims/spark322/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider: -------------------------------------------------------------------------------- 1 | com.intel.raydp.shims.spark322.SparkShimProvider 2 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider: -------------------------------------------------------------------------------- 1 | com.intel.raydp.shims.spark330.SparkShimProvider 2 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider: -------------------------------------------------------------------------------- 1 | com.intel.raydp.shims.spark340.SparkShimProvider 2 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/resources/META-INF/services/com.intel.raydp.shims.SparkShimProvider: -------------------------------------------------------------------------------- 1 | com.intel.raydp.shims.spark350.SparkShimProvider 2 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager: -------------------------------------------------------------------------------- 1 | org.apache.spark.scheduler.cluster.raydp.RayClusterManager -------------------------------------------------------------------------------- /examples/test_pyfile.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper module for testing raydp-submit --py-files functionality. 3 | This module defines functions that will be used by the main.py script. 
4 | """ 5 | 6 | def compute_sum(numbers): 7 | """Compute sum to verify module functionality in Spark.""" 8 | return sum(numbers) 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.iml 3 | 4 | __pycache__/ 5 | build/ 6 | dist/ 7 | *.egg-info/ 8 | *.eggs/ 9 | 10 | dev/.tmp_dir/ 11 | target/ 12 | *.jar 13 | 14 | .DS_Store 15 | 16 | .vscode 17 | examples/.ipynb_checkpoints/ 18 | .python-version 19 | 20 | # Vim temp files 21 | *.swp 22 | *.swo 23 | *.parquet 24 | *.crc 25 | _SUCCESS 26 | 27 | .metals/ 28 | .bloop/ 29 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Report a Vulnerability 4 | 5 | Please report security issues or vulnerabilities to the [Intel® Security Center]. 6 | 7 | For more information on how Intel® works to resolve security issues, see 8 | [Vulnerability Handling Guidelines]. 9 | 10 | [Intel® Security Center]:https://www.intel.com/security 11 | 12 | [Vulnerability Handling Guidelines]:https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html 13 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rayproject/ray:latest@sha256:c864e37f4ce516ff49425f69cac5503a51e84c333d30928416714a2c3da55b43 2 | 3 | ARG HTTP_PROXY 4 | ARG HTTPS_PROXY 5 | 6 | # set http_proxy & https_proxy 7 | ENV http_proxy=${HTTP_PROXY} 8 | ENV https_proxy=${HTTPS_PROXY} 9 | 10 | # install java, create workdir and install raydp 11 | # You could change the raydp to raydp-nightly if you want to try the master branch code 12 | RUN sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get update -y \ 13 | && sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get install -y openjdk-8-jdk \ 14 | && sudo mkdir /raydp \ 15 | && sudo chown -R ray /raydp \ 16 | && $HOME/anaconda3/bin/pip --no-cache-dir install raydp 17 | 18 | WORKDIR /raydp 19 | 20 | # unset http_proxy & https_proxy 21 | ENV http_proxy= 22 | ENV https_proxy= 23 | -------------------------------------------------------------------------------- /python/raydp/tf/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | from .estimator import TFEstimator 19 | 20 | __all__ = ["TFEstimator"] 21 | -------------------------------------------------------------------------------- /python/raydp/torch/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from .estimator import TorchEstimator 19 | 20 | __all__ = ["TorchEstimator"] 21 | -------------------------------------------------------------------------------- /python/raydp/xgboost/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from .estimator import XGBoostEstimator 19 | 20 | __all__ = ["XGBoostEstimator"] 21 | -------------------------------------------------------------------------------- /python/MANIFEST.in: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | include README.md 19 | recursive-include raydp/jars *.jar 20 | global-exclude *.py[cod] __pycache__ .DS_Store -------------------------------------------------------------------------------- /python/raydp/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from raydp.context import init_spark, stop_spark 19 | 20 | __version__ = "1.7.0.dev0" 21 | 22 | __all__ = ["init_spark", "stop_spark"] 23 | -------------------------------------------------------------------------------- /python/raydp/mpi/network/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import sys 19 | import os 20 | 21 | dir_path = os.path.dirname(os.path.realpath(__file__)) 22 | sys.path.append(str(dir_path)) 23 | -------------------------------------------------------------------------------- /docker/build-docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | docker build --build-arg HTTP_PROXY=${http_proxy} \ 21 | --build-arg HTTPS_PROXY=${https_proxy} \ 22 | -t oap-project/raydp:latest . 23 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/RayDPException.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark 19 | 20 | class RayDPException(message: String, cause: Throwable) 21 | extends SparkException(message, cause) { 22 | def this(message: String) = this(message, null) 23 | } 24 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationState.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | object ApplicationState extends Enumeration { 21 | 22 | type ApplicationState = Value 23 | 24 | val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value 25 | } 26 | -------------------------------------------------------------------------------- /core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.intel.raydp.shims 19 | 20 | /** 21 | * Provider interface for matching and retrieving the Shims of a specific Spark version 22 | */ 23 | trait SparkShimProvider { 24 | def matches(version:String): Boolean 25 | def createShim: SparkShims 26 | } 27 | -------------------------------------------------------------------------------- /python/raydp/mpi/constants.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from os import path 19 | 20 | MPI_TYPE = "raydp_mpi_type" 21 | MPI_JOB_ID = "raydp_mpi_job_id" 22 | MPI_DRIVER_HOST = "raydp_mpi_driver_host" 23 | MPI_DRIVER_PORT = "raydp_mpi_driver_port" 24 | 25 | MAXIMUM_WAIT_TIME_OUT = "raydp_maximum_wait_time_out" 26 | 27 | _current_dir = path.dirname(path.realpath(__file__)) 28 | MPI_MAIN_CLASS_PATH = path.join(_current_dir, "mpi_worker.py") 29 | -------------------------------------------------------------------------------- /python/raydp/spark/parallel_iterator_worker.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | 18 | from typing import Any 19 | 20 | from ray.util.iter import ParallelIteratorWorker 21 | 22 | 23 | class ParallelIteratorWorkerWithLen(ParallelIteratorWorker): 24 | def __init__(self, item_generator: Any, repeat: bool, num_records: int): 25 | super().__init__(item_generator, repeat) 26 | self.num_records = num_records 27 | 28 | def __len__(self): 29 | return self.num_records 30 | -------------------------------------------------------------------------------- /core/shims/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-parent 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-shims 15 | RayDP Shims 16 | pom 17 | 18 | 19 | common 20 | spark322 21 | spark330 22 | spark340 23 | spark350 24 | 25 | 26 | 27 | 2.12 28 | 4.3.0 29 | 3.2.2 30 | 31 | 32 | 33 | 34 | 35 | net.alchim31.maven 36 | scala-maven-plugin 37 | ${scala.plugin.version} 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /core/shims/spark322/src/main/scala/org/apache/spark/TaskContextUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.spark322 19 | 20 | import java.util.Properties 21 | 22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} 23 | import org.apache.spark.memory.TaskMemoryManager 24 | 25 | object TaskContextUtils { 26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/org/apache/spark/TaskContextUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.spark330 19 | 20 | import java.util.Properties 21 | 22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} 23 | import org.apache.spark.memory.TaskMemoryManager 24 | 25 | object TaskContextUtils { 26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/org/apache/spark/TaskContextUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.spark340 19 | 20 | import java.util.Properties 21 | 22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} 23 | import org.apache.spark.memory.TaskMemoryManager 24 | 25 | object TaskContextUtils { 26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 0, 28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/org/apache/spark/TaskContextUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.spark350 19 | 20 | import java.util.Properties 21 | 22 | import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} 23 | import org.apache.spark.memory.TaskMemoryManager 24 | 25 | object TaskContextUtils { 26 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 27 | new TaskContextImpl(0, 0, partitionId, -1024, 0, 0, 28 | new TaskMemoryManager(env.memoryManager, 0), new Properties(), env.metricsSystem) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /examples/random_nyctaxi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | base_date = np.datetime64("2010-01-01 00:00:00") 8 | 9 | parser = argparse.ArgumentParser(description="Random NYC Taxi Generator") 10 | parser.add_argument( 11 | "--num-records", 12 | type=int, 13 | default=2000, 14 | metavar="N", 15 | help="number of records to generate (default: 2000)") 16 | 17 | args = parser.parse_args() 18 | 19 | N = args.num_records 20 | 21 | fare_amount = np.random.uniform(3.0, 50.0, size=N) 22 | pick_long = np.random.uniform(-74.2, -73.8, size=N) 23 | pick_lat = np.random.uniform(40.7, 40.8, size=N) 24 | drop_long = np.random.uniform(-74.2, -73.8, size=N) 25 | drop_lat = np.random.uniform(40.7, 40.8, size=N) 26 | passenger_count = np.random.randint(1, 5, size=N) 27 | date = np.random.randint(0, 157680000, size=N) + base_date 28 | date = np.array([t.item().strftime("%Y-%m-%d %H:%M:%S UTC") for t in date]) 29 | key = ["fake_key"] * N 30 | df = pd.DataFrame({ 31 | "key": key, 32 | "fare_amount": fare_amount, 33 | "pickup_datetime": date, 34 | "pickup_longitude": pick_long, 35 | "pickup_latitude": pick_lat, 36 | "dropoff_longitude": drop_long, 37 | "dropoff_latitude": drop_lat, 38 | "passenger_count": passenger_count 39 | }) 40 | csv_path = os.path.dirname(os.path.realpath(__file__)) + "/fake_nyctaxi.csv" 41 | df.to_csv(csv_path, index=False) 42 | -------------------------------------------------------------------------------- /python/raydp/spark/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
16 | # 17 | 18 | from .dataset import PartitionObjectsOwner, \ 19 | get_raydp_master_owner, \ 20 | spark_dataframe_to_ray_dataset, \ 21 | ray_dataset_to_spark_dataframe, \ 22 | from_spark_recoverable 23 | from .interfaces import SparkEstimatorInterface 24 | from .ray_cluster import SparkCluster 25 | 26 | __all__ = [ 27 | "SparkCluster", 28 | "SparkEstimatorInterface", 29 | "PartitionObjectsOwner", 30 | "get_raydp_master_owner", 31 | "spark_dataframe_to_ray_dataset", 32 | "ray_dataset_to_spark_dataframe", 33 | "from_spark_recoverable" 34 | ] 35 | -------------------------------------------------------------------------------- /python/raydp/estimator.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from abc import ABC, abstractmethod 19 | from typing import Any, NoReturn, Optional 20 | 21 | 22 | 23 | class EstimatorInterface(ABC): 24 | """ 25 | A scikit-learn like API. 26 | """ 27 | 28 | @abstractmethod 29 | def fit(self, 30 | train_ds, 31 | evaluate_ds = None) -> NoReturn: 32 | """Train or evaluate the model. 33 | 34 | :param train_ds: the model will train on the MLDataset 35 | :param evaluate_ds: if this is provided, the model will evaluate on the MLDataset 36 | """ 37 | 38 | @abstractmethod 39 | def get_model(self) -> Any: 40 | """Get the trained model 41 | 42 | :return the model 43 | """ 44 | -------------------------------------------------------------------------------- /core/shims/common/src/main/scala/org/apache/spark/executor/RayDPExecutorBackendFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.executor 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.rpc.RpcEnv 24 | import org.apache.spark.resource.ResourceProfile 25 | 26 | trait RayDPExecutorBackendFactory { 27 | def createExecutorBackend( 28 | rpcEnv: RpcEnv, 29 | driverUrl: String, 30 | executorId: String, 31 | bindAddress: String, 32 | hostname: String, 33 | cores: Int, 34 | userClassPath: Seq[URL], 35 | env: SparkEnv, 36 | resourcesFileOpt: Option[String], 37 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend 38 | } 39 | -------------------------------------------------------------------------------- /python/raydp/versions.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import re 19 | import pyspark 20 | 21 | 22 | # log4j1 if spark version <= 3.2, otherwise, log4j2 23 | SPARK_LOG4J_VERSION = "log4j" 24 | SPARK_LOG4J_CONFIG_FILE_NAME_KEY = "log4j.configurationFile" 25 | SPARK_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j-default.properties" 26 | _spark_ver = re.search("\\d+\\.\\d+", pyspark.version.__version__) 27 | if _spark_ver.group(0) > "3.2": 28 | SPARK_LOG4J_VERSION = "log4j2" 29 | SPARK_LOG4J_CONFIG_FILE_NAME_KEY = "log4j2.configurationFile" 30 | SPARK_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j2-default.properties" 31 | 32 | # support ray >= 2.1, they all use log4j2 33 | RAY_LOG4J_VERSION = "log4j2" 34 | RAY_LOG4J_CONFIG_FILE_NAME_KEY = "log4j2.configurationFile" 35 | RAY_LOG4J_CONFIG_FILE_NAME_DEFAULT = "log4j2.xml" 36 | -------------------------------------------------------------------------------- /python/raydp/torch/config.py: -------------------------------------------------------------------------------- 1 | from ray.train.torch.config import _TorchBackend 2 | from ray.train.torch.config import TorchConfig as RayTorchConfig 3 | from ray.train._internal.worker_group import WorkerGroup 4 | from dataclasses import dataclass 5 | import sys 6 | # The package importlib_metadata is in a different place, depending on the Python version. 7 | if sys.version_info < (3, 8): 8 | import importlib_metadata 9 | else: 10 | import importlib.metadata as importlib_metadata 11 | 12 | @dataclass 13 | class TorchConfig(RayTorchConfig): 14 | 15 | @property 16 | def backend_cls(self): 17 | return EnableCCLBackend 18 | 19 | def libs_import(): 20 | """try to import IPEX and oneCCL. 
21 | """ 22 | try: 23 | import intel_extension_for_pytorch 24 | except ImportError: 25 | raise ImportError( 26 | "Please install intel_extension_for_pytorch" 27 | ) 28 | try: 29 | ccl_version = importlib_metadata.version("oneccl_bind_pt") 30 | if ccl_version >= "1.12": 31 | # pylint: disable-all 32 | import oneccl_bindings_for_pytorch 33 | else: 34 | import torch_ccl 35 | except ImportError as ccl_not_exist: 36 | raise ImportError( 37 | "Please install torch-ccl" 38 | ) from ccl_not_exist 39 | 40 | class EnableCCLBackend(_TorchBackend): 41 | 42 | def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig): 43 | for i in range(len(worker_group)): 44 | worker_group.execute_single_async(i, libs_import) 45 | super().on_start(worker_group, backend_config) 46 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.spark330 19 | 20 | import org.apache.arrow.vector.types.pojo.Schema 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} 23 | import org.apache.spark.sql.execution.arrow.ArrowConverters 24 | import org.apache.spark.sql.types.StructType 25 | import org.apache.spark.sql.util.ArrowUtils 26 | 27 | object SparkSqlUtils { 28 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = { 29 | ArrowConverters.toDataFrame(rdd, schema, session) 30 | } 31 | 32 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 33 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /core/shims/spark322/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.sql.spark322 19 | 20 | import org.apache.arrow.vector.types.pojo.Schema 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} 23 | import org.apache.spark.sql.execution.arrow.ArrowConverters 24 | import org.apache.spark.sql.types.StructType 25 | import org.apache.spark.sql.util.ArrowUtils 26 | 27 | object SparkSqlUtils { 28 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame = { 29 | ArrowConverters.toDataFrame(rdd, schema, new SQLContext(session)) 30 | } 31 | 32 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 33 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /python/raydp/spark/interfaces.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from typing import NoReturn 19 | from typing import Optional, Union 20 | 21 | from raydp.utils import convert_to_spark 22 | 23 | DF = Union["pyspark.sql.DataFrame", "pyspark.pandas.DataFrame"] 24 | OPTIONAL_DF = Union[Optional["pyspark.sql.DataFrame"], Optional["pyspark.pandas.DataFrame"]] 25 | 26 | 27 | class SparkEstimatorInterface: 28 | def _check_and_convert(self, df): 29 | train_df, _ = convert_to_spark(df) 30 | return train_df 31 | 32 | def fit_on_spark(self, 33 | train_df: DF, 34 | evaluate_df: OPTIONAL_DF = None) -> NoReturn: 35 | """Fit and evaluate the model on the Spark or Koalas DataFrame. 36 | 37 | :param train_df: the DataFrame which the model will train on. 38 | :param evaluate_df: the optional DataFrame which the model will be evaluated on 39 | """ 40 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.executor 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.resource.ResourceProfile 24 | import org.apache.spark.rpc.RpcEnv 25 | 26 | class RayCoarseGrainedExecutorBackend( 27 | rpcEnv: RpcEnv, 28 | driverUrl: String, 29 | executorId: String, 30 | bindAddress: String, 31 | hostname: String, 32 | cores: Int, 33 | userClassPath: Seq[URL], 34 | env: SparkEnv, 35 | resourcesFileOpt: Option[String], 36 | resourceProfile: ResourceProfile) 37 | extends CoarseGrainedExecutorBackend( 38 | rpcEnv, 39 | driverUrl, 40 | executorId, 41 | bindAddress, 42 | hostname, 43 | cores, 44 | env, 45 | resourcesFileOpt, 46 | resourceProfile) { 47 | 48 | override def getUserClassPath: Seq[URL] = userClassPath 49 | 50 | } 51 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.executor 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.resource.ResourceProfile 24 | import org.apache.spark.rpc.RpcEnv 25 | 26 | class RayCoarseGrainedExecutorBackend( 27 | rpcEnv: RpcEnv, 28 | driverUrl: String, 29 | executorId: String, 30 | bindAddress: String, 31 | hostname: String, 32 | cores: Int, 33 | userClassPath: Seq[URL], 34 | env: SparkEnv, 35 | resourcesFileOpt: Option[String], 36 | resourceProfile: ResourceProfile) 37 | extends CoarseGrainedExecutorBackend( 38 | rpcEnv, 39 | driverUrl, 40 | executorId, 41 | bindAddress, 42 | hostname, 43 | cores, 44 | env, 45 | resourcesFileOpt, 46 | resourceProfile) { 47 | 48 | override def getUserClassPath: Seq[URL] = userClassPath 49 | 50 | } 51 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/org/apache/spark/executor/RayCoarseGrainedExecutorBackend.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.executor 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.resource.ResourceProfile 24 | import org.apache.spark.rpc.RpcEnv 25 | 26 | class RayCoarseGrainedExecutorBackend( 27 | rpcEnv: RpcEnv, 28 | driverUrl: String, 29 | executorId: String, 30 | bindAddress: String, 31 | hostname: String, 32 | cores: Int, 33 | userClassPath: Seq[URL], 34 | env: SparkEnv, 35 | resourcesFileOpt: Option[String], 36 | resourceProfile: ResourceProfile) 37 | extends CoarseGrainedExecutorBackend( 38 | rpcEnv, 39 | driverUrl, 40 | executorId, 41 | bindAddress, 42 | hostname, 43 | cores, 44 | env, 45 | resourcesFileOpt, 46 | resourceProfile) { 47 | 48 | override def getUserClassPath: Seq[URL] = userClassPath 49 | 50 | } 51 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/ExternalShuffleServiceUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.deploy.raydp; 19 | 20 | import java.util.List; 21 | 22 | import io.ray.api.ActorHandle; 23 | import io.ray.api.Ray; 24 | 25 | public class ExternalShuffleServiceUtils { 26 | public static ActorHandle<RayExternalShuffleService> createShuffleService( 27 | String node, List<String> options) { 28 | return Ray.actor(RayExternalShuffleService::new) 29 | .setResource("node:" + node, 0.01) 30 | .setJvmOptions(options).remote(); 31 | } 32 | 33 | public static void startShuffleService( 34 | ActorHandle<RayExternalShuffleService> handle) { 35 | handle.task(RayExternalShuffleService::start).remote(); 36 | } 37 | 38 | public static void stopShuffleService( 39 | ActorHandle<RayExternalShuffleService> handle) { 40 | handle.task(RayExternalShuffleService::stop).remote(); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.intel.raydp.shims.spark330 19 | 20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} 21 | 22 | object SparkShimProvider { 23 | val SPARK330_DESCRIPTOR = SparkShimDescriptor(3, 3, 0) 24 | val SPARK331_DESCRIPTOR = SparkShimDescriptor(3, 3, 1) 25 | val SPARK332_DESCRIPTOR = SparkShimDescriptor(3, 3, 2) 26 | val SPARK333_DESCRIPTOR = SparkShimDescriptor(3, 3, 3) 27 | val DESCRIPTOR_STRINGS = Seq(s"$SPARK330_DESCRIPTOR", s"$SPARK331_DESCRIPTOR", 28 | s"$SPARK332_DESCRIPTOR", s"$SPARK333_DESCRIPTOR") 29 | val DESCRIPTOR = SPARK332_DESCRIPTOR 30 | } 31 | 32 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { 33 | def createShim: SparkShims = { 34 | new Spark330Shims() 35 | } 36 | 37 | def matches(version: String): Boolean = { 38 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements.
See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.intel.raydp.shims.spark340 19 | 20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} 21 | 22 | object SparkShimProvider { 23 | val SPARK340_DESCRIPTOR = SparkShimDescriptor(3, 4, 0) 24 | val SPARK341_DESCRIPTOR = SparkShimDescriptor(3, 4, 1) 25 | val SPARK342_DESCRIPTOR = SparkShimDescriptor(3, 4, 2) 26 | val SPARK343_DESCRIPTOR = SparkShimDescriptor(3, 4, 3) 27 | val SPARK344_DESCRIPTOR = SparkShimDescriptor(3, 4, 4) 28 | val DESCRIPTOR_STRINGS = Seq(s"$SPARK340_DESCRIPTOR", s"$SPARK341_DESCRIPTOR", s"$SPARK342_DESCRIPTOR", 29 | s"$SPARK343_DESCRIPTOR", s"$SPARK344_DESCRIPTOR") 30 | val DESCRIPTOR = SPARK341_DESCRIPTOR 31 | } 32 | 33 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { 34 | def createShim: SparkShims = { 35 | new Spark340Shims() 36 | } 37 | 38 | def matches(version: String): Boolean = { 39 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShims.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims 19 | 20 | import org.apache.arrow.vector.types.pojo.Schema 21 | import org.apache.spark.{SparkEnv, TaskContext} 22 | import org.apache.spark.api.java.JavaRDD 23 | import org.apache.spark.executor.RayDPExecutorBackendFactory 24 | import org.apache.spark.sql.types.StructType 25 | import org.apache.spark.sql.{DataFrame, SparkSession} 26 | 27 | sealed abstract class ShimDescriptor 28 | 29 | case class SparkShimDescriptor(major: Int, minor: Int, patch: Int) extends ShimDescriptor { 30 | override def toString(): String = s"$major.$minor.$patch" 31 | } 32 | 33 | trait SparkShims { 34 | def getShimDescriptor: ShimDescriptor 35 | 36 | def toDataFrame(rdd: JavaRDD[Array[Byte]], schema: String, session: SparkSession): DataFrame 37 | 38 | def getExecutorBackendFactory(): RayDPExecutorBackendFactory 39 | 40 | def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext 41 | 42 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema 43 | } 44 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/org/apache/spark/executor/RayDPSpark330ExecutorBackendFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.executor.spark330 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.executor._ 24 | import org.apache.spark.resource.ResourceProfile 25 | import org.apache.spark.rpc.RpcEnv 26 | 27 | class RayDPSpark330ExecutorBackendFactory 28 | extends RayDPExecutorBackendFactory { 29 | override def createExecutorBackend( 30 | rpcEnv: RpcEnv, 31 | driverUrl: String, 32 | executorId: String, 33 | bindAddress: String, 34 | hostname: String, 35 | cores: Int, 36 | userClassPath: Seq[URL], 37 | env: SparkEnv, 38 | resourcesFileOpt: Option[String], 39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { 40 | new RayCoarseGrainedExecutorBackend( 41 | rpcEnv, 42 | driverUrl, 43 | executorId, 44 | bindAddress, 45 | hostname, 46 | cores, 47 | userClassPath, 48 | env, 49 | resourcesFileOpt, 50 | resourceProfile) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/org/apache/spark/executor/RayDPSpark340ExecutorBackendFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.executor.spark340 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.executor._ 24 | import org.apache.spark.resource.ResourceProfile 25 | import org.apache.spark.rpc.RpcEnv 26 | 27 | class RayDPSpark340ExecutorBackendFactory 28 | extends RayDPExecutorBackendFactory { 29 | override def createExecutorBackend( 30 | rpcEnv: RpcEnv, 31 | driverUrl: String, 32 | executorId: String, 33 | bindAddress: String, 34 | hostname: String, 35 | cores: Int, 36 | userClassPath: Seq[URL], 37 | env: SparkEnv, 38 | resourcesFileOpt: Option[String], 39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { 40 | new RayCoarseGrainedExecutorBackend( 41 | rpcEnv, 42 | driverUrl, 43 | executorId, 44 | bindAddress, 45 | hostname, 46 | cores, 47 | userClassPath, 48 | env, 49 | resourcesFileOpt, 50 | resourceProfile) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/org/apache/spark/executor/RayDPSpark350ExecutorBackendFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.executor.spark350 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.executor._ 24 | import org.apache.spark.resource.ResourceProfile 25 | import org.apache.spark.rpc.RpcEnv 26 | 27 | class RayDPSpark350ExecutorBackendFactory 28 | extends RayDPExecutorBackendFactory { 29 | override def createExecutorBackend( 30 | rpcEnv: RpcEnv, 31 | driverUrl: String, 32 | executorId: String, 33 | bindAddress: String, 34 | hostname: String, 35 | cores: Int, 36 | userClassPath: Seq[URL], 37 | env: SparkEnv, 38 | resourcesFileOpt: Option[String], 39 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { 40 | new RayCoarseGrainedExecutorBackend( 41 | rpcEnv, 42 | driverUrl, 43 | executorId, 44 | bindAddress, 45 | hostname, 46 | cores, 47 | userClassPath, 48 | env, 49 | resourcesFileOpt, 50 | resourceProfile) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/scheduler/cluster/raydp/RayClusterManager.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.scheduler.cluster.raydp 19 | 20 | import org.apache.spark.SparkContext 21 | import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} 22 | 23 | private[spark] class RayClusterManager extends ExternalClusterManager { 24 | 25 | override def canCreate(masterURL: String): Boolean = { 26 | masterURL.startsWith("ray") 27 | } 28 | 29 | override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { 30 | new TaskSchedulerImpl(sc) 31 | } 32 | 33 | override def createSchedulerBackend( 34 | sc: SparkContext, 35 | masterURL: String, 36 | scheduler: TaskScheduler): SchedulerBackend = { 37 | new RayCoarseGrainedSchedulerBackend( 38 | sc, 39 | scheduler.asInstanceOf[TaskSchedulerImpl], 40 | masterURL) 41 | } 42 | 43 | override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { 44 | scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/rdd/RayObjectRefRDD.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.rdd 19 | 20 | import java.util.List; 21 | 22 | import scala.collection.JavaConverters._ 23 | 24 | import io.ray.runtime.generated.Common.Address 25 | 26 | import org.apache.spark.{Partition, SparkContext, TaskContext} 27 | import org.apache.spark.raydp.RayDPUtils 28 | import org.apache.spark.sql.Row 29 | 30 | private[spark] class RayObjectRefRDDPartition(idx: Int) extends Partition { 31 | val index = idx 32 | } 33 | 34 | private[spark] 35 | class RayObjectRefRDD( 36 | sc: SparkContext, 37 | locations: List[Array[Byte]]) 38 | extends RDD[Row](sc, Nil) { 39 | 40 | override def getPartitions: Array[Partition] = { 41 | (0 until locations.size()).map { i => 42 | new RayObjectRefRDDPartition(i).asInstanceOf[Partition] 43 | }.toArray 44 | } 45 | 46 | override def compute(split: Partition, context: TaskContext): Iterator[Row] = { 47 | (Row(split.index) :: Nil).iterator 48 | } 49 | 50 | override def getPreferredLocations(split: Partition): Seq[String] = { 51 | Seq(Address.parseFrom(locations.get(split.index)).getIpAddress()) 52 | } 53 | } 54 | 55 | -------------------------------------------------------------------------------- /core/shims/spark322/src/main/scala/org/apache/spark/executor/RayDPSpark322ExecutorBackendFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.executor.spark322 19 | 20 | import java.net.URL 21 | 22 | import org.apache.spark.SparkEnv 23 | import org.apache.spark.executor.CoarseGrainedExecutorBackend 24 | import org.apache.spark.executor.RayDPExecutorBackendFactory 25 | import org.apache.spark.resource.ResourceProfile 26 | import org.apache.spark.rpc.RpcEnv 27 | 28 | class RayDPSpark322ExecutorBackendFactory 29 | extends RayDPExecutorBackendFactory { 30 | override def createExecutorBackend( 31 | rpcEnv: RpcEnv, 32 | driverUrl: String, 33 | executorId: String, 34 | bindAddress: String, 35 | hostname: String, 36 | cores: Int, 37 | userClassPath: Seq[URL], 38 | env: SparkEnv, 39 | resourcesFileOpt: Option[String], 40 | resourceProfile: ResourceProfile): CoarseGrainedExecutorBackend = { 41 | new CoarseGrainedExecutorBackend( 42 | rpcEnv, 43 | driverUrl, 44 | executorId, 45 | bindAddress, 46 | hostname, 47 | cores, 48 | userClassPath, 49 | env, 50 | resourcesFileOpt, 51 | resourceProfile) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/Messages.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import org.apache.spark.rpc.RpcEndpointRef 21 | 22 | private[deploy] sealed trait RayDPDeployMessage extends Serializable 23 | 24 | case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef) 25 | extends RayDPDeployMessage 26 | 27 | case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends RayDPDeployMessage 28 | 29 | case class UnregisterApplication(appId: String) extends RayDPDeployMessage 30 | 31 | case class RegisterExecutor(executorId: String, nodeIp: String) extends RayDPDeployMessage 32 | 33 | case class ExecutorStarted(executorId: String) extends RayDPDeployMessage 34 | 35 | case class RequestExecutors(appId: String, requestedTotal: Int) extends RayDPDeployMessage 36 | 37 | case class KillExecutors(appId: String, executorIds: Seq[String]) extends RayDPDeployMessage 38 | 39 | case class RequestAddPendingRestartedExecutor(executorId: String) 40 | extends RayDPDeployMessage 41 | 42 | case class AddPendingRestartedExecutorReply(newExecutorId: Option[String]) 43 | extends RayDPDeployMessage 44 | 45 | case class RecacheRDD(rddId: Int) extends RayDPDeployMessage 46 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.sql.spark340 19 | 20 | import org.apache.arrow.vector.types.pojo.Schema 21 | import org.apache.spark.TaskContext 22 | import org.apache.spark.api.java.JavaRDD 23 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} 24 | import org.apache.spark.sql.execution.arrow.ArrowConverters 25 | import org.apache.spark.sql.types._ 26 | import org.apache.spark.sql.util.ArrowUtils 27 | 28 | object SparkSqlUtils { 29 | def toDataFrame( 30 | arrowBatchRDD: JavaRDD[Array[Byte]], 31 | schemaString: String, 32 | session: SparkSession): DataFrame = { 33 | val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] 34 | val timeZoneId = session.sessionState.conf.sessionLocalTimeZone 35 | val rdd = arrowBatchRDD.rdd.mapPartitions { iter => 36 | val context = TaskContext.get() 37 | ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) 38 | } 39 | session.internalCreateDataFrame(rdd.setName("arrow"), schema) 40 | } 41 | 42 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 43 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /python/raydp/tests/test_torch_sequential.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | import pytest 19 | import sys 20 | import torch 21 | import raydp 22 | from raydp.torch import TorchEstimator 23 | 24 | def test_torch_estimator(spark_on_ray_small): 25 | ##prepare the data 26 | customers = [ 27 | (1,'James', 21, 6), 28 | (2, "Liz", 25, 8), 29 | (3, "John", 31, 6), 30 | (4, "Jennifer", 45, 7), 31 | (5, "Robert", 41, 5), 32 | (6, "Sandra", 45, 8) 33 | ] 34 | df = spark_on_ray_small.createDataFrame(customers, ["cID", "name", "age", "grade"]) 35 | 36 | ##create model 37 | model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2,1)) 38 | optimizer = torch.optim.Adam(model.parameters()) 39 | loss = torch.nn.MSELoss() 40 | 41 | #config 42 | estimator = TorchEstimator( 43 | model = model, 44 | optimizer = optimizer, 45 | loss = loss, 46 | num_workers = 3, 47 | num_epochs = 5, 48 | feature_columns = ["age"], 49 | feature_types = torch.float, 50 | label_column = "grade", 51 | label_type = torch.float, 52 | batch_size = 1 53 | ) 54 | estimator.fit_on_spark(df) 55 | 56 | if __name__ == "__main__": 57 | sys.exit(pytest.main(["-v", __file__])) -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayExternalShuffleService.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import io.ray.api.Ray; 21 | 22 | import org.apache.spark.{SecurityManager, SparkConf} 23 | import org.apache.spark.deploy.ExternalShuffleService 24 | import org.apache.spark.internal.Logging 25 | 26 | class RayExternalShuffleService() extends Logging { 27 | val conf = new SparkConf() 28 | val mgr = new SecurityManager(conf) 29 | val instance = new ExternalShuffleService(conf, mgr) 30 | 31 | def start(): Unit = { 32 | instance.start() 33 | } 34 | 35 | def stop(): Unit = { 36 | instance.stop() 37 | Ray.exitActor() 38 | } 39 | } 40 | 41 | object RayExternalShuffleService { 42 | def getShuffleConf(conf: SparkConf): Array[String] = { 43 | // all conf needed by external shuffle service 44 | var shuffleConf = conf.getAll.filter { 45 | case (k, v) => k.startsWith("spark.shuffle") 46 | }.map { 47 | case (k, v) => 48 | "-D" + k + "=" + v 49 | } 50 | val localDirKey = "spark.local.dir" 51 | if (conf.contains(localDirKey)) { 52 | shuffleConf = shuffleConf :+ 53 | "-D" + localDirKey + "=" + conf.get(localDirKey) 54 | } 55 | shuffleConf 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/org/apache/spark/sql/SparkSqlUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.sql.spark350 19 | 20 | import org.apache.arrow.vector.types.pojo.Schema 21 | import org.apache.spark.TaskContext 22 | import org.apache.spark.api.java.JavaRDD 23 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkSession} 24 | import org.apache.spark.sql.execution.arrow.ArrowConverters 25 | import org.apache.spark.sql.types._ 26 | import org.apache.spark.sql.util.ArrowUtils 27 | 28 | object SparkSqlUtils { 29 | def toDataFrame( 30 | arrowBatchRDD: JavaRDD[Array[Byte]], 31 | schemaString: String, 32 | session: SparkSession): DataFrame = { 33 | val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] 34 | val timeZoneId = session.sessionState.conf.sessionLocalTimeZone 35 | val rdd = arrowBatchRDD.rdd.mapPartitions { iter => 36 | val context = TaskContext.get() 37 | ArrowConverters.fromBatchIterator(iter, schema, timeZoneId,false, context) 38 | } 39 | session.internalCreateDataFrame(rdd.setName("arrow"), schema) 40 | } 41 | 42 | def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 43 | ArrowUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId, errorOnDuplicatedFieldNames = false) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims.spark322 19 | 20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} 21 | 22 | object SparkShimProvider { 23 | val SPARK311_DESCRIPTOR = SparkShimDescriptor(3, 1, 1) 24 | val SPARK312_DESCRIPTOR = SparkShimDescriptor(3, 1, 2) 25 | val SPARK313_DESCRIPTOR = SparkShimDescriptor(3, 1, 3) 26 | val SPARK320_DESCRIPTOR = SparkShimDescriptor(3, 2, 0) 27 | val SPARK321_DESCRIPTOR = SparkShimDescriptor(3, 2, 1) 28 | val SPARK322_DESCRIPTOR = SparkShimDescriptor(3, 2, 2) 29 | val SPARK323_DESCRIPTOR = SparkShimDescriptor(3, 2, 3) 30 | val SPARK324_DESCRIPTOR = SparkShimDescriptor(3, 2, 4) 31 | val DESCRIPTOR_STRINGS = 32 | Seq(s"$SPARK311_DESCRIPTOR", s"$SPARK312_DESCRIPTOR" ,s"$SPARK313_DESCRIPTOR", 33 | s"$SPARK320_DESCRIPTOR", s"$SPARK321_DESCRIPTOR", s"$SPARK322_DESCRIPTOR", 34 | s"$SPARK323_DESCRIPTOR", s"$SPARK324_DESCRIPTOR") 35 | val DESCRIPTOR = SPARK323_DESCRIPTOR 36 | } 37 | 38 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { 39 | def createShim: SparkShims = { 40 | new Spark322Shims() 41 | } 42 | 43 | def matches(version: String): Boolean = { 44 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShimProvider.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims.spark350 19 | 20 | import com.intel.raydp.shims.{SparkShims, SparkShimDescriptor} 21 | 22 | object SparkShimProvider { 23 | val SPARK350_DESCRIPTOR = SparkShimDescriptor(3, 5, 0) 24 | val SPARK351_DESCRIPTOR = SparkShimDescriptor(3, 5, 1) 25 | val SPARK352_DESCRIPTOR = SparkShimDescriptor(3, 5, 2) 26 | val SPARK353_DESCRIPTOR = SparkShimDescriptor(3, 5, 3) 27 | val SPARK354_DESCRIPTOR = SparkShimDescriptor(3, 5, 4) 28 | val SPARK355_DESCRIPTOR = SparkShimDescriptor(3, 5, 5) 29 | val SPARK356_DESCRIPTOR = SparkShimDescriptor(3, 5, 6) 30 | val SPARK357_DESCRIPTOR = SparkShimDescriptor(3, 5, 7) 31 | val DESCRIPTOR_STRINGS = Seq( 32 | s"$SPARK350_DESCRIPTOR", s"$SPARK351_DESCRIPTOR", s"$SPARK352_DESCRIPTOR", 33 | s"$SPARK353_DESCRIPTOR", s"$SPARK354_DESCRIPTOR", s"$SPARK355_DESCRIPTOR", 34 | s"$SPARK356_DESCRIPTOR", s"$SPARK357_DESCRIPTOR" 35 | ) 36 | val DESCRIPTOR = SPARK350_DESCRIPTOR 37 | } 38 | 39 | class SparkShimProvider extends com.intel.raydp.shims.SparkShimProvider { 40 | def createShim: SparkShims = { 41 | new Spark350Shims() 42 | } 43 | 44 | def matches(version: String): Boolean = { 45 | SparkShimProvider.DESCRIPTOR_STRINGS.contains(version) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/ApplicationDescription.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import scala.collection.Map 21 | 22 | private[spark] case class Command( 23 | driverUrl: String, 24 | environment: Map[String, String], 25 | classPathEntries: Seq[String], 26 | libraryPathEntries: Seq[String], 27 | javaOpts: Seq[String]) { 28 | 29 | def withNewJavaOpts(newJavaOptions: Seq[String]): Command = { 30 | Command(driverUrl, environment, classPathEntries, libraryPathEntries, newJavaOptions) 31 | } 32 | } 33 | 34 | private[spark] case class ApplicationDescription( 35 | name: String, 36 | numExecutors: Int, 37 | coresPerExecutor: Option[Int], 38 | memoryPerExecutorMB: Int, 39 | rayActorCPU: Double, 40 | command: Command, 41 | user: String = System.getProperty("user.name", ""), 42 | resourceReqsPerExecutor: Map[String, Double] = Map.empty) { 43 | 44 | def withNewCommand(newCommand: Command): ApplicationDescription = { 45 | ApplicationDescription(name = name, 46 | numExecutors = numExecutors, coresPerExecutor = coresPerExecutor, 47 | memoryPerExecutorMB = memoryPerExecutorMB, command = newCommand, user = user, 48 | resourceReqsPerExecutor = resourceReqsPerExecutor, 49 | rayActorCPU = rayActorCPU) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /python/raydp/tests/test_spark_master_memory.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pytest 3 | import ray 4 | import raydp 5 | from ray.cluster_utils import Cluster 6 | from ray.util.state import list_actors 7 | 8 | 9 | def test_spark_master_memory_custom(jdk17_extra_spark_configs): 10 | cluster = Cluster( 11 | initialize_head=True, 12 | head_node_args={ 13 | "num_cpus": 2, 14 | "resources": {"master": 10}, 15 | "include_dashboard": True, 16 | "dashboard_port": 8270, 17 | }, 18 | ) 19 | ray.init(address=cluster.address, 20 | dashboard_port=cluster.head_node.dashboard_grpc_port, 21 | include_dashboard=True) 22 | 23 | custom_memory = 100 * 1024 * 1024 # 100MB in bytes 24 | configs = jdk17_extra_spark_configs.copy() 25 | # Config under test: set Spark Master actor memory via RayDP config 26 | configs["spark.ray.raydp_spark_master.actor.resource.memory"] = str(custom_memory) 27 | # Also require the master custom resource so the actor is scheduled on the head 28 | configs["spark.ray.raydp_spark_master.actor.resource.master"] = "1" 29 | 30 | app_name = "test_spark_master_memory_custom" 31 | 32 | spark = raydp.init_spark( 33 | app_name=app_name, 34 | num_executors=1, 35 | executor_cores=1, 36 | executor_memory="500M", 37 | configs=configs, 38 | ) 39 | 40 | # Trigger the Spark master / RayDPSparkMaster startup 41 | spark.createDataFrame([(1, 2)], ["a", "b"]).count() 42 | 43 | # RayDPSparkMaster name is app_name + RAYDP_SPARK_MASTER_SUFFIX 44 | master_actor_name = f"{app_name}_SPARK_MASTER" 45 | 46 | actor = ray.get_actor(master_actor_name) 47 | assert actor is not None 48 | 49 | # Query Ray state for this actor 50 | actor_state = list_actors(filters=[("actor_id", "=", actor._actor_id.hex())], detail=True)[0] 51 | resources = actor_state.required_resources 52 | 53 | assert resources["memory"] == custom_memory 54 | assert resources["master"] == 1 55 | 56 | spark.stop() 57 | raydp.stop_spark() 58 | ray.shutdown() 59 | cluster.shutdown() 60 | 61 | 62 | if __name__ == "__main__": 63 | sys.exit(pytest.main(["-v", __file__])) 64 | 65 | 66 | -------------------------------------------------------------------------------- 
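A note on the configuration exercised above: the `spark.ray.raydp_spark_master.actor.resource.*` keys let an application control the resources requested for the RayDP Spark master actor. A minimal application-side sketch, assuming only the config keys shown in the test above (the app name and sizes are placeholders, not part of the repository):

```python
import ray
import raydp

ray.init(address="auto")

spark = raydp.init_spark(
    app_name="raydp_master_resources_demo",  # placeholder name
    num_executors=1,
    executor_cores=1,
    executor_memory="500M",
    configs={
        # Reserve 100 MB (value is in bytes) for the RayDP Spark master actor.
        "spark.ray.raydp_spark_master.actor.resource.memory": str(100 * 1024 * 1024),
    },
)

spark.range(10).count()  # run a trivial job; the test above uses a count the same way

raydp.stop_spark()
ray.shutdown()
```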
/core/raydp-main/src/main/java/org/apache/spark/raydp/RayDPUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.raydp; 19 | 20 | import io.ray.api.ObjectRef; 21 | import io.ray.api.Ray; 22 | import io.ray.api.id.ObjectId; 23 | import io.ray.runtime.AbstractRayRuntime; 24 | import io.ray.runtime.object.ObjectRefImpl; 25 | 26 | public class RayDPUtils { 27 | 28 | /** 29 | * Convert ObjectRef to its subclass ObjectRefImpl. Throws a RuntimeException if it is not 30 | * an instance of ObjectRefImpl. We can't import ObjectRefImpl in Scala code, so we do the 31 | * conversion here. 32 | */ 33 | public static <T> ObjectRefImpl<T> convert(ObjectRef<T> obj) { 34 | if (obj instanceof ObjectRefImpl) { 35 | return (ObjectRefImpl<T>)obj; 36 | } else { 37 | throw new RuntimeException(obj.getClass() + " is not ObjectRefImpl"); 38 | } 39 | } 40 | 41 | /** 42 | * Create an ObjectRef from Array[Byte] and register ownership. 43 | * We can't import ObjectRefImpl in Scala code, so we do the conversion here. 44 | */ 45 | public static <T> ObjectRef<T> readBinary(byte[] obj, Class<T> clazz, byte[] ownerAddress) { 46 | ObjectId id = new ObjectId(obj); 47 | ObjectRefImpl<T> ref = new ObjectRefImpl<>(id, clazz, false); 48 | AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal(); 49 | runtime.getObjectStore().registerOwnershipInfoAndResolveFuture( 50 | id, null, ownerAddress 51 | ); 52 | return ref; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Running RayDP on k8s cluster 2 | 3 | ## Build docker image 4 | Build the docker image to use in K8S with the following command, which will create an image tagged `oap-project/raydp:latest`: 5 | ```shell 6 | # under ${RAYDP_HOME}/docker 7 | ./build-docker.sh 8 | ``` 9 | 10 | You can install our nightly build with `pip install raydp --pre` or `pip install raydp-nightly`. To install raydp-nightly in the image, modify the following code in `Dockerfile`: 11 | ```Dockerfile 12 | RUN sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get update -y \ 13 | && sudo http_proxy=${HTTP_PROXY} https_proxy=${HTTPS_PROXY} apt-get install -y openjdk-8-jdk \ 14 | && sudo mkdir /raydp \ 15 | && sudo chown -R ray /raydp \ 16 | && $HOME/anaconda3/bin/pip --no-cache-dir install raydp-nightly 17 | ``` 18 | 19 | Meanwhile, you should install all the dependencies of your application in the `Dockerfile`, as sketched below. If suitable, you can change the base image to `ray-ml`: 20 | ```Dockerfile 21 | FROM rayproject/ray-ml:latest 22 | ``` 23 | 
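For example, a dependency layer on top of that base image might look like the following sketch; the extra package is a placeholder for whatever your application actually needs, and the pip path simply mirrors the snippet above:

```Dockerfile
FROM rayproject/ray-ml:latest

# Placeholder application dependencies: replace with the packages your job really needs.
RUN $HOME/anaconda3/bin/pip --no-cache-dir install raydp-nightly torchmetrics
```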
24 | Then, you can push the built image to a repository or distribute it to the k8s worker nodes. 25 | 26 | ## Deploy ray cluster with Helm 27 | You need to create a Helm chart first. To start with, check out this [example ray cluster Helm chart](https://github.com/ray-project/kuberay/tree/master/helm-chart/ray-cluster). You can clone that repo and copy the chart directory, then modify `values.yaml` to use the previously built image. 28 | 29 | ```yaml 30 | image: 31 | repository: oap-project/raydp 32 | tag: latest 33 | pullPolicy: IfNotPresent 34 | ``` 35 | 36 | You can also change other fields in this file to specify the number of workers, etc. 37 | 38 | Then, you need to deploy the KubeRay operator first; please refer to [here](https://docs.ray.io/en/latest/cluster/kubernetes/getting-started.html#kuberay-quickstart) for instructions. You can now deploy a Ray cluster with RayDP installed via `helm install ray-cluster PATH_to_CHART`. 39 | 40 | ## Access the cluster 41 | Check [here](https://docs.ray.io/en/master/cluster/kubernetes/getting-started.html#running-applications-on-a-ray-cluster) to see how to run applications on the cluster you just deployed. 42 | 43 | ## Legacy 44 | If you are using Ray versions before 2.0, you can try this command: 45 | ```shell 46 | ray up ${RAYDP_HOME}/docker/legacy.yaml 47 | ``` -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/rdd/RayDatasetRDD.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.rdd 19 | 20 | import java.util.List; 21 | 22 | import scala.collection.JavaConverters._ 23 | 24 | import io.ray.runtime.generated.Common.Address 25 | 26 | import org.apache.spark.{Partition, SparkContext, TaskContext} 27 | import org.apache.spark.api.java.JavaSparkContext 28 | import org.apache.spark.raydp.RayDPUtils 29 | import org.apache.spark.sql.raydp.ObjectStoreReader 30 | 31 | private[spark] class RayDatasetRDDPartition(val ref: Array[Byte], idx: Int) extends Partition { 32 | val index = idx 33 | } 34 | 35 | private[spark] 36 | class RayDatasetRDD( 37 | jsc: JavaSparkContext, 38 | @transient val objectIds: List[Array[Byte]], 39 | locations: List[Array[Byte]]) 40 | extends RDD[Array[Byte]](jsc.sc, Nil) { 41 | 42 | override def getPartitions: Array[Partition] = { 43 | objectIds.asScala.zipWithIndex.map { case (k, i) => 44 | new RayDatasetRDDPartition(k, i).asInstanceOf[Partition] 45 | }.toArray 46 | } 47 | 48 | override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { 49 | val ref = split.asInstanceOf[RayDatasetRDDPartition].ref 50 | ObjectStoreReader.getBatchesFromStream(ref, locations.get(split.index)) 51 | } 52 | 53 | override def getPreferredLocations(split: Partition): Seq[String] = { 54 | val address = Address.parseFrom(locations.get(split.index)) 55 | Seq(address.getIpAddress()) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /core/shims/spark322/src/main/scala/com/intel/raydp/shims/SparkShims.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims.spark322 19 | 20 | import org.apache.spark.{SparkEnv, TaskContext} 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.executor.RayDPExecutorBackendFactory 23 | import org.apache.spark.executor.spark322._ 24 | import org.apache.spark.spark322.TaskContextUtils 25 | import org.apache.spark.sql.{DataFrame, SparkSession} 26 | import org.apache.spark.sql.spark322.SparkSqlUtils 27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims} 28 | import org.apache.arrow.vector.types.pojo.Schema 29 | import org.apache.spark.sql.types.StructType 30 | 31 | class Spark322Shims extends SparkShims { 32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR 33 | 34 | override def toDataFrame( 35 | rdd: JavaRDD[Array[Byte]], 36 | schema: String, 37 | session: SparkSession): DataFrame = { 38 | SparkSqlUtils.toDataFrame(rdd, schema, session) 39 | } 40 | 41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { 42 | new RayDPSpark322ExecutorBackendFactory() 43 | } 44 | 45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 46 | TaskContextUtils.getDummyTaskContext(partitionId, env) 47 | } 48 | 49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark330/src/main/scala/com/intel/raydp/shims/SparkShims.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims.spark330 19 | 20 | import org.apache.spark.{SparkEnv, TaskContext} 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.executor.RayDPExecutorBackendFactory 23 | import org.apache.spark.executor.spark330._ 24 | import org.apache.spark.spark330.TaskContextUtils 25 | import org.apache.spark.sql.{DataFrame, SparkSession} 26 | import org.apache.spark.sql.spark330.SparkSqlUtils 27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims} 28 | import org.apache.arrow.vector.types.pojo.Schema 29 | import org.apache.spark.sql.types.StructType 30 | 31 | class Spark330Shims extends SparkShims { 32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR 33 | 34 | override def toDataFrame( 35 | rdd: JavaRDD[Array[Byte]], 36 | schema: String, 37 | session: SparkSession): DataFrame = { 38 | SparkSqlUtils.toDataFrame(rdd, schema, session) 39 | } 40 | 41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { 42 | new RayDPSpark330ExecutorBackendFactory() 43 | } 44 | 45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 46 | TaskContextUtils.getDummyTaskContext(partitionId, env) 47 | } 48 | 49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark340/src/main/scala/com/intel/raydp/shims/SparkShims.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims.spark340 19 | 20 | import org.apache.spark.{SparkEnv, TaskContext} 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.executor.RayDPExecutorBackendFactory 23 | import org.apache.spark.executor.spark340._ 24 | import org.apache.spark.spark340.TaskContextUtils 25 | import org.apache.spark.sql.{DataFrame, SparkSession} 26 | import org.apache.spark.sql.spark340.SparkSqlUtils 27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims} 28 | import org.apache.arrow.vector.types.pojo.Schema 29 | import org.apache.spark.sql.types.StructType 30 | 31 | class Spark340Shims extends SparkShims { 32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR 33 | 34 | override def toDataFrame( 35 | rdd: JavaRDD[Array[Byte]], 36 | schema: String, 37 | session: SparkSession): DataFrame = { 38 | SparkSqlUtils.toDataFrame(rdd, schema, session) 39 | } 40 | 41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { 42 | new RayDPSpark340ExecutorBackendFactory() 43 | } 44 | 45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 46 | TaskContextUtils.getDummyTaskContext(partitionId, env) 47 | } 48 | 49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /core/shims/spark350/src/main/scala/com/intel/raydp/shims/SparkShims.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims.spark350 19 | 20 | import org.apache.spark.{SparkEnv, TaskContext} 21 | import org.apache.spark.api.java.JavaRDD 22 | import org.apache.spark.executor.RayDPExecutorBackendFactory 23 | import org.apache.spark.executor.spark350._ 24 | import org.apache.spark.spark350.TaskContextUtils 25 | import org.apache.spark.sql.{DataFrame, SparkSession} 26 | import org.apache.spark.sql.spark350.SparkSqlUtils 27 | import com.intel.raydp.shims.{ShimDescriptor, SparkShims} 28 | import org.apache.arrow.vector.types.pojo.Schema 29 | import org.apache.spark.sql.types.StructType 30 | 31 | class Spark350Shims extends SparkShims { 32 | override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR 33 | 34 | override def toDataFrame( 35 | rdd: JavaRDD[Array[Byte]], 36 | schema: String, 37 | session: SparkSession): DataFrame = { 38 | SparkSqlUtils.toDataFrame(rdd, schema, session) 39 | } 40 | 41 | override def getExecutorBackendFactory(): RayDPExecutorBackendFactory = { 42 | new RayDPSpark350ExecutorBackendFactory() 43 | } 44 | 45 | override def getDummyTaskContext(partitionId: Int, env: SparkEnv): TaskContext = { 46 | TaskContextUtils.getDummyTaskContext(partitionId, env) 47 | } 48 | 49 | override def toArrowSchema(schema : StructType, timeZoneId : String) : Schema = { 50 | SparkSqlUtils.toArrowSchema(schema = schema, timeZoneId = timeZoneId) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # RayDP Examples 2 | Here are a few examples showing how RayDP works together with other libraries, such as PyTorch, Tensorflow, XGBoost and Horovod. 3 | 4 | In order to run these examples, you may need to install corresponding dependencies. For installation guides, please refer to their homepages. Notice that we need to install [xgboost_ray](https://github.com/ray-project/xgboost_ray) to run the xgboost example. In addition, if you are running the examples in a ray cluster, all nodes should have the dependencies installed. 5 | 6 | ## NYC Taxi Fare Prediction Dataset 7 | We have a few examples which use this dataset. 8 | You can run our examples right away after you clone our repo, because we include a small example dataset generated randomly using `examples/random_nyctaxi.py`. Generated datasets just demonstrates that our examples can work, but the trained models might not be meaningful. 9 | 10 | The original dataset can be downloaded [here](https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data). After you download it, please modify the variable `NYC_TRAIN_CSV` in `data_process.py` and point it to where `train.csv` is saved. 11 | 12 | ## Horovod 13 | To run the example, please install horovod via `pip install horovod[pytorch, ray]`. In addition, `HOROVOD_WITH_PYTORCH` and `HOROVOD_WITH_GLOO` should be set to `1` before pip. Notice that macOS users need to first install `libuv` via `brew install libuv`. Please refer to [here](https://horovod.readthedocs.io/en/stable/install_include.html) for details. 14 | 15 | When running `horovod_nyctaxi.py`, do not use `horovodrun`. Check [here](https://horovod.readthedocs.io/en/stable/ray_include.html) for more information. 16 | 17 | ## RaySGD Example 18 | In the RaySGD example, we demonstrate how to use our `MLDataset` API. 
After we use Spark to transform the dataset, we call `RayMLDataset.from_spark` to write the Spark DataFrames into Ray object store, using Apache Arrow format. We then convert the data to `pandas` DataFrame, hopefully zero-copy. Finally, they can be consumed by any framework supports `numpy` format, such as PyTorch or Tensorflow. `MLDataset` is partitioned, or sharded, just like Spark DataFrames. Their numbers of partitions are not required to be the same. However, the number of shards of `MLDataset` should be the same as the number of workers of `TorchTrainer` or `TFTrainer`, so that each worker is mapped to a shard. 19 | -------------------------------------------------------------------------------- /python/raydp/tests/test_xgboost.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import os 19 | import sys 20 | import shutil 21 | import platform 22 | import pytest 23 | import pyspark 24 | import numpy as np 25 | from pyspark.sql.functions import rand 26 | 27 | from raydp.xgboost import XGBoostEstimator 28 | from raydp.utils import random_split 29 | 30 | @pytest.mark.parametrize("use_fs_directory", [True, False]) 31 | def test_xgb_estimator(spark_on_ray_small, use_fs_directory): 32 | if platform.system() == "Darwin": 33 | pytest.skip("Skip xgboost test on MacOS") 34 | spark = spark_on_ray_small 35 | 36 | # calculate z = 3 * x + 4 * y + 5 37 | df: pyspark.sql.DataFrame = spark.range(0, 100000) 38 | df = df.withColumn("x", rand() * 100) # add x column 39 | df = df.withColumn("y", rand() * 1000) # ad y column 40 | df = df.withColumn("z", df.x * 3 + df.y * 4 + rand() + 5) # ad z column 41 | df = df.select(df.x, df.y, df.z) 42 | 43 | train_df, test_df = random_split(df, [0.7, 0.3]) 44 | params = {} 45 | estimator = XGBoostEstimator(params, "z", resources_per_worker={"CPU": 1}) 46 | if use_fs_directory: 47 | dir = os.path.dirname(os.path.realpath(__file__)) + "/test_xgboost" 48 | uri = "file://" + dir 49 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri) 50 | else: 51 | estimator.fit_on_spark(train_df, test_df) 52 | print(estimator.get_model().inplace_predict(np.asarray([[1,2]]))) 53 | if use_fs_directory: 54 | shutil.rmtree(dir) 55 | 56 | if __name__ == '__main__': 57 | sys.exit(pytest.main(["-v", __file__])) -------------------------------------------------------------------------------- /examples/raydp-submit.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname 2 | import sys 3 | import json 4 | import subprocess 5 | import shlex 6 | import ray 7 | import pyspark 8 | 9 | ray.init(address="auto") 10 | node = ray.worker.global_worker.node 11 
| options = {} 12 | options["ray"] = {} 13 | options["ray"]["run-mode"] = "CLUSTER" 14 | options["ray"]["node-ip"] = node.node_ip_address 15 | options["ray"]["address"] = node.address 16 | options["ray"]["session-dir"] = node.get_session_dir_path() 17 | 18 | ray.shutdown() 19 | conf_path = dirname(__file__) + "/ray.conf" 20 | with open(conf_path, "w") as f: 21 | json.dump(options, f) 22 | 23 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access 24 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP. 25 | java_opts = " ".join([ 26 | "-XX:+IgnoreUnrecognizedVMOptions", 27 | "--add-opens=java.base/java.lang=ALL-UNNAMED", 28 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED", 29 | "--add-opens=java.base/java.io=ALL-UNNAMED", 30 | "--add-opens=java.base/java.net=ALL-UNNAMED", 31 | "--add-opens=java.base/java.nio=ALL-UNNAMED", 32 | "--add-opens=java.base/java.math=ALL-UNNAMED", 33 | "--add-opens=java.base/java.text=ALL-UNNAMED", 34 | "--add-opens=java.base/java.util=ALL-UNNAMED", 35 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED", 36 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED", 37 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", 38 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED", 39 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED", 40 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED", 41 | ]) 42 | 43 | command = ["bin/raydp-submit", "--ray-conf", conf_path] 44 | command += ["--conf", "spark.executor.cores=1"] 45 | command += ["--conf", "spark.executor.instances=1"] 46 | command += ["--conf", "spark.executor.memory=500m"] 47 | command += ["--conf", f"spark.executor.extraJavaOptions={java_opts}"] 48 | command += ["--conf", f"spark.driver.extraJavaOptions={java_opts}"] 49 | command += ["--conf", f"spark.ray.raydp_app_master.extraJavaOptions={java_opts}"] 50 | example_path = dirname(pyspark.__file__) 51 | # run SparkPi as example 52 | command.append(example_path + "/examples/src/main/python/pi.py") 53 | cmd_str = " ".join(shlex.quote(arg) for arg in command) 54 | sys.exit(subprocess.run(cmd_str, check=True, shell=True).returncode) 55 | -------------------------------------------------------------------------------- /python/raydp/torch/torch_metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | module = sys.modules[__name__] 3 | 4 | def try_import_torchmetrics(): 5 | """Tries importing torchmetrics and returns the module (or None). 6 | Returns: 7 | torchmetrics modules. 8 | """ 9 | try: 10 | # pylint: disable=import-outside-toplevel 11 | import torchmetrics 12 | 13 | return torchmetrics 14 | except ImportError as torchmetrics_not_exist: 15 | raise ImportError( 16 | "Could not import torchmetrics! Raydp TorchEstimator requires " 17 | "you to install torchmetrics: " 18 | "`pip install torchmetrics`." 
19 | ) from torchmetrics_not_exist 20 | 21 | class TorchMetric(): 22 | def __init__(self, metrics_name, metrics_config): 23 | torchmetrics = try_import_torchmetrics() 24 | self._metrics_name = metrics_name 25 | self._metrics_func = {} 26 | if self._metrics_name is not None: 27 | assert isinstance(metrics_name, list), "metrics_name must be a list" 28 | for metric in self._metrics_name: 29 | if isinstance(metric, torchmetrics.Metric): 30 | self._metrics_func[metric.__class__.__name__] = metric 31 | elif isinstance(metric, str) and hasattr(torchmetrics, metric): 32 | if metrics_config is not None and metrics_config[metric] is not None: 33 | self._metrics_func[metric] = getattr(torchmetrics, metric)( 34 | **metrics_config[metric]) 35 | else: 36 | self._metrics_func[metric] = getattr(torchmetrics, metric)() 37 | else: 38 | raise Exception( 39 | "Unsupported parameter, we only support list of " 40 | "torchmetrics.Metric instances or arr of torchmetrics.") 41 | 42 | def update(self, preds, targets): 43 | for metric in self._metrics_func: 44 | self._metrics_func[metric].update(preds, targets) 45 | 46 | def compute(self): 47 | epoch_res = {} 48 | for metric in self._metrics_func: 49 | epoch_res[metric] = self._metrics_func[metric].compute().item() 50 | 51 | return epoch_res 52 | 53 | def reset(self): 54 | for metric in self._metrics_func: 55 | self._metrics_func[metric].reset() 56 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | name: RayDP PyPi 19 | 20 | on: 21 | schedule: 22 | - cron: '0 0 * * *' 23 | # can manually trigger the workflow 24 | workflow_dispatch: 25 | 26 | permissions: # added using https://github.com/step-security/secure-repo 27 | contents: read 28 | 29 | jobs: 30 | build-and-publish: 31 | # do not run in forks 32 | if: ${{ github.repository_owner == 'ray-project' }} 33 | name: build wheel and upload 34 | runs-on: ubuntu-latest 35 | steps: 36 | - uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master 37 | - name: Set up Python 3.10 38 | uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4 39 | with: 40 | python-version: 3.10.14 41 | - name: Set up JDK 1.8 42 | uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4 43 | with: 44 | java-version: 1.8 45 | - name: days since the commit date 46 | run: | 47 | : 48 | timestamp=$(git log --no-walk --date=unix --format=%cd $GITHUB_SHA) 49 | days=$(( ( $(date --utc +%s) - $timestamp ) / 86400 )) 50 | if [ $days -eq 0 ]; then 51 | echo COMMIT_TODAY=true >> $GITHUB_ENV 52 | fi 53 | - name: Build wheel 54 | if: env.COMMIT_TODAY == 'true' 55 | env: 56 | RAYDP_BUILD_MODE: nightly 57 | run: pip install wheel grpcio-tools && ./build.sh 58 | - name: Upload 59 | if: env.COMMIT_TODAY == 'true' 60 | uses: pypa/gh-action-pypi-publish@v1.13.0 61 | with: 62 | password: ${{ secrets.PYPI_API_TOKEN }} 63 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreReader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.sql.raydp 19 | 20 | import java.io.ByteArrayInputStream 21 | import java.nio.channels.{Channels, ReadableByteChannel} 22 | import java.util.List 23 | 24 | import com.intel.raydp.shims.SparkShimLoader 25 | 26 | import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} 27 | import org.apache.spark.raydp.RayDPUtils 28 | import org.apache.spark.rdd.{RayDatasetRDD, RayObjectRefRDD} 29 | import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext} 30 | import org.apache.spark.sql.catalyst.expressions.GenericRow 31 | import org.apache.spark.sql.execution.arrow.ArrowConverters 32 | import org.apache.spark.sql.types.{IntegerType, StructType} 33 | 34 | object ObjectStoreReader { 35 | def createRayObjectRefDF( 36 | spark: SparkSession, 37 | locations: List[Array[Byte]]): DataFrame = { 38 | val rdd = new RayObjectRefRDD(spark.sparkContext, locations) 39 | val schema = new StructType().add("idx", IntegerType) 40 | spark.createDataFrame(rdd, schema) 41 | } 42 | 43 | def RayDatasetToDataFrame( 44 | sparkSession: SparkSession, 45 | rdd: RayDatasetRDD, 46 | schema: String): DataFrame = { 47 | SparkShimLoader.getSparkShims.toDataFrame(JavaRDD.fromRDD(rdd), schema, sparkSession) 48 | } 49 | 50 | def getBatchesFromStream( 51 | ref: Array[Byte], 52 | ownerAddress: Array[Byte]): Iterator[Array[Byte]] = { 53 | val objectRef = RayDPUtils.readBinary(ref, classOf[Array[Byte]], ownerAddress) 54 | ArrowConverters.getBatchesFromStream( 55 | Channels.newChannel(new ByteArrayInputStream(objectRef.get))) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /core/agent/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-parent 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-agent 15 | RayDP Java Agent 16 | jar 17 | 18 | 19 | 20 | 21 | org.apache.maven.plugins 22 | maven-compiler-plugin 23 | 3.8.0 24 | 25 | 1.8 26 | 1.8 27 | 28 | 29 | 30 | org.apache.maven.plugins 31 | maven-jar-plugin 32 | 3.3.0 33 | 34 | 35 | 36 | org.apache.spark.raydp.Agent 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | org.apache.logging.log4j 47 | log4j-core 48 | 2.25.3 49 | 50 | 51 | org.apache.logging.log4j 52 | log4j-slf4j-impl 53 | 2.17.1 54 | 55 | 56 | org.slf4j 57 | slf4j-api 58 | 1.7.32 59 | 60 | 61 | com.intel 62 | raydp 63 | ${project.version} 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/java/org/apache/spark/deploy/raydp/RayAppMasterUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.deploy.raydp; 19 | 20 | import java.util.List; 21 | import java.util.Map; 22 | 23 | import io.ray.api.ActorHandle; 24 | import io.ray.api.Ray; 25 | import io.ray.api.call.ActorCreator; 26 | import org.apache.spark.raydp.SparkOnRayConfigs; 27 | 28 | public class RayAppMasterUtils { 29 | public static ActorHandle createAppMaster( 30 | String cp, 31 | String name, 32 | List jvmOptions, 33 | Map appMasterResource) { 34 | ActorCreator creator = Ray.actor(RayAppMaster::new, cp); 35 | if (name != null) { 36 | creator.setName(name); 37 | } 38 | jvmOptions.add("-cp"); 39 | jvmOptions.add(cp); 40 | creator.setJvmOptions(jvmOptions); 41 | for(Map.Entry resource : appMasterResource.entrySet()) { 42 | String resourceName = resource.getKey() 43 | .substring(SparkOnRayConfigs.SPARK_MASTER_ACTOR_RESOURCE_PREFIX.length() + 1); 44 | creator.setResource(resourceName, resource.getValue()); 45 | } 46 | 47 | return creator.remote(); 48 | } 49 | 50 | public static String getMasterUrl( 51 | ActorHandle handle) { 52 | return handle.task(RayAppMaster::getMasterUrl).remote().get(); 53 | } 54 | 55 | public static Map getRestartedExecutors( 56 | ActorHandle handle) { 57 | return handle.task(RayAppMaster::getRestartedExecutors).remote().get(); 58 | } 59 | 60 | public static void stopAppMaster( 61 | ActorHandle handle) { 62 | handle.task(RayAppMaster::stop).remote().get(); 63 | handle.kill(); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /core/javastyle-suppressions.xml: -------------------------------------------------------------------------------- 1 | 17 | 18 | 21 | 22 | 29 | 30 | 31 | 33 | 35 | 37 | 39 | 41 | 43 | 45 | 47 | 49 | 51 | 52 | -------------------------------------------------------------------------------- /python/raydp/mpi/network/network.proto: -------------------------------------------------------------------------------- 1 | // 2 | // Licensed to the Apache Software Foundation (ASF) under one or more 3 | // contributor license agreements. See the NOTICE file distributed with 4 | // this work for additional information regarding copyright ownership. 5 | // The ASF licenses this file to You under the Apache License, Version 2.0 6 | // (the "License"); you may not use this file except in compliance with 7 | // the License. You may obtain a copy of the License at 8 | // 9 | // http://www.apache.org/licenses/LICENSE-2.0 10 | // 11 | // Unless required by applicable law or agreed to in writing, software 12 | // distributed under the License is distributed on an "AS IS" BASIS, 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | // See the License for the specific language governing permissions and 15 | // limitations under the License. 
16 | // 17 | 18 | // Syntax version 19 | syntax = "proto3"; 20 | 21 | // Driver Service definition 22 | service DriverService { 23 | // register the worker process with the driver; used to signal that the worker has started up 24 | rpc RegisterWorker (RegisterWorkerRequest) returns (RegisterWorkerReply); 25 | // register the worker service host and port 26 | rpc RegisterWorkerService (RegisterWorkerServiceRequest) returns (RegisterWorkerServiceReply); 27 | // register the function result 28 | rpc RegisterFuncResult (FunctionResult) returns (Empty); 29 | } 30 | 31 | // Worker Service 32 | service WorkerService { 33 | // run the given function 34 | rpc RunFunction (Function) returns (Empty); 35 | // stop the worker service 36 | rpc Stop (Empty) returns (Empty); 37 | } 38 | 39 | message RegisterWorkerRequest { 40 | // the job id 41 | string job_id = 1; 42 | // the world rank id 43 | int32 world_rank = 2; 44 | } 45 | 46 | message RegisterWorkerReply { 47 | // all node addresses, used to determine the current node's IP address 48 | repeated string node_addresses = 3; 49 | } 50 | 51 | message RegisterWorkerServiceRequest { 52 | // the world rank 53 | int32 world_rank = 1; 54 | // the worker service listening ip 55 | string worker_ip = 2; 56 | // the worker service listening port 57 | int32 worker_port = 3; 58 | } 59 | 60 | message RegisterWorkerServiceReply { 61 | // the ray redis address 62 | string ray_address = 1; 63 | // the ray redis password 64 | string redis_password = 2; 65 | } 66 | 67 | message Function { 68 | // the function id 69 | int32 func_id = 1; 70 | // the serialized python function 71 | bytes func = 2; 72 | } 73 | 74 | message FunctionResult { 75 | int32 world_rank = 1; 76 | // the function id 77 | int32 func_id = 2; 78 | // the function results 79 | bytes result = 3; 80 | } 81 | 82 | message Empty { 83 | } 84 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/RayDPDriverAgent.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import io.ray.runtime.config.RayConfig 21 | 22 | import org.apache.spark.{SecurityManager, SparkConf, SparkContext} 23 | import org.apache.spark.internal.Logging 24 | import org.apache.spark.rpc._ 25 | 26 | 27 | class RayDPDriverAgent() { 28 | private val spark = SparkContext.getOrCreate() 29 | private var endpoint: RpcEndpointRef = _ 30 | private var rpcEnv: RpcEnv = _ 31 | private val conf: SparkConf = new SparkConf() 32 | 33 | init 34 | 35 | def init(): Unit = { 36 | val securityMgr = new SecurityManager(conf) 37 | val host = RayConfig.create().nodeIp 38 | rpcEnv = RpcEnv.create( 39 | RayAppMaster.ENV_NAME, 40 | host, 41 | host, 42 | 0, 43 | conf, 44 | securityMgr, 45 | // limit to single-thread 46 | numUsableCores = 1, 47 | clientMode = false) 48 | // register endpoint 49 | endpoint = rpcEnv.setupEndpoint(RayDPDriverAgent.ENDPOINT_NAME, 50 | new RayDPDriverAgentEndpoint(rpcEnv)) 51 | } 52 | 53 | def getDriverAgentEndpointUrl(): String = { 54 | RpcEndpointAddress(rpcEnv.address, RayDPDriverAgent.ENDPOINT_NAME).toString 55 | } 56 | 57 | class RayDPDriverAgentEndpoint(override val rpcEnv: RpcEnv) 58 | extends ThreadSafeRpcEndpoint with Logging { 59 | override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { 60 | case RecacheRDD(rddId) => 61 | // TODO if multiple blocks get lost, should call this only once 62 | // SparkEnv.get.blockManagerMaster.getLocationsAndStatus() 63 | spark.getPersistentRDDs.map { 64 | case (id, rdd) => 65 | if (id == rddId) { 66 | rdd.count 67 | } 68 | } 69 | context.reply(true) 70 | } 71 | } 72 | 73 | } 74 | 75 | object RayDPDriverAgent { 76 | val ENDPOINT_NAME = "RAYDP_DRIVER_AGENT" 77 | } 78 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/test/org/apache/spark/scheduler/cluster/raydp/TestRayCoarseGrainedSchedulerBackend.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | package org.apache.spark.scheduler.cluster.raydp; 18 | 19 | import org.apache.spark.SparkConf; 20 | import org.junit.jupiter.api.Test; 21 | 22 | import org.apache.spark.scheduler.cluster.SchedulerBackendUtils; 23 | 24 | import static org.junit.jupiter.api.Assertions.assertEquals; 25 | 26 | /** 27 | * This class performs unit testing on some methods in `RayCoarseGrainedSchedulerBackend`. 28 | */ 29 | public class TestRayCoarseGrainedSchedulerBackend { 30 | 31 | // Test using the default value. 
32 | @Test 33 | public void testExecutorNumberWithDefaultConfig() { 34 | SparkConf conf = new SparkConf(); 35 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2); 36 | assertEquals(2, executorNumber); 37 | } 38 | 39 | // Test using a negative value. 40 | @Test 41 | public void testExecutorNumberWithNegativeConfig() { 42 | SparkConf conf = new SparkConf(); 43 | conf.set("spark.dynamicAllocation.initialExecutors", "-1"); 44 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2); 45 | assertEquals(2, executorNumber); 46 | } 47 | 48 | // Test using reasonable values. 49 | @Test 50 | public void testExecutorNumberWithValidConfig() { 51 | SparkConf conf = new SparkConf(); 52 | conf.set("spark.executor.instances", "5"); 53 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2); 54 | assertEquals(5, executorNumber); 55 | } 56 | 57 | // Test using dynamic values. 58 | @Test 59 | public void testExecutorNumberWithDynamicConfig() { 60 | SparkConf conf = new SparkConf(); 61 | conf.set("spark.dynamicAllocation.enabled", "true"); 62 | conf.set("spark.dynamicAllocation.minExecutors", "3"); 63 | int executorNumber = SchedulerBackendUtils.getInitialTargetExecutorNumber(conf, 2); 64 | assertEquals(3, executorNumber); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /examples/test_pyfiles_main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main script for testing raydp-submit --py-files functionality. 3 | This script imports and uses functions from test_pyfile.py. 4 | """ 5 | import sys 6 | import raydp 7 | 8 | 9 | def main(): 10 | """Test that py-files are properly distributed and accessible.""" 11 | print("=" * 60) 12 | print("Testing raydp-submit --py-files functionality") 13 | print("=" * 60) 14 | 15 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access 16 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP. 
17 | java_opts = " ".join([ 18 | "-XX:+IgnoreUnrecognizedVMOptions", 19 | "--add-opens=java.base/java.lang=ALL-UNNAMED", 20 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED", 21 | "--add-opens=java.base/java.io=ALL-UNNAMED", 22 | "--add-opens=java.base/java.net=ALL-UNNAMED", 23 | "--add-opens=java.base/java.nio=ALL-UNNAMED", 24 | "--add-opens=java.base/java.math=ALL-UNNAMED", 25 | "--add-opens=java.base/java.text=ALL-UNNAMED", 26 | "--add-opens=java.base/java.util=ALL-UNNAMED", 27 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED", 28 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED", 29 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", 30 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED", 31 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED", 32 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED", 33 | ]) 34 | extra_configs = { 35 | "spark.executor.extraJavaOptions": java_opts, 36 | "spark.driver.extraJavaOptions": java_opts, 37 | "spark.ray.raydp_app_master.extraJavaOptions": java_opts, 38 | } 39 | spark = raydp.init_spark("Test PyFiles", 1, 1, "500M", 40 | configs=extra_configs) 41 | 42 | # Test: Use compute_sum in Spark executor context 43 | print("\nTesting py-files in Spark executor context...") 44 | # Create RDD and use function from test_pyfile in executors 45 | numbers_rdd = spark.sparkContext.parallelize([1, 2, 3, 4, 5], 2) 46 | 47 | # Map function that uses imported function 48 | def process_partition(partition): 49 | from test_pyfile import compute_sum # pylint: disable=import-outside-toplevel 50 | nums = list(partition) 51 | return [compute_sum(nums)] 52 | 53 | results = numbers_rdd.mapPartitions(process_partition).collect() 54 | total = sum(results) 55 | print(f"Sum from executors: {total}") 56 | assert total == 15, f"Expected 15, got {total}" 57 | print("✓ Functions from py-files work in Spark executors!") 58 | 59 | raydp.stop_spark() 60 | 61 | print("\n" + "=" * 60) 62 | print("Test passed! --py-files is working correctly!") 63 | print("=" * 60) 64 | return 0 65 | 66 | 67 | if __name__ == "__main__": 68 | sys.exit(main()) 69 | -------------------------------------------------------------------------------- /python/raydp/mpi/network/network_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: network.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import symbol_database as _symbol_database 8 | from google.protobuf.internal import builder as _builder 9 | # @@protoc_insertion_point(imports) 10 | 11 | _sym_db = _symbol_database.Default() 12 | 13 | 14 | 15 | 16 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rnetwork.proto\";\n\x15RegisterWorkerRequest\x12\x0e\n\x06job_id\x18\x01 \x01(\t\x12\x12\n\nworld_rank\x18\x02 \x01(\x05\"-\n\x13RegisterWorkerReply\x12\x16\n\x0enode_addresses\x18\x03 \x03(\t\"Z\n\x1cRegisterWorkerServiceRequest\x12\x12\n\nworld_rank\x18\x01 \x01(\x05\x12\x11\n\tworker_ip\x18\x02 \x01(\t\x12\x13\n\x0bworker_port\x18\x03 \x01(\x05\"I\n\x1aRegisterWorkerServiceReply\x12\x13\n\x0bray_address\x18\x01 \x01(\t\x12\x16\n\x0eredis_password\x18\x02 \x01(\t\")\n\x08\x46unction\x12\x0f\n\x07\x66unc_id\x18\x01 \x01(\x05\x12\x0c\n\x04\x66unc\x18\x02 \x01(\x0c\"E\n\x0e\x46unctionResult\x12\x12\n\nworld_rank\x18\x01 \x01(\x05\x12\x0f\n\x07\x66unc_id\x18\x02 \x01(\x05\x12\x0e\n\x06result\x18\x03 \x01(\x0c\"\x07\n\x05\x45mpty2\xd3\x01\n\rDriverService\x12>\n\x0eRegisterWorker\x12\x16.RegisterWorkerRequest\x1a\x14.RegisterWorkerReply\x12S\n\x15RegisterWorkerService\x12\x1d.RegisterWorkerServiceRequest\x1a\x1b.RegisterWorkerServiceReply\x12-\n\x12RegisterFuncResult\x12\x0f.FunctionResult\x1a\x06.Empty2I\n\rWorkerService\x12 \n\x0bRunFunction\x12\t.Function\x1a\x06.Empty\x12\x16\n\x04Stop\x12\x06.Empty\x1a\x06.Emptyb\x06proto3') 17 | 18 | _globals = globals() 19 | _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) 20 | _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'network_pb2', _globals) 21 | if _descriptor._USE_C_DESCRIPTORS == False: 22 | 23 | DESCRIPTOR._options = None 24 | _globals['_REGISTERWORKERREQUEST']._serialized_start=17 25 | _globals['_REGISTERWORKERREQUEST']._serialized_end=76 26 | _globals['_REGISTERWORKERREPLY']._serialized_start=78 27 | _globals['_REGISTERWORKERREPLY']._serialized_end=123 28 | _globals['_REGISTERWORKERSERVICEREQUEST']._serialized_start=125 29 | _globals['_REGISTERWORKERSERVICEREQUEST']._serialized_end=215 30 | _globals['_REGISTERWORKERSERVICEREPLY']._serialized_start=217 31 | _globals['_REGISTERWORKERSERVICEREPLY']._serialized_end=290 32 | _globals['_FUNCTION']._serialized_start=292 33 | _globals['_FUNCTION']._serialized_end=333 34 | _globals['_FUNCTIONRESULT']._serialized_start=335 35 | _globals['_FUNCTIONRESULT']._serialized_end=404 36 | _globals['_EMPTY']._serialized_start=406 37 | _globals['_EMPTY']._serialized_end=413 38 | _globals['_DRIVERSERVICE']._serialized_start=416 39 | _globals['_DRIVERSERVICE']._serialized_end=627 40 | _globals['_WORKERSERVICE']._serialized_start=629 41 | _globals['_WORKERSERVICE']._serialized_end=702 42 | # @@protoc_insertion_point(module_scope) 43 | -------------------------------------------------------------------------------- /python/raydp/ray_cluster_resources.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 
5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from typing import Dict, List 19 | 20 | import ray 21 | import time 22 | from ray.ray_constants import MEMORY_RESOURCE_UNIT_BYTES 23 | 24 | 25 | class ClusterResources: 26 | # TODO: make this configurable 27 | refresh_interval = 0.1 28 | latest_refresh_time = time.time() - refresh_interval 29 | node_to_resources = {} 30 | item_keys_mapping = {"num_cpus": "CPU"} 31 | label_name = "__ray_spark_node_label" 32 | 33 | @classmethod 34 | def total_alive_nodes(cls): 35 | cls._refresh() 36 | return len(cls.node_to_resources) 37 | 38 | @classmethod 39 | def satisfy(cls, request: Dict[str, float]) -> List[str]: 40 | cls._refresh() 41 | satisfied = [] 42 | for host_name, resources in cls.node_to_resources.items(): 43 | if cls._compare_two_dict(resources, request): 44 | satisfied.append(resources[cls.label_name]) 45 | 46 | return satisfied 47 | 48 | @classmethod 49 | def _refresh(cls): 50 | if (time.time() - cls.latest_refresh_time) < cls.refresh_interval: 51 | return 52 | 53 | for node in ray.nodes(): 54 | if node["Alive"]: 55 | host_name = node["NodeManagerHostname"] 56 | resources = node["Resources"] 57 | for key in resources: 58 | if key.startswith("node:"): 59 | resources[cls.label_name] = key 60 | break 61 | assert cls.label_name in resources,\ 62 | f"{resources} should contain a resource likes: 'node:10.0.0.131': 1.0" 63 | cls.node_to_resources[host_name] = resources 64 | cls.latest_refresh_time = time.time() 65 | 66 | @classmethod 67 | def _compare_two_dict(cls, available: Dict[str, float], request: Dict[str, float]) -> bool: 68 | for k, v in request.items(): 69 | k = cls.item_keys_mapping.get(k, k) 70 | if k not in available: 71 | return False 72 | 73 | if k == "memory": 74 | v = int(v / MEMORY_RESOURCE_UNIT_BYTES) 75 | 76 | if available[k] < v: 77 | return False 78 | 79 | return True 80 | -------------------------------------------------------------------------------- /core/shims/common/src/main/scala/com/intel/raydp/shims/SparkShimLoader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package com.intel.raydp.shims 19 | 20 | import java.util.ServiceLoader 21 | 22 | import scala.collection.JavaConverters._ 23 | 24 | import org.apache.spark.SPARK_VERSION_SHORT 25 | import org.apache.spark.internal.Logging 26 | 27 | object SparkShimLoader extends Logging { 28 | private var sparkShims: SparkShims = null 29 | private var sparkShimProviderClass: String = null 30 | 31 | def getSparkShims: SparkShims = { 32 | if (sparkShims == null) { 33 | val provider = getSparkShimProvider() 34 | sparkShims = provider.createShim 35 | } 36 | sparkShims 37 | } 38 | 39 | def getSparkVersion: String = { 40 | SPARK_VERSION_SHORT 41 | } 42 | 43 | def setSparkShimProviderClass(providerClass: String): Unit = { 44 | sparkShimProviderClass = providerClass 45 | } 46 | 47 | private def loadSparkShimProvider(): SparkShimProvider = { 48 | // Match and load Shim provider for current Spark version. 49 | val sparkVersion = getSparkVersion 50 | logInfo(s"Loading Spark Shims for version: $sparkVersion") 51 | 52 | // Load and filter the providers based on version 53 | val shimProviders = 54 | ServiceLoader.load(classOf[SparkShimProvider]).asScala.filter(_.matches(sparkVersion)) 55 | if (shimProviders.size > 1) { 56 | throw new IllegalStateException(s"More than one SparkShimProvider found: $shimProviders") 57 | } 58 | 59 | val shimProvider = shimProviders.headOption match { 60 | case Some(shimProvider) => shimProvider 61 | case None => 62 | throw new IllegalStateException(s"No Spark Shim Provider found for $sparkVersion") 63 | } 64 | logInfo(s"Using Shim provider: $shimProviders") 65 | shimProvider 66 | } 67 | 68 | private def getSparkShimProvider(): SparkShimProvider = { 69 | if (sparkShimProviderClass != null) { 70 | logInfo(s"Using Spark Shim Provider specified by $sparkShimProviderClass. ") 71 | val providerClass = Class.forName(sparkShimProviderClass) 72 | val providerConstructor = providerClass.getConstructor() 73 | providerConstructor.newInstance().asInstanceOf[SparkShimProvider] 74 | } else { 75 | loadSparkShimProvider() 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /python/raydp/tests/test_tf.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | import pyspark 19 | import pytest 20 | import os 21 | import sys 22 | import shutil 23 | 24 | import tensorflow as tf 25 | import tensorflow.keras as keras 26 | 27 | from pyspark.sql.functions import rand 28 | 29 | from raydp.tf import TFEstimator 30 | from raydp.utils import random_split 31 | 32 | @pytest.mark.parametrize("use_fs_directory", [True, False]) 33 | def test_tf_estimator(spark_on_ray_small, use_fs_directory): 34 | spark = spark_on_ray_small 35 | 36 | # ---------------- data process with Spark ------------ 37 | # calculate y = 3 * x + 4 38 | df: pyspark.sql.DataFrame = spark.range(0, 100000) 39 | df = df.withColumn("x", rand() * 100) # add x column 40 | df = df.withColumn("y", df.x * 3 + rand() + 4) # add y column 41 | df = df.select(df.x, df.y) 42 | 43 | train_df, test_df = random_split(df, [0.7, 0.3]) 44 | 45 | # create model 46 | model = keras.Sequential( 47 | [ 48 | keras.layers.InputLayer(input_shape=()), 49 | # Add feature dimension, expanding (batch_size,) to (batch_size, 1). 50 | keras.layers.Flatten(), 51 | keras.layers.Dense(1), 52 | ] 53 | ) 54 | 55 | optimizer = keras.optimizers.Adam(0.01) 56 | loss = keras.losses.MeanSquaredError() 57 | 58 | estimator = TFEstimator(num_workers=2, 59 | model=model, 60 | optimizer=optimizer, 61 | loss=loss, 62 | metrics=["accuracy", "mse"], 63 | feature_columns="x", 64 | label_columns="y", 65 | batch_size=1000, 66 | num_epochs=2, 67 | use_gpu=False) 68 | 69 | if use_fs_directory: 70 | dir = os.path.dirname(__file__) + "/test_tf" 71 | uri = "file://" + dir 72 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri) 73 | else: 74 | estimator.fit_on_spark(train_df, test_df) 75 | model = estimator.get_model() 76 | result = model(tf.constant([0, 0])) 77 | assert result.shape == (2, 1) 78 | if use_fs_directory: 79 | shutil.rmtree(dir) 80 | 81 | if __name__ == "__main__": 82 | sys.exit(pytest.main(["-v", __file__])) 83 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/AppMasterEntryPoint.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import java.io.{DataOutputStream, File, FileOutputStream} 21 | import java.net.InetAddress 22 | import java.nio.file.Files 23 | 24 | import scala.util.Try 25 | 26 | import py4j.GatewayServer 27 | 28 | import org.apache.spark.internal.Logging 29 | 30 | 31 | class AppMasterEntryPoint { 32 | private val appMaster: AppMasterJavaBridge = new AppMasterJavaBridge() 33 | 34 | def getAppMasterBridge(): AppMasterJavaBridge = { 35 | appMaster 36 | } 37 | } 38 | 39 | object AppMasterEntryPoint extends Logging { 40 | private val localhost = InetAddress.getLoopbackAddress() 41 | 42 | def getGatewayServer(): GatewayServer = { 43 | new GatewayServer.GatewayServerBuilder() 44 | .javaPort(0) 45 | .javaAddress(localhost) 46 | .entryPoint(new AppMasterEntryPoint()) 47 | .build() 48 | } 49 | 50 | def main(args: Array[String]): Unit = { 51 | 52 | var server = getGatewayServer() 53 | 54 | while(true) { 55 | if (!Try(server.start()).isFailure) { 56 | val boundPort: Int = server.getListeningPort() 57 | if (boundPort == -1) { 58 | logError(s"${server.getClass} failed to bind; exiting") 59 | System.exit(1) 60 | } else { 61 | logDebug(s"Started PythonGatewayServer on port $boundPort") 62 | } 63 | 64 | 65 | val connectionInfoPath = new File(sys.env("_RAYDP_APPMASTER_CONN_INFO_PATH")) 66 | val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(), 67 | "connection", ".info").toFile() 68 | 69 | val dos = new DataOutputStream(new FileOutputStream(tmpPath)) 70 | dos.writeInt(boundPort) 71 | dos.close() 72 | 73 | if (!tmpPath.renameTo(connectionInfoPath)) { 74 | logError(s"Unable to write connection information to $connectionInfoPath.") 75 | System.exit(1) 76 | } 77 | 78 | // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: 79 | while (System.in.read() != -1) { 80 | // Do nothing 81 | } 82 | logDebug("Exiting due to broken pipe from Python driver") 83 | System.exit(0) 84 | } else { 85 | server.shutdown() 86 | logError(s"${server.getClass} failed to bind; retrying...") 87 | Thread.sleep(1000) 88 | server = getGatewayServer() 89 | } 90 | } 91 | 92 | 93 | 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /examples/test_raydp_submit_pyfiles.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test script for raydp-submit --py-files functionality. 3 | This script mimics raydp-submit.py but tests that --py-files works correctly. 4 | """ 5 | from os.path import dirname, abspath, join 6 | import sys 7 | import json 8 | import subprocess 9 | import shlex 10 | import ray 11 | 12 | def main(): 13 | print("Starting raydp-submit --py-files test...") 14 | 15 | # Initialize Ray and get cluster info 16 | ray.init(address="auto") 17 | node = ray.worker.global_worker.node 18 | options = {} 19 | options["ray"] = {} 20 | options["ray"]["run-mode"] = "CLUSTER" 21 | options["ray"]["node-ip"] = node.node_ip_address 22 | options["ray"]["address"] = node.address 23 | options["ray"]["session-dir"] = node.get_session_dir_path() 24 | 25 | ray.shutdown() 26 | 27 | # Write Ray configuration 28 | examples_dir = dirname(abspath(__file__)) 29 | conf_path = join(examples_dir, "ray.conf") 30 | with open(conf_path, "w") as f: 31 | json.dump(options, f) 32 | 33 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access 34 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP. 
35 | java_opts = " ".join([ 36 | "-XX:+IgnoreUnrecognizedVMOptions", 37 | "--add-opens=java.base/java.lang=ALL-UNNAMED", 38 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED", 39 | "--add-opens=java.base/java.io=ALL-UNNAMED", 40 | "--add-opens=java.base/java.net=ALL-UNNAMED", 41 | "--add-opens=java.base/java.nio=ALL-UNNAMED", 42 | "--add-opens=java.base/java.math=ALL-UNNAMED", 43 | "--add-opens=java.base/java.text=ALL-UNNAMED", 44 | "--add-opens=java.base/java.util=ALL-UNNAMED", 45 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED", 46 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED", 47 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", 48 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED", 49 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED", 50 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED", 51 | ]) 52 | 53 | # Build raydp-submit command 54 | command = ["bin/raydp-submit", "--ray-conf", conf_path] 55 | command += ["--conf", "spark.executor.cores=1"] 56 | command += ["--conf", "spark.executor.instances=1"] 57 | command += ["--conf", "spark.executor.memory=500m"] 58 | command += ["--conf", f"spark.executor.extraJavaOptions={java_opts}"] 59 | command += ["--conf", f"spark.driver.extraJavaOptions={java_opts}"] 60 | command += ["--conf", f"spark.ray.raydp_app_master.extraJavaOptions={java_opts}"] 61 | 62 | # Add --py-files with test_pyfile.py 63 | test_pyfile_path = join(examples_dir, "test_pyfile.py") 64 | command += ["--py-files", test_pyfile_path] 65 | 66 | # Add the main script 67 | main_script_path = join(examples_dir, "test_pyfiles_main.py") 68 | command.append(main_script_path) 69 | 70 | # Execute the command 71 | print("\nExecuting command:") 72 | cmd_str = " ".join(shlex.quote(arg) for arg in command) 73 | print(cmd_str) 74 | print("\n" + "=" * 60) 75 | 76 | result = subprocess.run(cmd_str, check=True, shell=True) 77 | 78 | return result.returncode 79 | 80 | 81 | if __name__ == "__main__": 82 | sys.exit(main()) 83 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/scala/org/apache/spark/deploy/raydp/AppMasterJavaBridge.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.deploy.raydp 19 | 20 | import java.util.Map 21 | 22 | import scala.collection.JavaConverters._ 23 | 24 | import io.ray.api.{ActorHandle, Ray} 25 | 26 | import org.apache.spark.raydp.SparkOnRayConfigs 27 | 28 | class AppMasterJavaBridge { 29 | private var handle: ActorHandle[RayAppMaster] = null 30 | 31 | def startUpAppMaster(extra_cp: String, sparkProps: Map[String, Any]): Unit = { 32 | if (handle == null) { 33 | // init ray, we should set the config by java properties 34 | Ray.init() 35 | val name = RayAppMaster.ACTOR_NAME 36 | val sparkJvmOptions = sparkProps.asScala.toMap.filter( 37 | e => { 38 | !SparkOnRayConfigs.SPARK_DRIVER_EXTRA_JAVA_OPTIONS.equals(e._1) && 39 | !SparkOnRayConfigs.SPARK_APP_MASTER_EXTRA_JAVA_OPTIONS.equals(e._1) 40 | }) 41 | .map { 42 | case (k, v) => 43 | if (!SparkOnRayConfigs.SPARK_JAVAAGENT.equals(k)) { 44 | "-D" + k + "=" + v 45 | } else { 46 | "-javaagent:" + v 47 | } 48 | }.toBuffer 49 | 50 | // Add raw JVM options from spark.ray.raydp_app_master.extraJavaOptions 51 | // (e.g., --add-opens for JDK 17) 52 | val appMasterExtraJavaOptions = 53 | sparkProps.get(SparkOnRayConfigs.SPARK_APP_MASTER_EXTRA_JAVA_OPTIONS) 54 | if (appMasterExtraJavaOptions != null) { 55 | val opts = appMasterExtraJavaOptions.toString.trim 56 | if (opts.nonEmpty) { 57 | sparkJvmOptions ++= opts.split("\\s+").toSeq 58 | } 59 | } 60 | 61 | val appMasterResources = sparkProps.asScala.filter { 62 | case (k, v) => k.startsWith(SparkOnRayConfigs.SPARK_MASTER_ACTOR_RESOURCE_PREFIX) 63 | }.map{ case (k, v) => k->double2Double(v.toString.toDouble) }.asJava 64 | 65 | handle = RayAppMasterUtils.createAppMaster( 66 | extra_cp, name, 67 | (sparkJvmOptions ++ Seq(SparkOnRayConfigs.RAYDP_LOGFILE_PREFIX_CFG)).asJava, 68 | appMasterResources) 69 | } 70 | } 71 | 72 | def getMasterUrl(): String = { 73 | if (handle == null) { 74 | throw new RuntimeException("You should create the RayAppMaster handle first") 75 | } 76 | RayAppMasterUtils.getMasterUrl(handle) 77 | } 78 | 79 | def stop(): Unit = { 80 | if (handle != null) { 81 | RayAppMasterUtils.stopAppMaster(handle) 82 | Ray.shutdown() 83 | handle = null 84 | } 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /core/shims/spark340/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-shims 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-shims-spark340 15 | RayDP Shims for Spark 3.4.0 16 | jar 17 | 18 | 19 | 2.12.15 20 | 2.13.5 21 | 22 | 23 | 24 | 25 | 26 | org.scalastyle 27 | scalastyle-maven-plugin 28 | 29 | 30 | net.alchim31.maven 31 | scala-maven-plugin 32 | 3.2.2 33 | 34 | 35 | scala-compile-first 36 | process-resources 37 | 38 | compile 39 | 40 | 41 | 42 | scala-test-compile-first 43 | process-test-resources 44 | 45 | testCompile 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | src/main/resources 55 | 56 | 57 | 58 | 59 | 60 | 61 | com.intel 62 | raydp-shims-common 63 | ${project.version} 64 | compile 65 | 66 | 67 | org.apache.spark 68 | spark-sql_${scala.binary.version} 69 | ${spark340.version} 70 | provided 71 | 72 | 73 | org.apache.spark 74 | spark-core_${scala.binary.version} 75 | ${spark340.version} 76 | provided 77 | 78 | 79 | org.xerial.snappy 80 | snappy-java 81 | 82 | 83 | io.netty 84 | netty-handler 85 | 86 | 87 | 88 | 89 | org.xerial.snappy 90 | snappy-java 91 | ${snappy.version} 92 | 93 | 94 | io.netty 95 | netty-handler 96 | ${netty.version} 97 | 98 | 99 | 
100 | -------------------------------------------------------------------------------- /core/shims/spark350/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-shims 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-shims-spark350 15 | RayDP Shims for Spark 3.5.0 16 | jar 17 | 18 | 19 | 2.12.15 20 | 2.13.5 21 | 22 | 23 | 24 | 25 | 26 | org.scalastyle 27 | scalastyle-maven-plugin 28 | 29 | 30 | net.alchim31.maven 31 | scala-maven-plugin 32 | 3.2.2 33 | 34 | 35 | scala-compile-first 36 | process-resources 37 | 38 | compile 39 | 40 | 41 | 42 | scala-test-compile-first 43 | process-test-resources 44 | 45 | testCompile 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | src/main/resources 55 | 56 | 57 | 58 | 59 | 60 | 61 | com.intel 62 | raydp-shims-common 63 | ${project.version} 64 | compile 65 | 66 | 67 | org.apache.spark 68 | spark-sql_${scala.binary.version} 69 | ${spark350.version} 70 | provided 71 | 72 | 73 | org.apache.spark 74 | spark-core_${scala.binary.version} 75 | ${spark350.version} 76 | provided 77 | 78 | 79 | org.xerial.snappy 80 | snappy-java 81 | 82 | 83 | io.netty 84 | netty-handler 85 | 86 | 87 | 88 | 89 | org.xerial.snappy 90 | snappy-java 91 | ${snappy.version} 92 | 93 | 94 | io.netty 95 | netty-handler 96 | ${netty.version} 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /examples/xgboost_ray_nyctaxi.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import numpy as np 3 | # XGBoost on ray is needed to run this example. 4 | # Please refer to https://docs.ray.io/en/latest/xgboost-ray.html to install it. 5 | from xgboost_ray import RayDMatrix, train, RayParams 6 | import raydp 7 | from raydp.utils import random_split 8 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV 9 | 10 | # connect to ray cluster 11 | # ray.init(address="auto") 12 | ray.init(address="local", num_cpus=4) 13 | # After ray.init, you can use the raydp api to get a spark session 14 | app_name = "NYC Taxi Fare Prediction with RayDP" 15 | num_executors = 1 16 | cores_per_executor = 1 17 | memory_per_executor = "500M" 18 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access 19 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP. 
20 | java_opts = " ".join([ 21 | "-XX:+IgnoreUnrecognizedVMOptions", 22 | "--add-opens=java.base/java.lang=ALL-UNNAMED", 23 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED", 24 | "--add-opens=java.base/java.io=ALL-UNNAMED", 25 | "--add-opens=java.base/java.net=ALL-UNNAMED", 26 | "--add-opens=java.base/java.nio=ALL-UNNAMED", 27 | "--add-opens=java.base/java.math=ALL-UNNAMED", 28 | "--add-opens=java.base/java.text=ALL-UNNAMED", 29 | "--add-opens=java.base/java.util=ALL-UNNAMED", 30 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED", 31 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED", 32 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", 33 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED", 34 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED", 35 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED", 36 | ]) 37 | extra_configs = { 38 | "spark.executor.extraJavaOptions": java_opts, 39 | "spark.driver.extraJavaOptions": java_opts, 40 | "spark.ray.raydp_app_master.extraJavaOptions": java_opts, 41 | } 42 | spark = raydp.init_spark(app_name, num_executors, 43 | cores_per_executor, memory_per_executor, 44 | configs=extra_configs) 45 | data = spark.read.format("csv").option("header", "true") \ 46 | .option("inferSchema", "true") \ 47 | .load(NYC_TRAIN_CSV) 48 | # Set spark timezone for processing datetime 49 | spark.conf.set("spark.sql.session.timeZone", "UTC") 50 | # Transform the dataset 51 | data = nyc_taxi_preprocess(data) 52 | # Split data into train_dataset and test_dataset 53 | train_df, test_df = random_split(data, [0.9, 0.1], 0) 54 | # Convert spark dataframe into ray dataset 55 | train_dataset = ray.data.from_spark(train_df) 56 | test_dataset = ray.data.from_spark(test_df) 57 | # Then convert them into DMatrix used by xgboost 58 | dtrain = RayDMatrix(train_dataset, label="fare_amount") 59 | dtest = RayDMatrix(test_dataset, label="fare_amount") 60 | # Configure the XGBoost model 61 | config = { 62 | "tree_method": "hist", 63 | "eval_metric": ["logloss", "error"], 64 | } 65 | evals_result = {} 66 | # Train the model 67 | bst = train( 68 | config, 69 | dtrain, 70 | evals=[(dtest, "eval")], 71 | evals_result=evals_result, 72 | ray_params=RayParams(max_actor_restarts=1, num_actors=1, cpus_per_actor=1), 73 | num_boost_round=10) 74 | # print evaluation stats 75 | print("Final validation error: {:.4f}".format( 76 | evals_result["eval"]["error"][-1])) 77 | raydp.stop_spark() 78 | ray.shutdown() 79 | -------------------------------------------------------------------------------- /python/raydp/services.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | from abc import ABC, abstractmethod 19 | from typing import Any, Dict, NoReturn 20 | 21 | 22 | class Cluster(ABC): 23 | """ 24 | This is the base class for all specified cluster, such as SparkCluster, FlinkCluster. 25 | :param master_resources_requirement: The resources requirement for the master service. 26 | """ 27 | def __init__(self, master_resources_requirement): 28 | # the master node is live as same as ray driver node. And we can specify the resources 29 | # limitation for master node. So we don't count it. 30 | self._num_nodes = 0 31 | 32 | @abstractmethod 33 | def _set_up_master(self, 34 | resources: Dict[str, float], 35 | kwargs: Dict[Any, Any]): 36 | """ 37 | Subcluster should implement this to set up master node. 38 | """ 39 | 40 | def add_worker(self, 41 | resources_requirement: Dict[str, float], 42 | **kwargs: Dict[Any, Any]): 43 | """ 44 | Add one worker to the cluster. 45 | :param resources_requirement: The resource requirements for the worker service. 46 | """ 47 | try: 48 | self._set_up_worker(resources_requirement, kwargs) 49 | except: 50 | self.stop() 51 | raise 52 | 53 | @abstractmethod 54 | def _set_up_worker(self, 55 | resources: Dict[str, float], 56 | kwargs: Dict[str, str]): 57 | """ 58 | Subcluster should implement this to set up worker node. 59 | """ 60 | 61 | @abstractmethod 62 | def get_cluster_url(self) -> str: 63 | """ 64 | Return the cluster url, eg: spark://master-node:7077 65 | """ 66 | 67 | @abstractmethod 68 | def stop(self): 69 | """ 70 | Stop cluster 71 | """ 72 | 73 | 74 | class ClusterMaster(ABC): 75 | 76 | @abstractmethod 77 | def start_up(self) -> NoReturn: 78 | pass 79 | 80 | @abstractmethod 81 | def get_master_url(self) -> str: 82 | pass 83 | 84 | @abstractmethod 85 | def get_host(self) -> str: 86 | pass 87 | 88 | @abstractmethod 89 | def stop(self): 90 | pass 91 | 92 | 93 | class ClusterWorker(ABC): 94 | 95 | @abstractmethod 96 | def start_up(self) -> str: 97 | """ 98 | :return: error message, return None if succeeded 99 | """ 100 | 101 | @abstractmethod 102 | def get_host(self) -> str: 103 | pass 104 | 105 | @abstractmethod 106 | def stop(self): 107 | pass 108 | -------------------------------------------------------------------------------- /python/raydp/tests/test_torch.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | # 17 | 18 | from pyspark.sql import SparkSession 19 | import pytest 20 | import os 21 | import sys 22 | import shutil 23 | import torch 24 | 25 | # https://spark.apache.org/docs/latest/api/python/migration_guide/koalas_to_pyspark.html 26 | # import databricks.koalas as ks 27 | import pyspark.pandas as ps 28 | 29 | from raydp.torch import TorchEstimator 30 | from raydp.utils import random_split 31 | 32 | @pytest.mark.parametrize("use_fs_directory", [True, False]) 33 | def test_torch_estimator(spark_on_ray_small, use_fs_directory): 34 | # ---------------- data process with koalas ------------ 35 | spark: SparkSession = spark_on_ray_small 36 | 37 | # calculate z = 3 * x + 4 * y + 5 38 | df = ps.range(0, 100000) 39 | df["x"] = df["id"] + 100 40 | df["y"] = df["id"] + 1000 41 | df["z"] = df["x"] * 3 + df["y"] * 4 + 5 42 | df = df.astype("float") 43 | 44 | train_df, test_df = random_split(df, [0.7, 0.3]) 45 | 46 | # ---------------- ray sgd ------------------------- 47 | # create the model 48 | class LinearModel(torch.nn.Module): 49 | def __init__(self): 50 | super(LinearModel, self).__init__() 51 | self.linear = torch.nn.Linear(2, 1) 52 | 53 | def forward(self, x): 54 | return self.linear(x) 55 | 56 | model = LinearModel() 57 | # create the optimizer 58 | optimizer = torch.optim.Adam(model.parameters()) 59 | # create the loss 60 | loss = torch.nn.MSELoss() 61 | # create lr_scheduler 62 | 63 | def lr_scheduler_creator(optimizer, config): 64 | return torch.optim.lr_scheduler.MultiStepLR( 65 | optimizer, milestones=[150, 250, 350], gamma=0.1) 66 | 67 | # create the estimator 68 | estimator = TorchEstimator(num_workers=2, 69 | model=model, 70 | optimizer=optimizer, 71 | loss=loss, 72 | lr_scheduler_creator=lr_scheduler_creator, 73 | feature_columns=["x", "y"], 74 | feature_types=torch.float, 75 | label_column="z", 76 | label_type=torch.float, 77 | batch_size=1000, 78 | num_epochs=2, 79 | use_gpu=False) 80 | 81 | # train the model 82 | if use_fs_directory: 83 | dir = os.path.dirname(__file__) + "/test_torch" 84 | uri = "file://" + dir 85 | estimator.fit_on_spark(train_df, test_df, fs_directory=uri) 86 | else: 87 | estimator.fit_on_spark(train_df, test_df) 88 | model = estimator.get_model() 89 | result = model(torch.Tensor([[0, 0], [1, 1]])) 90 | assert result.shape == (2, 1) 91 | if use_fs_directory: 92 | shutil.rmtree(dir) 93 | 94 | 95 | if __name__ == "__main__": 96 | sys.exit(pytest.main(["-v", __file__])) 97 | -------------------------------------------------------------------------------- /core/raydp-main/src/main/java/org/apache/spark/raydp/RayExecutorUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */
17 | 
18 | package org.apache.spark.raydp;
19 | 
20 | import io.ray.api.ActorHandle;
21 | import io.ray.api.ObjectRef;
22 | import io.ray.api.Ray;
23 | import io.ray.api.call.ActorCreator;
24 | import java.util.Map;
25 | import java.util.List;
26 | 
27 | import io.ray.api.placementgroup.PlacementGroup;
28 | import io.ray.runtime.object.ObjectRefImpl;
29 | import org.apache.spark.executor.RayDPExecutor;
30 | 
31 | public class RayExecutorUtils {
32 |   /**
33 |    * Convert from MB to Ray memory units. The memory unit in Ray is bytes.
34 |    */
35 | 
36 |   private static double toMemoryUnits(int memoryInMB) {
37 |     double result = 1.0 * memoryInMB * 1024 * 1024;
38 |     return Math.round(result);
39 |   }
40 | 
41 |   public static ActorHandle<RayDPExecutor> createExecutorActor(
42 |       String executorId,
43 |       String appMasterURL,
44 |       double cores,
45 |       int memoryInMB,
46 |       Map<String, Double> resources,
47 |       PlacementGroup placementGroup,
48 |       int bundleIndex,
49 |       List<String> javaOpts) {
50 |     ActorCreator<RayDPExecutor> creator = Ray.actor(
51 |         RayDPExecutor::new, executorId, appMasterURL);
52 |     creator.setName("raydp-executor-" + executorId);
53 |     creator.setJvmOptions(javaOpts);
54 |     creator.setResource("CPU", cores);
55 |     creator.setResource("memory", toMemoryUnits(memoryInMB));
56 | 
57 |     for (Map.Entry<String, Double> entry: resources.entrySet()) {
58 |       creator.setResource(entry.getKey(), entry.getValue());
59 |     }
60 |     if (placementGroup != null) {
61 |       creator.setPlacementGroup(placementGroup, bundleIndex);
62 |     }
63 |     creator.setMaxRestarts(-1);
64 |     creator.setMaxTaskRetries(-1);
65 |     creator.setMaxConcurrency(2);
66 |     return creator.remote();
67 |   }
68 | 
69 |   public static void setUpExecutor(
70 |       ActorHandle<RayDPExecutor> handler,
71 |       String appId,
72 |       String driverUrl,
73 |       int cores,
74 |       String classPathEntries) {
75 |     handler.task(RayDPExecutor::startUp,
76 |         appId, driverUrl, cores, classPathEntries).remote();
77 |   }
78 | 
79 |   public static String[] getBlockLocations(
80 |       ActorHandle<RayDPExecutor> handler,
81 |       int rddId,
82 |       int numPartitions) {
83 |     return handler.task(RayDPExecutor::getBlockLocations,
84 |         rddId, numPartitions).remote().get();
85 |   }
86 | 
87 |   public static ObjectRef<byte[]> getRDDPartition(
88 |       ActorHandle<RayDPExecutor> handle,
89 |       int rddId,
90 |       int partitionId,
91 |       String schema,
92 |       String driverAgentUrl) {
93 |     return (ObjectRefImpl<byte[]>) handle.task(
94 |         RayDPExecutor::getRDDPartition,
95 |         rddId, partitionId, schema, driverAgentUrl).remote();
96 |   }
97 | 
98 |   public static void exitExecutor(
99 |       ActorHandle<RayDPExecutor> handle
100 |   ) {
101 |     handle.task(RayDPExecutor::stop).remote();
102 |   }
103 | }
104 | 
--------------------------------------------------------------------------------
/doc/mpi.md:
--------------------------------------------------------------------------------
 1 | # MPI on Ray
 2 | 
 3 | RayDP also provides a simple API for running MPI jobs on top of Ray. Currently, we support three types of MPI: `intel_mpi`, `openmpi` and `MPICH`. To use the following API, make sure the chosen MPI implementation is installed on each Ray worker node.
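
Before creating a job, it is worth verifying that an MPI launcher is actually available on every worker node. The sketch below is only a plain-Python check and is not part of the RayDP API; the binary names `mpirun` and `mpiexec` are just the usual defaults shipped by these MPI distributions.

```python
# Run this on every Ray worker node (for example from a provisioning script).
import shutil

launcher = next((b for b in ("mpirun", "mpiexec") if shutil.which(b)), None)
if launcher is None:
    raise RuntimeError("no MPI launcher found on PATH; install Intel MPI, Open MPI or MPICH first")
print(f"found MPI launcher: {launcher}")
```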
 4 | 
 5 | ### API
 6 | 
 7 | ```python
 8 | def create_mpi_job(job_name: str,
 9 |                    world_size: int,
10 |                    num_cpus_per_process: int,
11 |                    num_processes_per_node: int,
12 |                    mpi_script_prepare_fn: Callable = None,
13 |                    timeout: int = 1,
14 |                    mpi_type: str = "intel_mpi",
15 |                    placement_group=None,
16 |                    placement_group_bundle_indexes: List[int] = None) -> MPIJob:
17 |     """Create an MPI job.
18 | 
19 |     :param job_name: the job name
20 |     :param world_size: the world size
21 |     :param num_cpus_per_process: num CPUs per process, used to request resources from Ray
22 |     :param num_processes_per_node: num processes per node
23 |     :param mpi_script_prepare_fn: a function used to create the mpi script; it will be passed an
24 |         MPIJobContext instance. The default script is used if it is not provided.
25 |     :param timeout: the timeout used to wait for job creation
26 |     :param mpi_type: the mpi type, currently only openmpi, intel_mpi and mpich are supported
27 |     :param placement_group: the placement group used to request mpi resources
28 |     :param placement_group_bundle_indexes: if provided, its length should equal
29 |         world_size / num_processes_per_node.
30 |     """
31 | ```
32 | 
33 | ### Create a simple MPI Job
34 | 
35 | ```python
36 | from raydp.mpi import create_mpi_job, MPIJobContext, WorkerContext
37 | 
38 | # Define the MPI job. We want a world_size of 4, and each process requires 2 CPUs.
39 | # num_processes_per_node is set to 2, so the processes will be strictly spread across two nodes.
40 | 
41 | # You can also specify a placement group to reserve the resources for the MPI job. num_cpus_per_process
42 | # is ignored if a placement group is provided, and the length of
43 | # placement_group_bundle_indexes should equal world_size // num_processes_per_node.
44 | job = create_mpi_job(job_name="example",
45 |                      world_size=4,
46 |                      num_cpus_per_process=2,
47 |                      num_processes_per_node=2,
48 |                      timeout=5,
49 |                      mpi_type="intel_mpi",
50 |                      placement_group=None,
51 |                      placement_group_bundle_indexes=None)
52 | 
53 | # Start the MPI job; this starts the MPI processes and connects them to the Ray cluster
54 | job.start()
55 | 
56 | # Define the MPI task function
57 | def func(context: WorkerContext):
58 |     return context.job_id
59 | 
60 | # Run the MPI task. This is a blocking operation, and the result is a list of length world_size.
61 | results = job.run(func)
62 | 
63 | # Stop the MPI job
64 | job.stop()
65 | ```
66 | 
67 | ### Use `with` to auto start/stop the MPIJob
68 | ```python
69 | with create_mpi_job(job_name="example",
70 |                     world_size=4,
71 |                     num_cpus_per_process=2,
72 |                     num_processes_per_node=2,
73 |                     timeout=5,
74 |                     mpi_type="intel_mpi") as job:
75 |     def f(context: WorkerContext):
76 |         return context.job_id
77 |     results = job.run(f)
78 | ```
79 | 
80 | ### Specify the MPI script and environments
81 | 
82 | You can customize the MPI job environment variables and the MPI script with the `mpi_script_prepare_fn` argument, as shown below.
83 | 84 | ```python 85 | def script_prepare_fn(context: MPIJobContext): 86 | context.add_env("OMP_NUM_THREADS", "2") 87 | default_script = ["mpirun", "--allow-run-as-root", "--tag-output", "-H", 88 | ",".join(context.hosts), "-N", f"{context.num_procs_per_node}"] 89 | return default_script 90 | 91 | job = create_mpi_job(job_name="example", 92 | world_size=4, 93 | num_cpus_per_process=2, 94 | num_processes_per_node=2, 95 | timeout=5, 96 | mpi_type="intel_mpi", 97 | mpi_script_prepare_fn=script_prepare_fn) 98 | ``` 99 | -------------------------------------------------------------------------------- /core/agent/src/main/java/org/apache/spark/raydp/Agent.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.raydp; 19 | 20 | import org.slf4j.LoggerFactory; 21 | 22 | import java.io.File; 23 | import java.io.FileOutputStream; 24 | import java.io.IOException; 25 | import java.io.OutputStream; 26 | import java.io.OutputStreamWriter; 27 | import java.io.PrintStream; 28 | import java.io.Writer; 29 | import java.lang.instrument.Instrumentation; 30 | import java.lang.management.ManagementFactory; 31 | import java.nio.charset.Charset; 32 | import java.nio.charset.StandardCharsets; 33 | 34 | 35 | public class Agent { 36 | 37 | public static final PrintStream DEFAULT_ERR_PS = System.err; 38 | 39 | public static final PrintStream DEFAULT_OUT_PS = System.out; 40 | 41 | public static void premain(String agentArgs, Instrumentation inst) 42 | throws IOException { 43 | // redirect system output/error stream so that annoying SLF4J warnings 44 | // and other logs during binding 45 | // SLF4J factory don't show in spark-shell 46 | // Instead, the warnings and logs are kept in 47 | // /logs/slf4j-.log 48 | 49 | String pid = ManagementFactory.getRuntimeMXBean().getName() 50 | .split("@")[0]; 51 | String logDir = System.getProperty("ray.logging.dir"); 52 | if (logDir == null) { 53 | logDir = "/tmp/ray/session_latest/logs"; 54 | System.getProperties().put("ray.logging.dir", logDir); 55 | } 56 | 57 | File parentDir = new File(logDir); 58 | if (!parentDir.exists()) { 59 | boolean flag = parentDir.mkdirs(); 60 | if (!flag) { 61 | throw new RuntimeException("Error create log dir."); 62 | } 63 | } 64 | 65 | File logFile = new File(parentDir, "/slf4j-" + pid + ".log"); 66 | try (PrintStream ps = new PrintStream(logFile, "UTF-8")) { 67 | System.setOut(ps); 68 | System.setErr(ps); 69 | // slf4j binding 70 | LoggerFactory.getLogger(Agent.class); 71 | } catch (Exception e) { 72 | e.printStackTrace(); 73 | } finally { 74 | System.out.flush(); 75 | System.err.flush(); 76 | // restore system output/error stream 77 | 
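      // From this point on, console output goes to the original streams again; only the
      // SLF4J binding messages captured above remain in the slf4j-<pid>.log file.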
System.setErr(DEFAULT_ERR_PS);
 78 |       System.setOut(DEFAULT_OUT_PS);
 79 |     }
 80 |     // The code below writes ':job_id:' followed by the job id to the first line of a log file
 81 |     // whose name is prefixed with 'java-worker', as required by the Ray PR https://github.com/ray-project/ray/pull/31772.
 82 |     // It is a workaround for a Ray 2.3.[0-1] issue that will be fixed by https://github.com/ray-project/ray/pull/33665.
 83 |     String jobId = System.getenv("RAY_JOB_ID");
 84 |     String rayAddress = System.getProperty("ray.address");
 85 |     if (jobId != null && rayAddress != null) {
 86 |       String prefix = "java-worker";
 87 |       // TODO: uncomment after the ray PR #33665 is released
 88 |       // String prefix = System.getProperty("ray.logging.file-prefix", "java-worker");
 89 |       // if ("java-worker".equals(prefix)) {
 90 |       File file = new File(new String((logDir + "/" + prefix + "-" + jobId + "-" + pid + ".log")
 91 |           .getBytes(Charset.forName("UTF-8")), "UTF-8"));
 92 |       try (OutputStream out = new FileOutputStream(file);
 93 |           Writer writer = new OutputStreamWriter(out, StandardCharsets.UTF_8)) {
 94 |         writer.write(":job_id:" + jobId + "\n");
 95 |       }
 96 |       // }
 97 |     }
 98 |   }
 99 | }
100 | 
--------------------------------------------------------------------------------
/.github/workflows/pypi_release.yml:
--------------------------------------------------------------------------------
 1 | #
 2 | # Licensed to the Apache Software Foundation (ASF) under one or more
 3 | # contributor license agreements. See the NOTICE file distributed with
 4 | # this work for additional information regarding copyright ownership.
 5 | # The ASF licenses this file to You under the Apache License, Version 2.0
 6 | # (the "License"); you may not use this file except in compliance with
 7 | # the License. You may obtain a copy of the License at
 8 | #
 9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # 17 | 18 | name: RayDP PyPI Release 19 | 20 | on: 21 | workflow_dispatch: 22 | inputs: 23 | tag: 24 | description: 'Git tag to publish (e.g., v1.7.0 or 1.7.0)' 25 | required: true 26 | type: string 27 | 28 | permissions: # added using https://github.com/step-security/secure-repo 29 | contents: read 30 | 31 | jobs: 32 | build-and-publish: 33 | # do not run in forks 34 | if: ${{ github.repository_owner == 'ray-project' }} 35 | name: build wheel and upload release 36 | runs-on: ubuntu-latest 37 | env: 38 | PYSPARK_VERSION: "3.5.7" 39 | RAY_VERSION: "2.40.0" 40 | steps: 41 | - uses: actions/checkout@61b9e3751b92087fd0b06925ba6dd6314e06f089 # master 42 | with: 43 | ref: ${{ inputs.tag }} 44 | fetch-depth: 0 45 | - name: Set up Python 3.10 46 | uses: actions/setup-python@e9aba2c848f5ebd159c070c61ea2c4e2b122355e # v2.3.4 47 | with: 48 | python-version: 3.10.14 49 | - name: Set up JDK 1.8 50 | uses: actions/setup-java@b6e674f4b717d7b0ae3baee0fbe79f498905dfde # v1.4.4 51 | with: 52 | java-version: 1.8 53 | - name: Install extra dependencies for Ubuntu 54 | run: | 55 | sudo apt-get install -y mpich 56 | - name: Cache pip 57 | uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2 58 | with: 59 | path: ~/.cache/pip 60 | key: ubuntu-latest-3.10.14-pip 61 | - name: Install dependencies 62 | run: | 63 | python -m pip install --upgrade pip 64 | pip install wheel 65 | pip install "numpy<1.24" "click<8.3.0" 66 | pip install "pydantic<2.0" 67 | pip install torch --index-url https://download.pytorch.org/whl/cpu 68 | pip install pyarrow "ray[train,default]==${{ env.RAY_VERSION }}" tqdm pytest tensorflow==2.13.1 tabulate grpcio-tools wget 69 | pip install "xgboost_ray[default]<=0.1.13" 70 | pip install "xgboost<=2.0.3" 71 | pip install torchmetrics 72 | - name: Cache Maven 73 | uses: actions/cache@8492260343ad570701412c2f464a5877dc76bace # v2 74 | with: 75 | path: ~/.m2 76 | key: ubuntu-latest-m2-${{ hashFiles('core/pom.xml') }} 77 | - name: Build and install 78 | env: 79 | GITHUB_CI: 1 80 | run: | 81 | pip install pyspark==${{ env.PYSPARK_VERSION }} 82 | ./build.sh 83 | pip install dist/raydp-*.whl 84 | - name: Lint 85 | run: | 86 | pip install pylint==2.8.3 87 | pylint --rcfile=python/pylintrc python/raydp 88 | pylint --rcfile=python/pylintrc examples/*.py 89 | - name: Test with pytest 90 | run: | 91 | ray start --head --num-cpus 6 92 | pytest python/raydp/tests/ -v 93 | ray stop --force 94 | - name: Test Examples 95 | run: | 96 | ray start --head 97 | python examples/raydp-submit.py 98 | python examples/test_raydp_submit_pyfiles.py 99 | ray stop 100 | python examples/pytorch_nyctaxi.py 101 | python examples/tensorflow_nyctaxi.py 102 | python examples/xgboost_ray_nyctaxi.py 103 | # python examples/raytrain_nyctaxi.py 104 | python examples/data_process.py 105 | - name: Upload to PyPI 106 | uses: pypa/gh-action-pypi-publish@v1.13.0 107 | with: 108 | password: ${{ secrets.PYPI_API_TOKEN }} 109 | -------------------------------------------------------------------------------- /python/raydp/torch/torch_ml_dataset.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. 
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | import logging 19 | import queue 20 | import threading 21 | from typing import Callable 22 | 23 | import ray 24 | from ray.util.data import MLDataset 25 | from torch.utils.data import IterableDataset 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | class TorchMLDataset(IterableDataset): 31 | def __init__(self, 32 | ds: MLDataset, 33 | collate_fn: Callable, 34 | shuffle: bool = False, 35 | shuffle_seed: int = None): 36 | super().__init__() 37 | self.ds = ds 38 | self.collate_fn = collate_fn 39 | self.shuffle = shuffle 40 | self.shuffle_seed = shuffle_seed or 1 41 | 42 | def __iter__(self): 43 | it = self.ds.gather_async(batch_ms=0, num_async=self.ds.num_shards()) 44 | it = iter(it) 45 | for pdf in it: 46 | if self.shuffle: 47 | pdf = pdf.sample(frac=1.0, random_state=self.shuffle_seed) 48 | yield self.collate_fn(pdf) 49 | 50 | def __len__(self): 51 | all_actors = [] 52 | for actor_set in self.ds.actor_sets: 53 | all_actors.extend(actor_set.actors) 54 | assert len(all_actors) > 0 55 | if "__len__" in dir(all_actors[0]): 56 | # This is a very hack method to get the length of the iterator 57 | num_records = sum([ray.get(actor.__len__.remote()) for actor in all_actors]) 58 | else: 59 | logger.warning("The MLDataset has not provide the __len__ method, we will iter all " 60 | "data to count the number of rows. 
This should be pretty slowly.") 61 | it = self.ds.gather_async(batch_ms=0, num_async=self.ds.num_shards()) 62 | it = iter(it) 63 | num_records = 0 64 | for pdf in it: 65 | num_records += pdf.shape[0] 66 | return num_records 67 | 68 | 69 | class PrefetchedDataLoader: 70 | def __init__(self, base_loader, max_size: int = 5): 71 | self.base_loader = base_loader 72 | self.max_size = max_size 73 | self.queue = queue.Queue(maxsize=max_size) 74 | self.fetcher = None 75 | self.fetcher_stop = threading.Event() 76 | 77 | def _setup(self): 78 | if self.fetcher is not None: 79 | self.fetcher_stop.set() 80 | if self.queue is not None and not self.queue.empty(): 81 | self.queue.get() 82 | self.queue = queue.Queue(maxsize=self.max_size) 83 | self.fetcher = None 84 | self.fetcher_stop.clear() 85 | 86 | it = iter(self.base_loader) 87 | 88 | def fetch_task(): 89 | while not self.fetcher_stop.is_set(): 90 | try: 91 | got_data = next(it) 92 | self.queue.put(got_data) 93 | except StopIteration: 94 | self.queue.put(None) 95 | break 96 | except: # pylint: disable=W0707, W0706 97 | raise 98 | self.fetcher = threading.Thread(target=fetch_task) 99 | self.fetcher.start() 100 | 101 | def __iter__(self): 102 | self._setup() 103 | while True: 104 | fetched_data = self.queue.get() 105 | if fetched_data is not None: 106 | yield fetched_data 107 | else: 108 | break 109 | 110 | def __len__(self): 111 | return len(self.base_loader) 112 | -------------------------------------------------------------------------------- /core/shims/spark330/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-shims 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-shims-spark330 15 | RayDP Shims for Spark 3.3.0 16 | jar 17 | 18 | 19 | 2.12.15 20 | 2.13.5 21 | 22 | 23 | 24 | 25 | 26 | org.scalastyle 27 | scalastyle-maven-plugin 28 | 29 | 30 | net.alchim31.maven 31 | scala-maven-plugin 32 | 3.2.2 33 | 34 | 35 | scala-compile-first 36 | process-resources 37 | 38 | compile 39 | 40 | 41 | 42 | scala-test-compile-first 43 | process-test-resources 44 | 45 | testCompile 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | src/main/resources 55 | 56 | 57 | 58 | 59 | 60 | 61 | com.intel 62 | raydp-shims-common 63 | ${project.version} 64 | compile 65 | 66 | 67 | org.apache.spark 68 | spark-sql_${scala.binary.version} 69 | ${spark330.version} 70 | provided 71 | 72 | 73 | com.google.protobuf 74 | protobuf-java 75 | 76 | 77 | 78 | 79 | org.apache.spark 80 | spark-core_${scala.binary.version} 81 | ${spark330.version} 82 | provided 83 | 84 | 85 | org.xerial.snappy 86 | snappy-java 87 | 88 | 89 | io.netty 90 | netty-handler 91 | 92 | 93 | org.apache.commons 94 | commons-text 95 | 96 | 97 | org.apache.ivy 98 | ivy 99 | 100 | 101 | 102 | 103 | org.xerial.snappy 104 | snappy-java 105 | ${snappy.version} 106 | 107 | 108 | io.netty 109 | netty-handler 110 | ${netty.version} 111 | 112 | 113 | org.apache.commons 114 | commons-text 115 | ${commons.text.version} 116 | 117 | 118 | org.apache.ivy 119 | ivy 120 | ${ivy.version} 121 | 122 | 123 | com.google.protobuf 124 | protobuf-java 125 | ${protobuf.version} 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /python/raydp/mpi/utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. 
See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import os 19 | import select 20 | import subprocess 21 | import threading 22 | import time 23 | from typing import List 24 | 25 | import grpc 26 | import netifaces 27 | 28 | 29 | class StoppableThread(threading.Thread): 30 | 31 | def __init__(self, group=None, target=None, name=None, 32 | args=(), kwargs=None, *, daemon=None): 33 | super().__init__(group, target, name, args, kwargs, daemon=daemon) 34 | self._stop_event = threading.Event() 35 | 36 | def stop(self): 37 | self._stop_event.set() 38 | 39 | def stopped(self): 40 | return self._stop_event.is_set() 41 | 42 | 43 | def run_cmd(cmd: str, env, failed_callback): 44 | # pylint: disable=R1732 45 | proc = subprocess.Popen(cmd, 46 | shell=True, 47 | stdin=subprocess.DEVNULL, 48 | stdout=subprocess.PIPE, 49 | stderr=subprocess.PIPE, 50 | env=env, 51 | start_new_session=True) 52 | 53 | def check_failed(): 54 | # check whether the process has finished 55 | while not threading.current_thread().stopped(): 56 | ret_code = proc.poll() 57 | if ret_code: 58 | failed_callback() 59 | raise Exception(f"mpirun failed: {ret_code}") 60 | 61 | if ret_code == 0: 62 | break 63 | 64 | time.sleep(1) 65 | 66 | check_thread = StoppableThread(target=check_failed) 67 | 68 | def redirect_stream(streams): 69 | while not threading.current_thread().stopped() and streams: 70 | readable, _, _ = select.select(streams, [], [], 0.5) 71 | for stream in readable: 72 | if not stream: 73 | continue 74 | line = stream.readline() 75 | if not line: 76 | streams.remove(stream) 77 | else: 78 | print(line.decode().strip("\n")) 79 | 80 | redirect_thread = StoppableThread(target=redirect_stream, args=([proc.stdout, proc.stderr],)) 81 | check_thread.start() 82 | redirect_thread.start() 83 | return proc, check_thread, redirect_thread 84 | 85 | 86 | def create_insecure_channel(address, 87 | options=None, 88 | compression=None): 89 | """Disable the http proxy when create channel""" 90 | # disable http proxy 91 | if options is not None: 92 | need_add = True 93 | for k, v in options: 94 | if k == "grpc.enable_http_proxy": 95 | need_add = False 96 | break 97 | if need_add: 98 | options = (*options, ("grpc.enable_http_proxy", 0)) 99 | else: 100 | options = (("grpc.enable_http_proxy", 0),) 101 | 102 | return grpc.insecure_channel( 103 | address, options, compression) 104 | 105 | 106 | def get_environ_value(key: str) -> str: 107 | """Get value from environ, raise exception if the key not existed""" 108 | assert key in os.environ, f"{key} should be set in the environ" 109 | return os.environ[key] 110 | 111 | 112 | def get_node_ip_address(node_addresses: List[str]) -> str: 113 | found = None 114 | for interface in netifaces.interfaces(): 115 | addrs = netifaces.ifaddresses(interface) 116 | addresses = addrs.get(netifaces.AF_INET, None) 117 | if not addresses: 118 | 
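            # this interface has no IPv4 address configured, so skip it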
continue 119 | for inet_addr in addresses: 120 | address = inet_addr.get("addr", None) 121 | if address in node_addresses: 122 | found = address 123 | return found 124 | -------------------------------------------------------------------------------- /core/shims/spark322/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | 8 | com.intel 9 | raydp-shims 10 | 1.7.0-SNAPSHOT 11 | ../pom.xml 12 | 13 | 14 | raydp-shims-spark322 15 | RayDP Shims for Spark 3.2.2 16 | jar 17 | 18 | 19 | 2.12.15 20 | 2.13.5 21 | 22 | 23 | 24 | 25 | 26 | org.scalastyle 27 | scalastyle-maven-plugin 28 | 29 | 30 | net.alchim31.maven 31 | scala-maven-plugin 32 | 3.2.2 33 | 34 | 35 | scala-compile-first 36 | process-resources 37 | 38 | compile 39 | 40 | 41 | 42 | scala-test-compile-first 43 | process-test-resources 44 | 45 | testCompile 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | src/main/resources 55 | 56 | 57 | 58 | 59 | 60 | 61 | com.intel 62 | raydp-shims-common 63 | ${project.version} 64 | compile 65 | 66 | 67 | org.apache.spark 68 | spark-sql_${scala.binary.version} 69 | ${spark322.version} 70 | provided 71 | 72 | 73 | com.google.protobuf 74 | protobuf-java 75 | 76 | 77 | 78 | 79 | org.apache.spark 80 | spark-core_${scala.binary.version} 81 | ${spark322.version} 82 | provided 83 | 84 | 85 | org.xerial.snappy 86 | snappy-java 87 | 88 | 89 | org.apache.commons 90 | commons-compress 91 | 92 | 93 | org.apache.commons 94 | commons-text 95 | 96 | 97 | org.apache.ivy 98 | ivy 99 | 100 | 101 | log4j 102 | log4j 103 | 104 | 105 | 106 | 107 | org.xerial.snappy 108 | snappy-java 109 | ${snappy.version} 110 | 111 | 112 | org.apache.commons 113 | commons-compress 114 | ${commons.compress.version} 115 | 116 | 117 | org.apache.commons 118 | commons-text 119 | ${commons.text.version} 120 | 121 | 122 | org.apache.ivy 123 | ivy 124 | ${ivy.version} 125 | 126 | 127 | com.google.protobuf 128 | protobuf-java 129 | ${protobuf.version} 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /examples/tensorflow_nyctaxi.py: -------------------------------------------------------------------------------- 1 | import ray 2 | from tensorflow import keras 3 | from tensorflow.keras.callbacks import Callback 4 | 5 | import raydp 6 | from raydp.tf import TFEstimator 7 | from raydp.utils import random_split 8 | 9 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV 10 | from typing import List, Dict 11 | # Firstly, You need to init or connect to a ray cluster. 12 | # Note that you should set include_java to True. 13 | # For more config info in ray, please refer the ray doc: 14 | # https://docs.ray.io/en/latest/package-ref.html 15 | # ray.init(address="auto") 16 | ray.init(address="local", num_cpus=6) 17 | 18 | # After initialize ray cluster, you can use the raydp api to get a spark session 19 | app_name = "NYC Taxi Fare Prediction with RayDP" 20 | num_executors = 1 21 | cores_per_executor = 1 22 | memory_per_executor = "500M" 23 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access 24 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP. 
25 | java_opts = " ".join([ 26 | "-XX:+IgnoreUnrecognizedVMOptions", 27 | "--add-opens=java.base/java.lang=ALL-UNNAMED", 28 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED", 29 | "--add-opens=java.base/java.io=ALL-UNNAMED", 30 | "--add-opens=java.base/java.net=ALL-UNNAMED", 31 | "--add-opens=java.base/java.nio=ALL-UNNAMED", 32 | "--add-opens=java.base/java.math=ALL-UNNAMED", 33 | "--add-opens=java.base/java.text=ALL-UNNAMED", 34 | "--add-opens=java.base/java.util=ALL-UNNAMED", 35 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED", 36 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED", 37 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", 38 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED", 39 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED", 40 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED", 41 | ]) 42 | extra_configs = { 43 | "spark.executor.extraJavaOptions": java_opts, 44 | "spark.driver.extraJavaOptions": java_opts, 45 | "spark.ray.raydp_app_master.extraJavaOptions": java_opts, 46 | } 47 | spark = raydp.init_spark(app_name, num_executors, 48 | cores_per_executor, memory_per_executor, 49 | configs=extra_configs) 50 | 51 | # Then you can code as you are using spark 52 | # The dataset can be downloaded from: 53 | # https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data 54 | # Here we just use a subset of the training data 55 | data = spark.read.format("csv").option("header", "true") \ 56 | .option("inferSchema", "true") \ 57 | .load(NYC_TRAIN_CSV) 58 | # Set spark timezone for processing datetime 59 | spark.conf.set("spark.sql.session.timeZone", "UTC") 60 | # Transform the dataset 61 | data = nyc_taxi_preprocess(data) 62 | data = data.cache() 63 | # Split data into train_dataset and test_dataset 64 | train_df, test_df = random_split(data, [0.9, 0.1], 0) 65 | features = [field.name for field in list(train_df.schema) if field.name != "fare_amount"] 66 | 67 | # Define the keras model 68 | model = keras.Sequential( 69 | [ 70 | keras.layers.InputLayer(input_shape=(len(features),)), 71 | keras.layers.Flatten(), 72 | keras.layers.Dense(256, activation="relu"), 73 | keras.layers.BatchNormalization(), 74 | keras.layers.Dense(128, activation="relu"), 75 | keras.layers.BatchNormalization(), 76 | keras.layers.Dense(64, activation="relu"), 77 | keras.layers.BatchNormalization(), 78 | keras.layers.Dense(32, activation="relu"), 79 | keras.layers.BatchNormalization(), 80 | keras.layers.Dense(16, activation="relu"), 81 | keras.layers.BatchNormalization(), 82 | keras.layers.Dense(1), 83 | ] 84 | ) 85 | 86 | class PrintingCallback(Callback): 87 | def handle_result(self, results: List[Dict], **info): 88 | print(results) 89 | 90 | # Define the optimizer and loss function 91 | # Then create the tensorflow estimator provided by Raydp 92 | adam = keras.optimizers.Adam(learning_rate=0.001) 93 | loss = keras.losses.MeanSquaredError() 94 | estimator = TFEstimator(num_workers=2, model=model, optimizer=adam, loss=loss, 95 | merge_feature_columns=True, metrics=["mae"], 96 | feature_columns=features, label_columns="fare_amount", 97 | batch_size=256, num_epochs=10, callbacks=[PrintingCallback()]) 98 | 99 | # Train the model 100 | estimator.fit_on_spark(train_df, test_df) 101 | # Get the model 102 | model = estimator.get_model() 103 | # shudown raydp and ray 104 | raydp.stop_spark() 105 | ray.shutdown() 106 | -------------------------------------------------------------------------------- /python/raydp/mpi/__init__.py: 
--------------------------------------------------------------------------------
 1 | #
 2 | # Licensed to the Apache Software Foundation (ASF) under one or more
 3 | # contributor license agreements. See the NOTICE file distributed with
 4 | # this work for additional information regarding copyright ownership.
 5 | # The ASF licenses this file to You under the Apache License, Version 2.0
 6 | # (the "License"); you may not use this file except in compliance with
 7 | # the License. You may obtain a copy of the License at
 8 | #
 9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | 
18 | 
19 | from typing import Callable, List
20 | 
21 | from .mpi_job import MPIJob, MPIType, IntelMPIJob, OpenMPIJob, MPICHJob, MPIJobContext
22 | from .mpi_worker import WorkerContext
23 | 
24 | 
25 | def _get_mpi_type(mpi_type: str) -> MPIType:
26 |     if mpi_type.strip().lower() == "openmpi":
27 |         return MPIType.OPEN_MPI
28 |     elif mpi_type.strip().lower() == "intel_mpi":
29 |         return MPIType.INTEL_MPI
30 |     elif mpi_type.strip().lower() == "mpich":
31 |         return MPIType.MPICH
32 |     else:
33 |         return None
34 | 
35 | 
36 | def create_mpi_job(job_name: str,
37 |                    world_size: int,
38 |                    num_cpus_per_process: int,
39 |                    num_processes_per_node: int,
40 |                    mpi_script_prepare_fn: Callable = None,
41 |                    timeout: int = 1,
42 |                    mpi_type: str = "intel_mpi",
43 |                    placement_group=None,
44 |                    placement_group_bundle_indexes: List[int] = None) -> MPIJob:
45 |     """Create an MPI job.
46 | 
47 |     :param job_name: the job name
48 |     :param world_size: the world size
49 |     :param num_cpus_per_process: num CPUs per process, used to request resources from Ray
50 |     :param num_processes_per_node: num processes per node
51 |     :param mpi_script_prepare_fn: a function used to create the mpi script; it will be passed an
52 |         MPIJobContext instance. The default script is used if it is not provided.
53 |     :param timeout: the timeout used to wait for job creation
54 |     :param mpi_type: the mpi type, currently only openmpi, intel_mpi and MPICH are supported
55 |     :param placement_group: the placement group used to request mpi resources
56 |     :param placement_group_bundle_indexes: if provided, its length should equal
57 |         world_size / num_processes_per_node.
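    :return: an MPIJob instance, which can be started and stopped explicitly or used as a
        context manager (see doc/mpi.md).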
58 | """ 59 | mpi_type = _get_mpi_type(mpi_type) 60 | if mpi_type == MPIType.OPEN_MPI: 61 | return OpenMPIJob(mpi_type=MPIType.OPEN_MPI, 62 | job_name=job_name, 63 | world_size=world_size, 64 | num_cpus_per_process=num_cpus_per_process, 65 | num_processes_per_node=num_processes_per_node, 66 | mpi_script_prepare_fn=mpi_script_prepare_fn, 67 | timeout=timeout, 68 | placement_group=placement_group, 69 | placement_group_bundle_indexes=placement_group_bundle_indexes) 70 | elif mpi_type == MPIType.INTEL_MPI: 71 | return IntelMPIJob(mpi_type=MPIType.INTEL_MPI, 72 | job_name=job_name, 73 | world_size=world_size, 74 | num_cpus_per_process=num_cpus_per_process, 75 | num_processes_per_node=num_processes_per_node, 76 | mpi_script_prepare_fn=mpi_script_prepare_fn, 77 | timeout=timeout, 78 | placement_group=placement_group, 79 | placement_group_bundle_indexes=placement_group_bundle_indexes) 80 | elif mpi_type == MPIType.MPICH: 81 | return MPICHJob(mpi_type=MPIType.MPICH, 82 | job_name=job_name, 83 | world_size=world_size, 84 | num_cpus_per_process=num_cpus_per_process, 85 | num_processes_per_node=num_processes_per_node, 86 | mpi_script_prepare_fn=mpi_script_prepare_fn, 87 | timeout=timeout, 88 | placement_group=placement_group, 89 | placement_group_bundle_indexes=placement_group_bundle_indexes) 90 | else: 91 | raise Exception(f"MPI type: {mpi_type} not supported now") 92 | 93 | 94 | __all__ = ["create_mpi_job", "MPIJobContext", "WorkerContext"] 95 | -------------------------------------------------------------------------------- /examples/pytorch_nyctaxi.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import raydp 7 | from raydp.torch import TorchEstimator 8 | from raydp.utils import random_split 9 | 10 | from data_process import nyc_taxi_preprocess, NYC_TRAIN_CSV 11 | from typing import List, Dict 12 | 13 | # Firstly, You need to init or connect to a ray cluster. 14 | # Note that you should set include_java to True. 15 | # For more config info in ray, please refer the ray doc: 16 | # https://docs.ray.io/en/latest/package-ref.html 17 | # ray.init(address="auto") 18 | ray.init(address="local", num_cpus=4) 19 | 20 | # After initialize ray cluster, you can use the raydp api to get a spark session 21 | app_name = "NYC Taxi Fare Prediction with RayDP" 22 | num_executors = 1 23 | cores_per_executor = 1 24 | memory_per_executor = "500M" 25 | 26 | # JDK 17+ requires --add-opens for reflective access and --add-exports for direct access 27 | # to internal JDK modules. These are needed for Spark, Ray serialization, and RayDP. 
28 | java_opts = " ".join([ 29 | "-XX:+IgnoreUnrecognizedVMOptions", 30 | "--add-opens=java.base/java.lang=ALL-UNNAMED", 31 | "--add-opens=java.base/java.lang.invoke=ALL-UNNAMED", 32 | "--add-opens=java.base/java.io=ALL-UNNAMED", 33 | "--add-opens=java.base/java.net=ALL-UNNAMED", 34 | "--add-opens=java.base/java.nio=ALL-UNNAMED", 35 | "--add-opens=java.base/java.math=ALL-UNNAMED", 36 | "--add-opens=java.base/java.text=ALL-UNNAMED", 37 | "--add-opens=java.base/java.util=ALL-UNNAMED", 38 | "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED", 39 | "--add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED", 40 | "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED", 41 | "--add-opens=java.base/sun.nio.cs=ALL-UNNAMED", 42 | "--add-opens=java.base/sun.security.action=ALL-UNNAMED", 43 | "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED", 44 | ]) 45 | extra_configs = { 46 | "spark.executor.extraJavaOptions": java_opts, 47 | "spark.driver.extraJavaOptions": java_opts, 48 | "spark.ray.raydp_app_master.extraJavaOptions": java_opts, 49 | } 50 | spark = raydp.init_spark(app_name, num_executors, 51 | cores_per_executor, memory_per_executor, 52 | configs=extra_configs) 53 | 54 | # Then you can code as you are using spark 55 | # The dataset can be downloaded from: 56 | # https://www.kaggle.com/c/new-york-city-taxi-fare-prediction/data 57 | # Here we just use a subset of the training data 58 | data = spark.read.format("csv").option("header", "true") \ 59 | .option("inferSchema", "true") \ 60 | .load(NYC_TRAIN_CSV) 61 | # Set spark timezone for processing datetime 62 | spark.conf.set("spark.sql.session.timeZone", "UTC") 63 | # Transform the dataset 64 | data = nyc_taxi_preprocess(data) 65 | # Split data into train_dataset and test_dataset 66 | train_df, test_df = random_split(data, [0.9, 0.1], 0) 67 | features = [field.name for field in list(train_df.schema) if field.name != "fare_amount"] 68 | # Define a neural network model 69 | class NYC_Model(nn.Module): 70 | def __init__(self, cols): 71 | super().__init__() 72 | self.fc1 = nn.Linear(cols, 256) 73 | self.fc2 = nn.Linear(256, 128) 74 | self.fc3 = nn.Linear(128, 64) 75 | self.fc4 = nn.Linear(64, 16) 76 | self.fc5 = nn.Linear(16, 1) 77 | self.bn1 = nn.BatchNorm1d(256) 78 | self.bn2 = nn.BatchNorm1d(128) 79 | self.bn3 = nn.BatchNorm1d(64) 80 | self.bn4 = nn.BatchNorm1d(16) 81 | 82 | def forward(self, x): 83 | x = F.relu(self.fc1(x)) 84 | x = self.bn1(x) 85 | x = F.relu(self.fc2(x)) 86 | x = self.bn2(x) 87 | x = F.relu(self.fc3(x)) 88 | x = self.bn3(x) 89 | x = F.relu(self.fc4(x)) 90 | x = self.bn4(x) 91 | x = self.fc5(x) 92 | return x 93 | 94 | nyc_model = NYC_Model(len(features)) 95 | criterion = nn.SmoothL1Loss() 96 | optimizer = torch.optim.Adam(nyc_model.parameters(), lr=0.001) 97 | # Create a distributed estimator based on the raydp api 98 | estimator = TorchEstimator(num_workers=1, model=nyc_model, optimizer=optimizer, loss=criterion, 99 | feature_columns=features, feature_types=torch.float, 100 | label_column="fare_amount", label_type=torch.float, 101 | batch_size=64, num_epochs=30, 102 | metrics_name = ["MeanAbsoluteError", "MeanSquaredError"], 103 | use_ccl=False) 104 | # Train the model 105 | estimator.fit_on_spark(train_df, test_df) 106 | # Get the trained model 107 | model = estimator.get_model() 108 | # shutdown raydp and ray 109 | raydp.stop_spark() 110 | ray.shutdown() 111 | --------------------------------------------------------------------------------