├── .github └── workflows │ ├── ISSUE_TEMPLATE.md │ ├── PULL_REQUEST_TEMPLATE.md │ ├── check_label.yml │ ├── pull_request.yml │ ├── release.yml │ └── snapshot.yml ├── .gitignore ├── .scalafmt.conf ├── .travis.yml ├── LICENSE ├── LICENSES └── Apache-2.0.txt ├── README.md ├── README_CN.md ├── codecov.yml ├── example ├── .gitignore ├── pom.xml └── src │ └── main │ ├── resources │ ├── data.csv │ ├── edge │ ├── log4j.properties │ ├── ssl │ │ ├── casigned.crt │ │ ├── casigned.key │ │ ├── casigned.pem │ │ ├── selfsigned.key │ │ ├── selfsigned.password │ │ └── selfsigned.pem │ └── vertex │ └── scala │ └── com │ └── vesoft │ └── nebula │ └── examples │ └── connector │ ├── NebulaSparkReaderExample.scala │ └── NebulaSparkWriterExample.scala ├── nebula-spark-common ├── pom.xml └── src │ ├── main │ └── scala │ │ └── com │ │ └── vesoft │ │ └── nebula │ │ └── connector │ │ ├── NebulaConfig.scala │ │ ├── NebulaEnum.scala │ │ ├── NebulaOptions.scala │ │ ├── NebulaUtils.scala │ │ ├── PartitionUtils.scala │ │ ├── Template.scala │ │ ├── exception │ │ └── Exception.scala │ │ ├── nebula │ │ ├── GraphProvider.scala │ │ └── MetaProvider.scala │ │ ├── package.scala │ │ ├── reader │ │ └── NebulaReader.scala │ │ ├── ssl │ │ ├── SSLEnum.scala │ │ └── SSLSignParams.scala │ │ ├── utils │ │ ├── AddressCheckUtil.scala │ │ └── SparkValidate.scala │ │ └── writer │ │ └── NebulaExecutor.scala │ └── test │ ├── resources │ └── log4j.properties │ └── scala │ └── com │ └── vesoft │ └── nebula │ └── connector │ ├── AddressCheckUtilsSuite.scala │ ├── DataTypeEnumSuite.scala │ ├── NebulaConfigSuite.scala │ ├── NebulaUtilsSuite.scala │ ├── PartitionUtilsSuite.scala │ ├── mock │ └── NebulaGraphMock.scala │ └── nebula │ ├── GraphProviderTest.scala │ └── MetaProviderTest.scala ├── nebula-spark-connector ├── .gitignore ├── pom.xml └── src │ ├── main │ └── scala │ │ └── com │ │ └── vesoft │ │ └── nebula │ │ └── connector │ │ ├── NebulaDataSource.scala │ │ ├── package.scala │ │ ├── reader │ │ ├── NebulaEdgePartitionReader.scala │ │ ├── NebulaNgqlEdgePartitionReader.scala │ │ ├── NebulaPartition.scala │ │ ├── NebulaPartitionReader.scala │ │ ├── NebulaSourceReader.scala │ │ └── NebulaVertexPartitionReader.scala │ │ └── writer │ │ ├── NebulaCommitMessage.scala │ │ ├── NebulaEdgeWriter.scala │ │ ├── NebulaSourceWriter.scala │ │ ├── NebulaVertexWriter.scala │ │ └── NebulaWriter.scala │ └── test │ ├── resources │ ├── docker-compose.yaml │ ├── edge.csv │ ├── log4j.properties │ └── vertex.csv │ └── scala │ └── com │ └── vesoft │ └── nebula │ └── connector │ ├── SparkVersionValidateSuite.scala │ ├── mock │ ├── NebulaGraphMock.scala │ └── SparkMock.scala │ ├── reader │ └── ReadSuite.scala │ └── writer │ ├── NebulaExecutorSuite.scala │ ├── WriteDeleteSuite.scala │ └── WriteInsertSuite.scala ├── nebula-spark-connector_2.2 ├── .gitignore ├── pom.xml └── src │ ├── main │ └── scala │ │ └── com │ │ └── vesoft │ │ └── nebula │ │ └── connector │ │ ├── NebulaDataSource.scala │ │ ├── package.scala │ │ ├── reader │ │ ├── NebulaEdgePartitionReader.scala │ │ ├── NebulaIterator.scala │ │ ├── NebulaNgqlEdgeReader.scala │ │ ├── NebulaNgqlRDD.scala │ │ ├── NebulaRDD.scala │ │ ├── NebulaRelation.scala │ │ ├── NebulaRelationProvider.scala │ │ └── NebulaVertexPartitionReader.scala │ │ └── writer │ │ ├── NebulaCommitMessage.scala │ │ ├── NebulaEdgeWriter.scala │ │ ├── NebulaInsertableRelation.scala │ │ ├── NebulaVertexWriter.scala │ │ ├── NebulaWriter.scala │ │ └── NebulaWriterResultRelation.scala │ └── test │ ├── resources │ ├── edge.csv │ ├── log4j.properties │ └── 
vertex.csv │ └── scala │ └── com │ └── vesoft │ └── nebula │ └── connector │ ├── SparkVersionValidateSuite.scala │ ├── mock │ ├── NebulaGraphMock.scala │ └── SparkMock.scala │ ├── reader │ └── ReadSuite.scala │ └── writer │ ├── WriteDeleteSuite.scala │ └── WriteInsertSuite.scala ├── nebula-spark-connector_3.0 ├── .gitignore ├── pom.xml └── src │ ├── main │ └── scala │ │ └── com │ │ └── vesoft │ │ └── nebula │ │ └── connector │ │ ├── NebulaDataSource.scala │ │ ├── NebulaTable.scala │ │ ├── package.scala │ │ ├── reader │ │ ├── NebulaEdgePartitionReader.scala │ │ ├── NebulaNgqlEdgePartitionReader.scala │ │ ├── NebulaPartitionReader.scala │ │ ├── NebulaPartitionReaderFactory.scala │ │ ├── NebulaVertexPartitionReader.scala │ │ └── SimpleScanBuilder.scala │ │ ├── utils │ │ └── Validations.scala │ │ └── writer │ │ ├── NebulaCommitMessage.scala │ │ ├── NebulaEdgeWriter.scala │ │ ├── NebulaSourceWriter.scala │ │ ├── NebulaVertexWriter.scala │ │ ├── NebulaWriter.scala │ │ └── NebulaWriterBuilder.scala │ └── test │ ├── resources │ ├── docker-compose.yaml │ ├── edge.csv │ ├── log4j.properties │ └── vertex.csv │ └── scala │ └── com │ └── vesoft │ └── nebula │ └── connector │ ├── SparkVersionValidateSuite.scala │ ├── mock │ ├── NebulaGraphMock.scala │ └── SparkMock.scala │ ├── reader │ └── ReadSuite.scala │ └── writer │ ├── NebulaExecutorSuite.scala │ ├── WriteDeleteSuite.scala │ └── WriteInsertSuite.scala └── pom.xml /.github/workflows/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | #### Expected behavior 2 | 3 | #### Actual behavior 4 | 5 | #### Steps to reproduce 6 | 7 | #### JVM version (e.g. `java -version`) 8 | 9 | #### Scala version (e.g. `scala -version`) 10 | 11 | #### OS version (e.g. `uname -a`) 12 | -------------------------------------------------------------------------------- /.github/workflows/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Motivation: 2 | 3 | Why you're making that change and what is the problem you're trying to solve. 4 | 5 | Modification: 6 | 7 | Describe the modifications you've done. 8 | 9 | Result: 10 | 11 | Fixes #. 
12 | -------------------------------------------------------------------------------- /.github/workflows/check_label.yml: -------------------------------------------------------------------------------- 1 | name: Auto label 2 | 3 | on: 4 | issues: 5 | types: 6 | - reopened 7 | - opened 8 | - labeled 9 | - unlabeled 10 | - closed 11 | 12 | env: 13 | GH_PAT: ${{ secrets.GITHUB_TOKEN }} 14 | EVENT: ${{ toJSON(github.event)}} 15 | EVENT_NAME: ${{ github.event_name}} 16 | 17 | jobs: 18 | sync: 19 | name: auto label 20 | runs-on: ubuntu-latest 21 | steps: 22 | - uses: HarrisChu/auto_label@v1 23 | -------------------------------------------------------------------------------- /.github/workflows/pull_request.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a Java project with Maven 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/building-and-testing-java-with-maven 3 | 4 | name: pull_request 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: 11 | - master 12 | - 'v[0-9]+.*' 13 | 14 | jobs: 15 | build: 16 | 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Set up JDK 1.8 22 | uses: actions/setup-java@v4 23 | with: 24 | distribution: "temurin" 25 | java-version: "8" 26 | 27 | - name: Cache the Maven packages to speed up build 28 | uses: actions/cache@v2 29 | with: 30 | path: ~/.m2/repository 31 | key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} 32 | restore-keys: ${{ runner.os }}-maven- 33 | 34 | - name: Install nebula-graph 35 | run: | 36 | mkdir tmp 37 | pushd tmp 38 | git clone https://github.com/vesoft-inc/nebula-docker-compose.git 39 | pushd nebula-docker-compose/ 40 | cp ../../nebula-spark-connector/src/test/resources/docker-compose.yaml . 41 | docker compose up -d 42 | sleep 30 43 | docker compose ps 44 | popd 45 | popd 46 | 47 | - name: Build with Maven 48 | run: | 49 | mvn clean package -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2 50 | mvn clean package -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4 51 | mvn clean package -pl nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0 52 | 53 | 54 | - uses: codecov/codecov-action@v2 55 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a Java project with Maven 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/building-and-testing-java-with-maven 3 | 4 | name: release 5 | 6 | on: 7 | release: 8 | types: published 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up JDK 1.8 18 | uses: actions/setup-java@v4 19 | with: 20 | distribution: "temurin" 21 | java-version: "8" 22 | 23 | - name: Cache the Maven packages to speed up build 24 | uses: actions/cache@v2 25 | with: 26 | path: ~/.m2/repository 27 | key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} 28 | restore-keys: ${{ runner.os }}-maven- 29 | 30 | - name: Install nebula-graph 31 | run: | 32 | mkdir tmp 33 | pushd tmp 34 | git clone https://github.com/vesoft-inc/nebula-docker-compose.git 35 | pushd nebula-docker-compose/ 36 | cp ../../nebula-spark-connector/src/test/resources/docker-compose.yaml . 
37 | docker compose up -d 38 | sleep 30 39 | popd 40 | popd 41 | 42 | - name: Deploy release for spark2.4 to Maven 43 | uses: samuelmeuli/action-maven-publish@v1 44 | with: 45 | gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} 46 | gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }} 47 | nexus_username: ${{ secrets.OSSRH_USERNAME }} 48 | nexus_password: ${{ secrets.OSSRH_TOKEN }} 49 | maven_args: -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4 50 | 51 | - name: Deploy release for spark2.2 to Maven 52 | uses: samuelmeuli/action-maven-publish@v1 53 | with: 54 | gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} 55 | gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }} 56 | nexus_username: ${{ secrets.OSSRH_USERNAME }} 57 | nexus_password: ${{ secrets.OSSRH_TOKEN }} 58 | maven_args: -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2 59 | 60 | - name: Deploy release for spark3.0 to Maven 61 | uses: samuelmeuli/action-maven-publish@v1 62 | with: 63 | gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} 64 | gpg_passphrase: "" 65 | nexus_username: ${{ secrets.OSSRH_USERNAME }} 66 | nexus_password: ${{ secrets.OSSRH_TOKEN }} 67 | maven_args: -pl nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0 68 | -------------------------------------------------------------------------------- /.github/workflows/snapshot.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a Java project with Maven 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/building-and-testing-java-with-maven 3 | 4 | name: snapshot 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | schedule: 10 | - cron: '0 6 * * *' 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up JDK 1.8 20 | uses: actions/setup-java@v4 21 | with: 22 | distribution: "temurin" 23 | java-version: "8" 24 | 25 | - name: Cache the Maven packages to speed up build 26 | uses: actions/cache@v2 27 | with: 28 | path: ~/.m2/repository 29 | key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} 30 | restore-keys: ${{ runner.os }}-maven- 31 | 32 | - name: Install nebula-graph 33 | run: | 34 | mkdir tmp 35 | pushd tmp 36 | git clone https://github.com/vesoft-inc/nebula-docker-compose.git 37 | pushd nebula-docker-compose/ 38 | cp ../../nebula-spark-connector/src/test/resources/docker-compose.yaml . 
39 | docker compose up -d 40 | sleep 30 41 | popd 42 | popd 43 | 44 | - name: Deploy SNAPSHOT for spark2.4 to Sonatype 45 | uses: samuelmeuli/action-maven-publish@v1 46 | with: 47 | gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} 48 | gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }} 49 | nexus_username: ${{ secrets.OSSRH_USERNAME }} 50 | nexus_password: ${{ secrets.OSSRH_TOKEN }} 51 | maven_args: -pl nebula-spark-connector -am -Pscala-2.11 -Pspark-2.4 52 | 53 | - name: Deploy SNAPSHOT for spark2.2 to Sonatype 54 | uses: samuelmeuli/action-maven-publish@v1 55 | with: 56 | gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} 57 | gpg_passphrase: "" 58 | nexus_username: ${{ secrets.OSSRH_USERNAME }} 59 | nexus_password: ${{ secrets.OSSRH_TOKEN }} 60 | maven_args: -pl nebula-spark-connector_2.2 -am -Pscala-2.11 -Pspark-2.2 61 | 62 | - name: Deploy SNAPSHOT for spark3.0 to Sonatype 63 | uses: samuelmeuli/action-maven-publish@v1 64 | with: 65 | gpg_private_key: ${{ secrets.JAVA_GPG_PRIVATE_KEY }} 66 | gpg_passphrase: ${{ secrets.JAVA_GPG_PASSPHRASE }} 67 | nexus_username: ${{ secrets.OSSRH_USERNAME }} 68 | nexus_password: ${{ secrets.OSSRH_TOKEN }} 69 | maven_args: -pl nebula-spark-connector_3.0 -am -Pscala-2.12 -Pspark-3.0 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled class file 2 | *.class 3 | 4 | # Log file 5 | *.log 6 | 7 | # BlueJ files 8 | *.ctxt 9 | 10 | # Mobile Tools for Java (J2ME) 11 | .mtj.tmp/ 12 | 13 | # Package Files # 14 | *.jar 15 | *.war 16 | *.nar 17 | *.ear 18 | *.zip 19 | *.tar.gz 20 | *.rar 21 | 22 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 23 | hs_err_pid* 24 | 25 | # build target 26 | target/ 27 | 28 | # IDE 29 | .idea/ 30 | .eclipse/ 31 | *.iml 32 | 33 | spark-importer.ipr 34 | spark-importer.iws 35 | 36 | # mac 37 | .DS_Store 38 | -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | align = more 2 | maxColumn = 100 3 | docstrings = ScalaDoc 4 | assumeStandardLibraryStripMargin = true -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 vesoft inc. All rights reserved. 2 | # 3 | # This source code is licensed under Apache 2.0 License. 4 | 5 | language: java 6 | 7 | jdk: 8 | - oraclejdk11 9 | - openjdk8 10 | - openjdk11 11 | 12 | install: mvn clean compile package install -Dgpg.skip -Dmaven.javadoc.skip=true 13 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 vesoft inc. All rights reserved. 2 | # 3 | # This source code is licensed under Apache 2.0 License. 
4 | 5 | # For more configuration details: 6 | # https://docs.codecov.io/docs/codecov-yaml 7 | 8 | # validate the configuration: 9 | # curl -X POST --data-binary @codecov.yml https://codecov.io/validate 10 | 11 | codecov: 12 | require_ci_to_pass: false 13 | 14 | -------------------------------------------------------------------------------- /example/.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled class file 2 | *.class 3 | 4 | # Log file 5 | *.log 6 | 7 | # BlueJ files 8 | *.ctxt 9 | 10 | # Mobile Tools for Java (J2ME) 11 | .mtj.tmp/ 12 | 13 | # Package Files # 14 | *.jar 15 | *.war 16 | *.nar 17 | *.ear 18 | *.zip 19 | *.tar.gz 20 | *.rar 21 | 22 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 23 | hs_err_pid* 24 | 25 | # build target 26 | target/ 27 | 28 | # IDE 29 | .idea/ 30 | .eclipse/ 31 | *.iml 32 | 33 | spark-importer.ipr 34 | spark-importer.iws 35 | 36 | .DS_Store 37 | -------------------------------------------------------------------------------- /example/src/main/resources/data.csv: -------------------------------------------------------------------------------- 1 | id,col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13 2 | 1,Tom,tom,10,20,30,40,2021-01-27,2021-01-01T12:10:10,43535232,true,1.0,2.0,10:10:10 3 | 2,Jina,Jina,11,21,31,41,2021-01-28,2021-01-02T12:10:10,43535232,false,1.1,2.1,11:10:10 4 | 3,Tim,Tim,12,22,32,42,2021-01-29,2021-01-03T12:10:10,43535232,false,1.2,2.2,12:10:10 5 | 4,张三,张三,13,23,33,43,2021-01-30,2021-01-04T12:10:10,43535232,true,1.3,2.3,13:10:10 6 | 5,李四,李四,14,24,34,44,2021-02-01,2021-01-05T12:10:10,43535232,false,1.4,2.4,14:10:10 7 | 6,王五,王五,15,25,35,45,2021-02-02,2021-01-06T12:10:10,0,false,1.5,2.5,15:10:10 8 | 7,Jina,Jina,16,26,36,46,2021-02-03,2021-01-07T12:10:10,43535232,true,1.6,2.6,16:10:10 9 | 8,Jina,Jina,17,27,37,47,2021-02-04,2021-01-08T12:10:10,43535232,false,1.7,2.7,17:10:10 10 | 9,Jina,Jina,18,28,38,48,2021-02-05,2021-01-09T12:10:10,43535232,true,1.8,2.8,18:10:10 11 | 10,Jina,Jina,19,29,39,49,2021-02-06,2021-01-10T12:10:10,43535232,false,1.9,2.9,19:10:10 12 | -1,Jina,Jina,20,30,40,50,2021-02-07,2021-02-11T12:10:10,43535232,false,2.0,3.0,20:10:10 13 | -2,Jina,Jina,21,31,41,51,2021-02-08,2021-03-12T12:10:10,43535232,false,2.1,3.1,21:10:10 14 | -3,Jina,Jina,22,32,42,52,2021-02-09,2021-04-13T12:10:10,43535232,false,2.2,3.2,22:10:10 15 | -------------------------------------------------------------------------------- /example/src/main/resources/edge: -------------------------------------------------------------------------------- 1 | {"src":12345,"dst":23456,"degree":34, "descr": "aaa","timep": "2020-01-01"} 2 | {"src":11111,"dst":22222,"degree":33, "descr": "aaa","timep": "2020-01-01"} 3 | {"src":11111,"dst":33333,"degree":32, "descr": "a\baa","timep": "2020-01-01"} 4 | {"src":11111,"dst":44444,"degree":31, "descr": "aaa","timep": "2020-01-01"} 5 | {"src":22222,"dst":55555,"degree":30, "descr": "a\naa","timep": "2020-01-01"} 6 | {"src":33333,"dst":44444,"degree":29, "descr": "aaa","timep": "2020-01-01"} 7 | {"src":33333,"dst":55555,"degree":28, "descr": "aa\ta","timep": "2020-01-01"} 8 | {"src":44444,"dst":22222,"degree":27, "descr": "aaa","timep": "2020-01-01"} 9 | {"src":44444,"dst":55555,"degree":26, "descr": "aaa","timep": "2020-01-01"} 10 | {"src":22222,"dst":66666,"degree":25, "descr": "aaa","timep": "2020-01-01"} -------------------------------------------------------------------------------- 
/example/src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Global logging configuration 2 | log4j.rootLogger=INFO, stdout 3 | # Console output... 4 | log4j.appender.stdout=org.apache.log4j.ConsoleAppender 5 | log4j.appender.stdout.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.stdout.layout.ConversionPattern=%5p [%t] - %m%n 7 | -------------------------------------------------------------------------------- /example/src/main/resources/ssl/casigned.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIICljCCAX4CCQC9uuUY+ah8qzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJD 3 | TjAeFw0yMTA5MjkwNzM4MDRaFw0yNDAxMDIwNzM4MDRaMA0xCzAJBgNVBAYTAkNO 4 | MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuo7hKpcs+VQKbGRq0fUL 5 | +GcSfPfJ8mARtIeI8WfU0j1vI5KNujI//G2olOGEueDCw4OO0UbdjnsFpgj2awAo 6 | rj4ga2W6adQHK8qHY6q/Rdqv0oDCrcePMtQ8IwbFjNWOXC4bn7GcV7mzOkigdcj8 7 | UPkSeaqI9XxBRm3OoDX+T8h6cDLrm+ncKB8KKe/QApKH4frV3HYDqGtN49zuRs6F 8 | iurFbXDGVAZEdFEJl38IQJdmE2ASOzEHZbxWKzO/DZr/Z2+L1CuycZIwuITcnddx 9 | b2Byx/opwX4HlyODeUBbyDp+hd+GkasmIcpOlIDw9OXIvrcajKvzLEbqGt2ThsxX 10 | QwIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQAxzxtbYBQ2WgBGrpzOX4TxsuSaigqo 11 | YJ5zbVEHtwbsbBTZ7UJvRc9IyhrOL5Ui4PJI85chh1GpGqOmMoYSaWdddaIroilQ 12 | 56bn5haB8ezAMnLXbPuf97UENO0RIkyzt63XPIUkDnwlzOukIq50qgsYEDuiioM/ 13 | wpCqSbMJ4iK/SlSSUWw3cKuAHvFfLv7hkC6AhvT7yfaCNDs29xEQUCD12XlIdFGH 14 | FjMgVMcvcIePQq5ZcmSfVMge9jPjPx/Nj9SVauF5z5pil9qHG4jyXPGThiiJ3CE4 15 | GU5d/Qfe7OeiYI3LaoVufZ5pZnR9nMnpzqU46w9gY7vgi6bAhNwsCDr3 16 | -----END CERTIFICATE----- 17 | -------------------------------------------------------------------------------- /example/src/main/resources/ssl/casigned.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpAIBAAKCAQEAuo7hKpcs+VQKbGRq0fUL+GcSfPfJ8mARtIeI8WfU0j1vI5KN 3 | ujI//G2olOGEueDCw4OO0UbdjnsFpgj2awAorj4ga2W6adQHK8qHY6q/Rdqv0oDC 4 | rcePMtQ8IwbFjNWOXC4bn7GcV7mzOkigdcj8UPkSeaqI9XxBRm3OoDX+T8h6cDLr 5 | m+ncKB8KKe/QApKH4frV3HYDqGtN49zuRs6FiurFbXDGVAZEdFEJl38IQJdmE2AS 6 | OzEHZbxWKzO/DZr/Z2+L1CuycZIwuITcnddxb2Byx/opwX4HlyODeUBbyDp+hd+G 7 | kasmIcpOlIDw9OXIvrcajKvzLEbqGt2ThsxXQwIDAQABAoIBAH4SEBe4EaxsHp8h 8 | PQ6linFTNis9SDuCsHRPIzv/7tIksfZYE27Ahn0Pndz+ibMTMIrvXJQQT6j5ede6 9 | NswYT2Vwlnf9Rvw9TJtLQjMYMCoEnsyiNu047oxq4DjLWrTRnGKuxfwlCoI9++Bn 10 | NAhkyh3uM44EsIk0bugpTHj4A+PlbUPe7xdEI/6XpaZrRN9oiejJ4VxZAPgFGiTm 11 | uNF5qg16+0900Pfj5Y/M4vXmn+gq39PO/y0FlTpaoEuYZiZZS3xHGmSVhlt8LIgI 12 | 8MdMRaKTfNeNITaqgOWh9pAW4xmK48/KfLgNPQgtDHjMJpgM0BbcBOayOY8Eio0x 13 | Z66G2AECgYEA9vj/8Fm3CKn/ogNOO81y9kIs0iPcbjasMnQ3UXeOdD0z0+7TM86F 14 | Xj3GK/z2ecvY7skWtO5ZUbbxp4aB7omW8Ke9+q8XPzMEmUuAOTzxQkAOxdr++HXP 15 | TILy0hNX2cmiLQT1U60KoZHzPZ5o5hNIQPMt7hN12ERWcIfR/MUZa5UCgYEAwWCP 16 | 6Y7Zso1QxQR/qfjuILET3/xU+ZmqSRDvzJPEiGI3oeWNG4L6cKR+XTe0FWZBAmVk 17 | Qq/1qXmdBnf5S7azffoJe2+H/m3kHJSprIiAAWlBN2e+kFlNfBhtkgia5NvsrjRw 18 | al6mf/+weRD1FiPoZY3e1wBKoqro7aI8fE5gwXcCgYEAnEI05OROeyvb8qy2vf2i 19 | JA8AfsBzwkPTNWT0bxX+yqrCdO/hLyEWnubk0IYPiEYibgpK1JUNbDcctErVQJBL 20 | MN5gxBAt3C2yVi8/5HcbijgvYJ3LvnYDf7xGWAYnCkOZ2XQOqC+Oz2UhijYE1rUS 21 | fQ2fXMdxQzERo8c7Y/tstvUCgYBuixy5jwezokUB20h/ieXWmmOaL00EQmutyRjM 22 | AczfigXzbp3zlDRGIEJ8V1OCyClxjTR7SstMTlENWZgRSCfjZAP3pBJBx+AW1oUI 23 | NB+4rsqxOYUeT26T+gLo8DJbkb0C+Mcqh2D22tuu2ZrBRVWceDVjAq+nvbvZ3Fxn 24 | UwbMkQKBgQCxL3aA6ART6laIxT/ZqMhV0ZcaoDJogjF+4I4bhlO4ivWGWJ4RpEDn 25 | 
ziFb6+M/4pe4vCou9yuAof6WTKM8JG4rok0yxhN3V6QGP49TjtrfkkrEPCtB2LSI 26 | N1+YRSTrS5VDcl8h8JH7fpghRnXHONEyIqasYVqsbxKzNyLV/z2rkw== 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /example/src/main/resources/ssl/casigned.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIEGzCCAwOgAwIBAgIUDcmZFpL4PcdCXfLRBK8bR2vb39cwDQYJKoZIhvcNAQEL 3 | BQAwgZwxCzAJBgNVBAYTAkNOMREwDwYDVQQIDAhaaGVqaWFuZzERMA8GA1UEBwwI 4 | SGFuZ3pob3UxFDASBgNVBAoMC1Zlc29mdCBJbmMuMRAwDgYDVQQLDAdzZWN0aW9u 5 | MRYwFAYDVQQDDA1zaHlsb2NrIGh1YW5nMScwJQYJKoZIhvcNAQkBFhhzaHlsb2Nr 6 | Lmh1YW5nQHZlc29mdC5jb20wHhcNMjEwODE5MDkyNDQ3WhcNMjUwODE4MDkyNDQ3 7 | WjCBnDELMAkGA1UEBhMCQ04xETAPBgNVBAgMCFpoZWppYW5nMREwDwYDVQQHDAhI 8 | YW5nemhvdTEUMBIGA1UECgwLVmVzb2Z0IEluYy4xEDAOBgNVBAsMB3NlY3Rpb24x 9 | FjAUBgNVBAMMDXNoeWxvY2sgaHVhbmcxJzAlBgkqhkiG9w0BCQEWGHNoeWxvY2su 10 | aHVhbmdAdmVzb2Z0LmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB 11 | AMEAgpamCQHl+8JnUHI6/VmJHjDLYJLTliN/CwpFrhMqIVjJ8wG57WYLpXpn91Lz 12 | eHu52LkVzcikybIJ2a+LOTvnhNFdbmTbqDtrb+s6wM/sO+nF6tU2Av4e5zhyKoeR 13 | LL+rHMk3nymohbdN4djySFmOOU5A1O/4b0bZz4Ylu995kUawdiaEo13BzxxOC7Ik 14 | Gge5RyDcm0uLXZqTAPy5Sjv/zpOyj0AqL1CJUH7XBN9OMRhVU0ZX9nHWl1vgLRld 15 | J6XT17Y9QbbHhCNEdAmFE5kEFgCvZc+MungUYABlkvoj86TLmC/FMV6fWdxQssyd 16 | hS+ssfJFLaTDaEFz5a/Tr48CAwEAAaNTMFEwHQYDVR0OBBYEFK0GVrQx+wX1GCHy 17 | e+6fl4X+prmYMB8GA1UdIwQYMBaAFK0GVrQx+wX1GCHye+6fl4X+prmYMA8GA1Ud 18 | EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAHqP8P+ZUHmngviHLSSN1ln5 19 | Mx4BCkVeFRUaFx0yFXytV/iLXcG2HpFg3A9rAFoYgCDwi1xpsERnBZ/ShTv/eFOc 20 | IxBY5yggx3/lGi8tAgvUdarhd7mQO67UJ0V4YU3hAkbnZ8grHHXj+4hfgUpY4ok6 21 | yaed6HXwknBb9W8N1jZI8ginhkhjaeRCHdMiF+fBvNCtmeR1bCml1Uz7ailrpcaT 22 | Mf84+5VYuFEnaRZYWFNsWNCOBlJ/6/b3V10vMXzMmYHqz3xgAq0M3fVTFTzopnAX 23 | DLSzorL/dYVdqEDCQi5XI9YAlgWN4VeGzJI+glkLOCNzHxRNP6Qev+YI+7Uxz6I= 24 | -----END CERTIFICATE----- 25 | -------------------------------------------------------------------------------- /example/src/main/resources/ssl/selfsigned.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | Proc-Type: 4,ENCRYPTED 3 | DEK-Info: DES-EDE3-CBC,6D12ED8559E80FA3 4 | 5 | tv9epnwlt4dP6Q5ee0dACOyFA5BTwYTdoMykQRJrKGwfaNeXUXn+sQ/U/oFHp1Wx 6 | O8VZE+z2aHpiFSTw+Eh6MPt86X5yVG3tpeVO6dErvr8Kd+NpuI8zn7rNoOFRh8wD 7 | 33EFcQMLQPneDl10O18hooIoi0qwp1pd63hYZPwEhB3eOrM5Mnv9OVJs65bzYfyf 8 | Wku33YWYxeqlDvMCsou8PZnv/M2wYsr7+QoTcNmGKP45igMthMDBzwgF+q0p9ZZU 9 | N11c6ojAs01kfuqFf3vKfHNYe6zsBiNhnUuEy8enXSxD5E7tR/OI8aEzPLdk7fmN 10 | /UsMK2LE0Yd5iS3O1x/1ZjSBxJ+M/UzzCO692GTAiD6Hc13iJOavq/vt1mEPjfCD 11 | neF38Bhb5DfFi+UAHrz6EHMreamGCzP82us2maIs7mSTq7nXDZfbBc7mBDLAUUnT 12 | J6tlrTyc+DQXzkJa6jmbxJhcsWm6XvjIBEzSXVHxEDPLnZICQk3VXODjCXTD75Rg 13 | 0WaS78Ven7DW8wn07q3VzWAFDKaet3VI+TVTv7EfIavlfiA6LSshaENdFLeHahNE 14 | s/V/j5K3Pg6+WQcZRgOsfqIwUCSQxY13R6TTdaaCkLay5BggF5iiAO3pkqsJiadf 15 | w843Ak4USBptymJxoZgJyFtQHpQyNiFfsAbs9BaYbg2evvE7/VQhLk0gQ7HgQMeJ 16 | wgxEQqZQKDCCSugSzY1YEGXKnrZYCKyipzyyH936mE15zNwhYp/Pi2020+gmtP3h 17 | CDfcPs1yeLI2/1JuimafbuKsv9xchWa6ASU8p8Q7wTLtUj9ylLKyA4A/75pK0DXG 18 | Hv/q0O+UfhAMD438SoPBle7RSvIsDU1VjUqstlNybBglBZxGIME7/18+Ms7U32wh 19 | 4xFkZwxT2nqFgyk37tXMdMz9UBh12/AXR9NU4XY37C3Ao2TDT7/0DvU6KdJhsDpv 20 | rGcaC2zzhko+0CPrLlk52KbqP003JXiWvOSI+FylyPPDB/YGitmndJUuQblf3u/E 21 | l+tGi9MeSBQeWKV6D3AVnO05AZjfTUzSK0vw4DgNh5YPNJvLy31B7kDAS88vyGI1 22 | t6MBwjW4/tz/nS/p1Go3mSzBhPkIsCrZE+ar7lH8p8JqkLl4fXIMaVKIfyfJdzyS 23 
| lkh3K7bOGDPegxxxaWdb+EnC7k+1R3EOU7uJFW61HyrGI3q6Y7kOl5aYSJ5Ge1Uv 24 | PycFWHWVTHq/R7HRE6HIJzGe/PnLIbStXLDFeivjfcYq1YaSaF8Vl+xg+0u3ULOl 25 | P6IuPTph6dlcgttRZVl3ETcF0T+2wfbUwgjf0ZiguCJfR2jLGhPl1KBg0Kd9cTSY 26 | zI3YMMd2G8hApt/QFlm4Ry8CqaJUmDcjDNIJT3M+RldUgfz37NsX05cA5e9+I1AL 27 | 2406F/v5U9gWsYx7HuwJtQrDzYYDbl1GD4H+qHFJE5JYhPP4AyWYxJ1NR5dqyvrt 28 | +3r5+xlwZrS76c10RsBWL7th8ZEzRxOZxbtLwbf4bG/tIGfQP2sTnWwA+qym6b2S 29 | sRduqOTP+xwnhOq/ZKn8lfsDfhT8CPnKHBsd09kM9y/UWuxFe0upLydRLE/Wsb9s 30 | -----END RSA PRIVATE KEY----- 31 | -------------------------------------------------------------------------------- /example/src/main/resources/ssl/selfsigned.password: -------------------------------------------------------------------------------- 1 | vesoft 2 | -------------------------------------------------------------------------------- /example/src/main/resources/ssl/selfsigned.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIEGzCCAwOgAwIBAgIUDcmZFpL4PcdCXfLRBK8bR2vb39cwDQYJKoZIhvcNAQEL 3 | BQAwgZwxCzAJBgNVBAYTAkNOMREwDwYDVQQIDAhaaGVqaWFuZzERMA8GA1UEBwwI 4 | SGFuZ3pob3UxFDASBgNVBAoMC1Zlc29mdCBJbmMuMRAwDgYDVQQLDAdzZWN0aW9u 5 | MRYwFAYDVQQDDA1zaHlsb2NrIGh1YW5nMScwJQYJKoZIhvcNAQkBFhhzaHlsb2Nr 6 | Lmh1YW5nQHZlc29mdC5jb20wHhcNMjEwODE5MDkyNDQ3WhcNMjUwODE4MDkyNDQ3 7 | WjCBnDELMAkGA1UEBhMCQ04xETAPBgNVBAgMCFpoZWppYW5nMREwDwYDVQQHDAhI 8 | YW5nemhvdTEUMBIGA1UECgwLVmVzb2Z0IEluYy4xEDAOBgNVBAsMB3NlY3Rpb24x 9 | FjAUBgNVBAMMDXNoeWxvY2sgaHVhbmcxJzAlBgkqhkiG9w0BCQEWGHNoeWxvY2su 10 | aHVhbmdAdmVzb2Z0LmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB 11 | AMEAgpamCQHl+8JnUHI6/VmJHjDLYJLTliN/CwpFrhMqIVjJ8wG57WYLpXpn91Lz 12 | eHu52LkVzcikybIJ2a+LOTvnhNFdbmTbqDtrb+s6wM/sO+nF6tU2Av4e5zhyKoeR 13 | LL+rHMk3nymohbdN4djySFmOOU5A1O/4b0bZz4Ylu995kUawdiaEo13BzxxOC7Ik 14 | Gge5RyDcm0uLXZqTAPy5Sjv/zpOyj0AqL1CJUH7XBN9OMRhVU0ZX9nHWl1vgLRld 15 | J6XT17Y9QbbHhCNEdAmFE5kEFgCvZc+MungUYABlkvoj86TLmC/FMV6fWdxQssyd 16 | hS+ssfJFLaTDaEFz5a/Tr48CAwEAAaNTMFEwHQYDVR0OBBYEFK0GVrQx+wX1GCHy 17 | e+6fl4X+prmYMB8GA1UdIwQYMBaAFK0GVrQx+wX1GCHye+6fl4X+prmYMA8GA1Ud 18 | EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAHqP8P+ZUHmngviHLSSN1ln5 19 | Mx4BCkVeFRUaFx0yFXytV/iLXcG2HpFg3A9rAFoYgCDwi1xpsERnBZ/ShTv/eFOc 20 | IxBY5yggx3/lGi8tAgvUdarhd7mQO67UJ0V4YU3hAkbnZ8grHHXj+4hfgUpY4ok6 21 | yaed6HXwknBb9W8N1jZI8ginhkhjaeRCHdMiF+fBvNCtmeR1bCml1Uz7ailrpcaT 22 | Mf84+5VYuFEnaRZYWFNsWNCOBlJ/6/b3V10vMXzMmYHqz3xgAq0M3fVTFTzopnAX 23 | DLSzorL/dYVdqEDCQi5XI9YAlgWN4VeGzJI+glkLOCNzHxRNP6Qev+YI+7Uxz6I= 24 | -----END CERTIFICATE----- 25 | -------------------------------------------------------------------------------- /example/src/main/resources/vertex: -------------------------------------------------------------------------------- 1 | {"id":12,"name":"Tom","age":20,"born": "2000-01-01"} 2 | {"id":13,"name":"Bob","age":21,"born": "1999-01-02"} 3 | {"id":14,"name":"Jane","age":22,"born": "1998-01-03"} 4 | {"id":15,"name":"Jena","age":23,"born": "1997-01-04"} 5 | {"id":16,"name":"Nic","age":24,"born": "1996-01-05"} 6 | {"id":17,"name":"Mei","age":25,"born": "1995-01-06"} 7 | {"id":18,"name":"HH","age":26,"born": "1994-01-07"} 8 | {"id":19,"name":"Tyler","age":27,"born": "1993-01-08"} 9 | {"id":20,"name":"Ber","age":28,"born": "1992-01-09"} 10 | {"id":21,"name":"Mercy","age":29,"born": "1991-01-10"} -------------------------------------------------------------------------------- /nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/NebulaEnum.scala: 
-------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | object DataTypeEnum extends Enumeration { 9 | 10 | type DataType = Value 11 | val VERTEX = Value("vertex") 12 | val EDGE = Value("edge") 13 | 14 | def validDataType(dataType: String): Boolean = 15 | values.exists(_.toString.equalsIgnoreCase(dataType)) 16 | } 17 | 18 | object KeyPolicy extends Enumeration { 19 | 20 | type POLICY = Value 21 | val HASH = Value("hash") 22 | val UUID = Value("uuid") 23 | } 24 | 25 | object OperaType extends Enumeration { 26 | 27 | type Operation = Value 28 | val READ = Value("read") 29 | val WRITE = Value("write") 30 | } 31 | 32 | object WriteMode extends Enumeration { 33 | 34 | type Mode = Value 35 | val INSERT = Value("insert") 36 | val UPDATE = Value("update") 37 | val DELETE = Value("delete") 38 | } 39 | -------------------------------------------------------------------------------- /nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/PartitionUtils.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | object PartitionUtils { 9 | 10 | /** 11 | * compute which nebula parts each spark partition should scan 12 | * 13 | * @param index spark partition index 14 | * @param nebulaTotalPart nebula space partition number 15 | * @param sparkPartitionNum spark total partition number 16 | * @return the list of nebula partitions assigned to the spark partition with this index 17 | */ 18 | def getScanParts(index: Int, nebulaTotalPart: Int, sparkPartitionNum: Int): List[Int] = 19 | (index to nebulaTotalPart by sparkPartitionNum).toList 20 | 21 | } 22 | -------------------------------------------------------------------------------- /nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/Template.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License.
4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | object NebulaTemplate { 9 | 10 | private[connector] val BATCH_INSERT_TEMPLATE = "INSERT %s `%s`(%s) VALUES %s" 11 | private[connector] val BATCH_INSERT_NO_OVERWRITE_TEMPLATE = 12 | "INSERT %s IF NOT EXISTS `%s`(%s) VALUES %s" 13 | private[connector] val VERTEX_VALUE_TEMPLATE = "%s: (%s)" 14 | private[connector] val VERTEX_VALUE_TEMPLATE_WITH_POLICY = "%s(\"%s\"): (%s)" 15 | private[connector] val ENDPOINT_TEMPLATE = "%s(\"%s\")" 16 | private[connector] val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE = "%s->%s: (%s)" 17 | private[connector] val EDGE_VALUE_TEMPLATE = "%s->%s@%d: (%s)" 18 | private[connector] val USE_TEMPLATE = "USE %s" 19 | 20 | private[connector] val UPDATE_VERTEX_TEMPLATE = "UPDATE %s ON `%s` %s SET %s" 21 | private[connector] val UPDATE_EDGE_TEMPLATE = "UPDATE %s ON `%s` %s->%s@%d SET %s" 22 | private[connector] val UPDATE_VALUE_TEMPLATE = "`%s`=%s" 23 | 24 | private[connector] val DELETE_VERTEX_TEMPLATE = "DELETE VERTEX %s" 25 | private[connector] val DELETE_VERTEX_WITH_EDGE_TEMPLATE = "DELETE VERTEX %s WITH EDGE" 26 | private[connector] val DELETE_EDGE_TEMPLATE = "DELETE EDGE `%s` %s" 27 | private[connector] val EDGE_ENDPOINT_TEMPLATE = "%s->%s@%d" 28 | } 29 | -------------------------------------------------------------------------------- /nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/exception/Exception.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.exception 7 | 8 | import com.facebook.thrift.TException 9 | 10 | /** 11 | * An exception thrown if the nebula client fails to connect. 12 | */ 13 | class GraphConnectException(message: String, cause: Throwable = null) 14 | extends TException(message, cause) 15 | 16 | /** 17 | * An exception thrown if a required option is missing from [[NebulaOptions]]. 18 | */ 19 | class IllegalOptionException(message: String, cause: Throwable = null) 20 | extends IllegalArgumentException(message, cause) 21 | 22 | /** 23 | * An exception thrown if a nebula execution fails. 24 | */ 25 | class GraphExecuteException(message: String, cause: Throwable = null) 26 | extends TException(message, cause) 27 | 28 | /** 29 | * An exception thrown if a nebula execution encounters an RPC exception. 30 | */ 31 | class NebulaRPCException(message: String, cause: Throwable = null) 32 | extends RuntimeException(message, cause) 33 | -------------------------------------------------------------------------------- /nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/nebula/GraphProvider.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License.
4 | */ 5 | 6 | package com.vesoft.nebula.connector.nebula 7 | 8 | import com.vesoft.nebula.client.graph.NebulaPoolConfig 9 | import com.vesoft.nebula.client.graph.data.{ 10 | CASignedSSLParam, 11 | HostAddress, 12 | ResultSet, 13 | SelfSignedSSLParam 14 | } 15 | import com.vesoft.nebula.client.graph.net.{NebulaPool, Session} 16 | import com.vesoft.nebula.connector.Address 17 | import com.vesoft.nebula.connector.exception.GraphConnectException 18 | import com.vesoft.nebula.connector.ssl.{CASSLSignParams, SSLSignType, SelfSSLSignParams} 19 | import org.apache.log4j.Logger 20 | 21 | import scala.collection.JavaConverters._ 22 | 23 | /** 24 | * GraphProvider for Nebula Graph Service 25 | */ 26 | class GraphProvider(addresses: List[Address], 27 | user: String, 28 | password: String, 29 | timeout: Int, 30 | enableSSL: Boolean = false, 31 | sslSignType: String = null, 32 | caSignParam: CASSLSignParams = null, 33 | selfSignParam: SelfSSLSignParams = null) 34 | extends AutoCloseable 35 | with Serializable { 36 | @transient private[this] lazy val LOG = Logger.getLogger(this.getClass) 37 | 38 | @transient val nebulaPoolConfig = new NebulaPoolConfig 39 | 40 | @transient val pool: NebulaPool = new NebulaPool 41 | val address = addresses.map { case (host, port) => new HostAddress(host, port) } 42 | nebulaPoolConfig.setMaxConnSize(1) 43 | nebulaPoolConfig.setTimeout(timeout) 44 | 45 | if (enableSSL) { 46 | nebulaPoolConfig.setEnableSsl(enableSSL) 47 | SSLSignType.withName(sslSignType) match { 48 | case SSLSignType.CA => 49 | nebulaPoolConfig.setSslParam( 50 | new CASignedSSLParam(caSignParam.caCrtFilePath, 51 | caSignParam.crtFilePath, 52 | caSignParam.keyFilePath)) 53 | case SSLSignType.SELF => 54 | nebulaPoolConfig.setSslParam( 55 | new SelfSignedSSLParam(selfSignParam.crtFilePath, 56 | selfSignParam.keyFilePath, 57 | selfSignParam.password)) 58 | case _ => throw new IllegalArgumentException("ssl sign type is not supported") 59 | } 60 | } 61 | val randAddr = scala.util.Random.shuffle(address) 62 | pool.init(randAddr.asJava, nebulaPoolConfig) 63 | 64 | lazy val session: Session = pool.getSession(user, password, true) 65 | 66 | /** 67 | * release session 68 | */ 69 | def releaseGraphClient(): Unit = 70 | session.release() 71 | 72 | override def close(): Unit = { 73 | releaseGraphClient() 74 | pool.close() 75 | } 76 | 77 | /** 78 | * switch to the given nebula space 79 | * 80 | * The switch statement is executed through the graph session. 81 | * 82 | * @param space the nebula space to use 83 | * @return true if the execution succeeds 84 | */ 85 | def switchSpace(space: String): Boolean = { 86 | val switchStatement = s"use $space" 87 | LOG.info(s"switch space $space") 88 | val result = submit(switchStatement) 89 | if (!result.isSucceeded) { 90 | LOG.error(s"switch space $space failed, ${result.getErrorMessage}") 91 | throw new RuntimeException(s"switch space $space failed, ${result.getErrorMessage}") 92 | } 93 | true 94 | } 95 | 96 | /** 97 | * execute the statement 98 | * 99 | * @param statement the statement to execute, e.g. an insert tag/edge statement 100 | * @return the execution result 101 | */ 102 | def submit(statement: String): ResultSet = 103 | session.execute(statement) 104 | } 105 | -------------------------------------------------------------------------------- /nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/nebula/MetaProvider.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License.
4 | */ 5 | 6 | package com.vesoft.nebula.connector.nebula 7 | 8 | import com.vesoft.nebula.PropertyType 9 | import com.vesoft.nebula.client.graph.data.{ 10 | CASignedSSLParam, 11 | HostAddress, 12 | SSLParam, 13 | SelfSignedSSLParam 14 | } 15 | import com.vesoft.nebula.client.meta.MetaClient 16 | import com.vesoft.nebula.connector.{Address, DataTypeEnum} 17 | import com.vesoft.nebula.connector.ssl.{CASSLSignParams, SSLSignType, SelfSSLSignParams} 18 | import com.vesoft.nebula.meta.Schema 19 | 20 | import scala.collection.JavaConverters._ 21 | import scala.collection.mutable 22 | 23 | class MetaProvider(addresses: List[Address], 24 | timeout: Int, 25 | connectionRetry: Int, 26 | executionRetry: Int, 27 | enableSSL: Boolean, 28 | sslSignType: String = null, 29 | caSignParam: CASSLSignParams, 30 | selfSignParam: SelfSSLSignParams) 31 | extends AutoCloseable 32 | with Serializable { 33 | 34 | val metaAddress = addresses.map(address => new HostAddress(address._1, address._2)).asJava 35 | @transient var client: MetaClient = null 36 | @transient var sslParam: SSLParam = null 37 | if (enableSSL) { 38 | SSLSignType.withName(sslSignType) match { 39 | case SSLSignType.CA => 40 | sslParam = new CASignedSSLParam(caSignParam.caCrtFilePath, 41 | caSignParam.crtFilePath, 42 | caSignParam.keyFilePath) 43 | case SSLSignType.SELF => 44 | sslParam = new SelfSignedSSLParam(selfSignParam.crtFilePath, 45 | selfSignParam.keyFilePath, 46 | selfSignParam.password) 47 | case _ => throw new IllegalArgumentException("ssl sign type is not supported") 48 | } 49 | client = new MetaClient(metaAddress, timeout, connectionRetry, executionRetry, true, sslParam) 50 | } else { 51 | client = new MetaClient(metaAddress, timeout, connectionRetry, executionRetry) 52 | } 53 | client.connect() 54 | 55 | /** 56 | * get the partition num of nebula space 57 | */ 58 | def getPartitionNumber(space: String): Int = { 59 | client.getPartsAlloc(space).size() 60 | } 61 | 62 | /** 63 | * get the vid type of nebula space 64 | */ 65 | def getVidType(space: String): VidType.Value = { 66 | val vidType = client.getSpace(space).getProperties.getVid_type.getType 67 | if (vidType == PropertyType.FIXED_STRING) VidType.STRING 68 | else VidType.INT 69 | } 70 | 71 | /** 72 | * get {@link Schema} of nebula tag 73 | * 74 | * @param space 75 | * @param tag 76 | * @return schema 77 | */ 78 | def getTag(space: String, tag: String): Schema = { 79 | client.getTag(space, tag) 80 | } 81 | 82 | /** 83 | * get {@link Schema} of nebula edge type 84 | * 85 | * @param space 86 | * @param edge 87 | * @return schema 88 | */ 89 | def getEdge(space: String, edge: String): Schema = { 90 | client.getEdge(space, edge) 91 | } 92 | 93 | /** 94 | * get tag's schema info 95 | * 96 | * @param space 97 | * @param tag 98 | * @return Map, property name -> data type {@link PropertyType} 99 | */ 100 | def getTagSchema(space: String, tag: String): Map[String, Integer] = { 101 | val tagSchema = client.getTag(space, tag) 102 | val schema = new mutable.HashMap[String, Integer] 103 | 104 | val columns = tagSchema.getColumns 105 | for (colDef <- columns.asScala) { 106 | schema.put(new String(colDef.getName), colDef.getType.getType.getValue) 107 | } 108 | schema.toMap 109 | } 110 | 111 | /** 112 | * get edge's schema info 113 | * 114 | * @param space 115 | * @param edge 116 | * @return Map, property name -> data type {@link PropertyType} 117 | */ 118 | def getEdgeSchema(space: String, edge: String): Map[String, Integer] = { 119 | val edgeSchema = client.getEdge(space, edge) 120 | val schema = 
new mutable.HashMap[String, Integer] 121 | 122 | val columns = edgeSchema.getColumns 123 | for (colDef <- columns.asScala) { 124 | schema.put(new String(colDef.getName), colDef.getType.getType.getValue) 125 | } 126 | schema.toMap 127 | } 128 | 129 | /** 130 | * check if a label is Tag or Edge 131 | */ 132 | def getLabelType(space: String, label: String): DataTypeEnum.Value = { 133 | val tags = client.getTags(space) 134 | tags.asScala.collectFirst { 135 | case tag if new String(tag.getTag_name).equals(label) => DataTypeEnum.VERTEX 136 | }.orElse { 137 | client.getEdges(space).asScala.collectFirst { 138 | case edge if new String(edge.getEdge_name).equals(label) => DataTypeEnum.EDGE 139 | } 140 | }.orNull 141 | } 142 | 143 | override def close(): Unit = { 144 | client.close() 145 | } 146 | 147 | } 148 | 149 | object VidType extends Enumeration { 150 | type Type = Value 151 | 152 | val STRING = Value("STRING") 153 | val INT = Value("INT") 154 | } 155 | -------------------------------------------------------------------------------- /nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/package.scala: -------------------------------------------------------------------------------- 1 | package com.vesoft.nebula 2 | 3 | import com.vesoft.nebula.connector.writer.NebulaExecutor 4 | 5 | package object connector { 6 | 7 | type Address = (String, Int) 8 | type NebulaType = Int 9 | type Prop = List[Any] 10 | type PropertyNames = List[String] 11 | type PropertyValues = List[Any] 12 | 13 | type VertexID = Long 14 | type VertexIDSlice = String 15 | type NebulaGraphxVertex = (VertexID, PropertyValues) 16 | type NebulaGraphxEdge = org.apache.spark.graphx.Edge[(EdgeRank, Prop)] 17 | type EdgeRank = Long 18 | 19 | case class NebulaVertex(vertexIDSlice: VertexIDSlice, values: PropertyValues) { 20 | def propertyValues = values.mkString(", ") 21 | 22 | override def toString: String = { 23 | s"Vertex ID: ${vertexIDSlice}, Values: ${values.mkString(", ")}" 24 | } 25 | } 26 | 27 | case class NebulaVertices(propNames: PropertyNames, 28 | values: List[NebulaVertex], 29 | policy: Option[KeyPolicy.Value]) { 30 | 31 | def propertyNames: String = NebulaExecutor.escapePropName(propNames).mkString(",") 32 | 33 | override def toString: String = { 34 | s"Vertices: " + 35 | s"Property Names: ${propNames.mkString(", ")}" + 36 | s"Vertex Values: ${values.mkString(", ")} " + 37 | s"with policy: ${policy}" 38 | } 39 | } 40 | 41 | case class NebulaEdge(source: VertexIDSlice, 42 | target: VertexIDSlice, 43 | rank: Option[EdgeRank], 44 | values: PropertyValues) { 45 | def propertyValues: String = values.mkString(", ") 46 | 47 | override def toString: String = { 48 | s"Edge: ${source}->${target}@${rank} values: ${propertyValues}" 49 | } 50 | } 51 | 52 | case class NebulaEdges(propNames: PropertyNames, 53 | values: List[NebulaEdge], 54 | sourcePolicy: Option[KeyPolicy.Value], 55 | targetPolicy: Option[KeyPolicy.Value]) { 56 | def propertyNames: String = NebulaExecutor.escapePropName(propNames).mkString(",") 57 | def getSourcePolicy = sourcePolicy 58 | def getTargetPolicy = targetPolicy 59 | 60 | override def toString: String = { 61 | "Edges:" + 62 | s" Property Names: ${propNames.mkString(", ")}" + 63 | s" with source policy ${sourcePolicy}" + 64 | s" with target policy ${targetPolicy}" 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/ssl/SSLEnum.scala: 
-------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.ssl 7 | 8 | object SSLSignType extends Enumeration { 9 | 10 | type signType = Value 11 | val CA = Value("ca") 12 | val SELF = Value("self") 13 | } 14 | -------------------------------------------------------------------------------- /nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/ssl/SSLSignParams.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.ssl 7 | 8 | case class CASSLSignParams(caCrtFilePath: String, crtFilePath: String, keyFilePath: String) 9 | 10 | case class SelfSSLSignParams(crtFilePath: String, keyFilePath: String, password: String) 11 | -------------------------------------------------------------------------------- /nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/utils/AddressCheckUtil.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.utils 7 | 8 | import com.google.common.base.Strings 9 | 10 | object AddressCheckUtil { 11 | 12 | def getAddressFromString(addr: String): (String, Int) = { 13 | if (addr == null) { 14 | throw new IllegalArgumentException("wrong address format.") 15 | } 16 | 17 | val (host, portString) = 18 | if (addr.startsWith("[")) { 19 | getHostAndPortFromBracketedHost(addr) 20 | } else if (addr.count(_ == ':') == 1) { 21 | val array = addr.split(":", 2) 22 | (array(0), array(1)) 23 | } else { 24 | (addr, null) 25 | } 26 | 27 | val port = getPort(portString, addr) 28 | (host, port) 29 | } 30 | 31 | private def getPort(portString: String, addr: String): Int = 32 | if (Strings.isNullOrEmpty(portString)) { 33 | -1 34 | } else { 35 | require(portString.forall(_.isDigit), s"Port must be numeric: $addr") 36 | val port = portString.toInt 37 | require(1 <= port && port <= 65535, s"Port number out of range: $addr") 38 | port 39 | } 40 | 41 | def getHostAndPortFromBracketedHost(addr: String): (String, String) = { 42 | val colonIndex = addr.indexOf(":") 43 | val closeBracketIndex = addr.lastIndexOf("]") 44 | if (colonIndex < 0 || closeBracketIndex < colonIndex) { 45 | throw new IllegalArgumentException(s"invalid bracketed host/port: $addr") 46 | } 47 | val host: String = addr.substring(1, closeBracketIndex) 48 | if (closeBracketIndex + 1 == addr.length) { 49 | (host, "") 50 | } else if (addr.charAt(closeBracketIndex + 1) != ':') { 51 | throw new IllegalArgumentException(s"only a colon may follow a close bracket: $addr") 52 | } else { 53 | val port = addr.substring(closeBracketIndex + 2) 54 | if (port.forall(_.isDigit)) { 55 | (host, port) 56 | } else { 57 | throw new IllegalArgumentException(s"Port must be numeric: $addr") 58 | } 59 | } 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /nebula-spark-common/src/main/scala/com/vesoft/nebula/connector/utils/SparkValidate.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 
2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.utils 7 | 8 | import org.apache.spark.sql.SparkSession 9 | 10 | object SparkValidate { 11 | def validate(supportedVersions: String*): Unit = { 12 | val sparkVersion = SparkSession.getActiveSession.map(_.version).getOrElse("UNKNOWN") 13 | if (sparkVersion != "UNKNOWN" && !supportedVersions.exists(sparkVersion.matches)) { 14 | throw new RuntimeException( 15 | s"""Your current spark version ${sparkVersion} is not supported by the current NebulaGraph Spark Connector. 16 | | please visit https://github.com/vesoft-inc/nebula-spark-connector#version-match to know which Connector you need. 17 | | """.stripMargin) 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /nebula-spark-common/src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Global logging configuration 2 | log4j.rootLogger=INFO, stdout 3 | # Console output... 4 | log4j.appender.stdout=org.apache.log4j.ConsoleAppender 5 | log4j.appender.stdout.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.stdout.layout.ConversionPattern=%5p [%t] - %m%n 7 | -------------------------------------------------------------------------------- /nebula-spark-common/src/test/scala/com/vesoft/nebula/connector/AddressCheckUtilsSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import com.vesoft.nebula.connector.utils.AddressCheckUtil 9 | import org.scalatest.funsuite.AnyFunSuite 10 | 11 | class AddressCheckUtilsSuite extends AnyFunSuite { 12 | 13 | test("checkAddress") { 14 | var addr = "127.0.0.1:9669" 15 | var hostAddress = AddressCheckUtil.getAddressFromString(addr) 16 | assert("127.0.0.1".equals(hostAddress._1)) 17 | assert(hostAddress._2 == 9669) 18 | 19 | addr = "localhost:9669" 20 | hostAddress = AddressCheckUtil.getAddressFromString(addr) 21 | assert("localhost".equals(hostAddress._1)) 22 | 23 | addr = "www.baidu.com:22" 24 | hostAddress = AddressCheckUtil.getAddressFromString(addr) 25 | assert(hostAddress._2 == 22) 26 | 27 | addr = "[2023::2]:65535" 28 | hostAddress = AddressCheckUtil.getAddressFromString(addr) 29 | assert(hostAddress._2 == 65535) 30 | 31 | addr = "2023::3" 32 | hostAddress = AddressCheckUtil.getAddressFromString(addr) 33 | assert(hostAddress._1.equals("2023::3")) 34 | assert(hostAddress._2 == -1) 35 | 36 | // bad address 37 | addr = "localhost:65536" 38 | assertThrows[IllegalArgumentException](AddressCheckUtil.getAddressFromString(addr)) 39 | addr = "localhost:-1" 40 | assertThrows[IllegalArgumentException](AddressCheckUtil.getAddressFromString(addr)) 41 | addr = "[localhost]:9669" 42 | assertThrows[IllegalArgumentException](AddressCheckUtil.getAddressFromString(addr)) 43 | addr = "www.baidu.com:+25" 44 | assertThrows[IllegalArgumentException](AddressCheckUtil.getAddressFromString(addr)) 45 | addr = "[]:8080" 46 | assertThrows[IllegalArgumentException](AddressCheckUtil.getAddressFromString(addr)) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /nebula-spark-common/src/test/scala/com/vesoft/nebula/connector/DataTypeEnumSuite.scala: -------------------------------------------------------------------------------- 1 
| /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import org.scalatest.funsuite.AnyFunSuite 9 | 10 | class DataTypeEnumSuite extends AnyFunSuite { 11 | 12 | test("validDataType") { 13 | assert(DataTypeEnum.validDataType("vertex")) 14 | assert(DataTypeEnum.validDataType("VERTEX")) 15 | assert(DataTypeEnum.validDataType("edge")) 16 | assert(DataTypeEnum.validDataType("EDGE")) 17 | assert(!DataTypeEnum.validDataType("relation")) 18 | } 19 | 20 | } 21 | -------------------------------------------------------------------------------- /nebula-spark-common/src/test/scala/com/vesoft/nebula/connector/NebulaConfigSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import com.vesoft.nebula.connector.ssl.SSLSignType 9 | import org.scalatest.BeforeAndAfterAll 10 | import org.scalatest.funsuite.AnyFunSuite 11 | 12 | class NebulaConfigSuite extends AnyFunSuite with BeforeAndAfterAll { 13 | 14 | test("test NebulaConnectionConfig") { 15 | 16 | assertThrows[AssertionError](NebulaConnectionConfig.builder().withTimeout(1).build()) 17 | 18 | assertThrows[AssertionError](NebulaConnectionConfig.builder().withTimeout(-1).build()) 19 | 20 | NebulaConnectionConfig 21 | .builder() 22 | .withMetaAddress("127.0.0.1:9559") 23 | .withTimeout(1) 24 | .build() 25 | } 26 | 27 | test("test correct ssl config") { 28 | NebulaConnectionConfig 29 | .builder() 30 | .withMetaAddress("127.0.0.1:9559") 31 | .withGraphAddress("127.0.0.1:9669") 32 | .withEnableGraphSSL(true) 33 | .withEnableMetaSSL(true) 34 | .withSSLSignType(SSLSignType.CA) 35 | .withCaSSLSignParam("cacrtFile", "crtFile", "keyFile") 36 | .build() 37 | } 38 | 39 | test("test correct ssl config with wrong ssl priority") { 40 | assertThrows[AssertionError]( 41 | NebulaConnectionConfig 42 | .builder() 43 | .withMetaAddress("127.0.0.1:9559") 44 | .withGraphAddress("127.0.0.1:9669") 45 | .withEnableStorageSSL(true) 46 | .withEnableMetaSSL(false) 47 | .withSSLSignType(SSLSignType.CA) 48 | .withCaSSLSignParam("caCrtFile", "crtFile", "keyFile") 49 | .build()) 50 | } 51 | 52 | test("test correct ssl config with no sign type param") { 53 | assertThrows[AssertionError]( 54 | NebulaConnectionConfig 55 | .builder() 56 | .withMetaAddress("127.0.0.1:9559") 57 | .withGraphAddress("127.0.0.1:9669") 58 | .withEnableGraphSSL(true) 59 | .withEnableMetaSSL(true) 60 | .withCaSSLSignParam("caCrtFile", "crtFile", "keyFile") 61 | .build()) 62 | } 63 | 64 | test("test correct ssl config with wrong ca param") { 65 | assertThrows[AssertionError]( 66 | NebulaConnectionConfig 67 | .builder() 68 | .withMetaAddress("127.0.0.1:9559") 69 | .withGraphAddress("127.0.0.1:9669") 70 | .withEnableGraphSSL(true) 71 | .withEnableMetaSSL(true) 72 | .withSSLSignType(SSLSignType.CA) 73 | .withSelfSSLSignParam("crtFile", "keyFile", "password") 74 | .build()) 75 | } 76 | 77 | test("test correct ssl config with wrong self param") { 78 | assertThrows[AssertionError]( 79 | NebulaConnectionConfig 80 | .builder() 81 | .withMetaAddress("127.0.0.1:9559") 82 | .withGraphAddress("127.0.0.1:9669") 83 | .withEnableGraphSSL(true) 84 | .withEnableMetaSSL(true) 85 | .withSSLSignType(SSLSignType.SELF) 86 | .withCaSSLSignParam("cacrtFile", "crtFile", "keyFile") 87 | .build()) 
88 | } 89 | 90 | test("test WriteNebulaConfig") { 91 | var writeNebulaConfig: WriteNebulaVertexConfig = null 92 | 93 | writeNebulaConfig = WriteNebulaVertexConfig 94 | .builder() 95 | .withSpace("test") 96 | .withTag("tag") 97 | .withVidField("vid") 98 | .build() 99 | 100 | assert(!writeNebulaConfig.getVidAsProp) 101 | assert(writeNebulaConfig.getSpace.equals("test")) 102 | } 103 | 104 | test("wrong batch size for update") { 105 | assertThrows[AssertionError]( 106 | WriteNebulaVertexConfig 107 | .builder() 108 | .withSpace("test") 109 | .withTag("tag") 110 | .withVidField("vId") 111 | .withWriteMode(WriteMode.UPDATE) 112 | .withBatch(513) 113 | .build()) 114 | assertThrows[AssertionError]( 115 | WriteNebulaEdgeConfig 116 | .builder() 117 | .withSpace("test") 118 | .withEdge("edge") 119 | .withSrcIdField("src") 120 | .withDstIdField("dst") 121 | .withWriteMode(WriteMode.UPDATE) 122 | .withBatch(513) 123 | .build()) 124 | } 125 | 126 | test("test wrong policy") { 127 | assertThrows[AssertionError]( 128 | WriteNebulaVertexConfig 129 | .builder() 130 | .withSpace("test") 131 | .withTag("tag") 132 | .withVidField("vId") 133 | .withVidPolicy("wrong_policy") 134 | .build()) 135 | } 136 | 137 | test("test wrong batch") { 138 | assertThrows[AssertionError]( 139 | WriteNebulaVertexConfig 140 | .builder() 141 | .withSpace("test") 142 | .withTag("tag") 143 | .withVidField("vId") 144 | .withVidPolicy("hash") 145 | .withBatch(-1) 146 | .build()) 147 | } 148 | 149 | test("test ReadNebulaConfig") { 150 | ReadNebulaConfig 151 | .builder() 152 | .withSpace("test") 153 | .withLabel("tagName") 154 | .withNoColumn(true) 155 | .withReturnCols(List("col")) 156 | .build() 157 | } 158 | 159 | } 160 | -------------------------------------------------------------------------------- /nebula-spark-common/src/test/scala/com/vesoft/nebula/connector/NebulaUtilsSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import com.vesoft.nebula.PropertyType 9 | import com.vesoft.nebula.meta.{ColumnDef, ColumnTypeDef} 10 | import org.apache.spark.sql.types.{ 11 | BooleanType, 12 | DoubleType, 13 | LongType, 14 | StringType, 15 | StructField, 16 | StructType 17 | } 18 | import org.scalatest.funsuite.AnyFunSuite 19 | 20 | class NebulaUtilsSuite extends AnyFunSuite { 21 | 22 | test("convertDataType") { 23 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.VID)) == LongType) 24 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.INT8)) == LongType) 25 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.INT16)) == LongType) 26 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.INT32)) == LongType) 27 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.INT64)) == LongType) 28 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.TIMESTAMP)) == LongType) 29 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.BOOL)) == BooleanType) 30 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.FLOAT)) == DoubleType) 31 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.DOUBLE)) == DoubleType) 32 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.FIXED_STRING)) == StringType) 33 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.STRING)) == StringType) 34 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.DATE)) == StringType) 35 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.DATETIME)) == StringType) 36 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.TIME)) == StringType) 37 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.GEOGRAPHY)) == StringType) 38 | assert(NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.DURATION)) == StringType) 39 | assertThrows[IllegalArgumentException]( 40 | NebulaUtils.convertDataType(new ColumnTypeDef(PropertyType.UNKNOWN))) 41 | } 42 | 43 | test("getColDataType") { 44 | val columnDefs: List[ColumnDef] = List( 45 | new ColumnDef("col1".getBytes(), new ColumnTypeDef(PropertyType.INT8)), 46 | new ColumnDef("col2".getBytes(), new ColumnTypeDef(PropertyType.DOUBLE)), 47 | new ColumnDef("col3".getBytes(), new ColumnTypeDef(PropertyType.STRING)), 48 | new ColumnDef("col4".getBytes(), new ColumnTypeDef(PropertyType.DATE)), 49 | new ColumnDef("col5".getBytes(), new ColumnTypeDef(PropertyType.DATETIME)), 50 | new ColumnDef("col6".getBytes(), new ColumnTypeDef(PropertyType.TIME)), 51 | new ColumnDef("col7".getBytes(), new ColumnTypeDef(PropertyType.TIMESTAMP)), 52 | new ColumnDef("col8".getBytes(), new ColumnTypeDef(PropertyType.BOOL)) 53 | ) 54 | assert(NebulaUtils.getColDataType(columnDefs, "col1") == LongType) 55 | assert(NebulaUtils.getColDataType(columnDefs, "col2") == DoubleType) 56 | assert(NebulaUtils.getColDataType(columnDefs, "col3") == StringType) 57 | assert(NebulaUtils.getColDataType(columnDefs, "col4") == StringType) 58 | assert(NebulaUtils.getColDataType(columnDefs, "col5") == StringType) 59 | assert(NebulaUtils.getColDataType(columnDefs, "col6") == StringType) 60 | assert(NebulaUtils.getColDataType(columnDefs, "col7") == LongType) 61 | assert(NebulaUtils.getColDataType(columnDefs, "col8") == BooleanType) 62 | assertThrows[IllegalArgumentException](NebulaUtils.getColDataType(columnDefs, "col9")) 63 | } 64 | 65 | test("makeGetters") { 66 | val schema = StructType( 67 | 
List( 68 | StructField("col1", LongType, nullable = false), 69 | StructField("col2", LongType, nullable = true) 70 | )) 71 | assert(NebulaUtils.makeGetters(schema).length == 2) 72 | } 73 | 74 | test("isNumic") { 75 | assert(NebulaUtils.isNumic("123")) 76 | assert(NebulaUtils.isNumic("-123")) 77 | assert(!NebulaUtils.isNumic("")) 78 | assert(!NebulaUtils.isNumic("-")) 79 | assert(!NebulaUtils.isNumic("1.0")) 80 | assert(!NebulaUtils.isNumic("a123")) 81 | assert(!NebulaUtils.isNumic("123b")) 82 | } 83 | 84 | test("escapeUtil") { 85 | assert(NebulaUtils.escapeUtil("123").equals("123")) 86 | // a\bc -> a\\bc 87 | assert(NebulaUtils.escapeUtil("a\bc").equals("a\\bc")) 88 | // a\tbc -> a\\tbc 89 | assert(NebulaUtils.escapeUtil("a\tbc").equals("a\\tbc")) 90 | // a\nbc -> a\\nbc 91 | assert(NebulaUtils.escapeUtil("a\nbc").equals("a\\nbc")) 92 | // a\"bc -> a\\"bc 93 | assert(NebulaUtils.escapeUtil("a\"bc").equals("a\\\"bc")) 94 | // a\'bc -> a\\'bc 95 | assert(NebulaUtils.escapeUtil("a\'bc").equals("a\\'bc")) 96 | // a\rbc -> a\\rbc 97 | assert(NebulaUtils.escapeUtil("a\rbc").equals("a\\rbc")) 98 | // a\bbc -> a\\bbc 99 | assert(NebulaUtils.escapeUtil("a\bbc").equals("a\\bbc")) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /nebula-spark-common/src/test/scala/com/vesoft/nebula/connector/PartitionUtilsSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import org.scalatest.funsuite.AnyFunSuite 9 | 10 | /** 11 | * base data: spark partition is 10 12 | */ 13 | class PartitionUtilsSuite extends AnyFunSuite { 14 | val partition: Int = 10 15 | 16 | test("getScanParts: nebula part is the same with spark partition") { 17 | val nebulaPart: Int = 10 18 | for (i <- 1 to 10) { 19 | val partsForIndex = PartitionUtils.getScanParts(i, nebulaPart, partition) 20 | assert(partsForIndex.size == 1) 21 | assert(partsForIndex.head == i) 22 | } 23 | } 24 | 25 | test("getScanParts: nebula part is more than spark partition") { 26 | val nebulaPart: Int = 20 27 | for (i <- 1 to 10) { 28 | val partsForIndex = PartitionUtils.getScanParts(i, nebulaPart, partition) 29 | assert(partsForIndex.contains(i) && partsForIndex.contains(i + 10)) 30 | assert(partsForIndex.size == 2) 31 | } 32 | } 33 | 34 | test("getScanParts: nebula part is less than spark partition") { 35 | val nebulaPart: Int = 5 36 | for (i <- 1 to 5) { 37 | val partsForIndex = PartitionUtils.getScanParts(i, nebulaPart, partition) 38 | assert(partsForIndex.contains(i)) 39 | } 40 | for (j <- 6 to 10) { 41 | val partsForIndex = PartitionUtils.getScanParts(j, nebulaPart, partition) 42 | assert(partsForIndex.isEmpty) 43 | } 44 | } 45 | 46 | } 47 | -------------------------------------------------------------------------------- /nebula-spark-common/src/test/scala/com/vesoft/nebula/connector/nebula/GraphProviderTest.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.nebula 7 | 8 | import com.vesoft.nebula.connector.Address 9 | import com.vesoft.nebula.connector.mock.NebulaGraphMock 10 | import org.apache.log4j.BasicConfigurator 11 | import org.scalatest.BeforeAndAfterAll 12 | import org.scalatest.funsuite.AnyFunSuite 13 | 14 | class GraphProviderTest extends AnyFunSuite with BeforeAndAfterAll { 15 | BasicConfigurator.configure() 16 | 17 | var graphProvider: GraphProvider = null 18 | 19 | override def beforeAll(): Unit = { 20 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 21 | graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 22 | val graphMock = new NebulaGraphMock 23 | graphMock.mockIntIdGraph() 24 | graphMock.mockStringIdGraph() 25 | graphMock.close() 26 | } 27 | 28 | override def afterAll(): Unit = { 29 | graphProvider.close() 30 | } 31 | 32 | test("switchSpace") { 33 | assertThrows[RuntimeException](graphProvider.switchSpace("space_not_exist")) 34 | assert(graphProvider.switchSpace("test_int")) 35 | } 36 | 37 | test("submit") { 38 | val result = graphProvider.submit("fetch prop on person 1 yield vertex as v") 39 | assert(result.isSucceeded) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /nebula-spark-common/src/test/scala/com/vesoft/nebula/connector/nebula/MetaProviderTest.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.nebula 7 | 8 | import com.vesoft.nebula.PropertyType 9 | import com.vesoft.nebula.connector.mock.NebulaGraphMock 10 | import com.vesoft.nebula.connector.{Address, DataTypeEnum} 11 | import com.vesoft.nebula.meta.Schema 12 | import org.apache.log4j.BasicConfigurator 13 | import org.scalatest.BeforeAndAfterAll 14 | import org.scalatest.funsuite.AnyFunSuite 15 | 16 | class MetaProviderTest extends AnyFunSuite with BeforeAndAfterAll { 17 | BasicConfigurator.configure() 18 | var metaProvider: MetaProvider = null 19 | 20 | override def beforeAll(): Unit = { 21 | val addresses: List[Address] = List(new Address("127.0.0.1", 9559)) 22 | metaProvider = new MetaProvider(addresses, 6000, 3, 3, false, null, null, null) 23 | 24 | val graphMock = new NebulaGraphMock 25 | graphMock.mockStringIdGraph() 26 | graphMock.mockIntIdGraph() 27 | graphMock.close() 28 | } 29 | 30 | override def afterAll(): Unit = { 31 | metaProvider.close() 32 | } 33 | 34 | test("getPartitionNumber") { 35 | assert(metaProvider.getPartitionNumber("test_int") == 10) 36 | assert(metaProvider.getPartitionNumber("test_string") == 10) 37 | } 38 | 39 | test("getVidType") { 40 | assert(metaProvider.getVidType("test_int") == VidType.INT) 41 | assert(metaProvider.getVidType("test_string") == VidType.STRING) 42 | } 43 | 44 | test("getTag") { 45 | val schema: Schema = metaProvider.getTag("test_int", "person") 46 | assert(schema.columns.size() == 13) 47 | 48 | val schema1: Schema = metaProvider.getTag("test_string", "person") 49 | assert(schema1.columns.size() == 13) 50 | } 51 | 52 | test("getEdge") { 53 | val schema: Schema = metaProvider.getEdge("test_int", "friend") 54 | assert(schema.columns.size() == 13) 55 | 56 | val schema1: Schema = metaProvider.getEdge("test_string", "friend") 57 | assert(schema1.columns.size() == 13) 58 | } 59 | 60 | test("getTagSchema for person") { 61 | val schemaMap: Map[String, Integer] = 
metaProvider.getTagSchema("test_int", "person") 62 | assert(schemaMap.size == 13) 63 | assert(schemaMap("col1") == PropertyType.STRING.getValue) 64 | assert(schemaMap("col2") == PropertyType.FIXED_STRING.getValue) 65 | assert(schemaMap("col3") == PropertyType.INT8.getValue) 66 | assert(schemaMap("col4") == PropertyType.INT16.getValue) 67 | assert(schemaMap("col5") == PropertyType.INT32.getValue) 68 | assert(schemaMap("col6") == PropertyType.INT64.getValue) 69 | assert(schemaMap("col7") == PropertyType.DATE.getValue) 70 | assert(schemaMap("col8") == PropertyType.DATETIME.getValue) 71 | assert(schemaMap("col9") == PropertyType.TIMESTAMP.getValue) 72 | assert(schemaMap("col10") == PropertyType.BOOL.getValue) 73 | assert(schemaMap("col11") == PropertyType.DOUBLE.getValue) 74 | assert(schemaMap("col12") == PropertyType.FLOAT.getValue) 75 | assert(schemaMap("col13") == PropertyType.TIME.getValue) 76 | } 77 | 78 | test("getTagSchema for geo_shape") { 79 | val schemaMap: Map[String, Integer] = metaProvider.getTagSchema("test_int", "geo_shape") 80 | assert(schemaMap.size == 1) 81 | assert(schemaMap("geo") == PropertyType.GEOGRAPHY.getValue) 82 | } 83 | 84 | test("getEdgeSchema") { 85 | val schemaMap: Map[String, Integer] = metaProvider.getEdgeSchema("test_int", "friend") 86 | assert(schemaMap.size == 13) 87 | assert(schemaMap("col1") == PropertyType.STRING.getValue) 88 | assert(schemaMap("col2") == PropertyType.FIXED_STRING.getValue) 89 | assert(schemaMap("col3") == PropertyType.INT8.getValue) 90 | assert(schemaMap("col4") == PropertyType.INT16.getValue) 91 | assert(schemaMap("col5") == PropertyType.INT32.getValue) 92 | assert(schemaMap("col6") == PropertyType.INT64.getValue) 93 | assert(schemaMap("col7") == PropertyType.DATE.getValue) 94 | assert(schemaMap("col8") == PropertyType.DATETIME.getValue) 95 | assert(schemaMap("col9") == PropertyType.TIMESTAMP.getValue) 96 | assert(schemaMap("col10") == PropertyType.BOOL.getValue) 97 | assert(schemaMap("col11") == PropertyType.DOUBLE.getValue) 98 | assert(schemaMap("col12") == PropertyType.FLOAT.getValue) 99 | assert(schemaMap("col13") == PropertyType.TIME.getValue) 100 | } 101 | 102 | test("getLabelType") { 103 | assert(metaProvider.getLabelType("test_int", "person") == DataTypeEnum.VERTEX) 104 | assert(metaProvider.getLabelType("test_int", "friend") == DataTypeEnum.EDGE) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /nebula-spark-connector/.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled class file 2 | *.class 3 | 4 | # Log file 5 | *.log 6 | 7 | # BlueJ files 8 | *.ctxt 9 | 10 | # Mobile Tools for Java (J2ME) 11 | .mtj.tmp/ 12 | 13 | # Package Files # 14 | *.jar 15 | *.war 16 | *.nar 17 | *.ear 18 | *.zip 19 | *.tar.gz 20 | *.rar 21 | 22 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 23 | hs_err_pid* 24 | 25 | # build target 26 | target/ 27 | 28 | # IDE 29 | .idea/ 30 | .eclipse/ 31 | *.iml 32 | 33 | spark-importer.ipr 34 | spark-importer.iws 35 | 36 | .DS_Store 37 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import java.util.Map.Entry 9 | import java.util.Optional 10 | import com.vesoft.nebula.connector.exception.IllegalOptionException 11 | import com.vesoft.nebula.connector.reader.{NebulaDataSourceEdgeReader, NebulaDataSourceNgqlEdgeReader, NebulaDataSourceVertexReader} 12 | import com.vesoft.nebula.connector.writer.{NebulaDataSourceEdgeWriter, NebulaDataSourceVertexWriter} 13 | import org.apache.spark.sql.SaveMode 14 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap 15 | import org.apache.spark.sql.sources.DataSourceRegister 16 | import org.apache.spark.sql.sources.v2.reader.DataSourceReader 17 | import org.apache.spark.sql.sources.v2.writer.DataSourceWriter 18 | import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, WriteSupport} 19 | import org.apache.spark.sql.types.StructType 20 | import org.slf4j.LoggerFactory 21 | 22 | import scala.collection.JavaConversions.iterableAsScalaIterable 23 | 24 | class NebulaDataSource 25 | extends DataSourceV2 26 | with ReadSupport 27 | with WriteSupport 28 | with DataSourceRegister { 29 | private val LOG = LoggerFactory.getLogger(this.getClass) 30 | 31 | /** 32 | * The string that represents the format that nebula data source provider uses. 33 | */ 34 | override def shortName(): String = "nebula" 35 | 36 | /** 37 | * Creates a {@link DataSourceReader} to scan the data from Nebula Graph. 38 | */ 39 | override def createReader(options: DataSourceOptions): DataSourceReader = { 40 | val nebulaOptions = getNebulaOptions(options) 41 | val dataType = nebulaOptions.dataType 42 | 43 | LOG.info("create reader") 44 | val parameters = options.asMap() 45 | parameters.remove("passwd") 46 | LOG.info(s"options ${parameters}") 47 | 48 | if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) { 49 | new NebulaDataSourceVertexReader(nebulaOptions) 50 | } else if (nebulaOptions.ngql != null && nebulaOptions.ngql.nonEmpty) { 51 | new NebulaDataSourceNgqlEdgeReader(nebulaOptions) 52 | } else { 53 | new NebulaDataSourceEdgeReader(nebulaOptions) 54 | } 55 | } 56 | 57 | /** 58 | * Creates an optional {@link DataSourceWriter} to save the data to Nebula Graph. 
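   * Note: SaveMode.Ignore and SaveMode.ErrorIfExists are not supported; a warning is logged and the write proceeds anyway.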
59 | */ 60 | override def createWriter(writeUUID: String, 61 | schema: StructType, 62 | mode: SaveMode, 63 | options: DataSourceOptions): Optional[DataSourceWriter] = { 64 | 65 | val nebulaOptions = getNebulaOptions(options) 66 | val dataType = nebulaOptions.dataType 67 | if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { 68 | LOG.warn(s"SaveMode ${mode} is currently not supported") 69 | } 70 | 71 | LOG.info("create writer") 72 | val parameters = options.asMap() 73 | parameters.remove("passwd") 74 | LOG.info(s"options ${parameters}") 75 | 76 | if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) { 77 | val vertexField = nebulaOptions.vertexField 78 | val vertexIndex: Int = { 79 | var index: Int = -1 80 | for (i <- schema.fields.indices) { 81 | if (schema.fields(i).name.equals(vertexField)) { 82 | index = i 83 | } 84 | } 85 | if (index < 0) { 86 | throw new IllegalOptionException( 87 | s"vertex field ${vertexField} does not exist in dataframe") 88 | } 89 | index 90 | } 91 | Optional.of(new NebulaDataSourceVertexWriter(nebulaOptions, vertexIndex, schema)) 92 | } else { 93 | val srcVertexField = nebulaOptions.srcVertexField 94 | val dstVertexField = nebulaOptions.dstVertexField 95 | val rankExist = !nebulaOptions.rankField.isEmpty 96 | val edgeFieldsIndex = { 97 | var srcIndex: Int = -1 98 | var dstIndex: Int = -1 99 | var rankIndex: Int = -1 100 | for (i <- schema.fields.indices) { 101 | if (schema.fields(i).name.equals(srcVertexField)) { 102 | srcIndex = i 103 | } 104 | if (schema.fields(i).name.equals(dstVertexField)) { 105 | dstIndex = i 106 | } 107 | if (rankExist) { 108 | if (schema.fields(i).name.equals(nebulaOptions.rankField)) { 109 | rankIndex = i 110 | } 111 | } 112 | } 113 | // check src field and dst field 114 | if (srcIndex < 0 || dstIndex < 0) { 115 | throw new IllegalOptionException( 116 | s"srcVertex field ${srcVertexField} or dstVertex field ${dstVertexField} do not exist in dataframe") 117 | } 118 | // check rank field 119 | if (rankExist && rankIndex < 0) { 120 | throw new IllegalOptionException(s"rank field does not exist in dataframe") 121 | } 122 | 123 | if (!rankExist) { 124 | (srcIndex, dstIndex, Option.empty) 125 | } else { 126 | (srcIndex, dstIndex, Option(rankIndex)) 127 | } 128 | 129 | } 130 | Optional.of( 131 | new NebulaDataSourceEdgeWriter(nebulaOptions, 132 | edgeFieldsIndex._1, 133 | edgeFieldsIndex._2, 134 | edgeFieldsIndex._3, 135 | schema)) 136 | } 137 | } 138 | 139 | /** 140 | * construct nebula options with DataSourceOptions 141 | */ 142 | def getNebulaOptions(options: DataSourceOptions): NebulaOptions = { 143 | var parameters: Map[String, String] = Map() 144 | for (entry: Entry[String, String] <- options.asMap().entrySet) { 145 | parameters += (entry.getKey -> entry.getValue) 146 | } 147 | val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(parameters)) 148 | nebulaOptions 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License.
4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.NebulaOptions 9 | import org.apache.spark.sql.types.StructType 10 | 11 | class NebulaEdgePartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType) 12 | extends NebulaPartitionReader(index, nebulaOptions, schema) { 13 | 14 | override def next(): Boolean = hasNextEdgeRow 15 | } 16 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlEdgePartitionReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.Value 9 | import com.vesoft.nebula.client.graph.data.{Relationship, ResultSet, ValueWrapper} 10 | import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter 11 | import com.vesoft.nebula.connector.nebula.GraphProvider 12 | import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils} 13 | import org.apache.spark.sql.catalyst.InternalRow 14 | import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow 15 | import org.apache.spark.sql.sources.v2.reader.InputPartitionReader 16 | import org.apache.spark.sql.types.StructType 17 | import org.slf4j.{Logger, LoggerFactory} 18 | 19 | import scala.collection.JavaConversions.asScalaBuffer 20 | import scala.collection.mutable 21 | import scala.collection.mutable.ListBuffer 22 | 23 | /** 24 | * create reader by ngql 25 | */ 26 | class NebulaNgqlEdgePartitionReader extends InputPartitionReader[InternalRow] { 27 | 28 | private val LOG: Logger = LoggerFactory.getLogger(this.getClass) 29 | 30 | private var nebulaOptions: NebulaOptions = _ 31 | private var graphProvider: GraphProvider = _ 32 | private var schema: StructType = _ 33 | private var resultSet: ResultSet = _ 34 | private var edgeIterator: Iterator[ListBuffer[ValueWrapper]] = _ 35 | 36 | def this(nebulaOptions: NebulaOptions, schema: StructType) { 37 | this() 38 | this.schema = schema 39 | this.nebulaOptions = nebulaOptions 40 | this.graphProvider = new GraphProvider( 41 | nebulaOptions.getGraphAddress, 42 | nebulaOptions.user, 43 | nebulaOptions.passwd, 44 | nebulaOptions.timeout, 45 | nebulaOptions.enableGraphSSL, 46 | nebulaOptions.sslSignType, 47 | nebulaOptions.caSignParam, 48 | nebulaOptions.selfSignParam 49 | ) 50 | // add exception when session build failed 51 | graphProvider.switchSpace(nebulaOptions.spaceName) 52 | resultSet = graphProvider.submit(nebulaOptions.ngql) 53 | edgeIterator = query() 54 | } 55 | 56 | def query(): Iterator[ListBuffer[ValueWrapper]] = { 57 | val edges: ListBuffer[ListBuffer[ValueWrapper]] = new ListBuffer[ListBuffer[ValueWrapper]] 58 | val properties = nebulaOptions.getReturnCols 59 | for (i <- 0 until resultSet.rowsSize()) { 60 | val rowValues = resultSet.rowValues(i).values() 61 | for (j <- 0 until rowValues.size()) { 62 | val value = rowValues.get(j) 63 | val valueType = value.getValue.getSetField 64 | if (valueType == Value.EVAL) { 65 | val relationship = value.asRelationship() 66 | if (checkLabel(relationship)) { 67 | edges.append(convertToEdge(relationship, properties)) 68 | } 69 | } else if (valueType == Value.LVAL) { 70 | val list: mutable.Buffer[ValueWrapper] = value.asList() 71 | edges.appendAll( 72 | list.toStream 73 | .filter(e => e != null && e.isEdge() && 
checkLabel(e.asRelationship())) 74 | .map(e => convertToEdge(e.asRelationship(), properties)) 75 | ) 76 | } else if (valueType == Value.PVAL){ 77 | val list: java.util.List[Relationship] = value.asPath().getRelationships() 78 | edges.appendAll( 79 | list.toStream 80 | .filter(e => checkLabel(e)) 81 | .map(e => convertToEdge(e, properties)) 82 | ) 83 | } else if (valueType != Value.NVAL && valueType != 0) { 84 | LOG.error(s"Unexpected edge type encountered: ${valueType}. Only edge or path should be returned.") 85 | throw new RuntimeException("Invalid nGQL return type. Value type conversion failed."); 86 | } 87 | } 88 | } 89 | edges.iterator 90 | } 91 | 92 | def checkLabel(relationship: Relationship): Boolean = { 93 | this.nebulaOptions.label.equals(relationship.edgeName()) 94 | } 95 | 96 | def convertToEdge(relationship: Relationship, 97 | properties: List[String]): ListBuffer[ValueWrapper] = { 98 | val edge: ListBuffer[ValueWrapper] = new ListBuffer[ValueWrapper] 99 | edge.append(relationship.srcId()) 100 | edge.append(relationship.dstId()) 101 | edge.append(new ValueWrapper(new Value(Value.IVAL, relationship.ranking()), "utf-8")) 102 | if (properties == null || properties.isEmpty) 103 | return edge 104 | else { 105 | for (i <- properties.indices) { 106 | edge.append(relationship.properties().get(properties(i))) 107 | } 108 | } 109 | edge 110 | } 111 | 112 | override def next(): Boolean = { 113 | edgeIterator.hasNext 114 | } 115 | 116 | override def get(): InternalRow = { 117 | val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema) 118 | val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) 119 | 120 | val edge = edgeIterator.next(); 121 | for (i <- getters.indices) { 122 | val value: ValueWrapper = edge(i) 123 | var resolved = false 124 | if (value.isNull) { 125 | mutableRow.setNullAt(i) 126 | resolved = true 127 | } 128 | if (value.isString) { 129 | getters(i).apply(value.asString(), mutableRow, i) 130 | resolved = true 131 | } 132 | if (value.isDate) { 133 | getters(i).apply(value.asDate(), mutableRow, i) 134 | resolved = true 135 | } 136 | if (value.isTime) { 137 | getters(i).apply(value.asTime(), mutableRow, i) 138 | resolved = true 139 | } 140 | if (value.isDateTime) { 141 | getters(i).apply(value.asDateTime(), mutableRow, i) 142 | resolved = true 143 | } 144 | if (value.isLong) { 145 | getters(i).apply(value.asLong(), mutableRow, i) 146 | } 147 | if (value.isBoolean) { 148 | getters(i).apply(value.asBoolean(), mutableRow, i) 149 | } 150 | if (value.isDouble) { 151 | getters(i).apply(value.asDouble(), mutableRow, i) 152 | } 153 | if (value.isGeography) { 154 | getters(i).apply(value.asGeography(), mutableRow, i) 155 | } 156 | if (value.isDuration) { 157 | getters(i).apply(value.asDuration(), mutableRow, i) 158 | } 159 | } 160 | mutableRow 161 | 162 | } 163 | 164 | override def close(): Unit = { 165 | graphProvider.close(); 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartition.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.NebulaOptions 9 | import org.apache.spark.sql.catalyst.InternalRow 10 | import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} 11 | import org.apache.spark.sql.types.StructType 12 | 13 | class NebulaVertexPartition(index: Int, nebulaOptions: NebulaOptions, schema: StructType) 14 | extends InputPartition[InternalRow] { 15 | override def createPartitionReader(): InputPartitionReader[InternalRow] = 16 | new NebulaVertexPartitionReader(index, nebulaOptions, schema) 17 | } 18 | 19 | class NebulaEdgePartition(index: Int, nebulaOptions: NebulaOptions, schema: StructType) 20 | extends InputPartition[InternalRow] { 21 | override def createPartitionReader(): InputPartitionReader[InternalRow] = 22 | new NebulaEdgePartitionReader(index, nebulaOptions, schema) 23 | } 24 | 25 | class NebulaNgqlEdgePartition(nebulaOptions: NebulaOptions, schema: StructType) 26 | extends InputPartition[InternalRow] { 27 | override def createPartitionReader(): InputPartitionReader[InternalRow] = 28 | new NebulaNgqlEdgePartitionReader(nebulaOptions, schema) 29 | } -------------------------------------------------------------------------------- /nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.{NebulaOptions, PartitionUtils} 9 | import org.apache.spark.sql.catalyst.InternalRow 10 | import org.apache.spark.sql.sources.v2.reader.InputPartitionReader 11 | import org.apache.spark.sql.types.StructType 12 | import org.slf4j.{Logger, LoggerFactory} 13 | 14 | /** 15 | * Read nebula data for each spark partition 16 | */ 17 | abstract class NebulaPartitionReader extends InputPartitionReader[InternalRow] with NebulaReader { 18 | private val LOG: Logger = LoggerFactory.getLogger(this.getClass) 19 | 20 | /** 21 | * @param index identifier for spark partition 22 | * @param nebulaOptions nebula Options 23 | * @param schema of data need to read 24 | */ 25 | def this(index: Int, nebulaOptions: NebulaOptions, schema: StructType) { 26 | this() 27 | val totalPart = super.init(index, nebulaOptions, schema) 28 | // index starts with 1 29 | val scanParts = PartitionUtils.getScanParts(index, totalPart, nebulaOptions.partitionNums.toInt) 30 | LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") 31 | scanPartIterator = scanParts.iterator 32 | } 33 | 34 | override def get(): InternalRow = super.getRow() 35 | 36 | override def close(): Unit = { 37 | super.closeReader() 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaSourceReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import java.util 9 | 10 | import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils} 11 | import org.apache.spark.sql.catalyst.InternalRow 12 | import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition} 13 | import org.apache.spark.sql.types.{StructType} 14 | import org.slf4j.LoggerFactory 15 | 16 | import scala.collection.JavaConverters._ 17 | 18 | /** 19 | * Base class of Nebula Source Reader 20 | */ 21 | abstract class NebulaSourceReader(nebulaOptions: NebulaOptions) extends DataSourceReader { 22 | private val LOG = LoggerFactory.getLogger(this.getClass) 23 | 24 | private var datasetSchema: StructType = _ 25 | 26 | override def readSchema(): StructType = { 27 | if (datasetSchema == null) { 28 | datasetSchema = NebulaUtils.getSchema(nebulaOptions) 29 | } 30 | 31 | LOG.info(s"dataset's schema: $datasetSchema") 32 | datasetSchema 33 | } 34 | 35 | protected def getSchema: StructType = 36 | if (datasetSchema == null) NebulaUtils.getSchema(nebulaOptions) else datasetSchema 37 | } 38 | 39 | /** 40 | * DataSourceReader for Nebula Vertex 41 | */ 42 | class NebulaDataSourceVertexReader(nebulaOptions: NebulaOptions) 43 | extends NebulaSourceReader(nebulaOptions) { 44 | 45 | override def planInputPartitions(): util.List[InputPartition[InternalRow]] = { 46 | val partitionNum = nebulaOptions.partitionNums.toInt 47 | val partitions = for (index <- 1 to partitionNum) 48 | yield { 49 | new NebulaVertexPartition(index, nebulaOptions, getSchema) 50 | } 51 | partitions.map(_.asInstanceOf[InputPartition[InternalRow]]).asJava 52 | } 53 | } 54 | 55 | /** 56 | * DataSourceReader for Nebula Edge 57 | */ 58 | class NebulaDataSourceEdgeReader(nebulaOptions: NebulaOptions) 59 | extends NebulaSourceReader(nebulaOptions) { 60 | 61 | override def planInputPartitions(): util.List[InputPartition[InternalRow]] = { 62 | val partitionNum = nebulaOptions.partitionNums.toInt 63 | val partitions = for (index <- 1 to partitionNum) 64 | yield new NebulaEdgePartition(index, nebulaOptions, getSchema) 65 | 66 | partitions.map(_.asInstanceOf[InputPartition[InternalRow]]).asJava 67 | } 68 | } 69 | 70 | /** 71 | * DataSourceReader for Nebula Edge by ngql 72 | */ 73 | class NebulaDataSourceNgqlEdgeReader(nebulaOptions: NebulaOptions) 74 | extends NebulaSourceReader(nebulaOptions) { 75 | 76 | override def planInputPartitions(): util.List[InputPartition[InternalRow]] = { 77 | val partitions = new util.ArrayList[InputPartition[InternalRow]]() 78 | partitions.add(new NebulaNgqlEdgePartition(nebulaOptions, getSchema)) 79 | partitions 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.NebulaOptions 9 | import org.apache.spark.sql.types.StructType 10 | 11 | class NebulaVertexPartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType) 12 | extends NebulaPartitionReader(index, nebulaOptions, schema) { 13 | 14 | override def next(): Boolean = hasNextVertexRow 15 | } 16 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaCommitMessage.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage 9 | 10 | case class NebulaCommitMessage(executeStatements: List[String]) extends WriterCommitMessage 11 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.connector.{NebulaEdge, NebulaEdges} 9 | import com.vesoft.nebula.connector.{KeyPolicy, NebulaOptions, WriteMode} 10 | import org.apache.spark.sql.catalyst.InternalRow 11 | import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} 12 | import org.apache.spark.sql.types.StructType 13 | import org.slf4j.LoggerFactory 14 | 15 | import scala.collection.mutable.ListBuffer 16 | 17 | class NebulaEdgeWriter(nebulaOptions: NebulaOptions, 18 | srcIndex: Int, 19 | dstIndex: Int, 20 | rankIndex: Option[Int], 21 | schema: StructType) 22 | extends NebulaWriter(nebulaOptions) 23 | with DataWriter[InternalRow] { 24 | 25 | private val LOG = LoggerFactory.getLogger(this.getClass) 26 | 27 | val rankIdx = if (rankIndex.isDefined) rankIndex.get else -1 28 | val propNames = NebulaExecutor.assignEdgePropNames(schema, 29 | srcIndex, 30 | dstIndex, 31 | rankIdx, 32 | nebulaOptions.srcAsProp, 33 | nebulaOptions.dstAsProp, 34 | nebulaOptions.rankAsProp) 35 | val fieldTypMap: Map[String, Integer] = 36 | if (nebulaOptions.writeMode == WriteMode.DELETE) Map[String, Integer]() 37 | else metaProvider.getEdgeSchema(nebulaOptions.spaceName, nebulaOptions.label) 38 | 39 | val srcPolicy = 40 | if (nebulaOptions.srcPolicy.isEmpty) Option.empty 41 | else Option(KeyPolicy.withName(nebulaOptions.srcPolicy)) 42 | val dstPolicy = { 43 | if (nebulaOptions.dstPolicy.isEmpty) Option.empty 44 | else Option(KeyPolicy.withName(nebulaOptions.dstPolicy)) 45 | } 46 | 47 | /** buffer to save batch edges */ 48 | var edges: ListBuffer[NebulaEdge] = new ListBuffer() 49 | 50 | prepareSpace() 51 | 52 | /** 53 | * write one edge record to buffer 54 | */ 55 | override def write(row: InternalRow): Unit = { 56 | val srcId = NebulaExecutor.extraID(schema, row, srcIndex, srcPolicy, isVidStringType) 57 | val dstId = NebulaExecutor.extraID(schema, row, dstIndex, dstPolicy, isVidStringType) 58 | val rank = 59 | if (rankIndex.isEmpty) Option.empty 60 | else Option(NebulaExecutor.extraRank(schema, row, rankIndex.get)) 61 | val values = 62 | if 
(nebulaOptions.writeMode == WriteMode.DELETE) List() 63 | else 64 | NebulaExecutor.assignEdgeValues(schema, 65 | row, 66 | srcIndex, 67 | dstIndex, 68 | rankIdx, 69 | nebulaOptions.srcAsProp, 70 | nebulaOptions.dstAsProp, 71 | nebulaOptions.rankAsProp, 72 | fieldTypMap) 73 | val nebulaEdge = NebulaEdge(srcId, dstId, rank, values) 74 | edges.append(nebulaEdge) 75 | if (edges.size >= nebulaOptions.batch) { 76 | execute() 77 | } 78 | } 79 | 80 | /** 81 | * submit buffer edges to nebula 82 | */ 83 | def execute(): Unit = { 84 | val nebulaEdges = NebulaEdges(propNames, edges.toList, srcPolicy, dstPolicy) 85 | val exec = nebulaOptions.writeMode match { 86 | case WriteMode.INSERT => 87 | NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaEdges, nebulaOptions.overwrite) 88 | case WriteMode.UPDATE => 89 | NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaEdges) 90 | case WriteMode.DELETE => 91 | NebulaExecutor.toDeleteExecuteStatement(nebulaOptions.label, nebulaEdges) 92 | case _ => 93 | throw new IllegalArgumentException(s"write mode ${nebulaOptions.writeMode} not supported.") 94 | } 95 | edges.clear() 96 | submit(exec) 97 | } 98 | 99 | override def commit(): WriterCommitMessage = { 100 | if (edges.nonEmpty) { 101 | execute() 102 | } 103 | graphProvider.close() 104 | metaProvider.close() 105 | NebulaCommitMessage.apply(failedExecs.toList) 106 | } 107 | 108 | override def abort(): Unit = { 109 | LOG.error("insert edge task abort.") 110 | graphProvider.close() 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaSourceWriter.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.connector.NebulaOptions 9 | import org.apache.spark.TaskContext 10 | import org.apache.spark.sql.catalyst.InternalRow 11 | import org.apache.spark.sql.sources.v2.writer.{ 12 | DataSourceWriter, 13 | DataWriter, 14 | DataWriterFactory, 15 | WriterCommitMessage 16 | } 17 | import org.apache.spark.sql.types.StructType 18 | import org.slf4j.LoggerFactory 19 | 20 | /** 21 | * creating and initializing the actual Nebula vertex writer at executor side 22 | */ 23 | class NebulaVertexWriterFactory(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType) 24 | extends DataWriterFactory[InternalRow] { 25 | override def createDataWriter(partitionId: Int, 26 | taskId: Long, 27 | epochId: Long): DataWriter[InternalRow] = { 28 | new NebulaVertexWriter(nebulaOptions, vertexIndex, schema) 29 | } 30 | } 31 | 32 | /** 33 | * creating and initializing the actual Nebula edge writer at executor side 34 | */ 35 | class NebulaEdgeWriterFactory(nebulaOptions: NebulaOptions, 36 | srcIndex: Int, 37 | dstIndex: Int, 38 | rankIndex: Option[Int], 39 | schema: StructType) 40 | extends DataWriterFactory[InternalRow] { 41 | override def createDataWriter(partitionId: Int, 42 | taskId: Long, 43 | epochId: Long): DataWriter[InternalRow] = { 44 | new NebulaEdgeWriter(nebulaOptions, srcIndex, dstIndex, rankIndex, schema) 45 | } 46 | } 47 | 48 | /** 49 | * nebula vertex writer to create factory 50 | */ 51 | class NebulaDataSourceVertexWriter(nebulaOptions: NebulaOptions, 52 | vertexIndex: Int, 53 | schema: StructType) 54 | extends DataSourceWriter { 55 | private val LOG = LoggerFactory.getLogger(this.getClass) 56 | 57 | override def createWriterFactory(): DataWriterFactory[InternalRow] = { 58 | new NebulaVertexWriterFactory(nebulaOptions, vertexIndex, schema) 59 | } 60 | 61 | override def commit(messages: Array[WriterCommitMessage]): Unit = { 62 | LOG.debug(s"${messages.length}") 63 | for (msg <- messages) { 64 | val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage] 65 | if (nebulaMsg.executeStatements.nonEmpty) { 66 | LOG.error(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}") 67 | } else { 68 | LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed") 69 | } 70 | } 71 | } 72 | 73 | override def abort(messages: Array[WriterCommitMessage]): Unit = { 74 | LOG.error("NebulaDataSourceVertexWriter abort") 75 | } 76 | } 77 | 78 | /** 79 | * nebula edge writer to create factory 80 | */ 81 | class NebulaDataSourceEdgeWriter(nebulaOptions: NebulaOptions, 82 | srcIndex: Int, 83 | dstIndex: Int, 84 | rankIndex: Option[Int], 85 | schema: StructType) 86 | extends DataSourceWriter { 87 | private val LOG = LoggerFactory.getLogger(this.getClass) 88 | 89 | override def createWriterFactory(): DataWriterFactory[InternalRow] = { 90 | new NebulaEdgeWriterFactory(nebulaOptions, srcIndex, dstIndex, rankIndex, schema) 91 | } 92 | 93 | override def commit(messages: Array[WriterCommitMessage]): Unit = { 94 | LOG.debug(s"${messages.length}") 95 | for (msg <- messages) { 96 | val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage] 97 | if (nebulaMsg.executeStatements.nonEmpty) { 98 | LOG.error(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}") 99 | } else { 100 | LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed") 101 | } 102 | } 103 | 104 | } 105 | 106 | override def abort(messages: Array[WriterCommitMessage]): Unit = { 107 | LOG.error("NebulaDataSourceEdgeWriter abort") 108 | } 
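  // abort() only logs the failure; per-partition statements that could not be executed are
  // carried back in NebulaCommitMessage.executeStatements and reported by commit() above.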
109 | } 110 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.connector.{ 9 | KeyPolicy, 10 | NebulaOptions, 11 | NebulaVertex, 12 | NebulaVertices, 13 | WriteMode 14 | } 15 | import org.apache.spark.sql.catalyst.InternalRow 16 | import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} 17 | import org.apache.spark.sql.types.StructType 18 | import org.slf4j.LoggerFactory 19 | 20 | import scala.collection.mutable.ListBuffer 21 | 22 | class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType) 23 | extends NebulaWriter(nebulaOptions) 24 | with DataWriter[InternalRow] { 25 | 26 | private val LOG = LoggerFactory.getLogger(this.getClass) 27 | 28 | val propNames = NebulaExecutor.assignVertexPropNames(schema, vertexIndex, nebulaOptions.vidAsProp) 29 | val fieldTypMap: Map[String, Integer] = 30 | if (nebulaOptions.writeMode == WriteMode.DELETE) Map[String, Integer]() 31 | else metaProvider.getTagSchema(nebulaOptions.spaceName, nebulaOptions.label) 32 | 33 | val policy = { 34 | if (nebulaOptions.vidPolicy.isEmpty) Option.empty 35 | else Option(KeyPolicy.withName(nebulaOptions.vidPolicy)) 36 | } 37 | 38 | /** buffer to save batch vertices */ 39 | var vertices: ListBuffer[NebulaVertex] = new ListBuffer() 40 | 41 | prepareSpace() 42 | 43 | /** 44 | * write one vertex row to buffer 45 | */ 46 | override def write(row: InternalRow): Unit = { 47 | val vertex = 48 | NebulaExecutor.extraID(schema, row, vertexIndex, policy, isVidStringType) 49 | val values = 50 | if (nebulaOptions.writeMode == WriteMode.DELETE) List() 51 | else 52 | NebulaExecutor.assignVertexPropValues(schema, 53 | row, 54 | vertexIndex, 55 | nebulaOptions.vidAsProp, 56 | fieldTypMap) 57 | val nebulaVertex = NebulaVertex(vertex, values) 58 | vertices.append(nebulaVertex) 59 | if (vertices.size >= nebulaOptions.batch) { 60 | execute() 61 | } 62 | } 63 | 64 | /** 65 | * submit buffer vertices to nebula 66 | */ 67 | def execute(): Unit = { 68 | val nebulaVertices = NebulaVertices(propNames, vertices.toList, policy) 69 | val exec = nebulaOptions.writeMode match { 70 | case WriteMode.INSERT => 71 | NebulaExecutor.toExecuteSentence(nebulaOptions.label, 72 | nebulaVertices, 73 | nebulaOptions.overwrite) 74 | case WriteMode.UPDATE => 75 | NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaVertices) 76 | case WriteMode.DELETE => 77 | NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, nebulaOptions.deleteEdge) 78 | case _ => 79 | throw new IllegalArgumentException(s"write mode ${nebulaOptions.writeMode} not supported.") 80 | } 81 | vertices.clear() 82 | submit(exec) 83 | } 84 | 85 | override def commit(): WriterCommitMessage = { 86 | if (vertices.nonEmpty) { 87 | execute() 88 | } 89 | graphProvider.close() 90 | metaProvider.close() 91 | NebulaCommitMessage(failedExecs.toList) 92 | } 93 | 94 | override def abort(): Unit = { 95 | LOG.error("insert vertex task abort.") 96 | graphProvider.close() 97 | } 98 | } 99 | -------------------------------------------------------------------------------- 
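Neither the partition readers nor the writers above are meant to be instantiated directly; both sides are reached through Spark's DataFrame API. The write path is exercised by SparkMock further below, so only the read path is sketched here. This is a minimal sketch, assuming the implicit NebulaDataFrameReader syntax exposed by the connector's package object and a locally running NebulaGraph; the addresses, space and tag names are illustrative, and unlisted options keep their defaults.

import com.vesoft.nebula.connector.connector.NebulaDataFrameReader
import com.vesoft.nebula.connector.{NebulaConnectionConfig, ReadNebulaConfig}
import org.apache.spark.sql.SparkSession

object NebulaReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()

    // scanning vertices only requires the meta address
    val config = NebulaConnectionConfig
      .builder()
      .withMetaAddress("127.0.0.1:9559")
      .withConnectionRetry(2)
      .build()

    // read the person_connector vertices of space test_write_string, returning col1 only
    val readConfig = ReadNebulaConfig
      .builder()
      .withSpace("test_write_string")
      .withLabel("person_connector")
      .withNoColumn(false)
      .withReturnCols(List("col1"))
      .build()

    // Nebula parts are mapped onto Spark partitions by NebulaSourceReader / PartitionUtils
    val vertexDf = spark.read.nebula(config, readConfig).loadVerticesToDF()
    vertexDf.show()
    spark.stop()
  }
}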
/nebula-spark-connector/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import java.util.concurrent.TimeUnit 9 | 10 | import com.google.common.util.concurrent.RateLimiter 11 | import com.vesoft.nebula.connector.NebulaOptions 12 | import com.vesoft.nebula.connector.nebula.{GraphProvider, MetaProvider, VidType} 13 | import org.slf4j.LoggerFactory 14 | 15 | import scala.collection.mutable.ListBuffer 16 | 17 | class NebulaWriter(nebulaOptions: NebulaOptions) extends Serializable { 18 | private val LOG = LoggerFactory.getLogger(this.getClass) 19 | 20 | val failedExecs: ListBuffer[String] = new ListBuffer[String] 21 | 22 | val metaProvider = new MetaProvider( 23 | nebulaOptions.getMetaAddress, 24 | nebulaOptions.timeout, 25 | nebulaOptions.connectionRetry, 26 | nebulaOptions.executionRetry, 27 | nebulaOptions.enableMetaSSL, 28 | nebulaOptions.sslSignType, 29 | nebulaOptions.caSignParam, 30 | nebulaOptions.selfSignParam 31 | ) 32 | val graphProvider = new GraphProvider( 33 | nebulaOptions.getGraphAddress, 34 | nebulaOptions.user, 35 | nebulaOptions.passwd, 36 | nebulaOptions.timeout, 37 | nebulaOptions.enableGraphSSL, 38 | nebulaOptions.sslSignType, 39 | nebulaOptions.caSignParam, 40 | nebulaOptions.selfSignParam 41 | ) 42 | val isVidStringType = metaProvider.getVidType(nebulaOptions.spaceName) == VidType.STRING 43 | 44 | def prepareSpace(): Unit = { 45 | graphProvider.switchSpace(nebulaOptions.spaceName) 46 | } 47 | 48 | def submit(exec: String): Unit = { 49 | @transient val rateLimiter = RateLimiter.create(nebulaOptions.rateLimit) 50 | if (rateLimiter.tryAcquire(nebulaOptions.rateTimeOut, TimeUnit.MILLISECONDS)) { 51 | val result = graphProvider.submit(exec) 52 | if (!result.isSucceeded) { 53 | failedExecs.append(exec) 54 | if (nebulaOptions.disableWriteLog) { 55 | LOG.error(s"write failed: ${result.getErrorMessage}") 56 | } else { 57 | LOG.error(s"write failed: ${result.getErrorMessage} failed statement: \n ${exec}") 58 | } 59 | } else { 60 | LOG.info(s"batch write succeeded") 61 | LOG.debug(s"batch write succeeded: ${exec}") 62 | } 63 | } else { 64 | failedExecs.append(exec) 65 | LOG.error(s"failed to acquire rateLimiter for statement ${exec}") 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/test/resources/edge.csv: -------------------------------------------------------------------------------- 1 | id1,id2,col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14 2 | 1,2,Tom,tom,10,20,30,40,2021-01-27,2021-01-01T12:10:10,43535232,true,1.0,2.0,10:10:10,POINT(1 2) 3 | 2,3,Jina,Jina,11,21,31,41,2021-01-28,2021-01-02T12:10:10,43535232,false,1.1,2.1,11:10:10,POINT(3 4) 4 | 3,4,Tim,Tim,12,22,32,42,2021-01-29,2021-01-03T12:10:10,43535232,false,1.2,2.2,12:10:10,POINT(5 6) 5 | 4,5,张三,张三,13,23,33,43,2021-01-30,2021-01-04T12:10:10,43535232,true,1.3,2.3,13:10:10,POINT(6 7) 6 | 5,6,李四,李四,14,24,34,44,2021-02-01,2021-01-05T12:10:10,43535232,false,1.4,2.4,14:10:10,POINT(1 5) 7 | 6,7,王五,王五,15,25,35,45,2021-02-02,2021-01-06T12:10:10,0,false,1.5,2.5,15:10:10,"LINESTRING(1 3, 4.7 73.23)" 8 | 7,1,Jina,Jina,16,26,36,46,2021-02-03,2021-01-07T12:10:10,43535232,true,1.6,2.6,16:10:10,"LINESTRING(1 3, 4.7 73.23)" 9 |
8,1,Jina,Jina,17,27,37,47,2021-02-04,2021-01-08T12:10:10,43535232,false,1.7,2.7,17:10:10,"LINESTRING(1 3, 4.7 73.23)" 10 | 9,1,Jina,Jina,18,28,38,48,2021-02-05,2021-01-09T12:10:10,43535232,true,1.8,2.8,18:10:10,"LINESTRING(1 3, 4.7 73.23)" 11 | 10,2,Jina,Jina,19,29,39,49,2021-02-06,2021-01-10T12:10:10,43535232,false,1.9,2.9,19:10:10,"LINESTRING(1 3, 4.7 73.23)" 12 | -1,5,Jina,Jina,20,30,40,50,2021-02-07,2021-02-11T12:10:10,43535232,false,2.0,3.0,20:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" 13 | -2,6,Jina,Jina,21,31,41,51,2021-02-08,2021-03-12T12:10:10,43535232,false,2.1,3.1,21:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" 14 | -3,7,Jina,Jina,22,32,42,52,2021-02-09,2021-04-13T12:10:10,43535232,false,2.2,3.2,22:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" 15 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Global logging configuration 2 | log4j.rootLogger=INFO, stdout 3 | # Console output... 4 | log4j.appender.stdout=org.apache.log4j.ConsoleAppender 5 | log4j.appender.stdout.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.stdout.layout.ConversionPattern=%5p [%t] - %m%n 7 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/test/resources/vertex.csv: -------------------------------------------------------------------------------- 1 | id,col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14,col15 2 | 1,Tom,tom,10,20,30,40,2021-01-27,2021-01-01T12:10:10,43535232,true,1.0,2.0,10:10:10,POINT(1 2),"duration({years:1,months:1,seconds:1})" 3 | 2,Jina,Jina,11,21,31,41,2021-01-28,2021-01-02T12:10:10,43535232,false,1.1,2.1,11:10:10,POINT(3 4),"duration({years:1,months:1,seconds:1})" 4 | 3,Tim,Tim,12,22,32,42,2021-01-29,2021-01-03T12:10:10,43535232,false,1.2,2.2,12:10:10,POINT(5 6),"duration({years:1,months:1,seconds:1})" 5 | 4,张三,张三,13,23,33,43,2021-01-30,2021-01-04T12:10:10,43535232,true,1.3,2.3,13:10:10,POINT(6 7),"duration({years:1,months:1,seconds:1})" 6 | 5,李四,李四,14,24,34,44,2021-02-01,2021-01-05T12:10:10,43535232,false,1.4,2.4,14:10:10,POINT(1 5),"duration({years:1,months:1,seconds:1})" 7 | 6,王五,王五,15,25,35,45,2021-02-02,2021-01-06T12:10:10,0,false,1.5,2.5,15:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 8 | 7,Jina,Jina,16,26,36,46,2021-02-03,2021-01-07T12:10:10,43535232,true,1.6,2.6,16:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 9 | 8,Jina,Jina,17,27,37,47,2021-02-04,2021-01-08T12:10:10,43535232,false,1.7,2.7,17:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 10 | 9,Jina,Jina,18,28,38,48,2021-02-05,2021-01-09T12:10:10,43535232,true,1.8,2.8,18:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 11 | 10,Jina,Jina,19,29,39,49,2021-02-06,2021-01-10T12:10:10,43535232,false,1.9,2.9,19:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 12 | -1,Jina,Jina,20,30,40,50,2021-02-07,2021-02-11T12:10:10,43535232,false,2.0,3.0,20:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})" 13 | -2,Jina,Jina,21,31,41,51,2021-02-08,2021-03-12T12:10:10,43535232,false,2.1,3.1,21:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})" 14 | -3,Jina,Jina,22,32,42,52,2021-02-09,2021-04-13T12:10:10,43535232,false,2.2,3.2,22:10:10,"POLYGON((0 1, 1 2, 2 3, 0 
1))","duration({years:1,months:1,seconds:1})" 15 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import com.vesoft.nebula.connector.utils.SparkValidate 9 | import org.apache.spark.sql.SparkSession 10 | import org.scalatest.funsuite.AnyFunSuite 11 | 12 | class SparkVersionValidateSuite extends AnyFunSuite { 13 | test("spark version validate") { 14 | try { 15 | val version = SparkSession.getActiveSession.map(_.version).getOrElse("UNKNOWN") 16 | SparkValidate.validate("2.4.*") 17 | } catch { 18 | case e: Exception => assert(false) 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/mock/SparkMock.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.mock 7 | 8 | import com.facebook.thrift.protocol.TCompactProtocol 9 | import com.vesoft.nebula.connector.{ 10 | NebulaConnectionConfig, 11 | WriteMode, 12 | WriteNebulaEdgeConfig, 13 | WriteNebulaVertexConfig 14 | } 15 | import com.vesoft.nebula.connector.connector.NebulaDataFrameWriter 16 | import org.apache.spark.SparkConf 17 | import org.apache.spark.sql.SparkSession 18 | 19 | object SparkMock { 20 | 21 | /** 22 | * write nebula vertex with insert mode 23 | */ 24 | def writeVertex(): Unit = { 25 | val sparkConf = new SparkConf 26 | sparkConf 27 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 28 | .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) 29 | val spark = SparkSession 30 | .builder() 31 | .master("local") 32 | .config(sparkConf) 33 | .getOrCreate() 34 | 35 | val df = spark.read 36 | .option("header", true) 37 | .csv("src/test/resources/vertex.csv") 38 | 39 | val config = 40 | NebulaConnectionConfig 41 | .builder() 42 | .withMetaAddress("127.0.0.1:9559") 43 | .withGraphAddress("127.0.0.1:9669") 44 | .withConnectionRetry(2) 45 | .build() 46 | val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig 47 | .builder() 48 | .withSpace("test_write_string") 49 | .withTag("person_connector") 50 | .withVidField("id") 51 | .withVidAsProp(false) 52 | .withBatch(5) 53 | .build() 54 | df.write.nebula(config, nebulaWriteVertexConfig).writeVertices() 55 | 56 | spark.stop() 57 | } 58 | 59 | /** 60 | * write nebula vertex with delete mode 61 | */ 62 | def deleteVertex(): Unit = { 63 | val sparkConf = new SparkConf 64 | sparkConf 65 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 66 | .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) 67 | val spark = SparkSession 68 | .builder() 69 | .master("local") 70 | .config(sparkConf) 71 | .getOrCreate() 72 | 73 | val df = spark.read 74 | .option("header", true) 75 | .csv("src/test/resources/vertex.csv") 76 | 77 | val config = 78 | NebulaConnectionConfig 79 | .builder() 80 | .withMetaAddress("127.0.0.1:9559") 81 | .withGraphAddress("127.0.0.1:9669") 82 | .withConnectionRetry(2) 83 | .build() 84 | 
val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig 85 | .builder() 86 | .withSpace("test_write_string") 87 | .withTag("person_connector") 88 | .withVidField("id") 89 | .withVidAsProp(false) 90 | .withWriteMode(WriteMode.DELETE) 91 | .withBatch(5) 92 | .build() 93 | df.write.nebula(config, nebulaWriteVertexConfig).writeVertices() 94 | 95 | spark.stop() 96 | } 97 | 98 | /** 99 | * write nebula edge with insert mode 100 | */ 101 | def writeEdge(): Unit = { 102 | val sparkConf = new SparkConf 103 | sparkConf 104 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 105 | .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) 106 | val spark = SparkSession 107 | .builder() 108 | .master("local") 109 | .config(sparkConf) 110 | .getOrCreate() 111 | 112 | val df = spark.read 113 | .option("header", true) 114 | .csv("src/test/resources/edge.csv") 115 | 116 | val config = 117 | NebulaConnectionConfig 118 | .builder() 119 | .withMetaAddress("127.0.0.1:9559") 120 | .withGraphAddress("127.0.0.1:9669") 121 | .withConnectionRetry(2) 122 | .build() 123 | val nebulaWriteEdgeConfig: WriteNebulaEdgeConfig = WriteNebulaEdgeConfig 124 | .builder() 125 | .withSpace("test_write_string") 126 | .withEdge("friend_connector") 127 | .withSrcIdField("id1") 128 | .withDstIdField("id2") 129 | .withRankField("col3") 130 | .withRankAsProperty(true) 131 | .withBatch(5) 132 | .build() 133 | df.write.nebula(config, nebulaWriteEdgeConfig).writeEdges() 134 | 135 | spark.stop() 136 | } 137 | 138 | /** 139 | * write nebula edge with delete mode 140 | */ 141 | def deleteEdge(): Unit = { 142 | val sparkConf = new SparkConf 143 | sparkConf 144 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 145 | .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) 146 | val spark = SparkSession 147 | .builder() 148 | .master("local") 149 | .config(sparkConf) 150 | .getOrCreate() 151 | 152 | val df = spark.read 153 | .option("header", true) 154 | .csv("src/test/resources/edge.csv") 155 | 156 | val config = 157 | NebulaConnectionConfig 158 | .builder() 159 | .withMetaAddress("127.0.0.1:9559") 160 | .withGraphAddress("127.0.0.1:9669") 161 | .withConnectionRetry(2) 162 | .build() 163 | val nebulaWriteEdgeConfig: WriteNebulaEdgeConfig = WriteNebulaEdgeConfig 164 | .builder() 165 | .withSpace("test_write_string") 166 | .withEdge("friend_connector") 167 | .withSrcIdField("id1") 168 | .withDstIdField("id2") 169 | .withRankField("col3") 170 | .withRankAsProperty(true) 171 | .withWriteMode(WriteMode.DELETE) 172 | .withBatch(5) 173 | .build() 174 | df.write.nebula(config, nebulaWriteEdgeConfig).writeEdges() 175 | 176 | spark.stop() 177 | } 178 | 179 | /** 180 | * write nebula vertex with delete_with_edge mode 181 | */ 182 | def deleteVertexWithEdge(): Unit = { 183 | val sparkConf = new SparkConf 184 | sparkConf 185 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 186 | .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) 187 | val spark = SparkSession 188 | .builder() 189 | .master("local") 190 | .config(sparkConf) 191 | .getOrCreate() 192 | 193 | val df = spark.read 194 | .option("header", true) 195 | .csv("src/test/resources/vertex.csv") 196 | 197 | val config = 198 | NebulaConnectionConfig 199 | .builder() 200 | .withMetaAddress("127.0.0.1:9559") 201 | .withGraphAddress("127.0.0.1:9669") 202 | .withConnectionRetry(2) 203 | .build() 204 | val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig 205 | 
.builder() 206 | .withSpace("test_write_string") 207 | .withTag("person_connector") 208 | .withVidField("id") 209 | .withVidAsProp(false) 210 | .withWriteMode(WriteMode.DELETE) 211 | .withDeleteEdge(true) 212 | .withBatch(5) 213 | .build() 214 | df.write.nebula(config, nebulaWriteVertexConfig).writeVertices() 215 | 216 | spark.stop() 217 | } 218 | 219 | } 220 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.client.graph.data.ResultSet 9 | import com.vesoft.nebula.connector.Address 10 | import com.vesoft.nebula.connector.mock.{NebulaGraphMock, SparkMock} 11 | import com.vesoft.nebula.connector.nebula.GraphProvider 12 | import org.apache.log4j.BasicConfigurator 13 | import org.scalatest.BeforeAndAfterAll 14 | import org.scalatest.funsuite.AnyFunSuite 15 | 16 | class WriteDeleteSuite extends AnyFunSuite with BeforeAndAfterAll { 17 | BasicConfigurator.configure() 18 | 19 | override def beforeAll(): Unit = { 20 | val graphMock = new NebulaGraphMock 21 | graphMock.mockStringIdGraphSchema() 22 | graphMock.mockIntIdGraphSchema() 23 | graphMock.close() 24 | Thread.sleep(10000) 25 | SparkMock.writeVertex() 26 | SparkMock.writeEdge() 27 | } 28 | 29 | test("write vertex into test_write_string space with delete mode") { 30 | SparkMock.deleteVertex() 31 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 32 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 33 | 34 | graphProvider.switchSpace("test_write_string") 35 | val resultSet: ResultSet = 36 | graphProvider.submit("use test_write_string;" 37 | + "match (v:person_connector) return v limit 100000;") 38 | assert(resultSet.isSucceeded) 39 | assert(resultSet.getColumnNames.size() == 1) 40 | assert(resultSet.isEmpty) 41 | } 42 | 43 | test("write vertex into test_write_string space with delete_with_edge mode") { 44 | SparkMock.writeVertex() 45 | SparkMock.writeEdge() 46 | SparkMock.deleteVertexWithEdge() 47 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 48 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 49 | 50 | graphProvider.switchSpace("test_write_string") 51 | // assert vertex is deleted 52 | val vertexResultSet: ResultSet = 53 | graphProvider.submit("use test_write_string;" 54 | + "match (v:person_connector) return v limit 1000000;") 55 | assert(vertexResultSet.isSucceeded) 56 | assert(vertexResultSet.getColumnNames.size() == 1) 57 | assert(vertexResultSet.isEmpty) 58 | 59 | // assert edge is deleted 60 | val edgeResultSet: ResultSet = 61 | graphProvider.submit("use test_write_string;" 62 | + "fetch prop on friend_connector \"1\"->\"2\"@10 yield edge as e") 63 | assert(edgeResultSet.isSucceeded) 64 | assert(edgeResultSet.getColumnNames.size() == 1) 65 | assert(edgeResultSet.isEmpty) 66 | 67 | } 68 | 69 | test("write edge into test_write_string space with delete mode") { 70 | SparkMock.deleteEdge() 71 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 72 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 73 | 74 | graphProvider.switchSpace("test_write_string") 75 | val resultSet:
ResultSet = 76 | graphProvider.submit("use test_write_string;" 77 | + "fetch prop on friend_connector \"1\"->\"2\"@10 yield edge as e;") 78 | assert(resultSet.isSucceeded) 79 | assert(resultSet.getColumnNames.size() == 1) 80 | assert(resultSet.isEmpty) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /nebula-spark-connector/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.client.graph.data.ResultSet 9 | import com.vesoft.nebula.connector.Address 10 | import com.vesoft.nebula.connector.mock.{NebulaGraphMock, SparkMock} 11 | import com.vesoft.nebula.connector.nebula.GraphProvider 12 | import org.apache.log4j.BasicConfigurator 13 | import org.scalatest.BeforeAndAfterAll 14 | import org.scalatest.funsuite.AnyFunSuite 15 | 16 | class WriteInsertSuite extends AnyFunSuite with BeforeAndAfterAll { 17 | BasicConfigurator.configure() 18 | 19 | override def beforeAll(): Unit = { 20 | val graphMock = new NebulaGraphMock 21 | graphMock.mockStringIdGraphSchema() 22 | graphMock.mockIntIdGraphSchema() 23 | graphMock.close() 24 | Thread.sleep(10000) 25 | } 26 | 27 | test("write vertex into test_write_string space with insert mode") { 28 | SparkMock.writeVertex() 29 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 30 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 31 | 32 | graphProvider.switchSpace("test_write_string") 33 | val createIndexResult: ResultSet = graphProvider.submit( 34 | "use test_write_string; " 35 | + "create tag index if not exists person_index on person_connector(col1(20));") 36 | Thread.sleep(5000) 37 | graphProvider.submit("rebuild tag index person_index;") 38 | 39 | Thread.sleep(5000) 40 | 41 | graphProvider.submit("use test_write_string;") 42 | val resultSet: ResultSet = 43 | graphProvider.submit("match (v:person_connector) return v;") 44 | assert(resultSet.isSucceeded) 45 | assert(resultSet.getColumnNames.size() == 1) 46 | assert(resultSet.getRows.size() == 13) 47 | } 48 | 49 | test("write edge into test_write_string space with insert mode") { 50 | SparkMock.writeEdge() 51 | 52 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 53 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 54 | 55 | graphProvider.switchSpace("test_write_string") 56 | val createIndexResult: ResultSet = graphProvider.submit( 57 | "use test_write_string; " 58 | + "create edge index if not exists friend_index on friend_connector(col1(20));") 59 | Thread.sleep(5000) 60 | graphProvider.submit("rebuild edge index friend_index;") 61 | 62 | Thread.sleep(5000) 63 | 64 | graphProvider.submit("use test_write_string;") 65 | val resultSet: ResultSet = 66 | graphProvider.submit("match (v:person_connector)-[e:friend_connector]-> () return e;") 67 | assert(resultSet.isSucceeded) 68 | assert(resultSet.getColumnNames.size() == 1) 69 | assert(resultSet.getRows.size() == 13) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled class file 2 | *.class 3 | 4 | # Log file 5 | *.log 6 | 7 | # BlueJ 
files 8 | *.ctxt 9 | 10 | # Mobile Tools for Java (J2ME) 11 | .mtj.tmp/ 12 | 13 | # Package Files # 14 | *.jar 15 | *.war 16 | *.nar 17 | *.ear 18 | *.zip 19 | *.tar.gz 20 | *.rar 21 | 22 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 23 | hs_err_pid* 24 | 25 | # build target 26 | target/ 27 | 28 | # IDE 29 | .idea/ 30 | .eclipse/ 31 | *.iml 32 | 33 | spark-importer.ipr 34 | spark-importer.iws 35 | 36 | .DS_Store 37 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import com.vesoft.nebula.connector.exception.IllegalOptionException 9 | import com.vesoft.nebula.connector.reader.NebulaRelation 10 | import com.vesoft.nebula.connector.writer.{ 11 | NebulaCommitMessage, 12 | NebulaEdgeWriter, 13 | NebulaVertexWriter, 14 | NebulaWriter, 15 | NebulaWriterResultRelation 16 | } 17 | import org.apache.spark.TaskContext 18 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} 19 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap 20 | import org.apache.spark.sql.sources.{ 21 | BaseRelation, 22 | CreatableRelationProvider, 23 | DataSourceRegister, 24 | RelationProvider 25 | } 26 | import org.apache.spark.sql.types.StructType 27 | import org.slf4j.LoggerFactory 28 | import scala.collection.mutable 29 | 30 | class NebulaDataSource 31 | extends RelationProvider 32 | with CreatableRelationProvider 33 | with DataSourceRegister 34 | with Serializable { 35 | private val LOG = LoggerFactory.getLogger(this.getClass) 36 | 37 | /** 38 | * The string that represents the format that nebula data source provider uses. 39 | */ 40 | override def shortName(): String = "nebula" 41 | 42 | /** 43 | * Creates a {@link DataSourceReader} to scan the data from Nebula Graph. 
44 | */ 45 | override def createRelation(sqlContext: SQLContext, 46 | parameters: Map[String, String]): BaseRelation = { 47 | val nebulaOptions = getNebulaOptions(parameters) 48 | 49 | LOG.info("create relation") 50 | val optionMap = new mutable.HashMap[String, String]() 51 | for (k: String <- parameters.keySet) { 52 | if (!k.equalsIgnoreCase("passwd")) { 53 | optionMap += (k -> parameters(k)) 54 | } 55 | } 56 | LOG.info(s"options ${optionMap}") 57 | 58 | NebulaRelation(sqlContext, nebulaOptions) 59 | } 60 | 61 | /** 62 | * Saves a DataFrame to a destination (using data source-specific parameters) 63 | */ 64 | override def createRelation(sqlContext: SQLContext, 65 | mode: SaveMode, 66 | parameters: Map[String, String], 67 | data: DataFrame): BaseRelation = { 68 | 69 | val nebulaOptions = getNebulaOptions(parameters) 70 | if (mode == SaveMode.Ignore || mode == SaveMode.ErrorIfExists) { 71 | LOG.warn(s"Currently do not support mode") 72 | } 73 | 74 | LOG.info("create writer") 75 | val optionMap = new mutable.HashMap[String, String]() 76 | for (k: String <- parameters.keySet) { 77 | if (!k.equalsIgnoreCase("passwd")) { 78 | optionMap += (k -> parameters(k)) 79 | } 80 | } 81 | LOG.info(s"options ${optionMap}") 82 | 83 | val schema = data.schema 84 | data.foreachPartition(iterator => { 85 | savePartition(nebulaOptions, schema, iterator) 86 | }) 87 | 88 | new NebulaWriterResultRelation(sqlContext, data.schema) 89 | } 90 | 91 | /** 92 | * construct nebula options with DataSourceOptions 93 | */ 94 | def getNebulaOptions(options: Map[String, String]): NebulaOptions = { 95 | val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(options)) 96 | nebulaOptions 97 | } 98 | 99 | private def savePartition(nebulaOptions: NebulaOptions, 100 | schema: StructType, 101 | iterator: Iterator[Row]): Unit = { 102 | val dataType = nebulaOptions.dataType 103 | val writer: NebulaWriter = { 104 | if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) { 105 | val vertexFiled = nebulaOptions.vertexField 106 | val vertexIndex: Int = { 107 | var index: Int = -1 108 | for (i <- schema.fields.indices) { 109 | if (schema.fields(i).name.equals(vertexFiled)) { 110 | index = i 111 | } 112 | } 113 | if (index < 0) { 114 | throw new IllegalOptionException( 115 | s" vertex field ${vertexFiled} does not exist in dataframe") 116 | } 117 | index 118 | } 119 | new NebulaVertexWriter(nebulaOptions, vertexIndex, schema).asInstanceOf[NebulaWriter] 120 | } else { 121 | val srcVertexFiled = nebulaOptions.srcVertexField 122 | val dstVertexField = nebulaOptions.dstVertexField 123 | val rankExist = !nebulaOptions.rankField.isEmpty 124 | val edgeFieldsIndex = { 125 | var srcIndex: Int = -1 126 | var dstIndex: Int = -1 127 | var rankIndex: Int = -1 128 | for (i <- schema.fields.indices) { 129 | if (schema.fields(i).name.equals(srcVertexFiled)) { 130 | srcIndex = i 131 | } 132 | if (schema.fields(i).name.equals(dstVertexField)) { 133 | dstIndex = i 134 | } 135 | if (rankExist) { 136 | if (schema.fields(i).name.equals(nebulaOptions.rankField)) { 137 | rankIndex = i 138 | } 139 | } 140 | } 141 | // check src filed and dst field 142 | if (srcIndex < 0 || dstIndex < 0) { 143 | throw new IllegalOptionException( 144 | s" srcVertex field ${srcVertexFiled} or dstVertex field ${dstVertexField} do not exist in dataframe") 145 | } 146 | // check rank field 147 | if (rankExist && rankIndex < 0) { 148 | throw new IllegalOptionException(s"rank field does not exist in dataframe") 149 | } 150 | 151 | if (!rankExist) { 152 | (srcIndex, dstIndex, 
Option.empty) 153 | } else { 154 | (srcIndex, dstIndex, Option(rankIndex)) 155 | } 156 | 157 | } 158 | new NebulaEdgeWriter(nebulaOptions, 159 | edgeFieldsIndex._1, 160 | edgeFieldsIndex._2, 161 | edgeFieldsIndex._3, 162 | schema).asInstanceOf[NebulaWriter] 163 | } 164 | } 165 | val message = writer.writeData(iterator) 166 | LOG.debug( 167 | s"spark partition id ${message.partitionId} write failed size: ${message.executeStatements.length}") 168 | if (message.executeStatements.nonEmpty) { 169 | LOG.error(s"failed execs:\n ${message.executeStatements.toString()}") 170 | } else { 171 | LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed") 172 | } 173 | 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.NebulaOptions 9 | import org.apache.spark.Partition 10 | import org.apache.spark.sql.types.StructType 11 | import org.slf4j.{Logger, LoggerFactory} 12 | 13 | class NebulaEdgePartitionReader(index: Partition, nebulaOptions: NebulaOptions, schema: StructType) 14 | extends NebulaIterator(index, nebulaOptions, schema) { 15 | 16 | override def hasNext(): Boolean = hasNextEdgeRow 17 | } 18 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaIterator.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils, PartitionUtils} 9 | import org.apache.spark.Partition 10 | import org.apache.spark.sql.catalyst.InternalRow 11 | import org.apache.spark.sql.types.StructType 12 | import org.slf4j.{Logger, LoggerFactory} 13 | 14 | /** 15 | * iterator for nebula vertex or edge data 16 | * convert each vertex data or edge data to Spark SQL's Row 17 | */ 18 | abstract class NebulaIterator extends Iterator[InternalRow] with NebulaReader { 19 | private val LOG: Logger = LoggerFactory.getLogger(this.getClass) 20 | 21 | def this(index: Partition, nebulaOptions: NebulaOptions, schema: StructType) { 22 | this() 23 | val totalPart = super.init(index.index, nebulaOptions, schema) 24 | // index starts with 0 25 | val nebulaPartition = index.asInstanceOf[NebulaPartition] 26 | val scanParts = 27 | nebulaPartition.getScanParts(totalPart, nebulaOptions.partitionNums.toInt) 28 | LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") 29 | scanPartIterator = scanParts.iterator 30 | } 31 | 32 | /** 33 | * whether this iterator can provide another element. 34 | */ 35 | override def hasNext: Boolean 36 | 37 | /** 38 | * Produces the next vertex or edge of this iterator. 
39 | */ 40 | override def next(): InternalRow = super.getRow() 41 | } 42 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlEdgeReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import java.util 9 | 10 | import com.vesoft.nebula.Value 11 | import com.vesoft.nebula.client.graph.data.{Relationship, ResultSet, ValueWrapper} 12 | import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils} 13 | import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter 14 | import com.vesoft.nebula.connector.nebula.GraphProvider 15 | import org.apache.spark.sql.catalyst.InternalRow 16 | import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow 17 | import org.apache.spark.sql.types.StructType 18 | import org.slf4j.{Logger, LoggerFactory} 19 | 20 | import scala.collection.JavaConversions.asScalaBuffer 21 | import scala.collection.mutable 22 | import scala.collection.mutable.ListBuffer 23 | 24 | /** 25 | * create reader by ngql 26 | */ 27 | class NebulaNgqlEdgeReader extends Iterator[InternalRow] { 28 | 29 | private val LOG: Logger = LoggerFactory.getLogger(this.getClass) 30 | 31 | private var nebulaOptions: NebulaOptions = _ 32 | private var graphProvider: GraphProvider = _ 33 | private var schema: StructType = _ 34 | private var resultSet: ResultSet = _ 35 | private var edgeIterator: Iterator[ListBuffer[ValueWrapper]] = _ 36 | 37 | def this(nebulaOptions: NebulaOptions, schema: StructType) { 38 | this() 39 | this.schema = schema 40 | this.nebulaOptions = nebulaOptions 41 | this.graphProvider = new GraphProvider( 42 | nebulaOptions.getGraphAddress, 43 | nebulaOptions.user, 44 | nebulaOptions.passwd, 45 | nebulaOptions.timeout, 46 | nebulaOptions.enableGraphSSL, 47 | nebulaOptions.sslSignType, 48 | nebulaOptions.caSignParam, 49 | nebulaOptions.selfSignParam 50 | ) 51 | // add exception when session build failed 52 | graphProvider.switchSpace(nebulaOptions.spaceName) 53 | resultSet = graphProvider.submit(nebulaOptions.ngql) 54 | close() 55 | edgeIterator = query() 56 | } 57 | 58 | def query(): Iterator[ListBuffer[ValueWrapper]] = { 59 | val edges: ListBuffer[ListBuffer[ValueWrapper]] = new ListBuffer[ListBuffer[ValueWrapper]] 60 | val properties = nebulaOptions.getReturnCols 61 | for (i <- 0 until resultSet.rowsSize()) { 62 | val rowValues = resultSet.rowValues(i).values() 63 | for (j <- 0 until rowValues.size()) { 64 | val value = rowValues.get(j) 65 | val valueType = value.getValue.getSetField 66 | if (valueType == Value.EVAL) { 67 | val relationship = value.asRelationship() 68 | if (checkLabel(relationship)) { 69 | edges.append(convertToEdge(relationship, properties)) 70 | } 71 | } else if (valueType == Value.LVAL) { 72 | val list: mutable.Buffer[ValueWrapper] = value.asList() 73 | edges.appendAll( 74 | list.toStream 75 | .filter(e => checkLabel(e.asRelationship())) 76 | .map(e => convertToEdge(e.asRelationship(), properties)) 77 | ) 78 | } else { 79 | LOG.error(s"Exception convert edge type ${valueType} ") 80 | throw new RuntimeException(" convert value type failed"); 81 | } 82 | } 83 | } 84 | edges.iterator 85 | } 86 | 87 | def checkLabel(relationship: Relationship): Boolean = { 88 | 
this.nebulaOptions.label.equals(relationship.edgeName()) 89 | } 90 | 91 | def convertToEdge(relationship: Relationship, 92 | properties: List[String]): ListBuffer[ValueWrapper] = { 93 | val edge: ListBuffer[ValueWrapper] = new ListBuffer[ValueWrapper] 94 | edge.append(relationship.srcId()) 95 | edge.append(relationship.dstId()) 96 | edge.append(new ValueWrapper(new Value(3, relationship.ranking()), "utf-8")) 97 | if (properties == null || properties.isEmpty) 98 | return edge 99 | else { 100 | for (i <- properties.indices) { 101 | edge.append(relationship.properties().get(properties(i))) 102 | } 103 | } 104 | edge 105 | } 106 | 107 | override def hasNext(): Boolean = { 108 | edgeIterator.hasNext 109 | } 110 | 111 | override def next(): InternalRow = { 112 | val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema) 113 | val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) 114 | 115 | val edge = edgeIterator.next(); 116 | for (i <- getters.indices) { 117 | val value: ValueWrapper = edge(i) 118 | var resolved = false 119 | if (value.isNull) { 120 | mutableRow.setNullAt(i) 121 | resolved = true 122 | } 123 | if (value.isString) { 124 | getters(i).apply(value.asString(), mutableRow, i) 125 | resolved = true 126 | } 127 | if (value.isDate) { 128 | getters(i).apply(value.asDate(), mutableRow, i) 129 | resolved = true 130 | } 131 | if (value.isTime) { 132 | getters(i).apply(value.asTime(), mutableRow, i) 133 | resolved = true 134 | } 135 | if (value.isDateTime) { 136 | getters(i).apply(value.asDateTime(), mutableRow, i) 137 | resolved = true 138 | } 139 | if (value.isLong) { 140 | getters(i).apply(value.asLong(), mutableRow, i) 141 | } 142 | if (value.isBoolean) { 143 | getters(i).apply(value.asBoolean(), mutableRow, i) 144 | } 145 | if (value.isDouble) { 146 | getters(i).apply(value.asDouble(), mutableRow, i) 147 | } 148 | if (value.isGeography) { 149 | getters(i).apply(value.asGeography(), mutableRow, i) 150 | } 151 | if (value.isDuration) { 152 | getters(i).apply(value.asDuration(), mutableRow, i) 153 | } 154 | } 155 | mutableRow 156 | 157 | } 158 | 159 | def close(): Unit = { 160 | graphProvider.close(); 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlRDD.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions} 9 | import org.apache.spark.{Partition, TaskContext} 10 | import org.apache.spark.rdd.RDD 11 | import org.apache.spark.sql.SQLContext 12 | import org.apache.spark.sql.catalyst.InternalRow 13 | import org.apache.spark.sql.types.StructType 14 | 15 | import scala.collection.mutable.ListBuffer 16 | 17 | class NebulaNgqlRDD(val sqlContext: SQLContext, 18 | var nebulaOptions: NebulaOptions, 19 | schema: StructType) 20 | extends RDD[InternalRow](sqlContext.sparkContext, Nil) { 21 | 22 | /** 23 | * start to get edge data from query resultSet 24 | * 25 | * @param split 26 | * @param context 27 | * @return Iterator 28 | */ 29 | override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { 30 | new NebulaNgqlEdgeReader() 31 | } 32 | 33 | override def getPartitions: Array[Partition] = { 34 | val partitions = new Array[Partition](1) 35 | partitions(0) = NebulaNgqlPartition(0) 36 | partitions 37 | } 38 | 39 | } 40 | 41 | /** 42 | * An identifier for a partition in an NebulaRDD. 43 | */ 44 | case class NebulaNgqlPartition(indexNum: Int) extends Partition { 45 | override def index: Int = indexNum 46 | } 47 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRDD.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions} 9 | import org.apache.spark.{Partition, TaskContext} 10 | import org.apache.spark.rdd.RDD 11 | import org.apache.spark.sql.SQLContext 12 | import org.apache.spark.sql.catalyst.InternalRow 13 | import org.apache.spark.sql.types.StructType 14 | 15 | class NebulaRDD(val sqlContext: SQLContext, var nebulaOptions: NebulaOptions, schema: StructType) 16 | extends RDD[InternalRow](sqlContext.sparkContext, Nil) { 17 | 18 | /** 19 | * start to scan vertex or edge data 20 | * 21 | * @param split 22 | * @param context 23 | * @return Iterator 24 | */ 25 | override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { 26 | val dataType = nebulaOptions.dataType 27 | if (DataTypeEnum.VERTEX.toString.equalsIgnoreCase(dataType)) 28 | new NebulaVertexPartitionReader(split, nebulaOptions, schema) 29 | else new NebulaEdgePartitionReader(split, nebulaOptions, schema) 30 | } 31 | 32 | override def getPartitions = { 33 | val partitionNumber = nebulaOptions.partitionNums.toInt 34 | val partitions = new Array[Partition](partitionNumber) 35 | for (i <- 0 until partitionNumber) { 36 | partitions(i) = NebulaPartition(i) 37 | } 38 | partitions 39 | } 40 | } 41 | 42 | /** 43 | * An identifier for a partition in an NebulaRDD. 
44 | */ 45 | case class NebulaPartition(indexNum: Int) extends Partition { 46 | override def index: Int = indexNum 47 | 48 | /** 49 | * allocate scanPart to partition 50 | * 51 | * @param totalPart nebula data part num 52 | * @return scan data part list 53 | */ 54 | def getScanParts(totalPart: Int, totalPartition: Int): List[Int] = 55 | (indexNum + 1 to totalPart by totalPartition).toList 56 | 57 | } 58 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelation.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils} 9 | import org.apache.spark.rdd.RDD 10 | import org.apache.spark.sql.{Row, SQLContext} 11 | import org.apache.spark.sql.sources.{BaseRelation, TableScan} 12 | import org.apache.spark.sql.types.StructType 13 | import org.slf4j.LoggerFactory 14 | 15 | case class NebulaRelation(override val sqlContext: SQLContext, nebulaOptions: NebulaOptions) 16 | extends BaseRelation 17 | with TableScan { 18 | private val LOG = LoggerFactory.getLogger(this.getClass) 19 | 20 | // the schema is resolved lazily in schema() and buildScan() 21 | protected var datasetSchema: StructType = _ 22 | 23 | override val needConversion: Boolean = false 24 | 25 | override def schema: StructType = { 26 | if (datasetSchema == null) { 27 | datasetSchema = NebulaUtils.getSchema(nebulaOptions) 28 | } 29 | datasetSchema 30 | } 31 | 32 | override def buildScan(): RDD[Row] = { 33 | 34 | if (datasetSchema == null) { 35 | datasetSchema = NebulaUtils.getSchema(nebulaOptions) 36 | } 37 | if (nebulaOptions.ngql != null && nebulaOptions.ngql.nonEmpty) { 38 | new NebulaNgqlRDD(sqlContext, nebulaOptions, datasetSchema).asInstanceOf[RDD[Row]] 39 | } else { 40 | new NebulaRDD(sqlContext, nebulaOptions, datasetSchema).asInstanceOf[RDD[Row]] 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaRelationProvider.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.{NebulaOptions, OperaType} 9 | import org.apache.spark.sql.SQLContext 10 | import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} 11 | 12 | class NebulaRelationProvider extends RelationProvider with DataSourceRegister { 13 | 14 | /** 15 | * The string that represents the format that nebula data source provider uses. 16 | */ 17 | override def shortName(): String = "nebula" 18 | 19 | /** 20 | * Returns a new base relation with the given parameters. 21 | * you can see it as the reader.
22 | */ 23 | override def createRelation(sqlContext: SQLContext, 24 | parameters: Map[String, String]): BaseRelation = { 25 | val nebulaOptions = new NebulaOptions(parameters, OperaType.READ) 26 | NebulaRelation(sqlContext, nebulaOptions) 27 | } 28 | 29 | } 30 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.NebulaOptions 9 | import org.apache.spark.Partition 10 | import org.apache.spark.sql.types.StructType 11 | 12 | class NebulaVertexPartitionReader(index: Partition, 13 | nebulaOptions: NebulaOptions, 14 | schema: StructType) 15 | extends NebulaIterator(index, nebulaOptions, schema) { 16 | 17 | override def hasNext: Boolean = hasNextVertexRow 18 | 19 | } 20 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaCommitMessage.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | case class NebulaCommitMessage(partitionId: Int, executeStatements: List[String]) 9 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.connector.{KeyPolicy, NebulaEdge, NebulaEdges, NebulaOptions, WriteMode} 9 | import org.apache.spark.TaskContext 10 | import org.apache.spark.sql.Row 11 | import org.apache.spark.sql.catalyst.InternalRow 12 | import org.apache.spark.sql.types.StructType 13 | import org.slf4j.LoggerFactory 14 | 15 | import scala.collection.mutable.ListBuffer 16 | 17 | class NebulaEdgeWriter(nebulaOptions: NebulaOptions, 18 | srcIndex: Int, 19 | dstIndex: Int, 20 | rankIndex: Option[Int], 21 | schema: StructType) 22 | extends NebulaWriter(nebulaOptions, schema) { 23 | 24 | private val LOG = LoggerFactory.getLogger(this.getClass) 25 | 26 | val rankIdx = if (rankIndex.isDefined) rankIndex.get else -1 27 | val propNames = NebulaExecutor.assignEdgePropNames(schema, 28 | srcIndex, 29 | dstIndex, 30 | rankIdx, 31 | nebulaOptions.srcAsProp, 32 | nebulaOptions.dstAsProp, 33 | nebulaOptions.rankAsProp) 34 | val fieldTypMap: Map[String, Integer] = 35 | if (nebulaOptions.writeMode == WriteMode.DELETE) Map[String, Integer]() 36 | else metaProvider.getEdgeSchema(nebulaOptions.spaceName, nebulaOptions.label) 37 | 38 | val srcPolicy = 39 | if (nebulaOptions.srcPolicy.isEmpty) Option.empty 40 | else Option(KeyPolicy.withName(nebulaOptions.srcPolicy)) 41 | val dstPolicy = { 42 | if (nebulaOptions.dstPolicy.isEmpty) Option.empty 43 | else Option(KeyPolicy.withName(nebulaOptions.dstPolicy)) 44 | } 45 | 46 | /** buffer to save batch edges */ 47 | var edges: ListBuffer[NebulaEdge] = new ListBuffer() 48 | 49 | prepareSpace() 50 | 51 | override def writeData(iterator: Iterator[Row]): NebulaCommitMessage = { 52 | while (iterator.hasNext) { 53 | val internalRow = rowEncoder.toRow(iterator.next()) 54 | write(internalRow) 55 | } 56 | if (edges.nonEmpty) { 57 | execute() 58 | } 59 | graphProvider.close() 60 | metaProvider.close() 61 | NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList) 62 | } 63 | 64 | /** 65 | * write one edge record to buffer 66 | */ 67 | override def write(row: InternalRow): Unit = { 68 | val srcId = NebulaExecutor.extraID(schema, row, srcIndex, srcPolicy, isVidStringType) 69 | val dstId = NebulaExecutor.extraID(schema, row, dstIndex, dstPolicy, isVidStringType) 70 | val rank = 71 | if (rankIndex.isEmpty) Option.empty 72 | else Option(NebulaExecutor.extraRank(schema, row, rankIndex.get)) 73 | val values = 74 | if (nebulaOptions.writeMode == WriteMode.DELETE) List() 75 | else 76 | NebulaExecutor.assignEdgeValues(schema, 77 | row, 78 | srcIndex, 79 | dstIndex, 80 | rankIdx, 81 | nebulaOptions.srcAsProp, 82 | nebulaOptions.dstAsProp, 83 | nebulaOptions.rankAsProp, 84 | fieldTypMap) 85 | val nebulaEdge = NebulaEdge(srcId, dstId, rank, values) 86 | edges.append(nebulaEdge) 87 | if (edges.size >= nebulaOptions.batch) { 88 | execute() 89 | } 90 | } 91 | 92 | /** 93 | * submit buffer edges to nebula 94 | */ 95 | def execute(): Unit = { 96 | val nebulaEdges = NebulaEdges(propNames, edges.toList, srcPolicy, dstPolicy) 97 | val exec = nebulaOptions.writeMode match { 98 | case WriteMode.INSERT => 99 | NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaEdges, nebulaOptions.overwrite) 100 | case WriteMode.UPDATE => 101 | NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaEdges) 102 | case WriteMode.DELETE => 103 | NebulaExecutor.toDeleteExecuteStatement(nebulaOptions.label, nebulaEdges) 104 | case _ => 105 | throw new IllegalArgumentException(s"write mode ${nebulaOptions.writeMode} not supported.") 
106 | } 107 | edges.clear() 108 | submit(exec) 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaInsertableRelation.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import org.apache.spark.sql.DataFrame 9 | import org.apache.spark.sql.sources.InsertableRelation 10 | 11 | class NebulaInsertableRelation extends InsertableRelation { 12 | override def insert(data: DataFrame, overwrite: Boolean): Unit = {} 13 | } 14 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.connector.{ 9 | KeyPolicy, 10 | NebulaOptions, 11 | NebulaVertex, 12 | NebulaVertices, 13 | WriteMode 14 | } 15 | import org.apache.spark.TaskContext 16 | import org.apache.spark.sql.Row 17 | import org.apache.spark.sql.catalyst.InternalRow 18 | import org.apache.spark.sql.types.StructType 19 | import org.slf4j.LoggerFactory 20 | 21 | import scala.collection.mutable.ListBuffer 22 | 23 | class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType) 24 | extends NebulaWriter(nebulaOptions, schema) { 25 | 26 | private val LOG = LoggerFactory.getLogger(this.getClass) 27 | 28 | val propNames = NebulaExecutor.assignVertexPropNames(schema, vertexIndex, nebulaOptions.vidAsProp) 29 | val fieldTypMap: Map[String, Integer] = 30 | if (nebulaOptions.writeMode == WriteMode.DELETE) Map[String, Integer]() 31 | else metaProvider.getTagSchema(nebulaOptions.spaceName, nebulaOptions.label) 32 | 33 | val policy = { 34 | if (nebulaOptions.vidPolicy.isEmpty) Option.empty 35 | else Option(KeyPolicy.withName(nebulaOptions.vidPolicy)) 36 | } 37 | 38 | /** buffer to save batch vertices */ 39 | var vertices: ListBuffer[NebulaVertex] = new ListBuffer() 40 | 41 | prepareSpace() 42 | 43 | override def writeData(iterator: Iterator[Row]): NebulaCommitMessage = { 44 | while (iterator.hasNext) { 45 | val internalRow = rowEncoder.toRow(iterator.next()) 46 | write(internalRow) 47 | } 48 | if (vertices.nonEmpty) { 49 | execute() 50 | } 51 | graphProvider.close() 52 | metaProvider.close() 53 | NebulaCommitMessage(TaskContext.getPartitionId(), failedExecs.toList) 54 | } 55 | 56 | /** 57 | * write one vertex row to buffer 58 | */ 59 | override def write(row: InternalRow): Unit = { 60 | val vertex = 61 | NebulaExecutor.extraID(schema, row, vertexIndex, policy, isVidStringType) 62 | val values = 63 | if (nebulaOptions.writeMode == WriteMode.DELETE) List() 64 | else 65 | NebulaExecutor.assignVertexPropValues(schema, 66 | row, 67 | vertexIndex, 68 | nebulaOptions.vidAsProp, 69 | fieldTypMap) 70 | val nebulaVertex = NebulaVertex(vertex, values) 71 | vertices.append(nebulaVertex) 72 | if (vertices.size >= nebulaOptions.batch) { 73 | execute() 74 | } 75 | } 76 | 77 | /** 78 | * submit buffer vertices to nebula 79 | */ 80 | private def execute(): Unit = { 81 | val 
nebulaVertices = NebulaVertices(propNames, vertices.toList, policy) 82 | val exec = nebulaOptions.writeMode match { 83 | case WriteMode.INSERT => 84 | NebulaExecutor.toExecuteSentence(nebulaOptions.label, 85 | nebulaVertices, 86 | nebulaOptions.overwrite) 87 | case WriteMode.UPDATE => 88 | NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaVertices) 89 | case WriteMode.DELETE => 90 | NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, nebulaOptions.deleteEdge) 91 | case _ => 92 | throw new IllegalArgumentException(s"write mode ${nebulaOptions.writeMode} not supported.") 93 | } 94 | vertices.clear() 95 | submit(exec) 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import java.util.concurrent.TimeUnit 9 | 10 | import com.google.common.util.concurrent.RateLimiter 11 | import com.vesoft.nebula.connector.NebulaOptions 12 | import com.vesoft.nebula.connector.nebula.{GraphProvider, MetaProvider, VidType} 13 | import org.apache.spark.TaskContext 14 | import org.apache.spark.sql.Row 15 | import org.apache.spark.sql.catalyst.InternalRow 16 | import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} 17 | import org.apache.spark.sql.types.StructType 18 | import org.slf4j.LoggerFactory 19 | 20 | import scala.collection.mutable.ListBuffer 21 | 22 | abstract class NebulaWriter(nebulaOptions: NebulaOptions, schema: StructType) extends Serializable { 23 | private val LOG = LoggerFactory.getLogger(this.getClass) 24 | 25 | protected val rowEncoder: ExpressionEncoder[Row] = RowEncoder(schema).resolveAndBind() 26 | protected val failedExecs: ListBuffer[String] = new ListBuffer[String] 27 | 28 | val metaProvider = new MetaProvider( 29 | nebulaOptions.getMetaAddress, 30 | nebulaOptions.timeout, 31 | nebulaOptions.connectionRetry, 32 | nebulaOptions.executionRetry, 33 | nebulaOptions.enableMetaSSL, 34 | nebulaOptions.sslSignType, 35 | nebulaOptions.caSignParam, 36 | nebulaOptions.selfSignParam 37 | ) 38 | val graphProvider = new GraphProvider( 39 | nebulaOptions.getGraphAddress, 40 | nebulaOptions.user, 41 | nebulaOptions.passwd, 42 | nebulaOptions.timeout, 43 | nebulaOptions.enableGraphSSL, 44 | nebulaOptions.sslSignType, 45 | nebulaOptions.caSignParam, 46 | nebulaOptions.selfSignParam 47 | ) 48 | val isVidStringType = metaProvider.getVidType(nebulaOptions.spaceName) == VidType.STRING 49 | 50 | def prepareSpace(): Unit = { 51 | graphProvider.switchSpace(nebulaOptions.spaceName) 52 | } 53 | 54 | def submit(exec: String): Unit = { 55 | @transient val rateLimiter = RateLimiter.create(nebulaOptions.rateLimit) 56 | if (rateLimiter.tryAcquire(nebulaOptions.rateTimeOut, TimeUnit.MILLISECONDS)) { 57 | val result = graphProvider.submit(exec) 58 | if (!result.isSucceeded) { 59 | failedExecs.append(exec) 60 | if (nebulaOptions.disableWriteLog) { 61 | LOG.error(s"write failed: " + result.getErrorMessage) 62 | } else { 63 | LOG.error(s"write failed: ${result.getErrorMessage} failed statement: \n ${exec}") 64 | } 65 | } else { 66 | LOG.info(s"batch write succeed") 67 | LOG.debug(s"batch write succeed: ${exec}") 68 | } 69 | } else { 70 | failedExecs.append(exec) 71 | 
LOG.error(s"failed to acquire reteLimiter for statement {$exec}") 72 | } 73 | } 74 | 75 | def write(row: InternalRow): Unit 76 | 77 | /** write dataframe data into nebula for each partition */ 78 | def writeData(iterator: Iterator[Row]): NebulaCommitMessage 79 | 80 | } 81 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriterResultRelation.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import org.apache.spark.sql.SQLContext 9 | import org.apache.spark.sql.sources.BaseRelation 10 | import org.apache.spark.sql.types.StructType 11 | 12 | class NebulaWriterResultRelation(SQLContext: SQLContext, userDefSchema: StructType) 13 | extends BaseRelation { 14 | override def sqlContext: SQLContext = SQLContext 15 | 16 | override def schema: StructType = userDefSchema 17 | } 18 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/test/resources/edge.csv: -------------------------------------------------------------------------------- 1 | id1,id2,col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14 2 | 1,2,Tom,tom,10,20,30,40,2021-01-27,2021-01-01T12:10:10,43535232,true,1.0,2.0,10:10:10,POINT(1 2) 3 | 2,3,Jina,Jina,11,21,31,41,2021-01-28,2021-01-02T12:10:10,43535232,false,1.1,2.1,11:10:10,POINT(3 4) 4 | 3,4,Tim,Tim,12,22,32,42,2021-01-29,2021-01-03T12:10:10,43535232,false,1.2,2.2,12:10:10,POINT(5 6) 5 | 4,5,张三,张三,13,23,33,43,2021-01-30,2021-01-04T12:10:10,43535232,true,1.3,2.3,13:10:10,POINT(6 7) 6 | 5,6,李四,李四,14,24,34,44,2021-02-01,2021-01-05T12:10:10,43535232,false,1.4,2.4,14:10:10,POINT(1 5) 7 | 6,7,王五,王五,15,25,35,45,2021-02-02,2021-01-06T12:10:10,0,false,1.5,2.5,15:10:10,"LINESTRING(1 3, 4.7 73.23)" 8 | 7,1,Jina,Jina,16,26,36,46,2021-02-03,2021-01-07T12:10:10,43535232,true,1.6,2.6,16:10:10,"LINESTRING(1 3, 4.7 73.23)" 9 | 8,1,Jina,Jina,17,27,37,47,2021-02-04,2021-01-08T12:10:10,43535232,false,1.7,2.7,17:10:10,"LINESTRING(1 3, 4.7 73.23)" 10 | 9,1,Jina,Jina,18,28,38,48,2021-02-05,2021-01-09T12:10:10,43535232,true,1.8,2.8,18:10:10,"LINESTRING(1 3, 4.7 73.23)" 11 | 10,2,Jina,Jina,19,29,39,49,2021-02-06,2021-01-10T12:10:10,43535232,false,1.9,2.9,19:10:10,"LINESTRING(1 3, 4.7 73.23)" 12 | -1,5,Jina,Jina,20,30,40,50,2021-02-07,2021-02-11T12:10:10,43535232,false,2.0,3.0,20:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" 13 | -2,6,Jina,Jina,21,31,41,51,2021-02-08,2021-03-12T12:10:10,43535232,false,2.1,3.1,21:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" 14 | -3,7,Jina,Jina,22,32,42,52,2021-02-09,2021-04-13T12:10:10,43535232,false,2.2,3.2,22:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" 15 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Global logging configuration 2 | log4j.rootLogger=INFO, stdout 3 | # Console output... 
4 | log4j.appender.stdout=org.apache.log4j.ConsoleAppender 5 | log4j.appender.stdout.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.stdout.layout.ConversionPattern=%5p [%t] - %m%n 7 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/test/resources/vertex.csv: -------------------------------------------------------------------------------- 1 | id,col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14,col15 2 | 1,Tom,tom,10,20,30,40,2021-01-27,2021-01-01T12:10:10,43535232,true,1.0,2.0,10:10:10,POINT(1 2),"duration({years:1,months:1,seconds:1})" 3 | 2,Jina,Jina,11,21,31,41,2021-01-28,2021-01-02T12:10:10,43535232,false,1.1,2.1,11:10:10,POINT(3 4),"duration({years:1,months:1,seconds:1})" 4 | 3,Tim,Tim,12,22,32,42,2021-01-29,2021-01-03T12:10:10,43535232,false,1.2,2.2,12:10:10,POINT(5 6),"duration({years:1,months:1,seconds:1})" 5 | 4,张三,张三,13,23,33,43,2021-01-30,2021-01-04T12:10:10,43535232,true,1.3,2.3,13:10:10,POINT(6 7),"duration({years:1,months:1,seconds:1})" 6 | 5,李四,李四,14,24,34,44,2021-02-01,2021-01-05T12:10:10,43535232,false,1.4,2.4,14:10:10,POINT(1 5),"duration({years:1,months:1,seconds:1})" 7 | 6,王五,王五,15,25,35,45,2021-02-02,2021-01-06T12:10:10,0,false,1.5,2.5,15:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 8 | 7,Jina,Jina,16,26,36,46,2021-02-03,2021-01-07T12:10:10,43535232,true,1.6,2.6,16:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 9 | 8,Jina,Jina,17,27,37,47,2021-02-04,2021-01-08T12:10:10,43535232,false,1.7,2.7,17:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 10 | 9,Jina,Jina,18,28,38,48,2021-02-05,2021-01-09T12:10:10,43535232,true,1.8,2.8,18:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 11 | 10,Jina,Jina,19,29,39,49,2021-02-06,2021-01-10T12:10:10,43535232,false,1.9,2.9,19:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 12 | -1,Jina,Jina,20,30,40,50,2021-02-07,2021-02-11T12:10:10,43535232,false,2.0,3.0,20:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})" 13 | -2,Jina,Jina,21,31,41,51,2021-02-08,2021-03-12T12:10:10,43535232,false,2.1,3.1,21:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})" 14 | -3,Jina,Jina,22,32,42,52,2021-02-09,2021-04-13T12:10:10,43535232,false,2.2,3.2,22:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})" 15 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import com.vesoft.nebula.connector.utils.SparkValidate 9 | import org.apache.spark.sql.SparkSession 10 | import org.scalatest.funsuite.AnyFunSuite 11 | 12 | class SparkVersionValidateSuite extends AnyFunSuite { 13 | test("spark version validate") { 14 | try { 15 | val version = SparkSession.getActiveSession.map(_.version).getOrElse("UNKNOWN") 16 | SparkValidate.validate("2.2.*") 17 | } catch { 18 | case e: Exception => assert(false) 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/mock/SparkMock.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.mock 7 | 8 | import com.facebook.thrift.protocol.TCompactProtocol 9 | import com.vesoft.nebula.connector.connector.NebulaDataFrameWriter 10 | import com.vesoft.nebula.connector.{ 11 | NebulaConnectionConfig, 12 | WriteMode, 13 | WriteNebulaEdgeConfig, 14 | WriteNebulaVertexConfig 15 | } 16 | import org.apache.spark.SparkConf 17 | import org.apache.spark.sql.SparkSession 18 | 19 | object SparkMock { 20 | 21 | /** 22 | * write nebula vertex with insert mode 23 | */ 24 | def writeVertex(): Unit = { 25 | val sparkConf = new SparkConf 26 | sparkConf 27 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 28 | .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) 29 | val spark = SparkSession 30 | .builder() 31 | .master("local") 32 | .config(sparkConf) 33 | .getOrCreate() 34 | 35 | val df = spark.read 36 | .option("header", true) 37 | .csv("src/test/resources/vertex.csv") 38 | 39 | val config = 40 | NebulaConnectionConfig 41 | .builder() 42 | .withMetaAddress("127.0.0.1:9559") 43 | .withGraphAddress("127.0.0.1:9669") 44 | .withConnectionRetry(2) 45 | .build() 46 | val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig 47 | .builder() 48 | .withSpace("test_write_string") 49 | .withTag("person_connector") 50 | .withVidField("id") 51 | .withVidAsProp(false) 52 | .withBatch(5) 53 | .build() 54 | df.write.nebula(config, nebulaWriteVertexConfig).writeVertices() 55 | 56 | spark.stop() 57 | } 58 | 59 | /** 60 | * write nebula vertex with delete mode 61 | */ 62 | def deleteVertex(): Unit = { 63 | val sparkConf = new SparkConf 64 | sparkConf 65 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 66 | .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) 67 | val spark = SparkSession 68 | .builder() 69 | .master("local") 70 | .config(sparkConf) 71 | .getOrCreate() 72 | 73 | val df = spark.read 74 | .option("header", true) 75 | .csv("src/test/resources/vertex.csv") 76 | 77 | val config = 78 | NebulaConnectionConfig 79 | .builder() 80 | .withMetaAddress("127.0.0.1:9559") 81 | .withGraphAddress("127.0.0.1:9669") 82 | .withConnectionRetry(2) 83 | .build() 84 | val nebulaWriteVertexConfig: WriteNebulaVertexConfig = WriteNebulaVertexConfig 85 | .builder() 86 | .withSpace("test_write_string") 87 | .withTag("person_connector") 88 | .withVidField("id") 89 | .withVidAsProp(false) 90 | .withWriteMode(WriteMode.DELETE) 91 | .withBatch(5) 92 | .build() 93 | df.write.nebula(config, nebulaWriteVertexConfig).writeVertices() 94 | 95 | spark.stop() 96 | }
97 | 98 | /** 99 | * write nebula edge with insert mode 100 | */ 101 | def writeEdge(): Unit = { 102 | val sparkConf = new SparkConf 103 | sparkConf 104 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 105 | .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) 106 | val spark = SparkSession 107 | .builder() 108 | .master("local") 109 | .config(sparkConf) 110 | .getOrCreate() 111 | 112 | val df = spark.read 113 | .option("header", true) 114 | .csv("src/test/resources/edge.csv") 115 | 116 | val config = 117 | NebulaConnectionConfig 118 | .builder() 119 | .withMetaAddress("127.0.0.1:9559") 120 | .withGraphAddress("127.0.0.1:9669") 121 | .withConnectionRetry(2) 122 | .build() 123 | val nebulaWriteEdgeConfig: WriteNebulaEdgeConfig = WriteNebulaEdgeConfig 124 | .builder() 125 | .withSpace("test_write_string") 126 | .withEdge("friend_connector") 127 | .withSrcIdField("id1") 128 | .withDstIdField("id2") 129 | .withRankField("col3") 130 | .withRankAsProperty(true) 131 | .withBatch(5) 132 | .build() 133 | df.write.nebula(config, nebulaWriteEdgeConfig).writeEdges() 134 | 135 | spark.stop() 136 | } 137 | 138 | /** 139 | * write nebula edge with delete mode 140 | */ 141 | def deleteEdge(): Unit = { 142 | val sparkConf = new SparkConf 143 | sparkConf 144 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 145 | .registerKryoClasses(Array[Class[_]](classOf[TCompactProtocol])) 146 | val spark = SparkSession 147 | .builder() 148 | .master("local") 149 | .config(sparkConf) 150 | .getOrCreate() 151 | 152 | val df = spark.read 153 | .option("header", true) 154 | .csv("src/test/resources/edge.csv") 155 | 156 | val config = 157 | NebulaConnectionConfig 158 | .builder() 159 | .withMetaAddress("127.0.0.1:9559") 160 | .withGraphAddress("127.0.0.1:9669") 161 | .withConnectionRetry(2) 162 | .build() 163 | val nebulaWriteEdgeConfig: WriteNebulaEdgeConfig = WriteNebulaEdgeConfig 164 | .builder() 165 | .withSpace("test_write_string") 166 | .withEdge("friend_connector") 167 | .withSrcIdField("id1") 168 | .withDstIdField("id2") 169 | .withRankField("col3") 170 | .withRankAsProperty(true) 171 | .withWriteMode(WriteMode.DELETE) 172 | .withBatch(5) 173 | .build() 174 | df.write.nebula(config, nebulaWriteEdgeConfig).writeEdges() 175 | 176 | spark.stop() 177 | } 178 | 179 | } 180 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.client.graph.data.ResultSet 9 | import com.vesoft.nebula.connector.connector.Address 10 | import com.vesoft.nebula.connector.mock.{NebulaGraphMock, SparkMock} 11 | import com.vesoft.nebula.connector.nebula.GraphProvider 12 | import org.apache.log4j.BasicConfigurator 13 | import org.scalatest.BeforeAndAfterAll 14 | import org.scalatest.funsuite.AnyFunSuite 15 | 16 | class WriteDeleteSuite extends AnyFunSuite with BeforeAndAfterAll { 17 | BasicConfigurator.configure() 18 | 19 | override def beforeAll(): Unit = { 20 | val graphMock = new NebulaGraphMock 21 | graphMock.mockStringIdGraphSchema() 22 | graphMock.mockIntIdGraphSchema() 23 | graphMock.close() 24 | SparkMock.writeVertex() 25 | } 26 | 27 | test("write vertex into test_write_string space with delete mode") { 28 | SparkMock.deleteVertex() 29 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 30 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 31 | 32 | graphProvider.switchSpace("test_write_string") 33 | val resultSet: ResultSet = 34 | graphProvider.submit("use test_write_string;" 35 | + "match (v:person_connector) return v limit 100000;") 36 | assert(resultSet.isSucceeded) 37 | assert(resultSet.getColumnNames.size() == 1) 38 | assert(resultSet.isEmpty) 39 | } 40 | 41 | test("write edge into test_write_string space with delete mode") { 42 | SparkMock.deleteEdge() 43 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 44 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 45 | 46 | graphProvider.switchSpace("test_write_string") 47 | val resultSet: ResultSet = 48 | graphProvider.submit("use test_write_string;" 49 | + "fetch prop on friend_connector \"1\"->\"2\"@10 yield edge as e;") 50 | assert(resultSet.isSucceeded) 51 | assert(resultSet.getColumnNames.size() == 1) 52 | assert(resultSet.isEmpty) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /nebula-spark-connector_2.2/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.client.graph.data.ResultSet 9 | import com.vesoft.nebula.connector.connector.Address 10 | import com.vesoft.nebula.connector.mock.{NebulaGraphMock, SparkMock} 11 | import com.vesoft.nebula.connector.nebula.GraphProvider 12 | import org.apache.log4j.BasicConfigurator 13 | import org.scalatest.BeforeAndAfterAll 14 | import org.scalatest.funsuite.AnyFunSuite 15 | 16 | class WriteInsertSuite extends AnyFunSuite with BeforeAndAfterAll { 17 | BasicConfigurator.configure() 18 | 19 | override def beforeAll(): Unit = { 20 | val graphMock = new NebulaGraphMock 21 | graphMock.mockStringIdGraphSchema() 22 | graphMock.mockIntIdGraphSchema() 23 | graphMock.close() 24 | } 25 | 26 | test("write vertex into test_write_string space with insert mode") { 27 | SparkMock.writeVertex() 28 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 29 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 30 | 31 | graphProvider.switchSpace("test_write_string") 32 | val createIndexResult: ResultSet = graphProvider.submit( 33 | "use test_write_string; " 34 | + "create tag index if not exists person_index on person_connector(col1(20));") 35 | Thread.sleep(5000) 36 | graphProvider.submit("rebuild tag index person_index;") 37 | 38 | Thread.sleep(5000) 39 | 40 | graphProvider.submit("use test_write_string;") 41 | val resultSet: ResultSet = 42 | graphProvider.submit("match (v:person_connector) return v;") 43 | assert(resultSet.getColumnNames.size() == 1) 44 | assert(resultSet.getRows.size() == 13) 45 | 46 | for (i <- 0 until resultSet.getRows.size) { 47 | println(resultSet.rowValues(i).toString) 48 | } 49 | } 50 | 51 | test("write edge into test_write_string space with insert mode") { 52 | SparkMock.writeEdge() 53 | 54 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 55 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 56 | 57 | graphProvider.switchSpace("test_write_string") 58 | val createIndexResult: ResultSet = graphProvider.submit( 59 | "use test_write_string; " 60 | + "create edge index if not exists friend_index on friend_connector(col1(20));") 61 | Thread.sleep(5000) 62 | graphProvider.submit("rebuild edge index friend_index;") 63 | 64 | Thread.sleep(5000) 65 | 66 | graphProvider.submit("use test_write_string;") 67 | val resultSet: ResultSet = 68 | graphProvider.submit("match (v:person_connector)-[e:friend_connector]-> () return e;") 69 | assert(resultSet.getColumnNames.size() == 1) 70 | assert(resultSet.getRows.size() == 13) 71 | 72 | for (i <- 0 until resultSet.getRows.size) { 73 | println(resultSet.rowValues(i).toString) 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled class file 2 | *.class 3 | 4 | # Log file 5 | *.log 6 | 7 | # BlueJ files 8 | *.ctxt 9 | 10 | # Mobile Tools for Java (J2ME) 11 | .mtj.tmp/ 12 | 13 | # Package Files # 14 | *.jar 15 | *.war 16 | *.nar 17 | *.ear 18 | *.zip 19 | *.tar.gz 20 | *.rar 21 | 22 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 23 | hs_err_pid* 24 | 25 | # build target 26 | target/ 27 | 28 | # IDE 29 | .idea/ 30 | .eclipse/ 31 | *.iml 32 | 33 | spark-importer.ipr 34 | spark-importer.iws 35 | 36 | .DS_Store 37 | 
-------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaDataSource.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import java.util 9 | import java.util.Map.Entry 10 | 11 | import com.vesoft.nebula.connector.nebula.MetaProvider 12 | import com.vesoft.nebula.connector.reader.SimpleScanBuilder 13 | import com.vesoft.nebula.connector.utils.Validations 14 | import com.vesoft.nebula.meta.ColumnDef 15 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap 16 | import org.apache.spark.sql.connector.catalog.{ 17 | SupportsRead, 18 | SupportsWrite, 19 | Table, 20 | TableCapability, 21 | TableProvider 22 | } 23 | import org.apache.spark.sql.connector.expressions.Transform 24 | import org.apache.spark.sql.connector.read.ScanBuilder 25 | import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} 26 | import org.apache.spark.sql.sources.DataSourceRegister 27 | import org.apache.spark.sql.types.{DataTypes, StructField, StructType} 28 | import org.apache.spark.sql.util.CaseInsensitiveStringMap 29 | import org.slf4j.LoggerFactory 30 | 31 | import scala.collection.mutable.ListBuffer 32 | import scala.jdk.CollectionConverters.asScalaSetConverter 33 | 34 | class NebulaDataSource extends TableProvider with DataSourceRegister { 35 | private val LOG = LoggerFactory.getLogger(this.getClass) 36 | 37 | Validations.validateSparkVersion("3.*") 38 | 39 | private var schema: StructType = null 40 | private var nebulaOptions: NebulaOptions = _ 41 | 42 | /** 43 | * The string that represents the format that nebula data source provider uses. 44 | */ 45 | override def shortName(): String = "nebula" 46 | 47 | override def supportsExternalMetadata(): Boolean = true 48 | 49 | override def inferSchema(caseInsensitiveStringMap: CaseInsensitiveStringMap): StructType = { 50 | if (schema == null) { 51 | nebulaOptions = getNebulaOptions(caseInsensitiveStringMap) 52 | if (nebulaOptions.operaType == OperaType.READ) { 53 | schema = NebulaUtils.getSchema(nebulaOptions) 54 | } else { 55 | schema = new StructType() 56 | } 57 | } 58 | schema 59 | } 60 | 61 | override def getTable(tableSchema: StructType, 62 | transforms: Array[Transform], 63 | map: util.Map[String, String]): Table = { 64 | if (nebulaOptions == null) { 65 | nebulaOptions = getNebulaOptions(new CaseInsensitiveStringMap(map)) 66 | } 67 | new NebulaTable(tableSchema, nebulaOptions) 68 | } 69 | 70 | /** 71 | * construct nebula options with DataSourceOptions 72 | */ 73 | private def getNebulaOptions( 74 | caseInsensitiveStringMap: CaseInsensitiveStringMap): NebulaOptions = { 75 | var parameters: Map[String, String] = Map() 76 | for (entry: Entry[String, String] <- caseInsensitiveStringMap 77 | .asCaseSensitiveMap() 78 | .entrySet() 79 | .asScala) { 80 | parameters += (entry.getKey -> entry.getValue) 81 | } 82 | val nebulaOptions = new NebulaOptions(CaseInsensitiveMap(parameters)) 83 | nebulaOptions 84 | } 85 | 86 | } 87 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/NebulaTable.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 
2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import java.util 9 | import java.util.Map.Entry 10 | import com.vesoft.nebula.connector.reader.SimpleScanBuilder 11 | import com.vesoft.nebula.connector.writer.NebulaWriterBuilder 12 | import org.apache.spark.sql.SaveMode 13 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap 14 | import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} 15 | import org.apache.spark.sql.connector.read.ScanBuilder 16 | import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} 17 | import org.apache.spark.sql.types.StructType 18 | import org.apache.spark.sql.util.CaseInsensitiveStringMap 19 | import org.slf4j.LoggerFactory 20 | 21 | import scala.collection.JavaConverters._ 22 | import scala.collection.mutable 23 | 24 | class NebulaTable(schema: StructType, nebulaOptions: NebulaOptions) 25 | extends Table 26 | with SupportsRead 27 | with SupportsWrite { 28 | 29 | private val LOG = LoggerFactory.getLogger(this.getClass) 30 | 31 | /** 32 | * Creates a {@link DataSourceReader} to scan the data from Nebula Graph. 33 | */ 34 | override def newScanBuilder(caseInsensitiveStringMap: CaseInsensitiveStringMap): ScanBuilder = { 35 | LOG.info("create scan builder") 36 | val options = new mutable.HashMap[String, String]() 37 | val parameters = caseInsensitiveStringMap.asCaseSensitiveMap().asScala 38 | for (k: String <- parameters.keySet) { 39 | if (!k.equalsIgnoreCase("passwd")) { 40 | options += (k -> parameters(k)) 41 | } 42 | } 43 | LOG.info(s"options ${options}") 44 | 45 | new SimpleScanBuilder(nebulaOptions, schema) 46 | } 47 | 48 | /** 49 | * Creates an optional {@link DataSourceWriter} to save the data to Nebula Graph. 50 | */ 51 | override def newWriteBuilder(logicalWriteInfo: LogicalWriteInfo): WriteBuilder = { 52 | LOG.info("create writer") 53 | val options = new mutable.HashMap[String, String]() 54 | val parameters = logicalWriteInfo.options().asCaseSensitiveMap().asScala 55 | for (k: String <- parameters.keySet) { 56 | if (!k.equalsIgnoreCase("passwd")) { 57 | options += (k -> parameters(k)) 58 | } 59 | } 60 | LOG.info(s"options ${options}") 61 | new NebulaWriterBuilder(logicalWriteInfo.schema(), SaveMode.Append, nebulaOptions) 62 | } 63 | 64 | /** 65 | * NebulaGraph table name 66 | */ 67 | override def name(): String = { 68 | nebulaOptions.label 69 | } 70 | 71 | override def schema(): StructType = schema 72 | 73 | override def capabilities(): util.Set[TableCapability] = 74 | Set( 75 | TableCapability.BATCH_READ, 76 | TableCapability.BATCH_WRITE, 77 | TableCapability.ACCEPT_ANY_SCHEMA, 78 | TableCapability.OVERWRITE_BY_FILTER, 79 | TableCapability.OVERWRITE_DYNAMIC, 80 | TableCapability.STREAMING_WRITE, 81 | TableCapability.MICRO_BATCH_READ 82 | ).asJava 83 | 84 | } 85 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaEdgePartitionReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.NebulaOptions 9 | import org.apache.spark.sql.types.StructType 10 | 11 | class NebulaEdgePartitionReader(index: Int, nebulaOptions: NebulaOptions, schema: StructType) 12 | extends NebulaPartitionReader(index, nebulaOptions, schema) { 13 | 14 | override def next(): Boolean = hasNextEdgeRow 15 | } 16 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaNgqlEdgePartitionReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.Value 9 | import com.vesoft.nebula.client.graph.data.{Relationship, ResultSet, ValueWrapper} 10 | import com.vesoft.nebula.connector.NebulaUtils.NebulaValueGetter 11 | import com.vesoft.nebula.connector.nebula.GraphProvider 12 | import com.vesoft.nebula.connector.{NebulaOptions, NebulaUtils} 13 | import org.apache.spark.sql.catalyst.InternalRow 14 | import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow 15 | import org.apache.spark.sql.connector.read.PartitionReader 16 | import org.apache.spark.sql.types.StructType 17 | import org.slf4j.{Logger, LoggerFactory} 18 | 19 | import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable` 20 | import scala.collection.mutable 21 | import scala.collection.mutable.ListBuffer 22 | import scala.jdk.CollectionConverters.asScalaBufferConverter 23 | 24 | /** 25 | * create reader by ngql 26 | */ 27 | class NebulaNgqlEdgePartitionReader extends PartitionReader[InternalRow] { 28 | 29 | private val LOG: Logger = LoggerFactory.getLogger(this.getClass) 30 | 31 | private var nebulaOptions: NebulaOptions = _ 32 | private var graphProvider: GraphProvider = _ 33 | private var schema: StructType = _ 34 | private var resultSet: ResultSet = _ 35 | private var edgeIterator: Iterator[ListBuffer[ValueWrapper]] = _ 36 | 37 | def this(nebulaOptions: NebulaOptions, schema: StructType) { 38 | this() 39 | this.schema = schema 40 | this.nebulaOptions = nebulaOptions 41 | this.graphProvider = new GraphProvider( 42 | nebulaOptions.getGraphAddress, 43 | nebulaOptions.user, 44 | nebulaOptions.passwd, 45 | nebulaOptions.timeout, 46 | nebulaOptions.enableGraphSSL, 47 | nebulaOptions.sslSignType, 48 | nebulaOptions.caSignParam, 49 | nebulaOptions.selfSignParam 50 | ) 51 | // add exception when session build failed 52 | graphProvider.switchSpace(nebulaOptions.spaceName) 53 | resultSet = graphProvider.submit(nebulaOptions.ngql) 54 | edgeIterator = query() 55 | } 56 | 57 | def query(): Iterator[ListBuffer[ValueWrapper]] = { 58 | val edges: ListBuffer[ListBuffer[ValueWrapper]] = new ListBuffer[ListBuffer[ValueWrapper]] 59 | val properties = nebulaOptions.getReturnCols 60 | for (i <- 0 until resultSet.rowsSize()) { 61 | val rowValues = resultSet.rowValues(i).values() 62 | for (j <- 0 until rowValues.size()) { 63 | val value = rowValues.get(j) 64 | val valueType = value.getValue.getSetField 65 | if (valueType == Value.EVAL) { 66 | val relationship = value.asRelationship() 67 | if (checkLabel(relationship)) { 68 | edges.append(convertToEdge(relationship, properties)) 69 | } 70 | } else if (valueType == Value.LVAL) { 71 | val list: mutable.Buffer[ValueWrapper] = 
value.asList().asScala 72 | edges.appendAll( 73 | list.toStream 74 | .filter(e => e != null && e.isEdge() && checkLabel(e.asRelationship())) 75 | .map(e => convertToEdge(e.asRelationship(), properties)) 76 | ) 77 | } else if (valueType == Value.PVAL){ 78 | val list: java.util.List[Relationship] = value.asPath().getRelationships() 79 | edges.appendAll( 80 | list.toStream 81 | .filter(e => checkLabel(e)) 82 | .map(e => convertToEdge(e, properties)) 83 | ) 84 | } else if (valueType != Value.NVAL && valueType != 0) { 85 | LOG.error(s"Unexpected edge type encountered: ${valueType}. Only edge or path should be returned.") 86 | throw new RuntimeException("Invalid nGQL return type. Value type conversion failed."); 87 | } 88 | } 89 | } 90 | edges.iterator 91 | } 92 | 93 | def checkLabel(relationship: Relationship): Boolean = { 94 | this.nebulaOptions.label.equals(relationship.edgeName()) 95 | } 96 | 97 | def convertToEdge(relationship: Relationship, 98 | properties: List[String]): ListBuffer[ValueWrapper] = { 99 | val edge: ListBuffer[ValueWrapper] = new ListBuffer[ValueWrapper] 100 | edge.append(relationship.srcId()) 101 | edge.append(relationship.dstId()) 102 | edge.append(new ValueWrapper(new Value(Value.IVAL, relationship.ranking()), "utf-8")) 103 | if (properties == null || properties.isEmpty) 104 | return edge 105 | else { 106 | for (i <- properties.indices) { 107 | edge.append(relationship.properties().get(properties(i))) 108 | } 109 | } 110 | edge 111 | } 112 | 113 | override def next(): Boolean = { 114 | edgeIterator.hasNext 115 | } 116 | 117 | override def get(): InternalRow = { 118 | val getters: Array[NebulaValueGetter] = NebulaUtils.makeGetters(schema) 119 | val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) 120 | 121 | val edge = edgeIterator.next(); 122 | for (i <- getters.indices) { 123 | val value: ValueWrapper = edge(i) 124 | var resolved = false 125 | if (value.isNull) { 126 | mutableRow.setNullAt(i) 127 | resolved = true 128 | } 129 | if (value.isString) { 130 | getters(i).apply(value.asString(), mutableRow, i) 131 | resolved = true 132 | } 133 | if (value.isDate) { 134 | getters(i).apply(value.asDate(), mutableRow, i) 135 | resolved = true 136 | } 137 | if (value.isTime) { 138 | getters(i).apply(value.asTime(), mutableRow, i) 139 | resolved = true 140 | } 141 | if (value.isDateTime) { 142 | getters(i).apply(value.asDateTime(), mutableRow, i) 143 | resolved = true 144 | } 145 | if (value.isLong) { 146 | getters(i).apply(value.asLong(), mutableRow, i) 147 | } 148 | if (value.isBoolean) { 149 | getters(i).apply(value.asBoolean(), mutableRow, i) 150 | } 151 | if (value.isDouble) { 152 | getters(i).apply(value.asDouble(), mutableRow, i) 153 | } 154 | if (value.isGeography) { 155 | getters(i).apply(value.asGeography(), mutableRow, i) 156 | } 157 | if (value.isDuration) { 158 | getters(i).apply(value.asDuration(), mutableRow, i) 159 | } 160 | } 161 | mutableRow 162 | 163 | } 164 | 165 | override def close(): Unit = { 166 | graphProvider.close(); 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.{NebulaOptions, PartitionUtils} 9 | import org.apache.spark.sql.catalyst.InternalRow 10 | import org.apache.spark.sql.connector.read.PartitionReader 11 | import org.apache.spark.sql.types.StructType 12 | import org.slf4j.{Logger, LoggerFactory} 13 | 14 | /** 15 | * Read nebula data for each spark partition 16 | */ 17 | abstract class NebulaPartitionReader extends PartitionReader[InternalRow] with NebulaReader { 18 | private val LOG: Logger = LoggerFactory.getLogger(this.getClass) 19 | 20 | /** 21 | * @param index identifier for spark partition 22 | * @param nebulaOptions nebula Options 23 | * @param schema of data need to read 24 | */ 25 | def this(index: Int, nebulaOptions: NebulaOptions, schema: StructType) { 26 | this() 27 | val totalPart = super.init(index, nebulaOptions, schema) 28 | // index starts with 1 29 | val scanParts = PartitionUtils.getScanParts(index, totalPart, nebulaOptions.partitionNums.toInt) 30 | LOG.info(s"partition index: ${index}, scanParts: ${scanParts.toString}") 31 | scanPartIterator = scanParts.iterator 32 | } 33 | 34 | override def get(): InternalRow = super.getRow() 35 | 36 | override def close(): Unit = { 37 | super.closeReader() 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaPartitionReaderFactory.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions} 9 | import org.apache.spark.sql.catalyst.InternalRow 10 | import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} 11 | import org.apache.spark.sql.types.StructType 12 | 13 | class NebulaPartitionReaderFactory(private val nebulaOptions: NebulaOptions, 14 | private val schema: StructType) 15 | extends PartitionReaderFactory { 16 | override def createReader(inputPartition: InputPartition): PartitionReader[InternalRow] = { 17 | val partition = inputPartition.asInstanceOf[NebulaPartition].partition 18 | if (DataTypeEnum.VERTEX.toString.equals(nebulaOptions.dataType)) { 19 | 20 | new NebulaVertexPartitionReader(partition, nebulaOptions, schema) 21 | } else if (DataTypeEnum.EDGE.toString.equals(nebulaOptions.dataType)) { 22 | new NebulaEdgePartitionReader(partition, nebulaOptions, schema) 23 | } else { 24 | new NebulaNgqlEdgePartitionReader(nebulaOptions, schema) 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/NebulaVertexPartitionReader.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import com.vesoft.nebula.connector.NebulaOptions 9 | import org.apache.spark.sql.types.StructType 10 | 11 | class NebulaVertexPartitionReader(split: Int, nebulaOptions: NebulaOptions, schema: StructType) 12 | extends NebulaPartitionReader(split, nebulaOptions, schema) { 13 | 14 | override def next(): Boolean = hasNextVertexRow 15 | } 16 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/reader/SimpleScanBuilder.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.reader 7 | 8 | import java.util 9 | 10 | import com.vesoft.nebula.connector.NebulaOptions 11 | import org.apache.spark.sql.connector.read.{ 12 | Batch, 13 | InputPartition, 14 | PartitionReaderFactory, 15 | Scan, 16 | ScanBuilder, 17 | SupportsPushDownFilters, 18 | SupportsPushDownRequiredColumns 19 | } 20 | import org.apache.spark.sql.sources.Filter 21 | import org.apache.spark.sql.types.StructType 22 | 23 | import scala.collection.mutable.ListBuffer 24 | import scala.jdk.CollectionConverters.asScalaBufferConverter 25 | 26 | class SimpleScanBuilder(nebulaOptions: NebulaOptions, schema: StructType) 27 | extends ScanBuilder 28 | with SupportsPushDownFilters 29 | with SupportsPushDownRequiredColumns { 30 | 31 | private var filters: Array[Filter] = Array[Filter]() 32 | 33 | override def build(): Scan = { 34 | new SimpleScan(nebulaOptions, nebulaOptions.partitionNums.toInt, schema) 35 | } 36 | 37 | override def pushFilters(pushFilters: Array[Filter]): Array[Filter] = { 38 | if (nebulaOptions.pushDownFiltersEnabled) { 39 | filters = pushFilters 40 | } 41 | pushFilters 42 | } 43 | 44 | override def pushedFilters(): Array[Filter] = filters 45 | 46 | override def pruneColumns(requiredColumns: StructType): Unit = { 47 | if (!nebulaOptions.pushDownFiltersEnabled || requiredColumns == schema) { 48 | new StructType() 49 | } 50 | } 51 | } 52 | 53 | class SimpleScan(nebulaOptions: NebulaOptions, nebulaTotalPart: Int, schema: StructType) 54 | extends Scan 55 | with Batch { 56 | override def toBatch: Batch = this 57 | 58 | override def planInputPartitions(): Array[InputPartition] = { 59 | val partitionSize = nebulaTotalPart 60 | val inputPartitions = for (i <- 1 to partitionSize) 61 | yield { 62 | NebulaPartition(i) 63 | } 64 | inputPartitions.map(_.asInstanceOf[InputPartition]).toArray 65 | } 66 | 67 | override def readSchema(): StructType = schema 68 | 69 | override def createReaderFactory(): PartitionReaderFactory = 70 | new NebulaPartitionReaderFactory(nebulaOptions, schema) 71 | } 72 | 73 | /** 74 | * An identifier for a partition in an NebulaRDD. 75 | */ 76 | case class NebulaPartition(partition: Int) extends InputPartition 77 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/utils/Validations.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.utils 7 | 8 | import org.apache.spark.sql.SparkSession 9 | 10 | object Validations { 11 | def validateSparkVersion(supportedVersions: String*): Unit = { 12 | val sparkVersion = SparkSession.getActiveSession.map { _.version }.getOrElse("UNKNOWN") 13 | if (!(sparkVersion == "UNKNOWN" || supportedVersions.exists(sparkVersion.matches))) { 14 | throw new RuntimeException( 15 | s"Your current spark version ${sparkVersion} is not supported by the current connector.") 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaCommitMessage.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import org.apache.spark.sql.connector.write.WriterCommitMessage 9 | 10 | case class NebulaCommitMessage(executeStatements: List[String]) extends WriterCommitMessage 11 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaEdgeWriter.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.connector.{NebulaEdge, NebulaEdges} 9 | import com.vesoft.nebula.connector.{KeyPolicy, NebulaOptions, WriteMode} 10 | import org.apache.spark.sql.catalyst.InternalRow 11 | import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} 12 | import org.apache.spark.sql.types.StructType 13 | import org.slf4j.LoggerFactory 14 | 15 | import scala.collection.mutable.ListBuffer 16 | 17 | class NebulaEdgeWriter(nebulaOptions: NebulaOptions, 18 | srcIndex: Int, 19 | dstIndex: Int, 20 | rankIndex: Option[Int], 21 | schema: StructType) 22 | extends NebulaWriter(nebulaOptions) 23 | with DataWriter[InternalRow] { 24 | 25 | private val LOG = LoggerFactory.getLogger(this.getClass) 26 | 27 | val rankIdx = if (rankIndex.isDefined) rankIndex.get else -1 28 | val propNames = NebulaExecutor.assignEdgePropNames(schema, 29 | srcIndex, 30 | dstIndex, 31 | rankIdx, 32 | nebulaOptions.srcAsProp, 33 | nebulaOptions.dstAsProp, 34 | nebulaOptions.rankAsProp) 35 | val fieldTypMap: Map[String, Integer] = 36 | if (nebulaOptions.writeMode == WriteMode.DELETE) Map[String, Integer]() 37 | else metaProvider.getEdgeSchema(nebulaOptions.spaceName, nebulaOptions.label) 38 | 39 | val srcPolicy = 40 | if (nebulaOptions.srcPolicy.isEmpty) Option.empty 41 | else Option(KeyPolicy.withName(nebulaOptions.srcPolicy)) 42 | val dstPolicy = { 43 | if (nebulaOptions.dstPolicy.isEmpty) Option.empty 44 | else Option(KeyPolicy.withName(nebulaOptions.dstPolicy)) 45 | } 46 | 47 | /** buffer to save batch edges */ 48 | var edges: ListBuffer[NebulaEdge] = new ListBuffer() 49 | 50 | prepareSpace() 51 | 52 | /** 53 | * write one edge record to buffer 54 | */ 55 | override def write(row: InternalRow): Unit = { 56 | val srcId = NebulaExecutor.extraID(schema, row, srcIndex, srcPolicy, isVidStringType) 57 | val dstId = NebulaExecutor.extraID(schema, row, dstIndex, dstPolicy, isVidStringType) 58 | val rank = 59 | if
(rankIndex.isEmpty) Option.empty 60 | else Option(NebulaExecutor.extraRank(schema, row, rankIndex.get)) 61 | val values = 62 | if (nebulaOptions.writeMode == WriteMode.DELETE) List() 63 | else 64 | NebulaExecutor.assignEdgeValues(schema, 65 | row, 66 | srcIndex, 67 | dstIndex, 68 | rankIdx, 69 | nebulaOptions.srcAsProp, 70 | nebulaOptions.dstAsProp, 71 | nebulaOptions.rankAsProp, 72 | fieldTypMap) 73 | val nebulaEdge = NebulaEdge(srcId, dstId, rank, values) 74 | edges.append(nebulaEdge) 75 | if (edges.size >= nebulaOptions.batch) { 76 | execute() 77 | } 78 | } 79 | 80 | /** 81 | * submit buffer edges to nebula 82 | */ 83 | def execute(): Unit = { 84 | val nebulaEdges = NebulaEdges(propNames, edges.toList, srcPolicy, dstPolicy) 85 | val exec = nebulaOptions.writeMode match { 86 | case WriteMode.INSERT => 87 | NebulaExecutor.toExecuteSentence(nebulaOptions.label, nebulaEdges, nebulaOptions.overwrite) 88 | case WriteMode.UPDATE => 89 | NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaEdges) 90 | case WriteMode.DELETE => 91 | NebulaExecutor.toDeleteExecuteStatement(nebulaOptions.label, nebulaEdges) 92 | case _ => 93 | throw new IllegalArgumentException(s"write mode ${nebulaOptions.writeMode} not supported.") 94 | } 95 | edges.clear() 96 | submit(exec) 97 | } 98 | 99 | override def commit(): WriterCommitMessage = { 100 | if (edges.nonEmpty) { 101 | execute() 102 | } 103 | graphProvider.close() 104 | metaProvider.close() 105 | NebulaCommitMessage.apply(failedExecs.toList) 106 | } 107 | 108 | override def abort(): Unit = { 109 | LOG.error("insert edge task abort.") 110 | graphProvider.close() 111 | } 112 | 113 | override def close(): Unit = { 114 | graphProvider.close() 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaSourceWriter.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 
4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.connector.NebulaOptions 9 | import org.apache.spark.TaskContext 10 | import org.apache.spark.sql.catalyst.InternalRow 11 | import org.apache.spark.sql.connector.write.{ 12 | BatchWrite, 13 | DataWriter, 14 | DataWriterFactory, 15 | PhysicalWriteInfo, 16 | WriterCommitMessage 17 | } 18 | import org.apache.spark.sql.types.StructType 19 | import org.slf4j.LoggerFactory 20 | 21 | /** 22 | * creating and initializing the actual Nebula vertex writer at executor side 23 | */ 24 | class NebulaVertexWriterFactory(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType) 25 | extends DataWriterFactory { 26 | override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { 27 | new NebulaVertexWriter(nebulaOptions, vertexIndex, schema) 28 | } 29 | } 30 | 31 | /** 32 | * creating and initializing the actual Nebula edge writer at executor side 33 | */ 34 | class NebulaEdgeWriterFactory(nebulaOptions: NebulaOptions, 35 | srcIndex: Int, 36 | dstIndex: Int, 37 | rankIndex: Option[Int], 38 | schema: StructType) 39 | extends DataWriterFactory { 40 | override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { 41 | new NebulaEdgeWriter(nebulaOptions, srcIndex, dstIndex, rankIndex, schema) 42 | } 43 | } 44 | 45 | /** 46 | * nebula vertex writer to create factory 47 | */ 48 | class NebulaDataSourceVertexWriter(nebulaOptions: NebulaOptions, 49 | vertexIndex: Int, 50 | schema: StructType) 51 | extends BatchWrite { 52 | private val LOG = LoggerFactory.getLogger(this.getClass) 53 | 54 | override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { 55 | new NebulaVertexWriterFactory(nebulaOptions, vertexIndex, schema) 56 | } 57 | 58 | override def commit(messages: Array[WriterCommitMessage]): Unit = { 59 | LOG.debug(s"${messages.length}") 60 | for (msg <- messages) { 61 | val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage] 62 | if (nebulaMsg.executeStatements.nonEmpty) { 63 | LOG.error(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}") 64 | } else { 65 | LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed") 66 | } 67 | } 68 | } 69 | 70 | override def abort(messages: Array[WriterCommitMessage]): Unit = { 71 | LOG.error("NebulaDataSourceVertexWriter abort") 72 | } 73 | } 74 | 75 | /** 76 | * nebula edge writer to create factory 77 | */ 78 | class NebulaDataSourceEdgeWriter(nebulaOptions: NebulaOptions, 79 | srcIndex: Int, 80 | dstIndex: Int, 81 | rankIndex: Option[Int], 82 | schema: StructType) 83 | extends BatchWrite { 84 | private val LOG = LoggerFactory.getLogger(this.getClass) 85 | 86 | override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = 87 | new NebulaEdgeWriterFactory(nebulaOptions, srcIndex, dstIndex, rankIndex, schema) 88 | 89 | override def commit(messages: Array[WriterCommitMessage]): Unit = { 90 | LOG.debug(s"${messages.length}") 91 | for (msg <- messages) { 92 | val nebulaMsg = msg.asInstanceOf[NebulaCommitMessage] 93 | if (nebulaMsg.executeStatements.nonEmpty) { 94 | LOG.error(s"failed execs:\n ${nebulaMsg.executeStatements.toString()}") 95 | } else { 96 | LOG.info(s"execs for spark partition ${TaskContext.getPartitionId()} all succeed") 97 | } 98 | } 99 | 100 | } 101 | 102 | override def abort(messages: Array[WriterCommitMessage]): Unit = { 103 | LOG.error("NebulaDataSourceEdgeWriter abort") 104 | } 105 | } 106 | 
-------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaVertexWriter.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.connector.{ 9 | KeyPolicy, 10 | NebulaOptions, 11 | NebulaVertex, 12 | NebulaVertices, 13 | WriteMode 14 | } 15 | import org.apache.spark.sql.catalyst.InternalRow 16 | import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} 17 | import org.apache.spark.sql.types.StructType 18 | import org.slf4j.LoggerFactory 19 | 20 | import scala.collection.mutable.ListBuffer 21 | 22 | class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexIndex: Int, schema: StructType) 23 | extends NebulaWriter(nebulaOptions) 24 | with DataWriter[InternalRow] { 25 | 26 | private val LOG = LoggerFactory.getLogger(this.getClass) 27 | 28 | val propNames = NebulaExecutor.assignVertexPropNames(schema, vertexIndex, nebulaOptions.vidAsProp) 29 | val fieldTypMap: Map[String, Integer] = 30 | if (nebulaOptions.writeMode == WriteMode.DELETE) Map[String, Integer]() 31 | else metaProvider.getTagSchema(nebulaOptions.spaceName, nebulaOptions.label) 32 | 33 | val policy = { 34 | if (nebulaOptions.vidPolicy.isEmpty) Option.empty 35 | else Option(KeyPolicy.withName(nebulaOptions.vidPolicy)) 36 | } 37 | 38 | /** buffer to save batch vertices */ 39 | var vertices: ListBuffer[NebulaVertex] = new ListBuffer() 40 | 41 | prepareSpace() 42 | 43 | /** 44 | * write one vertex row to buffer 45 | */ 46 | override def write(row: InternalRow): Unit = { 47 | val vertex = 48 | NebulaExecutor.extraID(schema, row, vertexIndex, policy, isVidStringType) 49 | val values = 50 | if (nebulaOptions.writeMode == WriteMode.DELETE) List() 51 | else 52 | NebulaExecutor.assignVertexPropValues(schema, 53 | row, 54 | vertexIndex, 55 | nebulaOptions.vidAsProp, 56 | fieldTypMap) 57 | val nebulaVertex = NebulaVertex(vertex, values) 58 | vertices.append(nebulaVertex) 59 | if (vertices.size >= nebulaOptions.batch) { 60 | execute() 61 | } 62 | } 63 | 64 | /** 65 | * submit buffer vertices to nebula 66 | */ 67 | def execute(): Unit = { 68 | val nebulaVertices = NebulaVertices(propNames, vertices.toList, policy) 69 | val exec = nebulaOptions.writeMode match { 70 | case WriteMode.INSERT => 71 | NebulaExecutor.toExecuteSentence(nebulaOptions.label, 72 | nebulaVertices, 73 | nebulaOptions.overwrite) 74 | case WriteMode.UPDATE => 75 | NebulaExecutor.toUpdateExecuteStatement(nebulaOptions.label, nebulaVertices) 76 | case WriteMode.DELETE => 77 | NebulaExecutor.toDeleteExecuteStatement(nebulaVertices, nebulaOptions.deleteEdge) 78 | case _ => 79 | throw new IllegalArgumentException(s"write mode ${nebulaOptions.writeMode} not supported.") 80 | } 81 | vertices.clear() 82 | submit(exec) 83 | } 84 | 85 | override def commit(): WriterCommitMessage = { 86 | if (vertices.nonEmpty) { 87 | execute() 88 | } 89 | graphProvider.close() 90 | metaProvider.close() 91 | NebulaCommitMessage(failedExecs.toList) 92 | } 93 | 94 | override def abort(): Unit = { 95 | LOG.error("insert vertex task abort.") 96 | graphProvider.close() 97 | } 98 | 99 | override def close(): Unit = { 100 | graphProvider.close() 101 | } 102 | } 103 | 
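Note that the vertex and edge writers above are not constructed by user code; Spark reaches them through the "nebula" data source registered by NebulaDataSource above (shortName() returns "nebula") and the NebulaWriterBuilder below. As a rough usage sketch, a vertex write from a DataFrame might look like the following; the option keys are illustrative assumptions and should be checked against NebulaOptions.scala in nebula-spark-common, while the space, tag, and write modes are the ones used elsewhere in this repository.

import org.apache.spark.sql.{DataFrame, SparkSession}

object NebulaVertexWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("nebula-vertex-write-sketch").getOrCreate()
    // any DataFrame whose rows carry a vertex id column plus property columns
    val df: DataFrame = spark.read.option("header", "true").csv("vertex.csv")

    df.write
      .format("nebula")                          // format name from NebulaDataSource.shortName()
      .option("spaceName", "test_write_string")  // space used by the test suites
      .option("label", "person_connector")       // tag used by the test suites
      .option("type", "vertex")                  // assumed key: vertex vs. edge data
      .option("vertexField", "id")               // assumed key: column holding the vertex id
      .option("writeMode", "insert")             // assumed key: insert / update / delete (see WriteMode)
      .mode("append")
      .save()

    spark.stop()
  }
}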
-------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriter.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2020 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import java.util.concurrent.TimeUnit 9 | 10 | import com.google.common.util.concurrent.RateLimiter 11 | import com.vesoft.nebula.connector.NebulaOptions 12 | import com.vesoft.nebula.connector.nebula.{GraphProvider, MetaProvider, VidType} 13 | import org.slf4j.LoggerFactory 14 | 15 | import scala.collection.mutable.ListBuffer 16 | 17 | class NebulaWriter(nebulaOptions: NebulaOptions) extends Serializable { 18 | private val LOG = LoggerFactory.getLogger(this.getClass) 19 | 20 | val failedExecs: ListBuffer[String] = new ListBuffer[String] 21 | 22 | val metaProvider = new MetaProvider( 23 | nebulaOptions.getMetaAddress, 24 | nebulaOptions.timeout, 25 | nebulaOptions.connectionRetry, 26 | nebulaOptions.executionRetry, 27 | nebulaOptions.enableMetaSSL, 28 | nebulaOptions.sslSignType, 29 | nebulaOptions.caSignParam, 30 | nebulaOptions.selfSignParam 31 | ) 32 | val graphProvider = new GraphProvider( 33 | nebulaOptions.getGraphAddress, 34 | nebulaOptions.user, 35 | nebulaOptions.passwd, 36 | nebulaOptions.timeout, 37 | nebulaOptions.enableGraphSSL, 38 | nebulaOptions.sslSignType, 39 | nebulaOptions.caSignParam, 40 | nebulaOptions.selfSignParam 41 | ) 42 | val isVidStringType = metaProvider.getVidType(nebulaOptions.spaceName) == VidType.STRING 43 | 44 | def prepareSpace(): Unit = { 45 | graphProvider.switchSpace(nebulaOptions.spaceName) 46 | } 47 | 48 | def submit(exec: String): Unit = { 49 | @transient val rateLimiter = RateLimiter.create(nebulaOptions.rateLimit) 50 | if (rateLimiter.tryAcquire(nebulaOptions.rateTimeOut, TimeUnit.MILLISECONDS)) { 51 | val result = graphProvider.submit(exec) 52 | if (!result.isSucceeded) { 53 | failedExecs.append(exec) 54 | if (nebulaOptions.disableWriteLog) { 55 | LOG.error(s"write failed: " + result.getErrorMessage) 56 | } else { 57 | LOG.error(s"write failed: ${result.getErrorMessage} failed statement: \n ${exec}") 58 | } 59 | } else { 60 | LOG.info(s"batch write succeeded") 61 | LOG.debug(s"batch write succeeded: ${exec}") 62 | } 63 | } else { 64 | failedExecs.append(exec) 65 | LOG.error(s"failed to acquire rateLimiter for statement ${exec}") 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/main/scala/com/vesoft/nebula/connector/writer/NebulaWriterBuilder.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License.
4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.connector.exception.IllegalOptionException 9 | import com.vesoft.nebula.connector.{DataTypeEnum, NebulaOptions} 10 | import org.apache.spark.sql.SaveMode 11 | import org.apache.spark.sql.connector.write.{ 12 | BatchWrite, 13 | SupportsOverwrite, 14 | SupportsTruncate, 15 | WriteBuilder 16 | } 17 | import org.apache.spark.sql.sources.Filter 18 | import org.apache.spark.sql.types.StructType 19 | 20 | class NebulaWriterBuilder(schema: StructType, saveMode: SaveMode, nebulaOptions: NebulaOptions) 21 | extends WriteBuilder 22 | with SupportsOverwrite 23 | with SupportsTruncate { 24 | 25 | override def buildForBatch(): BatchWrite = { 26 | val dataType = nebulaOptions.dataType 27 | if (DataTypeEnum.VERTEX == DataTypeEnum.withName(dataType)) { 28 | val vertexFiled = nebulaOptions.vertexField 29 | val vertexIndex: Int = { 30 | var index: Int = -1 31 | for (i <- schema.fields.indices) { 32 | if (schema.fields(i).name.equals(vertexFiled)) { 33 | index = i 34 | } 35 | } 36 | if (index < 0) { 37 | throw new IllegalOptionException( 38 | s" vertex field ${vertexFiled} does not exist in dataframe") 39 | } 40 | index 41 | } 42 | new NebulaDataSourceVertexWriter(nebulaOptions, vertexIndex, schema) 43 | } else { 44 | val srcVertexFiled = nebulaOptions.srcVertexField 45 | val dstVertexField = nebulaOptions.dstVertexField 46 | val rankExist = !nebulaOptions.rankField.isEmpty 47 | val edgeFieldsIndex = { 48 | var srcIndex: Int = -1 49 | var dstIndex: Int = -1 50 | var rankIndex: Int = -1 51 | for (i <- schema.fields.indices) { 52 | if (schema.fields(i).name.equals(srcVertexFiled)) { 53 | srcIndex = i 54 | } 55 | if (schema.fields(i).name.equals(dstVertexField)) { 56 | dstIndex = i 57 | } 58 | if (rankExist) { 59 | if (schema.fields(i).name.equals(nebulaOptions.rankField)) { 60 | rankIndex = i 61 | } 62 | } 63 | } 64 | // check src filed and dst field 65 | if (srcIndex < 0 || dstIndex < 0) { 66 | throw new IllegalOptionException( 67 | s" srcVertex field ${srcVertexFiled} or dstVertex field ${dstVertexField} do not exist in dataframe") 68 | } 69 | // check rank field 70 | if (rankExist && rankIndex < 0) { 71 | throw new IllegalOptionException(s"rank field does not exist in dataframe") 72 | } 73 | 74 | if (!rankExist) { 75 | (srcIndex, dstIndex, Option.empty) 76 | } else { 77 | (srcIndex, dstIndex, Option(rankIndex)) 78 | } 79 | 80 | } 81 | new NebulaDataSourceEdgeWriter(nebulaOptions, 82 | edgeFieldsIndex._1, 83 | edgeFieldsIndex._2, 84 | edgeFieldsIndex._3, 85 | schema) 86 | } 87 | } 88 | 89 | override def overwrite(filters: Array[Filter]): WriteBuilder = { 90 | new NebulaWriterBuilder(schema, SaveMode.Overwrite, nebulaOptions) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/test/resources/edge.csv: -------------------------------------------------------------------------------- 1 | id1,id2,col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14 2 | 1,2,Tom,tom,10,20,30,40,2021-01-27,2021-01-01T12:10:10,43535232,true,1.0,2.0,10:10:10,POINT(1 2) 3 | 2,3,Jina,Jina,11,21,31,41,2021-01-28,2021-01-02T12:10:10,43535232,false,1.1,2.1,11:10:10,POINT(3 4) 4 | 3,4,Tim,Tim,12,22,32,42,2021-01-29,2021-01-03T12:10:10,43535232,false,1.2,2.2,12:10:10,POINT(5 6) 5 | 4,5,张三,张三,13,23,33,43,2021-01-30,2021-01-04T12:10:10,43535232,true,1.3,2.3,13:10:10,POINT(6 7) 6 | 
5,6,李四,李四,14,24,34,44,2021-02-01,2021-01-05T12:10:10,43535232,false,1.4,2.4,14:10:10,POINT(1 5) 7 | 6,7,王五,王五,15,25,35,45,2021-02-02,2021-01-06T12:10:10,0,false,1.5,2.5,15:10:10,"LINESTRING(1 3, 4.7 73.23)" 8 | 7,1,Jina,Jina,16,26,36,46,2021-02-03,2021-01-07T12:10:10,43535232,true,1.6,2.6,16:10:10,"LINESTRING(1 3, 4.7 73.23)" 9 | 8,1,Jina,Jina,17,27,37,47,2021-02-04,2021-01-08T12:10:10,43535232,false,1.7,2.7,17:10:10,"LINESTRING(1 3, 4.7 73.23)" 10 | 9,1,Jina,Jina,18,28,38,48,2021-02-05,2021-01-09T12:10:10,43535232,true,1.8,2.8,18:10:10,"LINESTRING(1 3, 4.7 73.23)" 11 | 10,2,Jina,Jina,19,29,39,49,2021-02-06,2021-01-10T12:10:10,43535232,false,1.9,2.9,19:10:10,"LINESTRING(1 3, 4.7 73.23)" 12 | -1,5,Jina,Jina,20,30,40,50,2021-02-07,2021-02-11T12:10:10,43535232,false,2.0,3.0,20:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" 13 | -2,6,Jina,Jina,21,31,41,51,2021-02-08,2021-03-12T12:10:10,43535232,false,2.1,3.1,21:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" 14 | -3,7,Jina,Jina,22,32,42,52,2021-02-09,2021-04-13T12:10:10,43535232,false,2.2,3.2,22:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))" 15 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Global logging configuration 2 | log4j.rootLogger=INFO, stdout 3 | # Console output... 4 | log4j.appender.stdout=org.apache.log4j.ConsoleAppender 5 | log4j.appender.stdout.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.stdout.layout.ConversionPattern=%5p [%t] - %m%n 7 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/test/resources/vertex.csv: -------------------------------------------------------------------------------- 1 | id,col1,col2,col3,col4,col5,col6,col7,col8,col9,col10,col11,col12,col13,col14,col15 2 | 1,Tom,tom,10,20,30,40,2021-01-27,2021-01-01T12:10:10,43535232,true,1.0,2.0,10:10:10,POINT(1 2),"duration({years:1,months:1,seconds:1})" 3 | 2,Jina,Jina,11,21,31,41,2021-01-28,2021-01-02T12:10:10,43535232,false,1.1,2.1,11:10:10,POINT(3 4),"duration({years:1,months:1,seconds:1})" 4 | 3,Tim,Tim,12,22,32,42,2021-01-29,2021-01-03T12:10:10,43535232,false,1.2,2.2,12:10:10,POINT(5 6),"duration({years:1,months:1,seconds:1})" 5 | 4,张三,张三,13,23,33,43,2021-01-30,2021-01-04T12:10:10,43535232,true,1.3,2.3,13:10:10,POINT(6 7),"duration({years:1,months:1,seconds:1})" 6 | 5,李四,李四,14,24,34,44,2021-02-01,2021-01-05T12:10:10,43535232,false,1.4,2.4,14:10:10,POINT(1 5),"duration({years:1,months:1,seconds:1})" 7 | 6,王五,王五,15,25,35,45,2021-02-02,2021-01-06T12:10:10,0,false,1.5,2.5,15:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 8 | 7,Jina,Jina,16,26,36,46,2021-02-03,2021-01-07T12:10:10,43535232,true,1.6,2.6,16:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 9 | 8,Jina,Jina,17,27,37,47,2021-02-04,2021-01-08T12:10:10,43535232,false,1.7,2.7,17:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 10 | 9,Jina,Jina,18,28,38,48,2021-02-05,2021-01-09T12:10:10,43535232,true,1.8,2.8,18:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 11 | 10,Jina,Jina,19,29,39,49,2021-02-06,2021-01-10T12:10:10,43535232,false,1.9,2.9,19:10:10,"LINESTRING(1 3, 4.7 73.23)","duration({years:1,months:1,seconds:1})" 12 | -1,Jina,Jina,20,30,40,50,2021-02-07,2021-02-11T12:10:10,43535232,false,2.0,3.0,20:10:10,"POLYGON((0 1, 1 2, 2 3, 0 
1))","duration({years:1,months:1,seconds:1})" 13 | -2,Jina,Jina,21,31,41,51,2021-02-08,2021-03-12T12:10:10,43535232,false,2.1,3.1,21:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})" 14 | -3,Jina,Jina,22,32,42,52,2021-02-09,2021-04-13T12:10:10,43535232,false,2.2,3.2,22:10:10,"POLYGON((0 1, 1 2, 2 3, 0 1))","duration({years:1,months:1,seconds:1})" 15 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/SparkVersionValidateSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2022 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector 7 | 8 | import com.vesoft.nebula.connector.utils.SparkValidate 9 | import org.apache.spark.sql.SparkSession 10 | import org.scalatest.funsuite.AnyFunSuite 11 | 12 | class SparkVersionValidateSuite extends AnyFunSuite { 13 | test("spark version validate") { 14 | try { 15 | val version = SparkSession.getActiveSession.map(_.version).getOrElse("UNKNOWN") 16 | SparkValidate.validate("3.0.*", "3.1.*", "3.2.*", "3.3.*") 17 | } catch { 18 | case e: Exception => assert(false) 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteDeleteSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.client.graph.data.ResultSet 9 | import com.vesoft.nebula.connector.Address 10 | import com.vesoft.nebula.connector.mock.{NebulaGraphMock, SparkMock} 11 | import com.vesoft.nebula.connector.nebula.GraphProvider 12 | import org.apache.log4j.BasicConfigurator 13 | import org.scalatest.BeforeAndAfterAll 14 | import org.scalatest.funsuite.AnyFunSuite 15 | 16 | class WriteDeleteSuite extends AnyFunSuite with BeforeAndAfterAll { 17 | BasicConfigurator.configure() 18 | 19 | override def beforeAll(): Unit = { 20 | val graphMock = new NebulaGraphMock 21 | graphMock.mockStringIdGraphSchema() 22 | graphMock.mockIntIdGraphSchema() 23 | graphMock.close() 24 | Thread.sleep(10000) 25 | SparkMock.writeVertex() 26 | SparkMock.writeEdge() 27 | } 28 | 29 | test("write vertex into test_write_string space with delete mode") { 30 | SparkMock.deleteVertex() 31 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 32 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 33 | 34 | graphProvider.switchSpace("test_write_string") 35 | val resultSet: ResultSet = 36 | graphProvider.submit("use test_write_string;" 37 | + "match (v:person_connector) return v limit 100000;") 38 | assert(resultSet.isSucceeded) 39 | assert(resultSet.getColumnNames.size() == 1) 40 | assert(resultSet.isEmpty) 41 | } 42 | 43 | test("write vertex into test_write_with_edge_string space with delete with edge mode") { 44 | SparkMock.writeVertex() 45 | SparkMock.writeEdge() 46 | SparkMock.deleteVertexWithEdge() 47 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 48 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 49 | 50 | graphProvider.switchSpace("test_write_string") 51 | // assert vertex is 
deleted 52 | val vertexResultSet: ResultSet = 53 | graphProvider.submit("use test_write_string;" 54 | + "match (v:person_connector) return v limit 1000000;") 55 | assert(vertexResultSet.isSucceeded) 56 | assert(vertexResultSet.getColumnNames.size() == 1) 57 | assert(vertexResultSet.isEmpty) 58 | 59 | // assert edge is deleted 60 | val edgeResultSet: ResultSet = 61 | graphProvider.submit("use test_write_string;" 62 | + "fetch prop on friend_connector \"1\"->\"2\"@10 yield edge as e") 63 | assert(edgeResultSet.isSucceeded) 64 | assert(edgeResultSet.getColumnNames.size() == 1) 65 | assert(edgeResultSet.isEmpty) 66 | 67 | } 68 | 69 | test("write edge into test_write_string space with delete mode") { 70 | SparkMock.deleteEdge() 71 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 72 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 73 | 74 | graphProvider.switchSpace("test_write_string") 75 | val resultSet: ResultSet = 76 | graphProvider.submit("use test_write_string;" 77 | + "fetch prop on friend_connector \"1\"->\"2\"@10 yield edge as e;") 78 | assert(resultSet.isSucceeded) 79 | assert(resultSet.getColumnNames.size() == 1) 80 | assert(resultSet.isEmpty) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /nebula-spark-connector_3.0/src/test/scala/com/vesoft/nebula/connector/writer/WriteInsertSuite.scala: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021 vesoft inc. All rights reserved. 2 | * 3 | * This source code is licensed under Apache 2.0 License. 4 | */ 5 | 6 | package com.vesoft.nebula.connector.writer 7 | 8 | import com.vesoft.nebula.client.graph.data.ResultSet 9 | import com.vesoft.nebula.connector.Address 10 | import com.vesoft.nebula.connector.mock.{NebulaGraphMock, SparkMock} 11 | import com.vesoft.nebula.connector.nebula.GraphProvider 12 | import org.apache.log4j.BasicConfigurator 13 | import org.scalatest.BeforeAndAfterAll 14 | import org.scalatest.funsuite.AnyFunSuite 15 | 16 | class WriteInsertSuite extends AnyFunSuite with BeforeAndAfterAll { 17 | BasicConfigurator.configure() 18 | 19 | override def beforeAll(): Unit = { 20 | val graphMock = new NebulaGraphMock 21 | graphMock.mockStringIdGraphSchema() 22 | graphMock.mockIntIdGraphSchema() 23 | graphMock.close() 24 | Thread.sleep(10000) 25 | } 26 | 27 | test("write vertex into test_write_string space with insert mode") { 28 | SparkMock.writeVertex() 29 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 30 | val graphProvider = new GraphProvider(addresses, "root", "nebula", 3000) 31 | 32 | graphProvider.switchSpace("test_write_string") 33 | val createIndexResult: ResultSet = graphProvider.submit( 34 | "use test_write_string; " 35 | + "create tag index if not exists person_index on person_connector(col1(20));") 36 | Thread.sleep(5000) 37 | graphProvider.submit("rebuild tag index person_index;") 38 | 39 | Thread.sleep(5000) 40 | 41 | graphProvider.submit("use test_write_string;") 42 | val resultSet: ResultSet = 43 | graphProvider.submit("match (v:person_connector) return v;") 44 | assert(resultSet.isSucceeded) 45 | assert(resultSet.getColumnNames.size() == 1) 46 | assert(resultSet.getRows.size() == 13) 47 | } 48 | 49 | test("write edge into test_write_string space with insert mode") { 50 | SparkMock.writeEdge() 51 | 52 | val addresses: List[Address] = List(new Address("127.0.0.1", 9669)) 53 | val graphProvider = new GraphProvider(addresses, "root", "nebula",
3000) 54 | 55 | graphProvider.switchSpace("test_write_string") 56 | val createIndexResult: ResultSet = graphProvider.submit( 57 | "use test_write_string; " 58 | + "create edge index if not exists friend_index on friend_connector(col1(20));") 59 | Thread.sleep(5000) 60 | graphProvider.submit("rebuild edge index friend_index;") 61 | 62 | Thread.sleep(5000) 63 | 64 | graphProvider.submit("use test_write_string;") 65 | val resultSet: ResultSet = 66 | graphProvider.submit("match (v:person_connector)-[e:friend_connector]-> () return e;") 67 | assert(resultSet.isSucceeded) 68 | assert(resultSet.getColumnNames.size() == 1) 69 | assert(resultSet.getRows.size() == 13) 70 | } 71 | } 72 | --------------------------------------------------------------------------------
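The insert suites above exercise the write path end to end against a local graphd on 127.0.0.1:9669; the matching read path goes through SimpleScanBuilder and the partition readers. Below is a minimal read-side sketch, assuming the connector jar is on the classpath and a metad service on port 9559 (that port, like the option keys, is an assumption to be checked against NebulaOptions.scala; the "nebula" format name and the space/tag names come from the sources above).

import org.apache.spark.sql.SparkSession

object NebulaVertexReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("nebula-vertex-read-sketch").getOrCreate()

    val df = spark.read
      .format("nebula")                          // format name from NebulaDataSource.shortName()
      .option("metaAddress", "127.0.0.1:9559")   // assumed key: meta service address
      .option("spaceName", "test_write_string")  // space used by the test suites
      .option("label", "person_connector")       // tag used by the test suites
      .option("type", "vertex")                  // assumed key: scan vertices rather than edges
      .option("partitionNumber", "10")           // assumed key: maps to nebulaOptions.partitionNums
      .load()

    df.show(20, truncate = false)
    spark.stop()
  }
}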