├── .github └── workflows │ └── main.yml ├── .gitignore ├── LICENSE ├── README.md ├── build.gradle ├── deploy.sh └── src ├── README.md └── main ├── resources └── log4j.properties └── scala └── org └── apache └── spark └── sql └── execution ├── datasources └── jdbc2 │ ├── DefaultSource.scala │ ├── DriverRegistry.scala │ ├── DriverWrapper.scala │ ├── JDBCOptions.scala │ ├── JDBCPartition.scala │ ├── JDBCPartitioningInfo.scala │ ├── JDBCSaveMode.scala │ └── JdbcUtils.scala └── jdbc ├── JdbcSink.scala └── JdbcSourceProvider.scala /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Publish package to the Maven Central Repository 2 | on: 3 | release: 4 | types: [created] 5 | jobs: 6 | publish: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Set up Java 11 | uses: actions/setup-java@v1 12 | with: 13 | java-version: 1.8 14 | - name: Decode gpg 15 | run: | 16 | echo "${{ secrets.SIGNING_SECRET_KEY_RING_KEY }}" > ~/secring.gpg.b64 17 | base64 -d ~/secring.gpg.b64 > ~/secring.gpg 18 | - name: Publish package 19 | run: gradle publish -Psigning.keyId=${{ secrets.SIGNING_KEY_ID }} -Psigning.password=${{ secrets.SIGNING_PASSWORD }} -Psigning.secretKeyRingFile=$(echo ~/secring.gpg) 20 | env: 21 | MAVEN_USERNAME: ${{ secrets.MAVEN_USERNAME }} 22 | MAVEN_PASSWORD: ${{ secrets.MAVEN_PASSWORD }} 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | .bloop/ 3 | .gradle/ 4 | .metals/ 5 | .vscode/ 6 | build/ 7 | .scalafmt.conf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 dounine 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![](https://github.com/dounine/spark-sql-datasource/workflows/Publish%20package%20to%20the%20Maven%20Central%20Repository/badge.svg) ![](https://img.shields.io/github/license/dounine/spark-sql-datasource)
2 | 
3 | ## Usage
4 | Add the Maven dependency:
5 | ```
6 | <dependency>
7 |     <groupId>com.dounine</groupId>
8 |     <artifactId>spark-sql-datasource</artifactId>
9 |     <version>1.0.4</version>
10 | </dependency>
11 | 
12 | ```
13 | 
14 | Example:
15 | ```
16 | spark.sql("select name,time,indicator from log")
17 |   .write
18 |   .format("org.apache.spark.sql.execution.datasources.jdbc2")
19 |   .options(
20 |     Map(
21 |       "savemode" -> JDBCSaveMode.Update.toString,
22 |       "driver" -> "com.mysql.jdbc.Driver",
23 |       "url" -> "jdbc:mysql://localhost:3306/ttable",
24 |       "user" -> "root",
25 |       "password" -> "root",
26 |       "dbtable" -> "test",
27 |       "useSSL" -> "false",
28 |       "duplicateIncs" -> "indicator",
29 |       "showSql" -> "true"
30 |     )
31 |   ).save()
32 | ```
33 | This will generate the following SQL:
34 | ```
35 | INSERT INTO test (`name`,`time`,`indicator`) VALUES (?,?,?) ON DUPLICATE KEY UPDATE `name`=?,`time`=?,`indicator`=`indicator`+?
36 | ```
37 | If the `duplicateIncs` option is not set, the following SQL will be generated:
38 | ```
39 | INSERT INTO test (`name`,`time`,`indicator`) VALUES (?,?,?) ON DUPLICATE KEY UPDATE `name`=?,`time`=?,`indicator`=?
40 | ```
41 | 
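Reading back through the same source also works, since `jdbc2` registers an ordinary read-side relation provider. A minimal sketch, reusing the placeholder connection settings from the example above:
```
val df = spark.read
  .format("org.apache.spark.sql.execution.datasources.jdbc2")
  .options(
    Map(
      "driver" -> "com.mysql.jdbc.Driver",
      "url" -> "jdbc:mysql://localhost:3306/ttable",
      "user" -> "root",
      "password" -> "root",
      "dbtable" -> "test",
      "useSSL" -> "false"
    )
  ).load()
df.show()
```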
--------------------------------------------------------------------------------
/build.gradle:
--------------------------------------------------------------------------------
1 | group 'com.dounine'
2 | version '1.0.4'
3 | 
4 | buildscript {
5 |     repositories {
6 |         maven {
7 |             url "https://plugins.gradle.org/m2/"
8 |         }
9 |     }
10 |     dependencies {
11 |         classpath "com.github.jengelman.gradle.plugins:shadow:4.0.2"
12 |     }
13 | }
14 | 
15 | apply plugin: "com.github.johnrengelman.shadow"
16 | apply plugin: 'signing'
17 | apply plugin: 'scala'
18 | apply plugin: 'maven-publish'
19 | 
20 | task sourcesJar(type: Jar) {
21 |     from sourceSets.main.allJava
22 |     classifier = 'sources'
23 | }
24 | 
25 | task javadocJar(type: Jar) {
26 |     from javadoc
27 |     classifier = 'javadoc'
28 | }
29 | 
30 | publishing {
31 |     publications {
32 |         mavenJava(MavenPublication) {
33 |             from components.java
34 | 
35 |             artifact sourcesJar
36 |             artifact javadocJar
37 | 
38 |             pom {
39 |                 name = 'spark-sql-datasource'
40 |                 description = 'spark sql datasource'
41 |                 url = 'https://github.com/dounine/spark-sql-datasource'
42 |                 licenses {
43 |                     license {
44 |                         name = 'The Apache License, Version 2.0'
45 |                         url = 'http://www.apache.org/licenses/LICENSE-2.0.txt'
46 |                     }
47 |                 }
48 |                 developers {
49 |                     developer {
50 |                         id = 'lake'
51 |                         name = 'lake'
52 |                         email = 'amwoqmgo@gmail.com'
53 |                     }
54 |                 }
55 |                 scm {
56 |                     connection = 'scm:git:git://github.com/dounine/spark-sql-datasource.git'
57 |                     developerConnection = 'scm:git:ssh://github.com/dounine/spark-sql-datasource.git'
58 |                     url = 'https://github.com/dounine/spark-sql-datasource'
59 |                 }
60 |             }
61 |         }
62 |     }
63 |     repositories {
64 |         maven {
65 |             def releasesRepoUrl = "https://oss.sonatype.org/service/local/staging/deploy/maven2/"
66 |             def snapshotsRepoUrl = "https://oss.sonatype.org/content/repositories/snapshots/"
67 |             url = version.endsWith('SNAPSHOT') ? snapshotsRepoUrl : releasesRepoUrl
68 |             credentials {
69 |                 username System.getenv("MAVEN_USERNAME")
70 |                 password System.getenv("MAVEN_PASSWORD")
71 |             }
72 |         }
73 |     }
74 | }
75 | 
76 | sourceCompatibility = 1.8
77 | 
78 | repositories {
79 |     mavenLocal()
80 |     mavenCentral()
81 |     maven { url "http://repo.hortonworks.com/content/repositories/releases/" }
82 | }
83 | 
84 | 
85 | task copyJars(type: Copy) {
86 |     from configurations.runtime
87 |     into new File('build/libs/lib')
88 | }
89 | compileJava.dependsOn copyJars
90 | 
91 | shadowJar {
92 |     zip64 true
93 | }
94 | 
95 | signing {
96 |     sign publishing.publications.mavenJava
97 | }
98 | 
99 | ext {
100 |     flinkVersion = "1.8-SNAPSHOT"
101 |     hbaseVersion = "2.0.0"
102 |     slf4jVersion = "1.7.25"
103 | }
104 | 
105 | dependencies {
106 |     compile 'org.scala-lang:scala-library:2.12.12'
107 | 
108 |     compile group: 'org.apache.spark', name: 'spark-sql_2.12', version: '3.0.0'
109 |     compile group: 'mysql', name: 'mysql-connector-java', version: '5.1.47'
110 | 
111 |     testCompile group: 'junit', name: 'junit', version: '4.12'
112 | }
113 | 
--------------------------------------------------------------------------------
/deploy.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | gradle clean build -xtest
3 | gradle publish
4 | 
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | ## Usage
2 | The `jdbc2` data source for spark-sql-datasource
3 | 
4 | ## Demo
5 | ```
6 | spark.sql("select name,time,indicator from log")
7 |   .write
8 |   .format("org.apache.spark.sql.execution.datasources.jdbc2")
9 |   .options(
10 |     Map(
11 |       "savemode" -> JDBCSaveMode.Update.toString,
12 |       "driver" -> "com.mysql.jdbc.Driver",
13 |       "url" -> "jdbc:mysql://localhost:3306/ttable",
14 |       "user" -> "root",
15 |       "password" -> "root",
16 |       "dbtable" -> "test",
17 |       "useSSL" -> "false",
18 |       "duplicateIncs" -> "indicator"
19 |     )
20 |   ).save()
21 | ```
22 | This will generate the following SQL:
23 | ```
24 | INSERT INTO test (`name`,`time`,`indicator`) VALUES (?,?,?) ON DUPLICATE KEY UPDATE `name`=?,`time`=?,`indicator`=`indicator`+?
25 | ```
26 | If the `duplicateIncs` option is not set, the following SQL will be generated:
27 | ```
28 | INSERT INTO test (`name`,`time`,`indicator`) VALUES (?,?,?) ON DUPLICATE KEY UPDATE `name`=?,`time`=?,`indicator`=?
29 | ``` -------------------------------------------------------------------------------- /src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | log4j.rootLogger=INFO, console 2 | 3 | log4j.appender.console=org.apache.log4j.ConsoleAppender 4 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 5 | log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} %-5p %-60c %x - %m%n -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/execution/datasources/jdbc2/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.execution.datasources.jdbc2 2 | 3 | import java.sql.Connection 4 | 5 | import org.apache.spark.sql.execution.datasources.jdbc2.JDBCOptions._ 6 | import org.apache.spark.sql.execution.datasources.jdbc2.JdbcUtils._ 7 | import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider} 8 | import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} 9 | 10 | class DefaultSource extends CreatableRelationProvider with RelationProvider with DataSourceRegister { 11 | 12 | override def shortName(): String = "jdbc2" 13 | 14 | override def createRelation( 15 | sqlContext: SQLContext, 16 | parameters: Map[String, String]): BaseRelation = { 17 | 18 | val jdbcOptions = new JDBCOptions(parameters) 19 | val partitionColumn = jdbcOptions.partitionColumn 20 | val lowerBound = jdbcOptions.lowerBound 21 | val upperBound = jdbcOptions.upperBound 22 | val numPartitions = jdbcOptions.numPartitions 23 | 24 | val partitionInfo = if (partitionColumn.isEmpty) { 25 | assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not specified, " + 26 | s"'$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty") 27 | null 28 | } else { 29 | assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty, 30 | s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " + 31 | s"'$JDBC_NUM_PARTITIONS' are also required") 32 | JDBCPartitioningInfo( 33 | partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get) 34 | } 35 | val parts = JDBCRelation.columnPartition(partitionInfo) 36 | JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession) 37 | } 38 | 39 | override def createRelation( 40 | sqlContext: SQLContext, 41 | mode: SaveMode, 42 | parameters: Map[String, String], 43 | df: DataFrame): BaseRelation = { 44 | val options = new JDBCOptions(parameters) 45 | val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis 46 | 47 | var saveMode: JDBCSaveMode.Value = mode match { 48 | case SaveMode.Overwrite => JDBCSaveMode.Overwrite 49 | case SaveMode.Append => JDBCSaveMode.Append 50 | case SaveMode.ErrorIfExists => JDBCSaveMode.ErrorIfExists 51 | case SaveMode.Ignore => JDBCSaveMode.Ignore 52 | } 53 | val parameterLower: Map[String, String] = parameters.map(kv => (kv._1.toLowerCase, kv._2)) 54 | saveMode = if (parameterLower.keySet.contains("savemode")) { 55 | if (parameterLower("savemode").toLowerCase == JDBCSaveMode.Update.toString.toLowerCase) JDBCSaveMode.Update else saveMode 56 | } else { 57 | saveMode 58 | } 59 | 60 | val conn: Connection = JdbcUtils.createConnectionFactory(options)() 61 | try { 62 | val tableExists: Boolean = JdbcUtils.tableExists(conn, options) 63 | if (tableExists) { 64 | saveMode match { 65 | case JDBCSaveMode.Overwrite => 
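// For Overwrite, either truncate the existing table in place (when the 'truncate' option is set and
// the dialect's TRUNCATE does not cascade) or drop and recreate it, as the two branches below do.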
66 | if (options.isTruncate && isCascadingTruncateTable(options.url).contains(false)) { 67 | // In this case, we should truncate table and then load. 68 | truncateTable(conn, options) 69 | val tableSchema = JdbcUtils.getSchemaOption(conn, options) 70 | saveTable(df, tableSchema, isCaseSensitive, options, saveMode) 71 | } else { 72 | // Otherwise, do not truncate the table, instead drop and recreate it 73 | dropTable(conn, options.table) 74 | createTable(conn, df, options) 75 | saveTable(df, Some(df.schema), isCaseSensitive, options, saveMode) 76 | } 77 | 78 | case JDBCSaveMode.Update => 79 | val tableSchema = JdbcUtils.getSchemaOption(conn, options) 80 | saveTable(df, tableSchema, isCaseSensitive, options, saveMode) 81 | 82 | case JDBCSaveMode.Append => 83 | val tableSchema = JdbcUtils.getSchemaOption(conn, options) 84 | saveTable(df, tableSchema, isCaseSensitive, options, saveMode) 85 | 86 | case JDBCSaveMode.ErrorIfExists => 87 | throw new AnalysisException( 88 | s"Table or view '${options.table}' already exists. SaveMode: ErrorIfExists.") 89 | 90 | case JDBCSaveMode.Ignore => 91 | // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected 92 | // to not save the contents of the DataFrame and to not change the existing data. 93 | // Therefore, it is okay to do nothing here and then just return the relation below. 94 | } 95 | } else { 96 | createTable(conn, df, options) 97 | saveTable(df, Some(df.schema), isCaseSensitive, options, saveMode) 98 | } 99 | } finally { 100 | conn.close() 101 | } 102 | 103 | createRelation(sqlContext, parameters) 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/execution/datasources/jdbc2/DriverRegistry.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.sql.execution.datasources.jdbc2 19 | 20 | import java.sql.{Driver, DriverManager} 21 | 22 | import org.apache.spark.internal.Logging 23 | import org.apache.spark.util.Utils 24 | 25 | import scala.collection.mutable 26 | 27 | 28 | object DriverRegistry extends Logging { 29 | 30 | DriverManager.getDrivers 31 | 32 | private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty 33 | 34 | def register(className: String): Unit = { 35 | val cls = Utils.getContextOrSparkClassLoader.loadClass(className) 36 | if (cls.getClassLoader == null) { 37 | logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required") 38 | } else if (wrapperMap.get(className).isDefined) { 39 | logTrace(s"Wrapper for $className already exists") 40 | } else { 41 | synchronized { 42 | if (wrapperMap.get(className).isEmpty) { 43 | val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver]) 44 | DriverManager.registerDriver(wrapper) 45 | wrapperMap(className) = wrapper 46 | logTrace(s"Wrapper for $className registered") 47 | } 48 | } 49 | } 50 | } 51 | } 52 | 53 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/execution/datasources/jdbc2/DriverWrapper.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.execution.datasources.jdbc2 2 | 3 | import java.sql.{Connection, Driver, DriverPropertyInfo, SQLFeatureNotSupportedException} 4 | import java.util.Properties 5 | 6 | class DriverWrapper(val wrapped: Driver) extends Driver { 7 | override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url) 8 | 9 | override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant() 10 | 11 | override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = { 12 | wrapped.getPropertyInfo(url, info) 13 | } 14 | 15 | override def getMinorVersion: Int = wrapped.getMinorVersion 16 | 17 | def getParentLogger: java.util.logging.Logger = { 18 | throw new SQLFeatureNotSupportedException( 19 | s"${this.getClass.getName}.getParentLogger is not yet implemented.") 20 | } 21 | 22 | override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) 23 | 24 | override def getMajorVersion: Int = wrapped.getMajorVersion 25 | } 26 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/execution/datasources/jdbc2/JDBCOptions.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.execution.datasources.jdbc2 2 | 3 | import java.sql.{Connection, DriverManager} 4 | import java.util.{Locale, Properties} 5 | 6 | import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap 7 | 8 | class JDBCOptions( 9 | @transient private val parameters: CaseInsensitiveMap[String]) 10 | extends Serializable { 11 | 12 | import JDBCOptions._ 13 | 14 | def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters)) 15 | 16 | def this(url: String, table: String, parameters: Map[String, String]) = { 17 | this(CaseInsensitiveMap(parameters ++ Map( 18 | JDBCOptions.JDBC_URL -> url, 19 | JDBCOptions.JDBC_TABLE_NAME -> table))) 20 | } 21 | 22 | /** 23 | * Returns a property with all options. 
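 * (Unlike `asConnectionProperties` below, this also includes Spark-specific keys such as `url` and `dbtable`.)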
24 | */ 25 | val asProperties: Properties = { 26 | val properties = new Properties() 27 | parameters.originalMap.foreach { case (k, v) => properties.setProperty(k, v) } 28 | properties 29 | } 30 | 31 | /** 32 | * Returns a property with all options except Spark internal data source options like `url`, 33 | * `dbtable`, and `numPartition`. This should be used when invoking JDBC API like `Driver.connect` 34 | * because each DBMS vendor has its own property list for JDBC driver. See SPARK-17776. 35 | */ 36 | val asConnectionProperties: Properties = { 37 | val properties = new Properties() 38 | parameters.originalMap.filterKeys(key => !jdbcOptionNames(key.toLowerCase(Locale.ROOT))) 39 | .foreach { case (k, v) => properties.setProperty(k, v) } 40 | properties 41 | } 42 | 43 | // ------------------------------------------------------------ 44 | // Required parameters 45 | // ------------------------------------------------------------ 46 | require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.") 47 | require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.") 48 | // a JDBC URL 49 | val url = parameters(JDBC_URL) 50 | // name of table 51 | val table = parameters(JDBC_TABLE_NAME) 52 | 53 | // ------------------------------------------------------------ 54 | // Optional parameters 55 | // ------------------------------------------------------------ 56 | val driverClass = { 57 | val userSpecifiedDriverClass = parameters.get(JDBC_DRIVER_CLASS) 58 | userSpecifiedDriverClass.foreach(DriverRegistry.register) 59 | 60 | // Performing this part of the logic on the driver guards against the corner-case where the 61 | // driver returned for a URL is different on the driver and executors due to classpath 62 | // differences. 63 | userSpecifiedDriverClass.getOrElse { 64 | DriverManager.getDriver(url).getClass.getCanonicalName 65 | } 66 | } 67 | 68 | // the number of partitions 69 | val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt) 70 | 71 | // ------------------------------------------------------------ 72 | // Optional parameters only for reading 73 | // ------------------------------------------------------------ 74 | // the column used to partition 75 | val partitionColumn = parameters.get(JDBC_PARTITION_COLUMN) 76 | // the lower bound of partition column 77 | val lowerBound = parameters.get(JDBC_LOWER_BOUND).map(_.toLong) 78 | // the upper bound of the partition column 79 | val upperBound = parameters.get(JDBC_UPPER_BOUND).map(_.toLong) 80 | // numPartitions is also used for data source writing 81 | require((partitionColumn.isEmpty && lowerBound.isEmpty && upperBound.isEmpty) || 82 | (partitionColumn.isDefined && lowerBound.isDefined && upperBound.isDefined && 83 | numPartitions.isDefined), 84 | s"When reading JDBC data sources, users need to specify all or none for the following " + 85 | s"options: '$JDBC_PARTITION_COLUMN', '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', " + 86 | s"and '$JDBC_NUM_PARTITIONS'") 87 | val fetchSize = { 88 | val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt 89 | require(size >= 0, 90 | s"Invalid value `${size.toString}` for parameter " + 91 | s"`$JDBC_BATCH_FETCH_SIZE`. The minimum value is 0. 
When the value is 0, " + 92 | "the JDBC driver ignores the value and does the estimates.") 93 | size 94 | } 95 | 96 | // ------------------------------------------------------------ 97 | // Optional parameters only for writing 98 | // ------------------------------------------------------------ 99 | // if to truncate the table from the JDBC database 100 | val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean 101 | // the create table option , which can be table_options or partition_options. 102 | // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" 103 | // TODO: to reuse the existing partition parameters for those partition specific options 104 | val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "") 105 | val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES) 106 | val customSchema = parameters.get(JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES) 107 | 108 | val batchSize = { 109 | val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt 110 | require(size >= 1, 111 | s"Invalid value `${size.toString}` for parameter " + 112 | s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.") 113 | size 114 | } 115 | val isolationLevel = 116 | parameters.getOrElse(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match { 117 | case "NONE" => Connection.TRANSACTION_NONE 118 | case "READ_UNCOMMITTED" => Connection.TRANSACTION_READ_UNCOMMITTED 119 | case "READ_COMMITTED" => Connection.TRANSACTION_READ_COMMITTED 120 | case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ 121 | case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE 122 | } 123 | // An option to execute custom SQL before fetching data from the remote DB 124 | val sessionInitStatement = parameters.get(JDBC_SESSION_INIT_STATEMENT) 125 | } 126 | 127 | object JDBCOptions { 128 | private val jdbcOptionNames = collection.mutable.Set[String]() 129 | 130 | private def newOption(name: String): String = { 131 | jdbcOptionNames += name.toLowerCase(Locale.ROOT) 132 | name 133 | } 134 | 135 | val JDBC_URL = newOption("url") 136 | val JDBC_TABLE_NAME = newOption("dbtable") 137 | val JDBC_DRIVER_CLASS = newOption("driver") 138 | val JDBC_PARTITION_COLUMN = newOption("partitionColumn") 139 | val JDBC_LOWER_BOUND = newOption("lowerBound") 140 | val JDBC_UPPER_BOUND = newOption("upperBound") 141 | val JDBC_NUM_PARTITIONS = newOption("numPartitions") 142 | val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") 143 | val JDBC_TRUNCATE = newOption("truncate") 144 | val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") 145 | val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") 146 | val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema") 147 | val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") 148 | val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") 149 | val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement") 150 | val JDBC_DUPLICATE_INCS = newOption("duplicateIncs") 151 | } 152 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/execution/datasources/jdbc2/JDBCPartition.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.execution.datasources.jdbc2 2 | 3 | import java.sql.{Connection, PreparedStatement, ResultSet, SQLException} 4 | 5 | import org.apache.spark.internal.Logging 6 | import org.apache.spark.rdd.RDD 7 | import org.apache.spark.sql.catalyst.InternalRow 8 | import 
org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} 9 | import org.apache.spark.sql.sources._ 10 | import org.apache.spark.sql.types._ 11 | import org.apache.spark.util.{CompletionIterator, TaskCompletionListener} 12 | import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} 13 | 14 | import scala.collection.JavaConverters._ 15 | import scala.util.control.NonFatal 16 | 17 | /** 18 | * Data corresponding to one partition of a JDBCRDD. 19 | */ 20 | case class JDBCPartition(whereClause: String, idx: Int) extends Partition { 21 | override def index: Int = idx 22 | } 23 | 24 | object JDBCRDD extends Logging { 25 | 26 | /** 27 | * Takes a (schema, table) specification and returns the table's Catalyst 28 | * schema. 29 | * 30 | * @param options - JDBC options that contains url, table and other information. 31 | * @return A StructType giving the table's Catalyst schema. 32 | * @throws SQLException if the table specification is garbage. 33 | * @throws SQLException if the table contains an unsupported type. 34 | */ 35 | def resolveTable(options: JDBCOptions): StructType = { 36 | val url = options.url 37 | val table = options.table 38 | val dialect = JdbcDialects.get(url) 39 | val conn: Connection = JdbcUtils.createConnectionFactory(options)() 40 | try { 41 | val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) 42 | try { 43 | val rs = statement.executeQuery() 44 | try { 45 | JdbcUtils.getSchema(rs, dialect, alwaysNullable = true) 46 | } finally { 47 | rs.close() 48 | } 49 | } finally { 50 | statement.close() 51 | } 52 | } finally { 53 | conn.close() 54 | } 55 | } 56 | 57 | /** 58 | * Prune all but the specified columns from the specified Catalyst schema. 59 | * 60 | * @param schema - The Catalyst schema of the master table 61 | * @param columns - The list of desired columns 62 | * @return A Catalyst schema corresponding to columns in the given order. 63 | */ 64 | private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { 65 | val fieldMap = Map(schema.fields.map(x => x.name -> x): _*) 66 | new StructType(columns.map(name => fieldMap(name))) 67 | } 68 | 69 | /** 70 | * Turns a single Filter into a String representing a SQL expression. 71 | * Returns None for an unhandled filter. 
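 * For example (illustrative), EqualTo("name", "fred") is compiled to "name" = 'fred' (the exact
 * identifier quoting comes from the dialect), and an Or filter is compiled only when both of its
 * sub-filters can be compiled.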
72 | */ 73 | def compileFilter(f: Filter, dialect: JdbcDialect): Option[String] = { 74 | def quote(colName: String): String = dialect.quoteIdentifier(colName) 75 | 76 | Option(f match { 77 | case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}" 78 | case EqualNullSafe(attr, value) => 79 | val col = quote(attr) 80 | s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " + 81 | s"${dialect.compileValue(value)} IS NULL) OR " + 82 | s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))" 83 | case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}" 84 | case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}" 85 | case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}" 86 | case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}" 87 | case IsNull(attr) => s"${quote(attr)} IS NULL" 88 | case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL" 89 | case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'" 90 | case StringEndsWith(attr, value) => s"${quote(attr)} LIKE '%${value}'" 91 | case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'" 92 | case In(attr, value) if value.isEmpty => 93 | s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END" 94 | case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})" 95 | case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null) 96 | case Or(f1, f2) => 97 | // We can't compile Or filter unless both sub-filters are compiled successfully. 98 | // It applies too for the following And filter. 99 | // If we can make sure compileFilter supports all filters, we can remove this check. 100 | val or = Seq(f1, f2).flatMap(compileFilter(_, dialect)) 101 | if (or.size == 2) { 102 | or.map(p => s"($p)").mkString(" OR ") 103 | } else { 104 | null 105 | } 106 | case And(f1, f2) => 107 | val and = Seq(f1, f2).flatMap(compileFilter(_, dialect)) 108 | if (and.size == 2) { 109 | and.map(p => s"($p)").mkString(" AND ") 110 | } else { 111 | null 112 | } 113 | case _ => null 114 | }) 115 | } 116 | 117 | /** 118 | * Build and return JDBCRDD from the given information. 119 | * 120 | * @param sc - Your SparkContext. 121 | * @param schema - The Catalyst schema of the underlying database table. 122 | * @param requiredColumns - The names of the columns to SELECT. 123 | * @param filters - The filters to include in all WHERE clauses. 124 | * @param parts - An array of JDBCPartitions specifying partition ids and 125 | * per-partition WHERE clauses. 126 | * @param options - JDBC options that contains url, table and other information. 127 | * @return An RDD representing "SELECT requiredColumns FROM fqTable". 128 | */ 129 | def scanTable( 130 | sc: SparkContext, 131 | schema: StructType, 132 | requiredColumns: Array[String], 133 | filters: Array[Filter], 134 | parts: Array[Partition], 135 | options: JDBCOptions): RDD[InternalRow] = { 136 | val url = options.url 137 | val dialect = JdbcDialects.get(url) 138 | val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) 139 | new JDBCRDD( 140 | sc, 141 | JdbcUtils.createConnectionFactory(options), 142 | pruneSchema(schema, requiredColumns), 143 | quotedColumns, 144 | filters, 145 | parts, 146 | url, 147 | options) 148 | } 149 | } 150 | 151 | /** 152 | * An RDD representing a table in a database accessed via JDBC. 
Both the 153 | * driver code and the workers must be able to access the database; the driver 154 | * needs to fetch the schema while the workers need to fetch the data. 155 | */ 156 | private[jdbc2] class JDBCRDD( 157 | sc: SparkContext, 158 | getConnection: () => Connection, 159 | schema: StructType, 160 | columns: Array[String], 161 | filters: Array[Filter], 162 | partitions: Array[Partition], 163 | url: String, 164 | options: JDBCOptions) 165 | extends RDD[InternalRow](sc, Nil) { 166 | 167 | /** 168 | * Retrieve the list of partitions corresponding to this RDD. 169 | */ 170 | override def getPartitions: Array[Partition] = partitions 171 | 172 | /** 173 | * `columns`, but as a String suitable for injection into a SQL query. 174 | */ 175 | private val columnList: String = { 176 | val sb = new StringBuilder() 177 | columns.foreach(x => sb.append(",").append(x)) 178 | if (sb.isEmpty) "1" else sb.substring(1) 179 | } 180 | 181 | /** 182 | * `filters`, but as a WHERE clause suitable for injection into a SQL query. 183 | */ 184 | private val filterWhereClause: String = 185 | filters 186 | .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url))) 187 | .map(p => s"($p)").mkString(" AND ") 188 | 189 | /** 190 | * A WHERE clause representing both `filters`, if any, and the current partition. 191 | */ 192 | private def getWhereClause(part: JDBCPartition): String = { 193 | if (part.whereClause != null && filterWhereClause.length > 0) { 194 | "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})" 195 | } else if (part.whereClause != null) { 196 | "WHERE " + part.whereClause 197 | } else if (filterWhereClause.length > 0) { 198 | "WHERE " + filterWhereClause 199 | } else { 200 | "" 201 | } 202 | } 203 | 204 | /** 205 | * Runs the SQL query against the JDBC driver. 206 | * 207 | */ 208 | override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = { 209 | var closed = false 210 | var rs: ResultSet = null 211 | var stmt: PreparedStatement = null 212 | var conn: Connection = null 213 | 214 | def close() { 215 | if (closed) return 216 | try { 217 | if (null != rs) { 218 | rs.close() 219 | } 220 | } catch { 221 | case e: Exception => logWarning("Exception closing resultset", e) 222 | } 223 | try { 224 | if (null != stmt) { 225 | stmt.close() 226 | } 227 | } catch { 228 | case e: Exception => logWarning("Exception closing statement", e) 229 | } 230 | try { 231 | if (null != conn) { 232 | if (!conn.isClosed && !conn.getAutoCommit) { 233 | try { 234 | conn.commit() 235 | } catch { 236 | case NonFatal(e) => logWarning("Exception committing transaction", e) 237 | } 238 | } 239 | conn.close() 240 | } 241 | logInfo("closed connection") 242 | } catch { 243 | case e: Exception => logWarning("Exception closing connection", e) 244 | } 245 | closed = true 246 | } 247 | 248 | context.addTaskCompletionListener(new TaskCompletionListener { 249 | override def onTaskCompletion(context: TaskContext): Unit = { 250 | close() 251 | } 252 | }) 253 | 254 | val inputMetrics = context.taskMetrics().inputMetrics 255 | val part = thePart.asInstanceOf[JDBCPartition] 256 | conn = getConnection() 257 | val dialect = JdbcDialects.get(url) 258 | dialect.beforeFetch(conn, options.asProperties.asScala.toMap) 259 | 260 | // This executes a generic SQL statement (or PL/SQL block) before reading 261 | // the table/query via JDBC. Use this feature to initialize the database 262 | // session environment, e.g. for optimizations and/or troubleshooting. 
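// e.g. (illustrative) sessionInitStatement = "SET SESSION time_zone = '+00:00'" would be run once
// per opened connection, right before the SELECT below.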
263 | options.sessionInitStatement match { 264 | case Some(sql) => 265 | val statement = conn.prepareStatement(sql) 266 | logInfo(s"Executing sessionInitStatement: $sql") 267 | try { 268 | statement.execute() 269 | } finally { 270 | statement.close() 271 | } 272 | case None => 273 | } 274 | 275 | // H2's JDBC driver does not support the setSchema() method. We pass a 276 | // fully-qualified table name in the SELECT statement. I don't know how to 277 | // talk about a table in a completely portable way. 278 | 279 | val myWhereClause = getWhereClause(part) 280 | 281 | val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause" 282 | stmt = conn.prepareStatement(sqlText, 283 | ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) 284 | stmt.setFetchSize(options.fetchSize) 285 | rs = stmt.executeQuery() 286 | val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) 287 | 288 | CompletionIterator[InternalRow, Iterator[InternalRow]]( 289 | new InterruptibleIterator(context, rowsIterator), close()) 290 | } 291 | } 292 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/execution/datasources/jdbc2/JDBCPartitioningInfo.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.execution.datasources.jdbc2 2 | 3 | import org.apache.spark.Partition 4 | import org.apache.spark.internal.Logging 5 | import org.apache.spark.rdd.RDD 6 | import org.apache.spark.sql.jdbc.JdbcDialects 7 | import org.apache.spark.sql.sources._ 8 | import org.apache.spark.sql.types.StructType 9 | import org.apache.spark.sql._ 10 | 11 | import scala.collection.mutable.ArrayBuffer 12 | 13 | /** 14 | * Instructions on how to partition the table among workers. 15 | */ 16 | private[sql] case class JDBCPartitioningInfo( 17 | column: String, 18 | lowerBound: Long, 19 | upperBound: Long, 20 | numPartitions: Int) 21 | 22 | private[sql] object JDBCRelation extends Logging { 23 | /** 24 | * Given a partitioning schematic (a column of integral type, a number of 25 | * partitions, and upper and lower bounds on the column's value), generate 26 | * WHERE clauses for each partition so that each row in the table appears 27 | * exactly once. The parameters minValue and maxValue are advisory in that 28 | * incorrect values may cause the partitioning to be poor, but no data 29 | * will fail to be represented. 30 | * 31 | * Null value predicate is added to the first partition where clause to include 32 | * the rows with null value for the partitions column. 33 | * 34 | * @param partitioning partition information to generate the where clause for each partition 35 | * @return an array of partitions with where clause for each partition 36 | */ 37 | def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { 38 | if (partitioning == null || partitioning.numPartitions <= 1 || 39 | partitioning.lowerBound == partitioning.upperBound) { 40 | return Array[Partition](JDBCPartition(null, 0)) 41 | } 42 | 43 | val lowerBound = partitioning.lowerBound 44 | val upperBound = partitioning.upperBound 45 | require (lowerBound <= upperBound, 46 | "Operation not allowed: the lower bound of partitioning column is larger than the upper " + 47 | s"bound. 
Lower bound: $lowerBound; Upper bound: $upperBound") 48 | 49 | val numPartitions = 50 | if ((upperBound - lowerBound) >= partitioning.numPartitions || /* check for overflow */ 51 | (upperBound - lowerBound) < 0) { 52 | partitioning.numPartitions 53 | } else { 54 | logWarning("The number of partitions is reduced because the specified number of " + 55 | "partitions is less than the difference between upper bound and lower bound. " + 56 | s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " + 57 | s"partitions: ${partitioning.numPartitions}; Lower bound: $lowerBound; " + 58 | s"Upper bound: $upperBound.") 59 | upperBound - lowerBound 60 | } 61 | // Overflow and silliness can happen if you subtract then divide. 62 | // Here we get a little roundoff, but that's (hopefully) OK. 63 | val stride: Long = upperBound / numPartitions - lowerBound / numPartitions 64 | val column = partitioning.column 65 | var i: Int = 0 66 | var currentValue: Long = lowerBound 67 | val ans = new ArrayBuffer[Partition]() 68 | while (i < numPartitions) { 69 | val lBound = if (i != 0) s"$column >= $currentValue" else null 70 | currentValue += stride 71 | val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null 72 | val whereClause = 73 | if (uBound == null) { 74 | lBound 75 | } else if (lBound == null) { 76 | s"$uBound or $column is null" 77 | } else { 78 | s"$lBound AND $uBound" 79 | } 80 | ans += JDBCPartition(whereClause, i) 81 | i = i + 1 82 | } 83 | ans.toArray 84 | } 85 | } 86 | 87 | private[sql] case class JDBCRelation( 88 | parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession) 89 | extends BaseRelation 90 | with PrunedFilteredScan 91 | with InsertableRelation { 92 | 93 | override def sqlContext: SQLContext = sparkSession.sqlContext 94 | 95 | override val needConversion: Boolean = false 96 | 97 | override val schema: StructType = { 98 | val tableSchema = JDBCRDD.resolveTable(jdbcOptions) 99 | jdbcOptions.customSchema match { 100 | case Some(customSchema) => JdbcUtils.getCustomSchema( 101 | tableSchema, customSchema, sparkSession.sessionState.conf.resolver) 102 | case None => tableSchema 103 | } 104 | } 105 | 106 | // Check if JDBCRDD.compileFilter can accept input filters 107 | override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { 108 | filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) 109 | } 110 | 111 | override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { 112 | // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] 113 | JDBCRDD.scanTable( 114 | sparkSession.sparkContext, 115 | schema, 116 | requiredColumns, 117 | filters, 118 | parts, 119 | jdbcOptions).asInstanceOf[RDD[Row]] 120 | } 121 | 122 | override def insert(data: DataFrame, overwrite: Boolean): Unit = { 123 | data 124 | .write 125 | .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) 126 | .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties) 127 | } 128 | 129 | override def toString: String = { 130 | val partitioningInfo = if (parts.nonEmpty) s" [numPartitions=${parts.length}]" else "" 131 | // credentials should not be included in the plan output, table information is sufficient. 
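// e.g. (illustrative) this renders as: JDBCRelation(test) [numPartitions=4]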
132 | s"JDBCRelation(${jdbcOptions.table})" + partitioningInfo 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/execution/datasources/jdbc2/JDBCSaveMode.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.execution.datasources.jdbc2 2 | 3 | object JDBCSaveMode extends Enumeration { 4 | 5 | type JDBCSaveMode = Value 6 | 7 | val Append = Value("Append") 8 | val Overwrite = Value("Overwrite") 9 | val ErrorIfExists = Value("ErrorIfExists") 10 | val Ignore = Value("Ignore") 11 | val Update = Value("Update") 12 | 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/execution/datasources/jdbc2/JdbcUtils.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.execution.datasources.jdbc2 2 | 3 | import java.sql.{Connection, Driver, DriverManager, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} 4 | 5 | import org.apache.commons.lang3.StringUtils 6 | import org.apache.spark.TaskContext 7 | import org.apache.spark.executor.InputMetrics 8 | import org.apache.spark.internal.Logging 9 | import org.apache.spark.sql.catalyst.InternalRow 10 | import org.apache.spark.sql.catalyst.analysis.Resolver 11 | import org.apache.spark.sql.catalyst.encoders.RowEncoder 12 | import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow 13 | import org.apache.spark.sql.catalyst.parser.CatalystSqlParser 14 | import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} 15 | import org.apache.spark.sql.execution.datasources.jdbc2.JDBCSaveMode.JDBCSaveMode 16 | import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} 17 | import org.apache.spark.sql.types._ 18 | import org.apache.spark.sql.util.SchemaUtils 19 | import org.apache.spark.sql.{AnalysisException, DataFrame, Row} 20 | import org.apache.spark.unsafe.types.UTF8String 21 | import org.apache.spark.util.NextIterator 22 | 23 | import scala.collection.JavaConverters._ 24 | import scala.util.Try 25 | import scala.util.control.NonFatal 26 | 27 | /** 28 | * Util functions for JDBC tables. 29 | */ 30 | object JdbcUtils extends Logging { 31 | /** 32 | * Returns a factory for creating connections to the given JDBC URL. 33 | * 34 | * @param options - JDBC options that contains url, table and other information. 35 | */ 36 | def createConnectionFactory(options: JDBCOptions): () => Connection = { 37 | val driverClass: String = options.driverClass 38 | () => { 39 | DriverRegistry.register(driverClass) 40 | val driver: Driver = DriverManager.getDrivers.asScala.collectFirst { 41 | case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d 42 | case d if d.getClass.getCanonicalName == driverClass => d 43 | }.getOrElse { 44 | throw new IllegalStateException( 45 | s"Did not find registered driver with class $driverClass") 46 | } 47 | driver.connect(options.url, options.asConnectionProperties) 48 | } 49 | } 50 | 51 | /** 52 | * Returns true if the table already exists in the JDBC database. 
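 * (The check simply runs the dialect's table-exists query and treats any failure as the table being absent.)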
53 | */ 54 | def tableExists(conn: Connection, options: JDBCOptions): Boolean = { 55 | val dialect = JdbcDialects.get(options.url) 56 | 57 | // Somewhat hacky, but there isn't a good way to identify whether a table exists for all 58 | // SQL database systems using JDBC meta data calls, considering "table" could also include 59 | // the database name. Query used to find table exists can be overridden by the dialects. 60 | Try { 61 | val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table)) 62 | try { 63 | statement.executeQuery() 64 | } finally { 65 | statement.close() 66 | } 67 | }.isSuccess 68 | } 69 | 70 | /** 71 | * Drops a table from the JDBC database. 72 | */ 73 | def dropTable(conn: Connection, table: String): Unit = { 74 | val statement = conn.createStatement 75 | try { 76 | statement.executeUpdate(s"DROP TABLE $table") 77 | } finally { 78 | statement.close() 79 | } 80 | } 81 | 82 | /** 83 | * Truncates a table from the JDBC database without side effects. 84 | */ 85 | def truncateTable(conn: Connection, options: JDBCOptions): Unit = { 86 | val dialect = JdbcDialects.get(options.url) 87 | val statement = conn.createStatement 88 | try { 89 | statement.executeUpdate(dialect.getTruncateQuery(options.table)) 90 | } finally { 91 | statement.close() 92 | } 93 | } 94 | 95 | def isCascadingTruncateTable(url: String): Option[Boolean] = { 96 | JdbcDialects.get(url).isCascadingTruncateTable() 97 | } 98 | 99 | /** 100 | * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn. 101 | */ 102 | def getInsertStatement( 103 | table: String, 104 | rddSchema: StructType, 105 | tableSchema: Option[StructType], 106 | isCaseSensitive: Boolean, 107 | dialect: JdbcDialect, 108 | mode: JDBCSaveMode, 109 | options: JDBCOptions 110 | ): String = { 111 | val columns = if (tableSchema.isEmpty) { 112 | rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",") 113 | } else { 114 | val columnNameEquality = if (isCaseSensitive) { 115 | org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution 116 | } else { 117 | org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution 118 | } 119 | // The generated insert statement needs to follow rddSchema's column sequence and 120 | // tableSchema's column names. When appending data into some case-sensitive DBMSs like 121 | // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of 122 | // RDD column names for user convenience. 123 | val tableColumnNames = tableSchema.get.fieldNames 124 | rddSchema.fields.map { col => 125 | val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse { 126 | throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""") 127 | } 128 | dialect.quoteIdentifier(normalizedName) 129 | }.mkString(",") 130 | } 131 | val placeholders = rddSchema.fields.map(_ => "?").mkString(",") 132 | 133 | mode match { 134 | case JDBCSaveMode.Update => 135 | val props = options.asProperties 136 | val duplicateIncs = props 137 | .getProperty(JDBCOptions.JDBC_DUPLICATE_INCS, "") 138 | .split(",") 139 | .filter { x => StringUtils.isNotBlank(x) } 140 | .map { x => s"`$x`" } 141 | val duplicateSetting = rddSchema 142 | .fields 143 | .map { x => dialect.quoteIdentifier(x.name) } 144 | .map { name => if (duplicateIncs.contains(name)) s"$name=$name+?" else s"$name=?" 
} 145 | .mkString(",") 146 | val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting" 147 | if (props.getProperty("showSql", "false").equals("true")) { 148 | println(s"${JDBCSaveMode.Update} => sql => $sql") 149 | } 150 | sql 151 | case _ => s"INSERT INTO $table ($columns) VALUES ($placeholders)" 152 | } 153 | // s"INSERT INTO $table ($columns) VALUES ($placeholders)" 154 | } 155 | 156 | /** 157 | * Retrieve standard jdbc types. 158 | * 159 | * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]]) 160 | * @return The default JdbcType for this DataType 161 | */ 162 | def getCommonJDBCType(dt: DataType): Option[JdbcType] = { 163 | dt match { 164 | case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER)) 165 | case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT)) 166 | case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE)) 167 | case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT)) 168 | case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT)) 169 | case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT)) 170 | case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT)) 171 | case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB)) 172 | case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB)) 173 | case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP)) 174 | case DateType => Option(JdbcType("DATE", java.sql.Types.DATE)) 175 | case t: DecimalType => Option( 176 | JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL)) 177 | case _ => None 178 | } 179 | } 180 | 181 | private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = { 182 | dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse( 183 | throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}")) 184 | } 185 | 186 | /** 187 | * Maps a JDBC type to a Catalyst type. This function is called only when 188 | * the JdbcDialect class corresponding to your database driver returns null. 189 | * 190 | * @param sqlType - A field of java.sql.Types 191 | * @return The Catalyst type corresponding to sqlType. 
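 * For example, java.sql.Types.VARCHAR maps to StringType, a signed BIGINT to LongType, and an
 * unsigned BIGINT to DecimalType(20, 0).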
192 | */ 193 | private def getCatalystType( 194 | sqlType: Int, 195 | precision: Int, 196 | scale: Int, 197 | signed: Boolean): DataType = { 198 | val answer = sqlType match { 199 | // scalastyle:off 200 | case java.sql.Types.ARRAY => null 201 | case java.sql.Types.BIGINT => if (signed) { 202 | LongType 203 | } else { 204 | DecimalType(20, 0) 205 | } 206 | case java.sql.Types.BINARY => BinaryType 207 | case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks 208 | case java.sql.Types.BLOB => BinaryType 209 | case java.sql.Types.BOOLEAN => BooleanType 210 | case java.sql.Types.CHAR => StringType 211 | case java.sql.Types.CLOB => StringType 212 | case java.sql.Types.DATALINK => null 213 | case java.sql.Types.DATE => DateType 214 | case java.sql.Types.DECIMAL 215 | if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale) 216 | case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT 217 | case java.sql.Types.DISTINCT => null 218 | case java.sql.Types.DOUBLE => DoubleType 219 | case java.sql.Types.FLOAT => FloatType 220 | case java.sql.Types.INTEGER => if (signed) { 221 | IntegerType 222 | } else { 223 | LongType 224 | } 225 | case java.sql.Types.JAVA_OBJECT => null 226 | case java.sql.Types.LONGNVARCHAR => StringType 227 | case java.sql.Types.LONGVARBINARY => BinaryType 228 | case java.sql.Types.LONGVARCHAR => StringType 229 | case java.sql.Types.NCHAR => StringType 230 | case java.sql.Types.NCLOB => StringType 231 | case java.sql.Types.NULL => null 232 | case java.sql.Types.NUMERIC 233 | if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale) 234 | case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT 235 | case java.sql.Types.NVARCHAR => StringType 236 | case java.sql.Types.OTHER => null 237 | case java.sql.Types.REAL => DoubleType 238 | case java.sql.Types.REF => StringType 239 | case java.sql.Types.REF_CURSOR => null 240 | case java.sql.Types.ROWID => LongType 241 | case java.sql.Types.SMALLINT => IntegerType 242 | case java.sql.Types.SQLXML => StringType 243 | case java.sql.Types.STRUCT => StringType 244 | case java.sql.Types.TIME => TimestampType 245 | case java.sql.Types.TIME_WITH_TIMEZONE 246 | => null 247 | case java.sql.Types.TIMESTAMP => TimestampType 248 | case java.sql.Types.TIMESTAMP_WITH_TIMEZONE 249 | => null 250 | case java.sql.Types.TINYINT => IntegerType 251 | case java.sql.Types.VARBINARY => BinaryType 252 | case java.sql.Types.VARCHAR => StringType 253 | case _ => 254 | throw new SQLException("Unrecognized SQL type " + sqlType) 255 | // scalastyle:on 256 | } 257 | 258 | if (answer == null) { 259 | throw new SQLException("Unsupported type " + JDBCType.valueOf(sqlType).getName) 260 | } 261 | answer 262 | } 263 | 264 | /** 265 | * Returns the schema if the table already exists in the JDBC database. 266 | */ 267 | def getSchemaOption(conn: Connection, options: JDBCOptions): Option[StructType] = { 268 | val dialect = JdbcDialects.get(options.url) 269 | 270 | try { 271 | val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table)) 272 | try { 273 | Some(getSchema(statement.executeQuery(), dialect)) 274 | } catch { 275 | case _: SQLException => None 276 | } finally { 277 | statement.close() 278 | } 279 | } catch { 280 | case _: SQLException => None 281 | } 282 | } 283 | 284 | /** 285 | * Takes a [[ResultSet]] and returns its Catalyst schema. 286 | * 287 | * @param alwaysNullable If true, all the columns are nullable. 288 | * @return A [[StructType]] giving the Catalyst schema. 
289 | * @throws SQLException if the schema contains an unsupported type. 290 | */ 291 | def getSchema( 292 | resultSet: ResultSet, 293 | dialect: JdbcDialect, 294 | alwaysNullable: Boolean = false): StructType = { 295 | val rsmd = resultSet.getMetaData 296 | val ncols = rsmd.getColumnCount 297 | val fields = new Array[StructField](ncols) 298 | var i = 0 299 | while (i < ncols) { 300 | val columnName = rsmd.getColumnLabel(i + 1) 301 | val dataType = rsmd.getColumnType(i + 1) 302 | val typeName = rsmd.getColumnTypeName(i + 1) 303 | val fieldSize = rsmd.getPrecision(i + 1) 304 | val fieldScale = rsmd.getScale(i + 1) 305 | val isSigned = { 306 | try { 307 | rsmd.isSigned(i + 1) 308 | } catch { 309 | // Workaround for HIVE-14684: 310 | case e: SQLException if 311 | e.getMessage == "Method not supported" && 312 | rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true 313 | } 314 | } 315 | val nullable = if (alwaysNullable) { 316 | true 317 | } else { 318 | rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls 319 | } 320 | val metadata = new MetadataBuilder().putLong("scale", fieldScale) 321 | val columnType = 322 | dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( 323 | getCatalystType(dataType, fieldSize, fieldScale, isSigned)) 324 | fields(i) = StructField(columnName, columnType, nullable) 325 | i = i + 1 326 | } 327 | new StructType(fields) 328 | } 329 | 330 | /** 331 | * Convert a [[ResultSet]] into an iterator of Catalyst Rows. 332 | */ 333 | def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = { 334 | val inputMetrics = 335 | Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics) 336 | val encoder = RowEncoder(schema).resolveAndBind() 337 | val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics) 338 | internalRows.map(encoder.createDeserializer()) 339 | } 340 | 341 | private[spark] def resultSetToSparkInternalRows( 342 | resultSet: ResultSet, 343 | schema: StructType, 344 | inputMetrics: InputMetrics): Iterator[InternalRow] = { 345 | new NextIterator[InternalRow] { 346 | private[this] val rs = resultSet 347 | private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema) 348 | private[this] val mutableRow = new SpecificInternalRow(schema.fields.map(x => x.dataType)) 349 | 350 | override protected def close(): Unit = { 351 | try { 352 | rs.close() 353 | } catch { 354 | case e: Exception => logWarning("Exception closing resultset", e) 355 | } 356 | } 357 | 358 | override protected def getNext(): InternalRow = { 359 | if (rs.next()) { 360 | inputMetrics.incRecordsRead(1) 361 | var i = 0 362 | while (i < getters.length) { 363 | getters(i).apply(rs, mutableRow, i) 364 | if (rs.wasNull) mutableRow.setNullAt(i) 365 | i = i + 1 366 | } 367 | mutableRow 368 | } else { 369 | finished = true 370 | null.asInstanceOf[InternalRow] 371 | } 372 | } 373 | } 374 | } 375 | 376 | // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field 377 | // for `MutableRow`. The last argument `Int` means the index for the value to be set in 378 | // the row and also used for the value in `ResultSet`. 379 | private type JDBCValueGetter = (ResultSet, InternalRow, Int) => Unit 380 | 381 | /** 382 | * Creates `JDBCValueGetter`s according to [[StructType]], which can set 383 | * each value from `ResultSet` to each field of [[InternalRow]] correctly. 
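 * For example, an IntegerType field is copied with row.setInt(pos, rs.getInt(pos + 1)); null values
 * are handled by the caller via rs.wasNull.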
384 | */ 385 | private def makeGetters(schema: StructType): Array[JDBCValueGetter] = 386 | schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata)) 387 | 388 | private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match { 389 | case BooleanType => 390 | (rs: ResultSet, row: InternalRow, pos: Int) => 391 | row.setBoolean(pos, rs.getBoolean(pos + 1)) 392 | 393 | case DateType => 394 | (rs: ResultSet, row: InternalRow, pos: Int) => 395 | // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it. 396 | val dateVal = rs.getDate(pos + 1) 397 | if (dateVal != null) { 398 | row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal)) 399 | } else { 400 | row.update(pos, null) 401 | } 402 | 403 | // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal 404 | // object returned by ResultSet.getBigDecimal is not correctly matched to the table 405 | // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. 406 | // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through 407 | // a BigDecimal object with scale as 0. But the dataframe schema has correct type as 408 | // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then 409 | // retrieve it, you will get wrong result 199.99. 410 | // So it is needed to set precision and scale for Decimal based on JDBC metadata. 411 | case DecimalType.Fixed(p, s) => 412 | (rs: ResultSet, row: InternalRow, pos: Int) => 413 | val decimal = 414 | nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s)) 415 | row.update(pos, decimal) 416 | 417 | case DoubleType => 418 | (rs: ResultSet, row: InternalRow, pos: Int) => 419 | row.setDouble(pos, rs.getDouble(pos + 1)) 420 | 421 | case FloatType => 422 | (rs: ResultSet, row: InternalRow, pos: Int) => 423 | row.setFloat(pos, rs.getFloat(pos + 1)) 424 | 425 | case IntegerType => 426 | (rs: ResultSet, row: InternalRow, pos: Int) => 427 | row.setInt(pos, rs.getInt(pos + 1)) 428 | 429 | case LongType if metadata.contains("binarylong") => 430 | (rs: ResultSet, row: InternalRow, pos: Int) => 431 | val bytes = rs.getBytes(pos + 1) 432 | var ans = 0L 433 | var j = 0 434 | while (j < bytes.length) { 435 | ans = 256 * ans + (255 & bytes(j)) 436 | j = j + 1 437 | } 438 | row.setLong(pos, ans) 439 | 440 | case LongType => 441 | (rs: ResultSet, row: InternalRow, pos: Int) => 442 | row.setLong(pos, rs.getLong(pos + 1)) 443 | 444 | case ShortType => 445 | (rs: ResultSet, row: InternalRow, pos: Int) => 446 | row.setShort(pos, rs.getShort(pos + 1)) 447 | 448 | case StringType => 449 | (rs: ResultSet, row: InternalRow, pos: Int) => 450 | // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 451 | row.update(pos, UTF8String.fromString(rs.getString(pos + 1))) 452 | 453 | case TimestampType => 454 | (rs: ResultSet, row: InternalRow, pos: Int) => 455 | val t = rs.getTimestamp(pos + 1) 456 | if (t != null) { 457 | row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t)) 458 | } else { 459 | row.update(pos, null) 460 | } 461 | 462 | case BinaryType => 463 | (rs: ResultSet, row: InternalRow, pos: Int) => 464 | row.update(pos, rs.getBytes(pos + 1)) 465 | 466 | case ArrayType(et, _) => 467 | val elementConversion = et match { 468 | case TimestampType => 469 | (array: Object) => 470 | array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp => 471 | nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp) 472 | } 473 | 474 | case StringType => 
475 |           (array: Object) =>
476 |             // some underlying types are not String, such as uuid, inet, cidr, etc.
477 |             array.asInstanceOf[Array[java.lang.Object]]
478 |               .map(obj => if (obj == null) null else UTF8String.fromString(obj.toString))
479 | 
480 |         case DateType =>
481 |           (array: Object) =>
482 |             array.asInstanceOf[Array[java.sql.Date]].map { date =>
483 |               nullSafeConvert(date, DateTimeUtils.fromJavaDate)
484 |             }
485 | 
486 |         case dt: DecimalType =>
487 |           (array: Object) =>
488 |             array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
489 |               nullSafeConvert[java.math.BigDecimal](
490 |                 decimal, d => Decimal(d, dt.precision, dt.scale))
491 |             }
492 | 
493 |         case LongType if metadata.contains("binarylong") =>
494 |           throw new IllegalArgumentException(s"Unsupported array element " +
495 |             s"type ${dt.simpleString} based on binary")
496 | 
497 |         case ArrayType(_, _) =>
498 |           throw new IllegalArgumentException("Nested arrays unsupported")
499 | 
500 |         case _ => (array: Object) => array.asInstanceOf[Array[Any]]
501 |       }
502 | 
503 |       (rs: ResultSet, row: InternalRow, pos: Int) =>
504 |         val array = nullSafeConvert[java.sql.Array](
505 |           input = rs.getArray(pos + 1),
506 |           array => new GenericArrayData(elementConversion.apply(array.getArray)))
507 |         row.update(pos, array)
508 | 
509 |     case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
510 |   }
511 | 
512 |   private def nullSafeConvert[T](input: T, f: T => Any): Any = {
513 |     if (input == null) {
514 |       null
515 |     } else {
516 |       f(input)
517 |     }
518 |   }
519 | 
520 |   // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field of a
521 |   // `PreparedStatement`. The third argument is the position: `pos + 1` is the parameter index
522 |   // in the SQL statement, and `pos - offset` (the last argument) locates the value in `Row`.
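  // In JDBCSaveMode.Update the prepared statement carries every column twice, so getSetter
  // (below) duplicates the setter array and savePartition binds each Row a second time with
  // offset = number of fields. Illustrative layout for a 3-column row:
  //   statement parameter index: 1 2 3 | 4 5 6
  //   row field (pos - offset) : 0 1 2 | 0 1 2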
523 | private type JDBCValueSetter = (PreparedStatement, Row, Int, Int) ⇒ Unit 524 | 525 | private def makeSetter( 526 | conn: Connection, 527 | dialect: JdbcDialect, 528 | dataType: DataType): JDBCValueSetter = dataType match { 529 | 530 | case IntegerType ⇒ 531 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 532 | stmt.setInt(pos + 1, row.getInt(pos - offset)) 533 | 534 | case LongType ⇒ 535 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 536 | stmt.setLong(pos + 1, row.getLong(pos - offset)) 537 | 538 | case DoubleType ⇒ 539 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 540 | stmt.setDouble(pos + 1, row.getDouble(pos - offset)) 541 | 542 | case FloatType ⇒ 543 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 544 | stmt.setFloat(pos + 1, row.getFloat(pos - offset)) 545 | 546 | case ShortType ⇒ 547 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 548 | stmt.setInt(pos + 1, row.getShort(pos - offset)) 549 | 550 | case ByteType ⇒ 551 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 552 | stmt.setInt(pos + 1, row.getByte(pos - offset)) 553 | 554 | case BooleanType ⇒ 555 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 556 | stmt.setBoolean(pos + 1, row.getBoolean(pos - offset)) 557 | 558 | case StringType ⇒ 559 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 560 | stmt.setString(pos + 1, row.getString(pos - offset)) 561 | 562 | case BinaryType ⇒ 563 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 564 | stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos - offset)) 565 | 566 | case TimestampType ⇒ 567 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 568 | stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos - offset)) 569 | 570 | case DateType ⇒ 571 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 572 | stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos - offset)) 573 | 574 | case t: DecimalType ⇒ 575 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 576 | stmt.setBigDecimal(pos + 1, row.getDecimal(pos - offset)) 577 | 578 | case ArrayType(et, _) ⇒ 579 | // remove type length parameters from end of type name 580 | val typeName = getJdbcType(et, dialect).databaseTypeDefinition 581 | .toLowerCase.split("\\(")(0) 582 | (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒ 583 | val array = conn.createArrayOf( 584 | typeName, 585 | row.getSeq[AnyRef](pos - offset).toArray) 586 | stmt.setArray(pos + 1, array) 587 | 588 | case _ ⇒ 589 | (_: PreparedStatement, _: Row, pos: Int, offset: Int) ⇒ 590 | throw new IllegalArgumentException( 591 | s"Can't translate non-null value for field $pos") 592 | } 593 | 594 | private def getSetter(fields: Array[StructField], connection: Connection, dialect: JdbcDialect, isUpdateMode: Boolean): Array[JDBCValueSetter] = { 595 | val setter = fields.map(_.dataType).map(makeSetter(connection, dialect, _)) 596 | if (isUpdateMode) { 597 | Array.fill(2)(setter).flatten 598 | } else { 599 | setter 600 | } 601 | } 602 | 603 | /** 604 | * Saves a partition of a DataFrame to the JDBC database. This is done in 605 | * a single database transaction (unless isolation level is "NONE") 606 | * in order to avoid repeatedly inserting data as much as possible. 607 | * 608 | * It is still theoretically possible for rows in a DataFrame to be 609 | * inserted into the database more than once if a stage somehow fails after 610 | * the commit occurs but before the stage can return successfully. 
611 | * 612 | * This is not a closure inside saveTable() because apparently cosmetic 613 | * implementation changes elsewhere might easily render such a closure 614 | * non-Serializable. Instead, we explicitly close over all variables that 615 | * are used. 616 | */ 617 | def savePartition( 618 | getConnection: () => Connection, 619 | table: String, 620 | iterator: Iterator[Row], 621 | rddSchema: StructType, 622 | insertStmt: String, 623 | batchSize: Int, 624 | dialect: JdbcDialect, 625 | isolationLevel: Int, 626 | mode: JDBCSaveMode): Iterator[Byte] = { 627 | val conn = getConnection() 628 | var committed = false 629 | 630 | var finalIsolationLevel = Connection.TRANSACTION_NONE 631 | if (isolationLevel != Connection.TRANSACTION_NONE) { 632 | try { 633 | val metadata = conn.getMetaData 634 | if (metadata.supportsTransactions()) { 635 | // Update to at least use the default isolation, if any transaction level 636 | // has been chosen and transactions are supported 637 | val defaultIsolation = metadata.getDefaultTransactionIsolation 638 | finalIsolationLevel = defaultIsolation 639 | if (metadata.supportsTransactionIsolationLevel(isolationLevel)) { 640 | // Finally update to actually requested level if possible 641 | finalIsolationLevel = isolationLevel 642 | } else { 643 | logWarning(s"Requested isolation level $isolationLevel is not supported; " + 644 | s"falling back to default isolation level $defaultIsolation") 645 | } 646 | } else { 647 | logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported") 648 | } 649 | } catch { 650 | case NonFatal(e) => logWarning("Exception while detecting transaction support", e) 651 | } 652 | } 653 | val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE 654 | 655 | try { 656 | if (supportsTransactions) { 657 | conn.setAutoCommit(false) // Everything in the same db transaction. 
658 | conn.setTransactionIsolation(finalIsolationLevel) 659 | } 660 | val isUpdateMode = mode == JDBCSaveMode.Update 661 | val stmt = conn.prepareStatement(insertStmt) 662 | val setters: Array[JDBCValueSetter] = getSetter(rddSchema.fields, conn, dialect, isUpdateMode) 663 | val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType) 664 | val length = rddSchema.fields.length 665 | val numFields = if (isUpdateMode) length * 2 else length 666 | val midField = numFields / 2 667 | try { 668 | var rowCount = 0 669 | while (iterator.hasNext) { 670 | val row = iterator.next() 671 | var i = 0 672 | while (i < numFields) { 673 | if (isUpdateMode) { 674 | i < midField match { 675 | case true ⇒ 676 | if (row.isNullAt(i)) { 677 | stmt.setNull(i + 1, nullTypes(i)) 678 | } else { 679 | setters(i).apply(stmt, row, i, 0) 680 | } 681 | case false ⇒ 682 | if (row.isNullAt(i - midField)) { 683 | stmt.setNull(i + 1, nullTypes(i - midField)) 684 | } else { 685 | setters(i).apply(stmt, row, i, midField) 686 | } 687 | } 688 | } else { 689 | if (row.isNullAt(i)) { 690 | stmt.setNull(i + 1, nullTypes(i)) 691 | } else { 692 | setters(i).apply(stmt, row, i, 0) 693 | } 694 | } 695 | i = i + 1 696 | } 697 | stmt.addBatch() 698 | rowCount += 1 699 | if (rowCount % batchSize == 0) { 700 | stmt.executeBatch() 701 | rowCount = 0 702 | } 703 | } 704 | if (rowCount > 0) { 705 | stmt.executeBatch() 706 | } 707 | } finally { 708 | stmt.close() 709 | } 710 | if (supportsTransactions) { 711 | conn.commit() 712 | } 713 | committed = true 714 | Iterator.empty 715 | } catch { 716 | case e: SQLException => 717 | val cause = e.getNextException 718 | if (cause != null && e.getCause != cause) { 719 | // If there is no cause already, set 'next exception' as cause. If cause is null, 720 | // it *may* be because no cause was set yet 721 | if (e.getCause == null) { 722 | try { 723 | e.initCause(cause) 724 | } catch { 725 | // Or it may be null because the cause *was* explicitly initialized, to *null*, 726 | // in which case this fails. There is no other way to detect it. 727 | // addSuppressed in this case as well. 728 | case _: IllegalStateException => e.addSuppressed(cause) 729 | } 730 | } else { 731 | e.addSuppressed(cause) 732 | } 733 | } 734 | throw e 735 | } finally { 736 | if (!committed) { 737 | // The stage must fail. We got here through an exception path, so 738 | // let the exception through unless rollback() or close() want to 739 | // tell the user about another problem. 740 | if (supportsTransactions) { 741 | conn.rollback() 742 | } 743 | conn.close() 744 | } else { 745 | // The stage must succeed. We cannot propagate any exception close() might throw. 746 | try { 747 | conn.close() 748 | } catch { 749 | case e: Exception => logWarning("Transaction succeeded, but closing failed", e) 750 | } 751 | } 752 | } 753 | } 754 | 755 | /** 756 | * Compute the schema string for this RDD. 
757 |    */
758 |   def schemaString(
759 |       df: DataFrame,
760 |       url: String,
761 |       createTableColumnTypes: Option[String] = None): String = {
762 |     val sb = new StringBuilder()
763 |     val dialect = JdbcDialects.get(url)
764 |     val userSpecifiedColTypesMap = createTableColumnTypes
765 |       .map(parseUserSpecifiedCreateTableColumnTypes(df, _))
766 |       .getOrElse(Map.empty[String, String])
767 |     df.schema.fields.foreach { field =>
768 |       val name = dialect.quoteIdentifier(field.name)
769 |       val typ = userSpecifiedColTypesMap
770 |         .getOrElse(field.name, getJdbcType(field.dataType, dialect).databaseTypeDefinition)
771 |       val nullable = if (field.nullable) "" else "NOT NULL"
772 |       sb.append(s", $name $typ $nullable")
773 |     }
774 |     if (sb.length < 2) "" else sb.substring(2)
775 |   }
776 | 
777 |   /**
778 |    * Parses the user-specified createTableColumnTypes option value, written in the same format
779 |    * as CREATE TABLE DDL column types, and returns a Map from field name to the data type to
780 |    * use in place of the default data type.
781 |    */
782 |   private def parseUserSpecifiedCreateTableColumnTypes(
783 |       df: DataFrame,
784 |       createTableColumnTypes: String): Map[String, String] = {
785 |     def typeName(f: StructField): String = {
786 |       // char/varchar gets translated to string type. The real data type specified by the user
787 |       // is available in the field metadata as HIVE_TYPE_STRING.
788 |       if (f.metadata.contains(HIVE_TYPE_STRING)) {
789 |         f.metadata.getString(HIVE_TYPE_STRING)
790 |       } else {
791 |         f.dataType.catalogString
792 |       }
793 |     }
794 | 
795 |     val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes)
796 |     val nameEquality = df.sparkSession.sessionState.conf.resolver
797 | 
798 |     // checks for duplicate columns in the user-specified column types.
799 |     SchemaUtils.checkColumnNameDuplication(
800 |       userSchema.map(_.name), "in the createTableColumnTypes option value", nameEquality)
801 | 
802 |     // checks that the user-specified column names exist in the DataFrame schema
803 |     userSchema.fieldNames.foreach { col =>
804 |       df.schema.find(f => nameEquality(f.name, col)).getOrElse {
805 |         throw new AnalysisException(
806 |           s"createTableColumnTypes option column $col not found in schema " +
807 |             df.schema.catalogString)
808 |       }
809 |     }
810 | 
811 |     val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap
812 |     val isCaseSensitive = df.sparkSession.sessionState.conf.caseSensitiveAnalysis
813 |     if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap)
814 |   }
815 | 
816 |   /**
817 |    * Parses the user-specified customSchema option value into a DataFrame schema, and returns
818 |    * the table schema with matching columns replaced by the custom schema's dataType.
819 |    */
820 |   def getCustomSchema(
821 |       tableSchema: StructType,
822 |       customSchema: String,
823 |       nameEquality: Resolver): StructType = {
824 |     if (null != customSchema && customSchema.nonEmpty) {
825 |       val userSchema = CatalystSqlParser.parseTableSchema(customSchema)
826 | 
827 |       SchemaUtils.checkColumnNameDuplication(
828 |         userSchema.map(_.name), "in the customSchema option value", nameEquality)
829 | 
830 |       // This is resolved by name: use the custom field's dataType to replace the default dataType.
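      // For example (hypothetical names), customSchema = "id DECIMAL(38, 0), name STRING"
      // overrides only those two columns; every other column keeps its type from tableSchema.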
831 | val newSchema = tableSchema.map { col => 832 | userSchema.find(f => nameEquality(f.name, col.name)) match { 833 | case Some(c) => col.copy(dataType = c.dataType) 834 | case None => col 835 | } 836 | } 837 | StructType(newSchema) 838 | } else { 839 | tableSchema 840 | } 841 | } 842 | 843 | /** 844 | * Saves the RDD to the database in a single transaction. 845 | */ 846 | def saveTable( 847 | df: DataFrame, 848 | tableSchema: Option[StructType], 849 | isCaseSensitive: Boolean, 850 | options: JDBCOptions, 851 | mode: JDBCSaveMode): Unit = { 852 | val url = options.url 853 | val table = options.table 854 | val dialect = JdbcDialects.get(url) 855 | val rddSchema = df.schema 856 | val getConnection: () => Connection = createConnectionFactory(options) 857 | val batchSize = options.batchSize 858 | val isolationLevel = options.isolationLevel 859 | 860 | val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect, mode, options) 861 | val repartitionedDF = options.numPartitions match { 862 | case Some(n) if n <= 0 => throw new IllegalArgumentException( 863 | s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " + 864 | "via JDBC. The minimum value is 1.") 865 | case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n) 866 | case _ => df 867 | } 868 | repartitionedDF.rdd.foreachPartition(iterator => savePartition( 869 | getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, mode) 870 | ) 871 | } 872 | 873 | /** 874 | * Creates a table with a given schema. 875 | */ 876 | def createTable( 877 | conn: Connection, 878 | df: DataFrame, 879 | options: JDBCOptions): Unit = { 880 | val strSchema = schemaString( 881 | df, options.url, options.createTableColumnTypes) 882 | val table = options.table 883 | val createTableOptions = options.createTableOptions 884 | // Create the table if the table does not exist. 885 | // To allow certain options to append when create a new table, which can be 886 | // table_options or partition_options. 
887 | // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" 888 | val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions" 889 | val statement = conn.createStatement 890 | try { 891 | statement.executeUpdate(sql) 892 | } finally { 893 | statement.close() 894 | } 895 | } 896 | } 897 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/execution/jdbc/JdbcSink.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.execution.jdbc 2 | 3 | import org.apache.spark.internal.Logging 4 | import org.apache.spark.rdd.RDD 5 | import org.apache.spark.sql.catalyst.CatalystTypeConverters 6 | import org.apache.spark.sql.execution.streaming.Sink 7 | import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} 8 | 9 | class JdbcSink( 10 | sqlContext: SQLContext, 11 | parameters: Map[String, String]) extends Sink with Logging { 12 | @volatile private var latestBatchId = -1L 13 | 14 | override def toString(): String = "JdbcSink" 15 | 16 | override def addBatch(batchId: Long, data: DataFrame): Unit = { 17 | if (batchId <= latestBatchId) { 18 | logInfo(s"Skipping already committed batch $batchId") 19 | } else { 20 | val schema = data.schema 21 | val rdd: RDD[Row] = data.queryExecution.toRdd.mapPartitions { rows => 22 | val converter = CatalystTypeConverters.createToScalaConverter(schema) 23 | rows.map(converter(_).asInstanceOf[Row]) 24 | } 25 | sqlContext.createDataFrame(rdd,schema) 26 | .write 27 | .format("org.apache.spark.sql.execution.datasources.jdbc2") 28 | .mode(SaveMode.Append) 29 | .options(parameters) 30 | .save() 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/sql/execution/jdbc/JdbcSourceProvider.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql.execution.jdbc 2 | 3 | import org.apache.spark.internal.Logging 4 | import org.apache.spark.sql.SQLContext 5 | import org.apache.spark.sql.execution.streaming.Sink 6 | import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} 7 | import org.apache.spark.sql.streaming.OutputMode 8 | 9 | class JdbcSourceProvider extends DataSourceRegister 10 | with StreamSinkProvider 11 | with Logging { 12 | 13 | override def shortName(): String = "jdbc" 14 | 15 | override def createSink(sqlContext: SQLContext, 16 | parameters: Map[String, String], 17 | partitionColumns: Seq[String], 18 | outputMode: OutputMode): Sink = { 19 | new JdbcSink(sqlContext,parameters) 20 | } 21 | } 22 | --------------------------------------------------------------------------------
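
The streaming entry point above (`JdbcSourceProvider`/`JdbcSink`) forwards its options unchanged to the `jdbc2` batch writer for every micro-batch. Below is a minimal sketch of wiring it into a Structured Streaming query; the `rate` input, column names, table name, connection settings, and checkpoint path are hypothetical placeholders, and the provider is addressed by its fully qualified class name to avoid any ambiguity with Spark's built-in `jdbc` source, which shares the `jdbc` short name.
```
import org.apache.spark.sql.SparkSession

object JdbcSinkStreamingSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("jdbc-sink-streaming-sketch")
      .master("local[*]")
      .getOrCreate()

    // Hypothetical streaming input: any streaming DataFrame whose columns match the target table.
    val events = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "5")
      .load()
      .selectExpr("CAST(value AS STRING) AS name", "timestamp AS ts", "value AS cnt")

    val query = events.writeStream
      // Fully qualified class name of the provider defined in JdbcSourceProvider.scala.
      .format("org.apache.spark.sql.execution.jdbc.JdbcSourceProvider")
      .outputMode("append")
      // These options are passed straight through to the jdbc2 writer by JdbcSink.addBatch;
      // the connection details below are placeholders.
      .option("driver", "com.mysql.jdbc.Driver")
      .option("url", "jdbc:mysql://localhost:3306/example_db")
      .option("user", "example_user")
      .option("password", "example_password")
      .option("dbtable", "example_events")
      .option("checkpointLocation", "/tmp/jdbc-sink-checkpoint")
      .start()

    query.awaitTermination()
  }
}
```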
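
`getCustomSchema` in JdbcUtils.scala is a plain function over schemas, so its behaviour can be illustrated without a database. The sketch below assumes `JdbcUtils` is a public object (as in upstream Spark), uses hypothetical column names, and borrows the case-insensitive resolver from Spark's analysis package.
```
import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
import org.apache.spark.sql.execution.datasources.jdbc2.JdbcUtils
import org.apache.spark.sql.types._

object CustomSchemaSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical schema as it would be reported by the JDBC table.
    val tableSchema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType),
      StructField("score", DoubleType)))

    // Only the named columns are overridden; "score" keeps its original type.
    val merged = JdbcUtils.getCustomSchema(
      tableSchema,
      customSchema = "id DECIMAL(38, 0), name STRING",
      nameEquality = caseInsensitiveResolution)

    merged.foreach(f => println(s"${f.name}: ${f.dataType.simpleString}"))
    // Expected output, given the replacement logic in getCustomSchema:
    //   id: decimal(38,0)
    //   name: string
    //   score: double
  }
}
```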