├── .artifactignore ├── .github └── workflows │ ├── sparkucx-ci.yml │ └── sparkucx-release.yml ├── LICENSE ├── README.md ├── buildlib ├── azure-pipelines.yml └── test.sh ├── pom.xml └── src └── main ├── java └── org │ └── apache │ └── spark │ └── shuffle │ └── ucx │ ├── UcxNode.java │ ├── UnsafeUtils.java │ ├── memory │ ├── MemoryPool.java │ └── RegisteredMemory.java │ ├── reducer │ ├── OnBlocksFetchCallback.java │ ├── ReducerCallback.java │ └── compat │ │ ├── spark_2_1 │ │ ├── OnOffsetsFetchCallback.java │ │ └── UcxShuffleClient.java │ │ ├── spark_2_4 │ │ ├── OnOffsetsFetchCallback.java │ │ └── UcxShuffleClient.java │ │ └── spark_3_0 │ │ ├── OnOffsetsFetchCallback.java │ │ └── UcxShuffleClient.java │ └── rpc │ ├── RpcConnectionCallback.java │ ├── SerializableBlockManagerID.java │ ├── UcxListenerThread.java │ └── UcxRemoteMemory.java └── scala └── org └── apache └── spark └── shuffle ├── CommonUcxShuffleBlockResolver.scala ├── CommonUcxShuffleManager.scala ├── UcxShuffleConf.scala ├── UcxWorkerWrapper.scala └── compat ├── spark_2_1 ├── UcxShuffleBlockResolver.scala ├── UcxShuffleManager.scala └── UcxShuffleReader.scala ├── spark_2_4 ├── UcxShuffleBlockResolver.scala ├── UcxShuffleManager.scala └── UcxShuffleReader.scala └── spark_3_0 ├── UcxLocalDiskShuffleDataIO.scala ├── UcxLocalDiskShuffleExecutorComponents.scala ├── UcxShuffleBlockResolver.scala ├── UcxShuffleManager.scala └── UcxShuffleReader.scala /.artifactignore: -------------------------------------------------------------------------------- 1 | **/* 2 | !target/*.jar 3 | -------------------------------------------------------------------------------- /.github/workflows/sparkucx-ci.yml: -------------------------------------------------------------------------------- 1 | name: SparkUCX CI 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | build-sparkucx: 10 | strategy: 11 | matrix: 12 | spark_version: ["2.1", "2.4", "3.0"] 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v1 16 | - name: Set up JDK 1.11 17 | uses: actions/setup-java@v1 18 | with: 19 | java-version: 1.11 20 | - name: Build with Maven 21 | run: mvn -B package -Pspark-${{ matrix.spark_version }} -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn 22 | --file pom.xml 23 | - name: Run Sonar code analysis 24 | run: mvn -B sonar:sonar -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -Dsonar.projectKey=openucx:spark-ucx -Dsonar.organization=openucx -Dsonar.host.url=https://sonarcloud.io -Dsonar.login=97f4df88ff4fa04e2d5b061acf07315717f1f08b -Pspark-${{ matrix.spark_version }} 25 | env: 26 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 27 | -------------------------------------------------------------------------------- /.github/workflows/sparkucx-release.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | # Sequence of patterns matched against refs/tags 4 | tags: 5 | - 'v*' # Push events to matching v*, i.e. 
v1.0, v20.15.10 6 | 7 | name: Upload Release Asset 8 | 9 | env: 10 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 11 | 12 | jobs: 13 | release: 14 | strategy: 15 | matrix: 16 | spark_version: ["2.1", "2.4", "3.0"] 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout code 20 | uses: actions/checkout@v2 21 | 22 | - name: Set up JDK 1.11 23 | uses: actions/setup-java@v1 24 | with: 25 | java-version: 1.11 26 | 27 | - name: Build with Maven 28 | id: maven_package 29 | run: | 30 | mvn -B -Pspark-${{ matrix.spark_version }} clean package \ 31 | -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn \ 32 | --file pom.xml 33 | cd target 34 | echo "::set-output name=jar_name::$(echo spark-ucx-*-jar-with-dependencies.jar)" 35 | 36 | - name: Upload Release Jars 37 | uses: svenstaro/upload-release-action@v1-release 38 | with: 39 | repo_token: ${{ secrets.GITHUB_TOKEN }} 40 | file: ./target/${{ steps.maven_package.outputs.jar_name }} 41 | asset_name: ${{ steps.maven_package.outputs.jar_name }} 42 | tag: ${{ github.ref }} 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2019 Mellanox Technologies Ltd. All rights reserved. 2 | Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 2. Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 3. Neither the name of the copyright holder nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 23 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 24 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 25 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 26 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SparkUCX ShuffleManager Plugin 2 | SparkUCX is a high performance ShuffleManager plugin for Apache Spark, that uses RDMA and other high performance transports 3 | that are supported by [UCX](https://github.com/openucx/ucx#supported-transports), to perform Shuffle data transfers in Spark jobs. 4 | 5 | This open-source project is developed, maintained and supported by the [UCF consortium](http://www.ucfconsortium.org/). 
6 | 7 | ## Runtime requirements 8 | * Apache Spark 2.3/2.4/3.0 9 | * Java 8+ 10 | * Installed UCX of version 1.10+, and [UCX supported transport hardware](https://github.com/openucx/ucx#supported-transports). 11 | 12 | ## Installation 13 | 14 | ### Obtain SparkUCX 15 | Please use the ["Releases"](https://github.com/openucx/sparkucx/releases) page to download SparkUCX jar file 16 | for your spark version (e.g. spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar). 17 | Put SparkUCX jar file in $SPARK_UCX_HOME on all the nodes in your cluster. 18 |
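For example, a minimal sketch of placing the release jar into $SPARK_UCX_HOME on all nodes (assuming passwordless SSH and a hypothetical `nodes.txt` file with one hostname per line; directory and jar name are placeholders, adjust to your environment):

```
# Sketch: copy the SparkUCX release jar into $SPARK_UCX_HOME on every node.
# Assumes passwordless SSH and a nodes.txt file listing cluster hostnames.
export SPARK_UCX_HOME=/opt/sparkucx
while read -r node; do
  ssh "$node" "mkdir -p $SPARK_UCX_HOME"
  scp spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar "$node:$SPARK_UCX_HOME/"
done < nodes.txt
```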
If you would like to build the project yourself, please refer to the ["Build"](https://github.com/openucx/sparkucx#build) section below. 19 | 20 | Ucx binaries **must** be in Spark classpath on every Spark Master and Worker. 21 | It can be obtained by installing the latest version from [Ucx release page](https://github.com/openucx/ucx/releases) 22 | 23 | ### Configuration 24 | 25 | Provide Spark the location of the SparkUCX plugin jars and ucx shared binaries by using the extraClassPath option. 26 | 27 | ``` 28 | spark.driver.extraClassPath $SPARK_UCX_HOME/spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar:$UCX_PREFIX/lib 29 | spark.executor.extraClassPath $SPARK_UCX_HOME/spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar:$UCX_PREFIX/lib 30 | ``` 31 | To enable the SparkUCX Shuffle Manager plugin, add the following configuration property 32 | to spark (e.g. in $SPARK_HOME/conf/spark-defaults.conf): 33 | 34 | ``` 35 | spark.shuffle.manager org.apache.spark.shuffle.UcxShuffleManager 36 | ``` 37 | For spark-3.0 version add SparkUCX ShuffleIO plugin: 38 | ``` 39 | spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_3_0.UcxLocalDiskShuffleDataIO 40 | ``` 41 | 42 | ### Build 43 | 44 | Building the SparkUCX plugin requires [Apache Maven](http://maven.apache.org/) and Java 8+ JDK 45 | 46 | Build instructions: 47 | 48 | ``` 49 | % git clone https://github.com/openucx/sparkucx 50 | % cd sparkucx 51 | % mvn -DskipTests clean package -Pspark-2.4 52 | ``` 53 | 54 | ### Performance 55 | 56 | SparkUCX plugin is built to provide the best performance out-of-the-box, and provides multiple configuration options to further tune SparkUCX per-job. For more information on how to setup [HiBench](https://github.com/Intel-bigdata/HiBench) benchmark and reproduce results, please refer to [Accelerated Apache SparkUCX 2.4/3.0 cluster deployment](https://docs.mellanox.com/pages/releaseview.action?pageId=19819236). 
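Before running full benchmarks, a quick way to check that the plugin is actually picked up is to pass the same settings from the Configuration section per job on the command line. A minimal sketch, assuming a standalone master at spark://master:7077 and the $SPARK_UCX_HOME / $UCX_PREFIX paths used above (all placeholders for your environment):

```
# Sketch: run a shuffle-heavy built-in example with SparkUCX enabled via --conf,
# instead of editing spark-defaults.conf. Master URL and paths are placeholders.
$SPARK_HOME/bin/run-example \
  --master spark://master:7077 \
  --conf spark.shuffle.manager=org.apache.spark.shuffle.UcxShuffleManager \
  --conf spark.driver.extraClassPath=$SPARK_UCX_HOME/spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar:$UCX_PREFIX/lib \
  --conf spark.executor.extraClassPath=$SPARK_UCX_HOME/spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar:$UCX_PREFIX/lib \
  org.apache.spark.examples.GroupByTest 100 100
```

If the UcxShuffleManager class is not on the classpath of both driver and executors, such a run will typically fail early with a ClassNotFoundException, which makes this a useful smoke test before a HiBench run.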
57 | 58 | ![Performance results](https://docs.mellanox.com/download/attachments/19819236/image2020-1-23_15-39-14.png) 59 | 60 | -------------------------------------------------------------------------------- /buildlib/azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | # See https://aka.ms/yaml 2 | 3 | trigger: 4 | - master 5 | - v*.*.x 6 | pr: 7 | - master 8 | - v*.*.x 9 | 10 | stages: 11 | - stage: Build 12 | jobs: 13 | - job: build 14 | strategy: 15 | maxParallel: 1 16 | matrix: 17 | spark-2.4: 18 | profile_version: "2.4" 19 | spark_version: "2.4.5" 20 | spark-3.0: 21 | profile_version: "3.0" 22 | spark_version: "3.0.1" 23 | pool: 24 | name: MLNX 25 | demands: Maven 26 | steps: 27 | - task: Maven@3 28 | displayName: build 29 | inputs: 30 | javaHomeOption: "path" 31 | jdkDirectory: "/hpc/local/oss/java/jdk/" 32 | jdkVersionOption: "1.8" 33 | mavenVersionSelection: "Path" 34 | mavenPath: "/hpc/local/oss/apache-maven-3.3.9" 35 | mavenSetM2Home: true 36 | publishJUnitResults: false 37 | goals: "package" 38 | options: "-B -Dmaven.repo.local=$(System.DefaultWorkingDirectory)/target/.deps -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -Pspark-$(profile_version)" 39 | - bash: | 40 | set -xeE 41 | module load tools/spark-$(spark_version) hpcx-gcc 42 | source buildlib/test.sh 43 | 44 | if [[ $(get_rdma_device_iface) != "" ]] 45 | then 46 | export SPARK_UCX_JAR=$(System.DefaultWorkingDirectory)/target/spark-ucx-1.0-for-spark-$(profile_version)-jar-with-dependencies.jar 47 | export SPARK_LOCAL_DIRS=$(System.DefaultWorkingDirectory)/target/spark 48 | export SPARK_VERSION=$(spark_version) 49 | export UCX_LIB=$HPCX_UCX_DIR/lib 50 | cd $(System.DefaultWorkingDirectory)/target/ 51 | run_tests 52 | else 53 | echo ##vso[task.complete result=Skipped;]No IB devices found 54 | fi 55 | displayName: Run spark tests 56 | -------------------------------------------------------------------------------- /buildlib/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eExl 2 | # 3 | # Testing script for SparkUCX 4 | # 5 | # Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 6 | # 7 | # Environment variables: 8 | # - NODELIST : Space separated list of nodes (default: localhost) 9 | # - UCX_LIB : Path to UCX installation (default: LD_LIBRARY_PATH) 10 | # - SPARK_UCX_JAR : Path to SparkUCX jar (default in $PWD) 11 | # - PROCESSES_PER_INSTANCE : Number of spark processes per instance (default: 2) 12 | # - EXECUTOR_NUMBER : Unique number of jenkins executor, to run multiple tests on 1 machine 13 | # - JOB_ID : Unique job id for this test 14 | # - SCRATCH_DIRECTORY : Separate directory for this test where configs will be added. 15 | # For now it's required to be on NFS 16 | # - SPARK_LOCAL_DIRS : Directory to use for map output files and RDDs that get stored on disk. 17 | # This should be on a fast, local disk in your system. 18 | # It can also be a comma-separated list of multiple directories on different disks. 19 | # - SPARK_WORKER_CORES : Number of cores per spark process (default: 2) 20 | # - SPARK_WORKER_MEMORY : Memory per spark process (default: 1g) 21 | # - SPARK_HOME : Spark location on the file system (default: $PWD/spark). If not set, will be downloaded 22 | # - SPARK_VERSION : Version of spark to test (default: 2.4.4). 
Used only in SPARK_HOME is not set 23 | # - RDMA_NET_IFACE : Network interface to test 24 | 25 | NODELIST=${NODELIST:="localhost"} 26 | 27 | UCX_LIB=${UCX_LIB:=${LD_LIBRARY_PATH}} 28 | 29 | SPARK_UCX_JAR=${SPARK_UCX_JAR:=$PWD/spark-ucx-1.0-for-spark-2.4-jar-with-dependencies.jar} 30 | 31 | PROCESSES_PER_INSTANCE=${PROCESSES_PER_INSTANCE:=2} 32 | 33 | if [ -n "$EXECUTOR_NUMBER" ] 34 | then 35 | AFFINITY="taskset -c $(( 2 * EXECUTOR_NUMBER ))","$(( 2 * EXECUTOR_NUMBER + 1))" 36 | else 37 | AFFINITY="" 38 | fi 39 | 40 | EXECUTOR_NUMBER=${EXECUTOR_NUMBER:=$((RANDOM % `nproc`))} 41 | 42 | JOB_ID=${SLURM_JOB_ID:="sparkucx-test-$(cat /proc/sys/kernel/random/uuid)"} 43 | 44 | SCRATCH_DIRECTORY=${SCRATCH_DIRECTORY:=$PWD/${JOB_ID}} 45 | 46 | SPARK_LOCAL_DIRS=${SPARK_LOCAL_DIRS:="/scrap/sparkucxtest/sparkucx-${JOB_ID}"} 47 | 48 | SPARK_WORKER_CORES=${SPARK_WORKER_CORES:=2} 49 | 50 | SPARK_WORKER_MEMORY=${SPARK_WORKER_MEMORY:="1g"} 51 | 52 | SPARK_HOME=${SPARK_HOME:=$PWD/spark} 53 | 54 | SPARK_VERSION=${SPARK_VERSION:="2.4.4"} 55 | 56 | UCX_BRANCH=${UCX_BRANCH:="master"} 57 | 58 | export SPARK_MASTER_HOST=${SPARK_MASTER_HOST:=$(hostname -f)} 59 | export SPARK_MASTER_PORT=${SPARK_MASTER_PORT:=$(( 2000 + 10 * ${EXECUTOR_NUMBER}))} 60 | export SPARK_CONF_DIR=${SCRATCH_DIRECTORY}/conf 61 | 62 | get_rdma_device_iface() { 63 | 64 | if [ ! -r /dev/infiniband/rdma_cm ] 65 | then 66 | return 67 | fi 68 | 69 | if ! which ibdev2netdev >&/dev/null 70 | then 71 | return 72 | fi 73 | 74 | iface=`ibdev2netdev | grep Up | awk '{print $5}' | head -1` 75 | if [[ -n "$iface" ]] 76 | then 77 | ipaddr=$(ip addr show ${iface} | awk '/inet /{print $2}' | awk -F '/' '{print $1}') 78 | fi 79 | 80 | if [[ -z "$ipaddr" ]] 81 | then 82 | # if there is no inet (IPv4) address, escape 83 | return 84 | fi 85 | 86 | ibdev=`ibdev2netdev | grep ${iface} | awk '{print $1}'` 87 | node_guid=`cat /sys/class/infiniband/$ibdev/node_guid` 88 | if [[ ${node_guid} == "0000:0000:0000:0000" ]] 89 | then 90 | return 91 | fi 92 | 93 | SPARK_MASTER_HOST=$ipaddr 94 | echo ${iface} 95 | } 96 | 97 | RDMA_NET_IFACE=${RDMA_NET_IFACE:=`get_rdma_device_iface`} 98 | 99 | download_spark() { 100 | curl -s -L -O https://www-eu.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop2.7.tgz 101 | mkdir -p ${SPARK_HOME} 102 | tar -xf spark-${SPARK_VERSION}-bin-hadoop2.7.tgz -C ${SPARK_HOME} --strip-components=1 103 | } 104 | 105 | build_ucx() { 106 | git clone -b ${UCX_BRANCH} --depth=1 https://github.com/openucx/ucx.git && cd ucx 107 | ./autogen.sh 108 | mkdir build && cd build 109 | ../contrib/configure-release-mt --with-java --prefix=$PWD 110 | make -j `nproc` 111 | make install 112 | UCX_LIB=$PWD/lib/ 113 | } 114 | 115 | setup_configuration() { 116 | mkdir -p ${SPARK_CONF_DIR} 117 | 118 | echo ${NODELIST} | tr -s ' ' '\n' >> "${SPARK_CONF_DIR}/slaves" 119 | 120 | cat <<-EOF > ${SPARK_CONF_DIR}/spark-defaults.conf 121 | spark.shuffle.manager org.apache.spark.shuffle.UcxShuffleManager 122 | spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_3_0.UcxLocalDiskShuffleDataIO 123 | spark.shuffle.readHostLocalDisk.enabled false 124 | spark.driver.extraClassPath ${SPARK_UCX_JAR}:${UCX_LIB} 125 | spark.executor.extraClassPath ${SPARK_UCX_JAR}:${UCX_LIB} 126 | spark.shuffle.ucx.driver.port $(( ${SPARK_MASTER_PORT} + 1 )) 127 | EOF 128 | 129 | cat <<-EOF > ${SPARK_CONF_DIR}/spark-env.sh 130 | export SPARK_LOCAL_IP=\`/sbin/ip addr show ${RDMA_NET_IFACE} | grep "inet\b" | awk '{print \$2}' | cut -d/ -f1\` 131 | export 
SPARK_WORKER_DIR=${SCRATCH_DIRECTORY}/work 132 | export SPARK_LOCAL_DIRS=${SPARK_LOCAL_DIRS} 133 | export SPARK_LOG_DIR=${SCRATCH_DIRECTORY}/logs 134 | export SPARK_CONF_DIR=${SPARK_CONF_DIR} 135 | export SPARK_MASTER_HOST=${SPARK_MASTER_HOST} 136 | export SPARK_MASTER_PORT=${SPARK_MASTER_PORT} 137 | export SPARK_WORKER_CORES=${SPARK_WORKER_CORES} 138 | export SPARK_WORKER_MEMORY=${SPARK_WORKER_MEMORY} 139 | export SPARK_IDENT_STRING=${JOB_ID} 140 | EOF 141 | 142 | cp ${SPARK_HOME}/conf/log4j.properties.template ${SPARK_CONF_DIR}/log4j.properties 143 | sed -i -e 's/INFO/WARN/g' ${SPARK_CONF_DIR}/log4j.properties 144 | echo "log4j.logger.org.apache.spark.shuffle=DEBUG" >> ${SPARK_CONF_DIR}/log4j.properties 145 | } 146 | 147 | start_cluster() { 148 | ${AFFINITY} ${SPARK_HOME}/sbin/start-master.sh 149 | 150 | # Make a script wrapper to propagate SPARK_CONF_DIR 151 | cat <<-EOF > ${SCRATCH_DIRECTORY}/sparkworker.sh 152 | #! /bin/bash 153 | export SPARK_CONF_DIR=${SPARK_CONF_DIR} 154 | export SPARK_WORKER_INSTANCES=${PROCESSES_PER_INSTANCE} 155 | export SPARK_IDENT_STRING=${JOB_ID} 156 | ${AFFINITY} ${SPARK_HOME}/sbin/start-slave.sh "spark://${SPARK_MASTER_HOST}:${SPARK_MASTER_PORT}" 157 | EOF 158 | 159 | SPARK_CONF_DIR=${SPARK_CONF_DIR} ${SPARK_HOME}/sbin/slaves.sh bash ${SCRATCH_DIRECTORY}/sparkworker.sh 160 | } 161 | 162 | run_groupby_test() { 163 | ${SPARK_HOME}/bin/run-example --verbose --master spark://${SPARK_MASTER_HOST}:${SPARK_MASTER_PORT} \ 164 | --jars "${SPARK_HOME}/examples/jars/*.jar" --executor-memory ${SPARK_WORKER_MEMORY} \ 165 | org.apache.spark.examples.GroupByTest 100 100 166 | } 167 | 168 | run_tc_test() { 169 | ${SPARK_HOME}/bin/run-example --verbose --master spark://${SPARK_MASTER_HOST}:${SPARK_MASTER_PORT} \ 170 | --jars "${SPARK_HOME}/examples/jars/*.jar" --executor-memory ${SPARK_WORKER_MEMORY} \ 171 | org.apache.spark.examples.SparkTC 172 | } 173 | 174 | run_tests() { 175 | if [[ ! -d ${SPARK_HOME} ]] 176 | then 177 | download_spark 178 | fi 179 | 180 | if [[ ! -d ${UCX_LIB} ]] 181 | then 182 | build_ucx 183 | fi 184 | 185 | trap stop_cluster EXIT; 186 | 187 | setup_configuration 188 | start_cluster 189 | run_groupby_test && run_tc_test 190 | } 191 | 192 | stop_cluster() { 193 | cat <<-EOF > ${SCRATCH_DIRECTORY}/stop-sparkworker.sh 194 | #! /bin/bash 195 | export SPARK_CONF_DIR=${SPARK_CONF_DIR} 196 | export SPARK_WORKER_INSTANCES=${PROCESSES_PER_INSTANCE} 197 | export SPARK_IDENT_STRING=${JOB_ID} 198 | ${SPARK_HOME}/sbin/stop-slave.sh 199 | EOF 200 | 201 | chmod +x ${SCRATCH_DIRECTORY}/stop-sparkworker.sh 202 | # Stop all slaves 203 | ${SPARK_HOME}/sbin/slaves.sh ${SCRATCH_DIRECTORY}/stop-sparkworker.sh 204 | 205 | ${SPARK_HOME}/sbin/stop-master.sh 206 | } 207 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 6 | 7 | 11 | 4.0.0 12 | org.openucx 13 | spark-ucx 14 | 1.0 15 | ${project.artifactId} 16 | 17 | A high-performance, scalable and efficient shuffle manager plugin for Apache Spark, 18 | utilizing UCX communication layer (https://github.com/openucx/ucx/). 
19 | 20 | jar 21 | 22 | 23 | 24 | BSD 3 Clause License 25 | http://www.openucx.org/license/ 26 | repo 27 | 28 | 29 | 30 | 31 | 1.8 32 | 1.8 33 | UTF-8 34 | 35 | 36 | 37 | 38 | spark-2.1 39 | 40 | 41 | 42 | org.apache.maven.plugins 43 | maven-compiler-plugin 44 | 45 | 46 | **/spark_3_0/** 47 | **/spark_2_4/** 48 | 49 | 50 | 51 | 52 | net.alchim31.maven 53 | scala-maven-plugin 54 | 55 | 56 | **/spark_3_0/** 57 | **/spark_2_4/** 58 | 59 | 60 | 61 | 62 | 63 | 64 | 2.1.0 65 | **/spark_3_0/**, **/spark_2_4/** 66 | 2.11.12 67 | 2.11 68 | 69 | 70 | 71 | spark-2.4 72 | 73 | 74 | 75 | org.apache.maven.plugins 76 | maven-compiler-plugin 77 | 78 | 79 | **/spark_3_0/** 80 | **/spark_2_1/** 81 | 82 | 83 | 84 | 85 | net.alchim31.maven 86 | scala-maven-plugin 87 | 88 | 89 | **/spark_2_1/** 90 | **/spark_3_0/** 91 | 92 | 93 | 94 | 95 | 96 | 97 | 2.4.0 98 | **/spark_3_0/**, **/spark_2_1/** 99 | 2.11.12 100 | 2.11 101 | 102 | 103 | 104 | spark-3.0 105 | 106 | true 107 | 108 | 109 | 110 | 111 | org.apache.maven.plugins 112 | maven-compiler-plugin 113 | 114 | 115 | **/spark_2_1/** 116 | **/spark_2_4/** 117 | 118 | 119 | 120 | 121 | net.alchim31.maven 122 | scala-maven-plugin 123 | 124 | 125 | **/spark_2_1/** 126 | **/spark_2_4/** 127 | 128 | 129 | 130 | 131 | 132 | 133 | 3.0.1 134 | 2.12.10 135 | 2.12 136 | **/spark_2_1/**, **/spark_2_4/** 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | org.apache.spark 145 | spark-core_${scala.compat.version} 146 | ${spark.version} 147 | provided 148 | 149 | 150 | org.openucx 151 | jucx 152 | 1.11.0-rc3 153 | 154 | 155 | 156 | 157 | ${project.artifactId}-${project.version}-for-${project.activeProfiles[0].id} 158 | 159 | 160 | org.apache.maven.plugins 161 | maven-compiler-plugin 162 | 3.8.1 163 | 164 | 1.8 165 | 1.8 166 | 167 | 168 | 169 | net.alchim31.maven 170 | scala-maven-plugin 171 | 4.3.0 172 | 173 | all 174 | 175 | -nobootcp 176 | -Xexperimental 177 | -Xfatal-warnings 178 | -explaintypes 179 | -unchecked 180 | -deprecation 181 | -feature 182 | 183 | 184 | 185 | 186 | compile 187 | 188 | compile 189 | 190 | compile 191 | 192 | 193 | process-resources 194 | 195 | compile 196 | 197 | 198 | 199 | 200 | 201 | maven-assembly-plugin 202 | 3.1.1 203 | 204 | 205 | jar-with-dependencies 206 | 207 | 208 | 209 | 210 | make-assembly 211 | package 212 | 213 | single 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | oss.sonatype.org-snapshot 224 | http://oss.sonatype.org/content/repositories/snapshots 225 | 226 | false 227 | 228 | 229 | true 230 | 231 | 232 | 233 | 234 | 235 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/UcxNode.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.ucx; 6 | 7 | import org.apache.spark.SparkEnv; 8 | import org.apache.spark.shuffle.UcxShuffleConf; 9 | import org.apache.spark.shuffle.UcxWorkerWrapper; 10 | import org.apache.spark.shuffle.ucx.memory.MemoryPool; 11 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory; 12 | import org.apache.spark.shuffle.ucx.rpc.SerializableBlockManagerID; 13 | import org.apache.spark.shuffle.ucx.rpc.UcxListenerThread; 14 | import org.apache.spark.storage.BlockManagerId; 15 | import org.openucx.jucx.UcxCallback; 16 | import org.openucx.jucx.UcxException; 17 | import org.openucx.jucx.ucp.*; 18 | 19 | import java.io.Closeable; 20 | import java.io.IOException; 21 | import java.net.InetSocketAddress; 22 | import java.nio.ByteBuffer; 23 | import java.util.ArrayList; 24 | import java.util.List; 25 | import java.util.Set; 26 | import java.util.concurrent.ConcurrentHashMap; 27 | import java.util.concurrent.ConcurrentMap; 28 | 29 | import org.slf4j.Logger; 30 | import org.slf4j.LoggerFactory; 31 | 32 | /** 33 | * Single instance class per spark process, that keeps UcpContext, memory and worker pools. 34 | */ 35 | public class UcxNode implements Closeable { 36 | // Global 37 | private static final Logger logger = LoggerFactory.getLogger(UcxNode.class); 38 | private final boolean isDriver; 39 | private final UcpContext context; 40 | private final MemoryPool memoryPool; 41 | private final UcpWorkerParams workerParams = new UcpWorkerParams(); 42 | private final UcpWorker globalWorker; 43 | private final UcxShuffleConf conf; 44 | // Mapping from spark's entity of BlockManagerId to UcxEntity workerAddress. 45 | private static final ConcurrentHashMap workerAdresses = 46 | new ConcurrentHashMap<>(); 47 | private final Thread listenerProgressThread; 48 | private boolean closed = false; 49 | 50 | // Driver 51 | private UcpListener listener; 52 | // Mapping from UcpEndpoint to ByteBuffer of RPC message, to introduce executor to cluster 53 | private static final ConcurrentHashMap rpcConnections = 54 | new ConcurrentHashMap<>(); 55 | private List backwardEndpoints = new ArrayList<>(); 56 | 57 | // Executor 58 | private UcpEndpoint globalDriverEndpoint; 59 | // Keep track of allocated workers to correctly close them. 
60 | private static final Set allocatedWorkers = ConcurrentHashMap.newKeySet(); 61 | private final ThreadLocal threadLocalWorker; 62 | 63 | public UcxNode(UcxShuffleConf conf, boolean isDriver) { 64 | this.conf = conf; 65 | this.isDriver = isDriver; 66 | UcpParams params = new UcpParams().requestTagFeature() 67 | .requestRmaFeature().requestWakeupFeature() 68 | .setMtWorkersShared(true); 69 | context = new UcpContext(params); 70 | memoryPool = new MemoryPool(context, conf); 71 | globalWorker = context.newWorker(workerParams); 72 | InetSocketAddress driverAddress = new InetSocketAddress(conf.driverHost(), conf.driverPort()); 73 | 74 | if (isDriver) { 75 | startDriver(driverAddress); 76 | } else { 77 | startExecutor(driverAddress); 78 | } 79 | 80 | // Global listener thread, that keeps lazy progress for connection establishment 81 | listenerProgressThread = new UcxListenerThread(this, isDriver); 82 | listenerProgressThread.start(); 83 | 84 | if (!isDriver) { 85 | memoryPool.preAlocate(); 86 | } 87 | 88 | threadLocalWorker = ThreadLocal.withInitial(() -> { 89 | UcpWorker localWorker = context.newWorker(workerParams); 90 | UcxWorkerWrapper result = new UcxWorkerWrapper(localWorker, 91 | conf, allocatedWorkers.size()); 92 | if (result.id() > conf.coresPerProcess()) { 93 | logger.warn("Thread: {} - creates new worker {} > numCores", 94 | Thread.currentThread().getId(), result.id()); 95 | } 96 | allocatedWorkers.add(result); 97 | return result; 98 | }); 99 | } 100 | 101 | private void startDriver(InetSocketAddress driverAddress) { 102 | // 1. Start listener on a driver and accept RPC messages from executors with their 103 | // worker addresses 104 | UcpListenerParams listenerParams = new UcpListenerParams().setSockAddr(driverAddress) 105 | .setConnectionHandler(ucpConnectionRequest -> 106 | backwardEndpoints.add(globalWorker.newEndpoint(new UcpEndpointParams() 107 | .setConnectionRequest(ucpConnectionRequest)))); 108 | listener = globalWorker.newListener(listenerParams); 109 | logger.info("Started UcxNode on {}", driverAddress); 110 | } 111 | 112 | /** 113 | * Allocates ByteBuffer from memoryPool and serializes there workerAddress, 114 | * followed by BlockManagerID 115 | * @return RegisteredMemory that holds metadata buffer. 116 | */ 117 | private RegisteredMemory buildMetadataBuffer() { 118 | BlockManagerId blockManagerId = SparkEnv.get().blockManager().blockManagerId(); 119 | ByteBuffer workerAddresses = globalWorker.getAddress(); 120 | 121 | RegisteredMemory metadataMemory = memoryPool.get(conf.metadataRPCBufferSize()); 122 | ByteBuffer metadataBuffer = metadataMemory.getBuffer(); 123 | metadataBuffer.putInt(workerAddresses.capacity()); 124 | metadataBuffer.put(workerAddresses); 125 | try { 126 | SerializableBlockManagerID.serializeBlockManagerID(blockManagerId, metadataBuffer); 127 | } catch (IOException e) { 128 | String errorMsg = String.format("Failed to serialize %s: %s", blockManagerId, 129 | e.getMessage()); 130 | throw new UcxException(errorMsg); 131 | } 132 | metadataBuffer.clear(); 133 | return metadataMemory; 134 | } 135 | 136 | private void startExecutor(InetSocketAddress driverAddress) { 137 | // 1. Executor: connect to driver using sockaddr 138 | // and send it's worker address followed by BlockManagerID. 
139 | globalDriverEndpoint = globalWorker.newEndpoint( 140 | new UcpEndpointParams().setSocketAddress(driverAddress).setPeerErrorHandlingMode() 141 | ); 142 | 143 | RegisteredMemory metadataMemory = buildMetadataBuffer(); 144 | // TODO: send using stream API when it would be available in jucx. 145 | globalDriverEndpoint.sendTaggedNonBlocking(metadataMemory.getBuffer(), new UcxCallback() { 146 | @Override 147 | public void onSuccess(UcpRequest request) { 148 | memoryPool.put(metadataMemory); 149 | } 150 | }); 151 | } 152 | 153 | public UcxShuffleConf getConf() { 154 | return conf; 155 | } 156 | 157 | public UcpWorker getGlobalWorker() { 158 | return globalWorker; 159 | } 160 | 161 | public MemoryPool getMemoryPool() { 162 | return memoryPool; 163 | } 164 | 165 | public UcpContext getContext() { 166 | return context; 167 | } 168 | 169 | /** 170 | * Get or initialize worker for current thread 171 | */ 172 | public UcxWorkerWrapper getThreadLocalWorker() { 173 | return threadLocalWorker.get(); 174 | } 175 | 176 | public static ConcurrentMap getWorkerAddresses() { 177 | return workerAdresses; 178 | } 179 | 180 | public static ConcurrentMap getRpcConnections() { 181 | return rpcConnections; 182 | } 183 | 184 | private void stopDriver() { 185 | if (listener != null) { 186 | listener.close(); 187 | listener = null; 188 | } 189 | rpcConnections.keySet().forEach(UcpEndpoint::close); 190 | rpcConnections.clear(); 191 | backwardEndpoints.forEach(UcpEndpoint::close); 192 | backwardEndpoints.clear(); 193 | } 194 | 195 | private void stopExecutor() { 196 | if (globalDriverEndpoint != null) { 197 | globalDriverEndpoint.close(); 198 | globalDriverEndpoint = null; 199 | } 200 | allocatedWorkers.forEach(UcxWorkerWrapper::close); 201 | allocatedWorkers.clear(); 202 | } 203 | 204 | @Override 205 | public void close() { 206 | threadLocalWorker.remove(); 207 | synchronized (this) { 208 | if (!closed) { 209 | logger.info("Stopping UcxNode"); 210 | listenerProgressThread.interrupt(); 211 | globalWorker.signal(); 212 | try { 213 | listenerProgressThread.join(); 214 | if (isDriver) { 215 | stopDriver(); 216 | } else { 217 | stopExecutor(); 218 | } 219 | memoryPool.close(); 220 | globalWorker.close(); 221 | context.close(); 222 | closed = true; 223 | } catch (InterruptedException e) { 224 | logger.error(e.getMessage()); 225 | Thread.currentThread().interrupt(); 226 | } catch (Exception ex) { 227 | logger.warn(ex.getLocalizedMessage()); 228 | } 229 | } 230 | } 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/UnsafeUtils.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx; 6 | 7 | import org.openucx.jucx.UcxException; 8 | import org.slf4j.Logger; 9 | import org.slf4j.LoggerFactory; 10 | import sun.nio.ch.FileChannelImpl; 11 | 12 | import java.io.IOException; 13 | import java.lang.reflect.Constructor; 14 | import java.lang.reflect.InvocationTargetException; 15 | import java.lang.reflect.Method; 16 | import java.nio.ByteBuffer; 17 | import java.nio.channels.FileChannel; 18 | 19 | /** 20 | * Java's native mmap functionality, that allows to mmap files > 2GB. 
21 | */ 22 | public class UnsafeUtils { 23 | private static final Method mmap; 24 | private static final Method unmmap; 25 | private static final Logger logger = LoggerFactory.getLogger(UnsafeUtils.class); 26 | 27 | private static final Constructor directBufferConstructor; 28 | 29 | public static final int LONG_SIZE = 8; 30 | public static final int INT_SIZE = 4; 31 | 32 | static { 33 | try { 34 | mmap = FileChannelImpl.class.getDeclaredMethod("map0", int.class, long.class, long.class); 35 | mmap.setAccessible(true); 36 | unmmap = FileChannelImpl.class.getDeclaredMethod("unmap0", long.class, long.class); 37 | unmmap.setAccessible(true); 38 | Class classDirectByteBuffer = Class.forName("java.nio.DirectByteBuffer"); 39 | directBufferConstructor = classDirectByteBuffer.getDeclaredConstructor(long.class, int.class); 40 | directBufferConstructor.setAccessible(true); 41 | } catch (Exception e) { 42 | throw new RuntimeException(e); 43 | } 44 | } 45 | 46 | private UnsafeUtils() {} 47 | 48 | public static long mmap(FileChannel fileChannel, long offset, long length) { 49 | long result; 50 | try { 51 | result = (long)mmap.invoke(fileChannel, 1, offset, length); 52 | } catch (Exception e) { 53 | logger.error("MMap({}, {}) failed: {}", offset, length, e.getMessage()); 54 | throw new UcxException(e.getMessage()); 55 | } 56 | return result; 57 | } 58 | 59 | public static void munmap(long address, long length) { 60 | try { 61 | unmmap.invoke(null, address, length); 62 | } catch (IllegalAccessException | InvocationTargetException e) { 63 | logger.error(e.getMessage()); 64 | } 65 | } 66 | 67 | public static ByteBuffer getByteBuffer(long address, int length) throws IOException { 68 | try { 69 | return (ByteBuffer)directBufferConstructor.newInstance(address, length); 70 | } catch (InvocationTargetException ex) { 71 | throw new IOException("java.nio.DirectByteBuffer: " + 72 | "InvocationTargetException: " + ex.getTargetException()); 73 | } catch (Exception e) { 74 | throw new IOException("java.nio.DirectByteBuffer exception: " + e.getMessage()); 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/memory/MemoryPool.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx.memory; 6 | 7 | import org.apache.spark.shuffle.UcxShuffleConf; 8 | import org.apache.spark.shuffle.ucx.UnsafeUtils; 9 | import org.openucx.jucx.UcxException; 10 | import org.openucx.jucx.UcxUtils; 11 | import org.openucx.jucx.ucp.UcpContext; 12 | import org.openucx.jucx.ucp.UcpMemMapParams; 13 | import org.openucx.jucx.ucp.UcpMemory; 14 | import org.slf4j.Logger; 15 | import org.slf4j.LoggerFactory; 16 | 17 | import java.io.Closeable; 18 | import java.nio.ByteBuffer; 19 | import java.util.concurrent.ConcurrentHashMap; 20 | import java.util.concurrent.ConcurrentLinkedDeque; 21 | import java.util.concurrent.atomic.AtomicInteger; 22 | 23 | /** 24 | * Utility class to reuse and preallocate registered memory to avoid memory allocation 25 | * and registration during shuffle phase. 
26 | */ 27 | public class MemoryPool implements Closeable { 28 | private static final Logger logger = LoggerFactory.getLogger(MemoryPool.class); 29 | 30 | @Override 31 | public void close() { 32 | for (AllocatorStack stack: allocStackMap.values()) { 33 | stack.close(); 34 | logger.info("Stack of size {}. " + 35 | "Total requests: {}, total allocations: {}, preAllocations: {}", 36 | stack.length, stack.totalRequests.get(), stack.totalAlloc.get(), stack.preAllocs.get()); 37 | } 38 | allocStackMap.clear(); 39 | } 40 | 41 | private class AllocatorStack implements Closeable { 42 | private final AtomicInteger totalRequests = new AtomicInteger(0); 43 | private final AtomicInteger totalAlloc = new AtomicInteger(0); 44 | private final AtomicInteger preAllocs = new AtomicInteger(0); 45 | private final ConcurrentLinkedDeque stack = new ConcurrentLinkedDeque<>(); 46 | private final int length; 47 | 48 | private AllocatorStack(int length) { 49 | this.length = length; 50 | } 51 | 52 | private RegisteredMemory get() { 53 | RegisteredMemory result = stack.pollFirst(); 54 | if (result == null) { 55 | if (length < conf.minRegistrationSize()) { 56 | int numBuffers = conf.minRegistrationSize() / length; 57 | logger.debug("Allocating {} buffers of size {}", numBuffers, length); 58 | preallocate(numBuffers); 59 | result = stack.pollFirst(); 60 | if (result == null) { 61 | return get(); 62 | } else { 63 | result.getRefCount().incrementAndGet(); 64 | } 65 | } else { 66 | UcpMemMapParams memMapParams = new UcpMemMapParams().setLength(length).allocate(); 67 | UcpMemory memory = context.memoryMap(memMapParams); 68 | ByteBuffer buffer; 69 | try { 70 | buffer = UcxUtils.getByteBufferView(memory.getAddress(), (int)memory.getLength()); 71 | } catch (Exception e) { 72 | throw new UcxException(e.getMessage()); 73 | } 74 | result = new RegisteredMemory(new AtomicInteger(1), memory, buffer); 75 | totalAlloc.incrementAndGet(); 76 | } 77 | } else { 78 | result.getRefCount().incrementAndGet(); 79 | } 80 | totalRequests.incrementAndGet(); 81 | return result; 82 | } 83 | 84 | private void put(RegisteredMemory registeredMemory) { 85 | registeredMemory.getRefCount().decrementAndGet(); 86 | stack.addLast(registeredMemory); 87 | } 88 | 89 | private void preallocate(int numBuffers) { 90 | // Platform.allocateDirectBuffer supports only 2GB of buffer. 91 | // Decrease number of buffers if total size of preAllocation > 2GB. 
92 | if ((long)length * (long)numBuffers > Integer.MAX_VALUE) { 93 | numBuffers = Integer.MAX_VALUE / length; 94 | } 95 | 96 | UcpMemMapParams memMapParams = new UcpMemMapParams().allocate().setLength(numBuffers * (long)length); 97 | UcpMemory memory = context.memoryMap(memMapParams); 98 | ByteBuffer buffer; 99 | try { 100 | buffer = UnsafeUtils.getByteBuffer(memory.getAddress(), numBuffers * length); 101 | } catch (Exception ex) { 102 | throw new UcxException(ex.getMessage()); 103 | } 104 | 105 | AtomicInteger refCount = new AtomicInteger(numBuffers); 106 | for (int i = 0; i < numBuffers; i++) { 107 | buffer.position(i * length).limit(i * length + length); 108 | final ByteBuffer slice = buffer.slice(); 109 | RegisteredMemory registeredMemory = new RegisteredMemory(refCount, memory, slice); 110 | put(registeredMemory); 111 | } 112 | preAllocs.incrementAndGet(); 113 | totalAlloc.incrementAndGet(); 114 | } 115 | 116 | @Override 117 | public void close() { 118 | while (!stack.isEmpty()) { 119 | RegisteredMemory memory = stack.pollFirst(); 120 | if (memory != null) { 121 | memory.deregisterNativeMemory(); 122 | } 123 | } 124 | } 125 | } 126 | 127 | private final ConcurrentHashMap allocStackMap = 128 | new ConcurrentHashMap<>(); 129 | private final UcpContext context; 130 | private final UcxShuffleConf conf; 131 | 132 | public MemoryPool(UcpContext context, UcxShuffleConf conf) { 133 | this.context = context; 134 | this.conf = conf; 135 | } 136 | 137 | private long roundUpToTheNextPowerOf2(long length) { 138 | // Round up length to the nearest power of two, or the minimum block size 139 | if (length < conf.minBufferSize()) { 140 | length = conf.minBufferSize(); 141 | } else { 142 | length--; 143 | length |= length >> 1; 144 | length |= length >> 2; 145 | length |= length >> 4; 146 | length |= length >> 8; 147 | length |= length >> 16; 148 | length++; 149 | } 150 | return length; 151 | } 152 | 153 | public RegisteredMemory get(int size) { 154 | long roundedSize = roundUpToTheNextPowerOf2(size); 155 | assert roundedSize < Integer.MAX_VALUE && roundedSize > 0; 156 | AllocatorStack stack = 157 | allocStackMap.computeIfAbsent((int)roundedSize, AllocatorStack::new); 158 | RegisteredMemory result = stack.get(); 159 | result.getBuffer().position(0).limit(size); 160 | return result; 161 | } 162 | 163 | public void put(RegisteredMemory memory) { 164 | AllocatorStack allocatorStack = allocStackMap.get(memory.getBuffer().capacity()); 165 | if (allocatorStack != null) { 166 | allocatorStack.put(memory); 167 | } 168 | } 169 | 170 | public void preAlocate() { 171 | conf.preallocateBuffersMap().forEach((size, numBuffers) -> { 172 | logger.debug("Pre allocating {} buffers of size {}", numBuffers, size); 173 | AllocatorStack stack = new AllocatorStack(size); 174 | allocStackMap.put(size, stack); 175 | stack.preallocate(numBuffers); 176 | }); 177 | } 178 | 179 | } 180 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/memory/RegisteredMemory.java: -------------------------------------------------------------------------------- 1 | package org.apache.spark.shuffle.ucx.memory; 2 | 3 | import org.openucx.jucx.ucp.UcpMemory; 4 | import org.slf4j.Logger; 5 | import org.slf4j.LoggerFactory; 6 | 7 | import java.nio.ByteBuffer; 8 | import java.util.concurrent.atomic.AtomicInteger; 9 | 10 | /** 11 | * Structure to use 1 memory region for multiple ByteBuffers. 12 | * Keeps track on reference count to memory region. 
13 | */ 14 | public class RegisteredMemory { 15 | private static final Logger logger = LoggerFactory.getLogger(RegisteredMemory.class); 16 | 17 | private final AtomicInteger refcount; 18 | private final UcpMemory memory; 19 | private final ByteBuffer buffer; 20 | 21 | RegisteredMemory(AtomicInteger refcount, UcpMemory memory, ByteBuffer buffer) { 22 | this.refcount = refcount; 23 | this.memory = memory; 24 | this.buffer = buffer; 25 | } 26 | 27 | public ByteBuffer getBuffer() { 28 | return buffer; 29 | } 30 | 31 | AtomicInteger getRefCount() { 32 | return refcount; 33 | } 34 | 35 | void deregisterNativeMemory() { 36 | if (refcount.get() != 0) { 37 | logger.warn("De-registering memory of size {} that has active references.", buffer.capacity()); 38 | } 39 | if (memory != null && memory.getNativeId() != null) { 40 | memory.deregister(); 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/reducer/OnBlocksFetchCallback.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx.reducer; 6 | 7 | import org.apache.spark.network.buffer.ManagedBuffer; 8 | import org.apache.spark.network.buffer.NioManagedBuffer; 9 | 10 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory; 11 | import org.apache.spark.storage.BlockId; 12 | import org.apache.spark.util.Utils; 13 | import org.openucx.jucx.ucp.UcpRequest; 14 | 15 | import java.nio.ByteBuffer; 16 | import java.util.concurrent.atomic.AtomicInteger; 17 | 18 | /** 19 | * Final callback when all blocks fetched. 20 | * Notifies Spark's shuffleFetchIterator on block fetch completion. 21 | */ 22 | public class OnBlocksFetchCallback extends ReducerCallback { 23 | protected RegisteredMemory blocksMemory; 24 | protected int[] sizes; 25 | 26 | public OnBlocksFetchCallback(ReducerCallback callback, RegisteredMemory blocksMemory, int[] sizes) { 27 | super(callback); 28 | this.blocksMemory = blocksMemory; 29 | this.sizes = sizes; 30 | } 31 | 32 | @Override 33 | public void onSuccess(UcpRequest request) { 34 | int position = 0; 35 | AtomicInteger refCount = new AtomicInteger(blockIds.length); 36 | for (int i = 0; i < blockIds.length; i++) { 37 | BlockId block = blockIds[i]; 38 | // Blocks are fetched to contiguous buffer. 39 | // |----block1---||---block2---||---block3---| 40 | // Slice each block to avoid buffer copy. 41 | blocksMemory.getBuffer().position(position).limit(position + sizes[i]); 42 | ByteBuffer blockBuffer = blocksMemory.getBuffer().slice(); 43 | position += sizes[i]; 44 | // Pass block to Spark's ShuffleFetchIterator. 
45 | listener.onBlockFetchSuccess(block.name(), new NioManagedBuffer(blockBuffer) { 46 | @Override 47 | public ManagedBuffer release() { 48 | if (refCount.decrementAndGet() == 0) { 49 | mempool.put(blocksMemory); 50 | } 51 | return this; 52 | } 53 | }); 54 | } 55 | logger.info("Endpoint {} fetched {} blocks of total size {} in {}ms", endpoint.getNativeId(), blockIds.length, 56 | Utils.bytesToString(position), System.currentTimeMillis() - startTime); 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/reducer/ReducerCallback.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx.reducer; 6 | 7 | import org.apache.spark.SparkEnv; 8 | import org.apache.spark.network.shuffle.BlockFetchingListener; 9 | import org.apache.spark.shuffle.CommonUcxShuffleManager; 10 | import org.apache.spark.shuffle.ucx.memory.MemoryPool; 11 | import org.apache.spark.storage.BlockId; 12 | import org.openucx.jucx.UcxCallback; 13 | import org.openucx.jucx.ucp.UcpEndpoint; 14 | import org.slf4j.Logger; 15 | import org.slf4j.LoggerFactory; 16 | 17 | /** 18 | * Common data needed for offset fetch callback and subsequent block fetch callback. 19 | */ 20 | public abstract class ReducerCallback extends UcxCallback { 21 | protected MemoryPool mempool; 22 | protected BlockId[] blockIds; 23 | protected UcpEndpoint endpoint; 24 | protected BlockFetchingListener listener; 25 | protected static final Logger logger = LoggerFactory.getLogger(ReducerCallback.class); 26 | protected long startTime = System.currentTimeMillis(); 27 | 28 | public ReducerCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener) { 29 | this.mempool = ((CommonUcxShuffleManager)SparkEnv.get().shuffleManager()).ucxNode().getMemoryPool(); 30 | this.blockIds = blockIds; 31 | this.endpoint = endpoint; 32 | this.listener = listener; 33 | } 34 | 35 | public ReducerCallback(ReducerCallback callback) { 36 | this.blockIds = callback.blockIds; 37 | this.endpoint = callback.endpoint; 38 | this.listener = callback.listener; 39 | this.mempool = callback.mempool; 40 | this.startTime = callback.startTime; 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1; 6 | 7 | import org.apache.spark.network.shuffle.BlockFetchingListener; 8 | import org.apache.spark.shuffle.ucx.UnsafeUtils; 9 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory; 10 | import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback; 11 | import org.apache.spark.shuffle.ucx.reducer.ReducerCallback; 12 | import org.apache.spark.storage.ShuffleBlockId; 13 | import org.openucx.jucx.UcxUtils; 14 | import org.openucx.jucx.ucp.UcpEndpoint; 15 | import org.openucx.jucx.ucp.UcpRemoteKey; 16 | import org.openucx.jucx.ucp.UcpRequest; 17 | 18 | import java.nio.ByteBuffer; 19 | import java.util.Map; 20 | 21 | /** 22 | * Callback, called when got all offsets for blocks 23 | */ 24 | public class OnOffsetsFetchCallback extends ReducerCallback { 25 | private final RegisteredMemory offsetMemory; 26 | private final long[] dataAddresses; 27 | private Map dataRkeysCache; 28 | 29 | public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener, 30 | RegisteredMemory offsetMemory, long[] dataAddresses, 31 | Map dataRkeysCache) { 32 | super(blockIds, endpoint, listener); 33 | this.offsetMemory = offsetMemory; 34 | this.dataAddresses = dataAddresses; 35 | this.dataRkeysCache = dataRkeysCache; 36 | } 37 | 38 | @Override 39 | public void onSuccess(UcpRequest request) { 40 | ByteBuffer resultOffset = offsetMemory.getBuffer(); 41 | long totalSize = 0; 42 | int[] sizes = new int[blockIds.length]; 43 | int offsetSize = UnsafeUtils.LONG_SIZE; 44 | for (int i = 0; i < blockIds.length; i++) { 45 | // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd | 46 | long blockOffset = resultOffset.getLong(i * 2 * offsetSize); 47 | long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset; 48 | assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE); 49 | sizes[i] = (int) blockLength; 50 | totalSize += blockLength; 51 | dataAddresses[i] += blockOffset; 52 | } 53 | 54 | assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE); 55 | mempool.put(offsetMemory); 56 | RegisteredMemory blocksMemory = mempool.get((int) totalSize); 57 | 58 | long offset = 0; 59 | // Submits N fetch blocks requests 60 | for (int i = 0; i < blockIds.length; i++) { 61 | endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId)blockIds[i]).mapId()), 62 | UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]); 63 | offset += sizes[i]; 64 | } 65 | 66 | // Process blocks when all fetched. 67 | // Flush guarantees that callback would invoke when all fetch requests will completed. 68 | endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes)); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1; 6 | 7 | import org.apache.spark.SparkEnv; 8 | import org.apache.spark.executor.TempShuffleReadMetrics; 9 | import org.apache.spark.network.shuffle.BlockFetchingListener; 10 | import org.apache.spark.network.shuffle.ShuffleClient; 11 | import org.apache.spark.shuffle.DriverMetadata; 12 | import org.apache.spark.shuffle.UcxShuffleManager; 13 | import org.apache.spark.shuffle.UcxWorkerWrapper; 14 | import org.apache.spark.shuffle.ucx.UnsafeUtils; 15 | import org.apache.spark.shuffle.ucx.memory.MemoryPool; 16 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory; 17 | import org.apache.spark.storage.BlockId; 18 | import org.apache.spark.storage.BlockManagerId; 19 | import org.apache.spark.storage.ShuffleBlockId; 20 | import org.openucx.jucx.UcxUtils; 21 | import org.openucx.jucx.ucp.UcpEndpoint; 22 | import org.openucx.jucx.ucp.UcpRemoteKey; 23 | import org.slf4j.Logger; 24 | import org.slf4j.LoggerFactory; 25 | import scala.Option; 26 | 27 | import java.util.Arrays; 28 | import java.util.HashMap; 29 | 30 | public class UcxShuffleClient extends ShuffleClient { 31 | private final MemoryPool mempool; 32 | private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class); 33 | private final UcxShuffleManager ucxShuffleManager; 34 | private final TempShuffleReadMetrics shuffleReadMetrics; 35 | private final UcxWorkerWrapper workerWrapper; 36 | final HashMap offsetRkeysCache = new HashMap<>(); 37 | final HashMap dataRkeysCache = new HashMap<>(); 38 | 39 | public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics, 40 | UcxWorkerWrapper workerWrapper) { 41 | this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager(); 42 | this.mempool = ucxShuffleManager.ucxNode().getMemoryPool(); 43 | this.shuffleReadMetrics = shuffleReadMetrics; 44 | this.workerWrapper = workerWrapper; 45 | } 46 | 47 | /** 48 | * Submits n non blocking fetch offsets to get needed offsets for n blocks. 49 | */ 50 | private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds, 51 | long[] dataAddresses, RegisteredMemory offsetMemory) { 52 | DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId()); 53 | for (int i = 0; i < blockIds.length; i++) { 54 | ShuffleBlockId blockId = blockIds[i]; 55 | 56 | long offsetAddress = driverMetadata.offsetAddress(blockId.mapId()); 57 | dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId()); 58 | 59 | offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId -> 60 | endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId()))); 61 | 62 | dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId -> 63 | endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId()))); 64 | 65 | endpoint.getNonBlockingImplicit( 66 | offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE, 67 | offsetRkeysCache.get(blockId.mapId()), 68 | UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE), 69 | 2L * UnsafeUtils.LONG_SIZE); 70 | } 71 | } 72 | 73 | /** 74 | * Reducer entry point. Fetches remote blocks, using 2 ucp_get calls. 75 | * This method is inside ShuffleFetchIterator's for loop over hosts. 76 | * First fetches block offset from index file, and then fetches block itself. 
77 | */ 78 | @Override 79 | public void fetchBlocks(String host, int port, String execId, 80 | String[] blockIds, BlockFetchingListener listener) { 81 | long startTime = System.currentTimeMillis(); 82 | 83 | BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty()); 84 | UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId); 85 | 86 | long[] dataAddresses = new long[blockIds.length]; 87 | 88 | // Need to fetch 2 long offsets current block + next block to calculate exact block size. 89 | RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length); 90 | 91 | ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds) 92 | .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new); 93 | 94 | // Submits N implicit get requests without callback 95 | submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory); 96 | 97 | // flush guarantees that all that requests completes when callback is called. 98 | // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush. 99 | workerWrapper.worker().flushNonBlocking( 100 | new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory, 101 | dataAddresses, dataRkeysCache)); 102 | shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime); 103 | } 104 | 105 | @Override 106 | public void close() { 107 | offsetRkeysCache.values().forEach(UcpRemoteKey::close); 108 | dataRkeysCache.values().forEach(UcpRemoteKey::close); 109 | logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime()); 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/OnOffsetsFetchCallback.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_4; 6 | 7 | import org.apache.spark.network.shuffle.BlockFetchingListener; 8 | import org.apache.spark.shuffle.ucx.UnsafeUtils; 9 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory; 10 | import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback; 11 | import org.apache.spark.shuffle.ucx.reducer.ReducerCallback; 12 | import org.apache.spark.storage.ShuffleBlockId; 13 | import org.openucx.jucx.UcxUtils; 14 | import org.openucx.jucx.ucp.UcpEndpoint; 15 | import org.openucx.jucx.ucp.UcpRemoteKey; 16 | import org.openucx.jucx.ucp.UcpRequest; 17 | 18 | import java.nio.ByteBuffer; 19 | import java.util.Map; 20 | 21 | /** 22 | * Callback, called when got all offsets for blocks 23 | */ 24 | public class OnOffsetsFetchCallback extends ReducerCallback { 25 | private final RegisteredMemory offsetMemory; 26 | private final long[] dataAddresses; 27 | private Map dataRkeysCache; 28 | 29 | public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener, 30 | RegisteredMemory offsetMemory, long[] dataAddresses, 31 | Map dataRkeysCache) { 32 | super(blockIds, endpoint, listener); 33 | this.offsetMemory = offsetMemory; 34 | this.dataAddresses = dataAddresses; 35 | this.dataRkeysCache = dataRkeysCache; 36 | } 37 | 38 | @Override 39 | public void onSuccess(UcpRequest request) { 40 | ByteBuffer resultOffset = offsetMemory.getBuffer(); 41 | long totalSize = 0; 42 | int[] sizes = new int[blockIds.length]; 43 | int offsetSize = UnsafeUtils.LONG_SIZE; 44 | for (int i = 0; i < blockIds.length; i++) { 45 | // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd | 46 | long blockOffset = resultOffset.getLong(i * 2 * offsetSize); 47 | long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset; 48 | assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE); 49 | sizes[i] = (int) blockLength; 50 | totalSize += blockLength; 51 | dataAddresses[i] += blockOffset; 52 | } 53 | 54 | assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE); 55 | mempool.put(offsetMemory); 56 | RegisteredMemory blocksMemory = mempool.get((int) totalSize); 57 | 58 | long offset = 0; 59 | // Submits N fetch blocks requests 60 | for (int i = 0; i < blockIds.length; i++) { 61 | endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId)blockIds[i]).mapId()), 62 | UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]); 63 | offset += sizes[i]; 64 | } 65 | 66 | // Process blocks when all fetched. 67 | // Flush guarantees that callback would invoke when all fetch requests will completed. 68 | endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes)); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/UcxShuffleClient.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_4; 6 | 7 | import org.apache.spark.SparkEnv; 8 | import org.apache.spark.executor.TempShuffleReadMetrics; 9 | import org.apache.spark.network.shuffle.BlockFetchingListener; 10 | import org.apache.spark.network.shuffle.DownloadFileManager; 11 | import org.apache.spark.network.shuffle.ShuffleClient; 12 | import org.apache.spark.shuffle.*; 13 | import org.apache.spark.shuffle.ucx.UnsafeUtils; 14 | import org.apache.spark.shuffle.ucx.memory.MemoryPool; 15 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory; 16 | import org.apache.spark.storage.BlockId; 17 | import org.apache.spark.storage.BlockManagerId; 18 | import org.apache.spark.storage.ShuffleBlockId; 19 | import org.openucx.jucx.UcxUtils; 20 | import org.slf4j.Logger; 21 | import org.slf4j.LoggerFactory; 22 | import org.openucx.jucx.ucp.UcpEndpoint; 23 | import org.openucx.jucx.ucp.UcpRemoteKey; 24 | import scala.Option; 25 | 26 | import java.util.Arrays; 27 | import java.util.HashMap; 28 | 29 | public class UcxShuffleClient extends ShuffleClient { 30 | private final MemoryPool mempool; 31 | private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class); 32 | private final UcxShuffleManager ucxShuffleManager; 33 | private final TempShuffleReadMetrics shuffleReadMetrics; 34 | private final UcxWorkerWrapper workerWrapper; 35 | final HashMap offsetRkeysCache = new HashMap<>(); 36 | final HashMap dataRkeysCache = new HashMap<>(); 37 | 38 | public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics, 39 | UcxWorkerWrapper workerWrapper) { 40 | this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager(); 41 | this.mempool = ucxShuffleManager.ucxNode().getMemoryPool(); 42 | this.shuffleReadMetrics = shuffleReadMetrics; 43 | this.workerWrapper = workerWrapper; 44 | } 45 | 46 | /** 47 | * Submits n non blocking fetch offsets to get needed offsets for n blocks. 48 | */ 49 | private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds, 50 | long[] dataAddresses, RegisteredMemory offsetMemory) { 51 | DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId()); 52 | for (int i = 0; i < blockIds.length; i++) { 53 | ShuffleBlockId blockId = blockIds[i]; 54 | 55 | long offsetAddress = driverMetadata.offsetAddress(blockId.mapId()); 56 | dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId()); 57 | 58 | offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId -> 59 | endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId()))); 60 | 61 | dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId -> 62 | endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId()))); 63 | 64 | endpoint.getNonBlockingImplicit( 65 | offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE, 66 | offsetRkeysCache.get(blockId.mapId()), 67 | UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE), 68 | 2L * UnsafeUtils.LONG_SIZE); 69 | } 70 | } 71 | 72 | /** 73 | * Reducer entry point. Fetches remote blocks, using 2 ucp_get calls. 74 | * This method is inside ShuffleFetchIterator's for loop over hosts. 75 | * First fetches block offset from index file, and then fetches block itself. 
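 * <p>
 * Rough shape of the flow, as a sketch of what the code below and OnOffsetsFetchCallback do
 * (not an additional API):
 * <pre>
 *   get(index[reduceId], index[reduceId + 1])   // per block, into offsetMemory
 *   worker flush  -> OnOffsetsFetchCallback     // block size = end offset - start offset
 *   get(data + start, size)                     // per block, into one pooled buffer
 *   endpoint flush -> OnBlocksFetchCallback     // hand fetched blocks to the listener
 * </pre>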
76 | */ 77 | @Override 78 | public void fetchBlocks(String host, int port, String execId, 79 | String[] blockIds, BlockFetchingListener listener, 80 | DownloadFileManager downloadFileManager) { 81 | long startTime = System.currentTimeMillis(); 82 | 83 | BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty()); 84 | UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId); 85 | 86 | long[] dataAddresses = new long[blockIds.length]; 87 | 88 | // Need to fetch 2 long offsets current block + next block to calculate exact block size. 89 | RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length); 90 | 91 | ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds) 92 | .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new); 93 | 94 | // Submits N implicit get requests without callback 95 | submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory); 96 | 97 | // flush guarantees that all that requests completes when callback is called. 98 | // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush. 99 | workerWrapper.worker().flushNonBlocking( 100 | new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory, 101 | dataAddresses, dataRkeysCache)); 102 | shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime); 103 | } 104 | 105 | @Override 106 | public void close() { 107 | offsetRkeysCache.values().forEach(UcpRemoteKey::close); 108 | dataRkeysCache.values().forEach(UcpRemoteKey::close); 109 | logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime()); 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/OnOffsetsFetchCallback.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_0; 6 | 7 | import org.apache.spark.network.shuffle.BlockFetchingListener; 8 | import org.apache.spark.shuffle.UcxWorkerWrapper; 9 | import org.apache.spark.shuffle.ucx.UnsafeUtils; 10 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory; 11 | import org.apache.spark.shuffle.ucx.reducer.ReducerCallback; 12 | import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback; 13 | import org.apache.spark.storage.BlockId; 14 | import org.apache.spark.storage.ShuffleBlockBatchId; 15 | import org.apache.spark.storage.ShuffleBlockId; 16 | import org.openucx.jucx.UcxUtils; 17 | import org.openucx.jucx.ucp.UcpEndpoint; 18 | import org.openucx.jucx.ucp.UcpRemoteKey; 19 | import org.openucx.jucx.ucp.UcpRequest; 20 | 21 | import java.nio.ByteBuffer; 22 | import java.util.Map; 23 | 24 | /** 25 | * Callback, called when got all offsets for blocks 26 | */ 27 | public class OnOffsetsFetchCallback extends ReducerCallback { 28 | private final RegisteredMemory offsetMemory; 29 | private final long[] dataAddresses; 30 | private Map dataRkeysCache; 31 | private final Map mapId2PartitionId; 32 | 33 | public OnOffsetsFetchCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener, 34 | RegisteredMemory offsetMemory, long[] dataAddresses, 35 | Map dataRkeysCache, 36 | Map mapId2PartitionId) { 37 | super(blockIds, endpoint, listener); 38 | this.offsetMemory = offsetMemory; 39 | this.dataAddresses = dataAddresses; 40 | this.dataRkeysCache = dataRkeysCache; 41 | this.mapId2PartitionId = mapId2PartitionId; 42 | } 43 | 44 | @Override 45 | public void onSuccess(UcpRequest request) { 46 | ByteBuffer resultOffset = offsetMemory.getBuffer(); 47 | long totalSize = 0; 48 | int[] sizes = new int[blockIds.length]; 49 | int offset = 0; 50 | long blockOffset; 51 | long blockLength; 52 | int offsetSize = UnsafeUtils.LONG_SIZE; 53 | for (int i = 0; i < blockIds.length; i++) { 54 | // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd | 55 | if (blockIds[i] instanceof ShuffleBlockBatchId) { 56 | ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blockIds[i]; 57 | int blocksInBatch = blockBatchId.endReduceId() - blockBatchId.startReduceId(); 58 | blockOffset = resultOffset.getLong(offset * 2 * offsetSize); 59 | blockLength = resultOffset.getLong(offset * 2 * offsetSize + offsetSize * blocksInBatch) 60 | - blockOffset; 61 | offset += blocksInBatch; 62 | } else { 63 | blockOffset = resultOffset.getLong(offset * 16); 64 | blockLength = resultOffset.getLong(offset * 16 + 8) - blockOffset; 65 | offset++; 66 | } 67 | 68 | assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE); 69 | sizes[i] = (int) blockLength; 70 | totalSize += blockLength; 71 | dataAddresses[i] += blockOffset; 72 | } 73 | 74 | assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE); 75 | mempool.put(offsetMemory); 76 | RegisteredMemory blocksMemory = mempool.get((int) totalSize); 77 | 78 | offset = 0; 79 | // Submits N fetch blocks requests 80 | for (int i = 0; i < blockIds.length; i++) { 81 | int mapPartitionId = (blockIds[i] instanceof ShuffleBlockId) ? 82 | mapId2PartitionId.get(((ShuffleBlockId)blockIds[i]).mapId()) : 83 | mapId2PartitionId.get(((ShuffleBlockBatchId)blockIds[i]).mapId()); 84 | endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(mapPartitionId), 85 | UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]); 86 | offset += sizes[i]; 87 | } 88 | 89 | // Process blocks when all fetched. 
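        // The per-block requests above are "implicit" gets: they carry no individual callback,
        // so completion is observed once for the whole batch via the endpoint flush below.
        // (A per-request form would look roughly like
        //   endpoint.getNonBlocking(remoteAddr, rkey, localAddr, size, callback);
        // sketch only, at the cost of one callback per block instead of one flush per batch.)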
90 | // Flush guarantees that callback would invoke when all fetch requests will completed. 91 | endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes)); 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/UcxShuffleClient.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_0; 6 | 7 | import org.apache.spark.SparkEnv; 8 | import org.apache.spark.executor.TempShuffleReadMetrics; 9 | import org.apache.spark.network.shuffle.BlockFetchingListener; 10 | import org.apache.spark.network.shuffle.BlockStoreClient; 11 | import org.apache.spark.network.shuffle.DownloadFileManager; 12 | import org.apache.spark.shuffle.DriverMetadata; 13 | import org.apache.spark.shuffle.UcxShuffleManager; 14 | import org.apache.spark.shuffle.UcxWorkerWrapper; 15 | import org.apache.spark.shuffle.ucx.UnsafeUtils; 16 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory; 17 | import org.apache.spark.storage.*; 18 | import org.openucx.jucx.UcxUtils; 19 | import org.openucx.jucx.ucp.UcpEndpoint; 20 | import org.openucx.jucx.ucp.UcpRemoteKey; 21 | import org.slf4j.Logger; 22 | import org.slf4j.LoggerFactory; 23 | import scala.Option; 24 | 25 | 26 | import java.util.HashMap; 27 | import java.util.Map; 28 | 29 | public class UcxShuffleClient extends BlockStoreClient { 30 | private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class); 31 | private final UcxWorkerWrapper workerWrapper; 32 | private final Map mapId2PartitionId; 33 | private final TempShuffleReadMetrics shuffleReadMetrics; 34 | private final int shuffleId; 35 | final HashMap offsetRkeysCache = new HashMap<>(); 36 | final HashMap dataRkeysCache = new HashMap<>(); 37 | 38 | 39 | public UcxShuffleClient(int shuffleId, UcxWorkerWrapper workerWrapper, 40 | Map mapId2PartitionId, TempShuffleReadMetrics shuffleReadMetrics) { 41 | this.workerWrapper = workerWrapper; 42 | this.shuffleId = shuffleId; 43 | this.mapId2PartitionId = mapId2PartitionId; 44 | this.shuffleReadMetrics = shuffleReadMetrics; 45 | } 46 | 47 | /** 48 | * Submits n non blocking fetch offsets to get needed offsets for n blocks. 
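 * <p>
 * For a plain ShuffleBlockId this reads two longs (block start and next-block start) from the
 * remote index file; for a ShuffleBlockBatchId it reads two longs per reduce partition in the
 * batch, so that OnOffsetsFetchCallback can recover each block's length as end - start.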
49 | */ 50 | private void submitFetchOffsets(UcpEndpoint endpoint, BlockId[] blockIds, 51 | RegisteredMemory offsetMemory, 52 | long[] dataAddresses) { 53 | DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(shuffleId); 54 | long offset = 0; 55 | int startReduceId; 56 | long size; 57 | 58 | for (int i = 0; i < blockIds.length; i++) { 59 | BlockId blockId = blockIds[i]; 60 | int mapIdpartition; 61 | 62 | if (blockId instanceof ShuffleBlockId) { 63 | ShuffleBlockId shuffleBlockId = (ShuffleBlockId) blockId; 64 | mapIdpartition = mapId2PartitionId.get(shuffleBlockId.mapId()); 65 | size = 2L * UnsafeUtils.LONG_SIZE; 66 | startReduceId = shuffleBlockId.reduceId(); 67 | } else { 68 | ShuffleBlockBatchId shuffleBlockBatchId = (ShuffleBlockBatchId) blockId; 69 | mapIdpartition = mapId2PartitionId.get(shuffleBlockBatchId.mapId()); 70 | size = (shuffleBlockBatchId.endReduceId() - shuffleBlockBatchId.startReduceId()) 71 | * 2L * UnsafeUtils.LONG_SIZE; 72 | startReduceId = shuffleBlockBatchId.startReduceId(); 73 | } 74 | 75 | long offsetAddress = driverMetadata.offsetAddress(mapIdpartition); 76 | dataAddresses[i] = driverMetadata.dataAddress(mapIdpartition); 77 | 78 | offsetRkeysCache.computeIfAbsent(mapIdpartition, mapId -> 79 | endpoint.unpackRemoteKey(driverMetadata.offsetRkey(mapIdpartition))); 80 | 81 | dataRkeysCache.computeIfAbsent(mapIdpartition, mapId -> 82 | endpoint.unpackRemoteKey(driverMetadata.dataRkey(mapIdpartition))); 83 | 84 | endpoint.getNonBlockingImplicit( 85 | offsetAddress + startReduceId * UnsafeUtils.LONG_SIZE, 86 | offsetRkeysCache.get(mapIdpartition), 87 | UcxUtils.getAddress(offsetMemory.getBuffer()) + offset, 88 | size); 89 | 90 | offset += size; 91 | } 92 | } 93 | 94 | @Override 95 | public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener, 96 | DownloadFileManager downloadFileManager) { 97 | long startTime = System.currentTimeMillis(); 98 | BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty()); 99 | UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId); 100 | long[] dataAddresses = new long[blockIds.length]; 101 | int totalBlocks = 0; 102 | 103 | BlockId[] blocks = new BlockId[blockIds.length]; 104 | 105 | for (int i = 0; i < blockIds.length; i++) { 106 | blocks[i] = BlockId.apply(blockIds[i]); 107 | if (blocks[i] instanceof ShuffleBlockId) { 108 | totalBlocks += 1; 109 | } else { 110 | ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId)blocks[i]; 111 | totalBlocks += (blockBatchId.endReduceId() - blockBatchId.startReduceId()); 112 | } 113 | } 114 | 115 | RegisteredMemory offsetMemory = ((UcxShuffleManager)SparkEnv.get().shuffleManager()) 116 | .ucxNode().getMemoryPool().get(totalBlocks * 2 * UnsafeUtils.LONG_SIZE); 117 | // Submits N implicit get requests without callback 118 | submitFetchOffsets(endpoint, blocks, offsetMemory, dataAddresses); 119 | 120 | // flush guarantees that all that requests completes when callback is called. 121 | // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush. 
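        // The worker-level flush below completes when every outstanding operation on this worker
        // is done, which is broader than strictly needed. Once the referenced UCX issue is fixed,
        // the intent (per the TODO) is the narrower endpoint-scoped form, roughly:
        //   endpoint.flushNonBlocking(new OnOffsetsFetchCallback(...));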
122 | workerWrapper.worker().flushNonBlocking( 123 | new OnOffsetsFetchCallback(blocks, endpoint, listener, offsetMemory, 124 | dataAddresses, dataRkeysCache, mapId2PartitionId)); 125 | 126 | shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime); 127 | } 128 | 129 | @Override 130 | public void close() { 131 | offsetRkeysCache.values().forEach(UcpRemoteKey::close); 132 | dataRkeysCache.values().forEach(UcpRemoteKey::close); 133 | logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime()); 134 | } 135 | 136 | } 137 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/rpc/RpcConnectionCallback.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx.rpc; 6 | 7 | import org.apache.spark.shuffle.ucx.UcxNode; 8 | import org.apache.spark.storage.BlockManagerId; 9 | import org.apache.spark.unsafe.Platform; 10 | import org.openucx.jucx.UcxCallback; 11 | import org.openucx.jucx.UcxException; 12 | import org.openucx.jucx.ucp.UcpEndpoint; 13 | import org.openucx.jucx.ucp.UcpEndpointParams; 14 | import org.openucx.jucx.ucp.UcpRequest; 15 | import org.openucx.jucx.ucp.UcpWorker; 16 | import org.slf4j.Logger; 17 | import org.slf4j.LoggerFactory; 18 | 19 | import java.io.IOException; 20 | import java.nio.ByteBuffer; 21 | import java.util.concurrent.ConcurrentMap; 22 | 23 | /** 24 | * RPC processing logic. Both the driver and the executors accept the same RPC message: 25 | * an executor worker address followed by its serialized BlockManagerID. 26 | * An executor accepting this message just adds the workerAddress to the connection map. 27 | * The driver additionally introduces the connected executor to the cluster nodes and 28 | * introduces the cluster to the connected executor. 29 | */ 30 | public class RpcConnectionCallback extends UcxCallback { 31 | private static final Logger logger = LoggerFactory.getLogger(RpcConnectionCallback.class); 32 | private final ByteBuffer metadataBuffer; 33 | private final boolean isDriver; 34 | private final UcxNode ucxNode; 35 | private static final ConcurrentMap<UcpEndpoint, ByteBuffer> rpcConnections = 36 | UcxNode.getRpcConnections(); 37 | private static final ConcurrentMap<BlockManagerId, ByteBuffer> workerAdresses = 38 | UcxNode.getWorkerAddresses(); 39 | 40 | RpcConnectionCallback(ByteBuffer metadataBuffer, boolean isDriver, UcxNode ucxNode) { 41 | this.metadataBuffer = metadataBuffer; 42 | this.isDriver = isDriver; 43 | this.ucxNode = ucxNode; 44 | } 45 | 46 | @Override 47 | public void onSuccess(UcpRequest request) { 48 | int workerAddressSize = metadataBuffer.getInt(); 49 | ByteBuffer workerAddress = Platform.allocateDirectBuffer(workerAddressSize); 50 | 51 | // Copy worker address from metadata buffer to separate buffer.
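        // Message layout on the wire (per the class javadoc):
        //   | int workerAddressSize | workerAddress bytes | serialized BlockManagerId |
        // The duplicate()/limit() below carve out just the workerAddress bytes, then the buffer
        // position is advanced so deserialization of the BlockManagerId can follow.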
52 | final ByteBuffer metadataView = metadataBuffer.duplicate(); 53 | metadataView.limit(metadataView.position() + workerAddressSize); 54 | workerAddress.put(metadataView); 55 | metadataBuffer.position(metadataBuffer.position() + workerAddressSize); 56 | 57 | BlockManagerId blockManagerId; 58 | try { 59 | blockManagerId = SerializableBlockManagerID 60 | .deserializeBlockManagerID(metadataBuffer); 61 | } catch (IOException e) { 62 | String errorMsg = String.format("Failed to deserialize BlockManagerId: %s", e.getMessage()); 63 | throw new UcxException(errorMsg); 64 | } 65 | logger.debug("Received RPC message from {}", blockManagerId); 66 | UcpWorker globalWorker = ucxNode.getGlobalWorker(); 67 | 68 | workerAddress.clear(); 69 | 70 | if (isDriver) { 71 | metadataBuffer.clear(); 72 | UcpEndpoint newConnection = globalWorker.newEndpoint( 73 | new UcpEndpointParams().setPeerErrorHandlingMode() 74 | .setUcpAddress(workerAddress)); 75 | // For each existing connection 76 | rpcConnections.forEach((connection, connectionMetadata) -> { 77 | // send address of joined worker to already connected workers 78 | connection.sendTaggedNonBlocking(metadataBuffer, null); 79 | // introduce other workers to joined worker 80 | newConnection.sendTaggedNonBlocking(connectionMetadata, null); 81 | }); 82 | 83 | rpcConnections.put(newConnection, metadataBuffer); 84 | } 85 | workerAdresses.put(blockManagerId, workerAddress); 86 | synchronized (workerAdresses) { 87 | workerAdresses.notifyAll(); 88 | } 89 | } 90 | 91 | @Override 92 | public void onError(int ucsStatus, String errorMsg) { 93 | // UCS_ERR_CANCELED = -16, 94 | if (ucsStatus != -16) { 95 | logger.error("Request error: {}", errorMsg); 96 | throw new UcxException(errorMsg); 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/rpc/SerializableBlockManagerID.java: -------------------------------------------------------------------------------- 1 | package org.apache.spark.shuffle.ucx.rpc; 2 | 3 | import com.esotericsoftware.kryo.io.ByteBufferInputStream; 4 | import com.esotericsoftware.kryo.io.ByteBufferOutputStream; 5 | import org.apache.spark.storage.BlockManagerId; 6 | 7 | import java.io.IOException; 8 | import java.io.ObjectInputStream; 9 | import java.io.ObjectOutputStream; 10 | import java.nio.ByteBuffer; 11 | 12 | /** 13 | * Static mthods to serialize BlockManagerID to ByteBuffer. 14 | */ 15 | public class SerializableBlockManagerID { 16 | 17 | public static void serializeBlockManagerID(BlockManagerId blockManagerId, 18 | ByteBuffer metadataBuffer) throws IOException { 19 | ObjectOutputStream oos = new ObjectOutputStream( 20 | new ByteBufferOutputStream(metadataBuffer)); 21 | blockManagerId.writeExternal(oos); 22 | oos.close(); 23 | } 24 | 25 | static BlockManagerId deserializeBlockManagerID(ByteBuffer metadataBuffer) throws IOException { 26 | ObjectInputStream ois = 27 | new ObjectInputStream(new ByteBufferInputStream(metadataBuffer)); 28 | BlockManagerId blockManagerId = BlockManagerId.apply(ois); 29 | ois.close(); 30 | return blockManagerId; 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxListenerThread.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.ucx.rpc; 6 | 7 | import org.apache.spark.shuffle.ucx.UcxNode; 8 | import org.apache.spark.unsafe.Platform; 9 | import org.openucx.jucx.ucp.UcpRequest; 10 | import org.openucx.jucx.ucp.UcpWorker; 11 | import org.slf4j.Logger; 12 | import org.slf4j.LoggerFactory; 13 | 14 | import java.nio.ByteBuffer; 15 | 16 | /** 17 | * Thread for progressing global worker for connection establishment and RPC exchange. 18 | */ 19 | public class UcxListenerThread extends Thread implements Runnable { 20 | private static final Logger logger = LoggerFactory.getLogger(UcxListenerThread.class); 21 | private final UcxNode ucxNode; 22 | private final boolean isDriver; 23 | private final UcpWorker globalWorker; 24 | 25 | public UcxListenerThread(UcxNode ucxNode, boolean isDriver) { 26 | this.ucxNode = ucxNode; 27 | this.isDriver = isDriver; 28 | this.globalWorker = ucxNode.getGlobalWorker(); 29 | setDaemon(true); 30 | setName("UcxListenerThread"); 31 | } 32 | 33 | /** 34 | * 2. Both Driver and Executor. Accept Recv request. 35 | * If on driver broadcast it to other executors. On executor just save worker addresses. 36 | */ 37 | private UcpRequest recvRequest() { 38 | ByteBuffer metadataBuffer = Platform.allocateDirectBuffer( 39 | ucxNode.getConf().metadataRPCBufferSize()); 40 | RpcConnectionCallback callback = new RpcConnectionCallback(metadataBuffer, isDriver, ucxNode); 41 | return globalWorker.recvTaggedNonBlocking(metadataBuffer, callback); 42 | } 43 | 44 | @Override 45 | public void run() { 46 | UcpRequest recv = recvRequest(); 47 | while (!isInterrupted()) { 48 | if (recv.isCompleted()) { 49 | // Process 1 recv request at a time. 50 | recv = recvRequest(); 51 | } 52 | try { 53 | if (globalWorker.progress() == 0) { 54 | globalWorker.waitForEvents(); 55 | } 56 | } catch (Exception e) { 57 | logger.error(e.getLocalizedMessage()); 58 | interrupt(); 59 | } 60 | } 61 | globalWorker.cancelRequest(recv); 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxRemoteMemory.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx.rpc; 6 | 7 | import java.io.IOException; 8 | import java.io.ObjectInputStream; 9 | import java.io.ObjectOutputStream; 10 | import java.io.Serializable; 11 | import java.nio.ByteBuffer; 12 | 13 | /** 14 | * Utility class to serialize / deserialize metadata buffer on a driver. 15 | * Needed to propagate metadata buffer information to executors using 16 | * spark's mechanism to broadcast tasks. 
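 * <p>
 * Serialized form (see writeObject / readObject below), as a sketch:
 *   | long address | int rkeyBufferSize | rkeyBuffer bytes |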
17 | */ 18 | public class UcxRemoteMemory implements Serializable { 19 | private long address; 20 | private ByteBuffer rkeyBuffer; 21 | 22 | public UcxRemoteMemory(long address, ByteBuffer rkeyBuffer) { 23 | this.address = address; 24 | this.rkeyBuffer = rkeyBuffer; 25 | } 26 | 27 | public UcxRemoteMemory() {} 28 | 29 | private void writeObject(ObjectOutputStream out) throws IOException { 30 | out.writeLong(address); 31 | out.writeInt(rkeyBuffer.limit()); 32 | byte[] copy = new byte[rkeyBuffer.limit()]; 33 | rkeyBuffer.clear(); 34 | rkeyBuffer.get(copy); 35 | out.write(copy); 36 | } 37 | 38 | private void readObject(ObjectInputStream in) throws IOException { 39 | this.address = in.readLong(); 40 | int bufferSize = in.readInt(); 41 | byte[] buffer = new byte[bufferSize]; 42 | in.read(buffer, 0, bufferSize); 43 | this.rkeyBuffer = ByteBuffer.allocateDirect(bufferSize).put(buffer); 44 | this.rkeyBuffer.clear(); 45 | } 46 | 47 | public long getAddress() { 48 | return address; 49 | } 50 | 51 | public ByteBuffer getRkeyBuffer() { 52 | return rkeyBuffer; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleBlockResolver.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle 6 | 7 | import java.io.{File, RandomAccessFile} 8 | import java.util.concurrent.{ConcurrentHashMap, CopyOnWriteArrayList} 9 | 10 | import scala.collection.JavaConverters._ 11 | 12 | import org.openucx.jucx.UcxUtils 13 | import org.openucx.jucx.ucp.{UcpMemMapParams, UcpMemory} 14 | import org.apache.spark.shuffle.ucx.UnsafeUtils 15 | import org.apache.spark.SparkException 16 | 17 | /** 18 | * Mapper entry point for UcxShuffle plugin. Performs memory registration 19 | * of data and index files and publish addresses to driver metadata buffer. 20 | */ 21 | abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) 22 | extends IndexShuffleBlockResolver(ucxShuffleManager.conf) { 23 | private lazy val memPool = ucxShuffleManager.ucxNode.getMemoryPool 24 | 25 | // Keep track of registered memory regions to release them when shuffle not needed 26 | private val fileMappings = new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala 27 | private val offsetMappings = new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala 28 | 29 | /** 30 | * Mapper commit protocol extension. Register index and data files and publish all needed 31 | * metadata to driver. 32 | */ 33 | def writeIndexFileAndCommitCommon(shuffleId: ShuffleId, mapId: Int, 34 | lengths: Array[Long], dataTmp: File, 35 | indexBackFile: RandomAccessFile, dataBackFile: RandomAccessFile): Unit = { 36 | val startTime = System.currentTimeMillis() 37 | 38 | fileMappings.putIfAbsent(shuffleId, new CopyOnWriteArrayList[UcpMemory]()) 39 | offsetMappings.putIfAbsent(shuffleId, new CopyOnWriteArrayList[UcpMemory]()) 40 | 41 | val indexFileChannel = indexBackFile.getChannel 42 | val dataFileChannel = dataBackFile.getChannel 43 | 44 | // Memory map and register data and index file. 
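    // The data and index files are mmap-ed and the mappings registered with UCX, so reducers can
    // read them directly with one-sided GETs. With spark.shuffle.ucx.memory.useOdp the mapping is
    // registered non-blocking (on-demand paging), trading up-front registration cost for page
    // faults on first remote access.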
45 | val dataAddress = UnsafeUtils.mmap(dataFileChannel, 0, dataBackFile.length()) 46 | val memMapParams = new UcpMemMapParams().setAddress(dataAddress) 47 | .setLength(dataBackFile.length()) 48 | if (ucxShuffleManager.ucxShuffleConf.useOdp) { 49 | memMapParams.nonBlocking() 50 | } 51 | val dataMemory = ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams) 52 | fileMappings(shuffleId).add(dataMemory) 53 | assume(indexBackFile.length() == UnsafeUtils.LONG_SIZE * (lengths.length + 1)) 54 | 55 | val offsetAddress = UnsafeUtils.mmap(indexFileChannel, 0, indexBackFile.length()) 56 | memMapParams.setAddress(offsetAddress).setLength(indexBackFile.length()) 57 | val offsetMemory = ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams) 58 | offsetMappings(shuffleId).add(offsetMemory) 59 | 60 | dataFileChannel.close() 61 | dataBackFile.close() 62 | indexFileChannel.close() 63 | indexBackFile.close() 64 | 65 | val fileMemoryRkey = dataMemory.getRemoteKeyBuffer 66 | val offsetRkey = offsetMemory.getRemoteKeyBuffer 67 | 68 | val metadataRegisteredMemory = memPool.get( 69 | fileMemoryRkey.capacity() + offsetRkey.capacity() + 24) 70 | val metadataBuffer = metadataRegisteredMemory.getBuffer.slice() 71 | 72 | if (metadataBuffer.remaining() > ucxShuffleManager.ucxShuffleConf.metadataBlockSize) { 73 | throw new SparkException(s"Metadata block size ${metadataBuffer.remaining() / 2} " + 74 | s"is greater then configured ${ucxShuffleManager.ucxShuffleConf.RKEY_SIZE.key}" + 75 | s"(${ucxShuffleManager.ucxShuffleConf.metadataBlockSize}).") 76 | } 77 | 78 | metadataBuffer.clear() 79 | 80 | metadataBuffer.putLong(offsetMemory.getAddress) 81 | metadataBuffer.putLong(dataMemory.getAddress) 82 | 83 | metadataBuffer.putInt(offsetRkey.capacity()) 84 | metadataBuffer.put(offsetRkey) 85 | 86 | metadataBuffer.putInt(fileMemoryRkey.capacity()) 87 | metadataBuffer.put(fileMemoryRkey) 88 | 89 | metadataBuffer.clear() 90 | 91 | val workerWrapper = ucxShuffleManager.ucxNode.getThreadLocalWorker 92 | val driverMetadata = workerWrapper.getDriverMetadata(shuffleId) 93 | val driverOffset = driverMetadata.address + 94 | mapId * ucxShuffleManager.ucxShuffleConf.metadataBlockSize 95 | 96 | val driverEndpoint = workerWrapper.driverEndpoint 97 | val request = driverEndpoint.putNonBlocking(UcxUtils.getAddress(metadataBuffer), 98 | metadataBuffer.remaining(), driverOffset, driverMetadata.driverRkey, null) 99 | 100 | workerWrapper.preconnect() 101 | // Blocking progress needed to make sure last mapper published data to driver before 102 | // reducer starts. 
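    // waitRequest below drives worker.progressRequest(request) until the PUT of this map's
    // metadata block (placed at driverMetadata.address + mapId * metadataBlockSize, as computed
    // above) has completed.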
103 | workerWrapper.waitRequest(request) 104 | memPool.put(metadataRegisteredMemory) 105 | logInfo(s"MapID: $mapId register files + publishing overhead: " + 106 | s"${System.currentTimeMillis() - startTime} ms") 107 | } 108 | 109 | private def unregisterAndUnmap(mem: UcpMemory): Unit = { 110 | val address = mem.getAddress 111 | val length = mem.getLength 112 | mem.deregister() 113 | UnsafeUtils.munmap(address, length) 114 | } 115 | 116 | def removeShuffle(shuffleId: Int): Unit = { 117 | fileMappings.remove(shuffleId).foreach((mappings: CopyOnWriteArrayList[UcpMemory]) => 118 | mappings.asScala.par.foreach(unregisterAndUnmap)) 119 | offsetMappings.remove(shuffleId).foreach((mappings: CopyOnWriteArrayList[UcpMemory]) => 120 | mappings.asScala.par.foreach(unregisterAndUnmap)) 121 | } 122 | 123 | override def stop(): Unit = { 124 | fileMappings.keys.foreach(removeShuffle) 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleManager.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle 6 | 7 | import java.util.concurrent.ConcurrentHashMap 8 | 9 | import scala.collection.JavaConverters._ 10 | import scala.collection.{concurrent, mutable} 11 | 12 | import org.openucx.jucx.ucp.UcpMemory 13 | import org.apache.spark.SparkConf 14 | import org.apache.spark.shuffle.sort.SortShuffleManager 15 | import org.apache.spark.shuffle.ucx.UcxNode 16 | import org.apache.spark.shuffle.ucx.rpc.UcxRemoteMemory 17 | import org.apache.spark.unsafe.Platform 18 | 19 | /** 20 | * Common part for all spark versions for UcxShuffleManager logic 21 | */ 22 | abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) { 23 | type ShuffleId = Int 24 | type MapId = Int 25 | val ucxShuffleConf = new UcxShuffleConf(conf) 26 | 27 | var ucxNode: UcxNode = _ 28 | 29 | // Shuffle handle is metadata information about the shuffle (num mappers, etc) 30 | // distributed by Spark task broadcast protocol. 31 | // UcxShuffleHandle is extension over Spark's shuffle handle to keep driver metadata info. 32 | val shuffleIdToHandle: concurrent.Map[ShuffleId, UcxShuffleHandle[_, _, _]] = 33 | new ConcurrentHashMap[ShuffleId, UcxShuffleHandle[_, _, _]]().asScala 34 | 35 | if (isDriver) { 36 | startUcxNodeIfMissing() 37 | } 38 | 39 | protected def registerShuffleCommon[K, V, C](baseHandle: BaseShuffleHandle[K,V,C], 40 | shuffleId: ShuffleId, 41 | numMaps: Int): ShuffleHandle = { 42 | // Register metadata buffer where each map will publish it's index and data file metadata 43 | val metadataBufferSize = numMaps * ucxShuffleConf.metadataBlockSize 44 | val metadataBuffer = Platform.allocateDirectBuffer(metadataBufferSize.toInt) 45 | 46 | val metadataMemory = ucxNode.getContext.registerMemory(metadataBuffer) 47 | shuffleIdToMetadataBuffer.put(shuffleId, metadataMemory) 48 | 49 | val driverMemory = new UcxRemoteMemory(metadataMemory.getAddress, 50 | metadataMemory.getRemoteKeyBuffer) 51 | 52 | val handle = new UcxShuffleHandle(shuffleId, driverMemory, numMaps, baseHandle) 53 | 54 | shuffleIdToHandle.putIfAbsent(shuffleId, handle) 55 | handle 56 | } 57 | 58 | /** 59 | * Mapping between shuffle and metadata buffer, to deregister it when shuffle not needed. 
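 * The buffer registered in registerShuffleCommon above is numMaps * metadataBlockSize bytes
 * (one fixed-size slot per map task), which is what lets every mapper publish its slot at
 * mapId * metadataBlockSize without extra coordination.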
60 | */ 61 | protected val shuffleIdToMetadataBuffer: mutable.Map[ShuffleId, UcpMemory] = 62 | new ConcurrentHashMap[ShuffleId, UcpMemory]().asScala 63 | 64 | /** 65 | * Atomically starts UcxNode singleton - one for all shuffle threads. 66 | */ 67 | def startUcxNodeIfMissing(): Unit = if (ucxNode == null) { 68 | synchronized { 69 | if (ucxNode == null) { 70 | ucxNode = new UcxNode(ucxShuffleConf, isDriver) 71 | } 72 | } 73 | } 74 | 75 | override def unregisterShuffle(shuffleId: Int): Boolean = { 76 | shuffleIdToMetadataBuffer.remove(shuffleId).foreach(_.deregister()) 77 | shuffleBlockResolver.asInstanceOf[CommonUcxShuffleBlockResolver].removeShuffle(shuffleId) 78 | super.unregisterShuffle(shuffleId) 79 | } 80 | 81 | /** 82 | * Called on both driver and executors to finally cleanup resources. 83 | */ 84 | override def stop(): Unit = synchronized { 85 | logInfo("Stopping shuffle manager") 86 | shuffleIdToHandle.keys.foreach(unregisterShuffle) 87 | shuffleIdToHandle.clear() 88 | if (ucxNode != null) { 89 | ucxNode.close() 90 | ucxNode = null 91 | } 92 | super.stop() 93 | } 94 | 95 | } 96 | 97 | /** 98 | * Spark shuffle handles extensions, broadcasted by TCP to executors. 99 | * Added metadataBufferOnDriver field, that contains address and rkey of driver metadata buffer. 100 | */ 101 | class UcxShuffleHandle[K, V, C](override val shuffleId: Int, 102 | val metadataBufferOnDriver: UcxRemoteMemory, 103 | val numMaps: Int, 104 | val baseHandle: BaseShuffleHandle[K,V,C]) extends ShuffleHandle(shuffleId) 105 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/UcxShuffleConf.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle 6 | 7 | import scala.collection.JavaConverters._ 8 | 9 | import org.apache.spark.SparkConf 10 | import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry} 11 | import org.apache.spark.network.util.ByteUnit 12 | import org.apache.spark.util.Utils 13 | 14 | /** 15 | * Plugin configuration properties. 
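 *
 * A configuration sketch (property names from this file, values illustrative only):
 * {{{
 *   spark.shuffle.ucx.driver.port                  55443
 *   spark.shuffle.ucx.rkeySize                     150
 *   spark.shuffle.ucx.rpc.metadata.bufferSize      4096
 *   spark.shuffle.ucx.memory.preAllocateBuffers    4k:1000,16k:500
 *   spark.shuffle.ucx.memory.minBufferSize         1024
 *   spark.shuffle.ucx.memory.minAllocationSize     4m
 *   spark.shuffle.ucx.memory.preregister           true
 *   spark.shuffle.ucx.memory.useOdp                false
 * }}}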
16 | */ 17 | class UcxShuffleConf(conf: SparkConf) extends SparkConf { 18 | private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name" 19 | 20 | lazy val getNumProcesses: Int = getInt("spark.executor.instances", 1) 21 | 22 | lazy val coresPerProcess: Int = getInt("spark.executor.cores", 23 | Runtime.getRuntime.availableProcessors()) 24 | 25 | lazy val driverHost: String = conf.get(getUcxConf("driver.host"), 26 | conf.get("spark.driver.host", "0.0.0.0")) 27 | 28 | lazy val driverPort: Int = conf.getInt(getUcxConf("driver.port"), 55443) 29 | 30 | // Metadata 31 | 32 | lazy val RKEY_SIZE: ConfigEntry[Long] = 33 | ConfigBuilder(getUcxConf("rkeySize")) 34 | .doc("Maximum size of rKeyBuffer") 35 | .bytesConf(ByteUnit.BYTE) 36 | .createWithDefault(150) 37 | 38 | // For metadata we publish index file + data file rkeys 39 | lazy val metadataBlockSize: Long = 2 * conf.getSizeAsBytes(RKEY_SIZE.key, 40 | RKEY_SIZE.defaultValueString) 41 | 42 | private lazy val METADATA_RPC_BUFFER_SIZE = 43 | ConfigBuilder(getUcxConf("rpc.metadata.bufferSize")) 44 | .doc("Buffer size of worker -> driver metadata message") 45 | .bytesConf(ByteUnit.BYTE) 46 | .createWithDefault(4096) 47 | 48 | lazy val metadataRPCBufferSize: Int = conf.getSizeAsBytes(METADATA_RPC_BUFFER_SIZE.key, 49 | METADATA_RPC_BUFFER_SIZE.defaultValueString).toInt 50 | 51 | // Memory Pool 52 | private lazy val PREALLOCATE_BUFFERS = 53 | ConfigBuilder(getUcxConf("memory.preAllocateBuffers")) 54 | .doc("Comma separated list of buffer size : buffer count pairs to preallocate in memory pool. E.g. 4k:1000,16k:500") 55 | .stringConf.createWithDefault("") 56 | 57 | lazy val preallocateBuffersMap: java.util.Map[java.lang.Integer, java.lang.Integer] = { 58 | conf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => !s.isEmpty) 59 | .map(entry => entry.split(":") match { 60 | case Array(bufferSize, bufferCount) => 61 | (int2Integer(Utils.byteStringAsBytes(bufferSize.trim).toInt), 62 | int2Integer(bufferCount.toInt)) 63 | }).toMap.asJava 64 | } 65 | 66 | private lazy val MIN_BUFFER_SIZE = ConfigBuilder(getUcxConf("memory.minBufferSize")) 67 | .doc("Minimal buffer size in memory pool.") 68 | .bytesConf(ByteUnit.BYTE) 69 | .createWithDefault(1024) 70 | 71 | lazy val minBufferSize: Long = conf.getSizeAsBytes(MIN_BUFFER_SIZE.key, 72 | MIN_BUFFER_SIZE.defaultValueString) 73 | 74 | private lazy val MIN_REGISTRATION_SIZE = 75 | ConfigBuilder(getUcxConf("memory.minAllocationSize")) 76 | .doc("Minimal memory registration size in memory pool.") 77 | .bytesConf(ByteUnit.MiB) 78 | .createWithDefault(4) 79 | 80 | lazy val minRegistrationSize: Int = conf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key, 81 | MIN_REGISTRATION_SIZE.defaultValueString).toInt 82 | 83 | private lazy val PREREGISTER_MEMORY = ConfigBuilder(getUcxConf("memory.preregister")) 84 | .doc("Whether to do ucp mem map for allocated memory in memory pool") 85 | .booleanConf.createWithDefault(true) 86 | 87 | lazy val preregisterMemory: Boolean = conf.getBoolean(PREREGISTER_MEMORY.key, PREREGISTER_MEMORY.defaultValue.get) 88 | 89 | lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), false) 90 | } 91 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/UcxWorkerWrapper.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle 6 | 7 | import java.io.Closeable 8 | import java.net.InetSocketAddress 9 | import java.nio.ByteBuffer 10 | import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} 11 | 12 | import scala.collection.JavaConverters._ 13 | import scala.collection.mutable 14 | 15 | import org.openucx.jucx.UcxException 16 | import org.openucx.jucx.ucp.{UcpEndpoint, UcpEndpointParams, UcpRemoteKey, UcpRequest, UcpWorker} 17 | import org.apache.spark.SparkEnv 18 | import org.apache.spark.internal.Logging 19 | import org.apache.spark.shuffle.ucx.{UcxNode, UnsafeUtils} 20 | import org.apache.spark.storage.BlockManagerId 21 | import org.apache.spark.unsafe.Platform 22 | 23 | /** 24 | * Driver metadata buffer information that holds unpacked RkeyBuffer for this WorkerWrapper 25 | * and fetched buffer itself. 26 | */ 27 | case class DriverMetadata(address: Long, driverRkey: UcpRemoteKey, length: Int, 28 | var data: ByteBuffer) { 29 | // Driver metadata is an array of blocks: 30 | // | mapId0 | mapId1 | mapId2 | mapId3 | mapId4 | mapId5 | 31 | // Each block in driver metadata has next layout: 32 | // |offsetAddress|dataAddress|offsetRkeySize|offsetRkey|dataRkeySize|dataRkey| 33 | 34 | def offsetAddress(mapId: Int): Long = { 35 | val mapIdBlock = mapId * UcxWorkerWrapper.metadataBlockSize 36 | data.getLong(mapIdBlock) 37 | } 38 | 39 | def dataAddress(mapId: Int): Long = { 40 | val mapIdBlock = mapId * UcxWorkerWrapper.metadataBlockSize 41 | data.getLong(mapIdBlock + UnsafeUtils.LONG_SIZE) 42 | } 43 | 44 | def offsetRkey(mapId: Int): ByteBuffer = { 45 | val mapIdBlock = mapId * UcxWorkerWrapper.metadataBlockSize 46 | var offsetWithinBlock = mapIdBlock + 2 * UnsafeUtils.LONG_SIZE 47 | val rkeySize = data.getInt(offsetWithinBlock) 48 | offsetWithinBlock += UnsafeUtils.INT_SIZE 49 | val result = data.duplicate() 50 | result.position(offsetWithinBlock).limit(offsetWithinBlock + rkeySize) 51 | result.slice() 52 | } 53 | 54 | def dataRkey(mapId: Int): ByteBuffer = { 55 | val mapIdBlock = mapId * UcxWorkerWrapper.metadataBlockSize 56 | var offsetWithinBlock = mapIdBlock + 2 * UnsafeUtils.LONG_SIZE 57 | val offsetRkeySize = data.getInt(offsetWithinBlock) 58 | offsetWithinBlock += UnsafeUtils.INT_SIZE + offsetRkeySize 59 | val dataRkeySize = data.getInt(offsetWithinBlock) 60 | offsetWithinBlock += UnsafeUtils.INT_SIZE 61 | val result = data.duplicate() 62 | result.position(offsetWithinBlock).limit(offsetWithinBlock + dataRkeySize) 63 | result.slice() 64 | } 65 | } 66 | 67 | /** 68 | * Worker per thread wrapper, that maintains connection and progress logic. 
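 * Executors obtain one instance per task thread via UcxNode#getThreadLocalWorker (see the
 * shuffle readers and the block resolver), so the per-worker state below can stay
 * unsynchronized.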
69 | */ 70 | class UcxWorkerWrapper(val worker: UcpWorker, val conf: UcxShuffleConf, val id: Int) 71 | extends Closeable with Logging { 72 | import UcxWorkerWrapper._ 73 | 74 | private final val driverSocketAddress = new InetSocketAddress(conf.driverHost, conf.driverPort) 75 | private final val endpointParams = new UcpEndpointParams().setSocketAddress(driverSocketAddress) 76 | .setPeerErrorHandlingMode() 77 | val driverEndpoint: UcpEndpoint = worker.newEndpoint(endpointParams) 78 | 79 | private final val connections = mutable.Map.empty[BlockManagerId, UcpEndpoint] 80 | 81 | private final val driverMetadata = mutable.Map.empty[ShuffleId, DriverMetadata] 82 | 83 | override def close(): Unit = { 84 | driverMetadata.values.foreach{ 85 | case DriverMetadata(address, rkey, length, data) => rkey.close() 86 | } 87 | driverMetadata.clear() 88 | driverEndpoint.close() 89 | connections.foreach{ 90 | case (_, endpoint) => endpoint.close() 91 | } 92 | connections.clear() 93 | worker.close() 94 | driverMetadataBuffer.clear() 95 | } 96 | 97 | /** 98 | * Blocking progress single request until it's not completed. 99 | */ 100 | def waitRequest(request: UcpRequest): Unit = { 101 | val startTime = System.currentTimeMillis() 102 | worker.progressRequest(request) 103 | logDebug(s"Request completed in ${System.currentTimeMillis() - startTime} ms") 104 | } 105 | 106 | /** 107 | * Blocking progress while result queue is empty. 108 | */ 109 | def fillQueueWithBlocks(queue: LinkedBlockingQueue[_]): Unit = { 110 | while (queue.isEmpty) { 111 | progress() 112 | } 113 | } 114 | 115 | /** 116 | * The only place for worker progress 117 | */ 118 | private def progress(): Int = { 119 | worker.progress() 120 | } 121 | 122 | /** 123 | * Establish connections to known instances. 124 | */ 125 | def preconnect(): Unit = { 126 | UcxNode.getWorkerAddresses.keySet().asScala.foreach(getConnection) 127 | } 128 | 129 | def getConnection(blockManagerId: BlockManagerId): UcpEndpoint = { 130 | val workerAddresses = UcxNode.getWorkerAddresses 131 | // Block untill there's no worker address for this BlockManagerID 132 | val startTime = System.currentTimeMillis() 133 | val timeout = conf.getTimeAsMs("spark.network.timeout", "100") 134 | if (workerAddresses.get(blockManagerId) == null) { 135 | workerAddresses.synchronized { 136 | while (workerAddresses.get(blockManagerId) == null) { 137 | workerAddresses.wait(timeout) 138 | if (System.currentTimeMillis() - startTime > timeout) { 139 | throw new UcxException(s"Didn't get worker address for $blockManagerId during $timeout") 140 | } 141 | } 142 | } 143 | } 144 | 145 | connections.getOrElseUpdate(blockManagerId, { 146 | logInfo(s"Worker $id connecting to $blockManagerId") 147 | val endpointParams = new UcpEndpointParams() 148 | .setPeerErrorHandlingMode() 149 | .setUcpAddress(workerAddresses.get(blockManagerId)) 150 | worker.newEndpoint(endpointParams) 151 | }) 152 | } 153 | 154 | /** 155 | * Unpacks driver metadata RkeyBuffer for this worker. 156 | * Needed to perform PUT operation to publish map output info. 
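 * Only the rkey is unpacked and the remote address recorded here; the buffer contents are
 * fetched lazily (and cached per shuffle) by fetchDriverMetadataBuffer below.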
157 | */ 158 | def getDriverMetadata(shuffleId: ShuffleId): DriverMetadata = { 159 | driverMetadata.getOrElseUpdate(shuffleId, { 160 | val ucxShuffleHandle = SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager] 161 | .shuffleIdToHandle(shuffleId) 162 | val (address, length, rkey): (Long, Int, ByteBuffer) = (ucxShuffleHandle.metadataBufferOnDriver.getAddress, 163 | ucxShuffleHandle.numMaps * conf.metadataBlockSize.toInt, 164 | ucxShuffleHandle.metadataBufferOnDriver.getRkeyBuffer) 165 | 166 | rkey.clear() 167 | val unpackedRkey = driverEndpoint.unpackRemoteKey(rkey) 168 | DriverMetadata(address, unpackedRkey, length, null) 169 | }) 170 | } 171 | 172 | /** 173 | * Fetches using ucp_get metadata buffer from driver, with all needed information 174 | * for offset and data addresses and keys. 175 | */ 176 | def fetchDriverMetadataBuffer(shuffleId: ShuffleId): DriverMetadata = { 177 | val handle = SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager] 178 | .shuffleIdToHandle(shuffleId) 179 | 180 | val metadata = getDriverMetadata(handle.shuffleId) 181 | 182 | UcxWorkerWrapper.driverMetadataBuffer.computeIfAbsent(shuffleId, 183 | (t: ShuffleId) => { 184 | val buffer = Platform.allocateDirectBuffer(metadata.length) 185 | val request = driverEndpoint.getNonBlocking( 186 | metadata.address, metadata.driverRkey, buffer, null) 187 | waitRequest(request) 188 | buffer 189 | } 190 | ) 191 | 192 | if (metadata.data == null) { 193 | metadata.data = UcxWorkerWrapper.driverMetadataBuffer.get(shuffleId) 194 | } 195 | metadata 196 | } 197 | } 198 | 199 | object UcxWorkerWrapper { 200 | type ShuffleId = Int 201 | type MapId = Int 202 | // Driver metadata buffer, to fetch by first worker wrapper. 203 | val driverMetadataBuffer = new ConcurrentHashMap[ShuffleId, ByteBuffer]() 204 | 205 | val metadataBlockSize: MapId = 206 | SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager].ucxShuffleConf.metadataBlockSize.toInt 207 | } 208 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.compat.spark_2_1 6 | 7 | import java.io.{File, RandomAccessFile} 8 | 9 | import org.apache.spark.SparkEnv 10 | import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager, IndexShuffleBlockResolver} 11 | import org.apache.spark.storage.ShuffleIndexBlockId 12 | 13 | /** 14 | * Mapper entry point for UcxShuffle plugin. Performs memory registration 15 | * of data and index files and publish addresses to driver metadata buffer. 16 | */ 17 | class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) 18 | extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { 19 | 20 | private def getIndexFile(shuffleId: Int, mapId: Int): File = { 21 | SparkEnv.get.blockManager 22 | .diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)) 23 | } 24 | 25 | /** 26 | * Mapper commit protocol extension. Register index and data files and publish all needed 27 | * metadata to driver. 
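 * Zero-length map outputs are skipped below: with an empty data file there is nothing to
 * register or publish.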
28 | */ 29 | override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Int, 30 | lengths: Array[Long], dataTmp: File): Unit = { 31 | super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) 32 | val dataFile = getDataFile(shuffleId, mapId) 33 | val dataBackFile = new RandomAccessFile(dataFile, "rw") 34 | 35 | if (dataBackFile.length() == 0) { 36 | dataBackFile.close() 37 | return 38 | } 39 | 40 | val indexFile = getIndexFile(shuffleId, mapId) 41 | val indexBackFile = new RandomAccessFile(indexFile, "rw") 42 | writeIndexFileAndCommitCommon(shuffleId, mapId, lengths, dataTmp, indexBackFile, dataBackFile) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle 6 | 7 | import org.apache.spark.shuffle.compat.spark_2_1.{UcxShuffleBlockResolver, UcxShuffleReader} 8 | import org.apache.spark.util.ShutdownHookManager 9 | import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} 10 | 11 | /** 12 | * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin 13 | * and injects needed logic in override methods. 14 | */ 15 | class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) { 16 | ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop) 17 | 18 | /** 19 | * Register a shuffle with the manager and obtain a handle for it to pass to tasks. 20 | * Called on driver and guaranteed by spark that shuffle on executor will start after it. 21 | */ 22 | override def registerShuffle[K, V, C](shuffleId: ShuffleId, 23 | numMaps: Int, 24 | dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { 25 | assume(isDriver) 26 | val baseHandle = super.registerShuffle(shuffleId, numMaps, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]] 27 | registerShuffleCommon(baseHandle, shuffleId, numMaps) 28 | } 29 | 30 | /** 31 | * Mapper callback on executor. Just start UcxNode and use Spark mapper logic. 32 | */ 33 | override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, 34 | context: TaskContext): ShuffleWriter[K, V] = { 35 | startUcxNodeIfMissing() 36 | shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,V,_]]) 37 | super.getWriter(handle.asInstanceOf[UcxShuffleHandle[K,V,_]].baseHandle, mapId, context) 38 | } 39 | 40 | override val shuffleBlockResolver: UcxShuffleBlockResolver = new UcxShuffleBlockResolver(this) 41 | 42 | /** 43 | * Reducer callback on executor. 44 | */ 45 | override def getReader[K, C](handle: ShuffleHandle, startPartition: Int, 46 | endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { 47 | startUcxNodeIfMissing() 48 | shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,_,C]]) 49 | new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition, 50 | endPartition, context) 51 | } 52 | } 53 | 54 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 
3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.compat.spark_2_1 6 | 7 | import java.io.InputStream 8 | import java.util.concurrent.LinkedBlockingQueue 9 | 10 | import org.apache.spark.internal.{Logging, config} 11 | import org.apache.spark.serializer.SerializerManager 12 | import org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1.UcxShuffleClient 13 | import org.apache.spark.shuffle.{ShuffleReader, UcxShuffleHandle, UcxShuffleManager} 14 | import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator} 15 | import org.apache.spark.util.CompletionIterator 16 | import org.apache.spark.util.collection.ExternalSorter 17 | import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} 18 | 19 | /** 20 | * Extension of Spark's shuffe reader with a logic of injection UcxShuffleClient, 21 | * and lazy progress only when result queue is empty. 22 | */ 23 | class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C], 24 | startPartition: Int, 25 | endPartition: Int, 26 | context: TaskContext, 27 | serializerManager: SerializerManager = SparkEnv.get.serializerManager, 28 | blockManager: BlockManager = SparkEnv.get.blockManager, 29 | mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) 30 | extends ShuffleReader[K, C] with Logging { 31 | 32 | private val dep = handle.baseHandle.dependency 33 | 34 | /** Read the combined key-values for this reduce task */ 35 | override def read(): Iterator[Product2[K, C]] = { 36 | val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() 37 | val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] 38 | .ucxNode.getThreadLocalWorker 39 | val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper) 40 | val wrappedStreams = new ShuffleBlockFetcherIterator( 41 | context, 42 | shuffleClient, 43 | blockManager, 44 | mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, 45 | startPartition, endPartition), 46 | // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility 47 | SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, 48 | SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)) 49 | 50 | // Ucx shuffle logic 51 | // Java reflection to get access to private results queue 52 | val queueField = wrappedStreams.getClass.getDeclaredField( 53 | "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") 54 | queueField.setAccessible(true) 55 | val resultQueue = queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]] 56 | 57 | // Do progress if queue is empty before calling next on ShuffleIterator 58 | val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { 59 | override def next(): (BlockId, InputStream) = { 60 | val startTime = System.currentTimeMillis() 61 | workerWrapper.fillQueueWithBlocks(resultQueue) 62 | shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) 63 | wrappedStreams.next() 64 | } 65 | 66 | override def hasNext: Boolean = { 67 | val result = wrappedStreams.hasNext 68 | if (!result) { 69 | shuffleClient.close() 70 | } 71 | result 72 | } 73 | } 74 | // End of ucx shuffle logic 75 | 76 | val serializerInstance = dep.serializer.newInstance() 77 | val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => 78 | // Note: the asKeyValueIterator below wraps a key/value iterator inside of a 79 | // NextIterator. 
The NextIterator makes sure that close() is called on the 80 | // underlying InputStream when all records have been read. 81 | serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator 82 | } 83 | 84 | // Update the context task metrics for each record read. 85 | val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() 86 | val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( 87 | recordIter.map { record => 88 | readMetrics.incRecordsRead(1) 89 | record 90 | }, 91 | context.taskMetrics().mergeShuffleReadMetrics()) 92 | 93 | // An interruptible iterator must be used here in order to support task cancellation 94 | val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) 95 | 96 | val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { 97 | if (dep.mapSideCombine) { 98 | // We are reading values that are already combined 99 | val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] 100 | dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) 101 | } else { 102 | // We don't know the value type, but also don't care -- the dependency *should* 103 | // have made sure its compatible w/ this aggregator, which will convert the value 104 | // type to the combined type C 105 | val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] 106 | dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) 107 | } 108 | } else { 109 | interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] 110 | } 111 | 112 | // Sort the output if there is a sort ordering defined. 113 | val resultIter = dep.keyOrdering match { 114 | case Some(keyOrd: Ordering[K]) => 115 | // Create an ExternalSorter to sort the data. 116 | val sorter = 117 | new ExternalSorter[K, C, C](context, 118 | ordering = Some(keyOrd), serializer = dep.serializer) 119 | sorter.insertAll(aggregatedIter) 120 | context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) 121 | context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) 122 | context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) 123 | // Use completion callback to stop sorter if task was finished/cancelled. 124 | CompletionIterator[Product2[K, C], 125 | Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) 126 | case None => 127 | aggregatedIter 128 | } 129 | 130 | resultIter match { 131 | case _: InterruptibleIterator[Product2[K, C]] => resultIter 132 | case _ => 133 | // Use another interruptible iterator here to support task cancellation as aggregator 134 | // or(and) sorter may have consumed previous interruptible iterator. 135 | new InterruptibleIterator[Product2[K, C]](context, resultIter) 136 | } 137 | } 138 | 139 | } 140 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleBlockResolver.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.compat.spark_2_4 6 | 7 | import java.io.{File, RandomAccessFile} 8 | 9 | import org.apache.spark.SparkEnv 10 | import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager, IndexShuffleBlockResolver} 11 | import org.apache.spark.storage.ShuffleIndexBlockId 12 | 13 | /** 14 | * Mapper entry point for UcxShuffle plugin. 
Performs memory registration 15 | * of data and index files and publish addresses to driver metadata buffer. 16 | */ 17 | class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) 18 | extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { 19 | 20 | private def getIndexFile(shuffleId: Int, mapId: Int): File = { 21 | SparkEnv.get.blockManager 22 | .diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)) 23 | } 24 | 25 | /** 26 | * Mapper commit protocol extension. Register index and data files and publish all needed 27 | * metadata to driver. 28 | */ 29 | override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Int, 30 | lengths: Array[Long], dataTmp: File): Unit = { 31 | super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) 32 | val dataFile = getDataFile(shuffleId, mapId) 33 | val dataBackFile = new RandomAccessFile(dataFile, "rw") 34 | 35 | if (dataBackFile.length() == 0) { 36 | dataBackFile.close() 37 | return 38 | } 39 | 40 | val indexFile = getIndexFile(shuffleId, mapId) 41 | val indexBackFile = new RandomAccessFile(indexFile, "rw") 42 | writeIndexFileAndCommitCommon(shuffleId, mapId, lengths, dataTmp, indexBackFile, dataBackFile) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleManager.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle 6 | 7 | import org.apache.spark.shuffle.compat.spark_2_4.{UcxShuffleBlockResolver, UcxShuffleReader} 8 | import org.apache.spark.util.ShutdownHookManager 9 | import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} 10 | 11 | /** 12 | * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin 13 | * and injects needed logic in override methods. 14 | */ 15 | class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) { 16 | ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop) 17 | 18 | /** 19 | * Register a shuffle with the manager and obtain a handle for it to pass to tasks. 20 | * Called on driver and guaranteed by spark that shuffle on executor will start after it. 21 | */ 22 | override def registerShuffle[K, V, C](shuffleId: ShuffleId, 23 | numMaps: Int, 24 | dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { 25 | assume(isDriver) 26 | val baseHandle = super.registerShuffle(shuffleId, numMaps, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]] 27 | registerShuffleCommon(baseHandle, shuffleId, numMaps) 28 | } 29 | 30 | /** 31 | * Mapper callback on executor. Just start UcxNode and use Spark mapper logic. 32 | */ 33 | override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, 34 | context: TaskContext): ShuffleWriter[K, V] = { 35 | startUcxNodeIfMissing() 36 | shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,V,_]]) 37 | super.getWriter(handle.asInstanceOf[UcxShuffleHandle[K,V,_]].baseHandle, mapId, context) 38 | } 39 | 40 | override val shuffleBlockResolver: UcxShuffleBlockResolver = new UcxShuffleBlockResolver(this) 41 | 42 | /** 43 | * Reducer callback on executor. 
44 |    */
45 |   override def getReader[K, C](handle: ShuffleHandle, startPartition: Int,
46 |                                 endPartition: Int, context: TaskContext): ShuffleReader[K, C] = {
47 |     startUcxNodeIfMissing()
48 |     shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,_,C]])
49 |     new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition,
50 |       endPartition, context)
51 |   }
52 | }
53 | 
54 | 
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala:
--------------------------------------------------------------------------------
 1 | /*
 2 |  * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
 3 |  * See file LICENSE for terms.
 4 |  */
 5 | package org.apache.spark.shuffle.compat.spark_2_4
 6 | 
 7 | import java.io.InputStream
 8 | import java.util.concurrent.LinkedBlockingQueue
 9 | 
10 | import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
11 | import org.apache.spark.internal.{Logging, config}
12 | import org.apache.spark.serializer.SerializerManager
13 | import org.apache.spark.shuffle.{ShuffleReader, UcxShuffleHandle, UcxShuffleManager}
14 | import org.apache.spark.shuffle.ucx.reducer.compat.spark_2_4.UcxShuffleClient
15 | import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
16 | import org.apache.spark.util.CompletionIterator
17 | import org.apache.spark.util.collection.ExternalSorter
18 | 
19 | /**
20 |  * Extension of Spark's shuffle reader that injects UcxShuffleClient and performs
21 |  * lazy progress only when the result queue is empty.
22 |  */
23 | class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C],
24 |                              startPartition: Int,
25 |                              endPartition: Int,
26 |                              context: TaskContext,
27 |                              serializerManager: SerializerManager = SparkEnv.get.serializerManager,
28 |                              blockManager: BlockManager = SparkEnv.get.blockManager,
29 |                              mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
30 |   extends ShuffleReader[K, C] with Logging {
31 | 
32 |   private val dep = handle.baseHandle.dependency
33 | 
34 |   /** Read the combined key-values for this reduce task */
35 |   override def read(): Iterator[Product2[K, C]] = {
36 |     val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
37 |     val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
38 |       .ucxNode.getThreadLocalWorker
39 |     val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper)
40 |     val wrappedStreams = new ShuffleBlockFetcherIterator(
41 |       context,
42 |       shuffleClient,
43 |       blockManager,
44 |       mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId,
45 |         startPartition, endPartition),
46 |       serializerManager.wrapStream,
47 |       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
48 |       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
49 |       SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
50 |       SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
51 |       SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
52 |       SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
53 | 
54 |     // Ucx shuffle logic
55 |     // Java reflection to get access to private results queue
56 |     val queueField = wrappedStreams.getClass.getDeclaredField(
57 |       "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
58 |     queueField.setAccessible(true)
59 |     val resultQueue = 
queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]] 60 | 61 | // Do progress if queue is empty before calling next on ShuffleIterator 62 | val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { 63 | override def next(): (BlockId, InputStream) = { 64 | val startTime = System.currentTimeMillis() 65 | workerWrapper.fillQueueWithBlocks(resultQueue) 66 | shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) 67 | wrappedStreams.next() 68 | } 69 | 70 | override def hasNext: Boolean = { 71 | val result = wrappedStreams.hasNext 72 | if (!result) { 73 | shuffleClient.close() 74 | } 75 | result 76 | } 77 | } 78 | // End of ucx shuffle logic 79 | 80 | val serializerInstance = dep.serializer.newInstance() 81 | val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => 82 | // Note: the asKeyValueIterator below wraps a key/value iterator inside of a 83 | // NextIterator. The NextIterator makes sure that close() is called on the 84 | // underlying InputStream when all records have been read. 85 | serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator 86 | } 87 | 88 | // Update the context task metrics for each record read. 89 | val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() 90 | val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( 91 | recordIter.map { record => 92 | readMetrics.incRecordsRead(1) 93 | record 94 | }, 95 | context.taskMetrics().mergeShuffleReadMetrics()) 96 | 97 | // An interruptible iterator must be used here in order to support task cancellation 98 | val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) 99 | 100 | val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { 101 | if (dep.mapSideCombine) { 102 | // We are reading values that are already combined 103 | val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] 104 | dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) 105 | } else { 106 | // We don't know the value type, but also don't care -- the dependency *should* 107 | // have made sure its compatible w/ this aggregator, which will convert the value 108 | // type to the combined type C 109 | val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] 110 | dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) 111 | } 112 | } else { 113 | interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] 114 | } 115 | 116 | // Sort the output if there is a sort ordering defined. 117 | val resultIter = dep.keyOrdering match { 118 | case Some(keyOrd: Ordering[K]) => 119 | // Create an ExternalSorter to sort the data. 120 | val sorter = 121 | new ExternalSorter[K, C, C](context, 122 | ordering = Some(keyOrd), serializer = dep.serializer) 123 | sorter.insertAll(aggregatedIter) 124 | context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) 125 | context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) 126 | context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) 127 | // Use completion callback to stop sorter if task was finished/cancelled. 
128 | context.addTaskCompletionListener[Unit](_ => { 129 | sorter.stop() 130 | }) 131 | CompletionIterator[Product2[K, C], 132 | Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) 133 | case None => 134 | aggregatedIter 135 | } 136 | 137 | resultIter match { 138 | case _: InterruptibleIterator[Product2[K, C]] => resultIter 139 | case _ => 140 | // Use another interruptible iterator here to support task cancellation as aggregator 141 | // or(and) sorter may have consumed previous interruptible iterator. 142 | new InterruptibleIterator[Product2[K, C]](context, resultIter) 143 | } 144 | } 145 | 146 | } 147 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleDataIO.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.compat.spark_3_0 6 | 7 | import org.apache.spark.SparkConf 8 | import org.apache.spark.internal.Logging 9 | import org.apache.spark.shuffle.api.ShuffleExecutorComponents 10 | import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO 11 | 12 | /** 13 | * Ucx local disk IO plugin to handle logic of writing to local disk and shuffle memory registration. 14 | */ 15 | case class UcxLocalDiskShuffleDataIO(sparkConf: SparkConf) extends LocalDiskShuffleDataIO(sparkConf) with Logging { 16 | 17 | override def executor(): ShuffleExecutorComponents = { 18 | new UcxLocalDiskShuffleExecutorComponents(sparkConf) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.compat.spark_3_0 6 | 7 | import java.util 8 | import java.util.Optional 9 | 10 | import org.apache.spark.internal.Logging 11 | import org.apache.spark.{SparkConf, SparkEnv} 12 | import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter} 13 | import org.apache.spark.shuffle.UcxShuffleManager 14 | import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter} 15 | 16 | /** 17 | * Entry point to UCX executor. 
18 | */ 19 | class UcxLocalDiskShuffleExecutorComponents(sparkConf: SparkConf) 20 | extends LocalDiskShuffleExecutorComponents(sparkConf) with Logging{ 21 | 22 | private var blockResolver: UcxShuffleBlockResolver = _ 23 | 24 | override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = { 25 | val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] 26 | ucxShuffleManager.startUcxNodeIfMissing() 27 | blockResolver = ucxShuffleManager.shuffleBlockResolver 28 | } 29 | 30 | override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = { 31 | if (blockResolver == null) { 32 | throw new IllegalStateException( 33 | "Executor components must be initialized before getting writers.") 34 | } 35 | new LocalDiskShuffleMapOutputWriter( 36 | shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf) 37 | } 38 | 39 | override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long): Optional[SingleSpillShuffleMapOutputWriter] = { 40 | if (blockResolver == null) { 41 | throw new IllegalStateException( 42 | "Executor components must be initialized before getting writers.") 43 | } 44 | Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)) 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleBlockResolver.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.compat.spark_3_0 6 | 7 | import java.io.{File, RandomAccessFile} 8 | 9 | import org.apache.spark.{SparkEnv, TaskContext} 10 | import org.apache.spark.network.shuffle.ExecutorDiskUtils 11 | import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID 12 | import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager} 13 | import org.apache.spark.storage.ShuffleIndexBlockId 14 | 15 | /** 16 | * Mapper entry point for UcxShuffle plugin. Performs memory registration 17 | * of data and index files and publish addresses to driver metadata buffer. 18 | */ 19 | class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) 20 | extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { 21 | 22 | private def getIndexFile( 23 | shuffleId: Int, 24 | mapId: Long, 25 | dirs: Option[Array[String]] = None): File = { 26 | val blockId = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID) 27 | val blockManager = SparkEnv.get.blockManager 28 | dirs 29 | .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) 30 | .getOrElse(blockManager.diskBlockManager.getFile(blockId)) 31 | } 32 | 33 | override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long, 34 | lengths: Array[Long], dataTmp: File): Unit = { 35 | super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) 36 | // In Spark-3.0 MapId is long and unique among all jobs in spark. 
We need to use partitionId as offset 37 | // in metadata buffer 38 | val partitionId = TaskContext.getPartitionId() 39 | val dataFile = getDataFile(shuffleId, mapId) 40 | val dataBackFile = new RandomAccessFile(dataFile, "rw") 41 | 42 | if (dataBackFile.length() == 0) { 43 | dataBackFile.close() 44 | return 45 | } 46 | 47 | val indexFile = getIndexFile(shuffleId, mapId) 48 | val indexBackFile = new RandomAccessFile(indexFile, "rw") 49 | 50 | writeIndexFileAndCommitCommon(shuffleId, partitionId, lengths, dataTmp, indexBackFile, dataBackFile) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleManager.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle 6 | 7 | import scala.collection.JavaConverters._ 8 | 9 | import org.apache.spark.shuffle.api.ShuffleExecutorComponents 10 | import org.apache.spark.shuffle.compat.spark_3_0.{UcxShuffleBlockResolver, UcxShuffleReader} 11 | import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, SortShuffleWriter, UnsafeShuffleWriter} 12 | import org.apache.spark.util.ShutdownHookManager 13 | import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext} 14 | 15 | /** 16 | * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin 17 | * and injects needed logic in override methods. 18 | */ 19 | class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) { 20 | ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop) 21 | private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) 22 | 23 | override val shuffleBlockResolver = new UcxShuffleBlockResolver(this) 24 | 25 | override def registerShuffle[K, V, C](shuffleId: ShuffleId, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { 26 | assume(isDriver) 27 | val numMaps = dependency.partitioner.numPartitions 28 | val baseHandle = super.registerShuffle(shuffleId, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]] 29 | registerShuffleCommon(baseHandle, shuffleId, numMaps) 30 | } 31 | 32 | override def getWriter[K, V](handle: ShuffleHandle, mapId: Long, context: TaskContext, 33 | metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { 34 | shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, V, _]]) 35 | val env = SparkEnv.get 36 | handle.asInstanceOf[UcxShuffleHandle[K, V, _]].baseHandle match { 37 | case unsafeShuffleHandle: SerializedShuffleHandle[K@unchecked, V@unchecked] => 38 | new UnsafeShuffleWriter( 39 | env.blockManager, 40 | context.taskMemoryManager(), 41 | unsafeShuffleHandle, 42 | mapId, 43 | context, 44 | env.conf, 45 | metrics, 46 | shuffleExecutorComponents) 47 | case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] => 48 | new SortShuffleWriter( 49 | shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents) 50 | } 51 | } 52 | 53 | override def getReader[K, C](handle: ShuffleHandle, startPartition: MapId, endPartition: MapId, 54 | context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { 55 | 56 | startUcxNodeIfMissing() 57 | shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, _, C]]) 58 | new 
UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition, endPartition,
59 |       context, readMetrics = metrics, shouldBatchFetch = true)
60 |   }
61 | 
62 | 
63 |   private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
64 |     val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
65 |     val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX)
66 |       .toMap
67 |     executorComponents.initializeExecutor(
68 |       conf.getAppId,
69 |       SparkEnv.get.executorId,
70 |       extraConfigs.asJava)
71 |     executorComponents
72 |   }
73 | 
74 | }
75 | 
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala:
--------------------------------------------------------------------------------
 1 | /*
 2 |  * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
 3 |  * See file LICENSE for terms.
 4 |  */
 5 | package org.apache.spark.shuffle.compat.spark_3_0
 6 | 
 7 | import java.io.InputStream
 8 | import java.util.concurrent.LinkedBlockingQueue
 9 | 
10 | import scala.collection.JavaConverters._
11 | 
12 | import org.apache.spark.internal.{Logging, config}
13 | import org.apache.spark.io.CompressionCodec
14 | import org.apache.spark.serializer.SerializerManager
15 | import org.apache.spark.shuffle.ucx.reducer.compat.spark_3_0.UcxShuffleClient
16 | import org.apache.spark.shuffle.{ShuffleReadMetricsReporter, ShuffleReader, UcxShuffleHandle, UcxShuffleManager}
17 | import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockBatchId, ShuffleBlockFetcherIterator, ShuffleBlockId}
18 | import org.apache.spark.util.CompletionIterator
19 | import org.apache.spark.util.collection.ExternalSorter
20 | import org.apache.spark.{InterruptibleIterator, SparkEnv, SparkException, TaskContext}
21 | 
22 | 
23 | /**
24 |  * Extension of Spark's shuffle reader that injects UcxShuffleClient and performs
25 |  * lazy progress only when the result queue is empty. 
26 | */ 27 | class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C], 28 | startPartition: Int, 29 | endPartition: Int, 30 | context: TaskContext, 31 | serializerManager: SerializerManager = SparkEnv.get.serializerManager, 32 | blockManager: BlockManager = SparkEnv.get.blockManager, 33 | readMetrics: ShuffleReadMetricsReporter, 34 | shouldBatchFetch: Boolean = false) extends ShuffleReader[K, C] with Logging { 35 | 36 | private val dep = handle.baseHandle.dependency 37 | 38 | /** Read the combined key-values for this reduce task */ 39 | override def read(): Iterator[Product2[K, C]] = { 40 | val (blocksByAddressIterator1, blocksByAddressIterator2) = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( 41 | handle.shuffleId, startPartition, endPartition).duplicate 42 | val mapIdToBlockIndex = blocksByAddressIterator2.flatMap{ 43 | case (_, blocks) => blocks.map { 44 | case (blockId, _, mapIdx) => blockId match { 45 | case x: ShuffleBlockId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer]) 46 | case x: ShuffleBlockBatchId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer]) 47 | case _ => throw new SparkException("Unknown block") 48 | } 49 | } 50 | }.toMap 51 | 52 | val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] 53 | .ucxNode.getThreadLocalWorker 54 | val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() 55 | val shuffleClient = new UcxShuffleClient(handle.shuffleId, workerWrapper, mapIdToBlockIndex.asJava, shuffleMetrics) 56 | val shuffleIterator = new ShuffleBlockFetcherIterator( 57 | context, 58 | shuffleClient, 59 | blockManager, 60 | blocksByAddressIterator1, 61 | serializerManager.wrapStream, 62 | // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility 63 | SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, 64 | SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), 65 | SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), 66 | SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), 67 | SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), 68 | SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), 69 | readMetrics, 70 | fetchContinuousBlocksInBatch) 71 | 72 | val wrappedStreams = shuffleIterator.toCompletionIterator 73 | 74 | // Ucx shuffle logic 75 | // Java reflection to get access to private results queue 76 | val queueField = shuffleIterator.getClass.getDeclaredField( 77 | "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") 78 | queueField.setAccessible(true) 79 | val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]] 80 | 81 | // Do progress if queue is empty before calling next on ShuffleIterator 82 | val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { 83 | override def next(): (BlockId, InputStream) = { 84 | val startTime = System.currentTimeMillis() 85 | workerWrapper.fillQueueWithBlocks(resultQueue) 86 | readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) 87 | wrappedStreams.next() 88 | } 89 | 90 | override def hasNext: Boolean = { 91 | val result = wrappedStreams.hasNext 92 | if (!result) { 93 | shuffleClient.close() 94 | } 95 | result 96 | } 97 | } 98 | // End of ucx shuffle logic 99 | 100 | val serializerInstance = dep.serializer.newInstance() 101 | 102 | // Create a key/value iterator for each stream 103 | val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => 104 | // Note: 
the asKeyValueIterator below wraps a key/value iterator inside of a 105 | // NextIterator. The NextIterator makes sure that close() is called on the 106 | // underlying InputStream when all records have been read. 107 | serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator 108 | } 109 | 110 | // Update the context task metrics for each record read. 111 | val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( 112 | recordIter.map { record => 113 | readMetrics.incRecordsRead(1) 114 | record 115 | }, 116 | context.taskMetrics().mergeShuffleReadMetrics()) 117 | 118 | // An interruptible iterator must be used here in order to support task cancellation 119 | val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) 120 | 121 | val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { 122 | if (dep.mapSideCombine) { 123 | // We are reading values that are already combined 124 | val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] 125 | dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) 126 | } else { 127 | // We don't know the value type, but also don't care -- the dependency *should* 128 | // have made sure its compatible w/ this aggregator, which will convert the value 129 | // type to the combined type C 130 | val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] 131 | dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) 132 | } 133 | } else { 134 | interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] 135 | } 136 | 137 | // Sort the output if there is a sort ordering defined. 138 | val resultIter = dep.keyOrdering match { 139 | case Some(keyOrd: Ordering[K]) => 140 | // Create an ExternalSorter to sort the data. 141 | val sorter = 142 | new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) 143 | sorter.insertAll(aggregatedIter) 144 | context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) 145 | context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) 146 | context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) 147 | // Use completion callback to stop sorter if task was finished/cancelled. 148 | context.addTaskCompletionListener[Unit](_ => { 149 | sorter.stop() 150 | }) 151 | CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) 152 | case None => 153 | aggregatedIter 154 | } 155 | 156 | resultIter match { 157 | case _: InterruptibleIterator[Product2[K, C]] => resultIter 158 | case _ => 159 | // Use another interruptible iterator here to support task cancellation as aggregator 160 | // or(and) sorter may have consumed previous interruptible iterator. 
161 | new InterruptibleIterator[Product2[K, C]](context, resultIter) 162 | } 163 | } 164 | 165 | private def fetchContinuousBlocksInBatch: Boolean = { 166 | val conf = SparkEnv.get.conf 167 | val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects 168 | val compressed = conf.get(config.SHUFFLE_COMPRESS) 169 | val codecConcatenation = if (compressed) { 170 | CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf)) 171 | } else { 172 | true 173 | } 174 | val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL) 175 | 176 | val doBatchFetch = shouldBatchFetch && serializerRelocatable && 177 | (!compressed || codecConcatenation) && !useOldFetchProtocol 178 | if (shouldBatchFetch && !doBatchFetch) { 179 | logWarning("The feature tag of continuous shuffle block fetching is set to true, but " + 180 | "we can not enable the feature because other conditions are not satisfied. " + 181 | s"Shuffle compress: $compressed, serializer ${dep.serializer.getClass.getName} " + 182 | s"relocatable: $serializerRelocatable, " + 183 | s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + 184 | s"$useOldFetchProtocol.") 185 | } 186 | doBatchFetch 187 | } 188 | 189 | } 190 | --------------------------------------------------------------------------------
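All three compat readers above rely on the same trick: Spark's ShuffleBlockFetcherIterator keeps its fetched results in a private LinkedBlockingQueue, and the UCX reader pulls that queue out via Java reflection so it can drive UCX worker progress whenever the queue is empty. A minimal standalone sketch of that reflection step follows; the object and method names are illustrative and not part of the plugin, while the mangled field name is the one referenced in the sources above.

    import java.util.concurrent.LinkedBlockingQueue

    object ResultsQueueAccess {
      // Scala-mangled name of ShuffleBlockFetcherIterator's private `results` field,
      // exactly as referenced by the UcxShuffleReader implementations above.
      private val ResultsField =
        "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results"

      // Pull the private results queue out of a fetcher-iterator instance. The parameter
      // is typed AnyRef so this sketch compiles without Spark on the classpath.
      def privateResultsQueue(fetcherIterator: AnyRef): LinkedBlockingQueue[_] = {
        val field = fetcherIterator.getClass.getDeclaredField(ResultsField)
        field.setAccessible(true)
        field.get(fetcherIterator).asInstanceOf[LinkedBlockingQueue[_]]
      }
    }

The readers then wrap the fetcher iterator and call the worker's fillQueueWithBlocks on this queue before delegating to next(), which is what turns blocking fetch waits into UCX progress calls.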
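To use one of these managers, a job points spark.shuffle.manager at the UcxShuffleManager class shown above and ships the matching spark-ucx jar to the driver and executors. A rough sketch, assuming a Spark 2.4 deployment and an illustrative jar path (the packaged artifact name comes from the Maven build and may differ):

    import org.apache.spark.SparkConf

    // Illustrative path only -- substitute the jar produced by the build for your Spark version.
    val ucxShuffleJar = "/path/to/spark-ucx-for-spark-2.4-jar-with-dependencies.jar"

    val conf = new SparkConf()
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.UcxShuffleManager")
      .set("spark.driver.extraClassPath", ucxShuffleJar)
      .set("spark.executor.extraClassPath", ucxShuffleJar)

Spark instantiates the configured manager on both driver and executors, which is why the registerShuffle path asserts it runs on the driver while getWriter/getReader lazily start the UcxNode on executors.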