├── src ├── main │ ├── resources │ │ └── META-INF │ │ │ └── services │ │ │ └── org.apache.spark.sql.sources.DataSourceRegister │ ├── java │ │ └── com │ │ │ └── qwshen │ │ │ └── flight │ │ │ ├── spark │ │ │ ├── DefaultSource.java │ │ │ ├── read │ │ │ │ ├── FlightPartitionReaderFactory.java │ │ │ │ ├── FlightScan.java │ │ │ │ ├── FlightBatch.java │ │ │ │ ├── FlightInputPartition.java │ │ │ │ ├── FlightPartitionReader.java │ │ │ │ └── FlightScanBuilder.java │ │ │ ├── write │ │ │ │ ├── FlightWriteAbortException.java │ │ │ │ ├── FlightWriteBuilder.java │ │ │ │ ├── FlightWriterCommitMessage.java │ │ │ │ ├── FlightDataWriterFactory.java │ │ │ │ ├── FlightWrite.java │ │ │ │ └── FlightDataWriter.java │ │ │ ├── FlightSource.java │ │ │ └── FlightTable.java │ │ │ ├── WriteProtocol.java │ │ │ ├── Endpoint.java │ │ │ ├── QueryEndpoints.java │ │ │ ├── PushAggregation.java │ │ │ ├── FieldVector.java │ │ │ ├── QueryStatement.java │ │ │ ├── RowSet.java │ │ │ ├── Field.java │ │ │ ├── WriteBehavior.java │ │ │ ├── PartitionBehavior.java │ │ │ ├── Configuration.java │ │ │ ├── Client.java │ │ │ ├── Table.java │ │ │ └── FieldType.java │ └── scala │ │ └── com │ │ └── qwshen │ │ └── flight │ │ └── spark │ │ └── implicits.scala └── test │ ├── resources │ └── data │ │ └── events │ │ └── events.csv │ └── scala │ └── com │ └── qwshen │ └── flight │ └── spark │ └── test │ └── DremioTest.scala ├── SECURITY.md ├── docs └── tutorial.md ├── LICENSE ├── pom.xml └── README.md /src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister: -------------------------------------------------------------------------------- 1 | com.qwshen.flight.spark.FlightSource 2 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/DefaultSource.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark; 2 | 3 | /** 4 | * The default data-source 5 | */ 6 | public class DefaultSource extends FlightSource { 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/WriteProtocol.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | /** 4 | * The protocol tells the connector how to conduct the write operation 5 | */ 6 | public enum WriteProtocol { 7 | //literal sql statements are submitted 8 | LITERAL_SQL, 9 | //prepared sql statements are submitted 10 | PREPARED_SQL 11 | } 12 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | Use this section to tell people about which versions of your project are 6 | currently being supported with security updates. 7 | 8 | | Version | Supported | 9 | | ------- | ------------------ | 10 | | 5.1.x | :white_check_mark: | 11 | | 5.0.x | :x: | 12 | | 4.0.x | :white_check_mark: | 13 | | < 4.0 | :x: | 14 | 15 | ## Reporting a Vulnerability 16 | 17 | Use this section to tell people how to report a vulnerability. 18 | 19 | Tell them where to go, how often they can expect to get an update on a 20 | reported vulnerability, what to expect if the vulnerability is accepted or 21 | declined, etc. 
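The service registration above wires the short name `flight` (declared by `FlightSource.shortName()` later in this listing) to `com.qwshen.flight.spark.FlightSource`, so the connector can be addressed purely by format name. A minimal read sketch in Scala, using only option keys that appear in `FlightSource.java` below; the host, credentials and table path are placeholders:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("flight-read").getOrCreate()

// "flight" resolves to com.qwshen.flight.spark.FlightSource through the
// META-INF/services/org.apache.spark.sql.sources.DataSourceRegister entry above
val customers = spark.read
  .format("flight")
  .option("host", "192.168.0.22")        // hypothetical flight endpoint
  .option("port", "32010")
  .option("tls.enabled", "false")
  .option("user", "test")
  .option("password", "Password@123")
  .option("table", "\"local-iceberg\".iceberg_db.iceberg_customers")   // hypothetical table path
  .load()

customers.printSchema()
customers.show(false)
```

Because `DefaultSource` simply extends `FlightSource`, the package name `com.qwshen.flight.spark` should also work as the format string.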
22 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/Endpoint.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import java.io.Serializable; 4 | import java.net.URI; 5 | 6 | /** 7 | * Describes the data-structure of a flight end-point for connections 8 | */ 9 | public class Endpoint implements Serializable { 10 | //the URIs of the end-point 11 | private final URI[] _uris; 12 | //the ticket for connecting to the end-point 13 | private final byte[] _ticket; 14 | 15 | /** 16 | * Construct an end-point 17 | * @param uris - the URIs of the end-point 18 | * @param ticket - the ticket for connecting to the end-point 19 | */ 20 | public Endpoint(URI[] uris, byte[] ticket) { 21 | this._uris = uris; 22 | this._ticket = ticket; 23 | } 24 | 25 | /** 26 | * Get the URIs of the end-point 27 | * @return - the URIs of the end-point 28 | */ 29 | public URI[] getURIs() { 30 | return this._uris; 31 | } 32 | 33 | /** 34 | * Get the ticket of the end-point 35 | * @return - the ticket of the end-point 36 | */ 37 | public byte[] getTicket() { 38 | return this._ticket; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/QueryEndpoints.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import org.apache.arrow.vector.types.pojo.Schema; 4 | import java.io.Serializable; 5 | 6 | /** 7 | * Describes the data-structure from executing a query on remote flight service 8 | */ 9 | public class QueryEndpoints implements Serializable { 10 | //the schema 11 | private final Schema _schema; 12 | //the collection of end-points exposed for the query 13 | private final Endpoint[] _endpoints; 14 | 15 | /** 16 | * Construct a QueryEndpoints 17 | * @param schema - the schema of the query result 18 | * @param endpoints - end end-points exposed on the remote flight-service for fetching data 19 | */ 20 | public QueryEndpoints(Schema schema, Endpoint[] endpoints) { 21 | this._schema = schema; 22 | this._endpoints = endpoints; 23 | } 24 | 25 | /** 26 | * Get the Schema 27 | * @return - the schema of the QueryEndpoints 28 | */ 29 | public Schema getSchema() { 30 | return this._schema; 31 | } 32 | 33 | /** 34 | * Get the end-points 35 | * @return - the end-points of the QueryEndpoints 36 | */ 37 | public Endpoint[] getEndpoints() { 38 | return this._endpoints; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/read/FlightPartitionReaderFactory.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.read; 2 | 3 | import com.qwshen.flight.Configuration; 4 | import org.apache.spark.sql.catalyst.InternalRow; 5 | import org.apache.spark.sql.connector.read.InputPartition; 6 | import org.apache.spark.sql.connector.read.PartitionReader; 7 | import org.apache.spark.sql.connector.read.PartitionReaderFactory; 8 | 9 | /** 10 | * The flight partition-reader factory for creating flight partition-readers 11 | */ 12 | public class FlightPartitionReaderFactory implements PartitionReaderFactory { 13 | private final Configuration _configuration; 14 | 15 | /** 16 | * Construct a flight partition-reader factory 17 | * @param configuration - the configuration of remote flight service 18 | */ 19 | public 
FlightPartitionReaderFactory(Configuration configuration) { 20 | this._configuration = configuration; 21 | } 22 | 23 | /** 24 | * Create a reader 25 | * @param inputPartition - the input-partition for the reader 26 | * @return - a partition-reader 27 | */ 28 | public PartitionReader createReader(InputPartition inputPartition) { 29 | return new FlightPartitionReader(this._configuration, inputPartition); 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/scala/com/qwshen/flight/spark/implicits.scala: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark 2 | 3 | import org.apache.spark.sql.{DataFrame, DataFrameReader, DataFrameWriter} 4 | 5 | object implicits { 6 | /** 7 | * Flight DataFrameReader is to load data into a DataFrame from remote flight service 8 | * @param r - the data-frame reader 9 | */ 10 | implicit class FlightDataFrameReader(r: DataFrameReader) { 11 | /** 12 | * Given the format as flight 13 | * @return - data-frame loaded from remote flight service 14 | */ 15 | def flight(): DataFrame = r.format("flight").load() 16 | 17 | /** 18 | * Given the format as flight with an input table 19 | * @param table - the name of a table being queried against 20 | * @return - data-frame loaded from remote flight service 21 | */ 22 | def flight(table: String): DataFrame = r.format("flight").option("table", table).load() 23 | } 24 | 25 | /** 26 | * Flight DataFrameWriter is to write a DataFrame into remote flight service 27 | * @param w - the data-frame writer 28 | */ 29 | implicit class FlightDataFrameWriter(r: DataFrameWriter[org.apache.spark.sql.DataFrame]) { 30 | /** 31 | * Given the format as flight 32 | */ 33 | def flight(): Unit = r.format("flight").save() 34 | 35 | /** 36 | * Given the format as flight with an input table 37 | * @param table - the name of a table being written into 38 | */ 39 | def flight(table: String): Unit = r.format("flight").option("table", table).save() 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/read/FlightScan.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.read; 2 | 3 | import com.qwshen.flight.Configuration; 4 | import com.qwshen.flight.Table; 5 | import org.apache.spark.sql.connector.read.Batch; 6 | import org.apache.spark.sql.connector.read.Scan; 7 | import org.apache.spark.sql.types.StructType; 8 | import java.io.Serializable; 9 | 10 | /** 11 | * Describes the data-structure of FlightScan 12 | */ 13 | public class FlightScan implements Scan, Serializable { 14 | private final Configuration _configuration; 15 | private final Table _table; 16 | 17 | /** 18 | * Construct a FligthScan 19 | * @param configuration - the configuration of remote flight service 20 | * @param table - the table object 21 | */ 22 | public FlightScan(Configuration configuration, Table table) { 23 | this._configuration = configuration; 24 | this._table = table; 25 | } 26 | 27 | /** 28 | * Get the schema of the scan 29 | * @return - the scan for the scan 30 | */ 31 | @Override 32 | public StructType readSchema() { 33 | return this._table.getSparkSchema(); 34 | } 35 | 36 | /** 37 | * The description of the scan 38 | * @return - description 39 | */ 40 | @Override 41 | public String description() { 42 | return this._table.getQueryStatement(); 43 | } 44 | 45 | /** 46 | * Translate the scan to batch 47 | * 
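The `implicits` object above adds `flight(...)` shortcuts so the format string does not have to be repeated. A minimal sketch of the reader-side shortcut, assuming the same hypothetical connection options as before and an existing `SparkSession` named `spark`:

```scala
import com.qwshen.flight.spark.implicits._

// equivalent to spark.read.format("flight").option("table", ...).load()
val events = spark.read
  .option("host", "192.168.0.22")        // hypothetical flight endpoint
  .option("port", "32010")
  .option("user", "test")
  .option("password", "Password@123")
  .flight("\"local-iceberg\".iceberg_db.events")   // hypothetical table path

events.show(false)
```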
@return - the batch desribes the scan 48 | */ 49 | @Override 50 | public Batch toBatch() { 51 | return new FlightBatch(this._configuration, this._table); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/test/resources/data/events/events.csv: -------------------------------------------------------------------------------- 1 | event_id,user_id,start_time,city,province,country 2 | 1918771225,4106419938,2012-10-03T08:00:00.002Z,,,, 3 | 1502284248,2016654644,2012-10-03T11:00:00.003Z,Ottawa,,Canada 4 | 2529072432,3639934255,2012-10-26T13:30:00.003Z,Toronto,,Canada 5 | 3072478280,97461525,2012-10-06T05:00:00.003Z,,, 6 | 1390707377,3639934255,2012-10-06T03:00:00.003Z,Toronto,,Canada 7 | 152418051,1618377432,2012-11-03T07:00:00.003Z,Ottawa,,Canada 8 | 4203627753,415464198,2012-10-31T00:00:00.001Z,,, 9 | 110357109,937597069,2012-10-30T00:00:00.001Z,,,, 10 | 799782433,1251929142,2012-11-01T16:00:00.003Z,,, 11 | 823015621,3086474574,2012-11-01T13:00:00.003Z,,,, 12 | 2790605371,3716495692,2012-11-04T00:00:00.001Z,,, 13 | 753115138,1350826506,2012-11-07T03:00:00.003Z,,, 14 | 825060275,2522070769,2012-11-06T15:00:00.003Z,,,, 15 | 1065213296,2522070769,2012-11-07T16:00:00.003Z,,,, 16 | 1073827062,1224027666,2012-11-07T03:00:00.003Z,Montreal,,Canada 17 | 1807884467,781107922,2012-11-25T01:45:00.003Z,,,, 18 | 2590444754,4268471513,2012-11-22T00:00:00.001Z,,,, 19 | 1868735086,2109081461,2012-10-30T00:00:00.001Z,,,, 20 | 1455527953,2109081461,2012-10-30T00:00:00.001Z,,,, 21 | 3487526512,1425684001,2012-10-31T03:00:00.003Z,,,, 22 | 150148297,1795700034,2012-10-31T08:30:00.003Z,Waterloo,ON,Canada 23 | 1461387354,3131071686,2012-10-31T13:00:00.003Z,,,, 24 | 2475079669,3291055841,2012-10-06T06:45:00.003Z,,,, 25 | 955398943,3286716293,2012-11-23T13:30:00.003Z,Toronto,,Canada 26 | 1936167908,2670616496,2012-10-05T21:00:00.003Z,,, 27 | 3858223520,1895222667,2012-10-07T01:00:00.003Z,Toronto,,Canada 28 | 1532377761,3286716293,2012-10-06T05:00:00.003Z,Toronto,,Canada 29 | 2529072432,3639934255,2012-10-26T13:30:00.003Z,Toronto,,Canada -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/PushAggregation.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import java.io.Serializable; 4 | 5 | /** 6 | * Describes the data structure for pushed-down aggregation 7 | */ 8 | public class PushAggregation implements Serializable { 9 | //pushed-down Aggregate-Columns (expressions) 10 | private String[] _columnExpressions = null; 11 | //pushed-down GroupBy-Columns 12 | private String[] _groupByColumns = null; 13 | 14 | /** 15 | * Push down aggregation of columns 16 | * select max(age), sum(distinct amount) from table where ... 17 | * @param columnExpressions - the collection of aggregation expressions 18 | */ 19 | public PushAggregation(String[] columnExpressions) { 20 | this._columnExpressions = columnExpressions; 21 | } 22 | 23 | /** 24 | * Push down aggregation with group by columns 25 | * select max(age), sum(amount) from table where ... 
group by gender 26 | * @param columnExpressions - the collection of aggregation expressions 27 | * @param groupByColumns - the columns in group by 28 | */ 29 | public PushAggregation(String[] columnExpressions, String[] groupByColumns) { 30 | this(columnExpressions); 31 | this._groupByColumns = groupByColumns; 32 | } 33 | 34 | /** 35 | * Return the collection of aggregation expressions 36 | * @return - the expressions 37 | */ 38 | public String[] getColumnExpressions() { 39 | return this._columnExpressions; 40 | } 41 | 42 | /** 43 | * The columns for group-by 44 | * @return - columns 45 | */ 46 | public String[] getGroupByColumns() { 47 | return this._groupByColumns; 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/FieldVector.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import java.io.Serializable; 4 | 5 | /** 6 | * Describes the format of a field-vector 7 | */ 8 | public class FieldVector implements Serializable { 9 | //the field information 10 | private final com.qwshen.flight.Field _field; 11 | //the values in the vector 12 | private final Object[] _values; 13 | 14 | /** 15 | * Construct a FieldVector 16 | * @param name - the name of the field 17 | * @param type - the data type of the field 18 | * @param values - the objects in the vector 19 | */ 20 | public FieldVector(String name, FieldType type, Object[] values) { 21 | this._field = new com.qwshen.flight.Field(name, type); 22 | this._values = values; 23 | } 24 | 25 | /** 26 | * Get the field 27 | * @return - the field for this vector 28 | */ 29 | public com.qwshen.flight.Field getField() { 30 | return this._field; 31 | } 32 | 33 | /** 34 | * Get the data in the vector 35 | * @return - data in the vector 36 | */ 37 | public Object[] getValues() { 38 | return this._values; 39 | } 40 | 41 | /** 42 | * Convert an arrow-FieldVector into a custom FieldVector 43 | * @param vector - the arrow field-vector 44 | * @param type - the data type of the field for the vector 45 | * @param rowCount - number of rows in the vector 46 | * @return - an instance of the custom FieldVector 47 | */ 48 | public static FieldVector fromArrow(org.apache.arrow.vector.FieldVector vector, FieldType type, int rowCount) { 49 | return ArrowConversion.getOrCreate().convert(vector, type, rowCount); 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/write/FlightWriteAbortException.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.write; 2 | 3 | import java.io.IOException; 4 | import java.io.Serializable; 5 | 6 | public class FlightWriteAbortException extends IOException implements Serializable { 7 | /** 8 | * Construct an abort exception for a batch write 9 | * @param partitionId - the partition-id of the data-frame being written 10 | * @param taskId - the task id of the writing operation 11 | * @param messageCount - number of rows written before the abort 12 | */ 13 | public FlightWriteAbortException(int partitionId, long taskId, long messageCount) { 14 | super(getMessage(partitionId, taskId, messageCount)); 15 | } 16 | 17 | /** 18 | * Construct an abort exception for a streaming write 19 | * @param partitionId - the partition-id of the data-frame being written 20 | * @param taskId - the task id of the writing operation 21 | * @param epochId - the epoch-id for 
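`PushAggregation` above is the container that `FlightScanBuilder` (near the end of this listing) fills when Spark's aggregate push-down fires; it covers count, count distinct, min, max, sum and sum distinct, optionally with group-by columns. A sketch of a query shape that should qualify, assuming `people` is a flight-backed DataFrame with `gender`, `age` and `amount` columns (all hypothetical):

```scala
import org.apache.spark.sql.functions._

// assuming `people` was loaded through format("flight"); the column names are hypothetical
val summary = people
  .groupBy("gender")
  .agg(max("age"), sum("amount"), countDistinct("amount"))

// FlightScanBuilder.pushAggregation rewrites this into roughly
//   select gender, max(age), sum(amount), count(distinct(amount)) from <table> ... group by gender
// so only the aggregated rows are fetched over Arrow Flight
summary.show(false)
```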
streaming write. 22 | * @param messageCount - number of rows been written 23 | */ 24 | public FlightWriteAbortException(int partitionId, long taskId, String epochId, long messageCount) { 25 | super(getMessage(partitionId, taskId, epochId, messageCount)); 26 | } 27 | 28 | //form the error message 29 | private static String getMessage(int partitionId, long taskId, long messageCount) { 30 | return String.format("Streaming write for %d messages with partition (%d), task (%d) aborted.", messageCount, partitionId, taskId); 31 | } 32 | 33 | //form the error message 34 | private static String getMessage(int partitionId, long taskId, String epochId, long messageCount) { 35 | return String.format("Streaming write for %d messages with partition (%d), task (%d) and epoch (%s) aborted.", messageCount, partitionId, taskId, epochId); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/write/FlightWriteBuilder.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.write; 2 | 3 | import com.qwshen.flight.Configuration; 4 | import com.qwshen.flight.Table; 5 | import com.qwshen.flight.WriteBehavior; 6 | import org.apache.spark.sql.connector.write.SupportsTruncate; 7 | import org.apache.spark.sql.connector.write.Write; 8 | import org.apache.spark.sql.connector.write.WriteBuilder; 9 | import org.apache.spark.sql.types.StructType; 10 | 11 | /** 12 | * The flight write builder to build flight writers 13 | */ 14 | public class FlightWriteBuilder implements WriteBuilder, SupportsTruncate { 15 | private final Configuration _configuration; 16 | private final Table _table; 17 | private final StructType _dataSchema; 18 | private final WriteBehavior _writeBehavior; 19 | 20 | /** 21 | * Construct a builder for creating flight writers 22 | * @param configuration - the configuraton of remote flight service 23 | * @param table - the table object for describing the target flight table 24 | * @param dataSchema - the schema of the data being written 25 | * @param writeBehavior - the write-behavior 26 | */ 27 | public FlightWriteBuilder(Configuration configuration, Table table, StructType dataSchema, WriteBehavior writeBehavior) { 28 | this._configuration = configuration; 29 | this._table = table; 30 | this._dataSchema = dataSchema; 31 | this._writeBehavior = writeBehavior; 32 | } 33 | 34 | @Override 35 | public Write build() { 36 | return new FlightWrite(this._configuration, this._table, this._dataSchema, this._writeBehavior); 37 | } 38 | 39 | /** 40 | * flag to truncate the target table 41 | * @return - the write-build which truncates the target table 42 | */ 43 | @Override 44 | public WriteBuilder truncate() { 45 | this._writeBehavior.truncate(); 46 | return this; 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/read/FlightBatch.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.read; 2 | 3 | import com.qwshen.flight.Configuration; 4 | import com.qwshen.flight.Table; 5 | import org.apache.spark.sql.connector.read.Batch; 6 | import org.apache.spark.sql.connector.read.InputPartition; 7 | import org.apache.spark.sql.connector.read.PartitionReaderFactory; 8 | import java.io.Serializable; 9 | import java.util.Arrays; 10 | 11 | /** 12 | * Describes a flight-batch 13 | */ 14 | public class FlightBatch implements Batch, 
Serializable { 15 | private final Configuration _configuration; 16 | private final Table _table; 17 | 18 | /** 19 | * Construct a FligthBatch for a scan 20 | * @param configuration - the configuration of remote flight service 21 | * @param table - the table object 22 | */ 23 | public FlightBatch(Configuration configuration, Table table) { 24 | this._configuration = configuration; 25 | this._table = table; 26 | } 27 | 28 | /** 29 | * Plan the input-partitions 30 | * @return - the logical partitions 31 | */ 32 | @Override 33 | public InputPartition[] planInputPartitions() { 34 | String[] partitionQueries = this._table.getPartitionStatements(); 35 | return (partitionQueries.length > 0) 36 | ? Arrays.stream(partitionQueries).map(q -> new FlightInputPartition.FlightQueryInputPartition(this._table.getSchema(), q)).toArray(InputPartition[]::new) 37 | : Arrays.stream(this._table.getEndpoints()).map(e -> new FlightInputPartition.FlightEndpointInputPartition(this._table.getSchema(), e)).toArray(InputPartition[]::new); 38 | } 39 | 40 | /** 41 | * Create a partition reader factory 42 | * @return - the partition reader factory 43 | */ 44 | @Override 45 | public PartitionReaderFactory createReaderFactory() { 46 | return new FlightPartitionReaderFactory(this._configuration); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/QueryStatement.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import java.io.Serializable; 4 | 5 | /** 6 | * The read statement for querying data from remote flight service 7 | */ 8 | public class QueryStatement implements Serializable { 9 | private final String _stmt; 10 | private final String _where; 11 | private final String _groupBy; 12 | 13 | /** 14 | * Construct a ReadStatement 15 | * @param stmt - the select portion of a select-statement 16 | * @param where - the where portion of a select-statement 17 | * @param groupBy - the groupBy portion of a select-statement 18 | */ 19 | public QueryStatement(String stmt, String where, String groupBy) { 20 | this._stmt = stmt; 21 | this._where = where; 22 | this._groupBy = groupBy; 23 | } 24 | 25 | /** 26 | * Check if the current ReadStatement is different from the input ReadStatement 27 | * @param rs - one ReadStatement to be compared 28 | * @return - true if they are different 29 | */ 30 | public boolean different(QueryStatement rs) { 31 | boolean changed = (rs == null || !rs._stmt.equalsIgnoreCase(this._stmt)); 32 | if (!changed) { 33 | changed = (rs._where != null) ? !rs._where.equalsIgnoreCase(this._where) : this._where != null; 34 | } 35 | if (!changed) { 36 | changed = (rs._groupBy != null) ? !rs._groupBy.equalsIgnoreCase(this._groupBy) : this._groupBy != null; 37 | } 38 | return changed; 39 | } 40 | 41 | /** 42 | * Get the whole select-statement 43 | * @return - the select-statement 44 | */ 45 | public String getStatement() { 46 | return String.format("%s %s %s", this._stmt, 47 | (this._where != null && this._where.length() > 0) ? String.format("where %s", this._where) : "", 48 | (this._groupBy != null && this._groupBy.length() > 0) ? 
String.format("group by %s", this._groupBy) : "" 49 | ); 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/RowSet.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import org.apache.arrow.vector.types.pojo.Schema; 4 | 5 | import java.io.Serializable; 6 | import java.util.ArrayList; 7 | 8 | /** 9 | * Describes row batch from remote flight service 10 | */ 11 | public class RowSet implements Serializable { 12 | /** 13 | * The definition of Row 14 | */ 15 | public static class Row implements Serializable { 16 | private final ArrayList _data; 17 | 18 | public Row() { 19 | this._data = new ArrayList<>(); 20 | } 21 | 22 | public void add(Object o) { 23 | this._data.add(o); 24 | } 25 | 26 | public Object[] getData() { 27 | return this._data.toArray(new Object[0]); 28 | } 29 | } 30 | 31 | //the schema of each ROw 32 | private final Schema _schema; 33 | //the row collection 34 | private final ArrayList _data; 35 | 36 | /** 37 | * Construct a RowSet 38 | * @param schema - the schema of each row in the collection 39 | */ 40 | public RowSet(Schema schema) { 41 | this._schema = schema; 42 | this._data = new ArrayList<>(); 43 | } 44 | 45 | /** 46 | * Get the schema of the RowSet 47 | * @return - the schema 48 | */ 49 | public Schema getSchema() { 50 | return this._schema; 51 | } 52 | 53 | /** 54 | * Add one Row 55 | * @param row - the row to be added 56 | */ 57 | public void add(Row row) { 58 | this._data.add(row); 59 | } 60 | 61 | /** 62 | * Add all rows from another RowSet 63 | * @param rs - the input RowSet 64 | */ 65 | public void add(RowSet rs) { 66 | if (rs._schema != this._schema) { 67 | throw new RuntimeException("The schema doesn't match. 
Cannot add the RowSet."); 68 | } 69 | this._data.addAll(rs._data); 70 | } 71 | 72 | /** 73 | * Get all Rows 74 | * @return - all rows in the RowSet 75 | */ 76 | public Row[] getData() { 77 | return this._data.toArray(new Row[] {}); 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/Field.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import org.apache.arrow.vector.types.pojo.Schema; 4 | import java.io.Serializable; 5 | import java.util.Arrays; 6 | import java.util.Optional; 7 | 8 | /** 9 | * Describes a field data-structure, such as the field name and date-type 10 | */ 11 | public class Field implements Serializable { 12 | //the name of the field 13 | private final String _name; 14 | //the data type 15 | private final FieldType _type; 16 | 17 | /** 18 | * Construct a Field 19 | * @param name - the name of the field 20 | * @param type - the data type of the field 21 | */ 22 | public Field(String name, FieldType type) { 23 | this._name = name; 24 | this._type = type; 25 | } 26 | 27 | /** 28 | * Get the name of the field 29 | * @return - the name 30 | */ 31 | public String getName() { 32 | return this._name; 33 | } 34 | 35 | /** 36 | * Get the datatype of the field 37 | * @return - the type 38 | */ 39 | public FieldType getType() { 40 | return this._type; 41 | } 42 | 43 | /** 44 | * Get the hash-code of the field 45 | * @return - the value of the hash-code 46 | */ 47 | public int hashCode() { 48 | return this._name.hashCode(); 49 | } 50 | 51 | /** 52 | * Find the field-type of a field by name 53 | * @param fields - the field collection from which to search 54 | * @param name - the name of a field to be searched 55 | * @return - the field type matching the input name 56 | */ 57 | public static FieldType find(Field[] fields, String name) { 58 | Optional fs = Arrays.stream(fields).filter(s -> s.getName().equalsIgnoreCase(name)).findFirst(); 59 | if (!fs.isPresent()) { 60 | throw new RuntimeException("The field with " + name + " doesn't exist."); 61 | } 62 | return fs.get().getType(); 63 | } 64 | 65 | /** 66 | * Extract all fields from the arrow-flight schema 67 | * @param schema - the arrow schema 68 | * @return - fields from the schema 69 | */ 70 | public static Field[] from(Schema schema) { 71 | return schema.getFields().stream().map(f -> new Field(f.getName(), FieldType.fromArrow(f.getType(), f.getChildren()))).toArray(Field[]::new); 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/WriteBehavior.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import java.io.Serializable; 4 | import java.util.Map; 5 | 6 | /** 7 | * Defines the write-behavior 8 | */ 9 | public class WriteBehavior implements Serializable { 10 | private final WriteProtocol _protocol; 11 | private final Map _typeMapping; 12 | 13 | private final int _batchSize; 14 | private final String[] _mergeByColumns; 15 | 16 | //the truncate flag 17 | private Boolean _truncate = false; 18 | 19 | /** 20 | * Construct a WriteBehavior 21 | * @param protocol - the protocol for submitting DML requests. 
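`WriteBehavior` above bundles everything the writer needs: the DML protocol, the batch size, optional merge-by columns and a truncate flag that `FlightWriteBuilder` sets when Spark asks for truncation. The option keys for those write settings are not part of this listing, so the sketch below sticks to the connection and table keys that are; names and values are placeholders, and append semantics are assumed:

```scala
// `df` stands for any DataFrame whose schema matches the target table (hypothetical)
df.write
  .format("flight")
  .option("host", "192.168.0.22")        // hypothetical flight endpoint
  .option("port", "32010")
  .option("user", "test")
  .option("password", "Password@123")
  .option("table", "\"local-iceberg\".iceberg_db.gender_summary")   // hypothetical target table
  .mode("append")
  .save()
```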
It must be either literal-sql or prepared-sql 22 | * @param batchSize - the size of each batch to be written 23 | * @param mergeByColumn - the columns on which to merge data into the target table 24 | * @param typeMapping - the arrow-type to target data-type mapping 25 | */ 26 | public WriteBehavior(WriteProtocol protocol, int batchSize, String[] mergeByColumn, Map typeMapping) { 27 | this._protocol = protocol; 28 | this._batchSize = batchSize; 29 | this._mergeByColumns = mergeByColumn; 30 | this._typeMapping = typeMapping; 31 | } 32 | 33 | /** 34 | * Get the write-procotol 35 | * @return - the protocol for writing 36 | */ 37 | public WriteProtocol getProtocol() { 38 | return this._protocol; 39 | } 40 | 41 | /** 42 | * Get the size of each batch 43 | * @return - the size of batch for writing 44 | */ 45 | public int getBatchSize() { 46 | return this._batchSize; 47 | } 48 | 49 | /** 50 | * Get the merge-by columns 51 | * @return - the columns on which to merge data into the target table 52 | */ 53 | public String[] getMergeByColumns() { 54 | return isTruncate() ? new String[0] : this._mergeByColumns; 55 | } 56 | 57 | /** 58 | * Get the type-mapping 59 | * @return - the mapping between arrow-type & target data-types 60 | */ 61 | public Map getTypeMapping() { return this._typeMapping; } 62 | 63 | /** 64 | * set the flag to truncate the target table 65 | */ 66 | public void truncate() { 67 | if (this._mergeByColumns != null && this._mergeByColumns.length > 0) { 68 | throw new RuntimeException("The merge-by can only work with append mode."); 69 | } 70 | this._truncate = true; 71 | } 72 | 73 | /** 74 | * Flag to truncate the target table 75 | * @return - true if it is to truncate the target table 76 | */ 77 | public Boolean isTruncate() { 78 | return this._truncate; 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/read/FlightInputPartition.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.read; 2 | 3 | import com.qwshen.flight.Endpoint; 4 | import org.apache.arrow.vector.types.pojo.Schema; 5 | import org.apache.spark.sql.connector.read.InputPartition; 6 | import java.io.IOException; 7 | import java.io.Serializable; 8 | 9 | /** 10 | * Describes the data-structure of flight input-partitions 11 | */ 12 | public class FlightInputPartition implements InputPartition, Serializable { 13 | /** 14 | * The input-partition with end-points 15 | */ 16 | public static class FlightEndpointInputPartition extends FlightInputPartition { 17 | private final Endpoint _ep; 18 | 19 | /** 20 | * Construct a flight end-point input-partition 21 | * @param schema - the schema of the partition 22 | * @param ep - the end-point of the partition 23 | */ 24 | public FlightEndpointInputPartition(Schema schema, Endpoint ep) { 25 | super(schema); 26 | this._ep = ep; 27 | } 28 | 29 | /** 30 | * Get the end-point of the input-partition 31 | * @return - the end-point of the input partition 32 | */ 33 | public Endpoint getEndpoint() { 34 | return this._ep; 35 | } 36 | } 37 | 38 | /** 39 | * The input-partition with query 40 | */ 41 | public static class FlightQueryInputPartition extends FlightInputPartition { 42 | private final String _query; 43 | 44 | /** 45 | * Construct a flight query input-partition 46 | * @param schema - the schema of the partition 47 | * @param query - the query for the partition 48 | */ 49 | public FlightQueryInputPartition(Schema schema, String query) 
{ 50 | super(schema); 51 | this._query = query; 52 | } 53 | 54 | /** 55 | * Get the query of the input-partition 56 | * @return - the query of the input-partition 57 | */ 58 | public String getQuery() { 59 | return this._query; 60 | } 61 | } 62 | 63 | //the schema of the input partition 64 | private final String _schema; 65 | 66 | /** 67 | * Construct a flight-input-partition 68 | * @param schema - the schema of the partition 69 | */ 70 | protected FlightInputPartition(Schema schema) { 71 | this._schema = schema.toJson(); 72 | } 73 | 74 | /** 75 | * Get the schema of the partition 76 | * @return - the schema of the partition 77 | * @throws IOException - thrown when the schema is in invalid json format. 78 | */ 79 | public Schema getSchema() throws IOException { 80 | return Schema.fromJSON(this._schema); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/write/FlightWriterCommitMessage.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.write; 2 | 3 | import org.apache.spark.sql.connector.write.WriterCommitMessage; 4 | import java.io.Serializable; 5 | 6 | /** 7 | * Describes the flight writer commit message 8 | */ 9 | public class FlightWriterCommitMessage implements WriterCommitMessage, Serializable { 10 | private final int _partitionId; 11 | private final long _taskId; 12 | private final String _epochId; 13 | private final long _messageCount; 14 | 15 | /** 16 | * Construct a success Flight-Writer-Commit-Message 17 | * @param partitionId - the partition-id of the data-frame been written 18 | * @param taskId - the task id of the writing operation 19 | * @param messageCount - number of rows been written 20 | */ 21 | public FlightWriterCommitMessage(int partitionId, long taskId, long messageCount) { 22 | this._partitionId = partitionId; 23 | this._taskId = taskId; 24 | this._epochId = ""; 25 | this._messageCount = messageCount; 26 | } 27 | 28 | /** 29 | * Construct a failure Flight-Writer-Commit-Message 30 | * @param partitionId - the partition-id of the data-frame been written 31 | * @param taskId - the task id of the writing operation 32 | * @param epochId - the epoch-id for streaming write. 33 | * @param messageCount - number of rows been written 34 | */ 35 | public FlightWriterCommitMessage(int partitionId, long taskId, String epochId, long messageCount) { 36 | this._partitionId = partitionId; 37 | this._taskId = taskId; 38 | this._epochId = epochId; 39 | this._messageCount = messageCount; 40 | } 41 | 42 | /** 43 | * Construct a failure Flight-Writer-Commit-Message 44 | * @param message - the base write commit message 45 | * @param epochId - the epoch-id for streaming write. 46 | */ 47 | public FlightWriterCommitMessage(FlightWriterCommitMessage message, long epochId) { 48 | this._partitionId = message._partitionId; 49 | this._taskId = message._taskId; 50 | this._epochId = Long.toString(epochId); 51 | this._messageCount = message._messageCount; 52 | } 53 | 54 | /** 55 | * form the committed message 56 | * @return - the commit message 57 | */ 58 | public String getMessage() { 59 | return (this._epochId == null || this._epochId.length() == 0) ? 
String.format("Streaming write for %d messages with partition (%d), task (%d) committed.", this._messageCount, this._partitionId, this._taskId) 60 | : String.format("Streaming write for %d messages with partition (%d), task (%d) and epoch (%s) committed.", this._messageCount, this._partitionId, this._taskId, this._epochId); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/write/FlightDataWriterFactory.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.write; 2 | 3 | import com.qwshen.flight.*; 4 | import org.apache.spark.sql.catalyst.InternalRow; 5 | import org.apache.spark.sql.connector.write.DataWriter; 6 | import org.apache.spark.sql.connector.write.DataWriterFactory; 7 | import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory; 8 | import org.apache.spark.sql.types.StructType; 9 | import org.slf4j.LoggerFactory; 10 | import java.util.Arrays; 11 | 12 | /** 13 | * Defines the FLightDataWriterFactory to create DataWriters 14 | */ 15 | public class FlightDataWriterFactory implements DataWriterFactory, StreamingDataWriterFactory { 16 | private final Configuration _configuration; 17 | 18 | private final WriteStatement _stmt; 19 | private final WriteProtocol _protocol; 20 | private final int _batchSize; 21 | 22 | /** 23 | * Construct a flight-write 24 | * @param configuration - the configuration of remote flight service 25 | * @param table - the table object for describing the target flight table 26 | * @param dataSchema - the schema of data being written 27 | * @param writeBehavior - the write-behavior 28 | */ 29 | public FlightDataWriterFactory(Configuration configuration, Table table, StructType dataSchema, WriteBehavior writeBehavior) { 30 | this._configuration = configuration; 31 | this._stmt = (writeBehavior.getMergeByColumns() == null || writeBehavior.getMergeByColumns().length == 0) 32 | ? 
new WriteStatement(table.getName(), dataSchema, table.getSchema(), table.getColumnQuote(), writeBehavior.getTypeMapping()) 33 | : new WriteStatement(table.getName(), writeBehavior.getMergeByColumns(), dataSchema, table.getSchema(), table.getColumnQuote(), writeBehavior.getTypeMapping()); 34 | this._protocol = writeBehavior.getProtocol(); 35 | this._batchSize = writeBehavior.getBatchSize(); 36 | 37 | //truncate the table if requested 38 | if (writeBehavior.isTruncate()) { 39 | this.truncate(table.getName()); 40 | } 41 | } 42 | 43 | /** 44 | * truncate the target table 45 | * @param table - the name of the table 46 | */ 47 | private void truncate(String table) { 48 | try { 49 | Client.getOrCreate(this._configuration).truncate(table); 50 | } catch (Exception e) { 51 | LoggerFactory.getLogger(this.getClass()).error(e.getMessage() + " --> " + Arrays.toString(e.getStackTrace())); 52 | throw new RuntimeException(e); 53 | } 54 | } 55 | 56 | /** 57 | * Create a DataWriter for batch-write 58 | * @param partitionId - the partition id 59 | * @param taskId - the task id 60 | * @return - a DataWriter 61 | */ 62 | @Override 63 | public DataWriter createWriter(int partitionId, long taskId) { 64 | return new FlightDataWriter(partitionId, taskId, this._configuration, this._stmt, this._protocol, this._batchSize); 65 | } 66 | 67 | /** 68 | * Create a DataWriter for streaming-write 69 | * @param partitionId - the partition id 70 | * @param taskId - the task id 71 | * @param epochId - a monotonically increasing id for streaming queries that are split into discrete periods of execution. 72 | * @return - a DataWriter 73 | */ 74 | @Override 75 | public DataWriter createWriter(int partitionId, long taskId, long epochId) { 76 | return new FlightDataWriter(partitionId, taskId, epochId, this._configuration, this._stmt, this._protocol, this._batchSize); 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/read/FlightPartitionReader.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.read; 2 | 3 | import com.qwshen.flight.Client; 4 | import com.qwshen.flight.Configuration; 5 | import com.qwshen.flight.QueryEndpoints; 6 | import com.qwshen.flight.RowSet; 7 | import org.apache.spark.sql.catalyst.InternalRow; 8 | import org.apache.spark.sql.connector.read.InputPartition; 9 | import org.apache.spark.sql.connector.read.PartitionReader; 10 | import org.slf4j.LoggerFactory; 11 | import scala.collection.JavaConverters; 12 | import java.io.IOException; 13 | import java.util.Arrays; 14 | 15 | /** 16 | * Describes a partition-reader for reading partitions from remote flight service 17 | */ 18 | public class FlightPartitionReader implements PartitionReader { 19 | private final Configuration _configuration; 20 | private final InputPartition _inputPartition; 21 | 22 | //hold all rows 23 | private RowSet.Row[] _rows = null; 24 | //the row pointer 25 | private int _curIdx = 0; 26 | 27 | /** 28 | * Constructor a partion reader 29 | * @param configuration - the configuration of remote flight service 30 | * @param inputPartition - the input-partition 31 | */ 32 | public FlightPartitionReader(Configuration configuration, InputPartition inputPartition) { 33 | this._configuration = configuration; 34 | this._inputPartition = inputPartition; 35 | } 36 | 37 | //fetch data from remote flight service 38 | private void execute() throws IOException { 39 | if (this._inputPartition instanceof 
FlightInputPartition.FlightEndpointInputPartition) { 40 | this.executeEndpoint((FlightInputPartition.FlightEndpointInputPartition)this._inputPartition); 41 | } else if (this._inputPartition instanceof FlightInputPartition.FlightQueryInputPartition) { 42 | this.executeQuery((FlightInputPartition.FlightQueryInputPartition)this._inputPartition); 43 | } 44 | } 45 | //fetch data by end-point 46 | private void executeEndpoint(FlightInputPartition.FlightEndpointInputPartition dePartition) throws IOException { 47 | try { 48 | Client client = Client.getOrCreate(this._configuration); 49 | //fetch the data 50 | this._rows = client.fetch(dePartition.getEndpoint(), dePartition.getSchema()).getData(); 51 | //reset the pointer 52 | this._curIdx = 0; 53 | } catch (Exception e) { 54 | LoggerFactory.getLogger(this.getClass()).error(e.getMessage() + Arrays.toString(e.getStackTrace())); 55 | throw new IOException(e); 56 | } 57 | } 58 | //fetch data by query 59 | private void executeQuery(FlightInputPartition.FlightQueryInputPartition dqPartition) throws IOException { 60 | try { 61 | Client client = Client.getOrCreate(this._configuration); 62 | //get all end-points of the query 63 | QueryEndpoints qeps = client.getQueryEndpoints(dqPartition.getQuery()); 64 | //fetch the data 65 | this._rows = client.fetch(qeps).getData(); 66 | //reset the pointer 67 | this._curIdx = 0; 68 | } catch (Exception e) { 69 | LoggerFactory.getLogger(this.getClass()).error(e.getMessage() + Arrays.toString(e.getStackTrace())); 70 | throw new IOException(e); 71 | } 72 | } 73 | 74 | /** 75 | * Move to the next row 76 | * @return - true if next row is available, otherwise false 77 | * @throws IOException - throws if the read fails 78 | */ 79 | @Override 80 | public boolean next() throws IOException { 81 | if (this._rows == null) { 82 | this.execute(); 83 | } 84 | return (this._rows != null && this._curIdx < this._rows.length); 85 | } 86 | 87 | /** 88 | * Get the current row 89 | * @return - the current row 90 | */ 91 | @Override 92 | public InternalRow get() { 93 | return InternalRow.fromSeq(JavaConverters.asScalaBuffer(Arrays.asList(this._rows[this._curIdx++].getData())).toSeq()); 94 | } 95 | 96 | /** 97 | * Close the reader 98 | */ 99 | @Override 100 | public void close() { 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/write/FlightWrite.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.write; 2 | 3 | import com.qwshen.flight.Configuration; 4 | import com.qwshen.flight.Table; 5 | import com.qwshen.flight.WriteBehavior; 6 | import org.apache.spark.sql.connector.write.*; 7 | import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory; 8 | import org.apache.spark.sql.connector.write.streaming.StreamingWrite; 9 | import org.apache.spark.sql.types.StructType; 10 | import org.slf4j.Logger; 11 | import org.slf4j.LoggerFactory; 12 | import java.util.Arrays; 13 | import java.util.function.BiFunction; 14 | import java.util.function.Function; 15 | 16 | /** 17 | * Defines how to write data to target flight-service 18 | */ 19 | public class FlightWrite implements Write, BatchWrite, StreamingWrite { 20 | private final Logger _logger = LoggerFactory.getLogger(this.getClass()); 21 | 22 | private final Configuration _configuration; 23 | private final Table _table; 24 | private final StructType _dataSchema; 25 | private final WriteBehavior _writeBehavior; 26 | 27 | /** 
28 | * Construct a flight-write 29 | * @param configuration - the configuration of remote flight service 30 | * @param table - the table object for describing the target flight table 31 | * @param dataSchema - the schema of data being written 32 | * @param writeBehavior - the write-behavior 33 | */ 34 | public FlightWrite(Configuration configuration, Table table, StructType dataSchema, WriteBehavior writeBehavior) { 35 | this._configuration = configuration; 36 | this._table = table; 37 | this._dataSchema = dataSchema; 38 | this._writeBehavior = writeBehavior; 39 | } 40 | 41 | @Override 42 | public boolean useCommitCoordinator() { return true;} 43 | 44 | @Override 45 | public BatchWrite toBatch() { 46 | return this; 47 | } 48 | 49 | @Override 50 | public StreamingWrite toStreaming() { 51 | return this; 52 | } 53 | 54 | @Override 55 | public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo physicalWriteInfo) { 56 | return new FlightDataWriterFactory(this._configuration, this._table, this._dataSchema, this._writeBehavior); 57 | } 58 | 59 | @Override 60 | public void onDataWriterCommit(WriterCommitMessage message) { 61 | BatchWrite.super.onDataWriterCommit(message); 62 | } 63 | 64 | @Override 65 | public void commit(WriterCommitMessage[] messages) { 66 | if (this._logger.isDebugEnabled()) { 67 | this._logger.info(this.concat.apply(messages)); 68 | } 69 | } 70 | 71 | @Override 72 | public void abort(WriterCommitMessage[] messages) { 73 | this._logger.error(this.concat.apply(messages)); 74 | } 75 | 76 | @Override 77 | public StreamingDataWriterFactory createStreamingWriterFactory(PhysicalWriteInfo var1) { 78 | return new FlightDataWriterFactory(this._configuration, this._table, this._dataSchema, this._writeBehavior); 79 | } 80 | 81 | @Override 82 | public void commit(long epochId, WriterCommitMessage[] messages) { 83 | if (this._logger.isDebugEnabled()) { 84 | this._logger.info(this.concat.apply(this.adapt.apply(epochId, messages))); 85 | } 86 | } 87 | 88 | @Override 89 | public void abort(long epochId, WriterCommitMessage[] messages) { 90 | this._logger.error(this.concat.apply(this.adapt.apply(epochId, messages))); 91 | } 92 | 93 | //concatenate all commit messages 94 | private final Function concat = (messages) -> { 95 | String[] details = Arrays.stream(messages).map(message -> (message instanceof FlightWriterCommitMessage) ? ((FlightWriterCommitMessage)message).getMessage() : "").toArray(String[]::new); 96 | return String.join(System.lineSeparator(), details); 97 | }; 98 | //adapt all commit messages 99 | private final BiFunction adapt = (epochId, messages) -> Arrays.stream(messages).map(message -> (message instanceof FlightWriterCommitMessage) ? 
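`FlightWrite` implements both `BatchWrite` and `StreamingWrite`, and its writer factory also implements `StreamingDataWriterFactory`, so the same target can in principle be fed from a Structured Streaming query; the epoch id then flows into the commit messages shown earlier. A hedged sketch, assuming the table exposes the streaming-write capability and that `stream` is an existing streaming DataFrame; the paths and names are placeholders:

```scala
// `stream` is a streaming DataFrame, e.g. produced by spark.readStream on files or Kafka (assumed)
val query = stream.writeStream
  .format("flight")
  .option("host", "192.168.0.22")        // hypothetical flight endpoint
  .option("port", "32010")
  .option("user", "test")
  .option("password", "Password@123")
  .option("table", "\"local-iceberg\".iceberg_db.events_stream")   // hypothetical target table
  .option("checkpointLocation", "/tmp/flight-checkpoint")          // hypothetical checkpoint dir
  .start()

query.awaitTermination()
```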
new FlightWriterCommitMessage((FlightWriterCommitMessage)message, epochId) : message).toArray(WriterCommitMessage[]::new); 100 | } 101 | 102 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/FlightSource.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark; 2 | 3 | import com.qwshen.flight.Configuration; 4 | import com.qwshen.flight.Table; 5 | import org.apache.spark.sql.connector.catalog.TableProvider; 6 | import org.apache.spark.sql.connector.expressions.Transform; 7 | import org.apache.spark.sql.sources.DataSourceRegister; 8 | import org.apache.spark.sql.types.StructType; 9 | import org.apache.spark.sql.util.CaseInsensitiveStringMap; 10 | import java.util.Map; 11 | 12 | /** 13 | * Define the flight-source which supports both read from and write to remote flight-service 14 | */ 15 | public class FlightSource implements TableProvider, DataSourceRegister { 16 | //the option keys for connection 17 | private static final String HOST = "host"; 18 | private static final String PORT = "port"; 19 | private static final String TLS_ENABLED = "tls.enabled"; 20 | private static final String TLS_VERIFY_SERVER = "tls.verifyServer"; 21 | private static final String TLS_TRUSTSTORE_JKS = "tls.truststore.jksFile"; 22 | private static final String TLS_TRUSTSTORE_PASS = "tls.truststore.pass"; 23 | //the option keys for account 24 | private static final String USER = "user"; 25 | private static final String PASSWORD = "password"; 26 | private static final String BEARER_TOKEN = "bearerToken"; 27 | 28 | //the option keys for table 29 | private static final String TABLE = "table"; 30 | //the quote for field in case a field containing irregular characters, such as - 31 | private static final String COLUMN_QUOTE = "column.quote"; 32 | 33 | //Managing Workloads 34 | public static final String KEY_DEFAULT_SCHEMA = "default.schema"; 35 | public static final String KEY_ROUTING_TAG = "routing.tag"; 36 | public static final String KEY_ROUTING_QUEUE = "routing.queue"; 37 | 38 | //the service configuration 39 | private Configuration _configuration = null; 40 | //the name of the table 41 | private Table _table = null; 42 | 43 | /** 44 | * Infer schema from the options 45 | * @param options - the options container 46 | * @return - the schema inferred 47 | */ 48 | @Override 49 | public StructType inferSchema(CaseInsensitiveStringMap options) { 50 | this.probeOptions(options); 51 | return this._table.getSparkSchema(); 52 | } 53 | 54 | //extract all related options 55 | private void probeOptions(CaseInsensitiveStringMap options) { 56 | //host & port 57 | String host = options.getOrDefault(FlightSource.HOST, ""); 58 | int port = Integer.parseInt(options.getOrDefault(FlightSource.PORT, "32010")); 59 | //account 60 | String user = options.getOrDefault(FlightSource.USER, ""); 61 | String password = options.getOrDefault(FlightSource.PASSWORD, ""); 62 | String bearerToken = options.getOrDefault(FlightSource.BEARER_TOKEN, ""); 63 | //validation - host, user & password cannot be empty 64 | if (host.isEmpty() || user.isEmpty() || (password.isEmpty() && bearerToken.isEmpty())) { 65 | throw new RuntimeException("The host, user and (password or access-token) are all mandatory."); 66 | } 67 | //tls configuration 68 | boolean tlsEnabled = Boolean.parseBoolean(options.getOrDefault(FlightSource.TLS_ENABLED, "false")); 69 | boolean tlsVerify = 
Boolean.parseBoolean(options.getOrDefault(FlightSource.TLS_VERIFY_SERVER, "true")); 70 | String truststoreJks = options.getOrDefault(FlightSource.TLS_TRUSTSTORE_JKS, ""); 71 | String truststorePass = options.getOrDefault(FlightSource.TLS_TRUSTSTORE_PASS, ""); 72 | //set up the configuration object 73 | this._configuration = (truststoreJks != null && !truststoreJks.isEmpty()) 74 | ? new Configuration(host, port, truststoreJks, truststorePass, user, password, bearerToken) : new Configuration(host, port, tlsEnabled, tlsEnabled && tlsVerify, user, password, bearerToken); 75 | 76 | //set the schema path, routing tag & queue if any to manage work-loads 77 | this._configuration.setDefaultSchema(options.getOrDefault(FlightSource.KEY_DEFAULT_SCHEMA, "")); 78 | this._configuration.setRoutingTag(options.getOrDefault(FlightSource.KEY_ROUTING_TAG, "")); 79 | this._configuration.setRoutingQueue(options.getOrDefault(FlightSource.KEY_ROUTING_QUEUE, "")); 80 | 81 | //the table name 82 | String tableName = options.getOrDefault(FlightSource.TABLE, ""); 83 | //table name cannot empty 84 | if (tableName == null || tableName.isEmpty()) { 85 | throw new RuntimeException("The table is mandatory."); 86 | } 87 | //set up the flight-table with the column quote-character. By default, columns are not quoted 88 | this._table = Table.forTable(tableName, options.getOrDefault(FlightSource.COLUMN_QUOTE, "")); 89 | this._table.initialize(this._configuration); 90 | } 91 | 92 | /** 93 | * Get the table of the DataSource 94 | * @param schema - the schema of the table 95 | * @param partitioning - the partitioning for the table 96 | * @param properties - the properties of the table 97 | * @return - a Table object 98 | */ 99 | @Override 100 | public org.apache.spark.sql.connector.catalog.Table getTable(StructType schema, Transform[] partitioning, Map properties) { 101 | return new FlightTable(this._configuration, this._table); 102 | } 103 | 104 | /** 105 | * Get the short-name of the DataSource 106 | * @return - the short name of the DataSource 107 | */ 108 | @Override 109 | public String shortName() { 110 | return "flight"; 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/read/FlightScanBuilder.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.read; 2 | 3 | import com.qwshen.flight.Configuration; 4 | import com.qwshen.flight.PartitionBehavior; 5 | import com.qwshen.flight.PushAggregation; 6 | import com.qwshen.flight.Table; 7 | import org.apache.spark.sql.connector.expressions.Expression; 8 | import org.apache.spark.sql.connector.expressions.aggregate.Aggregation; 9 | import org.apache.spark.sql.connector.expressions.aggregate.*; 10 | import org.apache.spark.sql.connector.read.*; 11 | import org.apache.spark.sql.sources.*; 12 | import org.apache.spark.sql.types.StructField; 13 | import org.apache.spark.sql.types.StructType; 14 | import java.util.ArrayList; 15 | import java.util.Arrays; 16 | import java.util.List; 17 | import java.util.Objects; 18 | import java.util.function.Function; 19 | 20 | /** 21 | * Build flight scans which supports pushed-down filter, fields & aggregates 22 | */ 23 | public final class FlightScanBuilder implements ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsPushDownAggregates { 24 | private final Configuration _configuration; 25 | private final Table _table; 26 | private final PartitionBehavior 
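`probeOptions` above is the single place where connection options are read, so the full surface is visible in `FlightSource.java`: TLS with an optional JKS truststore, password or bearer-token authentication, a column-quote character, and the workload settings `default.schema`, `routing.tag` and `routing.queue`. A sketch of a TLS, token-authenticated read that uses only those keys; every value is a placeholder:

```scala
val secured = spark.read
  .format("flight")
  .option("host", "dremio.example.com")                                // hypothetical host
  .option("port", "32010")
  .option("tls.enabled", "true")
  .option("tls.verifyServer", "true")
  .option("tls.truststore.jksFile", "/etc/pki/dremio-truststore.jks")  // hypothetical truststore
  .option("tls.truststore.pass", "changeit")                           // hypothetical truststore password
  .option("user", "test")
  .option("bearerToken", "<personal-access-token>")                    // placeholder token
  .option("column.quote", "\"")
  .option("default.schema", "iceberg_db")                              // hypothetical workload settings
  .option("routing.tag", "etl")
  .option("routing.queue", "Low Cost User Queries")
  .option("table", "iceberg_customers")
  .load()

secured.printSchema()
```

Note that when `tls.truststore.jksFile` is supplied, the branch above builds the `Configuration` from the truststore, so the plain `tls.enabled`/`tls.verifyServer` pair is effectively ignored.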
_partitionBehavior; 27 | 28 | //the pushed-down filters 29 | private Filter[] _pdFilters = new Filter[0]; 30 | //the pushed-down columns 31 | private StructField[] _pdColumns = new StructField[0]; 32 | 33 | //pushed-down aggregation 34 | private PushAggregation _pdAggregation = null; 35 | 36 | /** 37 | * Construct a flight-scan builder 38 | * @param configuration - the configuration of remote flight-service 39 | * @param table - the table instance representing a table in remote flight-service 40 | * @param partitionBehavior - the partitioning behavior when loading data from remote flight service 41 | */ 42 | public FlightScanBuilder(Configuration configuration, Table table, PartitionBehavior partitionBehavior) { 43 | this._configuration = configuration; 44 | this._table = table; 45 | this._partitionBehavior = partitionBehavior; 46 | } 47 | 48 | /** 49 | * Collection aggregations that will be pushed down 50 | * @param aggregation - the pushed aggregation 51 | * @return - pushed aggregation 52 | */ 53 | @Override 54 | public boolean pushAggregation(Aggregation aggregation) { 55 | Function quote = (s) -> String.format("%s%s%s", this._table.getColumnQuote(), s, this._table.getColumnQuote()); 56 | Function mks = (ss) -> String.join(",", Arrays.stream(ss).map(quote).toArray(String[]::new)); 57 | Function e2s = (e) -> String.join(",", Arrays.stream(e.references()).map(r -> mks.apply(r.fieldNames())).toArray(String[]::new)); 58 | 59 | List pdAggregateColumns = new ArrayList<>(); 60 | boolean push = true; 61 | for (AggregateFunc agg : aggregation.aggregateExpressions()) { 62 | if (agg instanceof CountStar) { 63 | pdAggregateColumns.add(agg.toString().toLowerCase()); 64 | } else if (agg instanceof Count) { 65 | Count c = (Count)agg; 66 | pdAggregateColumns.add(c.isDistinct() ? String.format("count(distinct(%s))", e2s.apply(c.column())) : String.format("count(%s)", e2s.apply(c.column()))); 67 | } else if (agg instanceof Min) { 68 | Min m = (Min)agg; 69 | pdAggregateColumns.add(String.format("min(%s)", e2s.apply(m.column()))); 70 | } else if (agg instanceof Max) { 71 | Max m = (Max)agg; 72 | pdAggregateColumns.add(String.format("max(%s)", e2s.apply(m.column()))); 73 | } else if (agg instanceof Sum) { 74 | Sum s = (Sum)agg; 75 | pdAggregateColumns.add(s.isDistinct() ? String.format("sum(distinct(%s))", e2s.apply(s.column())) : String.format("sum(%s)", e2s.apply(s.column()))); 76 | } else { 77 | push = false; 78 | break; 79 | } 80 | } 81 | if (push) { 82 | String[] pdGroupByColumns = Arrays.stream(aggregation.groupByExpressions()).flatMap(e -> Arrays.stream(e.references()).flatMap(r -> Arrays.stream(r.fieldNames()).map(quote))).toArray(String[]::new); 83 | pdAggregateColumns.addAll(0, Arrays.asList(pdGroupByColumns)); 84 | this._pdAggregation = pdGroupByColumns.length > 0 ? 
new PushAggregation(pdAggregateColumns.toArray(new String[0]), pdGroupByColumns) : new PushAggregation(pdAggregateColumns.toArray(new String[0])); 85 | } else { 86 | this._pdAggregation = null; 87 | } 88 | return this._pdAggregation != null; 89 | } 90 | 91 | /** 92 | * For SupportsPushDownFilters interface 93 | * @param filters - the pushed-down filters 94 | * @return - not-accepted filters 95 | */ 96 | @Override 97 | public Filter[] pushFilters(Filter[] filters) { 98 | Function isValid = (filter) -> (filter instanceof IsNotNull || filter instanceof IsNull 99 | || filter instanceof EqualTo || filter instanceof EqualNullSafe 100 | || filter instanceof LessThan || filter instanceof LessThanOrEqual || filter instanceof GreaterThan || filter instanceof GreaterThanOrEqual 101 | || filter instanceof StringStartsWith || filter instanceof StringContains || filter instanceof StringEndsWith 102 | || filter instanceof And || filter instanceof Or || filter instanceof Not || filter instanceof In 103 | ); 104 | 105 | java.util.List pushed = new java.util.ArrayList<>(); 106 | try { 107 | return Arrays.stream(filters).map(filter -> { 108 | if (isValid.apply(filter)) { 109 | pushed.add(filter); 110 | return null; 111 | } else { 112 | return filter; 113 | } 114 | }).filter(Objects::nonNull).toArray(Filter[]::new); 115 | } finally { 116 | this._pdFilters = pushed.toArray(new Filter[0]); 117 | } 118 | } 119 | 120 | /** 121 | * For SupportsPushDownFilters interface 122 | * @return - the pushed-down filters 123 | */ 124 | @Override 125 | public Filter[] pushedFilters() { 126 | return this._pdFilters; 127 | } 128 | 129 | /** 130 | * For SupportsPushDownRequiredColumns interface 131 | * @param columns - the schema containing the required columns 132 | */ 133 | @Override 134 | public void pruneColumns(StructType columns) { 135 | this._pdColumns = columns.fields(); 136 | } 137 | 138 | /** 139 | * To build a flight-scan 140 | * @return - A flight scan 141 | */ 142 | @Override 143 | public Scan build() { 144 | //adjust flight-table upon pushed filters & columns 145 | String where = String.join(" and ", Arrays.stream(this._pdFilters).map(this._table::toWhereClause).toArray(String[]::new)); 146 | if (this._table.probe(where, this._pdColumns, this._pdAggregation, this._partitionBehavior)) { 147 | this._table.initialize(this._configuration); 148 | } 149 | return new FlightScan(this._configuration, this._table); 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /docs/tutorial.md: -------------------------------------------------------------------------------- 1 | This tutorial uses Dremio Community Edition v22.0.0 or above as the back-end Arrow-Flight server. Please follow the following steps: 2 | 3 | 1. Set up a linux environment, such as Ubuntu 14.04+ with Java SE 8 (JDK 1.8) or Java SE 11 (JDK 1.11); 4 | 5 | 2. Download the latest Dremio Community Edition [here](https://download.dremio.com/community-server/dremio-community-LATEST.tar.gz). Make sure it is v21.2.0 or above. Unzip the tar file to /opt/dremio; 6 | 7 | 3. Download Apache Spark v3.2.1 or above from [here](https://spark.apache.org/downloads.html). Unzip the tar file to /opt/spark. 8 | 9 | 4. Edit ~/.bashrc by adding the following line at the end of the file: 10 | ```shell 11 | export PATH=/opt/dremio/bin:/opt/spark/bin:$PATH 12 | ``` 13 | 14 | 5. Run the following command: 15 | ```shell 16 | source ~/.bashrc 17 | ``` 18 | 19 | 6. 
Start the Dremio server by running the following command: 20 | ```shell 21 | dremio start 22 | ``` 23 | 24 | 7. In a browser, browse to http://127.0.0.1:9047 to open the Dremio Web Console. If this is your first time, you are required to create an admin user: 25 | - User-Name: test 26 | - Password: Password@123 27 | 28 | 8. Open the SQL Runner on the Dremio Web UI, and run the following query to make sure all iceberg entries are set correctly: 29 | ```roomsql 30 | SELECT name, bool_val, num_val FROM sys.options WHERE name like '%iceberg%' 31 | ``` 32 | The following properties should be set to true: 33 | ```roomsql 34 | dremio.iceberg.enabled = true 35 | dremio.iceberg.dml.enabled = true 36 | dremio.iceberg.ctas.enabled = true 37 | dremio.iceberg.time_travel.enabled = true 38 | ``` 39 | 40 | 9. If any of the above properties is not set to true, please execute the following commands in the SQL Runner to enable iceberg: 41 | ```roomsql 42 | alter system set dremio.iceberg.enabled = true; 43 | alter system set dremio.iceberg.dml.enabled = true; 44 | alter system set dremio.iceberg.ctas.enabled = true; 45 | alter system set dremio.iceberg.time_travel.enabled = true; 46 | ``` 47 | 48 | 10. Launch spark-shell: 49 | ```shell 50 | # create the data folder first 51 | mkdir -p /tmp/data 52 | # make sure spark-flight-connector_3.2.1-1.0.0.jar is copied to the current directory 53 | # launch spark-shell 54 | spark-shell --packages org.apache.iceberg:iceberg-spark-runtime-3.2_2.12:0.13.2 --jars ./spark-flight-connector_3.2.1-1.0.0.jar \ 55 | --conf spark.sql.catalog.iceberg_catalog=org.apache.iceberg.spark.SparkCatalog \ 56 | --conf spark.sql.catalog.iceberg_catalog.type=hadoop --conf spark.sql.catalog.iceberg_catalog.warehouse=file:///tmp/data 57 | ``` 58 | 59 | 11. In spark-shell, run the following code to create the iceberg database and make it the current database: 60 | ```scala 61 | sql("create database iceberg_catalog.iceberg_db") 62 | sql("use iceberg_catalog.iceberg_db") 63 | ``` 64 | 65 | 12. Create the iceberg_customers table: 66 | ```scala 67 | sql("create table iceberg_customers(customer_id bigint not null, created_date string not null, company_name string, contact_person string, contact_phone string, active boolean) using iceberg") 68 | sql("show tables").show(false) // make sure the iceberg_customers table has been created 69 | ``` 70 | 71 | 13. Insert a few records into the iceberg_customers table: 72 | ```scala 73 | sql("insert into iceberg_customers values(3001, '2019-12-15', 'ABC Manufacturing', 'Jay Douglous', '123-xxx-1212', true)") 74 | sql("insert into iceberg_customers values(3002, '2018-12-14', 'My Pharma Corp.', 'Jessica Smith', '123-xxx-3652', true)") 75 | sql("insert into iceberg_customers values(3003, '2014-12-17', 'My Investors, Inc.', 'Chris Pandha', '123-xxx-6845', true)") 76 | sql("insert into iceberg_customers values(3004, '2013-12-15', 'XYZ Life Insurance, Inc.', 'Foster Ling', '123-xxx-9487', true)") 77 | 78 | //make sure the records have been inserted 79 | sql("select * from iceberg_customers").show(false) 80 | ``` 81 | 82 | 14. Go back to the Dremio Web UI, and create a source pointing to /tmp/data: 83 | - Open http://127.0.0.1:9047 84 | - Sign in with test/Password@123 85 | - Click on the Datasets icon in the top-left corner of the page, then click on the + button to the right of the "Data Lakes" link in the bottom-left corner of the page 86 | - On the "Add Data Lake" window, pick NAS, then type: 87 | - Name: local-iceberg 88 | - Mount-Path: /tmp/data 89 | - Click the Save button 90 | 91 | 15. 
Click the local-iceberg source to show the iceberg_customers table; hover your mouse over the iceberg_customers item, and then click on the "Format Folder" button. Dremio will automatically detect the iceberg format. Click Save to save the format. 92 | 93 | 16. Open the SQL Runner, and run the following SQL statements: 94 | ```roomsql 95 | select * from "local-iceberg"."iceberg_db"."iceberg_customers" -- make sure all pre-inserted records show up 96 | -- insert a new record 97 | insert into "local-iceberg"."iceberg_db"."iceberg_customers"(customer_id, created_date, company_name, contact_person, contact_phone, active) values(3005, '2022-05-11', 'My Foods Inc.', 'Judy Smith', '416-xxx-2212', 'true') 98 | -- make sure the new record has been added 99 | select * from "local-iceberg"."iceberg_db"."iceberg_customers" 100 | ``` 101 | 102 | 17. Go back to spark-shell, and run the following code: 103 | ```scala 104 | val df = spark.read.format("flight") 105 | .option("host", "127.0.0.1").option("port", "32010").option("user", "test").option("password", "Password@123") 106 | .option("table", """"local-iceberg"."iceberg_db"."iceberg_customers"""") 107 | .load 108 | df.show(false) //to show the records from the table 109 | ``` 110 | 111 | 18. Run an overwrite, which truncates the table and then inserts the rows with new customer IDs: 112 | ```scala 113 | val df = spark.read.format("flight") 114 | .option("host", "127.0.0.1").option("port", "32010").option("user", "test").option("password", "Password@123") 115 | .option("table", """"local-iceberg"."iceberg_db"."iceberg_customers"""") 116 | .load 117 | df.show(false) //to show the records from the table 118 | 119 | //overwrite 120 | df.withColumn("customer_id", col("customer_id") + 90000) 121 | .write.format("flight") 122 | .option("host", "127.0.0.1").option("port", "32010").option("user", "test").option("password", "Password@123") 123 | .option("table", """"local-iceberg"."iceberg_db"."iceberg_customers"""") 124 | .mode("overwrite").save() 125 | ``` 126 | Then go to the Dremio web-ui to check that the data has been re-inserted with the new customer IDs. 127 | 128 | 19. Run the following merge-by write: 129 | ```scala 130 | val df = spark.read.format("flight") 131 | .option("host", "127.0.0.1").option("port", "32010").option("user", "test").option("password", "Password@123") 132 | .option("table", """"local-iceberg"."iceberg_db"."iceberg_customers"""") 133 | .load 134 | df.show(false) // to show the records from the table 135 | 136 | // merge-by 137 | df.withColumn("customer_id", when(col("customer_id") % 3 === lit(0), col("customer_id") + 90000).otherwise(col("customer_id"))) 138 | .withColumn("created_date", current_date()) 139 | .withColumn("company_name", when(col("company_name") === lit("ABC Manufacturing"), lit("Central Bank")).otherwise(col("company_name"))) 140 | .write.format("flight") 141 | .option("host", "127.0.0.1").option("port", "32010").option("user", "test").option("password", "Password@123") 142 | .option("table", """"local-iceberg"."iceberg_db"."iceberg_customers"""") 143 | .option("merge.byColumn", "customer_id") 144 | .mode("append").save() 145 | ``` 146 | Then go to the Dremio web-ui to check the data changes.
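20. Optionally, experiment with partitioned reads and batched writes. The option keys below (partition.byColumn, partition.lowerBound, partition.upperBound, partition.size, write.protocol, batch.size, merge.byColumn) come from FlightTable.java; the bounds, the batch size and the expected partition count are illustrative values only. This is a minimal sketch that assumes the Dremio set-up from the steps above:
```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()  // already available as `spark` in spark-shell

// partitioned read: with a numeric by-column plus lower/upper bounds, the connector derives
// range predicates; without bounds it falls back to hash partitioning (partition.hashFunc)
val dfp = spark.read.format("flight")
  .option("host", "127.0.0.1").option("port", "32010")
  .option("user", "test").option("password", "Password@123")
  .option("table", """"local-iceberg"."iceberg_db"."iceberg_customers"""")
  .option("partition.byColumn", "customer_id")
  .option("partition.lowerBound", "3001").option("partition.upperBound", "100000")
  .option("partition.size", "4")
  .load
println(dfp.rdd.getNumPartitions)  // normally one Spark partition per generated predicate
dfp.show(false)

// batched write using prepared sql statements, merging on customer_id
dfp.write.format("flight")
  .option("host", "127.0.0.1").option("port", "32010")
  .option("user", "test").option("password", "Password@123")
  .option("table", """"local-iceberg"."iceberg_db"."iceberg_customers"""")
  .option("write.protocol", "prepared-sql")
  .option("batch.size", "512")
  .option("merge.byColumn", "customer_id")
  .mode("append").save()
```
Then check in the Dremio web-ui: the merge-by write should update the rows with matching customer_ids rather than appending duplicates.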
147 | 148 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/write/FlightDataWriter.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.write; 2 | 3 | import com.qwshen.flight.*; 4 | import org.apache.arrow.flight.sql.FlightSqlClient; 5 | import org.apache.arrow.memory.RootAllocator; 6 | import org.apache.arrow.vector.VectorSchemaRoot; 7 | import org.apache.arrow.vector.types.pojo.Field; 8 | import org.apache.arrow.vector.types.pojo.Schema; 9 | import org.apache.spark.sql.catalyst.InternalRow; 10 | import org.apache.spark.sql.connector.write.DataWriter; 11 | import org.apache.spark.sql.connector.write.WriterCommitMessage; 12 | import org.apache.spark.sql.types.DataType; 13 | import org.apache.spark.sql.types.StructField; 14 | import org.apache.spark.sql.types.StructType; 15 | import java.io.IOException; 16 | import java.util.function.Function; 17 | import java.util.stream.IntStream; 18 | 19 | /** 20 | * Flight DataWriter writes rows to the target flight table 21 | */ 22 | public class FlightDataWriter implements DataWriter { 23 | private int _partitionId; 24 | private long _taskId; 25 | private String _epochId; 26 | 27 | private final StructType _dataSchema; 28 | private final Schema _arrowSchema; 29 | 30 | private final WriteStatement _stmt; 31 | private final int _batchSize; 32 | 33 | private final Client _client; 34 | private Field[] _fields = null; 35 | private FlightSqlClient.PreparedStatement _preparedStmt = null; 36 | private VectorSchemaRoot _root = null; 37 | private ArrowConversion _conversion = null; 38 | 39 | private final java.util.List _rows; 40 | private long _count = 0; 41 | 42 | /** 43 | * Construct a DataWriter for batch write 44 | * @param partitionId - the partition id of the data block to be written 45 | * @param taskId - the task id of the write operation 46 | * @param configuration - the configuration of remote flight service 47 | * @param protocol - the protocol for writing - sql or arrow 48 | * @param stmt - the write-statement 49 | * @param batchSize - the batch size for write 50 | */ 51 | public FlightDataWriter(int partitionId, long taskId, Configuration configuration, WriteStatement stmt, WriteProtocol protocol, int batchSize) { 52 | this(configuration, stmt, protocol, batchSize); 53 | this._partitionId = partitionId; 54 | this._taskId = taskId; 55 | this._epochId = ""; 56 | } 57 | 58 | /** 59 | * Construct a DataWriter for streaming write 60 | * @param partitionId - the partition id of the data block to be written 61 | * @param taskId - the task id of the write operation 62 | * @param configuration - the configuration of remote flight service 63 | * @param protocol - the protocol for writing - sql or arrow 64 | * @param stmt - the write-statement 65 | * @param batchSize - the batch size for write 66 | * @param epochId - a monotonically increasing id for streaming queries that are split into discrete periods of execution. 
67 | */ 68 | public FlightDataWriter(int partitionId, long taskId, long epochId, Configuration configuration, WriteStatement stmt, WriteProtocol protocol, int batchSize) { 69 | this(configuration, stmt, protocol, batchSize); 70 | this._partitionId = partitionId; 71 | this._taskId = taskId; 72 | this._epochId = Long.toString(epochId); 73 | } 74 | 75 | /** 76 | * Internal Constructor 77 | * @param configuration - the configuration of remote flight service 78 | * @param protocol - the protocol for writing - sql or arrow 79 | * @param stmt - the write-statement 80 | * @param batchSize - the batch size for write 81 | */ 82 | private FlightDataWriter(Configuration configuration, WriteStatement stmt, WriteProtocol protocol, int batchSize) { 83 | this._stmt = stmt; 84 | this._batchSize = batchSize; 85 | 86 | this._dataSchema = this._stmt.getDataSchema(); 87 | this._client = Client.getOrCreate(configuration); 88 | if (protocol == WriteProtocol.PREPARED_SQL) { 89 | this._preparedStmt = this._client.getPreparedStatement(this._stmt.getStatement()); 90 | this._arrowSchema = this._preparedStmt.getParameterSchema(); 91 | this._fields = this._arrowSchema.getFields().toArray(new Field[0]); 92 | this._root = VectorSchemaRoot.create(this._arrowSchema, new RootAllocator(Integer.MAX_VALUE)); 93 | this._conversion = ArrowConversion.getOrCreate(); 94 | } else { 95 | try { 96 | this._arrowSchema = this._stmt.getArrowSchema(); 97 | } catch (Exception e) { 98 | throw new RuntimeException("The arrow schema is invalid.", e); 99 | } 100 | } 101 | this._rows = new java.util.ArrayList<>(); 102 | } 103 | 104 | /** 105 | * Write one row 106 | * @param row - the row of data 107 | * @throws IOException - thrown when writing failed 108 | */ 109 | @Override 110 | public void write(InternalRow row) throws IOException { 111 | this._rows.add(row.copy()); 112 | if (this._rows.size() > this._batchSize) { 113 | this.write(this._rows.toArray(new InternalRow[0])); 114 | this._rows.clear(); 115 | } 116 | } 117 | /** 118 | * Write out all rows 119 | * @param rows - the data rows 120 | */ 121 | private void write(InternalRow[] rows) { 122 | if (this._conversion != null) { 123 | Function dtFind = (name) -> this._dataSchema.find(field -> field.name().equalsIgnoreCase(name)).map(StructField::dataType).get(); 124 | IntStream.range(0, this._fields.length).forEach(idx -> this._conversion.populate(this._root.getVector(idx), rows, idx, dtFind.apply(this._fields[idx].getName()))); 125 | this._root.setRowCount(rows.length); 126 | this._preparedStmt.setParameters(this._root); 127 | try { 128 | this._client.executeUpdate(this._preparedStmt); 129 | } finally { 130 | this._preparedStmt.clearParameters(); 131 | this._root.clear(); 132 | } 133 | } else { 134 | this._client.execute(this._stmt.fillStatement(rows, this._arrowSchema.getFields().toArray(new Field[0]))); 135 | } 136 | this._count += rows.length; 137 | } 138 | 139 | /** 140 | * Commit write 141 | * @return - a commit-message 142 | */ 143 | @Override 144 | public WriterCommitMessage commit() { 145 | //write any left-over 146 | if (this._rows.size() > 0) { 147 | this.write(this._rows.toArray(new InternalRow[0])); 148 | this._rows.clear(); 149 | } 150 | 151 | long cnt = this._count; 152 | this._count = 0; 153 | return (this._epochId.length() == 0) ? 
new FlightWriterCommitMessage(this._partitionId, this._taskId, cnt) 154 | : new FlightWriterCommitMessage(this._partitionId, this._taskId, this._epochId, cnt); 155 | } 156 | 157 | /** 158 | * Abort write 159 | * @throws IOException - the exception with the error message 160 | */ 161 | @Override 162 | public void abort() throws IOException { 163 | throw (this._epochId.length() == 0) ? new FlightWriteAbortException(this._partitionId, this._taskId, this._count) 164 | : new FlightWriteAbortException(this._partitionId, this._taskId, this._epochId, this._count); 165 | } 166 | 167 | /** 168 | * Close any connections 169 | */ 170 | @Override 171 | public void close() { 172 | if (this._preparedStmt != null) { 173 | this._preparedStmt.close(); 174 | } 175 | if (this._root != null) { 176 | this._root.close(); 177 | } 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/spark/FlightTable.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark; 2 | 3 | import com.qwshen.flight.*; 4 | import com.qwshen.flight.spark.read.FlightScanBuilder; 5 | import com.qwshen.flight.spark.write.FlightWriteBuilder; 6 | import org.apache.commons.lang3.ArrayUtils; 7 | import org.apache.spark.sql.connector.catalog.SupportsRead; 8 | import org.apache.spark.sql.connector.catalog.SupportsWrite; 9 | import org.apache.spark.sql.connector.catalog.TableCapability; 10 | import org.apache.spark.sql.connector.read.ScanBuilder; 11 | import org.apache.spark.sql.connector.write.LogicalWriteInfo; 12 | import org.apache.spark.sql.connector.write.WriteBuilder; 13 | import org.apache.spark.sql.types.StructType; 14 | import org.apache.spark.sql.util.CaseInsensitiveStringMap; 15 | import java.util.Arrays; 16 | import java.util.HashSet; 17 | import java.util.Set; 18 | import java.util.function.Function; 19 | import java.util.stream.Collectors; 20 | 21 | /** 22 | * Describes the flight-table 23 | */ 24 | public class FlightTable implements org.apache.spark.sql.connector.catalog.Table, SupportsRead, SupportsWrite { 25 | //the default type-mapping 26 | private static final String _TYPE_MAPPING_DEFAULT = "BIT:BOOLEAN;" 27 | + "LARGEVARCHAR:VARCHAR; VARCHAR:VARCHAR;" 28 | + "TINYINT:INT; SMALLINT:INT; UINT1:INT; UINT2:INT; UINT4:INT; UINT8:INT; INT:INT; BIGINT:BIGINT;" 29 | + "FLOAT4:FLOAT; FLOAT8:DOUBLE; DECIMAL:DECIMAL; DECIMAL256:DECIMAL;" 30 | + "DATE:DATE; TIME:TIME; TIMESTAMP:TIMESTAMP"; 31 | //the number of partitions when reading and writing 32 | private static final String PARTITION_SIZE = "partition.size"; 33 | //the option keys for partitioning 34 | private static final String PARTITION_HASH_FUN = "partition.hashFunc"; 35 | private static final String PARTITION_BY_COLUMN = "partition.byColumn"; 36 | private static final String PARTITION_LOWER_BOUND = "partition.lowerBound"; 37 | private static final String PARTITION_UPPER_BOUND = "partition.upperBound"; 38 | private static final String PARTITION_PREDICATE = "partition.predicate"; 39 | private static final String PARTITION_PREDICATES = "partition.predicates"; 40 | 41 | //write protocol 42 | private static final String WRITE_PROTOCOL = "write.protocol"; 43 | //type mapping 44 | private static final String WRITE_TYPE_MAPPING = "write.typeMapping"; 45 | //the batch-size for writing 46 | private static final String BATCH_SIZE = "batch.size"; 47 | //merge by keys 48 | private static final String MERGE_BY_COLUMN = "merge.byColumn"; 49 | private 
static final String MERGE_BY_COLUMNS = "merge.byColumns"; 50 | 51 | //the configuration of remote flight service 52 | private final Configuration _configuration; 53 | //the table description 54 | private final Table _table; 55 | 56 | //the table capabilities 57 | private final Set _capabilities = new HashSet<>(); 58 | 59 | /** 60 | * Construct a flight-table 61 | * @param configuration - the configuration of remote flight-service 62 | * @param table - the table instance pointing to a remote flight-table 63 | */ 64 | public FlightTable(Configuration configuration, Table table) { 65 | this._configuration = configuration; 66 | this._table = table; 67 | 68 | //the data-source supports batch read/write, truncate the table 69 | this._capabilities.add(TableCapability.ACCEPT_ANY_SCHEMA); 70 | this._capabilities.add(TableCapability.BATCH_READ); 71 | this._capabilities.add(TableCapability.BATCH_WRITE); 72 | this._capabilities.add(TableCapability.TRUNCATE); 73 | this._capabilities.add(TableCapability.STREAMING_WRITE); 74 | } 75 | 76 | /** 77 | * For Table interface 78 | * @return - the name of the table 79 | */ 80 | @Override 81 | public String name() { 82 | return this._table.getName(); 83 | } 84 | 85 | /** 86 | * For Table interface 87 | * @return - the schema for Spark 88 | */ 89 | @Override 90 | public StructType schema() { 91 | return this._table.getSparkSchema(); 92 | } 93 | 94 | /** 95 | * For Table interface 96 | * @return - the capabilities of this table 97 | */ 98 | @Override 99 | public Set capabilities() { 100 | return this._capabilities; 101 | } 102 | 103 | /** 104 | * For SupportsRead interface 105 | * @param options - the options from the api call 106 | * @return - A scan builder 107 | */ 108 | @Override 109 | public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { 110 | //the partitioning behavior 111 | PartitionBehavior partitionBehavior = new PartitionBehavior( 112 | options.getOrDefault(FlightTable.PARTITION_HASH_FUN, "hash"), options.getOrDefault(FlightTable.PARTITION_BY_COLUMN, ""), 113 | Integer.parseInt(options.getOrDefault(FlightTable.PARTITION_SIZE, "6")), 114 | options.getOrDefault(FlightTable.PARTITION_LOWER_BOUND, ""), options.getOrDefault(FlightTable.PARTITION_UPPER_BOUND, ""), 115 | (String[]) ArrayUtils.addAll( 116 | //filter out any partition.predicates for only partition.predicate 117 | options.keySet().stream() 118 | .filter(k -> !k.equalsIgnoreCase(FlightTable.PARTITION_PREDICATES) && k.toLowerCase().startsWith(FlightTable.PARTITION_PREDICATE.toLowerCase())) 119 | .map(k -> options.getOrDefault(k, "")).filter(p -> !p.isEmpty()).toArray(String[]::new), 120 | //combine with partition.predicates 121 | options.containsKey(FlightTable.PARTITION_PREDICATES) ? 
options.get(FlightTable.PARTITION_PREDICATES).split("[;|,]") : new String[0] 122 | ) 123 | ); 124 | return new FlightScanBuilder(this._configuration, this._table, partitionBehavior); 125 | } 126 | 127 | /** 128 | * For SupportsWrite interface 129 | * @param logicalWriteInfo - the logical information for the writing 130 | * @return - A write builder 131 | */ 132 | @Override 133 | public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) { 134 | //validate the schema - fields being written must be in the table schema 135 | Function exists = (field) -> this._table.getSparkSchema().exists(sf -> sf.name().equalsIgnoreCase(field)); 136 | if (!Arrays.stream(logicalWriteInfo.schema().fields()).map(sf -> exists.apply(sf.name())).reduce((x, y) -> x & y).orElse(false)) { 137 | throw new RuntimeException("The schema of the dataframe being written is not compatible with the schema of the underlying table."); 138 | } 139 | 140 | //construct write-behavior 141 | CaseInsensitiveStringMap options = logicalWriteInfo.options(); 142 | WriteBehavior writeBehavior = new WriteBehavior( 143 | //by default, the write-protocol is submitting literal sql statements 144 | options.getOrDefault(FlightTable.WRITE_PROTOCOL, "literal-sql").equalsIgnoreCase("prepared-sql") ? WriteProtocol.PREPARED_SQL : WriteProtocol.LITERAL_SQL, 145 | //by default, the batch-size is 10,240. 146 | Integer.parseInt(options.getOrDefault(FlightTable.BATCH_SIZE, "1024")), 147 | ArrayUtils.addAll( 148 | //filter out any merge.ByColumns for only merge.ByColumn 149 | options.keySet().stream() 150 | .filter(k -> !k.equalsIgnoreCase(FlightTable.MERGE_BY_COLUMNS) && k.toLowerCase().startsWith(FlightTable.MERGE_BY_COLUMN.toLowerCase())) 151 | .map(k -> options.getOrDefault(k, "")).filter(p -> !p.isEmpty()).toArray(String[]::new), 152 | //combine with merge.ByColumns 153 | options.containsKey(FlightTable.MERGE_BY_COLUMNS) ? options.get(FlightTable.MERGE_BY_COLUMNS).split("[;|,]") : new String[0] 154 | ), 155 | Arrays.stream(options.getOrDefault(FlightTable.WRITE_TYPE_MAPPING, FlightTable._TYPE_MAPPING_DEFAULT).split(";")) 156 | .map(tm -> Arrays.stream(tm.split(":")).map(s -> s.trim().toLowerCase()).toArray(String[]::new)).collect(Collectors.toMap(s -> s[0], s -> s[1])) 157 | ); 158 | return new FlightWriteBuilder(this._configuration, this._table, logicalWriteInfo.schema(), writeBehavior); 159 | } 160 | } 161 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/PartitionBehavior.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import org.apache.spark.sql.types.DataType; 4 | import org.apache.spark.sql.types.DataTypes; 5 | import org.apache.spark.sql.types.DecimalType; 6 | import org.apache.spark.sql.types.StructField; 7 | import org.joda.time.DateTime; 8 | import org.joda.time.format.DateTimeFormat; 9 | import org.joda.time.format.DateTimeFormatter; 10 | import java.io.Serializable; 11 | import java.util.Arrays; 12 | import java.util.Optional; 13 | import java.util.function.Function; 14 | import java.util.stream.IntStream; 15 | 16 | /** 17 | * Describes the partition behavior. If the data type of by-column is numberic or date-time, and lower-bound and upper-bound are specified, the 18 | * partitioning step is calculated upon partition size. Otherwise, hash-partitioning will be used. 
19 | */ 20 | public class PartitionBehavior implements Serializable { 21 | /** 22 | * The internal Bound for organizing predicates 23 | */ 24 | private static class Bound implements Serializable { 25 | private final double _lower; 26 | private final double _upper; 27 | 28 | public Bound(double lower, double upper) { 29 | this._lower = lower; 30 | this._upper = upper; 31 | } 32 | 33 | public String toLongPredicate(String name) { 34 | return toPredicate(name, Long.toString((long)this._lower), Long.toString((long)this._upper)); 35 | } 36 | 37 | public String toDoublePredicate(String name) { 38 | return toPredicate(name, Double.toString(this._lower), Double.toString(this._upper)); 39 | } 40 | 41 | public String toDateTimePredicate(String name, DateTimeFormatter dtFormat) { 42 | return toPredicate(name, String.format("'%s'", dtFormat.print(new DateTime((long)this._lower))), String.format("'%s'", dtFormat.print(new DateTime((long)this._upper)))); 43 | } 44 | 45 | private String toPredicate(String name, String lower, String upper) { 46 | return !lower.equalsIgnoreCase(upper) ? String.format("%s <= %s and %s < %s", lower, name, name, upper) : String.format("%s = %s", name, lower); 47 | } 48 | } 49 | 50 | //the name of hash-func in remote flight service 51 | private final String _hashFunc; 52 | //the name of partition-by column 53 | private final String _byColumn; 54 | //the number of partitions 55 | private final int _size; 56 | //the lower bound 57 | private final String _lowerBound; 58 | //the upper bound 59 | private final String _upperBound; 60 | 61 | //explicit predicates 62 | private final String[] _predicates; 63 | 64 | /** 65 | * Construct a partition behavior 66 | * @param hashFunc - the name of the hash-func 67 | * @param byColumn - the column used for partitioning 68 | * @param size - the partition size 69 | * @param lowerBound - the lower bound used for partitioning 70 | * @param upperBound - the upper bound used for partitioning 71 | * @param predicates - the explicit predicates 72 | */ 73 | public PartitionBehavior(String hashFunc, String byColumn, int size, String lowerBound, String upperBound, String[] predicates) { 74 | this._hashFunc = hashFunc; 75 | this._byColumn = byColumn; 76 | this._size = size; 77 | this._lowerBound = lowerBound; 78 | this._upperBound = upperBound; 79 | 80 | this._predicates = predicates; 81 | } 82 | 83 | /** 84 | * Get the by-column 85 | * @return - the name of the column used for partitioning 86 | */ 87 | public String getByColumn() { 88 | return this._byColumn; 89 | } 90 | 91 | /** 92 | * Get the predicates 93 | * @return - the collection of predicates for partitioning 94 | */ 95 | public String[] getPredicates() { 96 | return this._predicates; 97 | } 98 | 99 | /** 100 | * Calculate the predicates upon by-column, size, lower-bound & upper-bound 101 | * @param dataFields - The fields from the select-list. The column for partitioning may or may not on the select-list. 
102 | * @return - the predicates which partitions the rows 103 | */ 104 | public String[] calculatePredicates(StructField[] dataFields) { 105 | String[] predicates = null; 106 | if (this._lowerBound != null && this._lowerBound.length() > 0 && this._upperBound != null && this._upperBound.length() > 0 && dataFields != null) { 107 | StructField partitionColumn = Arrays.stream(dataFields).filter(field -> field.name().equalsIgnoreCase(this._byColumn)).findFirst().orElse(null); 108 | if (partitionColumn != null) { 109 | DataType dataType = partitionColumn.dataType(); 110 | if (dataType.equals(DataTypes.ByteType) || dataType.equals(DataTypes.ShortType) || dataType.equals(DataTypes.IntegerType) || dataType.equals(DataTypes.LongType)) { 111 | predicates = probeLongPredicates().orElse(probeDoublePredicates().orElse(null)); 112 | } else if (dataType.equals(DataTypes.FloatType) || dataType.equals(DataTypes.DoubleType) || dataType instanceof DecimalType) { 113 | predicates = probeDoublePredicates().orElse(null); 114 | } else if (dataType.equals(DataTypes.DateType) || dataType.equals(DataTypes.TimestampType)) { 115 | predicates = probeDateTimePredicates().orElse(null); 116 | } 117 | } 118 | } 119 | if (predicates == null) { 120 | //by default, hash-partitioning is applied 121 | Function hashPredicate = (idx) -> String.format("(%s(%s) %% %d + %d) %% %d = %d", this._hashFunc, this._byColumn, this._size, this._size, this._size, idx); 122 | predicates = IntStream.range(0, this._size).mapToObj(hashPredicate::apply).toArray(String[]::new); 123 | } 124 | return predicates; 125 | } 126 | 127 | //probe Long predicates 128 | private Optional probeLongPredicates() { 129 | try { 130 | long lower = Long.parseLong(this._lowerBound.replace(",", "")); 131 | long upper = Long.parseLong(this._upperBound.replace(",", "")); 132 | double step = (double)(upper - lower) / (double)this._size; 133 | return Optional.of(IntStream.range(0, this._size).mapToObj(i -> new Bound(lower + i * step, lower + (i + 1) * step)).map(b -> b.toLongPredicate(this._byColumn)).toArray(String[]::new)); 134 | } catch (Exception e) { 135 | return Optional.empty(); 136 | } 137 | } 138 | //probe Double predicates 139 | private Optional probeDoublePredicates() { 140 | try { 141 | double lower = Double.parseDouble(this._lowerBound.replace(",", "")); 142 | double upper = Double.parseDouble(this._upperBound.replace(",", "")); 143 | double step = (upper - lower) / (double)this._size; 144 | return Optional.of(IntStream.range(0, this._size).mapToObj(i -> new Bound(lower + i * step, lower + (i + 1) * step)).map(b -> b.toDoublePredicate(this._byColumn)).toArray(String[]::new)); 145 | } catch (Exception e) { 146 | return Optional.empty(); 147 | } 148 | } 149 | //probe DateTime predicates 150 | private Optional probeDateTimePredicates() { 151 | String[] dtFormats = new String[] { 152 | "yyyy-MM-dd HH:mm:ss", "yyyy/MM/dd HH:mm:ss", "MM/dd/yyyy HH:mm:ss", "dd/MM/yyyy HH:mm:ss", 153 | "yyyy-MM-dd'T'HH:mm:ss.SSSZ", "yyyy.MM.dd HH:mm:ss", "yyyyMMdd HH:mm:ss", 154 | "yyyy-MM-dd", "yyyy/MM/d", "MM/dd/yyyy", "dd/MM/yyyy", "yyyyMMdd" 155 | }; 156 | Optional predicates = Optional.empty(); 157 | for (int i = 0; i < dtFormats.length && !predicates.isPresent(); i++) { 158 | predicates = tryDateTimePredicates(DateTimeFormat.forPattern(dtFormats[i])); 159 | } 160 | return predicates; 161 | } 162 | private Optional tryDateTimePredicates(DateTimeFormatter dtFormat) { 163 | try { 164 | long lower = DateTime.parse(this._lowerBound, dtFormat).getMillis(); 165 | long upper = 
DateTime.parse(this._upperBound, dtFormat).getMillis(); 166 | double step = (double)(upper - lower) / (double)this._size; 167 | return Optional.of(IntStream.range(0, this._size).mapToObj(i -> new Bound(lower + i * step, lower + (i + 1) * step)).map(b -> b.toDateTimePredicate(this._byColumn, dtFormat)).toArray(String[]::new)); 168 | } catch (Exception e) { 169 | return Optional.empty(); 170 | } 171 | } 172 | 173 | /** 174 | * Check if the behavior is defined for partitioning 175 | * @return - true when partitioning is defined 176 | */ 177 | public Boolean enabled() { 178 | return ((this._byColumn != null && this._byColumn.length() > 0) || (this._predicates != null && this._predicates.length > 0)); 179 | } 180 | 181 | /** 182 | * Flg to indicate whether pre-defined predicates have been given 183 | * @return - true if partition predicates provided 184 | */ 185 | public Boolean predicateDefined() { 186 | return (this._predicates != null && this._predicates.length > 0); 187 | } 188 | } 189 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/Configuration.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import org.bouncycastle.openssl.jcajce.JcaPEMWriter; 4 | import org.slf4j.LoggerFactory; 5 | import java.io.IOException; 6 | import java.io.InputStream; 7 | import java.io.Serializable; 8 | import java.io.StringWriter; 9 | import java.nio.charset.StandardCharsets; 10 | import java.nio.file.Files; 11 | import java.nio.file.Paths; 12 | import java.security.KeyStore; 13 | import java.security.cert.Certificate; 14 | import java.util.Enumeration; 15 | 16 | /** 17 | * Describes the data-structure for connecting to remote flight-service. 
18 | */ 19 | public class Configuration implements Serializable { 20 | //the host name of the flight end-point 21 | private final String _fsHost; 22 | //the port # of the flight end-point 23 | private final int _fsPort; 24 | //the flight end-point is whether tls enabled 25 | private final Boolean _tlsEnabled; 26 | private final Boolean _crtVerify; 27 | 28 | //the trust-store file & pass-code 29 | private String _trustStoreJks; 30 | private String _trustStorePass; 31 | 32 | //the user and password/access-token with which to connect to flight service 33 | private final String _user; 34 | private final String _password; 35 | private final String _bearerToken; 36 | 37 | //information to manage work-loads 38 | private String _defaultSchema = ""; 39 | private String _routingTag = ""; 40 | private String _routingQueue = ""; 41 | 42 | //the binary content of the certificate 43 | private byte[] _certBytes; 44 | 45 | /** 46 | * Construct a Configuration object 47 | * @param host - the host name of the remote flight service 48 | * @param port - the port number of the remote flight service 49 | * @param user - the user account for connecting to remote flight service 50 | * @param password - the password of the user account 51 | * @param bearerToken - the pat or auth2 token 52 | */ 53 | public Configuration(String host, int port, String user, String password, String bearerToken) { 54 | this(host, port, false, false, user, password, bearerToken); 55 | } 56 | 57 | /** 58 | * Construct a Configuration object 59 | * @param host - the host name of the remote flight service 60 | * @param port - the port number of the remote flight service 61 | * @param tlsEnabled - whether the flight service has tls enabled for secure connection 62 | * @param crtVerify -whether to verify the certificate if remote flight service is tls-enabled. 
63 | * @param user - the user account for connecting to remote flight service 64 | * @param password - the password of the user account 65 | * @param bearerToken - the pat or auth2 token 66 | */ 67 | public Configuration(String host, int port, Boolean tlsEnabled, Boolean crtVerify, String user, String password, String bearerToken) { 68 | this._fsHost = host; 69 | this._fsPort = port; 70 | this._tlsEnabled = tlsEnabled; 71 | this._crtVerify = crtVerify; 72 | 73 | this._trustStoreJks = null; 74 | this._trustStorePass = null; 75 | this._certBytes = null; 76 | 77 | this._user = user; 78 | this._password = password; 79 | this._bearerToken = bearerToken; 80 | } 81 | 82 | /** 83 | * Construct a Configuration object 84 | * @param host - the host name of the remote flight service 85 | * @param port - the port number of the remote flight service 86 | * @param trustStoreJks - the filename of the trust store in jks 87 | * @param truststorePass - the pass code of the trust store 88 | * @param user - the user account for connecting to remote flight service 89 | * @param password - the password of the user account 90 | * @param bearerToken - the pat or auth2 token 91 | */ 92 | public Configuration(String host, int port, String trustStoreJks, String truststorePass, String user, String password, String bearerToken) { 93 | this(host, port, true, true, user, password, bearerToken); 94 | 95 | this._trustStoreJks = trustStoreJks; 96 | this._trustStorePass = truststorePass; 97 | 98 | this._certBytes = Configuration.getCertificateBytes(this._trustStoreJks, this._trustStorePass); 99 | } 100 | 101 | /** 102 | * Get the host name of the remote flight service 103 | * @return - the host name 104 | */ 105 | public String getFlightHost() { 106 | return this._fsHost; 107 | } 108 | 109 | /** 110 | * Get the port number of the remote flight service 111 | * @return - the port number 112 | */ 113 | public int getFlightPort() { 114 | return this._fsPort; 115 | } 116 | 117 | /** 118 | * Get the filename of the truststore 119 | * @return - the filename 120 | */ 121 | public String getTruststoreJks() { 122 | return this._trustStoreJks; 123 | } 124 | 125 | /** 126 | * Get the password of the truststore. 
* @return - the password 128 | */ 129 | public String getTruststorePass() { 130 | return this._trustStorePass; 131 | } 132 | 133 | /** 134 | * Get the flag of whether the remote flight service has tls enabled for secure connections 135 | * @return - true if the remote flight service supports secure connections 136 | */ 137 | public Boolean getTlsEnabled() { 138 | return this._tlsEnabled; 139 | } 140 | 141 | /** 142 | * Get the flag of whether to verify the certificate of the remote flight service 143 | * @return - true to verify the server certificate 144 | */ 145 | public Boolean verifyServer() { 146 | return this._crtVerify; 147 | } 148 | 149 | /** 150 | * Get the byte content of the service certificate 151 | * @return - the byte content 152 | */ 153 | public byte[] getCertificateBytes() { 154 | return this._certBytes; 155 | } 156 | 157 | /** 158 | * Get the user account 159 | * @return - the user account 160 | */ 161 | public String getUser() { 162 | return this._user; 163 | } 164 | 165 | /** 166 | * Get the password of the user account 167 | * @return - the password of the user account 168 | */ 169 | public String getPassword() { 170 | return this._password; 171 | } 172 | /** 173 | * Get the access-token of the user account 174 | * @return - the access-token of the user account 175 | */ 176 | public String getBearerToken() { 177 | return this._bearerToken; 178 | } 179 | 180 | /** 181 | * Retrieve the connection string for connecting to the remote flight service 182 | * @return - the connection string 183 | */ 184 | public String getConnectionString() { 185 | String secret = (this._password != null && this._password.length() > 0) ? this._password : this._bearerToken; 186 | return String.format("%s://%s:%s@%s:%d", this._tlsEnabled ? "https" : "http", this._user, secret, this._fsHost, this._fsPort); 187 | } 188 | 189 | private static byte[] getCertificateBytes(String keyStorePath, String keyStorePassword) { 190 | try { 191 | KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); 192 | try (InputStream keyStoreStream = Files.newInputStream(Paths.get(keyStorePath))) { 193 | keyStore.load(keyStoreStream, keyStorePassword.toCharArray()); 194 | } 195 | 196 | Enumeration aliases = keyStore.aliases(); 197 | while (aliases.hasMoreElements()) { 198 | String alias = aliases.nextElement(); 199 | if (keyStore.isCertificateEntry(alias)) { 200 | Certificate certificates = keyStore.getCertificate(alias); 201 | return toBytes(certificates); 202 | } 203 | } 204 | } catch (Exception e) { 205 | LoggerFactory.getLogger(Configuration.class).warn("Cannot load the cert - " + keyStorePath); 206 | } 207 | return new byte[0]; 208 | } 209 | 210 | private static byte[] toBytes(Certificate certificate) throws IOException { 211 | try ( 212 | StringWriter writer = new StringWriter(); 213 | JcaPEMWriter pemWriter = new JcaPEMWriter(writer) 214 | ) { 215 | pemWriter.writeObject(certificate); 216 | pemWriter.flush(); 217 | return writer.toString().getBytes(StandardCharsets.UTF_8); 218 | } 219 | } 220 | 221 | /** 222 | * Get the path of the default schema 223 | * @return - Default schema path to the dataset that the user wants to query. 224 | */ 225 | public String getDefaultSchema() { 226 | return _defaultSchema; 227 | } 228 | 229 | /** 230 | * Set the path of default schema 231 | * @param defaultSchema - Default schema path to the dataset that the user wants to query.
232 | */ 233 | public void setDefaultSchema(String defaultSchema) { 234 | this._defaultSchema = defaultSchema; 235 | } 236 | 237 | /** 238 | * Get the routing-tag 239 | * @return - Tag name associated with all queries executed within a Flight session. Used only during authentication. 240 | */ 241 | public String getRoutingTag() { 242 | return this._routingTag; 243 | } 244 | 245 | /** 246 | * Set the rouging-tag 247 | * @param routingTag - Tag name associated with all queries executed within a Flight session. Used only during authentication. 248 | */ 249 | public void setRoutingTag(String routingTag) { 250 | this._routingTag = routingTag; 251 | } 252 | 253 | /** 254 | * Get the routing-queue 255 | * @return - Name of the workload management queue. Used only during authentication. 256 | */ 257 | public String getRoutingQueue() { 258 | return this._routingQueue; 259 | } 260 | 261 | /** 262 | * Set the routing-queue 263 | * @param routingQueue - Name of the workload management queue. Used only during authentication. 264 | */ 265 | public void setRoutingQueue(String routingQueue) { 266 | this._routingQueue = routingQueue; 267 | } 268 | } 269 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/Client.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import org.apache.arrow.flight.*; 4 | import org.apache.arrow.flight.auth2.BasicAuthCredentialWriter; 5 | import org.apache.arrow.flight.auth2.BearerCredentialWriter; 6 | import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler; 7 | import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; 8 | import org.apache.arrow.flight.grpc.CredentialCallOption; 9 | import org.apache.arrow.flight.sql.FlightSqlClient; 10 | import org.apache.arrow.memory.BufferAllocator; 11 | import org.apache.arrow.memory.RootAllocator; 12 | import org.apache.arrow.util.AutoCloseables; 13 | import org.apache.arrow.vector.*; 14 | import org.apache.arrow.vector.types.pojo.Schema; 15 | import org.slf4j.LoggerFactory; 16 | import java.io.ByteArrayInputStream; 17 | import java.net.URI; 18 | import java.nio.charset.StandardCharsets; 19 | import java.util.Arrays; 20 | 21 | /** 22 | * Describes the data-structure of Client for communicating with remote flight service 23 | */ 24 | public final class Client implements AutoCloseable { 25 | //the factory 26 | private static final ClientIncomingAuthHeaderMiddleware.Factory _factory = new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); 27 | //the existing objects of client 28 | private static final java.util.Map _clients = new java.util.HashMap<>(); 29 | 30 | //the flight client 31 | private final FlightClient _client; 32 | //the flight-sql client 33 | private final FlightSqlClient _sqlClient; 34 | //the token for calls 35 | private final CredentialCallOption _bearerToken; 36 | 37 | //the buffer 38 | private final BufferAllocator _allocator; 39 | //the connection string to identify the client 40 | private final String _connectionString; 41 | 42 | /** 43 | * Construct a Client object 44 | * @param client - the client object of the flight service 45 | * @param bearerToken - the credential token 46 | * @param connectionString - the connection string to identify the client 47 | */ 48 | private Client(FlightClient client, CredentialCallOption bearerToken, String connectionString, BufferAllocator allocator) { 49 | this._client = client; 50 | this._sqlClient = new 
FlightSqlClient(this._client); 51 | this._bearerToken = bearerToken; 52 | 53 | this._connectionString = connectionString; 54 | this._allocator = allocator; 55 | } 56 | 57 | /** 58 | * Fetch meta-data of all related end-points by a query 59 | * @param query - the query submitted to remote flight-service 60 | * @return - the schema and end-points for the query 61 | */ 62 | public QueryEndpoints getQueryEndpoints(String query) { 63 | FlightInfo fi = this._client.getInfo(FlightDescriptor.command(query.getBytes(StandardCharsets.UTF_8)), this._bearerToken); 64 | Endpoint[] endpoints = fi.getEndpoints().stream().map(ep -> new Endpoint(ep.getLocations().stream().map(Location::getUri).toArray(URI[]::new), ep.getTicket().getBytes())).toArray(Endpoint[]::new); 65 | return new QueryEndpoints(fi.getSchema(), endpoints); 66 | } 67 | 68 | /** 69 | * Fetch rows from the end-point 70 | * @param ep - the end-point 71 | * @param schema - the schema of rows from the end-point 72 | * @return - row set from the end-point 73 | */ 74 | public RowSet fetch(Endpoint ep, Schema schema) { 75 | RowSet rs = new RowSet(schema); 76 | Field[] fields = Field.from(schema); 77 | FlightEndpoint fep = new FlightEndpoint(new Ticket(ep.getTicket()), Arrays.stream(ep.getURIs()).map(Location::new).toArray(Location[]::new)); 78 | try { 79 | try(FlightStream stream = this._client.getStream(fep.getTicket(), this._bearerToken)) { 80 | VectorSchemaRoot root = stream.getRoot(); 81 | while (stream.next()) { 82 | FieldVector[] fs = root.getFieldVectors().stream().map(fv -> FieldVector.fromArrow(fv, Field.find(fields, fv.getName()), root.getRowCount())).toArray(FieldVector[]::new); 83 | for (int i = 0; i < root.getRowCount(); i++) { 84 | RowSet.Row row = new RowSet.Row(); 85 | for (FieldVector f : fs) { 86 | row.add((f.getValues())[i]); 87 | } 88 | rs.add(row); 89 | } 90 | } 91 | } 92 | } catch (Exception e) { 93 | throw new RuntimeException(e); 94 | } 95 | return rs; 96 | } 97 | 98 | /** 99 | * Fetch rows from all end-points 100 | * @param qEndpoints - the query end-points 101 | * @return - all rows from the end-points 102 | */ 103 | public RowSet fetch(QueryEndpoints qEndpoints) { 104 | RowSet rs = new RowSet(qEndpoints.getSchema()); 105 | Arrays.stream(qEndpoints.getEndpoints()).forEach(ep -> rs.add(fetch(ep, qEndpoints.getSchema()))); 106 | return rs; 107 | } 108 | 109 | /** 110 | * Execute a literal SQL statement 111 | * @param stmt - the literal sql-statement 112 | */ 113 | public long execute(String stmt) { 114 | FlightInfo fi = this._sqlClient.execute(stmt, this._bearerToken); 115 | long count = 0; 116 | for (FlightEndpoint endpoint: fi.getEndpoints()) { 117 | try { 118 | try(FlightStream stream = this._client.getStream(endpoint.getTicket(), this._bearerToken)) { 119 | while (stream.next()) { 120 | VectorSchemaRoot root = stream.getRoot(); 121 | count += root.getRowCount(); 122 | } 123 | } 124 | } catch (Exception e) { 125 | throw new RuntimeException(e); 126 | } 127 | } 128 | return count; 129 | } 130 | 131 | /** 132 | * Truncate the target table 133 | * @param table - the name of the table 134 | */ 135 | public void truncate(String table) { 136 | this.execute(String.format("truncate table %s", table)); 137 | } 138 | 139 | /** 140 | * Get a prepared-statement for a query 141 | * @param query - the query being used for update 142 | * @return - the prepared sql-statement for down-stream operations 143 | */ 144 | public FlightSqlClient.PreparedStatement getPreparedStatement(String query) { 145 | return this._sqlClient.prepare(query, 
this._bearerToken); 146 | } 147 | 148 | /** 149 | * Execute a prepared-statement 150 | * @param preparedStmt - the prepared-statement being executed. 151 | * @return - the number of rows affected. 152 | */ 153 | public long executeUpdate(FlightSqlClient.PreparedStatement preparedStmt) { 154 | return preparedStmt.executeUpdate(this._bearerToken); 155 | } 156 | 157 | /** 158 | * Close the connection 159 | */ 160 | @Override 161 | public void close() { 162 | try { 163 | synchronized (Client._clients) { 164 | Client._clients.remove(this._connectionString); 165 | } 166 | this._client.close(); 167 | 168 | this._allocator.getChildAllocators().forEach(BufferAllocator::close); 169 | AutoCloseables.close(this._allocator); 170 | } catch (Exception ex) { 171 | LoggerFactory.getLogger(this.getClass()).warn(ex.getMessage() + Arrays.toString(ex.getStackTrace())); 172 | } 173 | } 174 | 175 | /** 176 | * Get a client object 177 | * @param config - the connection configuration for establishing connections to remote flight service 178 | * @return - the client object 179 | */ 180 | public static synchronized Client getOrCreate(Configuration config) { 181 | String cs = config.getConnectionString(); 182 | if (!Client._clients.containsKey(cs)) { 183 | final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); 184 | final FlightClient client = Client.create(config, allocator); 185 | 186 | final CallHeaders callHeaders = new FlightCallHeaders(); 187 | if (config.getDefaultSchema() != null && config.getDefaultSchema().length() > 0) { 188 | callHeaders.insert("SCHEMA", config.getDefaultSchema()); 189 | } 190 | if (config.getRoutingTag() != null && config.getRoutingTag().length() > 0) { 191 | callHeaders.insert("ROUTING_TAG", config.getRoutingTag()); 192 | } 193 | if (config.getRoutingQueue() != null && config.getRoutingQueue().length() > 0) { 194 | callHeaders.insert("ROUTING_QUEUE", config.getRoutingQueue()); 195 | } 196 | final HeaderCallOption clientProperties = (callHeaders.keys().size() > 0) ? new HeaderCallOption(callHeaders) : null; 197 | 198 | Client._clients.put(cs, new Client(client, authenticate(client, config.getUser(), config.getPassword(), config.getBearerToken(), clientProperties), cs, allocator)); 199 | } 200 | return Client._clients.get(cs); 201 | } 202 | 203 | //Create a client object with the service configuration 204 | private static FlightClient create(Configuration config, BufferAllocator allocator) { 205 | FlightClient.Builder builder = FlightClient.builder().allocator(allocator); 206 | if (config.getTlsEnabled()) { 207 | if (config.getTruststoreJks() == null || config.getTruststoreJks().isEmpty()) { 208 | builder.location(Location.forGrpcTls(config.getFlightHost(), config.getFlightPort())).useTls().verifyServer(config.verifyServer()); 209 | } else { 210 | builder.location(Location.forGrpcTls(config.getFlightHost(), config.getFlightPort())).useTls().trustedCertificates(new ByteArrayInputStream(config.getCertificateBytes())); 211 | } 212 | } else { 213 | builder.location(Location.forGrpcInsecure(config.getFlightHost(), config.getFlightPort())); 214 | } 215 | return (config.getPassword() != null && config.getPassword().length() > 0) ? 
builder.intercept(Client._factory).build() : builder.build(); 216 | } 217 | //Authenticate with user & password to obtain the credential token-ticket 218 | private static CredentialCallOption authenticate(FlightClient client, String user, String password, String bearerToken, HeaderCallOption clientProperties) { 219 | final java.util.List callOptions = new java.util.ArrayList<>(); 220 | callOptions.add(clientProperties); 221 | 222 | boolean usePassword = (password != null && password.length() > 0); 223 | callOptions.add(new CredentialCallOption(usePassword ? new BasicAuthCredentialWriter(user, password) : new BearerCredentialWriter(bearerToken))); 224 | client.handshake(callOptions.toArray(new CallOption[0])); 225 | return usePassword ? Client._factory.getCredentialCallOption() : (CredentialCallOption)callOptions.get(0); 226 | } 227 | } 228 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | 8 | 17 9 | 2.12 10 | 2.12.18 11 | 3.5.7 12 | 12.0.1 13 | 32.0.1-jre 14 | 1.41.0 15 | 4.1.68.Final 16 | 2.13.4 17 | 18 | 17 19 | 17 20 | 21 | UTF-8 22 | UTF-8 23 | 24 | 25 | com.qwshen 26 | spark-flight-connector_${spark.version} 27 | 1.0.5 28 | jar 29 | 30 | 31 | 32 | 33 | org.apache.maven.plugins 34 | maven-compiler-plugin 35 | 3.14.1 36 | 37 | ${java.version} 38 | ${java.version} 39 | 40 | 41 | 42 | net.alchim31.maven 43 | scala-maven-plugin 44 | 4.8.1 45 | 46 | 47 | 48 | compile 49 | testCompile 50 | 51 | 52 | 53 | 54 | 55 | org.apache.maven.plugins 56 | maven-shade-plugin 57 | 3.5.1 58 | 59 | 60 | package 61 | 62 | shade 63 | 64 | 65 | 66 | 67 | *:* 68 | 69 | module-info.class 70 | META-INF/*.SF 71 | META-INF/*.DSA 72 | META-INF/*.RSA 73 | 74 | 75 | 76 | 77 | 78 | com.google.common.base 79 | com.shaded.google.common.base 80 | 81 | 82 | com.google.common.util.concurrent 83 | com.shaded.google.common.util.concurrent 84 | 85 | 86 | com.google.protobuf 87 | com.shaded.google.protobuf 88 | 89 | 90 | org.apache.arrow.memory 91 | org.shaded.apache.arrow.memory 92 | 93 | 94 | org.apache.arrow.vector 95 | org.shaded.apache.arrow.vector 96 | 97 | 98 | org.apache.arrow.util 99 | org.shaded.apache.arrow.util 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | org.apache.maven.plugins 108 | maven-surefire-plugin 109 | 2.12.4 110 | 111 | true 112 | 113 | 114 | 115 | org.scalatest 116 | scalatest-maven-plugin 117 | 1.0 118 | 119 | ${project.build.directory}/surefire-reports 120 | . 121 | WDF TestSuite.txt 122 | -Xmx1g 123 | 124 | 125 | 126 | test 127 | 128 | test 129 | 130 | 131 | 132 | 133 | 134 | org.scala-tools 135 | maven-scala-plugin 136 | 2.15.2 137 | 138 | 139 | 140 | compile 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | io.grpc 152 | grpc-api 153 | ${grpc.version} 154 | 155 | 156 | io.grpc 157 | grpc-core 158 | ${grpc.version} 159 | 160 | 161 | io.grpc 162 | grpc-netty 163 | ${grpc.version} 164 | 165 | 166 | org.apache.arrow 167 | arrow-memory-core 168 | ${arrow.version} 169 | 170 | 171 | org.apache.arrow 172 | arrow-memory-netty 173 | ${arrow.version} 174 | 175 | 176 | 177 | 178 | 179 | org.scala-lang 180 | scala-library 181 | ${scala.version} 182 | 183 | 184 | org.scala-lang.modules 185 | scala-collection-compat_${scala.compact.version} 186 | 2.7.0 187 | 188 | 189 | org.apache.spark 190 | spark-core_${scala.compact.version} 191 | ${spark.version} 192 | provided 193 | 194 | 195 | org.apache.spark 196 | spark-sql_${scala.compact.version} 197 | ${spark.version} 198 | provided 199 | 200 | 201 | org.apache.arrow 202 | flight-core 203 | ${arrow.version} 204 | 205 | 206 | io.netty 207 | netty-transport-native-unix-common 208 | 209 | 210 | io.netty 211 | netty-transport-native-kqueue 212 | 213 | 214 | io.netty 215 | netty-transport-native-epoll 216 | 217 | 218 | com.fasterxml.jackson.core 219 | * 220 | 221 | 222 | 223 | 224 | org.apache.arrow 225 | flight-sql 226 | ${arrow.version} 227 | 228 | 229 | org.apache.arrow 230 | flight-grpc 231 | ${arrow.version} 232 | 233 | 234 | com.typesafe 235 | config 236 | 1.4.2 237 | 238 | 239 | com.google.guava 240 | guava 241 | ${guava.version} 242 | 243 | 244 | com.google.protobuf 245 | protobuf-java 246 | 3.25.5 247 | 248 | 249 | org.bouncycastle 250 | bcpkix-jdk18on 251 | 1.79 252 | 253 | 254 | com.fasterxml.jackson.module 255 | 
jackson-module-scala_${scala.compact.version} 256 | ${com.fasterxml.jackson} 257 | 258 | 259 | com.fasterxml.jackson.dataformat 260 | jackson-dataformat-yaml 261 | ${com.fasterxml.jackson} 262 | 263 | 264 | joda-time 265 | joda-time 266 | 2.10.10 267 | 268 | 269 | org.scalatest 270 | scalatest_${scala.compact.version} 271 | 3.0.4 272 | test 273 | 274 | 275 | org.junit.jupiter 276 | junit-jupiter-engine 277 | 5.7.2 278 | test 279 | 280 | 281 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/Table.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import org.apache.arrow.vector.types.pojo.Schema; 4 | import org.apache.spark.sql.sources.*; 5 | import org.apache.spark.sql.types.Metadata; 6 | import org.apache.spark.sql.types.StructField; 7 | import org.apache.spark.sql.types.StructType; 8 | import org.slf4j.LoggerFactory; 9 | import java.io.Serializable; 10 | import java.util.Arrays; 11 | import java.util.Hashtable; 12 | import java.util.function.BiFunction; 13 | import java.util.function.Function; 14 | 15 | /** 16 | * Describes a flight table 17 | */ 18 | public final class Table implements Serializable { 19 | //the name of a flight table whose data will be queried/updated 20 | private final String _name; 21 | //the character for quoting columns in sql statements 22 | private final String _columnQuote; 23 | 24 | //the read-statement 25 | private QueryStatement _stmt; 26 | 27 | //the spark schema 28 | private StructType _sparkSchema = null; 29 | //the flight schema 30 | private Schema _schema = null; 31 | //the end-points exposed by the remote flight-service for fetching data of this table 32 | private Endpoint[] _endpoints = new Endpoint[0]; 33 | 34 | //the container for holding the partitioning queries 35 | private final java.util.List _partitionStmts = new java.util.ArrayList<>(); 36 | 37 | /** 38 | * Construct a Table object 39 | * @param name - the name of the table 40 | * @param columnQuote - the character for quoting columns in sql statements 41 | */ 42 | private Table(String name, String columnQuote) { 43 | this._name = name; 44 | this._columnQuote = columnQuote; 45 | 46 | this.prepareQueryStatement(null, null, null, null); 47 | } 48 | 49 | /** 50 | * Get the name of the table 51 | * @return - the name of the table 52 | */ 53 | public String getName() { 54 | return this._name; 55 | } 56 | 57 | /** 58 | * Get the sql-statement for querying the table 59 | * @return - the physical query which will be submitted to remote flight service 60 | */ 61 | public String getQueryStatement() { 62 | if (this._stmt == null) { 63 | throw new RuntimeException("The read statement is not valid."); 64 | } 65 | return this._stmt.getStatement(); 66 | } 67 | 68 | /** 69 | * Get the partition queries 70 | * @return - the partition queries with each of which is submitted from a spark executor 71 | */ 72 | public String[] getPartitionStatements() { 73 | return this._partitionStmts.toArray(new String[0]); 74 | } 75 | 76 | /** 77 | * Get the spark schema 78 | * @return - the spark schema 79 | */ 80 | public StructType getSparkSchema() { 81 | return this._sparkSchema; 82 | } 83 | 84 | /** 85 | * Get the end-points 86 | * @return - end-points exposed by the remote flight service upon submitted query 87 | */ 88 | public Endpoint[] getEndpoints() { 89 | return this._endpoints; 90 | } 91 | 92 | /** 93 | * Get the flight schema 94 | * @return - the flight schema 95 | */ 96 | public 
Schema getSchema() { 97 | return this._schema; 98 | } 99 | 100 | /** 101 | * Get the character for quoting columns 102 | * @return - the character for quoting columns 103 | */ 104 | public String getColumnQuote() { 105 | return this._columnQuote; 106 | } 107 | 108 | /** 109 | * Initialize the schema and end-points by submitting the physical query 110 | * @param config - the connection configuration 111 | */ 112 | public void initialize(Configuration config) { 113 | try { 114 | Client client = Client.getOrCreate(config); 115 | QueryEndpoints eps = client.getQueryEndpoints(this.getQueryStatement()); 116 | 117 | this._sparkSchema = new StructType(Arrays.stream(Field.from(eps.getSchema())).map(fs -> new StructField(fs.getName(), FieldType.toSpark(fs.getType()), true, Metadata.empty())).toArray(StructField[]::new)); 118 | this._schema = eps.getSchema(); 119 | this._endpoints = eps.getEndpoints(); 120 | } catch (Exception e) { 121 | LoggerFactory.getLogger(this.getClass()).error(e.getMessage() + " --> " + Arrays.toString(e.getStackTrace())); 122 | throw new RuntimeException(e); 123 | } 124 | } 125 | 126 | //Prepare the query for submitting to remote flight service 127 | private boolean prepareQueryStatement(PushAggregation aggregation, StructField[] fields, String filter, PartitionBehavior partitionBehavior) { 128 | //aggregation mode: 0 -> no aggregation; 1 -> aggregation without group-by; 2 -> aggregation with group-by 129 | int aggMode = 0; 130 | String select = "", groupBy = ""; 131 | if (aggregation != null) { 132 | String[] groupByFields = aggregation.getGroupByColumns(); 133 | if (groupByFields != null && groupByFields.length > 0) { 134 | aggMode = 2; 135 | groupBy = String.join(",", groupByFields); 136 | } else { 137 | aggMode = 1; 138 | } 139 | select = String.format("select %s from %s", String.join(",", aggregation.getColumnExpressions()), this._name); 140 | } else if (fields != null && fields.length > 0) { 141 | select = String.format("select %s from %s", String.join(",", Arrays.stream(fields).map(column -> String.format("%s%s%s", this._columnQuote, column.name(), this._columnQuote)).toArray(String[]::new)), this._name); 142 | } else { 143 | select = String.format("select * from %s", this._name); 144 | } 145 | QueryStatement stmt = new QueryStatement(select, filter, groupBy); 146 | boolean changed = stmt.different(this._stmt); 147 | if (changed) { 148 | this._stmt = stmt; 149 | } 150 | 151 | if (aggMode == 1) { 152 | this._partitionStmts.clear(); 153 | } else if (partitionBehavior != null && partitionBehavior.enabled()) { 154 | String where = (filter != null && !filter.isEmpty()) ? String.format("(%s) and ", filter) : ""; 155 | BiFunction merge = (s1, s2) -> { 156 | Hashtable s = new Hashtable(); 157 | for (StructField sf : s1) { 158 | s.put(sf.name(), sf); 159 | } 160 | for (StructField sf : s2) { 161 | s.put(sf.name(), sf); 162 | } 163 | return s.values().toArray(new StructField[0]); 164 | }; 165 | String[] predicates = partitionBehavior.predicateDefined() ? partitionBehavior.getPredicates() 166 | : partitionBehavior.calculatePredicates(this._sparkSchema == null ? 
fields : merge.apply(fields, this._sparkSchema.fields())); 167 | for (String predicate : predicates) { 168 | QueryStatement s = new QueryStatement(select, String.format("%s(%s)", where, predicate), groupBy); 169 | this._partitionStmts.add(s.getStatement()); 170 | } 171 | } 172 | return changed; 173 | } 174 | 175 | //translate a filter to where clause 176 | public String toWhereClause(Filter filter) { 177 | StringBuilder sb = new StringBuilder(); 178 | if (filter instanceof EqualTo) { 179 | EqualTo et = (EqualTo)filter; 180 | sb.append((et.value() instanceof Number) 181 | ? String.format("%s%s%s = %s", this._columnQuote, et.attribute(), this._columnQuote, et.value().toString()) 182 | : String.format("%s%s%s = '%s'", this._columnQuote, et.attribute(), this._columnQuote, et.value().toString()) 183 | ); 184 | } else if (filter instanceof EqualNullSafe) { 185 | EqualNullSafe ens = (EqualNullSafe)filter; 186 | sb.append(String.format("((%s%s%s is null and %s is null) or (%s%s%s is not null and %s is not null))", this._columnQuote, ens.attribute(), this._columnQuote, ens.value(), this._columnQuote, ens.attribute(), this._columnQuote, ens.value())); 187 | } else if (filter instanceof LessThan) { 188 | LessThan lt = (LessThan)filter; 189 | sb.append((lt.value() instanceof Number) ? String.format("%s%s%s < %s", this._columnQuote, lt.attribute(), this._columnQuote, lt.value()) : String.format("%s%s%s < '%s'", this._columnQuote, lt.attribute(), this._columnQuote, lt.value())); 190 | } else if (filter instanceof LessThanOrEqual) { 191 | LessThanOrEqual lt = (LessThanOrEqual)filter; 192 | sb.append((lt.value() instanceof Number) ? String.format("%s%s%s <= %s", this._columnQuote, lt.attribute(), this._columnQuote, lt.value()) : String.format("%s%s%s <= '%s'", this._columnQuote, lt.attribute(), this._columnQuote, lt.value())); 193 | } else if (filter instanceof GreaterThan) { 194 | GreaterThan gt = (GreaterThan)filter; 195 | sb.append((gt.value() instanceof Number) ? String.format("%s%s%s > %s", this._columnQuote, gt.attribute(), this._columnQuote, gt.value()) : String.format("%s%s%s > '%s'", this._columnQuote, gt.attribute(), this._columnQuote, gt.value())); 196 | } else if (filter instanceof GreaterThanOrEqual) { 197 | GreaterThanOrEqual gt = (GreaterThanOrEqual)filter; 198 | sb.append((gt.value() instanceof Number) ? 
String.format("%s%s%s >= %s", this._columnQuote, gt.attribute(), this._columnQuote, gt.value()) : String.format("%s%s%s >= '%s'", this._columnQuote, gt.attribute(), this._columnQuote, gt.value())); 199 | } else if (filter instanceof And) { 200 | And and = (And)filter; 201 | sb.append(String.format("(%s and %s)", toWhereClause(and.left()), toWhereClause(and.right()))); 202 | } else if (filter instanceof Or) { 203 | Or or = (Or)filter; 204 | sb.append(String.format("(%s or %s)", toWhereClause(or.left()), toWhereClause(or.right()))); 205 | } else if (filter instanceof IsNull) { 206 | IsNull in = (IsNull)filter; 207 | sb.append(String.format("%s%s%s is null", this._columnQuote, in.attribute(), this._columnQuote)); 208 | } else if (filter instanceof IsNotNull) { 209 | IsNotNull in = (IsNotNull)filter; 210 | sb.append(String.format("%s%s%s is not null", this._columnQuote, in.attribute(), this._columnQuote)); 211 | } else if (filter instanceof StringStartsWith) { 212 | StringStartsWith ss = (StringStartsWith)filter; 213 | sb.append(String.format("%s%s%s like '%s%s'", this._columnQuote, ss.attribute(), this._columnQuote, ss.value(), "%")); 214 | } else if (filter instanceof StringContains) { 215 | StringContains sc = (StringContains)filter; 216 | sb.append(String.format("%s%s%s like '%s%s%s'", this._columnQuote, sc.attribute(), this._columnQuote, "%", sc.value(), "%")); 217 | } else if (filter instanceof StringEndsWith) { 218 | StringEndsWith se = (StringEndsWith)filter; 219 | sb.append(String.format("%s%s%s like '%s%s'", this._columnQuote, se.attribute(), this._columnQuote, "%", se.value())); 220 | } else if (filter instanceof Not) { 221 | Not not = (Not)filter; 222 | sb.append(String.format("not (%s)", toWhereClause(not.child()))); 223 | } else if (filter instanceof In) { 224 | In in = (In)filter; 225 | sb.append(String.format("%s%s%s in (%s)", this._columnQuote, in.attribute(), this._columnQuote, String.join(",", Arrays.stream(in.values()).map(v -> (v instanceof Number) ? v.toString() : String.format("'%s'", v.toString())).toArray(String[]::new)))); 226 | } 227 | return sb.toString(); 228 | } 229 | 230 | /** 231 | * Probe if the pushed filter, fields and aggregation would affect the existing schema & end-points 232 | * @param pushedFilter - the pushed filter 233 | * @param pushedFields - the pushed fields 234 | * @param pushedAggregation - the pushed aggregation 235 | * @param partitionBehavior - the partitioning behavior 236 | * @return - true if initialization is required 237 | */ 238 | public Boolean probe(String pushedFilter, StructField[] pushedFields, PushAggregation pushedAggregation, PartitionBehavior partitionBehavior) { 239 | if ((pushedFilter == null || pushedFilter.isEmpty()) && (pushedFields == null || pushedFields.length == 0) && pushedAggregation == null) { 240 | return false; 241 | } 242 | return this.prepareQueryStatement(pushedAggregation, pushedFields, pushedFilter, partitionBehavior); 243 | } 244 | 245 | /** 246 | * Table with name 247 | * @param tableName - the name of a table 248 | * @param columnQuote - the character for quoting columns in sql statements 249 | * @return - a Table object 250 | */ 251 | public static Table forTable(String tableName, String columnQuote) { 252 | Function isQuery = (t) -> t.replaceAll("[\r|\n]", " ").trim().toLowerCase().matches("^select .+ [from]?.+"); 253 | return new Table(isQuery.apply(tableName) ? 
String.format("(%s) t", tableName) : tableName, columnQuote); 254 | } 255 | } 256 | -------------------------------------------------------------------------------- /src/main/java/com/qwshen/flight/FieldType.java: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight; 2 | 3 | import org.apache.arrow.vector.types.pojo.ArrowType; 4 | import org.apache.spark.sql.types.DataTypes; 5 | import org.apache.spark.sql.types.Metadata; 6 | import java.io.Serializable; 7 | import java.util.Arrays; 8 | import java.util.Map; 9 | import java.util.Set; 10 | import java.util.function.Function; 11 | import org.apache.spark.sql.types.YearMonthIntervalType; 12 | import org.apache.spark.sql.types.DayTimeIntervalType; 13 | 14 | /** 15 | * Describes the data-type of a Field 16 | */ 17 | public class FieldType implements Serializable { 18 | /** 19 | * The ID of each type 20 | */ 21 | public enum IDs { 22 | NULL, //Null object 23 | BOOLEAN, //Boolean 24 | BYTE, //Byte 25 | CHAR, //Character 26 | SHORT, //Short 27 | INT, //Integer 28 | LONG, //Long 29 | BIGINT, //Big Integer 30 | FLOAT, //Float 31 | DOUBLE, //Double 32 | DECIMAL, //BidDecimal 33 | VARCHAR, //String 34 | BYTES, //byte[] 35 | DATE, //LocalDate 36 | TIME, //LocalTime 37 | TIMESTAMP, //TimeStamp 38 | PERIOD_YEAR_MONTH, //Period - year with month 39 | DURATION_DAY_TIME, //Period - day with time 40 | PERIOD_DURATION_MONTH_DAY_TIME, //Period - month with day 41 | LIST, //List 42 | MAP, //Map 43 | STRUCT //Struct 44 | } 45 | 46 | /** 47 | * Decimal Type 48 | */ 49 | public static class DecimalType extends FieldType { 50 | private final int _precision; 51 | private final int _scale; 52 | 53 | public DecimalType(int precision, int scale) { 54 | super(IDs.DECIMAL); 55 | this._precision = precision; 56 | this._scale = scale; 57 | } 58 | 59 | public int getPrecision() { 60 | return this._precision; 61 | } 62 | public int getScale() { 63 | return this._scale; 64 | } 65 | } 66 | 67 | /** 68 | * Binary Type 69 | */ 70 | public static class BinaryType extends FieldType { 71 | private final int _byteWidth; 72 | 73 | public BinaryType(int byteWidth) { 74 | super(IDs.BYTES); 75 | this._byteWidth = byteWidth; 76 | } 77 | 78 | public int getByteWidth() { 79 | return this._byteWidth; 80 | } 81 | } 82 | 83 | /** 84 | * List Type 85 | */ 86 | public static class ListType extends FieldType { 87 | private final int _length; 88 | private final FieldType _childType; 89 | 90 | public ListType(int length, FieldType childType) { 91 | super(IDs.LIST); 92 | this._length = length; 93 | this._childType = childType; 94 | } 95 | public ListType(FieldType childType) { 96 | //dynamic size of list 97 | this(-1, childType); 98 | } 99 | 100 | public int getLength() { 101 | return this._length; 102 | } 103 | public FieldType getChildType() { 104 | return this._childType; 105 | } 106 | } 107 | 108 | /** 109 | * May Type 110 | */ 111 | public static class MapType extends FieldType { 112 | private final FieldType _keyType; 113 | private final FieldType _valueType; 114 | 115 | public MapType(FieldType keyType, FieldType valueType) { 116 | super(IDs.MAP); 117 | this._keyType = keyType; 118 | this._valueType = valueType; 119 | } 120 | 121 | public FieldType getKeyType() { 122 | return this._keyType; 123 | } 124 | public FieldType getValueType() { 125 | return this._valueType; 126 | } 127 | } 128 | 129 | /** 130 | * Struct Type 131 | */ 132 | public static class StructType extends FieldType { 133 | private final java.util.Map _childrenType; 134 
| 135 | public StructType(java.util.Map childrenType) { 136 | super(IDs.STRUCT); 137 | this._childrenType = childrenType; 138 | } 139 | 140 | public java.util.Map getChildrenType() { 141 | return this._childrenType; 142 | } 143 | } 144 | 145 | /** 146 | * Union Type 147 | */ 148 | public static class UnionType extends StructType { 149 | public UnionType(java.util.Map childrenType) { 150 | super(childrenType); 151 | } 152 | } 153 | 154 | //the type value 155 | private final IDs _typeId; 156 | 157 | /** 158 | * Construct a FieldType object 159 | * @param typeId - the id of the type 160 | */ 161 | public FieldType(IDs typeId) { 162 | this._typeId = typeId; 163 | } 164 | 165 | /** 166 | * Get the Type ID 167 | * @return - the type ID 168 | */ 169 | public IDs getTypeID() { 170 | return this._typeId; 171 | } 172 | 173 | /** 174 | * Convert an arrow-type to field-type 175 | * @param at - the arrow type 176 | * @param children - any children of the arrow-type 177 | * @return - the converted field-type 178 | */ 179 | public static FieldType fromArrow(ArrowType at, java.util.List children) { 180 | switch(at.getTypeID()) { 181 | case Int: 182 | ArrowType.Int it = (ArrowType.Int)at; 183 | switch (it.getBitWidth()) { 184 | case 8: 185 | return new FieldType(it.getIsSigned() ? IDs.BYTE : IDs.SHORT); 186 | case 16: 187 | return new FieldType(it.getIsSigned() ? IDs.SHORT : IDs.INT); 188 | case 64: 189 | return new FieldType(it.getIsSigned() ? IDs.LONG : IDs.BIGINT); 190 | case 32: 191 | default: 192 | return new FieldType(it.getIsSigned() ? IDs.INT : IDs.LONG); 193 | } 194 | case Utf8: 195 | case LargeUtf8: 196 | return new FieldType(IDs.VARCHAR); 197 | case Decimal: 198 | ArrowType.Decimal d = (ArrowType.Decimal)at; 199 | return new FieldType.DecimalType(d.getPrecision(), d.getScale()); 200 | case Date: 201 | return new FieldType(IDs.DATE); 202 | case Time: 203 | return new FieldType(IDs.TIME); 204 | case Timestamp: 205 | return new FieldType(IDs.TIMESTAMP); 206 | case FloatingPoint: 207 | switch(((ArrowType.FloatingPoint)at).getPrecision()) { 208 | case HALF: 209 | case SINGLE: 210 | return new FieldType(IDs.FLOAT); 211 | case DOUBLE: 212 | default: 213 | return new FieldType(IDs.DOUBLE); 214 | } 215 | case Interval: 216 | switch (((ArrowType.Interval)at).getUnit()) { 217 | case YEAR_MONTH: 218 | return new FieldType(IDs.PERIOD_YEAR_MONTH); 219 | case DAY_TIME: 220 | return new FieldType(IDs.DURATION_DAY_TIME); 221 | case MONTH_DAY_NANO: 222 | default: 223 | return new FieldType(IDs.PERIOD_DURATION_MONTH_DAY_TIME); 224 | } 225 | case Duration: 226 | return new FieldType(IDs.PERIOD_DURATION_MONTH_DAY_TIME); 227 | case Bool: 228 | return new FieldType(IDs.BOOLEAN); 229 | case Struct: 230 | case Union: 231 | java.util.Map scType = new java.util.LinkedHashMap<>(); 232 | if (children != null) { 233 | children.forEach(c -> scType.put(c.getName(), FieldType.fromArrow(c.getType(), c.getChildren()))); 234 | } 235 | return (at.getTypeID() == ArrowType.ArrowTypeID.Struct) ? 
new StructType(scType) : new UnionType(scType); 236 | case Map: 237 | FieldType keyType = null, valueType = null; 238 | if (children != null) { 239 | if (children.size() == 1) { 240 | FieldType mcType = FieldType.fromArrow(children.get(0).getType(), children.get(0).getChildren()); 241 | if (mcType.getTypeID() == IDs.STRUCT) { 242 | Map cldType = ((StructType)mcType).getChildrenType(); 243 | String[] keys = cldType.keySet().toArray(new String[0]); 244 | if (keys.length == 2 && keys[0].equalsIgnoreCase("Key") && keys[1].equalsIgnoreCase("value")) { 245 | keyType = cldType.get("key"); 246 | valueType = cldType.get("value"); 247 | } 248 | } 249 | } 250 | else if (children.size() == 2) { 251 | keyType = FieldType.fromArrow(children.get(0).getType(), children.get(0).getChildren()); 252 | valueType = FieldType.fromArrow(children.get(1).getType(), children.get(1).getChildren()); 253 | } 254 | } 255 | if (keyType == null || valueType == null) { 256 | throw new RuntimeException("Invalid map-type."); 257 | } 258 | return new MapType(keyType, valueType); 259 | case List: 260 | case LargeList: 261 | case FixedSizeList: 262 | FieldType lcType = (children != null && children.size() > 0) ? FieldType.fromArrow(children.get(0).getType(), children.get(0).getChildren()) : null; 263 | return (at.getTypeID() == ArrowType.ArrowTypeID.FixedSizeList) ? new ListType(((ArrowType.FixedSizeList)at).getListSize(), lcType) : new ListType(lcType); 264 | case Binary: 265 | case LargeBinary: 266 | return new BinaryType(-1); 267 | case FixedSizeBinary: 268 | return new BinaryType(((ArrowType.FixedSizeBinary)at).getByteWidth()); 269 | case Null: 270 | case NONE: 271 | default: 272 | return new FieldType(IDs.NULL); 273 | } 274 | } 275 | 276 | //Convert to Spark DecimalType 277 | private static final Function toDecimalType = t -> { 278 | FieldType.DecimalType dt = (FieldType.DecimalType)t; 279 | return new org.apache.spark.sql.types.DecimalType(dt.getPrecision(), dt.getScale()); 280 | }; 281 | private static final Function toStructType = t -> { 282 | org.apache.spark.sql.types.MapType mt = null; 283 | java.util.List> entries = new java.util.ArrayList<>(t.getChildrenType().entrySet()); 284 | if (entries.size() == 1 && entries.get(0).getKey().equals("map") && entries.get(0).getValue().getTypeID() == FieldType.IDs.LIST) { 285 | FieldType.ListType lt = (FieldType.ListType)entries.get(0).getValue(); 286 | if (lt.getChildType().getTypeID() == FieldType.IDs.STRUCT) { 287 | FieldType.StructType st = (FieldType.StructType)lt.getChildType(); 288 | if (st.getTypeID() == FieldType.IDs.STRUCT) { 289 | java.util.List> children = new java.util.ArrayList<>(st.getChildrenType().entrySet()); 290 | if (children.size() == 2 && children.get(0).getKey().equals("key") && children.get(1).getKey().equals("value")) { 291 | mt = new org.apache.spark.sql.types.MapType(FieldType.toSpark(children.get(0).getValue()), FieldType.toSpark(children.get(1).getValue()), true); 292 | } 293 | } 294 | } 295 | } 296 | return (mt != null) ? 
mt : new org.apache.spark.sql.types.StructType(t.getChildrenType().entrySet().stream().map(e -> new org.apache.spark.sql.types.StructField(e.getKey(), FieldType.toSpark(e.getValue()), true, Metadata.empty())).toArray(org.apache.spark.sql.types.StructField[]::new)); 297 | }; 298 | //convert to Spark ListType 299 | private static final Function toListType = t -> org.apache.spark.sql.types.ArrayType.apply(FieldType.toSpark(t.getChildType())); 300 | //convert to Spark MapType 301 | private static final Function toMapType = t -> org.apache.spark.sql.types.MapType.apply(FieldType.toSpark(t.getKeyType()), FieldType.toSpark(t.getValueType())); 302 | 303 | public static org.apache.spark.sql.types.DataType toSpark(FieldType ft) { 304 | switch (ft.getTypeID()) { 305 | case INT: 306 | return DataTypes.IntegerType; 307 | case CHAR: 308 | case VARCHAR: 309 | case TIME: 310 | return DataTypes.StringType; 311 | case LONG: 312 | case BIGINT: 313 | return DataTypes.LongType; 314 | case FLOAT: 315 | return DataTypes.FloatType; 316 | case DOUBLE: 317 | return DataTypes.DoubleType; 318 | case DECIMAL: 319 | return toDecimalType.apply(ft); 320 | case DATE: 321 | return DataTypes.DateType; 322 | case TIMESTAMP: 323 | return DataTypes.TimestampType; 324 | case BOOLEAN: 325 | return DataTypes.BooleanType; 326 | case BYTE: 327 | return DataTypes.ByteType; 328 | case SHORT: 329 | return DataTypes.ShortType; 330 | case BYTES: 331 | return DataTypes.BinaryType; 332 | case PERIOD_YEAR_MONTH: 333 | return new YearMonthIntervalType(YearMonthIntervalType.YEAR(), YearMonthIntervalType.MONTH()); 334 | case DURATION_DAY_TIME: 335 | return new DayTimeIntervalType(DayTimeIntervalType.DAY(), DayTimeIntervalType.SECOND()); 336 | case PERIOD_DURATION_MONTH_DAY_TIME: 337 | return new org.apache.spark.sql.types.StructType(Arrays.stream(new org.apache.spark.sql.types.StructField[] { 338 | new org.apache.spark.sql.types.StructField("period", new YearMonthIntervalType(YearMonthIntervalType.YEAR(), YearMonthIntervalType.MONTH()), true, org.apache.spark.sql.types.Metadata.empty()), 339 | new org.apache.spark.sql.types.StructField("duration", new DayTimeIntervalType(DayTimeIntervalType.DAY(), DayTimeIntervalType.SECOND()), true, org.apache.spark.sql.types.Metadata.empty()) 340 | }).toArray(org.apache.spark.sql.types.StructField[]::new)); 341 | case LIST: 342 | return toListType.apply((ListType)ft); 343 | case MAP: 344 | return toMapType.apply((MapType)ft); 345 | case STRUCT: 346 | return toStructType.apply((StructType)ft); 347 | case NULL: 348 | default: 349 | return org.apache.spark.sql.types.DataTypes.NullType; 350 | } 351 | } 352 | } 353 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **spark-flight-connector** is an Apache Spark DataSource API that reads/writes data from/to arrow-flight endpoints, such as Dremio Flight Server. With proper partitioning, it supports fast loading of large datasets by parallelizing reads. It also supports all insert/merge/update/delete DML operations for writes. With arrow-flight, it enables high speed data transfers compared to ODBC/JDBC connections by utilizing the Apache Arrow format to avoid serializing and deserializing data. 2 | 3 | To build the project, run: 4 | ```shell 5 | mvn clean install -DskipTests 6 | ``` 7 | 8 | The connector requires Spark 3.2.0+ due to the support of Interval types. For a quick start, please jump to this [tutorial](docs/tutorial.md). 
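Once built, the shaded jar under `target/` (or the artifact installed into the local Maven repository) can be attached to a Spark session. A minimal sketch, where the exact artifact name is an assumption and depends on the Spark/Scala versions configured in the pom:
```shell
# attach the shaded jar produced by the build above (the file name is illustrative)
spark-shell --jars target/spark-flight-connector_<spark-version>-<version>.jar

# or, after `mvn install`, resolve it from the local Maven repository
spark-shell --packages com.qwshen:spark-flight-connector_<spark-version>:<version>
```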
9 | 10 | In order to connect to an arrow-flight endpoint, the following properties are mandatory: 11 | 12 | - `host`: the full host-name or ip of an arrow-flight server; 13 | - `port`: the port number; 14 | - `user`: the user account for connecting to and reading/writing data from/to the arrow-flight endpoint; 15 | - `password` or `bearerToken`: the password/PAT or OAuth2 token of the user account. One of them must be provided; the password takes precedence if both are provided. 16 | - `table`: the name of a table from/to which the connector reads/writes data. The table can be a physical data table, or any view. It can also be a select sql-statement or a table-join statement. For example: 17 | - Select statement: 18 | ```roomsql 19 | select id, name, address from customers where city = 'TORONTO' 20 | ``` 21 | - Join tables statement: 22 | ```roomsql 23 | orders o inner join customers c on o.customer_id = c.id 24 | ``` 25 | *Note: the connector doesn't support legacy flight authentication mode (flight.auth.mode = legacy.arrow.flight.auth).* 26 | 27 | The following properties are optional: 28 | 29 | - `tls.enabled`: whether the arrow-flight end-point is tls-enabled for secure communication; 30 | - `tls.verifyServer`: whether to verify the certificate from the arrow-flight end-point. Default: true if tls.enabled = true. 31 | - `tls.truststore.jksFile`: the trust-store file in jks format; 32 | - `tls.truststore.pass`: the pass code of the trust-store; 33 | - `column.quote`: the character for quoting field names that contain special characters, as in the following sql statement: 34 | ```roomsql 35 | select id, "departure-time", "arrival-time" from flights where "flight-no" = 'ABC-21'; 36 | ``` 37 | - `default.schema`: the default schema path to the dataset that the user wants to query. 38 | - `routing.tag`: the tag name associated with all queries executed within a Flight session. Used only during authentication. 39 | - `routing.queue`: the name of the workload management queue. Used only during authentication. 40 | 41 | The connector supports optimized reads with filter, required-column and aggregation push-down, and parallel reads when partitioning is enabled. 42 | 43 | ### 1.
Load data 44 | ```scala 45 | val df = spark.read.format("flight") 46 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 47 | .option("table", """"e-commerce".orders""") 48 | .options(options) //other options 49 | .load 50 | df.printSchema 51 | df.count 52 | df.show 53 | ``` 54 | or 55 | ```scala 56 | val query = "select e.event_id, e.start_time, u.name as host_user, e.city from events e left join users u on e.user_id = u.id" 57 | val df = spark.read 58 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 59 | .option("table", query) 60 | .options(options) //other options 61 | .flight 62 | df.show 63 | ``` 64 | or 65 | ```scala 66 | import com.qwshen.flight.spark.implicits._ 67 | val df = spark.read 68 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 69 | .option("column.quote", "\"") 70 | .options(options) //other options 71 | .flight(""""e-commerce".orders""") 72 | df.show 73 | ``` 74 | 75 | #### - Partitioning 76 | 77 | By default, the connector respects the partitioning from the source arrow-flight endpoints. Data from each endpoint is assigned to one partition. However, the partitioning behavior can be further customized with the following properties: 78 | - `partition.size`: the number of partitions in the final dataframe. The default is 6. 79 | - `partition.byColumn`: the name of a column used for partitioning. Only one column is supported. This is mandatory when custom partitioning is applied. 80 | - `partition.lowerBound`: the lower-bound of the by-column. This only applies when the data type of the by-column is numeric or date-time. 81 | - `partition.upperBound`: the upper-bound of the by-column. This only applies when the data type of the by-column is numeric or date-time. 82 | - `partition.hashFunc`: the name of the hash function supported in the arrow-flight end-points. This is required when the data-type of the by-column is not numeric or date-time, and the lower-bound and upper-bound are not provided. The default name is `hash`, as defined in Dremio. 83 | - `partition.predicate`: each individual partitioning predicate is prefixed with this key. 84 | - `partition.predicates`: all partitioning predicates, concatenated by semi-colons (;) or commas (,). 85 | 86 | Notes: 87 | - The by-column may or may not be on the select-list. If not, hash-partitioning is used by default because the data-type of the by-column is not available. 88 | - The lowerBound and upperBound must either both be specified or both be omitted. If only one of them is specified, it is ignored. 89 | - When lowerBound and upperBound are specified but their values don't match or are not compatible with the data-type of the by-column, hash-partitioning is applied.
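Conceptually, when custom partitioning is enabled, each Spark partition submits its own query with a partitioning predicate appended to the where-clause. For a numeric by-column with lower/upper bounds, the generated queries look roughly like the following sketch (the table, column and boundary values are illustrative only):
```roomsql
select * from "e-commerce".events where (92 <= event_id and event_id < 715827379);
select * from "e-commerce".events where (715827379 <= event_id and event_id < 1431654667);
-- ... one such query per partition, up to partition.size
```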
90 | 91 | Examples: 92 | ```scala 93 | //with lower-bound & upper-bound 94 | import com.qwshen.flight.spark.implicits._ 95 | spark.read 96 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 97 | .option("column.quote", "\"") 98 | .option("partition.size", 128).option("partition.byColumn", "order_date").option("partition.lowerBound", "2000-01-01").option("partition.upperBound", "2010-12-31") 99 | .flight(""""e-commerce".orders""") 100 | ``` 101 | ```scala 102 | //with hash-function 103 | import com.qwshen.flight.spark.implicits._ 104 | spark.read 105 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 106 | .option("column.quote", "\"") 107 | .option("partition.size", 128).option("partition.byColumn", "order_id").option("partition.hashFunc", "hash") 108 | .flight(""""e-commerce".orders""") 109 | ``` 110 | ```scala 111 | //with predicates 112 | import com.qwshen.flight.spark.implicits._ 113 | spark.read 114 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 115 | .option("column.quote", "\"") 116 | .option("partition.size", 128) 117 | .option("partition.predicate.1", "92 <= event_id and event_id < 715827379").option("partition.predicate.2", "715827379 <= event_id and event_id < 1431654667") 118 | .option("partition.predicates", "1431654667 <= event_id and event_id < 2147481954;2147481954 <= event_id and event_id < 2863309242;2863309242 <= event_id and event_id < 3579136529") //concatenated with ; 119 | .flight(""""e-commerce".events""") 120 | ``` 121 | Note: when lowerBound & upperBound with byColumn, or predicates, are used, they end up as filters on the queries that fetch the data, which may change the final result-set. Make sure these partitioning options only partition the output and do not filter out rows that belong in it. 122 | 123 | #### - Pushing filter & columns down 124 | Filters and required-columns are pushed down when they are provided. This limits the data at the source, which greatly decreases the amount of data transferred and processed. 125 | ```scala 126 | import com.qwshen.flight.spark.implicits._ 127 | spark.read 128 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 129 | .options(options) //other options 130 | .flight(""""e-commerce".orders""") 131 | .filter("order_date > '2020-01-01' and order_amount > 100") //filter is pushed down 132 | .select("order_id", "customer_id", "payment_method", "order_amount", "order_date") //required-columns are pushed down 133 | ``` 134 | 135 | #### - Pushing aggregation down 136 | Aggregations are pushed down when they are used. Only the following aggregations are supported: 137 | - max 138 | - min 139 | - count 140 | - count distinct 141 | - sum 142 | - sum distinct 143 | 144 | For avg, it can be achieved by combining count & sum, as shown in the sketch below. Any other aggregations are calculated at the Spark level.
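A minimal sketch for avg, using the `df` defined in the example that follows (the `gender` and `amount` columns are those used in that example): `sum` and `count` are pushed down, and the average is derived from them at the Spark level:
```scala
import org.apache.spark.sql.functions.{col, count, sum}

//sum and count are pushed down; the average is computed from them in Spark
df.groupBy(col("gender"))
  .agg(sum(col("amount")).as("total_amount"), count(col("amount")).as("num_rows"))
  .withColumn("avg_amount", col("total_amount") / col("num_rows"))
  .show()
```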
145 | 146 | ```scala 147 | val df = spark.read 148 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 149 | .options(options) //other options 150 | .flight(""""e-commerce".orders""") 151 | .filter("order_date > '2020-01-01' and order_amount > 100") //filter is pushed down 152 | 153 | df.agg(count(col("order_id")).as("num_orders"), sum(col("amount")).as("total_amount")).show() //aggregation pushed down 154 | 155 | df.groupBy(col("gender")) 156 | .agg( 157 | countDistinct(col("order_id")).as("num_orders"), 158 | max(col("amount")).as("max_amount"), 159 | min(col("amount")).as("min_amount"), 160 | sum(col("amount")).as("total_amount") 161 | ) //aggregation pushed down 162 | .show() 163 | ``` 164 | 165 | ### 2. Write & Streaming-Write data (tables being written must be iceberg tables in case of Dremio Flight) 166 | ```scala 167 | df.write.format("flight") 168 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 169 | .option("table", """"e-commerce".orders""") 170 | .option("write.protocol", "literal-sql").option("batch.size", "512") 171 | .options(options) //other options 172 | .mode("overwrite") 173 | .save 174 | ``` 175 | or 176 | ```scala 177 | import com.qwshen.flight.spark.implicits._ 178 | df.write 179 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 180 | .option("table", """"e-commerce".orders""") 181 | .option("merge.byColumn", "order_id") 182 | .options(options) //other options 183 | .mode("overwrite") 184 | .flight 185 | ``` 186 | or 187 | ```scala 188 | import com.qwshen.flight.spark.implicits._ 189 | df.write 190 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 191 | .option("merge.byColumns", "order_id,customer_id").option("column.quote", "\"") 192 | .options(options) //other options 193 | .mode("append") 194 | .flight(""""e-commerce".orders""") 195 | ``` 196 | or, for streaming-write: 197 | ```scala 198 | df.writeStream.format("flight") 199 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 200 | .option("table", """"e-commerce".orders""") 201 | .option("checkpointLocation", "/tmp/checkpointing") 202 | .options(options) //other options 203 | .trigger(Trigger.Continuous(1000)) 204 | .outputMode(OutputMode.Complete()) 205 | .start() 206 | .awaitTermination(6000) 207 | ``` 208 | The following options are supported for writing: 209 | - `write.protocol`: the protocol for submitting DML requests to flight end-points. It must be one of the following: 210 | - `prepared-sql`: the connector uses PreparedStatement of Flight-SQL to conduct all DML operations. 211 | - `literal-sql`: the connector creates literal sql-statements for all DML operations. Type mappings between arrow and the target flight end-point may be required; please check the Type-Mapping section below. This is the default protocol. 212 | - `batch.size`: the number of rows in each batch for writing. The default value is 1024.
Note: depending on the size of each record, StackOverflowError might be thrown if the batch size is too big. In such cases, adjust it to a smaller value. 213 | - `merge.byColumn`: the name of a column used for merging the data into the target table. This only applies when the save-mode is `append`; 214 | - `merge.byColumns`: the names of multiple columns used for merging the data into the target table. This only applies when the save-mode is `append`. 215 | 216 | Examples: 217 | - Batch Write 218 | ```scala 219 | df.write 220 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 221 | .option("batch.size", 16000) 222 | .option("merge.byColumn.1", "user_id").option("merge.byColumn.2", "product_id") 223 | .option("merge.byColumns", "order_date;order_amount") //concatenated with ; 224 | .mode("append") 225 | .flight(""""e-commerce".orders""") 226 | ``` 227 | - Streaming Write 228 | ```scala 229 | df.writeStream.format("flight") 230 | .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123") 231 | .option("table", """"local-iceberg".iceberg_db.iceberg_events""") 232 | .option("checkpointLocation", s"/tmp/staging/checkpoint/events") 233 | .option("batch.size", 640) 234 | .trigger(Trigger.Once()) 235 | .outputMode(OutputMode.Append()) 236 | .start() 237 | .awaitTermination(300000) 238 | ``` 239 | 240 | ### 3. Data-type Mapping 241 | #### - Arrow >> Spark 242 | 243 | Arrow | Spark 244 | --- | --- 245 | bit | boolean 246 | signed int8 | byte 247 | un-signed int8 | short 248 | signed int16 | short 249 | un-signed int16 | int 250 | signed int32 | int 251 | un-signed int32 | long 252 | signed int64 | long 253 | un-signed int64 | long 254 | var-char | string 255 | decimal | decimal 256 | decimal256 | decimal 257 | floating-single | float 258 | floating-half | float 259 | floating-double | double 260 | date | date 261 | time | string (hh:mm:ss.*) 262 | timestamp | timestamp 263 | interval - year.month | year-month interval 264 | interval - day.time | day-time interval 265 | duration | struct(year-month, day-time) 266 | interval - month-day-nano | struct(year-month, day-time) 267 | list | array 268 | struct | struct 269 | map | map 270 | 271 | Note: for Dremio Flight before v23.0.0, the Map type is converted to Struct. The connector detects the pattern and converts back to Map when reading data, and adapts to Struct when writing data (with v22.0.0 or above for write only). 272 | 273 | #### - Spark >> Arrow 274 | 275 | When the connector is writing data, the schema of the target table is retrieved first, then the connector tries to adapt each source field to the type of the target, so the types of source and target must be compatible. Otherwise, a runtime exception will be thrown. For example: 276 | - Spark Int adapts to Arrow Decimal; 277 | - Spark Timestamp adapts to Arrow Time; 278 | - When using literal sql-statements (write.protocol = literal-sql), all complex types (struct, map & list) are converted to json strings. 279 | - etc. 280 | 281 | Note: Dremio Flight (up to v22.0.0) doesn't support writing complex types with DML statements yet, nor batch-writing with prepared-statements against iceberg tables. In such cases, the connector can use literal sql-statements for DML operations.
282 | 283 | #### - Arrow >> Flight End-Point (For "write.protocol = literal-sql" only) 284 | When the connector uses literal sql-statements for DML operations, it needs to know the type system of the target flight end-point which may not support all types defined in Apache Arrow. 285 | 286 | The following is the type-mapping between Apache Arrow and Dremio end-point: 287 | ```scala worksheet 288 | BIT --> BOOLEAN 289 | LARGEVARCHAR --> VARCHAR 290 | VARCHAR --> VARCHAR 291 | TINYINT --> INT 292 | SMALLINT --> INT 293 | UINT1 --> INT 294 | UINT2 --> INT 295 | UINT4 --> INT 296 | UINT8 --> INT 297 | INT -> INT 298 | BIGINT -> BIGINT 299 | FLOAT4 --> FLOAT 300 | FLOAT8 --> DOUBLE 301 | DECIMAL --> DECIMAL 302 | DECIMAL256 --> DECIMAL 303 | DATE --> DATE 304 | TIME --> TIME 305 | TIMESTAMP --> TIMESTAMP 306 | ``` 307 | This is also the default type-mapping used by the connector. To override it, please use the following option: 308 | ```scala worksheet 309 | df.write.format("flight") 310 | .option("write.protocol", "literal-sql") 311 | .option("write.typeMapping", "LARGEVARCHAR:VARCHAR;VARCHAR:VARCHAR;TINYINT:INT;SMALLINT:INT;UINT1:INT;UINT2:INT;UINT4:INT;UINT8:INT;INT:INT;BIGINT:BIGINT;FLOAT4:FLOAT;FLOAT8:DOUBLE;DECIMAL:DECIMAL;DECIMAL256:DECIMAL;DATE:DATE;TIME:TIME;TIMESTAMP:TIMESTAMP") 312 | .option(options) //other options 313 | .mode("overwrite").save 314 | ``` 315 | Currently, the binary, interval, and complex types are not supported when using literal sql-statement for DML operations. 316 | -------------------------------------------------------------------------------- /src/test/scala/com/qwshen/flight/spark/test/DremioTest.scala: -------------------------------------------------------------------------------- 1 | package com.qwshen.flight.spark.test 2 | 3 | import org.apache.spark.sql.{DataFrame, SparkSession} 4 | import org.apache.spark.sql.functions.{array, avg, col, count, countDistinct, lit, map, max, min, struct, sum, sum_distinct, when} 5 | import org.apache.spark.sql.streaming.{OutputMode, Trigger} 6 | import org.apache.spark.sql.types.{StringType, StructField, StructType} 7 | import org.scalatest.{BeforeAndAfterEach, FunSuite} 8 | 9 | class DremioTest extends FunSuite with BeforeAndAfterEach { 10 | private val dremioHost = "192.168.0.19" 11 | private val dremioPort = "32010" 12 | private val dremioTlsEnabled = false 13 | private val user = "test" 14 | private val password = "Password@12345" 15 | 16 | test("Run a simple query") { 17 | val query = """select * from "azure-wstorage".input.users""" 18 | val run: SparkSession => DataFrame = this.load(Map("table" -> query, "column.quote" -> "\"")) 19 | val df = this.execute(run) 20 | df.printSchema() 21 | df.count() 22 | df.show() 23 | } 24 | 25 | test("Run a simple query with filter & fields") { 26 | val query = """select * from "azure-wstorage".input.users""" 27 | val run: SparkSession => DataFrame = this.load(Map("table" -> query, "column.quote" -> "\"")) 28 | val df = this.execute(run) 29 | .filter("""(birthyear >= 1988 and birthyear < 1997) or (gender = 'female' and "joined-at" like '2012-11%')""") 30 | .select("user_id", "joined-at") 31 | df.printSchema() 32 | df.count() 33 | df.show() 34 | } 35 | 36 | test("Query a table with filter & fields") { 37 | val table = """"azure-wstorage".input.users""" 38 | val run: SparkSession => DataFrame = this.load(Map("table" -> table, "column.quote" -> "\"")) 39 | val df = this.execute(run) 40 | .filter("""(birthyear >= 1988 and birthyear < 1997) or (gender = 'female' and "joined-at" like 
'2012-11%')""") 41 | .select("user_id", "joined-at") 42 | df.printSchema() 43 | df.count() 44 | df.show() 45 | } 46 | 47 | test("Query a join-table with filter & fields") { 48 | val table = """"azure-wstorage".input.events e left join "azure-wstorage".input.users u on e.user_id = u.user_id""" 49 | val run: SparkSession => DataFrame = this.load(Map("table" -> table, "column.quote" -> "\"")) 50 | val df = this.execute(run) 51 | .filter("""birthyear is not null""") 52 | .select("user_name", "gender", "event_id", "start_time") 53 | df.printSchema() 54 | df.count() 55 | df.show() 56 | } 57 | 58 | test("Run a query with decimal, date, time, timestamp, year-month, day-time etc.") { 59 | val query = """ 60 | |select 61 | | cast(event_id as bigint) as event_id, 62 | | gender, 63 | | birthyear, 64 | | cast(2022 - birthyear as interval year) as age_year, 65 | | cast(2022 - birthyear as interval month) as age_month, 66 | | cast(2022 - birthyear as interval day) as age_day, 67 | | start_time as raw_start_time, 68 | | cast(concat(substring(start_time, 1, 10), ' ', substring(start_time, 12, 8)) as timestamp) as real_start_time, 69 | | cast(concat(substring(start_time, 1, 10), ' ', substring(start_time, 12, 12)) as timestamp) as milli_start_time, 70 | | cast(substring(start_time, 12, 8) as time) as real_time, 71 | | cast(now() - cast(concat(substring(start_time, 1, 10), ' ', substring(start_time, 12, 8)) as timestamp) as interval hour) as ts_hour, 72 | | cast(now() - cast(concat(substring(start_time, 1, 10), ' ', substring(start_time, 12, 8)) as timestamp) as interval minute) as ts_minute, 73 | | cast(now() - cast(concat(substring(start_time, 1, 10), ' ', substring(start_time, 12, 8)) as timestamp) as interval second) as ts_second, 74 | | cast(substring(start_time, 1, 10) as date) as start_date, 75 | | cast(case when extract(year from now()) - extract(year from cast(substring(start_time, 1, 10) as date)) >= 65 then 1 else 0 end as boolean) as senior, 76 | | cast(3423.23 as float) as float_amount, 77 | | cast(2342345.13 as double) as double_amount, 78 | | cast(32423423.31 as decimal) as decimal_amount 79 | |from "@test".events e inner join "@test".users u on e.user_id = u.user_id 80 | |""".stripMargin 81 | val run: SparkSession => DataFrame = this.load(Map("table" -> query, "column.quote" -> "\"")) 82 | val df = this.execute(run) 83 | df.printSchema() 84 | df.count() 85 | df.show(false) 86 | } 87 | 88 | test("Run a query with simple list & struct types") { 89 | val query = """ 90 | |select 91 | | 'Gnarly' as name, 92 | | 7 as age, 93 | | CONVERT_FROM('{"name" : "Gnarly", "age": 7, "car": "BMW"}', 'json') as data, 94 | | CONVERT_FROM('["apple", "strawberry", "banana"]', 'json') as favorites 95 | |""".stripMargin 96 | val run: SparkSession => DataFrame = this.load(Map("table" -> query, "column.quote" -> "\"")) 97 | val df = this.execute(run) 98 | df.printSchema() 99 | df.count() 100 | df.show() 101 | } 102 | 103 | test("Run a query with simple map types") { 104 | val query = """ 105 | |select 106 | | 'James' as name, 107 | | convert_from('["Newark", "NY"]', 'json') as cities, 108 | | convert_from('{"city": "Newark", "state": "NY"}', 'json') as address, 109 | | convert_from('{"map": [{"Key": "hair", "value": "block"}, {"key": "eye", "value": "brown"}]}', 'json') as prop_1, 110 | | convert_from('{"map": [{"key": "height", "value": "5.9"}]}', 'json') as prop_2 111 | |""".stripMargin 112 | val run: SparkSession => DataFrame = this.load(Map("table" -> query, "column.quote" -> "\"")) 113 | val df = 
this.execute(run) 114 | df.printSchema() 115 | df.count() 116 | df.show() 117 | } 118 | 119 | test("Query a table with list, struct & map types") { 120 | /* 121 | Create a dataframe in spark-shell with the following code, then save the dataframe in parquet files. 122 | Create a data source pointing the parquet files in Dremio 123 | 124 | val arrayStructureData = Seq( 125 | Row("James",List(Row("Newark","NY"), 126 | Row("Brooklyn","NY")),Map("hair"->"black","eye"->"brown"), Map("height"->"5.9")), 127 | Row("Michael",List(Row("SanJose","CA"),Row("Sandiago","CA")), Map("hair"->"brown","eye"->"black"),Map("height"->"6")), 128 | Row("Robert",List(Row("LasVegas","NV")), Map("hair"->"red","eye"->"gray"),Map("height"->"6.3")), 129 | Row("Maria",null,Map("hair"->"blond","eye"->"red"), Map("height"->"5.6")), 130 | Row("Jen",List(Row("LAX","CA"),Row("Orange","CA")), Map("white"->"black","eye"->"black"),Map("height"->"5.2")) 131 | ) 132 | 133 | val mapType = DataTypes.createMapType(StringType,StringType) 134 | val arrayStructureSchema = new StructType() 135 | .add("name",StringType) 136 | .add("addresses", ArrayType(new StructType() 137 | .add("city",StringType) 138 | .add("state",StringType))) 139 | .add("properties", mapType) 140 | .add("secondProp", MapType(StringType,StringType)) 141 | 142 | val mapTypeDF = spark.createDataFrame(spark.sparkContext.parallelize(arrayStructureData),arrayStructureSchema) 143 | mapTypeDF.printSchema() 144 | mapTypeDF.show() 145 | */ 146 | val run: SparkSession => DataFrame = this.load(Map("table" -> "ESG.\"map\"", "column.quote" -> "\"")) 147 | val df = this.execute(run) 148 | df.printSchema() 149 | df.count() 150 | df.show() 151 | } 152 | 153 | test("Query a table with partitioning by column") { 154 | val table = """"azure-wstorage".input.events""" 155 | val run: SparkSession => DataFrame = this.load(Map("table" -> table, "partition.byColumn" -> "event_id", "partition.size" -> "3")) 156 | val df = this.execute(run) 157 | df.printSchema() 158 | df.count() 159 | df.show(10) 160 | } 161 | 162 | test("Query a table with partitioning by column having date-time lower-bound & upper-bound") { 163 | val table = """"azure-wstorage".input.events""" 164 | val run: SparkSession => DataFrame = this.load(Map("table" -> table, "partition.byColumn" -> "start_time", "partition.size" -> "6", "partition.lowerBound" -> "1912-01-01T05:59:24.003Z", "partition.upperBound" -> "3001-04-12T05:00:00.002Z")) 165 | val df = this.execute(run) 166 | df.printSchema() 167 | df.count() 168 | df.show(10) 169 | } 170 | 171 | test("Query a table with partitioning by column having long lower-bound & upper-bound") { 172 | val query = """ 173 | |select 174 | | cast(event_id as bigint) as event_id, 175 | | cast(user_id as bigint) as user_id, 176 | | start_time, 177 | | city, 178 | | state, 179 | | zip, 180 | | country 181 | |from "azure-wstorage".input.events 182 | |""".stripMargin 183 | val run: SparkSession => DataFrame = this.load(Map("table" -> query, "partition.byColumn" -> "event_id", "partition.size" -> "6", "partition.lowerBound" -> "92", "partition.upperBound" -> "4294963817")) 184 | val df = this.execute(run) 185 | df.printSchema() 186 | df.count() 187 | df.show() 188 | } 189 | 190 | test("Query a table with partitioning predicates") { 191 | val query = """ 192 | |select 193 | | cast(event_id as bigint) as event_id, 194 | | cast(user_id as bigint) as user_id, 195 | | start_time, 196 | | city 197 | |from "azure-wstorage".input.events 198 | |""".stripMargin 199 | val predicates = Seq( 200 | "92 <= 
event_id and event_id < 715827379", "715827379 <= event_id and event_id < 1431654667", "1431654667 <= event_id and event_id < 2147481954", 201 | "2147481954 <= event_id and event_id < 2863309242", "2863309242 <= event_id and event_id < 3579136529", "3579136529 <= event_id and event_id < 4294963817" 202 | ) 203 | val run: SparkSession => DataFrame = this.load(Map("table" -> query, "partition.predicates" -> predicates.mkString(";"))) 204 | val df = this.execute(run) 205 | df.printSchema() 206 | df.count() 207 | df.show() 208 | } 209 | 210 | test("Query a table with aggregation") { 211 | val table = """"local-iceberg".iceberg_db.log_events_iceberg_table_events""" 212 | val run: SparkSession => DataFrame = this.load(Map("table" -> table)) 213 | val df = this.execute(run) 214 | 215 | df.agg(max(col("float_amount")).as("max_float_amount"), sum(col("double_amount")).as("sum_double_amount")).show() 216 | df.filter(col("float_amount") >= lit(2.34f)).agg(max(col("float_amount")).as("max_float_amount"), sum(col("double_amount")).as("sum_double_amount")).show() 217 | 218 | //df.limit(10).show() //not supported 219 | df.distinct().show() 220 | 221 | df.filter(col("float_amount") >= lit(2.34f)) 222 | .groupBy(col("gender"), col("birthyear")) 223 | .agg( 224 | countDistinct(col("event_id")).as("distinct_count"), 225 | count(col("event_id")).as("count"), 226 | max(col("float_amount")).as("max_float_amount"), 227 | min(col("float_amount")).as("min_float_amount"), 228 | //avg(col("decimal_amount")).as("avg_decimal_amount"), //not supported 229 | sum_distinct(col("double_amount")).as("distinct_sum_double_amount"), 230 | sum(col("double_amount")).as("sum_double_amount") 231 | ) 232 | .show() 233 | } 234 | 235 | test("Query a table with aggregation with partitioning by hashing") { 236 | val table = """"local-iceberg".iceberg_db.log_events_iceberg_table_events""" 237 | val run: SparkSession => DataFrame = this.load(Map("table" -> table, "partition.size" -> "3", "partition.byColumn" -> "event_id")) 238 | val df = this.execute(run) 239 | 240 | df.agg(max(col("float_amount")).as("max_float_amount"), sum(col("double_amount")).as("sum_double_amount")).show() 241 | df.filter(col("float_amount") >= lit(2.34f)).agg(max(col("float_amount")).as("max_float_amount"), sum(col("double_amount")).as("sum_double_amount")).show() 242 | 243 | //df.limit(10).show() //not supported 244 | //df.distinct().show() //not supported 245 | 246 | df.filter(col("float_amount") >= lit(2.34f)) 247 | .groupBy(col("gender"), col("birthyear")) 248 | .agg( 249 | countDistinct(col("event_id")).as("distinct_count"), 250 | count(col("event_id")).as("count"), 251 | max(col("float_amount")).as("max_float_amount"), 252 | min(col("float_amount")).as("min_float_amount"), 253 | //avg(col("decimal_amount")).as("avg_decimal_amount"), //not supported 254 | sum_distinct(col("double_amount")).as("distinct_sum_double_amount"), 255 | sum(col("double_amount")).as("sum_double_amount") 256 | ) 257 | .show() 258 | } 259 | 260 | test("Query a table with aggregation with partitioning by lower/upper bounds") { 261 | val table = """"local-iceberg".iceberg_db.log_events_iceberg_table_events""" 262 | val run: SparkSession => DataFrame = this.load(Map("table" -> table, "partition.size" -> "3", "partition.byColumn" -> "start_date", "partition.lowerBound" -> "2012-10-01", "partition.upperBound" -> "2012-12-31")) 263 | val df = this.execute(run) 264 | 265 | df.agg(max(col("float_amount")).as("max_float_amount"), sum(col("double_amount")).as("sum_double_amount")).show() 266 | 
df.filter(col("float_amount") >= lit(2.34f)).agg(max(col("float_amount")).as("max_float_amount"), sum(col("double_amount")).as("sum_double_amount")).show() 267 | 268 | //df.limit(10).show() //not supported 269 | //df.distinct().show() //not supported 270 | 271 | df.filter(col("float_amount") >= lit(2.34f)) 272 | .groupBy(col("gender"), col("start_date")) 273 | .agg( 274 | countDistinct(col("event_id")).as("distinct_count"), 275 | count(col("event_id")).as("count"), 276 | max(col("float_amount")).as("max_float_amount"), 277 | min(col("float_amount")).as("min_float_amount"), 278 | //avg(col("decimal_amount")).as("avg_decimal_amount"), //not supported 279 | sum_distinct(col("double_amount")).as("distinct_sum_double_amount"), 280 | sum(col("double_amount")).as("sum_double_amount") 281 | ) 282 | .show() 283 | } 284 | 285 | test("Write a simple table") { 286 | val query = """ 287 | |select 288 | | cast(event_id as bigint) as event_id, 289 | | gender, 290 | | birthyear, 291 | | start_time as raw_start_time, 292 | | cast(concat(substring(start_time, 1, 10), ' ', substring(start_time, 12, 8)) as timestamp) as real_start_time, 293 | | cast(concat(substring(start_time, 1, 10), ' ', substring(start_time, 12, 12)) as timestamp) as milli_start_time, 294 | | cast(substring(start_time, 12, 8) as time) as real_time, 295 | | cast(substring(start_time, 1, 10) as date) as start_date, 296 | | cast(case when extract(year from now()) - extract(year from cast(substring(start_time, 1, 10) as date)) >= 65 then 1 else 0 end as boolean) as senior, 297 | | cast(3423.23 as float) as float_amount, 298 | | cast(2342345.13 as double) as double_amount, 299 | | cast(32423423.31 as decimal) as decimal_amount 300 | |from "azure-wstorage".input.events e inner join "azure-wstorage".input.users u on e.user_id = u.user_id 301 | |""".stripMargin 302 | val runLoad: SparkSession => DataFrame = this.load(Map("table" -> query)) 303 | val df = this.execute(runLoad).cache 304 | df.printSchema() 305 | df.show(false) 306 | 307 | val table = """"local-iceberg"."iceberg_db"."log_events_iceberg_table_events"""" 308 | //overwrite 309 | val overwrite: () => Unit = this.overWrite(Map("table" -> table), df) 310 | this.execute(overwrite) 311 | //appendWrite 312 | val dfAppend = df.withColumn("event_id", col("event_id") + 10000000000L) 313 | val append: () => Unit = this.appendWrite(Map("table" -> table), dfAppend) 314 | this.execute(append) 315 | //merge 316 | val dfMerge = df 317 | .withColumn("event_id", when(col("event_id") === lit(42204521L), col("event_id") + 1).otherwise(col("event_id"))) 318 | .withColumn("float_amount", col("float_amount") + 1111) 319 | .withColumn("double_amount", col("double_amount") + 2222) 320 | .withColumn("decimal_amount", col("decimal_amount") + 3333) 321 | val merge: () => Unit = this.mergeWrite(Map("table" -> table, "merge.byColumn" -> "event_id"), dfMerge) 322 | this.execute(merge) 323 | } 324 | 325 | test("Append a heavy table") { 326 | val srcTable = """"azure-wstorage".input.events""" 327 | val runLoad: SparkSession => DataFrame = this.load(Map("table" -> srcTable)) 328 | val df = this.execute(runLoad).cache 329 | 330 | //overwrite with a custom batch size 331 | val dstTable = """"local-iceberg"."iceberg_db"."iceberg_events"""" 332 | val dfAppend = df.limit(10000) 333 | .withColumn("event_id", col("event_id") + 1000000000L) 334 | .withColumn("remark", lit("new")) 335 | val overwrite: () => Unit = this.overWrite(Map("table" -> dstTable, "batch.size" -> "768"), dfAppend) 336 | this.execute(overwrite) 337 | } 338 | 339 | test("Append 
a heavy table with complex-type --> string") { 340 | val srcTable = """"azure-wstorage".input.events""" 341 | val runLoad: SparkSession => DataFrame = this.load(Map("table" -> srcTable)) 342 | val df = this.execute(runLoad).filter(col("user_id").isin("781622845", "1519813515", "1733137333", "3709565024")).cache 343 | 344 | //the target table 345 | val dstTable = """"local-iceberg"."iceberg_db"."iceberg_events"""" 346 | 347 | //appendWrite for struct 348 | val dfStruct = df.filter(col("user_id") === lit("781622845") || col("user_id") === lit("3709565024")) 349 | .withColumn("event_id", col("event_id") + 1000000000L) 350 | .withColumn("remark", struct(col("city").as("city"), col("state").as("state"), col("country").as("country"))) 351 | val appendStruct: () => Unit = this.appendWrite(Map("table" -> dstTable, "merge.byColumns" -> "event_id,user_id"), dfStruct) 352 | this.execute(appendStruct) 353 | 354 | //appendWrite for map 355 | val dfMap = df.filter(col("user_id") === lit("1519813515")) 356 | .withColumn("event_id", col("event_id") + 1000000000L) 357 | .withColumn("remark", map(lit("city"), col("city"), lit("state"), col("state"), lit("country"), col("country"))) 358 | val appendMap: () => Unit = this.appendWrite(Map("table" -> dstTable, "merge.byColumns_1" -> "event_id", "merge.byColumns_2" -> "user_id"), dfMap) 359 | this.execute(appendMap) 360 | 361 | //appendWrite for array 362 | val dfArray = df.filter(col("user_id") === lit("1733137333")) 363 | .withColumn("event_id", col("event_id") + 1000000000L) 364 | .withColumn("remark", array(col("city"), col("state"), col("country"))) 365 | val appendArray: () => Unit = this.appendWrite(Map("table" -> dstTable, "merge.byColumns" -> "event_id;user_id"), dfArray) 366 | this.execute(appendArray) 367 | } 368 | 369 | //inserting complex types is not yet supported by the flight service 370 | ignore("Write a table with list and struct") { 371 | val table = """"local-iceberg"."iceberg_db"."log_events_iceberg_struct_list"""" 372 | val runLoad: SparkSession => DataFrame = this.load(Map("table" -> table)) 373 | val df = this.execute(runLoad).cache 374 | df.printSchema() 375 | df.show(false) 376 | 377 | val run: () => Unit = this.overWrite(Map("table" -> table), df) 378 | this.execute(run) 379 | } 380 | 381 | ignore("Write a table with map") { 382 | val table = """"local-iceberg"."iceberg_db"."log_events_iceberg_map"""" 383 | val runLoad: SparkSession => DataFrame = this.load(Map("table" -> table)) 384 | val df = this.execute(runLoad).cache 385 | df.printSchema() 386 | df.show(false) 387 | 388 | val write: () => Unit = this.appendWrite(Map("table" -> table), df) 389 | this.execute(write) 390 | } 391 | 392 | test("Streaming-write a table") { 393 | val resRoot: String = getClass.getClassLoader.getResource("").getPath 394 | 395 | val fields = Seq(StructField("event_id", StringType), StructField("user_id", StringType), 396 | StructField("start_time", StringType), StructField("city", StringType), StructField("province", StringType), StructField("country", StringType)) 397 | val streamLoad: SparkSession => DataFrame = this.streamLoad(Map("header" -> "true", "delimiter" -> ","), StructType(fields), s"${resRoot}data/events") 398 | val df = this.execute(streamLoad) 399 | .select(col("event_id"), col("user_id"), col("start_time"), col("city"), col("province").as("state"), lit("n/a").as("zip"), col("country"), lit("n/a").as("remark")) 400 | val streamWrite: () => Unit = this.streamWrite(Map("table" -> """"local-iceberg".iceberg_db.iceberg_events""", 
"checkpointLocation" -> s"${resRoot}checkpoint/events"), df) 401 | this.execute(streamWrite) 402 | } 403 | 404 | //load the data-frame 405 | private def load(options: Map[String, String])(spark: SparkSession): DataFrame = spark.read.format("flight") 406 | .option("host", this.dremioHost).option("port", this.dremioPort).option("tls.enabled", dremioTlsEnabled).option("user", this.user).option("password", this.password) 407 | .options(options) 408 | .load 409 | 410 | //stream-load 411 | private def streamLoad(options: Map[String, String], schema: StructType, dataLocation: String)(spark: SparkSession): DataFrame = spark.readStream.format("csv").options(options).schema(schema).load(dataLocation) 412 | 413 | //overwrite with the data-frame 414 | private def overWrite(options: Map[String, String], df: DataFrame)(): Unit = df.write.format("flight") 415 | .option("host", this.dremioHost).option("port", this.dremioPort).option("tls.enabled", dremioTlsEnabled).option("user", this.user).option("password", this.password) 416 | .options(options) 417 | .mode("overwrite").save 418 | 419 | //appendWrite the data-frame 420 | private def appendWrite(options: Map[String, String], df: DataFrame)(): Unit = df.write.format("flight") 421 | .option("host", this.dremioHost).option("port", this.dremioPort).option("tls.enabled", dremioTlsEnabled).option("user", this.user).option("password", this.password) 422 | .options(options) 423 | .mode("append").save 424 | 425 | //merge the data-frame 426 | private def mergeWrite(options: Map[String, String], df: DataFrame)(): Unit = df.write.format("flight") 427 | .option("host", this.dremioHost).option("port", this.dremioPort).option("tls.enabled", dremioTlsEnabled).option("user", this.user).option("password", this.password) 428 | .options(options) 429 | .mode("append").save 430 | 431 | //streaming-write with the data-frame 432 | private def streamWrite(options: Map[String, String], df: DataFrame)(): Unit = df.writeStream.format("flight") 433 | .option("host", this.dremioHost).option("port", this.dremioPort).option("tls.enabled", dremioTlsEnabled).option("user", this.user).option("password", this.password) 434 | .options(options) 435 | .trigger(Trigger.Once()) 436 | .outputMode(OutputMode.Append()) 437 | .start() 438 | .awaitTermination(300000) 439 | 440 | //create spark-session 441 | private var spark: SparkSession = _ 442 | //execute a job 443 | private def execute[T](read: SparkSession => T): T = read(spark) 444 | private def execute[T](write: () => T): T = write() 445 | 446 | override def beforeEach(): Unit = spark = SparkSession.builder.master("local[*]").config("spark.executor.memory", "24g").config("spark.driver.memory", "24g").appName("test").getOrCreate 447 | override def afterEach(): Unit = spark.stop() 448 | } 449 | --------------------------------------------------------------------------------