options) {
91 | return makeTable(new CaseInsensitiveStringMap(options));
92 | }
93 | }
94 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightArrowColumnVector.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | /*
17 | * Licensed to the Apache Software Foundation (ASF) under one or more
18 | * contributor license agreements. See the NOTICE file distributed with
19 | * this work for additional information regarding copyright ownership.
20 | * The ASF licenses this file to You under the Apache License, Version 2.0
21 | * (the "License"); you may not use this file except in compliance with
22 | * the License. You may obtain a copy of the License at
23 | *
24 | * http://www.apache.org/licenses/LICENSE-2.0
25 | *
26 | * Unless required by applicable law or agreed to in writing, software
27 | * distributed under the License is distributed on an "AS IS" BASIS,
28 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29 | * See the License for the specific language governing permissions and
30 | * limitations under the License.
31 | */
32 |
33 | package org.apache.arrow.flight.spark;
34 |
35 | import org.apache.arrow.vector.BigIntVector;
36 | import org.apache.arrow.vector.BitVector;
37 | import org.apache.arrow.vector.DateDayVector;
38 | import org.apache.arrow.vector.DateMilliVector;
39 | import org.apache.arrow.vector.DecimalVector;
40 | import org.apache.arrow.vector.Float4Vector;
41 | import org.apache.arrow.vector.Float8Vector;
42 | import org.apache.arrow.vector.IntVector;
43 | import org.apache.arrow.vector.SmallIntVector;
44 | import org.apache.arrow.vector.TimeStampMicroVector;
45 | import org.apache.arrow.vector.TimeStampMicroTZVector;
46 | import org.apache.arrow.vector.TimeStampMilliVector;
47 | import org.apache.arrow.vector.TimeStampVector;
48 | import org.apache.arrow.vector.TinyIntVector;
49 | import org.apache.arrow.vector.ValueVector;
50 | import org.apache.arrow.vector.VarBinaryVector;
51 | import org.apache.arrow.vector.VarCharVector;
52 | import org.apache.arrow.vector.complex.ListVector;
53 | import org.apache.arrow.vector.complex.StructVector;
54 | import org.apache.arrow.vector.holders.NullableVarCharHolder;
55 | import org.apache.arrow.memory.ArrowBuf;
56 | import org.apache.spark.sql.execution.arrow.FlightArrowUtils;
57 | import org.apache.spark.sql.types.Decimal;
58 | import org.apache.spark.sql.vectorized.ColumnVector;
59 | import org.apache.spark.sql.vectorized.ColumnarArray;
60 | import org.apache.spark.sql.vectorized.ColumnarMap;
61 | import org.apache.spark.unsafe.types.UTF8String;
62 |
63 |
64 | /**
65 | * A column vector backed by Apache Arrow. Currently calendar interval type and map type are not
66 | * supported. This is a copy of ArrowColumnVector with added support for DateMilli and TimestampMilli
67 | */
68 | public final class FlightArrowColumnVector extends ColumnVector {
69 |
70 | private final ArrowVectorAccessor accessor;
71 | private FlightArrowColumnVector[] childColumns;
72 |
73 | @Override
74 | public boolean hasNull() {
75 | return accessor.getNullCount() > 0;
76 | }
77 |
78 | @Override
79 | public int numNulls() {
80 | return accessor.getNullCount();
81 | }
82 |
83 | @Override
84 | public void close() {
85 | if (childColumns != null) {
86 | for (int i = 0; i < childColumns.length; i++) {
87 | childColumns[i].close();
88 | childColumns[i] = null;
89 | }
90 | childColumns = null;
91 | }
92 | accessor.close();
93 | }
94 |
95 | @Override
96 | public boolean isNullAt(int rowId) {
97 | return accessor.isNullAt(rowId);
98 | }
99 |
100 | @Override
101 | public boolean getBoolean(int rowId) {
102 | return accessor.getBoolean(rowId);
103 | }
104 |
105 | @Override
106 | public byte getByte(int rowId) {
107 | return accessor.getByte(rowId);
108 | }
109 |
110 | @Override
111 | public short getShort(int rowId) {
112 | return accessor.getShort(rowId);
113 | }
114 |
115 | @Override
116 | public int getInt(int rowId) {
117 | return accessor.getInt(rowId);
118 | }
119 |
120 | @Override
121 | public long getLong(int rowId) {
122 | return accessor.getLong(rowId);
123 | }
124 |
125 | @Override
126 | public float getFloat(int rowId) {
127 | return accessor.getFloat(rowId);
128 | }
129 |
130 | @Override
131 | public double getDouble(int rowId) {
132 | return accessor.getDouble(rowId);
133 | }
134 |
135 | @Override
136 | public Decimal getDecimal(int rowId, int precision, int scale) {
137 | if (isNullAt(rowId)) {
138 | return null;
139 | }
140 | return accessor.getDecimal(rowId, precision, scale);
141 | }
142 |
143 | @Override
144 | public UTF8String getUTF8String(int rowId) {
145 | if (isNullAt(rowId)) {
146 | return null;
147 | }
148 | return accessor.getUTF8String(rowId);
149 | }
150 |
151 | @Override
152 | public byte[] getBinary(int rowId) {
153 | if (isNullAt(rowId)) {
154 | return null;
155 | }
156 | return accessor.getBinary(rowId);
157 | }
158 |
159 | @Override
160 | public ColumnarArray getArray(int rowId) {
161 | if (isNullAt(rowId)) {
162 | return null;
163 | }
164 | return accessor.getArray(rowId);
165 | }
166 |
167 | @Override
168 | public ColumnarMap getMap(int rowId) {
169 | throw new UnsupportedOperationException();
170 | }
171 |
172 | @Override
173 | public FlightArrowColumnVector getChild(int ordinal) {
174 | return childColumns[ordinal];
175 | }
176 |
177 | public FlightArrowColumnVector(ValueVector vector) {
178 | super(FlightArrowUtils.fromArrowField(vector.getField()));
179 |
180 | if (vector instanceof BitVector) {
181 | accessor = new BooleanAccessor((BitVector) vector);
182 | } else if (vector instanceof TinyIntVector) {
183 | accessor = new ByteAccessor((TinyIntVector) vector);
184 | } else if (vector instanceof SmallIntVector) {
185 | accessor = new ShortAccessor((SmallIntVector) vector);
186 | } else if (vector instanceof IntVector) {
187 | accessor = new IntAccessor((IntVector) vector);
188 | } else if (vector instanceof BigIntVector) {
189 | accessor = new LongAccessor((BigIntVector) vector);
190 | } else if (vector instanceof Float4Vector) {
191 | accessor = new FloatAccessor((Float4Vector) vector);
192 | } else if (vector instanceof Float8Vector) {
193 | accessor = new DoubleAccessor((Float8Vector) vector);
194 | } else if (vector instanceof DecimalVector) {
195 | accessor = new DecimalAccessor((DecimalVector) vector);
196 | } else if (vector instanceof VarCharVector) {
197 | accessor = new StringAccessor((VarCharVector) vector);
198 | } else if (vector instanceof VarBinaryVector) {
199 | accessor = new BinaryAccessor((VarBinaryVector) vector);
200 | } else if (vector instanceof DateDayVector) {
201 | accessor = new DateAccessor((DateDayVector) vector);
202 | } else if (vector instanceof DateMilliVector) {
203 | accessor = new DateMilliAccessor((DateMilliVector) vector);
204 | } else if (vector instanceof TimeStampMicroVector) {
205 | accessor = new TimestampMicroAccessor((TimeStampMicroVector) vector);
206 | } else if (vector instanceof TimeStampMicroTZVector) {
207 | accessor = new TimestampMicroTZAccessor((TimeStampMicroTZVector) vector);
208 | } else if (vector instanceof TimeStampMilliVector) {
209 | accessor = new TimestampMilliAccessor((TimeStampMilliVector) vector);
210 | } else if (vector instanceof ListVector) {
211 | ListVector listVector = (ListVector) vector;
212 | accessor = new ArrayAccessor(listVector);
213 | } else if (vector instanceof StructVector) {
214 | StructVector structVector = (StructVector) vector;
215 | accessor = new StructAccessor(structVector);
216 |
217 | childColumns = new FlightArrowColumnVector[structVector.size()];
218 | for (int i = 0; i < childColumns.length; ++i) {
219 | childColumns[i] = new FlightArrowColumnVector(structVector.getVectorById(i));
220 | }
221 | } else {
222 | System.out.println(vector);
223 | throw new UnsupportedOperationException();
224 | }
225 | }
226 |
227 | private abstract static class ArrowVectorAccessor {
228 |
229 | private final ValueVector vector;
230 |
231 | ArrowVectorAccessor(ValueVector vector) {
232 | this.vector = vector;
233 | }
234 |
235 | // TODO: should be final after removing ArrayAccessor workaround
236 | boolean isNullAt(int rowId) {
237 | return vector.isNull(rowId);
238 | }
239 |
240 | final int getNullCount() {
241 | return vector.getNullCount();
242 | }
243 |
244 | final void close() {
245 | vector.close();
246 | }
247 |
248 | boolean getBoolean(int rowId) {
249 | throw new UnsupportedOperationException();
250 | }
251 |
252 | byte getByte(int rowId) {
253 | throw new UnsupportedOperationException();
254 | }
255 |
256 | short getShort(int rowId) {
257 | throw new UnsupportedOperationException();
258 | }
259 |
260 | int getInt(int rowId) {
261 | throw new UnsupportedOperationException();
262 | }
263 |
264 | long getLong(int rowId) {
265 | throw new UnsupportedOperationException();
266 | }
267 |
268 | float getFloat(int rowId) {
269 | throw new UnsupportedOperationException();
270 | }
271 |
272 | double getDouble(int rowId) {
273 | throw new UnsupportedOperationException();
274 | }
275 |
276 | Decimal getDecimal(int rowId, int precision, int scale) {
277 | throw new UnsupportedOperationException();
278 | }
279 |
280 | UTF8String getUTF8String(int rowId) {
281 | throw new UnsupportedOperationException();
282 | }
283 |
284 | byte[] getBinary(int rowId) {
285 | throw new UnsupportedOperationException();
286 | }
287 |
288 | ColumnarArray getArray(int rowId) {
289 | throw new UnsupportedOperationException();
290 | }
291 | }
292 |
293 | private static class BooleanAccessor extends ArrowVectorAccessor {
294 |
295 | private final BitVector accessor;
296 |
297 | BooleanAccessor(BitVector vector) {
298 | super(vector);
299 | this.accessor = vector;
300 | }
301 |
302 | @Override
303 | final boolean getBoolean(int rowId) {
304 | return accessor.get(rowId) == 1;
305 | }
306 | }
307 |
308 | private static class ByteAccessor extends ArrowVectorAccessor {
309 |
310 | private final TinyIntVector accessor;
311 |
312 | ByteAccessor(TinyIntVector vector) {
313 | super(vector);
314 | this.accessor = vector;
315 | }
316 |
317 | @Override
318 | final byte getByte(int rowId) {
319 | return accessor.get(rowId);
320 | }
321 | }
322 |
323 | private static class ShortAccessor extends ArrowVectorAccessor {
324 |
325 | private final SmallIntVector accessor;
326 |
327 | ShortAccessor(SmallIntVector vector) {
328 | super(vector);
329 | this.accessor = vector;
330 | }
331 |
332 | @Override
333 | final short getShort(int rowId) {
334 | return accessor.get(rowId);
335 | }
336 | }
337 |
338 | private static class IntAccessor extends ArrowVectorAccessor {
339 |
340 | private final IntVector accessor;
341 |
342 | IntAccessor(IntVector vector) {
343 | super(vector);
344 | this.accessor = vector;
345 | }
346 |
347 | @Override
348 | final int getInt(int rowId) {
349 | return accessor.get(rowId);
350 | }
351 | }
352 |
353 | private static class LongAccessor extends ArrowVectorAccessor {
354 |
355 | private final BigIntVector accessor;
356 |
357 | LongAccessor(BigIntVector vector) {
358 | super(vector);
359 | this.accessor = vector;
360 | }
361 |
362 | @Override
363 | final long getLong(int rowId) {
364 | return accessor.get(rowId);
365 | }
366 | }
367 |
368 | private static class FloatAccessor extends ArrowVectorAccessor {
369 |
370 | private final Float4Vector accessor;
371 |
372 | FloatAccessor(Float4Vector vector) {
373 | super(vector);
374 | this.accessor = vector;
375 | }
376 |
377 | @Override
378 | final float getFloat(int rowId) {
379 | return accessor.get(rowId);
380 | }
381 | }
382 |
383 | private static class DoubleAccessor extends ArrowVectorAccessor {
384 |
385 | private final Float8Vector accessor;
386 |
387 | DoubleAccessor(Float8Vector vector) {
388 | super(vector);
389 | this.accessor = vector;
390 | }
391 |
392 | @Override
393 | final double getDouble(int rowId) {
394 | return accessor.get(rowId);
395 | }
396 | }
397 |
398 | private static class DecimalAccessor extends ArrowVectorAccessor {
399 |
400 | private final DecimalVector accessor;
401 |
402 | DecimalAccessor(DecimalVector vector) {
403 | super(vector);
404 | this.accessor = vector;
405 | }
406 |
407 | @Override
408 | final Decimal getDecimal(int rowId, int precision, int scale) {
409 | if (isNullAt(rowId)) {
410 | return null;
411 | }
412 | return Decimal.apply(accessor.getObject(rowId), precision, scale);
413 | }
414 | }
415 |
416 | private static class StringAccessor extends ArrowVectorAccessor {
417 |
418 | private final VarCharVector accessor;
419 | private final NullableVarCharHolder stringResult = new NullableVarCharHolder();
420 |
421 | StringAccessor(VarCharVector vector) {
422 | super(vector);
423 | this.accessor = vector;
424 | }
425 |
426 | @Override
427 | final UTF8String getUTF8String(int rowId) {
428 | accessor.get(rowId, stringResult);
429 | if (stringResult.isSet == 0) {
430 | return null;
431 | } else {
432 | return UTF8String.fromAddress(null,
433 | stringResult.buffer.memoryAddress() + stringResult.start,
434 | stringResult.end - stringResult.start);
435 | }
436 | }
437 | }
438 |
439 | private static class BinaryAccessor extends ArrowVectorAccessor {
440 |
441 | private final VarBinaryVector accessor;
442 |
443 | BinaryAccessor(VarBinaryVector vector) {
444 | super(vector);
445 | this.accessor = vector;
446 | }
447 |
448 | @Override
449 | final byte[] getBinary(int rowId) {
450 | return accessor.getObject(rowId);
451 | }
452 | }
453 |
454 | private static class DateAccessor extends ArrowVectorAccessor {
455 |
456 | private final DateDayVector accessor;
457 |
458 | DateAccessor(DateDayVector vector) {
459 | super(vector);
460 | this.accessor = vector;
461 | }
462 |
463 | @Override
464 | final int getInt(int rowId) {
465 | return accessor.get(rowId);
466 | }
467 | }
468 |
469 | private static class DateMilliAccessor extends ArrowVectorAccessor {
470 |
471 | private final DateMilliVector accessor;
472 | private final double val = 1.0 / (24. * 60. * 60. * 1000.);
473 |
474 | DateMilliAccessor(DateMilliVector vector) {
475 | super(vector);
476 | this.accessor = vector;
477 | }
478 |
479 | @Override
480 | final int getInt(int rowId) {
481 | System.out.println(accessor.get(rowId) + " " + (accessor.get(rowId) * val) + " " + val);
482 | return (int) (accessor.get(rowId) * val);
483 | }
484 | }
485 |
486 | private static class TimestampMicroAccessor extends ArrowVectorAccessor {
487 |
488 | private final TimeStampVector accessor;
489 |
490 | TimestampMicroAccessor(TimeStampMicroVector vector) {
491 | super(vector);
492 | this.accessor = vector;
493 | }
494 |
495 | @Override
496 | final long getLong(int rowId) {
497 | return accessor.get(rowId);
498 | }
499 | }
500 |
501 | private static class TimestampMicroTZAccessor extends ArrowVectorAccessor {
502 |
503 | private final TimeStampVector accessor;
504 |
505 | TimestampMicroTZAccessor(TimeStampMicroTZVector vector) {
506 | super(vector);
507 | this.accessor = vector;
508 | }
509 |
510 | @Override
511 | final long getLong(int rowId) {
512 | return accessor.get(rowId);
513 | }
514 | }
515 |
516 | private static class TimestampMilliAccessor extends ArrowVectorAccessor {
517 |
518 | private final TimeStampVector accessor;
519 |
520 | TimestampMilliAccessor(TimeStampMilliVector vector) {
521 | super(vector);
522 | this.accessor = vector;
523 | }
524 |
525 | @Override
526 | final long getLong(int rowId) {
527 | return accessor.get(rowId) * 1000;
528 | }
529 | }
530 |
531 | private static class ArrayAccessor extends ArrowVectorAccessor {
532 |
533 | private final ListVector accessor;
534 | private final FlightArrowColumnVector arrayData;
535 |
536 | ArrayAccessor(ListVector vector) {
537 | super(vector);
538 | this.accessor = vector;
539 | this.arrayData = new FlightArrowColumnVector(vector.getDataVector());
540 | }
541 |
542 | @Override
543 | final boolean isNullAt(int rowId) {
544 | // TODO: Workaround if vector has all non-null values, see ARROW-1948
545 | if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) {
546 | return false;
547 | } else {
548 | return super.isNullAt(rowId);
549 | }
550 | }
551 |
552 | @Override
553 | final ColumnarArray getArray(int rowId) {
554 | ArrowBuf offsets = accessor.getOffsetBuffer();
555 | int index = rowId * ListVector.OFFSET_WIDTH;
556 | int start = offsets.getInt(index);
557 | int end = offsets.getInt(index + ListVector.OFFSET_WIDTH);
558 | return new ColumnarArray(arrayData, start, end - start);
559 | }
560 | }
561 |
562 | /**
563 | * Any call to "get" method will throw UnsupportedOperationException.
564 | *
565 | * Access struct values in a ArrowColumnVector doesn't use this accessor. Instead, it uses
566 | * getStruct() method defined in the parent class. Any call to "get" method in this class is a
567 | * bug in the code.
568 | */
569 | private static class StructAccessor extends ArrowVectorAccessor {
570 |
571 | StructAccessor(StructVector vector) {
572 | super(vector);
573 | }
574 | }
575 | }
576 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightClientFactory.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.apache.arrow.flight.spark;
17 |
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
26 |
27 | public class FlightClientFactory implements AutoCloseable {
28 | private final BufferAllocator allocator = new RootAllocator();
29 | private final Location defaultLocation;
30 | private final FlightClientOptions clientOptions;
31 |
32 | private CredentialCallOption callOption;
33 |
34 | public FlightClientFactory(Location defaultLocation, FlightClientOptions clientOptions) {
35 | this.defaultLocation = defaultLocation;
36 | this.clientOptions = clientOptions;
37 | }
38 |
39 | public FlightClient apply() {
40 | FlightClient.Builder builder = FlightClient.builder(allocator, defaultLocation);
41 |
42 | if (!clientOptions.getTrustedCertificates().isEmpty()) {
43 | builder.trustedCertificates(new ByteArrayInputStream(clientOptions.getTrustedCertificates().getBytes()));
44 | }
45 |
46 | String clientCertificate = clientOptions.getClientCertificate();
47 | if (clientCertificate != null && !clientCertificate.isEmpty()) {
48 | InputStream clientCert = new ByteArrayInputStream(clientCertificate.getBytes());
49 | InputStream clientKey = new ByteArrayInputStream(clientOptions.getClientKey().getBytes());
50 | builder.clientCertificate(clientCert, clientKey);
51 | }
52 |
53 | // Add client middleware
54 | clientOptions.getMiddleware().stream().forEach(middleware -> builder.intercept(middleware));
55 |
56 | FlightClient client = builder.build();
57 | String username = clientOptions.getUsername();
58 | if (username != null && !username.isEmpty()) {
59 | this.callOption = client.authenticateBasicToken(clientOptions.getUsername(), clientOptions.getPassword()).get();
60 | }
61 |
62 | return client;
63 | }
64 |
65 | public CredentialCallOption getCallOption() {
66 | return this.callOption;
67 | }
68 |
69 | @Override
70 | public void close() {
71 | allocator.close();
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightClientMiddlewareFactory.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import java.io.Serializable;
20 |
21 | import org.apache.arrow.flight.FlightClientMiddleware;
22 |
/**
 * Marker interface combining {@link FlightClientMiddleware.Factory} with {@link Serializable},
 * so middleware factories can be shipped inside Spark task closures to executors.
 */
public interface FlightClientMiddlewareFactory extends FlightClientMiddleware.Factory, Serializable {

}
26 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightClientOptions.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import java.io.Serializable;
20 | import java.util.List;
21 |
/**
 * Immutable, serializable bag of connection options for building Flight clients: basic-auth
 * credentials, TLS material (PEM text), and client middleware factories. Serializable so it can
 * be captured in Spark task closures.
 */
public class FlightClientOptions implements Serializable {
  private final String username;
  private final String password;
  // PEM-encoded CA certificate(s) used to verify the server, or empty/null for system defaults.
  private final String trustedCertificates;
  // PEM-encoded client certificate/key pair for mutual TLS, or empty/null when unused.
  private final String clientCertificate;
  private final String clientKey;
  // NOTE(review): raw List — presumably List<FlightClientMiddlewareFactory> (the dump appears to
  // have stripped the type argument); TODO confirm against callers and restore the generic type.
  private final List middleware;

  public FlightClientOptions(String username, String password, String trustedCertificates, String clientCertificate, String clientKey, List middleware) {
    this.username = username;
    this.password = password;
    this.trustedCertificates = trustedCertificates;
    this.clientCertificate = clientCertificate;
    this.clientKey = clientKey;
    // The caller's list is stored as-is, not defensively copied.
    this.middleware = middleware;
  }

  public String getUsername() {
    return username;
  }

  public String getPassword() {
    return password;
  }

  public String getTrustedCertificates() {
    return trustedCertificates;
  }

  public String getClientCertificate() {
    return clientCertificate;
  }

  public String getClientKey() {
    return clientKey;
  }

  public List getMiddleware() {
    return middleware;
  }
}
63 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightColumnarPartitionReader.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import java.io.IOException;
20 |
21 | import org.apache.arrow.flight.grpc.CredentialCallOption;
22 | import org.apache.spark.sql.connector.read.PartitionReader;
23 | import org.apache.spark.sql.vectorized.ColumnarBatch;
24 | import org.apache.arrow.flight.FlightClient;
25 | import org.apache.arrow.flight.FlightStream;
26 | import org.apache.arrow.util.AutoCloseables;
27 | import org.apache.spark.sql.vectorized.ColumnVector;
28 |
29 | public class FlightColumnarPartitionReader implements PartitionReader {
30 | private final FlightClientFactory clientFactory;
31 | private final FlightClient client;
32 | private final FlightStream stream;
33 |
34 | public FlightColumnarPartitionReader(FlightClientOptions clientOptions, FlightPartition partition) {
35 | // TODO - Should we handle multiple locations?
36 | clientFactory = new FlightClientFactory(partition.getEndpoint().get().getLocations().get(0), clientOptions);
37 | client = clientFactory.apply();
38 | CredentialCallOption callOption = clientFactory.getCallOption();
39 | stream = client.getStream(partition.getEndpoint().get().getTicket(), callOption);
40 | }
41 |
42 | // This is written this way because the Spark interface iterates in a different way.
43 | // E.g., .next() -> .get() vs. .hasNext() -> .next()
44 | @Override
45 | public boolean next() throws IOException {
46 | try {
47 | return stream.next();
48 | } catch (RuntimeException e) {
49 | throw new IOException(e);
50 | }
51 | }
52 |
53 | @Override
54 | public ColumnarBatch get() {
55 | ColumnarBatch batch = new ColumnarBatch(
56 | stream.getRoot().getFieldVectors()
57 | .stream()
58 | .map(FlightArrowColumnVector::new)
59 | .toArray(ColumnVector[]::new)
60 | );
61 | batch.setNumRows(stream.getRoot().getRowCount());
62 | return batch;
63 | }
64 |
65 | @Override
66 | public void close() throws IOException {
67 | try {
68 | AutoCloseables.close(stream, client, clientFactory);
69 | } catch (Exception e) {
70 | throw new IOException(e);
71 | }
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightEndpointWrapper.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import java.io.IOException;
20 | import java.io.ObjectInputStream;
21 | import java.io.ObjectOutputStream;
22 | import java.io.Serializable;
23 | import java.net.URI;
24 | import java.util.ArrayList;
25 | import java.util.stream.Collectors;
26 |
27 | import org.apache.arrow.flight.FlightEndpoint;
28 | import org.apache.arrow.flight.Location;
29 | import org.apache.arrow.flight.Ticket;
30 |
31 | // This is needed for FlightEndpoint to be Serializable in spark.
32 | // org.apache.arrow.flight.FlightEndpoint is a POJO of Serializable types.
33 | // However if spark is using build-in serialization instead of Kyro then we must implement Serializable
34 | public class FlightEndpointWrapper implements Serializable {
35 | private FlightEndpoint inner;
36 |
37 | public FlightEndpointWrapper(FlightEndpoint inner) {
38 | this.inner = inner;
39 | }
40 |
41 | public FlightEndpoint get() {
42 | return inner;
43 | }
44 |
45 | private void writeObject(ObjectOutputStream out) throws IOException {
46 | ArrayList locations = inner.getLocations().stream().map(location -> location.getUri()).collect(Collectors.toCollection(ArrayList::new));
47 | out.writeObject(locations);
48 | out.write(inner.getTicket().getBytes());
49 | }
50 |
51 | private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
52 | @SuppressWarnings("unchecked")
53 | Location[] locations = ((ArrayList) in.readObject()).stream().map(l -> new Location(l)).toArray(Location[]::new);
54 | byte[] ticket = in.readAllBytes();
55 | this.inner = new FlightEndpoint(new Ticket(ticket), locations);
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightPartition.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import org.apache.spark.sql.connector.read.InputPartition;
20 |
21 | public class FlightPartition implements InputPartition {
22 | private final FlightEndpointWrapper endpoint;
23 |
24 | public FlightPartition(FlightEndpointWrapper endpoint) {
25 | this.endpoint = endpoint;
26 | }
27 |
28 | @Override
29 | public String[] preferredLocations() {
30 | return endpoint.get().getLocations().stream().map(location -> location.getUri().getHost()).toArray(String[]::new);
31 | }
32 |
33 | public FlightEndpointWrapper getEndpoint() {
34 | return endpoint;
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReader.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import java.io.IOException;
20 | import java.util.Iterator;
21 | import java.util.Optional;
22 |
23 | import org.apache.arrow.flight.FlightClient;
24 | import org.apache.arrow.flight.FlightStream;
25 | import org.apache.arrow.flight.grpc.CredentialCallOption;
26 | import org.apache.arrow.util.AutoCloseables;
27 | import org.apache.spark.sql.catalyst.InternalRow;
28 | import org.apache.spark.sql.connector.read.PartitionReader;
29 | import org.apache.spark.sql.vectorized.ColumnVector;
30 | import org.apache.spark.sql.vectorized.ColumnarBatch;
31 |
32 | public class FlightPartitionReader implements PartitionReader {
33 | private final FlightClientFactory clientFactory;;
34 | private final FlightClient client;
35 | private final CredentialCallOption callOption;
36 | private final FlightStream stream;
37 | private Optional> batch;
38 | private InternalRow row;
39 |
40 | public FlightPartitionReader(FlightClientOptions clientOptions, FlightPartition partition) {
41 | // TODO - Should we handle multiple locations?
42 | clientFactory = new FlightClientFactory(partition.getEndpoint().get().getLocations().get(0), clientOptions);
43 | client = clientFactory.apply();
44 | callOption = clientFactory.getCallOption();
45 | stream = client.getStream(partition.getEndpoint().get().getTicket(), callOption);
46 | }
47 |
48 | private Iterator getNextBatch() {
49 | ColumnarBatch batch = new ColumnarBatch(
50 | stream.getRoot().getFieldVectors()
51 | .stream()
52 | .map(FlightArrowColumnVector::new)
53 | .toArray(ColumnVector[]::new)
54 | );
55 | batch.setNumRows(stream.getRoot().getRowCount());
56 | return batch.rowIterator();
57 | }
58 |
59 | // This is written this way because the Spark interface iterates in a different way.
60 | // E.g., .next() -> .get() vs. .hasNext() -> .next()
61 | @Override
62 | public boolean next() throws IOException {
63 | try {
64 | // Try the iterator first then get next batch
65 | // Not quite rust match expressions...
66 | return batch.map(currentBatch -> {
67 | // Are there still rows in this batch?
68 | if (currentBatch.hasNext()) {
69 | row = currentBatch.next();
70 | return true;
71 | // No more rows, get the next batch
72 | } else {
73 | // Is there another batch?
74 | if (stream.next()) {
75 | // Yes, then fetch it.
76 | Iterator nextBatch = getNextBatch();
77 | batch = Optional.of(nextBatch);
78 | if (currentBatch.hasNext()) {
79 | row = currentBatch.next();
80 | return true;
81 | // Odd, we got an empty batch
82 | } else {
83 | return false;
84 | }
85 | // This partition / stream is complete
86 | } else {
87 | return false;
88 | }
89 | }
90 | // Fetch the first batch
91 | }).orElseGet(() -> {
92 | // Is the stream empty?
93 | if (stream.next()) {
94 | // No, then fetch the first batch
95 | Iterator firstBatch = getNextBatch();
96 | batch = Optional.of(firstBatch);
97 | if (firstBatch.hasNext()) {
98 | row = firstBatch.next();
99 | return true;
100 | // Odd, we got an empty batch
101 | } else {
102 | return false;
103 | }
104 | // The stream was empty...
105 | } else {
106 | return false;
107 | }
108 | });
109 | } catch (RuntimeException e) {
110 | throw new IOException(e);
111 | }
112 | }
113 |
114 | @Override
115 | public InternalRow get() {
116 | return row;
117 | }
118 |
119 | @Override
120 | public void close() throws IOException {
121 | try {
122 | AutoCloseables.close(stream, client, clientFactory);
123 | } catch (Exception e) {
124 | throw new IOException(e);
125 | }
126 | }
127 | }
128 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightPartitionReaderFactory.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import org.apache.spark.broadcast.Broadcast;
20 | import org.apache.spark.sql.catalyst.InternalRow;
21 | import org.apache.spark.sql.connector.read.InputPartition;
22 | import org.apache.spark.sql.connector.read.PartitionReader;
23 | import org.apache.spark.sql.connector.read.PartitionReaderFactory;
24 | import org.apache.spark.sql.vectorized.ColumnarBatch;
25 |
26 | public class FlightPartitionReaderFactory implements PartitionReaderFactory {
27 | private final Broadcast clientOptions;
28 |
29 | public FlightPartitionReaderFactory(Broadcast clientOptions) {
30 | this.clientOptions = clientOptions;
31 | }
32 |
33 | @Override
34 | public PartitionReader createReader(InputPartition iPartition) {
35 | // This feels wrong but this is what upstream spark sources do to.
36 | FlightPartition partition = (FlightPartition) iPartition;
37 | return new FlightPartitionReader(clientOptions.getValue(), partition);
38 | }
39 |
40 | @Override
41 | public PartitionReader createColumnarReader(InputPartition iPartition) {
42 | // This feels wrong but this is what upstream spark sources do to.
43 | FlightPartition partition = (FlightPartition) iPartition;
44 | return new FlightColumnarPartitionReader(clientOptions.getValue(), partition);
45 | }
46 |
47 | @Override
48 | public boolean supportColumnarReads(InputPartition partition) {
49 | return true;
50 | }
51 |
52 | }
53 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightScan.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import org.apache.spark.sql.connector.read.Scan;
20 |
21 | import org.apache.arrow.flight.FlightInfo;
22 | import org.apache.spark.broadcast.Broadcast;
23 | import org.apache.spark.sql.connector.read.Batch;
24 | import org.apache.spark.sql.connector.read.InputPartition;
25 | import org.apache.spark.sql.connector.read.PartitionReaderFactory;
26 | import org.apache.spark.sql.types.StructType;
27 |
28 | public class FlightScan implements Scan, Batch {
29 | private final StructType schema;
30 | private final FlightInfo info;
31 | private final Broadcast clientOptions;
32 |
33 | public FlightScan(StructType schema, FlightInfo info, Broadcast clientOptions) {
34 | this.schema = schema;
35 | this.info = info;
36 | this.clientOptions = clientOptions;
37 | }
38 |
39 | @Override
40 | public StructType readSchema() {
41 | return schema;
42 | }
43 |
44 | @Override
45 | public Batch toBatch() {
46 | return this;
47 | }
48 |
49 | @Override
50 | public InputPartition[] planInputPartitions() {
51 | InputPartition[] batches = info.getEndpoints().stream().map(endpoint -> {
52 | FlightEndpointWrapper endpointWrapper = new FlightEndpointWrapper(endpoint);
53 | return new FlightPartition(endpointWrapper);
54 | }).toArray(InputPartition[]::new);
55 | return batches;
56 | }
57 |
58 | @Override
59 | public PartitionReaderFactory createReaderFactory() {
60 | return new FlightPartitionReaderFactory(clientOptions);
61 | }
62 |
63 | }
64 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightScanBuilder.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import java.util.List;
20 | import java.util.Map;
21 | import java.util.stream.Collectors;
22 |
23 | import org.apache.arrow.flight.FlightClient;
24 | import org.apache.arrow.flight.FlightDescriptor;
25 | import org.apache.arrow.flight.FlightInfo;
26 | import org.apache.arrow.flight.Location;
27 | import org.apache.arrow.flight.SchemaResult;
28 | import org.apache.arrow.flight.grpc.CredentialCallOption;
29 | import org.apache.arrow.util.AutoCloseables;
30 | import org.apache.arrow.vector.types.FloatingPointPrecision;
31 | import org.apache.arrow.vector.types.pojo.ArrowType;
32 | import org.apache.arrow.vector.types.pojo.FieldType;
33 | import org.apache.spark.broadcast.Broadcast;
34 | import org.apache.spark.sql.connector.read.Scan;
35 | import org.apache.spark.sql.connector.read.ScanBuilder;
36 | import org.apache.spark.sql.connector.read.SupportsPushDownFilters;
37 | import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns;
38 | import org.apache.spark.sql.sources.*;
39 | import org.apache.spark.sql.types.*;
40 | import org.slf4j.Logger;
41 | import org.slf4j.LoggerFactory;
42 |
43 | import scala.collection.JavaConversions;
44 |
45 | import com.google.common.collect.Lists;
46 | import com.google.common.base.Joiner;
47 |
48 | public class FlightScanBuilder implements ScanBuilder, SupportsPushDownRequiredColumns, SupportsPushDownFilters {
49 | private static final Logger LOGGER = LoggerFactory.getLogger(FlightScanBuilder.class);
50 | private static final Joiner WHERE_JOINER = Joiner.on(" and ");
51 | private static final Joiner PROJ_JOINER = Joiner.on(", ");
52 | private SchemaResult flightSchema;
53 | private StructType schema;
54 | private final Location location;
55 | private final Broadcast clientOptions;
56 | private FlightDescriptor descriptor;
57 | private String sql;
58 | private Filter[] pushed;
59 |
60 | public FlightScanBuilder(Location location, Broadcast clientOptions, String sql) {
61 | this.location = location;
62 | this.clientOptions = clientOptions;
63 | this.sql = sql;
64 | descriptor = getDescriptor(sql);
65 | }
66 |
67 | private class Client implements AutoCloseable {
68 | private final FlightClientFactory clientFactory;
69 | private final FlightClient client;
70 | private final CredentialCallOption callOption;
71 |
72 | public Client(Location location, FlightClientOptions clientOptions) {
73 | this.clientFactory = new FlightClientFactory(location, clientOptions);
74 | this.client = clientFactory.apply();
75 | this.callOption = clientFactory.getCallOption();
76 | }
77 |
78 | public FlightClient get() {
79 | return client;
80 | }
81 |
82 | public CredentialCallOption getCallOption() {
83 | return this.callOption;
84 | }
85 |
86 | @Override
87 | public void close() throws Exception {
88 | AutoCloseables.close(client, clientFactory);
89 | }
90 | }
91 |
92 | private void getFlightSchema(FlightDescriptor descriptor) {
93 | try (Client client = new Client(location, clientOptions.getValue())) {
94 | LOGGER.info("getSchema() descriptor: %s", descriptor);
95 | flightSchema = client.get().getSchema(descriptor, client.getCallOption());
96 | } catch (Exception e) {
97 | throw new RuntimeException(e);
98 | }
99 | }
100 |
101 | @Override
102 | public Scan build() {
103 | try (Client client = new Client(location, clientOptions.getValue())) {
104 | FlightDescriptor descriptor = FlightDescriptor.command(sql.getBytes());
105 | LOGGER.info("getInfo() descriptor: %s", descriptor);
106 | FlightInfo info = client.get().getInfo(descriptor, client.getCallOption());
107 | return new FlightScan(readSchema(), info, clientOptions);
108 | } catch (Exception e) {
109 | throw new RuntimeException(e);
110 | }
111 | }
112 |
113 | private boolean canBePushed(Filter filter) {
114 | if (filter instanceof IsNotNull) {
115 | return true;
116 | } else if (filter instanceof EqualTo) {
117 | return true;
118 | }
119 | if (filter instanceof GreaterThan) {
120 | return true;
121 | }
122 | if (filter instanceof GreaterThanOrEqual) {
123 | return true;
124 | }
125 | if (filter instanceof LessThan) {
126 | return true;
127 | }
128 | if (filter instanceof LessThanOrEqual) {
129 | return true;
130 | }
131 | LOGGER.error("Cant push filter of type " + filter.toString());
132 | return false;
133 | }
134 |
135 | private String valueToString(Object value) {
136 | if (value instanceof String) {
137 | return String.format("'%s'", value);
138 | }
139 | return value.toString();
140 | }
141 |
142 | private String generateWhereClause(List pushed) {
143 | List filterStr = Lists.newArrayList();
144 | for (Filter filter : pushed) {
145 | if (filter instanceof IsNotNull) {
146 | filterStr.add(String.format("isnotnull(\"%s\")", ((IsNotNull) filter).attribute()));
147 | } else if (filter instanceof EqualTo) {
148 | filterStr.add(String.format("\"%s\" = %s", ((EqualTo) filter).attribute(), valueToString(((EqualTo) filter).value())));
149 | } else if (filter instanceof GreaterThan) {
150 | filterStr.add(String.format("\"%s\" > %s", ((GreaterThan) filter).attribute(), valueToString(((GreaterThan) filter).value())));
151 | } else if (filter instanceof GreaterThanOrEqual) {
152 | filterStr.add(String.format("\"%s\" <= %s", ((GreaterThanOrEqual) filter).attribute(), valueToString(((GreaterThanOrEqual) filter).value())));
153 | } else if (filter instanceof LessThan) {
154 | filterStr.add(String.format("\"%s\" < %s", ((LessThan) filter).attribute(), valueToString(((LessThan) filter).value())));
155 | } else if (filter instanceof LessThanOrEqual) {
156 | filterStr.add(String.format("\"%s\" <= %s", ((LessThanOrEqual) filter).attribute(), valueToString(((LessThanOrEqual) filter).value())));
157 | }
158 | //todo fill out rest of Filter types
159 | }
160 | return WHERE_JOINER.join(filterStr);
161 | }
162 |
163 | private FlightDescriptor getDescriptor(String sql) {
164 | return FlightDescriptor.command(sql.getBytes());
165 | }
166 |
167 | private void mergeWhereDescriptors(String whereClause) {
168 | sql = String.format("select * from (%s) as where_merge where %s", sql, whereClause);
169 | descriptor = getDescriptor(sql);
170 | }
171 |
172 | @Override
173 | public Filter[] pushFilters(Filter[] filters) {
174 | List notPushed = Lists.newArrayList();
175 | List pushed = Lists.newArrayList();
176 | for (Filter filter : filters) {
177 | boolean isPushed = canBePushed(filter);
178 | if (isPushed) {
179 | pushed.add(filter);
180 | } else {
181 | notPushed.add(filter);
182 | }
183 | }
184 | this.pushed = pushed.toArray(new Filter[0]);
185 | if (!pushed.isEmpty()) {
186 | String whereClause = generateWhereClause(pushed);
187 | mergeWhereDescriptors(whereClause);
188 | getFlightSchema(descriptor);
189 | }
190 | return notPushed.toArray(new Filter[0]);
191 | }
192 |
193 | @Override
194 | public Filter[] pushedFilters() {
195 | return pushed;
196 | }
197 |
198 | private DataType sparkFromArrow(FieldType fieldType) {
199 | switch (fieldType.getType().getTypeID()) {
200 | case Null:
201 | return DataTypes.NullType;
202 | case Struct:
203 | throw new UnsupportedOperationException("have not implemented Struct type yet");
204 | case List:
205 | throw new UnsupportedOperationException("have not implemented List type yet");
206 | case FixedSizeList:
207 | throw new UnsupportedOperationException("have not implemented FixedSizeList type yet");
208 | case Union:
209 | throw new UnsupportedOperationException("have not implemented Union type yet");
210 | case Int:
211 | ArrowType.Int intType = (ArrowType.Int) fieldType.getType();
212 | int bitWidth = intType.getBitWidth();
213 | if (bitWidth == 8) {
214 | return DataTypes.ByteType;
215 | } else if (bitWidth == 16) {
216 | return DataTypes.ShortType;
217 | } else if (bitWidth == 32) {
218 | return DataTypes.IntegerType;
219 | } else if (bitWidth == 64) {
220 | return DataTypes.LongType;
221 | }
222 | throw new UnsupportedOperationException("unknown int type with bitwidth " + bitWidth);
223 | case FloatingPoint:
224 | ArrowType.FloatingPoint floatType = (ArrowType.FloatingPoint) fieldType.getType();
225 | FloatingPointPrecision precision = floatType.getPrecision();
226 | switch (precision) {
227 | case HALF:
228 | case SINGLE:
229 | return DataTypes.FloatType;
230 | case DOUBLE:
231 | return DataTypes.DoubleType;
232 | }
233 | case Utf8:
234 | return DataTypes.StringType;
235 | case Binary:
236 | case FixedSizeBinary:
237 | return DataTypes.BinaryType;
238 | case Bool:
239 | return DataTypes.BooleanType;
240 | case Decimal:
241 | throw new UnsupportedOperationException("have not implemented Decimal type yet");
242 | case Date:
243 | return DataTypes.DateType;
244 | case Time:
245 | return DataTypes.TimestampType; // note i don't know what this will do!
246 | case Timestamp:
247 | return DataTypes.TimestampType;
248 | case Interval:
249 | return DataTypes.CalendarIntervalType;
250 | case NONE:
251 | return DataTypes.NullType;
252 | default:
253 | throw new IllegalStateException("Unexpected value: " + fieldType);
254 | }
255 | }
256 |
257 | private StructType readSchemaImpl() {
258 | if (flightSchema == null) {
259 | getFlightSchema(descriptor);
260 | }
261 | StructField[] fields = flightSchema.getSchema().getFields().stream()
262 | .map(field -> new StructField(field.getName(),
263 | sparkFromArrow(field.getFieldType()),
264 | field.isNullable(),
265 | Metadata.empty()))
266 | .toArray(StructField[]::new);
267 | return new StructType(fields);
268 | }
269 |
270 | public StructType readSchema() {
271 | if (schema == null) {
272 | schema = readSchemaImpl();
273 | }
274 | return schema;
275 | }
276 |
277 | private void mergeProjDescriptors(String projClause) {
278 | sql = String.format("select %s from (%s) as proj_merge", projClause, sql);
279 | descriptor = getDescriptor(sql);
280 | }
281 |
282 | @Override
283 | public void pruneColumns(StructType requiredSchema) {
284 | if (requiredSchema.toSeq().isEmpty()) {
285 | return;
286 | }
287 | StructType schema = readSchema();
288 | List fields = Lists.newArrayList();
289 | List fieldsLeft = Lists.newArrayList();
290 | Map fieldNames = JavaConversions.seqAsJavaList(schema.toSeq()).stream()
291 | .collect(Collectors.toMap(StructField::name, f -> f));
292 | for (StructField field : JavaConversions.seqAsJavaList(requiredSchema.toSeq())) {
293 | String name = field.name();
294 | StructField f = fieldNames.remove(name);
295 | if (f != null) {
296 | fields.add(String.format("\"%s\"", name));
297 | fieldsLeft.add(f);
298 | }
299 | }
300 | if (!fieldNames.isEmpty()) {
301 | this.schema = new StructType(fieldsLeft.toArray(new StructField[0]));
302 | mergeProjDescriptors(PROJ_JOINER.join(fields));
303 | getFlightSchema(descriptor);
304 | }
305 | }
306 | }
307 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightSparkContext.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.apache.arrow.flight.spark;
17 |
18 | import org.apache.spark.SparkConf;
19 | import org.apache.spark.sql.DataFrameReader;
20 | import org.apache.spark.sql.Dataset;
21 | import org.apache.spark.sql.Row;
22 | import org.apache.spark.sql.SparkSession;
23 |
24 | public class FlightSparkContext {
25 |
26 | private SparkConf conf;
27 |
28 | private final DataFrameReader reader;
29 |
30 | public FlightSparkContext(SparkSession spark) {
31 | this.conf = spark.sparkContext().getConf();
32 | reader = spark.read().format("org.apache.arrow.flight.spark");
33 | }
34 |
35 | public Dataset read(String s) {
36 | return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port")))
37 | .option("uri", String.format(
38 | "grpc://%s:%s",
39 | conf.get("spark.flight.endpoint.host"),
40 | conf.get("spark.flight.endpoint.port")))
41 | .option("username", conf.get("spark.flight.auth.username"))
42 | .option("password", conf.get("spark.flight.auth.password"))
43 | .load(s);
44 | }
45 |
46 | public Dataset readSql(String s) {
47 | return reader.option("port", Integer.parseInt(conf.get("spark.flight.endpoint.port")))
48 | .option("uri", String.format(
49 | "grpc://%s:%s",
50 | conf.get("spark.flight.endpoint.host"),
51 | conf.get("spark.flight.endpoint.port")))
52 | .option("username", conf.get("spark.flight.auth.username"))
53 | .option("password", conf.get("spark.flight.auth.password"))
54 | .load(s);
55 | }
56 | }
57 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/FlightTable.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import java.util.Set;
20 |
21 | import org.apache.arrow.flight.Location;
22 | import org.apache.spark.broadcast.Broadcast;
23 | import org.apache.spark.sql.connector.catalog.SupportsRead;
24 | import org.apache.spark.sql.connector.catalog.Table;
25 | import org.apache.spark.sql.connector.catalog.TableCapability;
26 | import org.apache.spark.sql.connector.read.ScanBuilder;
27 | import org.apache.spark.sql.types.StructType;
28 | import org.apache.spark.sql.util.CaseInsensitiveStringMap;
29 |
30 | public class FlightTable implements Table, SupportsRead {
31 | private static final Set CAPABILITIES = Set.of(TableCapability.BATCH_READ);
32 | private final String name;
33 | private final Location location;
34 | private final String sql;
35 | private final Broadcast clientOptions;
36 | private StructType schema;
37 |
38 | public FlightTable(String name, Location location, String sql, Broadcast clientOptions) {
39 | this.name = name;
40 | this.location = location;
41 | this.sql = sql;
42 | this.clientOptions = clientOptions;
43 | }
44 |
45 | @Override
46 | public String name() {
47 | return name;
48 | }
49 |
50 | @Override
51 | public StructType schema() {
52 | if (schema == null) {
53 | FlightScanBuilder scanBuilder = new FlightScanBuilder(location, clientOptions, sql);
54 | schema = scanBuilder.readSchema();
55 | }
56 | return schema;
57 | }
58 |
59 | // TODO - We could probably implement partitioning() but it would require server side support
60 |
61 | @Override
62 | public Set capabilities() {
63 | // We only support reading for now
64 | return CAPABILITIES;
65 | }
66 |
67 | @Override
68 | public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
69 | return new FlightScanBuilder(location, clientOptions, sql);
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddleware.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import org.apache.arrow.flight.CallHeaders;
20 | import org.apache.arrow.flight.CallStatus;
21 | import org.apache.arrow.flight.FlightClientMiddleware;
22 |
23 | public class TokenClientMiddleware implements FlightClientMiddleware {
24 | private final String token;
25 |
26 | public TokenClientMiddleware(String token) {
27 | this.token = token;
28 | }
29 |
30 | @Override
31 | public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
32 | outgoingHeaders.insert("authorization", String.format("Bearer %s", token));
33 | }
34 |
35 | @Override
36 | public void onHeadersReceived(CallHeaders incomingHeaders) {
37 | // Nothing needed here
38 | }
39 |
40 | @Override
41 | public void onCallCompleted(CallStatus status) {
42 | // Nothing needed here
43 | }
44 |
45 | }
46 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/arrow/flight/spark/TokenClientMiddlewareFactory.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.arrow.flight.spark;
18 |
19 | import org.apache.arrow.flight.CallInfo;
20 | import org.apache.arrow.flight.FlightClientMiddleware;
21 |
/**
 * Factory that installs a {@link TokenClientMiddleware} carrying a fixed
 * bearer token on every Flight call started by the client.
 */
public class TokenClientMiddlewareFactory implements FlightClientMiddlewareFactory {
  private final String token;

  public TokenClientMiddlewareFactory(String token) {
    this.token = token;
  }

  /** Returns a fresh middleware instance for each call, all sharing the same token. */
  @Override
  public FlightClientMiddleware onCallStarted(CallInfo info) {
    return new TokenClientMiddleware(token);
  }

}
35 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/sql/execution/arrow/FlightArrowUtils.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.apache.spark.sql.execution.arrow
18 |
19 | import org.apache.arrow.memory.RootAllocator
20 | import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
21 | import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
22 | import org.apache.spark.sql.internal.SQLConf
23 | import org.apache.spark.sql.types._
24 | import scala.collection.JavaConverters._
25 |
26 | /**
27 | * FlightArrowUtils is a copy of ArrowUtils with extra support for DateMilli and TimestampMilli
28 | */
29 | object FlightArrowUtils {
30 |
31 | val rootAllocator = new RootAllocator(Long.MaxValue)
32 |
33 | // todo: support more types.
34 |
  /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
  def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match {
    case BooleanType => ArrowType.Bool.INSTANCE
    case ByteType => new ArrowType.Int(8, true)
    // Bit widths written as byte-count * 8; all integer types map to signed Arrow ints.
    case ShortType => new ArrowType.Int(8 * 2, true)
    case IntegerType => new ArrowType.Int(8 * 4, true)
    case LongType => new ArrowType.Int(8 * 8, true)
    case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
    case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
    case StringType => ArrowType.Utf8.INSTANCE
    case BinaryType => ArrowType.Binary.INSTANCE
    case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale)
    case DateType => new ArrowType.Date(DateUnit.DAY)
    case TimestampType =>
      // Arrow timestamps carry an explicit zone; Spark's TimestampType does not,
      // so the caller must supply one.
      if (timeZoneId == null) {
        throw new UnsupportedOperationException(
          s"${TimestampType.catalogString} must supply timeZoneId parameter")
      } else {
        new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
      }
    case _ =>
      throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}")
  }
58 |
  /**
   * Maps an Arrow type back to Spark. Unlike stock ArrowUtils, this copy also
   * accepts MILLISECOND dates and timestamps (see the object's doc comment).
   */
  def fromArrowType(dt: ArrowType): DataType = dt match {
    case ArrowType.Bool.INSTANCE => BooleanType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType
    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
    case float: ArrowType.FloatingPoint
      if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType
    case float: ArrowType.FloatingPoint
      if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType
    case ArrowType.Utf8.INSTANCE => StringType
    case ArrowType.Binary.INSTANCE => BinaryType
    case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
    // DateMilli / TimestampMilli support is the point of this fork of ArrowUtils.
    case date: ArrowType.Date if date.getUnit == DateUnit.DAY || date.getUnit == DateUnit.MILLISECOND => DateType
    case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND || ts.getUnit == TimeUnit.MILLISECOND => TimestampType
    case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt")
  }
76 |
/** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
def toArrowField(
    name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = {
  dt match {
    case ArrayType(elementType, containsNull) =>
      // Arrow lists carry a single child named "element" holding the element type.
      val child = toArrowField("element", elementType, containsNull, timeZoneId)
      new Field(name, new FieldType(nullable, ArrowType.List.INSTANCE, null), Seq(child).asJava)
    case StructType(fields) =>
      // Recurse into each struct member, preserving its own nullability.
      val children = fields.map { f =>
        toArrowField(f.name, f.dataType, f.nullable, timeZoneId)
      }.toSeq
      new Field(name, new FieldType(nullable, ArrowType.Struct.INSTANCE, null), children.asJava)
    case other =>
      // Leaf types have no children; delegate the type mapping to toArrowType.
      new Field(name, new FieldType(nullable, toArrowType(other, timeZoneId), null), Seq.empty[Field].asJava)
  }
}
96 |
/** Maps an Arrow [[Field]] (including nested list/struct fields) back to a Spark [[DataType]]. */
def fromArrowField(field: Field): DataType = {
  field.getType match {
    case ArrowType.List.INSTANCE =>
      // The list's sole child describes the element type and its nullability.
      val element = field.getChildren().get(0)
      ArrayType(fromArrowField(element), containsNull = element.isNullable)
    case ArrowType.Struct.INSTANCE =>
      StructType(field.getChildren().asScala.map { child =>
        StructField(child.getName, fromArrowField(child), child.isNullable)
      })
    case leafType => fromArrowType(leafType)
  }
}
112 |
/** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */
def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
  val arrowFields = schema.map { f =>
    toArrowField(f.name, f.dataType, f.nullable, timeZoneId)
  }
  new Schema(arrowFields.asJava)
}
119 |
/** Maps an Arrow [[Schema]] back to the equivalent Spark [[StructType]]. */
def fromArrowSchema(schema: Schema): StructType = {
  val sparkFields = schema.getFields.asScala.map { f =>
    StructField(f.getName, fromArrowField(f), f.isNullable)
  }
  StructType(sparkFields)
}
126 |
/** Return Map with conf settings to be used in ArrowPythonRunner */
def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
  // Forward the session time zone and the pandas column-matching flag to the runner.
  Map(
    SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone,
    SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
      conf.pandasGroupedMapAssignColumnsByName.toString)
}
135 | }
136 |
--------------------------------------------------------------------------------
/src/test/java/org/apache/arrow/flight/spark/TestConnector.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 The flight-spark-source Authors
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 | package org.apache.arrow.flight.spark;
17 |
18 | import java.io.ByteArrayOutputStream;
19 | import java.io.IOException;
20 | import java.io.ObjectOutputStream;
21 | import java.util.ArrayList;
22 | import java.util.Iterator;
23 | import java.util.List;
24 | import java.util.Optional;
25 | import java.util.function.Consumer;
26 |
27 | import org.apache.arrow.flight.Action;
28 | import org.apache.arrow.flight.FlightDescriptor;
29 | import org.apache.arrow.flight.FlightEndpoint;
30 | import org.apache.arrow.flight.FlightInfo;
31 | import org.apache.arrow.flight.FlightServer;
32 | import org.apache.arrow.flight.FlightTestUtil;
33 | import org.apache.arrow.flight.Location;
34 | import org.apache.arrow.flight.NoOpFlightProducer;
35 | import org.apache.arrow.flight.Result;
36 | import org.apache.arrow.flight.Ticket;
37 | import org.apache.arrow.flight.auth.ServerAuthHandler;
38 | import org.apache.arrow.flight.auth2.CallHeaderAuthenticator;
39 | import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator;
40 | import org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator;
41 | import org.apache.arrow.memory.BufferAllocator;
42 | import org.apache.arrow.memory.RootAllocator;
43 | import org.apache.arrow.util.AutoCloseables;
44 | import org.apache.arrow.vector.BigIntVector;
45 | import org.apache.arrow.vector.Float8Vector;
46 | import org.apache.arrow.vector.VarCharVector;
47 | import org.apache.arrow.vector.VectorSchemaRoot;
48 | import org.apache.arrow.vector.types.Types;
49 | import org.apache.arrow.vector.types.pojo.Field;
50 | import org.apache.arrow.vector.types.pojo.Schema;
51 | import org.apache.arrow.vector.util.Text;
52 | import org.apache.spark.api.java.JavaSparkContext;
53 | import org.apache.spark.sql.Dataset;
54 | import org.apache.spark.sql.Row;
55 | import org.apache.spark.sql.SparkSession;
56 | import org.junit.AfterClass;
57 | import org.junit.Assert;
58 | import org.junit.BeforeClass;
59 | import org.junit.Test;
60 | import org.junit.Test.None;
61 | import org.apache.arrow.flight.CallStatus;
62 | import com.google.common.collect.ImmutableList;
63 | import com.google.common.base.Strings;
64 |
65 | public class TestConnector {
66 | private static final String USERNAME_1 = "flight1";
67 | private static final String USERNAME_2 = "flight2";
68 | private static final String NO_USERNAME = "";
69 | private static final String PASSWORD_1 = "woohoo1";
70 | private static final String PASSWORD_2 = "woohoo2";
71 |
72 | private static final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
73 | private static Location location;
74 | private static FlightServer server;
75 | private static SparkSession spark;
76 | private static FlightSparkContext csc;
77 |
78 | public static CallHeaderAuthenticator.AuthResult validate(String username, String password) {
79 | if (Strings.isNullOrEmpty(username)) {
80 | throw CallStatus.UNAUTHENTICATED.withDescription("Credentials not supplied.").toRuntimeException();
81 | }
82 | final String identity;
83 | if (USERNAME_1.equals(username) && PASSWORD_1.equals(password)) {
84 | identity = USERNAME_1;
85 | } else if (USERNAME_2.equals(username) && PASSWORD_2.equals(password)) {
86 | identity = USERNAME_2;
87 | } else {
88 | throw CallStatus.UNAUTHENTICATED.withDescription("Username or password is invalid.").toRuntimeException();
89 | }
90 | return () -> identity;
91 | }
92 |
93 | @BeforeClass
94 | public static void setUp() throws Exception {
95 | FlightServer.Builder builder = FlightServer.builder(allocator,
96 | Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, /*port*/ 0),
97 | new TestProducer());
98 | builder.headerAuthenticator(
99 | new GeneratedBearerTokenAuthenticator(
100 | new BasicCallHeaderAuthenticator(TestConnector::validate)
101 | )
102 | );
103 | server = builder.build();
104 | server.start();
105 | location = server.getLocation();
106 | spark = SparkSession.builder()
107 | .appName("flightTest")
108 | .master("local[*]")
109 | .config("spark.driver.host", "127.0.0.1")
110 | .config("spark.driver.allowMultipleContexts", "true")
111 | .config("spark.flight.endpoint.host", location.getUri().getHost())
112 | .config("spark.flight.endpoint.port", Integer.toString(location.getUri().getPort()))
113 | .config("spark.flight.auth.username", USERNAME_1)
114 | .config("spark.flight.auth.password", PASSWORD_1)
115 | .getOrCreate();
116 | csc = new FlightSparkContext(spark);
117 | }
118 |
119 | @AfterClass
120 | public static void tearDown() throws Exception {
121 | AutoCloseables.close(server, allocator, spark);
122 | }
123 |
124 | private class DummyObjectOutputStream extends ObjectOutputStream {
125 | public DummyObjectOutputStream() throws IOException {
126 | super(new ByteArrayOutputStream());
127 | }
128 | }
129 |
130 | @Test(expected = None.class)
131 | public void testFlightPartitionReaderFactorySerialization() throws IOException {
132 | List middleware = new ArrayList<>();
133 | FlightClientOptions clientOptions = new FlightClientOptions("xxx", "yyy", "FooBar", "FooBar", "FooBar", middleware);
134 | FlightPartitionReaderFactory readerFactory = new FlightPartitionReaderFactory(JavaSparkContext.fromSparkContext(spark.sparkContext()).broadcast(clientOptions));
135 |
136 | try (ObjectOutputStream oos = new DummyObjectOutputStream()) {
137 | oos.writeObject(readerFactory);
138 | }
139 | }
140 |
141 | @Test(expected = None.class)
142 | public void testFlightPartitionSerialization() throws IOException {
143 | Ticket ticket = new Ticket("FooBar".getBytes());
144 | FlightEndpoint endpoint = new FlightEndpoint(ticket, location);
145 | FlightPartition partition = new FlightPartition(new FlightEndpointWrapper(endpoint));
146 | try (ObjectOutputStream oos = new DummyObjectOutputStream()) {
147 | oos.writeObject(partition);
148 | }
149 | }
150 |
151 | @Test
152 | public void testConnect() {
153 | csc.read("test.table");
154 | }
155 |
156 | @Test
157 | public void testRead() {
158 | long count = csc.read("test.table").count();
159 | Assert.assertEquals(20, count);
160 | }
161 |
162 | @Test
163 | public void testSql() {
164 | long count = csc.readSql("select * from test.table").count();
165 | Assert.assertEquals(20, count);
166 | }
167 |
168 | @Test
169 | public void testFilter() {
170 | Dataset df = csc.readSql("select * from test.table");
171 | long count = df.filter(df.col("symbol").equalTo("USDCAD")).count();
172 | long countOriginal = csc.readSql("select * from test.table").count();
173 | Assert.assertTrue(count < countOriginal);
174 | }
175 |
176 | private static class SizeConsumer implements Consumer {
177 | private int length = 0;
178 | private int width = 0;
179 |
180 | @Override
181 | public void accept(Row row) {
182 | length += 1;
183 | width = row.length();
184 | }
185 | }
186 |
187 | @Test
188 | public void testProject() {
189 | Dataset df = csc.readSql("select * from test.table");
190 | SizeConsumer c = new SizeConsumer();
191 | df.select("bid", "ask", "symbol").toLocalIterator().forEachRemaining(c);
192 | long count = c.width;
193 | long countOriginal = csc.readSql("select * from test.table").columns().length;
194 | Assert.assertTrue(count < countOriginal);
195 | }
196 |
197 | private static class TestProducer extends NoOpFlightProducer {
198 | private boolean parallel = false;
199 |
200 | @Override
201 | public void doAction(CallContext context, Action action, StreamListener listener) {
202 | parallel = true;
203 | listener.onNext(new Result("ok".getBytes()));
204 | listener.onCompleted();
205 | }
206 |
207 | @Override
208 | public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
209 | Schema schema;
210 | List endpoints;
211 | if (parallel) {
212 | endpoints = ImmutableList.of(new FlightEndpoint(new Ticket(descriptor.getCommand()), location),
213 | new FlightEndpoint(new Ticket(descriptor.getCommand()), location));
214 | } else {
215 | endpoints = ImmutableList.of(new FlightEndpoint(new Ticket(descriptor.getCommand()), location));
216 | }
217 | if (new String(descriptor.getCommand()).equals("select \"bid\", \"ask\", \"symbol\" from (select * from test.table))")) {
218 | schema = new Schema(ImmutableList.of(
219 | Field.nullable("bid", Types.MinorType.FLOAT8.getType()),
220 | Field.nullable("ask", Types.MinorType.FLOAT8.getType()),
221 | Field.nullable("symbol", Types.MinorType.VARCHAR.getType()))
222 | );
223 |
224 | } else {
225 | schema = new Schema(ImmutableList.of(
226 | Field.nullable("bid", Types.MinorType.FLOAT8.getType()),
227 | Field.nullable("ask", Types.MinorType.FLOAT8.getType()),
228 | Field.nullable("symbol", Types.MinorType.VARCHAR.getType()),
229 | Field.nullable("bidsize", Types.MinorType.BIGINT.getType()),
230 | Field.nullable("asksize", Types.MinorType.BIGINT.getType()))
231 | );
232 | }
233 | return new FlightInfo(schema, descriptor, endpoints, 1000000, 10);
234 | }
235 |
236 | @Override
237 | public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
238 | final int size = (new String(ticket.getBytes()).contains("USDCAD")) ? 5 : 10;
239 |
240 | if (new String(ticket.getBytes()).equals("select \"bid\", \"ask\", \"symbol\" from (select * from test.table))")) {
241 | Float8Vector b = new Float8Vector("bid", allocator);
242 | Float8Vector a = new Float8Vector("ask", allocator);
243 | VarCharVector s = new VarCharVector("symbol", allocator);
244 |
245 | VectorSchemaRoot root = VectorSchemaRoot.of(b, a, s);
246 | listener.start(root);
247 |
248 | //batch 1
249 | root.allocateNew();
250 | for (int i = 0; i < size; i++) {
251 | b.set(i, (double) i);
252 | a.set(i, (double) i);
253 | s.set(i, (i % 2 == 0) ? new Text("USDCAD") : new Text("EURUSD"));
254 | }
255 | b.setValueCount(size);
256 | a.setValueCount(size);
257 | s.setValueCount(size);
258 | root.setRowCount(size);
259 | listener.putNext();
260 |
261 | // batch 2
262 |
263 | root.allocateNew();
264 | for (int i = 0; i < size; i++) {
265 | b.set(i, (double) i);
266 | a.set(i, (double) i);
267 | s.set(i, (i % 2 == 0) ? new Text("USDCAD") : new Text("EURUSD"));
268 | }
269 | b.setValueCount(size);
270 | a.setValueCount(size);
271 | s.setValueCount(size);
272 | root.setRowCount(size);
273 | listener.putNext();
274 | root.clear();
275 | listener.completed();
276 | } else {
277 | BigIntVector bs = new BigIntVector("bidsize", allocator);
278 | BigIntVector as = new BigIntVector("asksize", allocator);
279 | Float8Vector b = new Float8Vector("bid", allocator);
280 | Float8Vector a = new Float8Vector("ask", allocator);
281 | VarCharVector s = new VarCharVector("symbol", allocator);
282 |
283 | VectorSchemaRoot root = VectorSchemaRoot.of(b, a, s, bs, as);
284 | listener.start(root);
285 |
286 | //batch 1
287 | root.allocateNew();
288 | for (int i = 0; i < size; i++) {
289 | bs.set(i, (long) i);
290 | as.set(i, (long) i);
291 | b.set(i, (double) i);
292 | a.set(i, (double) i);
293 | s.set(i, (i % 2 == 0) ? new Text("USDCAD") : new Text("EURUSD"));
294 | }
295 | bs.setValueCount(size);
296 | as.setValueCount(size);
297 | b.setValueCount(size);
298 | a.setValueCount(size);
299 | s.setValueCount(size);
300 | root.setRowCount(size);
301 | listener.putNext();
302 |
303 | // batch 2
304 |
305 | root.allocateNew();
306 | for (int i = 0; i < size; i++) {
307 | bs.set(i, (long) i);
308 | as.set(i, (long) i);
309 | b.set(i, (double) i);
310 | a.set(i, (double) i);
311 | s.set(i, (i % 2 == 0) ? new Text("USDCAD") : new Text("EURUSD"));
312 | }
313 | bs.setValueCount(size);
314 | as.setValueCount(size);
315 | b.setValueCount(size);
316 | a.setValueCount(size);
317 | s.setValueCount(size);
318 | root.setRowCount(size);
319 | listener.putNext();
320 | root.clear();
321 | listener.completed();
322 | }
323 | }
324 |
325 |
326 | }
327 | }
328 |
--------------------------------------------------------------------------------
/src/test/resources/logback-test.xml:
--------------------------------------------------------------------------------
1 |
2 |
19 |
20 |
22 | true
23 | 10000
24 | true
25 | ${LILITH_HOSTNAME:-localhost}
26 |
27 |
28 |
29 |
30 | %highlight %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------