├── .artifactignore ├── .github └── workflows │ ├── signoff-check.yml │ ├── signoff-check │ ├── Dockerfile │ ├── action.yml │ └── signoff-check │ ├── sparkucx-ci.yml │ └── sparkucx-release.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── buildlib ├── azure-pipelines.yml └── test.sh ├── pom.xml └── src └── main └── scala └── org └── apache └── spark └── shuffle ├── compat ├── spark_2_4 │ ├── UcxShuffleBlockResolver.scala │ ├── UcxShuffleClient.scala │ ├── UcxShuffleManager.scala │ └── UcxShuffleReader.scala └── spark_3_0 │ ├── UcxLocalDiskShuffleExecutorComponents.scala │ ├── UcxShuffleBlockResolver.scala │ ├── UcxShuffleClient.scala │ ├── UcxShuffleManager.scala │ └── UcxShuffleReader.scala ├── ucx ├── CommonUcxShuffleBlockResolver.scala ├── CommonUcxShuffleManager.scala ├── ShuffleTransport.scala ├── UcxShuffleConf.scala ├── UcxShuffleTransport.scala ├── UcxWorkerWrapper.scala ├── memory │ └── MemoryPool.scala ├── perf │ └── UcxPerfBenchmark.scala └── rpc │ ├── GlobalWorkerRpcThread.scala │ ├── UcxDriverRpcEndpoint.scala │ ├── UcxExecutorRpcEndpoint.scala │ └── UcxRpcMessages.scala └── utils ├── SerializableDirectBuffer.scala └── UnsafeUtils.scala /.artifactignore: -------------------------------------------------------------------------------- 1 | **/* 2 | !target/*.jar 3 | -------------------------------------------------------------------------------- /.github/workflows/signoff-check.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # A workflow to check if PR got sign-off 16 | name: signoff check 17 | 18 | on: 19 | pull_request_target: 20 | types: [opened, synchronize, reopened] 21 | 22 | jobs: 23 | signoff-check: 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v2 27 | 28 | - name: sigoff-check job 29 | uses: ./.github/workflows/signoff-check 30 | env: 31 | OWNER: NVIDIA 32 | REPO_NAME: sparkucx 33 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | PULL_NUMBER: ${{ github.event.number }} 35 | -------------------------------------------------------------------------------- /.github/workflows/signoff-check/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | FROM python:3.8-slim-buster 16 | 17 | WORKDIR / 18 | COPY signoff-check . 19 | RUN pip install PyGithub && chmod +x /signoff-check 20 | 21 | # require envs: OWNER,REPO_NAME,GITHUB_TOKEN,PULL_NUMBER 22 | ENTRYPOINT ["/signoff-check"] 23 | -------------------------------------------------------------------------------- /.github/workflows/signoff-check/action.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: 'signoff check action' 16 | description: 'check if PR got signed off' 17 | runs: 18 | using: 'docker' 19 | image: 'Dockerfile' 20 | -------------------------------------------------------------------------------- /.github/workflows/signoff-check/signoff-check: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2021, NVIDIA CORPORATION. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """A signoff check 18 | 19 | The tool checks if any commit got signoff in a pull request. 20 | 21 | NOTE: this script is for github actions only, you should not use it anywhere else. 
22 | """ 23 | import os 24 | import re 25 | import sys 26 | from argparse import ArgumentParser 27 | 28 | from github import Github 29 | 30 | SIGNOFF_REGEX = re.compile('Signed-off-by:') 31 | 32 | 33 | def signoff(token: str, owner: str, repo_name: str, pull_number: int): 34 | gh = Github(token, per_page=100, user_agent='signoff-check', verify=True) 35 | pr = gh.get_repo(f"{owner}/{repo_name}").get_pull(pull_number) 36 | for c in pr.get_commits(): 37 | if SIGNOFF_REGEX.search(c.commit.message): 38 | print('Found signoff.\n') 39 | print(f"Commit sha:\n{c.commit.sha}") 40 | print(f"Commit message:\n{c.commit.message}") 41 | return True 42 | return False 43 | 44 | 45 | def main(token: str, owner: str, repo_name: str, pull_number: int): 46 | try: 47 | if not signoff(token, owner, repo_name, pull_number): 48 | raise Exception('No commits w/ signoff') 49 | except Exception as e: # pylint: disable=broad-except 50 | print(e) 51 | sys.exit(1) 52 | 53 | 54 | if __name__ == '__main__': 55 | parser = ArgumentParser(description="signoff check") 56 | parser.add_argument("--owner", help="repo owner", default='') 57 | parser.add_argument("--repo_name", help="repo name", default='') 58 | parser.add_argument("--token", help="github token, will use GITHUB_TOKEN if empty", default='') 59 | parser.add_argument("--pull_number", help="pull request number", type=int) 60 | args = parser.parse_args() 61 | 62 | GITHUB_TOKEN = args.token if args.token else os.environ.get('GITHUB_TOKEN') 63 | assert GITHUB_TOKEN, 'env GITHUB_TOKEN should not be empty' 64 | OWNER = args.owner if args.owner else os.environ.get('OWNER') 65 | assert OWNER, 'env OWNER should not be empty' 66 | REPO_NAME = args.repo_name if args.repo_name else os.environ.get('REPO_NAME') 67 | assert REPO_NAME, 'env REPO_NAME should not be empty' 68 | PULL_NUMBER = args.pull_number if args.pull_number else int(os.environ.get('PULL_NUMBER')) 69 | assert PULL_NUMBER, 'env PULL_NUMBER should not be empty' 70 | 71 | main(token=GITHUB_TOKEN, owner=OWNER, repo_name=REPO_NAME, pull_number=PULL_NUMBER) 72 | -------------------------------------------------------------------------------- /.github/workflows/sparkucx-ci.yml: -------------------------------------------------------------------------------- 1 | name: SparkUCX CI 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | build-sparkucx: 10 | strategy: 11 | matrix: 12 | spark_version: ["2.1", "2.4", "3.0"] 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v1 16 | - name: Set up JDK 1.11 17 | uses: actions/setup-java@v1 18 | with: 19 | java-version: 1.11 20 | - name: Build with Maven 21 | run: mvn -B package -Pspark-${{ matrix.spark_version }} -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn 22 | --file pom.xml 23 | - name: Run Sonar code analysis 24 | run: mvn -B sonar:sonar -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -Dsonar.projectKey=openucx:spark-ucx -Dsonar.organization=openucx -Dsonar.scanner.force-deprecated-java-version=true -Dsonar.host.url=https://sonarcloud.io -Dsonar.login=97f4df88ff4fa04e2d5b061acf07315717f1f08b -Pspark-${{ matrix.spark_version }} 25 | env: 26 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 27 | -------------------------------------------------------------------------------- /.github/workflows/sparkucx-release.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | # Sequence of patterns matched against 
refs/tags 4 | tags: 5 | - 'v*' # Push events to matching v*, i.e. v1.0, v20.15.10 6 | 7 | name: Upload Release Asset 8 | 9 | env: 10 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 11 | 12 | jobs: 13 | release: 14 | strategy: 15 | matrix: 16 | spark_version: ["2.4", "3.0"] 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout code 20 | uses: actions/checkout@v2 21 | 22 | - name: Set up JDK 1.11 23 | uses: actions/setup-java@v1 24 | with: 25 | java-version: 1.11 26 | 27 | - name: Build with Maven 28 | id: maven_package 29 | run: | 30 | mvn -B -Pspark-${{ matrix.spark_version }} clean package \ 31 | -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn \ 32 | --file pom.xml 33 | cd target 34 | echo "::set-output name=jar_name::$(echo ucx-spark-*-jar-with-dependencies.jar)" 35 | 36 | - name: Upload Release Jars 37 | uses: svenstaro/upload-release-action@v1-release 38 | with: 39 | repo_token: ${{ secrets.GITHUB_TOKEN }} 40 | file: ./target/${{ steps.maven_package.outputs.jar_name }} 41 | asset_name: ${{ steps.maven_package.outputs.jar_name }} 42 | tag: ${{ github.ref }} 43 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to UCX Accelerator for Apache Spark 2 | 3 | ### Sign your work 4 | 5 | We require that all contributors sign-off on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. 6 | 7 | Any contribution which contains commits that are not signed off will not be accepted. 8 | 9 | To sign off on a commit use the `--signoff` (or `-s`) option when committing your changes: 10 | 11 | ```shell 12 | git commit -s -m "Add cool feature." 13 | ``` 14 | 15 | This will append the following to your commit message: 16 | 17 | ``` 18 | Signed-off-by: Your Name 19 | ``` 20 | 21 | The sign-off is a simple line at the end of the explanation for the patch. Your signature certifies that you wrote the patch or otherwise have the right to pass it on as an open-source patch. Use your real name, no pseudonyms or anonymous contributions. If you set your `user.name` and `user.email` git configs, you can sign your commit automatically with `git commit -s`. 22 | 23 | 24 | The signoff means you certify the below (from [developercertificate.org](https://developercertificate.org)): 25 | 26 | ``` 27 | Developer Certificate of Origin 28 | Version 1.1 29 | 30 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 31 | 1 Letterman Drive 32 | Suite D4700 33 | San Francisco, CA, 94129 34 | 35 | Everyone is permitted to copy and distribute verbatim copies of this 36 | license document, but changing it is not allowed. 
37 | 38 | 39 | Developer's Certificate of Origin 1.1 40 | 41 | By making a contribution to this project, I certify that: 42 | 43 | (a) The contribution was created in whole or in part by me and I 44 | have the right to submit it under the open source license 45 | indicated in the file; or 46 | 47 | (b) The contribution is based upon previous work that, to the best 48 | of my knowledge, is covered under an appropriate open source 49 | license and I have the right under that license to submit that 50 | work with modifications, whether created in whole or in part 51 | by me, under the same open source license (unless I am 52 | permitted to submit under a different license), as indicated 53 | in the file; or 54 | 55 | (c) The contribution was provided directly to me by some other 56 | person who certified (a), (b) or (c) and I have not modified 57 | it. 58 | 59 | (d) I understand and agree that this project and the contribution 60 | are public and that a record of the contribution (including all 61 | personal information I submit with it, including my sign-off) is 62 | maintained indefinitely and may be redistributed consistent with 63 | this project or the open source license(s) involved. 64 | ``` 65 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2019 Mellanox Technologies Ltd. All rights reserved. 2 | Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 2. Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 3. Neither the name of the copyright holder nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 21 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 23 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 24 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 25 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 26 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UCX for Apache Spark Plugin 2 | UCX for Apache Spark is a high-performance ShuffleManager plugin for Apache Spark that uses RDMA and other high-performance transports 3 | supported by [UCX](https://github.com/openucx/ucx#supported-transports) to perform shuffle data transfers in Spark jobs. 4 | 5 | ## Runtime requirements 6 | * Apache Spark 2.4/3.0 7 | * Java 8+ 8 | * UCX version 1.13+ installed, and [UCX supported transport hardware](https://github.com/openucx/ucx#supported-transports). 9 | 10 | ## Installation 11 | 12 | ### Obtain UCX for Apache Spark 13 | Please use the ["Releases"](https://github.com/NVIDIA/sparkucx/releases) page to download the SparkUCX jar file 14 | for your Spark version (e.g. ucx-spark-1.1-for-spark-2.4.0-jar-with-dependencies.jar). 15 | Put the ucx-spark jar file in $SPARK_UCX_HOME on all the nodes in your cluster. 16 |
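For example, a minimal way to distribute the jar (assuming passwordless SSH, a hypothetical `nodes.txt` file listing your worker hosts, and `/opt/sparkucx` as `$SPARK_UCX_HOME`) could look like:

```
# Illustrative sketch only: adjust SPARK_UCX_HOME, the jar name and the node list to your setup
export SPARK_UCX_HOME=/opt/sparkucx
for node in $(cat nodes.txt); do
  ssh ${node} "mkdir -p ${SPARK_UCX_HOME}"
  scp ucx-spark-1.1-for-spark-2.4.0-jar-with-dependencies.jar ${node}:${SPARK_UCX_HOME}/
done
```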
If you would like to build the project yourself, please refer to the ["Build"](https://github.com/NVIDIA/sparkucx#build) section below. 17 | 18 | Ucx binaries **must** be in Spark classpath on every Spark Worker. 19 | It can be obtained by installing the latest version from [Ucx release page](https://github.com/openucx/ucx/releases) 20 | 21 | ### Configuration 22 | 23 | Provide Spark the location of the SparkUCX plugin jars and ucx shared binaries by using the extraClassPath option. 24 | 25 | ``` 26 | spark.driver.extraClassPath $SPARK_UCX_HOME/spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar 27 | spark.executor.extraClassPath $SPARK_UCX_HOME/spark-ucx-1.0-for-spark-2.4.0-jar-with-dependencies.jar:$UCX_PREFIX/lib 28 | ``` 29 | To enable the UCX for Apache Spark Shuffle Manager plugin, add the following configuration property 30 | to spark (e.g. in $SPARK_HOME/conf/spark-defaults.conf): 31 | 32 | ``` 33 | spark.shuffle.manager org.apache.spark.shuffle.UcxShuffleManager 34 | spark.executorEnv.UCX_ERROR_SIGNALS "" 35 | ``` 36 | 37 | 38 | ### Build 39 | 40 | Building the SparkUCX plugin requires [Apache Maven](http://maven.apache.org/) and Java 8+ JDK 41 | 42 | Build instructions: 43 | 44 | ``` 45 | % git clone https://github.com/nvidia/sparkucx 46 | % cd sparkucx 47 | % mvn -DskipTests clean package -Pspark-3.0 48 | ``` 49 | 50 | -------------------------------------------------------------------------------- /buildlib/azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | # See https://aka.ms/yaml 2 | 3 | trigger: 4 | - master 5 | - v*.*.x 6 | pr: 7 | - master 8 | - v*.*.x 9 | 10 | stages: 11 | - stage: Build 12 | jobs: 13 | - job: build 14 | strategy: 15 | maxParallel: 1 16 | matrix: 17 | spark-2.4: 18 | profile_version: "2.4" 19 | spark_version: "2.4.5" 20 | spark-3.0: 21 | profile_version: "3.0" 22 | spark_version: "3.0.1" 23 | pool: 24 | name: MLNX 25 | demands: 26 | - Maven 27 | - ucx_docker -equals yes 28 | steps: 29 | - task: Maven@3 30 | displayName: build 31 | inputs: 32 | javaHomeOption: "path" 33 | jdkDirectory: "/hpc/local/oss/java/jdk/" 34 | jdkVersionOption: "1.8" 35 | mavenVersionSelection: "Path" 36 | mavenPath: "/hpc/local/oss/apache-maven-3.3.9" 37 | mavenSetM2Home: true 38 | publishJUnitResults: false 39 | goals: "clean package" 40 | options: "-B -Dmaven.repo.local=$(System.DefaultWorkingDirectory)/target/.deps -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -Pspark-$(profile_version)" 41 | - bash: | 42 | set -xeE 43 | module load dev/jdk-1.8 tools/spark-$(spark_version) 44 | source buildlib/test.sh 45 | 46 | if [[ $(get_rdma_device_iface) != "" ]] 47 | then 48 | export SPARK_UCX_JAR=$(System.DefaultWorkingDirectory)/target/ucx-spark-1.1-for-spark-$(profile_version)-jar-with-dependencies.jar 49 | export SPARK_LOCAL_DIRS=$(System.DefaultWorkingDirectory)/target/spark 50 | export SPARK_VERSION=$(spark_version) 51 | cd $(System.DefaultWorkingDirectory)/target/ 52 | run_tests 53 | else 54 | echo ##vso[task.complete result=Skipped;]No IB devices found 55 | fi 56 | displayName: Run spark tests 57 | -------------------------------------------------------------------------------- /buildlib/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eExl 2 | # 3 | # Testing script for SparkUCX 4 | # 5 | # Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 
6 | # 7 | # Environment variables: 8 | # - NODELIST : Space separated list of nodes (default: localhost) 9 | # - UCX_LIB : Path to UCX installation (default: LD_LIBRARY_PATH) 10 | # - SPARK_UCX_JAR : Path to SparkUCX jar (default in $PWD) 11 | # - PROCESSES_PER_INSTANCE : Number of spark processes per instance (default: 2) 12 | # - EXECUTOR_NUMBER : Unique number of jenkins executor, to run multiple tests on 1 machine 13 | # - JOB_ID : Unique job id for this test 14 | # - SCRATCH_DIRECTORY : Separate directory for this test where configs will be added. 15 | # For now it's required to be on NFS 16 | # - SPARK_LOCAL_DIRS : Directory to use for map output files and RDDs that get stored on disk. 17 | # This should be on a fast, local disk in your system. 18 | # It can also be a comma-separated list of multiple directories on different disks. 19 | # - SPARK_WORKER_CORES : Number of cores per spark process (default: 2) 20 | # - SPARK_WORKER_MEMORY : Memory per spark process (default: 1g) 21 | # - SPARK_HOME : Spark location on the file system (default: $PWD/spark). If not set, will be downloaded 22 | # - SPARK_VERSION : Version of spark to test (default: 2.4.4). Used only in SPARK_HOME is not set 23 | # - RDMA_NET_IFACE : Network interface to test 24 | 25 | NODELIST=${NODELIST:="localhost"} 26 | 27 | UCX_LIB=${UCX_LIB:=${LD_LIBRARY_PATH}} 28 | 29 | SPARK_UCX_JAR=${SPARK_UCX_JAR:=$PWD/spark-ucx-1.1-for-spark-2.4-jar-with-dependencies.jar} 30 | 31 | PROCESSES_PER_INSTANCE=${PROCESSES_PER_INSTANCE:=2} 32 | 33 | if [ -n "$EXECUTOR_NUMBER" ] 34 | then 35 | AFFINITY="taskset -c $(( 2 * EXECUTOR_NUMBER ))","$(( 2 * EXECUTOR_NUMBER + 1))" 36 | else 37 | AFFINITY="" 38 | fi 39 | 40 | EXECUTOR_NUMBER=${EXECUTOR_NUMBER:=$((RANDOM % `nproc`))} 41 | 42 | JOB_ID=${SLURM_JOB_ID:="sparkucx-test-$(cat /proc/sys/kernel/random/uuid)"} 43 | 44 | SCRATCH_DIRECTORY=${SCRATCH_DIRECTORY:=$PWD/${JOB_ID}} 45 | 46 | SPARK_LOCAL_DIRS=${SPARK_LOCAL_DIRS:="/scrap/sparkucxtest/sparkucx-${JOB_ID}"} 47 | 48 | SPARK_WORKER_CORES=${SPARK_WORKER_CORES:=2} 49 | 50 | SPARK_WORKER_MEMORY=${SPARK_WORKER_MEMORY:="1g"} 51 | 52 | SPARK_HOME=${SPARK_HOME:=$PWD/spark} 53 | 54 | SPARK_VERSION=${SPARK_VERSION:="2.4.4"} 55 | 56 | UCX_BRANCH=${UCX_BRANCH:="master"} 57 | 58 | export SPARK_MASTER_HOST=${SPARK_MASTER_HOST:=$(hostname -f)} 59 | export SPARK_MASTER_PORT=${SPARK_MASTER_PORT:=$(( 2000 + 10 * ${EXECUTOR_NUMBER}))} 60 | export SPARK_CONF_DIR=${SCRATCH_DIRECTORY}/conf 61 | 62 | get_rdma_device_iface() { 63 | 64 | if [ ! -r /dev/infiniband/rdma_cm ] 65 | then 66 | return 67 | fi 68 | 69 | if ! 
which ibdev2netdev >&/dev/null 70 | then 71 | return 72 | fi 73 | 74 | iface=`ibdev2netdev | grep Up | awk '{print $5}' | head -1` 75 | if [[ -n "$iface" ]] 76 | then 77 | ipaddr=$(ip addr show ${iface} | awk '/inet /{print $2}' | awk -F '/' '{print $1}') 78 | fi 79 | 80 | if [[ -z "$ipaddr" ]] 81 | then 82 | # if there is no inet (IPv4) address, escape 83 | return 84 | fi 85 | 86 | ibdev=`ibdev2netdev | grep ${iface} | awk '{print $1}'` 87 | node_guid=`cat /sys/class/infiniband/$ibdev/node_guid` 88 | if [[ ${node_guid} == "0000:0000:0000:0000" ]] 89 | then 90 | return 91 | fi 92 | 93 | SPARK_MASTER_HOST=$ipaddr 94 | echo ${iface} 95 | } 96 | 97 | RDMA_NET_IFACE=${RDMA_NET_IFACE:=`get_rdma_device_iface`} 98 | 99 | download_spark() { 100 | curl -s -L -O https://www-eu.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop2.7.tgz 101 | mkdir -p ${SPARK_HOME} 102 | tar -xf spark-${SPARK_VERSION}-bin-hadoop2.7.tgz -C ${SPARK_HOME} --strip-components=1 103 | } 104 | 105 | build_ucx() { 106 | git clone -b ${UCX_BRANCH} --depth=1 https://github.com/openucx/ucx.git && cd ucx 107 | ./autogen.sh 108 | mkdir build && cd build 109 | ../contrib/configure-release-mt --prefix=$PWD 110 | make -j `nproc` 111 | make install 112 | UCX_LIB=$PWD/lib/ 113 | } 114 | 115 | setup_configuration() { 116 | mkdir -p ${SPARK_CONF_DIR} 117 | 118 | echo ${NODELIST} | tr -s ' ' '\n' > "${SPARK_CONF_DIR}/slaves" 119 | 120 | cat <<-EOF > ${SPARK_CONF_DIR}/spark-defaults.conf 121 | spark.shuffle.manager org.apache.spark.shuffle.UcxShuffleManager 122 | spark.shuffle.readHostLocalDisk.enabled false 123 | spark.driver.extraClassPath ${SPARK_UCX_JAR} 124 | spark.executor.extraClassPath ${SPARK_UCX_JAR}:${UCX_LIB}:${UCX_LIB}/ucx 125 | spark.shuffle.ucx.driver.port $(( ${SPARK_MASTER_PORT} + 1 )) 126 | spark.executorEnv.UCX_ERROR_SIGNALS "" 127 | spark.executorEnv.UCX_LOG_LEVEL trace 128 | EOF 129 | 130 | cat <<-EOF > ${SPARK_CONF_DIR}/spark-env.sh 131 | export SPARK_LOCAL_IP=\`/sbin/ip addr show ${RDMA_NET_IFACE} | grep "inet\b" | awk '{print \$2}' | cut -d/ -f1\` 132 | export SPARK_WORKER_DIR=${SCRATCH_DIRECTORY}/work 133 | export SPARK_LOCAL_DIRS=${SPARK_LOCAL_DIRS} 134 | export SPARK_LOG_DIR=${SCRATCH_DIRECTORY}/logs 135 | export SPARK_CONF_DIR=${SPARK_CONF_DIR} 136 | export SPARK_MASTER_HOST=${SPARK_MASTER_HOST} 137 | export SPARK_MASTER_PORT=${SPARK_MASTER_PORT} 138 | export SPARK_WORKER_CORES=${SPARK_WORKER_CORES} 139 | export SPARK_WORKER_MEMORY=${SPARK_WORKER_MEMORY} 140 | export SPARK_IDENT_STRING=${JOB_ID} 141 | EOF 142 | 143 | cp ${SPARK_HOME}/conf/log4j.properties.template ${SPARK_CONF_DIR}/log4j.properties 144 | sed -i -e 's/INFO/WARN/g' ${SPARK_CONF_DIR}/log4j.properties 145 | echo "log4j.logger.org.apache.spark.shuffle=DEBUG" >> ${SPARK_CONF_DIR}/log4j.properties 146 | } 147 | 148 | start_cluster() { 149 | ${AFFINITY} ${SPARK_HOME}/sbin/start-master.sh 150 | 151 | # Make a script wrapper to propagate SPARK_CONF_DIR 152 | cat <<-EOF > ${SCRATCH_DIRECTORY}/sparkworker.sh 153 | #! 
/bin/bash 154 | export SPARK_CONF_DIR=${SPARK_CONF_DIR} 155 | export SPARK_WORKER_INSTANCES=${PROCESSES_PER_INSTANCE} 156 | export SPARK_IDENT_STRING=${JOB_ID} 157 | ${AFFINITY} ${SPARK_HOME}/sbin/start-slave.sh "spark://${SPARK_MASTER_HOST}:${SPARK_MASTER_PORT}" 158 | EOF 159 | 160 | SPARK_CONF_DIR=${SPARK_CONF_DIR} ${SPARK_HOME}/sbin/slaves.sh bash ${SCRATCH_DIRECTORY}/sparkworker.sh 161 | } 162 | 163 | run_groupby_test() { 164 | ${SPARK_HOME}/bin/run-example --verbose --master spark://${SPARK_MASTER_HOST}:${SPARK_MASTER_PORT} \ 165 | --jars "${SPARK_HOME}/examples/jars/*.jar" --executor-memory ${SPARK_WORKER_MEMORY} \ 166 | org.apache.spark.examples.GroupByTest 100 100 167 | } 168 | 169 | run_big_test() { 170 | ${SPARK_HOME}/bin/run-example --verbose --master spark://${SPARK_MASTER_HOST}:${SPARK_MASTER_PORT} \ 171 | --jars "${SPARK_HOME}/examples/jars/*.jar" --executor-memory ${SPARK_WORKER_MEMORY} \ 172 | org.apache.spark.examples.GroupByTest 200 5000 25000 200 173 | } 174 | 175 | run_tc_test() { 176 | ${SPARK_HOME}/bin/run-example --verbose --master spark://${SPARK_MASTER_HOST}:${SPARK_MASTER_PORT} \ 177 | --jars "${SPARK_HOME}/examples/jars/*.jar" --executor-memory ${SPARK_WORKER_MEMORY} \ 178 | org.apache.spark.examples.SparkTC 179 | } 180 | 181 | run_tests() { 182 | if [[ ! -d ${SPARK_HOME} ]] 183 | then 184 | download_spark 185 | fi 186 | 187 | if [[ ! -d ${UCX_LIB} ]] 188 | then 189 | build_ucx 190 | fi 191 | 192 | trap stop_cluster EXIT; 193 | 194 | setup_configuration 195 | start_cluster 196 | run_groupby_test && run_tc_test 197 | } 198 | 199 | stop_cluster() { 200 | cat <<-EOF > ${SCRATCH_DIRECTORY}/stop-sparkworker.sh 201 | #! /bin/bash 202 | export SPARK_CONF_DIR=${SPARK_CONF_DIR} 203 | export SPARK_WORKER_INSTANCES=${PROCESSES_PER_INSTANCE} 204 | export SPARK_IDENT_STRING=${JOB_ID} 205 | ${SPARK_HOME}/sbin/stop-slave.sh 206 | EOF 207 | 208 | chmod +x ${SCRATCH_DIRECTORY}/stop-sparkworker.sh 209 | # Stop all slaves 210 | ${SPARK_HOME}/sbin/slaves.sh ${SCRATCH_DIRECTORY}/stop-sparkworker.sh 211 | 212 | ${SPARK_HOME}/sbin/stop-master.sh 213 | } 214 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 6 | 7 | 11 | 4.0.0 12 | org.openucx 13 | ucx-spark 14 | 1.1 15 | ${project.artifactId} 16 | 17 | A high-performance, scalable and efficient shuffle manager plugin for Apache Spark, 18 | utilizing UCX communication layer (https://github.com/openucx/ucx/). 
19 | 20 | jar 21 | 22 | 23 | 24 | BSD 3 Clause License 25 | http://www.openucx.org/license/ 26 | repo 27 | 28 | 29 | 30 | 31 | 1.8 32 | 1.8 33 | UTF-8 34 | 35 | 36 | 37 | 38 | spark-2.4 39 | 40 | 41 | 42 | net.alchim31.maven 43 | scala-maven-plugin 44 | 45 | 46 | **/spark_2_1/** 47 | **/spark_3_0/** 48 | 49 | 50 | 51 | 52 | 53 | 54 | 2.4.0 55 | **/spark_3_0/**, **/spark_2_1/** 56 | 2.11.12 57 | 2.11 58 | 59 | 60 | 61 | spark-3.0 62 | 63 | true 64 | 65 | 66 | 67 | 68 | net.alchim31.maven 69 | scala-maven-plugin 70 | 71 | 72 | **/spark_2_1/** 73 | **/spark_2_4/** 74 | 75 | 76 | 77 | 78 | 79 | 80 | 3.0.1 81 | 2.12.10 82 | 2.12 83 | **/spark_2_1/**, **/spark_2_4/** 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | org.apache.spark 92 | spark-core_${scala.compat.version} 93 | ${spark.version} 94 | provided 95 | 96 | 97 | org.openucx 98 | jucx 99 | 1.13.1 100 | 101 | 102 | 103 | 104 | ${project.artifactId}-${project.version}-for-${project.activeProfiles[0].id} 105 | 106 | 107 | net.alchim31.maven 108 | scala-maven-plugin 109 | 4.3.0 110 | 111 | all 112 | 113 | -nobootcp 114 | -Xexperimental 115 | -Xfatal-warnings 116 | -explaintypes 117 | -unchecked 118 | -deprecation 119 | -feature 120 | 121 | 122 | 123 | 124 | compile 125 | 126 | compile 127 | 128 | compile 129 | 130 | 131 | process-resources 132 | 133 | compile 134 | 135 | 136 | 137 | 138 | 139 | maven-assembly-plugin 140 | 3.1.1 141 | 142 | 143 | jar-with-dependencies 144 | 145 | 146 | 147 | 148 | make-assembly 149 | package 150 | 151 | single 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | oss.sonatype.org-snapshot 162 | http://oss.sonatype.org/content/repositories/snapshots 163 | 164 | false 165 | 166 | 167 | true 168 | 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleBlockResolver.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.compat.spark_2_4 6 | 7 | import java.io.{File, RandomAccessFile} 8 | 9 | import org.apache.spark.shuffle.ucx.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager} 10 | 11 | /** 12 | * Mapper entry point for UcxShuffle plugin. Performs memory registration 13 | * of data and index files and publish addresses to driver metadata buffer. 14 | */ 15 | class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) 16 | extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { 17 | 18 | /** 19 | * Mapper commit protocol extension. Register index and data files and publish all needed 20 | * metadata to driver. 
21 | */ 22 | override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Int, 23 | lengths: Array[Long], dataTmp: File): Unit = { 24 | super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) 25 | val dataFile = getDataFile(shuffleId, mapId) 26 | if (!dataFile.exists() || dataFile.length() == 0) { 27 | return 28 | } 29 | 30 | writeIndexFileAndCommitCommon(shuffleId, mapId, lengths, new RandomAccessFile(dataFile, "r")) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.shuffle.compat.spark_2_4 2 | 3 | import org.openucx.jucx.UcxUtils 4 | import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} 5 | import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient} 6 | import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport} 7 | import org.apache.spark.shuffle.utils.UnsafeUtils 8 | import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} 9 | 10 | class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient{ 11 | override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], 12 | listener: BlockFetchingListener, 13 | downloadFileManager: DownloadFileManager): Unit = { 14 | val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length) 15 | val callbacks = Array.ofDim[OperationCallback](blockIds.length) 16 | for (i <- blockIds.indices) { 17 | val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId] 18 | ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, blockId.mapId, blockId.reduceId) 19 | callbacks(i) = (result: OperationResult) => { 20 | val memBlock = result.getData 21 | val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt) 22 | listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) { 23 | override def release: ManagedBuffer = { 24 | memBlock.close() 25 | this 26 | } 27 | }) 28 | } 29 | } 30 | val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) 31 | transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks) 32 | } 33 | 34 | override def close(): Unit = { 35 | 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleManager.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle 6 | 7 | import org.apache.spark.shuffle.compat.spark_2_4.{UcxShuffleBlockResolver, UcxShuffleReader} 8 | import org.apache.spark.shuffle.ucx.CommonUcxShuffleManager 9 | import org.apache.spark.{SparkConf, TaskContext} 10 | 11 | /** 12 | * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin 13 | * and injects needed logic in override methods. 14 | */ 15 | class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) 16 | extends CommonUcxShuffleManager(conf, isDriver) { 17 | private[this] lazy val transport = awaitUcxTransport 18 | /** 19 | * Mapper callback on executor. 
Just start UcxNode and use Spark mapper logic. 20 | */ 21 | override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, 22 | context: TaskContext): ShuffleWriter[K, V] = { 23 | super.getWriter(handle, mapId, context) 24 | } 25 | 26 | override val shuffleBlockResolver: UcxShuffleBlockResolver = new UcxShuffleBlockResolver(this) 27 | 28 | /** 29 | * Reducer callback on executor. 30 | */ 31 | override def getReader[K, C](handle: ShuffleHandle, startPartition: Int, 32 | endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { 33 | new UcxShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K,_,C]], startPartition, 34 | endPartition, context, transport) 35 | } 36 | } 37 | 38 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleReader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.compat.spark_2_4 6 | 7 | import java.io.InputStream 8 | import java.util.concurrent.LinkedBlockingQueue 9 | 10 | import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} 11 | import org.apache.spark.internal.{Logging, config} 12 | import org.apache.spark.serializer.SerializerManager 13 | import org.apache.spark.shuffle.ucx.UcxShuffleTransport 14 | import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} 15 | import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator} 16 | import org.apache.spark.util.CompletionIterator 17 | import org.apache.spark.util.collection.ExternalSorter 18 | 19 | /** 20 | * Extension of Spark's shuffle reader with a logic of injection UcxShuffleClient, 21 | * and lazy progress only when result queue is empty. 
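 * The iterator's private results queue is obtained via Java reflection, and transport.progress()
 * is driven whenever that queue is empty, so outstanding UCX requests complete while the reducer waits.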
22 | */ 23 | class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], 24 | startPartition: Int, 25 | endPartition: Int, 26 | context: TaskContext, 27 | transport: UcxShuffleTransport, 28 | serializerManager: SerializerManager = SparkEnv.get.serializerManager, 29 | blockManager: BlockManager = SparkEnv.get.blockManager, 30 | mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) 31 | extends ShuffleReader[K, C] with Logging { 32 | 33 | private val dep = handle.dependency 34 | 35 | /** Read the combined key-values for this reduce task */ 36 | override def read(): Iterator[Product2[K, C]] = { 37 | val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() 38 | val shuffleClient = new UcxShuffleClient(transport) 39 | val wrappedStreams = new ShuffleBlockFetcherIterator( 40 | context, 41 | shuffleClient, 42 | blockManager, 43 | mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, 44 | startPartition, endPartition), 45 | serializerManager.wrapStream, 46 | // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility 47 | SparkEnv.get.conf.getSizeAsBytes("spark.reducer.maxSizeInFlight", "48m"), 48 | SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), 49 | SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), 50 | SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), 51 | SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) 52 | 53 | // Ucx shuffle logic 54 | // Java reflection to get access to private results queue 55 | val queueField = wrappedStreams.getClass.getDeclaredField( 56 | "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") 57 | queueField.setAccessible(true) 58 | val resultQueue = queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]] 59 | 60 | // Do progress if queue is empty before calling next on ShuffleIterator 61 | val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { 62 | override def next(): (BlockId, InputStream) = { 63 | val startTime = System.currentTimeMillis() 64 | while (resultQueue.isEmpty) { 65 | transport.progress() 66 | } 67 | shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) 68 | wrappedStreams.next() 69 | } 70 | 71 | override def hasNext: Boolean = { 72 | val result = wrappedStreams.hasNext 73 | if (!result) { 74 | shuffleClient.close() 75 | } 76 | result 77 | } 78 | } 79 | // End of ucx shuffle logic 80 | 81 | val serializerInstance = dep.serializer.newInstance() 82 | val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => 83 | // Note: the asKeyValueIterator below wraps a key/value iterator inside of a 84 | // NextIterator. The NextIterator makes sure that close() is called on the 85 | // underlying InputStream when all records have been read. 86 | serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator 87 | } 88 | 89 | // Update the context task metrics for each record read. 
90 | val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() 91 | val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( 92 | recordIter.map { record => 93 | readMetrics.incRecordsRead(1) 94 | record 95 | }, 96 | context.taskMetrics().mergeShuffleReadMetrics()) 97 | 98 | // An interruptible iterator must be used here in order to support task cancellation 99 | val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) 100 | 101 | val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { 102 | if (dep.mapSideCombine) { 103 | // We are reading values that are already combined 104 | val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] 105 | dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) 106 | } else { 107 | // We don't know the value type, but also don't care -- the dependency *should* 108 | // have made sure its compatible w/ this aggregator, which will convert the value 109 | // type to the combined type C 110 | val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] 111 | dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) 112 | } 113 | } else { 114 | interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] 115 | } 116 | 117 | // Sort the output if there is a sort ordering defined. 118 | val resultIter = dep.keyOrdering match { 119 | case Some(keyOrd: Ordering[K]) => 120 | // Create an ExternalSorter to sort the data. 121 | val sorter = 122 | new ExternalSorter[K, C, C](context, 123 | ordering = Some(keyOrd), serializer = dep.serializer) 124 | sorter.insertAll(aggregatedIter) 125 | context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) 126 | context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) 127 | context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) 128 | // Use completion callback to stop sorter if task was finished/cancelled. 129 | context.addTaskCompletionListener[Unit](_ => { 130 | sorter.stop() 131 | }) 132 | CompletionIterator[Product2[K, C], 133 | Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) 134 | case None => 135 | aggregatedIter 136 | } 137 | 138 | resultIter match { 139 | case _: InterruptibleIterator[Product2[K, C]] => resultIter 140 | case _ => 141 | // Use another interruptible iterator here to support task cancellation as aggregator 142 | // or(and) sorter may have consumed previous interruptible iterator. 143 | new InterruptibleIterator[Product2[K, C]](context, resultIter) 144 | } 145 | } 146 | 147 | } 148 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxLocalDiskShuffleExecutorComponents.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.compat.spark_3_0 6 | 7 | import java.util 8 | import java.util.Optional 9 | 10 | import org.apache.spark.internal.Logging 11 | import org.apache.spark.{SparkConf, SparkEnv} 12 | import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter} 13 | import org.apache.spark.shuffle.UcxShuffleManager 14 | import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter} 15 | 16 | /** 17 | * Entry point to UCX executor. 
18 | */ 19 | class UcxLocalDiskShuffleExecutorComponents(sparkConf: SparkConf) 20 | extends LocalDiskShuffleExecutorComponents(sparkConf) with Logging{ 21 | 22 | private var blockResolver: UcxShuffleBlockResolver = _ 23 | 24 | override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = { 25 | val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] 26 | while (ucxShuffleManager.ucxTransport == null) { 27 | Thread.sleep(5) 28 | } 29 | blockResolver = ucxShuffleManager.shuffleBlockResolver 30 | } 31 | 32 | override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = { 33 | if (blockResolver == null) { 34 | throw new IllegalStateException( 35 | "Executor components must be initialized before getting writers.") 36 | } 37 | new LocalDiskShuffleMapOutputWriter( 38 | shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf) 39 | } 40 | 41 | override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long): Optional[SingleSpillShuffleMapOutputWriter] = { 42 | if (blockResolver == null) { 43 | throw new IllegalStateException( 44 | "Executor components must be initialized before getting writers.") 45 | } 46 | Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)) 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleBlockResolver.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.compat.spark_3_0 6 | 7 | import java.io.{File, RandomAccessFile} 8 | 9 | import org.apache.spark.TaskContext 10 | import org.apache.spark.shuffle.ucx.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager} 11 | 12 | /** 13 | * Mapper entry point for UcxShuffle plugin. Performs memory registration 14 | * of data and index files and publish addresses to driver metadata buffer. 15 | */ 16 | class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) 17 | extends CommonUcxShuffleBlockResolver(ucxShuffleManager) { 18 | 19 | 20 | override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long, 21 | lengths: Array[Long], dataTmp: File): Unit = { 22 | super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) 23 | // In Spark-3.0 MapId is long and unique among all jobs in spark. We need to use partitionId as offset 24 | // in metadata buffer 25 | val partitionId = TaskContext.getPartitionId() 26 | val dataFile = getDataFile(shuffleId, mapId) 27 | if (!dataFile.exists() || dataFile.length() == 0) { 28 | return 29 | } 30 | writeIndexFileAndCommitCommon(shuffleId, partitionId, lengths, new RandomAccessFile(dataFile, "r")) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.compat.spark_3_0 6 | 7 | import org.apache.spark.internal.Logging 8 | import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} 9 | import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager} 10 | import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport} 11 | import org.apache.spark.shuffle.utils.UnsafeUtils 12 | import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} 13 | 14 | class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Map[Long, Int]) extends BlockStoreClient with Logging { 15 | 16 | override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], 17 | listener: BlockFetchingListener, 18 | downloadFileManager: DownloadFileManager): Unit = { 19 | if (blockIds.length > transport.ucxShuffleConf.maxBlocksPerRequest) { 20 | val (b1, b2) = blockIds.splitAt(blockIds.length / 2) 21 | fetchBlocks(host, port, execId, b1, listener, downloadFileManager) 22 | fetchBlocks(host, port, execId, b2, listener, downloadFileManager) 23 | return 24 | } 25 | 26 | val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length) 27 | val callbacks = Array.ofDim[OperationCallback](blockIds.length) 28 | for (i <- blockIds.indices) { 29 | val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId] 30 | ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, mapId2PartitionId(blockId.mapId), blockId.reduceId) 31 | callbacks(i) = (result: OperationResult) => { 32 | val memBlock = result.getData 33 | val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt) 34 | listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) { 35 | override def release: ManagedBuffer = { 36 | memBlock.close() 37 | this 38 | } 39 | }) 40 | } 41 | } 42 | val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) 43 | transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks) 44 | transport.progress() 45 | } 46 | 47 | override def close(): Unit = { 48 | 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleManager.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle 6 | 7 | import scala.collection.JavaConverters._ 8 | 9 | import org.apache.spark.shuffle.api.ShuffleExecutorComponents 10 | import org.apache.spark.shuffle.compat.spark_3_0.{UcxLocalDiskShuffleExecutorComponents, UcxShuffleBlockResolver, UcxShuffleReader} 11 | import org.apache.spark.shuffle.sort.{SerializedShuffleHandle, SortShuffleWriter, UnsafeShuffleWriter} 12 | import org.apache.spark.shuffle.ucx.CommonUcxShuffleManager 13 | import org.apache.spark.{SparkConf, SparkEnv, TaskContext} 14 | 15 | /** 16 | * Main entry point of Ucx shuffle plugin. It extends spark's default SortShufflePlugin 17 | * and injects needed logic in override methods. 
18 | */ 19 | class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) 20 | extends CommonUcxShuffleManager(conf, isDriver) { 21 | 22 | private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) 23 | private[this] lazy val transport = awaitUcxTransport 24 | 25 | override val shuffleBlockResolver = new UcxShuffleBlockResolver(this) 26 | 27 | override def getWriter[K, V](handle: ShuffleHandle, mapId: ReduceId, context: TaskContext, 28 | metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { 29 | val env = SparkEnv.get 30 | handle match { 31 | case unsafeShuffleHandle: SerializedShuffleHandle[K@unchecked, V@unchecked] => 32 | new UnsafeShuffleWriter( 33 | env.blockManager, 34 | context.taskMemoryManager(), 35 | unsafeShuffleHandle, 36 | mapId, 37 | context, 38 | env.conf, 39 | metrics, 40 | shuffleExecutorComponents) 41 | case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] => 42 | new SortShuffleWriter( 43 | shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents) 44 | } 45 | } 46 | 47 | override def getReader[K, C](handle: ShuffleHandle, startPartition: MapId, endPartition: MapId, 48 | context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { 49 | new UcxShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K,_,C]], startPartition, endPartition, 50 | context, transport, readMetrics = metrics, shouldBatchFetch = false) 51 | } 52 | 53 | private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { 54 | val executorComponents = new UcxLocalDiskShuffleExecutorComponents(conf) 55 | val extraConfigs = conf.getAllWithPrefix(ShuffleDataIOUtils.SHUFFLE_SPARK_CONF_PREFIX) 56 | .toMap 57 | executorComponents.initializeExecutor( 58 | conf.getAppId, 59 | SparkEnv.get.executorId, 60 | extraConfigs.asJava) 61 | executorComponents 62 | } 63 | 64 | } 65 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleReader.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.shuffle.compat.spark_3_0 19 | 20 | import java.io.InputStream 21 | import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} 22 | 23 | import org.apache.spark._ 24 | import org.apache.spark.internal.{Logging, config} 25 | import org.apache.spark.io.CompressionCodec 26 | import org.apache.spark.serializer.SerializerManager 27 | import org.apache.spark.shuffle.ucx.UcxShuffleTransport 28 | import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReadMetricsReporter, ShuffleReader} 29 | import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockBatchId, ShuffleBlockFetcherIterator, ShuffleBlockId} 30 | import org.apache.spark.util.CompletionIterator 31 | import org.apache.spark.util.collection.ExternalSorter 32 | 33 | /** 34 | * Fetches and reads the blocks from a shuffle by requesting them from other nodes' block stores. 35 | */ 36 | private[spark] class UcxShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], 37 | startPartition: Int, 38 | endPartition: Int, 39 | context: TaskContext, 40 | transport: UcxShuffleTransport, 41 | readMetrics: ShuffleReadMetricsReporter, 42 | serializerManager: SerializerManager = SparkEnv.get.serializerManager, 43 | blockManager: BlockManager = SparkEnv.get.blockManager, 44 | mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, 45 | shouldBatchFetch: Boolean = false) 46 | extends ShuffleReader[K, C] with Logging { 47 | 48 | private val dep = handle.dependency 49 | 50 | private def fetchContinuousBlocksInBatch: Boolean = { 51 | val conf = SparkEnv.get.conf 52 | val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects 53 | val compressed = conf.get(config.SHUFFLE_COMPRESS) 54 | val codecConcatenation = if (compressed) { 55 | CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf)) 56 | } else { 57 | true 58 | } 59 | val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL) 60 | 61 | val doBatchFetch = shouldBatchFetch && serializerRelocatable && 62 | (!compressed || codecConcatenation) && !useOldFetchProtocol 63 | if (shouldBatchFetch && !doBatchFetch) { 64 | logDebug("The feature tag of continuous shuffle block fetching is set to true, but " + 65 | "we can not enable the feature because other conditions are not satisfied. 
" + 66 | s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " + 67 | s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + 68 | s"$useOldFetchProtocol.") 69 | } 70 | doBatchFetch 71 | } 72 | 73 | /** Read the combined key-values for this reduce task */ 74 | override def read(): Iterator[Product2[K, C]] = { 75 | val (blocksByAddressIterator1, blocksByAddressIterator2) = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( 76 | handle.shuffleId, startPartition, endPartition).duplicate 77 | val mapIdToBlockIndex = blocksByAddressIterator2.flatMap{ 78 | case (_, blocks) => blocks.map { 79 | case (blockId, _, mapIdx) => blockId match { 80 | case x: ShuffleBlockId => (x.mapId, mapIdx) 81 | case x: ShuffleBlockBatchId => (x.mapId, mapIdx) 82 | case _ => throw new SparkException("Unknown block") 83 | } 84 | } 85 | }.toMap 86 | 87 | val shuffleClient = new UcxShuffleClient(transport, mapIdToBlockIndex) 88 | val shuffleIterator = new ShuffleBlockFetcherIterator( 89 | context, 90 | shuffleClient, 91 | blockManager, 92 | blocksByAddressIterator1, 93 | serializerManager.wrapStream, 94 | // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility 95 | SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, 96 | SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), 97 | SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), 98 | SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), 99 | SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), 100 | SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), 101 | readMetrics, 102 | // TODO: Support batch fetch 103 | doBatchFetch = false) 104 | 105 | val wrappedStreams = shuffleIterator.toCompletionIterator 106 | 107 | 108 | // Ucx shuffle logic 109 | // Java reflection to get access to private results queue 110 | val queueField = shuffleIterator.getClass.getDeclaredField( 111 | "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") 112 | queueField.setAccessible(true) 113 | val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]] 114 | 115 | // Do progress if queue is empty before calling next on ShuffleIterator 116 | val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { 117 | override def next(): (BlockId, InputStream) = { 118 | val startTime = System.nanoTime() 119 | while (resultQueue.isEmpty) { 120 | transport.progress() 121 | } 122 | val fetchWaitTime = System.nanoTime() - startTime 123 | readMetrics.incFetchWaitTime(TimeUnit.NANOSECONDS.toMillis(fetchWaitTime)) 124 | wrappedStreams.next() 125 | } 126 | 127 | override def hasNext: Boolean = { 128 | val result = wrappedStreams.hasNext 129 | if (!result) { 130 | shuffleClient.close() 131 | } 132 | result 133 | } 134 | } 135 | // End of ucx shuffle logic 136 | 137 | val serializerInstance = dep.serializer.newInstance() 138 | 139 | // Create a key/value iterator for each stream 140 | val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => 141 | // Note: the asKeyValueIterator below wraps a key/value iterator inside of a 142 | // NextIterator. The NextIterator makes sure that close() is called on the 143 | // underlying InputStream when all records have been read. 144 | serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator 145 | } 146 | 147 | // Update the context task metrics for each record read. 
148 | val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( 149 | recordIter.map { record => 150 | readMetrics.incRecordsRead(1) 151 | record 152 | }, 153 | context.taskMetrics().mergeShuffleReadMetrics()) 154 | 155 | // An interruptible iterator must be used here in order to support task cancellation 156 | val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) 157 | 158 | val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { 159 | if (dep.mapSideCombine) { 160 | // We are reading values that are already combined 161 | val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] 162 | dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) 163 | } else { 164 | // We don't know the value type, but also don't care -- the dependency *should* 165 | // have made sure its compatible w/ this aggregator, which will convert the value 166 | // type to the combined type C 167 | val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] 168 | dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) 169 | } 170 | } else { 171 | interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] 172 | } 173 | 174 | // Sort the output if there is a sort ordering defined. 175 | val resultIter = dep.keyOrdering match { 176 | case Some(keyOrd: Ordering[K]) => 177 | // Create an ExternalSorter to sort the data. 178 | val sorter = 179 | new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) 180 | sorter.insertAll(aggregatedIter) 181 | context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) 182 | context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) 183 | context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) 184 | // Use completion callback to stop sorter if task was finished/cancelled. 185 | context.addTaskCompletionListener[Unit](_ => { 186 | sorter.stop() 187 | }) 188 | CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) 189 | case None => 190 | aggregatedIter 191 | } 192 | 193 | resultIter match { 194 | case _: InterruptibleIterator[Product2[K, C]] => resultIter 195 | case _ => 196 | // Use another interruptible iterator here to support task cancellation as aggregator 197 | // or(and) sorter may have consumed previous interruptible iterator. 198 | new InterruptibleIterator[Product2[K, C]](context, resultIter) 199 | } 200 | } 201 | } 202 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleBlockResolver.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx 6 | 7 | import java.io.RandomAccessFile 8 | import java.nio.ByteBuffer 9 | import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} 10 | 11 | import org.apache.spark.shuffle.IndexShuffleBlockResolver 12 | import org.apache.spark.shuffle.utils.UnsafeUtils 13 | 14 | class FileBackedMemoryBlock(baseAddress: Long, baseSize: Long, address: Long, size: Long) 15 | extends MemoryBlock(address, size) { 16 | override def close(): Unit = { 17 | UnsafeUtils.munmap(baseAddress, baseSize) 18 | } 19 | } 20 | 21 | /** 22 | * Mapper entry point for UcxShuffle plugin. 
Performs memory registration 23 | * of data and index files and publish addresses to driver metadata buffer. 24 | */ 25 | abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager) 26 | extends IndexShuffleBlockResolver(ucxShuffleManager.conf) { 27 | 28 | private val openFds = new ConcurrentHashMap[ShuffleId, ConcurrentLinkedQueue[RandomAccessFile]]() 29 | private[ucx] lazy val transport = ucxShuffleManager.awaitUcxTransport 30 | 31 | /** 32 | * Mapper commit protocol extension. Register index and data files and publish all needed 33 | * metadata to driver. 34 | */ 35 | def writeIndexFileAndCommitCommon(shuffleId: ShuffleId, mapId: Int, 36 | lengths: Array[Long], dataBackFile: RandomAccessFile): Unit = { 37 | openFds.computeIfAbsent(shuffleId, (_: ShuffleId) => new ConcurrentLinkedQueue[RandomAccessFile]()) 38 | openFds.get(shuffleId).add(dataBackFile) 39 | var offset = 0L 40 | val channel = dataBackFile.getChannel 41 | for ((blockLength, reduceId) <- lengths.zipWithIndex) { 42 | if (blockLength > 0) { 43 | val blockId = UcxShuffleBockId(shuffleId, mapId ,reduceId) 44 | val block = new Block { 45 | private val fileOffset = offset 46 | 47 | override def getBlock(byteBuffer: ByteBuffer): Unit = { 48 | channel.read(byteBuffer, fileOffset) 49 | } 50 | 51 | override def getSize: Long = blockLength 52 | } 53 | transport.register(blockId, block) 54 | offset += blockLength 55 | } 56 | } 57 | } 58 | 59 | def removeShuffle(shuffleId: Int): Unit = { 60 | val fds = openFds.remove(shuffleId) 61 | if (fds != null) { 62 | fds.forEach(f => f.close()) 63 | } 64 | if (ucxShuffleManager.ucxTransport != null) { 65 | ucxShuffleManager.ucxTransport.unregisterShuffle(shuffleId) 66 | } 67 | } 68 | 69 | override def stop(): Unit = { 70 | if (ucxShuffleManager.ucxTransport != null) { 71 | ucxShuffleManager.ucxTransport.unregisterAllBlocks() 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleManager.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx 6 | 7 | import java.util.concurrent.{CountDownLatch, TimeUnit} 8 | 9 | import scala.concurrent.ExecutionContext.Implicits.global 10 | import scala.util.Success 11 | 12 | import org.apache.spark.rpc.RpcEnv 13 | import org.apache.spark.shuffle.sort.SortShuffleManager 14 | import org.apache.spark.shuffle.ucx.rpc.{UcxDriverRpcEndpoint, UcxExecutorRpcEndpoint} 15 | import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} 16 | import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer 17 | import org.apache.spark.util.{RpcUtils, ThreadUtils} 18 | import org.apache.spark.{SecurityManager, SparkConf, SparkEnv} 19 | import org.openucx.jucx.{NativeLibs, UcxException} 20 | 21 | /** 22 | * Common part for all spark versions for UcxShuffleManager logic 23 | */ 24 | abstract class CommonUcxShuffleManager(val conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) { 25 | type ShuffleId = Int 26 | type MapId = Int 27 | type ReduceId = Long 28 | 29 | /* Load UCX/JUCX libraries as soon as possible to avoid collision with JVM when register malloc/mmap hook. 
*/ 30 | if (!isDriver) { 31 | NativeLibs.load(); 32 | } 33 | 34 | val ucxShuffleConf = new UcxShuffleConf(conf) 35 | 36 | private[this] val latch = new CountDownLatch(1) 37 | @volatile var ucxTransport: UcxShuffleTransport = _ 38 | 39 | private var executorEndpoint: UcxExecutorRpcEndpoint = _ 40 | private var driverEndpoint: UcxDriverRpcEndpoint = _ 41 | 42 | protected val driverRpcName = "SparkUCX_driver" 43 | 44 | private val setupThread = ThreadUtils.newDaemonSingleThreadExecutor("UcxTransportSetupThread") 45 | 46 | setupThread.submit(new Runnable { 47 | override def run(): Unit = { 48 | while (SparkEnv.get == null) { 49 | Thread.sleep(10) 50 | } 51 | if (isDriver) { 52 | val rpcEnv = SparkEnv.get.rpcEnv 53 | logInfo(s"Setting up driver RPC") 54 | driverEndpoint = new UcxDriverRpcEndpoint(rpcEnv) 55 | rpcEnv.setupEndpoint(driverRpcName, driverEndpoint) 56 | } else { 57 | while (SparkEnv.get.blockManager.blockManagerId == null) { 58 | Thread.sleep(5) 59 | } 60 | startUcxTransport() 61 | } 62 | } 63 | }) 64 | 65 | def awaitUcxTransport(): UcxShuffleTransport = { 66 | if (ucxTransport == null) { 67 | latch.await(10, TimeUnit.SECONDS) 68 | if (ucxTransport == null) { 69 | throw new UcxException("UcxShuffleTransport init timeout") 70 | } 71 | } 72 | ucxTransport 73 | } 74 | 75 | /** 76 | * Atomically starts UcxNode singleton - one for all shuffle threads. 77 | */ 78 | def startUcxTransport(): Unit = if (ucxTransport == null) { 79 | val blockManager = SparkEnv.get.blockManager.blockManagerId 80 | val transport = new UcxShuffleTransport(ucxShuffleConf, blockManager.executorId.toLong) 81 | val address = transport.init() 82 | ucxTransport = transport 83 | latch.countDown() 84 | val rpcEnv = RpcEnv.create("ucx-rpc-env", blockManager.host, blockManager.port, 85 | conf, new SecurityManager(conf), clientMode = false) 86 | executorEndpoint = new UcxExecutorRpcEndpoint(rpcEnv, ucxTransport, setupThread) 87 | val endpoint = rpcEnv.setupEndpoint( 88 | s"ucx-shuffle-executor-${blockManager.executorId}", 89 | executorEndpoint) 90 | val driverEndpoint = RpcUtils.makeDriverRef(driverRpcName, conf, rpcEnv) 91 | driverEndpoint.ask[IntroduceAllExecutors](ExecutorAdded(blockManager.executorId.toLong, endpoint, 92 | new SerializableDirectBuffer(address))) 93 | .andThen { 94 | case Success(msg) => 95 | logInfo(s"Receive reply $msg") 96 | executorEndpoint.receive(msg) 97 | } 98 | } 99 | 100 | 101 | override def unregisterShuffle(shuffleId: Int): Boolean = { 102 | shuffleBlockResolver.asInstanceOf[CommonUcxShuffleBlockResolver].removeShuffle(shuffleId) 103 | super.unregisterShuffle(shuffleId) 104 | } 105 | 106 | /** 107 | * Called on both driver and executors to finally cleanup resources. 108 | */ 109 | override def stop(): Unit = synchronized { 110 | super.stop() 111 | if (ucxTransport != null) { 112 | ucxTransport.close() 113 | ucxTransport = null 114 | } 115 | if (executorEndpoint != null) { 116 | executorEndpoint.stop() 117 | } 118 | if (driverEndpoint != null) { 119 | driverEndpoint.stop() 120 | } 121 | setupThread.shutdown() 122 | } 123 | 124 | } 125 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
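
For context on how the manager above gets instantiated at all: the plugin is switched on through ordinary Spark properties rather than anything in this file. The property values below are only an illustrative sketch; the manager class name and jar path are assumptions (the repository README is the authoritative source), not something defined here.

  val conf = new org.apache.spark.SparkConf()
    .set("spark.shuffle.manager", "org.apache.spark.shuffle.UcxShuffleManager")  // assumed class name
    .set("spark.driver.extraClassPath", "/path/to/spark-ucx.jar")                // placeholder jar path
    .set("spark.executor.extraClassPath", "/path/to/spark-ucx.jar")
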
4 | */ 5 | package org.apache.spark.shuffle.ucx 6 | 7 | import java.nio.ByteBuffer 8 | import java.util.concurrent.locks.StampedLock 9 | 10 | /** 11 | * Class that represents some block in memory with it's address, size. 12 | * 13 | * @param isHostMemory host or GPU memory 14 | */ 15 | case class MemoryBlock(address: Long, size: Long, isHostMemory: Boolean = true) extends AutoCloseable { 16 | /** 17 | * Important to call this method, to return memory to pool, or close resources 18 | */ 19 | override def close(): Unit = {} 20 | } 21 | 22 | /** 23 | * Base class to indicate some blockId. It should be hashable and could be constructed on both ends. 24 | * E.g. ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) 25 | */ 26 | trait BlockId { 27 | def serializedSize: Int 28 | def serialize(byteBuffer: ByteBuffer): Unit 29 | } 30 | 31 | private[ucx] sealed trait BlockLock { 32 | // Private transport lock to know when there are outstanding operations to block memory. 33 | private[ucx] lazy val lock = new StampedLock().asReadWriteLock() 34 | } 35 | 36 | /** 37 | * Some block in memory, that transport registers and that would requested on a remote side. 38 | */ 39 | trait Block extends BlockLock { 40 | def getSize: Long 41 | 42 | // This method for future use with a device buffers. 43 | def getMemoryBlock: MemoryBlock = ??? 44 | 45 | // Get block from a file into byte buffer backed bunce buffer 46 | def getBlock(byteBuffer: ByteBuffer): Unit 47 | } 48 | 49 | object OperationStatus extends Enumeration { 50 | val SUCCESS, CANCELED, FAILURE = Value 51 | } 52 | 53 | /** 54 | * Operation statistic, like completionTime, transport used, protocol used, etc. 55 | */ 56 | trait OperationStats { 57 | /** 58 | * Time it took from operation submit to callback call. 59 | * This depends on [[ ShuffleTransport.progress() ]] calls, 60 | * and does not indicate actual data transfer time. 61 | */ 62 | def getElapsedTimeNs: Long 63 | 64 | /** 65 | * Indicates number of valid bytes in receive memory when using 66 | * [[ ShuffleTransport.fetchBlocksByBlockIds()]] 67 | */ 68 | def recvSize: Long 69 | } 70 | 71 | class TransportError(errorMsg: String) extends Exception(errorMsg) 72 | 73 | trait OperationResult { 74 | def getStatus: OperationStatus.Value 75 | def getError: TransportError 76 | def getStats: Option[OperationStats] 77 | def getData: MemoryBlock 78 | } 79 | 80 | /** 81 | * Request object that returns by [[ ShuffleTransport.fetchBlocksByBlockIds() ]] routine. 82 | */ 83 | trait Request { 84 | def isCompleted: Boolean 85 | def getStats: Option[OperationStats] 86 | } 87 | 88 | /** 89 | * Async operation callbacks 90 | */ 91 | trait OperationCallback { 92 | def onComplete(result: OperationResult): Unit 93 | } 94 | 95 | /** 96 | * Transport flow example: 97 | * val transport = new UcxShuffleTransport() 98 | * transport.init() 99 | * 100 | * Mapper/writer: 101 | * transport.register(blockId, block) 102 | * 103 | * Reducer: 104 | * transport.fetchBlockByBlockId(blockId, resultBounceBuffer) 105 | * transport.progress() 106 | * 107 | * transport.unregister(blockId) 108 | * transport.close() 109 | */ 110 | trait ShuffleTransport { 111 | type ExecutorId = Long 112 | type BufferAllocator = Long => MemoryBlock 113 | /** 114 | * Initialize transport resources. This function should get called after ensuring that SparkConf 115 | * has the correct configurations since it will use the spark configuration to configure itself. 
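
The "Transport flow example" in the scaladoc above can be made concrete with a minimal sketch of the trait's API. This is a single-process illustration only, assumed to live in the org.apache.spark.shuffle.ucx package (the Block base trait is package-private); in the real plugin the two sides run in different executors and worker addresses travel over the driver RPC.

  import java.nio.ByteBuffer
  import org.apache.spark.SparkConf

  val transport = new UcxShuffleTransport(new UcxShuffleConf(new SparkConf()), executorId = 0)
  val workerAddress = transport.init()
  transport.addExecutor(0, workerAddress)  // normally announced by the driver, not a local variable

  // Mapper/server side: register a small in-memory block under a UcxShuffleBockId.
  val payload = ByteBuffer.allocateDirect(16)
  val blockId = UcxShuffleBockId(0, 0, 0)
  transport.register(blockId, new Block {
    override def getSize: Long = payload.capacity()
    override def getBlock(byteBuffer: ByteBuffer): Unit = byteBuffer.put(payload.duplicate())
  })

  // Reducer/client side: fetch into a pooled bounce buffer, driving progress until completion.
  val allocator: Long => MemoryBlock = size => transport.hostBounceBufferMemoryPool.get(size)
  val callback = new OperationCallback {
    override def onComplete(result: OperationResult): Unit = result.getData.close()
  }
  val requests = transport.fetchBlocksByBlockIds(0, Seq(blockId), allocator, Seq(callback))
  while (!requests.forall(_.isCompleted)) transport.progress()

  transport.unregister(blockId)
  transport.close()
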
116 | * 117 | * @return worker address of current process, to use in [[ addExecutor()]] 118 | */ 119 | def init(): ByteBuffer 120 | 121 | /** 122 | * Close all transport resources 123 | */ 124 | def close(): Unit 125 | 126 | /** 127 | * Add executor's worker address. For standalone testing purpose and for implementations that makes 128 | * connection establishment outside of UcxShuffleManager. 129 | */ 130 | def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit 131 | 132 | /** 133 | * Remove executor from communications. 134 | */ 135 | def removeExecutor(executorId: ExecutorId): Unit 136 | 137 | /** 138 | * Registers blocks using blockId on SERVER side. 139 | */ 140 | def register(blockId: BlockId, block: Block): Unit 141 | 142 | /** 143 | * Change location of underlying blockId in memory 144 | */ 145 | def mutate(blockId: BlockId, newBlock: Block, callback: OperationCallback): Unit 146 | 147 | /** 148 | * Indicate that this blockId is not needed any more by an application. 149 | * Note: this is a blocking call. On return it's safe to free blocks memory. 150 | */ 151 | def unregister(blockId: BlockId): Unit 152 | 153 | /** 154 | * Batch version of [[ fetchBlocksByBlockIds ]]. 155 | */ 156 | def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], 157 | resultBufferAllocator: BufferAllocator, 158 | callbacks: Seq[OperationCallback]): Seq[Request] 159 | 160 | /** 161 | * Progress outstanding operations. This routine is blocking (though may poll for event). 162 | * It's required to call this routine within same thread that submitted [[ fetchBlocksByBlockIds ]]. 163 | * 164 | * Return from this method guarantees that at least some operation was progressed. 165 | * But not guaranteed that at least one [[ fetchBlocksByBlockIds ]] completed! 166 | */ 167 | def progress(): Unit 168 | 169 | } 170 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx 6 | 7 | import org.apache.spark.SparkConf 8 | import org.apache.spark.internal.config.ConfigBuilder 9 | import org.apache.spark.network.util.ByteUnit 10 | 11 | /** 12 | * Plugin configuration properties. 13 | */ 14 | class UcxShuffleConf(sparkConf: SparkConf) extends SparkConf { 15 | 16 | def getSparkConf: SparkConf = sparkConf 17 | 18 | private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name" 19 | 20 | // Memory Pool 21 | private lazy val PREALLOCATE_BUFFERS = 22 | ConfigBuilder(getUcxConf("memory.preAllocateBuffers")) 23 | .doc("Comma separated list of buffer size : buffer count pairs to preallocate in memory pool. E.g. 
4k:1000,16k:500") 24 | .stringConf.createWithDefault("") 25 | 26 | lazy val preallocateBuffersMap: Map[Long, Int] = { 27 | sparkConf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => s.nonEmpty) 28 | .map(entry => entry.split(":") match { 29 | case Array(bufferSize, bufferCount) => (bufferSize.toLong, bufferCount.toInt) 30 | }).toMap 31 | } 32 | 33 | private lazy val MIN_BUFFER_SIZE = ConfigBuilder(getUcxConf("memory.minBufferSize")) 34 | .doc("Minimal buffer size in memory pool.") 35 | .bytesConf(ByteUnit.BYTE) 36 | .createWithDefault(4096) 37 | 38 | lazy val minBufferSize: Long = sparkConf.getSizeAsBytes(MIN_BUFFER_SIZE.key, 39 | MIN_BUFFER_SIZE.defaultValueString) 40 | 41 | private lazy val MIN_REGISTRATION_SIZE = 42 | ConfigBuilder(getUcxConf("memory.minAllocationSize")) 43 | .doc("Minimal memory registration size in memory pool.") 44 | .bytesConf(ByteUnit.MiB) 45 | .createWithDefault(1) 46 | 47 | lazy val minRegistrationSize: Int = sparkConf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key, 48 | MIN_REGISTRATION_SIZE.defaultValueString).toInt 49 | 50 | private lazy val SOCKADDR = 51 | ConfigBuilder(getUcxConf("listener.sockaddr")) 52 | .doc("Whether to use socket address to connect executors.") 53 | .stringConf 54 | .createWithDefault("0.0.0.0:0") 55 | 56 | lazy val listenerAddress: String = sparkConf.get(SOCKADDR.key, SOCKADDR.defaultValueString) 57 | 58 | private lazy val WAKEUP_FEATURE = 59 | ConfigBuilder(getUcxConf("useWakeup")) 60 | .doc("Whether to use busy polling for workers") 61 | .booleanConf 62 | .createWithDefault(true) 63 | 64 | lazy val useWakeup: Boolean = sparkConf.getBoolean(WAKEUP_FEATURE.key, WAKEUP_FEATURE.defaultValue.get) 65 | 66 | private lazy val NUM_IO_THREADS= ConfigBuilder(getUcxConf("numIoThreads")) 67 | .doc("Number of threads in io thread pool") 68 | .intConf 69 | .createWithDefault(1) 70 | 71 | lazy val numIoThreads: Int = sparkConf.getInt(NUM_IO_THREADS.key, NUM_IO_THREADS.defaultValue.get) 72 | 73 | private lazy val NUM_LISTNER_THREADS= ConfigBuilder(getUcxConf("numListenerThreads")) 74 | .doc("Number of threads in listener thread pool") 75 | .intConf 76 | .createWithDefault(3) 77 | 78 | lazy val numListenerThreads: Int = sparkConf.getInt(NUM_LISTNER_THREADS.key, NUM_LISTNER_THREADS.defaultValue.get) 79 | 80 | private lazy val NUM_WORKERS = ConfigBuilder(getUcxConf("numClientWorkers")) 81 | .doc("Number of client workers") 82 | .intConf 83 | .createWithDefault(1) 84 | 85 | lazy val numWorkers: Int = sparkConf.getInt(NUM_WORKERS.key, sparkConf.getInt("spark.executor.cores", 86 | NUM_WORKERS.defaultValue.get)) 87 | 88 | private lazy val MAX_BLOCKS_IN_FLIGHT = ConfigBuilder(getUcxConf("maxBlocksPerRequest")) 89 | .doc("Maximum number blocks per request") 90 | .intConf 91 | .createWithDefault(50) 92 | 93 | lazy val maxBlocksPerRequest: Int = sparkConf.getInt(MAX_BLOCKS_IN_FLIGHT.key, MAX_BLOCKS_IN_FLIGHT.defaultValue.get) 94 | } 95 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2022, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
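
The properties defined by UcxShuffleConf above all live under the spark.shuffle.ucx.* prefix. A short sketch of setting the most commonly tuned ones on a SparkConf; the values are examples taken from the defaults and doc strings above, not recommendations.

  val conf = new org.apache.spark.SparkConf()
    // Preallocate bounce buffers as "size:count" pairs, here 1000 x 4 KiB and 500 x 16 KiB.
    .set("spark.shuffle.ucx.memory.preAllocateBuffers", "4k:1000,16k:500")
    .set("spark.shuffle.ucx.memory.minBufferSize", "4096")
    .set("spark.shuffle.ucx.memory.minAllocationSize", "1m")
    // Bind the listener to a fixed interface/port instead of the 0.0.0.0:0 default.
    .set("spark.shuffle.ucx.listener.sockaddr", "0.0.0.0:55555")
    .set("spark.shuffle.ucx.useWakeup", "true")
    .set("spark.shuffle.ucx.numIoThreads", "1")
    .set("spark.shuffle.ucx.numListenerThreads", "3")
    .set("spark.shuffle.ucx.numClientWorkers", "2")
    .set("spark.shuffle.ucx.maxBlocksPerRequest", "50")
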
4 | */ 5 | package org.apache.spark.shuffle.ucx 6 | 7 | import org.apache.spark.SparkEnv 8 | import org.apache.spark.internal.Logging 9 | import org.apache.spark.shuffle.ucx.memory.UcxHostBounceBuffersPool 10 | import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread 11 | import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils} 12 | import org.apache.spark.shuffle.utils.UnsafeUtils 13 | import org.openucx.jucx.UcxException 14 | import org.openucx.jucx.ucp._ 15 | import org.openucx.jucx.ucs.UcsConstants 16 | 17 | import java.net.InetSocketAddress 18 | import java.nio.ByteBuffer 19 | import scala.collection.concurrent.TrieMap 20 | import scala.collection.mutable 21 | 22 | class UcxRequest(private var request: UcpRequest, stats: OperationStats) 23 | extends Request { 24 | 25 | private[ucx] var completed = false 26 | 27 | override def isCompleted: Boolean = completed || ((request != null) && request.isCompleted) 28 | 29 | override def getStats: Option[OperationStats] = Some(stats) 30 | 31 | private[ucx] def setRequest(request: UcpRequest): Unit = { 32 | this.request = request 33 | } 34 | } 35 | 36 | class UcxStats extends OperationStats { 37 | private[ucx] val startTime = System.nanoTime() 38 | private[ucx] var amHandleTime = 0L 39 | private[ucx] var endTime: Long = 0L 40 | private[ucx] var receiveSize: Long = 0L 41 | 42 | /** 43 | * Time it took from operation submit to callback call. 44 | * This depends on [[ ShuffleTransport.progress() ]] calls, 45 | * and does not indicate actual data transfer time. 46 | */ 47 | override def getElapsedTimeNs: Long = endTime - startTime 48 | 49 | /** 50 | * Indicates number of valid bytes in receive memory 51 | */ 52 | override def recvSize: Long = receiveSize 53 | } 54 | 55 | case class UcxShuffleBockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { 56 | override def serializedSize: Int = 12 57 | 58 | override def serialize(byteBuffer: ByteBuffer): Unit = { 59 | byteBuffer.putInt(shuffleId) 60 | byteBuffer.putInt(mapId) 61 | byteBuffer.putInt(reduceId) 62 | } 63 | } 64 | 65 | object UcxShuffleBockId { 66 | def deserialize(byteBuffer: ByteBuffer): UcxShuffleBockId = { 67 | val shuffleId = byteBuffer.getInt 68 | val mapId = byteBuffer.getInt 69 | val reduceId = byteBuffer.getInt 70 | UcxShuffleBockId(shuffleId, mapId, reduceId) 71 | } 72 | } 73 | 74 | class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executorId: Long = 0) extends ShuffleTransport 75 | with Logging { 76 | @volatile private var initialized: Boolean = false 77 | private[ucx] var ucxContext: UcpContext = _ 78 | private var globalWorker: UcpWorker = _ 79 | private var listener: UcpListener = _ 80 | private val ucpWorkerParams = new UcpWorkerParams().requestThreadSafety() 81 | val endpoints = mutable.Set.empty[UcpEndpoint] 82 | val executorAddresses = new TrieMap[ExecutorId, ByteBuffer] 83 | 84 | private var allocatedClientWorkers: Array[UcxWorkerWrapper] = _ 85 | private var allocatedServerWorkers: Array[UcxWorkerWrapper] = _ 86 | 87 | private val registeredBlocks = new TrieMap[BlockId, Block] 88 | private var progressThread: Thread = _ 89 | var hostBounceBufferMemoryPool: UcxHostBounceBuffersPool = _ 90 | 91 | private val errorHandler = new UcpEndpointErrorHandler { 92 | override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = { 93 | if (errorCode == UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) { 94 | logWarning(s"Connection closed on ep: $ucpEndpoint") 95 | } else { 96 | logError(s"Ep 
$ucpEndpoint got an error: $errorString") 97 | } 98 | endpoints.remove(ucpEndpoint) 99 | ucpEndpoint.close() 100 | } 101 | } 102 | 103 | override def init(): ByteBuffer = { 104 | if (ucxShuffleConf == null) { 105 | ucxShuffleConf = new UcxShuffleConf(SparkEnv.get.conf) 106 | } 107 | 108 | val numEndpoints = ucxShuffleConf.numWorkers * 109 | ucxShuffleConf.getSparkConf.getInt("spark.executor.instances", 1) * 110 | ucxShuffleConf.numListenerThreads // Each listener thread creates backward endpoint 111 | logInfo(s"Creating UCX context with an estimated number of endpoints: $numEndpoints") 112 | 113 | val params = new UcpParams().requestAmFeature().setMtWorkersShared(true).setEstimatedNumEps(numEndpoints) 114 | .requestAmFeature().setConfig("USE_MT_MUTEX", "yes") 115 | 116 | if (ucxShuffleConf.useWakeup) { 117 | params.requestWakeupFeature() 118 | ucpWorkerParams.requestWakeupRX().requestWakeupTX().requestWakeupEdge() 119 | } 120 | 121 | ucxContext = new UcpContext(params) 122 | globalWorker = ucxContext.newWorker(ucpWorkerParams) 123 | hostBounceBufferMemoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) 124 | 125 | allocatedServerWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numListenerThreads) 126 | logInfo(s"Allocating ${ucxShuffleConf.numListenerThreads} server workers") 127 | for (i <- 0 until ucxShuffleConf.numListenerThreads) { 128 | val worker = ucxContext.newWorker(ucpWorkerParams) 129 | allocatedServerWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = false) 130 | } 131 | 132 | val Array(host, port) = ucxShuffleConf.listenerAddress.split(":") 133 | listener = globalWorker.newListener(new UcpListenerParams().setSockAddr( 134 | new InetSocketAddress(host, port.toInt)) 135 | .setConnectionHandler((ucpConnectionRequest: UcpConnectionRequest) => { 136 | endpoints.add(globalWorker.newEndpoint(new UcpEndpointParams().setConnectionRequest(ucpConnectionRequest) 137 | .setPeerErrorHandlingMode().setErrorHandler(errorHandler) 138 | .setName(s"Endpoint to ${ucpConnectionRequest.getClientId}"))) 139 | })) 140 | 141 | progressThread = new GlobalWorkerRpcThread(globalWorker, this) 142 | progressThread.start() 143 | 144 | allocatedClientWorkers = new Array[UcxWorkerWrapper](ucxShuffleConf.numWorkers) 145 | logInfo(s"Allocating ${ucxShuffleConf.numWorkers} client workers") 146 | for (i <- 0 until ucxShuffleConf.numWorkers) { 147 | val clientId: Long = ((i.toLong + 1L) << 32) | executorId 148 | ucpWorkerParams.setClientId(clientId) 149 | val worker = ucxContext.newWorker(ucpWorkerParams) 150 | allocatedClientWorkers(i) = UcxWorkerWrapper(worker, this, isClientWorker = true, clientId) 151 | } 152 | 153 | initialized = true 154 | logInfo(s"Started listener on ${listener.getAddress}") 155 | SerializationUtils.serializeInetAddress(listener.getAddress) 156 | } 157 | 158 | /** 159 | * Close all transport resources 160 | */ 161 | override def close(): Unit = { 162 | if (initialized) { 163 | endpoints.foreach(_.closeNonBlockingForce()) 164 | endpoints.clear() 165 | 166 | hostBounceBufferMemoryPool.close() 167 | 168 | allocatedClientWorkers.foreach(_.close()) 169 | allocatedServerWorkers.foreach(_.close()) 170 | 171 | if (listener != null) { 172 | listener.close() 173 | listener = null 174 | } 175 | 176 | if (progressThread != null) { 177 | progressThread.interrupt() 178 | progressThread.join(10) 179 | } 180 | 181 | if (globalWorker != null) { 182 | globalWorker.close() 183 | globalWorker = null 184 | } 185 | 186 | if (ucxContext != null) { 187 | ucxContext.close() 188 | 
ucxContext = null 189 | } 190 | } 191 | } 192 | 193 | /** 194 | * Add executor's worker address. For standalone testing purpose and for implementations that makes 195 | * connection establishment outside of UcxShuffleManager. 196 | */ 197 | override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { 198 | executorAddresses.put(executorId, workerAddress) 199 | allocatedClientWorkers.foreach(w => { 200 | w.getConnection(executorId) 201 | w.progressConnect() 202 | }) 203 | } 204 | 205 | def addExecutors(executorIdsToAddress: Map[ExecutorId, SerializableDirectBuffer]): Unit = { 206 | executorIdsToAddress.foreach { 207 | case (executorId, address) => executorAddresses.put(executorId, address.value) 208 | } 209 | } 210 | 211 | def preConnect(): Unit = { 212 | allocatedClientWorkers.foreach(_.preconnect()) 213 | } 214 | 215 | /** 216 | * Remove executor from communications. 217 | */ 218 | override def removeExecutor(executorId: Long): Unit = executorAddresses.remove(executorId) 219 | 220 | /** 221 | * Registers blocks using blockId on SERVER side. 222 | */ 223 | override def register(blockId: BlockId, block: Block): Unit = { 224 | registeredBlocks.put(blockId, block) 225 | } 226 | 227 | /** 228 | * Change location of underlying blockId in memory 229 | */ 230 | override def mutate(blockId: BlockId, newBlock: Block, callback: OperationCallback): Unit = { 231 | unregister(blockId) 232 | register(blockId, newBlock) 233 | callback.onComplete(new OperationResult { 234 | override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS 235 | 236 | override def getError: TransportError = null 237 | 238 | override def getStats: Option[OperationStats] = None 239 | 240 | override def getData: MemoryBlock = newBlock.getMemoryBlock 241 | }) 242 | 243 | } 244 | 245 | /** 246 | * Indicate that this blockId is not needed any more by an application. 247 | * Note: this is a blocking call. On return it's safe to free blocks memory. 248 | */ 249 | override def unregister(blockId: BlockId): Unit = { 250 | registeredBlocks.remove(blockId) 251 | } 252 | 253 | def unregisterShuffle(shuffleId: Int): Unit = { 254 | registeredBlocks.keysIterator.foreach(bid => 255 | if (bid.asInstanceOf[UcxShuffleBockId].shuffleId == shuffleId) { 256 | registeredBlocks.remove(bid) 257 | } 258 | ) 259 | } 260 | 261 | def unregisterAllBlocks(): Unit = { 262 | registeredBlocks.clear() 263 | } 264 | 265 | /** 266 | * Batch version of [[ fetchBlocksByBlockIds ]]. 267 | */ 268 | override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], 269 | resultBufferAllocator: BufferAllocator, 270 | callbacks: Seq[OperationCallback]): Seq[Request] = { 271 | allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt) 272 | .fetchBlocksByBlockIds(executorId, blockIds, resultBufferAllocator, callbacks) 273 | } 274 | 275 | def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { 276 | allocatedServerWorkers.foreach(w => w.connectByWorkerAddress(executorId, workerAddress)) 277 | } 278 | 279 | def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = { 280 | val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) 281 | val blockIds = mutable.ArrayBuffer.empty[BlockId] 282 | 283 | // 1. 
Deserialize blockIds from header 284 | while (buffer.remaining() > 0) { 285 | val blockId = UcxShuffleBockId.deserialize(buffer) 286 | if (!registeredBlocks.contains(blockId)) { 287 | throw new UcxException(s"$blockId is not registered") 288 | } 289 | blockIds += blockId 290 | } 291 | 292 | val blocks = blockIds.map(bid => registeredBlocks(bid)) 293 | amData.close() 294 | allocatedServerWorkers((Thread.currentThread().getId % allocatedServerWorkers.length).toInt) 295 | .handleFetchBlockRequest(blocks, replyTag, replyExecutor) 296 | } 297 | 298 | 299 | /** 300 | * Progress outstanding operations. This routine is blocking (though may poll for event). 301 | * It's required to call this routine within same thread that submitted [[ fetchBlocksByBlockIds ]]. 302 | * 303 | * Return from this method guarantees that at least some operation was progressed. 304 | * But not guaranteed that at least one [[ fetchBlocksByBlockIds ]] completed! 305 | */ 306 | override def progress(): Unit = { 307 | allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt).progress() 308 | } 309 | 310 | def progressConnect(): Unit = { 311 | allocatedClientWorkers.par.foreach(_.progressConnect()) 312 | } 313 | } 314 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx 6 | 7 | import java.io.Closeable 8 | import java.util.concurrent.ConcurrentLinkedQueue 9 | import java.util.concurrent.atomic.AtomicInteger 10 | import scala.collection.concurrent.TrieMap 11 | import scala.util.Random 12 | import org.openucx.jucx.ucp._ 13 | import org.openucx.jucx.ucs.UcsConstants 14 | import org.openucx.jucx.ucs.UcsConstants.MEMORY_TYPE 15 | import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils} 16 | import org.apache.spark.internal.Logging 17 | import org.apache.spark.shuffle.ucx.memory.UcxBounceBufferMemoryBlock 18 | import org.apache.spark.shuffle.ucx.utils.SerializationUtils 19 | import org.apache.spark.shuffle.utils.UnsafeUtils 20 | import org.apache.spark.unsafe.Platform 21 | import org.apache.spark.util.ThreadUtils 22 | 23 | import java.nio.ByteBuffer 24 | import scala.collection.parallel.ForkJoinTaskSupport 25 | 26 | 27 | class UcxFailureOperationResult(errorMsg: String) extends OperationResult { 28 | override def getStatus: OperationStatus.Value = OperationStatus.FAILURE 29 | 30 | override def getError: TransportError = new TransportError(errorMsg) 31 | 32 | override def getStats: Option[OperationStats] = None 33 | 34 | override def getData: MemoryBlock = null 35 | } 36 | 37 | class UcxAmDataMemoryBlock(ucpAmData: UcpAmData, offset: Long, size: Long, 38 | refCount: AtomicInteger) 39 | extends MemoryBlock(ucpAmData.getDataAddress + offset, size, true) with Logging { 40 | 41 | override def close(): Unit = { 42 | if (refCount.decrementAndGet() == 0) { 43 | ucpAmData.close() 44 | } 45 | } 46 | } 47 | 48 | class UcxRefCountMemoryBlock(baseBlock: MemoryBlock, offset: Long, size: Long, 49 | refCount: AtomicInteger) 50 | extends MemoryBlock(baseBlock.address + offset, size, true) with Logging { 51 | 52 | override def close(): Unit = { 53 | if (refCount.decrementAndGet() == 0) { 54 | baseBlock.close() 55 | } 56 | } 57 | } 58 | 59 | /** 60 | * Worker per thread 
wrapper, that maintains connection and progress logic. 61 | */ 62 | case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, isClientWorker: Boolean, 63 | id: Long = 0L) 64 | extends Closeable with Logging { 65 | 66 | private final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint] 67 | private val requestData = new TrieMap[Int, (Seq[OperationCallback], UcxRequest, transport.BufferAllocator)] 68 | private val tag = new AtomicInteger(Random.nextInt()) 69 | private val flushRequests = new ConcurrentLinkedQueue[UcpRequest]() 70 | 71 | private val ioThreadPool = ThreadUtils.newForkJoinPool("IO threads", 72 | transport.ucxShuffleConf.numIoThreads) 73 | private val ioTaskSupport = new ForkJoinTaskSupport(ioThreadPool) 74 | 75 | if (isClientWorker) { 76 | // Receive block data handler 77 | worker.setAmRecvHandler(1, 78 | (headerAddress: Long, headerSize: Long, ucpAmData: UcpAmData, _: UcpEndpoint) => { 79 | val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt) 80 | val i = headerBuffer.getInt 81 | val data = requestData.remove(i) 82 | 83 | if (data.isEmpty) { 84 | throw new UcxException(s"No data for tag $i.") 85 | } 86 | 87 | val (callbacks, request, allocator) = data.get 88 | val stats = request.getStats.get.asInstanceOf[UcxStats] 89 | stats.receiveSize = ucpAmData.getLength 90 | 91 | // Header contains tag followed by sizes of blocks 92 | val numBlocks = (headerSize.toInt - UnsafeUtils.INT_SIZE) / UnsafeUtils.INT_SIZE 93 | 94 | var offset = 0 95 | val refCounts = new AtomicInteger(numBlocks) 96 | if (ucpAmData.isDataValid) { 97 | request.completed = true 98 | stats.endTime = System.nanoTime() 99 | logDebug(s"Received amData: $ucpAmData for tag $i " + 100 | s"in ${stats.getElapsedTimeNs} ns") 101 | 102 | for (b <- 0 until numBlocks) { 103 | val blockSize = headerBuffer.getInt 104 | if (callbacks(b) != null) { 105 | callbacks(b).onComplete(new OperationResult { 106 | override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS 107 | 108 | override def getError: TransportError = null 109 | 110 | override def getStats: Option[OperationStats] = Some(stats) 111 | 112 | override def getData: MemoryBlock = new UcxAmDataMemoryBlock(ucpAmData, offset, blockSize, refCounts) 113 | }) 114 | offset += blockSize 115 | } 116 | } 117 | if (callbacks.isEmpty) UcsConstants.STATUS.UCS_OK else UcsConstants.STATUS.UCS_INPROGRESS 118 | } else { 119 | val mem = allocator(ucpAmData.getLength) 120 | stats.amHandleTime = System.nanoTime() 121 | request.setRequest(worker.recvAmDataNonBlocking(ucpAmData.getDataHandle, mem.address, ucpAmData.getLength, 122 | new UcxCallback() { 123 | override def onSuccess(r: UcpRequest): Unit = { 124 | request.completed = true 125 | stats.endTime = System.nanoTime() 126 | logDebug(s"Received rndv data of size: ${mem.size} for tag $i in " + 127 | s"${stats.getElapsedTimeNs} ns " + 128 | s"time from amHandle: ${System.nanoTime() - stats.amHandleTime} ns") 129 | for (b <- 0 until numBlocks) { 130 | val blockSize = headerBuffer.getInt 131 | callbacks(b).onComplete(new OperationResult { 132 | override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS 133 | 134 | override def getError: TransportError = null 135 | 136 | override def getStats: Option[OperationStats] = Some(stats) 137 | 138 | override def getData: MemoryBlock = new UcxRefCountMemoryBlock(mem, offset, blockSize, refCounts) 139 | }) 140 | offset += blockSize 141 | } 142 | 143 | } 144 | }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)) 145 | 
UcsConstants.STATUS.UCS_OK 146 | } 147 | }, UcpConstants.UCP_AM_FLAG_PERSISTENT_DATA | UcpConstants.UCP_AM_FLAG_WHOLE_MSG) 148 | } 149 | 150 | override def close(): Unit = { 151 | val closeRequests = connections.map { 152 | case (_, endpoint) => endpoint.closeNonBlockingForce() 153 | } 154 | while (!closeRequests.forall(_.isCompleted)) { 155 | progress() 156 | } 157 | ioThreadPool.shutdown() 158 | connections.clear() 159 | worker.close() 160 | } 161 | 162 | /** 163 | * Blocking progress until there's outstanding flush requests. 164 | */ 165 | def progressConnect(): Unit = { 166 | while (!flushRequests.isEmpty) { 167 | progress() 168 | flushRequests.removeIf(_.isCompleted) 169 | } 170 | logTrace(s"Flush completed. Number of connections: ${connections.keys.size}") 171 | } 172 | 173 | /** 174 | * The only place for worker progress 175 | */ 176 | def progress(): Int = worker.synchronized { 177 | worker.progress() 178 | } 179 | 180 | /** 181 | * Establish connections to known instances. 182 | */ 183 | def preconnect(): Unit = { 184 | transport.executorAddresses.keys.foreach(getConnection) 185 | progressConnect() 186 | } 187 | 188 | def connectByWorkerAddress(executorId: transport.ExecutorId, workerAddress: ByteBuffer): Unit = { 189 | logDebug(s"Worker $this connecting back to $executorId by worker address") 190 | val ep = worker.newEndpoint(new UcpEndpointParams().setName(s"Server connection to $executorId") 191 | .setUcpAddress(workerAddress)) 192 | connections.put(executorId, ep) 193 | } 194 | 195 | def getConnection(executorId: transport.ExecutorId): UcpEndpoint = { 196 | 197 | val startTime = System.currentTimeMillis() 198 | while (!transport.executorAddresses.contains(executorId)) { 199 | if (System.currentTimeMillis() - startTime > 200 | transport.ucxShuffleConf.getSparkConf.getTimeAsMs("spark.network.timeout", "100")) { 201 | throw new UcxException(s"Don't get a worker address for $executorId") 202 | } 203 | } 204 | 205 | connections.getOrElseUpdate(executorId, { 206 | val address = transport.executorAddresses(executorId) 207 | val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode() 208 | .setSocketAddress(SerializationUtils.deserializeInetAddress(address)).sendClientId() 209 | .setErrorHandler(new UcpEndpointErrorHandler() { 210 | override def onError(ep: UcpEndpoint, status: Int, errorMsg: String): Unit = { 211 | logError(s"Endpoint to $executorId got an error: $errorMsg") 212 | connections.remove(executorId) 213 | } 214 | }).setName(s"Endpoint to $executorId") 215 | 216 | logDebug(s"Worker $this connecting to Executor($executorId, " + 217 | s"${SerializationUtils.deserializeInetAddress(address)}") 218 | val ep = worker.newEndpoint(endpointParams) 219 | val header = Platform.allocateDirectBuffer(UnsafeUtils.LONG_SIZE) 220 | header.putLong(id) 221 | header.rewind() 222 | val workerAddress = worker.getAddress 223 | 224 | ep.sendAmNonBlocking(1, UcxUtils.getAddress(header), UnsafeUtils.LONG_SIZE, 225 | UcxUtils.getAddress(workerAddress), workerAddress.capacity().toLong, UcpConstants.UCP_AM_SEND_FLAG_EAGER, 226 | new UcxCallback() { 227 | override def onSuccess(request: UcpRequest): Unit = { 228 | header.clear() 229 | workerAddress.clear() 230 | } 231 | }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) 232 | flushRequests.add(ep.flushNonBlocking(null)) 233 | ep 234 | }) 235 | } 236 | 237 | def fetchBlocksByBlockIds(executorId: transport.ExecutorId, blockIds: Seq[BlockId], 238 | resultBufferAllocator: transport.BufferAllocator, 239 | callbacks: Seq[OperationCallback]): Seq[Request] = { 
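// A summary of the wire layout used by this method and its peer handlers (derived from the
// surrounding code, not an authoritative protocol description):
//   request (AM id 0): header = [tag: Int][sender worker id: Long]
//                      data   = the serialized BlockIds (12 bytes each for UcxShuffleBockId)
//   reply   (AM id 1): header = [tag: Int][size of block 0: Int]...[size of block n-1: Int]
//                      data   = the block payloads, concatenated in request order
// Because the reply carries one Int size per block in its AM header, requests whose block count
// would overflow the worker's maximum AM header size are split in half and issued recursively
// (the check a few lines below).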
240 | val startTime = System.nanoTime() 241 | val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.LONG_SIZE 242 | val ep = getConnection(executorId) 243 | 244 | if (worker.getMaxAmHeaderSize <= 245 | headerSize + UnsafeUtils.INT_SIZE * blockIds.length) { 246 | val (b1, b2) = blockIds.splitAt(blockIds.length / 2) 247 | val (c1, c2) = callbacks.splitAt(callbacks.length / 2) 248 | val r1 = fetchBlocksByBlockIds(executorId, b1, resultBufferAllocator, c1) 249 | val r2 = fetchBlocksByBlockIds(executorId, b2, resultBufferAllocator, c2) 250 | return r1 ++ r2 251 | } 252 | 253 | val t = tag.incrementAndGet() 254 | 255 | val buffer = Platform.allocateDirectBuffer(headerSize + blockIds.map(_.serializedSize).sum) 256 | buffer.putInt(t) 257 | buffer.putLong(id) 258 | blockIds.foreach(b => b.serialize(buffer)) 259 | 260 | val request = new UcxRequest(null, new UcxStats()) 261 | requestData.put(t, (callbacks, request, resultBufferAllocator)) 262 | 263 | buffer.rewind() 264 | val address = UnsafeUtils.getAdress(buffer) 265 | val dataAddress = address + headerSize 266 | 267 | ep.sendAmNonBlocking(0, address, 268 | headerSize, dataAddress, buffer.capacity() - headerSize, 269 | UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() { 270 | override def onSuccess(request: UcpRequest): Unit = { 271 | buffer.clear() 272 | logDebug(s"Sent message on $ep to $executorId to fetch ${blockIds.length} blocks on tag $t id $id" + 273 | s"in ${System.nanoTime() - startTime} ns") 274 | } 275 | }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) 276 | 277 | worker.progressRequest(ep.flushNonBlocking(null)) 278 | Seq(request) 279 | } 280 | 281 | def handleFetchBlockRequest(blocks: Seq[Block], replyTag: Int, replyExecutor: Long): Unit = try { 282 | val tagAndSizes = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE * blocks.length 283 | val resultMemory = transport.hostBounceBufferMemoryPool.get(tagAndSizes + blocks.map(_.getSize).sum) 284 | .asInstanceOf[UcxBounceBufferMemoryBlock] 285 | val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, 286 | resultMemory.size) 287 | resultBuffer.putInt(replyTag) 288 | 289 | var offset = 0 290 | val localBuffers = blocks.zipWithIndex.map { 291 | case (block, i) => 292 | resultBuffer.putInt(UnsafeUtils.INT_SIZE + i * UnsafeUtils.INT_SIZE, block.getSize.toInt) 293 | resultBuffer.position(tagAndSizes + offset) 294 | val localBuffer = resultBuffer.slice() 295 | offset += block.getSize.toInt 296 | localBuffer.limit(block.getSize.toInt) 297 | localBuffer 298 | } 299 | // Do parallel read of blocks 300 | val blocksCollection = if (transport.ucxShuffleConf.numIoThreads > 1) { 301 | val parCollection = blocks.indices.par 302 | parCollection.tasksupport = ioTaskSupport 303 | parCollection 304 | } else { 305 | blocks.indices 306 | } 307 | 308 | for (i <- blocksCollection) { 309 | blocks(i).getBlock(localBuffers(i)) 310 | } 311 | 312 | val startTime = System.nanoTime() 313 | val req = connections(replyExecutor).sendAmNonBlocking(1, resultMemory.address, tagAndSizes, 314 | resultMemory.address + tagAndSizes, resultMemory.size - tagAndSizes, 0, new UcxCallback { 315 | override def onSuccess(request: UcpRequest): Unit = { 316 | logTrace(s"Sent ${blocks.length} blocks of size: ${resultMemory.size} " + 317 | s"to tag $replyTag in ${System.nanoTime() - startTime} ns.") 318 | transport.hostBounceBufferMemoryPool.put(resultMemory) 319 | } 320 | 321 | override def onError(ucsStatus: Int, errorMsg: String): Unit = { 322 | logError(s"Failed to send $errorMsg") 323 | } 324 | }, new 
UcpRequestParams().setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) 325 | .setMemoryHandle(resultMemory.memory)) 326 | 327 | while (!req.isCompleted) { 328 | progress() 329 | } 330 | } catch { 331 | case ex: Throwable => logError(s"Failed to read and send data: $ex") 332 | } 333 | 334 | } 335 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx.memory 6 | 7 | import java.io.Closeable 8 | import java.util.concurrent.atomic.AtomicInteger 9 | import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque} 10 | 11 | import org.openucx.jucx.ucp.{UcpContext, UcpMemMapParams, UcpMemory} 12 | import org.openucx.jucx.ucs.UcsConstants 13 | import org.apache.spark.internal.Logging 14 | import org.apache.spark.shuffle.ucx.{MemoryBlock, UcxShuffleConf} 15 | import org.apache.spark.util.Utils 16 | 17 | class UcxBounceBufferMemoryBlock(private[ucx] val memory: UcpMemory, private[ucx] val refCount: AtomicInteger, 18 | private[ucx] val memPool: MemoryPool, 19 | override val address: Long, override val size: Long) 20 | extends MemoryBlock(address, size, memory.getMemType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) { 21 | 22 | override def close(): Unit = { 23 | memPool.put(this) 24 | } 25 | } 26 | 27 | 28 | /** 29 | * Base class to implement memory pool 30 | */ 31 | case class MemoryPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext, memoryType: Int) 32 | extends Closeable with Logging { 33 | 34 | protected def roundUpToTheNextPowerOf2(size: Long): Long = { 35 | if (size < ucxShuffleConf.minBufferSize) { 36 | ucxShuffleConf.minBufferSize 37 | } else { 38 | // Round up length to the nearest power of two 39 | var length = size 40 | length -= 1 41 | length |= length >> 1 42 | length |= length >> 2 43 | length |= length >> 4 44 | length |= length >> 8 45 | length |= length >> 16 46 | length += 1 47 | length 48 | } 49 | } 50 | 51 | protected val allocatorMap = new ConcurrentHashMap[Long, AllocatorStack]() 52 | 53 | private val memPool = this 54 | 55 | protected case class AllocatorStack(length: Long, memType: Int) extends Closeable { 56 | logInfo(s"Allocator stack of memType: $memType and size $length") 57 | private val stack = new ConcurrentLinkedDeque[UcxBounceBufferMemoryBlock] 58 | private val numAllocs = new AtomicInteger(0) 59 | private val memMapParams = new UcpMemMapParams().allocate().setMemoryType(memType).setLength(length) 60 | 61 | private[memory] def get: UcxBounceBufferMemoryBlock = { 62 | var result = stack.pollFirst() 63 | if (result == null) { 64 | numAllocs.incrementAndGet() 65 | if (length < ucxShuffleConf.minRegistrationSize) { 66 | while (result == null) { 67 | preallocate((ucxShuffleConf.minRegistrationSize / length).toInt) 68 | result = stack.pollFirst() 69 | } 70 | } else { 71 | logTrace(s"Allocating buffer of size $length.") 72 | val memory = ucxContext.memoryMap(memMapParams) 73 | result = new UcxBounceBufferMemoryBlock(memory, new AtomicInteger(1), memPool, 74 | memory.getAddress, length) 75 | } 76 | } 77 | result 78 | } 79 | 80 | private[memory] def put(block: UcxBounceBufferMemoryBlock): Unit = { 81 | stack.add(block) 82 | } 83 | 84 | private[memory] def preallocate(numBuffers: Int): Unit = { 85 | 
logTrace(s"PreAllocating $numBuffers of size $length, " + 86 | s"totalSize: ${Utils.bytesToString(length * numBuffers)}.") 87 | val memory = ucxContext.memoryMap( 88 | new UcpMemMapParams().allocate().setMemoryType(memType).setLength(length * numBuffers)) 89 | val refCount = new AtomicInteger(numBuffers) 90 | var offset = 0L 91 | (0 until numBuffers).foreach(_ => { 92 | stack.add(new UcxBounceBufferMemoryBlock(memory, refCount, memPool, memory.getAddress + offset, length)) 93 | offset += length 94 | }) 95 | } 96 | 97 | override def close(): Unit = { 98 | var numBuffers = 0 99 | stack.forEach(block => { 100 | block.refCount.decrementAndGet() 101 | if (block.memory.getNativeId != null) { 102 | block.memory.deregister() 103 | } 104 | numBuffers += 1 105 | }) 106 | logInfo(s"Closing $numBuffers buffers of size $length." + 107 | s"Number of allocations: ${numAllocs.get()}. Total size: ${Utils.bytesToString(length * numBuffers)}") 108 | stack.clear() 109 | } 110 | } 111 | 112 | override def close(): Unit = { 113 | allocatorMap.values.forEach(allocator => allocator.close()) 114 | allocatorMap.clear() 115 | } 116 | 117 | def get(size: Long): MemoryBlock = { 118 | val allocatorStack = allocatorMap.computeIfAbsent(roundUpToTheNextPowerOf2(size), 119 | s => AllocatorStack(s, memoryType)) 120 | val result = allocatorStack.get 121 | new UcxBounceBufferMemoryBlock(result.memory, result.refCount, memPool, result.address, size) 122 | } 123 | 124 | def put(mem: MemoryBlock): Unit = { 125 | mem match { 126 | case m: UcxBounceBufferMemoryBlock => 127 | val allocatorStack = allocatorMap.get(roundUpToTheNextPowerOf2(mem.size)) 128 | allocatorStack.put(m) 129 | case _ => logWarning(s"Unknown memory block $mem") 130 | } 131 | } 132 | 133 | def preAllocate(size: Long, numBuffers: Int): Unit = { 134 | val roundedSize = roundUpToTheNextPowerOf2(size) 135 | val allocatorStack = allocatorMap.computeIfAbsent(roundedSize, 136 | s => AllocatorStack(s, memoryType)) 137 | allocatorStack.preallocate(numBuffers) 138 | } 139 | } 140 | 141 | class UcxHostBounceBuffersPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext) 142 | extends MemoryPool(ucxShuffleConf, ucxContext, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) { 143 | 144 | ucxShuffleConf.preallocateBuffersMap.foreach{ 145 | case (bufferSize, count) => preAllocate(bufferSize, count) 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2022, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.ucx.perf 6 | 7 | import java.io.{File, RandomAccessFile} 8 | import java.net.InetSocketAddress 9 | import java.nio.ByteBuffer 10 | import java.nio.charset.StandardCharsets 11 | import java.nio.channels.FileChannel 12 | import java.util.concurrent.atomic.AtomicInteger 13 | import org.apache.commons.cli.{GnuParser, HelpFormatter, Options} 14 | import org.apache.spark.SparkConf 15 | import org.apache.spark.internal.Logging 16 | import org.apache.spark.network.util.JavaUtils 17 | import org.apache.spark.shuffle.ucx._ 18 | import org.apache.spark.shuffle.utils.UnsafeUtils 19 | import org.apache.spark.util.{ShutdownHookManager, ThreadUtils} 20 | 21 | import scala.collection.parallel.ForkJoinTaskSupport 22 | 23 | object UcxPerfBenchmark extends App with Logging { 24 | 25 | case class PerfOptions(remoteAddress: InetSocketAddress, numBlocks: Int, blockSize: Long, 26 | numIterations: Int, files: Array[File], numOutstanding: Int, randOrder: Boolean, 27 | numThreads: Int) 28 | 29 | private val HELP_OPTION = "h" 30 | private val ADDRESS_OPTION = "a" 31 | private val FILE_OPTION = "f" 32 | private val NUM_BLOCKS_OPTION = "n" 33 | private val SIZE_OPTION = "s" 34 | private val ITER_OPTION = "i" 35 | private val OUTSTANDING_OPTION = "o" 36 | private val RANDREAD_OPTION = "r" 37 | private val THREAD_OPTION = "t" 38 | 39 | private val sparkConf = new SparkConf() 40 | 41 | private def initOptions(): Options = { 42 | val options = new Options() 43 | options.addOption(HELP_OPTION, "help", false, 44 | "display help message") 45 | options.addOption(ADDRESS_OPTION, "address", true, 46 | "address of the listener on the remote host") 47 | options.addOption(NUM_BLOCKS_OPTION, "num-blocks", true, 48 | "number of blocks to transfer. Default: 1") 49 | options.addOption(SIZE_OPTION, "block-size", true, 50 | "size of block to transfer. Default: 1m") 51 | options.addOption(ITER_OPTION, "num-iterations", true, 52 | "number of iterations. Default: 1") 53 | options.addOption(OUTSTANDING_OPTION, "num-outstanding", true, 54 | "number of outstanding requests. Default: 1") 55 | options.addOption(FILE_OPTION, "files", true, "Files to transfer") 56 | options.addOption(THREAD_OPTION, "thread", true, "Number of threads. 
Default: 1") 57 | options.addOption(RANDREAD_OPTION, "random", false, "Read blocks in random order") 58 | options 59 | } 60 | 61 | private def parseOptions(args: Array[String]): PerfOptions = { 62 | val parser = new GnuParser() 63 | val options = initOptions() 64 | val cmd = parser.parse(options, args) 65 | 66 | if (cmd.hasOption(HELP_OPTION)) { 67 | new HelpFormatter().printHelp("UcxShufflePerfTool", options) 68 | System.exit(0) 69 | } 70 | 71 | val inetAddress = if (cmd.hasOption(ADDRESS_OPTION)) { 72 | val Array(host, port) = cmd.getOptionValue(ADDRESS_OPTION).split(":") 73 | new InetSocketAddress(host, Integer.parseInt(port)) 74 | } else { 75 | null 76 | } 77 | 78 | val files = if (cmd.hasOption(FILE_OPTION)) { 79 | cmd.getOptionValue(FILE_OPTION).split(",").map(f => new File(f)) 80 | } else { 81 | Array.empty[File] 82 | } 83 | 84 | val randOrder = if (cmd.hasOption(RANDREAD_OPTION)) { 85 | true 86 | } else { 87 | false 88 | } 89 | 90 | PerfOptions(inetAddress, 91 | Integer.parseInt(cmd.getOptionValue(NUM_BLOCKS_OPTION, "1")), 92 | JavaUtils.byteStringAsBytes(cmd.getOptionValue(SIZE_OPTION, "1m")), 93 | Integer.parseInt(cmd.getOptionValue(ITER_OPTION, "1")), 94 | files, 95 | Integer.parseInt(cmd.getOptionValue(OUTSTANDING_OPTION, "1")), 96 | randOrder, 97 | Integer.parseInt(cmd.getOptionValue(THREAD_OPTION, "1"))) 98 | } 99 | 100 | def startClient(options: PerfOptions): Unit = { 101 | if (options.numThreads > 1) { 102 | sparkConf.set("spark.executor.cores", options.numThreads.toString) 103 | } 104 | val ucxTransport = new UcxShuffleTransport(new UcxShuffleConf(sparkConf), 0) 105 | ucxTransport.init() 106 | 107 | val hostString = options.remoteAddress.getHostString.getBytes(StandardCharsets.UTF_8) 108 | val address = ByteBuffer.allocateDirect(hostString.length + 4) 109 | address.putInt(options.remoteAddress.getPort) 110 | address.put(hostString) 111 | ucxTransport.addExecutor(1, address) 112 | 113 | val resultBufferAllocator = (size: Long) => ucxTransport.hostBounceBufferMemoryPool.get(size) 114 | val blocks = Array.ofDim[BlockId](options.numOutstanding) 115 | val callbacks = Array.ofDim[OperationCallback](options.numOutstanding) 116 | val requestInFlight = new AtomicInteger(0) 117 | val rnd = new scala.util.Random 118 | val blocksPerFile = options.numBlocks / options.files.length 119 | 120 | val blockCollection = if (options.numThreads > 1) { 121 | val parallelCollection = (0 until options.numBlocks by options.numOutstanding).par 122 | val threadPool = ThreadUtils.newForkJoinPool("Benchmark threads", options.numThreads) 123 | parallelCollection.tasksupport = new ForkJoinTaskSupport(threadPool) 124 | parallelCollection 125 | } else { 126 | 0 until options.numBlocks by options.numOutstanding 127 | } 128 | 129 | for (_ <- 0 until options.numIterations) { 130 | for (b <- blockCollection) { 131 | requestInFlight.set(options.numOutstanding) 132 | for (o <- 0 until options.numOutstanding) { 133 | val fileIdx = if (options.randOrder) rnd.nextInt(options.files.length) else (b+o) % options.files.length 134 | val blockIdx = if (options.randOrder) rnd.nextInt(blocksPerFile) else (b+o) % blocksPerFile 135 | blocks(o) = UcxShuffleBockId(0, fileIdx, blockIdx) 136 | callbacks(o) = (result: OperationResult) => { 137 | result.getData.close() 138 | val stats = result.getStats.get 139 | if (requestInFlight.decrementAndGet() == 0) { 140 | printf(s"Received ${options.numOutstanding} block of size: ${stats.recvSize} " + 141 | s"in ${stats.getElapsedTimeNs / 1000} usec. 
Bandwidth: %.2f Mb/s \n", 142 | (options.blockSize * options.numOutstanding * options.numThreads) / 143 | (1024.0 * 1024.0 * (stats.getElapsedTimeNs / 1e9))) 144 | } 145 | } 146 | } 147 | val requests = ucxTransport.fetchBlocksByBlockIds(1, blocks, resultBufferAllocator, callbacks) 148 | while (!requests.forall(_.isCompleted)) { 149 | ucxTransport.progress() 150 | } 151 | } 152 | } 153 | ucxTransport.close() 154 | } 155 | 156 | def startServer(options: PerfOptions): Unit = { 157 | 158 | if (options.files.isEmpty) { 159 | System.err.println(s"No file.") 160 | System.exit(-1) 161 | } 162 | options.files.foreach(f => if (!f.exists()) { 163 | System.err.println(s"File ${f.getPath} does not exist.") 164 | System.exit(-1) 165 | }) 166 | 167 | val ucxTransport = new UcxShuffleTransport(new UcxShuffleConf(sparkConf), 0) 168 | ucxTransport.init() 169 | val currentThread = Thread.currentThread() 170 | 171 | var channels = Array[FileChannel]() 172 | options.files.foreach(channels +:= new RandomAccessFile(_, "r").getChannel) 173 | 174 | ShutdownHookManager.addShutdownHook(()=>{ 175 | currentThread.interrupt() 176 | ucxTransport.close() 177 | }) 178 | 179 | for (fileIdx <- options.files.indices) { 180 | for (blockIdx <- 0 until (options.numBlocks / options.files.length)) { 181 | 182 | val blockId = UcxShuffleBockId(0, fileIdx, blockIdx) 183 | val block = new Block { 184 | private val channel = channels(fileIdx) 185 | private val fileOffset = blockIdx * options.blockSize 186 | 187 | override def getMemoryBlock: MemoryBlock = { 188 | val startTime = System.nanoTime() 189 | val memBlock = ucxTransport.hostBounceBufferMemoryPool.get(options.blockSize) 190 | val dstBuffer = UnsafeUtils.getByteBufferView(memBlock.address, options.blockSize.toInt) 191 | channel.read(dstBuffer, fileOffset) 192 | logTrace(s"Read $blockId block of size: ${options.blockSize} in ${System.nanoTime() - startTime} ns") 193 | memBlock 194 | } 195 | 196 | override def getSize: Long = options.blockSize 197 | 198 | override def getBlock(byteBuffer: ByteBuffer): Unit = { 199 | channel.read(byteBuffer, fileOffset) 200 | } 201 | } 202 | ucxTransport.register(blockId, block) 203 | } 204 | } 205 | while (!Thread.currentThread().isInterrupted) { 206 | Thread.sleep(10000) 207 | } 208 | } 209 | 210 | def start(): Unit = { 211 | val perfOptions = parseOptions(args) 212 | 213 | if (perfOptions.remoteAddress != null) { 214 | startClient(perfOptions) 215 | } else { 216 | startServer(perfOptions) 217 | } 218 | } 219 | 220 | start() 221 | } 222 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
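
UcxPerfBenchmark above is a plain main object, so it can be launched directly once the sparkucx jar, its dependencies and the Spark jars are on the classpath. The command lines below are only an illustrative sketch: the classpath, file paths and port are assumptions, and the server's listener port is pinned through the spark.shuffle.ucx.listener.sockaddr property from UcxShuffleConf.

  # server side: register blocks from the given files and wait for clients
  java -Dspark.shuffle.ucx.listener.sockaddr=0.0.0.0:55555 -cp sparkucx.jar:<spark-jars> \
    org.apache.spark.shuffle.ucx.perf.UcxPerfBenchmark -f /data/f0,/data/f1 -n 2000 -s 1m

  # client side: fetch the same blocks with 8 outstanding requests, 2 threads, random order
  java -cp sparkucx.jar:<spark-jars> org.apache.spark.shuffle.ucx.perf.UcxPerfBenchmark \
    -a <server-host>:55555 -f /data/f0,/data/f1 -n 2000 -s 1m -o 8 -t 2 -r
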
4 | */ 5 | package org.apache.spark.shuffle.ucx.rpc 6 | 7 | import org.openucx.jucx.ucp.{UcpAmData, UcpConstants, UcpEndpoint, UcpWorker} 8 | import org.openucx.jucx.ucs.UcsConstants 9 | import org.apache.spark.internal.Logging 10 | import org.apache.spark.shuffle.ucx.UcxShuffleTransport 11 | import org.apache.spark.shuffle.utils.UnsafeUtils 12 | import org.apache.spark.util.ThreadUtils 13 | 14 | class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransport) 15 | extends Thread with Logging { 16 | setDaemon(true) 17 | setName("Global worker progress thread") 18 | 19 | private val replyWorkersThreadPool = ThreadUtils.newDaemonFixedThreadPool(transport.ucxShuffleConf.numListenerThreads, 20 | "UcxListenerThread") 21 | 22 | // Main RPC thread. Submit each RPC request to separate thread and send reply back from separate worker. 23 | globalWorker.setAmRecvHandler(0, (headerAddress: Long, headerSize: Long, amData: UcpAmData, _: UcpEndpoint) => { 24 | val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt) 25 | val replyTag = header.getInt 26 | val replyExecutor = header.getLong 27 | replyWorkersThreadPool.submit(new Runnable { 28 | override def run(): Unit = { 29 | transport.handleFetchBlockRequest(replyTag, amData, replyExecutor) 30 | } 31 | }) 32 | UcsConstants.STATUS.UCS_INPROGRESS 33 | }, UcpConstants.UCP_AM_FLAG_PERSISTENT_DATA | UcpConstants.UCP_AM_FLAG_WHOLE_MSG ) 34 | 35 | 36 | // AM to get worker address for client worker and connect server workers to it 37 | globalWorker.setAmRecvHandler(1, (headerAddress: Long, headerSize: Long, amData: UcpAmData, 38 | _: UcpEndpoint) => { 39 | val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt) 40 | val executorId = header.getLong 41 | val workerAddress = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) 42 | transport.connectServerWorkers(executorId, workerAddress) 43 | UcsConstants.STATUS.UCS_OK 44 | }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG) 45 | 46 | override def run(): Unit = { 47 | if (transport.ucxShuffleConf.useWakeup) { 48 | while (!isInterrupted) { 49 | if (globalWorker.progress() == 0) { 50 | globalWorker.waitForEvents() 51 | } 52 | } 53 | } else { 54 | while (!isInterrupted) { 55 | globalWorker.progress() 56 | } 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.ucx.rpc 6 | 7 | import scala.collection.immutable.HashMap 8 | import scala.collection.mutable 9 | 10 | import org.apache.spark.internal.Logging 11 | import org.apache.spark.rpc._ 12 | import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} 13 | import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer 14 | 15 | class UcxDriverRpcEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { 16 | 17 | private val endpoints: mutable.Set[RpcEndpointRef] = mutable.HashSet.empty 18 | private var executorToWorkerAddress = HashMap.empty[Long, SerializableDirectBuffer] 19 | 20 | 21 | override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { 22 | case message@ExecutorAdded(executorId: Long, endpoint: RpcEndpointRef, 23 | ucxWorkerAddress: SerializableDirectBuffer) => { 24 | // Driver receives a message from an executor with its workerAddress 25 | // 1. Introduce existing members of a cluster 26 | logDebug(s"Received $message") 27 | if (executorToWorkerAddress.nonEmpty) { 28 | val msg = IntroduceAllExecutors(executorToWorkerAddress) 29 | logDebug(s"replying $msg to $executorId") 30 | context.reply(msg) 31 | } 32 | executorToWorkerAddress += executorId -> ucxWorkerAddress 33 | // 2. For each existing member, introduce the newly joined executor. 34 | endpoints.foreach(ep => { 35 | logDebug(s"Sending $message to $ep") 36 | ep.send(message) 37 | }) 38 | logDebug(s"Connecting back to address: ${context.senderAddress}") 39 | endpoints.add(endpoint) 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.ucx.rpc 6 | 7 | import org.apache.spark.internal.Logging 8 | import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} 9 | import org.apache.spark.shuffle.ucx.UcxShuffleTransport 10 | import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} 11 | import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer 12 | 13 | import java.util.concurrent.ExecutorService 14 | 15 | class UcxExecutorRpcEndpoint(override val rpcEnv: RpcEnv, transport: UcxShuffleTransport, 16 | executorService: ExecutorService) 17 | extends RpcEndpoint with Logging { 18 | 19 | override def receive: PartialFunction[Any, Unit] = { 20 | case ExecutorAdded(executorId: Long, _: RpcEndpointRef, 21 | ucxWorkerAddress: SerializableDirectBuffer) => 22 | logDebug(s"Received ExecutorAdded($executorId)") 23 | executorService.submit(new Runnable() { 24 | override def run(): Unit = { 25 | transport.addExecutor(executorId, ucxWorkerAddress.value) 26 | } 27 | }) 28 | case IntroduceAllExecutors(executorIdToWorkerAdresses: Map[Long, SerializableDirectBuffer]) => 29 | logDebug(s"Received IntroduceAllExecutors(${executorIdToWorkerAdresses.keys.mkString(",")})") 30 | executorService.submit(new Runnable() { 31 | override def run(): Unit = { 32 | transport.addExecutors(executorIdToWorkerAdresses) 33 | transport.preConnect() 34 | } 35 | }) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx.rpc 6 | 7 | import org.apache.spark.rpc.RpcEndpointRef 8 | import org.apache.spark.shuffle.ucx.BlockId 9 | import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer 10 | 11 | object UcxRpcMessages { 12 | /** 13 | * Called from executor to driver, to introduce ucx worker address. 14 | */ 15 | case class ExecutorAdded(executorId: Long, endpoint: RpcEndpointRef, 16 | ucxWorkerAddress: SerializableDirectBuffer) 17 | 18 | /** 19 | * Reply from driver with all executors in the cluster with their worker addresses. 20 | */ 21 | case class IntroduceAllExecutors(executorIdToAddress: Map[Long, SerializableDirectBuffer]) 22 | } 23 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/utils/SerializableDirectBuffer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 4 | */ 5 | package org.apache.spark.shuffle.ucx.utils 6 | 7 | import java.io.{EOFException, ObjectInputStream, ObjectOutputStream} 8 | import java.net.InetSocketAddress 9 | import java.nio.ByteBuffer 10 | import java.nio.channels.Channels 11 | import java.nio.charset.StandardCharsets 12 | 13 | import org.apache.spark.internal.Logging 14 | import org.apache.spark.util.Utils 15 | 16 | /** 17 | * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make 18 | * it easier to pass ByteBuffers in case class messages. 
19 | */ 20 | class SerializableDirectBuffer(@transient var buffer: ByteBuffer) extends Serializable 21 | with Logging { 22 | 23 | def value: ByteBuffer = buffer 24 | 25 | private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { 26 | val length = in.readInt() 27 | buffer = ByteBuffer.allocateDirect(length) 28 | var amountRead = 0 29 | val channel = Channels.newChannel(in) 30 | while (amountRead < length) { 31 | val ret = channel.read(buffer) 32 | if (ret == -1) { 33 | throw new EOFException("End of file before fully reading buffer") 34 | } 35 | amountRead += ret 36 | } 37 | buffer.rewind() // Allow us to read it later 38 | } 39 | 40 | private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { 41 | out.writeInt(buffer.limit()) 42 | buffer.rewind() 43 | while (buffer.position() < buffer.limit()) { 44 | out.write(buffer.get()) 45 | } 46 | buffer.rewind() // Allow us to write it again later 47 | } 48 | } 49 | 50 | class DeserializableToExternalMemoryBuffer(@transient var buffer: ByteBuffer)() extends Serializable 51 | with Logging { 52 | 53 | def value: ByteBuffer = buffer 54 | 55 | private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { 56 | val length = in.readInt() 57 | var amountRead = 0 58 | val channel = Channels.newChannel(in) 59 | while (amountRead < length) { 60 | val ret = channel.read(buffer) 61 | if (ret == -1) { 62 | throw new EOFException("End of file before fully reading buffer") 63 | } 64 | amountRead += ret 65 | } 66 | buffer.rewind() // Allow us to read it later 67 | } 68 | } 69 | 70 | 71 | object SerializationUtils { 72 | 73 | def deserializeInetAddress(workerAddress: ByteBuffer): InetSocketAddress = { 74 | val address = workerAddress.duplicate() 75 | address.rewind() 76 | val port = address.getInt() 77 | val host = StandardCharsets.UTF_8.decode(address.slice()).toString 78 | new InetSocketAddress(host, port) 79 | } 80 | 81 | def serializeInetAddress(address: InetSocketAddress): ByteBuffer = { 82 | val hostAddress = new InetSocketAddress(Utils.localCanonicalHostName(), address.getPort) 83 | val hostString = hostAddress.getHostName.getBytes(StandardCharsets.UTF_8) 84 | val result = ByteBuffer.allocateDirect(hostString.length + 4) 85 | result.putInt(hostAddress.getPort) 86 | result.put(hostString) 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/shuffle/utils/UnsafeUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2022, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. 3 | * See file LICENSE for terms. 
4 | */ 5 | package org.apache.spark.shuffle.utils 6 | 7 | import java.lang.reflect.InvocationTargetException 8 | import java.nio.ByteBuffer 9 | import java.nio.channels.FileChannel 10 | 11 | import org.openucx.jucx.UcxException 12 | import sun.nio.ch.{DirectBuffer, FileChannelImpl} 13 | import org.apache.spark.internal.Logging 14 | 15 | object UnsafeUtils extends Logging { 16 | val INT_SIZE: Int = 4 17 | val LONG_SIZE: Int = 8 18 | 19 | private val mmap = classOf[FileChannelImpl].getDeclaredMethod("map0", classOf[Int], classOf[Long], classOf[Long]) 20 | mmap.setAccessible(true) 21 | 22 | private val unmmap = classOf[FileChannelImpl].getDeclaredMethod("unmap0", classOf[Long], classOf[Long]) 23 | unmmap.setAccessible(true) 24 | 25 | private val classDirectByteBuffer = Class.forName("java.nio.DirectByteBuffer") 26 | private val directBufferConstructor = classDirectByteBuffer.getDeclaredConstructor(classOf[Long], classOf[Int]) 27 | directBufferConstructor.setAccessible(true) 28 | 29 | def getByteBufferView(address: Long, length: Int): ByteBuffer with DirectBuffer = { 30 | directBufferConstructor.newInstance(address.asInstanceOf[Object], length.asInstanceOf[Object]) 31 | .asInstanceOf[ByteBuffer with DirectBuffer] 32 | } 33 | 34 | def getAdress(buffer: ByteBuffer): Long = { 35 | buffer.asInstanceOf[sun.nio.ch.DirectBuffer].address 36 | } 37 | 38 | def mmap(fileChannel: FileChannel, offset: Long, length: Long): Long = { 39 | try { 40 | mmap.invoke(fileChannel, 1.asInstanceOf[Object], offset.asInstanceOf[Object], length.asInstanceOf[Object]) 41 | .asInstanceOf[Long] 42 | } catch { 43 | case e: Exception => 44 | logError(s"Failed to mmap (${fileChannel.size()} $offset $length): $e") 45 | throw new UcxException(e.getMessage) 46 | } 47 | } 48 | 49 | def munmap(address: Long, length: Long): Unit = { 50 | try { 51 | unmmap.invoke(null, address.asInstanceOf[Object], length.asInstanceOf[Object]) 52 | } catch { 53 | case e@(_: IllegalAccessException | _: InvocationTargetException) => 54 | logError(e.getMessage) 55 | } 56 | } 57 | 58 | } 59 | --------------------------------------------------------------------------------
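
The following is a minimal usage sketch, not a file from this repository, showing how the address layout built in UcxPerfBenchmark.startClient (a 4-byte port followed by the UTF-8 bytes of the host string) can be decoded with SerializationUtils.deserializeInetAddress from SerializableDirectBuffer.scala above. The object name AddressRoundTripExample and the standalone App entry point are illustrative assumptions, not part of sparkucx.

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

import org.apache.spark.shuffle.ucx.utils.SerializationUtils

// Hypothetical example object, not part of the sparkucx sources.
object AddressRoundTripExample extends App {
  val original = new InetSocketAddress("localhost", 12345)

  // Encode the address the same way UcxPerfBenchmark.startClient does:
  // a 4-byte port followed by the UTF-8 bytes of the host string.
  val hostBytes = original.getHostString.getBytes(StandardCharsets.UTF_8)
  val encoded = ByteBuffer.allocateDirect(hostBytes.length + 4)
  encoded.putInt(original.getPort)
  encoded.put(hostBytes)

  // deserializeInetAddress duplicates and rewinds the buffer internally,
  // so it can be passed with its position left at the end, as above.
  val decoded = SerializationUtils.deserializeInetAddress(encoded)
  assert(decoded.getPort == original.getPort)
  println(s"decoded address: $decoded")
}

SerializationUtils.serializeInetAddress in the same file emits this layout as well, and UcxPerfBenchmark.startClient builds it by hand before calling ucxTransport.addExecutor, so the round trip above mirrors the buffers shown in this section.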