├── .arcconfig
├── .circleci
│   └── config.yml
├── .gitignore
├── .idea
│   ├── codeStyles
│   │   ├── Project.xml
│   │   └── codeStyleConfig.xml
│   └── runConfigurations
│       ├── Test_Spark_3_0.xml
│       └── Test_Spark_3_1.xml
├── .java-version
├── .scalafmt.conf
├── CHANGELOG
├── LICENSE
├── Layerfile
├── README.md
├── build.sbt
├── ci
│   └── secring.asc.enc
├── demo
│   ├── Dockerfile
│   ├── README.md
│   └── notebook
│       ├── pyspark-singlestore-demo_2F8XQUKFG.zpln
│       ├── scala-singlestore-demo_2F6Y3APTX.zpln
│       └── spark-sql-singlestore-demo_2F7PZ81H6.zpln
├── project
│   ├── build.properties
│   └── plugins.sbt
├── scripts
│   ├── define-layerci-matrix.sh
│   ├── jwt
│   │   └── jwt_auth_config.json
│   ├── setup-cluster.sh
│   └── ssl
│       ├── test-ca-cert.pem
│       ├── test-ca-key.pem
│       ├── test-singlestore-cert.pem
│       └── test-singlestore-key.pem
└── src
    ├── main
    │   ├── resources
    │   │   └── META-INF
    │   │       └── services
    │   │           └── org.apache.spark.sql.sources.DataSourceRegister
    │   ├── scala-sparkv3.1
    │   │   └── spark
    │   │       ├── MaxNumConcurentTasks.scala
    │   │       ├── VersionSpecificAggregateExpressionExtractor.scala
    │   │       ├── VersionSpecificExpressionGen.scala
    │   │       ├── VersionSpecificUtil.scala
    │   │       └── VersionSpecificWindowBoundaryExpressionExtractor.scala
    │   ├── scala-sparkv3.2
    │   │   └── spark
    │   │       ├── MaxNumConcurrentTasks.scala
    │   │       ├── VersionSpecificAggregateExpressionExtractor.scala
    │   │       ├── VersionSpecificExpressionGen.scala
    │   │       ├── VersionSpecificUtil.scala
    │   │       └── VersionSpecificWindowBoundaryExpressionExtractor.scala
    │   ├── scala-sparkv3.3
    │   │   └── spark
    │   │       ├── MaxNumConcurrentTasks.scala
    │   │       ├── VersionSpecificAggregateExpressionExtractor.scala
    │   │       ├── VersionSpecificExpressionGen.scala
    │   │       ├── VersionSpecificUtil.scala
    │   │       └── VersionSpecificWindowBoundaryExpressionExtractor.scala
    │   ├── scala-sparkv3.4
    │   │   └── spark
    │   │       ├── MaxNumConcurrentTasks.scala
    │   │       ├── VersionSpecificAggregateExpressionExtractor.scala
    │   │       ├── VersionSpecificExpressionGen.scala
    │   │       ├── VersionSpecificUtil.scala
    │   │       └── VersionSpecificWindowBoundaryExpressionExtractor.scala
    │   ├── scala-sparkv3.5
    │   │   └── spark
    │   │       ├── MaxNumConcurrentTasks.scala
    │   │       ├── VersionSpecificAggregateExpressionExtractor.scala
    │   │       ├── VersionSpecificExpressionGen.scala
    │   │       ├── VersionSpecificUtil.scala
    │   │       └── VersionSpecificWindowBoundaryExpressionExtractor.scala
    │   └── scala
    │       └── com
    │           ├── memsql
    │           │   └── spark
    │           │       └── DefaultSource.scala
    │           └── singlestore
    │               └── spark
    │                   ├── AggregatorParallelReadListener.scala
    │                   ├── AvroSchemaHelper.scala
    │                   ├── CompletionIterator.scala
    │                   ├── DefaultSource.scala
    │                   ├── ExpressionGen.scala
    │                   ├── JdbcHelpers.scala
    │                   ├── LazyLogging.scala
    │                   ├── Loan.scala
    │                   ├── MetricsHandler.scala
    │                   ├── OverwriteBehavior.scala
    │                   ├── ParallelReadEnablement.scala
    │                   ├── ParallelReadType.scala
    │                   ├── SQLGen.scala
    │                   ├── SQLHelper.scala
    │                   ├── SQLPushdownRule.scala
    │                   ├── SinglestoreBatchInsertWriter.scala
    │                   ├── SinglestoreConnectionPool.scala
    │                   ├── SinglestoreConnectionPoolOptions.scala
    │                   ├── SinglestoreDialect.scala
    │                   ├── SinglestoreLoadDataWriter.scala
    │                   ├── SinglestoreOptions.scala
    │                   ├── SinglestorePartitioner.scala
    │                   ├── SinglestoreRDD.scala
    │                   ├── SinglestoreReader.scala
    │                   └── vendor
    │                       └── apache
    │                           ├── SchemaConverters.scala
    │                           └── third_party_license
    └── test
        ├── resources
        │   ├── data
        │   │   ├── movies.json
        │   │   ├── movies_rating.json
        │   │   ├── reviews.json
        │   │   └── users.json
        │   ├── log4j.properties
        │   ├── log4j2.properties
        │   └── mockito-extensions
        │       └── org.mockito.plugins.MockMaker
        └── scala
            └── com
                └── singlestore
                    └── spark
                        ├── BatchInsertBenchmark.scala
                        ├── BatchInsertTest.scala
                        ├── BenchmarkSerializingTest.scala
                        ├── BinaryTypeBenchmark.scala
                        ├── CustomDatatypesTest.scala
                        ├── ExternalHostTest.scala
                        ├── IntegrationSuiteBase.scala
                        ├── IssuesTest.scala
                        ├── LoadDataBenchmark.scala
                        ├── LoadDataTest.scala
                        ├── LoadbalanceTest.scala
                        ├── MaxErrorsTest.scala
                        ├── OutputMetricsTest.scala
                        ├── ReferenceTableTest.scala
                        ├── SQLHelperTest.scala
                        ├── SQLOverwriteTest.scala
                        ├── SQLPermissionsTest.scala
                        ├── SQLPushdownTest.scala
                        ├── SanityTest.scala
                        ├── SinglestoreConnectionPoolTest.scala
                        ├── SinglestoreOptionsTest.scala
                        ├── TestHelper.scala
                        └── VersionTest.scala
/.arcconfig:
--------------------------------------------------------------------------------
1 | {
2 | "project_id" : "memsql-spark-connector",
3 | "conduit_uri" : "https:\/\/grizzly.internal.memcompute.com\/api\/"
4 | }
5 |
--------------------------------------------------------------------------------
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 | commands:
3 | setup_environment:
4 | description: "Setup the machine environment"
5 | parameters:
6 | sbt_version:
7 | type: string
8 | default: 1.3.6
9 | steps:
10 | - run:
11 | name: Setup Machine
12 | command: |
13 | sudo apt update
14 | sudo update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/bin/java
15 | sudo apt install -y curl
16 | sudo wget https://github.com/sbt/sbt/releases/download/v<< parameters.sbt_version >>/sbt-<< parameters.sbt_version >>.tgz
17 | sudo tar xzvf sbt-<< parameters.sbt_version >>.tgz -C /usr/share/
18 | sudo rm sbt-<< parameters.sbt_version >>.tgz
19 | sudo update-alternatives --install /usr/bin/sbt sbt /usr/share/sbt/bin/sbt 100
20 | sudo apt-get update
21 | sudo apt-get install -y python-pip git mariadb-client-core-10.6
22 | sudo apt-get clean
23 | sudo apt-get autoclean
24 |
25 | jobs:
26 | test:
27 | parameters:
28 | spark_version:
29 | type: string
30 | singlestore_image:
31 | type: string
32 | machine: true
33 | resource_class: large
34 | environment:
35 | SINGLESTORE_IMAGE: << parameters.singlestore_image >>
36 | SINGLESTORE_PORT: 5506
37 | SINGLESTORE_USER: root
38 | SINGLESTORE_DB: test
39 | JAVA_HOME: /usr/lib/jvm/java-11-openjdk-amd64/
40 | CONTINUOUS_INTEGRATION: true
41 | SBT_OPTS: "-Xmx256M"
42 | steps:
43 | - setup_environment
44 | - checkout
45 | - run:
46 | name: Setup test cluster
47 | command: ./scripts/setup-cluster.sh
48 | - run:
49 | name: Run tests
50 | command: |
51 | export SINGLESTORE_HOST=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' singlestore-integration)
52 | if [ << parameters.spark_version >> == '3.1.3' ]
53 | then
54 | sbt ++2.12.12 "testOnly -- -l ExcludeFromSpark31" -Dspark.version=<< parameters.spark_version >>
55 | elif [ << parameters.spark_version >> == '3.2.4' ]
56 | then
57 | sbt ++2.12.12 "testOnly -- -l ExcludeFromSpark32" -Dspark.version=<< parameters.spark_version >>
58 | elif [ << parameters.spark_version >> == '3.3.4' ]
59 | then
60 | sbt ++2.12.12 "testOnly -- -l ExcludeFromSpark33" -Dspark.version=<< parameters.spark_version >>
61 | elif [ << parameters.spark_version >> == '3.4.2' ]
62 | then
63 | sbt ++2.12.12 "testOnly -- -l ExcludeFromSpark34" -Dspark.version=<< parameters.spark_version >>
64 | else
65 | sbt ++2.12.12 "testOnly -- -l ExcludeFromSpark35" -Dspark.version=<< parameters.spark_version >>
66 | fi
67 |
68 | publish:
69 | machine: true
70 | environment:
71 | JAVA_HOME: /usr/lib/jvm/java-11-openjdk-amd64/
72 | SONATYPE_USERNAME: memsql
73 | steps:
74 | - setup_environment
75 | - checkout
76 | - run:
77 | name: Import GPG key
78 | command: |
79 | openssl enc -d -aes-256-cbc -K ${ENCRYPTION_KEY} -iv ${ENCRYPTION_IV} -in ci/secring.asc.enc -out ci/secring.asc
80 | gpg --import ci/secring.asc
81 | - run:
82 | name: Publish Spark 3.2.4
83 | command: |
84 | sbt ++2.12.12 -Dspark.version=3.2.4 clean publishSigned sonatypeBundleRelease
85 | - run:
86 | name: Publish Spark 3.1.3
87 | command: |
88 | sbt ++2.12.12 -Dspark.version=3.1.3 clean publishSigned sonatypeBundleRelease
89 | - run:
90 | name: Publish Spark 3.3.4
91 | command: |
92 | sbt ++2.12.12 -Dspark.version=3.3.4 clean publishSigned sonatypeBundleRelease
93 | - run:
94 | name: Publish Spark 3.4.2
95 | command: |
96 | sbt ++2.12.12 -Dspark.version=3.4.2 clean publishSigned sonatypeBundleRelease
97 | - run:
98 | name: Publish Spark 3.5.0
99 | command: |
100 | sbt ++2.12.12 -Dspark.version=3.5.0 clean publishSigned sonatypeBundleRelease
101 |
102 | workflows:
103 | test:
104 | jobs:
105 | - test:
106 | filters:
107 | tags:
108 | only: /^v.*/
109 | branches:
110 | ignore: /.*/
111 | matrix:
112 | parameters:
113 | spark_version:
114 | - 3.1.3
115 | - 3.2.4
116 | - 3.3.4
117 | - 3.4.2
118 | - 3.5.0
119 | singlestore_image:
120 | - singlestore/cluster-in-a-box:alma-8.0.19-f48780d261-4.0.11-1.16.0
121 | - singlestore/cluster-in-a-box:alma-8.1.32-e3d3cde6da-4.0.16-1.17.6
122 | - singlestore/cluster-in-a-box:alma-8.5.22-fe61f40cd1-4.1.0-1.17.11
123 | - singlestore/cluster-in-a-box:alma-8.7.12-483e5f8acb-4.1.0-1.17.15
124 | publish:
125 | jobs:
126 | - approve-publish:
127 | type: approval
128 | filters:
129 | tags:
130 | only: /^v.*/
131 | branches:
132 | ignore: /.*/
133 | - publish:
134 | requires:
135 | - approve-publish
136 | filters:
137 | tags:
138 | only: /^v.*/
139 | branches:
140 | ignore: /.*/
141 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
2 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
3 |
4 | # User-specific stuff
5 | /.idea/*
6 | !/.idea/codeStyles/*
7 | !/.idea/runConfigurations/*
8 |
9 | # JIRA plugin
10 | atlassian-ide-plugin.xml
11 |
12 | # IntelliJ
13 | /out/
14 | /target/
15 |
16 | # mpeltonen/sbt-idea plugin
17 | /.idea_modules/
18 |
19 | # File-based project format
20 | /*.iws
21 |
22 | # sbt project stuff
23 | /project/*
24 | !/project/build.properties
25 | !/project/plugins.sbt
26 | /target
27 | /build
28 | /spark-warehouse
29 |
30 | /ci/secring.asc
31 |
--------------------------------------------------------------------------------
/.idea/codeStyles/Project.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/runConfigurations/Test_Spark_3_0.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/.idea/runConfigurations/Test_Spark_3_1.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/.java-version:
--------------------------------------------------------------------------------
1 | 1.8
2 |
--------------------------------------------------------------------------------
/.scalafmt.conf:
--------------------------------------------------------------------------------
1 | maxColumn = 100
2 | align = more
--------------------------------------------------------------------------------
/CHANGELOG:
--------------------------------------------------------------------------------
1 | 2024-11-22 Version 4.1.10
2 | * Updated JDBC driver to 1.2.7
3 |
4 | 2024-11-22 Version 4.1.9
5 | * Changed to work with Databricks runtime 15.4
6 |
7 | 2024-06-14 Version 4.1.8
8 | * Changed retry during reading from result table to use exponential backoff
9 | * Used ForkJoinPool instead of FixedThreadPool
10 | * Added more logging
11 |
12 | 2024-05-13 Version 4.1.7
13 | * Fixed bug that caused reading from the wrong result table when the task was restarted
14 |
15 | 2024-04-11 Version 4.1.6
16 | * Changed LoadDataWriter to send data in batches
17 | * Added numPartitions parameter to specify the exact number of resulting partitions during parallel read
18 |
19 | 2023-10-05 Version 4.1.5
20 | * Added support of Spark 3.5
21 | * Updated dependencies
22 |
23 | 2023-07-18 Version 4.1.4
24 | * Added support of Spark 3.4
25 | * Added connection attributes
26 | * Fixed conflicts of result table names during parallel read
27 | * Updated version of the SingleStore JDBC driver
28 |
29 | 2023-03-31 Version 4.1.3
30 | * Updated version of the SingleStore JDBC driver
31 | * Fixed error handling when `onDuplicateKeySQL` option is used
32 |
33 | 2023-02-21 Version 4.1.2
34 | * Fixed an issue that would cause a `Table has reached its quota of 1 reader(s)` error to be displayed when a parallel read was retried
35 |
36 | 2022-07-13 Version 4.1.1
37 | * Added clientEndpoint option for Cloud deployments of SingleStoreDB
38 | * Fixed bug in the error handling that caused a deadlock
39 | * Added support of Spark 3.3
40 |
41 | 2022-06-22 Version 4.1.0
42 | * Added support of more SQL expressions in pushdown
43 | * Added multi-partition support to the parallel read
44 | * Updated SingleStore JDBC Driver to 1.1.0
45 | * Added JWT authentication
46 | * Added connection pooling
47 |
48 | 2022-01-20 Version 4.0.0
49 | * Changed connector to use SingleStore JDBC Driver instead of MariaDB JDBC Driver
50 |
51 | 2021-12-23 Version 3.2.2
52 | * Added the ability to repartition the result by columns during parallel read from aggregators
53 | * Replaced usages of `transformDown` with `transform` in order to make connector work with Databricks 9.1 LTS
54 |
55 | 2021-12-14 Version 3.2.1
56 | * Added support of Spark 3.2
57 | * Fixed links in the README
58 |
59 | 2021-11-29 Version 3.2.0
60 | * Added support for reading in parallel from aggregator nodes instead of leaf nodes
61 |
62 | 2021-09-16 Version 3.1.3
63 | * Added Spark 3.1 support
64 | * Deleted Spark 2.3 and 2.4 support
65 |
66 | 2021-04-29 Version 3.1.2
67 | * Changed `useParallelRead` to use the external host and port by default
68 |
69 | 2021-02-05 Version 3.1.1
70 | * Added support of `com.memsql.spark` data source name for backward compatibility
71 |
72 | 2021-01-22 Version 3.1.0
73 | * Rebranded `memsql-spark-connector` to `singlestore-spark-connector`
74 | * Spark data source format changed from `memsql` to `singlestore`
75 | * Configuration prefix changed from `spark.datasource.memsql.` to `spark.datasource.singlestore.`
76 |
77 | 2020-10-19 Version 3.0.5
78 | * Fixed bug with load-balanced connections to the DML endpoint
79 |
80 | 2020-09-29 Version 3.1.0-beta
81 | * Added Spark 3.0 support
82 | * Fixed bugs in pushdowns
83 | * Fixed bug with wrong SQL code generation for attribute names that contain special characters
84 | * Added methods that allow you to run SQL queries on a MemSQL database directly
85 |
86 | 2020-08-20 Version 3.0.4
87 | * Added trim pushdown
88 |
89 | 2020-08-14 Version 3.0.3
90 | * Fixed bug with pushdown of the join condition
91 |
92 | 2020-08-03 Version 3.0.2
93 | * Added maxErrors option
94 | * Changed aliases in SQL queries to be more deterministic
95 | * Disabled comments inside SQL queries when the logging level is not TRACE
96 |
97 | 2020-06-12 Version 3.0.1
98 | * The connector now updates task metrics with the number of records written during write operations
99 |
100 | 2020-05-27 Version 3.0.0
101 | * Introduces SQL Optimization & Rewrite for most query shapes and compatible expressions
102 | * Implemented as a native Spark SQL plugin
103 | * Supports both the DataSource and DataSourceV2 API for maximum support of current and future functionality
104 | * Contains deep integrations with the Catalyst query optimizer
105 | * Is compatible with Spark 2.3 and 2.4
106 | * Leverages MemSQL LOAD DATA to accelerate ingest from Spark via compression, vectorized cpu instructions, and optimized segment sizes
107 | * Takes advantage of all the latest and greatest features in MemSQL 7.x
108 |
109 | 2020-05-06 Version 3.0.0-rc1
110 | * Support writing into MemSQL reference tables
111 | * Deprecated truncate option in favor of overwriteBehavior
112 | * New option overwriteBehavior allows you to specify how to overwrite or merge rows during ingest
113 | * The Ignore SaveMode now correctly skips all duplicate key errors during ingest
114 |
115 | 2020-04-30 Version 3.0.0-beta12
116 | * Improved performance of new batch insert functionality for `ON DUPLICATE KEY UPDATE` feature
117 |
118 | 2020-04-30 Version 3.0.0-beta11
119 | * Added support for merging rows on ingest via `ON DUPLICATE KEY UPDATE`
120 | * Added docker-based demo for running a Zeppelin notebook using the Spark connector
121 |
122 | 2020-04-20 Version 3.0.0-beta10
123 | * Additional functions supported in SQL Pushdown: toUnixTimestamp, unixTimestamp, nextDay, dateDiff, monthsAdd, hypot, rint
124 | * Now tested against MemSQL 6.7, and all tests use SSL
125 | * Fixed bug with disablePushdown
126 |
127 | 2020-04-09 Version 3.0.0-beta9
128 | * Add null handling to address Spark bug which causes incorrect handling of null literals (https://issues.apache.org/jira/browse/SPARK-31403)
129 |
130 | 2020-04-01 Version 3.0.0-beta8
131 | * Added support for more datetime expressions:
132 | * addition/subtraction of datetime objects
133 | * to_utc_timestamp, from_utc_timestamp
134 | * date_trunc, trunc
135 |
136 | 2020-03-25 Version 3.0.0-beta7
137 | * The connector now respects column selection when loading dataframes into MemSQL
138 |
139 | 2020-03-24 Version 3.0.0-beta6
140 | * Fix bug when you use an expression in an explicit query
141 |
142 | 2020-03-23 Version 3.0.0-beta5
143 | * Increase connection timeout to increase connector reliability
144 |
145 | 2020-03-20 Version 3.0.0-beta4
146 | * Set JDBC driver to MariaDB explicitly to avoid issues with the mysql driver
147 |
148 | 2020-03-19 Version 3.0.0-beta3
149 | * Created tables default to Columnstore
150 | * User can override keys attached to new tables
151 | * New parallelRead option which enables reading directly from MemSQL leaf nodes
152 | * Created tables now set case-sensitive collation on all columns
153 | to match Spark semantics
154 | * More SQL expressions supported in pushdown (tanh, sinh, cosh)
155 |
156 | 2020-02-08 Version 3.0.0-beta2
157 | * Removed options: masterHost and masterPort
158 | * Added ddlEndpoint and ddlEndpoints options
159 | * Added path option to support specifying the dbtable via `.load("mytable")` when creating a dataframe
160 |
161 | 2020-01-30 Version 3.0.0-beta
162 | * Full re-write of the Spark Connector
163 |
164 | 2019-02-27 Version 2.0.7
165 | * Add support for EXPLAIN JSON in MemSQL versions 6.7 and later to fix partition pushdown.
166 |
167 | 2018-09-14 Version 2.0.6
168 | * Force utf-8 encoding when loading data into MemSQL
169 |
170 | 2018-01-18 Version 2.0.5
171 | * Explicitly sort MemSQLRDD partitions due to MemSQL 6.0 no longer returning partitions in sorted order by ordinal.
172 |
173 | 2017-08-31 Version 2.0.4
174 | * Switch threads in LoadDataStrategy so that the parent thread reads from the RDD and the new thread writes
175 | to MemSQL so that Spark has access to the thread-local variables it expects
176 |
177 | 2017-07-19 Version 2.0.3
178 | * Handle special characters in column names in queries
179 | * Add option to enable jdbc connector to stream result sets row-by-row
180 | * Fix groupby queries incorrectly pushed down to leaves
181 | * Add option to write to master aggregator only
182 | * Add support for reading MemSQL columns of type unsigned bigint and unsigned int
183 |
184 | 2017-04-17
185 | * Pull MemSQL configuration from runtime configuration in sparkSession.conf instead of static config in sparkContext
186 | * Fix connection pooling bug where extraneous connections were created
187 | * Add MemSQL configuration to disable partition pushdown
188 |
189 | 2017-02-06 Version 2.0.1
190 | * Fixed bug to enable partition pushdown for MemSQL DataFrames loaded from a custom user query
191 |
192 | 2017-02-01 Version 2.0.0
193 | * Compatible with Apache Spark 2.0.0+
194 | * Removed experimental strategy SQL pushdown to instead use the more stable Data Sources API for reading
195 | data from MemSQL
196 | * Removed memsql-spark-interface, memsql-etl
197 |
198 | 2015-12-15 Version 1.2.1
199 | * Python support for extractors and transformers
200 | * More extensive SQL pushdown for DataFrame operations
201 | * Use DataFrames as common interface between extractor, transformer, and loader
202 | * Rewrite connectorLib internals to support SparkSQL relation provider API
203 | * Remove RDD.saveToMemSQL
204 |
205 | 2015-11-19 Version 1.1.1
206 | * Set JDBC login timeout to 10 seconds
207 |
208 | 2015-11-02 Version 1.1.0
209 |
210 | * Available on Maven Central Repository
211 | * More events for batches
212 | * Deprecated the old Kafka extractor and replaced it with a new one that takes in a Zookeeper quorum address
213 | * Added a new field to pipeline API responses indicating whether or not a pipeline is currently running
214 | * Renamed projects: memsqlsparkinterface -> memsql-spark-interface, memsqletl -> memsql-etl, memsqlrdd -> memsql-connector.
215 | * Robustness and bug fixes
216 |
217 | 2015-09-24 Version 1.0.0
218 |
219 | * Initial release of MemSQL Streamliner
220 |
--------------------------------------------------------------------------------
/Layerfile:
--------------------------------------------------------------------------------
1 | FROM vm/ubuntu:18.04
2 |
3 | # install curl, python and mysql-client
4 | RUN sudo apt update && \
5 | sudo apt install -y curl python-pip mysql-client-core-5.7
6 |
7 | # install sbt
8 | RUN wget https://github.com/sbt/sbt/releases/download/v1.3.5/sbt-1.3.5.tgz && \
9 | sudo tar xzvf sbt-1.3.5.tgz -C /usr/share/ && \
10 | sudo update-alternatives --install /usr/bin/sbt sbt /usr/share/sbt/bin/sbt 100
11 |
12 | # install the latest version of Docker
13 | RUN apt-get update && \
14 | apt-get install apt-transport-https ca-certificates curl software-properties-common && \
15 | curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add - && \
16 | add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu bionic stable" && \
17 | apt-get update && \
18 | apt install docker-ce
19 |
20 | # install java
21 | RUN apt-get update && \
22 | sudo apt-get install openjdk-8-jdk
23 |
24 | # set environment variables
25 | ENV MEMSQL_PORT=5506
26 | ENV MEMSQL_USER=root
27 | ENV MEMSQL_DB=test
28 | ENV JAVA_HOME=/usr/lib/jvm/jdk1.8.0
29 | ENV CONTINUOUS_INTEGRATION=true
30 | ENV SBT_OPTS=-Xmx1g
31 | ENV SBT_OPTS=-Xms1g
32 | SECRET ENV LICENSE_KEY
33 | SECRET ENV SINGLESTORE_PASSWORD
34 | SECRET ENV SINGLESTORE_JWT_PASSWORD
35 |
36 | # increase the memory
37 | MEMORY 4G
38 | MEMORY 8G
39 | MEMORY 12G
40 | MEMORY 16G
41 |
42 | # split into 21 states
43 | # each of them runs a different combination of SingleStore and Spark versions
44 | SPLIT 21
45 |
46 | # copy the entire git repository
47 | COPY . .
48 |
49 | # setup split specific env variables
50 | RUN scripts/define-layerci-matrix.sh >> ~/.profile
51 |
52 | # start singlestore cluster
53 | RUN ./scripts/setup-cluster.sh
54 |
55 | # run tests
56 | RUN sbt ++$SCALA_VERSION -Dspark.version=$SPARK_VERSION "${TEST_FILTER}"
57 |
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | import xerial.sbt.Sonatype._
2 |
3 | /*
4 | To run tests or publish with a specific Spark version, pass this java option:
5 | -Dspark.version=3.5.0 (supported: 3.1.3, 3.2.4, 3.3.4, 3.4.2, 3.5.0)
6 | */
7 | val sparkVersion = sys.props.get("spark.version").getOrElse("3.1.3")
8 | val scalaVersionStr = "2.12.12"
9 | val scalaVersionPrefix = scalaVersionStr.substring(0, 4)
10 | val jacksonDatabindVersion = sparkVersion match {
11 | case "3.1.3" => "2.10.0"
12 | case "3.2.4" => "2.12.3"
13 | case "3.3.4" => "2.13.4.2"
14 | case "3.4.2" => "2.14.2"
15 | case "3.5.0" => "2.15.2"
16 | }
17 |
18 | lazy val root = project
19 | .withId("singlestore-spark-connector")
20 | .in(file("."))
21 | .enablePlugins(BuildInfoPlugin)
22 | .settings(
23 | name := "singlestore-spark-connector",
24 | organization := "com.singlestore",
25 | scalaVersion := scalaVersionStr,
26 | Compile / unmanagedSourceDirectories += (Compile / sourceDirectory).value / (sparkVersion match {
27 | case "3.1.3" => "scala-sparkv3.1"
28 | case "3.2.4" => "scala-sparkv3.2"
29 | case "3.3.4" => "scala-sparkv3.3"
30 | case "3.4.2" => "scala-sparkv3.4"
31 | case "3.5.0" => "scala-sparkv3.5"
32 | }),
33 | version := s"4.1.10-spark-${sparkVersion}",
34 | licenses += "Apache-2.0" -> url(
35 | "http://opensource.org/licenses/Apache-2.0"
36 | ),
37 | resolvers += "Spark Packages Repo" at "https://dl.bintray.com/spark-packages/maven",
38 | libraryDependencies ++= Seq(
39 | // runtime dependencies
40 | "org.apache.spark" %% "spark-core" % sparkVersion % "provided, test",
41 | "org.apache.spark" %% "spark-sql" % sparkVersion % "provided, test",
42 | "org.apache.avro" % "avro" % "1.11.3",
43 | "org.apache.commons" % "commons-dbcp2" % "2.7.0",
44 | "org.scala-lang.modules" %% "scala-java8-compat" % "0.9.0",
45 | "com.singlestore" % "singlestore-jdbc-client" % "1.2.7",
46 | "io.spray" %% "spray-json" % "1.3.5",
47 | "io.netty" % "netty-buffer" % "4.1.70.Final",
48 | "org.apache.commons" % "commons-dbcp2" % "2.9.0",
49 | // test dependencies
50 | "org.mariadb.jdbc" % "mariadb-java-client" % "2.+" % Test,
51 | "org.scalatest" %% "scalatest" % "3.1.0" % Test,
52 | "org.scalacheck" %% "scalacheck" % "1.14.1" % Test,
53 | "org.mockito" %% "mockito-scala" % "1.16.37" % Test,
54 | "com.github.mrpowers" %% "spark-fast-tests" % "0.21.3" % Test,
55 | "com.github.mrpowers" %% "spark-daria" % "0.38.2" % Test
56 | ),
57 | dependencyOverrides += "com.fasterxml.jackson.core" % "jackson-databind" % jacksonDatabindVersion,
58 | Test / testOptions += Tests.Argument("-oF"),
59 | Test / fork := true,
60 | buildInfoKeys := Seq[BuildInfoKey](version),
61 | buildInfoPackage := "com.singlestore.spark"
62 | )
63 |
64 | credentials += Credentials(
65 | "GnuPG Key ID",
66 | "gpg",
67 | "CDD996495CF08BB2041D86D8D1EB3D14F1CD334F",
68 | "ignored" // this field is ignored; passwords are supplied by pinentry
69 | )
70 |
71 | assemblyMergeStrategy in assembly := {
72 | case PathList("META-INF", _*) => MergeStrategy.discard
73 | case _ => MergeStrategy.first
74 | }
75 |
76 | publishTo := sonatypePublishToBundle.value
77 | publishMavenStyle := true
78 | sonatypeSessionName := s"[sbt-sonatype] ${name.value} ${version.value}"
79 | sonatypeProjectHosting := Some(GitHubHosting("memsql", "memsql-spark-connector", "carl@memsql.com"))
80 |
--------------------------------------------------------------------------------
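A consuming sbt project pins both the Scala suffix and the Spark-specific artifact version produced by the settings above (organization `com.singlestore`, name `singlestore-spark-connector`, version `4.1.10-spark-<sparkVersion>`). A minimal sketch, with the Spark 3.5.0 artifact chosen as an example:

```scala
// build.sbt of a project that depends on the published connector.
// Pick the "-spark-…" suffix that matches the Spark version on your cluster.
libraryDependencies += "com.singlestore" %% "singlestore-spark-connector" % "4.1.10-spark-3.5.0"
```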
/ci/secring.asc.enc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/memsql/singlestore-spark-connector/febf9e5a854334233a2bbb8fdc52ad619bf3e13b/ci/secring.asc.enc
--------------------------------------------------------------------------------
/demo/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM apache/zeppelin:0.9.0
2 |
3 | ENV SPARK_VERSION=3.5.0
4 |
5 | USER root
6 |
7 | RUN wget https://apache.ip-connect.vn.ua/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop2.7.tgz
8 | RUN tar xf spark-${SPARK_VERSION}-bin-hadoop2.7.tgz -C /
9 | RUN rm -rf spark-${SPARK_VERSION}-bin-hadoop2.7.tgz
10 | ENV SPARK_HOME=/spark-${SPARK_VERSION}-bin-hadoop2.7
11 | ENV ZEPPELIN_PORT=8082
12 | RUN rm -rf /zeppelin/notebook/*
13 |
14 | EXPOSE ${ZEPPELIN_PORT}/tcp
15 |
--------------------------------------------------------------------------------
/demo/README.md:
--------------------------------------------------------------------------------
1 | ## singlestore-spark-connector demo
2 |
3 | This is a Dockerfile that uses the upstream [Zeppelin Image](https://hub.docker.com/r/apache/zeppelin/) as its base
4 | and ships three notebooks with examples of the singlestore-spark-connector.
5 |
6 | To run this image together with [MemSQL CIAB](https://hub.docker.com/r/memsql/cluster-in-a-box), follow the instructions below:
7 |
8 | * Create a docker network so that zeppelin and memsql-ciab can reach each other
9 | ```
10 | docker network create zeppelin-ciab-network
11 | ```
12 |
13 | * Pull memsql-ciab docker image
14 | ```
15 | docker pull memsql/cluster-in-a-box
16 | ```
17 |
18 | * Run and start the SingleStore Cluster in a Box docker container
19 |
20 | ```
21 | docker run -i --init \
22 | --name singlestore-ciab-for-zeppelin \
23 | -e LICENSE_KEY=[INPUT_YOUR_LICENSE_KEY] \
24 | -e ROOT_PASSWORD=my_password \
25 | -p 3306:3306 -p 8081:8080 \
26 | --net=zeppelin-ciab-network \
27 | memsql/cluster-in-a-box
28 | ```
29 | ```
30 | docker start singlestore-ciab-for-zeppelin
31 | ```
32 | > :note: at this step you may hit a port collision error
33 | >
34 | > ```
35 | > docker: Error response from daemon: driver failed programming external connectivity on endpoint singlestore-ciab-for-zeppelin
36 | > (38b0df3496f1ec83f120242a53a7023d8a0b74db67f5e487fb23641983c67a76):
37 | > Bind for 0.0.0.0:8080 failed: port is already allocated.
38 | > ERRO[0000] error waiting for container: context canceled
39 | > ```
40 | >
41 | > If that happens, remove the container
42 | >
43 | >`docker rm singlestore-ciab-for-zeppelin`
44 | >
45 | > and re-run the first command with different ports `-p {new_port1}:3306 -p {new_port2}:8080`
46 |
47 | * Build zeppelin docker image in `singlestore-spark-connector/demo` folder
48 |
49 | ```
50 | docker build -t zeppelin .
51 | ```
52 |
53 | * Run zeppelin docker container
54 | ```
55 | docker run -d --init \
56 | --name zeppelin \
57 | -p 8082:8082 \
58 | --net=zeppelin-ciab-network \
59 | -v $PWD/notebook:/opt/zeppelin/notebook/singlestore \
60 | -v $PWD/notebook:/zeppelin/notebook/singlestore \
61 | zeppelin
62 | ```
63 |
64 | > :note: at this step you may hit a port collision error
65 | >
66 | > ```
67 | > docker: Error response from daemon: driver failed programming external connectivity on endpoint zeppelin
68 | > (38b0df3496f1ec83f120242a53a7023d8a0b74db67f5e487fb23641983c67a76):
69 | > Bind for 0.0.0.0:8082 failed: port is already allocated.
70 | > ERRO[0000] error waiting for container: context canceled
71 | > ```
72 | >
73 | > If that happens, remove the container
74 | >
75 | >`docker rm zeppelin`
76 | >
77 | > and re-run this command with a different port `-p {new_port}:8082`
78 |
79 |
80 | * Open [zeppelin](http://localhost:8082/next) in your browser and try the
81 | [scala](http://localhost:8082/next/#/notebook/2F8XQUKFG),
82 | [pyspark](http://localhost:8082/next/#/notebook/2F6Y3APTX)
83 | and [spark sql](http://localhost:8082/next/#/notebook/2F7PZ81H6) notebooks
84 |
85 | To set up a more powerful SingleStore trial cluster, use [SingleStore Managed Service](https://www.singlestore.com/managed-service/)
86 |
--------------------------------------------------------------------------------
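For reference, a minimal sketch of the kind of read/write the Scala notebook exercises against the CIAB container. The endpoint, credentials, database, and table names below are placeholders (matching the `docker run` flags above), not values copied from the notebooks:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .appName("singlestore-demo")
  // Placeholder endpoint/credentials: the CIAB container name resolves inside
  // the zeppelin-ciab-network, and my_password is the ROOT_PASSWORD set above.
  .config("spark.datasource.singlestore.ddlEndpoint", "singlestore-ciab-for-zeppelin:3306")
  .config("spark.datasource.singlestore.user", "root")
  .config("spark.datasource.singlestore.password", "my_password")
  .getOrCreate()

import spark.implicits._

// Write a small DataFrame into SingleStore, then read it back.
val people = Seq(("alice", 1), ("bob", 2)).toDF("name", "id")
people.write
  .format("singlestore")
  .mode("overwrite")
  .save("demo.people") // database.table

spark.read
  .format("singlestore")
  .load("demo.people")
  .show()
```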
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=1.3.8
2 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.9.6")
2 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.1")
3 | addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.11.0")
4 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.2.0")
5 |
--------------------------------------------------------------------------------
/scripts/define-layerci-matrix.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -eu
3 |
4 | SINGLESTORE_IMAGE_TAGS=(
5 | "alma-8.0.19-f48780d261-4.0.11-1.16.0"
6 | "alma-8.1.32-e3d3cde6da-4.0.16-1.17.6"
7 | "alma-8.5.22-fe61f40cd1-4.1.0-1.17.11"
8 | "alma-8.7.12-483e5f8acb-4.1.0-1.17.15"
9 | )
10 | SINGLESTORE_IMAGE_TAGS_COUNT=${#SINGLESTORE_IMAGE_TAGS[@]}
11 | SPARK_VERSIONS=(
12 | "3.5.0"
13 | "3.4.2"
14 | "3.3.4"
15 | "3.2.4"
16 | "3.1.3"
17 | )
18 | SPARK_VERSIONS_COUNT=${#SPARK_VERSIONS[@]}
19 |
20 | TEST_NUM=${SPLIT:-"0"}
21 |
22 | SINGLESTORE_IMAGE_TAG_INDEX=$(( $TEST_NUM / $SPARK_VERSIONS_COUNT))
23 | SINGLESTORE_IMAGE_TAG_INDEX=$((SINGLESTORE_IMAGE_TAG_INDEX>=SINGLESTORE_IMAGE_TAGS_COUNT ? SINGLESTORE_IMAGE_TAGS_COUNT-1 : SINGLESTORE_IMAGE_TAG_INDEX))
24 | SINGLESTORE_IMAGE_TAG=${SINGLESTORE_IMAGE_TAGS[SINGLESTORE_IMAGE_TAG_INDEX]}
25 |
26 | SPARK_VERSION_INDEX=$(( $TEST_NUM % $SPARK_VERSIONS_COUNT))
27 | SPARK_VERSION=${SPARK_VERSIONS[SPARK_VERSION_INDEX]}
28 |
29 | if [ $TEST_NUM == $(($SINGLESTORE_IMAGE_TAGS_COUNT*$SPARK_VERSIONS_COUNT)) ]
30 | then
31 | echo 'export FORCE_READ_FROM_LEAVES=TRUE'
32 | else
33 | echo 'export FORCE_READ_FROM_LEAVES=FALSE'
34 | fi
35 |
36 | echo "export SINGLESTORE_IMAGE='singlestore/cluster-in-a-box:$SINGLESTORE_IMAGE_TAG'"
37 | echo "export SPARK_VERSION='$SPARK_VERSION'"
38 | echo "export TEST_FILTER='testOnly -- -l ExcludeFromSpark${SPARK_VERSION:0:1}${SPARK_VERSION:2:1}'"
39 | echo "export SCALA_VERSION='2.12.12'"
40 |
--------------------------------------------------------------------------------
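The split-to-matrix mapping above is plain integer division and modulo over the two arrays, with the extra 21st split (index 20) clamped to the last image and used to force reads from leaves. A small Scala sketch of the same arithmetic, with the array contents copied from the script, to make the pairing explicit:

```scala
// Mirrors scripts/define-layerci-matrix.sh: SPLIT index -> (image tag, Spark version, force-leaves flag).
object LayerciMatrix {
  val imageTags = Seq(
    "alma-8.0.19-f48780d261-4.0.11-1.16.0",
    "alma-8.1.32-e3d3cde6da-4.0.16-1.17.6",
    "alma-8.5.22-fe61f40cd1-4.1.0-1.17.11",
    "alma-8.7.12-483e5f8acb-4.1.0-1.17.15"
  )
  val sparkVersions = Seq("3.5.0", "3.4.2", "3.3.4", "3.2.4", "3.1.3")

  def entry(testNum: Int): (String, String, Boolean) = {
    val imageIdx = math.min(testNum / sparkVersions.size, imageTags.size - 1)
    val sparkIdx = testNum % sparkVersions.size
    // The 21st split (index 20) reuses the last image and sets FORCE_READ_FROM_LEAVES=TRUE.
    val forceReadFromLeaves = testNum == imageTags.size * sparkVersions.size
    (imageTags(imageIdx), sparkVersions(sparkIdx), forceReadFromLeaves)
  }
}

// e.g. LayerciMatrix.entry(7) == ("alma-8.1.32-e3d3cde6da-4.0.16-1.17.6", "3.3.4", false)
```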
/scripts/jwt/jwt_auth_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "username_claim": "username",
3 | "methods": [
4 | {
5 | "algorithms": [ "RS384" ],
6 | "secret": "-----BEGIN PUBLIC KEY-----\nMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA0i0dDauX6iaOaocic99O\nUTruTYPWFUv50aHTgfKxenFKpJTTL43T8ON36whwyObM3r/ayhPoyPxvSfkkCwxd\nE7XEmTRARHkJKQfebkbN6SaKlEgIdmZ8UroZCslSzOcsX0N1KNc3WSyFeOigHSp/\nww+roVtaJC/OJQ95kMjIGdN3ooO5g/YvZJZTn9KQ/dFmNDPaSyseT9/MCE2Rp0g0\nT2yCwewxVdfR+D4QcicaLat7CAFXMnoSxV9ifGXYkv6JE33dc95U4BgPYECca2QA\nNe3ZQHSNxC1rc+uim3cgcn6PP4WKTgTG4u74F9xA8FbumZUIMB7rChsr+E8Z4/Iq\nSGb7/y0J6Auho7BDJPL7ZFE9peuSA3NudZlpkH+GIWRW7fBY7qu1Koh8kGfZeg4k\nLIucgT0CpSvrxaDVqRSEIiqy4zoczrLXcVJN9ThtxVPJZCVTW2dir8aCwO9Sk60W\nQ8DP4wRWhnqa+Irsd/2r5c7QgSSOSo4/EIBU65qP5oA6xUw3F4sE7G+Q9ofijxxJ\n8iGr2Y7SBk5ztvSYAX/ZNfoUJZQ0cXDCSCKus8a9kQQ9zCNBLco9pIv7XUcJf1bj\nzfz1o911OZn/mdcgqUq0uRne8x73J/Z4uIsaVubvcSgv0XTSF8qYgpgRd016+nAa\nFtyVd9Y5xmYgIIHaXzQ3l7MCAwEAAQ==\n-----END PUBLIC KEY-----"
7 | }
8 | ]
9 | }
10 |
--------------------------------------------------------------------------------
/scripts/setup-cluster.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -eu
3 |
4 | # this script must be run from the top-level of the repo
5 | cd "$(git rev-parse --show-toplevel)"
6 |
7 | DEFAULT_IMAGE_NAME="singlestore/cluster-in-a-box:alma-8.0.15-0b9b66384f-4.0.11-1.15.2"
8 | IMAGE_NAME="${SINGLESTORE_IMAGE:-$DEFAULT_IMAGE_NAME}"
9 | CONTAINER_NAME="singlestore-integration"
10 |
11 | EXISTS=$(docker inspect ${CONTAINER_NAME} >/dev/null 2>&1 && echo 1 || echo 0)
12 |
13 | if [[ "${EXISTS}" -eq 1 ]]; then
14 | EXISTING_IMAGE_NAME=$(docker inspect -f '{{.Config.Image}}' ${CONTAINER_NAME})
15 | if [[ "${IMAGE_NAME}" != "${EXISTING_IMAGE_NAME}" ]]; then
16 | echo "Existing container ${CONTAINER_NAME} has image ${EXISTING_IMAGE_NAME} when ${IMAGE_NAME} is expected; recreating container."
17 | docker rm -f ${CONTAINER_NAME}
18 | EXISTS=0
19 | fi
20 | fi
21 |
22 | if [[ "${EXISTS}" -eq 0 ]]; then
23 | docker run -i --init \
24 | --name ${CONTAINER_NAME} \
25 | -v ${PWD}/scripts/ssl:/test-ssl \
26 | -v ${PWD}/scripts/jwt:/test-jwt \
27 | -e LICENSE_KEY=${LICENSE_KEY} \
28 | -e ROOT_PASSWORD=${SINGLESTORE_PASSWORD} \
29 | -p 5506:3306 -p 5507:3307 -p 5508:3308 \
30 | ${IMAGE_NAME}
31 | fi
32 |
33 | docker start ${CONTAINER_NAME}
34 |
35 | singlestore-wait-start() {
36 | echo -n "Waiting for SingleStore to start..."
37 | while true; do
38 | if mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" -e "select 1" >/dev/null 2>/dev/null; then
39 | break
40 | fi
41 | echo -n "."
42 | sleep 0.2
43 | done
44 | echo ". Success!"
45 | }
46 |
47 | singlestore-wait-start
48 |
49 | if [[ "${EXISTS}" -eq 0 ]]; then
50 | echo
51 | echo "Creating aggregator node"
52 | docker exec -it ${CONTAINER_NAME} memsqlctl create-node --yes --password ${SINGLESTORE_PASSWORD} --port 3308
53 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key minimum_core_count --value 0
54 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key minimum_memory_mb --value 0
55 | docker exec -it ${CONTAINER_NAME} memsqlctl start-node --yes --all
56 | docker exec -it ${CONTAINER_NAME} memsqlctl add-aggregator --yes --host 127.0.0.1 --password ${SINGLESTORE_PASSWORD} --port 3308
57 | fi
58 |
59 | echo
60 | echo "Setting up SSL"
61 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key ssl_ca --value /test-ssl/test-ca-cert.pem
62 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key ssl_cert --value /test-ssl/test-singlestore-cert.pem
63 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key ssl_key --value /test-ssl/test-singlestore-key.pem
64 | echo "Setting up JWT"
65 | docker exec -it ${CONTAINER_NAME} memsqlctl update-config --yes --all --key jwt_auth_config_file --value /test-jwt/jwt_auth_config.json
66 | echo "Restarting cluster"
67 | docker exec -it ${CONTAINER_NAME} memsqlctl restart-node --yes --all
68 | singlestore-wait-start
69 | echo "Setting up root-ssl user"
70 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" -e 'create user "root-ssl"@"%" require ssl'
71 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" -e 'grant all privileges on *.* to "root-ssl" with grant option'
72 | mysql -u root -h 127.0.0.1 -P 5507 -p"${SINGLESTORE_PASSWORD}" -e 'create user "root-ssl"@"%" require ssl'
73 | mysql -u root -h 127.0.0.1 -P 5507 -p"${SINGLESTORE_PASSWORD}" -e 'grant all privileges on *.* to "root-ssl" with grant option'
74 | mysql -u root -h 127.0.0.1 -P 5508 -p"${SINGLESTORE_PASSWORD}" -e 'grant all privileges on *.* to "root-ssl" with grant option'
75 | echo "Done!"
76 | echo "Setting up root-jwt user"
77 | mysql -h 127.0.0.1 -u root -P 5506 -p"${SINGLESTORE_PASSWORD}" -e "CREATE USER 'test_jwt_user' IDENTIFIED WITH authentication_jwt"
78 | mysql -h 127.0.0.1 -u root -P 5506 -p"${SINGLESTORE_PASSWORD}" -e "GRANT ALL PRIVILEGES ON *.* TO 'test_jwt_user'@'%'"
79 | echo "Done!"
80 |
81 | echo
82 | echo "Ensuring child nodes are connected using container IP"
83 | CONTAINER_IP=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' ${CONTAINER_NAME})
84 | CURRENT_LEAF_IP=$(mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e 'select host from information_schema.leaves')
85 | if [[ ${CONTAINER_IP} != "${CURRENT_LEAF_IP}" ]]; then
86 | # remove leaf with current ip
87 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e "remove leaf '${CURRENT_LEAF_IP}':3307"
88 | # add leaf with correct ip
89 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e "add leaf root:'${SINGLESTORE_PASSWORD}'@'${CONTAINER_IP}':3307"
90 | fi
91 | CURRENT_AGG_IP=$(mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e 'select host from information_schema.aggregators where master_aggregator=0')
92 | if [[ ${CONTAINER_IP} != "${CURRENT_AGG_IP}" ]]; then
93 | # remove aggregator with current ip
94 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e "remove aggregator '${CURRENT_AGG_IP}':3308"
95 | # add aggregator with correct ip
96 | mysql -u root -h 127.0.0.1 -P 5506 -p"${SINGLESTORE_PASSWORD}" --batch -N -e "add aggregator root:'${SINGLESTORE_PASSWORD}'@'${CONTAINER_IP}':3308"
97 | fi
98 | echo "Done!"
99 |
--------------------------------------------------------------------------------
/scripts/ssl/test-ca-cert.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIDCzCCAfOgAwIBAgIULJ/CPUCImfaUkxiXSMAXBKWvtnIwDQYJKoZIhvcNAQEL
3 | BQAwFDESMBAGA1UEAwwJdGVzdC1yb290MCAXDTIwMDQxMzIyMjYzOFoYDzMwMTkw
4 | ODE1MjIyNjM4WjAUMRIwEAYDVQQDDAl0ZXN0LXJvb3QwggEiMA0GCSqGSIb3DQEB
5 | AQUAA4IBDwAwggEKAoIBAQDQWWfkNMUG11C4INpkOf0E/Ao5pLfVIrAWVGRTQe+s
6 | KDx9w1UaLiKx21vWXWOHPz8bKvQ1OOhPP85r+ZOdbmsrOMyWF+kMZEgs8pUsG8kk
7 | ztO+nRWskAN0qCD9kLPcN0exqda7S+we/1WOBVUv92Qvc3ip6anwlNkRmCGDznP8
8 | p0q8vxj/cmPMhfwuq3RLvkJ3buBGboJBNkphTTR3K3iy40XiG25U3YEnjspaEpqs
9 | QO2JfZBRA4Jq/OTDfSPH16/JINsHkpe4GFgP78133DS7fFZs6HEonxMlr0cvgJGk
10 | JBJIcOZZYezTpgIlTETfNZeiJybTQ/0IX7KSugV/Dnt/AgMBAAGjUzBRMB0GA1Ud
11 | DgQWBBTneAQV+LCbwSS0EJ4u0ZvaUDJREjAfBgNVHSMEGDAWgBTneAQV+LCbwSS0
12 | EJ4u0ZvaUDJREjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAW
13 | sdYs9+C+GC81/dXMBstZDol9AQ2871hGHsjFaomk0/pxvCvgv+oLlHEjCyHaykcJ
14 | AY3hurFwyKMMXaR/AD5WcIe8Khk1wrzxoJhXyVIYmgE8cIdsDrA0D9ggILI0DmbH
15 | E2vZjAUdiDeS0IMv2AF1kprbPzA0lM/0Lm1DLLubPYs6oijOpVx8tyGxEesxqF8I
16 | KcIBk64WYkBNQFL7vy5fHI2S2N8D7COBybuabojCMKqUetGLw2BfSuOCOVAviwBK
17 | 5WMaFqULkAsc9GuMVvoddujK8B/yGQ1Nim4NTDdRGivtsSwphw+tNo+VRFLk/PK2
18 | VgeZz4rnrlh0Iog1f+iZ
19 | -----END CERTIFICATE-----
20 |
--------------------------------------------------------------------------------
/scripts/ssl/test-ca-key.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN RSA PRIVATE KEY-----
2 | MIIEpAIBAAKCAQEA0Fln5DTFBtdQuCDaZDn9BPwKOaS31SKwFlRkU0HvrCg8fcNV
3 | Gi4isdtb1l1jhz8/Gyr0NTjoTz/Oa/mTnW5rKzjMlhfpDGRILPKVLBvJJM7Tvp0V
4 | rJADdKgg/ZCz3DdHsanWu0vsHv9VjgVVL/dkL3N4qemp8JTZEZghg85z/KdKvL8Y
5 | /3JjzIX8Lqt0S75Cd27gRm6CQTZKYU00dyt4suNF4htuVN2BJ47KWhKarEDtiX2Q
6 | UQOCavzkw30jx9evySDbB5KXuBhYD+/Nd9w0u3xWbOhxKJ8TJa9HL4CRpCQSSHDm
7 | WWHs06YCJUxE3zWXoicm00P9CF+ykroFfw57fwIDAQABAoIBAQCbNd91W/JjNEfH
8 | w4GuJJze97vOUW05dAvltpy+gWJAyAC4V6mwRSpHgPibaxrYCD/Ex20BsREu6IOo
9 | YFadc0KXAks2jT1po9M42MZUA6cGqqWHXJJm6SoJ364j94ZlyTC5o6J6CQcv2Fst
10 | 378kapHR353GRnH47Yn/12swO76gOc7LLC8LqoWwMAnCjbas63rhBpdeBWymB1s5
11 | 8TbZn6d3ViHJuIt/wGuWvrDvc9Gi0GjQaUA+KhbdBLiL6THdFA9w1RUKIBoOFB7x
12 | 2rN/BL1cCARcrz31QpbLfQgHeFdXxj9iK15WOumSSD+iNWZoA16QxvzYDCL2Pnk9
13 | 0PyTFZcxAoGBAPvMnOfdwOmjXRBGJurYzj8O8amRqphvtRt76zZhjXG7wzdvMCmb
14 | WP77odIYtw97zNmIiS0p91b4O4xmByr6V74+8qIHM09xqB2+hqBft020vEGXNoOa
15 | Qq/hcAT6Nb0kLIMHOxs68CnPWmH7opjg28MaDXq1ScgHl7I7H4U2ZclFAoGBANPT
16 | Od7fLQ9uGnMaw5mIZ14mcb0gUQt+yFTXiZvI4XbAmuK6o2zzmG/4bBMktqOkTc1h
17 | JEnx1Zmeoqbbhv5cBQA03Y6HF/rREO0nfEFrc8YVGZN2ZqtZOwHR/cekobNkPckU
18 | CdGHVF4X6D10u/1S63tuxA7ON9A1s7YntTWq/CPzAoGATxXtEkZsGPXefQYLoyeF
19 | X/jpnkDKPCaZ05AQSHxLWLWIkxixH+BTC4MtSDfLB2ny5UAlFbJgpUhCK86/4ZfP
20 | h0luG8X3L7SbAPyefDCT+iwSFOfRj3QcDfHYpTeROV7rPBxBTEQuunMOCEhowWue
21 | mqDMKwZVriX0V16Kf+SeA6ECgYA6/+RojWTxnUtEsDm28+VGthKMCQpJ12BZMUek
22 | 2ojiGLeLW0zVtevJlDoWAu3UGpmJEPuYlQFXrnXDX/XztxG1gwQLBNnLBJxgUdUs
23 | K4+tpobfKeVi6JGk6iZziwl2+/6xmSE6+SSoqKQJKhCKeKQaVznIneux1KNfoyO3
24 | 9Q4RvQKBgQClP/fHEcdgsla47IHmwJi5+YzuX6y9gAJ6Taa+Fy6NawOrwE1c4d0o
25 | lOtwbl/Jxvjqew02kdm/R9LEQqIblzTT0S7o9n+1+TZ6AGacMEKbR3lNtgIKHIdl
26 | aer4g9JTw+Ob42RcRuS5CaKttZSxMYw6wgAxAnKM+zs5TaGK0jOL7g==
27 | -----END RSA PRIVATE KEY-----
28 |
--------------------------------------------------------------------------------
/scripts/ssl/test-singlestore-cert.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIICpzCCAY8CAQEwDQYJKoZIhvcNAQELBQAwFDESMBAGA1UEAwwJdGVzdC1yb290
3 | MCAXDTIwMDQxMzIyMjkyNloYDzMwMTkwODE1MjIyOTI2WjAdMRswGQYDVQQDDBJ0
4 | ZXN0LW1lbXNxbC1zZXJ2ZXIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB
5 | AQCUsvUKgAd4A2ggbGts/8L9y9mkS1z0R5oFvhuAu6zpqNSO0vQqkzY22mm6VxyD
6 | Di6vBECUTd+wQk8r1puU7c//H12okvxmZ2cB5UlfQQdFdfvzXJ7eUCEaHwIB7GUy
7 | 1nCwOFV1HT8CT+MCPZsS0SntsLOQfGwhkxsQ5qRhxcGfg5MoJh+Ew8b9CipLov8c
8 | Aryzai9aTW7NUOscuAaTrCuZyDdp1A940atQ/RZcaGyN1AueRZNiQYf4Kx8tff/D
9 | koAH293O86VbN4uj7IVKD7eZI9vDebRpAH2WuP82azLEHbQ4Zu9g2/Xf+4x2x1wl
10 | Xg0t4lxPntcNBLnX0b7PosBzAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAL4ohV3q
11 | V04vJg71/CG58tRpPfv9WAWr8yaiUedOxAKpG4vkQcN+eUXHneEXVWaPz8w4cCbV
12 | rIRDXpRoch33bjsKLMbKhnMOjv1NeNvPeIlrzXZjhMS4tfW0aAtAQAKdQCOZHfnZ
13 | GC1TY26c3YnLv2NvHwJ4PExgcFE7Ex3ej6Z0lMz971UoCgfm4zfb3ag4R9CBroAO
14 | GyqMteSxeXG1LOijXMaShjrba7moeqcClrxiV3oM7p84aGBHLlBiY1IvqVJzGTAx
15 | JwcISD0Ao3iRnrnQOhmC1ryQ8lYFFnmrxt4oVHQpUpanWNVV++SHQmO8dzRmTnFR
16 | xxVSDwHkhUt4m94=
17 | -----END CERTIFICATE-----
18 |
--------------------------------------------------------------------------------
/scripts/ssl/test-singlestore-key.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN RSA PRIVATE KEY-----
2 | MIIEogIBAAKCAQEAlLL1CoAHeANoIGxrbP/C/cvZpEtc9EeaBb4bgLus6ajUjtL0
3 | KpM2Ntppulccgw4urwRAlE3fsEJPK9ablO3P/x9dqJL8ZmdnAeVJX0EHRXX781ye
4 | 3lAhGh8CAexlMtZwsDhVdR0/Ak/jAj2bEtEp7bCzkHxsIZMbEOakYcXBn4OTKCYf
5 | hMPG/QoqS6L/HAK8s2ovWk1uzVDrHLgGk6wrmcg3adQPeNGrUP0WXGhsjdQLnkWT
6 | YkGH+CsfLX3/w5KAB9vdzvOlWzeLo+yFSg+3mSPbw3m0aQB9lrj/NmsyxB20OGbv
7 | YNv13/uMdsdcJV4NLeJcT57XDQS519G+z6LAcwIDAQABAoIBAEcLeagao3bjqcxU
8 | AL+DM1avHr0whKjxzNURj3JiOKsqzuOuRppQ24Y5tGojVKwJCqT0EybITieYhtsb
9 | Hhp5xPbPtZ/lGlKS9NQjCHtKRn8Zb9dGWWE+R5KDXiItH+y6J/0J7UqXPpOMN5nK
10 | dVz4MmAuHJzb1Y31Cul4SPGt2mSrbiNcIrghTfTI7iCpk/+70zJdxIYrw/x35DMl
11 | rzbiNIuUnQDYeEa7I2SCZAWMwpKget9B3S3fgnA1WKQuNKlKJDT5X9PVfzPcrhzV
12 | IWs0mxq+vCJx7AVqwSMJyIV1ijSR6rU+atr/zJJJv9VRofhUodx11Kjm3yXGFmQ8
13 | bHUu9QECgYEAw5avzPpyrZSXwFclCXt+sjiC/cpaYQ8NFLjSZWexuL5/peGCS+cF
14 | fGSPJnkSGSPWB5s350g/rt4zmAOPqGgBlvjofGczRBqFGO3x+3+pzxgLhPMlkP43
15 | ko33r3J07dWwVPe2KxvbDKoQrhiE10H11uIu1Tfs9k2/uUZLAzFHQCkCgYEAwqCu
16 | 6h34nsl0FzQA/tascCLbwa95Da0IDPrGFZ95X8aryPPNcB4hpuGn0TaHDW8zJV8K
17 | j9kmvBA5H5z9cXV6CeZJrW4eJcmPgUi5RvFUuF7fYSNpFLTE8K1b1mCw8P927WMm
18 | f0EfdY4qCKDU1rIJa6iD5x6Imy6t0HnTg54iHzsCgYBOZa8PzW98DiyJhySsWVje
19 | XPJ8gciaUOsgXDjRNrAw6gLGXc7ZV7+GLdSHSk4rz4ZxxBCzXu1PzXcGvp6tlQrW
20 | Fe0yODd/W9XvuSiec3yAKxYq8z8ikBN8ZfVa2NjvoBCu7h+RxfeWavCGqANPOPwu
21 | Zrj49BLCY0WvIPLeU7lIiQKBgDNuDoqjHN2o0lqHTXQJ+ksviu6lldF9VdFIOyvf
22 | lk0uzJovgqwL6kyU+KmaRRnRtqw7bykP8uJjTxUBgR+IMZWIGxQPMzw9BQTe2Mbc
23 | YszNlS2wE8Z69ke7J7eAmYE1oJGeT7/0z4Fa7dSV22hYZ5DhWOmr8eE/9oJOjwwK
24 | r22dAoGARD2MUPMV7f/98RA9X7XfXNG1yVBDxxGMJLGIDYhExtV6hni0nvC3tzBx
25 | lPvVOUUC54Ho7AuQoQqSUgDQisaAymdEawDYhfwrC5PMqC63dpl6D//vUgAHJvkX
26 | Fwz071Rvr520Yv5yDGS47DzhxImGWK/Wn2ZFkfQzZvtzGBnmHh0=
27 | -----END RSA PRIVATE KEY-----
28 |
--------------------------------------------------------------------------------
/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister:
--------------------------------------------------------------------------------
1 | com.singlestore.spark.DefaultSource
2 | com.memsql.spark.DefaultSource
3 |
--------------------------------------------------------------------------------
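These two service entries are what let Spark discover the connector through Java's ServiceLoader: `singlestore` is the current short format name, while the `com.memsql.spark` class is kept only for backward compatibility (see the 3.1.1 CHANGELOG entry). A sketch, assuming a session that already carries the `spark.datasource.singlestore.*` connection options and a placeholder table name:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

// Both of these resolve to the DefaultSource classes registered above; "db.table" is a placeholder.
def readBothWays(spark: SparkSession): (DataFrame, DataFrame) =
  (spark.read.format("singlestore").load("db.table"),
   spark.read.format("com.memsql.spark").load("db.table"))
```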
/src/main/scala-sparkv3.1/spark/MaxNumConcurentTasks.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.scheduler
2 |
3 | import org.apache.spark.rdd.RDD
4 |
5 | object MaxNumConcurrentTasks {
6 | def get(rdd: RDD[_]): Int = {
7 | val (_, resourceProfiles) =
8 | rdd.sparkContext.dagScheduler.getShuffleDependenciesAndResourceProfiles(rdd)
9 | val resourceProfile =
10 | rdd.sparkContext.dagScheduler.mergeResourceProfilesForStage(resourceProfiles)
11 | rdd.sparkContext.maxNumConcurrentTasks(resourceProfile)
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.1/spark/VersionSpecificAggregateExpressionExtractor.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement}
4 | import com.singlestore.spark.ExpressionGen.{aggregateWithFilter, f, op}
5 | import org.apache.spark.sql.catalyst.expressions.aggregate.{
6 | AggregateFunction,
7 | Average,
8 | First,
9 | Kurtosis,
10 | Last,
11 | Skewness,
12 | StddevPop,
13 | StddevSamp,
14 | Sum,
15 | VariancePop,
16 | VarianceSamp
17 | }
18 |
19 | case class VersionSpecificAggregateExpressionExtractor(expressionExtractor: ExpressionExtractor,
20 | context: SQLGenContext,
21 | filter: Option[SQLGen.Joinable]) {
22 | def unapply(aggFunc: AggregateFunction): Option[Statement] = {
23 | aggFunc match {
24 | // CentralMomentAgg.scala
25 | case StddevPop(expressionExtractor(child), true) =>
26 | Some(aggregateWithFilter("STDDEV_POP", child, filter))
27 | case StddevSamp(expressionExtractor(child), true) =>
28 | Some(aggregateWithFilter("STDDEV_SAMP", child, filter))
29 | case VariancePop(expressionExtractor(child), true) =>
30 | Some(aggregateWithFilter("VAR_POP", child, filter))
31 | case VarianceSamp(expressionExtractor(child), true) =>
32 | Some(aggregateWithFilter("VAR_SAMP", child, filter))
33 | case Kurtosis(expressionExtractor(child), true) =>
34 | // ( (AVG(POW(child, 4)) - AVG(child) * AVG(POW(child, 3)) * 4 + 6 * AVG(POW(child, 2)) * POW(AVG(child), 2) - 3 * POW(AVG(child), 4) )
35 | // / POW(STD(child), 4) ) - 3
36 | // following the formula from https://stats.oarc.ucla.edu/other/mult-pkg/faq/general/faq-whats-with-the-different-formulas-for-kurtosis/ article
37 | Some(
38 | op(
39 | "-",
40 | op(
41 | "/",
42 | op(
43 | "-",
44 | op(
45 | "+",
46 | op(
47 | "-",
48 | aggregateWithFilter("AVG", f("POW", child, "4"), filter),
49 | op("*",
50 | op("*",
51 | aggregateWithFilter("AVG", child, filter),
52 | aggregateWithFilter("AVG", f("POW", child, "3"), filter)),
53 | "4")
54 | ),
55 | op("*",
56 | "6",
57 | op("*",
58 | aggregateWithFilter("AVG", f("POW", child, "2"), filter),
59 | f("POW", aggregateWithFilter("AVG", child, filter), "2")))
60 | ),
61 | op("*", "3", f("POW", aggregateWithFilter("AVG", child, filter), "4"))
62 | ),
63 | f("POW", aggregateWithFilter("STD", child, filter), "4")
64 | ),
65 | "3"
66 | )
67 | )
68 |
69 | case Skewness(expressionExtractor(child), true) =>
70 | // (AVG(POW(child, 3)) - AVG(child) * POW(STD(child), 2) * 3 - POW(AVG(child), 3) ) / POW(STD(child), 3)
71 | // following the definition section in https://en.wikipedia.org/wiki/Skewness
72 | Some(
73 | op(
74 | "/",
75 | op(
76 | "-",
77 | op(
78 | "-",
79 | aggregateWithFilter("AVG", f("POW", child, "3"), filter),
80 | op("*",
81 | op("*",
82 | aggregateWithFilter("AVG", child, filter),
83 | f("POW", aggregateWithFilter("STD", child, filter), "2")),
84 | "3")
85 | ),
86 | f("POW", aggregateWithFilter("AVG", child, filter), "3")
87 | ),
88 | f("POW", aggregateWithFilter("STD", child, filter), "3")
89 | )
90 | )
91 |
92 | // First.scala
93 | case First(expressionExtractor(child), false) =>
94 | Some(aggregateWithFilter("ANY_VALUE", child, filter))
95 |
96 | // Last.scala
97 | case Last(expressionExtractor(child), false) =>
98 | Some(aggregateWithFilter("ANY_VALUE", child, filter))
99 |
100 | // Sum.scala
101 | case Sum(expressionExtractor(child)) =>
102 | Some(aggregateWithFilter("SUM", child, filter))
103 |
104 | // Average.scala
105 | case Average(expressionExtractor(child)) =>
106 | Some(aggregateWithFilter("AVG", child, filter))
107 |
108 | case _ => None
109 | }
110 | }
111 | }
112 |
--------------------------------------------------------------------------------
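The Kurtosis and Skewness cases above rewrite Spark's central-moment aggregates into aggregates SingleStore can evaluate (AVG, STD, POW). Written out, with μ = AVG(child) and σ = STD(child) taken as population moments, the identities the generated SQL relies on are:

```latex
% Fourth central moment expanded into raw moments, then excess kurtosis:
E[(X-\mu)^4] = E[X^4] - 4\mu\,E[X^3] + 6\mu^2 E[X^2] - 3\mu^4
\qquad
\mathrm{Kurt}[X] = \frac{E[(X-\mu)^4]}{\sigma^4} - 3

% Third central moment, then skewness:
E[(X-\mu)^3] = E[X^3] - 3\mu\sigma^2 - \mu^3
\qquad
\mathrm{Skew}[X] = \frac{E[(X-\mu)^3]}{\sigma^3}
```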
/src/main/scala-sparkv3.1/spark/VersionSpecificUtil.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import org.apache.spark.sql.types.{CalendarIntervalType, DataType}
4 |
5 | object VersionSpecificUtil {
6 | def isIntervalType(d:DataType): Boolean =
7 | d.isInstanceOf[CalendarIntervalType]
8 | }
9 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.1/spark/VersionSpecificWindowBoundaryExpressionExtractor.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement}
4 | import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus}
5 |
6 | case class VersionSpecificWindowBoundaryExpressionExtractor(
7 | expressionExtractor: ExpressionExtractor) {
8 | def unapply(arg: Expression): Option[Statement] = {
9 | arg match {
10 | case UnaryMinus(expressionExtractor(child), false) =>
11 | Some(child + "PRECEDING")
12 | case _ => None
13 | }
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.2/spark/MaxNumConcurrentTasks.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.scheduler
2 |
3 | import org.apache.spark.rdd.RDD
4 |
5 | object MaxNumConcurrentTasks {
6 | def get(rdd: RDD[_]): Int = {
7 | val (_, resourceProfiles) =
8 | rdd.sparkContext.dagScheduler.getShuffleDependenciesAndResourceProfiles(rdd)
9 | val resourceProfile =
10 | rdd.sparkContext.dagScheduler.mergeResourceProfilesForStage(resourceProfiles)
11 | rdd.sparkContext.maxNumConcurrentTasks(resourceProfile)
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.2/spark/VersionSpecificAggregateExpressionExtractor.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement}
4 | import com.singlestore.spark.ExpressionGen.{aggregateWithFilter, f, op}
5 | import org.apache.spark.sql.catalyst.expressions.aggregate.{
6 | AggregateFunction,
7 | Average,
8 | First,
9 | Kurtosis,
10 | Last,
11 | Skewness,
12 | StddevPop,
13 | StddevSamp,
14 | Sum,
15 | VariancePop,
16 | VarianceSamp
17 | }
18 |
19 | case class VersionSpecificAggregateExpressionExtractor(expressionExtractor: ExpressionExtractor,
20 | context: SQLGenContext,
21 | filter: Option[SQLGen.Joinable]) {
22 | def unapply(aggFunc: AggregateFunction): Option[Statement] = {
23 | aggFunc match {
24 | // CentralMomentAgg.scala
25 | case StddevPop(expressionExtractor(child), true) =>
26 | Some(aggregateWithFilter("STDDEV_POP", child, filter))
27 | case StddevSamp(expressionExtractor(child), true) =>
28 | Some(aggregateWithFilter("STDDEV_SAMP", child, filter))
29 | case VariancePop(expressionExtractor(child), true) =>
30 | Some(aggregateWithFilter("VAR_POP", child, filter))
31 | case VarianceSamp(expressionExtractor(child), true) =>
32 | Some(aggregateWithFilter("VAR_SAMP", child, filter))
33 | case Kurtosis(expressionExtractor(child), true) =>
34 | // ( (AVG(POW(child, 4)) - AVG(child) * AVG(POW(child, 3)) * 4 + 6 * AVG(POW(child, 2)) * POW(AVG(child), 2) - 3 * POW(AVG(child), 4) )
35 | // / POW(STD(child), 4) ) - 3
36 | // following the formula from https://stats.oarc.ucla.edu/other/mult-pkg/faq/general/faq-whats-with-the-different-formulas-for-kurtosis/ article
37 | Some(
38 | op(
39 | "-",
40 | op(
41 | "/",
42 | op(
43 | "-",
44 | op(
45 | "+",
46 | op(
47 | "-",
48 | aggregateWithFilter("AVG", f("POW", child, "4"), filter),
49 | op("*",
50 | op("*",
51 | aggregateWithFilter("AVG", child, filter),
52 | aggregateWithFilter("AVG", f("POW", child, "3"), filter)),
53 | "4")
54 | ),
55 | op("*",
56 | "6",
57 | op("*",
58 | aggregateWithFilter("AVG", f("POW", child, "2"), filter),
59 | f("POW", aggregateWithFilter("AVG", child, filter), "2")))
60 | ),
61 | op("*", "3", f("POW", aggregateWithFilter("AVG", child, filter), "4"))
62 | ),
63 | f("POW", aggregateWithFilter("STD", child, filter), "4")
64 | ),
65 | "3"
66 | )
67 | )
68 |
69 | case Skewness(expressionExtractor(child), true) =>
70 | // (AVG(POW(child, 3)) - AVG(child) * POW(STD(child), 2) * 3 - POW(AVG(child), 3) ) / POW(STD(child), 3)
71 | // following the definition section in https://en.wikipedia.org/wiki/Skewness
72 | Some(
73 | op(
74 | "/",
75 | op(
76 | "-",
77 | op(
78 | "-",
79 | aggregateWithFilter("AVG", f("POW", child, "3"), filter),
80 | op("*",
81 | op("*",
82 | aggregateWithFilter("AVG", child, filter),
83 | f("POW", aggregateWithFilter("STD", child, filter), "2")),
84 | "3")
85 | ),
86 | f("POW", aggregateWithFilter("AVG", child, filter), "3")
87 | ),
88 | f("POW", aggregateWithFilter("STD", child, filter), "3")
89 | )
90 | )
91 |
92 | // First.scala
93 | case First(expressionExtractor(child), false) =>
94 | Some(aggregateWithFilter("ANY_VALUE", child, filter))
95 |
96 | // Last.scala
97 | case Last(expressionExtractor(child), false) =>
98 | Some(aggregateWithFilter("ANY_VALUE", child, filter))
99 |
100 | // Sum.scala
101 | case Sum(expressionExtractor(child), false) =>
102 | Some(aggregateWithFilter("SUM", child, filter))
103 |
104 | // Average.scala
105 | case Average(expressionExtractor(child), false) =>
106 | Some(aggregateWithFilter("AVG", child, filter))
107 |
108 | case _ => None
109 | }
110 | }
111 | }
112 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.2/spark/VersionSpecificUtil.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import org.apache.spark.sql.types.{
4 | CalendarIntervalType,
5 | DataType,
6 | DayTimeIntervalType,
7 | YearMonthIntervalType
8 | }
9 |
10 | object VersionSpecificUtil {
11 | def isIntervalType(d: DataType): Boolean =
12 | d.isInstanceOf[CalendarIntervalType] || d.isInstanceOf[DayTimeIntervalType] || d
13 | .isInstanceOf[YearMonthIntervalType]
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.2/spark/VersionSpecificWindowBoundaryExpressionExtractor.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement}
4 | import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus}
5 |
6 | case class VersionSpecificWindowBoundaryExpressionExtractor(
7 | expressionExtractor: ExpressionExtractor) {
8 | def unapply(arg: Expression): Option[Statement] = {
9 | arg match {
10 | case UnaryMinus(expressionExtractor(child), false) =>
11 | Some(child + "PRECEDING")
12 | case _ => None
13 | }
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.3/spark/MaxNumConcurrentTasks.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.scheduler
2 |
3 | import org.apache.spark.rdd.RDD
4 |
5 | object MaxNumConcurrentTasks {
6 | def get(rdd: RDD[_]): Int = {
7 | val (_, resourceProfiles) =
8 | rdd.sparkContext.dagScheduler.getShuffleDependenciesAndResourceProfiles(rdd)
9 | val resourceProfile =
10 | rdd.sparkContext.dagScheduler.mergeResourceProfilesForStage(resourceProfiles)
11 | rdd.sparkContext.maxNumConcurrentTasks(resourceProfile)
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.3/spark/VersionSpecificAggregateExpressionExtractor.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement}
4 | import com.singlestore.spark.ExpressionGen.{aggregateWithFilter, f, op}
5 | import org.apache.spark.sql.catalyst.expressions.aggregate.{
6 | AggregateFunction,
7 | Average,
8 | First,
9 | Kurtosis,
10 | Last,
11 | Skewness,
12 | StddevPop,
13 | StddevSamp,
14 | Sum,
15 | VariancePop,
16 | VarianceSamp
17 | }
18 |
19 | case class VersionSpecificAggregateExpressionExtractor(expressionExtractor: ExpressionExtractor,
20 | context: SQLGenContext,
21 | filter: Option[SQLGen.Joinable]) {
22 | def unapply(aggFunc: AggregateFunction): Option[Statement] = {
23 | aggFunc match {
24 | // CentralMomentAgg.scala
25 | case StddevPop(expressionExtractor(child), true) =>
26 | Some(aggregateWithFilter("STDDEV_POP", child, filter))
27 | case StddevSamp(expressionExtractor(child), true) =>
28 | Some(aggregateWithFilter("STDDEV_SAMP", child, filter))
29 | case VariancePop(expressionExtractor(child), true) =>
30 | Some(aggregateWithFilter("VAR_POP", child, filter))
31 | case VarianceSamp(expressionExtractor(child), true) =>
32 | Some(aggregateWithFilter("VAR_SAMP", child, filter))
33 | case Kurtosis(expressionExtractor(child), true) =>
34 |           // ( (AVG(POW(child, 4)) - AVG(child) * AVG(POW(child, 3)) * 4 + 6 * AVG(POW(child, 2)) * POW(AVG(child), 2) - 3 * POW(AVG(child), 4) )
35 |           // / POW(STD(child), 4) ) - 3
36 |           // following the formula from the article at https://stats.oarc.ucla.edu/other/mult-pkg/faq/general/faq-whats-with-the-different-formulas-for-kurtosis/
37 | Some(
38 | op(
39 | "-",
40 | op(
41 | "/",
42 | op(
43 | "-",
44 | op(
45 | "+",
46 | op(
47 | "-",
48 | aggregateWithFilter("AVG", f("POW", child, "4"), filter),
49 | op("*",
50 | op("*",
51 | aggregateWithFilter("AVG", child, filter),
52 | aggregateWithFilter("AVG", f("POW", child, "3"), filter)),
53 | "4")
54 | ),
55 | op("*",
56 | "6",
57 | op("*",
58 | aggregateWithFilter("AVG", f("POW", child, "2"), filter),
59 | f("POW", aggregateWithFilter("AVG", child, filter), "2")))
60 | ),
61 | op("*", "3", f("POW", aggregateWithFilter("AVG", child, filter), "4"))
62 | ),
63 | f("POW", aggregateWithFilter("STD", child, filter), "4")
64 | ),
65 | "3"
66 | )
67 | )
68 |
69 | case Skewness(expressionExtractor(child), true) =>
70 | // (AVG(POW(child, 3)) - AVG(child) * POW(STD(child), 2) * 3 - POW(AVG(child), 3) ) / POW(STD(child), 3)
71 | // following the definition section in https://en.wikipedia.org/wiki/Skewness
72 | Some(
73 | op(
74 | "/",
75 | op(
76 | "-",
77 | op(
78 | "-",
79 | aggregateWithFilter("AVG", f("POW", child, "3"), filter),
80 | op("*",
81 | op("*",
82 | aggregateWithFilter("AVG", child, filter),
83 | f("POW", aggregateWithFilter("STD", child, filter), "2")),
84 | "3")
85 | ),
86 | f("POW", aggregateWithFilter("AVG", child, filter), "3")
87 | ),
88 | f("POW", aggregateWithFilter("STD", child, filter), "3")
89 | )
90 | )
91 |
92 | // First.scala
93 | case First(expressionExtractor(child), false) =>
94 | Some(aggregateWithFilter("ANY_VALUE", child, filter))
95 |
96 | // Last.scala
97 | case Last(expressionExtractor(child), false) =>
98 | Some(aggregateWithFilter("ANY_VALUE", child, filter))
99 |
100 | // Sum.scala
101 | case Sum(expressionExtractor(child), false) =>
102 | Some(aggregateWithFilter("SUM", child, filter))
103 |
104 | // Average.scala
105 | case Average(expressionExtractor(child), false) =>
106 | Some(aggregateWithFilter("AVG", child, filter))
107 |
108 | case _ => None
109 | }
110 | }
111 | }
112 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.3/spark/VersionSpecificUtil.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import org.apache.spark.sql.types.{
4 | CalendarIntervalType,
5 | DataType,
6 | DayTimeIntervalType,
7 | YearMonthIntervalType
8 | }
9 |
10 | object VersionSpecificUtil {
11 | def isIntervalType(d: DataType): Boolean =
12 | d.isInstanceOf[CalendarIntervalType] || d.isInstanceOf[DayTimeIntervalType] || d
13 | .isInstanceOf[YearMonthIntervalType]
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.3/spark/VersionSpecificWindowBoundaryExpressionExtractor.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement}
4 | import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus}
5 |
6 | case class VersionSpecificWindowBoundaryExpressionExtractor(
7 | expressionExtractor: ExpressionExtractor) {
8 | def unapply(arg: Expression): Option[Statement] = {
9 | arg match {
10 | case UnaryMinus(expressionExtractor(child), false) =>
11 | Some(child + "PRECEDING")
12 | case _ => None
13 | }
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.4/spark/MaxNumConcurrentTasks.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.scheduler
2 |
3 | import org.apache.spark.rdd.RDD
4 |
5 | object MaxNumConcurrentTasks {
6 | def get(rdd: RDD[_]): Int = {
7 | val (_, resourceProfiles) =
8 | rdd.sparkContext.dagScheduler.getShuffleDependenciesAndResourceProfiles(rdd)
9 | val resourceProfile =
10 | rdd.sparkContext.dagScheduler.mergeResourceProfilesForStage(resourceProfiles)
11 | rdd.sparkContext.maxNumConcurrentTasks(resourceProfile)
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.4/spark/VersionSpecificAggregateExpressionExtractor.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement}
4 | import com.singlestore.spark.ExpressionGen.{aggregateWithFilter, f, op}
5 | import org.apache.spark.sql.catalyst.expressions.aggregate.{
6 | AggregateFunction,
7 | Average,
8 | First,
9 | Kurtosis,
10 | Last,
11 | Skewness,
12 | StddevPop,
13 | StddevSamp,
14 | Sum,
15 | VariancePop,
16 | VarianceSamp
17 | }
18 | import org.apache.spark.sql.catalyst.expressions.EvalMode
19 |
20 | case class VersionSpecificAggregateExpressionExtractor(expressionExtractor: ExpressionExtractor,
21 | context: SQLGenContext,
22 | filter: Option[SQLGen.Joinable]) {
23 | def unapply(aggFunc: AggregateFunction): Option[Statement] = {
24 | aggFunc match {
25 | // CentralMomentAgg.scala
26 | case StddevPop(expressionExtractor(child), true) =>
27 | Some(aggregateWithFilter("STDDEV_POP", child, filter))
28 | case StddevSamp(expressionExtractor(child), true) =>
29 | Some(aggregateWithFilter("STDDEV_SAMP", child, filter))
30 | case VariancePop(expressionExtractor(child), true) =>
31 | Some(aggregateWithFilter("VAR_POP", child, filter))
32 | case VarianceSamp(expressionExtractor(child), true) =>
33 | Some(aggregateWithFilter("VAR_SAMP", child, filter))
34 | case Kurtosis(expressionExtractor(child), true) =>
35 |           // ( (AVG(POW(child, 4)) - AVG(child) * AVG(POW(child, 3)) * 4 + 6 * AVG(POW(child, 2)) * POW(AVG(child), 2) - 3 * POW(AVG(child), 4) )
36 |           // / POW(STD(child), 4) ) - 3
37 |           // following the formula from the article at https://stats.oarc.ucla.edu/other/mult-pkg/faq/general/faq-whats-with-the-different-formulas-for-kurtosis/
38 | Some(
39 | op(
40 | "-",
41 | op(
42 | "/",
43 | op(
44 | "-",
45 | op(
46 | "+",
47 | op(
48 | "-",
49 | aggregateWithFilter("AVG", f("POW", child, "4"), filter),
50 | op("*",
51 | op("*",
52 | aggregateWithFilter("AVG", child, filter),
53 | aggregateWithFilter("AVG", f("POW", child, "3"), filter)),
54 | "4")
55 | ),
56 | op("*",
57 | "6",
58 | op("*",
59 | aggregateWithFilter("AVG", f("POW", child, "2"), filter),
60 | f("POW", aggregateWithFilter("AVG", child, filter), "2")))
61 | ),
62 | op("*", "3", f("POW", aggregateWithFilter("AVG", child, filter), "4"))
63 | ),
64 | f("POW", aggregateWithFilter("STD", child, filter), "4")
65 | ),
66 | "3"
67 | )
68 | )
69 |
70 | case Skewness(expressionExtractor(child), true) =>
71 | // (AVG(POW(child, 3)) - AVG(child) * POW(STD(child), 2) * 3 - POW(AVG(child), 3) ) / POW(STD(child), 3)
72 | // following the definition section in https://en.wikipedia.org/wiki/Skewness
73 | Some(
74 | op(
75 | "/",
76 | op(
77 | "-",
78 | op(
79 | "-",
80 | aggregateWithFilter("AVG", f("POW", child, "3"), filter),
81 | op("*",
82 | op("*",
83 | aggregateWithFilter("AVG", child, filter),
84 | f("POW", aggregateWithFilter("STD", child, filter), "2")),
85 | "3")
86 | ),
87 | f("POW", aggregateWithFilter("AVG", child, filter), "3")
88 | ),
89 | f("POW", aggregateWithFilter("STD", child, filter), "3")
90 | )
91 | )
92 |
93 | // First.scala
94 | case First(expressionExtractor(child), false) =>
95 | Some(aggregateWithFilter("ANY_VALUE", child, filter))
96 |
97 | // Last.scala
98 | case Last(expressionExtractor(child), false) =>
99 | Some(aggregateWithFilter("ANY_VALUE", child, filter))
100 |
101 | // Sum.scala
102 | case Sum(expressionExtractor(child), EvalMode.LEGACY) =>
103 | Some(aggregateWithFilter("SUM", child, filter))
104 |
105 | // Average.scala
106 | case Average(expressionExtractor(child), EvalMode.LEGACY) =>
107 | Some(aggregateWithFilter("AVG", child, filter))
108 |
109 | case _ => None
110 | }
111 | }
112 | }
113 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.4/spark/VersionSpecificUtil.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import org.apache.spark.sql.types.{
4 | CalendarIntervalType,
5 | DataType,
6 | DayTimeIntervalType,
7 | YearMonthIntervalType
8 | }
9 |
10 | object VersionSpecificUtil {
11 | def isIntervalType(d: DataType): Boolean =
12 | d.isInstanceOf[CalendarIntervalType] || d.isInstanceOf[DayTimeIntervalType] || d
13 | .isInstanceOf[YearMonthIntervalType]
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.4/spark/VersionSpecificWindowBoundaryExpressionExtractor.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement}
4 | import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus}
5 |
6 | case class VersionSpecificWindowBoundaryExpressionExtractor(
7 | expressionExtractor: ExpressionExtractor) {
8 | def unapply(arg: Expression): Option[Statement] = {
9 | arg match {
10 | case UnaryMinus(expressionExtractor(child), false) =>
11 | Some(child + "PRECEDING")
12 | case _ => None
13 | }
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.5/spark/MaxNumConcurrentTasks.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.scheduler
2 |
3 | import org.apache.spark.rdd.RDD
4 |
5 | object MaxNumConcurrentTasks {
6 | def get(rdd: RDD[_]): Int = {
7 | val (_, resourceProfiles) =
8 | rdd.sparkContext.dagScheduler.getShuffleDependenciesAndResourceProfiles(rdd)
9 | val resourceProfile =
10 | rdd.sparkContext.dagScheduler.mergeResourceProfilesForStage(resourceProfiles)
11 | rdd.sparkContext.maxNumConcurrentTasks(resourceProfile)
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.5/spark/VersionSpecificAggregateExpressionExtractor.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement}
4 | import com.singlestore.spark.ExpressionGen.{aggregateWithFilter, f, op}
5 | import org.apache.spark.sql.catalyst.expressions.aggregate.{
6 | AggregateFunction,
7 | Average,
8 | First,
9 | Kurtosis,
10 | Last,
11 | Skewness,
12 | StddevPop,
13 | StddevSamp,
14 | Sum,
15 | VariancePop,
16 | VarianceSamp
17 | }
18 | import org.apache.spark.sql.catalyst.expressions.EvalMode
19 |
20 | case class VersionSpecificAggregateExpressionExtractor(expressionExtractor: ExpressionExtractor,
21 | context: SQLGenContext,
22 | filter: Option[SQLGen.Joinable]) {
23 | def unapply(aggFunc: AggregateFunction): Option[Statement] = {
24 | aggFunc match {
25 | // CentralMomentAgg.scala
26 | case StddevPop(expressionExtractor(child), true) =>
27 | Some(aggregateWithFilter("STDDEV_POP", child, filter))
28 | case StddevSamp(expressionExtractor(child), true) =>
29 | Some(aggregateWithFilter("STDDEV_SAMP", child, filter))
30 | case VariancePop(expressionExtractor(child), true) =>
31 | Some(aggregateWithFilter("VAR_POP", child, filter))
32 | case VarianceSamp(expressionExtractor(child), true) =>
33 | Some(aggregateWithFilter("VAR_SAMP", child, filter))
34 | case Kurtosis(expressionExtractor(child), true) =>
35 |           // ( (AVG(POW(child, 4)) - AVG(child) * AVG(POW(child, 3)) * 4 + 6 * AVG(POW(child, 2)) * POW(AVG(child), 2) - 3 * POW(AVG(child), 4) )
36 |           // / POW(STD(child), 4) ) - 3
37 |           // following the formula from the article at https://stats.oarc.ucla.edu/other/mult-pkg/faq/general/faq-whats-with-the-different-formulas-for-kurtosis/
38 | Some(
39 | op(
40 | "-",
41 | op(
42 | "/",
43 | op(
44 | "-",
45 | op(
46 | "+",
47 | op(
48 | "-",
49 | aggregateWithFilter("AVG", f("POW", child, "4"), filter),
50 | op("*",
51 | op("*",
52 | aggregateWithFilter("AVG", child, filter),
53 | aggregateWithFilter("AVG", f("POW", child, "3"), filter)),
54 | "4")
55 | ),
56 | op("*",
57 | "6",
58 | op("*",
59 | aggregateWithFilter("AVG", f("POW", child, "2"), filter),
60 | f("POW", aggregateWithFilter("AVG", child, filter), "2")))
61 | ),
62 | op("*", "3", f("POW", aggregateWithFilter("AVG", child, filter), "4"))
63 | ),
64 | f("POW", aggregateWithFilter("STD", child, filter), "4")
65 | ),
66 | "3"
67 | )
68 | )
69 |
70 | case Skewness(expressionExtractor(child), true) =>
71 | // (AVG(POW(child, 3)) - AVG(child) * POW(STD(child), 2) * 3 - POW(AVG(child), 3) ) / POW(STD(child), 3)
72 | // following the definition section in https://en.wikipedia.org/wiki/Skewness
73 | Some(
74 | op(
75 | "/",
76 | op(
77 | "-",
78 | op(
79 | "-",
80 | aggregateWithFilter("AVG", f("POW", child, "3"), filter),
81 | op("*",
82 | op("*",
83 | aggregateWithFilter("AVG", child, filter),
84 | f("POW", aggregateWithFilter("STD", child, filter), "2")),
85 | "3")
86 | ),
87 | f("POW", aggregateWithFilter("AVG", child, filter), "3")
88 | ),
89 | f("POW", aggregateWithFilter("STD", child, filter), "3")
90 | )
91 | )
92 |
93 | // First.scala
94 | case First(expressionExtractor(child), false) =>
95 | Some(aggregateWithFilter("ANY_VALUE", child, filter))
96 |
97 | // Last.scala
98 | case Last(expressionExtractor(child), false) =>
99 | Some(aggregateWithFilter("ANY_VALUE", child, filter))
100 |
101 | // Sum.scala
102 | case Sum(expressionExtractor(child), EvalMode.LEGACY) =>
103 | Some(aggregateWithFilter("SUM", child, filter))
104 |
105 | // Average.scala
106 | case Average(expressionExtractor(child), EvalMode.LEGACY) =>
107 | Some(aggregateWithFilter("AVG", child, filter))
108 |
109 | case _ => None
110 | }
111 | }
112 | }
113 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.5/spark/VersionSpecificUtil.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import org.apache.spark.sql.types.{
4 | CalendarIntervalType,
5 | DataType,
6 | DayTimeIntervalType,
7 | YearMonthIntervalType
8 | }
9 |
10 | object VersionSpecificUtil {
11 | def isIntervalType(d: DataType): Boolean =
12 | d.isInstanceOf[CalendarIntervalType] || d.isInstanceOf[DayTimeIntervalType] || d
13 | .isInstanceOf[YearMonthIntervalType]
14 | }
15 |
--------------------------------------------------------------------------------
/src/main/scala-sparkv3.5/spark/VersionSpecificWindowBoundaryExpressionExtractor.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, Statement}
4 | import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus}
5 |
6 | case class VersionSpecificWindowBoundaryExpressionExtractor(
7 | expressionExtractor: ExpressionExtractor) {
8 | def unapply(arg: Expression): Option[Statement] = {
9 | arg match {
10 | case UnaryMinus(expressionExtractor(child), false) =>
11 | Some(child + "PRECEDING")
12 | case _ => None
13 | }
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/src/main/scala/com/memsql/spark/DefaultSource.scala:
--------------------------------------------------------------------------------
1 | package com.memsql.spark
2 |
3 | import com.singlestore.spark
4 |
5 | class DefaultSource extends spark.DefaultSource {
6 |
7 | override def shortName(): String = spark.DefaultSource.MEMSQL_SOURCE_NAME_SHORT
8 | }
9 |
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/AggregatorParallelReadListener.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.{Connection, SQLException}
4 | import java.util.Properties
5 | import com.singlestore.spark.JdbcHelpers.getDDLConnProperties
6 | import com.singlestore.spark.SQLGen.VariableList
7 | import org.apache.spark.SparkContext
8 | import org.apache.spark.scheduler.{
9 | SparkListener,
10 | SparkListenerStageCompleted,
11 | SparkListenerStageSubmitted
12 | }
13 | import org.apache.spark.sql.types.StructType
14 |
15 | import scala.collection.mutable
16 |
17 | class AggregatorParallelReadListener(applicationId: String) extends SparkListener with LazyLogging {
18 | // connectionsMap is a map from the result table name to the connection with which this table was created
19 | private val connectionsMap: mutable.Map[String, Connection] =
20 | new mutable.HashMap[String, Connection]()
21 |
22 |   // rddInfos is a map from RDD id to the info needed to create the result table for this RDD
23 | private val rddInfos: mutable.Map[Int, SingleStoreRDDInfo] =
24 | new mutable.HashMap[Int, SingleStoreRDDInfo]()
25 |
26 | // SingleStoreRDDInfo is information needed to create a result table
27 | private case class SingleStoreRDDInfo(sc: SparkContext,
28 | query: String,
29 | variables: VariableList,
30 | schema: StructType,
31 | connectionProperties: Properties,
32 | materialized: Boolean,
33 | needsRepartition: Boolean,
34 | repartitionColumns: Seq[String])
35 |
36 | def addRDDInfo(rdd: SinglestoreRDD): Unit = {
37 | rddInfos.synchronized({
38 | rddInfos += (rdd.id -> SingleStoreRDDInfo(
39 | rdd.sparkContext,
40 | rdd.query,
41 | rdd.variables,
42 | rdd.schema,
43 | getDDLConnProperties(rdd.options, isOnExecutor = false),
44 | rdd.parallelReadType.contains(ReadFromAggregatorsMaterialized),
45 | rdd.options.parallelReadRepartition,
46 | rdd.parallelReadRepartitionColumns,
47 | ))
48 | })
49 | }
50 |
51 | def deleteRDDInfo(rdd: SinglestoreRDD): Unit = {
52 | rddInfos.synchronized({
53 | rddInfos -= rdd.id
54 | })
55 | }
56 |
57 | def isEmpty: Boolean = {
58 | rddInfos.synchronized({
59 | rddInfos.isEmpty
60 | })
61 | }
62 |
63 | override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
64 | stageSubmitted.stageInfo.rddInfos.foreach(rddInfo => {
65 | if (rddInfo.name.startsWith("SingleStoreRDD")) {
66 | rddInfos
67 | .synchronized(
68 | rddInfos.get(rddInfo.id)
69 | )
70 | .foreach(singleStoreRDDInfo => {
71 | val stageId = stageSubmitted.stageInfo.stageId
72 | val attemptNumber = stageSubmitted.stageInfo.attemptNumber()
73 | val randHex = rddInfo.name.substring("SingleStoreRDD".size)
74 | val tableName =
75 | JdbcHelpers
76 | .getResultTableName(applicationId, stageId, rddInfo.id, attemptNumber, randHex)
77 |
78 | // Create connection and save it in the map
79 | val conn =
80 | SinglestoreConnectionPool.getConnection(singleStoreRDDInfo.connectionProperties)
81 | connectionsMap.synchronized(
82 | connectionsMap += (tableName -> conn)
83 | )
84 |
85 | log.info(s"Creating result table '$tableName'")
86 | try {
87 | // Create result table
88 | JdbcHelpers.createResultTable(
89 | conn,
90 | tableName,
91 | singleStoreRDDInfo.query,
92 | singleStoreRDDInfo.schema,
93 | singleStoreRDDInfo.variables,
94 | singleStoreRDDInfo.materialized,
95 | singleStoreRDDInfo.needsRepartition,
96 | singleStoreRDDInfo.repartitionColumns
97 | )
98 | log.info(s"Successfully created result table '$tableName'")
99 | } catch {
100 | // Cancel execution if we failed to create a result table
101 | case e: SQLException => {
102 | singleStoreRDDInfo.sc.cancelStage(stageId)
103 | throw e
104 | }
105 | }
106 | })
107 | }
108 | })
109 | }
110 |
111 | override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
112 | stageCompleted.stageInfo.rddInfos.foreach(rddInfo => {
113 | if (rddInfo.name.startsWith("SingleStoreRDD")) {
114 | val stageId = stageCompleted.stageInfo.stageId
115 | val attemptNumber = stageCompleted.stageInfo.attemptNumber()
116 | val randHex = rddInfo.name.substring("SingleStoreRDD".size)
117 | val tableName =
118 | JdbcHelpers.getResultTableName(applicationId, stageId, rddInfo.id, attemptNumber, randHex)
119 |
120 | connectionsMap.synchronized(
121 | connectionsMap
122 | .get(tableName)
123 | .foreach(conn => {
124 | // Drop result table
125 | log.info(s"Dropping result table '$tableName'")
126 | JdbcHelpers.dropResultTable(conn, tableName)
127 | log.info(s"Successfully dropped result table '$tableName'")
128 | // Close connection
129 | conn.close()
130 | // Delete connection from map
131 | connectionsMap -= tableName
132 | })
133 | )
134 | }
135 | })
136 | }
137 | }
138 |
139 | case object AggregatorParallelReadListenerAdder {
140 |   // listeners is a map from a SparkContext to the listener associated with that SparkContext
141 | private val listeners = new mutable.HashMap[SparkContext, AggregatorParallelReadListener]()
142 |
143 | def addRDD(rdd: SinglestoreRDD): Unit = {
144 | this.synchronized({
145 | val listener = listeners.getOrElse(
146 | rdd.sparkContext, {
147 | val newListener = new AggregatorParallelReadListener(rdd.sparkContext.applicationId)
148 | rdd.sparkContext.addSparkListener(newListener)
149 | listeners += (rdd.sparkContext -> newListener)
150 | newListener
151 | }
152 | )
153 | listener.addRDDInfo(rdd)
154 | })
155 | }
156 |
157 | def deleteRDD(rdd: SinglestoreRDD): Unit = {
158 | this.synchronized({
159 | listeners
160 | .get(rdd.sparkContext)
161 | .foreach(listener => {
162 | listener.deleteRDDInfo(rdd)
163 | if (listener.isEmpty) {
164 | listeners -= rdd.sparkContext
165 | rdd.sparkContext.removeSparkListener(listener)
166 | }
167 | })
168 | })
169 | }
170 | }
171 |
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/AvroSchemaHelper.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import org.apache.avro.Schema
4 | import org.apache.avro.Schema.Type
5 | import org.apache.avro.Schema.Type._
6 |
7 | import scala.collection.JavaConverters._
8 |
9 | object AvroSchemaHelper {
10 |
11 | def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = {
12 | if (nullable && avroType.getType != NULL) {
13 |       // Avro uses a union to represent a nullable type.
14 | val fields = avroType.getTypes.asScala
15 | assert(fields.length == 2)
16 | val actualType = fields.filter(_.getType != Type.NULL)
17 | assert(actualType.length == 1)
18 | actualType.head
19 | } else {
20 | avroType
21 | }
22 | }
23 | }
24 |
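A minimal sketch of how resolveNullableType unwraps an Avro nullable union; the schema and object names below are illustrative and are not part of the connector:

    import com.singlestore.spark.AvroSchemaHelper
    import org.apache.avro.Schema
    import org.apache.avro.Schema.Type
    import scala.collection.JavaConverters._

    object AvroSchemaHelperSketch {
      // A nullable string in Avro is modeled as the union ["null", "string"] (hypothetical example)
      val nullableString: Schema =
        Schema.createUnion(List(Schema.create(Type.NULL), Schema.create(Type.STRING)).asJava)

      // Returns the non-null branch of the union: a plain STRING schema
      val resolved: Schema = AvroSchemaHelper.resolveNullableType(nullableString, nullable = true)

      // With nullable = false the schema is returned unchanged
      val untouched: Schema = AvroSchemaHelper.resolveNullableType(nullableString, nullable = false)
    }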
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/CompletionIterator.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | // Copied from Spark's CompletionIterator, which is private even though it is generically useful
4 |
5 | abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A] {
6 | private[this] var completed = false
7 | def next(): A = sub.next()
8 | def hasNext: Boolean = {
9 | val r = sub.hasNext
10 | if (!r && !completed) {
11 | completed = true
12 | completion()
13 | }
14 | r
15 | }
16 |
17 | def completion(): Unit
18 | }
19 |
20 | private[spark] object CompletionIterator {
21 | def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit): CompletionIterator[A, I] = {
22 | new CompletionIterator[A, I](sub) {
23 | def completion(): Unit = completionFunction
24 | }
25 | }
26 | }
27 |
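A small illustrative use of the helper above: run a cleanup action exactly once, when the wrapped iterator is exhausted. Since the factory is private[spark], the sketch assumes it lives in the com.singlestore.spark package; the names are hypothetical:

    package com.singlestore.spark

    object CompletionIteratorSketch {
      def demo(): Unit = {
        var cleaned = false
        // The by-name completion action runs exactly once, when hasNext first returns false
        val wrapped = CompletionIterator[Int, Iterator[Int]](Iterator(1, 2, 3), { cleaned = true })

        wrapped.foreach(println) // prints 1, 2, 3
        assert(cleaned)
      }
    }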
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/DefaultSource.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.SQLGenContext
4 | import org.apache.spark.TaskContext
5 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
6 | import org.apache.spark.metrics.source.MetricsHandler
7 | import org.apache.spark.sql.sources.{
8 | BaseRelation,
9 | CreatableRelationProvider,
10 | DataSourceRegister,
11 | RelationProvider
12 | }
13 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
14 |
15 | object DefaultSource {
16 |
17 | val SINGLESTORE_SOURCE_NAME = "com.singlestore.spark"
18 | val SINGLESTORE_SOURCE_NAME_SHORT = "singlestore"
19 | val SINGLESTORE_GLOBAL_OPTION_PREFIX = "spark.datasource.singlestore."
20 |
21 | @Deprecated val MEMSQL_SOURCE_NAME = "com.memsql.spark"
22 | @Deprecated val MEMSQL_SOURCE_NAME_SHORT = "memsql"
23 | @Deprecated val MEMSQL_GLOBAL_OPTION_PREFIX = "spark.datasource.memsql."
24 | }
25 |
26 | class DefaultSource
27 | extends RelationProvider
28 | with DataSourceRegister
29 | with CreatableRelationProvider
30 | with LazyLogging {
31 |
32 | override def shortName(): String = DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT
33 |
34 | private def includeGlobalParams(sqlContext: SQLContext,
35 | params: Map[String, String]): Map[String, String] =
36 | sqlContext.getAllConfs.foldLeft(params)({
37 | case (params, (k, v)) if k.startsWith(DefaultSource.SINGLESTORE_GLOBAL_OPTION_PREFIX) =>
38 | params + (k.stripPrefix(DefaultSource.SINGLESTORE_GLOBAL_OPTION_PREFIX) -> v)
39 | case (params, (k, v)) if k.startsWith(DefaultSource.MEMSQL_GLOBAL_OPTION_PREFIX) =>
40 | params + (k.stripPrefix(DefaultSource.MEMSQL_GLOBAL_OPTION_PREFIX) -> v)
41 | case (params, _) => params
42 | })
43 |
44 | override def createRelation(sqlContext: SQLContext,
45 | parameters: Map[String, String]): BaseRelation = {
46 | val params = CaseInsensitiveMap(includeGlobalParams(sqlContext, parameters))
47 | val options = SinglestoreOptions(params, sqlContext.sparkSession.sparkContext)
48 | if (options.disablePushdown) {
49 | SQLPushdownRule.ensureRemoved(sqlContext.sparkSession)
50 | SinglestoreReaderNoPushdown(SinglestoreOptions.getQuery(params), options, sqlContext)
51 | } else {
52 | SQLPushdownRule.ensureInjected(sqlContext.sparkSession)
53 | SinglestoreReader(SinglestoreOptions.getQuery(params),
54 | Nil,
55 | options,
56 | sqlContext,
57 | context = SQLGenContext(options))
58 | }
59 | }
60 |
61 | override def createRelation(sqlContext: SQLContext,
62 | mode: SaveMode,
63 | parameters: Map[String, String],
64 | data: DataFrame): BaseRelation = {
65 | val opts = CaseInsensitiveMap(includeGlobalParams(sqlContext, parameters))
66 | val conf = SinglestoreOptions(opts, sqlContext.sparkSession.sparkContext)
67 |
68 | val table = SinglestoreOptions
69 | .getTable(opts)
70 | .getOrElse(
71 | throw new IllegalArgumentException(
72 | s"To write a dataframe to SingleStore you must specify a table name via the '${SinglestoreOptions.TABLE_NAME}' parameter"
73 | )
74 | )
75 | JdbcHelpers.prepareTableForWrite(conf, table, mode, data.schema)
76 | val isReferenceTable = JdbcHelpers.isReferenceTable(conf, table)
77 | val partitionWriterFactory =
78 | if (conf.onDuplicateKeySQL.isEmpty) {
79 | new LoadDataWriterFactory(table, conf)
80 | } else {
81 | new BatchInsertWriterFactory(table, conf)
82 | }
83 |
84 | val schema = data.schema
85 | var totalRowCount = 0L
86 | data.foreachPartition((partition: Iterator[Row]) => {
87 | val writer = partitionWriterFactory.createDataWriter(schema,
88 | TaskContext.getPartitionId(),
89 | 0,
90 | isReferenceTable,
91 | mode)
92 | try {
93 | partition.foreach(record => {
94 | writer.write(record)
95 | totalRowCount += 1
96 | })
97 | writer.commit()
98 | MetricsHandler.setRecordsWritten(totalRowCount)
99 | } catch {
100 | case e: Exception =>
101 | writer.abort(e)
102 | throw e
103 | }
104 | })
105 |
106 | createRelation(sqlContext, parameters)
107 | }
108 | }
109 |
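A short usage sketch of the data source registered above. Only the format short name and the global option prefix come from this file; the endpoint and credential option keys, table names, and values are assumptions for illustration:

    import org.apache.spark.sql.{SaveMode, SparkSession}

    object SinglestoreIOSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          // Picked up by includeGlobalParams; the option keys after the prefix are assumptions
          .config("spark.datasource.singlestore.ddlEndpoint", "localhost:5506")
          .config("spark.datasource.singlestore.user", "root")
          .getOrCreate()

        // Read through the short name registered by this DefaultSource
        val df = spark.read.format("singlestore").load("db.users")

        // Write back; prepareTableForWrite and the partition writers handle the rest
        df.write.format("singlestore").mode(SaveMode.Append).save("db.users_copy")
      }
    }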
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/LazyLogging.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import org.slf4j.{Logger, LoggerFactory}
4 |
5 | trait LazyLogging {
6 | @transient
7 | protected lazy val log: Logger = LoggerFactory.getLogger(getClass.getName)
8 | }
9 |
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/Loan.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | class Loan[A <: AutoCloseable](resource: A) {
4 | def to[T](handle: A => T): T =
5 | try handle(resource)
6 | finally resource.close()
7 | }
8 |
9 | object Loan {
10 | def apply[A <: AutoCloseable](resource: A) = new Loan(resource)
11 | }
12 |
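An illustrative sketch of the loan pattern above: the statement and result set are closed even if the query throws. The helper name and query are hypothetical:

    import java.sql.Connection
    import com.singlestore.spark.Loan

    object LoanSketch {
      // Hypothetical helper: both AutoCloseable resources are released by Loan.to
      def countRows(conn: Connection, table: String): Long =
        Loan(conn.createStatement()).to { stmt =>
          Loan(stmt.executeQuery(s"SELECT COUNT(*) FROM $table")).to { rs =>
            rs.next()
            rs.getLong(1)
          }
        }
    }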
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/MetricsHandler.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.metrics.source
2 |
3 | import org.apache.spark.TaskContext
4 |
5 | object MetricsHandler {
6 | def setRecordsWritten(r: Long): Unit = {
7 | TaskContext.get().taskMetrics().outputMetrics.setRecordsWritten(r)
8 | }
9 | }
10 |
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/OverwriteBehavior.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | sealed trait OverwriteBehavior
4 |
5 | case object Truncate extends OverwriteBehavior
6 | case object Merge extends OverwriteBehavior
7 | case object DropAndCreate extends OverwriteBehavior
8 |
9 | object OverwriteBehavior {
10 | def apply(value: String): OverwriteBehavior = value.toLowerCase match {
11 | case "truncate" => Truncate
12 | case "merge" => Merge
13 | case "dropandcreate" => DropAndCreate
14 | case _ =>
15 | throw new IllegalArgumentException(
16 | s"Illegal argument for `${SinglestoreOptions.OVERWRITE_BEHAVIOR}` option")
17 | }
18 | }
19 |
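Parsing is case-insensitive and fails fast on unknown values; a tiny sketch (object name is illustrative):

    import com.singlestore.spark.OverwriteBehavior

    object OverwriteBehaviorSketch {
      val parsed: Seq[OverwriteBehavior] =
        Seq("Truncate", "merge", "DropAndCreate").map(OverwriteBehavior(_))
      // => List(Truncate, Merge, DropAndCreate); any other value throws IllegalArgumentException
    }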
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/ParallelReadEnablement.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | sealed trait ParallelReadEnablement
4 |
5 | case object Disabled extends ParallelReadEnablement
6 | case object Automatic extends ParallelReadEnablement
7 | case object AutomaticLite extends ParallelReadEnablement
8 | case object Forced extends ParallelReadEnablement
9 |
10 | object ParallelReadEnablement {
11 | def apply(value: String): ParallelReadEnablement = value.toLowerCase match {
12 | case "disabled" => Disabled
13 | case "automaticlite" => AutomaticLite
14 | case "automatic" => Automatic
15 | case "forced" => Forced
16 |
17 | // These two options are added for compatibility purposes
18 | case "false" => Disabled
19 | case "true" => Automatic
20 |
21 | case _ =>
22 | throw new IllegalArgumentException(
23 | s"""Illegal argument for `${SinglestoreOptions.ENABLE_PARALLEL_READ}` option. Valid arguments are:
24 | | - "Disabled"
25 | | - "AutomaticLite"
26 | | - "Automatic"
27 | | - "Forced"""".stripMargin)
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/ParallelReadType.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | sealed trait ParallelReadType
4 |
5 | case object ReadFromLeaves extends ParallelReadType
6 | case object ReadFromAggregators extends ParallelReadType
7 | case object ReadFromAggregatorsMaterialized extends ParallelReadType
8 |
9 | object ParallelReadType {
10 | def apply(value: String): ParallelReadType = value.toLowerCase match {
11 | case "readfromleaves" => ReadFromLeaves
12 | case "readfromaggregators" => ReadFromAggregators
13 | case "readfromaggregatorsmaterialized" => ReadFromAggregatorsMaterialized
14 | case _ =>
15 | throw new IllegalArgumentException(
16 | s"""Illegal argument for `${SinglestoreOptions.PARALLEL_READ_FEATURES}` option. Valid arguments are:
17 | | - "ReadFromLeaves"
18 | | - "ReadFromAggregators"
19 | | - "ReadFromAggregatorsMaterialized"""".stripMargin)
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/SQLHelper.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.JdbcHelpers.{executeQuery, getDDLConnProperties}
4 | import org.apache.spark.sql.{Row, SparkSession}
5 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
6 | import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
7 |
8 | object SQLHelper extends LazyLogging {
9 | implicit class QueryMethods(spark: SparkSession) {
10 | private def singlestoreQuery(db: Option[String],
11 | query: String,
12 | variables: Any*): Iterator[Row] = {
13 | val ctx = spark.sqlContext
14 | var opts = ctx.getAllConfs.collect {
15 | case (k, v) if k.startsWith(DefaultSource.SINGLESTORE_GLOBAL_OPTION_PREFIX) =>
16 | k.stripPrefix(DefaultSource.SINGLESTORE_GLOBAL_OPTION_PREFIX) -> v
17 | case (k, v) if k.startsWith(DefaultSource.MEMSQL_GLOBAL_OPTION_PREFIX) =>
18 | k.stripPrefix(DefaultSource.MEMSQL_GLOBAL_OPTION_PREFIX) -> v
19 | }
20 |
21 | if (db.isDefined) {
22 | val dbValue = db.get
23 | if (dbValue.isEmpty) {
24 | opts -= "database"
25 | } else {
26 | opts += ("database" -> dbValue)
27 | }
28 | }
29 |
30 | val conf = SinglestoreOptions(CaseInsensitiveMap(opts), spark.sparkContext)
31 | val conn =
32 | SinglestoreConnectionPool.getConnection(getDDLConnProperties(conf, isOnExecutor = false))
33 | try {
34 | executeQuery(conn, query, variables: _*)
35 | } finally {
36 | conn.close()
37 | }
38 | }
39 |
40 | def executeSinglestoreQueryDB(db: String, query: String, variables: Any*): Iterator[Row] = {
41 | singlestoreQuery(Some(db), query, variables: _*)
42 | }
43 |
44 | def executeSinglestoreQuery(query: String, variables: Any*): Iterator[Row] = {
45 | singlestoreQuery(None, query, variables: _*)
46 | }
47 |
48 | @Deprecated def executeMemsqlQueryDB(db: String,
49 | query: String,
50 | variables: Any*): Iterator[Row] = {
51 | singlestoreQuery(Some(db), query, variables: _*)
52 | }
53 |
54 | @Deprecated def executeMemsqlQuery(query: String, variables: Any*): Iterator[Row] = {
55 | singlestoreQuery(None, query, variables: _*)
56 | }
57 | }
58 |
59 | def executeSinglestoreQueryDB(spark: SparkSession,
60 | db: String,
61 | query: String,
62 | variables: Any*): Iterator[Row] = {
63 | spark.executeSinglestoreQueryDB(db, query, variables: _*)
64 | }
65 |
66 | def executeSinglestoreQuery(spark: SparkSession,
67 | query: String,
68 | variables: Any*): Iterator[Row] = {
69 | spark.executeSinglestoreQuery(query, variables: _*)
70 | }
71 |
72 | @Deprecated def executeMemsqlQueryDB(spark: SparkSession,
73 | db: String,
74 | query: String,
75 | variables: Any*): Iterator[Row] = {
76 | spark.executeSinglestoreQueryDB(db, query, variables: _*)
77 | }
78 |
79 | @Deprecated def executeMemsqlQuery(spark: SparkSession,
80 | query: String,
81 | variables: Any*): Iterator[Row] = {
82 | spark.executeSinglestoreQuery(query, variables: _*)
83 | }
84 | }
85 |
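A brief sketch of the implicit query helpers above. The queries are illustrative; connection details are expected to come from the spark.datasource.singlestore.* configuration:

    import com.singlestore.spark.SQLHelper.QueryMethods
    import org.apache.spark.sql.SparkSession

    object SQLHelperSketch {
      def probe(spark: SparkSession): Unit = {
        // Runs directly against SingleStore, outside of any DataFrame plan
        spark.executeSinglestoreQuery("SELECT 1").foreach(row => println(row.get(0)))

        // Positional variables plus an explicit database (the query is illustrative)
        spark
          .executeSinglestoreQueryDB("information_schema", "SELECT ? AS answer", 42)
          .foreach(row => println(row.get(0)))
      }
    }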
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/SQLPushdownRule.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext}
4 | import org.apache.spark.sql.SparkSession
5 | import org.apache.spark.sql.catalyst.plans.logical._
6 | import org.apache.spark.sql.catalyst.rules.Rule
7 |
8 | class SQLPushdownRule extends Rule[LogicalPlan] {
9 | override def apply(root: LogicalPlan): LogicalPlan = {
10 | var context: SQLGenContext = null
11 | val needsPushdown = root
12 | .find({
13 | case SQLGen.Relation(r: SQLGen.Relation) if !r.reader.isFinal =>
14 | context = SQLGenContext(root, r.reader.options)
15 | true
16 | case _ => false
17 | })
18 | .isDefined
19 |
20 | if (!needsPushdown) {
21 | return root
22 | }
23 |
24 | if (log.isTraceEnabled) {
25 | log.trace(s"Optimizing plan:\n${root.treeString(true)}")
26 | }
27 |
28 | // We first need to set a SQLGenContext in every reader.
29 | // This transform is done to ensure that we will generate the same aliases in the same queries.
30 | val normalized = root.transform({
31 | case SQLGen.Relation(relation) =>
32 | relation.toLogicalPlan(
33 | relation.output,
34 | relation.reader.query,
35 | relation.reader.variables,
36 | relation.reader.isFinal,
37 | context
38 | )
39 | })
40 |
41 | // Second, we need to rename the outputs of each SingleStore relation in the tree. This transform is
42 | // done to ensure that we can handle projections which involve ambiguous column name references.
43 | var ptr, nextPtr = normalized.transform({
44 | case SQLGen.Relation(relation) => relation.renameOutput
45 | })
46 |
47 | val expressionExtractor = ExpressionExtractor(context)
48 | val transforms =
49 | List(
50 | // do all rewrites except top-level sort, e.g. Project([a,b,c], Relation(select * from foo))
51 | SQLGen.fromLogicalPlan(expressionExtractor).andThen(_.asLogicalPlan()),
52 |       // do rewrites with top-level Sort, e.g. Sort(a, Limit(10, Relation(select * from foo)))
53 | // won't be done for relations with parallel read enabled
54 | SQLGen.fromTopLevelSort(expressionExtractor),
55 | )
56 |
57 | // Run our transforms in a loop until the tree converges
58 | do {
59 | ptr = nextPtr
60 | nextPtr = transforms.foldLeft(ptr)(_.transformUp(_))
61 | } while (!ptr.fastEquals(nextPtr))
62 |
63 | // Finalize all the relations in the tree and perform casts into the expected output datatype for Spark
64 | val out = ptr.transform({
65 | case SQLGen.Relation(relation) if !relation.isFinal => relation.castOutputAndFinalize
66 | })
67 |
68 | if (log.isTraceEnabled) {
69 | log.trace(s"Optimized Plan:\n${out.treeString(true)}")
70 | }
71 |
72 | out
73 | }
74 | }
75 |
76 | object SQLPushdownRule {
77 | def injected(session: SparkSession): Boolean = {
78 | session.experimental.extraOptimizations
79 | .exists(s => s.isInstanceOf[SQLPushdownRule])
80 | }
81 |
82 | def ensureInjected(session: SparkSession): Unit = {
83 | if (!injected(session)) {
84 | session.experimental.extraOptimizations ++= Seq(new SQLPushdownRule)
85 | }
86 | }
87 |
88 | def ensureRemoved(session: SparkSession): Unit = {
89 | session.experimental.extraOptimizations = session.experimental.extraOptimizations
90 | .filterNot(s => s.isInstanceOf[SQLPushdownRule])
91 | }
92 | }
93 |
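The rule is injected automatically when a relation is created with pushdown enabled, but it can also be toggled by hand; a minimal sketch (helper name is hypothetical):

    import com.singlestore.spark.SQLPushdownRule
    import org.apache.spark.sql.SparkSession

    object PushdownToggleSketch {
      def setPushdown(spark: SparkSession, enabled: Boolean): Unit =
        if (enabled) SQLPushdownRule.ensureInjected(spark) // idempotent: the rule is added at most once
        else SQLPushdownRule.ensureRemoved(spark)          // removes every SQLPushdownRule instance
    }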
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/SinglestoreBatchInsertWriter.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.Connection
4 | import java.util.Base64
5 |
6 | import com.singlestore.spark.JdbcHelpers.{getDDLConnProperties, getDMLConnProperties}
7 | import org.apache.spark.sql.catalyst.TableIdentifier
8 | import org.apache.spark.sql.types.{BinaryType, StructType}
9 | import org.apache.spark.sql.{Row, SaveMode}
10 |
11 | import scala.collection.mutable.ListBuffer
12 | import scala.concurrent.ExecutionContext
13 |
14 | // TODO: extend it from DataWriterFactory
15 | class BatchInsertWriterFactory(table: TableIdentifier, conf: SinglestoreOptions)
16 | extends WriterFactory
17 | with LazyLogging {
18 |
19 | def createDataWriter(schema: StructType,
20 | partitionId: Int,
21 | attemptNumber: Int,
22 | isReferenceTable: Boolean,
23 | mode: SaveMode): DataWriter[Row] = {
24 | val columnNames = schema.map(s => SinglestoreDialect.quoteIdentifier(s.name))
25 | val queryPrefix = s"INSERT INTO ${table.quotedString} (${columnNames.mkString(", ")}) VALUES "
26 | val querySuffix = s" ON DUPLICATE KEY UPDATE ${conf.onDuplicateKeySQL.get}"
27 |
28 | val rowTemplate = "(" + schema
29 | .map(x =>
30 | x.dataType match {
31 | case BinaryType => "FROM_BASE64(?)"
32 | case _ => "?"
33 | })
34 | .mkString(",") + ")"
35 | def valueTemplate(rows: Int): String =
36 | List.fill(rows)(rowTemplate).mkString(",")
37 | val fullBatchQuery = queryPrefix + valueTemplate(conf.insertBatchSize) + querySuffix
38 |
39 | val conn = SinglestoreConnectionPool.getConnection(if (isReferenceTable) {
40 | getDDLConnProperties(conf, isOnExecutor = true)
41 | } else {
42 | getDMLConnProperties(conf, isOnExecutor = true)
43 | })
44 | conn.setAutoCommit(false)
45 |
46 | def writeBatch(buff: ListBuffer[Row]): Long = {
47 | if (buff.isEmpty) {
48 | 0
49 | } else {
50 | val rowsCount = buff.size
51 | val query = if (rowsCount == conf.insertBatchSize) {
52 | fullBatchQuery
53 | } else {
54 | queryPrefix + valueTemplate(rowsCount) + querySuffix
55 | }
56 |
57 | val stmt = conn.prepareStatement(query)
58 | try {
59 | for {
60 | (row, i) <- buff.iterator.zipWithIndex
61 | rowLength = row.size
62 | j <- 0 until rowLength
63 | } row(j) match {
64 | case bytes: Array[Byte] =>
65 | stmt.setObject(i * rowLength + j + 1, Base64.getEncoder.encode(bytes))
66 | case obj =>
67 | stmt.setObject(i * rowLength + j + 1, obj)
68 | }
69 | stmt.executeUpdate()
70 | } finally {
71 | stmt.close()
72 | conn.commit()
73 | }
74 | }
75 | }
76 |
77 | new BatchInsertWriter(conf.insertBatchSize, writeBatch, conn)
78 | }
79 | }
80 |
81 | class BatchInsertWriter(batchSize: Int, writeBatch: ListBuffer[Row] => Long, conn: Connection)
82 | extends DataWriter[Row] {
83 | var buff: ListBuffer[Row] = ListBuffer.empty[Row]
84 |
85 | override def write(row: Row): Unit = {
86 | buff += row
87 | if (buff.size >= batchSize) {
88 | writeBatch(buff)
89 | buff = ListBuffer.empty[Row]
90 | }
91 | }
92 |
93 | override def commit(): WriterCommitMessage = {
94 | try {
95 | writeBatch(buff)
96 | buff = ListBuffer.empty[Row]
97 | } finally {
98 | conn.close()
99 | }
100 | new WriteSuccess
101 | }
102 |
103 | override def abort(e: Exception): Unit = {
104 | buff = ListBuffer.empty[Row]
105 | if (!conn.isClosed) {
106 | conn.abort(ExecutionContext.global)
107 | }
108 | }
109 | }
110 |
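As a sketch, the following mirrors the statement template built above for a hypothetical two-column table with onDuplicateKeySQL = "c = VALUES(c)" and two buffered rows; all names are illustrative:

    object BatchInsertQuerySketch {
      // Hypothetical column and table names
      val columns     = Seq("`id`", "`c`")
      val rowTemplate = "(" + columns.map(_ => "?").mkString(",") + ")"
      val query: String =
        s"INSERT INTO `db`.`t` (${columns.mkString(", ")}) VALUES " +
          List.fill(2)(rowTemplate).mkString(",") +
          " ON DUPLICATE KEY UPDATE c = VALUES(c)"
      // => INSERT INTO `db`.`t` (`id`, `c`) VALUES (?,?),(?,?) ON DUPLICATE KEY UPDATE c = VALUES(c)
      // BinaryType columns instead use FROM_BASE64(?) and are bound as Base64-encoded bytes
    }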
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/SinglestoreConnectionPool.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.Connection
4 | import java.util.Properties
5 | import org.apache.commons.dbcp2.{BasicDataSource, BasicDataSourceFactory}
6 | import scala.collection.mutable
7 |
8 | object SinglestoreConnectionPool {
9 | private var dataSources = new mutable.HashMap[Properties, BasicDataSource]()
10 |
11 | private def deleteEmptyDataSources(): Unit = {
12 | dataSources = dataSources.filter(pair => {
13 | val dataSource = pair._2
14 | if (dataSource.getNumActive + dataSource.getNumIdle == 0) {
15 | dataSource.close()
16 | false
17 | } else {
18 | true
19 | }
20 | })
21 | }
22 |
23 | def getConnection(properties: Properties): Connection = {
24 | this.synchronized({
25 | dataSources
26 | .getOrElse(
27 | properties, {
28 | deleteEmptyDataSources()
29 | val newDataSource = BasicDataSourceFactory.createDataSource(properties)
30 | newDataSource.addConnectionProperty("connectionAttributes",
31 | properties.getProperty("connectionAttributes"))
32 | dataSources += (properties -> newDataSource)
33 | newDataSource
34 | }
35 | )
36 | .getConnection
37 | })
38 | }
39 |
40 | def close(): Unit = {
41 | this.synchronized({
42 | dataSources.foreach(pair => pair._2.close())
43 | dataSources = new mutable.HashMap[Properties, BasicDataSource]()
44 | })
45 | }
46 | }
47 |
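A small sketch of borrowing a pooled connection directly. In the connector these Properties are produced by JdbcHelpers; the JDBC URL, credentials, and pool sizing below are placeholders, and the keys are standard DBCP2 BasicDataSourceFactory properties:

    import java.util.Properties
    import com.singlestore.spark.SinglestoreConnectionPool

    object PooledConnectionSketch {
      def main(args: Array[String]): Unit = {
        val props = new Properties()
        // Standard DBCP2 keys (url, username, password, maxTotal); values are placeholders
        props.setProperty("url", "jdbc:singlestore://localhost:5506/db")
        props.setProperty("username", "root")
        props.setProperty("maxTotal", "8")
        // Normally filled in by the connector and forwarded to the driver by getConnection above
        props.setProperty("connectionAttributes", "_connector_name:sketch")

        val conn = SinglestoreConnectionPool.getConnection(props)
        try {
          // ... use the connection ...
        } finally {
          conn.close() // returns the connection to its pool instead of closing the socket
        }
      }
    }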
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/SinglestoreConnectionPoolOptions.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | case class SinglestoreConnectionPoolOptions(enabled: Boolean,
4 | MaxOpenConns: Int,
5 | MaxIdleConns: Int,
6 | MinEvictableIdleTimeMs: Long,
7 | TimeBetweenEvictionRunsMS: Long,
8 | MaxWaitMS: Long,
9 | MaxConnLifetimeMS: Long)
10 |
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/SinglestoreDialect.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.Types
4 |
5 | import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
6 | import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
7 | import org.apache.spark.sql.types._
8 |
9 | case object SinglestoreDialect extends JdbcDialect {
10 | override def canHandle(url: String): Boolean = url.startsWith("jdbc:memsql")
11 |
12 | val SINGLESTORE_DECIMAL_MAX_SCALE = 30
13 |
14 | override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
15 | case BooleanType => Option(JdbcType("BOOL", Types.BOOLEAN))
16 | case ByteType => Option(JdbcType("TINYINT", Types.TINYINT))
17 | case ShortType => Option(JdbcType("SMALLINT", Types.SMALLINT))
18 | case FloatType => Option(JdbcType("FLOAT", Types.FLOAT))
19 | case TimestampType => Option(JdbcType("TIMESTAMP(6)", Types.TIMESTAMP))
20 | case dt: DecimalType if (dt.scale <= SINGLESTORE_DECIMAL_MAX_SCALE) =>
21 | Option(JdbcType(s"DECIMAL(${dt.precision}, ${dt.scale})", Types.DECIMAL))
22 | case dt: DecimalType =>
23 | throw new IllegalArgumentException(
24 | s"Too big scale specified(${dt.scale}). SingleStore DECIMAL maximum scale is ${SINGLESTORE_DECIMAL_MAX_SCALE}")
25 | case NullType =>
26 | throw new IllegalArgumentException(
27 | "No corresponding SingleStore type found for NullType. If you want to use NullType, please write to an already existing SingleStore table.")
28 | case t => JdbcUtils.getCommonJDBCType(t)
29 | }
30 |
31 | override def getCatalystType(sqlType: Int,
32 | typeName: String,
33 | size: Int,
34 | md: MetadataBuilder): Option[DataType] = {
35 | (sqlType, typeName) match {
36 | case (Types.REAL, "FLOAT") => Option(FloatType)
37 | case (Types.BIT, "BIT") => Option(BinaryType)
38 | // JDBC driver returns incorrect SQL type for BIT
39 | // TODO delete after PLAT-6829 is fixed
40 | case (Types.BOOLEAN, "BIT") => Option(BinaryType)
41 | case (Types.TINYINT, "TINYINT") => Option(ShortType)
42 | case (Types.SMALLINT, "SMALLINT") => Option(ShortType)
43 | case (Types.INTEGER, "SMALLINT") => Option(IntegerType)
44 | case (Types.INTEGER, "SMALLINT UNSIGNED") => Option(IntegerType)
45 | case (Types.DECIMAL, "DECIMAL") => {
46 | if (size > DecimalType.MAX_PRECISION) {
47 | throw new IllegalArgumentException(
48 | s"DECIMAL precision ${size} exceeds max precision ${DecimalType.MAX_PRECISION}")
49 | } else {
50 | Option(
51 | DecimalType(size, md.build().getLong("scale").toInt)
52 | )
53 | }
54 | }
55 | case _ => None
56 | }
57 | }
58 |
59 | override def quoteIdentifier(colName: String): String = {
60 | s"`${colName.replace("`", "``")}`"
61 | }
62 |
63 | override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
64 | }
65 |
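A couple of illustrative calls against the dialect above (object name is hypothetical):

    import com.singlestore.spark.SinglestoreDialect
    import org.apache.spark.sql.types.{FloatType, TimestampType}

    object DialectSketch {
      // Backticks inside identifiers are escaped by doubling them
      val quoted: String = SinglestoreDialect.quoteIdentifier("weird`name") // `weird``name`

      // Spark SQL types map to SingleStore column definitions for generated DDL
      val floatCol = SinglestoreDialect.getJDBCType(FloatType).map(_.databaseTypeDefinition)     // Some(FLOAT)
      val tsCol    = SinglestoreDialect.getJDBCType(TimestampType).map(_.databaseTypeDefinition) // Some(TIMESTAMP(6))
    }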
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/SinglestoreRDD.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.{Connection, PreparedStatement, ResultSet}
4 | import java.util.concurrent.{Executors, ForkJoinPool}
5 | import com.singlestore.spark.SQLGen.VariableList
6 | import org.apache.spark.rdd.RDD
7 | import org.apache.spark.sql.Row
8 | import org.apache.spark.sql.catalyst.expressions.Attribute
9 | import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
10 | import org.apache.spark.sql.types._
11 | import org.apache.spark.util.TaskCompletionListener
12 | import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
13 |
14 | import scala.concurrent.duration.Duration
15 | import scala.concurrent.ExecutionContext.Implicits.global
16 | import scala.concurrent.{Await, ExecutionContext, Future}
17 |
18 | case class SinglestoreRDD(query: String,
19 | variables: VariableList,
20 | options: SinglestoreOptions,
21 | schema: StructType,
22 | expectedOutput: Seq[Attribute],
23 | resultMustBeSorted: Boolean,
24 | parallelReadRepartitionColumns: Seq[String],
25 | @transient val sc: SparkContext,
26 | randHex: String)
27 | extends RDD[Row](sc, Nil) {
28 | val (parallelReadType, partitions_) = SinglestorePartitioner(this).getPartitions
29 |
30 |   // Spark serializes the RDD object and sends it to the executors.
31 |   // On an executor, sc will be null because it is marked as transient.
32 | def isRunOnExecutor: Boolean = sc == null
33 |
34 | val applicationId: String = sc.applicationId
35 |
36 | val aggregatorParallelReadUsed: Boolean =
37 | parallelReadType.contains(ReadFromAggregators) ||
38 | parallelReadType.contains(ReadFromAggregatorsMaterialized)
39 |
40 | if (!isRunOnExecutor && aggregatorParallelReadUsed) {
41 | AggregatorParallelReadListenerAdder.addRDD(this)
42 | }
43 |
44 | override def finalize(): Unit = {
45 | if (!isRunOnExecutor && aggregatorParallelReadUsed) {
46 | AggregatorParallelReadListenerAdder.deleteRDD(this)
47 | }
48 | super.finalize()
49 | }
50 |
51 | override protected def getPartitions: Array[Partition] = partitions_
52 |
53 | override def compute(rawPartition: Partition, context: TaskContext): Iterator[Row] = {
54 | val multiPartition: SinglestoreMultiPartition =
55 | rawPartition.asInstanceOf[SinglestoreMultiPartition]
56 | val threadPool = new ForkJoinPool(multiPartition.partitions.size)
57 | try {
58 | val executionContext =
59 | ExecutionContext.fromExecutor(threadPool)
60 | val future: Future[Seq[Iterator[Row]]] = Future.sequence(multiPartition.partitions.map(p =>
61 | Future(computeSinglePartition(p, context))(executionContext)))
62 |
63 | Await.result(future, Duration.Inf).foldLeft(Iterator[Row]())(_ ++ _)
64 | } finally {
65 | threadPool.shutdownNow()
66 | }
67 | }
68 |
69 | def computeSinglePartition(rawPartition: SinglestorePartition,
70 | context: TaskContext): Iterator[Row] = {
71 | var closed = false
72 | var rs: ResultSet = null
73 | var stmt: PreparedStatement = null
74 | var conn: Connection = null
75 | val partition: SinglestorePartition = rawPartition.asInstanceOf[SinglestorePartition]
76 |
77 | def tryClose(name: String, what: AutoCloseable): Unit = {
78 | try {
79 | if (what != null) { what.close() }
80 | } catch {
81 | case e: Exception => logWarning(s"Exception closing $name", e)
82 | }
83 | }
84 |
85 | val ErrResultTableNotExistCode = 2318
86 |
87 | def close(): Unit = {
88 | if (closed) { return }
89 | tryClose("resultset", rs)
90 | tryClose("statement", stmt)
91 | tryClose("connection", conn)
92 | closed = true
93 | }
94 |
95 | context.addTaskCompletionListener {
96 | new TaskCompletionListener {
97 | override def onTaskCompletion(context: TaskContext): Unit = close()
98 | }
99 | }
100 |
101 | conn = SinglestoreConnectionPool.getConnection(partition.connectionInfo)
102 | if (aggregatorParallelReadUsed) {
103 | val tableName = JdbcHelpers.getResultTableName(applicationId,
104 | context.stageId(),
105 | id,
106 | context.stageAttemptNumber(),
107 | randHex)
108 |
109 | stmt =
110 | conn.prepareStatement(JdbcHelpers.getSelectFromResultTableQuery(tableName, partition.index))
111 |
112 | val startTime = System.currentTimeMillis()
113 | val timeout = parallelReadType match {
114 | case Some(ReadFromAggregators) =>
115 | options.parallelReadTableCreationTimeoutMS
116 | case Some(ReadFromAggregatorsMaterialized) =>
117 | options.parallelReadMaterializedTableCreationTimeoutMS
118 | case _ =>
119 | 0
120 | }
121 |
122 | var lastError: java.sql.SQLException = null
123 | var delay = 50
124 | val maxDelay = 10000
125 | while (rs == null && (timeout == 0 || System.currentTimeMillis() - startTime < timeout)) {
126 | try {
127 | rs = stmt.executeQuery()
128 | } catch {
129 | case e: java.sql.SQLException if e.getErrorCode == ErrResultTableNotExistCode =>
130 | lastError = e
131 | delay = Math.min(maxDelay, delay * 2)
132 | Thread.sleep(delay)
133 | }
134 | }
135 |
136 | if (rs == null) {
137 | throw new java.sql.SQLException("Failed to read data from result table", lastError)
138 | }
139 | } else {
140 | stmt = conn.prepareStatement(partition.query)
141 | JdbcHelpers.fillStatement(stmt, partition.variables)
142 | rs = stmt.executeQuery()
143 | }
144 |
145 | var rowsIter = JdbcUtils.resultSetToRows(rs, schema)
146 |
147 | if (expectedOutput.nonEmpty) {
148 | val schemaDatatypes = schema.map(_.dataType)
149 | val expectedDatatypes = expectedOutput.map(_.dataType)
150 |
151 | def getOrNull(f: => Any, r: Row, i: Int): Any = {
152 | if (r.isNullAt(i)) null
153 | else f
154 | }
155 |
156 | if (schemaDatatypes != expectedDatatypes) {
157 | val columnEncoders = schemaDatatypes.zip(expectedDatatypes).zipWithIndex.map {
158 | case ((_: StringType, _: NullType), _) => ((_: Row) => null)
159 | case ((_: ShortType, _: BooleanType), i) =>
160 | (r: Row) =>
161 | getOrNull(r.getShort(i) != 0, r, i)
162 | case ((_: IntegerType, _: BooleanType), i) =>
163 | (r: Row) =>
164 | getOrNull(r.getInt(i) != 0, r, i)
165 | case ((_: LongType, _: BooleanType), i) =>
166 | (r: Row) =>
167 | getOrNull(r.getLong(i) != 0, r, i)
168 |
169 | case ((_: ShortType, _: ByteType), i) =>
170 | (r: Row) =>
171 | getOrNull(r.getShort(i).toByte, r, i)
172 | case ((_: IntegerType, _: ByteType), i) =>
173 | (r: Row) =>
174 | getOrNull(r.getInt(i).toByte, r, i)
175 | case ((_: LongType, _: ByteType), i) =>
176 | (r: Row) =>
177 | getOrNull(r.getLong(i).toByte, r, i)
178 |
179 | case ((l, r), i) =>
180 | options.assert(l == r, s"SinglestoreRDD: unable to encode ${l} into ${r}")
181 | ((r: Row) => getOrNull(r.get(i), r, i))
182 | }
183 |
184 | rowsIter = rowsIter
185 | .map(row => Row.fromSeq(columnEncoders.map(_(row))))
186 | }
187 | }
188 |
189 | CompletionIterator[Row, Iterator[Row]](new InterruptibleIterator[Row](context, rowsIter), close)
190 | }
191 |
192 | }
193 |
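The polling loop in computeSinglePartition above retries the SELECT against the result table with capped exponential backoff until the table exists or the configured timeout elapses (a timeout of 0 means "wait indefinitely"). A minimal, self-contained sketch of the same pattern, assuming a caller-supplied `fetchOnce` action and `retryErrorCode` (both hypothetical placeholders, not connector API):

    import java.sql.SQLException

    object BackoffPollSketch {
      // Retry `fetchOnce` on the given SQL error code with capped exponential backoff;
      // timeoutMs == 0 means no overall deadline.
      def pollWithBackoff[T](retryErrorCode: Int, timeoutMs: Long)(fetchOnce: () => T): T = {
        val start                   = System.currentTimeMillis()
        val maxDelayMs              = 10000L
        var delayMs                 = 50L
        var result: Option[T]       = None
        var lastError: SQLException = null

        while (result.isEmpty && (timeoutMs == 0 || System.currentTimeMillis() - start < timeoutMs)) {
          try {
            result = Some(fetchOnce())
          } catch {
            case e: SQLException if e.getErrorCode == retryErrorCode =>
              lastError = e
              delayMs = math.min(maxDelayMs, delayMs * 2)
              Thread.sleep(delayMs)
          }
        }

        result.getOrElse(throw new SQLException("Gave up waiting for the result table", lastError))
      }
    }
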
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/SinglestoreReader.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.SQLSyntaxErrorException
4 |
5 | import com.singlestore.spark.SQLGen.{ExpressionExtractor, SQLGenContext, VariableList}
6 | import org.apache.spark.rdd.RDD
7 | import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression => CatalystExpression}
8 | import org.apache.spark.sql.sources.{BaseRelation, CatalystScan, TableScan}
9 | import org.apache.spark.sql.{Row, SQLContext}
10 |
11 | import scala.util.Random
12 |
13 | case class SinglestoreReaderNoPushdown(query: String,
14 | options: SinglestoreOptions,
15 | @transient val sqlContext: SQLContext)
16 | extends BaseRelation
17 | with TableScan {
18 |
19 | override lazy val schema = JdbcHelpers.loadSchema(options, query, Nil)
20 |
21 | override def buildScan: RDD[Row] = {
22 | val randHex = Random.nextInt().toHexString
23 | val rdd =
24 | SinglestoreRDD(
25 | query,
26 | Nil,
27 | options,
28 | schema,
29 | Nil,
30 | resultMustBeSorted = false,
31 | schema
32 | .filter(sf => options.parallelReadRepartitionColumns.contains(sf.name))
33 | .map(sf => SQLGen.Ident(sf.name).sql),
34 | sqlContext.sparkContext,
35 | randHex
36 | )
37 |       // Append the random hex to the RDD name;
38 |       // it is needed to generate unique result table names during parallel reads
39 | .setName("SingleStoreRDD" + randHex)
40 | if (rdd.parallelReadType.contains(ReadFromAggregators)) {
41 |       // Wrap the RDD in a barrier stage to force all readers to start reading at the same time.
42 |       // Repartition it to force Spark to read the data and do all other computations in separate stages.
43 | // Otherwise we will likely get the following error:
44 | // [SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following
45 | // pattern of RDD chain within a barrier stage...
46 | rdd.barrier().mapPartitions(v => v).repartition(rdd.getNumPartitions)
47 | } else {
48 | rdd
49 | }
50 | }
51 | }
52 |
53 | case class SinglestoreReader(query: String,
54 | variables: VariableList,
55 | options: SinglestoreOptions,
56 | @transient val sqlContext: SQLContext,
57 | isFinal: Boolean = false,
58 | expectedOutput: Seq[Attribute] = Nil,
59 | var resultMustBeSorted: Boolean = false,
60 | context: SQLGenContext)
61 | extends BaseRelation
62 | with LazyLogging
63 | with TableScan
64 | with CatalystScan {
65 |
66 | override lazy val schema = JdbcHelpers.loadSchema(options, query, variables)
67 |
68 | override def buildScan: RDD[Row] = {
69 | val randHex = Random.nextInt().toHexString
70 | val rdd =
71 | SinglestoreRDD(
72 | query,
73 | variables,
74 | options,
75 | schema,
76 | expectedOutput,
77 | resultMustBeSorted,
78 | expectedOutput
79 | .filter(attr => options.parallelReadRepartitionColumns.contains(attr.name))
80 | .map(attr => context.ident(attr.name, attr.exprId)),
81 | sqlContext.sparkContext,
82 | randHex
83 | )
84 |       // Append the random hex to the RDD name;
85 |       // it is needed to generate unique result table names during parallel reads
86 | .setName("SingleStoreRDD" + randHex)
87 | if (rdd.parallelReadType.contains(ReadFromAggregators)) {
88 |       // Wrap the RDD in a barrier stage to force all readers to start reading at the same time.
89 |       // Repartition it to force Spark to read the data and do all other computations in separate stages.
90 | // Otherwise we will likely get the following error:
91 | // [SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following
92 | // pattern of RDD chain within a barrier stage...
93 | rdd.barrier().mapPartitions(v => v).repartition(rdd.getNumPartitions)
94 | } else {
95 | rdd
96 | }
97 | }
98 |
99 | override def buildScan(rawColumns: Seq[Attribute],
100 | rawFilters: Seq[CatalystExpression]): RDD[Row] = {
101 | // we don't have to push down *everything* using this interface since Spark will
102 | // run the projection and filter again upon receiving the results from SingleStore
103 | val projection =
104 | rawColumns
105 | .flatMap(ExpressionGen.apply(ExpressionExtractor(context)).lift(_))
106 | .reduceOption(_ + "," + _)
107 | val filters =
108 | rawFilters
109 | .flatMap(ExpressionGen.apply(ExpressionExtractor(context)).lift(_))
110 | .reduceOption(_ + "AND" + _)
111 |
112 | val stmt = (projection, filters) match {
113 | case (Some(p), Some(f)) =>
114 | SQLGen
115 | .select(p)
116 | .from(SQLGen.Relation(Nil, this, context.nextAlias(), null))
117 | .where(f)
118 | .output(rawColumns)
119 | case (Some(p), None) =>
120 | SQLGen
121 | .select(p)
122 | .from(SQLGen.Relation(Nil, this, context.nextAlias(), null))
123 | .output(rawColumns)
124 | case (None, Some(f)) =>
125 | SQLGen.selectAll
126 | .from(SQLGen.Relation(Nil, this, context.nextAlias(), null))
127 | .where(f)
128 | .output(expectedOutput)
129 | case _ =>
130 | return buildScan
131 | }
132 |
133 | val newReader = copy(query = stmt.sql, variables = stmt.variables, expectedOutput = stmt.output)
134 |
135 | if (log.isTraceEnabled) {
136 | log.trace(s"CatalystScan additional rewrite:\n${newReader}")
137 | }
138 |
139 | newReader.buildScan
140 | }
141 |
142 | override def toString: String = {
143 | val explain =
144 | try {
145 | JdbcHelpers.explainQuery(options, query, variables)
146 | } catch {
147 | case e: SQLSyntaxErrorException => e.toString
148 | case e: Exception => throw e
149 | }
150 | val v = variables.map(_.variable).mkString(", ")
151 |
152 | s"""
153 | |---------------
154 | |SingleStore Query
155 | |Variables: ($v)
156 | |SQL:
157 | |$query
158 | |
159 | |EXPLAIN:
160 | |$explain
161 | |---------------
162 | """.stripMargin
163 | }
164 | }
165 |
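Both readers above wrap the RDD in a barrier stage and then repartition it, as the inline comments describe. A minimal sketch of that pattern on a plain in-memory RDD (illustrative only, not connector code; it needs at least as many executor slots as barrier tasks, e.g. local[2] for two partitions):

    import org.apache.spark.sql.SparkSession

    object BarrierReadSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("barrier-sketch").getOrCreate()
        val sc    = spark.sparkContext

        val rdd = sc.parallelize(1 to 100, numSlices = 2)

        // Run the "read" inside a barrier stage so all tasks start together, then repartition
        // so downstream computation happens in a separate, non-barrier stage.
        val result = rdd
          .barrier()
          .mapPartitions(iter => iter)
          .repartition(rdd.getNumPartitions)
          .collect()

        println(s"collected ${result.length} rows")
        spark.stop()
      }
    }
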
--------------------------------------------------------------------------------
/src/main/scala/com/singlestore/spark/vendor/apache/SchemaConverters.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark.vendor.apache
2 |
3 | import org.apache.avro._
4 | import org.apache.spark.sql.types._
5 |
6 | /**
7 |  * NOTE: this converter has been taken from the `spark-avro` library,
8 |  * because this functionality is only available starting with Spark 2.4.0 and we support lower versions.
9 | * org.apache.spark.sql.avro.SchemaConverters (2.4.0 version)
10 | * Changes:
11 | * 1. Removed everything except toAvroType function
12 | */
13 | object SchemaConverters {
14 |
15 | private lazy val nullSchema = Schema.create(Schema.Type.NULL)
16 |
17 | case class SchemaType(dataType: DataType, nullable: Boolean)
18 |
19 | def toAvroType(catalystType: DataType,
20 | nullable: Boolean = false,
21 | recordName: String = "topLevelRecord",
22 | nameSpace: String = ""): Schema = {
23 | val builder = SchemaBuilder.builder()
24 |
25 | val schema = catalystType match {
26 | case BooleanType => builder.booleanType()
27 | case ByteType | ShortType | IntegerType => builder.intType()
28 | case LongType => builder.longType()
29 | case DateType =>
30 | LogicalTypes.date().addToSchema(builder.intType())
31 | case TimestampType =>
32 | LogicalTypes.timestampMicros().addToSchema(builder.longType())
33 |
34 | case FloatType => builder.floatType()
35 | case DoubleType => builder.doubleType()
36 | case StringType => builder.stringType()
37 | case _: DecimalType => builder.stringType()
38 | case BinaryType => builder.bytesType()
39 | case ArrayType(et, containsNull) =>
40 | builder
41 | .array()
42 | .items(toAvroType(et, containsNull, recordName, nameSpace))
43 | case MapType(StringType, vt, valueContainsNull) =>
44 | builder
45 | .map()
46 | .values(toAvroType(vt, valueContainsNull, recordName, nameSpace))
47 | case st: StructType =>
48 | val childNameSpace = if (nameSpace != "") s"$nameSpace.$recordName" else recordName
49 | val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields()
50 | st.foreach { f =>
51 | val fieldAvroType =
52 | toAvroType(f.dataType, f.nullable, f.name, childNameSpace)
53 | fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault()
54 | }
55 | fieldsAssembler.endRecord()
56 |
57 | // This should never happen.
58 | case other => throw new IncompatibleSchemaException(s"Unexpected type $other.")
59 | }
60 | if (nullable) {
61 | Schema.createUnion(schema, nullSchema)
62 | } else {
63 | schema
64 | }
65 | }
66 | }
67 |
68 | class IncompatibleSchemaException(msg: String, ex: Throwable = null) extends Exception(msg, ex)
69 |
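A short usage sketch of the toAvroType conversion above, with an arbitrary example schema: integral types map to Avro int/long, DecimalType maps to an Avro string, and nullable fields become a union with null:

    import org.apache.spark.sql.types._
    import com.singlestore.spark.vendor.apache.SchemaConverters

    object ToAvroTypeSketch extends App {
      val schema = StructType(
        Seq(
          StructField("id", LongType, nullable = false),
          StructField("price", DecimalType(10, 2), nullable = true)
        ))

      // Produces an Avro record schema: id -> long, price -> ["string", "null"]
      val avroSchema = SchemaConverters.toAvroType(schema, nullable = false, recordName = "row")
      println(avroSchema.toString(true))
    }
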
--------------------------------------------------------------------------------
/src/test/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | # Set everything to be logged to the console
2 | log4j.rootCategory=ERROR, console
3 | log4j.appender.console=org.apache.log4j.ConsoleAppender
4 | log4j.appender.console.target=System.err
5 | log4j.appender.console.layout=org.apache.log4j.PatternLayout
6 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
7 |
8 | # Settings to quiet third party logs that are too verbose
9 | log4j.logger.org.eclipse.jetty=WARN
10 | log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
11 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=ERROR
12 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=ERROR
13 |
14 | # Make our logs LOUD
15 | log4j.logger.com.singlestore.spark=DEBUG
16 |
--------------------------------------------------------------------------------
/src/test/resources/log4j2.properties:
--------------------------------------------------------------------------------
1 | # Create STDOUT appender that writes data to the console
2 | appenders = console
3 | appender.console.type = Console
4 | appender.console.name = STDOUT
5 | appender.console.layout.type = PatternLayout
6 | appender.console.layout.pattern = %d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
7 |
8 | # Set everything to be logged to the console
9 | rootLogger.level = ERROR
10 | rootLogger.appenderRef.stdout.ref = STDOUT
11 |
12 | # Make our logs LOUD
13 | loggers = singlestore
14 | logger.singlestore.name = com.singlestore.spark
15 | logger.singlestore.level = TRACE
16 | logger.singlestore.appenderRef.stdout.ref = STDOUT
17 | logger.singlestore.additivity = false
18 |
--------------------------------------------------------------------------------
/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker:
--------------------------------------------------------------------------------
1 | mock-maker-inline
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/BatchInsertBenchmark.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.{Connection, Date, DriverManager}
4 | import java.time.LocalDate
5 | import java.util.Properties
6 |
7 | import org.apache.spark.sql.types._
8 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
9 | import org.apache.spark.sql.{SaveMode, SparkSession}
10 |
11 | import scala.util.Random
12 |
13 | // BatchInsertBenchmark is written to test batch insert with the CPU profiler
14 | // this feature is accessible in the Ultimate version of IntelliJ IDEA
15 | // see https://www.jetbrains.com/help/idea/async-profiler.html#profile for more details
16 | object BatchInsertBenchmark extends App {
17 | final val masterHost: String = sys.props.getOrElse("singlestore.host", "localhost")
18 | final val masterPort: String = sys.props.getOrElse("singlestore.port", "5506")
19 |
20 | val spark: SparkSession = SparkSession
21 | .builder()
22 | .master("local")
23 | .config("spark.sql.shuffle.partitions", "1")
24 | .config("spark.driver.bindAddress", "localhost")
25 | .config("spark.datasource.singlestore.ddlEndpoint", s"${masterHost}:${masterPort}")
26 | .config("spark.datasource.singlestore.database", "testdb")
27 | .getOrCreate()
28 |
29 | def jdbcConnection: Loan[Connection] = {
30 | val connProperties = new Properties()
31 | connProperties.put("user", "root")
32 |
33 | Loan(
34 | DriverManager.getConnection(
35 | s"jdbc:singlestore://$masterHost:$masterPort",
36 | connProperties
37 | ))
38 | }
39 |
40 | def executeQuery(sql: String): Unit = {
41 | jdbcConnection.to(conn => Loan(conn.createStatement).to(_.execute(sql)))
42 | }
43 |
44 | executeQuery("set global default_partitions_per_leaf = 2")
45 | executeQuery("drop database if exists testdb")
46 | executeQuery("create database testdb")
47 |
48 | def genDate() =
49 | Date.valueOf(LocalDate.ofEpochDay(LocalDate.of(2001, 4, 11).toEpochDay + Random.nextInt(10000)))
50 | def genRow(): (Long, Int, Double, String, Date) =
51 | (Random.nextLong(), Random.nextInt(), Random.nextDouble(), Random.nextString(20), genDate())
52 | val df =
53 | spark.createDF(
54 | List.fill(1000000)(genRow()),
55 | List(("LongType", LongType, true),
56 | ("IntType", IntegerType, true),
57 | ("DoubleType", DoubleType, true),
58 | ("StringType", StringType, true),
59 | ("DateType", DateType, true))
60 | )
61 |
62 | val start = System.nanoTime()
63 | df.write
64 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
65 | .option("tableKey.primary", "IntType")
66 | .option("onDuplicateKeySQL", "IntType = IntType")
67 | .mode(SaveMode.Append)
68 | .save("testdb.batchinsert")
69 |
70 | val diff = System.nanoTime() - start
71 | println("Elapsed time: " + diff + "ns")
72 | }
73 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/BatchInsertTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
4 | import org.apache.spark.sql.types.{IntegerType, StringType}
5 | import org.apache.spark.sql.{DataFrame, SaveMode}
6 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
7 |
8 | class BatchInsertTest extends IntegrationSuiteBase with BeforeAndAfterEach with BeforeAndAfterAll {
9 | var df: DataFrame = _
10 |
11 | override def beforeEach(): Unit = {
12 | super.beforeEach()
13 |
14 | df = spark.createDF(
15 | List(
16 | (1, "Jack", 20),
17 | (2, "Dan", 30),
18 | (3, "Bob", 15),
19 | (4, "Alice", 40)
20 | ),
21 | List(("id", IntegerType, false), ("name", StringType, true), ("age", IntegerType, true))
22 | )
23 |
24 | df.write
25 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
26 | .mode(SaveMode.Overwrite)
27 | .option("tableKey.primary", "id")
28 | .save("testdb.batchinsert")
29 | }
30 |
31 | it("insert into a new table") {
32 | df = spark.createDF(
33 | List((5, "Eric", 5)),
34 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
35 | )
36 | df.write
37 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
38 | .option("tableKey.primary", "id")
39 | .option("onDuplicateKeySQL", "age = age + 1")
40 | .option("insertBatchSize", 10)
41 | .mode(SaveMode.Append)
42 | .save("testdb.batchinsertnew")
43 |
44 | val actualDF =
45 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.batchinsertnew")
46 | assertSmallDataFrameEquality(actualDF, df)
47 | }
48 |
49 | def insertAndCheckContent(batchSize: Int,
50 | dfToInsert: List[Any],
51 | expectedContent: List[Any]): Unit = {
52 | df = spark.createDF(
53 | dfToInsert,
54 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
55 | )
56 | insertValues("testdb.batchinsert", df, "age = age + 1", batchSize)
57 |
58 | val actualDF =
59 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.batchinsert")
60 | assertSmallDataFrameEquality(
61 | actualDF,
62 | spark
63 | .createDF(
64 | expectedContent,
65 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
66 | ),
67 | orderedComparison = false
68 | )
69 | }
70 |
71 | it("insert a new row") {
72 | insertAndCheckContent(
73 | 10,
74 | List((5, "Eric", 5)),
75 | List(
76 | (1, "Jack", 20),
77 | (2, "Dan", 30),
78 | (3, "Bob", 15),
79 | (4, "Alice", 40),
80 | (5, "Eric", 5)
81 | )
82 | )
83 | }
84 |
85 | it("insert several new rows with small batchSize") {
86 | insertAndCheckContent(
87 | 2,
88 | List(
89 | (5, "Jack", 20),
90 | (6, "Mark", 30),
91 | (7, "Fred", 15),
92 | (8, "Jany", 40),
93 | (9, "Monica", 5)
94 | ),
95 | List(
96 | (1, "Jack", 20),
97 | (2, "Dan", 30),
98 | (3, "Bob", 15),
99 | (4, "Alice", 40),
100 | (5, "Jack", 20),
101 | (6, "Mark", 30),
102 | (7, "Fred", 15),
103 | (8, "Jany", 40),
104 | (9, "Monica", 5)
105 | )
106 | )
107 | }
108 |
109 | it("insert exactly batchSize rows") {
110 | insertAndCheckContent(
111 | 2,
112 | List(
113 | (5, "Jack", 20),
114 | (6, "Mark", 30)
115 | ),
116 | List(
117 | (1, "Jack", 20),
118 | (2, "Dan", 30),
119 | (3, "Bob", 15),
120 | (4, "Alice", 40),
121 | (5, "Jack", 20),
122 | (6, "Mark", 30)
123 | )
124 | )
125 | }
126 |
127 | it("negative batchsize") {
128 | insertAndCheckContent(
129 | -2,
130 | List(
131 | (5, "Jack", 20),
132 | (6, "Mark", 30)
133 | ),
134 | List(
135 | (1, "Jack", 20),
136 | (2, "Dan", 30),
137 | (3, "Bob", 15),
138 | (4, "Alice", 40),
139 | (5, "Jack", 20),
140 | (6, "Mark", 30)
141 | )
142 | )
143 | }
144 |
145 | it("empty insert") {
146 | insertAndCheckContent(
147 | 2,
148 | List(),
149 | List(
150 | (1, "Jack", 20),
151 | (2, "Dan", 30),
152 | (3, "Bob", 15),
153 | (4, "Alice", 40)
154 | )
155 | )
156 | }
157 |
158 | it("insert one existing row") {
159 | insertAndCheckContent(
160 | 2,
161 | List((1, "Jack", 20)),
162 | List(
163 | (1, "Jack", 21),
164 | (2, "Dan", 30),
165 | (3, "Bob", 15),
166 | (4, "Alice", 40)
167 | )
168 | )
169 | }
170 |
171 | it("insert several existing rows") {
172 | insertAndCheckContent(
173 | 2,
174 | List(
175 | (1, "Jack", 20),
176 | (2, "Dan", 30),
177 | (3, "Bob", 15),
178 | (4, "Alice", 40)
179 | ),
180 | List(
181 | (1, "Jack", 21),
182 | (2, "Dan", 31),
183 | (3, "Bob", 16),
184 | (4, "Alice", 41)
185 | )
186 | )
187 | }
188 |
189 | it("insert existing and non existing row") {
190 | insertAndCheckContent(
191 | 2,
192 | List(
193 | (1, "Jack", 20),
194 | (5, "Mark", 30)
195 | ),
196 | List(
197 | (1, "Jack", 21),
198 | (2, "Dan", 30),
199 | (3, "Bob", 15),
200 | (4, "Alice", 40),
201 | (5, "Mark", 30)
202 | )
203 | )
204 | }
205 |
206 | it("insert NULL") {
207 | insertAndCheckContent(
208 | 2,
209 | List(
210 | (5, null, null)
211 | ),
212 | List(
213 | (1, "Jack", 20),
214 | (2, "Dan", 30),
215 | (3, "Bob", 15),
216 | (4, "Alice", 40),
217 | (5, null, null)
218 | )
219 | )
220 | }
221 |
222 | it("non-existing column") {
223 | executeQueryWithLog("DROP TABLE IF EXISTS batchinsert")
224 | executeQueryWithLog("CREATE TABLE batchinsert(id INT, name TEXT)")
225 |
226 | df = spark.createDF(
227 | List((5, "EBCEFGRHFED" * 100, 50)),
228 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
229 | )
230 |
231 | try {
232 | insertValues("testdb.batchinsert", df, "age = age + 1", 10)
233 | fail()
234 | } catch {
235 | case e: Exception if e.getMessage.contains("Unknown column 'age' in 'field list'") =>
236 | }
237 | }
238 | }
239 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/BenchmarkSerializingTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
4 | import org.apache.spark.sql.{DataFrame, SaveMode}
5 | import org.apache.spark.sql.types.{IntegerType, LongType, StringType}
6 |
7 | class BenchmarkSerializingTest extends IntegrationSuiteBase {
8 |
9 | val dbName = "testdb"
10 | val tableName = "avro_table"
11 |
12 | val writeIterations = 10
13 |
14 | val smallDataCount = 1
15 | val mediumDataCount = 1000
16 | val largeDataCount = 100000
17 |
18 | val smallSchema = List(("id", StringType, false))
19 | val mediumSchema = List(("id", StringType, false),
20 | ("name", StringType, false),
21 | ("surname", StringType, false),
22 | ("age", IntegerType, false))
23 | val largeSchema = List(
24 | ("id", StringType, false),
25 | ("name", StringType, false),
26 | ("surname", StringType, false),
27 | ("someString", StringType, false),
28 | ("anotherString", StringType, false),
29 | ("age", IntegerType, false),
30 | ("secondNumber", IntegerType, false),
31 | ("thirdNumber", LongType, false)
32 | )
33 |
34 | def generateSmallData(index: Int) = s"$index"
35 | def generateMediumData(index: Int) = (s"$index", s"name$index", s"surname$index", index)
36 | def generateLargeData(index: Int) =
37 | (s"$index",
38 | s"name$index",
39 | s"surname$index",
40 | s"someString$index",
41 | s"anotherString$index",
42 | index,
43 | index + 1,
44 | index * 2L)
45 |
46 | val generateDataMap: Map[List[scala.Product], Int => Any] = Map(
47 | smallSchema -> generateSmallData,
48 | mediumSchema -> generateMediumData,
49 | largeSchema -> generateLargeData
50 | )
51 |
52 | override def beforeEach(): Unit = {
53 | super.beforeEach()
54 | executeQueryWithLog(s"drop table if exists $dbName.$tableName")
55 | }
56 |
57 | def doWriteOperation(dataFrame: DataFrame, options: Map[String, String]): Long = {
58 | val startTime = System.currentTimeMillis()
59 | for (_ <- 1 to writeIterations) {
60 | dataFrame.write
61 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
62 | .mode(SaveMode.Append)
63 | .options(options)
64 | .save(s"$dbName.$tableName")
65 | }
66 | val endTime = System.currentTimeMillis()
67 | endTime - startTime
68 | }
69 |
70 | def doTestOperation(dataCount: Int,
71 | schema: List[scala.Product],
72 | options: Map[String, String]): Unit = {
73 | val dataFrameValues = List.tabulate(dataCount)(generateDataMap(schema))
74 | val dataFrame =
75 | spark.createDF(dataFrameValues, schema)
76 |     val timeSpent = doWriteOperation(dataFrame, options)
77 |     print(s"Time spent for $writeIterations iterations: $timeSpent ms")
78 | }
79 |
80 | describe("Avro testing") {
81 |
82 | it("small data | small schema") {
83 | doTestOperation(smallDataCount,
84 | smallSchema,
85 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro"))
86 | }
87 |
88 | it("medium data | small schema") {
89 | doTestOperation(mediumDataCount,
90 | smallSchema,
91 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro"))
92 | }
93 |
94 | it("large data | small schema") {
95 | doTestOperation(largeDataCount,
96 | smallSchema,
97 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro"))
98 | }
99 |
100 | it("small data | medium schema") {
101 | doTestOperation(smallDataCount,
102 | mediumSchema,
103 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro"))
104 | }
105 |
106 | it("medium data | medium schema") {
107 | doTestOperation(mediumDataCount,
108 | mediumSchema,
109 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro"))
110 | }
111 |
112 | it("large data | medium schema") {
113 | doTestOperation(largeDataCount,
114 | mediumSchema,
115 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro"))
116 | }
117 |
118 | it("small data | large schema") {
119 | doTestOperation(smallDataCount,
120 | largeSchema,
121 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro"))
122 | }
123 |
124 | it("medium data | large schema") {
125 | doTestOperation(mediumDataCount,
126 | largeSchema,
127 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro"))
128 | }
129 |
130 | it("large data | large schema") {
131 | doTestOperation(largeDataCount,
132 | largeSchema,
133 | Map(SinglestoreOptions.LOAD_DATA_FORMAT -> "avro"))
134 | }
135 | }
136 |
137 | describe("CSV testing") {
138 | it("small data | small schema") {
139 | doTestOperation(smallDataCount, smallSchema, Map.empty)
140 | }
141 |
142 | it("medium data | small schema") {
143 | doTestOperation(mediumDataCount, smallSchema, Map.empty)
144 | }
145 |
146 | it("large data | small schema") {
147 | doTestOperation(largeDataCount, smallSchema, Map.empty)
148 | }
149 |
150 | it("small data | medium schema") {
151 | doTestOperation(smallDataCount, mediumSchema, Map.empty)
152 | }
153 |
154 | it("medium data | medium schema") {
155 | doTestOperation(mediumDataCount, mediumSchema, Map.empty)
156 | }
157 |
158 | it("large data | medium schema") {
159 | doTestOperation(largeDataCount, mediumSchema, Map.empty)
160 | }
161 |
162 | it("small data | large schema") {
163 | doTestOperation(smallDataCount, largeSchema, Map.empty)
164 | }
165 |
166 | it("medium data | large schema") {
167 | doTestOperation(mediumDataCount, largeSchema, Map.empty)
168 | }
169 |
170 | it("large data | large schema") {
171 | doTestOperation(largeDataCount, largeSchema, Map.empty)
172 | }
173 | }
174 | }
175 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/BinaryTypeBenchmark.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.{Connection, DriverManager}
4 | import java.util.Properties
5 |
6 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
7 | import com.singlestore.spark.BatchInsertBenchmark.{df, executeQuery}
8 | import org.apache.spark.sql.types.{BinaryType, IntegerType}
9 | import org.apache.spark.sql.{SaveMode, SparkSession}
10 |
11 | import scala.util.Random
12 |
13 | // BinaryTypeBenchmark is written to test writing of the BinaryType with the CPU profiler
14 | // this feature is accessible in the Ultimate version of IntelliJ IDEA
15 | // see https://www.jetbrains.com/help/idea/async-profiler.html#profile for more details
16 | object BinaryTypeBenchmark extends App {
17 | final val masterHost: String = sys.props.getOrElse("singlestore.host", "localhost")
18 | final val masterPort: String = sys.props.getOrElse("singlestore.port", "5506")
19 |
20 | val spark: SparkSession = SparkSession
21 | .builder()
22 | .master("local")
23 | .config("spark.sql.shuffle.partitions", "1")
24 | .config("spark.driver.bindAddress", "localhost")
25 | .config("spark.datasource.singlestore.ddlEndpoint", s"${masterHost}:${masterPort}")
26 | .config("spark.datasource.singlestore.database", "testdb")
27 | .getOrCreate()
28 |
29 | def jdbcConnection: Loan[Connection] = {
30 | val connProperties = new Properties()
31 | connProperties.put("user", "root")
32 |
33 | Loan(
34 | DriverManager.getConnection(
35 | s"jdbc:singlestore://$masterHost:$masterPort",
36 | connProperties
37 | ))
38 | }
39 |
40 | def executeQuery(sql: String): Unit = {
41 | jdbcConnection.to(conn => Loan(conn.createStatement).to(_.execute(sql)))
42 | }
43 |
44 | executeQuery("set global default_partitions_per_leaf = 2")
45 | executeQuery("drop database if exists testdb")
46 | executeQuery("create database testdb")
47 |
48 | def genRandomByte(): Byte = (Random.nextInt(256) - 128).toByte
49 | def genRandomRow(): Array[Byte] =
50 | Array.fill(1000)(genRandomByte())
51 |
52 | val df = spark.createDF(
53 | List.fill(100000)(genRandomRow()).zipWithIndex,
54 | List(("data", BinaryType, true), ("id", IntegerType, true))
55 | )
56 |
57 | val start1 = System.nanoTime()
58 | df.write
59 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
60 | .mode(SaveMode.Overwrite)
61 | .save("testdb.LoadData")
62 |
63 | println("Elapsed time: " + (System.nanoTime() - start1) + "ns [LoadData CSV]")
64 |
65 | val start2 = System.nanoTime()
66 | df.write
67 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
68 | .option("tableKey.primary", "id")
69 | .option("onDuplicateKeySQL", "data = data")
70 | .mode(SaveMode.Overwrite)
71 | .save("testdb.BatchInsert")
72 |
73 | println("Elapsed time: " + (System.nanoTime() - start2) + "ns [BatchInsert]")
74 |
75 | val avroStart = System.nanoTime()
76 | df.write
77 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
78 | .mode(SaveMode.Overwrite)
79 | .option(SinglestoreOptions.LOAD_DATA_FORMAT, "Avro")
80 | .save("testdb.AvroSerialization")
81 | println("Elapsed time: " + (System.nanoTime() - avroStart) + "ns [LoadData Avro] ")
82 | }
83 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/ExternalHostTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.PreparedStatement
4 | import java.util.Properties
5 |
6 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
7 | import com.singlestore.spark.JdbcHelpers.getDDLConnProperties
8 | import com.singlestore.spark.SQLGen.VariableList
9 | import org.apache.spark.sql.DataFrame
10 | import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
11 | import org.apache.spark.sql.types.{IntegerType, StringType}
12 | import org.mockito.ArgumentMatchers.any
13 | import org.mockito.MockitoSugar
14 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
15 |
16 | class ExternalHostTest
17 | extends IntegrationSuiteBase
18 | with BeforeAndAfterEach
19 | with BeforeAndAfterAll
20 | with MockitoSugar {
21 |
22 | val testDb = "testdb"
23 | val testCollection = "externalHost"
24 | val mvNodesCollection = "mv_nodes"
25 |
26 | var df: DataFrame = _
27 |
28 | override def beforeEach(): Unit = {
29 | super.beforeEach()
30 | spark.sqlContext.setConf("spark.datasource.singlestore.enableParallelRead", "forced")
31 | spark.sqlContext.setConf("spark.datasource.singlestore.parallelRead.Features", "ReadFromLeaves")
32 | df = spark.createDF(
33 | List((2, "B")),
34 | List(("id", IntegerType, true), ("name", StringType, true))
35 | )
36 | writeTable(s"$testDb.$testCollection", df)
37 | }
38 |
39 | def setupMockJdbcHelper(): Unit = {
40 | when(JdbcHelpers.loadSchema(any[SinglestoreOptions], any[String], any[SQLGen.VariableList]))
41 | .thenCallRealMethod()
42 | when(JdbcHelpers.getDDLConnProperties(any[SinglestoreOptions], any[Boolean]))
43 | .thenCallRealMethod()
44 | when(JdbcHelpers.getDMLConnProperties(any[SinglestoreOptions], any[Boolean]))
45 | .thenCallRealMethod()
46 | when(JdbcHelpers.getConnProperties(any[SinglestoreOptions], any[Boolean], any[String]))
47 | .thenCallRealMethod()
48 | when(JdbcHelpers.explainJSONQuery(any[SinglestoreOptions], any[String], any[VariableList]))
49 | .thenCallRealMethod()
50 | when(JdbcHelpers.partitionHostPorts(any[SinglestoreOptions], any[String]))
51 | .thenCallRealMethod()
52 | when(JdbcHelpers.fillStatement(any[PreparedStatement], any[VariableList]))
53 | .thenCallRealMethod()
54 | }
55 |
56 | describe("success tests") {
57 | it("low SingleStore version") {
58 |
59 | withObjectMocked[JdbcHelpers.type] {
60 |
61 | setupMockJdbcHelper()
62 | when(JdbcHelpers.getSinglestoreVersion(any[SinglestoreOptions])).thenReturn("6.8.10")
63 |
64 | val actualDF =
65 | spark.read
66 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
67 | .option("useExternalHost", "true")
68 | .load(s"$testDb.$testCollection")
69 |
70 | assertSmallDataFrameEquality(
71 | actualDF,
72 | df
73 | )
74 | }
75 | }
76 |
77 | it("valid external host") {
78 |
79 | withObjectMocked[JdbcHelpers.type] {
80 |
81 | setupMockJdbcHelper()
82 | when(JdbcHelpers.getSinglestoreVersion(any[SinglestoreOptions])).thenReturn("7.1.0")
83 |
84 | val externalHostMap = Map(
85 | "172.17.0.2:3307" -> "172.17.0.2:3307"
86 | )
87 | when(JdbcHelpers.externalHostPorts(any[SinglestoreOptions]))
88 | .thenReturn(externalHostMap)
89 |
90 | val actualDF =
91 | spark.read
92 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
93 | .option("useExternalHost", "true")
94 | .load(s"$testDb.$testCollection")
95 |
96 | assertSmallDataFrameEquality(
97 | actualDF,
98 | df
99 | )
100 | }
101 | }
102 |
103 | it("empty external host map") {
104 |
105 | withObjectMocked[JdbcHelpers.type] {
106 |
107 | setupMockJdbcHelper()
108 | when(JdbcHelpers.getSinglestoreVersion(any[SinglestoreOptions])).thenReturn("7.1.0")
109 | when(JdbcHelpers.externalHostPorts(any[SinglestoreOptions]))
110 | .thenReturn(Map.empty[String, String])
111 |
112 | val actualDf = spark.read
113 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
114 | .option("useExternalHost", "true")
115 | .load(s"$testDb.$testCollection")
116 |
117 | assertSmallDataFrameEquality(df, actualDf)
118 | }
119 | }
120 |
121 | it("wrong external host map") {
122 |
123 | withObjectMocked[JdbcHelpers.type] {
124 |
125 | setupMockJdbcHelper()
126 | when(JdbcHelpers.getSinglestoreVersion(any[SinglestoreOptions])).thenReturn("7.1.0")
127 |
128 | val externalHostMap = Map(
129 | "172.17.0.3:3307" -> "172.17.0.100:3307",
130 | "172.17.0.4:3307" -> "172.17.0.200:3307"
131 | )
132 |
133 | when(JdbcHelpers.externalHostPorts(any[SinglestoreOptions]))
134 | .thenReturn(externalHostMap)
135 |
136 | val actualDf = spark.read
137 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
138 | .option("useExternalHost", "true")
139 | .load(s"$testDb.$testCollection")
140 |
141 | assertSmallDataFrameEquality(df, actualDf)
142 | }
143 | }
144 |
145 | it("valid external host function") {
146 |
147 | val mvNodesDf = spark.createDF(
148 | List(("172.17.0.2", 3307, "172.17.0.10", 3310),
149 | ("172.17.0.20", 3312, "172.17.0.100", null),
150 | ("172.17.0.2", 3308, null, 3310),
151 | ("172.17.0.15", 3311, null, null)),
152 | List(("IP_ADDR", StringType, true),
153 | ("PORT", IntegerType, true),
154 | ("EXTERNAL_HOST", StringType, true),
155 | ("EXTERNAL_PORT", IntegerType, true))
156 | )
157 | writeTable(s"$testDb.$mvNodesCollection", mvNodesDf)
158 |
159 | val conf = new SinglestoreOptions(
160 | s"$masterHost:$masterPort",
161 | List.empty[String],
162 | "root",
163 | masterPassword,
164 | None,
165 | Map.empty[String, String],
166 | false,
167 | false,
168 | Automatic,
169 | List(ReadFromLeaves),
170 | 0,
171 | 0,
172 | 0,
173 | 0,
174 | true,
175 | Set.empty,
176 | Truncate,
177 | SinglestoreOptions.CompressionType.GZip,
178 | SinglestoreOptions.LoadDataFormat.CSV,
179 | List.empty[SinglestoreOptions.TableKey],
180 | None,
181 | 10,
182 | 10,
183 | false,
184 | SinglestoreConnectionPoolOptions(enabled = true, -1, 8, 30000, 1000, -1, -1),
185 | SinglestoreConnectionPoolOptions(enabled = true, -1, 8, 2000, 1000, -1, -1),
186 | "3.4.0"
187 | )
188 |
189 | val conn =
190 | SinglestoreConnectionPool.getConnection(getDDLConnProperties(conf, isOnExecutor = false))
191 | val statement = conn.prepareStatement(s"""
192 | SELECT IP_ADDR,
193 | PORT,
194 | EXTERNAL_HOST,
195 | EXTERNAL_PORT
196 | FROM testdb.mv_nodes;
197 | """)
198 | val spyConn = spy(conn)
199 | val spyStatement = spy(statement)
200 | when(spyConn.prepareStatement(s"""
201 | SELECT IP_ADDR,
202 | PORT,
203 | EXTERNAL_HOST,
204 | EXTERNAL_PORT
205 | FROM INFORMATION_SCHEMA.MV_NODES
206 | WHERE TYPE = "LEAF";
207 | """)).thenReturn(spyStatement)
208 |
209 | withObjectMocked[SinglestoreConnectionPool.type] {
210 |
211 | when(SinglestoreConnectionPool.getConnection(any[Properties])).thenReturn(spyConn)
212 | val externalHostPorts = JdbcHelpers.externalHostPorts(conf)
213 | val expectedResult = Map(
214 | "172.17.0.2:3307" -> "172.17.0.10:3310"
215 | )
216 | assert(externalHostPorts.equals(expectedResult))
217 | }
218 | }
219 | }
220 |
221 | describe("failed tests") {
222 |
223 | it("wrong external host") {
224 |
225 | withObjectMocked[JdbcHelpers.type] {
226 |
227 | setupMockJdbcHelper()
228 | when(JdbcHelpers.getSinglestoreVersion(any[SinglestoreOptions])).thenReturn("7.1.0")
229 |
230 | val externalHostMap = Map(
231 | "172.17.0.2:3307" -> "somehost:3307"
232 | )
233 | when(JdbcHelpers.externalHostPorts(any[SinglestoreOptions]))
234 | .thenReturn(externalHostMap)
235 |
236 | try {
237 | spark.read
238 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
239 | .option("useExternalHost", "true")
240 | .option("enableParallelRead", "forced")
241 | .option("parallelRead.Features", "ReadFromLeaves")
242 | .load(s"$testDb.$testCollection")
243 | .collect()
244 | fail("Exception expected")
245 | } catch {
246 | case ex: Throwable =>
247 | ex match {
248 | case sqlEx: ParallelReadFailedException =>
249 | assert(
250 | sqlEx.getMessage startsWith "Failed to read data in parallel.\nTried following parallel read features:")
251 | case _ => fail("ParallelReadFailedException expected")
252 | }
253 | }
254 | }
255 | }
256 | }
257 | }
258 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/IntegrationSuiteBase.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.{Connection, DriverManager}
4 | import java.util.{Properties, TimeZone}
5 |
6 | import com.github.mrpowers.spark.fast.tests.DataFrameComparer
7 | import org.apache.log4j.{Level, LogManager}
8 | import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
9 | import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
10 | import org.scalatest._
11 | import org.scalatest.funspec.AnyFunSpec
12 | import com.singlestore.spark.JdbcHelpers.executeQuery
13 | import com.singlestore.spark.SQLGen.SinglestoreVersion
14 | import com.singlestore.spark.SQLHelper._
15 |
16 | import scala.util.Random
17 |
18 | trait IntegrationSuiteBase
19 | extends AnyFunSpec
20 | with BeforeAndAfterEach
21 | with BeforeAndAfterAll
22 | with DataFrameComparer
23 | with LazyLogging {
24 | object ExcludeFromSpark35 extends Tag("ExcludeFromSpark35")
25 | object ExcludeFromSpark34 extends Tag("ExcludeFromSpark34")
26 | object ExcludeFromSpark33 extends Tag("ExcludeFromSpark33")
27 | object ExcludeFromSpark32 extends Tag("ExcludeFromSpark32")
28 | object ExcludeFromSpark31 extends Tag("ExcludeFromSpark31")
29 |
30 | final val masterHost: String = sys.props.getOrElse("singlestore.host", "localhost")
31 | final val masterPort: String = sys.props.getOrElse("singlestore.port", "5506")
32 |
33 | final val continuousIntegration: Boolean = sys.env
34 | .getOrElse("CONTINUOUS_INTEGRATION", "false") == "true"
35 |
36 | final val masterPassword: String = sys.env.getOrElse("SINGLESTORE_PASSWORD", "1")
37 | final val masterJWTPassword: String = sys.env.getOrElse("SINGLESTORE_JWT_PASSWORD", "")
38 | final val forceReadFromLeaves: Boolean =
39 | sys.env.getOrElse("FORCE_READ_FROM_LEAVES", "FALSE").equalsIgnoreCase("TRUE")
40 |
41 | var spark: SparkSession = _
42 |
43 | val jdbcDefaultProps = new Properties()
44 | jdbcDefaultProps.setProperty(JDBCOptions.JDBC_TABLE_NAME, "XXX")
45 | jdbcDefaultProps.setProperty(JDBCOptions.JDBC_DRIVER_CLASS, "org.mariadb.jdbc.Driver")
46 | jdbcDefaultProps.setProperty("user", "root")
47 | jdbcDefaultProps.setProperty("password", masterPassword)
48 |
49 | val version: SinglestoreVersion = {
50 | val conn =
51 | DriverManager.getConnection(s"jdbc:mysql://$masterHost:$masterPort", jdbcDefaultProps)
52 | val resultSet = executeQuery(conn, "select @@memsql_version")
53 | SinglestoreVersion(resultSet.next().getString(0))
54 | }
55 |
56 | val canDoParallelReadFromAggregators: Boolean = version.atLeast("7.5.0") && !forceReadFromLeaves
57 |
58 | override def beforeAll(): Unit = {
59 | // override global JVM timezone to GMT
60 | TimeZone.setDefault(TimeZone.getTimeZone("GMT"))
61 |
62 | val conn =
63 | DriverManager.getConnection(s"jdbc:mysql://$masterHost:$masterPort", jdbcDefaultProps)
64 | try {
65 | // make singlestore use less memory
66 | executeQuery(conn, "set global default_partitions_per_leaf = 2")
67 | executeQuery(conn, "set global data_conversion_compatibility_level = '6.0'")
68 |
69 | executeQuery(conn, "drop database if exists testdb")
70 | executeQuery(conn, "create database testdb")
71 | } finally {
72 | conn.close()
73 | }
74 | }
75 |
76 | override def withFixture(test: NoArgTest): Outcome = {
77 | def retryThrowable(t: Throwable): Boolean = t match {
78 | case _: java.sql.SQLNonTransientConnectionException => true
79 | case _ => false
80 | }
81 |
82 | @scala.annotation.tailrec
83 | def runWithRetry(attempts: Int, lastError: Option[Throwable]): Outcome = {
84 | if (attempts == 0) {
85 | return Canceled(
86 | s"too many SQLNonTransientConnectionExceptions occurred, last error was:\n${lastError.get}")
87 | }
88 |
89 | super.withFixture(test) match {
90 | case Failed(t: Throwable) if retryThrowable(t) || retryThrowable(t.getCause) => {
91 | Thread.sleep(3000)
92 | runWithRetry(attempts - 1, Some(t))
93 | }
94 | case other => other
95 | }
96 | }
97 |
98 | runWithRetry(attempts = 5, None)
99 | }
100 |
101 | override def beforeEach(): Unit = {
102 | super.beforeEach()
103 |
104 | val seed = Random.nextLong()
105 | log.debug("Random seed: " + seed)
106 | Random.setSeed(seed)
107 |
108 | if (!continuousIntegration) {
109 | LogManager.getLogger("com.singlestore.spark").setLevel(Level.TRACE)
110 | }
111 |
112 | spark = SparkSession
113 | .builder()
114 | .master(if (canDoParallelReadFromAggregators) "local[2]" else "local")
115 | .appName("singlestore-integration-tests")
116 | .config("spark.sql.shuffle.partitions", "1")
117 | .config("spark.driver.bindAddress", "localhost")
118 | .config("spark.driver.extraJavaOptions", "-Duser.timezone=GMT")
119 | .config("spark.executor.extraJavaOptions", "-Duser.timezone=GMT")
120 | .config("spark.sql.session.timeZone", "GMT")
121 | .config("spark.datasource.singlestore.ddlEndpoint", s"${masterHost}:${masterPort}")
122 | .config("spark.datasource.singlestore.user", "root-ssl")
123 | .config("spark.datasource.singlestore.password", "")
124 | .config("spark.datasource.singlestore.enableAsserts", "true")
125 | .config("spark.datasource.singlestore.enableParallelRead", "automaticLite")
126 | .config("spark.datasource.singlestore.parallelRead.Features",
127 | if (forceReadFromLeaves) "ReadFromLeaves" else "ReadFromAggregators,ReadFromLeaves")
128 | .config("spark.datasource.singlestore.database", "testdb")
129 | .config("spark.datasource.singlestore.useSSL", "true")
130 | .config("spark.datasource.singlestore.serverSslCert",
131 | s"${System.getProperty("user.dir")}/scripts/ssl/test-ca-cert.pem")
132 | .config("spark.datasource.singlestore.disableSslHostnameVerification", "true")
133 | .config("spark.sql.crossJoin.enabled", "true")
134 | .getOrCreate()
135 | }
136 |
137 | override def afterEach(): Unit = {
138 | super.afterEach()
139 | spark.close()
140 | }
141 |
142 | def executeQueryWithLog(sql: String): Unit = {
143 | log.trace(s"executing query: ${sql}")
144 | spark.executeSinglestoreQuery(sql)
145 | }
146 |
147 | def jdbcOptions(dbtable: String): Map[String, String] = Map(
148 | "url" -> s"jdbc:mysql://$masterHost:$masterPort",
149 | "dbtable" -> dbtable,
150 | "user" -> "root",
151 | "password" -> masterPassword,
152 | "pushDownPredicate" -> "false"
153 | )
154 |
155 | def jdbcOptionsSQL(dbtable: String): String =
156 | jdbcOptions(dbtable)
157 | .foldLeft(List.empty[String])({
158 | case (out, (k, v)) => s"'${k}'='${v}'" :: out
159 | })
160 | .mkString(", ")
161 |
162 | def writeTable(dbtable: String, df: DataFrame, saveMode: SaveMode = SaveMode.Overwrite): Unit =
163 | df.write
164 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
165 | .mode(saveMode)
166 | .save(dbtable)
167 |
168 | def insertValues(dbtable: String,
169 | df: DataFrame,
170 | onDuplicateKeySQL: String,
171 | insertBatchSize: Long): Unit =
172 | df.write
173 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
174 | .option("onDuplicateKeySQL", onDuplicateKeySQL)
175 | .option("insertBatchSize", insertBatchSize)
176 | .mode(SaveMode.Append)
177 | .save(dbtable)
178 | }
179 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/IssuesTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
4 | import org.apache.spark.sql.SaveMode
5 | import org.apache.spark.sql.functions._
6 | import org.apache.spark.sql.types._
7 |
8 | class IssuesTest extends IntegrationSuiteBase {
9 | it("https://github.com/memsql/singlestore-spark-connector/issues/41") {
10 | executeQueryWithLog("""
11 | | create table if not exists testdb.issue41 (
12 | | start_video_pos smallint(5) unsigned DEFAULT NULL
13 | | )
14 | |""".stripMargin)
15 |
16 | val df = spark.createDF(
17 | List(1.toShort, 2.toShort, 3.toShort, 4.toShort),
18 | List(("start_video_pos", ShortType, true))
19 | )
20 | df.write
21 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
22 | .mode(SaveMode.Append)
23 | .save("issue41")
24 |
25 | val df2 = spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("issue41")
26 | assertSmallDataFrameEquality(df2,
27 | spark.createDF(
28 | List(1, 2, 3, 4),
29 | List(("start_video_pos", IntegerType, true))
30 | ),
31 | orderedComparison = false)
32 | }
33 |
34 | it("https://memsql.zendesk.com/agent/tickets/10451") {
35 | // parallel read should support columnar scan with filter
36 | executeQueryWithLog("""
37 | | create table if not exists testdb.ticket10451 (
38 | | t text,
39 | | h bigint(20) DEFAULT NULL,
40 | | KEY h (h) USING CLUSTERED COLUMNSTORE
41 | | )
42 | | """.stripMargin)
43 |
44 | val df = spark.createDF(
45 | List(("hi", 2L), ("hi", 3L), ("foo", 4L)),
46 | List(("t", StringType, true), ("h", LongType, true))
47 | )
48 | df.write
49 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
50 | .mode(SaveMode.Append)
51 | .save("ticket10451")
52 |
53 | val df2 = spark.read
54 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
55 | .load("ticket10451")
56 | .where(col("t") === "hi")
57 | .where(col("h") === 3L)
58 |
59 | assert(df2.rdd.getNumPartitions > 1)
60 | assertSmallDataFrameEquality(df2,
61 | spark.createDF(
62 | List(("hi", 3L)),
63 | List(("t", StringType, true), ("h", LongType, true))
64 | ))
65 | }
66 |
67 | it("supports reading count from query") {
68 | val df = spark.createDF(
69 | List((1, "Albert"), (5, "Ronny"), (7, "Ben"), (9, "David")),
70 | List(("id", IntegerType, true), ("name", StringType, true))
71 | )
72 | writeTable("testdb.testcount", df)
73 | val data = spark.read
74 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
75 | .option("query", "select count(1) from testcount where id > 1 ")
76 | .option("database", "testdb")
77 | .load()
78 | .collect()
79 | val count = data.head.getLong(0)
80 | assert(count == 3)
81 | }
82 |
83 | it("handles exceptions raised by asCode") {
84 | // in certain cases asCode will raise NullPointerException due to this bug
85 | // https://issues.apache.org/jira/browse/SPARK-31403
86 | writeTable("testdb.nulltest",
87 | spark.createDF(
88 | List(1, null),
89 | List(("i", IntegerType, true))
90 | ))
91 | spark.sql(s"create table nulltest using singlestore options ('dbtable'='testdb.nulltest')")
92 |
93 | val df2 = spark.sql("select if(isnull(i), null, 2) as x from nulltest order by i")
94 |
95 | assertSmallDataFrameEquality(df2,
96 | spark.createDF(
97 | List(null, 2),
98 | List(("x", IntegerType, true))
99 | ))
100 | }
101 | }
102 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/LoadDataBenchmark.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.{Connection, Date, DriverManager}
4 | import java.time.{Instant, LocalDate}
5 | import java.util.Properties
6 |
7 | import org.apache.spark.sql.types._
8 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
9 | import org.apache.spark.sql.{SaveMode, SparkSession}
10 |
11 | import scala.util.Random
12 |
13 | // LoadDataBenchmark is written to test load data with the CPU profiler
14 | // this feature is accessible in the Ultimate version of IntelliJ IDEA
15 | // see https://www.jetbrains.com/help/idea/async-profiler.html#profile for more details
16 | object LoadDataBenchmark extends App {
17 | final val masterHost: String = sys.props.getOrElse("singlestore.host", "localhost")
18 | final val masterPort: String = sys.props.getOrElse("singlestore.port", "5506")
19 |
20 | val spark: SparkSession = SparkSession
21 | .builder()
22 | .master("local")
23 | .config("spark.sql.shuffle.partitions", "1")
24 | .config("spark.driver.bindAddress", "localhost")
25 | .config("spark.datasource.singlestore.ddlEndpoint", s"${masterHost}:${masterPort}")
26 | .config("spark.datasource.singlestore.database", "testdb")
27 | .getOrCreate()
28 |
29 | def jdbcConnection: Loan[Connection] = {
30 | val connProperties = new Properties()
31 | connProperties.put("user", "root")
32 |
33 | Loan(
34 | DriverManager.getConnection(
35 | s"jdbc:singlestore://$masterHost:$masterPort",
36 | connProperties
37 | ))
38 | }
39 |
40 | def executeQuery(sql: String): Unit = {
41 | jdbcConnection.to(conn => Loan(conn.createStatement).to(_.execute(sql)))
42 | }
43 |
44 | executeQuery("set global default_partitions_per_leaf = 2")
45 | executeQuery("drop database if exists testdb")
46 | executeQuery("create database testdb")
47 |
48 | def genRow(): (Long, Int, Double, String) =
49 | (Random.nextLong(), Random.nextInt(), Random.nextDouble(), Random.nextString(20))
50 | val df =
51 | spark.createDF(
52 | List.fill(1000000)(genRow()),
53 | List(("LongType", LongType, true),
54 | ("IntType", IntegerType, true),
55 | ("DoubleType", DoubleType, true),
56 | ("StringType", StringType, true))
57 | )
58 |
59 | val start = System.nanoTime()
60 | df.write
61 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
62 | .mode(SaveMode.Append)
63 | .save("testdb.batchinsert")
64 |
65 | val diff = System.nanoTime() - start
66 | println("Elapsed time: " + diff + "ns [CSV serialization] ")
67 |
68 | executeQuery("truncate testdb.batchinsert")
69 |
70 | val avroStart = System.nanoTime()
71 | df.write
72 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
73 | .mode(SaveMode.Append)
74 | .option(SinglestoreOptions.LOAD_DATA_FORMAT, "Avro")
75 | .save("testdb.batchinsert")
76 | val avroDiff = System.nanoTime() - avroStart
77 | println("Elapsed time: " + avroDiff + "ns [Avro serialization] ")
78 | }
79 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/LoadDataTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
4 | import org.apache.spark.sql.types.{DecimalType, IntegerType, NullType, StringType}
5 | import org.apache.spark.sql.{DataFrame, SaveMode}
6 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
7 |
8 | import scala.util.Try
9 |
10 | class LoadDataTest extends IntegrationSuiteBase with BeforeAndAfterEach with BeforeAndAfterAll {
11 | var df: DataFrame = _
12 |
13 | override def beforeEach(): Unit = {
14 | super.beforeEach()
15 |
16 | df = spark.createDF(
17 | List(),
18 | List(("id", IntegerType, false), ("name", StringType, true), ("age", IntegerType, true))
19 | )
20 |
21 | writeTable("testdb.loaddata", df)
22 | }
23 |
24 | it("appends row without `age` field") {
25 | df = spark.createDF(
26 | List((2, "B")),
27 | List(("id", IntegerType, true), ("name", StringType, true))
28 | )
29 | writeTable("testdb.loaddata", df, SaveMode.Append)
30 |
31 | val actualDF =
32 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata")
33 | assertSmallDataFrameEquality(
34 | actualDF,
35 | spark.createDF(
36 | List((2, "B", null)),
37 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
38 | ))
39 | }
40 |
41 | it("appends row without `name` field") {
42 | df = spark.createDF(
43 | List((3, 30)),
44 | List(("id", IntegerType, true), ("age", IntegerType, true))
45 | )
46 | writeTable("testdb.loaddata", df, SaveMode.Append)
47 |
48 | val actualDF =
49 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata")
50 | assertSmallDataFrameEquality(
51 | actualDF,
52 | spark.createDF(
53 | List((3, null, 30)),
54 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
55 | ))
56 | }
57 |
58 |   it("should not append a row without the non-nullable `id` field") {
59 | df = spark.createDF(
60 | List(("D", 40)),
61 | List(("name", StringType, true), ("age", IntegerType, true))
62 | )
63 |
64 | try {
65 | writeTable("testdb.loaddata", df, SaveMode.Append)
66 | fail()
67 | } catch {
68 | // error code 1364 is `Field 'id' doesn't have a default value`
69 | case e: Throwable if TestHelper.isSQLExceptionWithCode(e, List(1364)) =>
70 | }
71 | }
72 |
73 | it("appends row with all fields") {
74 | df = spark.createDF(
75 | List((5, "E", 50)),
76 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
77 | )
78 | writeTable("testdb.loaddata", df, SaveMode.Append)
79 |
80 | val actualDF =
81 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata")
82 | assertSmallDataFrameEquality(
83 | actualDF,
84 | spark.createDF(
85 | List((5, "E", 50)),
86 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
87 | ))
88 | }
89 |
90 | it("appends row only with `id` field") {
91 | df = spark.createDF(List(6), List(("id", IntegerType, true)))
92 | writeTable("testdb.loaddata", df, SaveMode.Append)
93 |
94 | val actualDF =
95 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata")
96 | assertSmallDataFrameEquality(
97 | actualDF,
98 | spark.createDF(
99 | List((6, null, null)),
100 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
101 | )
102 | )
103 | }
104 |
105 | it("appends row with all fields in wrong order") {
106 | df = spark.createDF(
107 | List(("WO", 101, 101)),
108 | List(("name", StringType, true), ("age", IntegerType, true), ("id", IntegerType, true))
109 | )
110 | writeTable("testdb.loaddata", df, SaveMode.Append)
111 |
112 | val actualDF =
113 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata")
114 | assertSmallDataFrameEquality(
115 | actualDF,
116 | spark.createDF(
117 | List((101, "WO", 101)),
118 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
119 | ))
120 | }
121 |
122 | it("appends row with `id` and `name` fields in wrong order") {
123 | df = spark.createDF(
124 | List(("WO2", 102)),
125 | List(("name", StringType, true), ("id", IntegerType, true))
126 | )
127 | writeTable("testdb.loaddata", df, SaveMode.Append)
128 |
129 | val actualDF =
130 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loaddata")
131 | assertSmallDataFrameEquality(
132 | actualDF,
133 | spark.createDF(
134 | List((102, "WO2", null)),
135 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
136 | )
137 | )
138 | }
139 |
140 | it("should not append row with more fields than expected") {
141 | df = spark.createDF(
142 | List((1, "1", 1, 1)),
143 | List(("id", IntegerType, true),
144 | ("name", StringType, true),
145 | ("age", IntegerType, true),
146 | ("extra", IntegerType, false))
147 | )
148 |
149 | try {
150 | writeTable("testdb.loaddata", df, SaveMode.Append)
151 | fail()
152 | } catch {
153 | // error code 1054 is `Unknown column 'extra' in 'field list'`
154 | case e: Throwable if TestHelper.isSQLExceptionWithCode(e, List(1054)) =>
155 | }
156 | }
157 |
158 | it("should not append row with wrong field name") {
159 | df = spark.createDF(
160 | List((1, "1", 1)),
161 | List(("id", IntegerType, true), ("wrongname", StringType, true), ("age", IntegerType, true)))
162 |
163 | try {
164 | writeTable("testdb.loaddata", df, SaveMode.Append)
165 | fail()
166 | } catch {
167 | // error code 1054 is `Unknown column 'wrongname' in 'field list'`
168 | case e: Throwable if TestHelper.isSQLExceptionWithCode(e, List(1054)) =>
169 | }
170 | }
171 |
172 | it("should fail creating table with NullType") {
173 | val tableName = "null_type"
174 |
175 | val dfNull = spark.createDF(List(null), List(("id", NullType, true)))
176 | val writeResult = Try {
177 | writeTable(s"testdb.$tableName", dfNull, SaveMode.Append)
178 | }
179 | assert(writeResult.isFailure)
180 | assert(
181 | writeResult.failed.get.getMessage
182 | .equals(
183 | "No corresponding SingleStore type found for NullType. If you want to use NullType, please write to an already existing SingleStore table."))
184 | }
185 |
186 | it("should succeed inserting NullType into an existing table") {
187 | val tableName = "null_type"
188 |
189 | df = spark.createDF(List(1, 2, 3, 4, 5), List(("id", IntegerType, true)))
190 | writeTable(s"testdb.$tableName", df, SaveMode.Append)
191 |
192 | val dfNull = spark.createDF(List(null), List(("id", NullType, true)))
193 | writeTable(s"testdb.$tableName", dfNull, SaveMode.Append)
194 | }
195 |
196 | it("should write BigDecimal with Avro serializing") {
197 | val tableName = "bigDecimalAvro"
198 | val df = spark.createDF(
199 | List((1, "Alice", 213: BigDecimal)),
200 | List(("id", IntegerType, true), ("name", StringType, true), ("age", DecimalType(10, 0), true))
201 | )
202 | df.write
203 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
204 | .mode(SaveMode.Overwrite)
205 | .option(SinglestoreOptions.LOAD_DATA_FORMAT, "avro")
206 | .save(s"testdb.$tableName")
207 |
208 | val actualDF =
209 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load(s"testdb.$tableName")
210 | assertLargeDataFrameEquality(actualDF, df)
211 | }
212 |
213 | it("should work with `memsql` and `com.memsql.spark` source") {
214 | df = spark.createDF(
215 | List((5, "E", 50)),
216 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
217 | )
218 | writeTable("testdb.loaddata", df, SaveMode.Append)
219 |
220 | val actualDFShort =
221 | spark.read.format(DefaultSource.MEMSQL_SOURCE_NAME_SHORT).load("testdb.loaddata")
222 | assertSmallDataFrameEquality(
223 | actualDFShort,
224 | spark.createDF(
225 | List((5, "E", 50)),
226 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
227 | ))
228 | val actualDF =
229 | spark.read.format(DefaultSource.MEMSQL_SOURCE_NAME).load("testdb.loaddata")
230 | assertSmallDataFrameEquality(
231 | actualDF,
232 | spark.createDF(
233 | List((5, "E", 50)),
234 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
235 | ))
236 | }
237 |
238 | it("non-existing column") {
239 | executeQueryWithLog("DROP TABLE IF EXISTS loaddata")
240 | executeQueryWithLog("CREATE TABLE loaddata(id INT, name TEXT)")
241 |
242 | df = spark.createDF(
243 | List((5, "EBCEFGRHFED" * 10000000, 50)),
244 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
245 | )
246 |
247 | try {
248 | writeTable("testdb.loaddata", df, SaveMode.Append)
249 | fail()
250 | } catch {
251 | case e: Exception if e.getMessage.contains("Unknown column 'age' in 'field list'") =>
252 | }
253 | }
254 |
255 | it("load data in batches") {
256 | val df1 = spark.createDF(
257 | List(
258 | (5, "Jack", 20),
259 | (6, "Mark", 30),
260 | (7, "Fred", 15),
261 | (8, "Jany", 40),
262 | (9, "Monica", 5)
263 | ),
264 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
265 | )
266 | val df2 = spark.createDF(
267 | List(
268 | (10, "Jany", 40),
269 | (11, "Monica", 5)
270 | ),
271 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
272 | )
273 | val df3 = spark.createDF(
274 | List(),
275 | List(("id", IntegerType, true), ("name", StringType, true), ("age", IntegerType, true))
276 | )
277 |
278 | df1.write
279 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
280 | .option("insertBatchSize", 2)
281 | .mode(SaveMode.Append)
282 | .save("testdb.loadDataBatches")
283 | df2.write
284 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
285 | .option("insertBatchSize", 2)
286 | .mode(SaveMode.Append)
287 | .save("testdb.loadDataBatches")
288 | df3.write
289 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
290 | .option("insertBatchSize", 2)
291 | .mode(SaveMode.Append)
292 | .save("testdb.loadDataBatches")
293 |
294 | val actualDF =
295 | spark.read.format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT).load("testdb.loadDataBatches")
296 | assertSmallDataFrameEquality(
297 | actualDF,
298 | df1.union(df2).union(df3),
299 | orderedComparison = false
300 | )
301 | }
302 | }
303 |
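
A minimal standalone sketch of the two write configurations exercised above (Avro-serialized LOAD DATA and batched inserts), assuming a reachable SingleStore cluster. The endpoint, password, and target table are placeholders, and setting `ddlEndpoint` through the `spark.datasource.singlestore.*` prefix is an assumption modeled on the `dmlEndpoints`, `database`, and `user` keys used elsewhere in this suite.

import org.apache.spark.sql.{SaveMode, SparkSession}
import com.singlestore.spark.{DefaultSource, SinglestoreOptions}

object LoadDataSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("load-data-sketch")
      .config("spark.datasource.singlestore.ddlEndpoint", "localhost:5506") // placeholder endpoint
      .config("spark.datasource.singlestore.password", "password")          // placeholder password
      .getOrCreate()

    import spark.implicits._
    val df = Seq((1, "Alice", 30), (2, "Bob", 25)).toDF("id", "name", "age")

    // LOAD DATA with Avro serialization, as in "should write BigDecimal with Avro serializing"
    df.write
      .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
      .option(SinglestoreOptions.LOAD_DATA_FORMAT, "avro")
      .mode(SaveMode.Overwrite)
      .save("testdb.loaddata_sketch")

    // Batched inserts, as in "load data in batches"
    df.write
      .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
      .option("insertBatchSize", 2)
      .mode(SaveMode.Append)
      .save("testdb.loaddata_sketch")

    spark.stop()
  }
}
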
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/LoadbalanceTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
4 | import com.singlestore.spark.JdbcHelpers.executeQuery
5 | import org.apache.spark.sql.SaveMode
6 | import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
7 | import org.apache.spark.sql.jdbc.JdbcDialects
8 | import org.apache.spark.sql.types.IntegerType
9 |
10 | import java.sql.DriverManager
11 | import java.util.Properties
12 |
13 | class LoadbalanceTest extends IntegrationSuiteBase {
14 |
15 | val masterHostPort = s"${masterHost}:${masterPort}"
16 | val childHostPort = "localhost:5508"
17 |
18 | override def beforeEach(): Unit = {
19 | super.beforeEach()
20 |
21 | // Set master + child aggregator as dmlEndpoints
22 | spark.conf
23 | .set("spark.datasource.singlestore.dmlEndpoints", s"${masterHostPort},${childHostPort}")
24 | }
25 |
26 | def countQueries(hostport: String): Int = {
27 | val props = new Properties()
28 | props.setProperty("dbtable", "testdb")
29 | props.setProperty("user", "root")
30 | props.setProperty("password", masterPassword)
31 |
32 | val conn = DriverManager.getConnection(s"jdbc:singlestore://$hostport", props)
33 | try {
34 | // we only use write queries since read queries are always increasing due to internal status checks
35 | val rows =
36 | JdbcHelpers.executeQuery(conn, "show status extended like 'Successful_write_queries'")
37 | rows.map(r => r.getAs[String](1).toInt).sum
38 | } finally {
39 | conn.close()
40 | }
41 | }
42 |
43 | def counters =
44 | Map(
45 | masterHostPort -> countQueries(masterHostPort),
46 | childHostPort -> countQueries(childHostPort)
47 | )
48 |
49 | describe("load-balances among all hosts listed in dmlEndpoints") {
50 |
51 | it("queries both aggregators eventually") {
52 |
53 | val df = spark.createDF(
54 | List(4, 5, 6),
55 | List(("id", IntegerType, true))
56 | )
57 |
58 | val startCounters = counters
59 |
60 | // 50/50 chance of picking either agg; 11 write attempts (0 to 10) should be enough to ensure we hit both aggs with write queries
61 | for (i <- 0 to 10) {
62 | df.write
63 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
64 | .option("driverConnectionPool.MinEvictableIdleTimeMs", "100")
65 | .option("driverConnectionPool.TimeBetweenEvictionRunsMS", "50")
66 | .option("executorConnectionPool.MinEvictableIdleTimeMs", "100")
67 | .option("executorConnectionPool.TimeBetweenEvictionRunsMS", "50")
68 | .mode(SaveMode.Overwrite)
69 | .save("test")
70 |
71 | Thread.sleep(300)
72 | }
73 |
74 | val endCounters = counters
75 |
76 | assert(endCounters(childHostPort) > startCounters(childHostPort))
77 | assert(endCounters(masterHostPort) > startCounters(masterHostPort))
78 | }
79 |
80 | }
81 | }
82 |
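
A spark-shell style sketch of the dmlEndpoints configuration this test depends on, so that write queries are spread across both aggregators; the hosts, ports, and password are placeholders for a real master (5506) and child (5508) aggregator.

import org.apache.spark.sql.SaveMode
import com.singlestore.spark.DefaultSource

// Both aggregators become candidates for DML; the connector picks one per write.
spark.conf.set("spark.datasource.singlestore.dmlEndpoints", "localhost:5506,localhost:5508")
spark.conf.set("spark.datasource.singlestore.password", "password")

import spark.implicits._
Seq(1, 2, 3).toDF("id")
  .write
  .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
  .mode(SaveMode.Overwrite)
  .save("testdb.test")
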
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/MaxErrorsTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.SQLTransientConnectionException
4 |
5 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
6 | import org.apache.spark.sql.SaveMode
7 | import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType}
8 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
9 |
10 | import scala.util.Try
11 |
12 | class MaxErrorsTest extends IntegrationSuiteBase with BeforeAndAfterEach with BeforeAndAfterAll {
13 |
14 | def testMaxErrors(tableName: String, maxErrors: Int, duplicateItems: Int): Unit = {
15 | val df = spark.createDF(
16 | List.fill(duplicateItems + 1)((1, "Alice", 213: BigDecimal)),
17 | List(("id", IntegerType, true),
18 | ("name", StringType, false),
19 | ("age", DecimalType(10, 0), true))
20 | )
21 | val result = Try {
22 | df.write
23 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
24 | .option("tableKey.primary", "name")
25 | .option("maxErrors", maxErrors)
26 | .mode(SaveMode.Ignore)
27 | .save(s"testdb.$tableName")
28 | }
29 | if (duplicateItems > maxErrors) {
30 | assert(result.isFailure)
31 | result.failed.get.getCause match {
32 | case _: SQLTransientConnectionException =>
33 | case _ => fail("SQLTransientConnectionException should be thrown")
34 | }
35 | } else {
36 | assert(result.isSuccess)
37 | }
38 | }
39 |
40 | describe("small dataset") {
41 | // TODO DB-51213
42 | //it("hit maxErrors") {
43 | // testMaxErrors("hitMaxErrorsSmall", 1, 2)
44 | //}
45 |
46 | it("not hit maxErrors") {
47 | testMaxErrors("notHitMaxErrorsSmall", 1, 1)
48 | }
49 | }
50 |
51 | describe("big dataset") {
52 | // TODO DB-51213
53 | //it("hit maxErrors") {
54 | // testMaxErrors("hitMaxErrorsBig", 10000, 10001)
55 | //}
56 |
57 | it("not hit maxErrors") {
58 | testMaxErrors("notHitMaxErrorsBig", 10000, 10000)
59 | }
60 | }
61 |
62 | it("wrong configuration") {
63 | val df = spark.createDF(
64 | List((1, "Alice", 213: BigDecimal)),
65 | List(("id", IntegerType, true),
66 | ("name", StringType, false),
67 | ("age", DecimalType(10, 0), true))
68 | )
69 | val result = Try {
70 | df.write
71 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
72 | .option("onDuplicateKeySQL", "id=id")
73 | .option("maxErrors", 1)
74 | .mode(SaveMode.Ignore)
75 | .save(s"testdb.someTable")
76 | }
77 | assert(result.isFailure)
78 | result.failed.get match {
79 | case ex: IllegalArgumentException
80 | if ex.getMessage.equals("can't use both `onDuplicateKeySQL` and `maxErrors` options") =>
81 | succeed
82 | case _ => fail()
83 | }
84 | }
85 | }
86 |
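
A short sketch of the two error-handling options this suite covers, assuming an active SparkSession `spark` already configured for the cluster; the table name is a placeholder. As the "wrong configuration" test asserts, `maxErrors` and `onDuplicateKeySQL` cannot be combined.

import org.apache.spark.sql.SaveMode
import com.singlestore.spark.DefaultSource

import spark.implicits._
val df = Seq((1, "Alice"), (2, "Bob")).toDF("id", "name")

// Tolerate up to 10 errored rows during the load; the write fails once that count is exceeded.
df.write
  .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
  .option("tableKey.primary", "name")
  .option("maxErrors", 10)
  .mode(SaveMode.Ignore)
  .save("testdb.maxErrorsSketch")

// Or resolve duplicate keys in SQL instead (mutually exclusive with maxErrors).
df.write
  .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
  .option("onDuplicateKeySQL", "id = id")
  .mode(SaveMode.Append)
  .save("testdb.maxErrorsSketch")
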
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/OutputMetricsTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.util.concurrent.CountDownLatch
4 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
5 | import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
6 | import org.apache.spark.sql.types.{IntegerType, StringType}
7 |
8 | class OutputMetricsTest extends IntegrationSuiteBase {
9 | it("records written") {
10 | var outputWritten = 0L
11 | var countDownLatch: CountDownLatch = null
12 | spark.sparkContext.addSparkListener(new SparkListener() {
13 | override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
14 | if (taskEnd.taskType == "ResultTask") {
15 | this.synchronized({
16 | val metrics = taskEnd.taskMetrics
17 | outputWritten += metrics.outputMetrics.recordsWritten
18 | countDownLatch.countDown()
19 | })
20 | }
21 | }
22 | })
23 |
24 | val numRows = 100000
25 | var df1 = spark.createDF(
26 | List.range(0, numRows),
27 | List(("id", IntegerType, true))
28 | )
29 |
30 | var numPartitions = 30
31 | countDownLatch = new CountDownLatch(numPartitions)
32 | df1 = df1.repartition(numPartitions)
33 |
34 | df1.write
35 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
36 | .save("metricsInts")
37 |
38 | countDownLatch.await()
39 | assert(outputWritten == numRows)
40 |
41 | var df2 = spark.createDF(
42 | List("st1", "", null),
43 | List(("st", StringType, true))
44 | )
45 |
46 | outputWritten = 0
47 | numPartitions = 1
48 | countDownLatch = new CountDownLatch(numPartitions)
49 | df2 = df2.repartition(numPartitions)
50 |
51 | df2.write
52 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
53 | .save("metricsStrings")
54 | countDownLatch.await()
55 | assert(outputWritten == 3)
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/ReferenceTableTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
4 | import org.apache.spark.sql.types.IntegerType
5 | import org.apache.spark.sql.{DataFrame, SaveMode}
6 |
7 | import scala.util.Try
8 |
9 | class ReferenceTableTest extends IntegrationSuiteBase {
10 |
11 | val childAggregatorHost = "localhost"
12 | val childAggregatorPort = "5508"
13 |
14 | val dbName = "testdb"
15 | val commonCollectionName = "test_table"
16 | val referenceCollectionName = "reference_table"
17 |
18 | override def beforeEach(): Unit = {
19 | super.beforeEach()
20 |
21 | // Set child aggregator as a dmlEndpoint
22 | spark.conf
23 | .set("spark.datasource.singlestore.dmlEndpoints",
24 | s"${childAggregatorHost}:${childAggregatorPort}")
25 | }
26 |
27 | def writeToTable(tableName: String): Unit = {
28 | val df = spark.createDF(
29 | List(4, 5, 6),
30 | List(("id", IntegerType, true))
31 | )
32 | df.write
33 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
34 | .mode(SaveMode.Append)
35 | .save(s"${dbName}.${tableName}")
36 | }
37 |
38 | def readFromTable(tableName: String): DataFrame = {
39 | spark.read
40 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
41 | .load(s"${dbName}.${tableName}")
42 | }
43 |
44 | def writeAndReadFromTable(tableName: String): Unit = {
45 | writeToTable(tableName)
46 | val dataFrame = readFromTable(tableName)
47 | val sqlRows = dataFrame.collect()
48 | assert(sqlRows.length == 3)
49 | }
50 |
51 | def dropTable(tableName: String): Unit =
52 | executeQueryWithLog(s"drop table if exists $dbName.$tableName")
53 |
54 | describe("Success during write operations") {
55 |
56 | it("to common table") {
57 | dropTable(commonCollectionName)
58 | executeQueryWithLog(
59 | s"create table if not exists $dbName.$commonCollectionName (id INT NOT NULL, PRIMARY KEY (id))")
60 | writeAndReadFromTable(commonCollectionName)
61 | }
62 |
63 | it("to reference table") {
64 | dropTable(referenceCollectionName)
65 | executeQueryWithLog(
66 | s"create reference table if not exists $dbName.$referenceCollectionName (id INT NOT NULL, PRIMARY KEY (id))")
67 | writeAndReadFromTable(referenceCollectionName)
68 | }
69 | }
70 |
71 | describe("Success during creating") {
72 |
73 | it("common table") {
74 | dropTable(commonCollectionName)
75 | writeAndReadFromTable(commonCollectionName)
76 | }
77 | }
78 |
79 | describe("Failure because of") {
80 |
81 | it("database name not specified") {
82 | spark.conf.set("spark.datasource.singlestore.database", "")
83 | val df = spark.createDF(
84 | List(4, 5, 6),
85 | List(("id", IntegerType, true))
86 | )
87 | val result = Try {
88 | df.write
89 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
90 | .mode(SaveMode.Append)
91 | .save(s"${commonCollectionName}")
92 | }
93 | /* Error code description:
94 | 1046 = Database name not provided
95 | */
96 | assert(TestHelper.isSQLExceptionWithCode(result.failed.get, List(1046)))
97 | }
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/SQLHelperTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.{Date, Timestamp}
4 |
5 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
6 | import com.singlestore.spark.SQLHelper._
7 | import org.apache.spark.sql.DataFrame
8 | import org.apache.spark.sql.types._
9 | import org.scalatest.BeforeAndAfterEach
10 |
11 | class SQLHelperTest extends IntegrationSuiteBase with BeforeAndAfterEach {
12 | var df: DataFrame = _
13 |
14 | override def beforeEach(): Unit = {
15 | super.beforeEach()
16 | df = spark.createDF(
17 | List((1, "Cat", true), (2, "Dog", true), (3, "CatDog", false)),
18 | List(("id", IntegerType, true), ("name", StringType, true), ("domestic", BooleanType, true))
19 | )
20 | writeTable("testdb.animal", df)
21 |
22 | df = spark.createDF(
23 | List((1: Long, 2: Short, 3: Float, 4: Double),
24 | (-2: Long, 22: Short, 2.0.toFloat, 5.1: Double),
25 | (3: Long, 4: Short, -0.11.toFloat, 66.77: Double)),
26 | List(("nLong", LongType, true),
27 | ("nShort", ShortType, true),
28 | ("nFloat", FloatType, true),
29 | ("nDouble", DoubleType, true))
30 | )
31 | writeTable("testdb.numbers", df)
32 |
33 | df = spark.createDF(
34 | List((1: Byte, Date.valueOf("2015-03-30"), new Timestamp(2147483649L)),
35 | (2: Byte, Date.valueOf("2020-09-09"), new Timestamp(1000))),
36 | List(("nByte", ByteType, true),
37 | ("nDate", DateType, true),
38 | ("nTimestamp", TimestampType, true))
39 | )
40 | writeTable("testdb.byte_dates", df)
41 | }
42 |
43 | override def afterEach(): Unit = {
44 | super.afterEach()
45 |
46 | spark.executeSinglestoreQueryDB("testdb", "DROP TABLE IF EXISTS animal")
47 | spark.executeSinglestoreQueryDB("testdb", "DROP TABLE IF EXISTS numbers")
48 | spark.executeSinglestoreQueryDB("testdb", "DROP TABLE IF EXISTS byte_dates")
49 | spark.executeSinglestoreQuery("DROP DATABASE IF EXISTS test_db_1")
50 | }
51 |
52 | describe("implicit version") {
53 | it("global query test") {
54 | val s = spark.executeSinglestoreQuery("SHOW DATABASES")
55 | val result =
56 | (for (row <- s)
57 | yield row.getString(0)).toList
58 | for (db <- List("memsql", "cluster", "testdb", "information_schema")) {
59 | assert(result.contains(db))
60 | }
61 | }
62 |
63 | it("executeSinglestoreQuery with 2 parameters") {
64 | spark.executeSinglestoreQuery("DROP DATABASE IF EXISTS test_db_1")
65 | spark.executeSinglestoreQuery("CREATE DATABASE test_db_1")
66 | }
67 |
68 | it("executeSinglestoreQueryDB with 3 parameters") {
69 | val res = spark.executeSinglestoreQueryDB("testdb", "SELECT * FROM animal")
70 | val out =
71 | for (row <- res)
72 | yield row.getString(1)
73 |
74 | assert(out.toList.sorted == List("Cat", "CatDog", "Dog"))
75 | }
76 |
77 | it("executeSinglestoreQuery without explicitly specified db") {
78 | val res = spark.executeSinglestoreQuery("SELECT * FROM animal")
79 | val out =
80 | for (row <- res)
81 | yield row.getString(1)
82 |
83 | assert(out.toList.sorted == List("Cat", "CatDog", "Dog"))
84 | }
85 |
86 | it("executeSinglestoreQuery with query params") {
87 | val params = List("%Cat%", 1, false)
88 | val res = spark.executeSinglestoreQuery(
89 | "SELECT * FROM animal WHERE name LIKE ? AND id > ? AND domestic = ?",
90 | "%Cat%",
91 | 1,
92 | false)
93 | val out =
94 | for (row <- res)
95 | yield row.getString(1)
96 |
97 | assert(out.toList == List("CatDog"))
98 | }
99 |
100 | it("executeSinglestoreQuery with db different from the one found in SparkContext") {
101 | spark.executeSinglestoreQuery(query = "CREATE DATABASE test_db_1")
102 | spark.executeSinglestoreQuery(query = "CREATE TABLE test_db_1.animal (id INT, name TEXT)")
103 |
104 | val res = spark.executeSinglestoreQueryDB("test_db_1", "SELECT * FROM animal")
105 | val out =
106 | for (row <- res)
107 | yield row.getString(1)
108 |
109 | assert(out.toList == List())
110 | }
111 | }
112 |
113 | it("executeSinglestoreQuery with numeric columns") {
114 | val params = List(1, 22, 2, null)
115 | val res = spark.executeSinglestoreQuery(
116 | "SELECT * FROM numbers WHERE nLong = ? OR nShort = ? OR nFloat = ? OR nDouble = ?",
117 | 1,
118 | 22,
119 | 2,
120 | null)
121 | assert(res.length == 2)
122 | }
123 |
124 | it("executeSinglestoreQuery with byte and date columns") {
125 | val res = spark.executeSinglestoreQuery(
126 | "SELECT * FROM byte_dates WHERE nByte = ? OR nDate = ? OR nTimestamp = ?",
127 | 2,
128 | "2015-03-30",
129 | 2000)
130 | assert(res.length == 2)
131 | }
132 |
133 | describe("explicit version") {
134 | it("global query test") {
135 | val s = executeSinglestoreQuery(spark, "SHOW DATABASES")
136 | val result =
137 | (for (row <- s)
138 | yield row.getString(0)).toList
139 | for (db <- List("memsql", "cluster", "testdb", "information_schema")) {
140 | assert(result.contains(db))
141 | }
142 | }
143 |
144 | it("executeSinglestoreQuery with 2 parameters") {
145 | executeSinglestoreQuery(spark, "DROP DATABASE IF EXISTS test_db_1")
146 | executeSinglestoreQuery(spark, "CREATE DATABASE test_db_1")
147 | }
148 |
149 | it("executeSinglestoreQueryDB with 3 parameters") {
150 | val res = executeSinglestoreQueryDB(spark, "testdb", "SELECT * FROM animal")
151 | val out =
152 | for (row <- res)
153 | yield row.getString(1)
154 |
155 | assert(out.toList.sorted == List("Cat", "CatDog", "Dog"))
156 | }
157 |
158 | it("executeSinglestoreQuery without explicitly specified db") {
159 | val res = executeSinglestoreQuery(spark, "SELECT * FROM animal")
160 | val out =
161 | for (row <- res)
162 | yield row.getString(1)
163 |
164 | assert(out.toList.sorted == List("Cat", "CatDog", "Dog"))
165 | }
166 |
167 | it("executeSinglestoreQuery with query params") {
168 | val params = List("%Cat%", 1, false)
169 | val res =
170 | executeSinglestoreQuery(
171 | spark,
172 | "SELECT * FROM animal WHERE name LIKE ? AND id > ? AND domestic = ?",
173 | "%Cat%",
174 | 1,
175 | false)
176 | val out =
177 | for (row <- res)
178 | yield row.getString(1)
179 |
180 | assert(out.toList == List("CatDog"))
181 | }
182 |
183 | it("executeSinglestoreQuery with db different from the one found in SparkContext") {
184 | executeSinglestoreQuery(spark, query = "CREATE DATABASE test_db_1")
185 | executeSinglestoreQuery(spark, query = "CREATE TABLE test_db_1.animal (id INT, name TEXT)")
186 |
187 | val res = executeSinglestoreQueryDB(spark, "test_db_1", "SELECT * FROM animal")
188 | val out =
189 | for (row <- res)
190 | yield row.getString(1)
191 |
192 | assert(out.toList == List())
193 | }
194 |
195 | it("executeSinglestoreQuery with numeric columns") {
196 | val params = List(1, 22, 2, null)
197 | val res = executeSinglestoreQuery(
198 | spark,
199 | "SELECT * FROM numbers WHERE nLong = ? OR nShort = ? OR nFloat = ? OR nDouble = ?",
200 | 1,
201 | 22,
202 | 2,
203 | null)
204 | assert(res.length == 2)
205 | }
206 |
207 | it("executeSinglestoreQuery with byte and date columns") {
208 | val res =
209 | executeSinglestoreQuery(
210 | spark,
211 | "SELECT * FROM byte_dates WHERE nByte = ? OR nDate = ? OR nTimestamp = ?",
212 | 2,
213 | "2015-03-30",
214 | 2000)
215 | assert(res.length == 2)
216 | }
217 | }
218 | }
219 |
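
A compact sketch of the two SQLHelper call styles exercised above, assuming an active SparkSession `spark` with SingleStore connection options already set; the table and database names match the fixtures created in beforeEach.

import com.singlestore.spark.SQLHelper._ // enables the implicit spark.executeSinglestoreQuery(...) syntax

// Implicit style: the SparkSession is the receiver; `?` placeholders are bound positionally.
val rows1 = spark.executeSinglestoreQuery(
  "SELECT * FROM animal WHERE name LIKE ? AND id > ?", "%Cat%", 1)

// Explicit style: the SparkSession is passed as the first argument.
val rows2 = executeSinglestoreQuery(
  spark, "SELECT * FROM animal WHERE name LIKE ? AND id > ?", "%Cat%", 1)

// Per-database variant, overriding whatever database the session is configured with.
val rows3 = spark.executeSinglestoreQueryDB("testdb", "SELECT * FROM animal")
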
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/SQLPermissionsTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.util.UUID
4 | import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
5 | import com.singlestore.spark.JdbcHelpers.executeQuery
6 | import org.apache.spark.sql.SaveMode
7 | import org.apache.spark.sql.types.IntegerType
8 |
9 | import java.sql.DriverManager
10 | import scala.util.Try
11 |
12 | class SQLPermissionsTest extends IntegrationSuiteBase {
13 |
14 | val testUserName = "sparkuserselect"
15 | val dbName = "testdb"
16 | val collectionName = "temps_test"
17 |
18 | override def beforeAll(): Unit = {
19 | super.beforeAll()
20 | val conn =
21 | DriverManager.getConnection(s"jdbc:mysql://$masterHost:$masterPort", jdbcDefaultProps)
22 | executeQuery(conn, s"CREATE USER '${testUserName}'@'%'")
23 | }
24 |
25 | override def beforeEach(): Unit = {
26 | super.beforeEach()
27 | val df = spark.createDF(
28 | List(1, 2, 3),
29 | List(("id", IntegerType, true))
30 | )
31 | writeTable(s"${dbName}.${collectionName}", df)
32 | }
33 |
34 | private def setUpUserPermissions(privilege: String): Unit = {
35 | /* Revoke all permissions from user */
36 | Try(executeQueryWithLog(s"REVOKE ALL PRIVILEGES ON ${dbName}.* FROM '${testUserName}'@'%'"))
37 | /* Give permissions to user */
38 | executeQueryWithLog(s"GRANT ${privilege} ON ${dbName}.* TO '${testUserName}'@'%'")
39 | /* Set up user to spark */
40 | spark.conf.set("spark.datasource.singlestore.user", s"${testUserName}")
41 | }
42 |
43 | private def doSuccessOperation(operation: () => Unit)(privilege: String): Unit = {
44 | it(s"success with ${privilege} permission") {
45 | setUpUserPermissions(privilege)
46 | val result = Try(operation())
47 | if (result.isFailure) {
48 | result.failed.get.printStackTrace()
49 | fail()
50 | }
51 | }
52 | }
53 |
54 | private def doFailOperation(operation: () => Unit)(privilege: String): Unit = {
55 | it(s"fails with ${privilege} permission") {
56 | setUpUserPermissions(privilege)
57 | val result = Try(operation())
58 | /* Error code descriptions:
59 | 1142 = command denied to the current user
60 | 1050 = table already exists (thrown when the user lacks the SELECT permission needed to check whether the table already exists)
61 | */
62 | assert(TestHelper.isSQLExceptionWithCode(result.failed.get, List(1142, 1050)))
63 | }
64 | }
65 |
66 | describe("read permissions") {
67 | /* List of supported privileges for read operation */
68 | val supportedPrivileges = List("SELECT", "ALL PRIVILEGES")
69 | /* List of unsupported privileges for read operation */
70 | val unsupportedPrivileges = List("CREATE", "DROP", "DELETE", "INSERT", "UPDATE")
71 |
72 | def operation(): Unit =
73 | spark.read
74 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
75 | .load(s"${dbName}.${collectionName}")
76 |
77 | unsupportedPrivileges.foreach(doFailOperation(operation))
78 | supportedPrivileges.foreach(doSuccessOperation(operation))
79 | }
80 |
81 | describe("write permissions") {
82 | /* List of supported privileges for write operation */
83 | val supportedPrivileges = List("INSERT, SELECT", "ALL PRIVILEGES")
84 | /* List of unsupported privileges for write operation */
85 | val unsupportedPrivileges = List("CREATE", "DROP", "DELETE", "SELECT", "UPDATE")
86 |
87 | def operation(): Unit = {
88 | val df = spark.createDF(
89 | List(4, 5, 6),
90 | List(("id", IntegerType, true))
91 | )
92 | df.write
93 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
94 | .mode(SaveMode.Append)
95 | .save(s"${dbName}.${collectionName}")
96 | }
97 |
98 | unsupportedPrivileges.foreach(doFailOperation(operation))
99 | supportedPrivileges.foreach(doSuccessOperation(operation))
100 | }
101 |
102 | describe("drop permissions") {
103 |
104 | /* List of supported privileges for drop operation */
105 | val supportedPrivileges = List("DROP, SELECT, INSERT", "ALL PRIVILEGES")
106 | /* List of unsupported privileges for drop operation */
107 | val unsupportedPrivileges = List("CREATE", "INSERT", "DELETE", "SELECT", "UPDATE")
108 |
109 | def operation(): Unit = {
110 | val df = spark.createDF(
111 | List(1, 2, 3),
112 | List(("id", IntegerType, true))
113 | )
114 | df.write
115 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
116 | .option("truncate", "true")
117 | .mode(SaveMode.Overwrite)
118 | .save(s"${dbName}.${collectionName}")
119 | }
120 |
121 | unsupportedPrivileges.foreach(doFailOperation(operation))
122 | supportedPrivileges.foreach(doSuccessOperation(operation))
123 | }
124 |
125 | describe("create permissions") {
126 |
127 | /* List of supported privileges for create operation */
128 | val supportedPrivileges = List("CREATE, SELECT, INSERT", "ALL PRIVILEGES")
129 | /* List of unsupported privileges for create operation */
130 | val unsupportedPrivileges = List("DROP", "INSERT", "DELETE", "SELECT", "UPDATE")
131 |
132 | def operation(): Unit = {
133 | val df = spark.createDF(
134 | List(1, 2, 3),
135 | List(("id", IntegerType, true))
136 | )
137 | df.write
138 | .format(DefaultSource.SINGLESTORE_SOURCE_NAME_SHORT)
139 | .mode(SaveMode.Overwrite)
140 | .save(s"${dbName}.${collectionName}_${UUID.randomUUID().toString.split("-")(0)}")
141 | }
142 |
143 | unsupportedPrivileges.foreach(doFailOperation(operation))
144 | supportedPrivileges.foreach(doSuccessOperation(operation))
145 | }
146 | }
147 |
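
For reference, a sketch of the minimum grants that make each operation above pass, taken directly from the supportedPrivileges lists; it assumes a root connection `conn` obtained as in beforeAll (DriverManager.getConnection against the master aggregator with jdbcDefaultProps).

import com.singlestore.spark.JdbcHelpers.executeQuery

executeQuery(conn, "GRANT SELECT ON testdb.* TO 'sparkuserselect'@'%'")                 // read
executeQuery(conn, "GRANT INSERT, SELECT ON testdb.* TO 'sparkuserselect'@'%'")         // append writes
executeQuery(conn, "GRANT DROP, SELECT, INSERT ON testdb.* TO 'sparkuserselect'@'%'")   // truncate/drop overwrite
executeQuery(conn, "GRANT CREATE, SELECT, INSERT ON testdb.* TO 'sparkuserselect'@'%'") // overwrite that creates a new table
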
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/SinglestoreConnectionPoolTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.Connection
4 | import java.time.Duration
5 | import java.util.Properties
6 |
7 | import org.apache.commons.dbcp2.DelegatingConnection
8 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
9 |
10 | class SinglestoreConnectionPoolTest extends IntegrationSuiteBase {
11 | var properties = new Properties()
12 |
13 | override def beforeEach(): Unit = {
14 | super.beforeEach()
15 | properties = JdbcHelpers.getConnProperties(
16 | SinglestoreOptions(
17 | CaseInsensitiveMap(
18 | Map("ddlEndpoint" -> s"$masterHost:$masterPort", "password" -> masterPassword)),
19 | spark.sparkContext),
20 | isOnExecutor = false,
21 | s"$masterHost:$masterPort"
22 | )
23 | }
24 |
25 | override def afterEach(): Unit = {
26 | super.afterEach()
27 | SinglestoreConnectionPool.close()
28 | }
29 |
30 | it("reuses a connection") {
31 | var conn = SinglestoreConnectionPool.getConnection(properties)
32 | val conn1 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal
33 | conn.close()
34 |
35 | conn = SinglestoreConnectionPool.getConnection(properties)
36 | val conn2 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal
37 | conn.close()
38 |
39 | assert(conn1 == conn2, "should reuse idle connection")
40 | }
41 |
42 | it("creates a new connection when existing is in use") {
43 | val conn1 = SinglestoreConnectionPool.getConnection(properties)
44 | val conn2 = SinglestoreConnectionPool.getConnection(properties)
45 | val originalConn1 =
46 | conn1.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal
47 | val originalConn2 =
48 | conn2.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal
49 |
50 | assert(originalConn1 != originalConn2, "should create a new connection when existing is in use")
51 |
52 | conn1.close()
53 | conn2.close()
54 | }
55 |
56 | it("creates different pools for different properties") {
57 | var conn = SinglestoreConnectionPool.getConnection(properties)
58 | val conn1 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal
59 | conn.close()
60 |
61 | properties.setProperty("newProperty", "")
62 |
63 | conn = SinglestoreConnectionPool.getConnection(properties)
64 | val conn2 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal
65 | conn.close()
66 |
67 | assert(conn1 != conn2, "should create different pools for different properties")
68 | }
69 |
70 | it("maxTotal and maxWaitMillis") {
71 | val maxWaitMillis = 200
72 | properties.setProperty("maxTotal", "1")
73 | properties.setProperty("maxWaitMillis", maxWaitMillis.toString)
74 |
75 | val conn1 = SinglestoreConnectionPool.getConnection(properties)
76 |
77 | val start = System.nanoTime()
78 | try {
79 | SinglestoreConnectionPool.getConnection(properties)
80 | fail()
81 | } catch {
82 | case e: Throwable =>
83 | assert(
84 | e.getMessage.equals(
85 | "Cannot get a connection, pool error Timeout waiting for idle object"),
86 | "should throw timeout error"
87 | )
88 | assert(System.nanoTime() - start > Duration.ofMillis(maxWaitMillis).toNanos,
89 | "should throw timeout error only after maxWaitMillis has elapsed")
90 | }
91 |
92 | conn1.close()
93 | }
94 |
95 | it("eviction of idle connections") {
96 | val minEvictableIdleTimeMillis = 100
97 | val timeBetweenEvictionRunsMillis = 50
98 | properties.setProperty("minEvictableIdleTimeMillis", minEvictableIdleTimeMillis.toString)
99 | properties.setProperty("timeBetweenEvictionRunsMillis", timeBetweenEvictionRunsMillis.toString)
100 |
101 | var conn = SinglestoreConnectionPool.getConnection(properties)
102 | val conn1 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal
103 | conn.close()
104 |
105 | Thread.sleep(((minEvictableIdleTimeMillis + timeBetweenEvictionRunsMillis) * 1.1).toLong)
106 |
107 | conn = SinglestoreConnectionPool.getConnection(properties)
108 | val conn2 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal
109 | conn.close()
110 |
111 | assert(conn1 != conn2, "should evict idle connection")
112 | }
113 |
114 | it("maxConnLifetimeMillis") {
115 | val maxConnLifetimeMillis = 100
116 | properties.setProperty("maxConnLifetimeMillis", maxConnLifetimeMillis.toString)
117 | var conn = SinglestoreConnectionPool.getConnection(properties)
118 | val conn1 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal
119 | conn.close()
120 |
121 | Thread.sleep((maxConnLifetimeMillis * 1.1).toLong)
122 |
123 | conn = SinglestoreConnectionPool.getConnection(properties)
124 | val conn2 = conn.asInstanceOf[DelegatingConnection[Connection]].getInnermostDelegateInternal
125 | conn.close()
126 |
127 | assert(conn1 != conn2, "should not use a connection after end of lifetime")
128 | }
129 | }
130 |
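
A sketch of obtaining a pooled connection directly, mirroring the setup above; the endpoint and password are placeholders and `spark` is assumed to be an active SparkSession. The pool is keyed by the full Properties object, so property sets that differ get separate pools (as the "different pools" test shows), and close() returns a connection to its pool rather than discarding it.

import java.sql.Connection
import java.util.Properties
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import com.singlestore.spark.{JdbcHelpers, SinglestoreConnectionPool, SinglestoreOptions}

val props: Properties = JdbcHelpers.getConnProperties(
  SinglestoreOptions(
    CaseInsensitiveMap(Map("ddlEndpoint" -> "localhost:5506", "password" -> "password")),
    spark.sparkContext),
  isOnExecutor = false,
  "localhost:5506"
)
props.setProperty("maxTotal", "1")                     // at most one open connection in this pool
props.setProperty("maxWaitMillis", "200")              // give up after 200 ms if none becomes free
props.setProperty("minEvictableIdleTimeMillis", "100") // idle connections become evictable after 100 ms

val conn: Connection = SinglestoreConnectionPool.getConnection(props)
try {
  // use conn
} finally {
  conn.close() // returns the connection to the pool
}
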
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/SinglestoreOptionsTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
4 |
5 | class SinglestoreOptionsTest extends IntegrationSuiteBase {
6 | val requiredOptions = Map("ddlEndpoint" -> "h:3306")
7 |
8 | describe("equality") {
9 | it("should sort dmlEndpoints") {
10 | assert(
11 | SinglestoreOptions(
12 | CaseInsensitiveMap(
13 | requiredOptions ++ Map("dmlEndpoints" -> "host1:3302,host2:3302,host1:3342")),
14 | spark.sparkContext) ==
15 | SinglestoreOptions(
16 | CaseInsensitiveMap(
17 | requiredOptions ++ Map("dmlEndpoints" -> "host2:3302,host1:3302,host1:3342")),
18 | spark.sparkContext),
19 | "Should sort dmlEndpoints"
20 | )
21 | }
22 | }
23 |
24 | describe("splitEscapedColumns") {
25 | it("empty string") {
26 | assert(SinglestoreOptions.splitEscapedColumns("") == List())
27 | }
28 |
29 | it("3 columns") {
30 | assert(
31 | SinglestoreOptions.splitEscapedColumns("col1,col2,col3") == List("col1", "col2", "col3"))
32 | }
33 |
34 | it("with spaces") {
35 | assert(
36 | SinglestoreOptions
37 | .splitEscapedColumns(" col1 , col2, col3") == List(" col1 ", " col2", " col3"))
38 | }
39 |
40 | it("with backticks") {
41 | assert(
42 | SinglestoreOptions.splitEscapedColumns(" ` col1` , `col2`, `` col3") == List(" ` col1` ",
43 | " `col2`",
44 | " `` col3"))
45 | }
46 |
47 | it("with commas inside of backticks") {
48 | assert(
49 | SinglestoreOptions
50 | .splitEscapedColumns(" ` ,, col1,` , ``,```,col3`, `` col4,`,,`") == List(
51 | " ` ,, col1,` ",
52 | " ``",
53 | "```,col3`",
54 | " `` col4",
55 | "`,,`"))
56 | }
57 | }
58 |
59 | describe("trimAndUnescapeColumn") {
60 | it("empty string") {
61 | assert(SinglestoreOptions.trimAndUnescapeColumn("") == "")
62 | }
63 |
64 | it("spaces") {
65 | assert(SinglestoreOptions.trimAndUnescapeColumn(" ") == "")
66 | }
67 |
68 | it("in backticks") {
69 | assert(SinglestoreOptions.trimAndUnescapeColumn(" `asd` ") == "asd")
70 | }
71 |
72 | it("backticks in the result") {
73 | assert(SinglestoreOptions.trimAndUnescapeColumn(" ```a``sd` ") == "`a`sd")
74 | }
75 |
76 | it("several escaped words") {
77 | assert(SinglestoreOptions.trimAndUnescapeColumn(" ```a``sd` ```a``sd` ") == "`a`sd `a`sd")
78 | }
79 |
80 | it("backtick in the middle of string") {
81 | assert(
82 | SinglestoreOptions
83 | .trimAndUnescapeColumn(" a```a``sd` ```a``sd` ") == "a```a``sd` ```a``sd`")
84 | }
85 | }
86 | }
87 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/TestHelper.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import java.sql.SQLException
4 |
5 | import scala.annotation.tailrec
6 |
7 | object TestHelper {
8 |
9 | @tailrec
10 | def isSQLExceptionWithCode(e: Throwable, codes: List[Integer]): Boolean = e match {
11 | case e: SQLException if codes.contains(e.getErrorCode) => true
12 | case e if e.getCause != null => isSQLExceptionWithCode(e.getCause, codes)
13 | case e =>
14 | e.printStackTrace()
15 | false
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/src/test/scala/com/singlestore/spark/VersionTest.scala:
--------------------------------------------------------------------------------
1 | package com.singlestore.spark
2 |
3 | import com.singlestore.spark.SQLGen.SinglestoreVersion
4 | import org.scalatest.funspec.AnyFunSpec
5 |
6 | class VersionTest extends AnyFunSpec {
7 |
8 | it("singlestore version test") {
9 |
10 | assert(SinglestoreVersion("7.0.1").atLeast("6.8.1"))
11 | assert(!SinglestoreVersion("6.8.1").atLeast("7.0.1"))
12 | assert(SinglestoreVersion("7.0.2").atLeast("7.0.1"))
13 | assert(SinglestoreVersion("7.0.10").atLeast("7.0.9"))
14 | assert(SinglestoreVersion("7.2.5").atLeast("7.1.99999"))
15 | assert(SinglestoreVersion("7.2.500").atLeast("7.2.499"))
16 | }
17 | }
18 |
--------------------------------------------------------------------------------