├── .artifactignore
├── .github
│   └── workflows
│       ├── sparkucx-ci.yml
│       └── sparkucx-release.yml
├── LICENSE
├── README.md
├── buildlib
│   ├── azure-pipelines.yml
│   └── test.sh
├── pom.xml
└── src
    └── main
        ├── java
        │   └── org
        │       └── apache
        │           └── spark
        │               └── shuffle
        │                   └── ucx
        │                       ├── UcxNode.java
        │                       ├── UnsafeUtils.java
        │                       ├── memory
        │                       │   ├── MemoryPool.java
        │                       │   └── RegisteredMemory.java
        │                       ├── reducer
        │                       │   ├── OnBlocksFetchCallback.java
        │                       │   ├── ReducerCallback.java
        │                       │   └── compat
        │                       │       ├── spark_2_1
        │                       │       │   ├── OnOffsetsFetchCallback.java
        │                       │       │   └── UcxShuffleClient.java
        │                       │       ├── spark_2_4
        │                       │       │   ├── OnOffsetsFetchCallback.java
        │                       │       │   └── UcxShuffleClient.java
        │                       │       └── spark_3_0
        │                       │           ├── OnOffsetsFetchCallback.java
        │                       │           └── UcxShuffleClient.java
        │                       └── rpc
        │                           ├── RpcConnectionCallback.java
        │                           ├── SerializableBlockManagerID.java
        │                           ├── UcxListenerThread.java
        │                           └── UcxRemoteMemory.java
        └── scala
            └── org
                └── apache
                    └── spark
                        └── shuffle
                            ├── CommonUcxShuffleBlockResolver.scala
                            ├── CommonUcxShuffleManager.scala
                            ├── UcxShuffleConf.scala
                            ├── UcxWorkerWrapper.scala
                            └── compat
                                ├── spark_2_1
                                │   ├── UcxShuffleBlockResolver.scala
                                │   ├── UcxShuffleManager.scala
                                │   └── UcxShuffleReader.scala
                                ├── spark_2_4
                                │   ├── UcxShuffleBlockResolver.scala
                                │   ├── UcxShuffleManager.scala
                                │   └── UcxShuffleReader.scala
                                └── spark_3_0
                                    ├── UcxLocalDiskShuffleDataIO.scala
                                    ├── UcxLocalDiskShuffleExecutorComponents.scala
                                    ├── UcxShuffleBlockResolver.scala
                                    ├── UcxShuffleManager.scala
                                    └── UcxShuffleReader.scala
/.artifactignore:
--------------------------------------------------------------------------------
1 | **/*
2 | !target/*.jar
3 |
--------------------------------------------------------------------------------
/.github/workflows/sparkucx-ci.yml:
--------------------------------------------------------------------------------
1 | name: SparkUCX CI
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - master
7 |
8 | jobs:
9 | build-sparkucx:
10 | strategy:
11 | matrix:
12 | spark_version: ["2.1", "2.4", "3.0"]
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v1
16 | - name: Set up JDK 1.11
17 | uses: actions/setup-java@v1
18 | with:
19 | java-version: 1.11
20 | - name: Build with Maven
21 | run: mvn -B package -Pspark-${{ matrix.spark_version }} -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn
22 | --file pom.xml
23 | - name: Run Sonar code analysis
24 | run: mvn -B sonar:sonar -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -Dsonar.projectKey=openucx:spark-ucx -Dsonar.organization=openucx -Dsonar.host.url=https://sonarcloud.io -Dsonar.login=97f4df88ff4fa04e2d5b061acf07315717f1f08b -Pspark-${{ matrix.spark_version }}
25 | env:
26 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
27 |
--------------------------------------------------------------------------------
/.github/workflows/sparkucx-release.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | # Sequence of patterns matched against refs/tags
4 | tags:
5 | - 'v*' # Push events to matching v*, i.e. v1.0, v20.15.10
6 |
7 | name: Upload Release Asset
8 |
9 | env:
10 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
11 |
12 | jobs:
13 | release:
14 | strategy:
15 | matrix:
16 | spark_version: ["2.1", "2.4", "3.0"]
17 | runs-on: ubuntu-latest
18 | steps:
19 | - name: Checkout code
20 | uses: actions/checkout@v2
21 |
22 | - name: Set up JDK 1.11
23 | uses: actions/setup-java@v1
24 | with:
25 | java-version: 1.11
26 |
27 | - name: Build with Maven
28 | id: maven_package
29 | run: |
30 | mvn -B -Pspark-${{ matrix.spark_version }} clean package \
31 | -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn \
32 | --file pom.xml
33 | cd target
34 | echo "::set-output name=jar_name::$(echo spark-ucx-*-jar-with-dependencies.jar)"
35 |
36 | - name: Upload Release Jars
37 | uses: svenstaro/upload-release-action@v1-release
38 | with:
39 | repo_token: ${{ secrets.GITHUB_TOKEN }}
40 | file: ./target/${{ steps.maven_package.outputs.jar_name }}
41 | asset_name: ${{ steps.maven_package.outputs.jar_name }}
42 | tag: ${{ github.ref }}
43 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2019 Mellanox Technologies Ltd. All rights reserved.
2 | Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions
6 | are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright
9 | notice, this list of conditions and the following disclaimer.
10 | 2. Redistributions in binary form must reproduce the above copyright
11 | notice, this list of conditions and the following disclaimer in the
12 | documentation and/or other materials provided with the distribution.
13 | 3. Neither the name of the copyright holder nor the names of its
14 | contributors may be used to endorse or promote products derived from
15 | this software without specific prior written permission.
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
21 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
23 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
24 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
25 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
26 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SparkUCX ShuffleManager Plugin
2 | SparkUCX is a high-performance ShuffleManager plugin for Apache Spark that uses RDMA and other high-performance transports
3 | supported by [UCX](https://github.com/openucx/ucx#supported-transports) to perform shuffle data transfers in Spark jobs.
4 |
5 | This open-source project is developed, maintained and supported by the [UCF consortium](http://www.ucfconsortium.org/).
6 |
7 | ## Runtime requirements
8 | * Apache Spark 2.1/2.4/3.0
9 | * Java 8+
10 | * UCX version 1.10+ installed, and [UCX-supported transport hardware](https://github.com/openucx/ucx#supported-transports).
11 |
12 | ## Installation
13 |
14 | ### Obtain SparkUCX
15 | Please use the ["Releases"](https://github.com/openucx/sparkucx/releases) page to download the SparkUCX jar file
16 | for your Spark version (e.g. spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar).
17 | Put the SparkUCX jar file in $SPARK_UCX_HOME on all the nodes of your cluster.
18 |
If you would like to build the project yourself, please refer to the ["Build"](https://github.com/openucx/sparkucx#build) section below.
19 |
20 | UCX binaries **must** be in the Spark classpath on every Spark Master and Worker.
21 | They can be obtained by installing the latest version from the [UCX release page](https://github.com/openucx/ucx/releases).
22 |
23 | ### Configuration
24 |
25 | Provide Spark with the location of the SparkUCX plugin jar and the UCX shared libraries using the extraClassPath option.
26 |
27 | ```
28 | spark.driver.extraClassPath $SPARK_UCX_HOME/spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar:$UCX_PREFIX/lib
29 | spark.executor.extraClassPath $SPARK_UCX_HOME/spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar:$UCX_PREFIX/lib
30 | ```
31 | To enable the SparkUCX Shuffle Manager plugin, add the following configuration property
32 | to Spark (e.g. in $SPARK_HOME/conf/spark-defaults.conf):
33 |
34 | ```
35 | spark.shuffle.manager org.apache.spark.shuffle.UcxShuffleManager
36 | ```
37 | For Spark 3.0, also add the SparkUCX ShuffleIO plugin:
38 | ```
39 | spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_3_0.UcxLocalDiskShuffleDataIO
40 | ```
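
The same properties can also be supplied per job on the `spark-submit` command line. A minimal sketch (the jar and UCX paths, and `your-application.jar`, are illustrative; adjust them to your installation):

```
spark-submit \
  --conf "spark.driver.extraClassPath=$SPARK_UCX_HOME/spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar:$UCX_PREFIX/lib" \
  --conf "spark.executor.extraClassPath=$SPARK_UCX_HOME/spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar:$UCX_PREFIX/lib" \
  --conf "spark.shuffle.manager=org.apache.spark.shuffle.UcxShuffleManager" \
  your-application.jar
```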
41 |
42 | ### Build
43 |
44 | Building the SparkUCX plugin requires [Apache Maven](http://maven.apache.org/) and a Java 8+ JDK.
45 |
46 | Build instructions:
47 |
48 | ```
49 | % git clone https://github.com/openucx/sparkucx
50 | % cd sparkucx
51 | % mvn -DskipTests clean package -Pspark-2.4
52 | ```
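
The `-Pspark-2.4` profile above selects the Spark line to build against; the pom also defines `spark-2.1` and `spark-3.0` profiles (spark-3.0 is active by default). For example, a Spark 3.0 build would be:

```
% mvn -DskipTests clean package -Pspark-3.0
```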
53 |
54 | ### Performance
55 |
56 | The SparkUCX plugin is built to provide the best performance out of the box, and provides multiple configuration options to further tune SparkUCX per job. For more information on how to set up the [HiBench](https://github.com/Intel-bigdata/HiBench) benchmark and reproduce results, please refer to [Accelerated Apache SparkUCX 2.4/3.0 cluster deployment](https://docs.mellanox.com/pages/releaseview.action?pageId=19819236).
57 |
58 | 
59 |
60 |
--------------------------------------------------------------------------------
/buildlib/azure-pipelines.yml:
--------------------------------------------------------------------------------
1 | # See https://aka.ms/yaml
2 |
3 | trigger:
4 | - master
5 | - v*.*.x
6 | pr:
7 | - master
8 | - v*.*.x
9 |
10 | stages:
11 | - stage: Build
12 | jobs:
13 | - job: build
14 | strategy:
15 | maxParallel: 1
16 | matrix:
17 | spark-2.4:
18 | profile_version: "2.4"
19 | spark_version: "2.4.5"
20 | spark-3.0:
21 | profile_version: "3.0"
22 | spark_version: "3.0.1"
23 | pool:
24 | name: MLNX
25 | demands: Maven
26 | steps:
27 | - task: Maven@3
28 | displayName: build
29 | inputs:
30 | javaHomeOption: "path"
31 | jdkDirectory: "/hpc/local/oss/java/jdk/"
32 | jdkVersionOption: "1.8"
33 | mavenVersionSelection: "Path"
34 | mavenPath: "/hpc/local/oss/apache-maven-3.3.9"
35 | mavenSetM2Home: true
36 | publishJUnitResults: false
37 | goals: "package"
38 | options: "-B -Dmaven.repo.local=$(System.DefaultWorkingDirectory)/target/.deps -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -Pspark-$(profile_version)"
39 | - bash: |
40 | set -xeE
41 | module load tools/spark-$(spark_version) hpcx-gcc
42 | source buildlib/test.sh
43 |
44 | if [[ $(get_rdma_device_iface) != "" ]]
45 | then
46 | export SPARK_UCX_JAR=$(System.DefaultWorkingDirectory)/target/spark-ucx-1.0-for-spark-$(profile_version)-jar-with-dependencies.jar
47 | export SPARK_LOCAL_DIRS=$(System.DefaultWorkingDirectory)/target/spark
48 | export SPARK_VERSION=$(spark_version)
49 | export UCX_LIB=$HPCX_UCX_DIR/lib
50 | cd $(System.DefaultWorkingDirectory)/target/
51 | run_tests
52 | else
53 | echo "##vso[task.complete result=Skipped;]No IB devices found"
54 | fi
55 | displayName: Run spark tests
56 |
--------------------------------------------------------------------------------
/buildlib/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eExl
2 | #
3 | # Testing script for SparkUCX
4 | #
5 | # Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
6 | #
7 | # Environment variables:
8 | # - NODELIST : Space separated list of nodes (default: localhost)
9 | # - UCX_LIB : Path to UCX installation (default: LD_LIBRARY_PATH)
10 | # - SPARK_UCX_JAR : Path to SparkUCX jar (default in $PWD)
11 | # - PROCESSES_PER_INSTANCE : Number of spark processes per instance (default: 2)
12 | # - EXECUTOR_NUMBER : Unique number of jenkins executor, to run multiple tests on 1 machine
13 | # - JOB_ID : Unique job id for this test
14 | # - SCRATCH_DIRECTORY : Separate directory for this test where configs will be added.
15 | # For now it's required to be on NFS
16 | # - SPARK_LOCAL_DIRS : Directory to use for map output files and RDDs that get stored on disk.
17 | # This should be on a fast, local disk in your system.
18 | # It can also be a comma-separated list of multiple directories on different disks.
19 | # - SPARK_WORKER_CORES : Number of cores per spark process (default: 2)
20 | # - SPARK_WORKER_MEMORY : Memory per spark process (default: 1g)
21 | # - SPARK_HOME : Spark location on the file system (default: $PWD/spark). If not set, will be downloaded
22 | # - SPARK_VERSION : Version of spark to test (default: 2.4.4). Used only if SPARK_HOME is not set
23 | # - RDMA_NET_IFACE : Network interface to test
24 |
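# Example usage (paths are illustrative; the CI pipeline in buildlib/azure-pipelines.yml does the equivalent):
#   export SPARK_UCX_JAR=$PWD/target/spark-ucx-1.0-for-spark-3.0-jar-with-dependencies.jar
#   export UCX_LIB=/path/to/ucx/lib
#   export SPARK_VERSION=3.0.1
#   source buildlib/test.sh
#   run_tests
#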
25 | NODELIST=${NODELIST:="localhost"}
26 |
27 | UCX_LIB=${UCX_LIB:=${LD_LIBRARY_PATH}}
28 |
29 | SPARK_UCX_JAR=${SPARK_UCX_JAR:=$PWD/spark-ucx-1.0-for-spark-2.4-jar-with-dependencies.jar}
30 |
31 | PROCESSES_PER_INSTANCE=${PROCESSES_PER_INSTANCE:=2}
32 |
33 | if [ -n "$EXECUTOR_NUMBER" ]
34 | then
35 | AFFINITY="taskset -c $(( 2 * EXECUTOR_NUMBER ))","$(( 2 * EXECUTOR_NUMBER + 1))"
36 | else
37 | AFFINITY=""
38 | fi
39 |
40 | EXECUTOR_NUMBER=${EXECUTOR_NUMBER:=$((RANDOM % `nproc`))}
41 |
42 | JOB_ID=${SLURM_JOB_ID:="sparkucx-test-$(cat /proc/sys/kernel/random/uuid)"}
43 |
44 | SCRATCH_DIRECTORY=${SCRATCH_DIRECTORY:=$PWD/${JOB_ID}}
45 |
46 | SPARK_LOCAL_DIRS=${SPARK_LOCAL_DIRS:="/scrap/sparkucxtest/sparkucx-${JOB_ID}"}
47 |
48 | SPARK_WORKER_CORES=${SPARK_WORKER_CORES:=2}
49 |
50 | SPARK_WORKER_MEMORY=${SPARK_WORKER_MEMORY:="1g"}
51 |
52 | SPARK_HOME=${SPARK_HOME:=$PWD/spark}
53 |
54 | SPARK_VERSION=${SPARK_VERSION:="2.4.4"}
55 |
56 | UCX_BRANCH=${UCX_BRANCH:="master"}
57 |
58 | export SPARK_MASTER_HOST=${SPARK_MASTER_HOST:=$(hostname -f)}
59 | export SPARK_MASTER_PORT=${SPARK_MASTER_PORT:=$(( 2000 + 10 * ${EXECUTOR_NUMBER}))}
60 | export SPARK_CONF_DIR=${SCRATCH_DIRECTORY}/conf
61 |
62 | get_rdma_device_iface() {
63 |
64 | if [ ! -r /dev/infiniband/rdma_cm ]
65 | then
66 | return
67 | fi
68 |
69 | if ! which ibdev2netdev >&/dev/null
70 | then
71 | return
72 | fi
73 |
74 | iface=`ibdev2netdev | grep Up | awk '{print $5}' | head -1`
75 | if [[ -n "$iface" ]]
76 | then
77 | ipaddr=$(ip addr show ${iface} | awk '/inet /{print $2}' | awk -F '/' '{print $1}')
78 | fi
79 |
80 | if [[ -z "$ipaddr" ]]
81 | then
82 | # if there is no inet (IPv4) address, escape
83 | return
84 | fi
85 |
86 | ibdev=`ibdev2netdev | grep ${iface} | awk '{print $1}'`
87 | node_guid=`cat /sys/class/infiniband/$ibdev/node_guid`
88 | if [[ ${node_guid} == "0000:0000:0000:0000" ]]
89 | then
90 | return
91 | fi
92 |
93 | SPARK_MASTER_HOST=$ipaddr
94 | echo ${iface}
95 | }
96 |
97 | RDMA_NET_IFACE=${RDMA_NET_IFACE:=`get_rdma_device_iface`}
98 |
99 | download_spark() {
100 | curl -s -L -O https://www-eu.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop2.7.tgz
101 | mkdir -p ${SPARK_HOME}
102 | tar -xf spark-${SPARK_VERSION}-bin-hadoop2.7.tgz -C ${SPARK_HOME} --strip-components=1
103 | }
104 |
105 | build_ucx() {
106 | git clone -b ${UCX_BRANCH} --depth=1 https://github.com/openucx/ucx.git && cd ucx
107 | ./autogen.sh
108 | mkdir build && cd build
109 | ../contrib/configure-release-mt --with-java --prefix=$PWD
110 | make -j `nproc`
111 | make install
112 | UCX_LIB=$PWD/lib/
113 | }
114 |
115 | setup_configuration() {
116 | mkdir -p ${SPARK_CONF_DIR}
117 |
118 | echo ${NODELIST} | tr -s ' ' '\n' >> "${SPARK_CONF_DIR}/slaves"
119 |
120 | cat <<-EOF > ${SPARK_CONF_DIR}/spark-defaults.conf
121 | spark.shuffle.manager org.apache.spark.shuffle.UcxShuffleManager
122 | spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_3_0.UcxLocalDiskShuffleDataIO
123 | spark.shuffle.readHostLocalDisk.enabled false
124 | spark.driver.extraClassPath ${SPARK_UCX_JAR}:${UCX_LIB}
125 | spark.executor.extraClassPath ${SPARK_UCX_JAR}:${UCX_LIB}
126 | spark.shuffle.ucx.driver.port $(( ${SPARK_MASTER_PORT} + 1 ))
127 | EOF
128 |
129 | cat <<-EOF > ${SPARK_CONF_DIR}/spark-env.sh
130 | export SPARK_LOCAL_IP=\`/sbin/ip addr show ${RDMA_NET_IFACE} | grep "inet\b" | awk '{print \$2}' | cut -d/ -f1\`
131 | export SPARK_WORKER_DIR=${SCRATCH_DIRECTORY}/work
132 | export SPARK_LOCAL_DIRS=${SPARK_LOCAL_DIRS}
133 | export SPARK_LOG_DIR=${SCRATCH_DIRECTORY}/logs
134 | export SPARK_CONF_DIR=${SPARK_CONF_DIR}
135 | export SPARK_MASTER_HOST=${SPARK_MASTER_HOST}
136 | export SPARK_MASTER_PORT=${SPARK_MASTER_PORT}
137 | export SPARK_WORKER_CORES=${SPARK_WORKER_CORES}
138 | export SPARK_WORKER_MEMORY=${SPARK_WORKER_MEMORY}
139 | export SPARK_IDENT_STRING=${JOB_ID}
140 | EOF
141 |
142 | cp ${SPARK_HOME}/conf/log4j.properties.template ${SPARK_CONF_DIR}/log4j.properties
143 | sed -i -e 's/INFO/WARN/g' ${SPARK_CONF_DIR}/log4j.properties
144 | echo "log4j.logger.org.apache.spark.shuffle=DEBUG" >> ${SPARK_CONF_DIR}/log4j.properties
145 | }
146 |
147 | start_cluster() {
148 | ${AFFINITY} ${SPARK_HOME}/sbin/start-master.sh
149 |
150 | # Make a script wrapper to propagate SPARK_CONF_DIR
151 | cat <<-EOF > ${SCRATCH_DIRECTORY}/sparkworker.sh
152 | #! /bin/bash
153 | export SPARK_CONF_DIR=${SPARK_CONF_DIR}
154 | export SPARK_WORKER_INSTANCES=${PROCESSES_PER_INSTANCE}
155 | export SPARK_IDENT_STRING=${JOB_ID}
156 | ${AFFINITY} ${SPARK_HOME}/sbin/start-slave.sh "spark://${SPARK_MASTER_HOST}:${SPARK_MASTER_PORT}"
157 | EOF
158 |
159 | SPARK_CONF_DIR=${SPARK_CONF_DIR} ${SPARK_HOME}/sbin/slaves.sh bash ${SCRATCH_DIRECTORY}/sparkworker.sh
160 | }
161 |
162 | run_groupby_test() {
163 | ${SPARK_HOME}/bin/run-example --verbose --master spark://${SPARK_MASTER_HOST}:${SPARK_MASTER_PORT} \
164 | --jars "${SPARK_HOME}/examples/jars/*.jar" --executor-memory ${SPARK_WORKER_MEMORY} \
165 | org.apache.spark.examples.GroupByTest 100 100
166 | }
167 |
168 | run_tc_test() {
169 | ${SPARK_HOME}/bin/run-example --verbose --master spark://${SPARK_MASTER_HOST}:${SPARK_MASTER_PORT} \
170 | --jars "${SPARK_HOME}/examples/jars/*.jar" --executor-memory ${SPARK_WORKER_MEMORY} \
171 | org.apache.spark.examples.SparkTC
172 | }
173 |
174 | run_tests() {
175 | if [[ ! -d ${SPARK_HOME} ]]
176 | then
177 | download_spark
178 | fi
179 |
180 | if [[ ! -d ${UCX_LIB} ]]
181 | then
182 | build_ucx
183 | fi
184 |
185 | trap stop_cluster EXIT;
186 |
187 | setup_configuration
188 | start_cluster
189 | run_groupby_test && run_tc_test
190 | }
191 |
192 | stop_cluster() {
193 | cat <<-EOF > ${SCRATCH_DIRECTORY}/stop-sparkworker.sh
194 | #! /bin/bash
195 | export SPARK_CONF_DIR=${SPARK_CONF_DIR}
196 | export SPARK_WORKER_INSTANCES=${PROCESSES_PER_INSTANCE}
197 | export SPARK_IDENT_STRING=${JOB_ID}
198 | ${SPARK_HOME}/sbin/stop-slave.sh
199 | EOF
200 |
201 | chmod +x ${SCRATCH_DIRECTORY}/stop-sparkworker.sh
202 | # Stop all slaves
203 | ${SPARK_HOME}/sbin/slaves.sh ${SCRATCH_DIRECTORY}/stop-sparkworker.sh
204 |
205 | ${SPARK_HOME}/sbin/stop-master.sh
206 | }
207 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>org.openucx</groupId>
  <artifactId>spark-ucx</artifactId>
  <version>1.0</version>
  <name>${project.artifactId}</name>
  <description>
    A high-performance, scalable and efficient shuffle manager plugin for Apache Spark,
    utilizing UCX communication layer (https://github.com/openucx/ucx/).
  </description>
  <packaging>jar</packaging>

  <licenses>
    <license>
      <name>BSD 3 Clause License</name>
      <url>http://www.openucx.org/license/</url>
      <distribution>repo</distribution>
    </license>
  </licenses>

  <properties>
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
  </properties>

  <profiles>
    <profile>
      <id>spark-2.1</id>
      <build>
        <plugins>
          <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-compiler-plugin</artifactId>
            <configuration>
              <excludes>
                <exclude>**/spark_3_0/**</exclude>
                <exclude>**/spark_2_4/**</exclude>
              </excludes>
            </configuration>
          </plugin>
          <plugin>
            <groupId>net.alchim31.maven</groupId>
            <artifactId>scala-maven-plugin</artifactId>
            <configuration>
              <excludes>
                <exclude>**/spark_3_0/**</exclude>
                <exclude>**/spark_2_4/**</exclude>
              </excludes>
            </configuration>
          </plugin>
        </plugins>
      </build>
      <properties>
        <spark.version>2.1.0</spark.version>
        <!-- comma-separated exclude list for this profile (property tag name assumed) -->
        <profile.excludes>**/spark_3_0/**, **/spark_2_4/**</profile.excludes>
        <scala.version>2.11.12</scala.version>
        <scala.compat.version>2.11</scala.compat.version>
      </properties>
    </profile>
    <profile>
      <id>spark-2.4</id>
      <build>
        <plugins>
          <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-compiler-plugin</artifactId>
            <configuration>
              <excludes>
                <exclude>**/spark_3_0/**</exclude>
                <exclude>**/spark_2_1/**</exclude>
              </excludes>
            </configuration>
          </plugin>
          <plugin>
            <groupId>net.alchim31.maven</groupId>
            <artifactId>scala-maven-plugin</artifactId>
            <configuration>
              <excludes>
                <exclude>**/spark_2_1/**</exclude>
                <exclude>**/spark_3_0/**</exclude>
              </excludes>
            </configuration>
          </plugin>
        </plugins>
      </build>
      <properties>
        <spark.version>2.4.0</spark.version>
        <profile.excludes>**/spark_3_0/**, **/spark_2_1/**</profile.excludes>
        <scala.version>2.11.12</scala.version>
        <scala.compat.version>2.11</scala.compat.version>
      </properties>
    </profile>
    <profile>
      <id>spark-3.0</id>
      <activation>
        <activeByDefault>true</activeByDefault>
      </activation>
      <build>
        <plugins>
          <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-compiler-plugin</artifactId>
            <configuration>
              <excludes>
                <exclude>**/spark_2_1/**</exclude>
                <exclude>**/spark_2_4/**</exclude>
              </excludes>
            </configuration>
          </plugin>
          <plugin>
            <groupId>net.alchim31.maven</groupId>
            <artifactId>scala-maven-plugin</artifactId>
            <configuration>
              <excludes>
                <exclude>**/spark_2_1/**</exclude>
                <exclude>**/spark_2_4/**</exclude>
              </excludes>
            </configuration>
          </plugin>
        </plugins>
      </build>
      <properties>
        <spark.version>3.0.1</spark.version>
        <scala.version>2.12.10</scala.version>
        <scala.compat.version>2.12</scala.compat.version>
        <profile.excludes>**/spark_2_1/**, **/spark_2_4/**</profile.excludes>
      </properties>
    </profile>
  </profiles>

  <dependencies>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_${scala.compat.version}</artifactId>
      <version>${spark.version}</version>
      <scope>provided</scope>
    </dependency>
    <dependency>
      <groupId>org.openucx</groupId>
      <artifactId>jucx</artifactId>
      <version>1.11.0-rc3</version>
    </dependency>
  </dependencies>

  <build>
    <finalName>${project.artifactId}-${project.version}-for-${project.activeProfiles[0].id}</finalName>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <version>3.8.1</version>
        <configuration>
          <source>1.8</source>
          <target>1.8</target>
        </configuration>
      </plugin>
      <plugin>
        <groupId>net.alchim31.maven</groupId>
        <artifactId>scala-maven-plugin</artifactId>
        <version>4.3.0</version>
        <configuration>
          <recompileMode>all</recompileMode>
          <args>
            <arg>-nobootcp</arg>
            <arg>-Xexperimental</arg>
            <arg>-Xfatal-warnings</arg>
            <arg>-explaintypes</arg>
            <arg>-unchecked</arg>
            <arg>-deprecation</arg>
            <arg>-feature</arg>
          </args>
        </configuration>
        <executions>
          <execution>
            <id>compile</id>
            <goals>
              <goal>compile</goal>
            </goals>
            <phase>compile</phase>
          </execution>
          <execution>
            <phase>process-resources</phase>
            <goals>
              <goal>compile</goal>
            </goals>
          </execution>
        </executions>
      </plugin>
      <plugin>
        <artifactId>maven-assembly-plugin</artifactId>
        <version>3.1.1</version>
        <configuration>
          <descriptorRefs>
            <descriptorRef>jar-with-dependencies</descriptorRef>
          </descriptorRefs>
        </configuration>
        <executions>
          <execution>
            <id>make-assembly</id>
            <phase>package</phase>
            <goals>
              <goal>single</goal>
            </goals>
          </execution>
        </executions>
      </plugin>
    </plugins>
  </build>

  <repositories>
    <repository>
      <id>oss.sonatype.org-snapshot</id>
      <url>http://oss.sonatype.org/content/repositories/snapshots</url>
      <releases>
        <enabled>false</enabled>
      </releases>
      <snapshots>
        <enabled>true</enabled>
      </snapshots>
    </repository>
  </repositories>
</project>
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/UcxNode.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx;
6 |
7 | import org.apache.spark.SparkEnv;
8 | import org.apache.spark.shuffle.UcxShuffleConf;
9 | import org.apache.spark.shuffle.UcxWorkerWrapper;
10 | import org.apache.spark.shuffle.ucx.memory.MemoryPool;
11 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
12 | import org.apache.spark.shuffle.ucx.rpc.SerializableBlockManagerID;
13 | import org.apache.spark.shuffle.ucx.rpc.UcxListenerThread;
14 | import org.apache.spark.storage.BlockManagerId;
15 | import org.openucx.jucx.UcxCallback;
16 | import org.openucx.jucx.UcxException;
17 | import org.openucx.jucx.ucp.*;
18 |
19 | import java.io.Closeable;
20 | import java.io.IOException;
21 | import java.net.InetSocketAddress;
22 | import java.nio.ByteBuffer;
23 | import java.util.ArrayList;
24 | import java.util.List;
25 | import java.util.Set;
26 | import java.util.concurrent.ConcurrentHashMap;
27 | import java.util.concurrent.ConcurrentMap;
28 |
29 | import org.slf4j.Logger;
30 | import org.slf4j.LoggerFactory;
31 |
32 | /**
33 | * Single-instance class per Spark process that keeps the UcpContext, memory and worker pools.
34 | */
35 | public class UcxNode implements Closeable {
36 | // Global
37 | private static final Logger logger = LoggerFactory.getLogger(UcxNode.class);
38 | private final boolean isDriver;
39 | private final UcpContext context;
40 | private final MemoryPool memoryPool;
41 | private final UcpWorkerParams workerParams = new UcpWorkerParams();
42 | private final UcpWorker globalWorker;
43 | private final UcxShuffleConf conf;
44 | // Mapping from spark's entity of BlockManagerId to UcxEntity workerAddress.
45 | private static final ConcurrentHashMap<BlockManagerId, ByteBuffer> workerAdresses =
46 | new ConcurrentHashMap<>();
47 | private final Thread listenerProgressThread;
48 | private boolean closed = false;
49 |
50 | // Driver
51 | private UcpListener listener;
52 | // Mapping from UcpEndpoint to ByteBuffer of RPC message, to introduce executor to cluster
53 | private static final ConcurrentHashMap<UcpEndpoint, ByteBuffer> rpcConnections =
54 | new ConcurrentHashMap<>();
55 | private List<UcpEndpoint> backwardEndpoints = new ArrayList<>();
56 |
57 | // Executor
58 | private UcpEndpoint globalDriverEndpoint;
59 | // Keep track of allocated workers to correctly close them.
60 | private static final Set<UcxWorkerWrapper> allocatedWorkers = ConcurrentHashMap.newKeySet();
61 | private final ThreadLocal<UcxWorkerWrapper> threadLocalWorker;
62 |
63 | public UcxNode(UcxShuffleConf conf, boolean isDriver) {
64 | this.conf = conf;
65 | this.isDriver = isDriver;
66 | UcpParams params = new UcpParams().requestTagFeature()
67 | .requestRmaFeature().requestWakeupFeature()
68 | .setMtWorkersShared(true);
69 | context = new UcpContext(params);
70 | memoryPool = new MemoryPool(context, conf);
71 | globalWorker = context.newWorker(workerParams);
72 | InetSocketAddress driverAddress = new InetSocketAddress(conf.driverHost(), conf.driverPort());
73 |
74 | if (isDriver) {
75 | startDriver(driverAddress);
76 | } else {
77 | startExecutor(driverAddress);
78 | }
79 |
80 | // Global listener thread, that keeps lazy progress for connection establishment
81 | listenerProgressThread = new UcxListenerThread(this, isDriver);
82 | listenerProgressThread.start();
83 |
84 | if (!isDriver) {
85 | memoryPool.preAlocate();
86 | }
87 |
88 | threadLocalWorker = ThreadLocal.withInitial(() -> {
89 | UcpWorker localWorker = context.newWorker(workerParams);
90 | UcxWorkerWrapper result = new UcxWorkerWrapper(localWorker,
91 | conf, allocatedWorkers.size());
92 | if (result.id() > conf.coresPerProcess()) {
93 | logger.warn("Thread: {} - creates new worker {} > numCores",
94 | Thread.currentThread().getId(), result.id());
95 | }
96 | allocatedWorkers.add(result);
97 | return result;
98 | });
99 | }
100 |
101 | private void startDriver(InetSocketAddress driverAddress) {
102 | // 1. Start listener on a driver and accept RPC messages from executors with their
103 | // worker addresses
104 | UcpListenerParams listenerParams = new UcpListenerParams().setSockAddr(driverAddress)
105 | .setConnectionHandler(ucpConnectionRequest ->
106 | backwardEndpoints.add(globalWorker.newEndpoint(new UcpEndpointParams()
107 | .setConnectionRequest(ucpConnectionRequest))));
108 | listener = globalWorker.newListener(listenerParams);
109 | logger.info("Started UcxNode on {}", driverAddress);
110 | }
111 |
112 | /**
113 | * Allocates a ByteBuffer from the memoryPool and serializes the workerAddress into it,
114 | * followed by BlockManagerID
115 | * @return RegisteredMemory that holds metadata buffer.
116 | */
117 | private RegisteredMemory buildMetadataBuffer() {
118 | BlockManagerId blockManagerId = SparkEnv.get().blockManager().blockManagerId();
119 | ByteBuffer workerAddresses = globalWorker.getAddress();
120 |
121 | RegisteredMemory metadataMemory = memoryPool.get(conf.metadataRPCBufferSize());
122 | ByteBuffer metadataBuffer = metadataMemory.getBuffer();
123 | metadataBuffer.putInt(workerAddresses.capacity());
124 | metadataBuffer.put(workerAddresses);
125 | try {
126 | SerializableBlockManagerID.serializeBlockManagerID(blockManagerId, metadataBuffer);
127 | } catch (IOException e) {
128 | String errorMsg = String.format("Failed to serialize %s: %s", blockManagerId,
129 | e.getMessage());
130 | throw new UcxException(errorMsg);
131 | }
132 | metadataBuffer.clear();
133 | return metadataMemory;
134 | }
135 |
136 | private void startExecutor(InetSocketAddress driverAddress) {
137 | // 1. Executor: connect to driver using sockaddr
138 | // and send its worker address followed by BlockManagerID.
139 | globalDriverEndpoint = globalWorker.newEndpoint(
140 | new UcpEndpointParams().setSocketAddress(driverAddress).setPeerErrorHandlingMode()
141 | );
142 |
143 | RegisteredMemory metadataMemory = buildMetadataBuffer();
144 | // TODO: send using stream API when it would be available in jucx.
145 | globalDriverEndpoint.sendTaggedNonBlocking(metadataMemory.getBuffer(), new UcxCallback() {
146 | @Override
147 | public void onSuccess(UcpRequest request) {
148 | memoryPool.put(metadataMemory);
149 | }
150 | });
151 | }
152 |
153 | public UcxShuffleConf getConf() {
154 | return conf;
155 | }
156 |
157 | public UcpWorker getGlobalWorker() {
158 | return globalWorker;
159 | }
160 |
161 | public MemoryPool getMemoryPool() {
162 | return memoryPool;
163 | }
164 |
165 | public UcpContext getContext() {
166 | return context;
167 | }
168 |
169 | /**
170 | * Get or initialize worker for current thread
171 | */
172 | public UcxWorkerWrapper getThreadLocalWorker() {
173 | return threadLocalWorker.get();
174 | }
175 |
176 | public static ConcurrentMap<BlockManagerId, ByteBuffer> getWorkerAddresses() {
177 | return workerAdresses;
178 | }
179 |
180 | public static ConcurrentMap<UcpEndpoint, ByteBuffer> getRpcConnections() {
181 | return rpcConnections;
182 | }
183 |
184 | private void stopDriver() {
185 | if (listener != null) {
186 | listener.close();
187 | listener = null;
188 | }
189 | rpcConnections.keySet().forEach(UcpEndpoint::close);
190 | rpcConnections.clear();
191 | backwardEndpoints.forEach(UcpEndpoint::close);
192 | backwardEndpoints.clear();
193 | }
194 |
195 | private void stopExecutor() {
196 | if (globalDriverEndpoint != null) {
197 | globalDriverEndpoint.close();
198 | globalDriverEndpoint = null;
199 | }
200 | allocatedWorkers.forEach(UcxWorkerWrapper::close);
201 | allocatedWorkers.clear();
202 | }
203 |
204 | @Override
205 | public void close() {
206 | threadLocalWorker.remove();
207 | synchronized (this) {
208 | if (!closed) {
209 | logger.info("Stopping UcxNode");
210 | listenerProgressThread.interrupt();
211 | globalWorker.signal();
212 | try {
213 | listenerProgressThread.join();
214 | if (isDriver) {
215 | stopDriver();
216 | } else {
217 | stopExecutor();
218 | }
219 | memoryPool.close();
220 | globalWorker.close();
221 | context.close();
222 | closed = true;
223 | } catch (InterruptedException e) {
224 | logger.error(e.getMessage());
225 | Thread.currentThread().interrupt();
226 | } catch (Exception ex) {
227 | logger.warn(ex.getLocalizedMessage());
228 | }
229 | }
230 | }
231 | }
232 | }
233 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/UnsafeUtils.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx;
6 |
7 | import org.openucx.jucx.UcxException;
8 | import org.slf4j.Logger;
9 | import org.slf4j.LoggerFactory;
10 | import sun.nio.ch.FileChannelImpl;
11 |
12 | import java.io.IOException;
13 | import java.lang.reflect.Constructor;
14 | import java.lang.reflect.InvocationTargetException;
15 | import java.lang.reflect.Method;
16 | import java.nio.ByteBuffer;
17 | import java.nio.channels.FileChannel;
18 |
19 | /**
20 | * Java's native mmap functionality that allows mmapping files larger than 2GB.
21 | */
22 | public class UnsafeUtils {
23 | private static final Method mmap;
24 | private static final Method unmmap;
25 | private static final Logger logger = LoggerFactory.getLogger(UnsafeUtils.class);
26 |
27 | private static final Constructor<?> directBufferConstructor;
28 |
29 | public static final int LONG_SIZE = 8;
30 | public static final int INT_SIZE = 4;
31 |
32 | static {
33 | try {
34 | mmap = FileChannelImpl.class.getDeclaredMethod("map0", int.class, long.class, long.class);
35 | mmap.setAccessible(true);
36 | unmmap = FileChannelImpl.class.getDeclaredMethod("unmap0", long.class, long.class);
37 | unmmap.setAccessible(true);
38 | Class<?> classDirectByteBuffer = Class.forName("java.nio.DirectByteBuffer");
39 | directBufferConstructor = classDirectByteBuffer.getDeclaredConstructor(long.class, int.class);
40 | directBufferConstructor.setAccessible(true);
41 | } catch (Exception e) {
42 | throw new RuntimeException(e);
43 | }
44 | }
45 |
46 | private UnsafeUtils() {}
47 |
48 | public static long mmap(FileChannel fileChannel, long offset, long length) {
49 | long result;
50 | try {
51 | result = (long)mmap.invoke(fileChannel, 1, offset, length);
52 | } catch (Exception e) {
53 | logger.error("MMap({}, {}) failed: {}", offset, length, e.getMessage());
54 | throw new UcxException(e.getMessage());
55 | }
56 | return result;
57 | }
58 |
59 | public static void munmap(long address, long length) {
60 | try {
61 | unmmap.invoke(null, address, length);
62 | } catch (IllegalAccessException | InvocationTargetException e) {
63 | logger.error(e.getMessage());
64 | }
65 | }
66 |
67 | public static ByteBuffer getByteBuffer(long address, int length) throws IOException {
68 | try {
69 | return (ByteBuffer)directBufferConstructor.newInstance(address, length);
70 | } catch (InvocationTargetException ex) {
71 | throw new IOException("java.nio.DirectByteBuffer: " +
72 | "InvocationTargetException: " + ex.getTargetException());
73 | } catch (Exception e) {
74 | throw new IOException("java.nio.DirectByteBuffer exception: " + e.getMessage());
75 | }
76 | }
77 | }
78 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/memory/MemoryPool.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.memory;
6 |
7 | import org.apache.spark.shuffle.UcxShuffleConf;
8 | import org.apache.spark.shuffle.ucx.UnsafeUtils;
9 | import org.openucx.jucx.UcxException;
10 | import org.openucx.jucx.UcxUtils;
11 | import org.openucx.jucx.ucp.UcpContext;
12 | import org.openucx.jucx.ucp.UcpMemMapParams;
13 | import org.openucx.jucx.ucp.UcpMemory;
14 | import org.slf4j.Logger;
15 | import org.slf4j.LoggerFactory;
16 |
17 | import java.io.Closeable;
18 | import java.nio.ByteBuffer;
19 | import java.util.concurrent.ConcurrentHashMap;
20 | import java.util.concurrent.ConcurrentLinkedDeque;
21 | import java.util.concurrent.atomic.AtomicInteger;
22 |
23 | /**
24 | * Utility class to reuse and preallocate registered memory to avoid memory allocation
25 | * and registration during shuffle phase.
26 | */
27 | public class MemoryPool implements Closeable {
28 | private static final Logger logger = LoggerFactory.getLogger(MemoryPool.class);
29 |
30 | @Override
31 | public void close() {
32 | for (AllocatorStack stack: allocStackMap.values()) {
33 | stack.close();
34 | logger.info("Stack of size {}. " +
35 | "Total requests: {}, total allocations: {}, preAllocations: {}",
36 | stack.length, stack.totalRequests.get(), stack.totalAlloc.get(), stack.preAllocs.get());
37 | }
38 | allocStackMap.clear();
39 | }
40 |
41 | private class AllocatorStack implements Closeable {
42 | private final AtomicInteger totalRequests = new AtomicInteger(0);
43 | private final AtomicInteger totalAlloc = new AtomicInteger(0);
44 | private final AtomicInteger preAllocs = new AtomicInteger(0);
45 | private final ConcurrentLinkedDeque<RegisteredMemory> stack = new ConcurrentLinkedDeque<>();
46 | private final int length;
47 |
48 | private AllocatorStack(int length) {
49 | this.length = length;
50 | }
51 |
52 | private RegisteredMemory get() {
53 | RegisteredMemory result = stack.pollFirst();
54 | if (result == null) {
55 | if (length < conf.minRegistrationSize()) {
56 | int numBuffers = conf.minRegistrationSize() / length;
57 | logger.debug("Allocating {} buffers of size {}", numBuffers, length);
58 | preallocate(numBuffers);
59 | result = stack.pollFirst();
60 | if (result == null) {
61 | return get();
62 | } else {
63 | result.getRefCount().incrementAndGet();
64 | }
65 | } else {
66 | UcpMemMapParams memMapParams = new UcpMemMapParams().setLength(length).allocate();
67 | UcpMemory memory = context.memoryMap(memMapParams);
68 | ByteBuffer buffer;
69 | try {
70 | buffer = UcxUtils.getByteBufferView(memory.getAddress(), (int)memory.getLength());
71 | } catch (Exception e) {
72 | throw new UcxException(e.getMessage());
73 | }
74 | result = new RegisteredMemory(new AtomicInteger(1), memory, buffer);
75 | totalAlloc.incrementAndGet();
76 | }
77 | } else {
78 | result.getRefCount().incrementAndGet();
79 | }
80 | totalRequests.incrementAndGet();
81 | return result;
82 | }
83 |
84 | private void put(RegisteredMemory registeredMemory) {
85 | registeredMemory.getRefCount().decrementAndGet();
86 | stack.addLast(registeredMemory);
87 | }
88 |
89 | private void preallocate(int numBuffers) {
90 | // Platform.allocateDirectBuffer supports only 2GB of buffer.
91 | // Decrease number of buffers if total size of preAllocation > 2GB.
92 | if ((long)length * (long)numBuffers > Integer.MAX_VALUE) {
93 | numBuffers = Integer.MAX_VALUE / length;
94 | }
95 |
96 | UcpMemMapParams memMapParams = new UcpMemMapParams().allocate().setLength(numBuffers * (long)length);
97 | UcpMemory memory = context.memoryMap(memMapParams);
98 | ByteBuffer buffer;
99 | try {
100 | buffer = UnsafeUtils.getByteBuffer(memory.getAddress(), numBuffers * length);
101 | } catch (Exception ex) {
102 | throw new UcxException(ex.getMessage());
103 | }
104 |
105 | AtomicInteger refCount = new AtomicInteger(numBuffers);
106 | for (int i = 0; i < numBuffers; i++) {
107 | buffer.position(i * length).limit(i * length + length);
108 | final ByteBuffer slice = buffer.slice();
109 | RegisteredMemory registeredMemory = new RegisteredMemory(refCount, memory, slice);
110 | put(registeredMemory);
111 | }
112 | preAllocs.incrementAndGet();
113 | totalAlloc.incrementAndGet();
114 | }
115 |
116 | @Override
117 | public void close() {
118 | while (!stack.isEmpty()) {
119 | RegisteredMemory memory = stack.pollFirst();
120 | if (memory != null) {
121 | memory.deregisterNativeMemory();
122 | }
123 | }
124 | }
125 | }
126 |
127 | private final ConcurrentHashMap<Integer, AllocatorStack> allocStackMap =
128 | new ConcurrentHashMap<>();
129 | private final UcpContext context;
130 | private final UcxShuffleConf conf;
131 |
132 | public MemoryPool(UcpContext context, UcxShuffleConf conf) {
133 | this.context = context;
134 | this.conf = conf;
135 | }
136 |
137 | private long roundUpToTheNextPowerOf2(long length) {
138 | // Round up length to the nearest power of two, or the minimum block size
139 | if (length < conf.minBufferSize()) {
140 | length = conf.minBufferSize();
141 | } else {
142 | length--;
143 | length |= length >> 1;
144 | length |= length >> 2;
145 | length |= length >> 4;
146 | length |= length >> 8;
147 | length |= length >> 16;
148 | length++;
149 | }
150 | return length;
151 | }
152 |
153 | public RegisteredMemory get(int size) {
154 | long roundedSize = roundUpToTheNextPowerOf2(size);
155 | assert roundedSize < Integer.MAX_VALUE && roundedSize > 0;
156 | AllocatorStack stack =
157 | allocStackMap.computeIfAbsent((int)roundedSize, AllocatorStack::new);
158 | RegisteredMemory result = stack.get();
159 | result.getBuffer().position(0).limit(size);
160 | return result;
161 | }
162 |
163 | public void put(RegisteredMemory memory) {
164 | AllocatorStack allocatorStack = allocStackMap.get(memory.getBuffer().capacity());
165 | if (allocatorStack != null) {
166 | allocatorStack.put(memory);
167 | }
168 | }
169 |
170 | public void preAlocate() {
171 | conf.preallocateBuffersMap().forEach((size, numBuffers) -> {
172 | logger.debug("Pre allocating {} buffers of size {}", numBuffers, size);
173 | AllocatorStack stack = new AllocatorStack(size);
174 | allocStackMap.put(size, stack);
175 | stack.preallocate(numBuffers);
176 | });
177 | }
178 |
179 | }
180 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/memory/RegisteredMemory.java:
--------------------------------------------------------------------------------
1 | package org.apache.spark.shuffle.ucx.memory;
2 |
3 | import org.openucx.jucx.ucp.UcpMemory;
4 | import org.slf4j.Logger;
5 | import org.slf4j.LoggerFactory;
6 |
7 | import java.nio.ByteBuffer;
8 | import java.util.concurrent.atomic.AtomicInteger;
9 |
10 | /**
11 | * Structure to use a single memory region for multiple ByteBuffers.
12 | * Keeps track of the reference count to the memory region.
13 | */
14 | public class RegisteredMemory {
15 | private static final Logger logger = LoggerFactory.getLogger(RegisteredMemory.class);
16 |
17 | private final AtomicInteger refcount;
18 | private final UcpMemory memory;
19 | private final ByteBuffer buffer;
20 |
21 | RegisteredMemory(AtomicInteger refcount, UcpMemory memory, ByteBuffer buffer) {
22 | this.refcount = refcount;
23 | this.memory = memory;
24 | this.buffer = buffer;
25 | }
26 |
27 | public ByteBuffer getBuffer() {
28 | return buffer;
29 | }
30 |
31 | AtomicInteger getRefCount() {
32 | return refcount;
33 | }
34 |
35 | void deregisterNativeMemory() {
36 | if (refcount.get() != 0) {
37 | logger.warn("De-registering memory of size {} that has active references.", buffer.capacity());
38 | }
39 | if (memory != null && memory.getNativeId() != null) {
40 | memory.deregister();
41 | }
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/reducer/OnBlocksFetchCallback.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.reducer;
6 |
7 | import org.apache.spark.network.buffer.ManagedBuffer;
8 | import org.apache.spark.network.buffer.NioManagedBuffer;
9 |
10 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
11 | import org.apache.spark.storage.BlockId;
12 | import org.apache.spark.util.Utils;
13 | import org.openucx.jucx.ucp.UcpRequest;
14 |
15 | import java.nio.ByteBuffer;
16 | import java.util.concurrent.atomic.AtomicInteger;
17 |
18 | /**
19 | * Final callback invoked when all blocks are fetched.
20 | * Notifies Spark's shuffleFetchIterator on block fetch completion.
21 | */
22 | public class OnBlocksFetchCallback extends ReducerCallback {
23 | protected RegisteredMemory blocksMemory;
24 | protected int[] sizes;
25 |
26 | public OnBlocksFetchCallback(ReducerCallback callback, RegisteredMemory blocksMemory, int[] sizes) {
27 | super(callback);
28 | this.blocksMemory = blocksMemory;
29 | this.sizes = sizes;
30 | }
31 |
32 | @Override
33 | public void onSuccess(UcpRequest request) {
34 | int position = 0;
35 | AtomicInteger refCount = new AtomicInteger(blockIds.length);
36 | for (int i = 0; i < blockIds.length; i++) {
37 | BlockId block = blockIds[i];
38 | // Blocks are fetched to contiguous buffer.
39 | // |----block1---||---block2---||---block3---|
40 | // Slice each block to avoid buffer copy.
41 | blocksMemory.getBuffer().position(position).limit(position + sizes[i]);
42 | ByteBuffer blockBuffer = blocksMemory.getBuffer().slice();
43 | position += sizes[i];
44 | // Pass block to Spark's ShuffleFetchIterator.
45 | listener.onBlockFetchSuccess(block.name(), new NioManagedBuffer(blockBuffer) {
46 | @Override
47 | public ManagedBuffer release() {
48 | if (refCount.decrementAndGet() == 0) {
49 | mempool.put(blocksMemory);
50 | }
51 | return this;
52 | }
53 | });
54 | }
55 | logger.info("Endpoint {} fetched {} blocks of total size {} in {}ms", endpoint.getNativeId(), blockIds.length,
56 | Utils.bytesToString(position), System.currentTimeMillis() - startTime);
57 | }
58 | }
59 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/reducer/ReducerCallback.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.reducer;
6 |
7 | import org.apache.spark.SparkEnv;
8 | import org.apache.spark.network.shuffle.BlockFetchingListener;
9 | import org.apache.spark.shuffle.CommonUcxShuffleManager;
10 | import org.apache.spark.shuffle.ucx.memory.MemoryPool;
11 | import org.apache.spark.storage.BlockId;
12 | import org.openucx.jucx.UcxCallback;
13 | import org.openucx.jucx.ucp.UcpEndpoint;
14 | import org.slf4j.Logger;
15 | import org.slf4j.LoggerFactory;
16 |
17 | /**
18 | * Common data needed for offset fetch callback and subsequent block fetch callback.
19 | */
20 | public abstract class ReducerCallback extends UcxCallback {
21 | protected MemoryPool mempool;
22 | protected BlockId[] blockIds;
23 | protected UcpEndpoint endpoint;
24 | protected BlockFetchingListener listener;
25 | protected static final Logger logger = LoggerFactory.getLogger(ReducerCallback.class);
26 | protected long startTime = System.currentTimeMillis();
27 |
28 | public ReducerCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener) {
29 | this.mempool = ((CommonUcxShuffleManager)SparkEnv.get().shuffleManager()).ucxNode().getMemoryPool();
30 | this.blockIds = blockIds;
31 | this.endpoint = endpoint;
32 | this.listener = listener;
33 | }
34 |
35 | public ReducerCallback(ReducerCallback callback) {
36 | this.blockIds = callback.blockIds;
37 | this.endpoint = callback.endpoint;
38 | this.listener = callback.listener;
39 | this.mempool = callback.mempool;
40 | this.startTime = callback.startTime;
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1;
6 |
7 | import org.apache.spark.network.shuffle.BlockFetchingListener;
8 | import org.apache.spark.shuffle.ucx.UnsafeUtils;
9 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
10 | import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback;
11 | import org.apache.spark.shuffle.ucx.reducer.ReducerCallback;
12 | import org.apache.spark.storage.ShuffleBlockId;
13 | import org.openucx.jucx.UcxUtils;
14 | import org.openucx.jucx.ucp.UcpEndpoint;
15 | import org.openucx.jucx.ucp.UcpRemoteKey;
16 | import org.openucx.jucx.ucp.UcpRequest;
17 |
18 | import java.nio.ByteBuffer;
19 | import java.util.Map;
20 |
21 | /**
22 | * Callback invoked when all block offsets have been fetched.
23 | */
24 | public class OnOffsetsFetchCallback extends ReducerCallback {
25 | private final RegisteredMemory offsetMemory;
26 | private final long[] dataAddresses;
27 | private Map<Integer, UcpRemoteKey> dataRkeysCache;
28 |
29 | public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
30 | RegisteredMemory offsetMemory, long[] dataAddresses,
31 | Map<Integer, UcpRemoteKey> dataRkeysCache) {
32 | super(blockIds, endpoint, listener);
33 | this.offsetMemory = offsetMemory;
34 | this.dataAddresses = dataAddresses;
35 | this.dataRkeysCache = dataRkeysCache;
36 | }
37 |
38 | @Override
39 | public void onSuccess(UcpRequest request) {
40 | ByteBuffer resultOffset = offsetMemory.getBuffer();
41 | long totalSize = 0;
42 | int[] sizes = new int[blockIds.length];
43 | int offsetSize = UnsafeUtils.LONG_SIZE;
44 | for (int i = 0; i < blockIds.length; i++) {
45 | // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd |
46 | long blockOffset = resultOffset.getLong(i * 2 * offsetSize);
47 | long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset;
48 | assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
49 | sizes[i] = (int) blockLength;
50 | totalSize += blockLength;
51 | dataAddresses[i] += blockOffset;
52 | }
53 |
54 | assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
55 | mempool.put(offsetMemory);
56 | RegisteredMemory blocksMemory = mempool.get((int) totalSize);
57 |
58 | long offset = 0;
59 | // Submits N fetch blocks requests
60 | for (int i = 0; i < blockIds.length; i++) {
61 | endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId)blockIds[i]).mapId()),
62 | UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
63 | offset += sizes[i];
64 | }
65 |
66 | // Process blocks once they are all fetched.
67 | // Flush guarantees that the callback is invoked after all fetch requests have completed.
68 | endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1;
6 |
7 | import org.apache.spark.SparkEnv;
8 | import org.apache.spark.executor.TempShuffleReadMetrics;
9 | import org.apache.spark.network.shuffle.BlockFetchingListener;
10 | import org.apache.spark.network.shuffle.ShuffleClient;
11 | import org.apache.spark.shuffle.DriverMetadata;
12 | import org.apache.spark.shuffle.UcxShuffleManager;
13 | import org.apache.spark.shuffle.UcxWorkerWrapper;
14 | import org.apache.spark.shuffle.ucx.UnsafeUtils;
15 | import org.apache.spark.shuffle.ucx.memory.MemoryPool;
16 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
17 | import org.apache.spark.storage.BlockId;
18 | import org.apache.spark.storage.BlockManagerId;
19 | import org.apache.spark.storage.ShuffleBlockId;
20 | import org.openucx.jucx.UcxUtils;
21 | import org.openucx.jucx.ucp.UcpEndpoint;
22 | import org.openucx.jucx.ucp.UcpRemoteKey;
23 | import org.slf4j.Logger;
24 | import org.slf4j.LoggerFactory;
25 | import scala.Option;
26 |
27 | import java.util.Arrays;
28 | import java.util.HashMap;
29 |
30 | public class UcxShuffleClient extends ShuffleClient {
31 | private final MemoryPool mempool;
32 | private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
33 | private final UcxShuffleManager ucxShuffleManager;
34 | private final TempShuffleReadMetrics shuffleReadMetrics;
35 | private final UcxWorkerWrapper workerWrapper;
36 | final HashMap<Integer, UcpRemoteKey> offsetRkeysCache = new HashMap<>();
37 | final HashMap<Integer, UcpRemoteKey> dataRkeysCache = new HashMap<>();
38 |
39 | public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics,
40 | UcxWorkerWrapper workerWrapper) {
41 | this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager();
42 | this.mempool = ucxShuffleManager.ucxNode().getMemoryPool();
43 | this.shuffleReadMetrics = shuffleReadMetrics;
44 | this.workerWrapper = workerWrapper;
45 | }
46 |
47 | /**
48 | * Submits n non-blocking offset-fetch requests to get the needed offsets for n blocks.
49 | */
50 | private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds,
51 | long[] dataAddresses, RegisteredMemory offsetMemory) {
52 | DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId());
53 | for (int i = 0; i < blockIds.length; i++) {
54 | ShuffleBlockId blockId = blockIds[i];
55 |
56 | long offsetAddress = driverMetadata.offsetAddress(blockId.mapId());
57 | dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId());
58 |
59 | offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
60 | endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId())));
61 |
62 | dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
63 | endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId())));
64 |
65 | endpoint.getNonBlockingImplicit(
66 | offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE,
67 | offsetRkeysCache.get(blockId.mapId()),
68 | UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE),
69 | 2L * UnsafeUtils.LONG_SIZE);
70 | }
71 | }
72 |
73 | /**
74 | * Reducer entry point. Fetches remote blocks using 2 ucp_get calls.
75 | * This method is called inside ShuffleFetchIterator's loop over hosts.
76 | * It first fetches the block offset from the index file, and then fetches the block itself.
77 | */
78 | @Override
79 | public void fetchBlocks(String host, int port, String execId,
80 | String[] blockIds, BlockFetchingListener listener) {
81 | long startTime = System.currentTimeMillis();
82 |
83 | BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
84 | UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
85 |
86 | long[] dataAddresses = new long[blockIds.length];
87 |
88 | // Need to fetch 2 long offsets (current block + next block) to calculate the exact block size.
89 | RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length);
90 |
91 | ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds)
92 | .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new);
93 |
94 | // Submits N implicit get requests without callback
95 | submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory);
96 |
97 | // Flush guarantees that all those requests have completed when the callback is called.
98 | // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush.
99 | workerWrapper.worker().flushNonBlocking(
100 | new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory,
101 | dataAddresses, dataRkeysCache));
102 | shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
103 | }
104 |
105 | @Override
106 | public void close() {
107 | offsetRkeysCache.values().forEach(UcpRemoteKey::close);
108 | dataRkeysCache.values().forEach(UcpRemoteKey::close);
109 | logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
110 | }
111 | }
112 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/OnOffsetsFetchCallback.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_4;
6 |
7 | import org.apache.spark.network.shuffle.BlockFetchingListener;
8 | import org.apache.spark.shuffle.ucx.UnsafeUtils;
9 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
10 | import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback;
11 | import org.apache.spark.shuffle.ucx.reducer.ReducerCallback;
12 | import org.apache.spark.storage.ShuffleBlockId;
13 | import org.openucx.jucx.UcxUtils;
14 | import org.openucx.jucx.ucp.UcpEndpoint;
15 | import org.openucx.jucx.ucp.UcpRemoteKey;
16 | import org.openucx.jucx.ucp.UcpRequest;
17 |
18 | import java.nio.ByteBuffer;
19 | import java.util.Map;
20 |
21 | /**
22 | * Callback invoked when all block offsets have been fetched.
23 | */
24 | public class OnOffsetsFetchCallback extends ReducerCallback {
25 | private final RegisteredMemory offsetMemory;
26 | private final long[] dataAddresses;
27 | private Map<Integer, UcpRemoteKey> dataRkeysCache;
28 |
29 | public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
30 | RegisteredMemory offsetMemory, long[] dataAddresses,
31 | Map<Integer, UcpRemoteKey> dataRkeysCache) {
32 | super(blockIds, endpoint, listener);
33 | this.offsetMemory = offsetMemory;
34 | this.dataAddresses = dataAddresses;
35 | this.dataRkeysCache = dataRkeysCache;
36 | }
37 |
38 | @Override
39 | public void onSuccess(UcpRequest request) {
40 | ByteBuffer resultOffset = offsetMemory.getBuffer();
41 | long totalSize = 0;
42 | int[] sizes = new int[blockIds.length];
43 | int offsetSize = UnsafeUtils.LONG_SIZE;
44 | for (int i = 0; i < blockIds.length; i++) {
45 | // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd |
46 | long blockOffset = resultOffset.getLong(i * 2 * offsetSize);
47 | long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset;
48 | assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
49 | sizes[i] = (int) blockLength;
50 | totalSize += blockLength;
51 | dataAddresses[i] += blockOffset;
52 | }
53 |
54 | assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
55 | mempool.put(offsetMemory);
56 | RegisteredMemory blocksMemory = mempool.get((int) totalSize);
57 |
58 | long offset = 0;
59 | // Submits N fetch blocks requests
60 | for (int i = 0; i < blockIds.length; i++) {
61 | endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId)blockIds[i]).mapId()),
62 | UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
63 | offset += sizes[i];
64 | }
65 |
66 | // Process blocks once they are all fetched.
67 | // Flush guarantees that the callback is invoked after all fetch requests have completed.
68 | endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
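A note on the offset arithmetic above: for every requested block the offset buffer holds a pair of
index-file entries, | blockOffsetStart | blockOffsetEnd |, and the block size is their difference.
A minimal Scala sketch of the same arithmetic with hypothetical index entries (the buffer contents
below are illustrative only, not taken from a real shuffle):

    import java.nio.ByteBuffer

    // Hypothetical offset buffer for two blocks whose index entries are (0, 100) and (100, 250).
    val resultOffset = ByteBuffer.allocate(4 * java.lang.Long.BYTES)
    Seq(0L, 100L, 100L, 250L).foreach(v => resultOffset.putLong(v))
    resultOffset.clear()

    val longSize = java.lang.Long.BYTES
    val sizes = (0 until 2).map { i =>
      val start = resultOffset.getLong(i * 2 * longSize)
      val end   = resultOffset.getLong(i * 2 * longSize + longSize)
      (end - start).toInt              // 100 and 150 bytes
    }
    val totalSize = sizes.sum          // 250 bytes, fetched into one registered buffer

--------------------------------------------------------------------------------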
/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_4/UcxShuffleClient.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_4;
6 |
7 | import org.apache.spark.SparkEnv;
8 | import org.apache.spark.executor.TempShuffleReadMetrics;
9 | import org.apache.spark.network.shuffle.BlockFetchingListener;
10 | import org.apache.spark.network.shuffle.DownloadFileManager;
11 | import org.apache.spark.network.shuffle.ShuffleClient;
12 | import org.apache.spark.shuffle.*;
13 | import org.apache.spark.shuffle.ucx.UnsafeUtils;
14 | import org.apache.spark.shuffle.ucx.memory.MemoryPool;
15 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
16 | import org.apache.spark.storage.BlockId;
17 | import org.apache.spark.storage.BlockManagerId;
18 | import org.apache.spark.storage.ShuffleBlockId;
19 | import org.openucx.jucx.UcxUtils;
20 | import org.slf4j.Logger;
21 | import org.slf4j.LoggerFactory;
22 | import org.openucx.jucx.ucp.UcpEndpoint;
23 | import org.openucx.jucx.ucp.UcpRemoteKey;
24 | import scala.Option;
25 |
26 | import java.util.Arrays;
27 | import java.util.HashMap;
28 |
29 | public class UcxShuffleClient extends ShuffleClient {
30 | private final MemoryPool mempool;
31 | private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
32 | private final UcxShuffleManager ucxShuffleManager;
33 | private final TempShuffleReadMetrics shuffleReadMetrics;
34 | private final UcxWorkerWrapper workerWrapper;
35 | final HashMap<Integer, UcpRemoteKey> offsetRkeysCache = new HashMap<>();
36 | final HashMap<Integer, UcpRemoteKey> dataRkeysCache = new HashMap<>();
37 |
38 | public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics,
39 | UcxWorkerWrapper workerWrapper) {
40 | this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager();
41 | this.mempool = ucxShuffleManager.ucxNode().getMemoryPool();
42 | this.shuffleReadMetrics = shuffleReadMetrics;
43 | this.workerWrapper = workerWrapper;
44 | }
45 |
46 | /**
47 |  * Submits n non-blocking fetch requests to get the needed offsets for n blocks.
48 | */
49 | private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds,
50 | long[] dataAddresses, RegisteredMemory offsetMemory) {
51 | DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId());
52 | for (int i = 0; i < blockIds.length; i++) {
53 | ShuffleBlockId blockId = blockIds[i];
54 |
55 | long offsetAddress = driverMetadata.offsetAddress(blockId.mapId());
56 | dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId());
57 |
58 | offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
59 | endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId())));
60 |
61 | dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
62 | endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId())));
63 |
64 | endpoint.getNonBlockingImplicit(
65 | offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE,
66 | offsetRkeysCache.get(blockId.mapId()),
67 | UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE),
68 | 2L * UnsafeUtils.LONG_SIZE);
69 | }
70 | }
71 |
72 | /**
73 |  * Reducer entry point. Fetches remote blocks using 2 ucp_get calls.
74 |  * This method is called from ShuffleBlockFetcherIterator's loop over hosts.
75 |  * It first fetches the block offsets from the index file and then fetches the blocks themselves.
76 | */
77 | @Override
78 | public void fetchBlocks(String host, int port, String execId,
79 | String[] blockIds, BlockFetchingListener listener,
80 | DownloadFileManager downloadFileManager) {
81 | long startTime = System.currentTimeMillis();
82 |
83 | BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
84 | UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
85 |
86 | long[] dataAddresses = new long[blockIds.length];
87 |
88 | // Need to fetch 2 long offsets (current block + next block) to calculate the exact block size.
89 | RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length);
90 |
91 | ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds)
92 | .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new);
93 |
94 | // Submits N implicit get requests without callback
95 | submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory);
96 |
97 | // Flush guarantees that all those requests have completed when the callback is called.
98 | // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush.
99 | workerWrapper.worker().flushNonBlocking(
100 | new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory,
101 | dataAddresses, dataRkeysCache));
102 | shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
103 | }
104 |
105 | @Override
106 | public void close() {
107 | offsetRkeysCache.values().forEach(UcpRemoteKey::close);
108 | dataRkeysCache.values().forEach(UcpRemoteKey::close);
109 | logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
110 | }
111 | }
112 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/OnOffsetsFetchCallback.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_0;
6 |
7 | import org.apache.spark.network.shuffle.BlockFetchingListener;
8 | import org.apache.spark.shuffle.UcxWorkerWrapper;
9 | import org.apache.spark.shuffle.ucx.UnsafeUtils;
10 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
11 | import org.apache.spark.shuffle.ucx.reducer.ReducerCallback;
12 | import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback;
13 | import org.apache.spark.storage.BlockId;
14 | import org.apache.spark.storage.ShuffleBlockBatchId;
15 | import org.apache.spark.storage.ShuffleBlockId;
16 | import org.openucx.jucx.UcxUtils;
17 | import org.openucx.jucx.ucp.UcpEndpoint;
18 | import org.openucx.jucx.ucp.UcpRemoteKey;
19 | import org.openucx.jucx.ucp.UcpRequest;
20 |
21 | import java.nio.ByteBuffer;
22 | import java.util.Map;
23 |
24 | /**
25 |  * Callback invoked once all offsets for the requested blocks have been fetched.
26 | */
27 | public class OnOffsetsFetchCallback extends ReducerCallback {
28 | private final RegisteredMemory offsetMemory;
29 | private final long[] dataAddresses;
30 | private Map<Integer, UcpRemoteKey> dataRkeysCache;
31 | private final Map<Long, Integer> mapId2PartitionId;
32 |
33 | public OnOffsetsFetchCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
34 | RegisteredMemory offsetMemory, long[] dataAddresses,
35 | Map<Integer, UcpRemoteKey> dataRkeysCache,
36 | Map<Long, Integer> mapId2PartitionId) {
37 | super(blockIds, endpoint, listener);
38 | this.offsetMemory = offsetMemory;
39 | this.dataAddresses = dataAddresses;
40 | this.dataRkeysCache = dataRkeysCache;
41 | this.mapId2PartitionId = mapId2PartitionId;
42 | }
43 |
44 | @Override
45 | public void onSuccess(UcpRequest request) {
46 | ByteBuffer resultOffset = offsetMemory.getBuffer();
47 | long totalSize = 0;
48 | int[] sizes = new int[blockIds.length];
49 | int offset = 0;
50 | long blockOffset;
51 | long blockLength;
52 | int offsetSize = UnsafeUtils.LONG_SIZE;
53 | for (int i = 0; i < blockIds.length; i++) {
54 | // Blocks in metadata buffer are in form | blockOffsetStart | blockOffsetEnd |
55 | if (blockIds[i] instanceof ShuffleBlockBatchId) {
56 | ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blockIds[i];
57 | int blocksInBatch = blockBatchId.endReduceId() - blockBatchId.startReduceId();
58 | blockOffset = resultOffset.getLong(offset * 2 * offsetSize);
59 | blockLength = resultOffset.getLong(offset * 2 * offsetSize + offsetSize * blocksInBatch)
60 | - blockOffset;
61 | offset += blocksInBatch;
62 | } else {
63 | blockOffset = resultOffset.getLong(offset * 16);
64 | blockLength = resultOffset.getLong(offset * 16 + 8) - blockOffset;
65 | offset++;
66 | }
67 |
68 | assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
69 | sizes[i] = (int) blockLength;
70 | totalSize += blockLength;
71 | dataAddresses[i] += blockOffset;
72 | }
73 |
74 | assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
75 | mempool.put(offsetMemory);
76 | RegisteredMemory blocksMemory = mempool.get((int) totalSize);
77 |
78 | offset = 0;
79 | // Submits N fetch blocks requests
80 | for (int i = 0; i < blockIds.length; i++) {
81 | int mapPartitionId = (blockIds[i] instanceof ShuffleBlockId) ?
82 | mapId2PartitionId.get(((ShuffleBlockId)blockIds[i]).mapId()) :
83 | mapId2PartitionId.get(((ShuffleBlockBatchId)blockIds[i]).mapId());
84 | endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(mapPartitionId),
85 | UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
86 | offset += sizes[i];
87 | }
88 |
89 | // Process blocks when all fetched.
90 | // Flush guarantees that the callback is invoked once all fetch requests have completed.
91 | endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
92 | }
93 | }
94 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_3_0/UcxShuffleClient.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_0;
6 |
7 | import org.apache.spark.SparkEnv;
8 | import org.apache.spark.executor.TempShuffleReadMetrics;
9 | import org.apache.spark.network.shuffle.BlockFetchingListener;
10 | import org.apache.spark.network.shuffle.BlockStoreClient;
11 | import org.apache.spark.network.shuffle.DownloadFileManager;
12 | import org.apache.spark.shuffle.DriverMetadata;
13 | import org.apache.spark.shuffle.UcxShuffleManager;
14 | import org.apache.spark.shuffle.UcxWorkerWrapper;
15 | import org.apache.spark.shuffle.ucx.UnsafeUtils;
16 | import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
17 | import org.apache.spark.storage.*;
18 | import org.openucx.jucx.UcxUtils;
19 | import org.openucx.jucx.ucp.UcpEndpoint;
20 | import org.openucx.jucx.ucp.UcpRemoteKey;
21 | import org.slf4j.Logger;
22 | import org.slf4j.LoggerFactory;
23 | import scala.Option;
24 |
25 |
26 | import java.util.HashMap;
27 | import java.util.Map;
28 |
29 | public class UcxShuffleClient extends BlockStoreClient {
30 | private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
31 | private final UcxWorkerWrapper workerWrapper;
32 | private final Map<Long, Integer> mapId2PartitionId;
33 | private final TempShuffleReadMetrics shuffleReadMetrics;
34 | private final int shuffleId;
35 | final HashMap<Integer, UcpRemoteKey> offsetRkeysCache = new HashMap<>();
36 | final HashMap<Integer, UcpRemoteKey> dataRkeysCache = new HashMap<>();
37 |
38 |
39 | public UcxShuffleClient(int shuffleId, UcxWorkerWrapper workerWrapper,
40 | Map<Long, Integer> mapId2PartitionId, TempShuffleReadMetrics shuffleReadMetrics) {
41 | this.workerWrapper = workerWrapper;
42 | this.shuffleId = shuffleId;
43 | this.mapId2PartitionId = mapId2PartitionId;
44 | this.shuffleReadMetrics = shuffleReadMetrics;
45 | }
46 |
47 | /**
48 |  * Submits n non-blocking fetch requests to get the needed offsets for n blocks.
49 | */
50 | private void submitFetchOffsets(UcpEndpoint endpoint, BlockId[] blockIds,
51 | RegisteredMemory offsetMemory,
52 | long[] dataAddresses) {
53 | DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(shuffleId);
54 | long offset = 0;
55 | int startReduceId;
56 | long size;
57 |
58 | for (int i = 0; i < blockIds.length; i++) {
59 | BlockId blockId = blockIds[i];
60 | int mapIdpartition;
61 |
62 | if (blockId instanceof ShuffleBlockId) {
63 | ShuffleBlockId shuffleBlockId = (ShuffleBlockId) blockId;
64 | mapIdpartition = mapId2PartitionId.get(shuffleBlockId.mapId());
65 | size = 2L * UnsafeUtils.LONG_SIZE;
66 | startReduceId = shuffleBlockId.reduceId();
67 | } else {
68 | ShuffleBlockBatchId shuffleBlockBatchId = (ShuffleBlockBatchId) blockId;
69 | mapIdpartition = mapId2PartitionId.get(shuffleBlockBatchId.mapId());
70 | size = (shuffleBlockBatchId.endReduceId() - shuffleBlockBatchId.startReduceId())
71 | * 2L * UnsafeUtils.LONG_SIZE;
72 | startReduceId = shuffleBlockBatchId.startReduceId();
73 | }
74 |
75 | long offsetAddress = driverMetadata.offsetAddress(mapIdpartition);
76 | dataAddresses[i] = driverMetadata.dataAddress(mapIdpartition);
77 |
78 | offsetRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
79 | endpoint.unpackRemoteKey(driverMetadata.offsetRkey(mapIdpartition)));
80 |
81 | dataRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
82 | endpoint.unpackRemoteKey(driverMetadata.dataRkey(mapIdpartition)));
83 |
84 | endpoint.getNonBlockingImplicit(
85 | offsetAddress + startReduceId * UnsafeUtils.LONG_SIZE,
86 | offsetRkeysCache.get(mapIdpartition),
87 | UcxUtils.getAddress(offsetMemory.getBuffer()) + offset,
88 | size);
89 |
90 | offset += size;
91 | }
92 | }
93 |
94 | @Override
95 | public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener,
96 | DownloadFileManager downloadFileManager) {
97 | long startTime = System.currentTimeMillis();
98 | BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
99 | UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
100 | long[] dataAddresses = new long[blockIds.length];
101 | int totalBlocks = 0;
102 |
103 | BlockId[] blocks = new BlockId[blockIds.length];
104 |
105 | for (int i = 0; i < blockIds.length; i++) {
106 | blocks[i] = BlockId.apply(blockIds[i]);
107 | if (blocks[i] instanceof ShuffleBlockId) {
108 | totalBlocks += 1;
109 | } else {
110 | ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId)blocks[i];
111 | totalBlocks += (blockBatchId.endReduceId() - blockBatchId.startReduceId());
112 | }
113 | }
114 |
115 | RegisteredMemory offsetMemory = ((UcxShuffleManager)SparkEnv.get().shuffleManager())
116 | .ucxNode().getMemoryPool().get(totalBlocks * 2 * UnsafeUtils.LONG_SIZE);
117 | // Submits N implicit get requests without callback
118 | submitFetchOffsets(endpoint, blocks, offsetMemory, dataAddresses);
119 |
120 | // Flush guarantees that all those requests have completed when the callback is called.
121 | // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush.
122 | workerWrapper.worker().flushNonBlocking(
123 | new OnOffsetsFetchCallback(blocks, endpoint, listener, offsetMemory,
124 | dataAddresses, dataRkeysCache, mapId2PartitionId));
125 |
126 | shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
127 | }
128 |
129 | @Override
130 | public void close() {
131 | offsetRkeysCache.values().forEach(UcpRemoteKey::close);
132 | dataRkeysCache.values().forEach(UcpRemoteKey::close);
133 | logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
134 | }
135 |
136 | }
137 |
--------------------------------------------------------------------------------
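For a ShuffleBlockBatchId the client above fetches a contiguous run of index-file entries rather
than a single start/end pair, and OnOffsetsFetchCallback then reads the batch end at byte position
8 * blocksInBatch within that run. A worked Scala sketch of that arithmetic with hypothetical index
entries (all values are illustrative only):

    import java.nio.ByteBuffer

    // Hypothetical index-file entries for reduce partitions 2..7 of one map output.
    val fetched = ByteBuffer.allocate(6 * java.lang.Long.BYTES)
    Seq(100L, 150L, 180L, 250L, 300L, 320L).foreach(v => fetched.putLong(v))
    fetched.clear()

    // Batch covering reduceIds [2, 5): blocksInBatch = 3, so 2 * 3 longs were fetched.
    val blocksInBatch = 3
    val blockOffset = fetched.getLong(0)                  // 100: start of the first block in the batch
    val blockEnd    = fetched.getLong(8 * blocksInBatch)  // 250: end of the last block in the batch
    val blockLength = blockEnd - blockOffset               // 150 bytes, fetched with a single ucp_get

--------------------------------------------------------------------------------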
/src/main/java/org/apache/spark/shuffle/ucx/rpc/RpcConnectionCallback.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.rpc;
6 |
7 | import org.apache.spark.shuffle.ucx.UcxNode;
8 | import org.apache.spark.storage.BlockManagerId;
9 | import org.apache.spark.unsafe.Platform;
10 | import org.openucx.jucx.UcxCallback;
11 | import org.openucx.jucx.UcxException;
12 | import org.openucx.jucx.ucp.UcpEndpoint;
13 | import org.openucx.jucx.ucp.UcpEndpointParams;
14 | import org.openucx.jucx.ucp.UcpRequest;
15 | import org.openucx.jucx.ucp.UcpWorker;
16 | import org.slf4j.Logger;
17 | import org.slf4j.LoggerFactory;
18 |
19 | import java.io.IOException;
20 | import java.nio.ByteBuffer;
21 | import java.util.concurrent.ConcurrentMap;
22 |
23 | /**
24 |  * RPC processing logic. Both driver and executor accept the same RPC message:
25 |  * the executor's worker address followed by its serialized BlockManagerId.
26 |  * On accepting this message an executor just adds the worker address to the connection map.
27 |  * The driver additionally introduces the newly connected executor to the cluster nodes and
28 |  * introduces the cluster to the connected executor.
29 | */
30 | public class RpcConnectionCallback extends UcxCallback {
31 | private static final Logger logger = LoggerFactory.getLogger(RpcConnectionCallback.class);
32 | private final ByteBuffer metadataBuffer;
33 | private final boolean isDriver;
34 | private final UcxNode ucxNode;
35 | private static final ConcurrentMap<UcpEndpoint, ByteBuffer> rpcConnections =
36 | UcxNode.getRpcConnections();
37 | private static final ConcurrentMap<BlockManagerId, ByteBuffer> workerAdresses =
38 | UcxNode.getWorkerAddresses();
39 |
40 | RpcConnectionCallback(ByteBuffer metadataBuffer, boolean isDriver, UcxNode ucxNode) {
41 | this.metadataBuffer = metadataBuffer;
42 | this.isDriver = isDriver;
43 | this.ucxNode = ucxNode;
44 | }
45 |
46 | @Override
47 | public void onSuccess(UcpRequest request) {
48 | int workerAddressSize = metadataBuffer.getInt();
49 | ByteBuffer workerAddress = Platform.allocateDirectBuffer(workerAddressSize);
50 |
51 | // Copy worker address from metadata buffer to separate buffer.
52 | final ByteBuffer metadataView = metadataBuffer.duplicate();
53 | metadataView.limit(metadataView.position() + workerAddressSize);
54 | workerAddress.put(metadataView);
55 | metadataBuffer.position(metadataBuffer.position() + workerAddressSize);
56 |
57 | BlockManagerId blockManagerId;
58 | try {
59 | blockManagerId = SerializableBlockManagerID
60 | .deserializeBlockManagerID(metadataBuffer);
61 | } catch (IOException e) {
62 | String errorMsg = String.format("Failed to deserialize BlockManagerId: %s", e.getMessage());
63 | throw new UcxException(errorMsg);
64 | }
65 | logger.debug("Received RPC message from {}", blockManagerId);
66 | UcpWorker globalWorker = ucxNode.getGlobalWorker();
67 |
68 | workerAddress.clear();
69 |
70 | if (isDriver) {
71 | metadataBuffer.clear();
72 | UcpEndpoint newConnection = globalWorker.newEndpoint(
73 | new UcpEndpointParams().setPeerErrorHandlingMode()
74 | .setUcpAddress(workerAddress));
75 | // For each existing connection
76 | rpcConnections.forEach((connection, connectionMetadata) -> {
77 | // send address of joined worker to already connected workers
78 | connection.sendTaggedNonBlocking(metadataBuffer, null);
79 | // introduce other workers to joined worker
80 | newConnection.sendTaggedNonBlocking(connectionMetadata, null);
81 | });
82 |
83 | rpcConnections.put(newConnection, metadataBuffer);
84 | }
85 | workerAdresses.put(blockManagerId, workerAddress);
86 | synchronized (workerAdresses) {
87 | workerAdresses.notifyAll();
88 | }
89 | }
90 |
91 | @Override
92 | public void onError(int ucsStatus, String errorMsg) {
93 | // UCS_ERR_CANCELED = -16,
94 | if (ucsStatus != -16) {
95 | logger.error("Request error: {}", errorMsg);
96 | throw new UcxException(errorMsg);
97 | }
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/rpc/SerializableBlockManagerID.java:
--------------------------------------------------------------------------------
1 | package org.apache.spark.shuffle.ucx.rpc;
2 |
3 | import com.esotericsoftware.kryo.io.ByteBufferInputStream;
4 | import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
5 | import org.apache.spark.storage.BlockManagerId;
6 |
7 | import java.io.IOException;
8 | import java.io.ObjectInputStream;
9 | import java.io.ObjectOutputStream;
10 | import java.nio.ByteBuffer;
11 |
12 | /**
13 |  * Static methods to serialize / deserialize BlockManagerId to / from a ByteBuffer.
14 | */
15 | public class SerializableBlockManagerID {
16 |
17 | public static void serializeBlockManagerID(BlockManagerId blockManagerId,
18 | ByteBuffer metadataBuffer) throws IOException {
19 | ObjectOutputStream oos = new ObjectOutputStream(
20 | new ByteBufferOutputStream(metadataBuffer));
21 | blockManagerId.writeExternal(oos);
22 | oos.close();
23 | }
24 |
25 | static BlockManagerId deserializeBlockManagerID(ByteBuffer metadataBuffer) throws IOException {
26 | ObjectInputStream ois =
27 | new ObjectInputStream(new ByteBufferInputStream(metadataBuffer));
28 | BlockManagerId blockManagerId = BlockManagerId.apply(ois);
29 | ois.close();
30 | return blockManagerId;
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
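RpcConnectionCallback parses each RPC message as | workerAddressSize : int | workerAddress |
serialized BlockManagerId |. A minimal Scala sketch of assembling a message in that layout with the
helper above; the actual assembly lives in UcxNode (not shown in this section), so this is only an
illustration of the expected wire format:

    import java.nio.ByteBuffer
    import org.apache.spark.SparkEnv
    import org.apache.spark.shuffle.ucx.rpc.SerializableBlockManagerID

    // Sketch only: build | workerAddressSize | workerAddress | BlockManagerId |
    // so that RpcConnectionCallback.onSuccess can parse it back.
    def buildRpcMessage(workerAddress: ByteBuffer, bufferSize: Int): ByteBuffer = {
      val msg = ByteBuffer.allocateDirect(bufferSize)
      workerAddress.clear()
      msg.putInt(workerAddress.remaining())
      msg.put(workerAddress)
      // The serialized BlockManagerId of this executor follows the worker address.
      SerializableBlockManagerID.serializeBlockManagerID(
        SparkEnv.get.blockManager.blockManagerId, msg)
      msg.clear()
      msg
    }

--------------------------------------------------------------------------------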
/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxListenerThread.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.rpc;
6 |
7 | import org.apache.spark.shuffle.ucx.UcxNode;
8 | import org.apache.spark.unsafe.Platform;
9 | import org.openucx.jucx.ucp.UcpRequest;
10 | import org.openucx.jucx.ucp.UcpWorker;
11 | import org.slf4j.Logger;
12 | import org.slf4j.LoggerFactory;
13 |
14 | import java.nio.ByteBuffer;
15 |
16 | /**
17 | * Thread for progressing global worker for connection establishment and RPC exchange.
18 | */
19 | public class UcxListenerThread extends Thread implements Runnable {
20 | private static final Logger logger = LoggerFactory.getLogger(UcxListenerThread.class);
21 | private final UcxNode ucxNode;
22 | private final boolean isDriver;
23 | private final UcpWorker globalWorker;
24 |
25 | public UcxListenerThread(UcxNode ucxNode, boolean isDriver) {
26 | this.ucxNode = ucxNode;
27 | this.isDriver = isDriver;
28 | this.globalWorker = ucxNode.getGlobalWorker();
29 | setDaemon(true);
30 | setName("UcxListenerThread");
31 | }
32 |
33 | /**
34 |  * 2. Both driver and executor: accept the receive request.
35 |  * On the driver, broadcast it to the other executors; on an executor, just save the worker address.
36 | */
37 | private UcpRequest recvRequest() {
38 | ByteBuffer metadataBuffer = Platform.allocateDirectBuffer(
39 | ucxNode.getConf().metadataRPCBufferSize());
40 | RpcConnectionCallback callback = new RpcConnectionCallback(metadataBuffer, isDriver, ucxNode);
41 | return globalWorker.recvTaggedNonBlocking(metadataBuffer, callback);
42 | }
43 |
44 | @Override
45 | public void run() {
46 | UcpRequest recv = recvRequest();
47 | while (!isInterrupted()) {
48 | if (recv.isCompleted()) {
49 | // Process 1 recv request at a time.
50 | recv = recvRequest();
51 | }
52 | try {
53 | if (globalWorker.progress() == 0) {
54 | globalWorker.waitForEvents();
55 | }
56 | } catch (Exception e) {
57 | logger.error(e.getLocalizedMessage());
58 | interrupt();
59 | }
60 | }
61 | globalWorker.cancelRequest(recv);
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/shuffle/ucx/rpc/UcxRemoteMemory.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.ucx.rpc;
6 |
7 | import java.io.IOException;
8 | import java.io.ObjectInputStream;
9 | import java.io.ObjectOutputStream;
10 | import java.io.Serializable;
11 | import java.nio.ByteBuffer;
12 |
13 | /**
14 |  * Utility class to serialize / deserialize the driver metadata buffer information (address and rkey).
15 |  * Needed to propagate the metadata buffer information to executors using
16 |  * Spark's task broadcast mechanism.
17 | */
18 | public class UcxRemoteMemory implements Serializable {
19 | private long address;
20 | private ByteBuffer rkeyBuffer;
21 |
22 | public UcxRemoteMemory(long address, ByteBuffer rkeyBuffer) {
23 | this.address = address;
24 | this.rkeyBuffer = rkeyBuffer;
25 | }
26 |
27 | public UcxRemoteMemory() {}
28 |
29 | private void writeObject(ObjectOutputStream out) throws IOException {
30 | out.writeLong(address);
31 | out.writeInt(rkeyBuffer.limit());
32 | byte[] copy = new byte[rkeyBuffer.limit()];
33 | rkeyBuffer.clear();
34 | rkeyBuffer.get(copy);
35 | out.write(copy);
36 | }
37 |
38 | private void readObject(ObjectInputStream in) throws IOException {
39 | this.address = in.readLong();
40 | int bufferSize = in.readInt();
41 | byte[] buffer = new byte[bufferSize];
42 | in.readFully(buffer, 0, bufferSize); // read() may return fewer bytes than requested
43 | this.rkeyBuffer = ByteBuffer.allocateDirect(bufferSize).put(buffer);
44 | this.rkeyBuffer.clear();
45 | }
46 |
47 | public long getAddress() {
48 | return address;
49 | }
50 |
51 | public ByteBuffer getRkeyBuffer() {
52 | return rkeyBuffer;
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleBlockResolver.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle
6 |
7 | import java.io.{File, RandomAccessFile}
8 | import java.util.concurrent.{ConcurrentHashMap, CopyOnWriteArrayList}
9 |
10 | import scala.collection.JavaConverters._
11 |
12 | import org.openucx.jucx.UcxUtils
13 | import org.openucx.jucx.ucp.{UcpMemMapParams, UcpMemory}
14 | import org.apache.spark.shuffle.ucx.UnsafeUtils
15 | import org.apache.spark.SparkException
16 |
17 | /**
18 |  * Mapper entry point for the UcxShuffle plugin. Performs memory registration
19 |  * of data and index files and publishes their addresses to the driver metadata buffer.
20 | */
21 | abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
22 | extends IndexShuffleBlockResolver(ucxShuffleManager.conf) {
23 | private lazy val memPool = ucxShuffleManager.ucxNode.getMemoryPool
24 |
25 | // Keep track of registered memory regions to release them when the shuffle is no longer needed
26 | private val fileMappings = new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala
27 | private val offsetMappings = new ConcurrentHashMap[ShuffleId, CopyOnWriteArrayList[UcpMemory]].asScala
28 |
29 | /**
30 | * Mapper commit protocol extension. Register index and data files and publish all needed
31 | * metadata to driver.
32 | */
33 | def writeIndexFileAndCommitCommon(shuffleId: ShuffleId, mapId: Int,
34 | lengths: Array[Long], dataTmp: File,
35 | indexBackFile: RandomAccessFile, dataBackFile: RandomAccessFile): Unit = {
36 | val startTime = System.currentTimeMillis()
37 |
38 | fileMappings.putIfAbsent(shuffleId, new CopyOnWriteArrayList[UcpMemory]())
39 | offsetMappings.putIfAbsent(shuffleId, new CopyOnWriteArrayList[UcpMemory]())
40 |
41 | val indexFileChannel = indexBackFile.getChannel
42 | val dataFileChannel = dataBackFile.getChannel
43 |
44 | // Memory map and register data and index file.
45 | val dataAddress = UnsafeUtils.mmap(dataFileChannel, 0, dataBackFile.length())
46 | val memMapParams = new UcpMemMapParams().setAddress(dataAddress)
47 | .setLength(dataBackFile.length())
48 | if (ucxShuffleManager.ucxShuffleConf.useOdp) {
49 | memMapParams.nonBlocking()
50 | }
51 | val dataMemory = ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams)
52 | fileMappings(shuffleId).add(dataMemory)
53 | assume(indexBackFile.length() == UnsafeUtils.LONG_SIZE * (lengths.length + 1))
54 |
55 | val offsetAddress = UnsafeUtils.mmap(indexFileChannel, 0, indexBackFile.length())
56 | memMapParams.setAddress(offsetAddress).setLength(indexBackFile.length())
57 | val offsetMemory = ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams)
58 | offsetMappings(shuffleId).add(offsetMemory)
59 |
60 | dataFileChannel.close()
61 | dataBackFile.close()
62 | indexFileChannel.close()
63 | indexBackFile.close()
64 |
65 | val fileMemoryRkey = dataMemory.getRemoteKeyBuffer
66 | val offsetRkey = offsetMemory.getRemoteKeyBuffer
67 |
68 | val metadataRegisteredMemory = memPool.get(
69 | fileMemoryRkey.capacity() + offsetRkey.capacity() + 24)
70 | val metadataBuffer = metadataRegisteredMemory.getBuffer.slice()
71 |
72 | if (metadataBuffer.remaining() > ucxShuffleManager.ucxShuffleConf.metadataBlockSize) {
73 | throw new SparkException(s"Metadata block size ${metadataBuffer.remaining() / 2} " +
74 | s"is greater than the configured ${ucxShuffleManager.ucxShuffleConf.RKEY_SIZE.key} " +
75 | s"(${ucxShuffleManager.ucxShuffleConf.metadataBlockSize}).")
76 | }
77 |
78 | metadataBuffer.clear()
79 |
80 | metadataBuffer.putLong(offsetMemory.getAddress)
81 | metadataBuffer.putLong(dataMemory.getAddress)
82 |
83 | metadataBuffer.putInt(offsetRkey.capacity())
84 | metadataBuffer.put(offsetRkey)
85 |
86 | metadataBuffer.putInt(fileMemoryRkey.capacity())
87 | metadataBuffer.put(fileMemoryRkey)
88 |
89 | metadataBuffer.clear()
90 |
91 | val workerWrapper = ucxShuffleManager.ucxNode.getThreadLocalWorker
92 | val driverMetadata = workerWrapper.getDriverMetadata(shuffleId)
93 | val driverOffset = driverMetadata.address +
94 | mapId * ucxShuffleManager.ucxShuffleConf.metadataBlockSize
95 |
96 | val driverEndpoint = workerWrapper.driverEndpoint
97 | val request = driverEndpoint.putNonBlocking(UcxUtils.getAddress(metadataBuffer),
98 | metadataBuffer.remaining(), driverOffset, driverMetadata.driverRkey, null)
99 |
100 | workerWrapper.preconnect()
101 | // Blocking progress needed to make sure the last mapper has published its metadata to the driver
102 | // before reducers start.
103 | workerWrapper.waitRequest(request)
104 | memPool.put(metadataRegisteredMemory)
105 | logInfo(s"MapID: $mapId register files + publishing overhead: " +
106 | s"${System.currentTimeMillis() - startTime} ms")
107 | }
108 |
109 | private def unregisterAndUnmap(mem: UcpMemory): Unit = {
110 | val address = mem.getAddress
111 | val length = mem.getLength
112 | mem.deregister()
113 | UnsafeUtils.munmap(address, length)
114 | }
115 |
116 | def removeShuffle(shuffleId: Int): Unit = {
117 | fileMappings.remove(shuffleId).foreach((mappings: CopyOnWriteArrayList[UcpMemory]) =>
118 | mappings.asScala.par.foreach(unregisterAndUnmap))
119 | offsetMappings.remove(shuffleId).foreach((mappings: CopyOnWriteArrayList[UcpMemory]) =>
120 | mappings.asScala.par.foreach(unregisterAndUnmap))
121 | }
122 |
123 | override def stop(): Unit = {
124 | fileMappings.keys.foreach(removeShuffle)
125 | }
126 | }
127 |
--------------------------------------------------------------------------------
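The 24 extra bytes requested from the memory pool above cover the fixed part of a per-map metadata
block: two 8-byte addresses plus two 4-byte rkey-length fields, followed by the two rkeys themselves.
A small Scala sketch of the byte offsets inside one block, assuming the default
spark.shuffle.ucx.rkeySize of 150 bytes (so metadataBlockSize = 2 * 150 = 300); DriverMetadata in
UcxWorkerWrapper.scala reads the block back using the same layout:

    // Per-map metadata block layout, as written above and read back by DriverMetadata:
    // | offsetAddress (8B) | dataAddress (8B) | offsetRkeySize (4B) | offsetRkey | dataRkeySize (4B) | dataRkey |
    val metadataBlockSize = 300                                      // 2 * default rkeySize of 150 bytes
    def blockStart(mapId: Int): Int = mapId * metadataBlockSize
    def offsetAddressPos(mapId: Int): Int = blockStart(mapId)        // getLong here
    def dataAddressPos(mapId: Int): Int = blockStart(mapId) + 8      // getLong here
    def offsetRkeySizePos(mapId: Int): Int = blockStart(mapId) + 16  // getInt; offsetRkey bytes follow

--------------------------------------------------------------------------------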
/src/main/scala/org/apache/spark/shuffle/CommonUcxShuffleManager.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle
6 |
7 | import java.util.concurrent.ConcurrentHashMap
8 |
9 | import scala.collection.JavaConverters._
10 | import scala.collection.{concurrent, mutable}
11 |
12 | import org.openucx.jucx.ucp.UcpMemory
13 | import org.apache.spark.SparkConf
14 | import org.apache.spark.shuffle.sort.SortShuffleManager
15 | import org.apache.spark.shuffle.ucx.UcxNode
16 | import org.apache.spark.shuffle.ucx.rpc.UcxRemoteMemory
17 | import org.apache.spark.unsafe.Platform
18 |
19 | /**
20 | * Common part for all spark versions for UcxShuffleManager logic
21 | */
22 | abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) {
23 | type ShuffleId = Int
24 | type MapId = Int
25 | val ucxShuffleConf = new UcxShuffleConf(conf)
26 |
27 | var ucxNode: UcxNode = _
28 |
29 | // A shuffle handle holds metadata about the shuffle (number of mappers, etc.)
30 | // and is distributed by Spark's task broadcast protocol.
31 | // UcxShuffleHandle extends Spark's shuffle handle to also carry the driver metadata info.
32 | val shuffleIdToHandle: concurrent.Map[ShuffleId, UcxShuffleHandle[_, _, _]] =
33 | new ConcurrentHashMap[ShuffleId, UcxShuffleHandle[_, _, _]]().asScala
34 |
35 | if (isDriver) {
36 | startUcxNodeIfMissing()
37 | }
38 |
39 | protected def registerShuffleCommon[K, V, C](baseHandle: BaseShuffleHandle[K,V,C],
40 | shuffleId: ShuffleId,
41 | numMaps: Int): ShuffleHandle = {
42 | // Register metadata buffer where each map will publish its index and data file metadata
43 | val metadataBufferSize = numMaps * ucxShuffleConf.metadataBlockSize
44 | val metadataBuffer = Platform.allocateDirectBuffer(metadataBufferSize.toInt)
45 |
46 | val metadataMemory = ucxNode.getContext.registerMemory(metadataBuffer)
47 | shuffleIdToMetadataBuffer.put(shuffleId, metadataMemory)
48 |
49 | val driverMemory = new UcxRemoteMemory(metadataMemory.getAddress,
50 | metadataMemory.getRemoteKeyBuffer)
51 |
52 | val handle = new UcxShuffleHandle(shuffleId, driverMemory, numMaps, baseHandle)
53 |
54 | shuffleIdToHandle.putIfAbsent(shuffleId, handle)
55 | handle
56 | }
57 |
58 | /**
59 |  * Mapping between a shuffle and its metadata buffer, used to deregister it when the shuffle is no longer needed.
60 | */
61 | protected val shuffleIdToMetadataBuffer: mutable.Map[ShuffleId, UcpMemory] =
62 | new ConcurrentHashMap[ShuffleId, UcpMemory]().asScala
63 |
64 | /**
65 | * Atomically starts UcxNode singleton - one for all shuffle threads.
66 | */
67 | def startUcxNodeIfMissing(): Unit = if (ucxNode == null) {
68 | synchronized {
69 | if (ucxNode == null) {
70 | ucxNode = new UcxNode(ucxShuffleConf, isDriver)
71 | }
72 | }
73 | }
74 |
75 | override def unregisterShuffle(shuffleId: Int): Boolean = {
76 | shuffleIdToMetadataBuffer.remove(shuffleId).foreach(_.deregister())
77 | shuffleBlockResolver.asInstanceOf[CommonUcxShuffleBlockResolver].removeShuffle(shuffleId)
78 | super.unregisterShuffle(shuffleId)
79 | }
80 |
81 | /**
82 | * Called on both driver and executors to finally cleanup resources.
83 | */
84 | override def stop(): Unit = synchronized {
85 | logInfo("Stopping shuffle manager")
86 | shuffleIdToHandle.keys.foreach(unregisterShuffle)
87 | shuffleIdToHandle.clear()
88 | if (ucxNode != null) {
89 | ucxNode.close()
90 | ucxNode = null
91 | }
92 | super.stop()
93 | }
94 |
95 | }
96 |
97 | /**
98 |  * Extension of Spark's shuffle handle, broadcast to executors over TCP.
99 |  * Adds a metadataBufferOnDriver field that contains the address and rkey of the driver metadata buffer.
100 | */
101 | class UcxShuffleHandle[K, V, C](override val shuffleId: Int,
102 | val metadataBufferOnDriver: UcxRemoteMemory,
103 | val numMaps: Int,
104 | val baseHandle: BaseShuffleHandle[K,V,C]) extends ShuffleHandle(shuffleId)
105 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/shuffle/UcxShuffleConf.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle
6 |
7 | import scala.collection.JavaConverters._
8 |
9 | import org.apache.spark.SparkConf
10 | import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry}
11 | import org.apache.spark.network.util.ByteUnit
12 | import org.apache.spark.util.Utils
13 |
14 | /**
15 | * Plugin configuration properties.
16 | */
17 | class UcxShuffleConf(conf: SparkConf) extends SparkConf {
18 | private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name"
19 |
20 | lazy val getNumProcesses: Int = getInt("spark.executor.instances", 1)
21 |
22 | lazy val coresPerProcess: Int = getInt("spark.executor.cores",
23 | Runtime.getRuntime.availableProcessors())
24 |
25 | lazy val driverHost: String = conf.get(getUcxConf("driver.host"),
26 | conf.get("spark.driver.host", "0.0.0.0"))
27 |
28 | lazy val driverPort: Int = conf.getInt(getUcxConf("driver.port"), 55443)
29 |
30 | // Metadata
31 |
32 | lazy val RKEY_SIZE: ConfigEntry[Long] =
33 | ConfigBuilder(getUcxConf("rkeySize"))
34 | .doc("Maximum size of rKeyBuffer")
35 | .bytesConf(ByteUnit.BYTE)
36 | .createWithDefault(150)
37 |
38 | // For metadata we publish index file + data file rkeys
39 | lazy val metadataBlockSize: Long = 2 * conf.getSizeAsBytes(RKEY_SIZE.key,
40 | RKEY_SIZE.defaultValueString)
41 |
42 | private lazy val METADATA_RPC_BUFFER_SIZE =
43 | ConfigBuilder(getUcxConf("rpc.metadata.bufferSize"))
44 | .doc("Buffer size of worker -> driver metadata message")
45 | .bytesConf(ByteUnit.BYTE)
46 | .createWithDefault(4096)
47 |
48 | lazy val metadataRPCBufferSize: Int = conf.getSizeAsBytes(METADATA_RPC_BUFFER_SIZE.key,
49 | METADATA_RPC_BUFFER_SIZE.defaultValueString).toInt
50 |
51 | // Memory Pool
52 | private lazy val PREALLOCATE_BUFFERS =
53 | ConfigBuilder(getUcxConf("memory.preAllocateBuffers"))
54 | .doc("Comma separated list of buffer size : buffer count pairs to preallocate in memory pool. E.g. 4k:1000,16k:500")
55 | .stringConf.createWithDefault("")
56 |
57 | lazy val preallocateBuffersMap: java.util.Map[java.lang.Integer, java.lang.Integer] = {
58 | conf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => !s.isEmpty)
59 | .map(entry => entry.split(":") match {
60 | case Array(bufferSize, bufferCount) =>
61 | (int2Integer(Utils.byteStringAsBytes(bufferSize.trim).toInt),
62 | int2Integer(bufferCount.toInt))
63 | }).toMap.asJava
64 | }
65 |
66 | private lazy val MIN_BUFFER_SIZE = ConfigBuilder(getUcxConf("memory.minBufferSize"))
67 | .doc("Minimal buffer size in memory pool.")
68 | .bytesConf(ByteUnit.BYTE)
69 | .createWithDefault(1024)
70 |
71 | lazy val minBufferSize: Long = conf.getSizeAsBytes(MIN_BUFFER_SIZE.key,
72 | MIN_BUFFER_SIZE.defaultValueString)
73 |
74 | private lazy val MIN_REGISTRATION_SIZE =
75 | ConfigBuilder(getUcxConf("memory.minAllocationSize"))
76 | .doc("Minimal memory registration size in memory pool.")
77 | .bytesConf(ByteUnit.MiB)
78 | .createWithDefault(4)
79 |
80 | lazy val minRegistrationSize: Int = conf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key,
81 | MIN_REGISTRATION_SIZE.defaultValueString).toInt
82 |
83 | private lazy val PREREGISTER_MEMORY = ConfigBuilder(getUcxConf("memory.preregister"))
84 | .doc("Whether to do ucp mem map for allocated memory in memory pool")
85 | .booleanConf.createWithDefault(true)
86 |
87 | lazy val preregisterMemory: Boolean = conf.getBoolean(PREREGISTER_MEMORY.key, PREREGISTER_MEMORY.defaultValue.get)
88 |
89 | lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), false)
90 | }
91 |
--------------------------------------------------------------------------------
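The properties above are all read from the standard SparkConf, so the plugin is enabled and tuned
like any other Spark setting. A minimal Scala sketch of a configuration that selects the UCX shuffle
manager and preallocates pool buffers; the preAllocateBuffers value follows the size:count format
documented above and the numbers are only an example:

    import org.apache.spark.SparkConf

    // Example configuration; buffer sizes and counts are illustrative.
    val conf = new SparkConf()
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.UcxShuffleManager")
      .set("spark.shuffle.ucx.driver.port", "55443")                        // default shown explicitly
      .set("spark.shuffle.ucx.memory.preAllocateBuffers", "4k:1000,16k:500")

--------------------------------------------------------------------------------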
/src/main/scala/org/apache/spark/shuffle/UcxWorkerWrapper.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle
6 |
7 | import java.io.Closeable
8 | import java.net.InetSocketAddress
9 | import java.nio.ByteBuffer
10 | import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
11 |
12 | import scala.collection.JavaConverters._
13 | import scala.collection.mutable
14 |
15 | import org.openucx.jucx.UcxException
16 | import org.openucx.jucx.ucp.{UcpEndpoint, UcpEndpointParams, UcpRemoteKey, UcpRequest, UcpWorker}
17 | import org.apache.spark.SparkEnv
18 | import org.apache.spark.internal.Logging
19 | import org.apache.spark.shuffle.ucx.{UcxNode, UnsafeUtils}
20 | import org.apache.spark.storage.BlockManagerId
21 | import org.apache.spark.unsafe.Platform
22 |
23 | /**
24 |  * Driver metadata buffer information, holding the unpacked rkey for this worker wrapper
25 |  * and the fetched buffer itself.
26 | */
27 | case class DriverMetadata(address: Long, driverRkey: UcpRemoteKey, length: Int,
28 | var data: ByteBuffer) {
29 | // Driver metadata is an array of blocks:
30 | // | mapId0 | mapId1 | mapId2 | mapId3 | mapId4 | mapId5 |
31 | // Each block in driver metadata has next layout:
32 | // |offsetAddress|dataAddress|offsetRkeySize|offsetRkey|dataRkeySize|dataRkey|
33 |
34 | def offsetAddress(mapId: Int): Long = {
35 | val mapIdBlock = mapId * UcxWorkerWrapper.metadataBlockSize
36 | data.getLong(mapIdBlock)
37 | }
38 |
39 | def dataAddress(mapId: Int): Long = {
40 | val mapIdBlock = mapId * UcxWorkerWrapper.metadataBlockSize
41 | data.getLong(mapIdBlock + UnsafeUtils.LONG_SIZE)
42 | }
43 |
44 | def offsetRkey(mapId: Int): ByteBuffer = {
45 | val mapIdBlock = mapId * UcxWorkerWrapper.metadataBlockSize
46 | var offsetWithinBlock = mapIdBlock + 2 * UnsafeUtils.LONG_SIZE
47 | val rkeySize = data.getInt(offsetWithinBlock)
48 | offsetWithinBlock += UnsafeUtils.INT_SIZE
49 | val result = data.duplicate()
50 | result.position(offsetWithinBlock).limit(offsetWithinBlock + rkeySize)
51 | result.slice()
52 | }
53 |
54 | def dataRkey(mapId: Int): ByteBuffer = {
55 | val mapIdBlock = mapId * UcxWorkerWrapper.metadataBlockSize
56 | var offsetWithinBlock = mapIdBlock + 2 * UnsafeUtils.LONG_SIZE
57 | val offsetRkeySize = data.getInt(offsetWithinBlock)
58 | offsetWithinBlock += UnsafeUtils.INT_SIZE + offsetRkeySize
59 | val dataRkeySize = data.getInt(offsetWithinBlock)
60 | offsetWithinBlock += UnsafeUtils.INT_SIZE
61 | val result = data.duplicate()
62 | result.position(offsetWithinBlock).limit(offsetWithinBlock + dataRkeySize)
63 | result.slice()
64 | }
65 | }
66 |
67 | /**
68 |  * Per-thread worker wrapper that maintains connection and progress logic.
69 | */
70 | class UcxWorkerWrapper(val worker: UcpWorker, val conf: UcxShuffleConf, val id: Int)
71 | extends Closeable with Logging {
72 | import UcxWorkerWrapper._
73 |
74 | private final val driverSocketAddress = new InetSocketAddress(conf.driverHost, conf.driverPort)
75 | private final val endpointParams = new UcpEndpointParams().setSocketAddress(driverSocketAddress)
76 | .setPeerErrorHandlingMode()
77 | val driverEndpoint: UcpEndpoint = worker.newEndpoint(endpointParams)
78 |
79 | private final val connections = mutable.Map.empty[BlockManagerId, UcpEndpoint]
80 |
81 | private final val driverMetadata = mutable.Map.empty[ShuffleId, DriverMetadata]
82 |
83 | override def close(): Unit = {
84 | driverMetadata.values.foreach{
85 | case DriverMetadata(address, rkey, length, data) => rkey.close()
86 | }
87 | driverMetadata.clear()
88 | driverEndpoint.close()
89 | connections.foreach{
90 | case (_, endpoint) => endpoint.close()
91 | }
92 | connections.clear()
93 | worker.close()
94 | driverMetadataBuffer.clear()
95 | }
96 |
97 | /**
98 |  * Blocking progress of a single request until it is completed.
99 | */
100 | def waitRequest(request: UcpRequest): Unit = {
101 | val startTime = System.currentTimeMillis()
102 | worker.progressRequest(request)
103 | logDebug(s"Request completed in ${System.currentTimeMillis() - startTime} ms")
104 | }
105 |
106 | /**
107 |  * Blocking progress while the result queue is empty.
108 | */
109 | def fillQueueWithBlocks(queue: LinkedBlockingQueue[_]): Unit = {
110 | while (queue.isEmpty) {
111 | progress()
112 | }
113 | }
114 |
115 | /**
116 | * The only place for worker progress
117 | */
118 | private def progress(): Int = {
119 | worker.progress()
120 | }
121 |
122 | /**
123 | * Establish connections to known instances.
124 | */
125 | def preconnect(): Unit = {
126 | UcxNode.getWorkerAddresses.keySet().asScala.foreach(getConnection)
127 | }
128 |
129 | def getConnection(blockManagerId: BlockManagerId): UcpEndpoint = {
130 | val workerAddresses = UcxNode.getWorkerAddresses
131 | // Block until a worker address for this BlockManagerId becomes available
132 | val startTime = System.currentTimeMillis()
133 | val timeout = conf.getTimeAsMs("spark.network.timeout", "100")
134 | if (workerAddresses.get(blockManagerId) == null) {
135 | workerAddresses.synchronized {
136 | while (workerAddresses.get(blockManagerId) == null) {
137 | workerAddresses.wait(timeout)
138 | if (System.currentTimeMillis() - startTime > timeout) {
139 | throw new UcxException(s"Didn't get worker address for $blockManagerId within $timeout ms")
140 | }
141 | }
142 | }
143 | }
144 |
145 | connections.getOrElseUpdate(blockManagerId, {
146 | logInfo(s"Worker $id connecting to $blockManagerId")
147 | val endpointParams = new UcpEndpointParams()
148 | .setPeerErrorHandlingMode()
149 | .setUcpAddress(workerAddresses.get(blockManagerId))
150 | worker.newEndpoint(endpointParams)
151 | })
152 | }
153 |
154 | /**
155 | * Unpacks driver metadata RkeyBuffer for this worker.
156 | * Needed to perform PUT operation to publish map output info.
157 | */
158 | def getDriverMetadata(shuffleId: ShuffleId): DriverMetadata = {
159 | driverMetadata.getOrElseUpdate(shuffleId, {
160 | val ucxShuffleHandle = SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager]
161 | .shuffleIdToHandle(shuffleId)
162 | val (address, length, rkey): (Long, Int, ByteBuffer) = (ucxShuffleHandle.metadataBufferOnDriver.getAddress,
163 | ucxShuffleHandle.numMaps * conf.metadataBlockSize.toInt,
164 | ucxShuffleHandle.metadataBufferOnDriver.getRkeyBuffer)
165 |
166 | rkey.clear()
167 | val unpackedRkey = driverEndpoint.unpackRemoteKey(rkey)
168 | DriverMetadata(address, unpackedRkey, length, null)
169 | })
170 | }
171 |
172 | /**
173 |  * Fetches the metadata buffer from the driver using ucp_get; it contains all the needed
174 |  * offset and data addresses and rkeys.
175 | */
176 | def fetchDriverMetadataBuffer(shuffleId: ShuffleId): DriverMetadata = {
177 | val handle = SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager]
178 | .shuffleIdToHandle(shuffleId)
179 |
180 | val metadata = getDriverMetadata(handle.shuffleId)
181 |
182 | UcxWorkerWrapper.driverMetadataBuffer.computeIfAbsent(shuffleId,
183 | (t: ShuffleId) => {
184 | val buffer = Platform.allocateDirectBuffer(metadata.length)
185 | val request = driverEndpoint.getNonBlocking(
186 | metadata.address, metadata.driverRkey, buffer, null)
187 | waitRequest(request)
188 | buffer
189 | }
190 | )
191 |
192 | if (metadata.data == null) {
193 | metadata.data = UcxWorkerWrapper.driverMetadataBuffer.get(shuffleId)
194 | }
195 | metadata
196 | }
197 | }
198 |
199 | object UcxWorkerWrapper {
200 | type ShuffleId = Int
201 | type MapId = Int
202 | // Driver metadata buffer, fetched once by the first worker wrapper that needs it and then cached.
203 | val driverMetadataBuffer = new ConcurrentHashMap[ShuffleId, ByteBuffer]()
204 |
205 | val metadataBlockSize: MapId =
206 | SparkEnv.get.shuffleManager.asInstanceOf[CommonUcxShuffleManager].ucxShuffleConf.metadataBlockSize.toInt
207 | }
208 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.compat.spark_2_1
6 |
7 | import java.io.{File, RandomAccessFile}
8 |
9 | import org.apache.spark.SparkEnv
10 | import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager, IndexShuffleBlockResolver}
11 | import org.apache.spark.storage.ShuffleIndexBlockId
12 |
13 | /**
14 |  * Mapper entry point for the UcxShuffle plugin. Performs memory registration
15 |  * of data and index files and publishes their addresses to the driver metadata buffer.
16 | */
17 | class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
18 | extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {
19 |
20 | private def getIndexFile(shuffleId: Int, mapId: Int): File = {
21 | SparkEnv.get.blockManager
22 | .diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID))
23 | }
24 |
25 | /**
26 | * Mapper commit protocol extension. Register index and data files and publish all needed
27 | * metadata to driver.
28 | */
29 | override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Int,
30 | lengths: Array[Long], dataTmp: File): Unit = {
31 | super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
32 | val dataFile = getDataFile(shuffleId, mapId)
33 | val dataBackFile = new RandomAccessFile(dataFile, "rw")
34 |
35 | if (dataBackFile.length() == 0) {
36 | dataBackFile.close()
37 | return
38 | }
39 |
40 | val indexFile = getIndexFile(shuffleId, mapId)
41 | val indexBackFile = new RandomAccessFile(indexFile, "rw")
42 | writeIndexFileAndCommitCommon(shuffleId, mapId, lengths, dataTmp, indexBackFile, dataBackFile)
43 | }
44 | }
45 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle
6 |
7 | import org.apache.spark.shuffle.compat.spark_2_1.{UcxShuffleBlockResolver, UcxShuffleReader}
8 | import org.apache.spark.util.ShutdownHookManager
9 | import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
10 |
11 | /**
12 |  * Main entry point of the UCX shuffle plugin. It extends Spark's default SortShuffleManager
13 |  * and injects the needed logic in overridden methods.
14 | */
15 | class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) {
16 | ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop)
17 |
18 | /**
19 | * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
20 |  * Called on the driver; Spark guarantees that the shuffle on executors starts only after this call.
21 | */
22 | override def registerShuffle[K, V, C](shuffleId: ShuffleId,
23 | numMaps: Int,
24 | dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
25 | assume(isDriver)
26 | val baseHandle = super.registerShuffle(shuffleId, numMaps, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]]
27 | registerShuffleCommon(baseHandle, shuffleId, numMaps)
28 | }
29 |
30 | /**
31 |  * Mapper callback on the executor. Just starts UcxNode and reuses Spark's mapper logic.
32 | */
33 | override def getWriter[K, V](handle: ShuffleHandle, mapId: Int,
34 | context: TaskContext): ShuffleWriter[K, V] = {
35 | startUcxNodeIfMissing()
36 | shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,V,_]])
37 | super.getWriter(handle.asInstanceOf[UcxShuffleHandle[K,V,_]].baseHandle, mapId, context)
38 | }
39 |
40 | override val shuffleBlockResolver: UcxShuffleBlockResolver = new UcxShuffleBlockResolver(this)
41 |
42 | /**
43 | * Reducer callback on executor.
44 | */
45 | override def getReader[K, C](handle: ShuffleHandle, startPartition: Int,
46 | endPartition: Int, context: TaskContext): ShuffleReader[K, C] = {
47 | startUcxNodeIfMissing()
48 | shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,_,C]])
49 | new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition,
50 | endPartition, context)
51 | }
52 | }
53 |
54 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.compat.spark_2_1
6 |
7 | import java.io.InputStream
8 | import java.util.concurrent.LinkedBlockingQueue
9 |
10 | import org.apache.spark.internal.{Logging, config}
11 | import org.apache.spark.serializer.SerializerManager
12 | import org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1.UcxShuffleClient
13 | import org.apache.spark.shuffle.{ShuffleReader, UcxShuffleHandle, UcxShuffleManager}
14 | import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
15 | import org.apache.spark.util.CompletionIterator
16 | import org.apache.spark.util.collection.ExternalSorter
17 | import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
18 |
19 | /**
20 |  * Extension of Spark's shuffle reader that injects UcxShuffleClient
21 |  * and progresses the worker lazily, only when the result queue is empty.
22 | */
23 | class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C],
24 | startPartition: Int,
25 | endPartition: Int,
26 | context: TaskContext,
27 | serializerManager: SerializerManager = SparkEnv.get.serializerManager,
28 | blockManager: BlockManager = SparkEnv.get.blockManager,
29 | mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
30 | extends ShuffleReader[K, C] with Logging {
31 |
32 | private val dep = handle.baseHandle.dependency
33 |
34 | /** Read the combined key-values for this reduce task */
35 | override def read(): Iterator[Product2[K, C]] = {
36 | val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
37 | val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
38 | .ucxNode.getThreadLocalWorker
39 | val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper)
40 | val wrappedStreams = new ShuffleBlockFetcherIterator(
41 | context,
42 | shuffleClient,
43 | blockManager,
44 | mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId,
45 | startPartition, endPartition),
46 | // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
47 | SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
48 | SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
49 |
50 | // Ucx shuffle logic
51 | // Java reflection to get access to private results queue
52 | val queueField = wrappedStreams.getClass.getDeclaredField(
53 | "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
54 | queueField.setAccessible(true)
55 | val resultQueue = queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]]
56 |
57 | // Do progress if queue is empty before calling next on ShuffleIterator
58 | val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
59 | override def next(): (BlockId, InputStream) = {
60 | val startTime = System.currentTimeMillis()
61 | workerWrapper.fillQueueWithBlocks(resultQueue)
62 | shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
63 | wrappedStreams.next()
64 | }
65 |
66 | override def hasNext: Boolean = {
67 | val result = wrappedStreams.hasNext
68 | if (!result) {
69 | shuffleClient.close()
70 | }
71 | result
72 | }
73 | }
74 | // End of ucx shuffle logic
75 |
76 | val serializerInstance = dep.serializer.newInstance()
77 | val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
78 | // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
79 | // NextIterator. The NextIterator makes sure that close() is called on the
80 | // underlying InputStream when all records have been read.
81 | serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
82 | }
83 |
84 | // Update the context task metrics for each record read.
85 | val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
86 | val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
87 | recordIter.map { record =>
88 | readMetrics.incRecordsRead(1)
89 | record
90 | },
91 | context.taskMetrics().mergeShuffleReadMetrics())
92 |
93 | // An interruptible iterator must be used here in order to support task cancellation
94 | val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
95 |
96 | val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
97 | if (dep.mapSideCombine) {
98 | // We are reading values that are already combined
99 | val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
100 | dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
101 | } else {
102 | // We don't know the value type, but also don't care -- the dependency *should*
103 | // have made sure its compatible w/ this aggregator, which will convert the value
104 | // type to the combined type C
105 | val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
106 | dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
107 | }
108 | } else {
109 | interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
110 | }
111 |
112 | // Sort the output if there is a sort ordering defined.
113 | val resultIter = dep.keyOrdering match {
114 | case Some(keyOrd: Ordering[K]) =>
115 | // Create an ExternalSorter to sort the data.
116 | val sorter =
117 | new ExternalSorter[K, C, C](context,
118 | ordering = Some(keyOrd), serializer = dep.serializer)
119 | sorter.insertAll(aggregatedIter)
120 | context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
121 | context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
122 | context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
123 | // Use completion callback to stop sorter if task was finished/cancelled.
124 | CompletionIterator[Product2[K, C],
125 | Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
126 | case None =>
127 | aggregatedIter
128 | }
129 |
130 | resultIter match {
131 | case _: InterruptibleIterator[Product2[K, C]] => resultIter
132 | case _ =>
133 | // Use another interruptible iterator here to support task cancellation as aggregator
134 |         // and/or sorter may have consumed the previous interruptible iterator.
135 | new InterruptibleIterator[Product2[K, C]](context, resultIter)
136 | }
137 | }
138 |
139 | }
140 |
--------------------------------------------------------------------------------
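
The reader above (and its 2.4/3.0 counterparts below) reaches into ShuffleBlockFetcherIterator's private results queue through Java reflection, using the mangled field name that scalac emits for a private member accessed from nested classes. A minimal, self-contained sketch of the same trick, using a hypothetical Demo class rather than Spark's iterator:

import java.util.concurrent.LinkedBlockingQueue

// Hypothetical stand-in for ShuffleBlockFetcherIterator: it keeps its work
// queue in a private field that outside code can only reach via reflection.
class Demo {
  private val results = new LinkedBlockingQueue[String]()
  results.put("block-0")
  def size: Int = results.size()
}

object ReflectionSketch {
  def main(args: Array[String]): Unit = {
    val demo = new Demo
    // For this simple class scalac keeps the plain field name; Spark's field is
    // name-mangled ("org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
    // because it is accessed from inner classes.
    val field = demo.getClass.getDeclaredField("results")
    field.setAccessible(true)
    val queue = field.get(demo).asInstanceOf[LinkedBlockingQueue[String]]
    queue.put("injected-block") // mutate the private queue directly
    println(demo.size)          // prints 2
  }
}
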
/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleBlockResolver.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.compat.spark_2_4
6 |
7 | import java.io.{File, RandomAccessFile}
8 |
9 | import org.apache.spark.SparkEnv
10 | import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager, IndexShuffleBlockResolver}
11 | import org.apache.spark.storage.ShuffleIndexBlockId
12 |
13 | /**
14 |  * Mapper entry point for the UcxShuffle plugin. Performs memory registration
15 |  * of data and index files and publishes their addresses to the driver metadata buffer.
16 | */
17 | class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
18 | extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {
19 |
20 | private def getIndexFile(shuffleId: Int, mapId: Int): File = {
21 | SparkEnv.get.blockManager
22 | .diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID))
23 | }
24 |
25 | /**
26 |  * Mapper commit protocol extension. Registers index and data files and publishes all needed
27 |  * metadata to the driver.
28 | */
29 | override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Int,
30 | lengths: Array[Long], dataTmp: File): Unit = {
31 | super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
32 | val dataFile = getDataFile(shuffleId, mapId)
33 | val dataBackFile = new RandomAccessFile(dataFile, "rw")
34 |
35 | if (dataBackFile.length() == 0) {
36 | dataBackFile.close()
37 | return
38 | }
39 |
40 | val indexFile = getIndexFile(shuffleId, mapId)
41 | val indexBackFile = new RandomAccessFile(indexFile, "rw")
42 | writeIndexFileAndCommitCommon(shuffleId, mapId, lengths, dataTmp, indexBackFile, dataBackFile)
43 | }
44 | }
45 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleManager.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle
6 |
7 | import org.apache.spark.shuffle.compat.spark_2_4.{UcxShuffleBlockResolver, UcxShuffleReader}
8 | import org.apache.spark.util.ShutdownHookManager
9 | import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
10 |
11 | /**
12 |  * Main entry point of the UCX shuffle plugin. It extends Spark's default SortShuffleManager
13 |  * and injects the needed logic in overridden methods.
14 | */
15 | class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) {
16 | ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop)
17 |
18 | /**
19 | * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
20 |  * Called on the driver; Spark guarantees that executors start the shuffle only after this completes.
21 | */
22 | override def registerShuffle[K, V, C](shuffleId: ShuffleId,
23 | numMaps: Int,
24 | dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
25 | assume(isDriver)
26 | val baseHandle = super.registerShuffle(shuffleId, numMaps, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]]
27 | registerShuffleCommon(baseHandle, shuffleId, numMaps)
28 | }
29 |
30 | /**
31 |  * Mapper callback on the executor. Starts the UcxNode if missing and reuses Spark's mapper logic.
32 | */
33 | override def getWriter[K, V](handle: ShuffleHandle, mapId: Int,
34 | context: TaskContext): ShuffleWriter[K, V] = {
35 | startUcxNodeIfMissing()
36 | shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,V,_]])
37 | super.getWriter(handle.asInstanceOf[UcxShuffleHandle[K,V,_]].baseHandle, mapId, context)
38 | }
39 |
40 | override val shuffleBlockResolver: UcxShuffleBlockResolver = new UcxShuffleBlockResolver(this)
41 |
42 | /**
43 |  * Reducer callback on the executor.
44 | */
45 | override def getReader[K, C](handle: ShuffleHandle, startPartition: Int,
46 | endPartition: Int, context: TaskContext): ShuffleReader[K, C] = {
47 | startUcxNodeIfMissing()
48 | shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,_,C]])
49 | new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition,
50 | endPartition, context)
51 | }
52 | }
53 |
54 |
--------------------------------------------------------------------------------
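
The manager is not instantiated directly by user code; Spark loads it by class name. A hedged sketch of the wiring (the jar path is a placeholder and the exact deployment steps are assumptions; the project README is authoritative):

import org.apache.spark.SparkConf

object UcxShuffleConfSketch {
  // Sketch only: point Spark at the UCX shuffle manager built for the matching
  // Spark profile and make the plugin jar visible to driver and executors.
  def ucxConf(): SparkConf = new SparkConf()
    .set("spark.shuffle.manager", "org.apache.spark.shuffle.UcxShuffleManager")
    .set("spark.driver.extraClassPath", "/path/to/spark-ucx-jar-with-dependencies.jar")
    .set("spark.executor.extraClassPath", "/path/to/spark-ucx-jar-with-dependencies.jar")
}
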
/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.compat.spark_2_4
6 |
7 | import java.io.InputStream
8 | import java.util.concurrent.LinkedBlockingQueue
9 |
10 | import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
11 | import org.apache.spark.internal.{Logging, config}
12 | import org.apache.spark.serializer.SerializerManager
13 | import org.apache.spark.shuffle.{ShuffleReader, UcxShuffleHandle, UcxShuffleManager}
14 | import org.apache.spark.shuffle.ucx.reducer.compat.spark_2_4.UcxShuffleClient
15 | import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
16 | import org.apache.spark.util.CompletionIterator
17 | import org.apache.spark.util.collection.ExternalSorter
18 |
19 | /**
20 |  * Extension of Spark's shuffle reader that injects UcxShuffleClient
21 |  * and makes UCX progress lazily, only when the result queue is empty.
22 | */
23 | class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C],
24 | startPartition: Int,
25 | endPartition: Int,
26 | context: TaskContext,
27 | serializerManager: SerializerManager = SparkEnv.get.serializerManager,
28 | blockManager: BlockManager = SparkEnv.get.blockManager,
29 | mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
30 | extends ShuffleReader[K, C] with Logging {
31 |
32 | private val dep = handle.baseHandle.dependency
33 |
34 | /** Read the combined key-values for this reduce task */
35 | override def read(): Iterator[Product2[K, C]] = {
36 | val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
37 | val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
38 | .ucxNode.getThreadLocalWorker
39 | val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper)
40 | val wrappedStreams = new ShuffleBlockFetcherIterator(
41 | context,
42 | shuffleClient,
43 | blockManager,
44 | mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId,
45 | startPartition, endPartition),
46 | serializerManager.wrapStream,
47 | // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
48 | SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
49 | SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
50 | SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
51 | SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
52 | SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
53 |
54 | // Ucx shuffle logic
55 | // Java reflection to get access to private results queue
56 | val queueField = wrappedStreams.getClass.getDeclaredField(
57 | "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
58 | queueField.setAccessible(true)
59 | val resultQueue = queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]]
60 |
61 |     // Make UCX progress while the result queue is empty before calling next() on the ShuffleBlockFetcherIterator
62 | val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
63 | override def next(): (BlockId, InputStream) = {
64 | val startTime = System.currentTimeMillis()
65 | workerWrapper.fillQueueWithBlocks(resultQueue)
66 | shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
67 | wrappedStreams.next()
68 | }
69 |
70 | override def hasNext: Boolean = {
71 | val result = wrappedStreams.hasNext
72 | if (!result) {
73 | shuffleClient.close()
74 | }
75 | result
76 | }
77 | }
78 | // End of ucx shuffle logic
79 |
80 | val serializerInstance = dep.serializer.newInstance()
81 | val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
82 | // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
83 | // NextIterator. The NextIterator makes sure that close() is called on the
84 | // underlying InputStream when all records have been read.
85 | serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
86 | }
87 |
88 | // Update the context task metrics for each record read.
89 | val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
90 | val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
91 | recordIter.map { record =>
92 | readMetrics.incRecordsRead(1)
93 | record
94 | },
95 | context.taskMetrics().mergeShuffleReadMetrics())
96 |
97 | // An interruptible iterator must be used here in order to support task cancellation
98 | val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
99 |
100 | val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
101 | if (dep.mapSideCombine) {
102 | // We are reading values that are already combined
103 | val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
104 | dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
105 | } else {
106 | // We don't know the value type, but also don't care -- the dependency *should*
107 |       // have made sure it's compatible w/ this aggregator, which will convert the value
108 | // type to the combined type C
109 | val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
110 | dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
111 | }
112 | } else {
113 | interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
114 | }
115 |
116 | // Sort the output if there is a sort ordering defined.
117 | val resultIter = dep.keyOrdering match {
118 | case Some(keyOrd: Ordering[K]) =>
119 | // Create an ExternalSorter to sort the data.
120 | val sorter =
121 | new ExternalSorter[K, C, C](context,
122 | ordering = Some(keyOrd), serializer = dep.serializer)
123 | sorter.insertAll(aggregatedIter)
124 | context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
125 | context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
126 | context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
127 | // Use completion callback to stop sorter if task was finished/cancelled.
128 | context.addTaskCompletionListener[Unit](_ => {
129 | sorter.stop()
130 | })
131 | CompletionIterator[Product2[K, C],
132 | Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
133 | case None =>
134 | aggregatedIter
135 | }
136 |
137 | resultIter match {
138 | case _: InterruptibleIterator[Product2[K, C]] => resultIter
139 | case _ =>
140 | // Use another interruptible iterator here to support task cancellation as aggregator
141 |         // and/or sorter may have consumed the previous interruptible iterator.
142 | new InterruptibleIterator[Product2[K, C]](context, resultIter)
143 | }
144 | }
145 |
146 | }
147 |
--------------------------------------------------------------------------------
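
All three readers share the same "lazy progress" shape: an Iterator decorator that pumps the UCX worker before delegating next(), and releases the client once the underlying iterator is exhausted. A stripped-down sketch of that pattern (names are illustrative, not the plugin's API):

// Generic decorator mirroring the wrapper built inside read(): drive progress
// before each next(), and clean up when the wrapped iterator runs out.
class LazyProgressIterator[A](underlying: Iterator[A],
                              progress: () => Unit,
                              onExhausted: () => Unit) extends Iterator[A] {
  override def next(): A = {
    progress()          // e.g. workerWrapper.fillQueueWithBlocks(resultQueue)
    underlying.next()
  }

  override def hasNext: Boolean = {
    val more = underlying.hasNext
    if (!more) onExhausted() // e.g. shuffleClient.close()
    more
  }
}
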
/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleDataIO.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.compat.spark_3_0
6 |
7 | import org.apache.spark.SparkConf
8 | import org.apache.spark.internal.Logging
9 | import org.apache.spark.shuffle.api.ShuffleExecutorComponents
10 | import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO
11 |
12 | /**
13 |  * UCX local disk IO plugin that handles writing shuffle data to local disk and registering shuffle memory.
14 | */
15 | case class UcxLocalDiskShuffleDataIO(sparkConf: SparkConf) extends LocalDiskShuffleDataIO(sparkConf) with Logging {
16 |
17 | override def executor(): ShuffleExecutorComponents = {
18 | new UcxLocalDiskShuffleExecutorComponents(sparkConf)
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
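
On Spark 3.0 this data IO plugin is picked up through Spark's shuffle IO plugin setting (spark.shuffle.sort.io.plugin.class), which UcxShuffleManager.loadShuffleExecutorComponents below reads via ShuffleDataIOUtils. A hedged configuration sketch, assuming the plugin is selected explicitly:

import org.apache.spark.SparkConf

object UcxSpark3ConfSketch {
  // Sketch: pair the UCX shuffle manager with its local-disk IO plugin.
  def ucxConf(): SparkConf = new SparkConf()
    .set("spark.shuffle.manager", "org.apache.spark.shuffle.UcxShuffleManager")
    .set("spark.shuffle.sort.io.plugin.class",
      "org.apache.spark.shuffle.compat.spark_3_0.UcxLocalDiskShuffleDataIO")
}
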
/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.compat.spark_3_0
6 |
7 | import java.util
8 | import java.util.Optional
9 |
10 | import org.apache.spark.internal.Logging
11 | import org.apache.spark.{SparkConf, SparkEnv}
12 | import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter}
13 | import org.apache.spark.shuffle.UcxShuffleManager
14 | import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter}
15 |
16 | /**
17 | * Entry point to UCX executor.
18 | */
19 | class UcxLocalDiskShuffleExecutorComponents(sparkConf: SparkConf)
20 |   extends LocalDiskShuffleExecutorComponents(sparkConf) with Logging {
21 |
22 | private var blockResolver: UcxShuffleBlockResolver = _
23 |
24 | override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = {
25 | val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
26 | ucxShuffleManager.startUcxNodeIfMissing()
27 | blockResolver = ucxShuffleManager.shuffleBlockResolver
28 | }
29 |
30 | override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = {
31 | if (blockResolver == null) {
32 | throw new IllegalStateException(
33 | "Executor components must be initialized before getting writers.")
34 | }
35 | new LocalDiskShuffleMapOutputWriter(
36 | shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf)
37 | }
38 |
39 | override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long): Optional[SingleSpillShuffleMapOutputWriter] = {
40 | if (blockResolver == null) {
41 | throw new IllegalStateException(
42 | "Executor components must be initialized before getting writers.")
43 | }
44 | Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver))
45 | }
46 |
47 | }
48 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleBlockResolver.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.compat.spark_3_0
6 |
7 | import java.io.{File, RandomAccessFile}
8 |
9 | import org.apache.spark.{SparkEnv, TaskContext}
10 | import org.apache.spark.network.shuffle.ExecutorDiskUtils
11 | import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
12 | import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager}
13 | import org.apache.spark.storage.ShuffleIndexBlockId
14 |
15 | /**
16 |  * Mapper entry point for the UcxShuffle plugin. Performs memory registration
17 |  * of data and index files and publishes their addresses to the driver metadata buffer.
18 | */
19 | class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
20 | extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {
21 |
22 | private def getIndexFile(
23 | shuffleId: Int,
24 | mapId: Long,
25 | dirs: Option[Array[String]] = None): File = {
26 | val blockId = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
27 | val blockManager = SparkEnv.get.blockManager
28 | dirs
29 | .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name))
30 | .getOrElse(blockManager.diskBlockManager.getFile(blockId))
31 | }
32 |
33 | override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long,
34 | lengths: Array[Long], dataTmp: File): Unit = {
35 | super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
36 |     // In Spark 3.0 mapId is a Long that is unique across all jobs in the application, so we use
37 |     // the partitionId as the offset into the metadata buffer.
38 | val partitionId = TaskContext.getPartitionId()
39 | val dataFile = getDataFile(shuffleId, mapId)
40 | val dataBackFile = new RandomAccessFile(dataFile, "rw")
41 |
42 | if (dataBackFile.length() == 0) {
43 | dataBackFile.close()
44 | return
45 | }
46 |
47 | val indexFile = getIndexFile(shuffleId, mapId)
48 | val indexBackFile = new RandomAccessFile(indexFile, "rw")
49 |
50 | writeIndexFileAndCommitCommon(shuffleId, partitionId, lengths, dataTmp, indexBackFile, dataBackFile)
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleManager.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle
6 |
7 | import scala.collection.JavaConverters._
8 |
9 | import org.apache.spark.shuffle.api.ShuffleExecutorComponents
10 | import org.apache.spark.shuffle.compat.spark_3_0.{UcxShuffleBlockResolver, UcxShuffleReader}
11 | import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, SortShuffleWriter, UnsafeShuffleWriter}
12 | import org.apache.spark.util.ShutdownHookManager
13 | import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
14 |
15 | /**
16 |  * Main entry point of the UCX shuffle plugin. It extends Spark's default SortShuffleManager
17 |  * and injects the needed logic in overridden methods.
18 | */
19 | class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) {
20 | ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop)
21 | private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
22 |
23 | override val shuffleBlockResolver = new UcxShuffleBlockResolver(this)
24 |
25 | override def registerShuffle[K, V, C](shuffleId: ShuffleId, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
26 | assume(isDriver)
27 | val numMaps = dependency.partitioner.numPartitions
28 | val baseHandle = super.registerShuffle(shuffleId, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]]
29 | registerShuffleCommon(baseHandle, shuffleId, numMaps)
30 | }
31 |
32 | override def getWriter[K, V](handle: ShuffleHandle, mapId: Long, context: TaskContext,
33 | metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
34 | shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, V, _]])
35 | val env = SparkEnv.get
36 | handle.asInstanceOf[UcxShuffleHandle[K, V, _]].baseHandle match {
37 | case unsafeShuffleHandle: SerializedShuffleHandle[K@unchecked, V@unchecked] =>
38 | new UnsafeShuffleWriter(
39 | env.blockManager,
40 | context.taskMemoryManager(),
41 | unsafeShuffleHandle,
42 | mapId,
43 | context,
44 | env.conf,
45 | metrics,
46 | shuffleExecutorComponents)
47 | case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] =>
48 | new SortShuffleWriter(
49 | shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents)
50 | }
51 | }
52 |
53 | override def getReader[K, C](handle: ShuffleHandle, startPartition: MapId, endPartition: MapId,
54 | context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
55 |
56 | startUcxNodeIfMissing()
57 | shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K, _, C]])
58 | new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition, endPartition,
59 | context, readMetrics = metrics, shouldBatchFetch = true)
60 | }
61 |
62 |
63 | private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
64 | val executorComponents = ShuffleDataIOUtils.loadShuffleDataIO(conf).executor()
65 | val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX)
66 | .toMap
67 | executorComponents.initializeExecutor(
68 | conf.getAppId,
69 | SparkEnv.get.executorId,
70 | extraConfigs.asJava)
71 | executorComponents
72 | }
73 |
74 | }
75 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
3 | * See file LICENSE for terms.
4 | */
5 | package org.apache.spark.shuffle.compat.spark_3_0
6 |
7 | import java.io.InputStream
8 | import java.util.concurrent.LinkedBlockingQueue
9 |
10 | import scala.collection.JavaConverters._
11 |
12 | import org.apache.spark.internal.{Logging, config}
13 | import org.apache.spark.io.CompressionCodec
14 | import org.apache.spark.serializer.SerializerManager
15 | import org.apache.spark.shuffle.ucx.reducer.compat.spark_3_0.UcxShuffleClient
16 | import org.apache.spark.shuffle.{ShuffleReadMetricsReporter, ShuffleReader, UcxShuffleHandle, UcxShuffleManager}
17 | import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockBatchId, ShuffleBlockFetcherIterator, ShuffleBlockId}
18 | import org.apache.spark.util.CompletionIterator
19 | import org.apache.spark.util.collection.ExternalSorter
20 | import org.apache.spark.{InterruptibleIterator, SparkEnv, SparkException, TaskContext}
21 |
22 |
23 | /**
24 |  * Extension of Spark's shuffle reader that injects UcxShuffleClient
25 |  * and makes UCX progress lazily, only when the result queue is empty.
26 | */
27 | class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C],
28 | startPartition: Int,
29 | endPartition: Int,
30 | context: TaskContext,
31 | serializerManager: SerializerManager = SparkEnv.get.serializerManager,
32 | blockManager: BlockManager = SparkEnv.get.blockManager,
33 | readMetrics: ShuffleReadMetricsReporter,
34 | shouldBatchFetch: Boolean = false) extends ShuffleReader[K, C] with Logging {
35 |
36 | private val dep = handle.baseHandle.dependency
37 |
38 | /** Read the combined key-values for this reduce task */
39 | override def read(): Iterator[Product2[K, C]] = {
40 | val (blocksByAddressIterator1, blocksByAddressIterator2) = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
41 | handle.shuffleId, startPartition, endPartition).duplicate
42 | val mapIdToBlockIndex = blocksByAddressIterator2.flatMap{
43 | case (_, blocks) => blocks.map {
44 | case (blockId, _, mapIdx) => blockId match {
45 | case x: ShuffleBlockId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer])
46 | case x: ShuffleBlockBatchId => (x.mapId.asInstanceOf[java.lang.Long], mapIdx.asInstanceOf[java.lang.Integer])
47 | case _ => throw new SparkException("Unknown block")
48 | }
49 | }
50 | }.toMap
51 |
52 | val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
53 | .ucxNode.getThreadLocalWorker
54 | val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
55 | val shuffleClient = new UcxShuffleClient(handle.shuffleId, workerWrapper, mapIdToBlockIndex.asJava, shuffleMetrics)
56 | val shuffleIterator = new ShuffleBlockFetcherIterator(
57 | context,
58 | shuffleClient,
59 | blockManager,
60 | blocksByAddressIterator1,
61 | serializerManager.wrapStream,
62 | // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
63 | SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
64 | SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
65 | SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
66 | SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
67 | SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
68 | SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
69 | readMetrics,
70 | fetchContinuousBlocksInBatch)
71 |
72 | val wrappedStreams = shuffleIterator.toCompletionIterator
73 |
74 | // Ucx shuffle logic
75 | // Java reflection to get access to private results queue
76 | val queueField = shuffleIterator.getClass.getDeclaredField(
77 | "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
78 | queueField.setAccessible(true)
79 | val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]]
80 |
81 |     // Make UCX progress while the result queue is empty before calling next() on the ShuffleBlockFetcherIterator
82 | val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
83 | override def next(): (BlockId, InputStream) = {
84 | val startTime = System.currentTimeMillis()
85 | workerWrapper.fillQueueWithBlocks(resultQueue)
86 | readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
87 | wrappedStreams.next()
88 | }
89 |
90 | override def hasNext: Boolean = {
91 | val result = wrappedStreams.hasNext
92 | if (!result) {
93 | shuffleClient.close()
94 | }
95 | result
96 | }
97 | }
98 | // End of ucx shuffle logic
99 |
100 | val serializerInstance = dep.serializer.newInstance()
101 |
102 | // Create a key/value iterator for each stream
103 | val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
104 | // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
105 | // NextIterator. The NextIterator makes sure that close() is called on the
106 | // underlying InputStream when all records have been read.
107 | serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
108 | }
109 |
110 | // Update the context task metrics for each record read.
111 | val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
112 | recordIter.map { record =>
113 | readMetrics.incRecordsRead(1)
114 | record
115 | },
116 | context.taskMetrics().mergeShuffleReadMetrics())
117 |
118 | // An interruptible iterator must be used here in order to support task cancellation
119 | val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
120 |
121 | val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
122 | if (dep.mapSideCombine) {
123 | // We are reading values that are already combined
124 | val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
125 | dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
126 | } else {
127 | // We don't know the value type, but also don't care -- the dependency *should*
128 |       // have made sure it's compatible w/ this aggregator, which will convert the value
129 | // type to the combined type C
130 | val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
131 | dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
132 | }
133 | } else {
134 | interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
135 | }
136 |
137 | // Sort the output if there is a sort ordering defined.
138 | val resultIter = dep.keyOrdering match {
139 | case Some(keyOrd: Ordering[K]) =>
140 | // Create an ExternalSorter to sort the data.
141 | val sorter =
142 | new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
143 | sorter.insertAll(aggregatedIter)
144 | context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
145 | context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
146 | context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
147 | // Use completion callback to stop sorter if task was finished/cancelled.
148 | context.addTaskCompletionListener[Unit](_ => {
149 | sorter.stop()
150 | })
151 | CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
152 | case None =>
153 | aggregatedIter
154 | }
155 |
156 | resultIter match {
157 | case _: InterruptibleIterator[Product2[K, C]] => resultIter
158 | case _ =>
159 | // Use another interruptible iterator here to support task cancellation as aggregator
160 |         // and/or sorter may have consumed the previous interruptible iterator.
161 | new InterruptibleIterator[Product2[K, C]](context, resultIter)
162 | }
163 | }
164 |
165 | private def fetchContinuousBlocksInBatch: Boolean = {
166 | val conf = SparkEnv.get.conf
167 | val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
168 | val compressed = conf.get(config.SHUFFLE_COMPRESS)
169 | val codecConcatenation = if (compressed) {
170 | CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf))
171 | } else {
172 | true
173 | }
174 | val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)
175 |
176 | val doBatchFetch = shouldBatchFetch && serializerRelocatable &&
177 | (!compressed || codecConcatenation) && !useOldFetchProtocol
178 | if (shouldBatchFetch && !doBatchFetch) {
179 | logWarning("The feature tag of continuous shuffle block fetching is set to true, but " +
180 | "we can not enable the feature because other conditions are not satisfied. " +
181 | s"Shuffle compress: $compressed, serializer ${dep.serializer.getClass.getName} " +
182 | s"relocatable: $serializerRelocatable, " +
183 | s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " +
184 | s"$useOldFetchProtocol.")
185 | }
186 | doBatchFetch
187 | }
188 |
189 | }
190 |
--------------------------------------------------------------------------------
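
For reference, the batch-fetch gate in fetchContinuousBlocksInBatch reduces to a single predicate over four conditions; restated here as a standalone sketch (a mirror of the method above, not an addition to the plugin):

object BatchFetchRule {
  // Batch fetching is only safe when fetched streams can be concatenated:
  // the serializer must support relocation, compression (if enabled) must
  // support concatenated streams, and the old fetch protocol must be off.
  def canBatchFetch(shouldBatchFetch: Boolean,
                    serializerRelocatable: Boolean,
                    compressed: Boolean,
                    codecConcatenation: Boolean,
                    useOldFetchProtocol: Boolean): Boolean =
    shouldBatchFetch && serializerRelocatable &&
      (!compressed || codecConcatenation) && !useOldFetchProtocol
}
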