├── .arcconfig ├── .circleci └── config.yml ├── .gitignore ├── .idea ├── codeStyles │ ├── Project.xml │ └── codeStyleConfig.xml └── runConfigurations │ ├── Test_Spark_3_0.xml │ └── Test_Spark_3_1.xml ├── .java-version ├── .scalafmt.conf ├── CHANGELOG ├── LICENSE ├── Layerfile ├── README.md ├── build.sbt ├── ci └── secring.asc.enc ├── demo ├── Dockerfile ├── README.md └── notebook │ ├── pyspark-singlestore-demo_2F8XQUKFG.zpln │ ├── scala-singlestore-demo_2F6Y3APTX.zpln │ └── spark-sql-singlestore-demo_2F7PZ81H6.zpln ├── project ├── build.properties └── plugins.sbt ├── scripts ├── define-layerci-matrix.sh ├── jwt │ └── jwt_auth_config.json ├── setup-cluster.sh └── ssl │ ├── test-ca-cert.pem │ ├── test-ca-key.pem │ ├── test-singlestore-cert.pem │ └── test-singlestore-key.pem └── src ├── main ├── resources │ └── META-INF │ │ └── services │ │ └── org.apache.spark.sql.sources.DataSourceRegister ├── scala-sparkv3.1 │ └── spark │ │ ├── MaxNumConcurentTasks.scala │ │ ├── VersionSpecificAggregateExpressionExtractor.scala │ │ ├── VersionSpecificExpressionGen.scala │ │ ├── VersionSpecificUtil.scala │ │ └── VersionSpecificWindowBoundaryExpressionExtractor.scala ├── scala-sparkv3.2 │ └── spark │ │ ├── MaxNumConcurrentTasks.scala │ │ ├── VersionSpecificAggregateExpressionExtractor.scala │ │ ├── VersionSpecificExpressionGen.scala │ │ ├── VersionSpecificUtil.scala │ │ └── VersionSpecificWindowBoundaryExpressionExtractor.scala ├── scala-sparkv3.3 │ └── spark │ │ ├── MaxNumConcurrentTasks.scala │ │ ├── VersionSpecificAggregateExpressionExtractor.scala │ │ ├── VersionSpecificExpressionGen.scala │ │ ├── VersionSpecificUtil.scala │ │ └── VersionSpecificWindowBoundaryExpressionExtractor.scala ├── scala-sparkv3.4 │ └── spark │ │ ├── MaxNumConcurrentTasks.scala │ │ ├── VersionSpecificAggregateExpressionExtractor.scala │ │ ├── VersionSpecificExpressionGen.scala │ │ ├── VersionSpecificUtil.scala │ │ └── VersionSpecificWindowBoundaryExpressionExtractor.scala ├── scala-sparkv3.5 │ └── spark │ │ ├── MaxNumConcurrentTasks.scala │ │ ├── VersionSpecificAggregateExpressionExtractor.scala │ │ ├── VersionSpecificExpressionGen.scala │ │ ├── VersionSpecificUtil.scala │ │ └── VersionSpecificWindowBoundaryExpressionExtractor.scala └── scala │ └── com │ ├── memsql │ └── spark │ │ └── DefaultSource.scala │ └── singlestore │ └── spark │ ├── AggregatorParallelReadListener.scala │ ├── AvroSchemaHelper.scala │ ├── CompletionIterator.scala │ ├── DefaultSource.scala │ ├── ExpressionGen.scala │ ├── JdbcHelpers.scala │ ├── LazyLogging.scala │ ├── Loan.scala │ ├── MetricsHandler.scala │ ├── OverwriteBehavior.scala │ ├── ParallelReadEnablement.scala │ ├── ParallelReadType.scala │ ├── SQLGen.scala │ ├── SQLHelper.scala │ ├── SQLPushdownRule.scala │ ├── SinglestoreBatchInsertWriter.scala │ ├── SinglestoreConnectionPool.scala │ ├── SinglestoreConnectionPoolOptions.scala │ ├── SinglestoreDialect.scala │ ├── SinglestoreLoadDataWriter.scala │ ├── SinglestoreOptions.scala │ ├── SinglestorePartitioner.scala │ ├── SinglestoreRDD.scala │ ├── SinglestoreReader.scala │ └── vendor │ └── apache │ ├── SchemaConverters.scala │ └── third_party_license └── test ├── resources ├── data │ ├── movies.json │ ├── movies_rating.json │ ├── reviews.json │ └── users.json ├── log4j.properties ├── log4j2.properties └── mockito-extensions │ └── org.mockito.plugins.MockMaker └── scala └── com └── singlestore └── spark ├── BatchInsertBenchmark.scala ├── BatchInsertTest.scala ├── BenchmarkSerializingTest.scala ├── BinaryTypeBenchmark.scala ├── CustomDatatypesTest.scala ├── 
ExternalHostTest.scala ├── IntegrationSuiteBase.scala ├── IssuesTest.scala ├── LoadDataBenchmark.scala ├── LoadDataTest.scala ├── LoadbalanceTest.scala ├── MaxErrorsTest.scala ├── OutputMetricsTest.scala ├── ReferenceTableTest.scala ├── SQLHelperTest.scala ├── SQLOverwriteTest.scala ├── SQLPermissionsTest.scala ├── SQLPushdownTest.scala ├── SanityTest.scala ├── SinglestoreConnectionPoolTest.scala ├── SinglestoreOptionsTest.scala ├── TestHelper.scala └── VersionTest.scala /.arcconfig: -------------------------------------------------------------------------------- 1 | { 2 | "project_id" : "memsql-spark-connector", 3 | "conduit_uri" : "https:\/\/grizzly.internal.memcompute.com\/api\/" 4 | } 5 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | commands: 3 | setup_environment: 4 | description: "Setup the machine environment" 5 | parameters: 6 | sbt_version: 7 | type: string 8 | default: 1.3.6 9 | steps: 10 | - run: 11 | name: Setup Machine 12 | command: | 13 | sudo apt update 14 | sudo update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/bin/java 15 | sudo apt install -y curl 16 | sudo wget https://github.com/sbt/sbt/releases/download/v<< parameters.sbt_version >>/sbt-<< parameters.sbt_version >>.tgz 17 | sudo tar xzvf sbt-<< parameters.sbt_version >>.tgz -C /usr/share/ 18 | sudo rm sbt-<< parameters.sbt_version >>.tgz 19 | sudo update-alternatives --install /usr/bin/sbt sbt /usr/share/sbt/bin/sbt 100 20 | sudo apt-get update 21 | sudo apt-get install -y python-pip git mariadb-client-core-10.6 22 | sudo apt-get clean 23 | sudo apt-get autoclean 24 | 25 | jobs: 26 | test: 27 | parameters: 28 | spark_version: 29 | type: string 30 | singlestore_image: 31 | type: string 32 | machine: true 33 | resource_class: large 34 | environment: 35 | SINGLESTORE_IMAGE: << parameters.singlestore_image >> 36 | SINGLESTORE_PORT: 5506 37 | SINGLESTORE_USER: root 38 | SINGLESTORE_DB: test 39 | JAVA_HOME: /usr/lib/jvm/java-11-openjdk-amd64/ 40 | CONTINUOUS_INTEGRATION: true 41 | SBT_OPTS: "-Xmx256M" 42 | steps: 43 | - setup_environment 44 | - checkout 45 | - run: 46 | name: Setup test cluster 47 | command: ./scripts/setup-cluster.sh 48 | - run: 49 | name: Run tests 50 | command: | 51 | export SINGLESTORE_HOST=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' singlestore-integration) 52 | if [ << parameters.spark_version >> == '3.1.3' ] 53 | then 54 | sbt ++2.12.12 "testOnly -- -l ExcludeFromSpark31" -Dspark.version=<< parameters.spark_version >> 55 | elif [ << parameters.spark_version >> == '3.2.4' ] 56 | then 57 | sbt ++2.12.12 "testOnly -- -l ExcludeFromSpark32" -Dspark.version=<< parameters.spark_version >> 58 | elif [ << parameters.spark_version >> == '3.3.4' ] 59 | then 60 | sbt ++2.12.12 "testOnly -- -l ExcludeFromSpark33" -Dspark.version=<< parameters.spark_version >> 61 | elif [ << parameters.spark_version >> == '3.4.2' ] 62 | then 63 | sbt ++2.12.12 "testOnly -- -l ExcludeFromSpark34" -Dspark.version=<< parameters.spark_version >> 64 | else 65 | sbt ++2.12.12 "testOnly -- -l ExcludeFromSpark35" -Dspark.version=<< parameters.spark_version >> 66 | fi 67 | 68 | publish: 69 | machine: true 70 | environment: 71 | JAVA_HOME: /usr/lib/jvm/java-11-openjdk-amd64/ 72 | SONATYPE_USERNAME: memsql 73 | steps: 74 | - setup_environment 75 | - checkout 76 | - run: 77 | name: Import GPG key 78 | command: | 79 | openssl enc -d 
-aes-256-cbc -K ${ENCRYPTION_KEY} -iv ${ENCRYPTION_IV} -in ci/secring.asc.enc -out ci/secring.asc 80 | gpg --import ci/secring.asc 81 | - run: 82 | name: Publish Spark 3.2.4 83 | command: | 84 | sbt ++2.12.12 -Dspark.version=3.2.4 clean publishSigned sonatypeBundleRelease 85 | - run: 86 | name: Publish Spark 3.1.3 87 | command: | 88 | sbt ++2.12.12 -Dspark.version=3.1.3 clean publishSigned sonatypeBundleRelease 89 | - run: 90 | name: Publish Spark 3.3.4 91 | command: | 92 | sbt ++2.12.12 -Dspark.version=3.3.4 clean publishSigned sonatypeBundleRelease 93 | - run: 94 | name: Publish Spark 3.4.2 95 | command: | 96 | sbt ++2.12.12 -Dspark.version=3.4.2 clean publishSigned sonatypeBundleRelease 97 | - run: 98 | name: Publish Spark 3.5.0 99 | command: | 100 | sbt ++2.12.12 -Dspark.version=3.5.0 clean publishSigned sonatypeBundleRelease 101 | 102 | workflows: 103 | test: 104 | jobs: 105 | - test: 106 | filters: 107 | tags: 108 | only: /^v.*/ 109 | branches: 110 | ignore: /.*/ 111 | matrix: 112 | parameters: 113 | spark_version: 114 | - 3.1.3 115 | - 3.2.4 116 | - 3.3.4 117 | - 3.4.2 118 | - 3.5.0 119 | singlestore_image: 120 | - singlestore/cluster-in-a-box:alma-8.0.19-f48780d261-4.0.11-1.16.0 121 | - singlestore/cluster-in-a-box:alma-8.1.32-e3d3cde6da-4.0.16-1.17.6 122 | - singlestore/cluster-in-a-box:alma-8.5.22-fe61f40cd1-4.1.0-1.17.11 123 | - singlestore/cluster-in-a-box:alma-8.7.12-483e5f8acb-4.1.0-1.17.15 124 | publish: 125 | jobs: 126 | - approve-publish: 127 | type: approval 128 | filters: 129 | tags: 130 | only: /^v.*/ 131 | branches: 132 | ignore: /.*/ 133 | - publish: 134 | requires: 135 | - approve-publish 136 | filters: 137 | tags: 138 | only: /^v.*/ 139 | branches: 140 | ignore: /.*/ 141 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 2 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 3 | 4 | # User-specific stuff 5 | /.idea/* 6 | !/.idea/codeStyles/* 7 | !/.idea/runConfigurations/* 8 | 9 | # JIRA plugin 10 | atlassian-ide-plugin.xml 11 | 12 | # IntelliJ 13 | /out/ 14 | /target/ 15 | 16 | # mpeltonen/sbt-idea plugin 17 | /.idea_modules/ 18 | 19 | # File-based project format 20 | /*.iws 21 | 22 | # sbt project stuff 23 | /project/* 24 | !/project/build.properties 25 | !/project/plugins.sbt 26 | /target 27 | /build 28 | /spark-warehouse 29 | 30 | /ci/secring.asc 31 | -------------------------------------------------------------------------------- /.idea/codeStyles/Project.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/codeStyles/codeStyleConfig.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/runConfigurations/Test_Spark_3_0.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 13 | -------------------------------------------------------------------------------- /.idea/runConfigurations/Test_Spark_3_1.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 13 | -------------------------------------------------------------------------------- 
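The test-filter flags in the CI config above (`testOnly -- -l ExcludeFromSpark31`, `ExcludeFromSpark32`, and so on) rely on ScalaTest tags to skip version-specific tests. A minimal sketch of how such a tag is typically defined and attached, assuming plain ScalaTest `Tag` objects; the connector's real tag definitions live in its test sources, which are not shown here:

```scala
import org.scalatest.Tag
import org.scalatest.funsuite.AnyFunSuite

// Hypothetical tag object mirroring the names used by the CI filters.
object ExcludeFromSpark31 extends Tag("ExcludeFromSpark31")

class ExampleSpec extends AnyFunSuite {
  // `sbt "testOnly -- -l ExcludeFromSpark31"` excludes every test carrying this tag.
  test("feature that is unavailable on Spark 3.1", ExcludeFromSpark31) {
    assert(1 + 1 == 2)
  }
}
```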
/.java-version: -------------------------------------------------------------------------------- 1 | 1.8 2 | -------------------------------------------------------------------------------- /.scalafmt.conf: -------------------------------------------------------------------------------- 1 | maxColumn = 100 2 | align = more -------------------------------------------------------------------------------- /CHANGELOG: -------------------------------------------------------------------------------- 1 | 2024-11-22 Version 4.1.10 2 | * Updated JDBC driver to 1.2.7 3 | 4 | 2024-11-22 Version 4.1.9 5 | * Changed to work with Databricks runtime 15.4 6 | 7 | 2024-06-14 Version 4.1.8 8 | * Changed retry during reading from result table to use exponential backoff 9 | * Used ForkJoinPool instead of FixedThreadPool 10 | * Added more logging 11 | 12 | 2024-05-13 Version 4.1.7 13 | * Fixed bug that caused reading from the wrong result table when the task was restarted 14 | 15 | 2024-04-11 Version 4.1.6 16 | * Changed LoadDataWriter to send data in batches 17 | * Added numPartitions parameter to specify exact number of resulting partition during parallel read 18 | 19 | 2023-10-05 Version 4.1.5 20 | * Added support of Spark 3.5 21 | * Updated dependencies 22 | 23 | 2023-07-18 Version 4.1.4 24 | * Added support of Spark 3.4 25 | * Added connection attributes 26 | * Fixed conflicts of result table names during parallel read 27 | * Updated version of the SingleStore JDBC driver 28 | 29 | 2023-03-31 Version 4.1.3 30 | * Updated version of the SingleStore JDBC driver 31 | * Fixed error handling when `onDuplicateKeySQL` option is used 32 | 33 | 2023-02-21 Version 4.1.2 34 | * Fixed an issue that would cause a `Table has reached its quota of 1 reader(s)`` error to be displayed when a parallel read was retried 35 | 36 | 2022-07-13 Version 4.1.1 37 | * Added clientEndpoint option for Cloud deployment of the SingleStoreDB 38 | * Fixed bug in the error handling that caused deadlock 39 | * Added support of the Spark 3.3 40 | 41 | 2022-06-22 Version 4.1.0 42 | * Added support of more SQL expressions in pushdown 43 | * Added multi-partition to the parallel read 44 | * Updated SingleStore JDBC Driver to 1.1.0 45 | * Added JWT authentication 46 | * Added connection pooling 47 | 48 | 2022-01-20 Version 4.0.0 49 | * Changed connector to use SingleStore JDBC Driver instead of MariaDB JDBC Driver 50 | 51 | 2021-12-23 Version 3.2.2 52 | * Added possibility to repartition result by columns in parallel read from aggregators 53 | * Replaced usages of `transformDown` with `transform` in order to make connector work with Databricks 9.1 LTS 54 | 55 | 2021-12-14 Version 3.2.1 56 | * Added support of the Spark 3.2 57 | * Fixed links in the README 58 | 59 | 2021-11-29 Version 3.2.0 60 | * Added support for reading in parallel from aggregator nodes instead of leaf nodes 61 | 62 | 2021-09-16 Version 3.1.3 63 | * Added Spark 3.1 support 64 | * Deleted Spark 2.3 and 2.4 support 65 | 66 | 2021-04-29 Version 3.1.2 67 | * Added using external host and port by default while using `useParallelRead` 68 | 69 | 2021-02-05 Version 3.1.1 70 | * Added support of `com.memsql.spark` data source name for backward compatibility 71 | 72 | 2021-01-22 Version 3.1.0 73 | * Rebranded `memsql-spark-connector` to `singlestore-spark-connector` 74 | * Spark data source format changed from `memsql` to `singlestore` 75 | * Configuration prefix changed from `spark.datasource.memsql.` to `spark.datasource.singlestore.` 76 | 77 | 2020-10-19 Version 3.0.5 78 | * Fixed 
bug with load balance connections to dml endpoint 79 | 80 | 2020-09-29 Version 3.1.0-beta 81 | * Added Spark 3.0 support 82 | * Fixed bugs in pushdowns 83 | * Fixed bug with wrong SQL code generation of attribute names that contains special characters 84 | * Added methods that allow you to run SQL queries on a MemSQL database directly 85 | 86 | 2020-08-20 Version 3.0.4 87 | * Added trim pushdown 88 | 89 | 2020-08-14 Version 3.0.3 90 | * Fixed bug with pushdown of the join condition 91 | 92 | 2020-08-03 Version 3.0.2 93 | * added maxErrors option 94 | * changed aliases in SQL queries to be more deterministic 95 | * disabled comments inside of the SQL queries when logging level is not TRACE 96 | 97 | 2020-06-12 Version 3.0.1 98 | * The connector now updates task metrics with the number of records written during write operations 99 | 100 | 2020-05-27 Version 3.0.0 101 | * Introduces SQL Optimization & Rewrite for most query shapes and compatible expressions 102 | * Implemented as a native Spark SQL plugin 103 | * Supports both the DataSource and DataSourceV2 API for maximum support of current and future functionality 104 | * Contains deep integrations with the Catalyst query optimizer 105 | * Is compatible with Spark 2.3 and 2.4 106 | * Leverages MemSQL LOAD DATA to accelerate ingest from Spark via compression, vectorized cpu instructions, and optimized segment sizes 107 | * Takes advantage of all the latest and greatest features in MemSQL 7.x 108 | 109 | 2020-05-06 Version 3.0.0-rc1 110 | * Support writing into MemSQL reference tables 111 | * Deprecated truncate option in favor of overwriteBehavior 112 | * New option overwriteBehavior allows you to specify how to overwrite or merge rows during ingest 113 | * The Ignore SaveMode now correctly skips all duplicate key errors during ingest 114 | 115 | 2020-04-30 Version 3.0.0-beta12 116 | * Improved performance of new batch insert functionality for `ON DUPLICATE KEY UPDATE` feature 117 | 118 | 2020-04-30 Version 3.0.0-beta11 119 | * Added support for merging rows on ingest via `ON DUPLICATE KEY UPDATE` 120 | * Added docker-based demo for running a Zeppelin notebook using the Spark connector 121 | 122 | 2020-04-20 Version 3.0.0-beta10 123 | * Additional functions supported in SQL Pushdown: toUnixTimestamp, unixTimestamp, nextDay, dateDiff, monthsAdd, hypot, rint 124 | * Now tested against MemSQL 6.7, and all tests use SSL 125 | * Fixed bug with disablePushdown 126 | 127 | 2020-04-09 Version 3.0.0-beta9 128 | * Add null handling to address Spark bug which causes incorrect handling of null literals (https://issues.apache.org/jira/browse/SPARK-31403) 129 | 130 | 2020-04-01 Version 3.0.0-beta8 131 | * Added support for more datetime expressions: 132 | * addition/subtraction of datetime objects 133 | * to_utc_timestamp, from_utc_timestamp 134 | * date_trunc, trunc 135 | 136 | 2020-03-25 Version 3.0.0-beta7 137 | * The connector now respects column selection when loading dataframes into MemSQL 138 | 139 | 2020-03-24 Version 3.0.0-beta6 140 | * Fix bug when you use an expression in an explicit query 141 | 142 | 2020-03-23 Version 3.0.0-beta5 143 | * Increase connection timeout to increase connector reliability 144 | 145 | 2020-03-20 Version 3.0.0-beta4 146 | * Set JDBC driver to MariaDB explicitely to avoid issues with the mysql driver 147 | 148 | 2020-03-19 Version 3.0.0-beta3 149 | * Created tables default to Columnstore 150 | * User can override keys attached to new tables 151 | * New parallelRead option which enables reading directly from MemSQL 
leaf nodes 152 | * Created tables now set case-sensitive collation on all columns 153 | to match Spark semantics 154 | * More SQL expressions supported in pushdown (tanh, sinh, cosh) 155 | 156 | 2020-02-08 Version 3.0.0-beta2 157 | * Removed options: masterHost and masterPort 158 | * Added ddlEndpoint and ddlEndpoints options 159 | * Added path option to support specifying the dbtable via `.load("mytable")` when creating a dataframe 160 | 161 | 2020-01-30 Version 3.0.0-beta 162 | * Full re-write of the Spark Connector 163 | 164 | 2019-02-27 Version 2.0.7 165 | * Add support for EXPLAIN JSON in MemSQL versions 6.7 and later to fix partition pushdown. 166 | 167 | 2018-09-14 Version 2.0.6 168 | * Force utf-8 encoding when loading data into MemSQL 169 | 170 | 2018-01-18 Version 2.0.5 171 | * Explicitly sort MemSQLRDD partitions due to MemSQL 6.0 no longer returning partitions in sorted order by ordinal. 172 | 173 | 2017-08-31 Version 2.0.4 174 | * Switch threads in LoadDataStrategy so that the parent thread reads from the RDD and the new thread writes 175 | to MemSQL so that Spark has access to the thread-local variables it expects 176 | 177 | 2017-07-19 Version 2.0.3 178 | * Handle special characters column names in query 179 | * Add option to enable jdbc connector to stream result sets row-by-row 180 | * Fix groupby queries incorrectly pushed down to leaves 181 | * Add option to write to master aggregator only 182 | * Add support for reading MemSQL columns of type unsigned bigint and unsigned int 183 | 184 | 2017-04-17 185 | * Pull MemSQL configuration from runtime configuration in sparkSession.conf instead of static config in sparkContext 186 | * Fix connection pooling bug where extraneous connections were created 187 | * Add MemSQL configuration to disable partition pushdown 188 | 189 | 2017-02-06 Version 2.0.1 190 | * Fixed bug to enable partition pushdown for MemSQL DataFrames loaded from a custom user query 191 | 192 | 2017-02-01 Version 2.0.0 193 | * Compatible with Apache Spark 2.0.0+ 194 | * Removed experimental strategy SQL pushdown to instead use the more stable Data Sources API for reading 195 | data from MemSQL 196 | * Removed memsql-spark-interface, memsql-etl 197 | 198 | 2015-12-15 Version 1.2.1 199 | * Python support for extractors and transformers 200 | * More extensive SQL pushdown for DataFrame operations 201 | * Use DataFrames as common interface between extractor, transformer, and loader 202 | * Rewrite connectorLib internals to support SparkSQL relation provider API 203 | * Remove RDD.saveToMemSQL 204 | 205 | 2015-11-19 Version 1.1.1 206 | * Set JDBC login timeout to 10 seconds 207 | 208 | 2015-11-02 Version 1.1.0 209 | 210 | * Available on Maven Central Repository 211 | * More events for batches 212 | * Deprecated the old Kafka extractor and replaced it with a new one that takes in a Zookeeper quorum address 213 | * Added a new field to pipeline API responses indicating whether or not a pipeline is currently running 214 | * Renamed projects: memsqlsparkinterface -> memsql-spark-interface, memsqletl -> memsql-etl, memsqlrdd -> memsql-connector. 
215 | * Robustness and bug fixes 216 | 217 | 2015-09-24 Version 1.0.0 218 | 219 | * Initial release of MemSQL Streamliner 220 | -------------------------------------------------------------------------------- /Layerfile: -------------------------------------------------------------------------------- 1 | FROM vm/ubuntu:18.04 2 | 3 | # install curl, python and mysql-client 4 | RUN sudo apt update && \ 5 | sudo apt install -y curl python-pip mysql-client-core-5.7 6 | 7 | # install sbt 8 | RUN wget https://github.com/sbt/sbt/releases/download/v1.3.5/sbt-1.3.5.tgz && \ 9 | sudo tar xzvf sbt-1.3.5.tgz -C /usr/share/ && \ 10 | sudo update-alternatives --install /usr/bin/sbt sbt /usr/share/sbt/bin/sbt 100 11 | 12 | # install the latest version of Docker 13 | RUN apt-get update && \ 14 | apt-get install apt-transport-https ca-certificates curl software-properties-common && \ 15 | curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add - && \ 16 | add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu bionic stable" && \ 17 | apt-get update && \ 18 | apt install docker-ce 19 | 20 | # install java 21 | RUN apt-get update && \ 22 | sudo apt-get install openjdk-8-jdk 23 | 24 | # set environment variables 25 | ENV MEMSQL_PORT=5506 26 | ENV MEMSQL_USER=root 27 | ENV MEMSQL_DB=test 28 | ENV JAVA_HOME=/usr/lib/jvm/jdk1.8.0 29 | ENV CONTINUOUS_INTEGRATION=true 30 | ENV SBT_OPTS=-Xmx1g 31 | ENV SBT_OPTS=-Xms1g 32 | SECRET ENV LICENSE_KEY 33 | SECRET ENV SINGLESTORE_PASSWORD 34 | SECRET ENV SINGLESTORE_JWT_PASSWORD 35 | 36 | # increase the memory 37 | MEMORY 4G 38 | MEMORY 8G 39 | MEMORY 12G 40 | MEMORY 16G 41 | 42 | # split to 21 states 43 | # each of them will run different version of the singlestore and spark 44 | SPLIT 21 45 | 46 | # copy the entire git repository 47 | COPY . . 
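# How the 21 split states map onto the test matrix (derived from scripts/define-layerci-matrix.sh;
# assumes LayerCI exposes the zero-based split index in $SPLIT):
#   singlestore image index = SPLIT / 5   (clamped to the last of the 4 image tags)
#   spark version index     = SPLIT % 5
# Splits 0-19 cover the full 4 x 5 image/Spark matrix; the extra split 20 re-runs the last
# image with FORCE_READ_FROM_LEAVES=TRUE, giving 21 states in total.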
48 | 49 | # setup split specific env variables 50 | RUN scripts/define-layerci-matrix.sh >> ~/.profile 51 | 52 | # start singlestore cluster 53 | RUN ./scripts/setup-cluster.sh 54 | 55 | # run tests 56 | RUN sbt ++$SCALA_VERSION -Dspark.version=$SPARK_VERSION "${TEST_FILTER}" 57 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import xerial.sbt.Sonatype._ 2 | 3 | /* 4 | To run tests or publish with a specific spark version use this java option: 5 | -Dspark.version=3.0.0 6 | */ 7 | val sparkVersion = sys.props.get("spark.version").getOrElse("3.1.3") 8 | val scalaVersionStr = "2.12.12" 9 | val scalaVersionPrefix = scalaVersionStr.substring(0, 4) 10 | val jacksonDatabindVersion = sparkVersion match { 11 | case "3.1.3" => "2.10.0" 12 | case "3.2.4" => "2.12.3" 13 | case "3.3.4" => "2.13.4.2" 14 | case "3.4.2" => "2.14.2" 15 | case "3.5.0" => "2.15.2" 16 | } 17 | 18 | lazy val root = project 19 | .withId("singlestore-spark-connector") 20 | .in(file(".")) 21 | .enablePlugins(BuildInfoPlugin) 22 | .settings( 23 | name := "singlestore-spark-connector", 24 | organization := "com.singlestore", 25 | scalaVersion := scalaVersionStr, 26 | Compile / unmanagedSourceDirectories += (Compile / sourceDirectory).value / (sparkVersion match { 27 | case "3.1.3" => "scala-sparkv3.1" 28 | case "3.2.4" => "scala-sparkv3.2" 29 | case "3.3.4" => "scala-sparkv3.3" 30 | case "3.4.2" => "scala-sparkv3.4" 31 | case "3.5.0" => "scala-sparkv3.5" 32 | }), 33 | version := s"4.1.10-spark-${sparkVersion}", 34 | licenses += "Apache-2.0" -> url( 35 | "http://opensource.org/licenses/Apache-2.0" 36 | ), 37 | resolvers += "Spark Packages Repo" at "https://dl.bintray.com/spark-packages/maven", 38 | libraryDependencies ++= Seq( 39 | // runtime dependencies 40 | "org.apache.spark" %% "spark-core" % sparkVersion % "provided, test", 41 | "org.apache.spark" %% "spark-sql" % sparkVersion % "provided, test", 42 | "org.apache.avro" % "avro" % "1.11.3", 43 | "org.apache.commons" % "commons-dbcp2" % "2.7.0", 44 | "org.scala-lang.modules" %% "scala-java8-compat" % "0.9.0", 45 | "com.singlestore" % "singlestore-jdbc-client" % "1.2.7", 46 | "io.spray" %% "spray-json" % "1.3.5", 47 | "io.netty" % "netty-buffer" % "4.1.70.Final", 48 | "org.apache.commons" % "commons-dbcp2" % "2.9.0", 49 | // test dependencies 50 | "org.mariadb.jdbc" % "mariadb-java-client" % "2.+" % Test, 51 | "org.scalatest" %% "scalatest" % "3.1.0" % Test, 52 | "org.scalacheck" %% "scalacheck" % "1.14.1" % Test, 53 | "org.mockito" %% "mockito-scala" % "1.16.37" % Test, 54 | "com.github.mrpowers" %% "spark-fast-tests" % "0.21.3" % Test, 55 | "com.github.mrpowers" %% "spark-daria" % "0.38.2" % Test 56 | ), 57 | dependencyOverrides += "com.fasterxml.jackson.core" % "jackson-databind" % jacksonDatabindVersion, 58 | Test / testOptions += Tests.Argument("-oF"), 59 | Test / fork := true, 60 | buildInfoKeys := Seq[BuildInfoKey](version), 61 | buildInfoPackage := "com.singlestore.spark" 62 | ) 63 | 64 | credentials += Credentials( 65 | "GnuPG Key ID", 66 | "gpg", 67 | "CDD996495CF08BB2041D86D8D1EB3D14F1CD334F", 68 | "ignored" // this field is ignored; passwords are supplied by pinentry 69 | ) 70 | 71 | assemblyMergeStrategy in assembly := { 72 | case PathList("META-INF", _*) => MergeStrategy.discard 73 | case _ => MergeStrategy.first 74 | } 75 | 76 | publishTo := sonatypePublishToBundle.value 77 | publishMavenStyle := true 78 | sonatypeSessionName := 
s"[sbt-sonatype] ${name.value} ${version.value}" 79 | sonatypeProjectHosting := Some(GitHubHosting("memsql", "memsql-spark-connector", "carl@memsql.com")) 80 | -------------------------------------------------------------------------------- /ci/secring.asc.enc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/memsql/singlestore-spark-connector/febf9e5a854334233a2bbb8fdc52ad619bf3e13b/ci/secring.asc.enc -------------------------------------------------------------------------------- /demo/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM apache/zeppelin:0.9.0 2 | 3 | ENV SPARK_VERSION=3.5.0 4 | 5 | USER root 6 | 7 | RUN wget https://apache.ip-connect.vn.ua/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop2.7.tgz 8 | RUN tar xf spark-${SPARK_VERSION}-bin-hadoop2.7.tgz -C / 9 | RUN rm -rf spark-${SPARK_VERSION}-bin-hadoop2.7.tgz 10 | ENV SPARK_HOME=/spark-${SPARK_VERSION}-bin-hadoop2.7 11 | ENV ZEPPELIN_PORT=8082 12 | RUN rm -rf /zeppelin/notebook/* 13 | 14 | EXPOSE ${ZEPPELIN_PORT}/tcp 15 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | ## singlestore-spark-connector demo 2 | 3 | This is Dockerfile which uses the upstream [Zeppelin Image](https://hub.docker.com/r/apache/zeppelin/) as it's base 4 | and has two notebooks with examples of singlestore-spark-connector. 5 | 6 | To run this docker with [MemSQL CIAB](https://hub.docker.com/r/memsql/cluster-in-a-box) follow the instructions 7 | 8 | * Create a docker network to be able to connect zeppelin and memsql-ciab 9 | ``` 10 | docker network create zeppelin-ciab-network 11 | ``` 12 | 13 | * Pull memsql-ciab docker image 14 | ``` 15 | docker pull memsql/cluster-in-a-box 16 | ``` 17 | 18 | * Run and start the SingleStore Cluster in a Box docker container 19 | 20 | ``` 21 | docker run -i --init \ 22 | --name singlestore-ciab-for-zeppelin \ 23 | -e LICENSE_KEY=[INPUT_YOUR_LICENSE_KEY] \ 24 | -e ROOT_PASSWORD=my_password \ 25 | -p 3306:3306 -p 8081:8080 \ 26 | --net=zeppelin-ciab-network \ 27 | memsql/cluster-in-a-box 28 | ``` 29 | ``` 30 | docker start singlestore-ciab-for-zeppelin 31 | ``` 32 | > :note: in this step you can hit a port collision error 33 | > 34 | > ``` 35 | > docker: Error response from daemon: driver failed programming external connectivity on endpoint singlestore-ciab-for-zeppelin 36 | > (38b0df3496f1ec83f120242a53a7023d8a0b74db67f5e487fb23641983c67a76): 37 | > Bind for 0.0.0.0:8080 failed: port is already allocated. 38 | > ERRO[0000] error waiting for container: context canceled 39 | > ``` 40 | > 41 | > If it happened then remove the container 42 | > 43 | >`docker rm singlestore-ciab-for-zeppelin` 44 | > 45 | > and run the first command with other ports `-p {new_port1}:3306 -p {new_port2}:8080` 46 | 47 | * Build zeppelin docker image in `singlestore-spark-connector/demo` folder 48 | 49 | ``` 50 | docker build -t zeppelin . 
51 | ``` 52 | 53 | * Run the zeppelin docker container 54 | ``` 55 | docker run -d --init \ 56 | --name zeppelin \ 57 | -p 8082:8082 \ 58 | --net=zeppelin-ciab-network \ 59 | -v $PWD/notebook:/opt/zeppelin/notebook/singlestore \ 60 | -v $PWD/notebook:/zeppelin/notebook/singlestore \ 61 | zeppelin 62 | ``` 63 | 64 | > **Note:** you may hit a port collision error in this step 65 | > 66 | > ``` 67 | > docker: Error response from daemon: driver failed programming external connectivity on endpoint zeppelin 68 | > (38b0df3496f1ec83f120242a53a7023d8a0b74db67f5e487fb23641983c67a76): 69 | > Bind for 0.0.0.0:8082 failed: port is already allocated. 70 | > ERRO[0000] error waiting for container: context canceled 71 | > ``` 72 | > 73 | > If this happens, remove the container 74 | > 75 | >`docker rm zeppelin` 76 | > 77 | > and re-run this command with a different port `-p {new_port}:8082` 78 | 79 | 80 | * Open [zeppelin](http://localhost:8082/next) in your browser and try 81 | [scala](http://localhost:8082/next/#/notebook/2F8XQUKFG), 82 | [pyspark](http://localhost:8082/next/#/notebook/2F6Y3APTX) 83 | and [spark sql](http://localhost:8082/next/#/notebook/2F7PZ81H6) notebooks 84 | 85 | To set up a more powerful SingleStore trial cluster, use [SingleStore Managed Service](https://www.singlestore.com/managed-service/) 86 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.3.8 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.9.6") 2 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.1") 3 | addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.11.0") 4 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.2.0") 5 | -------------------------------------------------------------------------------- /scripts/define-layerci-matrix.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eu 3 | 4 | SINGLESTORE_IMAGE_TAGS=( 5 | "alma-8.0.19-f48780d261-4.0.11-1.16.0" 6 | "alma-8.1.32-e3d3cde6da-4.0.16-1.17.6" 7 | "alma-8.5.22-fe61f40cd1-4.1.0-1.17.11" 8 | "alma-8.7.12-483e5f8acb-4.1.0-1.17.15" 9 | ) 10 | SINGLESTORE_IMAGE_TAGS_COUNT=${#SINGLESTORE_IMAGE_TAGS[@]} 11 | SPARK_VERSIONS=( 12 | "3.5.0" 13 | "3.4.2" 14 | "3.3.4" 15 | "3.2.4" 16 | "3.1.3" 17 | ) 18 | SPARK_VERSIONS_COUNT=${#SPARK_VERSIONS[@]} 19 | 20 | TEST_NUM=${SPLIT:-"0"} 21 | 22 | SINGLESTORE_IMAGE_TAG_INDEX=$(( $TEST_NUM / $SPARK_VERSIONS_COUNT)) 23 | SINGLESTORE_IMAGE_TAG_INDEX=$((SINGLESTORE_IMAGE_TAG_INDEX>=SINGLESTORE_IMAGE_TAGS_COUNT ? 
SINGLESTORE_IMAGE_TAGS_COUNT-1 : SINGLESTORE_IMAGE_TAG_INDEX)) 24 | SINGLESTORE_IMAGE_TAG=${SINGLESTORE_IMAGE_TAGS[SINGLESTORE_IMAGE_TAG_INDEX]} 25 | 26 | SPARK_VERSION_INDEX=$(( $TEST_NUM % $SPARK_VERSIONS_COUNT)) 27 | SPARK_VERSION=${SPARK_VERSIONS[SPARK_VERSION_INDEX]} 28 | 29 | if [ $TEST_NUM == $(($SINGLESTORE_IMAGE_TAGS_COUNT*$SPARK_VERSIONS_COUNT)) ] 30 | then 31 | echo 'export FORCE_READ_FROM_LEAVES=TRUE' 32 | else 33 | echo 'export FORCE_READ_FROM_LEAVES=FALSE' 34 | fi 35 | 36 | echo "export SINGLESTORE_IMAGE='singlestore/cluster-in-a-box:$SINGLESTORE_IMAGE_TAG'" 37 | echo "export SPARK_VERSION='$SPARK_VERSION'" 38 | echo "export TEST_FILTER='testOnly -- -l ExcludeFromSpark${SPARK_VERSION:0:1}${SPARK_VERSION:2:1}'" 39 | echo "export SCALA_VERSION='2.12.12'" 40 | -------------------------------------------------------------------------------- /scripts/jwt/jwt_auth_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "username_claim": "username", 3 | "methods": [ 4 | { 5 | "algorithms": [ "RS384" ], 6 | "secret": "-----BEGIN PUBLIC KEY-----\nMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA0i0dDauX6iaOaocic99O\nUTruTYPWFUv50aHTgfKxenFKpJTTL43T8ON36whwyObM3r/ayhPoyPxvSfkkCwxd\nE7XEmTRARHkJKQfebkbN6SaKlEgIdmZ8UroZCslSzOcsX0N1KNc3WSyFeOigHSp/\nww+roVtaJC/OJQ95kMjIGdN3ooO5g/YvZJZTn9KQ/dFmNDPaSyseT9/MCE2Rp0g0\nT2yCwewxVdfR+D4QcicaLat7CAFXMnoSxV9ifGXYkv6JE33dc95U4BgPYECca2QA\nNe3ZQHSNxC1rc+uim3cgcn6PP4WKTgTG4u74F9xA8FbumZUIMB7rChsr+E8Z4/Iq\nSGb7/y0J6Auho7BDJPL7ZFE9peuSA3NudZlpkH+GIWRW7fBY7qu1Koh8kGfZeg4k\nLIucgT0CpSvrxaDVqRSEIiqy4zoczrLXcVJN9ThtxVPJZCVTW2dir8aCwO9Sk60W\nQ8DP4wRWhnqa+Irsd/2r5c7QgSSOSo4/EIBU65qP5oA6xUw3F4sE7G+Q9ofijxxJ\n8iGr2Y7SBk5ztvSYAX/ZNfoUJZQ0cXDCSCKus8a9kQQ9zCNBLco9pIv7XUcJf1bj\nzfz1o911OZn/mdcgqUq0uRne8x73J/Z4uIsaVubvcSgv0XTSF8qYgpgRd016+nAa\nFtyVd9Y5xmYgIIHaXzQ3l7MCAwEAAQ==\n-----END PUBLIC KEY-----" 7 | } 8 | ] 9 | } 10 | -------------------------------------------------------------------------------- /scripts/setup-cluster.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eu 3 | 4 | # this script must be run from the top-level of the repo 5 | cd "$(git rev-parse --show-toplevel)" 6 | 7 | DEFAULT_IMAGE_NAME="singlestore/cluster-in-a-box:alma-8.0.15-0b9b66384f-4.0.11-1.15.2" 8 | IMAGE_NAME="${SINGLESTORE_IMAGE:-$DEFAULT_IMAGE_NAME}" 9 | CONTAINER_NAME="singlestore-integration" 10 | 11 | EXISTS=$(docker inspect ${CONTAINER_NAME} >/dev/null 2>&1 && echo 1 || echo 0) 12 | 13 | if [[ "${EXISTS}" -eq 1 ]]; then 14 | EXISTING_IMAGE_NAME=$(docker inspect -f '{{.Config.Image}}' ${CONTAINER_NAME}) 15 | if [[ "${IMAGE_NAME}" != "${EXISTING_IMAGE_NAME}" ]]; then 16 | echo "Existing container ${CONTAINER_NAME} has image ${EXISTING_IMAGE_NAME} when ${IMAGE_NAME} is expected; recreating container." 17 | docker rm -f ${CONTAINER_NAME} 18 | EXISTS=0 19 | fi 20 | fi 21 | 22 | if [[ "${EXISTS}" -eq 0 ]]; then 23 | docker run -i --init \ 24 | --name ${CONTAINER_NAME} \ 25 | -v ${PWD}/scripts/ssl:/test-ssl \ 26 | -v ${PWD}/scripts/jwt:/test-jwt \ 27 | -e LICENSE_KEY=${LICENSE_KEY} \ 28 | -e ROOT_PASSWORD=${SINGLESTORE_PASSWORD} \ 29 | -p 5506:3306 -p 5507:3307 -p 5508:3308 \ 30 | ${IMAGE_NAME} 31 | fi 32 | 33 | docker start ${CONTAINER_NAME} 34 | 35 | singlestore-wait-start() { 36 | echo -n "Waiting for SingleStore to start..." 
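  # Poll the master aggregator through the host-mapped port (5506 -> container 3306) until it accepts queries.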
37 | while true; do 38 | if mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" -e "select 1" >/dev/null 2>/dev/null; then 39 | break 40 | fi 41 | echo -n "." 42 | sleep 0.2 43 | done 44 | echo ". Success!" 45 | } 46 | 47 | singlestore-wait-start 48 | 49 | if [[ "${EXISTS}" -eq 0 ]]; then 50 | echo 51 | echo "Creating aggregator node" 52 | docker exec -it ${CONTAINER_NAME} memsqlctl create-node --yes --password ${SINGLESTORE_PASSWORD} --port 3308 53 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key minimum_core_count --value 0 54 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key minimum_memory_mb --value 0 55 | docker exec -it ${CONTAINER_NAME} memsqlctl start-node --yes --all 56 | docker exec -it ${CONTAINER_NAME} memsqlctl add-aggregator --yes --host 127.0.0.1 --password ${SINGLESTORE_PASSWORD} --port 3308 57 | fi 58 | 59 | echo 60 | echo "Setting up SSL" 61 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key ssl_ca --value /test-ssl/test-ca-cert.pem 62 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key ssl_cert --value /test-ssl/test-singlestore-cert.pem 63 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key ssl_key --value /test-ssl/test-singlestore-key.pem 64 | echo "Setting up JWT" 65 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key jwt_auth_config_file --value /test-jwt/jwt_auth_config.json 66 | echo "Restarting cluster" 67 | docker exec -it ${CONTAINER_NAME} memsqlctl restart-node --yes --all 68 | singlestore-wait-start 69 | echo "Setting up root-ssl user" 70 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" -e 'create user "root-ssl"@"%" require ssl' 71 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" -e 'grant all privileges on *.* to "root-ssl" with grant option' 72 | mysql -u root -h 127.0.0.1 -P 5507 -p"${SINGLESTORE_PASSWORD}" -e 'create user "root-ssl"@"%" require ssl' 73 | mysql -u root -h 127.0.0.1 -P 5507 -p"${SINGLESTORE_PASSWORD}" -e 'grant all privileges on *.* to "root-ssl" with grant option' 74 | mysql -u root -h 127.0.0.1 -P 5508 -p"${SINGLESTORE_PASSWORD}" -e 'grant all privileges on *.* to "root-ssl" with grant option' 75 | echo "Done!" 76 | echo "Setting up root-jwt user" 77 | mysql -h 127.0.0.1 -u root -P 5506 -p"${SINGLESTORE_PASSWORD}" -e "CREATE USER 'test_jwt_user' IDENTIFIED WITH authentication_jwt" 78 | mysql -h 127.0.0.1 -u root -P 5506 -p"${SINGLESTORE_PASSWORD}" -e "GRANT ALL PRIVILEGES ON *.* TO 'test_jwt_user'@'%'" 79 | echo "Done!" 
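# Illustrative smoke test (an assumption, not part of the original script): the root-ssl user
# created above can be checked with the MariaDB/MySQL client against the mounted CA certificate:
#   mysql -u root-ssl -h 127.0.0.1 -P 5506 --ssl --ssl-ca=scripts/ssl/test-ca-cert.pem -e "select 1"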
80 | 81 | echo 82 | echo "Ensuring child nodes are connected using container IP" 83 | CONTAINER_IP=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' ${CONTAINER_NAME}) 84 | CURRENT_LEAF_IP=$(mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e 'select host from information_schema.leaves') 85 | if [[ ${CONTAINER_IP} != "${CURRENT_LEAF_IP}" ]]; then 86 | # remove leaf with current ip 87 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e "remove leaf '${CURRENT_LEAF_IP}':3307" 88 | # add leaf with correct ip 89 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e "add leaf root:'${SINGLESTORE_PASSWORD}'@'${CONTAINER_IP}':3307" 90 | fi 91 | CURRENT_AGG_IP=$(mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e 'select host from information_schema.aggregators where master_aggregator=0') 92 | if [[ ${CONTAINER_IP} != "${CURRENT_AGG_IP}" ]]; then 93 | # remove aggregator with current ip 94 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e "remove aggregator '${CURRENT_AGG_IP}':3308" 95 | # add aggregator with correct ip 96 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e "add aggregator root:'${SINGLESTORE_PASSWORD}'@'${CONTAINER_IP}':3308" 97 | fi 98 | echo "Done!" 99 | -------------------------------------------------------------------------------- /scripts/ssl/test-ca-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDCzCCAfOgAwIBAgIULJ/CPUCImfaUkxiXSMAXBKWvtnIwDQYJKoZIhvcNAQEL 3 | BQAwFDESMBAGA1UEAwwJdGVzdC1yb290MCAXDTIwMDQxMzIyMjYzOFoYDzMwMTkw 4 | ODE1MjIyNjM4WjAUMRIwEAYDVQQDDAl0ZXN0LXJvb3QwggEiMA0GCSqGSIb3DQEB 5 | AQUAA4IBDwAwggEKAoIBAQDQWWfkNMUG11C4INpkOf0E/Ao5pLfVIrAWVGRTQe+s 6 | KDx9w1UaLiKx21vWXWOHPz8bKvQ1OOhPP85r+ZOdbmsrOMyWF+kMZEgs8pUsG8kk 7 | ztO+nRWskAN0qCD9kLPcN0exqda7S+we/1WOBVUv92Qvc3ip6anwlNkRmCGDznP8 8 | p0q8vxj/cmPMhfwuq3RLvkJ3buBGboJBNkphTTR3K3iy40XiG25U3YEnjspaEpqs 9 | QO2JfZBRA4Jq/OTDfSPH16/JINsHkpe4GFgP78133DS7fFZs6HEonxMlr0cvgJGk 10 | JBJIcOZZYezTpgIlTETfNZeiJybTQ/0IX7KSugV/Dnt/AgMBAAGjUzBRMB0GA1Ud 11 | DgQWBBTneAQV+LCbwSS0EJ4u0ZvaUDJREjAfBgNVHSMEGDAWgBTneAQV+LCbwSS0 12 | EJ4u0ZvaUDJREjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAW 13 | sdYs9+C+GC81/dXMBstZDol9AQ2871hGHsjFaomk0/pxvCvgv+oLlHEjCyHaykcJ 14 | AY3hurFwyKMMXaR/AD5WcIe8Khk1wrzxoJhXyVIYmgE8cIdsDrA0D9ggILI0DmbH 15 | E2vZjAUdiDeS0IMv2AF1kprbPzA0lM/0Lm1DLLubPYs6oijOpVx8tyGxEesxqF8I 16 | KcIBk64WYkBNQFL7vy5fHI2S2N8D7COBybuabojCMKqUetGLw2BfSuOCOVAviwBK 17 | 5WMaFqULkAsc9GuMVvoddujK8B/yGQ1Nim4NTDdRGivtsSwphw+tNo+VRFLk/PK2 18 | VgeZz4rnrlh0Iog1f+iZ 19 | -----END CERTIFICATE----- 20 | -------------------------------------------------------------------------------- /scripts/ssl/test-ca-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpAIBAAKCAQEA0Fln5DTFBtdQuCDaZDn9BPwKOaS31SKwFlRkU0HvrCg8fcNV 3 | Gi4isdtb1l1jhz8/Gyr0NTjoTz/Oa/mTnW5rKzjMlhfpDGRILPKVLBvJJM7Tvp0V 4 | rJADdKgg/ZCz3DdHsanWu0vsHv9VjgVVL/dkL3N4qemp8JTZEZghg85z/KdKvL8Y 5 | /3JjzIX8Lqt0S75Cd27gRm6CQTZKYU00dyt4suNF4htuVN2BJ47KWhKarEDtiX2Q 6 | UQOCavzkw30jx9evySDbB5KXuBhYD+/Nd9w0u3xWbOhxKJ8TJa9HL4CRpCQSSHDm 7 | WWHs06YCJUxE3zWXoicm00P9CF+ykroFfw57fwIDAQABAoIBAQCbNd91W/JjNEfH 8 | w4GuJJze97vOUW05dAvltpy+gWJAyAC4V6mwRSpHgPibaxrYCD/Ex20BsREu6IOo 9 | YFadc0KXAks2jT1po9M42MZUA6cGqqWHXJJm6SoJ364j94ZlyTC5o6J6CQcv2Fst 10 
| 378kapHR353GRnH47Yn/12swO76gOc7LLC8LqoWwMAnCjbas63rhBpdeBWymB1s5 11 | 8TbZn6d3ViHJuIt/wGuWvrDvc9Gi0GjQaUA+KhbdBLiL6THdFA9w1RUKIBoOFB7x 12 | 2rN/BL1cCARcrz31QpbLfQgHeFdXxj9iK15WOumSSD+iNWZoA16QxvzYDCL2Pnk9 13 | 0PyTFZcxAoGBAPvMnOfdwOmjXRBGJurYzj8O8amRqphvtRt76zZhjXG7wzdvMCmb 14 | WP77odIYtw97zNmIiS0p91b4O4xmByr6V74+8qIHM09xqB2+hqBft020vEGXNoOa 15 | Qq/hcAT6Nb0kLIMHOxs68CnPWmH7opjg28MaDXq1ScgHl7I7H4U2ZclFAoGBANPT 16 | Od7fLQ9uGnMaw5mIZ14mcb0gUQt+yFTXiZvI4XbAmuK6o2zzmG/4bBMktqOkTc1h 17 | JEnx1Zmeoqbbhv5cBQA03Y6HF/rREO0nfEFrc8YVGZN2ZqtZOwHR/cekobNkPckU 18 | CdGHVF4X6D10u/1S63tuxA7ON9A1s7YntTWq/CPzAoGATxXtEkZsGPXefQYLoyeF 19 | X/jpnkDKPCaZ05AQSHxLWLWIkxixH+BTC4MtSDfLB2ny5UAlFbJgpUhCK86/4ZfP 20 | h0luG8X3L7SbAPyefDCT+iwSFOfRj3QcDfHYpTeROV7rPBxBTEQuunMOCEhowWue 21 | mqDMKwZVriX0V16Kf+SeA6ECgYA6/+RojWTxnUtEsDm28+VGthKMCQpJ12BZMUek 22 | 2ojiGLeLW0zVtevJlDoWAu3UGpmJEPuYlQFXrnXDX/XztxG1gwQLBNnLBJxgUdUs 23 | K4+tpobfKeVi6JGk6iZziwl2+/6xmSE6+SSoqKQJKhCKeKQaVznIneux1KNfoyO3 24 | 9Q4RvQKBgQClP/fHEcdgsla47IHmwJi5+YzuX6y9gAJ6Taa+Fy6NawOrwE1c4d0o 25 | lOtwbl/Jxvjqew02kdm/R9LEQqIblzTT0S7o9n+1+TZ6AGacMEKbR3lNtgIKHIdl 26 | aer4g9JTw+Ob42RcRuS5CaKttZSxMYw6wgAxAnKM+zs5TaGK0jOL7g== 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /scripts/ssl/test-singlestore-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIICpzCCAY8CAQEwDQYJKoZIhvcNAQELBQAwFDESMBAGA1UEAwwJdGVzdC1yb290 3 | MCAXDTIwMDQxMzIyMjkyNloYDzMwMTkwODE1MjIyOTI2WjAdMRswGQYDVQQDDBJ0 4 | ZXN0LW1lbXNxbC1zZXJ2ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB 5 | AQCUsvUKgAd4A2ggbGts/8L9y9mkS1z0R5oFvhuAu6zpqNSO0vQqkzY22mm6VxyD 6 | Di6vBECUTd+wQk8r1puU7c//H12okvxmZ2cB5UlfQQdFdfvzXJ7eUCEaHwIB7GUy 7 | 1nCwOFV1HT8CT+MCPZsS0SntsLOQfGwhkxsQ5qRhxcGfg5MoJh+Ew8b9CipLov8c 8 | Aryzai9aTW7NUOscuAaTrCuZyDdp1A940atQ/RZcaGyN1AueRZNiQYf4Kx8tff/D 9 | koAH293O86VbN4uj7IVKD7eZI9vDebRpAH2WuP82azLEHbQ4Zu9g2/Xf+4x2x1wl 10 | Xg0t4lxPntcNBLnX0b7PosBzAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAL4ohV3q 11 | V04vJg71/CG58tRpPfv9WAWr8yaiUedOxAKpG4vkQcN+eUXHneEXVWaPz8w4cCbV 12 | rIRDXpRoch33bjsKLMbKhnMOjv1NeNvPeIlrzXZjhMS4tfW0aAtAQAKdQCOZHfnZ 13 | GC1TY26c3YnLv2NvHwJ4PExgcFE7Ex3ej6Z0lMz971UoCgfm4zfb3ag4R9CBroAO 14 | GyqMteSxeXG1LOijXMaShjrba7moeqcClrxiV3oM7p84aGBHLlBiY1IvqVJzGTAx 15 | JwcISD0Ao3iRnrnQOhmC1ryQ8lYFFnmrxt4oVHQpUpanWNVV++SHQmO8dzRmTnFR 16 | xxVSDwHkhUt4m94= 17 | -----END CERTIFICATE----- 18 | -------------------------------------------------------------------------------- /scripts/ssl/test-singlestore-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEogIBAAKCAQEAlLL1CoAHeANoIGxrbP/C/cvZpEtc9EeaBb4bgLus6ajUjtL0 3 | KpM2Ntppulccgw4urwRAlE3fsEJPK9ablO3P/x9dqJL8ZmdnAeVJX0EHRXX781ye 4 | 3lAhGh8CAexlMtZwsDhVdR0/Ak/jAj2bEtEp7bCzkHxsIZMbEOakYcXBn4OTKCYf 5 | hMPG/QoqS6L/HAK8s2ovWk1uzVDrHLgGk6wrmcg3adQPeNGrUP0WXGhsjdQLnkWT 6 | YkGH+CsfLX3/w5KAB9vdzvOlWzeLo+yFSg+3mSPbw3m0aQB9lrj/NmsyxB20OGbv 7 | YNv13/uMdsdcJV4NLeJcT57XDQS519G+z6LAcwIDAQABAoIBAEcLeagao3bjqcxU 8 | AL+DM1avHr0whKjxzNURj3JiOKsqzuOuRppQ24Y5tGojVKwJCqT0EybITieYhtsb 9 | Hhp5xPbPtZ/lGlKS9NQjCHtKRn8Zb9dGWWE+R5KDXiItH+y6J/0J7UqXPpOMN5nK 10 | dVz4MmAuHJzb1Y31Cul4SPGt2mSrbiNcIrghTfTI7iCpk/+70zJdxIYrw/x35DMl 11 | rzbiNIuUnQDYeEa7I2SCZAWMwpKget9B3S3fgnA1WKQuNKlKJDT5X9PVfzPcrhzV 12 | IWs0mxq+vCJx7AVqwSMJyIV1ijSR6rU+atr/zJJJv9VRofhUodx11Kjm3yXGFmQ8 13 | bHUu9QECgYEAw5avzPpyrZSXwFclCXt+sjiC/cpaYQ8NFLjSZWexuL5/peGCS+cF 14 
| fGSPJnkSGSPWB5s350g/rt4zmAOPqGgBlvjofGczRBqFGO3x+3+pzxgLhPMlkP43 15 | ko33r3J07dWwVPe2KxvbDKoQrhiE10H11uIu1Tfs9k2/uUZLAzFHQCkCgYEAwqCu 16 | 6h34nsl0FzQA/tascCLbwa95Da0IDPrGFZ95X8aryPPNcB4hpuGn0TaHDW8zJV8K 17 | j9kmvBA5H5z9cXV6CeZJrW4eJcmPgUi5RvFUuF7fYSNpFLTE8K1b1mCw8P927WMm 18 | f0EfdY4qCKDU1rIJa6iD5x6Imy6t0HnTg54iHzsCgYBOZa8PzW98DiyJhySsWVje 19 | XPJ8gciaUOsgXDjRNrAw6gLGXc7ZV7+GLdSHSk4rz4ZxxBCzXu1PzXcGvp6tlQrW 20 | Fe0yODd/W9XvuSiec3yAKxYq8z8ikBN8ZfVa2NjvoBCu7h+RxfeWavCGqANPOPwu 21 | Zrj49BLCY0WvIPLeU7lIiQKBgDNuDoqjHN2o0lqHTXQJ+ksviu6lldF9VdFIOyvf 22 | lk0uzJovgqwL6kyU+KmaRRnRtqw7bykP8uJjTxUBgR+IMZWIGxQPMzw9BQTe2Mbc 23 | YszNlS2wE8Z69ke7J7eAmYE1oJGeT7/0z4Fa7dSV22hYZ5DhWOmr8eE/9oJOjwwK 24 | r22dAoGARD2MUPMV7f/98RA9X7XfXNG1yVBDxxGMJLGIDYhExtV6hni0nvC3tzBx 25 | lPvVOUUC54Ho7AuQoQqSUgDQisaAymdEawDYhfwrC5PMqC63dpl6D//vUgAHJvkX 26 | Fwz071Rvr520Yv5yDGS47DzhxImGWK/Wn2ZFkfQzZvtzGBnmHh0= 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister: -------------------------------------------------------------------------------- 1 | com.singlestore.spark.DefaultSource 2 | com.memsql.spark.DefaultSource 3 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.1/spark/MaxNumConcurentTasks.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.scheduler 2 | 3 | import org.apache.spark.rdd.RDD 4 | 5 | object MaxNumConcurrentTasks { 6 | def get(rdd: RDD[_]): Int = { 7 | val (_, resourceProfiles) = 8 | rdd.sparkContext.dagScheduler.getShuffleDependenciesAndResourceProfiles(rdd) 9 | val resourceProfile = 10 | rdd.sparkContext.dagScheduler.mergeResourceProfilesForStage(resourceProfiles) 11 | rdd.sparkContext.maxNumConcurrentTasks(resourceProfile) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.1/spark/VersionSpecificAggregateExpressionExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement} 4 | import com.singlestore.spark.ExpressionGen.{aggregateWithFilter, f, op} 5 | import org.apache.spark.sql.catalyst.expressions.aggregate.{ 6 | AggregateFunction, 7 | Average, 8 | First, 9 | Kurtosis, 10 | Last, 11 | Skewness, 12 | StddevPop, 13 | StddevSamp, 14 | Sum, 15 | VariancePop, 16 | VarianceSamp 17 | } 18 | 19 | case class VersionSpecificAggregateExpressionExtractor(expressionExtractor: ExpressionExtractor, 20 | context: SQLGenContext, 21 | filter: Option[SQLGen.Joinable]) { 22 | def unapply(aggFunc: AggregateFunction): Option[Statement] = { 23 | aggFunc match { 24 | // CentralMomentAgg.scala 25 | case StddevPop(expressionExtractor(child), true) => 26 | Some(aggregateWithFilter("STDDEV_POP", child, filter)) 27 | case StddevSamp(expressionExtractor(child), true) => 28 | Some(aggregateWithFilter("STDDEV_SAMP", child, filter)) 29 | case VariancePop(expressionExtractor(child), true) => 30 | Some(aggregateWithFilter("VAR_POP", child, filter)) 31 | case VarianceSamp(expressionExtractor(child), true) => 32 | Some(aggregateWithFilter("VAR_SAMP", child, filter)) 33 | case Kurtosis(expressionExtractor(child), true) => 34 | // ( (AVG(POW(child, 4)) - AVG(child) * POW(AVG(child), 3) * 4 + 6 * AVG(POW(child), 2) * 
POW(AVG(child), 2) - 3 * POW(AVG(child), 4) ) 35 | // / POW(STD(child), 4) ) - 3 36 | // following the formula from https://stats.oarc.ucla.edu/other/mult-pkg/faq/general/faq-whats-with-the-different-formulas-for-kurtosis/ article 37 | Some( 38 | op( 39 | "-", 40 | op( 41 | "/", 42 | op( 43 | "-", 44 | op( 45 | "+", 46 | op( 47 | "-", 48 | aggregateWithFilter("AVG", f("POW", child, "4"), filter), 49 | op("*", 50 | op("*", 51 | aggregateWithFilter("AVG", child, filter), 52 | aggregateWithFilter("AVG", f("POW", child, "3"), filter)), 53 | "4") 54 | ), 55 | op("*", 56 | "6", 57 | op("*", 58 | aggregateWithFilter("AVG", f("POW", child, "2"), filter), 59 | f("POW", aggregateWithFilter("AVG", child, filter), "2"))) 60 | ), 61 | op("*", "3", f("POW", aggregateWithFilter("AVG", child, filter), "4")) 62 | ), 63 | f("POW", aggregateWithFilter("STD", child, filter), "4") 64 | ), 65 | "3" 66 | ) 67 | ) 68 | 69 | case Skewness(expressionExtractor(child), true) => 70 | // (AVG(POW(child, 3)) - AVG(child) * POW(STD(child), 2) * 3 - POW(AVG(child), 3) ) / POW(STD(child), 3) 71 | // following the definition section in https://en.wikipedia.org/wiki/Skewness 72 | Some( 73 | op( 74 | "/", 75 | op( 76 | "-", 77 | op( 78 | "-", 79 | aggregateWithFilter("AVG", f("POW", child, "3"), filter), 80 | op("*", 81 | op("*", 82 | aggregateWithFilter("AVG", child, filter), 83 | f("POW", aggregateWithFilter("STD", child, filter), "2")), 84 | "3") 85 | ), 86 | f("POW", aggregateWithFilter("AVG", child, filter), "3") 87 | ), 88 | f("POW", aggregateWithFilter("STD", child, filter), "3") 89 | ) 90 | ) 91 | 92 | // First.scala 93 | case First(expressionExtractor(child), false) => 94 | Some(aggregateWithFilter("ANY_VALUE", child, filter)) 95 | 96 | // Last.scala 97 | case Last(expressionExtractor(child), false) => 98 | Some(aggregateWithFilter("ANY_VALUE", child, filter)) 99 | 100 | // Sum.scala 101 | case Sum(expressionExtractor(child)) => 102 | Some(aggregateWithFilter("SUM", child, filter)) 103 | 104 | // Average.scala 105 | case Average(expressionExtractor(child)) => 106 | Some(aggregateWithFilter("AVG", child, filter)) 107 | 108 | case _ => None 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.1/spark/VersionSpecificUtil.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import org.apache.spark.sql.types.{CalendarIntervalType, DataType} 4 | 5 | object VersionSpecificUtil { 6 | def isIntervalType(d:DataType): Boolean = 7 | d.isInstanceOf[CalendarIntervalType] 8 | } 9 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.1/spark/VersionSpecificWindowBoundaryExpressionExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement} 4 | import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus} 5 | 6 | case class VersionSpecificWindowBoundaryExpressionExtractor( 7 | expressionExtractor: ExpressionExtractor) { 8 | def unapply(arg: Expression): Option[Statement] = { 9 | arg match { 10 | case UnaryMinus(expressionExtractor(child), false) => 11 | Some(child + "PRECEDING") 12 | case _ => None 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- 
/src/main/scala-sparkv3.2/spark/MaxNumConcurrentTasks.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.scheduler 2 | 3 | import org.apache.spark.rdd.RDD 4 | 5 | object MaxNumConcurrentTasks { 6 | def get(rdd: RDD[_]): Int = { 7 | val (_, resourceProfiles) = 8 | rdd.sparkContext.dagScheduler.getShuffleDependenciesAndResourceProfiles(rdd) 9 | val resourceProfile = 10 | rdd.sparkContext.dagScheduler.mergeResourceProfilesForStage(resourceProfiles) 11 | rdd.sparkContext.maxNumConcurrentTasks(resourceProfile) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.2/spark/VersionSpecificAggregateExpressionExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement} 4 | import com.singlestore.spark.ExpressionGen.{aggregateWithFilter, f, op} 5 | import org.apache.spark.sql.catalyst.expressions.aggregate.{ 6 | AggregateFunction, 7 | Average, 8 | First, 9 | Kurtosis, 10 | Last, 11 | Skewness, 12 | StddevPop, 13 | StddevSamp, 14 | Sum, 15 | VariancePop, 16 | VarianceSamp 17 | } 18 | 19 | case class VersionSpecificAggregateExpressionExtractor(expressionExtractor: ExpressionExtractor, 20 | context: SQLGenContext, 21 | filter: Option[SQLGen.Joinable]) { 22 | def unapply(aggFunc: AggregateFunction): Option[Statement] = { 23 | aggFunc match { 24 | // CentralMomentAgg.scala 25 | case StddevPop(expressionExtractor(child), true) => 26 | Some(aggregateWithFilter("STDDEV_POP", child, filter)) 27 | case StddevSamp(expressionExtractor(child), true) => 28 | Some(aggregateWithFilter("STDDEV_SAMP", child, filter)) 29 | case VariancePop(expressionExtractor(child), true) => 30 | Some(aggregateWithFilter("VAR_POP", child, filter)) 31 | case VarianceSamp(expressionExtractor(child), true) => 32 | Some(aggregateWithFilter("VAR_SAMP", child, filter)) 33 | case Kurtosis(expressionExtractor(child), true) => 34 | // ( (AVG(POW(child, 4)) - AVG(child) * POW(AVG(child), 3) * 4 + 6 * AVG(POW(child), 2) * POW(AVG(child), 2) - 3 * POW(AVG(child), 4) ) 35 | // / POW(STD(child), 4) ) - 3 36 | // following the formula from https://stats.oarc.ucla.edu/other/mult-pkg/faq/general/faq-whats-with-the-different-formulas-for-kurtosis/ article 37 | Some( 38 | op( 39 | "-", 40 | op( 41 | "/", 42 | op( 43 | "-", 44 | op( 45 | "+", 46 | op( 47 | "-", 48 | aggregateWithFilter("AVG", f("POW", child, "4"), filter), 49 | op("*", 50 | op("*", 51 | aggregateWithFilter("AVG", child, filter), 52 | aggregateWithFilter("AVG", f("POW", child, "3"), filter)), 53 | "4") 54 | ), 55 | op("*", 56 | "6", 57 | op("*", 58 | aggregateWithFilter("AVG", f("POW", child, "2"), filter), 59 | f("POW", aggregateWithFilter("AVG", child, filter), "2"))) 60 | ), 61 | op("*", "3", f("POW", aggregateWithFilter("AVG", child, filter), "4")) 62 | ), 63 | f("POW", aggregateWithFilter("STD", child, filter), "4") 64 | ), 65 | "3" 66 | ) 67 | ) 68 | 69 | case Skewness(expressionExtractor(child), true) => 70 | // (AVG(POW(child, 3)) - AVG(child) * POW(STD(child), 2) * 3 - POW(AVG(child), 3) ) / POW(STD(child), 3) 71 | // following the definition section in https://en.wikipedia.org/wiki/Skewness 72 | Some( 73 | op( 74 | "/", 75 | op( 76 | "-", 77 | op( 78 | "-", 79 | aggregateWithFilter("AVG", f("POW", child, "3"), filter), 80 | op("*", 81 | op("*", 82 | aggregateWithFilter("AVG", child, 
filter), 83 | f("POW", aggregateWithFilter("STD", child, filter), "2")), 84 | "3") 85 | ), 86 | f("POW", aggregateWithFilter("AVG", child, filter), "3") 87 | ), 88 | f("POW", aggregateWithFilter("STD", child, filter), "3") 89 | ) 90 | ) 91 | 92 | // First.scala 93 | case First(expressionExtractor(child), false) => 94 | Some(aggregateWithFilter("ANY_VALUE", child, filter)) 95 | 96 | // Last.scala 97 | case Last(expressionExtractor(child), false) => 98 | Some(aggregateWithFilter("ANY_VALUE", child, filter)) 99 | 100 | // Sum.scala 101 | case Sum(expressionExtractor(child), false) => 102 | Some(aggregateWithFilter("SUM", child, filter)) 103 | 104 | // Average.scala 105 | case Average(expressionExtractor(child), false) => 106 | Some(aggregateWithFilter("AVG", child, filter)) 107 | 108 | case _ => None 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.2/spark/VersionSpecificUtil.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import org.apache.spark.sql.types.{ 4 | CalendarIntervalType, 5 | DataType, 6 | DayTimeIntervalType, 7 | YearMonthIntervalType 8 | } 9 | 10 | object VersionSpecificUtil { 11 | def isIntervalType(d: DataType): Boolean = 12 | d.isInstanceOf[CalendarIntervalType] || d.isInstanceOf[DayTimeIntervalType] || d 13 | .isInstanceOf[YearMonthIntervalType] 14 | } 15 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.2/spark/VersionSpecificWindowBoundaryExpressionExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement} 4 | import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus} 5 | 6 | case class VersionSpecificWindowBoundaryExpressionExtractor( 7 | expressionExtractor: ExpressionExtractor) { 8 | def unapply(arg: Expression): Option[Statement] = { 9 | arg match { 10 | case UnaryMinus(expressionExtractor(child), false) => 11 | Some(child + "PRECEDING") 12 | case _ => None 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.3/spark/MaxNumConcurrentTasks.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.scheduler 2 | 3 | import org.apache.spark.rdd.RDD 4 | 5 | object MaxNumConcurrentTasks { 6 | def get(rdd: RDD[_]): Int = { 7 | val (_, resourceProfiles) = 8 | rdd.sparkContext.dagScheduler.getShuffleDependenciesAndResourceProfiles(rdd) 9 | val resourceProfile = 10 | rdd.sparkContext.dagScheduler.mergeResourceProfilesForStage(resourceProfiles) 11 | rdd.sparkContext.maxNumConcurrentTasks(resourceProfile) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.3/spark/VersionSpecificAggregateExpressionExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement} 4 | import com.singlestore.spark.ExpressionGen.{aggregateWithFilter, f, op} 5 | import org.apache.spark.sql.catalyst.expressions.aggregate.{ 6 | AggregateFunction, 7 | Average, 8 | First, 9 | Kurtosis, 10 | Last, 11 | Skewness, 12 | StddevPop, 13 | StddevSamp, 14 | 
Sum, 15 | VariancePop, 16 | VarianceSamp 17 | } 18 | 19 | case class VersionSpecificAggregateExpressionExtractor(expressionExtractor: ExpressionExtractor, 20 | context: SQLGenContext, 21 | filter: Option[SQLGen.Joinable]) { 22 | def unapply(aggFunc: AggregateFunction): Option[Statement] = { 23 | aggFunc match { 24 | // CentralMomentAgg.scala 25 | case StddevPop(expressionExtractor(child), true) => 26 | Some(aggregateWithFilter("STDDEV_POP", child, filter)) 27 | case StddevSamp(expressionExtractor(child), true) => 28 | Some(aggregateWithFilter("STDDEV_SAMP", child, filter)) 29 | case VariancePop(expressionExtractor(child), true) => 30 | Some(aggregateWithFilter("VAR_POP", child, filter)) 31 | case VarianceSamp(expressionExtractor(child), true) => 32 | Some(aggregateWithFilter("VAR_SAMP", child, filter)) 33 | case Kurtosis(expressionExtractor(child), true) => 34 | // ( (AVG(POW(child, 4)) - AVG(child) * POW(AVG(child), 3) * 4 + 6 * AVG(POW(child), 2) * POW(AVG(child), 2) - 3 * POW(AVG(child), 4) ) 35 | // / POW(STD(child), 4) ) - 3 36 | // following the formula from https://stats.oarc.ucla.edu/other/mult-pkg/faq/general/faq-whats-with-the-different-formulas-for-kurtosis/ article 37 | Some( 38 | op( 39 | "-", 40 | op( 41 | "/", 42 | op( 43 | "-", 44 | op( 45 | "+", 46 | op( 47 | "-", 48 | aggregateWithFilter("AVG", f("POW", child, "4"), filter), 49 | op("*", 50 | op("*", 51 | aggregateWithFilter("AVG", child, filter), 52 | aggregateWithFilter("AVG", f("POW", child, "3"), filter)), 53 | "4") 54 | ), 55 | op("*", 56 | "6", 57 | op("*", 58 | aggregateWithFilter("AVG", f("POW", child, "2"), filter), 59 | f("POW", aggregateWithFilter("AVG", child, filter), "2"))) 60 | ), 61 | op("*", "3", f("POW", aggregateWithFilter("AVG", child, filter), "4")) 62 | ), 63 | f("POW", aggregateWithFilter("STD", child, filter), "4") 64 | ), 65 | "3" 66 | ) 67 | ) 68 | 69 | case Skewness(expressionExtractor(child), true) => 70 | // (AVG(POW(child, 3)) - AVG(child) * POW(STD(child), 2) * 3 - POW(AVG(child), 3) ) / POW(STD(child), 3) 71 | // following the definition section in https://en.wikipedia.org/wiki/Skewness 72 | Some( 73 | op( 74 | "/", 75 | op( 76 | "-", 77 | op( 78 | "-", 79 | aggregateWithFilter("AVG", f("POW", child, "3"), filter), 80 | op("*", 81 | op("*", 82 | aggregateWithFilter("AVG", child, filter), 83 | f("POW", aggregateWithFilter("STD", child, filter), "2")), 84 | "3") 85 | ), 86 | f("POW", aggregateWithFilter("AVG", child, filter), "3") 87 | ), 88 | f("POW", aggregateWithFilter("STD", child, filter), "3") 89 | ) 90 | ) 91 | 92 | // First.scala 93 | case First(expressionExtractor(child), false) => 94 | Some(aggregateWithFilter("ANY_VALUE", child, filter)) 95 | 96 | // Last.scala 97 | case Last(expressionExtractor(child), false) => 98 | Some(aggregateWithFilter("ANY_VALUE", child, filter)) 99 | 100 | // Sum.scala 101 | case Sum(expressionExtractor(child), false) => 102 | Some(aggregateWithFilter("SUM", child, filter)) 103 | 104 | // Average.scala 105 | case Average(expressionExtractor(child), false) => 106 | Some(aggregateWithFilter("AVG", child, filter)) 107 | 108 | case _ => None 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.3/spark/VersionSpecificUtil.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import org.apache.spark.sql.types.{ 4 | CalendarIntervalType, 5 | DataType, 6 | DayTimeIntervalType, 7 | YearMonthIntervalType 8 | } 
9 | 10 | object VersionSpecificUtil { 11 | def isIntervalType(d: DataType): Boolean = 12 | d.isInstanceOf[CalendarIntervalType] || d.isInstanceOf[DayTimeIntervalType] || d 13 | .isInstanceOf[YearMonthIntervalType] 14 | } 15 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.3/spark/VersionSpecificWindowBoundaryExpressionExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement} 4 | import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus} 5 | 6 | case class VersionSpecificWindowBoundaryExpressionExtractor( 7 | expressionExtractor: ExpressionExtractor) { 8 | def unapply(arg: Expression): Option[Statement] = { 9 | arg match { 10 | case UnaryMinus(expressionExtractor(child), false) => 11 | Some(child + "PRECEDING") 12 | case _ => None 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.4/spark/MaxNumConcurrentTasks.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.scheduler 2 | 3 | import org.apache.spark.rdd.RDD 4 | 5 | object MaxNumConcurrentTasks { 6 | def get(rdd: RDD[_]): Int = { 7 | val (_, resourceProfiles) = 8 | rdd.sparkContext.dagScheduler.getShuffleDependenciesAndResourceProfiles(rdd) 9 | val resourceProfile = 10 | rdd.sparkContext.dagScheduler.mergeResourceProfilesForStage(resourceProfiles) 11 | rdd.sparkContext.maxNumConcurrentTasks(resourceProfile) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.4/spark/VersionSpecificAggregateExpressionExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement} 4 | import com.singlestore.spark.ExpressionGen.{aggregateWithFilter, f, op} 5 | import org.apache.spark.sql.catalyst.expressions.aggregate.{ 6 | AggregateFunction, 7 | Average, 8 | First, 9 | Kurtosis, 10 | Last, 11 | Skewness, 12 | StddevPop, 13 | StddevSamp, 14 | Sum, 15 | VariancePop, 16 | VarianceSamp 17 | } 18 | import org.apache.spark.sql.catalyst.expressions.EvalMode 19 | 20 | case class VersionSpecificAggregateExpressionExtractor(expressionExtractor: ExpressionExtractor, 21 | context: SQLGenContext, 22 | filter: Option[SQLGen.Joinable]) { 23 | def unapply(aggFunc: AggregateFunction): Option[Statement] = { 24 | aggFunc match { 25 | // CentralMomentAgg.scala 26 | case StddevPop(expressionExtractor(child), true) => 27 | Some(aggregateWithFilter("STDDEV_POP", child, filter)) 28 | case StddevSamp(expressionExtractor(child), true) => 29 | Some(aggregateWithFilter("STDDEV_SAMP", child, filter)) 30 | case VariancePop(expressionExtractor(child), true) => 31 | Some(aggregateWithFilter("VAR_POP", child, filter)) 32 | case VarianceSamp(expressionExtractor(child), true) => 33 | Some(aggregateWithFilter("VAR_SAMP", child, filter)) 34 | case Kurtosis(expressionExtractor(child), true) => 35 | // ( (AVG(POW(child, 4)) - AVG(child) * POW(AVG(child), 3) * 4 + 6 * AVG(POW(child), 2) * POW(AVG(child), 2) - 3 * POW(AVG(child), 4) ) 36 | // / POW(STD(child), 4) ) - 3 37 | // following the formula from 
https://stats.oarc.ucla.edu/other/mult-pkg/faq/general/faq-whats-with-the-different-formulas-for-kurtosis/ article 38 | Some( 39 | op( 40 | "-", 41 | op( 42 | "/", 43 | op( 44 | "-", 45 | op( 46 | "+", 47 | op( 48 | "-", 49 | aggregateWithFilter("AVG", f("POW", child, "4"), filter), 50 | op("*", 51 | op("*", 52 | aggregateWithFilter("AVG", child, filter), 53 | aggregateWithFilter("AVG", f("POW", child, "3"), filter)), 54 | "4") 55 | ), 56 | op("*", 57 | "6", 58 | op("*", 59 | aggregateWithFilter("AVG", f("POW", child, "2"), filter), 60 | f("POW", aggregateWithFilter("AVG", child, filter), "2"))) 61 | ), 62 | op("*", "3", f("POW", aggregateWithFilter("AVG", child, filter), "4")) 63 | ), 64 | f("POW", aggregateWithFilter("STD", child, filter), "4") 65 | ), 66 | "3" 67 | ) 68 | ) 69 | 70 | case Skewness(expressionExtractor(child), true) => 71 | // (AVG(POW(child, 3)) - AVG(child) * POW(STD(child), 2) * 3 - POW(AVG(child), 3) ) / POW(STD(child), 3) 72 | // following the definition section in https://en.wikipedia.org/wiki/Skewness 73 | Some( 74 | op( 75 | "/", 76 | op( 77 | "-", 78 | op( 79 | "-", 80 | aggregateWithFilter("AVG", f("POW", child, "3"), filter), 81 | op("*", 82 | op("*", 83 | aggregateWithFilter("AVG", child, filter), 84 | f("POW", aggregateWithFilter("STD", child, filter), "2")), 85 | "3") 86 | ), 87 | f("POW", aggregateWithFilter("AVG", child, filter), "3") 88 | ), 89 | f("POW", aggregateWithFilter("STD", child, filter), "3") 90 | ) 91 | ) 92 | 93 | // First.scala 94 | case First(expressionExtractor(child), false) => 95 | Some(aggregateWithFilter("ANY_VALUE", child, filter)) 96 | 97 | // Last.scala 98 | case Last(expressionExtractor(child), false) => 99 | Some(aggregateWithFilter("ANY_VALUE", child, filter)) 100 | 101 | // Sum.scala 102 | case Sum(expressionExtractor(child), EvalMode.LEGACY) => 103 | Some(aggregateWithFilter("SUM", child, filter)) 104 | 105 | // Average.scala 106 | case Average(expressionExtractor(child), EvalMode.LEGACY) => 107 | Some(aggregateWithFilter("AVG", child, filter)) 108 | 109 | case _ => None 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.4/spark/VersionSpecificUtil.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import org.apache.spark.sql.types.{ 4 | CalendarIntervalType, 5 | DataType, 6 | DayTimeIntervalType, 7 | YearMonthIntervalType 8 | } 9 | 10 | object VersionSpecificUtil { 11 | def isIntervalType(d: DataType): Boolean = 12 | d.isInstanceOf[CalendarIntervalType] || d.isInstanceOf[DayTimeIntervalType] || d 13 | .isInstanceOf[YearMonthIntervalType] 14 | } 15 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.4/spark/VersionSpecificWindowBoundaryExpressionExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement} 4 | import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus} 5 | 6 | case class VersionSpecificWindowBoundaryExpressionExtractor( 7 | expressionExtractor: ExpressionExtractor) { 8 | def unapply(arg: Expression): Option[Statement] = { 9 | arg match { 10 | case UnaryMinus(expressionExtractor(child), false) => 11 | Some(child + "PRECEDING") 12 | case _ => None 13 | } 14 | } 15 | } 16 | 
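Note: the scala-sparkv3.x trees above differ only where Catalyst changed between releases (for example, Sum and Average are matched on a Boolean failure flag up to Spark 3.3 and on EvalMode.LEGACY from Spark 3.4 onward), but they all rely on the same extractor-object pattern: a case class exposes an unapply that either yields a pushed-down SQL fragment or returns None so the subtree is left to Spark. The sketch below illustrates that pattern in isolation; it uses plain Strings in place of the connector's Statement/Joinable types, and the names NegatedLiteral, BoundaryDemo, and renderBoundary are invented for this example and do not exist in the repository.

// Minimal, self-contained sketch of the extractor-object pattern used by the
// VersionSpecific*Extractor classes above. Types are simplified to String and
// all names here (NegatedLiteral, BoundaryDemo, renderBoundary) are hypothetical.
object NegatedLiteral {
  // Mirrors the shape of `case UnaryMinus(expressionExtractor(child), false)`:
  // succeed only for inputs we know how to push down, otherwise return None.
  def unapply(expr: String): Option[String] =
    if (expr.startsWith("-")) Some(expr.drop(1)) else None
}

object BoundaryDemo {
  // Analogous to VersionSpecificWindowBoundaryExpressionExtractor.unapply:
  // a matching expression becomes "<child> PRECEDING", anything else is None,
  // which signals the caller to fall back to Spark-side execution.
  def renderBoundary(expr: String): Option[String] = expr match {
    case NegatedLiteral(child) => Some(s"$child PRECEDING")
    case _                     => None
  }

  def main(args: Array[String]): Unit = {
    println(renderBoundary("-3")) // prints Some(3 PRECEDING)
    println(renderBoundary("x"))  // prints None
  }
}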
-------------------------------------------------------------------------------- /src/main/scala-sparkv3.5/spark/MaxNumConcurrentTasks.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.scheduler 2 | 3 | import org.apache.spark.rdd.RDD 4 | 5 | object MaxNumConcurrentTasks { 6 | def get(rdd: RDD[_]): Int = { 7 | val (_, resourceProfiles) = 8 | rdd.sparkContext.dagScheduler.getShuffleDependenciesAndResourceProfiles(rdd) 9 | val resourceProfile = 10 | rdd.sparkContext.dagScheduler.mergeResourceProfilesForStage(resourceProfiles) 11 | rdd.sparkContext.maxNumConcurrentTasks(resourceProfile) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.5/spark/VersionSpecificAggregateExpressionExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement} 4 | import com.singlestore.spark.ExpressionGen.{aggregateWithFilter, f, op} 5 | import org.apache.spark.sql.catalyst.expressions.aggregate.{ 6 | AggregateFunction, 7 | Average, 8 | First, 9 | Kurtosis, 10 | Last, 11 | Skewness, 12 | StddevPop, 13 | StddevSamp, 14 | Sum, 15 | VariancePop, 16 | VarianceSamp 17 | } 18 | import org.apache.spark.sql.catalyst.expressions.EvalMode 19 | 20 | case class VersionSpecificAggregateExpressionExtractor(expressionExtractor: ExpressionExtractor, 21 | context: SQLGenContext, 22 | filter: Option[SQLGen.Joinable]) { 23 | def unapply(aggFunc: AggregateFunction): Option[Statement] = { 24 | aggFunc match { 25 | // CentralMomentAgg.scala 26 | case StddevPop(expressionExtractor(child), true) => 27 | Some(aggregateWithFilter("STDDEV_POP", child, filter)) 28 | case StddevSamp(expressionExtractor(child), true) => 29 | Some(aggregateWithFilter("STDDEV_SAMP", child, filter)) 30 | case VariancePop(expressionExtractor(child), true) => 31 | Some(aggregateWithFilter("VAR_POP", child, filter)) 32 | case VarianceSamp(expressionExtractor(child), true) => 33 | Some(aggregateWithFilter("VAR_SAMP", child, filter)) 34 | case Kurtosis(expressionExtractor(child), true) => 35 | // ( (AVG(POW(child, 4)) - AVG(child) * POW(AVG(child), 3) * 4 + 6 * AVG(POW(child), 2) * POW(AVG(child), 2) - 3 * POW(AVG(child), 4) ) 36 | // / POW(STD(child), 4) ) - 3 37 | // following the formula from https://stats.oarc.ucla.edu/other/mult-pkg/faq/general/faq-whats-with-the-different-formulas-for-kurtosis/ article 38 | Some( 39 | op( 40 | "-", 41 | op( 42 | "/", 43 | op( 44 | "-", 45 | op( 46 | "+", 47 | op( 48 | "-", 49 | aggregateWithFilter("AVG", f("POW", child, "4"), filter), 50 | op("*", 51 | op("*", 52 | aggregateWithFilter("AVG", child, filter), 53 | aggregateWithFilter("AVG", f("POW", child, "3"), filter)), 54 | "4") 55 | ), 56 | op("*", 57 | "6", 58 | op("*", 59 | aggregateWithFilter("AVG", f("POW", child, "2"), filter), 60 | f("POW", aggregateWithFilter("AVG", child, filter), "2"))) 61 | ), 62 | op("*", "3", f("POW", aggregateWithFilter("AVG", child, filter), "4")) 63 | ), 64 | f("POW", aggregateWithFilter("STD", child, filter), "4") 65 | ), 66 | "3" 67 | ) 68 | ) 69 | 70 | case Skewness(expressionExtractor(child), true) => 71 | // (AVG(POW(child, 3)) - AVG(child) * POW(STD(child), 2) * 3 - POW(AVG(child), 3) ) / POW(STD(child), 3) 72 | // following the definition section in https://en.wikipedia.org/wiki/Skewness 73 | Some( 74 | op( 75 | "/", 76 | op( 77 | "-", 78 | 
op( 79 | "-", 80 | aggregateWithFilter("AVG", f("POW", child, "3"), filter), 81 | op("*", 82 | op("*", 83 | aggregateWithFilter("AVG", child, filter), 84 | f("POW", aggregateWithFilter("STD", child, filter), "2")), 85 | "3") 86 | ), 87 | f("POW", aggregateWithFilter("AVG", child, filter), "3") 88 | ), 89 | f("POW", aggregateWithFilter("STD", child, filter), "3") 90 | ) 91 | ) 92 | 93 | // First.scala 94 | case First(expressionExtractor(child), false) => 95 | Some(aggregateWithFilter("ANY_VALUE", child, filter)) 96 | 97 | // Last.scala 98 | case Last(expressionExtractor(child), false) => 99 | Some(aggregateWithFilter("ANY_VALUE", child, filter)) 100 | 101 | // Sum.scala 102 | case Sum(expressionExtractor(child), EvalMode.LEGACY) => 103 | Some(aggregateWithFilter("SUM", child, filter)) 104 | 105 | // Average.scala 106 | case Average(expressionExtractor(child), EvalMode.LEGACY) => 107 | Some(aggregateWithFilter("AVG", child, filter)) 108 | 109 | case _ => None 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.5/spark/VersionSpecificUtil.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import org.apache.spark.sql.types.{ 4 | CalendarIntervalType, 5 | DataType, 6 | DayTimeIntervalType, 7 | YearMonthIntervalType 8 | } 9 | 10 | object VersionSpecificUtil { 11 | def isIntervalType(d: DataType): Boolean = 12 | d.isInstanceOf[CalendarIntervalType] || d.isInstanceOf[DayTimeIntervalType] || d 13 | .isInstanceOf[YearMonthIntervalType] 14 | } 15 | -------------------------------------------------------------------------------- /src/main/scala-sparkv3.5/spark/VersionSpecificWindowBoundaryExpressionExtractor.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement} 4 | import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus} 5 | 6 | case class VersionSpecificWindowBoundaryExpressionExtractor( 7 | expressionExtractor: ExpressionExtractor) { 8 | def unapply(arg: Expression): Option[Statement] = { 9 | arg match { 10 | case UnaryMinus(expressionExtractor(child), false) => 11 | Some(child + "PRECEDING") 12 | case _ => None 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/scala/com/memsql/spark/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | package com.memsql.spark 2 | 3 | import com.singlestore.spark 4 | 5 | class DefaultSource extends spark.DefaultSource { 6 | 7 | override def shortName(): String = spark.DefaultSource.MEMSQL_SOURCE_NAME_SHORT 8 | } 9 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/AggregatorParallelReadListener.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.{Connection, SQLException} 4 | import java.util.Properties 5 | import com.singlestore.spark.JdbcHelpers.getDDLConnProperties 6 | import com.singlestore.spark.SQLGen.VariableList 7 | import org.apache.spark.SparkContext 8 | import org.apache.spark.scheduler.{ 9 | SparkListener, 10 | SparkListenerStageCompleted, 11 | SparkListenerStageSubmitted 12 | } 13 | import org.apache.spark.sql.types.StructType 14 | 15 | 
import scala.collection.mutable 16 | 17 | class AggregatorParallelReadListener(applicationId: String) extends SparkListener with LazyLogging { 18 | // connectionsMap is a map from the result table name to the connection with which this table was created 19 | private val connectionsMap: mutable.Map[String, Connection] = 20 | new mutable.HashMap[String, Connection]() 21 | 22 | // rddInfos is a map from RDD id to the info needed to create result table for this RDD 23 | private val rddInfos: mutable.Map[Int, SingleStoreRDDInfo] = 24 | new mutable.HashMap[Int, SingleStoreRDDInfo]() 25 | 26 | // SingleStoreRDDInfo is information needed to create a result table 27 | private case class SingleStoreRDDInfo(sc: SparkContext, 28 | query: String, 29 | variables: VariableList, 30 | schema: StructType, 31 | connectionProperties: Properties, 32 | materialized: Boolean, 33 | needsRepartition: Boolean, 34 | repartitionColumns: Seq[String]) 35 | 36 | def addRDDInfo(rdd: SinglestoreRDD): Unit = { 37 | rddInfos.synchronized({ 38 | rddInfos += (rdd.id -> SingleStoreRDDInfo( 39 | rdd.sparkContext, 40 | rdd.query, 41 | rdd.variables, 42 | rdd.schema, 43 | getDDLConnProperties(rdd.options, isOnExecutor = false), 44 | rdd.parallelReadType.contains(ReadFromAggregatorsMaterialized), 45 | rdd.options.parallelReadRepartition, 46 | rdd.parallelReadRepartitionColumns, 47 | )) 48 | }) 49 | } 50 | 51 | def deleteRDDInfo(rdd: SinglestoreRDD): Unit = { 52 | rddInfos.synchronized({ 53 | rddInfos -= rdd.id 54 | }) 55 | } 56 | 57 | def isEmpty: Boolean = { 58 | rddInfos.synchronized({ 59 | rddInfos.isEmpty 60 | }) 61 | } 62 | 63 | override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { 64 | stageSubmitted.stageInfo.rddInfos.foreach(rddInfo => { 65 | if (rddInfo.name.startsWith("SingleStoreRDD")) { 66 | rddInfos 67 | .synchronized( 68 | rddInfos.get(rddInfo.id) 69 | ) 70 | .foreach(singleStoreRDDInfo => { 71 | val stageId = stageSubmitted.stageInfo.stageId 72 | val attemptNumber = stageSubmitted.stageInfo.attemptNumber() 73 | val randHex = rddInfo.name.substring("SingleStoreRDD".size) 74 | val tableName = 75 | JdbcHelpers 76 | .getResultTableName(applicationId, stageId, rddInfo.id, attemptNumber, randHex) 77 | 78 | // Create connection and save it in the map 79 | val conn = 80 | SinglestoreConnectionPool.getConnection(singleStoreRDDInfo.connectionProperties) 81 | connectionsMap.synchronized( 82 | connectionsMap += (tableName -> conn) 83 | ) 84 | 85 | log.info(s"Creating result table '$tableName'") 86 | try { 87 | // Create result table 88 | JdbcHelpers.createResultTable( 89 | conn, 90 | tableName, 91 | singleStoreRDDInfo.query, 92 | singleStoreRDDInfo.schema, 93 | singleStoreRDDInfo.variables, 94 | singleStoreRDDInfo.materialized, 95 | singleStoreRDDInfo.needsRepartition, 96 | singleStoreRDDInfo.repartitionColumns 97 | ) 98 | log.info(s"Successfully created result table '$tableName'") 99 | } catch { 100 | // Cancel execution if we failed to create a result table 101 | case e: SQLException => { 102 | singleStoreRDDInfo.sc.cancelStage(stageId) 103 | throw e 104 | } 105 | } 106 | }) 107 | } 108 | }) 109 | } 110 | 111 | override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { 112 | stageCompleted.stageInfo.rddInfos.foreach(rddInfo => { 113 | if (rddInfo.name.startsWith("SingleStoreRDD")) { 114 | val stageId = stageCompleted.stageInfo.stageId 115 | val attemptNumber = stageCompleted.stageInfo.attemptNumber() 116 | val randHex = rddInfo.name.substring("SingleStoreRDD".size) 117 
| val tableName = 118 | JdbcHelpers.getResultTableName(applicationId, stageId, rddInfo.id, attemptNumber, randHex) 119 | 120 | connectionsMap.synchronized( 121 | connectionsMap 122 | .get(tableName) 123 | .foreach(conn => { 124 | // Drop result table 125 | log.info(s"Dropping result table '$tableName'") 126 | JdbcHelpers.dropResultTable(conn, tableName) 127 | log.info(s"Successfully dropped result table '$tableName'") 128 | // Close connection 129 | conn.close() 130 | // Delete connection from map 131 | connectionsMap -= tableName 132 | }) 133 | ) 134 | } 135 | }) 136 | } 137 | } 138 | 139 | case object AggregatorParallelReadListenerAdder { 140 | // listeners is a map from SparkContext hash code to the listener associated with this SparkContext 141 | private val listeners = new mutable.HashMap[SparkContext, AggregatorParallelReadListener]() 142 | 143 | def addRDD(rdd: SinglestoreRDD): Unit = { 144 | this.synchronized({ 145 | val listener = listeners.getOrElse( 146 | rdd.sparkContext, { 147 | val newListener = new AggregatorParallelReadListener(rdd.sparkContext.applicationId) 148 | rdd.sparkContext.addSparkListener(newListener) 149 | listeners += (rdd.sparkContext -> newListener) 150 | newListener 151 | } 152 | ) 153 | listener.addRDDInfo(rdd) 154 | }) 155 | } 156 | 157 | def deleteRDD(rdd: SinglestoreRDD): Unit = { 158 | this.synchronized({ 159 | listeners 160 | .get(rdd.sparkContext) 161 | .foreach(listener => { 162 | listener.deleteRDDInfo(rdd) 163 | if (listener.isEmpty) { 164 | listeners -= rdd.sparkContext 165 | rdd.sparkContext.removeSparkListener(listener) 166 | } 167 | }) 168 | }) 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/AvroSchemaHelper.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import org.apache.avro.Schema 4 | import org.apache.avro.Schema.Type 5 | import org.apache.avro.Schema.Type._ 6 | 7 | import scala.collection.JavaConverters._ 8 | 9 | object AvroSchemaHelper { 10 | 11 | def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = { 12 | if (nullable && avroType.getType != NULL) { 13 | // avro uses union to represent nullable type. 
14 | val fields = avroType.getTypes.asScala 15 | assert(fields.length == 2) 16 | val actualType = fields.filter(_.getType != Type.NULL) 17 | assert(actualType.length == 1) 18 | actualType.head 19 | } else { 20 | avroType 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/CompletionIterator.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | // Copied from spark's CompletionIterator which is private even though it is generically useful 4 | 5 | abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A] { 6 | private[this] var completed = false 7 | def next(): A = sub.next() 8 | def hasNext: Boolean = { 9 | val r = sub.hasNext 10 | if (!r && !completed) { 11 | completed = true 12 | completion() 13 | } 14 | r 15 | } 16 | 17 | def completion(): Unit 18 | } 19 | 20 | private[spark] object CompletionIterator { 21 | def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit): CompletionIterator[A, I] = { 22 | new CompletionIterator[A, I](sub) { 23 | def completion(): Unit = completionFunction 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.SQLGenContext 4 | import org.apache.spark.TaskContext 5 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap 6 | import org.apache.spark.metrics.source.MetricsHandler 7 | import org.apache.spark.sql.sources.{ 8 | BaseRelation, 9 | CreatableRelationProvider, 10 | DataSourceRegister, 11 | RelationProvider 12 | } 13 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} 14 | 15 | object DefaultSource { 16 | 17 | val SINGLESTORE_SOURCE_NAME = "com.singlestore.spark" 18 | val SINGLESTORE_SOURCE_NAME_SHORT = "singlestore" 19 | val SINGLESTORE_GLOBAL_OPTION_PREFIX = "spark.datasource.singlestore." 20 | 21 | @Deprecated val MEMSQL_SOURCE_NAME = "com.memsql.spark" 22 | @Deprecated val MEMSQL_SOURCE_NAME_SHORT = "memsql" 23 | @Deprecated val MEMSQL_GLOBAL_OPTION_PREFIX = "spark.datasource.memsql." 
24 | } 25 | 26 | class DefaultSource 27 | extends RelationProvider 28 | with DataSourceRegister 29 | with CreatableRelationProvider 30 | with LazyLogging { 31 | 32 | override def shortName(): String = DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT 33 | 34 | private def includeGlobalParams(sqlContext: SQLContext, 35 | params: Map[String, String]): Map[String, String] = 36 | sqlContext.getAllConfs.foldLeft(params)({ 37 | case (params, (k, v)) if k.startsWith(DefaultSource.SINGLESTORE_GLOBAL_OPTION_PREFIX) => 38 | params + (k.stripPrefix(DefaultSource.SINGLESTORE_GLOBAL_OPTION_PREFIX) -> v) 39 | case (params, (k, v)) if k.startsWith(DefaultSource.MEMSQL_GLOBAL_OPTION_PREFIX) => 40 | params + (k.stripPrefix(DefaultSource.MEMSQL_GLOBAL_OPTION_PREFIX) -> v) 41 | case (params, _) => params 42 | }) 43 | 44 | override def createRelation(sqlContext: SQLContext, 45 | parameters: Map[String, String]): BaseRelation = { 46 | val params = CaseInsensitiveMap(includeGlobalParams(sqlContext, parameters)) 47 | val options = SinglestoreOptions(params, sqlContext.sparkSession.sparkContext) 48 | if (options.disablePushdown) { 49 | SQLPushdownRule.ensureRemoved(sqlContext.sparkSession) 50 | SinglestoreReaderNoPushdown(SinglestoreOptions.getQuery(params), options, sqlContext) 51 | } else { 52 | SQLPushdownRule.ensureInjected(sqlContext.sparkSession) 53 | SinglestoreReader(SinglestoreOptions.getQuery(params), 54 | Nil, 55 | options, 56 | sqlContext, 57 | context = SQLGenContext(options)) 58 | } 59 | } 60 | 61 | override def createRelation(sqlContext: SQLContext, 62 | mode: SaveMode, 63 | parameters: Map[String, String], 64 | data: DataFrame): BaseRelation = { 65 | val opts = CaseInsensitiveMap(includeGlobalParams(sqlContext, parameters)) 66 | val conf = SinglestoreOptions(opts, sqlContext.sparkSession.sparkContext) 67 | 68 | val table = SinglestoreOptions 69 | .getTable(opts) 70 | .getOrElse( 71 | throw new IllegalArgumentException( 72 | s"To write a dataframe to SingleStore you must specify a table name via the '${SinglestoreOptions.TABLE_NAME}' parameter" 73 | ) 74 | ) 75 | JdbcHelpers.prepareTableForWrite(conf, table, mode, data.schema) 76 | val isReferenceTable = JdbcHelpers.isReferenceTable(conf, table) 77 | val partitionWriterFactory = 78 | if (conf.onDuplicateKeySQL.isEmpty) { 79 | new LoadDataWriterFactory(table, conf) 80 | } else { 81 | new BatchInsertWriterFactory(table, conf) 82 | } 83 | 84 | val schema = data.schema 85 | var totalRowCount = 0L 86 | data.foreachPartition((partition: Iterator[Row]) => { 87 | val writer = partitionWriterFactory.createDataWriter(schema, 88 | TaskContext.getPartitionId(), 89 | 0, 90 | isReferenceTable, 91 | mode) 92 | try { 93 | partition.foreach(record => { 94 | writer.write(record) 95 | totalRowCount += 1 96 | }) 97 | writer.commit() 98 | MetricsHandler.setRecordsWritten(totalRowCount) 99 | } catch { 100 | case e: Exception => 101 | writer.abort(e) 102 | throw e 103 | } 104 | }) 105 | 106 | createRelation(sqlContext, parameters) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/LazyLogging.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import org.slf4j.{Logger, LoggerFactory} 4 | 5 | trait LazyLogging { 6 | @transient 7 | protected lazy val log: Logger = LoggerFactory.getLogger(getClass.getName) 8 | } 9 | -------------------------------------------------------------------------------- 
/src/main/scala/com/singlestore/spark/Loan.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | class Loan[A <: AutoCloseable](resource: A) { 4 | def to[T](handle: A => T): T = 5 | try handle(resource) 6 | finally resource.close() 7 | } 8 | 9 | object Loan { 10 | def apply[A <: AutoCloseable](resource: A) = new Loan(resource) 11 | } 12 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/MetricsHandler.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.metrics.source 2 | 3 | import org.apache.spark.TaskContext 4 | 5 | object MetricsHandler { 6 | def setRecordsWritten(r: Long): Unit = { 7 | TaskContext.get().taskMetrics().outputMetrics.setRecordsWritten(r) 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/OverwriteBehavior.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | sealed trait OverwriteBehavior 4 | 5 | case object Truncate extends OverwriteBehavior 6 | case object Merge extends OverwriteBehavior 7 | case object DropAndCreate extends OverwriteBehavior 8 | 9 | object OverwriteBehavior { 10 | def apply(value: String): OverwriteBehavior = value.toLowerCase match { 11 | case "truncate" => Truncate 12 | case "merge" => Merge 13 | case "dropandcreate" => DropAndCreate 14 | case _ => 15 | throw new IllegalArgumentException( 16 | s"Illegal argument for `${SinglestoreOptions.OVERWRITE_BEHAVIOR}` option") 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/ParallelReadEnablement.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | sealed trait ParallelReadEnablement 4 | 5 | case object Disabled extends ParallelReadEnablement 6 | case object Automatic extends ParallelReadEnablement 7 | case object AutomaticLite extends ParallelReadEnablement 8 | case object Forced extends ParallelReadEnablement 9 | 10 | object ParallelReadEnablement { 11 | def apply(value: String): ParallelReadEnablement = value.toLowerCase match { 12 | case "disabled" => Disabled 13 | case "automaticlite" => AutomaticLite 14 | case "automatic" => Automatic 15 | case "forced" => Forced 16 | 17 | // These two options are added for compatibility purposes 18 | case "false" => Disabled 19 | case "true" => Automatic 20 | 21 | case _ => 22 | throw new IllegalArgumentException( 23 | s"""Illegal argument for `${SinglestoreOptions.ENABLE_PARALLEL_READ}` option. 
Valid arguments are: 24 | | - "Disabled" 25 | | - "AutomaticLite" 26 | | - "Automatic" 27 | | - "Forced"""".stripMargin) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/ParallelReadType.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | sealed trait ParallelReadType 4 | 5 | case object ReadFromLeaves extends ParallelReadType 6 | case object ReadFromAggregators extends ParallelReadType 7 | case object ReadFromAggregatorsMaterialized extends ParallelReadType 8 | 9 | object ParallelReadType { 10 | def apply(value: String): ParallelReadType = value.toLowerCase match { 11 | case "readfromleaves" => ReadFromLeaves 12 | case "readfromaggregators" => ReadFromAggregators 13 | case "readfromaggregatorsmaterialized" => ReadFromAggregatorsMaterialized 14 | case _ => 15 | throw new IllegalArgumentException( 16 | s"""Illegal argument for `${SinglestoreOptions.PARALLEL_READ_FEATURES}` option. Valid arguments are: 17 | | - "ReadFromLeaves" 18 | | - "ReadFromAggregators" 19 | | - "ReadFromAggregatorsMaterialized"""".stripMargin) 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/SQLHelper.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.JdbcHelpers.{executeQuery, getDDLConnProperties} 4 | import org.apache.spark.sql.{Row, SparkSession} 5 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap 6 | import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils 7 | 8 | object SQLHelper extends LazyLogging { 9 | implicit class QueryMethods(spark: SparkSession) { 10 | private def singlestoreQuery(db: Option[String], 11 | query: String, 12 | variables: Any*): Iterator[Row] = { 13 | val ctx = spark.sqlContext 14 | var opts = ctx.getAllConfs.collect { 15 | case (k, v) if k.startsWith(DefaultSource.SINGLESTORE_GLOBAL_OPTION_PREFIX) => 16 | k.stripPrefix(DefaultSource.SINGLESTORE_GLOBAL_OPTION_PREFIX) -> v 17 | case (k, v) if k.startsWith(DefaultSource.MEMSQL_GLOBAL_OPTION_PREFIX) => 18 | k.stripPrefix(DefaultSource.MEMSQL_GLOBAL_OPTION_PREFIX) -> v 19 | } 20 | 21 | if (db.isDefined) { 22 | val dbValue = db.get 23 | if (dbValue.isEmpty) { 24 | opts -= "database" 25 | } else { 26 | opts += ("database" -> dbValue) 27 | } 28 | } 29 | 30 | val conf = SinglestoreOptions(CaseInsensitiveMap(opts), spark.sparkContext) 31 | val conn = 32 | SinglestoreConnectionPool.getConnection(getDDLConnProperties(conf, isOnExecutor = false)) 33 | try { 34 | executeQuery(conn, query, variables: _*) 35 | } finally { 36 | conn.close() 37 | } 38 | } 39 | 40 | def executeSinglestoreQueryDB(db: String, query: String, variables: Any*): Iterator[Row] = { 41 | singlestoreQuery(Some(db), query, variables: _*) 42 | } 43 | 44 | def executeSinglestoreQuery(query: String, variables: Any*): Iterator[Row] = { 45 | singlestoreQuery(None, query, variables: _*) 46 | } 47 | 48 | @Deprecated def executeMemsqlQueryDB(db: String, 49 | query: String, 50 | variables: Any*): Iterator[Row] = { 51 | singlestoreQuery(Some(db), query, variables: _*) 52 | } 53 | 54 | @Deprecated def executeMemsqlQuery(query: String, variables: Any*): Iterator[Row] = { 55 | singlestoreQuery(None, query, variables: _*) 56 | } 57 | } 58 | 59 | def executeSinglestoreQueryDB(spark: SparkSession, 60 | db: String, 61 | 
query: String, 62 | variables: Any*): Iterator[Row] = { 63 | spark.executeSinglestoreQueryDB(db, query, variables: _*) 64 | } 65 | 66 | def executeSinglestoreQuery(spark: SparkSession, 67 | query: String, 68 | variables: Any*): Iterator[Row] = { 69 | spark.executeSinglestoreQuery(query, variables: _*) 70 | } 71 | 72 | @Deprecated def executeMemsqlQueryDB(spark: SparkSession, 73 | db: String, 74 | query: String, 75 | variables: Any*): Iterator[Row] = { 76 | spark.executeSinglestoreQueryDB(db, query, variables: _*) 77 | } 78 | 79 | @Deprecated def executeMemsqlQuery(spark: SparkSession, 80 | query: String, 81 | variables: Any*): Iterator[Row] = { 82 | spark.executeSinglestoreQuery(query, variables: _*) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/SQLPushdownRule.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext} 4 | import org.apache.spark.sql.SparkSession 5 | import org.apache.spark.sql.catalyst.plans.logical._ 6 | import org.apache.spark.sql.catalyst.rules.Rule 7 | 8 | class SQLPushdownRule extends Rule[LogicalPlan] { 9 | override def apply(root: LogicalPlan): LogicalPlan = { 10 | var context: SQLGenContext = null 11 | val needsPushdown = root 12 | .find({ 13 | case SQLGen.Relation(r: SQLGen.Relation) if !r.reader.isFinal => 14 | context = SQLGenContext(root, r.reader.options) 15 | true 16 | case _ => false 17 | }) 18 | .isDefined 19 | 20 | if (!needsPushdown) { 21 | return root 22 | } 23 | 24 | if (log.isTraceEnabled) { 25 | log.trace(s"Optimizing plan:\n${root.treeString(true)}") 26 | } 27 | 28 | // We first need to set a SQLGenContext in every reader. 29 | // This transform is done to ensure that we will generate the same aliases in the same queries. 30 | val normalized = root.transform({ 31 | case SQLGen.Relation(relation) => 32 | relation.toLogicalPlan( 33 | relation.output, 34 | relation.reader.query, 35 | relation.reader.variables, 36 | relation.reader.isFinal, 37 | context 38 | ) 39 | }) 40 | 41 | // Second, we need to rename the outputs of each SingleStore relation in the tree. This transform is 42 | // done to ensure that we can handle projections which involve ambiguous column name references. 43 | var ptr, nextPtr = normalized.transform({ 44 | case SQLGen.Relation(relation) => relation.renameOutput 45 | }) 46 | 47 | val expressionExtractor = ExpressionExtractor(context) 48 | val transforms = 49 | List( 50 | // do all rewrites except top-level sort, e.g. Project([a,b,c], Relation(select * from foo)) 51 | SQLGen.fromLogicalPlan(expressionExtractor).andThen(_.asLogicalPlan()), 52 | // do rewrites with top-level Sort, e.g. 
Sort(a, Limit(10, Relation(select * from foo)) 53 | // won't be done for relations with parallel read enabled 54 | SQLGen.fromTopLevelSort(expressionExtractor), 55 | ) 56 | 57 | // Run our transforms in a loop until the tree converges 58 | do { 59 | ptr = nextPtr 60 | nextPtr = transforms.foldLeft(ptr)(_.transformUp(_)) 61 | } while (!ptr.fastEquals(nextPtr)) 62 | 63 | // Finalize all the relations in the tree and perform casts into the expected output datatype for Spark 64 | val out = ptr.transform({ 65 | case SQLGen.Relation(relation) if !relation.isFinal => relation.castOutputAndFinalize 66 | }) 67 | 68 | if (log.isTraceEnabled) { 69 | log.trace(s"Optimized Plan:\n${out.treeString(true)}") 70 | } 71 | 72 | out 73 | } 74 | } 75 | 76 | object SQLPushdownRule { 77 | def injected(session: SparkSession): Boolean = { 78 | session.experimental.extraOptimizations 79 | .exists(s => s.isInstanceOf[SQLPushdownRule]) 80 | } 81 | 82 | def ensureInjected(session: SparkSession): Unit = { 83 | if (!injected(session)) { 84 | session.experimental.extraOptimizations ++= Seq(new SQLPushdownRule) 85 | } 86 | } 87 | 88 | def ensureRemoved(session: SparkSession): Unit = { 89 | session.experimental.extraOptimizations = session.experimental.extraOptimizations 90 | .filterNot(s => s.isInstanceOf[SQLPushdownRule]) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/SinglestoreBatchInsertWriter.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.Connection 4 | import java.util.Base64 5 | 6 | import com.singlestore.spark.JdbcHelpers.{getDDLConnProperties, getDMLConnProperties} 7 | import org.apache.spark.sql.catalyst.TableIdentifier 8 | import org.apache.spark.sql.types.{BinaryType, StructType} 9 | import org.apache.spark.sql.{Row, SaveMode} 10 | 11 | import scala.collection.mutable.ListBuffer 12 | import scala.concurrent.ExecutionContext 13 | 14 | // TODO: extend it from DataWriterFactory 15 | class BatchInsertWriterFactory(table: TableIdentifier, conf: SinglestoreOptions) 16 | extends WriterFactory 17 | with LazyLogging { 18 | 19 | def createDataWriter(schema: StructType, 20 | partitionId: Int, 21 | attemptNumber: Int, 22 | isReferenceTable: Boolean, 23 | mode: SaveMode): DataWriter[Row] = { 24 | val columnNames = schema.map(s => SinglestoreDialect.quoteIdentifier(s.name)) 25 | val queryPrefix = s"INSERT INTO ${table.quotedString} (${columnNames.mkString(", ")}) VALUES " 26 | val querySuffix = s" ON DUPLICATE KEY UPDATE ${conf.onDuplicateKeySQL.get}" 27 | 28 | val rowTemplate = "(" + schema 29 | .map(x => 30 | x.dataType match { 31 | case BinaryType => "FROM_BASE64(?)" 32 | case _ => "?" 
33 | }) 34 | .mkString(",") + ")" 35 | def valueTemplate(rows: Int): String = 36 | List.fill(rows)(rowTemplate).mkString(",") 37 | val fullBatchQuery = queryPrefix + valueTemplate(conf.insertBatchSize) + querySuffix 38 | 39 | val conn = SinglestoreConnectionPool.getConnection(if (isReferenceTable) { 40 | getDDLConnProperties(conf, isOnExecutor = true) 41 | } else { 42 | getDMLConnProperties(conf, isOnExecutor = true) 43 | }) 44 | conn.setAutoCommit(false) 45 | 46 | def writeBatch(buff: ListBuffer[Row]): Long = { 47 | if (buff.isEmpty) { 48 | 0 49 | } else { 50 | val rowsCount = buff.size 51 | val query = if (rowsCount == conf.insertBatchSize) { 52 | fullBatchQuery 53 | } else { 54 | queryPrefix + valueTemplate(rowsCount) + querySuffix 55 | } 56 | 57 | val stmt = conn.prepareStatement(query) 58 | try { 59 | for { 60 | (row, i) <- buff.iterator.zipWithIndex 61 | rowLength = row.size 62 | j <- 0 until rowLength 63 | } row(j) match { 64 | case bytes: Array[Byte] => 65 | stmt.setObject(i * rowLength + j + 1, Base64.getEncoder.encode(bytes)) 66 | case obj => 67 | stmt.setObject(i * rowLength + j + 1, obj) 68 | } 69 | stmt.executeUpdate() 70 | } finally { 71 | stmt.close() 72 | conn.commit() 73 | } 74 | } 75 | } 76 | 77 | new BatchInsertWriter(conf.insertBatchSize, writeBatch, conn) 78 | } 79 | } 80 | 81 | class BatchInsertWriter(batchSize: Int, writeBatch: ListBuffer[Row] => Long, conn: Connection) 82 | extends DataWriter[Row] { 83 | var buff: ListBuffer[Row] = ListBuffer.empty[Row] 84 | 85 | override def write(row: Row): Unit = { 86 | buff += row 87 | if (buff.size >= batchSize) { 88 | writeBatch(buff) 89 | buff = ListBuffer.empty[Row] 90 | } 91 | } 92 | 93 | override def commit(): WriterCommitMessage = { 94 | try { 95 | writeBatch(buff) 96 | buff = ListBuffer.empty[Row] 97 | } finally { 98 | conn.close() 99 | } 100 | new WriteSuccess 101 | } 102 | 103 | override def abort(e: Exception): Unit = { 104 | buff = ListBuffer.empty[Row] 105 | if (!conn.isClosed) { 106 | conn.abort(ExecutionContext.global) 107 | } 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/SinglestoreConnectionPool.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.Connection 4 | import java.util.Properties 5 | import org.apache.commons.dbcp2.{BasicDataSource, BasicDataSourceFactory} 6 | import scala.collection.mutable 7 | 8 | object SinglestoreConnectionPool { 9 | private var dataSources = new mutable.HashMap[Properties, BasicDataSource]() 10 | 11 | private def deleteEmptyDataSources(): Unit = { 12 | dataSources = dataSources.filter(pair => { 13 | val dataSource = pair._2 14 | if (dataSource.getNumActive + dataSource.getNumIdle == 0) { 15 | dataSource.close() 16 | false 17 | } else { 18 | true 19 | } 20 | }) 21 | } 22 | 23 | def getConnection(properties: Properties): Connection = { 24 | this.synchronized({ 25 | dataSources 26 | .getOrElse( 27 | properties, { 28 | deleteEmptyDataSources() 29 | val newDataSource = BasicDataSourceFactory.createDataSource(properties) 30 | newDataSource.addConnectionProperty("connectionAttributes", 31 | properties.getProperty("connectionAttributes")) 32 | dataSources += (properties -> newDataSource) 33 | newDataSource 34 | } 35 | ) 36 | .getConnection 37 | }) 38 | } 39 | 40 | def close(): Unit = { 41 | this.synchronized({ 42 | dataSources.foreach(pair => pair._2.close()) 43 | dataSources = new 
mutable.HashMap[Properties, BasicDataSource]() 44 | }) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/SinglestoreConnectionPoolOptions.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | case class SinglestoreConnectionPoolOptions(enabled: Boolean, 4 | MaxOpenConns: Int, 5 | MaxIdleConns: Int, 6 | MinEvictableIdleTimeMs: Long, 7 | TimeBetweenEvictionRunsMS: Long, 8 | MaxWaitMS: Long, 9 | MaxConnLifetimeMS: Long) 10 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/SinglestoreDialect.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.Types 4 | 5 | import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils 6 | import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType} 7 | import org.apache.spark.sql.types._ 8 | 9 | case object SinglestoreDialect extends JdbcDialect { 10 | override def canHandle(url: String): Boolean = url.startsWith("jdbc:memsql") 11 | 12 | val SINGLESTORE_DECIMAL_MAX_SCALE = 30 13 | 14 | override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { 15 | case BooleanType => Option(JdbcType("BOOL", Types.BOOLEAN)) 16 | case ByteType => Option(JdbcType("TINYINT", Types.TINYINT)) 17 | case ShortType => Option(JdbcType("SMALLINT", Types.SMALLINT)) 18 | case FloatType => Option(JdbcType("FLOAT", Types.FLOAT)) 19 | case TimestampType => Option(JdbcType("TIMESTAMP(6)", Types.TIMESTAMP)) 20 | case dt: DecimalType if (dt.scale <= SINGLESTORE_DECIMAL_MAX_SCALE) => 21 | Option(JdbcType(s"DECIMAL(${dt.precision}, ${dt.scale})", Types.DECIMAL)) 22 | case dt: DecimalType => 23 | throw new IllegalArgumentException( 24 | s"Too big scale specified(${dt.scale}). SingleStore DECIMAL maximum scale is ${SINGLESTORE_DECIMAL_MAX_SCALE}") 25 | case NullType => 26 | throw new IllegalArgumentException( 27 | "No corresponding SingleStore type found for NullType. 
If you want to use NullType, please write to an already existing SingleStore table.") 28 | case t => JdbcUtils.getCommonJDBCType(t) 29 | } 30 | 31 | override def getCatalystType(sqlType: Int, 32 | typeName: String, 33 | size: Int, 34 | md: MetadataBuilder): Option[DataType] = { 35 | (sqlType, typeName) match { 36 | case (Types.REAL, "FLOAT") => Option(FloatType) 37 | case (Types.BIT, "BIT") => Option(BinaryType) 38 | // JDBC driver returns incorrect SQL type for BIT 39 | // TODO delete after PLAT-6829 is fixed 40 | case (Types.BOOLEAN, "BIT") => Option(BinaryType) 41 | case (Types.TINYINT, "TINYINT") => Option(ShortType) 42 | case (Types.SMALLINT, "SMALLINT") => Option(ShortType) 43 | case (Types.INTEGER, "SMALLINT") => Option(IntegerType) 44 | case (Types.INTEGER, "SMALLINT UNSIGNED") => Option(IntegerType) 45 | case (Types.DECIMAL, "DECIMAL") => { 46 | if (size > DecimalType.MAX_PRECISION) { 47 | throw new IllegalArgumentException( 48 | s"DECIMAL precision ${size} exceeds max precision ${DecimalType.MAX_PRECISION}") 49 | } else { 50 | Option( 51 | DecimalType(size, md.build().getLong("scale").toInt) 52 | ) 53 | } 54 | } 55 | case _ => None 56 | } 57 | } 58 | 59 | override def quoteIdentifier(colName: String): String = { 60 | s"`${colName.replace("`", "``")}`" 61 | } 62 | 63 | override def isCascadingTruncateTable(): Option[Boolean] = Some(false) 64 | } 65 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/SinglestoreRDD.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.{Connection, PreparedStatement, ResultSet} 4 | import java.util.concurrent.{Executors, ForkJoinPool} 5 | import com.singlestore.spark.SQLGen.VariableList 6 | import org.apache.spark.rdd.RDD 7 | import org.apache.spark.sql.Row 8 | import org.apache.spark.sql.catalyst.expressions.Attribute 9 | import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils 10 | import org.apache.spark.sql.types._ 11 | import org.apache.spark.util.TaskCompletionListener 12 | import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} 13 | 14 | import scala.concurrent.duration.Duration 15 | import scala.concurrent.ExecutionContext.Implicits.global 16 | import scala.concurrent.{Await, ExecutionContext, Future} 17 | 18 | case class SinglestoreRDD(query: String, 19 | variables: VariableList, 20 | options: SinglestoreOptions, 21 | schema: StructType, 22 | expectedOutput: Seq[Attribute], 23 | resultMustBeSorted: Boolean, 24 | parallelReadRepartitionColumns: Seq[String], 25 | @transient val sc: SparkContext, 26 | randHex: String) 27 | extends RDD[Row](sc, Nil) { 28 | val (parallelReadType, partitions_) = SinglestorePartitioner(this).getPartitions 29 | 30 | // Spark serializes RDD object and sends it to executor 31 | // On executor sc value will be null as it is marked as transient 32 | def isRunOnExecutor: Boolean = sc == null 33 | 34 | val applicationId: String = sc.applicationId 35 | 36 | val aggregatorParallelReadUsed: Boolean = 37 | parallelReadType.contains(ReadFromAggregators) || 38 | parallelReadType.contains(ReadFromAggregatorsMaterialized) 39 | 40 | if (!isRunOnExecutor && aggregatorParallelReadUsed) { 41 | AggregatorParallelReadListenerAdder.addRDD(this) 42 | } 43 | 44 | override def finalize(): Unit = { 45 | if (!isRunOnExecutor && aggregatorParallelReadUsed) { 46 | AggregatorParallelReadListenerAdder.deleteRDD(this) 47 | } 48 | super.finalize() 49 
| } 50 | 51 | override protected def getPartitions: Array[Partition] = partitions_ 52 | 53 | override def compute(rawPartition: Partition, context: TaskContext): Iterator[Row] = { 54 | val multiPartition: SinglestoreMultiPartition = 55 | rawPartition.asInstanceOf[SinglestoreMultiPartition] 56 | val threadPool = new ForkJoinPool(multiPartition.partitions.size) 57 | try { 58 | val executionContext = 59 | ExecutionContext.fromExecutor(threadPool) 60 | val future: Future[Seq[Iterator[Row]]] = Future.sequence(multiPartition.partitions.map(p => 61 | Future(computeSinglePartition(p, context))(executionContext))) 62 | 63 | Await.result(future, Duration.Inf).foldLeft(Iterator[Row]())(_ ++ _) 64 | } finally { 65 | threadPool.shutdownNow() 66 | } 67 | } 68 | 69 | def computeSinglePartition(rawPartition: SinglestorePartition, 70 | context: TaskContext): Iterator[Row] = { 71 | var closed = false 72 | var rs: ResultSet = null 73 | var stmt: PreparedStatement = null 74 | var conn: Connection = null 75 | val partition: SinglestorePartition = rawPartition.asInstanceOf[SinglestorePartition] 76 | 77 | def tryClose(name: String, what: AutoCloseable): Unit = { 78 | try { 79 | if (what != null) { what.close() } 80 | } catch { 81 | case e: Exception => logWarning(s"Exception closing $name", e) 82 | } 83 | } 84 | 85 | val ErrResultTableNotExistCode = 2318 86 | 87 | def close(): Unit = { 88 | if (closed) { return } 89 | tryClose("resultset", rs) 90 | tryClose("statement", stmt) 91 | tryClose("connection", conn) 92 | closed = true 93 | } 94 | 95 | context.addTaskCompletionListener { 96 | new TaskCompletionListener { 97 | override def onTaskCompletion(context: TaskContext): Unit = close() 98 | } 99 | } 100 | 101 | conn = SinglestoreConnectionPool.getConnection(partition.connectionInfo) 102 | if (aggregatorParallelReadUsed) { 103 | val tableName = JdbcHelpers.getResultTableName(applicationId, 104 | context.stageId(), 105 | id, 106 | context.stageAttemptNumber(), 107 | randHex) 108 | 109 | stmt = 110 | conn.prepareStatement(JdbcHelpers.getSelectFromResultTableQuery(tableName, partition.index)) 111 | 112 | val startTime = System.currentTimeMillis() 113 | val timeout = parallelReadType match { 114 | case Some(ReadFromAggregators) => 115 | options.parallelReadTableCreationTimeoutMS 116 | case Some(ReadFromAggregatorsMaterialized) => 117 | options.parallelReadMaterializedTableCreationTimeoutMS 118 | case _ => 119 | 0 120 | } 121 | 122 | var lastError: java.sql.SQLException = null 123 | var delay = 50 124 | val maxDelay = 10000 125 | while (rs == null && (timeout == 0 || System.currentTimeMillis() - startTime < timeout)) { 126 | try { 127 | rs = stmt.executeQuery() 128 | } catch { 129 | case e: java.sql.SQLException if e.getErrorCode == ErrResultTableNotExistCode => 130 | lastError = e 131 | delay = Math.min(maxDelay, delay * 2) 132 | Thread.sleep(delay) 133 | } 134 | } 135 | 136 | if (rs == null) { 137 | throw new java.sql.SQLException("Failed to read data from result table", lastError) 138 | } 139 | } else { 140 | stmt = conn.prepareStatement(partition.query) 141 | JdbcHelpers.fillStatement(stmt, partition.variables) 142 | rs = stmt.executeQuery() 143 | } 144 | 145 | var rowsIter = JdbcUtils.resultSetToRows(rs, schema) 146 | 147 | if (expectedOutput.nonEmpty) { 148 | val schemaDatatypes = schema.map(_.dataType) 149 | val expectedDatatypes = expectedOutput.map(_.dataType) 150 | 151 | def getOrNull(f: => Any, r: Row, i: Int): Any = { 152 | if (r.isNullAt(i)) null 153 | else f 154 | } 155 | 156 | if (schemaDatatypes != 
expectedDatatypes) { 157 | val columnEncoders = schemaDatatypes.zip(expectedDatatypes).zipWithIndex.map { 158 | case ((_: StringType, _: NullType), _) => ((_: Row) => null) 159 | case ((_: ShortType, _: BooleanType), i) => 160 | (r: Row) => 161 | getOrNull(r.getShort(i) != 0, r, i) 162 | case ((_: IntegerType, _: BooleanType), i) => 163 | (r: Row) => 164 | getOrNull(r.getInt(i) != 0, r, i) 165 | case ((_: LongType, _: BooleanType), i) => 166 | (r: Row) => 167 | getOrNull(r.getLong(i) != 0, r, i) 168 | 169 | case ((_: ShortType, _: ByteType), i) => 170 | (r: Row) => 171 | getOrNull(r.getShort(i).toByte, r, i) 172 | case ((_: IntegerType, _: ByteType), i) => 173 | (r: Row) => 174 | getOrNull(r.getInt(i).toByte, r, i) 175 | case ((_: LongType, _: ByteType), i) => 176 | (r: Row) => 177 | getOrNull(r.getLong(i).toByte, r, i) 178 | 179 | case ((l, r), i) => 180 | options.assert(l == r, s"SinglestoreRDD: unable to encode ${l} into ${r}") 181 | ((r: Row) => getOrNull(r.get(i), r, i)) 182 | } 183 | 184 | rowsIter = rowsIter 185 | .map(row => Row.fromSeq(columnEncoders.map(_(row)))) 186 | } 187 | } 188 | 189 | CompletionIterator[Row, Iterator[Row]](new InterruptibleIterator[Row](context, rowsIter), close) 190 | } 191 | 192 | } 193 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/SinglestoreReader.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.SQLSyntaxErrorException 4 | 5 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, VariableList} 6 | import org.apache.spark.rdd.RDD 7 | import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression => CatalystExpression} 8 | import org.apache.spark.sql.sources.{BaseRelation, CatalystScan, TableScan} 9 | import org.apache.spark.sql.{Row, SQLContext} 10 | 11 | import scala.util.Random; 12 | 13 | case class SinglestoreReaderNoPushdown(query: String, 14 | options: SinglestoreOptions, 15 | @transient val sqlContext: SQLContext) 16 | extends BaseRelation 17 | with TableScan { 18 | 19 | override lazy val schema = JdbcHelpers.loadSchema(options, query, Nil) 20 | 21 | override def buildScan: RDD[Row] = { 22 | val randHex = Random.nextInt().toHexString 23 | val rdd = 24 | SinglestoreRDD( 25 | query, 26 | Nil, 27 | options, 28 | schema, 29 | Nil, 30 | resultMustBeSorted = false, 31 | schema 32 | .filter(sf => options.parallelReadRepartitionColumns.contains(sf.name)) 33 | .map(sf => SQLGen.Ident(sf.name).sql), 34 | sqlContext.sparkContext, 35 | randHex 36 | ) 37 | // Add random hex to the name 38 | // It is needed to generate unique names for result tables during parallel read 39 | .setName("SingleStoreRDD" + randHex) 40 | if (rdd.parallelReadType.contains(ReadFromAggregators)) { 41 | // Wrap an RDD with barrier stage, to force all readers start reading at the same time. 42 | // Repartition it to force spark to read data and do all other computations in different stages. 43 | // Otherwise we will likely get the following error: 44 | // [SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following 45 | // pattern of RDD chain within a barrier stage... 
46 | rdd.barrier().mapPartitions(v => v).repartition(rdd.getNumPartitions) 47 | } else { 48 | rdd 49 | } 50 | } 51 | } 52 | 53 | case class SinglestoreReader(query: String, 54 | variables: VariableList, 55 | options: SinglestoreOptions, 56 | @transient val sqlContext: SQLContext, 57 | isFinal: Boolean = false, 58 | expectedOutput: Seq[Attribute] = Nil, 59 | var resultMustBeSorted: Boolean = false, 60 | context: SQLGenContext) 61 | extends BaseRelation 62 | with LazyLogging 63 | with TableScan 64 | with CatalystScan { 65 | 66 | override lazy val schema = JdbcHelpers.loadSchema(options, query, variables) 67 | 68 | override def buildScan: RDD[Row] = { 69 | val randHex = Random.nextInt().toHexString 70 | val rdd = 71 | SinglestoreRDD( 72 | query, 73 | variables, 74 | options, 75 | schema, 76 | expectedOutput, 77 | resultMustBeSorted, 78 | expectedOutput 79 | .filter(attr => options.parallelReadRepartitionColumns.contains(attr.name)) 80 | .map(attr => context.ident(attr.name, attr.exprId)), 81 | sqlContext.sparkContext, 82 | randHex 83 | ) 84 | // Add random hex to the name 85 | // It is needed to generate unique names for result tables during parallel read 86 | .setName("SingleStoreRDD" + randHex) 87 | if (rdd.parallelReadType.contains(ReadFromAggregators)) { 88 | // Wrap an RDD with barrier stage, to force all readers start reading at the same time. 89 | // Repartition it to force spark to read data and do all other computations in different stages. 90 | // Otherwise we will likely get the following error: 91 | // [SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following 92 | // pattern of RDD chain within a barrier stage... 93 | rdd.barrier().mapPartitions(v => v).repartition(rdd.getNumPartitions) 94 | } else { 95 | rdd 96 | } 97 | } 98 | 99 | override def buildScan(rawColumns: Seq[Attribute], 100 | rawFilters: Seq[CatalystExpression]): RDD[Row] = { 101 | // we don't have to push down *everything* using this interface since Spark will 102 | // run the projection and filter again upon receiving the results from SingleStore 103 | val projection = 104 | rawColumns 105 | .flatMap(ExpressionGen.apply(ExpressionExtractor(context)).lift(_)) 106 | .reduceOption(_ + "," + _) 107 | val filters = 108 | rawFilters 109 | .flatMap(ExpressionGen.apply(ExpressionExtractor(context)).lift(_)) 110 | .reduceOption(_ + "AND" + _) 111 | 112 | val stmt = (projection, filters) match { 113 | case (Some(p), Some(f)) => 114 | SQLGen 115 | .select(p) 116 | .from(SQLGen.Relation(Nil, this, context.nextAlias(), null)) 117 | .where(f) 118 | .output(rawColumns) 119 | case (Some(p), None) => 120 | SQLGen 121 | .select(p) 122 | .from(SQLGen.Relation(Nil, this, context.nextAlias(), null)) 123 | .output(rawColumns) 124 | case (None, Some(f)) => 125 | SQLGen.selectAll 126 | .from(SQLGen.Relation(Nil, this, context.nextAlias(), null)) 127 | .where(f) 128 | .output(expectedOutput) 129 | case _ => 130 | return buildScan 131 | } 132 | 133 | val newReader = copy(query = stmt.sql, variables = stmt.variables, expectedOutput = stmt.output) 134 | 135 | if (log.isTraceEnabled) { 136 | log.trace(s"CatalystScan additional rewrite:\n${newReader}") 137 | } 138 | 139 | newReader.buildScan 140 | } 141 | 142 | override def toString: String = { 143 | val explain = 144 | try { 145 | JdbcHelpers.explainQuery(options, query, variables) 146 | } catch { 147 | case e: SQLSyntaxErrorException => e.toString 148 | case e: Exception => throw e 149 | } 150 | val v = variables.map(_.variable).mkString(", ") 151 | 152 | s""" 153 | 
|--------------- 154 | |SingleStore Query 155 | |Variables: ($v) 156 | |SQL: 157 | |$query 158 | | 159 | |EXPLAIN: 160 | |$explain 161 | |--------------- 162 | """.stripMargin 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /src/main/scala/com/singlestore/spark/vendor/apache/SchemaConverters.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark.vendor.apache 2 | 3 | import org.apache.avro._ 4 | import org.apache.spark.sql.types._ 5 | 6 | /** 7 | * NOTE: this converter has been taken from `spark-avro` library, 8 | * as this functionality starts with spark 2.4.0 but we support lower versions. 9 | * org.apache.spark.sql.avro.SchemaConverters (2.4.0 version) 10 | * Changes: 11 | * 1. Removed everything except toAvroType function 12 | */ 13 | object SchemaConverters { 14 | 15 | private lazy val nullSchema = Schema.create(Schema.Type.NULL) 16 | 17 | case class SchemaType(dataType: DataType, nullable: Boolean) 18 | 19 | def toAvroType(catalystType: DataType, 20 | nullable: Boolean = false, 21 | recordName: String = "topLevelRecord", 22 | nameSpace: String = ""): Schema = { 23 | val builder = SchemaBuilder.builder() 24 | 25 | val schema = catalystType match { 26 | case BooleanType => builder.booleanType() 27 | case ByteType | ShortType | IntegerType => builder.intType() 28 | case LongType => builder.longType() 29 | case DateType => 30 | LogicalTypes.date().addToSchema(builder.intType()) 31 | case TimestampType => 32 | LogicalTypes.timestampMicros().addToSchema(builder.longType()) 33 | 34 | case FloatType => builder.floatType() 35 | case DoubleType => builder.doubleType() 36 | case StringType => builder.stringType() 37 | case _: DecimalType => builder.stringType() 38 | case BinaryType => builder.bytesType() 39 | case ArrayType(et, containsNull) => 40 | builder 41 | .array() 42 | .items(toAvroType(et, containsNull, recordName, nameSpace)) 43 | case MapType(StringType, vt, valueContainsNull) => 44 | builder 45 | .map() 46 | .values(toAvroType(vt, valueContainsNull, recordName, nameSpace)) 47 | case st: StructType => 48 | val childNameSpace = if (nameSpace != "") s"$nameSpace.$recordName" else recordName 49 | val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields() 50 | st.foreach { f => 51 | val fieldAvroType = 52 | toAvroType(f.dataType, f.nullable, f.name, childNameSpace) 53 | fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault() 54 | } 55 | fieldsAssembler.endRecord() 56 | 57 | // This should never happen. 
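      // Reaching the default case means the Catalyst type has no Avro mapping above (for
      // example a MapType with non-string keys), so the writer fails fast instead of guessing
      // a schema. As a small worked example,
      // toAvroType(StructType(Seq(StructField("id", IntegerType, nullable = true)))) yields a
      // "topLevelRecord" Avro record whose "id" field is a union of int and null.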
58 | case other => throw new IncompatibleSchemaException(s"Unexpected type $other.") 59 | } 60 | if (nullable) { 61 | Schema.createUnion(schema, nullSchema) 62 | } else { 63 | schema 64 | } 65 | } 66 | } 67 | 68 | class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex) 69 | -------------------------------------------------------------------------------- /src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Set everything to be logged to the console 2 | log4j.rootCategory=ERROR, console 3 | log4j.appender.console=org.apache.log4j.ConsoleAppender 4 | log4j.appender.console.target=System.err 5 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 7 | 8 | # Settings to quiet third party logs that are too verbose 9 | log4j.logger.org.eclipse.jetty=WARN 10 | log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR 11 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=ERROR 12 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=ERROR 13 | 14 | # Make our logs LOUD 15 | log4j.logger.com.singlestore.spark=DEBUG 16 | -------------------------------------------------------------------------------- /src/test/resources/log4j2.properties: -------------------------------------------------------------------------------- 1 | # Create STDOUT appender that writes data to the console 2 | appenders = console 3 | appender.console.type = Console 4 | appender.console.name = STDOUT 5 | appender.console.layout.type = PatternLayout 6 | appender.console.layout.pattern = %d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 7 | 8 | # Set everything to be logged to the console 9 | rootLogger.level = ERROR 10 | rootLogger.appenderRef.stdout.ref = STDOUT 11 | 12 | # Make our logs LOUD 13 | loggers = singlestore 14 | logger.singlestore.name = com.singlestore.spark 15 | logger.singlestore.level = TRACE 16 | logger.singlestore.appenderRef.stdout.ref = STDOUT 17 | logger.singlestore.additivity = false 18 | -------------------------------------------------------------------------------- /src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker: -------------------------------------------------------------------------------- 1 | mock-maker-inline -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/BatchInsertBenchmark.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.{Connection, Date, DriverManager} 4 | import java.time.LocalDate 5 | import java.util.Properties 6 | 7 | import org.apache.spark.sql.types._ 8 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 9 | import org.apache.spark.sql.{SaveMode, SparkSession} 10 | 11 | import scala.util.Random 12 | 13 | // BatchInsertBenchmark is written to test batch insert with CPU profiler 14 | // this feature is accessible in Ultimate version of IntelliJ IDEA 15 | // see https://www.jetbrains.com/help/idea/async-profiler.html#profile for more details 16 | object BatchInsertBenchmark extends App { 17 | final val masterHost: String = sys.props.getOrElse("singlestore.host", "localhost") 18 | final val masterPort: String = sys.props.getOrElse("singlestore.port", "5506") 19 | 20 | val spark: SparkSession = SparkSession 21 | .builder() 22 | .master("local") 23 | 
.config("spark.sql.shuffle.partitions", "1") 24 | .config("spark.driver.bindAddress", "localhost") 25 | .config("spark.datasource.singlestore.ddlEndpoint", s"${masterHost}:${masterPort}") 26 | .config("spark.datasource.singlestore.database", "testdb") 27 | .getOrCreate() 28 | 29 | def jdbcConnection: Loan[Connection] = { 30 | val connProperties = new Properties() 31 | connProperties.put("user", "root") 32 | 33 | Loan( 34 | DriverManager.getConnection( 35 | s"jdbc:singlestore://$masterHost:$masterPort", 36 | connProperties 37 | )) 38 | } 39 | 40 | def executeQuery(sql: String): Unit = { 41 | jdbcConnection.to(conn => Loan(conn.createStatement).to(_.execute(sql))) 42 | } 43 | 44 | executeQuery("set global default_partitions_per_leaf = 2") 45 | executeQuery("drop database if exists testdb") 46 | executeQuery("create database testdb") 47 | 48 | def genDate() = 49 | Date.valueOf(LocalDate.ofEpochDay(LocalDate.of(2001, 4, 11).toEpochDay + Random.nextInt(10000))) 50 | def genRow(): (Long, Int, Double, String, Date) = 51 | (Random.nextLong(), Random.nextInt(), Random.nextDouble(), Random.nextString(20), genDate()) 52 | val df = 53 | spark.createDF( 54 | List.fill(1000000)(genRow()), 55 | List(("LongType", LongType, true), 56 | ("IntType", IntegerType, true), 57 | ("DoubleType", DoubleType, true), 58 | ("StringType", StringType, true), 59 | ("DateType", DateType, true)) 60 | ) 61 | 62 | val start = System.nanoTime() 63 | df.write 64 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 65 | .option("tableKey.primary", "IntType") 66 | .option("onDuplicateKeySQL", "IntType = IntType") 67 | .mode(SaveMode.Append) 68 | .save("testdb.batchinsert") 69 | 70 | val diff = System.nanoTime() - start 71 | println("Elapsed time: " + diff + "ns") 72 | } 73 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/BatchInsertTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 4 | import org.apache.spark.sql.types.{IntegerType, StringType} 5 | import org.apache.spark.sql.{DataFrame, SaveMode} 6 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} 7 | 8 | class BatchInsertTest extends IntegrationSuiteBase with BeforeAndAfterEach with BeforeAndAfterAll { 9 | var df: DataFrame = _ 10 | 11 | override def beforeEach(): Unit = { 12 | super.beforeEach() 13 | 14 | df = spark.createDF( 15 | List( 16 | (1, "Jack", 20), 17 | (2, "Dan", 30), 18 | (3, "Bob", 15), 19 | (4, "Alice", 40) 20 | ), 21 | List(("id", IntegerType, false), ("name", StringType, true), ("age", IntegerType, true)) 22 | ) 23 | 24 | df.write 25 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 26 | .mode(SaveMode.Overwrite) 27 | .option("tableKey.primary", "id") 28 | .save("testdb.batchinsert") 29 | } 30 | 31 | it("insert into a new table") { 32 | df = spark.createDF( 33 | List((5, "Eric", 5)), 34 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 35 | ) 36 | df.write 37 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 38 | .option("tableKey.primary", "id") 39 | .option("onDuplicateKeySQL", "age = age + 1") 40 | .option("insertBatchSize", 10) 41 | .mode(SaveMode.Append) 42 | .save("testdb.batchinsertnew") 43 | 44 | val actualDF = 45 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.batchinsertnew") 46 | assertSmallDataFrameEquality(actualDF, df) 47 | } 48 | 49 | def 
insertAndCheckContent(batchSize: Int, 50 | dfToInsert: List[Any], 51 | expectedContent: List[Any]): Unit = { 52 | df = spark.createDF( 53 | dfToInsert, 54 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 55 | ) 56 | insertValues("testdb.batchinsert", df, "age = age + 1", batchSize) 57 | 58 | val actualDF = 59 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.batchinsert") 60 | assertSmallDataFrameEquality( 61 | actualDF, 62 | spark 63 | .createDF( 64 | expectedContent, 65 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 66 | ), 67 | orderedComparison = false 68 | ) 69 | } 70 | 71 | it("insert a new row") { 72 | insertAndCheckContent( 73 | 10, 74 | List((5, "Eric", 5)), 75 | List( 76 | (1, "Jack", 20), 77 | (2, "Dan", 30), 78 | (3, "Bob", 15), 79 | (4, "Alice", 40), 80 | (5, "Eric", 5) 81 | ) 82 | ) 83 | } 84 | 85 | it("insert several new rows with small batchSize") { 86 | insertAndCheckContent( 87 | 2, 88 | List( 89 | (5, "Jack", 20), 90 | (6, "Mark", 30), 91 | (7, "Fred", 15), 92 | (8, "Jany", 40), 93 | (9, "Monica", 5) 94 | ), 95 | List( 96 | (1, "Jack", 20), 97 | (2, "Dan", 30), 98 | (3, "Bob", 15), 99 | (4, "Alice", 40), 100 | (5, "Jack", 20), 101 | (6, "Mark", 30), 102 | (7, "Fred", 15), 103 | (8, "Jany", 40), 104 | (9, "Monica", 5) 105 | ) 106 | ) 107 | } 108 | 109 | it("insert exactly batchSize rows") { 110 | insertAndCheckContent( 111 | 2, 112 | List( 113 | (5, "Jack", 20), 114 | (6, "Mark", 30) 115 | ), 116 | List( 117 | (1, "Jack", 20), 118 | (2, "Dan", 30), 119 | (3, "Bob", 15), 120 | (4, "Alice", 40), 121 | (5, "Jack", 20), 122 | (6, "Mark", 30) 123 | ) 124 | ) 125 | } 126 | 127 | it("negative batchsize") { 128 | insertAndCheckContent( 129 | -2, 130 | List( 131 | (5, "Jack", 20), 132 | (6, "Mark", 30) 133 | ), 134 | List( 135 | (1, "Jack", 20), 136 | (2, "Dan", 30), 137 | (3, "Bob", 15), 138 | (4, "Alice", 40), 139 | (5, "Jack", 20), 140 | (6, "Mark", 30) 141 | ) 142 | ) 143 | } 144 | 145 | it("empty insert") { 146 | insertAndCheckContent( 147 | 2, 148 | List(), 149 | List( 150 | (1, "Jack", 20), 151 | (2, "Dan", 30), 152 | (3, "Bob", 15), 153 | (4, "Alice", 40) 154 | ) 155 | ) 156 | } 157 | 158 | it("insert one existing row") { 159 | insertAndCheckContent( 160 | 2, 161 | List((1, "Jack", 20)), 162 | List( 163 | (1, "Jack", 21), 164 | (2, "Dan", 30), 165 | (3, "Bob", 15), 166 | (4, "Alice", 40) 167 | ) 168 | ) 169 | } 170 | 171 | it("insert several existing rows") { 172 | insertAndCheckContent( 173 | 2, 174 | List( 175 | (1, "Jack", 20), 176 | (2, "Dan", 30), 177 | (3, "Bob", 15), 178 | (4, "Alice", 40) 179 | ), 180 | List( 181 | (1, "Jack", 21), 182 | (2, "Dan", 31), 183 | (3, "Bob", 16), 184 | (4, "Alice", 41) 185 | ) 186 | ) 187 | } 188 | 189 | it("insert existing and non existing row") { 190 | insertAndCheckContent( 191 | 2, 192 | List( 193 | (1, "Jack", 20), 194 | (5, "Mark", 30) 195 | ), 196 | List( 197 | (1, "Jack", 21), 198 | (2, "Dan", 30), 199 | (3, "Bob", 15), 200 | (4, "Alice", 40), 201 | (5, "Mark", 30) 202 | ) 203 | ) 204 | } 205 | 206 | it("insert NULL") { 207 | insertAndCheckContent( 208 | 2, 209 | List( 210 | (5, null, null) 211 | ), 212 | List( 213 | (1, "Jack", 20), 214 | (2, "Dan", 30), 215 | (3, "Bob", 15), 216 | (4, "Alice", 40), 217 | (5, null, null) 218 | ) 219 | ) 220 | } 221 | 222 | it("non-existing column") { 223 | executeQueryWithLog("DROP TABLE IF EXISTS batchinsert") 224 | executeQueryWithLog("CREATE TABLE batchinsert(id INT, name TEXT)") 
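    // The table is re-created above with only (id, name), while the DataFrame below also
    // carries an `age` column, so the insertValues call is expected to fail with
    // "Unknown column 'age' in 'field list'".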
225 | 226 | df = spark.createDF( 227 | List((5, "EBCEFGRHFED" * 100, 50)), 228 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 229 | ) 230 | 231 | try { 232 | insertValues("testdb.batchinsert", df, "age = age + 1", 10) 233 | fail() 234 | } catch { 235 | case e: Exception if e.getMessage.contains("Unknown column 'age' in 'field list'") => 236 | } 237 | } 238 | } 239 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/BenchmarkSerializingTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 4 | import org.apache.spark.sql.{DataFrame, SaveMode} 5 | import org.apache.spark.sql.types.{IntegerType, LongType, StringType} 6 | 7 | class BenchmarkSerializingTest extends IntegrationSuiteBase { 8 | 9 | val dbName = "testdb" 10 | val tableName = "avro_table" 11 | 12 | val writeIterations = 10 13 | 14 | val smallDataCount = 1 15 | val mediumDataCount = 1000 16 | val largeDataCount = 100000 17 | 18 | val smallSchema = List(("id", StringType, false)) 19 | val mediumSchema = List(("id", StringType, false), 20 | ("name", StringType, false), 21 | ("surname", StringType, false), 22 | ("age", IntegerType, false)) 23 | val largeSchema = List( 24 | ("id", StringType, false), 25 | ("name", StringType, false), 26 | ("surname", StringType, false), 27 | ("someString", StringType, false), 28 | ("anotherString", StringType, false), 29 | ("age", IntegerType, false), 30 | ("secondNumber", IntegerType, false), 31 | ("thirdNumber", LongType, false) 32 | ) 33 | 34 | def generateSmallData(index: Int) = s"$index" 35 | def generateMediumData(index: Int) = (s"$index", s"name$index", s"surname$index", index) 36 | def generateLargeData(index: Int) = 37 | (s"$index", 38 | s"name$index", 39 | s"surname$index", 40 | s"someString$index", 41 | s"anotherString$index", 42 | index, 43 | index + 1, 44 | index * 2L) 45 | 46 | val generateDataMap: Map[List[scala.Product], Int => Any] = Map( 47 | smallSchema -> generateSmallData, 48 | mediumSchema -> generateMediumData, 49 | largeSchema -> generateLargeData 50 | ) 51 | 52 | override def beforeEach(): Unit = { 53 | super.beforeEach() 54 | executeQueryWithLog(s"drop table if exists $dbName.$tableName") 55 | } 56 | 57 | def doWriteOperation(dataFrame: DataFrame, options: Map[String, String]): Long = { 58 | val startTime = System.currentTimeMillis() 59 | for (_ <- 1 to writeIterations) { 60 | dataFrame.write 61 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 62 | .mode(SaveMode.Append) 63 | .options(options) 64 | .save(s"$dbName.$tableName") 65 | } 66 | val endTime = System.currentTimeMillis() 67 | endTime - startTime 68 | } 69 | 70 | def doTestOperation(dataCount: Int, 71 | schema: List[scala.Product], 72 | options: Map[String, String]): Unit = { 73 | val dataFrameValues = List.tabulate(dataCount)(generateDataMap(schema)) 74 | val dataFrame = 75 | spark.createDF(dataFrameValues, schema) 76 | val timeSpend = doWriteOperation(dataFrame, options) 77 | print(s"Time spend for $writeIterations iterations: $timeSpend") 78 | } 79 | 80 | describe("Avro testing") { 81 | 82 | it("small data | small schema") { 83 | doTestOperation(smallDataCount, 84 | smallSchema, 85 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro")) 86 | } 87 | 88 | it("medium data | small schema") { 89 | doTestOperation(mediumDataCount, 90 | smallSchema, 91 | 
Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro")) 92 | } 93 | 94 | it("large data | small schema") { 95 | doTestOperation(largeDataCount, 96 | smallSchema, 97 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro")) 98 | } 99 | 100 | it("small data | medium schema") { 101 | doTestOperation(smallDataCount, 102 | mediumSchema, 103 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro")) 104 | } 105 | 106 | it("medium data | medium schema") { 107 | doTestOperation(mediumDataCount, 108 | mediumSchema, 109 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro")) 110 | } 111 | 112 | it("large data | medium schema") { 113 | doTestOperation(largeDataCount, 114 | mediumSchema, 115 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro")) 116 | } 117 | 118 | it("small data | large schema") { 119 | doTestOperation(smallDataCount, 120 | largeSchema, 121 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro")) 122 | } 123 | 124 | it("medium data | large schema") { 125 | doTestOperation(mediumDataCount, 126 | largeSchema, 127 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro")) 128 | } 129 | 130 | it("large data | large schema") { 131 | doTestOperation(largeDataCount, 132 | largeSchema, 133 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro")) 134 | } 135 | } 136 | 137 | describe("CSV testing") { 138 | it("small data | small schema") { 139 | doTestOperation(smallDataCount, smallSchema, Map.empty) 140 | } 141 | 142 | it("medium data | small schema") { 143 | doTestOperation(mediumDataCount, smallSchema, Map.empty) 144 | } 145 | 146 | it("large data | small schema") { 147 | doTestOperation(largeDataCount, smallSchema, Map.empty) 148 | } 149 | 150 | it("small data | medium schema") { 151 | doTestOperation(smallDataCount, mediumSchema, Map.empty) 152 | } 153 | 154 | it("medium data | medium schema") { 155 | doTestOperation(mediumDataCount, mediumSchema, Map.empty) 156 | } 157 | 158 | it("large data | medium schema") { 159 | doTestOperation(largeDataCount, mediumSchema, Map.empty) 160 | } 161 | 162 | it("small data | large schema") { 163 | doTestOperation(smallDataCount, largeSchema, Map.empty) 164 | } 165 | 166 | it("medium data | large schema") { 167 | doTestOperation(mediumDataCount, largeSchema, Map.empty) 168 | } 169 | 170 | it("large data | large schema") { 171 | doTestOperation(largeDataCount, largeSchema, Map.empty) 172 | } 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/BinaryTypeBenchmark.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.{Connection, DriverManager} 4 | import java.util.Properties 5 | 6 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 7 | import com.singlestore.spark.BatchInsertBenchmark.{df, executeQuery} 8 | import org.apache.spark.sql.types.{BinaryType, IntegerType} 9 | import org.apache.spark.sql.{SaveMode, SparkSession} 10 | 11 | import scala.util.Random 12 | 13 | // BinaryTypeBenchmark is written to writing of the BinaryType with CPU profiler 14 | // this feature is accessible in Ultimate version of IntelliJ IDEA 15 | // see https://www.jetbrains.com/help/idea/async-profiler.html#profile for more details 16 | object BinaryTypeBenchmark extends App { 17 | final val masterHost: String = sys.props.getOrElse("singlestore.host", "localhost") 18 | final val masterPort: String = sys.props.getOrElse("singlestore.port", "5506") 19 | 20 | val spark: SparkSession = SparkSession 21 | .builder() 22 | 
.master("local") 23 | .config("spark.sql.shuffle.partitions", "1") 24 | .config("spark.driver.bindAddress", "localhost") 25 | .config("spark.datasource.singlestore.ddlEndpoint", s"${masterHost}:${masterPort}") 26 | .config("spark.datasource.singlestore.database", "testdb") 27 | .getOrCreate() 28 | 29 | def jdbcConnection: Loan[Connection] = { 30 | val connProperties = new Properties() 31 | connProperties.put("user", "root") 32 | 33 | Loan( 34 | DriverManager.getConnection( 35 | s"jdbc:singlestore://$masterHost:$masterPort", 36 | connProperties 37 | )) 38 | } 39 | 40 | def executeQuery(sql: String): Unit = { 41 | jdbcConnection.to(conn => Loan(conn.createStatement).to(_.execute(sql))) 42 | } 43 | 44 | executeQuery("set global default_partitions_per_leaf = 2") 45 | executeQuery("drop database if exists testdb") 46 | executeQuery("create database testdb") 47 | 48 | def genRandomByte(): Byte = (Random.nextInt(256) - 128).toByte 49 | def genRandomRow(): Array[Byte] = 50 | Array.fill(1000)(genRandomByte()) 51 | 52 | val df = spark.createDF( 53 | List.fill(100000)(genRandomRow()).zipWithIndex, 54 | List(("data", BinaryType, true), ("id", IntegerType, true)) 55 | ) 56 | 57 | val start1 = System.nanoTime() 58 | df.write 59 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 60 | .mode(SaveMode.Overwrite) 61 | .save("testdb.LoadData") 62 | 63 | println("Elapsed time: " + (System.nanoTime() - start1) + "ns [LoadData CSV]") 64 | 65 | val start2 = System.nanoTime() 66 | df.write 67 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 68 | .option("tableKey.primary", "id") 69 | .option("onDuplicateKeySQL", "data = data") 70 | .mode(SaveMode.Overwrite) 71 | .save("testdb.BatchInsert") 72 | 73 | println("Elapsed time: " + (System.nanoTime() - start2) + "ns [BatchInsert]") 74 | 75 | val avroStart = System.nanoTime() 76 | df.write 77 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 78 | .mode(SaveMode.Overwrite) 79 | .option(SinglestoreOptions.LOAD_DATA_FORMAT, "Avro") 80 | .save("testdb.AvroSerialization") 81 | println("Elapsed time: " + (System.nanoTime() - avroStart) + "ns [LoadData Avro] ") 82 | } 83 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/ExternalHostTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.PreparedStatement 4 | import java.util.Properties 5 | 6 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 7 | import com.singlestore.spark.JdbcHelpers.getDDLConnProperties 8 | import com.singlestore.spark.SQLGen.VariableList 9 | import org.apache.spark.sql.DataFrame 10 | import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} 11 | import org.apache.spark.sql.types.{IntegerType, StringType} 12 | import org.mockito.ArgumentMatchers.any 13 | import org.mockito.MockitoSugar 14 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} 15 | 16 | class ExternalHostTest 17 | extends IntegrationSuiteBase 18 | with BeforeAndAfterEach 19 | with BeforeAndAfterAll 20 | with MockitoSugar { 21 | 22 | val testDb = "testdb" 23 | val testCollection = "externalHost" 24 | val mvNodesCollection = "mv_nodes" 25 | 26 | var df: DataFrame = _ 27 | 28 | override def beforeEach(): Unit = { 29 | super.beforeEach() 30 | spark.sqlContext.setConf("spark.datasource.singlestore.enableParallelRead", "forced") 31 | spark.sqlContext.setConf("spark.datasource.singlestore.parallelRead.Features", "ReadFromLeaves") 
32 | df = spark.createDF( 33 | List((2, "B")), 34 | List(("id", IntegerType, true), ("name", StringType, true)) 35 | ) 36 | writeTable(s"$testDb.$testCollection", df) 37 | } 38 | 39 | def setupMockJdbcHelper(): Unit = { 40 | when(JdbcHelpers.loadSchema(any[SinglestoreOptions], any[String], any[SQLGen.VariableList])) 41 | .thenCallRealMethod() 42 | when(JdbcHelpers.getDDLConnProperties(any[SinglestoreOptions], any[Boolean])) 43 | .thenCallRealMethod() 44 | when(JdbcHelpers.getDMLConnProperties(any[SinglestoreOptions], any[Boolean])) 45 | .thenCallRealMethod() 46 | when(JdbcHelpers.getConnProperties(any[SinglestoreOptions], any[Boolean], any[String])) 47 | .thenCallRealMethod() 48 | when(JdbcHelpers.explainJSONQuery(any[SinglestoreOptions], any[String], any[VariableList])) 49 | .thenCallRealMethod() 50 | when(JdbcHelpers.partitionHostPorts(any[SinglestoreOptions], any[String])) 51 | .thenCallRealMethod() 52 | when(JdbcHelpers.fillStatement(any[PreparedStatement], any[VariableList])) 53 | .thenCallRealMethod() 54 | } 55 | 56 | describe("success tests") { 57 | it("low SingleStore version") { 58 | 59 | withObjectMocked[JdbcHelpers.type] { 60 | 61 | setupMockJdbcHelper() 62 | when(JdbcHelpers.getSinglestoreVersion(any[SinglestoreOptions])).thenReturn("6.8.10") 63 | 64 | val actualDF = 65 | spark.read 66 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 67 | .option("useExternalHost", "true") 68 | .load(s"$testDb.$testCollection") 69 | 70 | assertSmallDataFrameEquality( 71 | actualDF, 72 | df 73 | ) 74 | } 75 | } 76 | 77 | it("valid external host") { 78 | 79 | withObjectMocked[JdbcHelpers.type] { 80 | 81 | setupMockJdbcHelper() 82 | when(JdbcHelpers.getSinglestoreVersion(any[SinglestoreOptions])).thenReturn("7.1.0") 83 | 84 | val externalHostMap = Map( 85 | "172.17.0.2:3307" -> "172.17.0.2:3307" 86 | ) 87 | when(JdbcHelpers.externalHostPorts(any[SinglestoreOptions])) 88 | .thenReturn(externalHostMap) 89 | 90 | val actualDF = 91 | spark.read 92 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 93 | .option("useExternalHost", "true") 94 | .load(s"$testDb.$testCollection") 95 | 96 | assertSmallDataFrameEquality( 97 | actualDF, 98 | df 99 | ) 100 | } 101 | } 102 | 103 | it("empty external host map") { 104 | 105 | withObjectMocked[JdbcHelpers.type] { 106 | 107 | setupMockJdbcHelper() 108 | when(JdbcHelpers.getSinglestoreVersion(any[SinglestoreOptions])).thenReturn("7.1.0") 109 | when(JdbcHelpers.externalHostPorts(any[SinglestoreOptions])) 110 | .thenReturn(Map.empty[String, String]) 111 | 112 | val actualDf = spark.read 113 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 114 | .option("useExternalHost", "true") 115 | .load(s"$testDb.$testCollection") 116 | 117 | assertSmallDataFrameEquality(df, actualDf) 118 | } 119 | } 120 | 121 | it("wrong external host map") { 122 | 123 | withObjectMocked[JdbcHelpers.type] { 124 | 125 | setupMockJdbcHelper() 126 | when(JdbcHelpers.getSinglestoreVersion(any[SinglestoreOptions])).thenReturn("7.1.0") 127 | 128 | val externalHostMap = Map( 129 | "172.17.0.3:3307" -> "172.17.0.100:3307", 130 | "172.17.0.4:3307" -> "172.17.0.200:3307" 131 | ) 132 | 133 | when(JdbcHelpers.externalHostPorts(any[SinglestoreOptions])) 134 | .thenReturn(externalHostMap) 135 | 136 | val actualDf = spark.read 137 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 138 | .option("useExternalHost", "true") 139 | .load(s"$testDb.$testCollection") 140 | 141 | assertSmallDataFrameEquality(df, actualDf) 142 | } 143 | } 144 | 145 | it("valid external host function") { 146 | 147 | 
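      // This test fakes INFORMATION_SCHEMA.MV_NODES with a local table and a spied
      // connection; of the four rows inserted below, only the one with both EXTERNAL_HOST
      // and EXTERNAL_PORT populated is expected to survive into the map returned by
      // JdbcHelpers.externalHostPorts.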
val mvNodesDf = spark.createDF( 148 | List(("172.17.0.2", 3307, "172.17.0.10", 3310), 149 | ("172.17.0.20", 3312, "172.17.0.100", null), 150 | ("172.17.0.2", 3308, null, 3310), 151 | ("172.17.0.15", 3311, null, null)), 152 | List(("IP_ADDR", StringType, true), 153 | ("PORT", IntegerType, true), 154 | ("EXTERNAL_HOST", StringType, true), 155 | ("EXTERNAL_PORT", IntegerType, true)) 156 | ) 157 | writeTable(s"$testDb.$mvNodesCollection", mvNodesDf) 158 | 159 | val conf = new SinglestoreOptions( 160 | s"$masterHost:$masterPort", 161 | List.empty[String], 162 | "root", 163 | masterPassword, 164 | None, 165 | Map.empty[String, String], 166 | false, 167 | false, 168 | Automatic, 169 | List(ReadFromLeaves), 170 | 0, 171 | 0, 172 | 0, 173 | 0, 174 | true, 175 | Set.empty, 176 | Truncate, 177 | SinglestoreOptions.CompressionType.GZip, 178 | SinglestoreOptions.LoadDataFormat.CSV, 179 | List.empty[SinglestoreOptions.TableKey], 180 | None, 181 | 10, 182 | 10, 183 | false, 184 | SinglestoreConnectionPoolOptions(enabled = true, -1, 8, 30000, 1000, -1, -1), 185 | SinglestoreConnectionPoolOptions(enabled = true, -1, 8, 2000, 1000, -1, -1), 186 | "3.4.0" 187 | ) 188 | 189 | val conn = 190 | SinglestoreConnectionPool.getConnection(getDDLConnProperties(conf, isOnExecutor = false)) 191 | val statement = conn.prepareStatement(s""" 192 | SELECT IP_ADDR, 193 | PORT, 194 | EXTERNAL_HOST, 195 | EXTERNAL_PORT 196 | FROM testdb.mv_nodes; 197 | """) 198 | val spyConn = spy(conn) 199 | val spyStatement = spy(statement) 200 | when(spyConn.prepareStatement(s""" 201 | SELECT IP_ADDR, 202 | PORT, 203 | EXTERNAL_HOST, 204 | EXTERNAL_PORT 205 | FROM INFORMATION_SCHEMA.MV_NODES 206 | WHERE TYPE = "LEAF"; 207 | """)).thenReturn(spyStatement) 208 | 209 | withObjectMocked[SinglestoreConnectionPool.type] { 210 | 211 | when(SinglestoreConnectionPool.getConnection(any[Properties])).thenReturn(spyConn) 212 | val externalHostPorts = JdbcHelpers.externalHostPorts(conf) 213 | val expectedResult = Map( 214 | "172.17.0.2:3307" -> "172.17.0.10:3310" 215 | ) 216 | assert(externalHostPorts.equals(expectedResult)) 217 | } 218 | } 219 | } 220 | 221 | describe("failed tests") { 222 | 223 | it("wrong external host") { 224 | 225 | withObjectMocked[JdbcHelpers.type] { 226 | 227 | setupMockJdbcHelper() 228 | when(JdbcHelpers.getSinglestoreVersion(any[SinglestoreOptions])).thenReturn("7.1.0") 229 | 230 | val externalHostMap = Map( 231 | "172.17.0.2:3307" -> "somehost:3307" 232 | ) 233 | when(JdbcHelpers.externalHostPorts(any[SinglestoreOptions])) 234 | .thenReturn(externalHostMap) 235 | 236 | try { 237 | spark.read 238 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 239 | .option("useExternalHost", "true") 240 | .option("enableParallelRead", "forced") 241 | .option("parallelRead.Features", "ReadFromLeaves") 242 | .load(s"$testDb.$testCollection") 243 | .collect() 244 | fail("Exception expected") 245 | } catch { 246 | case ex: Throwable => 247 | ex match { 248 | case sqlEx: ParallelReadFailedException => 249 | assert( 250 | sqlEx.getMessage startsWith "Failed to read data in parallel.\nTried following parallel read features:") 251 | case _ => fail("ParallelReadFailedException expected") 252 | } 253 | } 254 | } 255 | } 256 | } 257 | } 258 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/IntegrationSuiteBase.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.{Connection, 
DriverManager} 4 | import java.util.{Properties, TimeZone} 5 | 6 | import com.github.mrpowers.spark.fast.tests.DataFrameComparer 7 | import org.apache.log4j.{Level, LogManager} 8 | import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} 9 | import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} 10 | import org.scalatest._ 11 | import org.scalatest.funspec.AnyFunSpec 12 | import com.singlestore.spark.JdbcHelpers.executeQuery 13 | import com.singlestore.spark.SQLGen.SinglestoreVersion 14 | import com.singlestore.spark.SQLHelper._ 15 | 16 | import scala.util.Random 17 | 18 | trait IntegrationSuiteBase 19 | extends AnyFunSpec 20 | with BeforeAndAfterEach 21 | with BeforeAndAfterAll 22 | with DataFrameComparer 23 | with LazyLogging { 24 | object ExcludeFromSpark35 extends Tag("ExcludeFromSpark35") 25 | object ExcludeFromSpark34 extends Tag("ExcludeFromSpark34") 26 | object ExcludeFromSpark33 extends Tag("ExcludeFromSpark33") 27 | object ExcludeFromSpark32 extends Tag("ExcludeFromSpark32") 28 | object ExcludeFromSpark31 extends Tag("ExcludeFromSpark31") 29 | 30 | final val masterHost: String = sys.props.getOrElse("singlestore.host", "localhost") 31 | final val masterPort: String = sys.props.getOrElse("singlestore.port", "5506") 32 | 33 | final val continuousIntegration: Boolean = sys.env 34 | .getOrElse("CONTINUOUS_INTEGRATION", "false") == "true" 35 | 36 | final val masterPassword: String = sys.env.getOrElse("SINGLESTORE_PASSWORD", "1") 37 | final val masterJWTPassword: String = sys.env.getOrElse("SINGLESTORE_JWT_PASSWORD", "") 38 | final val forceReadFromLeaves: Boolean = 39 | sys.env.getOrElse("FORCE_READ_FROM_LEAVES", "FALSE").equalsIgnoreCase("TRUE") 40 | 41 | var spark: SparkSession = _ 42 | 43 | val jdbcDefaultProps = new Properties() 44 | jdbcDefaultProps.setProperty(JDBCOptions.JDBC_TABLE_NAME, "XXX") 45 | jdbcDefaultProps.setProperty(JDBCOptions.JDBC_DRIVER_CLASS, "org.mariadb.jdbc.Driver") 46 | jdbcDefaultProps.setProperty("user", "root") 47 | jdbcDefaultProps.setProperty("password", masterPassword) 48 | 49 | val version: SinglestoreVersion = { 50 | val conn = 51 | DriverManager.getConnection(s"jdbc:mysql://$masterHost:$masterPort", jdbcDefaultProps) 52 | val resultSet = executeQuery(conn, "select @@memsql_version") 53 | SinglestoreVersion(resultSet.next().getString(0)) 54 | } 55 | 56 | val canDoParallelReadFromAggregators: Boolean = version.atLeast("7.5.0") && !forceReadFromLeaves 57 | 58 | override def beforeAll(): Unit = { 59 | // override global JVM timezone to GMT 60 | TimeZone.setDefault(TimeZone.getTimeZone("GMT")) 61 | 62 | val conn = 63 | DriverManager.getConnection(s"jdbc:mysql://$masterHost:$masterPort", jdbcDefaultProps) 64 | try { 65 | // make singlestore use less memory 66 | executeQuery(conn, "set global default_partitions_per_leaf = 2") 67 | executeQuery(conn, "set global data_conversion_compatibility_level = '6.0'") 68 | 69 | executeQuery(conn, "drop database if exists testdb") 70 | executeQuery(conn, "create database testdb") 71 | } finally { 72 | conn.close() 73 | } 74 | } 75 | 76 | override def withFixture(test: NoArgTest): Outcome = { 77 | def retryThrowable(t: Throwable): Boolean = t match { 78 | case _: java.sql.SQLNonTransientConnectionException => true 79 | case _ => false 80 | } 81 | 82 | @scala.annotation.tailrec 83 | def runWithRetry(attempts: Int, lastError: Option[Throwable]): Outcome = { 84 | if (attempts == 0) { 85 | return Canceled( 86 | s"too many SQLNonTransientConnectionExceptions occurred, last error 
was:\n${lastError.get}") 87 | } 88 | 89 | super.withFixture(test) match { 90 | case Failed(t: Throwable) if retryThrowable(t) || retryThrowable(t.getCause) => { 91 | Thread.sleep(3000) 92 | runWithRetry(attempts - 1, Some(t)) 93 | } 94 | case other => other 95 | } 96 | } 97 | 98 | runWithRetry(attempts = 5, None) 99 | } 100 | 101 | override def beforeEach(): Unit = { 102 | super.beforeEach() 103 | 104 | val seed = Random.nextLong() 105 | log.debug("Random seed: " + seed) 106 | Random.setSeed(seed) 107 | 108 | if (!continuousIntegration) { 109 | LogManager.getLogger("com.singlestore.spark").setLevel(Level.TRACE) 110 | } 111 | 112 | spark = SparkSession 113 | .builder() 114 | .master(if (canDoParallelReadFromAggregators) "local[2]" else "local") 115 | .appName("singlestore-integration-tests") 116 | .config("spark.sql.shuffle.partitions", "1") 117 | .config("spark.driver.bindAddress", "localhost") 118 | .config("spark.driver.extraJavaOptions", "-Duser.timezone=GMT") 119 | .config("spark.executor.extraJavaOptions", "-Duser.timezone=GMT") 120 | .config("spark.sql.session.timeZone", "GMT") 121 | .config("spark.datasource.singlestore.ddlEndpoint", s"${masterHost}:${masterPort}") 122 | .config("spark.datasource.singlestore.user", "root-ssl") 123 | .config("spark.datasource.singlestore.password", "") 124 | .config("spark.datasource.singlestore.enableAsserts", "true") 125 | .config("spark.datasource.singlestore.enableParallelRead", "automaticLite") 126 | .config("spark.datasource.singlestore.parallelRead.Features", 127 | if (forceReadFromLeaves) "ReadFromLeaves" else "ReadFromAggregators,ReadFromLeaves") 128 | .config("spark.datasource.singlestore.database", "testdb") 129 | .config("spark.datasource.singlestore.useSSL", "true") 130 | .config("spark.datasource.singlestore.serverSslCert", 131 | s"${System.getProperty("user.dir")}/scripts/ssl/test-ca-cert.pem") 132 | .config("spark.datasource.singlestore.disableSslHostnameVerification", "true") 133 | .config("spark.sql.crossJoin.enabled", "true") 134 | .getOrCreate() 135 | } 136 | 137 | override def afterEach(): Unit = { 138 | super.afterEach() 139 | spark.close() 140 | } 141 | 142 | def executeQueryWithLog(sql: String): Unit = { 143 | log.trace(s"executing query: ${sql}") 144 | spark.executeSinglestoreQuery(sql) 145 | } 146 | 147 | def jdbcOptions(dbtable: String): Map[String, String] = Map( 148 | "url" -> s"jdbc:mysql://$masterHost:$masterPort", 149 | "dbtable" -> dbtable, 150 | "user" -> "root", 151 | "password" -> masterPassword, 152 | "pushDownPredicate" -> "false" 153 | ) 154 | 155 | def jdbcOptionsSQL(dbtable: String): String = 156 | jdbcOptions(dbtable) 157 | .foldLeft(List.empty[String])({ 158 | case (out, (k, v)) => s"'${k}'='${v}'" :: out 159 | }) 160 | .mkString(", ") 161 | 162 | def writeTable(dbtable: String, df: DataFrame, saveMode: SaveMode = SaveMode.Overwrite): Unit = 163 | df.write 164 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 165 | .mode(saveMode) 166 | .save(dbtable) 167 | 168 | def insertValues(dbtable: String, 169 | df: DataFrame, 170 | onDuplicateKeySQL: String, 171 | insertBatchSize: Long): Unit = 172 | df.write 173 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 174 | .option("onDuplicateKeySQL", onDuplicateKeySQL) 175 | .option("insertBatchSize", insertBatchSize) 176 | .mode(SaveMode.Append) 177 | .save(dbtable) 178 | } 179 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/IssuesTest.scala: 
-------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 4 | import org.apache.spark.sql.SaveMode 5 | import org.apache.spark.sql.functions._ 6 | import org.apache.spark.sql.types._ 7 | 8 | class IssuesTest extends IntegrationSuiteBase { 9 | it("https://github.com/memsql/singlestore-spark-connector/issues/41") { 10 | executeQueryWithLog(""" 11 | | create table if not exists testdb.issue41 ( 12 | | start_video_pos smallint(5) unsigned DEFAULT NULL 13 | | ) 14 | |""".stripMargin) 15 | 16 | val df = spark.createDF( 17 | List(1.toShort, 2.toShort, 3.toShort, 4.toShort), 18 | List(("start_video_pos", ShortType, true)) 19 | ) 20 | df.write 21 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 22 | .mode(SaveMode.Append) 23 | .save("issue41") 24 | 25 | val df2 = spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("issue41") 26 | assertSmallDataFrameEquality(df2, 27 | spark.createDF( 28 | List(1, 2, 3, 4), 29 | List(("start_video_pos", IntegerType, true)) 30 | ), 31 | orderedComparison = false) 32 | } 33 | 34 | it("https://memsql.zendesk.com/agent/tickets/10451") { 35 | // parallel read should support columnar scan with filter 36 | executeQueryWithLog(""" 37 | | create table if not exists testdb.ticket10451 ( 38 | | t text, 39 | | h bigint(20) DEFAULT NULL, 40 | | KEY h (h) USING CLUSTERED COLUMNSTORE 41 | | ) 42 | | """.stripMargin) 43 | 44 | val df = spark.createDF( 45 | List(("hi", 2L), ("hi", 3L), ("foo", 4L)), 46 | List(("t", StringType, true), ("h", LongType, true)) 47 | ) 48 | df.write 49 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 50 | .mode(SaveMode.Append) 51 | .save("ticket10451") 52 | 53 | val df2 = spark.read 54 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 55 | .load("ticket10451") 56 | .where(col("t") === "hi") 57 | .where(col("h") === 3L) 58 | 59 | assert(df2.rdd.getNumPartitions > 1) 60 | assertSmallDataFrameEquality(df2, 61 | spark.createDF( 62 | List(("hi", 3L)), 63 | List(("t", StringType, true), ("h", LongType, true)) 64 | )) 65 | } 66 | 67 | it("supports reading count from query") { 68 | val df = spark.createDF( 69 | List((1, "Albert"), (5, "Ronny"), (7, "Ben"), (9, "David")), 70 | List(("id", IntegerType, true), ("name", StringType, true)) 71 | ) 72 | writeTable("testdb.testcount", df) 73 | val data = spark.read 74 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 75 | .option("query", "select count(1) from testcount where id > 1 ") 76 | .option("database", "testdb") 77 | .load() 78 | .collect() 79 | val count = data.head.getLong(0) 80 | assert(count == 3) 81 | } 82 | 83 | it("handles exceptions raised by asCode") { 84 | // in certain cases asCode will raise NullPointerException due to this bug 85 | // https://issues.apache.org/jira/browse/SPARK-31403 86 | writeTable("testdb.nulltest", 87 | spark.createDF( 88 | List(1, null), 89 | List(("i", IntegerType, true)) 90 | )) 91 | spark.sql(s"create table nulltest using singlestore options ('dbtable'='testdb.nulltest')") 92 | 93 | val df2 = spark.sql("select if(isnull(i), null, 2) as x from nulltest order by i") 94 | 95 | assertSmallDataFrameEquality(df2, 96 | spark.createDF( 97 | List(null, 2), 98 | List(("x", IntegerType, true)) 99 | )) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/LoadDataBenchmark.scala: 
-------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.{Connection, Date, DriverManager} 4 | import java.time.{Instant, LocalDate} 5 | import java.util.Properties 6 | 7 | import org.apache.spark.sql.types._ 8 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 9 | import org.apache.spark.sql.{SaveMode, SparkSession} 10 | 11 | import scala.util.Random 12 | 13 | // LoadDataBenchmark is written to test load data with CPU profiler 14 | // this feature is accessible in Ultimate version of IntelliJ IDEA 15 | // see https://www.jetbrains.com/help/idea/async-profiler.html#profile for more details 16 | object LoadDataBenchmark extends App { 17 | final val masterHost: String = sys.props.getOrElse("singlestore.host", "localhost") 18 | final val masterPort: String = sys.props.getOrElse("singlestore.port", "5506") 19 | 20 | val spark: SparkSession = SparkSession 21 | .builder() 22 | .master("local") 23 | .config("spark.sql.shuffle.partitions", "1") 24 | .config("spark.driver.bindAddress", "localhost") 25 | .config("spark.datasource.singlestore.ddlEndpoint", s"${masterHost}:${masterPort}") 26 | .config("spark.datasource.singlestore.database", "testdb") 27 | .getOrCreate() 28 | 29 | def jdbcConnection: Loan[Connection] = { 30 | val connProperties = new Properties() 31 | connProperties.put("user", "root") 32 | 33 | Loan( 34 | DriverManager.getConnection( 35 | s"jdbc:singlestore://$masterHost:$masterPort", 36 | connProperties 37 | )) 38 | } 39 | 40 | def executeQuery(sql: String): Unit = { 41 | jdbcConnection.to(conn => Loan(conn.createStatement).to(_.execute(sql))) 42 | } 43 | 44 | executeQuery("set global default_partitions_per_leaf = 2") 45 | executeQuery("drop database if exists testdb") 46 | executeQuery("create database testdb") 47 | 48 | def genRow(): (Long, Int, Double, String) = 49 | (Random.nextLong(), Random.nextInt(), Random.nextDouble(), Random.nextString(20)) 50 | val df = 51 | spark.createDF( 52 | List.fill(1000000)(genRow()), 53 | List(("LongType", LongType, true), 54 | ("IntType", IntegerType, true), 55 | ("DoubleType", DoubleType, true), 56 | ("StringType", StringType, true)) 57 | ) 58 | 59 | val start = System.nanoTime() 60 | df.write 61 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 62 | .mode(SaveMode.Append) 63 | .save("testdb.batchinsert") 64 | 65 | val diff = System.nanoTime() - start 66 | println("Elapsed time: " + diff + "ns [CSV serialization] ") 67 | 68 | executeQuery("truncate testdb.batchinsert") 69 | 70 | val avroStart = System.nanoTime() 71 | df.write 72 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 73 | .mode(SaveMode.Append) 74 | .option(SinglestoreOptions.LOAD_DATA_FORMAT, "Avro") 75 | .save("testdb.batchinsert") 76 | val avroDiff = System.nanoTime() - avroStart 77 | println("Elapsed time: " + avroDiff + "ns [Avro serialization] ") 78 | } 79 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/LoadDataTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 4 | import org.apache.spark.sql.types.{DecimalType, IntegerType, NullType, StringType} 5 | import org.apache.spark.sql.{DataFrame, SaveMode} 6 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} 7 | 8 | import scala.util.Try 9 | 10 | class LoadDataTest extends IntegrationSuiteBase with 
BeforeAndAfterEach with BeforeAndAfterAll { 11 | var df: DataFrame = _ 12 | 13 | override def beforeEach(): Unit = { 14 | super.beforeEach() 15 | 16 | df = spark.createDF( 17 | List(), 18 | List(("id", IntegerType, false), ("name", StringType, true), ("age", IntegerType, true)) 19 | ) 20 | 21 | writeTable("testdb.loaddata", df) 22 | } 23 | 24 | it("appends row without `age` field") { 25 | df = spark.createDF( 26 | List((2, "B")), 27 | List(("id", IntegerType, true), ("name", StringType, true)) 28 | ) 29 | writeTable("testdb.loaddata", df, SaveMode.Append) 30 | 31 | val actualDF = 32 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata") 33 | assertSmallDataFrameEquality( 34 | actualDF, 35 | spark.createDF( 36 | List((2, "B", null)), 37 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 38 | )) 39 | } 40 | 41 | it("appends row without `name` field") { 42 | df = spark.createDF( 43 | List((3, 30)), 44 | List(("id", IntegerType, true), ("age", IntegerType, true)) 45 | ) 46 | writeTable("testdb.loaddata", df, SaveMode.Append) 47 | 48 | val actualDF = 49 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata") 50 | assertSmallDataFrameEquality( 51 | actualDF, 52 | spark.createDF( 53 | List((3, null, 30)), 54 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 55 | )) 56 | } 57 | 58 | it("should not append row without not nullable `id` field") { 59 | df = spark.createDF( 60 | List(("D", 40)), 61 | List(("name", StringType, true), ("age", IntegerType, true)) 62 | ) 63 | 64 | try { 65 | writeTable("testdb.loaddata", df, SaveMode.Append) 66 | fail() 67 | } catch { 68 | // error code 1364 is `Field 'id' doesn't have a default value` 69 | case e: Throwable if TestHelper.isSQLExceptionWithCode(e, List(1364)) => 70 | } 71 | } 72 | 73 | it("appends row with all fields") { 74 | df = spark.createDF( 75 | List((5, "E", 50)), 76 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 77 | ) 78 | writeTable("testdb.loaddata", df, SaveMode.Append) 79 | 80 | val actualDF = 81 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata") 82 | assertSmallDataFrameEquality( 83 | actualDF, 84 | spark.createDF( 85 | List((5, "E", 50)), 86 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 87 | )) 88 | } 89 | 90 | it("appends row only with `id` field") { 91 | df = spark.createDF(List(6), List(("id", IntegerType, true))) 92 | writeTable("testdb.loaddata", df, SaveMode.Append) 93 | 94 | val actualDF = 95 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata") 96 | assertSmallDataFrameEquality( 97 | actualDF, 98 | spark.createDF( 99 | List((6, null, null)), 100 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 101 | ) 102 | ) 103 | } 104 | 105 | it("appends row with all fields in wrong order") { 106 | df = spark.createDF( 107 | List(("WO", 101, 101)), 108 | List(("name", StringType, true), ("age", IntegerType, true), ("id", IntegerType, true)) 109 | ) 110 | writeTable("testdb.loaddata", df, SaveMode.Append) 111 | 112 | val actualDF = 113 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata") 114 | assertSmallDataFrameEquality( 115 | actualDF, 116 | spark.createDF( 117 | List((101, "WO", 101)), 118 | List(("id", IntegerType, true), ("name", StringType, true), ("age", 
IntegerType, true)) 119 | )) 120 | } 121 | 122 | it("appends row with `id` and `name` fields in wrong order") { 123 | df = spark.createDF( 124 | List(("WO2", 102)), 125 | List(("name", StringType, true), ("id", IntegerType, true)) 126 | ) 127 | writeTable("testdb.loaddata", df, SaveMode.Append) 128 | 129 | val actualDF = 130 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata") 131 | assertSmallDataFrameEquality( 132 | actualDF, 133 | spark.createDF( 134 | List((102, "WO2", null)), 135 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 136 | ) 137 | ) 138 | } 139 | 140 | it("should not append row with more fields than expected") { 141 | df = spark.createDF( 142 | List((1, "1", 1, 1)), 143 | List(("id", IntegerType, true), 144 | ("name", StringType, true), 145 | ("age", IntegerType, true), 146 | ("extra", IntegerType, false)) 147 | ) 148 | 149 | try { 150 | writeTable("testdb.loaddata", df, SaveMode.Append) 151 | fail() 152 | } catch { 153 | // error code 1054 is `Unknown column 'extra' in 'field list'` 154 | case e: Throwable if TestHelper.isSQLExceptionWithCode(e, List(1054)) => 155 | } 156 | } 157 | 158 | it("should not append row with wrong field name") { 159 | df = spark.createDF( 160 | List((1, "1", 1)), 161 | List(("id", IntegerType, true), ("wrongname", StringType, true), ("age", IntegerType, true))) 162 | 163 | try { 164 | writeTable("testdb.loaddata", df, SaveMode.Append) 165 | fail() 166 | } catch { 167 | // error code 1054 is `Unknown column 'wrongname' in 'field list'` 168 | case e: Throwable if TestHelper.isSQLExceptionWithCode(e, List(1054)) => 169 | } 170 | } 171 | 172 | it("should fail creating table with NullType") { 173 | val tableName = "null_type" 174 | 175 | val dfNull = spark.createDF(List(null), List(("id", NullType, true))) 176 | val writeResult = Try { 177 | writeTable(s"testdb.$tableName", dfNull, SaveMode.Append) 178 | } 179 | assert(writeResult.isFailure) 180 | assert( 181 | writeResult.failed.get.getMessage 182 | .equals( 183 | "No corresponding SingleStore type found for NullType. 
If you want to use NullType, please write to an already existing SingleStore table.")) 184 | } 185 | 186 | it("should succeed inserting NullType in existed table") { 187 | val tableName = "null_type" 188 | 189 | df = spark.createDF(List(1, 2, 3, 4, 5), List(("id", IntegerType, true))) 190 | writeTable(s"testdb.$tableName", df, SaveMode.Append) 191 | 192 | val dfNull = spark.createDF(List(null), List(("id", NullType, true))) 193 | writeTable(s"testdb.$tableName", dfNull, SaveMode.Append) 194 | } 195 | 196 | it("should write BigDecimal with Avro serializing") { 197 | val tableName = "bigDecimalAvro" 198 | val df = spark.createDF( 199 | List((1, "Alice", 213: BigDecimal)), 200 | List(("id", IntegerType, true), ("name", StringType, true), ("age", DecimalType(10, 0), true)) 201 | ) 202 | df.write 203 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 204 | .mode(SaveMode.Overwrite) 205 | .option(SinglestoreOptions.LOAD_DATA_FORMAT, "avro") 206 | .save(s"testdb.$tableName") 207 | 208 | val actualDF = 209 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load(s"testdb.$tableName") 210 | assertLargeDataFrameEquality(actualDF, df) 211 | } 212 | 213 | it("should work with `memsql` and `com.memsql.spark` source") { 214 | df = spark.createDF( 215 | List((5, "E", 50)), 216 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 217 | ) 218 | writeTable("testdb.loaddata", df, SaveMode.Append) 219 | 220 | val actualDFShort = 221 | spark.read.format(DefaultSource.MEMSQL_SOURCE_NAME_SHORT).load("testdb.loaddata") 222 | assertSmallDataFrameEquality( 223 | actualDFShort, 224 | spark.createDF( 225 | List((5, "E", 50)), 226 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 227 | )) 228 | val actualDF = 229 | spark.read.format(DefaultSource.MEMSQL_SOURCE_NAME).load("testdb.loaddata") 230 | assertSmallDataFrameEquality( 231 | actualDF, 232 | spark.createDF( 233 | List((5, "E", 50)), 234 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 235 | )) 236 | } 237 | 238 | it("non-existing column") { 239 | executeQueryWithLog("DROP TABLE IF EXISTS loaddata") 240 | executeQueryWithLog("CREATE TABLE loaddata(id INT, name TEXT)") 241 | 242 | df = spark.createDF( 243 | List((5, "EBCEFGRHFED" * 10000000, 50)), 244 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 245 | ) 246 | 247 | try { 248 | writeTable("testdb.loaddata", df, SaveMode.Append) 249 | fail() 250 | } catch { 251 | case e: Exception if e.getMessage.contains("Unknown column 'age' in 'field list'") => 252 | } 253 | } 254 | 255 | it("load data in batches") { 256 | val df1 = spark.createDF( 257 | List( 258 | (5, "Jack", 20), 259 | (6, "Mark", 30), 260 | (7, "Fred", 15), 261 | (8, "Jany", 40), 262 | (9, "Monica", 5) 263 | ), 264 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 265 | ) 266 | val df2 = spark.createDF( 267 | List( 268 | (10, "Jany", 40), 269 | (11, "Monica", 5) 270 | ), 271 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 272 | ) 273 | val df3 = spark.createDF( 274 | List(), 275 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true)) 276 | ) 277 | 278 | df1.write 279 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 280 | .option("insertBatchSize", 2) 281 | .mode(SaveMode.Append) 282 | .save("testdb.loadDataBatches") 283 | df2.write 284 | 
.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 285 | .option("insertBatchSize", 2) 286 | .mode(SaveMode.Append) 287 | .save("testdb.loadDataBatches") 288 | df3.write 289 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 290 | .option("insertBatchSize", 2) 291 | .mode(SaveMode.Append) 292 | .save("testdb.loadDataBatches") 293 | 294 | val actualDF = 295 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loadDataBatches") 296 | assertSmallDataFrameEquality( 297 | actualDF, 298 | df1.union(df2).union(df3), 299 | orderedComparison = false 300 | ) 301 | } 302 | } 303 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/LoadbalanceTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 4 | import com.singlestore.spark.JdbcHelpers.executeQuery 5 | import org.apache.spark.sql.SaveMode 6 | import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} 7 | import org.apache.spark.sql.jdbc.JdbcDialects 8 | import org.apache.spark.sql.types.IntegerType 9 | 10 | import java.sql.DriverManager 11 | import java.util.Properties 12 | 13 | class LoadbalanceTest extends IntegrationSuiteBase { 14 | 15 | val masterHostPort = s"${masterHost}:${masterPort}" 16 | val childHostPort = "localhost:5508" 17 | 18 | override def beforeEach(): Unit = { 19 | super.beforeEach() 20 | 21 | // Set master + child aggregator as dmlEndpoints 22 | spark.conf 23 | .set("spark.datasource.singlestore.dmlEndpoints", s"${masterHostPort},${childHostPort}") 24 | } 25 | 26 | def countQueries(hostport: String): Int = { 27 | val props = new Properties() 28 | props.setProperty("dbtable", "testdb") 29 | props.setProperty("user", "root") 30 | props.setProperty("password", masterPassword) 31 | 32 | val conn = DriverManager.getConnection(s"jdbc:singlestore://$hostport", props) 33 | try { 34 | // we only use write queries since read queries are always increasing due to internal status checks 35 | val rows = 36 | JdbcHelpers.executeQuery(conn, "show status extended like 'Successful_write_queries'") 37 | rows.map(r => r.getAs[String](1).toInt).sum 38 | } finally { 39 | conn.close() 40 | } 41 | } 42 | 43 | def counters = 44 | Map( 45 | masterHostPort -> countQueries(masterHostPort), 46 | childHostPort -> countQueries(childHostPort) 47 | ) 48 | 49 | describe("load-balances among all hosts listed in dmlEndpoints") { 50 | 51 | it("queries both aggregators eventually") { 52 | 53 | val df = spark.createDF( 54 | List(4, 5, 6), 55 | List(("id", IntegerType, true)) 56 | ) 57 | 58 | val startCounters = counters 59 | 60 | // 50/50 chance of picking either agg, 10 tries should be enough to ensure we hit both aggs with write queries 61 | for (i <- 0 to 10) { 62 | df.write 63 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 64 | .option("driverConnectionPool.MinEvictableIdleTimeMs", "100") 65 | .option("driverConnectionPool.TimeBetweenEvictionRunsMS", "50") 66 | .option("executorConnectionPool.MinEvictableIdleTimeMs", "100") 67 | .option("executorConnectionPool.TimeBetweenEvictionRunsMS", "50") 68 | .mode(SaveMode.Overwrite) 69 | .save("test") 70 | 71 | Thread.sleep(300) 72 | } 73 | 74 | val endCounters = counters 75 | 76 | assert(endCounters(childHostPort) > startCounters(childHostPort)) 77 | assert(endCounters(masterHostPort) > startCounters(masterHostPort)) 78 | } 79 | 80 | } 81 | } 82 | 
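For readers skimming the dump, a minimal usage sketch of the endpoint options this test exercises (not part of the repository): the host names, the DataFrame and the target table are placeholders, and only the ddlEndpoint/dmlEndpoints settings and the source constant are taken from the code above.

import org.apache.spark.sql.{DataFrame, SaveMode}
import com.singlestore.spark.DefaultSource

object LoadBalancedWriteSketch {
  def writeLoadBalanced(df: DataFrame): Unit = {
    // DDL runs against ddlEndpoint (typically the master aggregator); writes are
    // load-balanced across every host listed in dmlEndpoints, as the test above asserts.
    df.sparkSession.conf.set("spark.datasource.singlestore.ddlEndpoint", "master-agg:3306")
    df.sparkSession.conf.set("spark.datasource.singlestore.dmlEndpoints",
                             "master-agg:3306,child-agg:3306")
    df.write
      .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
      .mode(SaveMode.Append)
      .save("testdb.my_table")
  }
}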
-------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/MaxErrorsTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.SQLTransientConnectionException 4 | 5 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 6 | import org.apache.spark.sql.SaveMode 7 | import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType} 8 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} 9 | 10 | import scala.util.Try 11 | 12 | class MaxErrorsTest extends IntegrationSuiteBase with BeforeAndAfterEach with BeforeAndAfterAll { 13 | 14 | def testMaxErrors(tableName: String, maxErrors: Int, duplicateItems: Int): Unit = { 15 | val df = spark.createDF( 16 | List.fill(duplicateItems + 1)((1, "Alice", 213: BigDecimal)), 17 | List(("id", IntegerType, true), 18 | ("name", StringType, false), 19 | ("age", DecimalType(10, 0), true)) 20 | ) 21 | val result = Try { 22 | df.write 23 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 24 | .option("tableKey.primary", "name") 25 | .option("maxErrors", maxErrors) 26 | .mode(SaveMode.Ignore) 27 | .save(s"testdb.$tableName") 28 | } 29 | if (duplicateItems > maxErrors) { 30 | assert(result.isFailure) 31 | result.failed.get.getCause match { 32 | case _: SQLTransientConnectionException => 33 | case _ => fail("SQLTransientConnectionException should be thrown") 34 | } 35 | } else { 36 | assert(result.isSuccess) 37 | } 38 | } 39 | 40 | describe("small dataset") { 41 | // TODO DB-51213 42 | //it("hit maxErrors") { 43 | // testMaxErrors("hitMaxErrorsSmall", 1, 2) 44 | //} 45 | 46 | it("not hit maxErrors") { 47 | testMaxErrors("notHitMaxErrorsSmall", 1, 1) 48 | } 49 | } 50 | 51 | describe("big dataset") { 52 | // TODO DB-51213 53 | //it("hit maxErrors") { 54 | // testMaxErrors("hitMaxErrorsBig", 10000, 10001) 55 | //} 56 | 57 | it("not hit maxErrors") { 58 | testMaxErrors("notHitMaxErrorsBig", 10000, 10000) 59 | } 60 | } 61 | 62 | it("wrong configuration") { 63 | val df = spark.createDF( 64 | List((1, "Alice", 213: BigDecimal)), 65 | List(("id", IntegerType, true), 66 | ("name", StringType, false), 67 | ("age", DecimalType(10, 0), true)) 68 | ) 69 | val result = Try { 70 | df.write 71 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 72 | .option("onDuplicateKeySQL", "id=id") 73 | .option("maxErrors", 1) 74 | .mode(SaveMode.Ignore) 75 | .save(s"testdb.someTable") 76 | } 77 | assert(result.isFailure) 78 | result.failed.get match { 79 | case ex: IllegalArgumentException 80 | if ex.getMessage.equals("can't use both `onDuplicateKeySQL` and `maxErrors` options") => 81 | succeed 82 | case _ => fail() 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/OutputMetricsTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.util.concurrent.CountDownLatch 4 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 5 | import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} 6 | import org.apache.spark.sql.types.{IntegerType, StringType} 7 | 8 | class OutputMetricsTest extends IntegrationSuiteBase { 9 | it("records written") { 10 | var outputWritten = 0L 11 | var countDownLatch: CountDownLatch = null 12 | spark.sparkContext.addSparkListener(new SparkListener() { 13 | override def onTaskEnd(taskEnd: 
SparkListenerTaskEnd) { 14 | if (taskEnd.taskType == "ResultTask") { 15 | outputWritten.synchronized({ 16 | val metrics = taskEnd.taskMetrics 17 | outputWritten += metrics.outputMetrics.recordsWritten 18 | countDownLatch.countDown() 19 | }) 20 | } 21 | } 22 | }) 23 | 24 | val numRows = 100000 25 | var df1 = spark.createDF( 26 | List.range(0, numRows), 27 | List(("id", IntegerType, true)) 28 | ) 29 | 30 | var numPartitions = 30 31 | countDownLatch = new CountDownLatch(numPartitions) 32 | df1 = df1.repartition(numPartitions) 33 | 34 | df1.write 35 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 36 | .save("metricsInts") 37 | 38 | countDownLatch.await() 39 | assert(outputWritten == numRows) 40 | 41 | var df2 = spark.createDF( 42 | List("st1", "", null), 43 | List(("st", StringType, true)) 44 | ) 45 | 46 | outputWritten = 0 47 | numPartitions = 1 48 | countDownLatch = new CountDownLatch(numPartitions) 49 | df2 = df2.repartition(numPartitions) 50 | 51 | df2.write 52 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 53 | .save("metricsStrings") 54 | countDownLatch.await() 55 | assert(outputWritten == 3) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/ReferenceTableTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 4 | import org.apache.spark.sql.types.IntegerType 5 | import org.apache.spark.sql.{DataFrame, SaveMode} 6 | 7 | import scala.util.Try 8 | 9 | class ReferenceTableTest extends IntegrationSuiteBase { 10 | 11 | val childAggregatorHost = "localhost" 12 | val childAggregatorPort = "5508" 13 | 14 | val dbName = "testdb" 15 | val commonCollectionName = "test_table" 16 | val referenceCollectionName = "reference_table" 17 | 18 | override def beforeEach(): Unit = { 19 | super.beforeEach() 20 | 21 | // Set child aggregator as a dmlEndpoint 22 | spark.conf 23 | .set("spark.datasource.singlestore.dmlEndpoints", 24 | s"${childAggregatorHost}:${childAggregatorPort}") 25 | } 26 | 27 | def writeToTable(tableName: String): Unit = { 28 | val df = spark.createDF( 29 | List(4, 5, 6), 30 | List(("id", IntegerType, true)) 31 | ) 32 | df.write 33 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 34 | .mode(SaveMode.Append) 35 | .save(s"${dbName}.${tableName}") 36 | } 37 | 38 | def readFromTable(tableName: String): DataFrame = { 39 | spark.read 40 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 41 | .load(s"${dbName}.${tableName}") 42 | } 43 | 44 | def writeAndReadFromTable(tableName: String): Unit = { 45 | writeToTable(tableName) 46 | val dataFrame = readFromTable(tableName) 47 | val sqlRows = dataFrame.collect(); 48 | assert(sqlRows.length == 3) 49 | } 50 | 51 | def dropTable(tableName: String): Unit = 52 | executeQueryWithLog(s"drop table if exists $dbName.$tableName") 53 | 54 | describe("Success during write operations") { 55 | 56 | it("to common table") { 57 | dropTable(commonCollectionName) 58 | executeQueryWithLog( 59 | s"create table if not exists $dbName.$commonCollectionName (id INT NOT NULL, PRIMARY KEY (id))") 60 | writeAndReadFromTable(commonCollectionName) 61 | } 62 | 63 | it("to reference table") { 64 | dropTable(referenceCollectionName) 65 | executeQueryWithLog( 66 | s"create reference table if not exists $dbName.$referenceCollectionName (id INT NOT NULL, PRIMARY KEY (id))") 67 | writeAndReadFromTable(referenceCollectionName) 68 | } 69 | } 
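  // Context for the two cases above: a SingleStore reference table is replicated to every node
  // in the cluster, while a regular (sharded) table is distributed across the leaves; from the
  // connector's point of view both accept the same DataFrame writes, which is what these tests
  // verify through the child aggregator endpoint configured in beforeEach().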
70 | 71 | describe("Success during creating") { 72 | 73 | it("common table") { 74 | dropTable(commonCollectionName) 75 | writeAndReadFromTable(commonCollectionName) 76 | } 77 | } 78 | 79 | describe("Failure because of") { 80 | 81 | it("database name not specified") { 82 | spark.conf.set("spark.datasource.singlestore.database", "") 83 | val df = spark.createDF( 84 | List(4, 5, 6), 85 | List(("id", IntegerType, true)) 86 | ) 87 | val result = Try { 88 | df.write 89 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 90 | .mode(SaveMode.Append) 91 | .save(s"${commonCollectionName}") 92 | } 93 | /* Error code description: 94 | 1046 = Database name not provided 95 | * */ 96 | assert(TestHelper.isSQLExceptionWithCode(result.failed.get, List(1046))) 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/SQLHelperTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.{Date, Timestamp} 4 | 5 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 6 | import com.singlestore.spark.SQLHelper._ 7 | import org.apache.spark.sql.DataFrame 8 | import org.apache.spark.sql.types._ 9 | import org.scalatest.BeforeAndAfterEach 10 | 11 | class SQLHelperTest extends IntegrationSuiteBase with BeforeAndAfterEach { 12 | var df: DataFrame = _ 13 | 14 | override def beforeEach(): Unit = { 15 | super.beforeEach() 16 | df = spark.createDF( 17 | List((1, "Cat", true), (2, "Dog", true), (3, "CatDog", false)), 18 | List(("id", IntegerType, true), ("name", StringType, true), ("domestic", BooleanType, true)) 19 | ) 20 | writeTable("testdb.animal", df) 21 | 22 | df = spark.createDF( 23 | List((1: Long, 2: Short, 3: Float, 4: Double), 24 | (-2: Long, 22: Short, 2.0.toFloat, 5.1: Double), 25 | (3: Long, 4: Short, -0.11.toFloat, 66.77: Double)), 26 | List(("nLong", LongType, true), 27 | ("nShort", ShortType, true), 28 | ("nFloat", FloatType, true), 29 | ("nDouble", DoubleType, true)) 30 | ) 31 | writeTable("testdb.numbers", df) 32 | 33 | df = spark.createDF( 34 | List((1: Byte, Date.valueOf("2015-03-30"), new Timestamp(2147483649L)), 35 | (2: Byte, Date.valueOf("2020-09-09"), new Timestamp(1000))), 36 | List(("nByte", ByteType, true), 37 | ("nDate", DateType, true), 38 | ("nTimestamp", TimestampType, true)) 39 | ) 40 | writeTable("testdb.byte_dates", df) 41 | } 42 | 43 | override def afterEach(): Unit = { 44 | super.afterEach() 45 | 46 | spark.executeSinglestoreQueryDB("testdb", "DROP TABLE IF EXISTS animal") 47 | spark.executeSinglestoreQueryDB("testdb", "DROP TABLE IF EXISTS numbers") 48 | spark.executeSinglestoreQueryDB("testdb", "DROP TABLE IF EXISTS byte_dates") 49 | spark.executeSinglestoreQuery("DROP DATABASE IF EXISTS test_db_1") 50 | } 51 | 52 | describe("implicit version") { 53 | it("global query test") { 54 | val s = spark.executeSinglestoreQuery("SHOW DATABASES") 55 | val result = 56 | (for (row <- s) 57 | yield row.getString(0)).toList 58 | for (db <- List("memsql", "cluster", "testdb", "information_schema")) { 59 | assert(result.contains(db)) 60 | } 61 | } 62 | 63 | it("executeSinglestoreQuery with 2 parameters") { 64 | spark.executeSinglestoreQuery("DROP DATABASE IF EXISTS test_db_1") 65 | spark.executeSinglestoreQuery("CREATE DATABASE test_db_1") 66 | } 67 | 68 | it("executeSinglestoreQueryDB with 3 parameters") { 69 | val res = spark.executeSinglestoreQueryDB("testdb", "SELECT * FROM animal") 70 | val out = 71 | for 
(row <- res) 72 | yield row.getString(1) 73 | 74 | assert(out.toList.sorted == List("Cat", "CatDog", "Dog")) 75 | } 76 | 77 | it("executeSingleStoreQuery without explicitly specified db") { 78 | val res = spark.executeSinglestoreQuery("SELECT * FROM animal") 79 | val out = 80 | for (row <- res) 81 | yield row.getString(1) 82 | 83 | assert(out.toList.sorted == List("Cat", "CatDog", "Dog")) 84 | } 85 | 86 | it("executeSinglestoreQuery with query params") { 87 | val params = List("%Cat%", 1, false) 88 | val res = spark.executeSinglestoreQuery( 89 | "SELECT * FROM animal WHERE name LIKE ? AND id > ? AND domestic = ?", 90 | "%Cat%", 91 | 1, 92 | false) 93 | val out = 94 | for (row <- res) 95 | yield row.getString(1) 96 | 97 | assert(out.toList == List("CatDog")) 98 | } 99 | 100 | it("executeSinglestoreQuery with db different from the one found in SparkContext") { 101 | spark.executeSinglestoreQuery(query = "CREATE DATABASE test_db_1") 102 | spark.executeSinglestoreQuery(query = "CREATE TABLE test_db_1.animal (id INT, name TEXT)") 103 | 104 | val res = spark.executeSinglestoreQueryDB("test_db_1", "SELECT * FROM animal") 105 | val out = 106 | for (row <- res) 107 | yield row.getString(1) 108 | 109 | assert(out.toList == List()) 110 | } 111 | } 112 | 113 | it("executeSinglestoreQuery with numeric columns") { 114 | val params = List(1, 22, 2, null) 115 | val res = spark.executeSinglestoreQuery( 116 | "SELECT * FROM numbers WHERE nLong = ? OR nShort = ? OR nFloat = ? OR nDouble = ?", 117 | 1, 118 | 22, 119 | 2, 120 | null) 121 | assert(res.length == 2) 122 | } 123 | 124 | it("executeSinglestoreQuery with byte and date columns") { 125 | val res = spark.executeSinglestoreQuery( 126 | "SELECT * FROM byte_dates WHERE nByte = ? OR nDate = ? OR nTimestamp = ?", 127 | 2, 128 | "2015-03-30", 129 | 2000) 130 | assert(res.length == 2) 131 | } 132 | 133 | describe("explicit version") { 134 | it("global query test") { 135 | val s = executeSinglestoreQuery(spark, "SHOW DATABASES") 136 | val result = 137 | (for (row <- s) 138 | yield row.getString(0)).toList 139 | for (db <- List("memsql", "cluster", "testdb", "information_schema")) { 140 | assert(result.contains(db)) 141 | } 142 | } 143 | 144 | it("executeSinglestoreQuery with 2 parameters") { 145 | executeSinglestoreQuery(spark, "DROP DATABASE IF EXISTS test_db_1") 146 | executeSinglestoreQuery(spark, "CREATE DATABASE test_db_1") 147 | } 148 | 149 | it("executeSinglestoreQueryDB with 3 parameters") { 150 | val res = executeSinglestoreQueryDB(spark, "testdb", "SELECT * FROM animal") 151 | val out = 152 | for (row <- res) 153 | yield row.getString(1) 154 | 155 | assert(out.toList.sorted == List("Cat", "CatDog", "Dog")) 156 | } 157 | 158 | it("executeSinglestoreQuery without explicitly specified db") { 159 | val res = executeSinglestoreQuery(spark, "SELECT * FROM animal") 160 | val out = 161 | for (row <- res) 162 | yield row.getString(1) 163 | 164 | assert(out.toList.sorted == List("Cat", "CatDog", "Dog")) 165 | } 166 | 167 | it("executeSinglestoreQuery with query params") { 168 | val params = List("%Cat%", 1, false) 169 | val res = 170 | executeSinglestoreQuery( 171 | spark, 172 | "SELECT * FROM animal WHERE name LIKE ? AND id > ? 
AND domestic = ?", 173 | "%Cat%", 174 | 1, 175 | false) 176 | val out = 177 | for (row <- res) 178 | yield row.getString(1) 179 | 180 | assert(out.toList == List("CatDog")) 181 | } 182 | 183 | it("executeSinglestoreQuery with db different from the one found in SparkContext") { 184 | executeSinglestoreQuery(spark, query = "CREATE DATABASE test_db_1") 185 | executeSinglestoreQuery(spark, query = "CREATE TABLE test_db_1.animal (id INT, name TEXT)") 186 | 187 | val res = executeSinglestoreQueryDB(spark, "test_db_1", "SELECT * FROM animal") 188 | val out = 189 | for (row <- res) 190 | yield row.getString(1) 191 | 192 | assert(out.toList == List()) 193 | } 194 | 195 | it("executeSinglestoreQuery with numeric columns") { 196 | val params = List(1, 22, 2, null) 197 | val res = executeSinglestoreQuery( 198 | spark, 199 | "SELECT * FROM numbers WHERE nLong = ? OR nShort = ? OR nFloat = ? OR nDouble = ?", 200 | 1, 201 | 22, 202 | 2, 203 | null) 204 | assert(res.length == 2) 205 | } 206 | 207 | it("executeSinglestoreQuery with byte and date columns") { 208 | val res = 209 | executeSinglestoreQuery( 210 | spark, 211 | "SELECT * FROM byte_dates WHERE nByte = ? OR nDate = ? OR nTimestamp = ?", 212 | 2, 213 | "2015-03-30", 214 | 2000) 215 | assert(res.length == 2) 216 | } 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/SQLPermissionsTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.util.UUID 4 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ 5 | import com.singlestore.spark.JdbcHelpers.executeQuery 6 | import org.apache.spark.sql.SaveMode 7 | import org.apache.spark.sql.types.IntegerType 8 | 9 | import java.sql.DriverManager 10 | import scala.util.Try 11 | 12 | class SQLPermissionsTest extends IntegrationSuiteBase { 13 | 14 | val testUserName = "sparkuserselect" 15 | val dbName = "testdb" 16 | val collectionName = "temps_test" 17 | 18 | override def beforeAll(): Unit = { 19 | super.beforeAll() 20 | val conn = 21 | DriverManager.getConnection(s"jdbc:mysql://$masterHost:$masterPort", jdbcDefaultProps) 22 | executeQuery(conn, s"CREATE USER '${testUserName}'@'%'") 23 | } 24 | 25 | override def beforeEach(): Unit = { 26 | super.beforeEach() 27 | val df = spark.createDF( 28 | List(1, 2, 3), 29 | List(("id", IntegerType, true)) 30 | ) 31 | writeTable(s"${dbName}.${collectionName}", df) 32 | } 33 | 34 | private def setUpUserPermissions(privilege: String): Unit = { 35 | /* Revoke all permissions from user */ 36 | Try(executeQueryWithLog(s"REVOKE ALL PRIVILEGES ON ${dbName}.* FROM '${testUserName}'@'%'")) 37 | /* Give permissions to user */ 38 | executeQueryWithLog(s"GRANT ${privilege} ON ${dbName}.* TO '${testUserName}'@'%'") 39 | /* Set up user to spark */ 40 | spark.conf.set("spark.datasource.singlestore.user", s"${testUserName}") 41 | } 42 | 43 | private def doSuccessOperation(operation: () => Unit)(privilege: String): Unit = { 44 | it(s"success with ${privilege} permission") { 45 | setUpUserPermissions(privilege) 46 | val result = Try(operation()) 47 | if (result.isFailure) { 48 | result.failed.get.printStackTrace() 49 | fail() 50 | } 51 | } 52 | } 53 | 54 | private def doFailOperation(operation: () => Unit)(privilege: String): Unit = { 55 | it(s"fails with ${privilege} permission") { 56 | setUpUserPermissions(privilege) 57 | val result = Try(operation()) 58 | /* Error codes description: 59 | 1142 = denied to 
current user 60 | 1050 = table already exists (error throws when we don't have SELECT permission to check if such table already exists) 61 | */ 62 | assert(TestHelper.isSQLExceptionWithCode(result.failed.get, List(1142, 1050))) 63 | } 64 | } 65 | 66 | describe("read permissions") { 67 | /* List of supported privileges for read operation */ 68 | val supportedPrivileges = List("SELECT", "ALL PRIVILEGES") 69 | /* List of unsupported privileges for read operation */ 70 | val unsupportedPrivileges = List("CREATE", "DROP", "DELETE", "INSERT", "UPDATE") 71 | 72 | def operation(): Unit = 73 | spark.read 74 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 75 | .load(s"${dbName}.${collectionName}") 76 | 77 | unsupportedPrivileges.foreach(doFailOperation(operation)) 78 | supportedPrivileges.foreach(doSuccessOperation(operation)) 79 | } 80 | 81 | describe("write permissions") { 82 | /* List of supported privileges for write operation */ 83 | val supportedPrivileges = List("INSERT, SELECT", "ALL PRIVILEGES") 84 | /* List of unsupported privileges for write operation */ 85 | val unsupportedPrivileges = List("CREATE", "DROP", "DELETE", "SELECT", "UPDATE") 86 | 87 | def operation(): Unit = { 88 | val df = spark.createDF( 89 | List(4, 5, 6), 90 | List(("id", IntegerType, true)) 91 | ) 92 | df.write 93 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 94 | .mode(SaveMode.Append) 95 | .save(s"${dbName}.${collectionName}") 96 | } 97 | 98 | unsupportedPrivileges.foreach(doFailOperation(operation)) 99 | supportedPrivileges.foreach(doSuccessOperation(operation)) 100 | } 101 | 102 | describe("drop permissions") { 103 | 104 | /* List of supported privileges for drop operation */ 105 | val supportedPrivileges = List("DROP, SELECT, INSERT", "ALL PRIVILEGES") 106 | /* List of unsupported privileges for drop operation */ 107 | val unsupportedPrivileges = List("CREATE", "INSERT", "DELETE", "SELECT", "UPDATE") 108 | 109 | implicit def operation(): Unit = { 110 | val df = spark.createDF( 111 | List(1, 2, 3), 112 | List(("id", IntegerType, true)) 113 | ) 114 | df.write 115 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 116 | .option("truncate", "true") 117 | .mode(SaveMode.Overwrite) 118 | .save(s"${dbName}.${collectionName}") 119 | } 120 | 121 | unsupportedPrivileges.foreach(doFailOperation(operation)) 122 | supportedPrivileges.foreach(doSuccessOperation(operation)) 123 | } 124 | 125 | describe("create permissions") { 126 | 127 | /* List of supported privileges for create operation */ 128 | val supportedPrivileges = List("CREATE, SELECT, INSERT", "ALL PRIVILEGES") 129 | /* List of unsupported privileges for create operation */ 130 | val unsupportedPrivileges = List("DROP", "INSERT", "DELETE", "SELECT", "UPDATE") 131 | 132 | implicit def operation(): Unit = { 133 | val df = spark.createDF( 134 | List(1, 2, 3), 135 | List(("id", IntegerType, true)) 136 | ) 137 | df.write 138 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT) 139 | .mode(SaveMode.Overwrite) 140 | .save(s"${dbName}.${collectionName}_${UUID.randomUUID().toString.split("-")(0)}") 141 | } 142 | 143 | unsupportedPrivileges.foreach(doFailOperation(operation)) 144 | supportedPrivileges.foreach(doSuccessOperation(operation)) 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/SinglestoreConnectionPoolTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import 
java.sql.Connection 4 | import java.time.Duration 5 | import java.util.Properties 6 | 7 | import org.apache.commons.dbcp2.DelegatingConnection 8 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap 9 | 10 | class SinglestoreConnectionPoolTest extends IntegrationSuiteBase { 11 | var properties = new Properties() 12 | 13 | override def beforeEach(): Unit = { 14 | super.beforeEach() 15 | properties = JdbcHelpers.getConnProperties( 16 | SinglestoreOptions( 17 | CaseInsensitiveMap( 18 | Map("ddlEndpoint" -> s"$masterHost:$masterPort", "password" -> masterPassword)), 19 | spark.sparkContext), 20 | isOnExecutor = false, 21 | s"$masterHost:$masterPort" 22 | ) 23 | } 24 | 25 | override def afterEach(): Unit = { 26 | super.afterEach() 27 | SinglestoreConnectionPool.close() 28 | } 29 | 30 | it("reuses a connection") { 31 | var conn = SinglestoreConnectionPool.getConnection(properties) 32 | val conn1 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal 33 | conn.close() 34 | 35 | conn = SinglestoreConnectionPool.getConnection(properties) 36 | val conn2 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal 37 | conn.close() 38 | 39 | assert(conn1 == conn2, "should reuse idle connection") 40 | } 41 | 42 | it("creates a new connection when existing is in use") { 43 | val conn1 = SinglestoreConnectionPool.getConnection(properties) 44 | val conn2 = SinglestoreConnectionPool.getConnection(properties) 45 | val originalConn1 = 46 | conn1.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal 47 | val originalConn2 = 48 | conn2.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal 49 | 50 | assert(originalConn1 != originalConn2, "should create a new connection when existing is in use") 51 | 52 | conn1.close() 53 | conn2.close() 54 | } 55 | 56 | it("creates different pools for different properties") { 57 | var conn = SinglestoreConnectionPool.getConnection(properties) 58 | val conn1 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal 59 | conn.close() 60 | 61 | properties.setProperty("newProperty", "") 62 | 63 | conn = SinglestoreConnectionPool.getConnection(properties) 64 | val conn2 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal 65 | conn.close() 66 | 67 | assert(conn1 != conn2, "should create different pools for different properties") 68 | } 69 | 70 | it("maxTotal and maxWaitMillis") { 71 | val maxWaitMillis = 200 72 | properties.setProperty("maxTotal", "1") 73 | properties.setProperty("maxWaitMillis", maxWaitMillis.toString) 74 | 75 | val conn1 = SinglestoreConnectionPool.getConnection(properties) 76 | 77 | val start = System.nanoTime() 78 | try { 79 | SinglestoreConnectionPool.getConnection(properties) 80 | fail() 81 | } catch { 82 | case e: Throwable => 83 | assert( 84 | e.getMessage.equals( 85 | "Cannot get a connection, pool error Timeout waiting for idle object"), 86 | "should throw timeout error" 87 | ) 88 | assert(System.nanoTime() - start > Duration.ofMillis(maxWaitMillis).toNanos, 89 | "should throw timeout error only after maxWaitMillis has elapsed") 90 | } 91 | 92 | conn1.close() 93 | } 94 | 95 | it("eviction of idle connections") { 96 | val minEvictableIdleTimeMillis = 100 97 | val timeBetweenEvictionRunsMillis = 50 98 | properties.setProperty("minEvictableIdleTimeMillis", minEvictableIdleTimeMillis.toString) 99 | properties.setProperty("timeBetweenEvictionRunsMillis", timeBetweenEvictionRunsMillis.toString) 100 | 101 | var conn = 
SinglestoreConnectionPool.getConnection(properties) 102 | val conn1 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal 103 | conn.close() 104 | 105 | Thread.sleep(((minEvictableIdleTimeMillis + timeBetweenEvictionRunsMillis) * 1.1).toLong) 106 | 107 | conn = SinglestoreConnectionPool.getConnection(properties) 108 | val conn2 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal 109 | conn.close() 110 | 111 | assert(conn1 != conn2, "should evict idle connection") 112 | } 113 | 114 | it("maxConnLifetimeMillis") { 115 | val maxConnLifetimeMillis = 100 116 | properties.setProperty("maxConnLifetimeMillis", maxConnLifetimeMillis.toString) 117 | var conn = SinglestoreConnectionPool.getConnection(properties) 118 | val conn1 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal 119 | conn.close() 120 | 121 | Thread.sleep((maxConnLifetimeMillis * 1.1).toLong) 122 | 123 | conn = SinglestoreConnectionPool.getConnection(properties) 124 | val conn2 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal 125 | conn.close() 126 | 127 | assert(conn1 != conn2, "should not use a connection after end of lifetime") 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/SinglestoreOptionsTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap 4 | 5 | class SinglestoreOptionsTest extends IntegrationSuiteBase { 6 | val requiredOptions = Map("ddlEndpoint" -> "h:3306") 7 | 8 | describe("equality") { 9 | it("should sort dmlEndpoints") { 10 | assert( 11 | SinglestoreOptions( 12 | CaseInsensitiveMap( 13 | requiredOptions ++ Map("dmlEndpoints" -> "host1:3302,host2:3302,host1:3342")), 14 | spark.sparkContext) == 15 | SinglestoreOptions( 16 | CaseInsensitiveMap( 17 | requiredOptions ++ Map("dmlEndpoints" -> "host2:3302,host1:3302,host1:3342")), 18 | spark.sparkContext), 19 | "Should sort dmlEndpoints" 20 | ) 21 | } 22 | } 23 | 24 | describe("splitEscapedColumns") { 25 | it("empty string") { 26 | assert(SinglestoreOptions.splitEscapedColumns("") == List()) 27 | } 28 | 29 | it("3 columns") { 30 | assert( 31 | SinglestoreOptions.splitEscapedColumns("col1,col2,col3") == List("col1", "col2", "col3")) 32 | } 33 | 34 | it("with spaces") { 35 | assert( 36 | SinglestoreOptions 37 | .splitEscapedColumns(" col1 , col2, col3") == List(" col1 ", " col2", " col3")) 38 | } 39 | 40 | it("with backticks") { 41 | assert( 42 | SinglestoreOptions.splitEscapedColumns(" ` col1` , `col2`, `` col3") == List(" ` col1` ", 43 | " `col2`", 44 | " `` col3")) 45 | } 46 | 47 | it("with commas inside of backticks") { 48 | assert( 49 | SinglestoreOptions 50 | .splitEscapedColumns(" ` ,, col1,` , ``,```,col3`, `` col4,`,,`") == List( 51 | " ` ,, col1,` ", 52 | " ``", 53 | "```,col3`", 54 | " `` col4", 55 | "`,,`")) 56 | } 57 | } 58 | 59 | describe("trimAndUnescapeColumn") { 60 | it("empty string") { 61 | assert(SinglestoreOptions.trimAndUnescapeColumn("") == "") 62 | } 63 | 64 | it("spaces") { 65 | assert(SinglestoreOptions.trimAndUnescapeColumn(" ") == "") 66 | } 67 | 68 | it("in backticks") { 69 | assert(SinglestoreOptions.trimAndUnescapeColumn(" `asd` ") == "asd") 70 | } 71 | 72 | it("backticks in the result") { 73 | assert(SinglestoreOptions.trimAndUnescapeColumn(" ```a``sd` ") == "`a`sd") 74 | } 75 | 76 | 
it("several escaped words") { 77 | assert(SinglestoreOptions.trimAndUnescapeColumn(" ```a``sd` ```a``sd` ") == "`a`sd `a`sd") 78 | } 79 | 80 | it("backtick in the middle of string") { 81 | assert( 82 | SinglestoreOptions 83 | .trimAndUnescapeColumn(" a```a``sd` ```a``sd` ") == "a```a``sd` ```a``sd`") 84 | } 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/TestHelper.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import java.sql.SQLException 4 | 5 | import scala.annotation.tailrec 6 | 7 | object TestHelper { 8 | 9 | @tailrec 10 | def isSQLExceptionWithCode(e: Throwable, codes: List[Integer]): Boolean = e match { 11 | case e: SQLException if codes.contains(e.getErrorCode) => true 12 | case e if e.getCause != null => isSQLExceptionWithCode(e.getCause, codes) 13 | case e => 14 | e.printStackTrace() 15 | false 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/test/scala/com/singlestore/spark/VersionTest.scala: -------------------------------------------------------------------------------- 1 | package com.singlestore.spark 2 | 3 | import com.singlestore.spark.SQLGen.SinglestoreVersion 4 | import org.scalatest.funspec.AnyFunSpec 5 | 6 | class VersionTest extends AnyFunSpec { 7 | 8 | it("singlestore version test") { 9 | 10 | assert(SinglestoreVersion("7.0.1").atLeast("6.8.1")) 11 | assert(!SinglestoreVersion("6.8.1").atLeast("7.0.1")) 12 | assert(SinglestoreVersion("7.0.2").atLeast("7.0.1")) 13 | assert(SinglestoreVersion("7.0.10").atLeast("7.0.9")) 14 | assert(SinglestoreVersion("7.2.5").atLeast("7.1.99999")) 15 | assert(SinglestoreVersion("7.2.500").atLeast("7.2.499")) 16 | } 17 | } 18 | --------------------------------------------------------------------------------