├── .gitignore ├── .mypy.ini ├── DEVELOPING.md ├── KHOP.md ├── LICENSE ├── MANIFESTO.md ├── README.md ├── build.gradle ├── client ├── README.md └── src │ └── main │ └── java │ └── org │ └── neo4j │ └── arrow │ └── demo │ └── Client.java ├── common ├── README.md └── src │ ├── main │ └── java │ │ └── org │ │ └── neo4j │ │ └── arrow │ │ └── action │ │ ├── CypherActionHandler.java │ │ └── CypherMessage.java │ └── test │ └── java │ └── org │ └── neo4j │ └── arrow │ └── action │ └── CypherMessageTest.java ├── examples ├── BulkImport.ipynb ├── PyArrow Demo - 17 Aug 2021.pdf ├── PyArrow Demo.ipynb ├── arrow_to_bq.py ├── example.py ├── khop.py ├── live_migration_demo.ipynb └── yolo.py ├── fast.gif ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── khop.svg ├── plugin ├── README.md └── src │ ├── main │ ├── java │ │ └── org │ │ │ └── neo4j │ │ │ └── arrow │ │ │ ├── ArrowExtensionFactory.java │ │ │ ├── ArrowService.java │ │ │ ├── CypherRecord.java │ │ │ ├── GdsNodeRecord.java │ │ │ ├── GdsRecord.java │ │ │ ├── GdsRelationshipRecord.java │ │ │ ├── Neo4jDefaults.java │ │ │ ├── SubGraphRecord.java │ │ │ ├── action │ │ │ ├── BulkImportActionHandler.java │ │ │ ├── BulkImportMessage.java │ │ │ ├── GdsActionHandler.java │ │ │ ├── GdsMessage.java │ │ │ ├── GdsWriteNodeMessage.java │ │ │ ├── GdsWriteRelsMessage.java │ │ │ └── KHopMessage.java │ │ │ ├── auth │ │ │ ├── ArrowConnectionInfo.java │ │ │ └── NativeAuthValidator.java │ │ │ ├── batchimport │ │ │ ├── NodeInputIterator.java │ │ │ ├── QueueInputIterator.java │ │ │ └── RelationshipInputIterator.java │ │ │ ├── gds │ │ │ ├── ArrowAdjacencyCursor.java │ │ │ ├── ArrowAdjacencyList.java │ │ │ ├── ArrowGraphStore.java │ │ │ ├── ArrowNodeProperties.java │ │ │ ├── Edge.java │ │ │ ├── KHop.java │ │ │ ├── NodeHistory.java │ │ │ └── SuperNodeCache.java │ │ │ └── job │ │ │ ├── BulkImportJob.java │ │ │ ├── GdsReadJob.java │ │ │ ├── GdsWriteJob.java │ │ │ └── TransactionApiJob.java │ 
└── resources │ │ └── META-INF │ │ └── services │ │ └── org.neo4j.kernel.extension.ExtensionFactory │ └── test │ └── java │ └── org │ └── neo4j │ └── arrow │ ├── GdsRecordBenchmarkTest.java │ ├── action │ └── GdsMessageTest.java │ ├── auth │ └── NativeAuthValidatorTest.java │ ├── gds │ └── NodeHistoryTest.java │ └── job │ ├── CypherRecordTest.java │ └── EdgePackingTest.java ├── python ├── neo4j_arrow.py ├── pyarrow │ ├── __init__.pyi │ ├── csv.pyi │ ├── flight.pyi │ └── lib.pyi ├── pyimport.py └── requirements.txt ├── server ├── README.md └── src │ └── main │ └── java │ └── org │ └── neo4j │ └── arrow │ ├── DriverRecord.java │ ├── demo │ └── Neo4jProxyServer.java │ └── job │ ├── AsyncDriverJob.java │ └── DriverJobSummary.java ├── settings.gradle ├── speed ├── 26-aug-2021 │ ├── README.md │ ├── figure1.png │ └── img.png ├── client.py └── direct.py ├── src ├── main │ └── java │ │ └── org │ │ └── neo4j │ │ └── arrow │ │ ├── App.java │ │ ├── Config.java │ │ ├── Producer.java │ │ ├── RowBasedRecord.java │ │ ├── WorkBuffer.java │ │ ├── action │ │ ├── ActionHandler.java │ │ ├── Message.java │ │ ├── Outcome.java │ │ ├── ServerInfoHandler.java │ │ ├── StatusHandler.java │ │ └── auth │ │ │ └── HorribleBasicAuthValidator.java │ │ ├── batch │ │ ├── ArrowBatch.java │ │ ├── ArrowBatches.java │ │ └── BatchedVector.java │ │ └── job │ │ ├── Job.java │ │ ├── JobCreator.java │ │ ├── JobSummary.java │ │ ├── ReadJob.java │ │ └── WriteJob.java └── test │ ├── java │ └── org │ │ └── neo4j │ │ └── arrow │ │ ├── ArrowBatchesTest.java │ │ ├── ArrowListTests.java │ │ ├── Neo4jFlightServerTest.java │ │ ├── NoOpBenchmark.java │ │ ├── SillyAsyncTest.java │ │ ├── SillyBenchmark.java │ │ ├── SillyStreamsTest.java │ │ └── action │ │ └── ServerInfoHandlerTest.java │ └── resources │ └── META-INF │ └── MANIFEST.MF ├── strawman ├── README.md └── src │ └── main │ └── java │ └── org │ └── neo4j │ └── arrow │ └── Neo4jDirectClient.java └── todo.sh /.gitignore: 
-------------------------------------------------------------------------------- 1 | .idea/ 2 | .gradle/ 3 | build/ 4 | *.swp 5 | .#* 6 | venv/ 7 | __pycache__/ 8 | /python/neo4j_arrow.iml 9 | -------------------------------------------------------------------------------- /.mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | files = python/*.py 3 | python_version = 3.9 4 | strict = True 5 | 6 | -------------------------------------------------------------------------------- /client/README.md: -------------------------------------------------------------------------------- 1 | # neo4j-arrow java client 2 | A quick and dirty Java client for talking to a `neo4j-arrow` service. 3 | 4 | ## Building 5 | In the parent project: 6 | 7 | ``` 8 | $ ./gradlew :client:shadowJar 9 | ``` 10 | 11 | ## Running 12 | You can just run the uberjar created by Gradle. See the [server](../server) 13 | docs for details on environment variables for configuration. 14 | 15 | ``` 16 | $ java -jar path/to/client-1.0-SNAPSHOT.jar 17 | ``` 18 | 19 | ## Extending or Embedding 20 | The `org.neo4j.arrow.demo.Client` class has both a `static main` method as 21 | well as a Class structure with a `run` method. Either use from the command 22 | line or feel free to instantiate, configure, and run from other JVM based code! -------------------------------------------------------------------------------- /common/README.md: -------------------------------------------------------------------------------- 1 | # common files 2 | These classes are just useful in multiple sub-projects...not much to see here!
-------------------------------------------------------------------------------- /common/src/main/java/org/neo4j/arrow/action/CypherMessage.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import com.fasterxml.jackson.databind.ObjectMapper; 4 | import org.neo4j.cypher.internal.ast.Statement; 5 | import org.neo4j.cypher.internal.parser.CypherParser; 6 | import org.neo4j.cypher.internal.util.CypherExceptionFactory; 7 | import org.neo4j.cypher.internal.util.InputPosition; 8 | import org.neo4j.cypher.internal.util.OpenCypherExceptionFactory; 9 | import scala.Option; 10 | 11 | import java.io.IOException; 12 | import java.nio.ByteBuffer; 13 | import java.nio.ByteOrder; 14 | import java.nio.charset.StandardCharsets; 15 | import java.util.Map; 16 | 17 | public class CypherMessage { 18 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(CypherMessage.class); 19 | 20 | static private final ObjectMapper mapper = new ObjectMapper(); 21 | static private final CypherParser parser = new CypherParser(); 22 | 23 | private final String cypher; 24 | private final String database; 25 | private final Map params; 26 | 27 | public CypherMessage(String database, String cypher) { 28 | this(database, cypher, Map.of()); 29 | } 30 | 31 | public CypherMessage(String database, String cypher, Map params) { 32 | this.cypher = cypher; 33 | this.params = params; 34 | this.database = database; 35 | 36 | Statement stmt = parser.parse(cypher, new CypherExceptionFactory() { 37 | @Override 38 | public Exception arithmeticException(String message, Exception cause) { 39 | logger.error(message, cause); 40 | return cause; 41 | } 42 | 43 | @Override 44 | public Exception syntaxException(String message, InputPosition pos) { 45 | Exception e = new OpenCypherExceptionFactory.SyntaxException(message, pos); 46 | logger.error(message, e); 47 | return e; 48 | } 49 | }, 
Option.apply(InputPosition.NONE())); 50 | 51 | logger.info("parsed: {}", stmt.asCanonicalStringVal()); 52 | } 53 | 54 | public static CypherMessage deserialize(byte[] bytes) throws IOException { 55 | ByteBuffer buffer = ByteBuffer.wrap(bytes) 56 | .asReadOnlyBuffer() 57 | .order(ByteOrder.BIG_ENDIAN); 58 | 59 | short len = buffer.getShort(); 60 | byte[] slice = new byte[len]; 61 | buffer.get(slice); 62 | String cypher = new String(slice, StandardCharsets.UTF_8); 63 | 64 | len = buffer.getShort(); 65 | slice = new byte[len]; 66 | buffer.get(slice); 67 | String database = new String(slice, StandardCharsets.UTF_8); 68 | 69 | len = buffer.getShort(); 70 | slice = new byte[len]; 71 | buffer.get(slice); 72 | Map params = mapper.createParser(slice).readValueAs(Map.class); 73 | 74 | return new CypherMessage(database, cypher, params); 75 | } 76 | 77 | /** 78 | * Serialize the CypherMessage to a platform independent format, kept very simple for now, in 79 | * network byte order. 80 | * 81 | * @return byte[] 82 | */ 83 | public byte[] serialize() { 84 | // not most memory sensitive approach for now, but simple 85 | try { 86 | final byte[] cypherBytes = cypher.getBytes(StandardCharsets.UTF_8); 87 | final byte[] databaseBytes = database.getBytes(StandardCharsets.UTF_8); 88 | final byte[] paramsBytes = mapper.writeValueAsString(params).getBytes(StandardCharsets.UTF_8); 89 | 90 | ByteBuffer buffer = ByteBuffer.allocate(cypherBytes.length + paramsBytes.length + 12); 91 | // Size prefixes, as scalars, are transmitted in network byte order 92 | buffer.order(ByteOrder.BIG_ENDIAN); 93 | 94 | // Cypher length 95 | buffer.putShort((short) cypherBytes.length); 96 | // Cypher utf8 string 97 | buffer.put(cypherBytes); 98 | // Database name length 99 | buffer.putShort((short) databaseBytes.length); 100 | // Database utf8 string 101 | buffer.put(databaseBytes); 102 | // JSON params length 103 | buffer.putShort((short) paramsBytes.length); 104 | // JSON params utf8 string 105 | 
buffer.put(paramsBytes); 106 | 107 | return buffer.array(); 108 | } catch (Exception e) { 109 | logger.error("serialization error", e); 110 | } 111 | return new byte[0]; 112 | } 113 | 114 | public String getCypher() { 115 | return cypher; 116 | } 117 | 118 | public String getDatabase() { 119 | return database; 120 | } 121 | 122 | public Map getParams() { 123 | return params; 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /common/src/test/java/org/neo4j/arrow/action/CypherMessageTest.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import org.junit.jupiter.api.Assertions; 4 | import org.junit.jupiter.api.Test; 5 | 6 | import java.io.IOException; 7 | import java.nio.ByteBuffer; 8 | import java.nio.charset.StandardCharsets; 9 | import java.util.Map; 10 | 11 | public class CypherMessageTest { 12 | private static final org.slf4j.Logger logger; 13 | 14 | static { 15 | // Set up nicer logging output. 
16 | System.setProperty("org.slf4j.simpleLogger.showDateTime", "true"); 17 | System.setProperty("org.slf4j.simpleLogger.dateTimeFormat", "[yyyy-MM-dd'T'HH:mm:ss:SSS]"); 18 | System.setProperty("org.slf4j.simpleLogger.logFile", "System.out"); 19 | logger = org.slf4j.LoggerFactory.getLogger(CypherMessageTest.class); 20 | } 21 | 22 | @Test 23 | public void canSerializeCypherMessages() throws IOException { 24 | final String cypher = "MATCH (n:MyLabel) RETURN [n.prop1] AS prop1, n AS node;"; 25 | final Map map = Map.of("n", 123, "name", "Dave"); 26 | 27 | CypherMessage msg = new CypherMessage("neo4j", cypher, map); 28 | byte[] bytes = msg.serialize(); 29 | Assertions.assertTrue(bytes.length > 0); 30 | 31 | CypherMessage msg2 = CypherMessage.deserialize(bytes); 32 | System.out.println(StandardCharsets.UTF_8.decode(ByteBuffer.wrap(bytes))); 33 | 34 | Assertions.assertEquals(cypher, msg2.getCypher()); 35 | Assertions.assertEquals(map, msg2.getParams()); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /examples/PyArrow Demo - 17 Aug 2021.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neo4j-field/neo4j-arrow/4eed6311eb5b0f7beae77a668ed282ce1cdced83/examples/PyArrow Demo - 17 Aug 2021.pdf -------------------------------------------------------------------------------- /examples/arrow_to_bq.py: -------------------------------------------------------------------------------- 1 | import pyarrow as pa 2 | from src.main.neo4j_arrow import neo4j_arrow as na 3 | from google.cloud import bigquery 4 | 5 | import os, sys, threading, queue 6 | from time import time 7 | 8 | ### Config 9 | HOST = os.environ.get('NEO4J_ARROW_HOST', 'localhost') 10 | PORT = int(os.environ.get('NEO4J_ARROW_PORT', '9999')) 11 | USERNAME = os.environ.get('NEO4J_USERNAME', 'neo4j') 12 | PASSWORD = os.environ.get('NEO4J_PASSWORD', 'password') 13 | GRAPH = os.environ.get('NEO4J_GRAPH', 
'random') 14 | PROPERTIES = [ 15 | p for p in os.environ.get('NEO4J_PROPERTIES', '').split(',') 16 | if len(p) > 0] 17 | TLS = len(os.environ.get('NEO4J_ARROW_TLS', '')) > 0 18 | TLS_VERIFY = len(os.environ.get('NEO4J_ARROW_TLS_NO_VERIFY', '')) < 1 19 | DATASET = os.environ.get('BQ_DATASET', 'neo4j_arrow') 20 | TABLE = os.environ.get('BQ_TABLE', 'nodes') 21 | DELETE = len(os.environ.get('BQ_DELETE_FIRST', '')) > 0 22 | 23 | ### Globals 24 | bq_client = bigquery.Client() 25 | data_q = queue.Queue() 26 | job_q = queue.Queue(maxsize=8) 27 | done_feeding = threading.Event() 28 | 29 | ### Various Load Options 30 | parquet_options = bigquery.format_options.ParquetOptions() 31 | parquet_options.enable_list_inference = True 32 | bigquery.dataset = DATASET 33 | job_config = bigquery.LoadJobConfig() 34 | job_config.source_format = bigquery.SourceFormat.PARQUET 35 | job_config.parquet_options = parquet_options 36 | 37 | 38 | def upload_complete(load_job): 39 | # print(f'bq upload complete ({load_job.job_id})') 40 | q.task_done() 41 | 42 | def write_to_bigquery(writer_id, table, job_config): 43 | """Convert a PyArrow Table to a Parquet file and load into BigQuery""" 44 | writer = pa.BufferOutputStream() 45 | pa.parquet.write_table(table, writer, use_compliant_nested_type=True) 46 | pq_reader = pa.BufferReader(writer.getvalue()) 47 | 48 | load_job = bq_client.load_table_from_file( 49 | pq_reader, f'{DATASET}.{TABLE}', job_config=job_config) 50 | load_job.add_done_callback(upload_complete) 51 | return load_job 52 | 53 | def bq_writer(writer_id): 54 | """Primary logice for a BigQuery writer thread""" 55 | global done_feeding, job_config 56 | jobs = [] 57 | print(f"w({writer_id}): writer starting") 58 | 59 | while True: 60 | try: 61 | batch = q.get(timeout=5) 62 | if len(batch) < 1: 63 | break 64 | table = pa.Table.from_batches(batch, batch[0].schema) 65 | load_job = write_to_bigquery(writer_id, table, job_config) 66 | jobs.append(load_job) 67 | except queue.Empty: 68 | # use this 
as a chance to cleanup our jobs 69 | _jobs = [] 70 | for j in jobs: 71 | if j.running(): 72 | _jobs.append(j) 73 | elif j.error_result: 74 | # consider this fatal for now...might have hit a rate-limit! 75 | print(f"w({writer_id}): job {j} had an error {j.error_result}!!!") 76 | sys.exit(1) 77 | if len(_jobs) > 0: 78 | print(f"w({writer_id}): waiting on {len(jobs)} bq load jobs") 79 | elif done_feeding.is_set(): 80 | break 81 | print(f"w({writer_id}): finished") 82 | 83 | 84 | def stream_records(reader): 85 | """Consume a neo4j-arrow GDS stream and populate a work queue""" 86 | print('Start arrow table processing') 87 | 88 | cnt, rows, nbytes = 0, 0, 0 89 | batch = [] 90 | start = time() 91 | for chunk, metadata, in reader: 92 | cnt = cnt + chunk.num_rows 93 | rows = rows + chunk.num_rows 94 | nbytes = nbytes + chunk.nbytes 95 | batch.append(chunk) 96 | if rows >= 100_000: 97 | q.put(batch) 98 | nbytes = (nbytes >> 20) 99 | print(f"stream row @ {cnt:,}, batch size: {rows:,} rows, {nbytes:,} MiB") 100 | batch = [] 101 | rows, nbytes = 0, 0 102 | if len(batch) > 0: 103 | # add any remaining data 104 | q.put(batch) 105 | 106 | # signal we're done consuming the source feed and wait for work to complete 107 | done_feeding.set() 108 | q.join() 109 | 110 | finish = time() 111 | print(f"Done! 
Time Delta: {round(finish - start, 1):,}s") 112 | print(f"Count: {cnt:,} rows, Rate: {round(cnt / (finish - start)):,} rows/s") 113 | 114 | 115 | if __name__ == "__main__": 116 | print("Creating neo4j-arrow client") 117 | client = na.Neo4jArrow(USERNAME, PASSWORD, (HOST, PORT), 118 | tls=TLS, verifyTls=TLS_VERIFY) 119 | 120 | print("Submitting read job for graph '{GRAPH}'") 121 | ticket = client.gds_nodes(GRAPH, properties=PROPERTIES) 122 | 123 | print("Starting worker threads") 124 | threads = [] 125 | for i in range(0, 12): 126 | t = threading.Thread(target=bq_writer, daemon=True, args=[i]) 127 | threads.append(t) 128 | t.start() 129 | 130 | print(f"Streaming nodes from {GRAPH} with properties {PROPERTIES}") 131 | client.wait_for_job(ticket, timeout=180) 132 | reader = client.stream(ticket) 133 | stream_records(reader) 134 | 135 | # try to nicely wait for threads to finish just in case 136 | for t in threads: 137 | t.join(timeout=60) 138 | -------------------------------------------------------------------------------- /examples/example.py: -------------------------------------------------------------------------------- 1 | from src.main.neo4j_arrow import neo4j_arrow as na 2 | 3 | client = na.Neo4jArrow('neo4j', 'password', ('voutila-arrow-test', 9999)) 4 | ticket = client.gds_nodes('mygraph', properties=['n']) 5 | 6 | table = client.stream(ticket).read_all() 7 | print(f"table num_rows={table.num_rows:,}") 8 | 9 | data = table[5:8].to_pydict() 10 | print(f"nodeIds: {data['nodeId']}") 11 | print(f"embeddings (sample): {[x[0:4] for x in data['n']]}") 12 | 13 | -------------------------------------------------------------------------------- /examples/khop.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from src.main.neo4j_arrow import neo4j_arrow as na 3 | from time import time 4 | import os, sys 5 | 6 | graph = sys.argv[-1] 7 | db = sys.argv[-2] 8 | HOST = os.environ.get('HOST', 'localhost') 9 | 
print(f"using graph {graph} in db {db} on host {HOST}") 10 | 11 | client = na.Neo4jArrow('neo4j', 'password', (HOST, 9999)) 12 | 13 | t = client.khop(graph, database=db) 14 | rows, nbytes, log_cnt, log_nbytes = 0, 0, 0, 0 15 | 16 | start = time() 17 | original_start = start 18 | for (chunk, _) in client.stream(t): 19 | rows = rows + chunk.num_rows 20 | log_cnt = log_cnt + chunk.num_rows 21 | nbytes = nbytes + chunk.nbytes 22 | log_nbytes = log_nbytes + chunk.nbytes 23 | 24 | if log_cnt >= 1_000_000: 25 | delta = time() - start 26 | rate = int(log_nbytes / delta) 27 | print(f"@ {rows:,} rows, {(nbytes >> 20):,} MiB, rate: {(rate >> 20):,} MiB/s, {int(log_cnt / delta):,} row/s") 28 | log_cnt, log_nbytes = 0, 0 29 | start = time() 30 | 31 | delta = time() - original_start 32 | rate = int(nbytes / delta) 33 | print(f"consumed {rows:,}, {(nbytes >> 20):,} MiB") 34 | print(f"rate: {(rate >> 20):,} MiB/s, {int(rows / delta):,}") 35 | 36 | 37 | -------------------------------------------------------------------------------- /examples/yolo.py: -------------------------------------------------------------------------------- 1 | from src.main.neo4j_arrow import neo4j_arrow as na 2 | from time import sleep 3 | 4 | client = na.Neo4jArrow('neo4j', 'password', ('', 9999)) 5 | 6 | db, graph, props = 'neo4j', 'mygraph', ['fastRp'] 7 | print(f"fetching graph '{graph}' from db '{db}' with node props {props}") 8 | 9 | ticket = client.gds_nodes(graph, database=db, properties=props) 10 | nodes = client.stream(ticket).read_all().to_pandas() 11 | print(f"Got nodes:\n{nodes}") 12 | 13 | ticket = client.gds_relationships(graph, database=db) 14 | rels = client.stream(ticket).read_all().to_pandas() 15 | # for now we need to slice off rel props 16 | subrels = rels[['_source_id_', '_target_id_', '_type_']] 17 | print(f"Got relationships:\n{subrels}") 18 | 19 | #### Now for writing! 
20 | db, graph = 'hacks', 'test' 21 | print(f"Writing our graph to db '{db}', graph '{graph}'") 22 | 23 | ticket = client.gds_write_nodes(graph, database=db) 24 | client.put_stream(ticket, nodes) 25 | print("Wrote nodes...") 26 | 27 | sleep(0.5) 28 | 29 | ticket = client.gds_write_relationships(graph, database=db) 30 | client.put_stream(ticket, subrels) 31 | print("Wrote rels...") 32 | 33 | 34 | -------------------------------------------------------------------------------- /fast.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neo4j-field/neo4j-arrow/4eed6311eb5b0f7beae77a668ed282ce1cdced83/fast.gif -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neo4j-field/neo4j-arrow/4eed6311eb5b0f7beae77a668ed282ce1cdced83/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | distributionUrl=https\://services.gradle.org/distributions/gradle-7.1.1-bin.zip 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | # 4 | # Copyright 2015 the original author or authors. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | ############################################################################## 20 | ## 21 | ## Gradle start up script for UN*X 22 | ## 23 | ############################################################################## 24 | 25 | # Attempt to set APP_HOME 26 | # Resolve links: $0 may be a link 27 | PRG="$0" 28 | # Follow the symlink chain; a relative link target is resolved against the directory of the link itself. 29 | while [ -h "$PRG" ] ; do 30 | ls=`ls -ld "$PRG"` 31 | link=`expr "$ls" : '.*-> \(.*\)$'` 32 | if expr "$link" : '/.*' > /dev/null; then 33 | PRG="$link" 34 | else 35 | PRG=`dirname "$PRG"`"/$link" 36 | fi 37 | done 38 | SAVED="`pwd`" 39 | cd "`dirname \"$PRG\"`/" >/dev/null 40 | APP_HOME="`pwd -P`" 41 | cd "$SAVED" >/dev/null 42 | 43 | APP_NAME="Gradle" 44 | APP_BASE_NAME=`basename "$0"` 45 | 46 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 47 | DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' 48 | 49 | # "maximum" (or "max") raises the open-file limit to the hard limit (not attempted on Cygwin, Darwin or NonStop); set MAX_FD to a number to request that many descriptors instead. 50 | MAX_FD="maximum" 51 | 52 | warn () { 53 | echo "$*" 54 | } 55 | 56 | die () { 57 | echo 58 | echo "$*" 59 | echo 60 | exit 1 61 | } 62 | 63 | # OS specific support (must be 'true' or 'false').
64 | cygwin=false 65 | msys=false 66 | darwin=false 67 | nonstop=false 68 | case "`uname`" in 69 | CYGWIN* ) 70 | cygwin=true 71 | ;; 72 | Darwin* ) 73 | darwin=true 74 | ;; 75 | MSYS* | MINGW* ) 76 | msys=true 77 | ;; 78 | NONSTOP* ) 79 | nonstop=true 80 | ;; 81 | esac 82 | 83 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 84 | 85 | 86 | # Determine the Java command to use to start the JVM. 87 | if [ -n "$JAVA_HOME" ] ; then 88 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 89 | # IBM's JDK on AIX uses strange locations for the executables 90 | JAVACMD="$JAVA_HOME/jre/sh/java" 91 | else 92 | JAVACMD="$JAVA_HOME/bin/java" 93 | fi 94 | if [ ! -x "$JAVACMD" ] ; then 95 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 96 | 97 | Please set the JAVA_HOME variable in your environment to match the 98 | location of your Java installation." 99 | fi 100 | else 101 | JAVACMD="java" 102 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 103 | 104 | Please set the JAVA_HOME variable in your environment to match the 105 | location of your Java installation." 106 | fi 107 | 108 | # Increase the maximum file descriptors if we can. 109 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then 110 | MAX_FD_LIMIT=`ulimit -H -n` 111 | if [ $? -eq 0 ] ; then 112 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 113 | MAX_FD="$MAX_FD_LIMIT" 114 | fi 115 | ulimit -n $MAX_FD 116 | if [ $?
-ne 0 ] ; then 117 | warn "Could not set maximum file descriptor limit: $MAX_FD" 118 | fi 119 | else 120 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 121 | fi 122 | fi 123 | 124 | # For Darwin, add options to specify how the application appears in the dock 125 | if $darwin; then 126 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 127 | fi 128 | 129 | # For Cygwin or MSYS, switch paths to Windows format before running java 130 | if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then 131 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 132 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 133 | 134 | JAVACMD=`cygpath --unix "$JAVACMD"` 135 | 136 | # We build the pattern for arguments to be converted via cygpath 137 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 138 | SEP="" 139 | for dir in $ROOTDIRSRAW ; do 140 | ROOTDIRS="$ROOTDIRS$SEP$dir" 141 | SEP="|" 142 | done 143 | OURCYGPATTERN="(^($ROOTDIRS))" 144 | # Add a user-defined pattern to the cygpath arguments 145 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 146 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 147 | fi 148 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 149 | i=0 150 | for arg in "$@" ; do 151 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 152 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 153 | 154 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 155 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 156 | else 157 | eval `echo args$i`="\"$arg\"" 158 | fi 159 | i=`expr $i + 1` 160 | done 161 | case $i in 162 | 0) set -- ;; 163 | 1) set -- "$args0" ;; 164 | 2) set -- "$args0" "$args1" ;; 165 | 3) set -- "$args0" "$args1" "$args2" ;; 166 | 4) set -- "$args0" "$args1" "$args2" "$args3" ;; 167 | 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 168 | 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 169 |
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; 170 | 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 171 | 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 172 | esac 173 | fi 174 | 175 | # Escape application args: wrap each one in single quotes (escaping embedded single quotes) so the eval below re-splits them exactly as given. 176 | save () { 177 | for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done 178 | echo " " 179 | } 180 | APP_ARGS=`save "$@"` 181 | 182 | # Collect all arguments for the java command, following the shell quoting and substitution rules 183 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" 184 | 185 | exec "$JAVACMD" "$@" 186 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @rem 2 | @rem Copyright 2015 the original author or authors. 3 | @rem 4 | @rem Licensed under the Apache License, Version 2.0 (the "License"); 5 | @rem you may not use this file except in compliance with the License. 6 | @rem You may obtain a copy of the License at 7 | @rem 8 | @rem https://www.apache.org/licenses/LICENSE-2.0 9 | @rem 10 | @rem Unless required by applicable law or agreed to in writing, software 11 | @rem distributed under the License is distributed on an "AS IS" BASIS, 12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | @rem See the License for the specific language governing permissions and 14 | @rem limitations under the License.
15 | @rem 16 | 17 | @if "%DEBUG%" == "" @echo off 18 | @rem ########################################################################## 19 | @rem 20 | @rem Gradle startup script for Windows 21 | @rem 22 | @rem ########################################################################## 23 | 24 | @rem Set local scope for the variables with windows NT shell 25 | if "%OS%"=="Windows_NT" setlocal 26 | 27 | set DIRNAME=%~dp0 28 | if "%DIRNAME%" == "" set DIRNAME=. 29 | set APP_BASE_NAME=%~n0 30 | set APP_HOME=%DIRNAME% 31 | 32 | @rem Resolve any "." and ".." in APP_HOME to make it shorter. 33 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi 34 | 35 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 36 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" 37 | 38 | @rem Find java.exe 39 | if defined JAVA_HOME goto findJavaFromJavaHome 40 | 41 | set JAVA_EXE=java.exe 42 | %JAVA_EXE% -version >NUL 2>&1 43 | if "%ERRORLEVEL%" == "0" goto execute 44 | 45 | echo. 46 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 47 | echo. 48 | echo Please set the JAVA_HOME variable in your environment to match the 49 | echo location of your Java installation. 50 | 51 | goto fail 52 | 53 | :findJavaFromJavaHome 54 | set JAVA_HOME=%JAVA_HOME:"=% 55 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 56 | 57 | if exist "%JAVA_EXE%" goto execute 58 | 59 | echo. 60 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 61 | echo. 62 | echo Please set the JAVA_HOME variable in your environment to match the 63 | echo location of your Java installation.
64 | 65 | goto fail 66 | 67 | :execute 68 | @rem Setup the command line 69 | 70 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 71 | 72 | 73 | @rem Execute Gradle 74 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* 75 | 76 | :end 77 | @rem End local scope for the variables with windows NT shell 78 | if "%ERRORLEVEL%"=="0" goto mainEnd 79 | 80 | :fail 81 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 82 | rem the _cmd.exe /c_ return code! 83 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 84 | exit /b 1 85 | 86 | :mainEnd 87 | if "%OS%"=="Windows_NT" endlocal 88 | 89 | :omega 90 | -------------------------------------------------------------------------------- /plugin/README.md: -------------------------------------------------------------------------------- 1 | # neo4j-arrow server plugin 2 | 3 | A Neo4j server-side extension exposing Arrow Flight services. This is the 4 | fun project you should build and play with :-) 5 | 6 | ## Building 7 | In the parent project, run: 8 | 9 | ``` 10 | $ ./gradlew :plugin:shadowJar 11 | ``` 12 | 13 | That's it! You'll get a `-all` plugin uberjar in: `plugin/build/libs` 14 | 15 | ## Installing 16 | Drop the uberjar into a 4.3 version of Neo4j Enterprise (hasn't been tested 17 | with Community Edition). You also need a recent 1.6.x version of GDS. 18 | 19 | > I recommend using `v1.6.4` of GDS as that's what the project builds with! 20 | 21 | ## Running 22 | The plugin exposes a new service on the Neo4j server. By default, this is on 23 | TCP port `9999`, though this is configurable via environment variables. (For 24 | now, see the [server](../server) docs for the list of environment variables.) 25 | 26 | ## Ok, now what?
27 | Grab an Arrow client, either: 28 | 29 | * [neo4j_arrow.py](../python/neo4j_arrow.py) for Python 30 | * [org.neo4j.arrow.demo.Client](../client) for Java 31 | 32 | Docs are still coming together, but see the recent 33 | [demo walk-through](../examples/PyArrow%20Demo.ipynb) using the PyArrow-based client 34 | for an example. 35 | 36 | ## An Example using Docker 37 | Assuming you've got a local directory called `./plugins` that has GDS 1.7 38 | and the `neo4j-arrow` jar file (and is readable by uid:gid 7474:7474): 39 | 40 | > Tune heap/pagecache as needed 41 | 42 | ```shell 43 | #!/bin/sh 44 | HOST=0.0.0.0 45 | PORT=9999 46 | docker run --rm -it --name neo4j-test \ 47 | -e HOST="${HOST}" -e PORT=${PORT} \ 48 | -v "$(pwd)/plugins":/plugins \ 49 | -e NEO4J_AUTH=neo4j/password \ 50 | -e NEO4J_ACCEPT_LICENSE_AGREEMENT=yes \ 51 | -p 7687:7687 -p 7474:7474 \ 52 | -v "$(pwd)/data:/data" \ 53 | -p "${PORT}:${PORT}" \ 54 | -e NEO4J_dbms_memory_heap_initial__size=80g \ 55 | -e NEO4J_dbms_memory_heap_max__size=80g \ 56 | -e NEO4J_dbms_memory_pagecache_size=16g \ 57 | -e NEO4J_dbms_security_procedures_unrestricted="gds.*" \ 58 | -e NEO4J_dbms_jvm_additional=-Dio.netty.tryReflectionSetAccessible=true \ 59 | -e NEO4J_dbms_allow__upgrade=true \ 60 | -e NEO4J_gds_enterprise_license__file=/plugins/gds.txt \ 61 | neo4j:4.3.6-enterprise 62 | ```
-------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/ArrowExtensionFactory.java: --------------------------------------------------------------------------------
1 | package org.neo4j.arrow;
2 | 
3 | import org.neo4j.dbms.api.DatabaseManagementService;
4 | import org.neo4j.kernel.extension.ExtensionFactory;
5 | import org.neo4j.kernel.extension.ExtensionType;
6 | import org.neo4j.kernel.extension.context.ExtensionContext;
7 | import org.neo4j.kernel.lifecycle.Lifecycle;
8 | import org.neo4j.logging.internal.LogService;
9 | 
10 | /**
11 |  * Neo4j kernel extension factory for the Arrow service: registers a GLOBAL extension named
12 |  * {@code "arrowExtension"} whose instance is an {@link ArrowService}. (Presumably discovered by
13 |  * Neo4j through the META-INF/services/org.neo4j.kernel.extension.ExtensionFactory entry.)
14 |  */
15 | public class ArrowExtensionFactory extends ExtensionFactory {
16 | 
17 |     static {
18 |         // XXX Neo4j's "Logger" is annoying me...need to figure out how to properly hook in Slf4j
19 |         // Configure slf4j-simple output (timestamps, written to stdout) when this class loads.
20 |         // NOTE(review): only effective if slf4j-simple has not already been initialized.
21 |         System.setProperty("org.slf4j.simpleLogger.showDateTime", "true");
22 |         System.setProperty("org.slf4j.simpleLogger.dateTimeFormat", "[yyyy-MM-dd'T'HH:mm:ss:SSS]");
23 |         System.setProperty("org.slf4j.simpleLogger.logFile", "System.out");
24 |     }
25 | 
26 |     public ArrowExtensionFactory() {
27 |         super(ExtensionType.GLOBAL, "arrowExtension");
28 |     }
29 | 
30 |     /**
31 |      * Create the lifecycle component for this extension.
32 |      *
33 |      * @param context      the extension context (not used here)
34 |      * @param dependencies kernel services this extension asked for
35 |      * @return a new {@link ArrowService} bound to the DBMS and log service
36 |      */
37 |     @Override
38 |     public Lifecycle newInstance(ExtensionContext context, Dependencies dependencies) {
39 |         return new ArrowService(dependencies.dbms(), dependencies.logService());
40 |     }
41 | 
42 |     /** Kernel services this extension needs; both are handed to {@link ArrowService}. */
43 |     public interface Dependencies {
44 |         DatabaseManagementService dbms();
45 |         LogService logService();
46 |     }
47 | }
48 | 
-------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/ArrowService.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.apache.arrow.flight.Location; 4 | import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; 5 | import org.apache.arrow.memory.BufferAllocator; 6 | import org.apache.arrow.memory.RootAllocator; 7 | import org.apache.arrow.util.AutoCloseables; 8 | import org.neo4j.arrow.action.BulkImportActionHandler; 9 | import org.neo4j.arrow.action.CypherActionHandler; 10 | import org.neo4j.arrow.action.GdsActionHandler; 11 | import org.neo4j.arrow.action.GdsMessage; 12 | import org.neo4j.arrow.auth.NativeAuthValidator; 13 | import org.neo4j.arrow.job.*; 14 | import org.neo4j.configuration.GraphDatabaseSettings; 15 | import org.neo4j.dbms.api.DatabaseManagementService; 16 | import org.neo4j.gds.compat.GraphDatabaseApiProxy; 17 | import org.neo4j.kernel.api.security.AuthManager; 18 | import org.neo4j.kernel.lifecycle.LifecycleAdapter; 19 | import org.neo4j.logging.Log; 20 | import org.neo4j.logging.internal.LogService; 21 | 22 | import
java.util.concurrent.CompletableFuture; 23 | import java.util.concurrent.ExecutionException; 24 | import java.util.concurrent.Executor; 25 | import java.util.concurrent.TimeUnit; 26 | import java.util.function.Supplier; 27 | import java.util.stream.Collectors; 28 | import java.util.stream.Stream; 29 | 30 | /** 31 | * An Apache Arrow service for Neo4j, offering Arrow RPC-based access to Cypher and GDS services. 32 | *

33 | * Since this runs as a database plugin, all Cypher access is via the Transaction API. GDS access 34 | * is available directly to the Graph Catalog. 35 | */ 36 | public class ArrowService extends LifecycleAdapter { 37 | 38 | private final DatabaseManagementService dbms; 39 | private final Log log; 40 | 41 | private App app; 42 | private Location location; 43 | private BufferAllocator allocator; 44 | 45 | public ArrowService(DatabaseManagementService dbms, LogService logService) { 46 | this.dbms = dbms; 47 | this.log = logService.getUserLog(ArrowService.class); 48 | } 49 | 50 | @Override 51 | public void init() throws Exception { 52 | super.init(); 53 | log.info(">>>--[Arrow]--> init()"); 54 | allocator = new RootAllocator(Config.maxArrowMemory); 55 | 56 | // Our TLS support is not yet integrated into neo4j.conf 57 | if (!Config.tlsCertficate.isBlank() && !Config.tlsPrivateKey.isBlank()) { 58 | location = Location.forGrpcTls(Config.host, Config.port); 59 | } else { 60 | location = Location.forGrpcInsecure(Config.host, Config.port); 61 | } 62 | 63 | // Allocator debug logging... 64 | CompletableFuture.runAsync(() -> { 65 | // Allocator debug logging... 
66 | if (!System.getenv() 67 | .getOrDefault("ARROW_ALLOCATOR_HEARTBEAT", "") 68 | .isBlank()) { 69 | final Executor delayedExecutor = CompletableFuture.delayedExecutor(30, TimeUnit.SECONDS); 70 | while (true) { 71 | try { 72 | CompletableFuture.runAsync(() -> { 73 | var s = Stream.concat(Stream.of(allocator), allocator.getChildAllocators() 74 | .stream() 75 | .flatMap(child -> Stream.concat(Stream.of(child), child.getChildAllocators() 76 | .stream() 77 | .flatMap(grandkid -> Stream.concat(Stream.of(grandkid), grandkid.getChildAllocators() 78 | .stream() 79 | .flatMap(greatgrandkid -> Stream.concat(Stream.of(greatgrandkid), 80 | greatgrandkid.getChildAllocators().stream()))))))) 81 | .map(a -> String.format("%s - %,d MiB allocated, %,d MiB limit", 82 | a.getName(), (a.getAllocatedMemory() >> 20), (a.getLimit() >> 20))); 83 | log.info("allocator report:\n" + String.join("\n", s.collect(Collectors.toList()))); 84 | }, delayedExecutor) 85 | .get(); 86 | } catch (InterruptedException | ExecutionException e) { 87 | log.error(e.getMessage(), e); 88 | break; 89 | } 90 | } 91 | } 92 | }); 93 | 94 | // Use GDS's handy hooks to get our Auth Manager. Needs to be deferred as it will fail 95 | // if we try to get a reference here since it doesn't exist yet. 96 | final Supplier authManager = () -> 97 | GraphDatabaseApiProxy.resolveDependency(dbms.database( 98 | GraphDatabaseSettings.SYSTEM_DATABASE_NAME), AuthManager.class); 99 | 100 | app = new App(allocator, location, "neo4j-arrow-plugin", 101 | new BasicCallHeaderAuthenticator(new NativeAuthValidator(authManager, log))); 102 | 103 | app.registerHandler(new CypherActionHandler( 104 | (msg, mode, username) -> new TransactionApiJob(msg, username, dbms, log))); 105 | app.registerHandler(new GdsActionHandler( 106 | (msg, mode, username) -> // XXX casts and stuff 107 | (mode == Job.Mode.READ) ? 
new GdsReadJob((GdsMessage) msg, username) 108 | : new GdsWriteJob(msg, username, allocator, dbms), log)); 109 | app.registerHandler(new BulkImportActionHandler( 110 | (msg, mode, username) -> new BulkImportJob(msg, username, allocator, dbms))); 111 | } 112 | 113 | @Override 114 | public void start() throws Exception { 115 | super.start(); 116 | log.info(">>>--[Arrow]--> start()"); 117 | app.start(); 118 | log.info("started arrow app at location " + location); 119 | } 120 | 121 | @Override 122 | public void stop() throws Exception { 123 | super.stop(); 124 | 125 | log.info(">>>--[Arrow]--> stop()"); 126 | 127 | // Use an async approach to stopping so we don't completely block Neo4j's shutdown waiting 128 | // for streams to terminate. 129 | long timeout = 5; 130 | TimeUnit unit = TimeUnit.SECONDS; 131 | 132 | log.info(String.format(">>>--[Arrow]--> waiting %d %s for jobs to complete", timeout, unit)); 133 | CompletableFuture.runAsync(() -> { 134 | try { 135 | app.awaitTermination(timeout, unit); 136 | log.info("stopped app " + app); 137 | } catch (InterruptedException e) { 138 | log.error("failed to stop app " + app, e); 139 | } 140 | }); 141 | } 142 | 143 | @Override 144 | public void shutdown() throws Exception { 145 | super.shutdown(); 146 | log.info(">>>--[Arrow]--> shutdown()"); 147 | 148 | AutoCloseables.close(app, allocator); 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/GdsNodeRecord.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.neo4j.gds.ElementIdentifier; 4 | import org.neo4j.gds.NodeLabel; 5 | import org.neo4j.gds.api.NodeProperties; 6 | import org.neo4j.gds.api.nodeproperties.ValueType; 7 | 8 | import java.util.ArrayList; 9 | import java.util.List; 10 | import java.util.Objects; 11 | import java.util.Set; 12 | import java.util.function.Function; 13 | import 
java.util.stream.Collectors; 14 | 15 | /** 16 | * Wrapper around a record of Nodes and properties from the in-memory graph in GDS. 17 | *

18 | * GDS gives us the ability to use fixed-width data structures, specifically certain Java 19 | * primitives. This dramatically increases the efficiency and performance of Apache Arrow. As 20 | * such, we use as many Java arrays as we can in lieu of boxed types in {@link List}s. 21 | *

22 | */ 23 | public class GdsNodeRecord extends GdsRecord { 24 | /** Represents the underlying node id */ 25 | private final Value nodeId; 26 | private final Value labels; 27 | 28 | // THESE SHOULD MATCH THE DEFAULTS IN neo4j_arrow.py!!!! 29 | public static final String NODE_ID_FIELD = Neo4jDefaults.ID_FIELD; 30 | public static final String LABELS_FIELD = Neo4jDefaults.LABELS_FIELD; 31 | 32 | protected GdsNodeRecord(long nodeId, Set labels, String[] keys, Value[] values, Function nodeIdResolver) { 33 | super(keys, values); 34 | this.nodeId = wrapScalar(nodeIdResolver.apply(nodeId), ValueType.LONG); 35 | this.labels = wrapNodeLabels(labels); 36 | } 37 | 38 | /** 39 | * Wrap the given GDS information into a single {@link GdsNodeRecord}. 40 | * 41 | * @param nodeId the native node id of the record 42 | * @param labels Set of node labels for the given node 43 | * @param fieldNames the names of the properties or fields 44 | * @param propertiesArray an array of references to the {@link NodeProperties} interface for 45 | * resolving the property values 46 | * @return a new {@link GdsNodeRecord} 47 | */ 48 | public static GdsNodeRecord wrap(long nodeId, Set labels, String[] fieldNames, 49 | NodeProperties[] propertiesArray, Function nodeIdResolver) { 50 | final Value[] values = new Value[propertiesArray.length]; 51 | 52 | assert fieldNames.length == values.length; 53 | 54 | for (int i=0; i nodeLabels) { 85 | return wrapLabels(nodeLabels.stream().map(ElementIdentifier::name).collect(Collectors.toUnmodifiableSet())); 86 | } 87 | 88 | public static Value wrapLabels(Set labels) { 89 | return new Value() { 90 | final List list = new ArrayList<>(labels); 91 | @Override 92 | public int size() { 93 | return list.size(); 94 | } 95 | 96 | @Override 97 | public String asString() { 98 | return String.join(",", list); 99 | } 100 | 101 | @Override 102 | public List asStringList() { 103 | return list; 104 | } 105 | 106 | @Override 107 | public List asList() { 108 | return new 
ArrayList<>(labels); 109 | } 110 | 111 | @Override 112 | public Type type() { 113 | return Type.STRING_LIST; 114 | } 115 | }; 116 | } 117 | 118 | @Override 119 | public Value get(int index) { 120 | if (index < 0 || index >= valueArray.length + 2) 121 | throw new RuntimeException("invalid index"); 122 | if (index == 0) 123 | return nodeId; 124 | if (index == 1) 125 | return labels; 126 | return valueArray[index - 2]; 127 | } 128 | 129 | @Override 130 | public Value get(String field) { 131 | if (NODE_ID_FIELD.equals(field)) { 132 | return nodeId; 133 | } else if (LABELS_FIELD.equals(field)) { 134 | return labels; 135 | } else { 136 | for (int i = 0; i < keyArray.length; i++) 137 | if (keyArray[i].equals(field)) 138 | return valueArray[i]; 139 | } 140 | throw new RuntimeException("invalid field"); 141 | } 142 | 143 | @Override 144 | public List keys() { 145 | ArrayList list = new ArrayList<>(); 146 | list.add(NODE_ID_FIELD); 147 | list.add(LABELS_FIELD); 148 | list.addAll(List.of(keyArray)); 149 | return list; 150 | } 151 | 152 | } 153 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/GdsRelationshipRecord.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.neo4j.gds.api.nodeproperties.ValueType; 4 | 5 | import java.util.List; 6 | 7 | /** 8 | * A simplified record that mimics the structure output by gds.graph.streamRelationships() 9 | */ 10 | public class GdsRelationshipRecord extends GdsRecord { 11 | 12 | // THESE SHOULD MATCH THE DEFAULTS IN neo4j_arrow.py!!!! 
13 | public static final String SOURCE_FIELD = Neo4jDefaults.SOURCE_FIELD; 14 | public static final String TARGET_FIELD = Neo4jDefaults.TARGET_FIELD; 15 | public static final String TYPE_FIELD = Neo4jDefaults.TYPE_FIELD; 16 | public static final String PROPERTY_FIELD = "property"; 17 | public static final String VALUE_FIELD = "value"; 18 | 19 | private final Value sourceId; 20 | private final Value targetId; 21 | private final Value relType; 22 | private final Value property; 23 | private final Value value; 24 | 25 | public GdsRelationshipRecord(long sourceId, long targetId, String type, String property, Value value) { 26 | super(new String[0], new Value[0]); 27 | this.sourceId = wrapScalar(sourceId, ValueType.LONG); 28 | this.targetId = wrapScalar(targetId, ValueType.LONG); 29 | this.relType = wrapString(type); 30 | this.property = wrapString(property); 31 | this.value = value; 32 | } 33 | 34 | @Override 35 | public Value get(int index) { 36 | switch (index) { 37 | case 0: 38 | return sourceId; 39 | case 1: 40 | return targetId; 41 | case 2: 42 | return relType; 43 | case 3: 44 | return property; 45 | case 4: 46 | return value; 47 | default: 48 | throw new RuntimeException("invalid index"); 49 | } 50 | } 51 | 52 | @Override 53 | public Value get(String field) { 54 | switch (field) { 55 | case SOURCE_FIELD: 56 | return sourceId; 57 | case TARGET_FIELD: 58 | return targetId; 59 | case TYPE_FIELD: 60 | return relType; 61 | case PROPERTY_FIELD: 62 | return property; 63 | case VALUE_FIELD: 64 | return value; 65 | default: 66 | throw new RuntimeException("invalid field"); 67 | } 68 | } 69 | 70 | @Override 71 | public List keys() { 72 | return List.of(SOURCE_FIELD, TARGET_FIELD, TYPE_FIELD, PROPERTY_FIELD, VALUE_FIELD); 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/Neo4jDefaults.java: -------------------------------------------------------------------------------- 1 | package 
org.neo4j.arrow; 2 | 3 | /** 4 | * Global defaults to keep consistency across Neo4j-specific Messages. 5 | */ 6 | public class Neo4jDefaults { 7 | /** Used for Node ids. */ 8 | public static final String ID_FIELD = "ID"; 9 | 10 | /** Used for Node labels. */ 11 | public static final String LABELS_FIELD = "LABELS"; 12 | 13 | /** Beginning Node id in a Relationship. */ 14 | public static final String SOURCE_FIELD = "START_ID"; 15 | 16 | /** Ending Node id in a Relationship. */ 17 | public static final String TARGET_FIELD = "END_ID"; 18 | 19 | /** Relationship Type. */ 20 | public static final String TYPE_FIELD = "TYPE"; 21 | } 22 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/SubGraphRecord.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.neo4j.arrow.gds.Edge; 4 | 5 | import java.util.Arrays; 6 | import java.util.List; 7 | import java.util.function.Function; 8 | import java.util.stream.Collectors; 9 | 10 | public class SubGraphRecord implements RowBasedRecord { 11 | 12 | public final static String KEY_ORIGIN_ID = Neo4jDefaults.ID_FIELD; 13 | public final static String KEY_SOURCE_IDS = Neo4jDefaults.SOURCE_FIELD; 14 | public final static String KEY_TARGET_IDS = Neo4jDefaults.TARGET_FIELD; 15 | 16 | private final String[] keys = { 17 | KEY_ORIGIN_ID, 18 | KEY_SOURCE_IDS, 19 | KEY_TARGET_IDS 20 | }; 21 | 22 | private final int origin; 23 | private final int[] sourceIds; 24 | private final int[] targetIds; 25 | 26 | 27 | protected SubGraphRecord(int origin, int[] sourceIds, int[] targetIds) { 28 | this.origin = origin; 29 | this.sourceIds = sourceIds; 30 | this.targetIds = targetIds; 31 | } 32 | 33 | public static SubGraphRecord of(int origin, int[] sourceIds, int[] targetIds) { 34 | if (sourceIds.length != targetIds.length) { 35 | throw new IllegalArgumentException("length of both source and target id arrays must 
be the same"); 36 | } 37 | return new SubGraphRecord(origin, sourceIds, targetIds); 38 | } 39 | 40 | public static SubGraphRecord of(long origin, Iterable edges, int size, Function idMapper) { 41 | final int[] sources = new int[size]; 42 | final int[] targets = new int[size]; 43 | 44 | int pos = 0; 45 | for (long edge : edges) { 46 | sources[pos] = idMapper.apply(Edge.source(edge)).intValue(); 47 | targets[pos] = idMapper.apply(Edge.target(edge)).intValue(); 48 | pos++; 49 | } 50 | return new SubGraphRecord(idMapper.apply(origin).intValue(), sources, targets); // XXX cast 51 | } 52 | 53 | @Override 54 | public Value get(int index) { 55 | switch (index) { 56 | case 0: return GdsRecord.wrapInt(origin); 57 | case 1: return GdsRecord.wrapInts(sourceIds); 58 | case 2: return GdsRecord.wrapInts(targetIds); 59 | default: 60 | throw new RuntimeException("invalid index: " + index); 61 | } 62 | } 63 | 64 | @Override 65 | public Value get(String field) { 66 | switch (field) { 67 | case KEY_ORIGIN_ID: return get(0); 68 | case KEY_SOURCE_IDS: return get(1); 69 | case KEY_TARGET_IDS: return get(2); 70 | default: 71 | throw new RuntimeException("invalid field: " + field); 72 | } 73 | } 74 | 75 | @Override 76 | public String toString() { 77 | return "SubGraphRecord{" + 78 | "origin=" + origin + 79 | ", sourceIds=" + sourceIds + 80 | ", targetIds=" + targetIds + 81 | '}'; 82 | } 83 | 84 | @Override 85 | public List keys() { 86 | return Arrays.stream(keys).collect(Collectors.toList()); 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/action/BulkImportActionHandler.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import org.apache.arrow.flight.*; 4 | import org.neo4j.arrow.Producer; 5 | import org.neo4j.arrow.job.BulkImportJob; 6 | import org.neo4j.arrow.job.Job; 7 | import org.neo4j.arrow.job.JobCreator; 8 | 9 | 
import java.io.IOException; 10 | import java.util.List; 11 | 12 | public class BulkImportActionHandler implements ActionHandler { 13 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(BulkImportActionHandler.class); 14 | 15 | private final JobCreator jobCreator; 16 | public static String importBulkActionType = "import.bulk"; 17 | public static String importBulkActionDescription = "Use neo4j bulk import to bootstrap a new database."; 18 | 19 | public BulkImportActionHandler(JobCreator jobCreator) { 20 | this.jobCreator = jobCreator; 21 | } 22 | 23 | @Override 24 | public List actionTypes() { 25 | return List.of(importBulkActionType); 26 | } 27 | 28 | @Override 29 | public List actionDescriptions() { 30 | return List.of(new ActionType(importBulkActionType, importBulkActionDescription)); 31 | } 32 | 33 | @Override 34 | public Outcome handle(FlightProducer.CallContext context, Action action, Producer producer) { 35 | final String username = context.peerIdentity(); 36 | logger.info("user {} attempting a bulk import action", username); 37 | 38 | // XXX assert user is admin/neo4j? 
39 | 40 | try { 41 | final BulkImportMessage msg = BulkImportMessage.deserialize(action.getBody()); 42 | final BulkImportJob job = jobCreator.newJob(msg, Job.Mode.WRITE, username); 43 | final Ticket ticket = producer.ticketJob(job); 44 | return Outcome.success(new Result(ticket.serialize().array())); 45 | 46 | } catch (IOException e) { 47 | e.printStackTrace(); 48 | return Outcome.failure(CallStatus.INTERNAL.withDescription(e.getMessage())); 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/action/BulkImportMessage.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import com.fasterxml.jackson.core.JsonParser; 4 | import com.fasterxml.jackson.core.JsonProcessingException; 5 | import com.fasterxml.jackson.core.type.TypeReference; 6 | import com.fasterxml.jackson.databind.ObjectMapper; 7 | import org.neo4j.arrow.Neo4jDefaults; 8 | 9 | import java.io.IOException; 10 | import java.nio.charset.StandardCharsets; 11 | import java.util.Map; 12 | 13 | public class BulkImportMessage implements Message { 14 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(BulkImportMessage.class); 15 | static private final ObjectMapper mapper = new ObjectMapper(); 16 | 17 | public static final String JSON_KEY_DATABASE_NAME = "db"; 18 | 19 | public static final String JSON_KEY_NODE_ID_FIELD = "id_field"; 20 | public static final String DEFAULT_NODE_ID_FIELD = Neo4jDefaults.ID_FIELD; 21 | 22 | public static final String JSON_KEY_LABELS_FIELD = "labels_field"; 23 | public static final String DEFAULT_LABELS_FIELD = Neo4jDefaults.LABELS_FIELD; 24 | 25 | public static final String JSON_KEY_SOURCE_FIELD = "source_field"; 26 | public static final String DEFAULT_SOURCE_FIELD = Neo4jDefaults.SOURCE_FIELD; 27 | 28 | public static final String JSON_KEY_TARGET_FIELD = "target_field"; 29 | public 
static final String DEFAULT_TARGET_FIELD = Neo4jDefaults.TARGET_FIELD; 30 | 31 | public static final String JSON_KEY_TYPE_FIELD = "type_field"; 32 | public static final String DEFAULT_TYPE_FIELD = Neo4jDefaults.TYPE_FIELD; 33 | 34 | private final String dbName; 35 | private final String idField; 36 | private final String labelsField; 37 | private final String sourceField; 38 | private final String targetField; 39 | private final String typeField; 40 | 41 | private String eitherOrDefault(String input, String defaultValue) { 42 | if (input == null || input.isEmpty()) 43 | return defaultValue; 44 | return input; 45 | } 46 | 47 | public BulkImportMessage(String dbName, String idField, String labelsField, 48 | String sourceField, String targetField, String typeField) { 49 | this.dbName = dbName; 50 | this.idField = eitherOrDefault(idField, DEFAULT_NODE_ID_FIELD); 51 | this.labelsField = eitherOrDefault(labelsField, DEFAULT_LABELS_FIELD); 52 | this.sourceField = eitherOrDefault(sourceField, DEFAULT_SOURCE_FIELD); 53 | this.targetField = eitherOrDefault(targetField, DEFAULT_TARGET_FIELD); 54 | this.typeField = eitherOrDefault(typeField, DEFAULT_TYPE_FIELD); 55 | } 56 | 57 | @Override 58 | public byte[] serialize() { 59 | try { 60 | return mapper.writeValueAsString( 61 | Map.of(JSON_KEY_DATABASE_NAME, dbName, 62 | JSON_KEY_NODE_ID_FIELD, idField, 63 | JSON_KEY_LABELS_FIELD, labelsField, 64 | JSON_KEY_SOURCE_FIELD, sourceField, 65 | JSON_KEY_TARGET_FIELD, targetField, 66 | JSON_KEY_TYPE_FIELD, typeField)) 67 | .getBytes(StandardCharsets.UTF_8); 68 | } catch (JsonProcessingException e) { 69 | logger.error("serialization error", e); 70 | } 71 | return new byte[0]; 72 | } 73 | 74 | private static class MapTypeReference extends TypeReference> { } 75 | 76 | public static BulkImportMessage deserialize(byte[] bytes) throws IOException { 77 | final JsonParser parser = mapper.createParser(bytes); 78 | final Map params = parser.readValueAs(new BulkImportMessage.MapTypeReference()); 
79 | 80 | final String dbName = params.getOrDefault(JSON_KEY_DATABASE_NAME, "").toString(); 81 | if (dbName.isEmpty()) { 82 | throw new IOException("empty or invalid db parameter"); 83 | } 84 | 85 | final String idField = params.get(JSON_KEY_NODE_ID_FIELD) != null ? 86 | params.get(JSON_KEY_NODE_ID_FIELD).toString() : null; 87 | 88 | final String labelsField = params.get(JSON_KEY_LABELS_FIELD) != null ? 89 | params.get(JSON_KEY_LABELS_FIELD).toString() : null; 90 | 91 | final String sourceField = params.get(JSON_KEY_SOURCE_FIELD) != null ? 92 | params.get(JSON_KEY_SOURCE_FIELD).toString() : null; 93 | 94 | final String targetField = params.get(JSON_KEY_TARGET_FIELD) != null ? 95 | params.get(JSON_KEY_TARGET_FIELD).toString() : null; 96 | 97 | final String typeField = params.get(JSON_KEY_TYPE_FIELD) != null ? 98 | params.get(JSON_KEY_TYPE_FIELD).toString() : null; 99 | 100 | return new BulkImportMessage(dbName, idField, labelsField, sourceField, 101 | targetField, typeField); 102 | 103 | } 104 | 105 | public String getDbName() { 106 | return dbName; 107 | } 108 | 109 | public String getIdField() { 110 | return idField; 111 | } 112 | 113 | public String getLabelsField() { 114 | return labelsField; 115 | } 116 | 117 | public String getSourceField() { 118 | return sourceField; 119 | } 120 | 121 | public String getTargetField() { 122 | return targetField; 123 | } 124 | 125 | public String getTypeField() { 126 | return typeField; 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/action/GdsWriteNodeMessage.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import com.fasterxml.jackson.core.JsonParser; 4 | import com.fasterxml.jackson.core.JsonProcessingException; 5 | import com.fasterxml.jackson.core.type.TypeReference; 6 | import com.fasterxml.jackson.databind.ObjectMapper; 7 | import 
org.neo4j.arrow.Neo4jDefaults; 8 | 9 | import javax.annotation.Nullable; 10 | import java.io.IOException; 11 | import java.nio.charset.StandardCharsets; 12 | import java.util.Map; 13 | 14 | public class GdsWriteNodeMessage implements Message { 15 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(GdsWriteNodeMessage.class); 16 | 17 | static private final ObjectMapper mapper = new ObjectMapper(); 18 | 19 | public static final String JSON_KEY_DATABASE_NAME = "db"; 20 | public static final String JSON_KEY_GRAPH_NAME = "graph"; 21 | // TODO: "type" is pretty vague...needs a better name 22 | public static final String JSON_KEY_NODE_ID_FIELD = "id_field"; 23 | public static final String DEFAULT_NODE_ID_FIELD = Neo4jDefaults.ID_FIELD; 24 | 25 | public static final String JSON_KEY_LABELS_FIELD = "labels_field"; 26 | public static final String DEFAULT_LABELS_FIELD = Neo4jDefaults.LABELS_FIELD; 27 | 28 | /** Name of the Neo4j Database where our Graph lives. Optional. Defaults to "neo4j" */ 29 | private final String dbName; 30 | /** Name of the Graph (projection) from the Graph Catalog. Required. */ 31 | private final String graphName; 32 | 33 | private final String idField; 34 | private final String labelsField; 35 | 36 | public GdsWriteNodeMessage(String dbName, String graphName, @Nullable String idField, @Nullable String labelsField) { 37 | this.dbName = dbName; 38 | this.graphName = graphName; 39 | 40 | this.idField = (idField != null) ? idField : DEFAULT_NODE_ID_FIELD; 41 | this.labelsField = (labelsField != null) ? labelsField : DEFAULT_LABELS_FIELD; 42 | } 43 | 44 | /** 45 | * Serialize the GdsMessage to a JSON blob of bytes. 
46 | * @return byte[] containing UTF-8 JSON blob or byte[0] on error 47 | */ 48 | public byte[] serialize() { 49 | try { 50 | return mapper.writeValueAsString( 51 | Map.of(JSON_KEY_DATABASE_NAME, dbName, 52 | JSON_KEY_GRAPH_NAME, graphName, 53 | JSON_KEY_NODE_ID_FIELD, idField, 54 | JSON_KEY_LABELS_FIELD, labelsField)) 55 | .getBytes(StandardCharsets.UTF_8); 56 | } catch (JsonProcessingException e) { 57 | logger.error("serialization error", e); 58 | } 59 | return new byte[0]; 60 | } 61 | 62 | private static class MapTypeReference extends TypeReference> { } 63 | 64 | /** 65 | * Deserialize the given bytes, containing JSON, into a GdsMessage instance. 66 | * @param bytes UTF-8 bytes containing JSON payload 67 | * @return new GdsMessage 68 | * @throws IOException if error encountered during serialization 69 | */ 70 | public static GdsWriteNodeMessage deserialize(byte[] bytes) throws IOException { 71 | 72 | final JsonParser parser = mapper.createParser(bytes); 73 | final Map params = parser.readValueAs(new GdsWriteNodeMessage.MapTypeReference()); 74 | 75 | final String dbName = params.getOrDefault(JSON_KEY_DATABASE_NAME, "neo4j").toString(); 76 | // TODO: assert our minimum schema? 77 | final String graphName = params.get(JSON_KEY_GRAPH_NAME).toString(); 78 | 79 | final String idField = params.get(JSON_KEY_NODE_ID_FIELD) != null ? 80 | params.get(JSON_KEY_NODE_ID_FIELD).toString() : null; 81 | 82 | final String labelsField = params.get(JSON_KEY_LABELS_FIELD) != null ? 
83 | params.get(JSON_KEY_LABELS_FIELD).toString() : null; 84 | 85 | return new GdsWriteNodeMessage(dbName, graphName, idField, labelsField); 86 | } 87 | 88 | public String getDbName() { 89 | return dbName; 90 | } 91 | 92 | public String getGraphName() { 93 | return graphName; 94 | } 95 | 96 | public String getIdField() { 97 | return idField; 98 | } 99 | 100 | public String getLabelsField() { 101 | return labelsField; 102 | } 103 | 104 | @Override 105 | public String toString() { 106 | return "GdsWriteNodeMessage{" + 107 | "db='" + dbName + "'" + 108 | ", graph='" + graphName + "'" + 109 | ", idField='" + idField + "'" + 110 | ", labelsField='" + labelsField + "'" + 111 | "}"; 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/action/GdsWriteRelsMessage.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import com.fasterxml.jackson.core.JsonParser; 4 | import com.fasterxml.jackson.core.JsonProcessingException; 5 | import com.fasterxml.jackson.core.type.TypeReference; 6 | import com.fasterxml.jackson.databind.ObjectMapper; 7 | import org.neo4j.arrow.Neo4jDefaults; 8 | 9 | import java.io.IOException; 10 | import java.nio.charset.StandardCharsets; 11 | import java.util.Map; 12 | 13 | public class GdsWriteRelsMessage implements Message { 14 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(GdsWriteRelsMessage.class); 15 | 16 | static private final ObjectMapper mapper = new ObjectMapper(); 17 | 18 | public static final String JSON_KEY_DATABASE_NAME = "db"; 19 | public static final String JSON_KEY_GRAPH_NAME = "graph"; 20 | 21 | public static final String JSON_KEY_SOURCE_FIELD = "source_field"; 22 | public static final String DEFAULT_SOURCE_FIELD = Neo4jDefaults.SOURCE_FIELD; 23 | 24 | public static final String JSON_KEY_TARGET_FIELD = "target_field"; 25 | public static 
final String DEFAULT_TARGET_FIELD = Neo4jDefaults.TARGET_FIELD; 26 | 27 | public static final String JSON_KEY_TYPE_FIELD = "type_field"; 28 | public static final String DEFAULT_TYPE_FIELD = Neo4jDefaults.TYPE_FIELD; 29 | 30 | /** Name of the Neo4j Database where our Graph lives. Optional. Defaults to "neo4j" */ 31 | private final String dbName; 32 | /** Name of the Graph (projection) from the Graph Catalog. Required. */ 33 | private final String graphName; 34 | 35 | private final String sourceField; 36 | private final String targetField; 37 | private final String typeField; 38 | 39 | protected GdsWriteRelsMessage(String dbName, String graphName, String sourceField, String targetField, String typeField) { 40 | this.dbName = dbName; 41 | this.graphName = graphName; 42 | 43 | this.sourceField = (sourceField != null) ? sourceField : DEFAULT_SOURCE_FIELD; 44 | this.targetField = (targetField != null) ? targetField : DEFAULT_TARGET_FIELD; 45 | this.typeField = (typeField != null) ? typeField : DEFAULT_TYPE_FIELD; 46 | } 47 | 48 | private static class MapTypeReference extends TypeReference> { } 49 | 50 | public static GdsWriteRelsMessage deserialize(byte[] bytes) throws IOException { 51 | final JsonParser parser = mapper.createParser(bytes); 52 | final Map params = parser.readValueAs(new GdsWriteRelsMessage.MapTypeReference()); 53 | 54 | final String dbName = params.getOrDefault(JSON_KEY_DATABASE_NAME, "neo4j").toString(); 55 | // TODO: assert our minimum schema? 56 | final String graphName = params.get(JSON_KEY_GRAPH_NAME).toString(); 57 | 58 | final String sourceIdField = params.get(JSON_KEY_SOURCE_FIELD) != null ? 59 | params.get(JSON_KEY_SOURCE_FIELD).toString() : null; 60 | 61 | final String targetIdField = params.get(JSON_KEY_TARGET_FIELD) != null ? 62 | params.get(JSON_KEY_TARGET_FIELD).toString() : null; 63 | 64 | final String typeField = params.get(JSON_KEY_TYPE_FIELD) != null ? 
65 | params.get(JSON_KEY_TYPE_FIELD).toString() : null; 66 | 67 | return new GdsWriteRelsMessage(dbName, graphName, sourceIdField, targetIdField, typeField); 68 | } 69 | 70 | @Override 71 | public byte[] serialize() { 72 | try { 73 | return mapper.writeValueAsString( 74 | Map.of(JSON_KEY_DATABASE_NAME, dbName, 75 | JSON_KEY_GRAPH_NAME, graphName, 76 | JSON_KEY_SOURCE_FIELD, sourceField, 77 | JSON_KEY_TARGET_FIELD, targetField, 78 | JSON_KEY_TYPE_FIELD, typeField)) 79 | .getBytes(StandardCharsets.UTF_8); 80 | } catch (JsonProcessingException e) { 81 | logger.error("serialization error", e); 82 | } 83 | return new byte[0]; 84 | } 85 | 86 | public String getDbName() { 87 | return dbName; 88 | } 89 | 90 | public String getGraphName() { 91 | return graphName; 92 | } 93 | 94 | public String getSourceField() { 95 | return sourceField; 96 | } 97 | 98 | public String getTargetField() { 99 | return targetField; 100 | } 101 | 102 | public String getTypeField() { 103 | return typeField; 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/action/KHopMessage.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import com.fasterxml.jackson.core.JsonParser; 4 | import com.fasterxml.jackson.core.JsonProcessingException; 5 | import com.fasterxml.jackson.core.type.TypeReference; 6 | import com.fasterxml.jackson.databind.ObjectMapper; 7 | import org.neo4j.arrow.Neo4jDefaults; 8 | 9 | import java.io.IOException; 10 | import java.nio.charset.StandardCharsets; 11 | import java.util.Map; 12 | 13 | public class KHopMessage implements Message { 14 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(KHopMessage.class); 15 | static private final ObjectMapper mapper = new ObjectMapper(); 16 | 17 | public static final String JSON_KEY_DATABASE_NAME = "db"; 18 | public static final String 
JSON_KEY_GRAPH_NAME = "graph"; 19 | public static final String JSON_KEY_K = "k"; 20 | public static final String JSON_KEY_REL_PROPERTY = "rel_property"; // XXX is this correct/used? 21 | public static final String JSON_KEY_NODE_ID_PROPERTY = Neo4jDefaults.ID_FIELD; 22 | 23 | private final String dbName; 24 | private final String graph; 25 | private final String relProperty; 26 | private final String nodeIdProperty; 27 | private final int k; 28 | 29 | public KHopMessage(String dbName, String graph, String nodeIdProperty, int k, String relProperty) { 30 | this.dbName = dbName; 31 | this.graph = graph; 32 | this.nodeIdProperty = nodeIdProperty; 33 | this.k = k; 34 | this.relProperty = relProperty; 35 | } 36 | 37 | public String getDbName() { 38 | return dbName; 39 | } 40 | 41 | public String getGraph() { 42 | return graph; 43 | } 44 | 45 | public String getNodeIdProperty() { 46 | return nodeIdProperty; 47 | } 48 | 49 | public int getK() { 50 | return k; 51 | } 52 | 53 | public String getRelProperty() { 54 | return relProperty; 55 | } 56 | 57 | private static class MapTypeReference extends TypeReference> { } 58 | 59 | @Override 60 | public byte[] serialize() { 61 | try { 62 | return mapper.writeValueAsString( 63 | Map.of(JSON_KEY_DATABASE_NAME, dbName, 64 | JSON_KEY_GRAPH_NAME, graph, 65 | JSON_KEY_NODE_ID_PROPERTY, nodeIdProperty, 66 | JSON_KEY_REL_PROPERTY, relProperty, 67 | JSON_KEY_K, k)) 68 | .getBytes(StandardCharsets.UTF_8); 69 | } catch (JsonProcessingException e) { 70 | logger.error("serialization error", e); 71 | } 72 | return new byte[0]; 73 | } 74 | 75 | public static KHopMessage deserialize(byte[] bytes) throws IOException { 76 | final JsonParser parser = mapper.createParser(bytes); 77 | final Map params = parser.readValueAs(new KHopMessage.MapTypeReference()); 78 | 79 | final String dbName = params.getOrDefault(JSON_KEY_DATABASE_NAME, "neo4j").toString(); 80 | final String graph = params.getOrDefault(JSON_KEY_GRAPH_NAME, "random").toString(); 81 | final 
String nodeIdProperty = params.getOrDefault(JSON_KEY_NODE_ID_PROPERTY, "").toString(); 82 | 83 | /* 84 | XXX note that this property name ('_type_') is the trickery used to describe the direction of the 85 | relationship and NOT the Relationship Type!!! 86 | */ 87 | final String relProperty = params.getOrDefault(JSON_KEY_REL_PROPERTY, "_type_").toString(); 88 | 89 | final Integer k = (Integer) params.getOrDefault(JSON_KEY_K, 2); 90 | 91 | return new KHopMessage(dbName, graph, nodeIdProperty, k, relProperty); 92 | } 93 | 94 | @Override 95 | public String toString() { 96 | return "KHopMessage{" + 97 | "dbName='" + dbName + '\'' + 98 | ", graph='" + graph + '\'' + 99 | ", relProperty='" + relProperty + '\'' + 100 | ", nodeIdProperty='" + nodeIdProperty + '\'' + 101 | ", k=" + k + 102 | '}'; 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/auth/ArrowConnectionInfo.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.auth; 2 | 3 | import org.neo4j.internal.kernel.api.connectioninfo.ClientConnectionInfo; 4 | 5 | /** 6 | * Placeholder for now... 
7 | */ 8 | public class ArrowConnectionInfo extends ClientConnectionInfo { 9 | @Override 10 | public String asConnectionDetails() { 11 | return "arrow"; 12 | } 13 | 14 | @Override 15 | public String protocol() { 16 | return "arrow"; 17 | } 18 | 19 | @Override 20 | public String connectionId() { 21 | return "arrow-tbd"; 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/auth/NativeAuthValidator.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.auth; 2 | 3 | import org.apache.arrow.flight.CallStatus; 4 | import org.apache.arrow.flight.auth.BasicServerAuthHandler; 5 | import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; 6 | import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; 7 | import org.neo4j.internal.kernel.api.security.AuthenticationResult; 8 | import org.neo4j.internal.kernel.api.security.LoginContext; 9 | import org.neo4j.kernel.api.security.AuthManager; 10 | import org.neo4j.kernel.api.security.AuthToken; 11 | import org.neo4j.kernel.api.security.exception.InvalidAuthTokenException; 12 | import org.neo4j.logging.Log; 13 | 14 | import java.nio.charset.StandardCharsets; 15 | import java.util.Base64; 16 | import java.util.Map; 17 | import java.util.Optional; 18 | import java.util.concurrent.ConcurrentHashMap; 19 | import java.util.function.Supplier; 20 | 21 | /** 22 | * A simple wrapper around using a Neo4j AuthManager for authenticating a client. 23 | *

24 | * Note: this does not check Authorization...it only Authenticates identities> 25 | */ 26 | public class NativeAuthValidator 27 | implements BasicServerAuthHandler.BasicAuthValidator, BasicCallHeaderAuthenticator.CredentialValidator { 28 | 29 | private final Supplier authManager; 30 | private final Log log; 31 | 32 | private static final Base64.Encoder encoder = Base64.getEncoder(); 33 | private static final Base64.Decoder decoder = Base64.getDecoder(); 34 | 35 | public static ConcurrentHashMap contextMap = new ConcurrentHashMap<>(); 36 | 37 | public NativeAuthValidator(Supplier authManager, Log log) { 38 | this.authManager = authManager; 39 | this.log = log; 40 | } 41 | 42 | /** 43 | * Generate a token representation of an identity and secret (password) as a byte array. 44 | *

45 | * The token is in the form of the username, base64 encoded, a

.
character, and the 46 | * password, base64 encoded. (All in UTF-8 encoding.) 47 | * @param username identity of the client 48 | * @param password secret token or password provided by the client 49 | * @return new byte[] containing the token 50 | */ 51 | @Override 52 | public byte[] getToken(String username, String password) { 53 | final String token = 54 | encoder.encodeToString(username.getBytes(StandardCharsets.UTF_8)) + 55 | "." + 56 | encoder.encodeToString(password.getBytes(StandardCharsets.UTF_8)); 57 | 58 | return token.getBytes(StandardCharsets.UTF_8); 59 | } 60 | 61 | /** 62 | * Validate the given token. 63 | *

64 | * Parses the given token byte array, assuming the format as defined by 65 | * {@link #getToken(String, String)}, and validates the password is correct for the username 66 | * contained within the token. 67 | * 68 | * @param token encoded identity and secret 69 | * @return Optional containing the username if valid, empty if not. 70 | */ 71 | @Override 72 | public Optional isValid(byte[] token) { 73 | final String in = new String(token, StandardCharsets.UTF_8); 74 | 75 | final String[] parts = in.split("\\."); 76 | assert parts.length == 2; 77 | 78 | final String username = new String(decoder.decode(parts[0]), StandardCharsets.UTF_8); 79 | final String password = new String(decoder.decode(parts[1]), StandardCharsets.UTF_8); 80 | 81 | try { 82 | // XXX Assumes no exception means success...highly suspect!!! 83 | validate(username, password); 84 | } catch (Exception e) { 85 | return Optional.empty(); 86 | } 87 | 88 | return Optional.of(username); 89 | } 90 | 91 | /** 92 | * Validate the given username and password for the calling client by using the Neo4j 93 | * {@link #authManager} instance. (Does not support realms!) 
94 | * 95 | * @param username identity of the client 96 | * @param password secret token or password provided by the client 97 | * @return an AuthResult that resolves to the username of the caller 98 | * @throws InvalidAuthTokenException if the basic auth token is malformed 99 | * @throws RuntimeException if the AuthManager service is unavailable or the identity cannot be 100 | * confirmed 101 | */ 102 | @Override 103 | public CallHeaderAuthenticator.AuthResult validate(String username, String password) 104 | throws RuntimeException, InvalidAuthTokenException { 105 | final Map token = AuthToken.newBasicAuthToken(username, password); 106 | final AuthManager am = authManager.get(); 107 | 108 | if (am == null) { 109 | log.error("Cannot get AuthManager from supplier!"); 110 | throw CallStatus.INTERNAL.withDescription("authentication system unavailable").toRuntimeException(); 111 | } 112 | final LoginContext context = am.login(token, new ArrowConnectionInfo()); 113 | 114 | if (context.subject().getAuthenticationResult() == AuthenticationResult.SUCCESS) { 115 | // XXX: this is HORRIBLE /facepalm 116 | contextMap.put(context.subject().authenticatedUser(), context); 117 | return () -> context.subject().authenticatedUser(); 118 | } 119 | throw CallStatus.UNAUTHENTICATED.withDescription("invalid username or password").toRuntimeException(); 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/batchimport/QueueInputIterator.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.batchimport; 2 | 3 | import org.neo4j.internal.batchimport.InputIterator; 4 | 5 | public interface QueueInputIterator extends InputIterator { 6 | void closeQueue(); 7 | boolean isOpen(); 8 | } 9 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/gds/ArrowAdjacencyCursor.java: 
--------------------------------------------------------------------------------
package org.neo4j.arrow.gds;

import org.neo4j.gds.api.AdjacencyCursor;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.function.Function;

/**
 * An {@link AdjacencyCursor} over an in-memory list of target node ids.
 * <p>
 * A cursor is either built over a fixed list of targets, or built with a resolver that supplies
 * the targets of a node each time {@link #init(long, int)} is called.
 */
public class ArrowAdjacencyCursor implements AdjacencyCursor {
    private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ArrowAdjacencyCursor.class);

    /** Index into {@link #targets} of the next id to return. */
    private long pos = 0;

    private double fallbackValue = Double.NaN;
    private List<Integer> targets = List.of();
    private Function<Long, ? extends List<Integer>> queueResolver = (unused) -> new LinkedList<>();

    /**
     * Create a cursor whose targets are looked up through {@code queueResolver} on each
     * {@link #init(long, int)}.
     *
     * @param queueResolver maps a node id to that node's target node ids
     */
    protected ArrowAdjacencyCursor(Function<Long, ? extends List<Integer>> queueResolver) {
        this.queueResolver = queueResolver;
    }

    /**
     * Create a cursor over a fixed list of targets.
     *
     * @param targets       target node ids to iterate, in order
     * @param fallbackValue retained and propagated to copies; not otherwise read by this class
     */
    protected ArrowAdjacencyCursor(List<Integer> targets, double fallbackValue) {
        this.targets = targets;
        this.fallbackValue = fallbackValue;
    }

    /**
     * Rewind and re-bind this cursor to the targets the resolver returns for {@code node}.
     * For a cursor built from a fixed list the resolver is the default, which returns an empty
     * list, so this empties the cursor.
     */
    @Override
    public void init(long node, int unused) {
        pos = 0;
        targets = new ArrayList<>(queueResolver.apply(node));
    }

    @Override
    public int size() {
        return targets.size();
    }

    @Override
    public boolean hasNextVLong() {
        return (pos < targets.size());
    }

    /** Return the next target and advance, or {@link #NOT_FOUND} (with a warning) when exhausted. */
    @Override
    public long nextVLong() {
        try {
            final int targetIdx = targets.get((int) pos);
            pos++;
            return targetIdx;
        } catch (IndexOutOfBoundsException e) {
            logger.warn(String.format("nextVLong() index out of bounds (index=%d)", pos));
            return NOT_FOUND;
        }
    }

    /** Return the next target without advancing, or {@link #NOT_FOUND} when exhausted. */
    @Override
    public long peekVLong() {
        if (hasNextVLong())
            return targets.get((int) pos); // XXX
        else
            return NOT_FOUND;
    }

    @Override
    public int remaining() {
        return (int) (targets.size() - pos);
    }

    /**
     * Consume targets until one strictly greater than {@code nodeId} is read and return it. When
     * the cursor runs out first, the last target read is returned (or {@link #NOT_FOUND} if none
     * remained).
     */
    @Override
    public long skipUntil(long nodeId) {
        long next = NOT_FOUND;
        while (hasNextVLong()) {
            next = nextVLong();
            if (next > nodeId)
                break;
        }
        return next;
    }

    /**
     * Consume targets until one greater than or equal to {@code nodeId} is read and return it.
     * When the cursor runs out first, the last target read is returned (or {@link #NOT_FOUND} if
     * none remained).
     */
    @Override
    public long advance(long nodeId) {
        long next = NOT_FOUND;
        while (hasNextVLong()) {
            next = nextVLong();
            if (next >= nodeId)
                break;
        }
        return next;
    }

    /**
     * Copy this cursor's iteration state into a cursor that then advances independently. The
     * targets list is shared, not copied; it is never mutated in place (init replaces it).
     *
     * @param destination reused if it is an {@code ArrowAdjacencyCursor}; otherwise a new cursor
     *                    is allocated
     * @return a cursor positioned at the same target as this one
     */
    @Override
    public AdjacencyCursor shallowCopy(AdjacencyCursor destination) {
        final ArrowAdjacencyCursor copy = (destination instanceof ArrowAdjacencyCursor)
                ? (ArrowAdjacencyCursor) destination
                : new ArrowAdjacencyCursor(targets, fallbackValue);
        // Copy state directly instead of calling init(): init() takes a node id (not a position)
        // and re-resolves targets, which would throw away the current list.
        copy.targets = targets;
        copy.fallbackValue = fallbackValue;
        copy.queueResolver = queueResolver;
        copy.pos = pos;
        return copy;
    }

    @Override
    public void close() {
        // nop
    }
}
--------------------------------------------------------------------------------
/plugin/src/main/java/org/neo4j/arrow/gds/ArrowAdjacencyList.java:
--------------------------------------------------------------------------------
package org.neo4j.arrow.gds;

import org.neo4j.gds.api.AdjacencyCursor;
import org.neo4j.gds.api.AdjacencyList;

import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

/**
 * An {@link AdjacencyList} backed by in-memory maps keyed by inner (gds) node id.
 * Node ids are narrowed from long to int on every lookup.
 */
public class ArrowAdjacencyList implements AdjacencyList {

    /** Map of an inner (gds) node id to a queue of neighboring nodes (in outward direction) */
    final private Map<Integer, ? extends List<Integer>> sourceIdMap;

    /** Map of count of incoming relationships for a given inner (gds) node id */
    final private Map<Integer, Integer> inDegreeMap;

    /** Hook into a block of code to call when closing, used for cleanup */
    final private Consumer<?> closeCallback;

    public ArrowAdjacencyList(Map<Integer, ? extends List<Integer>> sourceIdMap,
                              Map<Integer, Integer> inDegreeMap,
                              Consumer<?> closeCallback) {
        this.sourceIdMap = sourceIdMap;
        this.inDegreeMap = inDegreeMap;
        this.closeCallback = closeCallback;
    }

    /**
     * Out-degree plus recorded in-degree for nodes present in {@link #sourceIdMap}; 0 otherwise.
     * NOTE(review): a node with incoming relationships but no entry in sourceIdMap reports 0 —
     * confirm that every node is present in sourceIdMap.
     */
    @Override
    public int degree(long longNode) {
        final int node = (int) longNode; // XXX
        final List<Integer> targets = sourceIdMap.get(node);
        return (targets != null)
                ? targets.size() + inDegreeMap.getOrDefault(node, 0)
                : 0;
    }

    @Override
    public AdjacencyCursor adjacencyCursor(long longNode) {
        return cursorFor(longNode, Double.NaN);
    }

    @Override
    public AdjacencyCursor adjacencyCursor(long longNode, double fallbackValue) {
        return cursorFor(longNode, fallbackValue);
    }

    /** The {@code reuse} cursor is ignored; a fresh cursor is always returned. */
    @Override
    public AdjacencyCursor adjacencyCursor(AdjacencyCursor reuse, long longNode) {
        return cursorFor(longNode, Double.NaN);
    }

    /** The {@code reuse} cursor is ignored; a fresh cursor is always returned. */
    @Override
    public AdjacencyCursor adjacencyCursor(AdjacencyCursor reuse, long longNode, double fallbackValue) {
        return cursorFor(longNode, fallbackValue);
    }

    /** Build a cursor over the outgoing targets of {@code longNode}, or the shared empty cursor. */
    private AdjacencyCursor cursorFor(long longNode, double fallbackValue) {
        final int node = (int) longNode; // XXX
        final List<Integer> targets = sourceIdMap.get(node);
        if (targets != null) {
            return new ArrowAdjacencyCursor(targets, fallbackValue);
        }
        return AdjacencyCursor.EmptyAdjacencyCursor.INSTANCE;
    }

    @Override
    public AdjacencyCursor rawAdjacencyCursor() {
        return new ArrowAdjacencyCursor((node) -> {
            int id = node.intValue();
            if (sourceIdMap.containsKey(id)) {
                return sourceIdMap.get(id);
            }
            return new LinkedList<>();
        });
    }

    @Override
    public void close() {
        closeCallback.accept(null);
    }
}
--------------------------------------------------------------------------------
/plugin/src/main/java/org/neo4j/arrow/gds/Edge.java:
-------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.gds; 2 | 3 | /** 4 | * Optimized relationship format in 64-bits 5 | */ 6 | public class Edge { 7 | public static final long ZERO = 0x0L; 8 | public static final long ORIENTATION_BIT = 0x8000000000000000L; 9 | public static final long FLAG_BITS = 0xF000000000000000L; 10 | public static final long FLAG_MASK = 0x0FFFFFFFFFFFFFFFL; 11 | public static final long TARGET_BITS = 0x000000003FFFFFFFL; 12 | 13 | public static long edge(long source, long target) { 14 | return (source << 30) | target; 15 | } 16 | 17 | public static long edge(long source, long target, boolean flag) { 18 | // assumption: source and target are both < 30 bits in length 19 | return ((source << 30) | target) | (flag ? FLAG_BITS : ZERO); 20 | } 21 | 22 | public static long source(long edge) { 23 | return (edge >> 30) & TARGET_BITS; 24 | } 25 | 26 | public static int sourceAsInt(long edge) { 27 | return (int) source(edge); 28 | } 29 | 30 | public static long target(long edge) { 31 | return (edge & TARGET_BITS); 32 | } 33 | 34 | public static int targetAsInt(long edge) { 35 | return (int) target(edge); 36 | } 37 | 38 | public static long uniquify(long edge) { 39 | // convert an edge to a "unique" form...which for now just means flipping direction if it's reversed 40 | return flag(edge) ? 
(edge & FLAG_MASK) : edge(target(edge), source(edge)); 41 | } 42 | 43 | public static boolean flag(long edge) { 44 | // TODO: use all 4 bits 45 | return (edge >> 60) != 0; 46 | } 47 | 48 | public static boolean isNatural(long edge) { 49 | // 1 for natural, 0 for reversed 50 | return (edge & ORIENTATION_BIT) != 0; 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/gds/KHop.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.gds; 2 | 3 | import com.google.common.collect.Streams; 4 | import com.google.common.util.concurrent.ThreadFactoryBuilder; 5 | import org.apache.commons.lang3.tuple.Pair; 6 | import org.neo4j.gds.api.Graph; 7 | 8 | import java.util.Arrays; 9 | import java.util.Collection; 10 | import java.util.HashSet; 11 | import java.util.Set; 12 | import java.util.concurrent.ConcurrentLinkedQueue; 13 | import java.util.concurrent.Executor; 14 | import java.util.concurrent.Executors; 15 | import java.util.concurrent.atomic.AtomicInteger; 16 | import java.util.stream.LongStream; 17 | 18 | /** 19 | * Utilities for our k-hop implementation. 
20 | */ 21 | public class KHop { 22 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(KHop.class); 23 | 24 | public static Collection findSupernodes(Graph graph) { 25 | final long nodeCount = graph.nodeCount(); 26 | final long relCount = graph.relationshipCount(); 27 | 28 | final Collection superNodes = new ConcurrentLinkedQueue<>(); 29 | 30 | logger.info(String.format("checking for super nodes in graph with %,d nodes, %,d edges...", nodeCount, relCount)); 31 | final AtomicInteger[] histogram = new AtomicInteger[10]; 32 | for (int i = 0; i < histogram.length; i++) 33 | histogram[i] = new AtomicInteger(0); 34 | 35 | LongStream.range(0, nodeCount) 36 | .parallel() 37 | .mapToObj(id -> Pair.of(id, 0)) 38 | .map(pair -> Pair.of(pair.getLeft(), graph.degree(pair.getLeft()))) 39 | .map(pair -> (pair.getRight() == 0) ? Pair.of(pair.getLeft(), Double.NaN) 40 | : Pair.of(pair.getLeft(), Math.floor(Math.log10(pair.getRight())))) 41 | .map(pair -> (pair.getRight().isNaN()) ? 
Pair.of(pair.getLeft(), 0) 42 | : Pair.of(pair.getLeft(), pair.getRight().intValue() + 1)) 43 | .forEach(pair -> { 44 | int magnitude = pair.getRight(); 45 | if (magnitude > 4) { // 100k's of edges 46 | superNodes.add(pair.getLeft()); 47 | } 48 | histogram[pair.getRight()].incrementAndGet(); 49 | }); 50 | 51 | logger.info(String.format("\t[ zero ]\t- %,d nodes", histogram[0].get())); 52 | for (int i = 1; i < histogram.length; i++) 53 | logger.info(String.format(" [ 1e%d - 1e%d )\t- %,d nodes", i - 1, i, histogram[i].get())); 54 | logger.info(String.format("%,d potential supernodes!", superNodes.size())); 55 | 56 | return superNodes; 57 | } 58 | 59 | public static SuperNodeCache populateCache(Graph graph, Collection nodes) { 60 | // Pre-cache supernodes adjacency lists 61 | if (nodes.size() > 0) { 62 | logger.info("caching supernodes..."); 63 | final SuperNodeCache supernodeCache = SuperNodeCache.ofSize(nodes.size()); 64 | nodes.parallelStream() 65 | .forEach(superNodeId -> 66 | supernodeCache.set(superNodeId.intValue(), 67 | graph.concurrentCopy() 68 | .streamRelationships(superNodeId, Double.NaN) 69 | .mapToLong(cursor -> { 70 | final boolean isNatural = Double.isNaN(cursor.property()); 71 | return Edge.edge(cursor.sourceId(), cursor.targetId(), isNatural); 72 | }).toArray())); 73 | 74 | logger.info(String.format("cached %,d supernodes (%,d edges)", 75 | nodes.size(), supernodeCache.stream() 76 | .mapToLong(array -> array == null ? 0 : array.length).sum())); 77 | return supernodeCache; 78 | } 79 | 80 | return SuperNodeCache.empty(); 81 | } 82 | 83 | /** 84 | * Perform a k'th hop from the given origin. 
85 | * 86 | * @param k depth of current hop/traversal 87 | * @param origin starting node id for the hop 88 | * @param graph reference to a GDS {@link Graph} 89 | * @param history a stateful {@link NodeHistory} to update 90 | * @param cache an optional {@link SuperNodeCache} to leverage 91 | * @return new {@link LongStream} of {@link Edge}s as longs 92 | */ 93 | protected static LongStream hop(int k, int origin, Graph graph, NodeHistory history, SuperNodeCache cache) { 94 | final long[] cachedList = cache.get(origin); 95 | final Graph copy = graph.concurrentCopy(); 96 | 97 | final Set dropSet = new HashSet<>(); 98 | 99 | final LongStream stream; 100 | if (cachedList == null) { 101 | if (k == 0) { 102 | /* 103 | * If we're at the origin, we need to collect a set of the initial neighbor ids 104 | * AND find a list of targets to drop if there's a bi-directional relationship. 105 | * This prevents duplicates without having to keep a huge list of edges. 106 | */ 107 | copy.streamRelationships(origin, Double.NaN) 108 | .mapToInt(cursor -> (int) cursor.targetId()) 109 | .forEach(id -> { 110 | if (history.getAndSet(id)) 111 | dropSet.add(id); 112 | }); 113 | } 114 | stream = copy 115 | .streamRelationships(origin, Double.NaN) 116 | .mapToLong(cursor -> Edge.edge(cursor.sourceId(), cursor.targetId(), (Double.isNaN(cursor.property())))); 117 | } else { 118 | if (k == 0) { 119 | Arrays.stream(cachedList) 120 | .sequential() 121 | .mapToInt(Edge::targetAsInt) 122 | .forEach(id -> { 123 | if (history.getAndSet(id)) 124 | dropSet.add(id); 125 | }); 126 | } 127 | stream = Arrays.stream(cachedList); 128 | } 129 | if (k == 0) { 130 | return stream 131 | .filter(edge -> Edge.isNatural(edge) || !dropSet.contains(Edge.targetAsInt(edge))); 132 | } 133 | return stream; 134 | } 135 | 136 | public static LongStream stream(int origin, int k, Graph graph, SuperNodeCache cache) { 137 | final NodeHistory history = NodeHistory.given((int) graph.nodeCount()); // XXX cast 138 | 
history.getAndSet(origin); 139 | 140 | LongStream stream = hop(0, origin, graph, history, cache); 141 | 142 | // k-hop...where k == 2 for now ;) There are specific optimizations for the k=2 condition. 143 | for (int i = 1; i < k; i++) { 144 | final int n = i; 145 | stream = stream 146 | .flatMap(edge -> Streams.concat( 147 | LongStream.of(edge), 148 | hop(n, Edge.targetAsInt(edge), graph, history, cache))); 149 | } 150 | 151 | return stream 152 | .sequential() // Be safe since we use a stateful filter 153 | .filter(edge -> Edge.isNatural(edge) || !history.getAndSet(Edge.targetAsInt(edge))) 154 | .map(edge -> (Edge.isNatural(edge)) ? edge : Edge.uniquify(edge)); 155 | } 156 | 157 | public static Executor buildKhopExecutor() { 158 | return Executors.newCachedThreadPool((new ThreadFactoryBuilder()) 159 | .setDaemon(true) 160 | .setNameFormat("khop-%d") 161 | .build()); 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/gds/NodeHistory.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.gds; 2 | 3 | import org.neo4j.gds.mem.BitUtil; 4 | import org.roaringbitmap.RoaringBitmap; 5 | import org.roaringbitmap.RoaringBitmapWriter; 6 | 7 | import java.nio.ByteBuffer; 8 | import java.util.HashSet; 9 | import java.util.Set; 10 | 11 | public abstract class NodeHistory { 12 | public static final int CUTOFF = 100_000; 13 | 14 | public static NodeHistory given(long numNodes) { 15 | assert(numNodes < (1 << 30)); 16 | if (numNodes < CUTOFF) { 17 | return new SmolNodeHistory((int) numNodes); 18 | } 19 | return new LorgeNodeHistory(numNodes); 20 | } 21 | 22 | public static NodeHistory offHeap(int numNodes) { 23 | return new OffHeapNodeHistory(numNodes); 24 | } 25 | 26 | /** Retrieve the current value for the given bit and set to 1 */ 27 | public abstract boolean getAndSet(int node); 28 | 29 | /** Clear the bitmap, resetting all values 
to 0 */ 30 | public abstract void clear(); 31 | 32 | protected static class SmolNodeHistory extends NodeHistory { 33 | // XXX not threadsafe 34 | 35 | private final Set set; 36 | 37 | protected SmolNodeHistory(int numNodes) { 38 | set = new HashSet<>(numNodes); 39 | } 40 | 41 | @Override 42 | public boolean getAndSet(int node) { 43 | return !set.add(node); 44 | } 45 | 46 | @Override 47 | public void clear() { 48 | set.clear(); 49 | } 50 | } 51 | 52 | protected static class OffHeapNodeHistory extends NodeHistory { 53 | private final ByteBuffer buffer; 54 | private final int size; 55 | 56 | protected OffHeapNodeHistory(int size) { 57 | this.size = size; 58 | buffer = ByteBuffer.allocateDirect(BitUtil.ceilDiv(size, Byte.SIZE)); 59 | } 60 | 61 | @Override 62 | public boolean getAndSet(int node) { 63 | if (0 > node || node >= size) 64 | throw new IndexOutOfBoundsException(String.format("%d out of bounds for history of size %d", node, size)); 65 | 66 | boolean result; 67 | final int index = Math.floorDiv(node, Byte.SIZE); 68 | final int bitMask = 1 << Math.floorMod(node, Byte.SIZE); 69 | synchronized (buffer) { 70 | final byte b = buffer.get(index); 71 | result = (b & bitMask) != 0; 72 | buffer.put(index, (byte) (b | bitMask)); 73 | } 74 | return result; 75 | } 76 | 77 | @Override 78 | public void clear() { 79 | buffer.clear(); // XXX not sure if this works 80 | } 81 | } 82 | 83 | protected static class LorgeNodeHistory extends NodeHistory { 84 | 85 | private final RoaringBitmap bitmap; 86 | 87 | protected LorgeNodeHistory(long numNodes) { 88 | final RoaringBitmapWriter writer = RoaringBitmapWriter.writer() 89 | .expectedRange(0, numNodes) 90 | .optimiseForArrays() 91 | .get(); 92 | bitmap = writer.get(); 93 | } 94 | 95 | @Override 96 | public boolean getAndSet(int node) { 97 | // XXX not threadsafe 98 | final boolean result = bitmap.contains(node); 99 | bitmap.add(node); 100 | return result; 101 | } 102 | 103 | @Override 104 | public void clear() { 105 | bitmap.clear(); 
106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /plugin/src/main/java/org/neo4j/arrow/gds/SuperNodeCache.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.gds; 2 | 3 | import java.util.Map; 4 | import java.util.concurrent.ConcurrentHashMap; 5 | import java.util.stream.Stream; 6 | 7 | /** 8 | * Simple wrapper class to let us experiment with differing approaches to implementing 9 | * a supernode cache. 10 | */ 11 | public class SuperNodeCache { 12 | private final Map cache; 13 | 14 | private SuperNodeCache(int size) { 15 | if (size > 0) { 16 | cache = new ConcurrentHashMap<>(size); 17 | } else { 18 | cache = Map.of(); 19 | } 20 | } 21 | 22 | public static SuperNodeCache empty() { 23 | return new SuperNodeCache(0); 24 | } 25 | 26 | public static SuperNodeCache ofSize(int size) { 27 | return new SuperNodeCache(size); 28 | } 29 | 30 | public long[] get(int index) { 31 | if (cache != null) { 32 | return cache.get(index); 33 | } 34 | return null; 35 | } 36 | 37 | public void set(int index, long[] value) { 38 | cache.put(index, value); 39 | } 40 | 41 | public Stream stream() { 42 | return cache.values().stream(); 43 | } 44 | } -------------------------------------------------------------------------------- /plugin/src/main/resources/META-INF/services/org.neo4j.kernel.extension.ExtensionFactory: -------------------------------------------------------------------------------- 1 | org.neo4j.arrow.ArrowExtensionFactory -------------------------------------------------------------------------------- /plugin/src/test/java/org/neo4j/arrow/GdsRecordBenchmarkTest.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.apache.arrow.flight.Action; 4 | import org.apache.arrow.flight.Location; 5 | import org.apache.arrow.memory.RootAllocator; 6 | import 
org.junit.jupiter.api.Assertions; 7 | import org.junit.jupiter.api.Disabled; 8 | import org.junit.jupiter.api.Timeout; 9 | import org.junit.jupiter.params.ParameterizedTest; 10 | import org.junit.jupiter.params.provider.Arguments; 11 | import org.junit.jupiter.params.provider.MethodSource; 12 | import org.neo4j.arrow.action.GdsActionHandler; 13 | import org.neo4j.arrow.action.GdsMessage; 14 | import org.neo4j.arrow.demo.Client; 15 | import org.neo4j.arrow.job.ReadJob; 16 | import org.neo4j.gds.NodeLabel; 17 | import org.neo4j.logging.Log; 18 | import org.neo4j.logging.log4j.Log4jLogProvider; 19 | 20 | import java.util.List; 21 | import java.util.Set; 22 | import java.util.concurrent.CompletableFuture; 23 | import java.util.concurrent.TimeUnit; 24 | import java.util.function.BiConsumer; 25 | import java.util.function.Function; 26 | import java.util.stream.IntStream; 27 | import java.util.stream.LongStream; 28 | import java.util.stream.Stream; 29 | 30 | /** 31 | * A simple local perf test, nothing fancy. Just make sure data can flow and flow effectively. Used 32 | * mostly for doing flame graphs of the call stack to find hot spots. 
33 | */ 34 | public class GdsRecordBenchmarkTest { 35 | private final static Log log; 36 | 37 | static { 38 | System.setProperty("org.slf4j.simpleLogger.showDateTime", "true"); 39 | System.setProperty("org.slf4j.simpleLogger.dateTimeFormat", "[yyyy-MM-dd'T'HH:mm:ss:SSS]"); 40 | System.setProperty("org.slf4j.simpleLogger.logFile", "System.out"); 41 | log = new Log4jLogProvider(System.out).getLog(GdsRecordBenchmarkTest.class); 42 | } 43 | 44 | private static RowBasedRecord getRecord(String type) { 45 | final int size = 256; 46 | switch (type) { 47 | case "float": 48 | float[] data = new float[size]; 49 | for (int i=0; i future; 73 | final int numResults; 74 | 75 | public NoOpJob(int numResults, CompletableFuture signal, String type) { 76 | super(); 77 | this.numResults = numResults; 78 | 79 | future = CompletableFuture.supplyAsync(() -> { 80 | try { 81 | log.info("Job starting"); 82 | final RowBasedRecord record = getRecord(type); 83 | onFirstRecord(record); 84 | log.info("Job feeding"); 85 | BiConsumer consumer = super.futureConsumer.join(); 86 | 87 | IntStream.range(1, numResults).parallel().forEach(i -> consumer.accept(record, i)); 88 | 89 | signal.complete(System.currentTimeMillis()); 90 | log.info("Job finished"); 91 | onCompletion(() -> "done"); 92 | return numResults; 93 | } catch (Exception e) { 94 | Assertions.fail(e); 95 | } 96 | return 0; 97 | }); 98 | } 99 | 100 | @Override 101 | public boolean cancel(boolean mayInterruptIfRunning) { 102 | return false; 103 | } 104 | 105 | @Override 106 | public void close() {} 107 | } 108 | 109 | private static Stream provideDifferentNativeArrays() { 110 | return Stream.of( 111 | Arguments.of("long"), 112 | Arguments.of("float"), 113 | Arguments.of("double") 114 | ); 115 | } 116 | 117 | @Disabled 118 | @ParameterizedTest 119 | @MethodSource("provideDifferentNativeArrays") 120 | @Timeout(value = 5, unit = TimeUnit.MINUTES) 121 | public void testSpeed(String type) throws Exception { 122 | final Location location = 
Location.forGrpcInsecure("localhost", 12345); 123 | final CompletableFuture signal = new CompletableFuture<>(); 124 | 125 | try (App app = new App(new RootAllocator(Long.MAX_VALUE), location); 126 | Client client = new Client(new RootAllocator(Long.MAX_VALUE), location)) { 127 | 128 | app.registerHandler(new GdsActionHandler( 129 | (msg, mode, username) -> new NoOpJob(1_000_000, signal, type), log)); 130 | app.start(); 131 | 132 | final long start = System.currentTimeMillis(); 133 | final GdsMessage msg = new GdsMessage("neo4j", "mygraph", 134 | GdsMessage.RequestType.node, List.of("fastRp"), List.of(), ""); 135 | final Action action = new Action(GdsActionHandler.GDS_READ_ACTION, msg.serialize()); 136 | client.run(action); 137 | final long stop = signal.join(); 138 | log.info(String.format("Client Lifecycle Time: %,d ms", stop - start)); 139 | 140 | app.awaitTermination(5, TimeUnit.SECONDS); 141 | } 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /plugin/src/test/java/org/neo4j/arrow/action/GdsMessageTest.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import org.junit.jupiter.api.Assertions; 4 | import org.junit.jupiter.api.Test; 5 | 6 | import java.io.IOException; 7 | import java.util.List; 8 | 9 | public class GdsMessageTest { 10 | @Test 11 | public void testGdsMessageSerialization() throws IOException { 12 | final GdsMessage msg = new GdsMessage("db1", "graph1", 13 | GdsMessage.RequestType.node, List.of("prop1", "prop2"), List.of("filter1"), ""); 14 | final byte[] bytes = msg.serialize(); 15 | 16 | final GdsMessage msg2 = GdsMessage.deserialize(bytes); 17 | Assertions.assertEquals("db1", msg2.getDbName()); 18 | Assertions.assertEquals("graph1", msg2.getGraphName()); 19 | Assertions.assertEquals(GdsMessage.RequestType.node, msg2.getRequestType()); 20 | Assertions.assertArrayEquals(new String[] { "prop1", "prop2" }, 
msg2.getProperties().toArray(new String[2])); 21 | Assertions.assertArrayEquals(new String[] { "filter1" }, msg2.getFilters().toArray(new String[1])); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /plugin/src/test/java/org/neo4j/arrow/auth/NativeAuthValidatorTest.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.auth; 2 | 3 | import org.junit.jupiter.api.Assertions; 4 | import org.junit.jupiter.api.Test; 5 | import org.neo4j.kernel.api.security.AuthManager; 6 | 7 | public class NativeAuthValidatorTest { 8 | @Test 9 | public void testTokenLifecycle() throws Exception { 10 | final String username = "dave"; 11 | final String password = "so it goes!"; 12 | 13 | final NativeAuthValidator validator = new NativeAuthValidator(() -> AuthManager.NO_AUTH, null); 14 | final byte[] out = validator.getToken(username, password); 15 | final String newUsername = validator.isValid(out).get(); 16 | 17 | Assertions.assertEquals(username, newUsername, "usernames should match"); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /plugin/src/test/java/org/neo4j/arrow/gds/NodeHistoryTest.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.gds; 2 | 3 | import org.junit.jupiter.api.Assertions; 4 | import org.junit.jupiter.api.Test; 5 | 6 | public class NodeHistoryTest { 7 | @Test 8 | public void testSmallEdgeCache() { 9 | NodeHistory small = NodeHistory.given(4); 10 | int node1 = 1; 11 | int node2 = 2; 12 | int node3 = 2; 13 | int node4 = 3; 14 | 15 | Assertions.assertFalse(small.getAndSet(node1)); 16 | Assertions.assertFalse(small.getAndSet(node2)); 17 | Assertions.assertTrue(small.getAndSet(node3)); 18 | Assertions.assertFalse(small.getAndSet(node4)); 19 | } 20 | 21 | @Test 22 | public void testHugeEdgeCache() { 23 | NodeHistory small = 
NodeHistory.given(3_000_000); 24 | 25 | int node1 = 1; 26 | int node2 = 2; 27 | int node3 = 2; 28 | int node4 = 3; 29 | 30 | Assertions.assertFalse(small.getAndSet(node1)); 31 | Assertions.assertFalse(small.getAndSet(node2)); 32 | Assertions.assertTrue(small.getAndSet(node3)); 33 | Assertions.assertFalse(small.getAndSet(node4)); 34 | } 35 | 36 | @Test 37 | public void testOffHeapEdgeCache() { 38 | NodeHistory offheap = NodeHistory.offHeap(200); 39 | int node1 = 42; 40 | Assertions.assertFalse(offheap.getAndSet(node1)); 41 | Assertions.assertTrue(offheap.getAndSet(node1)); 42 | Assertions.assertFalse(offheap.getAndSet(node1 + 1)); 43 | Assertions.assertFalse(offheap.getAndSet(node1 - 1)); 44 | 45 | Assertions.assertFalse(offheap.getAndSet(0)); 46 | Assertions.assertFalse(offheap.getAndSet(199)); 47 | Assertions.assertThrows(IndexOutOfBoundsException.class, 48 | () -> offheap.getAndSet(200)); 49 | Assertions.assertThrows(IndexOutOfBoundsException.class, 50 | () -> offheap.getAndSet(-1)); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /plugin/src/test/java/org/neo4j/arrow/job/CypherRecordTest.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.job; 2 | 3 | import org.junit.jupiter.api.Assertions; 4 | import org.junit.jupiter.api.Test; 5 | import org.neo4j.arrow.CypherRecord; 6 | import org.neo4j.arrow.RowBasedRecord; 7 | import org.neo4j.graphdb.Node; 8 | import org.neo4j.graphdb.Path; 9 | import org.neo4j.graphdb.Relationship; 10 | import org.neo4j.graphdb.Result; 11 | 12 | import java.util.List; 13 | 14 | public class CypherRecordTest { 15 | private static final org.slf4j.Logger logger; 16 | 17 | static { 18 | // Set up nicer logging output. 
19 | System.setProperty("org.slf4j.simpleLogger.showDateTime", "true"); 20 | System.setProperty("org.slf4j.simpleLogger.dateTimeFormat", "[yyyy-MM-dd'T'HH:mm:ss:SSS]"); 21 | System.setProperty("org.slf4j.simpleLogger.logFile", "System.out"); 22 | logger = org.slf4j.LoggerFactory.getLogger(CypherRecordTest.class); 23 | } 24 | 25 | @Test 26 | public void canInferTypes() { 27 | Result.ResultRow row = new Result.ResultRow() { 28 | @Override 29 | public Node getNode(String key) { 30 | return null; 31 | } 32 | 33 | @Override 34 | public Relationship getRelationship(String key) { 35 | return null; 36 | } 37 | 38 | @Override 39 | public Object get(String key) { 40 | switch (key) { 41 | case "int": 42 | return 123; 43 | case "bool": 44 | return true; 45 | case "string": 46 | return "string"; 47 | case "float": 48 | return 1.23f; 49 | default: 50 | return null; 51 | } 52 | } 53 | 54 | @Override 55 | public String getString(String key) { 56 | return "hey"; 57 | } 58 | 59 | @Override 60 | public Number getNumber(String key) { 61 | return 123; 62 | } 63 | 64 | @Override 65 | public Boolean getBoolean(String key) { 66 | return true; 67 | } 68 | 69 | @Override 70 | public Path getPath(String key) { 71 | return null; 72 | } 73 | }; 74 | 75 | final CypherRecord record = CypherRecord.wrap(row, List.of("int", "bool", "string", "float")); 76 | Assertions.assertEquals(RowBasedRecord.Type.INT, record.get("int").type()); 77 | Assertions.assertEquals(RowBasedRecord.Type.FLOAT, record.get("float").type()); 78 | Assertions.assertEquals(RowBasedRecord.Type.STRING, record.get("string").type()); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /plugin/src/test/java/org/neo4j/arrow/job/EdgePackingTest.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.job; 2 | 3 | import org.junit.jupiter.api.Assertions; 4 | import org.junit.jupiter.api.Test; 5 | import org.neo4j.arrow.gds.Edge; 6 
| 7 | public class EdgePackingTest { 8 | @Test 9 | public void testEdgePackingLogic() { 10 | long startId = 123; 11 | long endId = 456; 12 | boolean isNatural = true; 13 | 14 | long edge = Edge.edge(startId, endId, isNatural); 15 | System.out.printf("0x%X\n", edge); 16 | System.out.printf("start: %d ?? %d\n", startId, Edge.source(edge)); 17 | Assertions.assertEquals(startId, Edge.source(edge)); 18 | 19 | System.out.printf("end: %d ?? %d\n", endId, Edge.target(edge)); 20 | Assertions.assertEquals(endId, Edge.target(edge)); 21 | 22 | System.out.printf("isNatural?: %s : %s\n", isNatural, Edge.flag(edge)); 23 | Assertions.assertEquals(isNatural, Edge.flag(edge)); 24 | 25 | startId = 300_000_000; 26 | endId = 0; 27 | isNatural = false; 28 | edge = Edge.edge(startId, endId, isNatural); 29 | System.out.printf("0x%X\n", edge); 30 | System.out.printf("start: %d ?? %d\n", startId, Edge.source(edge)); 31 | Assertions.assertEquals(startId, Edge.source(edge)); 32 | 33 | System.out.printf("end: %d ?? %d\n", endId, Edge.target(edge)); 34 | Assertions.assertEquals(endId, Edge.target(edge)); 35 | 36 | System.out.printf("isNatural?: %s : %s\n", isNatural, Edge.flag(edge)); 37 | Assertions.assertEquals(isNatural, Edge.flag(edge)); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /python/pyarrow/__init__.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterable, List, Optional, Sequence, Union 2 | from pyarrow.lib import Array, ChunkedArray, DataType, Mapping, \ 3 | MemoryPool, Schema, Table 4 | 5 | def array(obj: Union[Sequence[Any], Iterable[Any]], # also ndarray or Series 6 | type: Optional[DataType] = None, mask: Any = None, size: Optional[int] = None, 7 | from_pandas: Optional[bool] = None, safe: bool = True, 8 | memory_pool: Optional[MemoryPool] = None) -> Union[Array, ChunkedArray]: ... 
9 | 10 | def concat_tables(tables: Iterable[Table], promote: Optional[bool] = False, 11 | MemoryPool: Any = None) -> Table: ... 12 | 13 | def enable_signal_handlers(enable: bool) -> None: ... 14 | 15 | def table(data: Any, names: Optional[List[str]] = None, 16 | schema: Optional[Schema] = None, 17 | metadata: Optional[Union[Dict[Any, Any], Mapping]] = None, 18 | nthreads: Optional[int] = None) -> Table: ... 19 | -------------------------------------------------------------------------------- /python/pyarrow/csv.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, AnyStr, Dict, IO, List, Optional, Union 2 | from pyarrow.lib import MemoryPool, RecordBatch, Schema, Table 3 | 4 | class CSVStreamingReader(): 5 | def read_all(self) -> Table: ... 6 | def read_next_batch(self) -> RecordBatch: ... 7 | schema: Schema 8 | 9 | class ConvertOptions(): 10 | def __init__(self, check_utf8: bool = True, 11 | column_types: Optional[Union[Dict[Any, Any], Schema]] = None, 12 | null_values: Optional[List[str]] = None, 13 | true_values: Optional[List[str]] = None, 14 | false_values:Optional[List[str]] = None, 15 | decimal_point: str = '.', 16 | timestamp_parsers: Optional[List[str]] = None, 17 | strings_can_be_null: bool = False, 18 | quoted_strings_can_be_null: bool = True, 19 | auto_dict_encode: bool = False, 20 | auto_dict_max_cardinality: Optional[int] = None, 21 | include_columns: Optional[List[str]] = None, 22 | include_missing_columns: bool = False) -> None: ... 23 | 24 | class ParseOptions(): ... 25 | 26 | class ReadOptions(): 27 | def __init__(self, use_threads: bool = True, 28 | block_size: Optional[int] = None, 29 | skip_rows: int = 0, skip_rows_after_names: int = 0, 30 | column_names: Optional[List[str]] = None, 31 | autogenerate_column_names: bool = False, 32 | encoding: str = 'utf8') -> None: ... 
33 | 34 | def open_csv(input_file: Union[str, IO[AnyStr]], 35 | read_options: Optional[ReadOptions] = None, 36 | parse_options: Optional[ParseOptions] = None, 37 | convert_options: Optional[ConvertOptions] = None) \ 38 | -> CSVStreamingReader: ... 39 | 40 | def read_csv(input_file: Union[str, IO[AnyStr]], 41 | read_options: Optional[ReadOptions] = None, 42 | parse_options: Optional[ParseOptions] = None, 43 | convert_options: Optional[ConvertOptions] = None, 44 | memory_pool: Optional[MemoryPool] = None) -> Table: ... 45 | -------------------------------------------------------------------------------- /python/pyarrow/flight.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterable, Iterator, Generator, List, Optional, \ 2 | Tuple, Union 3 | from pyarrow.lib import Buffer, RecordBatch, Schema, Table 4 | 5 | class Action: ... 6 | 7 | class FlightCallOptions: 8 | def __init__(self, timeout: Optional[float] = None, write_options: Any = None, 9 | headers: Optional[Union[List[Tuple[str, str]], 10 | List[Tuple[bytes, bytes]]]] = None) \ 11 | -> None: ... 12 | 13 | class FlightDescriptor: 14 | @classmethod 15 | def deserialize(cls, serialized: bytes) -> FlightDescriptor: ... 16 | @classmethod 17 | def for_command(cls, command: bytes) -> FlightDescriptor: ... 18 | @classmethod 19 | def for_path(cls, path: str) -> FlightDescriptor: ... 20 | command: Any 21 | descriptor_type: Any 22 | path: Any 23 | 24 | class FlightInfo: 25 | schema: Schema 26 | descriptor: FlightDescriptor 27 | total_bytes: int 28 | total_records: int 29 | 30 | class FlightMetadataReader: ... 31 | 32 | class FlightStreamChunk: 33 | data: RecordBatch 34 | app_metadata: Union[None, Any] 35 | 36 | class FlightStreamReader(Iterable[FlightStreamChunk]): 37 | def __iter__(self) -> Iterator[FlightStreamChunk]: ... 38 | def read_chunk(self) -> FlightStreamChunk: ... 
39 | schema: Schema 40 | 41 | class _CRecordBatchWriter: 42 | def close(self) -> None: ... 43 | def write(self, table_or_batch: Union[RecordBatch, Table]) -> None: ... 44 | def write_table(self, table: Table, 45 | max_chunksize: Optional[int] = None, 46 | kwargs: Optional[Dict[str, Any]] = None) -> None: ... 47 | 48 | class MetadataRecordBatchWriter(_CRecordBatchWriter): 49 | def write_batch(self, batch: RecordBatch) -> None: ... 50 | def write_metadata(self, buf: Any) -> None: ... 51 | def write_with_metdata(self, batch: RecordBatch, buf: Any) -> None: ... 52 | 53 | class FlightStreamWriter(MetadataRecordBatchWriter): 54 | def done_writing(self) -> None: ... 55 | 56 | 57 | class Location: 58 | @classmethod 59 | def for_grpc_tcp(cls, host: str, port: int) -> Location: ... 60 | @classmethod 61 | def for_grpc_tls(cls, host: str, port: int) -> Location: ... 62 | @classmethod 63 | def for_grpc_unix(cls, path: str) -> Location: ... 64 | 65 | class Result: 66 | def __init__(self, buf: Union[Buffer, bytes]) -> None: ... 67 | body: Buffer 68 | 69 | class Ticket: 70 | def serialize(self) -> bytes: ... 71 | @classmethod 72 | def deserialize(cls, serialized: bytes) -> Ticket: ... 73 | 74 | class FlightClient: 75 | def __init__(self, location: Location, 76 | tls_root_certs: Optional[bytes] = None, 77 | cert_chain: Optional[bytes] = None, 78 | private_key: Optional[bytes] = None, 79 | override_hostname: Optional[str] = None, 80 | middleware: Optional[List[Any]] = None, 81 | write_size_limit_bytes: Optional[int] = None, 82 | disable_server_verification: bool = False, 83 | generic_options: Optional[List[Any]] = None) -> None: ... 84 | 85 | # TODO: do_action() supports other types for 'action' arg 86 | def do_action(self, action: Union[Tuple[str, bytes], Action], 87 | options: Optional[FlightCallOptions] = None) \ 88 | -> Iterator[Result]: ... 89 | 90 | def do_get(self, ticket: Ticket, 91 | options: Optional[FlightCallOptions] = None) \ 92 | -> FlightStreamReader: ... 
93 | 94 | def do_put(self, descriptor: FlightDescriptor, schema: Schema, 95 | options: Optional[FlightCallOptions] = None) \ 96 | -> Tuple[FlightStreamWriter, FlightStreamReader]: ... 97 | 98 | def get_flight_info(self, descriptor: FlightDescriptor, 99 | options: Optional[FlightCallOptions] = None) -> FlightInfo: ... 100 | 101 | def list_actions(self, options: Optional[FlightCallOptions] = None) \ 102 | -> List[Action]: ... 103 | 104 | def list_flights(self, criteria: Optional[bytes] = None, 105 | options: Optional[FlightCallOptions] = None) \ 106 | -> Generator[FlightInfo, None, None]: ... 107 | 108 | def wait_for_available(self, timeout: int = 5) -> None: ... 109 | -------------------------------------------------------------------------------- /python/pyarrow/lib.pyi: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping as ABC_Mapping 2 | from typing import Any, Dict, Iterable, Iterator, List, Optional, Union, Sequence 3 | 4 | class Array(): ... 5 | 6 | class ArrowException(Exception): ... 7 | class ArrowKeyError(ArrowException): ... 8 | 9 | class Buffer: 10 | def to_pybytes(self) -> bytes: ... 11 | address: int 12 | is_cpu: bool 13 | is_mutable: bool 14 | parent: Buffer 15 | size: int 16 | 17 | class ChunkedArray(Array): ... 18 | 19 | class DataType: ... 20 | class Field: ... 21 | class Mapping: ... 22 | class MemoryPool: ... 23 | class Metadata(): ... 24 | 25 | class Schema: 26 | def with_metadata(self, metadata: Dict[bytes, bytes]) -> Schema: ... 27 | metadata: Dict[bytes, bytes] 28 | names: List[str] 29 | 30 | class RecordBatch: 31 | @classmethod 32 | def from_pydict(cls, mapping: Union[Dict[Any, Any], Mapping], 33 | schema: Optional[Schema] = None, 34 | metadata: Optional[Union[Dict[Any, Any], Mapping]] = None) \ 35 | -> RecordBatch: ... 36 | def __len__(self) -> int: ... 37 | def replace_schema_metadata(self, 38 | metadata: Optional[Dict[Any, Any]] = None) \ 39 | -> RecordBatch: ... 
40 | def to_pydict(self) -> Dict[Any, Any]: ... 41 | 42 | nbytes: int 43 | num_columns: int 44 | num_rows: int 45 | schema: Schema 46 | 47 | class Table: 48 | @classmethod 49 | def from_pydict(cls, mapping: Dict[Any, Any], 50 | schema: Optional[Schema] = None, 51 | metadata: Optional[Union[Dict[Any, Any], 52 | Mapping]] = None) -> Table: ... 53 | @classmethod 54 | def from_batches(cls, batches: Union[Sequence[RecordBatch], 55 | Iterator[RecordBatch]], 56 | schema: Optional[Schema] = None) -> Table: ... 57 | def __len__(self) -> int: ... 58 | def append_column(self, field: Union[str, Field], 59 | column: Union[Array, List[Array], List[Any]]) -> Table: ... 60 | def replace_schema_metadata(self, 61 | metadata: Optional[Dict[Any, Any]] = None) \ 62 | -> Table: ... 63 | def slice(self, offset: Optional[int] = 0, 64 | length: Optional[int] = None) -> Table: ... 65 | def to_batches(self, max_chunksize: Optional[int] = None, 66 | kwargs: Optional[Dict[Any, Any]] = None) \ 67 | -> List[RecordBatch]: ... 68 | def to_pydict(self) -> Dict[Any, Any]: ... 69 | 70 | nbytes: int 71 | num_columns: int 72 | num_rows: int 73 | schema: Schema 74 | 75 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | pyarrow==6.0.1 2 | -------------------------------------------------------------------------------- /server/README.md: -------------------------------------------------------------------------------- 1 | > **NOTE**: I'm no longer hacking on this server sub-project! 2 | 3 | # neo4j-arrow "proxy" server 4 | Can't install the `neo4j-arrow` [plugin](../plugin)? Try the proxy. 5 | 6 | > This approach supports Cypher jobs, but does *not* support GDS jobs!
7 | 8 | ## Building 9 | The simplest way is to run the `shadowJar` task (from the project root): 10 | 11 | ``` 12 | $ ./gradlew :server:shadowJar 13 | ``` 14 | 15 | ## Running 16 | I'm lazy, so the config is all via environment variables: 17 | 18 | * `NEO4J_URL`: Bolt URL (default: `neo4j://localhost:7687`) 19 | * `NEO4J_USERNAME`: (default: `neo4j`) 20 | * `NEO4J_PASSWORD`: (default: `password`) 21 | * `NEO4J_DATABASE`: (default: `neo4j`) 22 | * `HOST`: IP/hostname to listen on (default: `localhost`) 23 | * `PORT`: TCP port to listen on (default: `9999`) 24 | 25 | Some tuning options are available: 26 | 27 | * `MAX_MEM_GLOBAL`: maximum memory the Arrow global allocator is allowed to 28 | allocate (default: Java's `Long.MAX_VALUE`) 29 | * `MAX_MEM_STREAM`: maximum memory allowed to be allocated by an individual 30 | stream (default: Java's `Integer.MAX_VALUE`) 31 | * `ARROW_BATCH_SIZE`: size of the transmitted vector batches (i.e. number of 32 | "rows") (default: `25,000`) 33 | * `BOLT_FETCH_SIZE`: Bolt fetch size used by the driver (default: `1,000`) 34 | 35 | Set any of the above and run: 36 | 37 | ``` 38 | $ java -jar /server-1.0-SNAPSHOT.jar 39 | ``` -------------------------------------------------------------------------------- /server/src/main/java/org/neo4j/arrow/DriverRecord.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.neo4j.driver.Record; 4 | import org.neo4j.driver.internal.types.InternalTypeSystem; 5 | 6 | import java.util.List; 7 | 8 | /** 9 | * Wraps the Records returned by the Neo4jDriver.
10 | */ 11 | public class DriverRecord implements RowBasedRecord { 12 | private final Record record; 13 | 14 | private DriverRecord(Record record) { 15 | this.record = record; 16 | } 17 | 18 | public static RowBasedRecord wrap(Record record) { 19 | return new DriverRecord(record); 20 | } 21 | 22 | private static RowBasedRecord.Value wrapValue(org.neo4j.driver.Value value) { 23 | return new Value() { 24 | @Override 25 | public int size() { 26 | if (value.hasType(InternalTypeSystem.TYPE_SYSTEM.LIST())) 27 | return value.asList().size(); 28 | return 1; 29 | } 30 | 31 | @Override 32 | public int asInt() { 33 | return value.asInt(); 34 | } 35 | 36 | @Override 37 | public long asLong() { 38 | return value.asLong(); 39 | } 40 | 41 | @Override 42 | public float asFloat() { 43 | return value.asFloat(); 44 | } 45 | 46 | @Override 47 | public double asDouble() { 48 | return value.asDouble(); 49 | } 50 | 51 | @Override 52 | public String asString() { 53 | return value.asString(); 54 | } 55 | 56 | @Override 57 | public List asList() { 58 | return value.asList(org.neo4j.driver.Value::asObject); 59 | } 60 | 61 | @Override 62 | public List asIntList() { 63 | return value.asList(org.neo4j.driver.Value::asInt); 64 | } 65 | 66 | @Override 67 | public List asLongList() { 68 | return value.asList(org.neo4j.driver.Value::asLong); 69 | } 70 | 71 | @Override 72 | public List asFloatList() { 73 | return value.asList(org.neo4j.driver.Value::asFloat); 74 | } 75 | 76 | @Override 77 | public List asDoubleList() { 78 | return value.asList(org.neo4j.driver.Value::asDouble); 79 | } 80 | 81 | @Override 82 | public double[] asDoubleArray() { 83 | return null; 84 | } 85 | 86 | public Type type() { 87 | switch (value.type().name()) { 88 | case "INTEGER": 89 | return Type.INT; 90 | case "LONG": 91 | return Type.LONG; 92 | case "FLOAT": 93 | case "DOUBLE": 94 | return Type.DOUBLE; 95 | case "STRING": 96 | return Type.STRING; 97 | case "LIST OF ANY?": 98 | return Type.LIST; 99 | } 100 | // Catch-all 101 
| return Type.OBJECT; 102 | } 103 | }; 104 | } 105 | 106 | @Override 107 | public Value get(int index) { 108 | return wrapValue(record.get(index)); 109 | } 110 | 111 | @Override 112 | public Value get(String field) { 113 | return wrapValue(record.get(field)); 114 | } 115 | 116 | @Override 117 | public List keys() { 118 | return record.keys(); 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /server/src/main/java/org/neo4j/arrow/demo/Neo4jProxyServer.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.demo; 2 | 3 | import org.apache.arrow.flight.Location; 4 | import org.apache.arrow.memory.BufferAllocator; 5 | import org.apache.arrow.memory.RootAllocator; 6 | import org.apache.arrow.util.AutoCloseables; 7 | import org.neo4j.arrow.App; 8 | import org.neo4j.arrow.Config; 9 | import org.neo4j.arrow.action.CypherActionHandler; 10 | import org.neo4j.arrow.job.AsyncDriverJob; 11 | import org.neo4j.driver.AuthTokens; 12 | 13 | import java.util.concurrent.TimeUnit; 14 | 15 | /** 16 | * A simple implementation of a Neo4j Arrow Service. Acts as a stand-alone bridge or proxy to a 17 | * remote Neo4j instance, offering Cypher-only services. (No native GDS integration.) 
18 | */ 19 | public class Neo4jProxyServer { 20 | private static final org.slf4j.Logger logger; 21 | 22 | static { 23 | System.setProperty("org.slf4j.simpleLogger.showDateTime", "true"); 24 | System.setProperty("org.slf4j.simpleLogger.dateTimeFormat", "[yyyy-MM-dd'T'HH:mm:ss:SSS]"); 25 | System.setProperty("org.slf4j.simpleLogger.logFile", "System.out"); 26 | System.setProperty("org.slf4j.simpleLogger.defaultLogLevel", "info"); 27 | logger = org.slf4j.LoggerFactory.getLogger(Neo4jProxyServer.class); 28 | } 29 | 30 | public static void main(String[] args) throws Exception { 31 | long timeout = 15; 32 | TimeUnit unit = TimeUnit.MINUTES; 33 | 34 | final BufferAllocator bufferAllocator = new RootAllocator(Config.maxArrowMemory); 35 | final App app = new App( 36 | bufferAllocator, 37 | Location.forGrpcInsecure(Config.host, Config.port)); 38 | 39 | final CypherActionHandler cypherHandler = new CypherActionHandler( 40 | (cypherMsg, mode, username) -> 41 | new AsyncDriverJob(cypherMsg, mode, AuthTokens.basic(Config.username, Config.password))); 42 | app.registerHandler(cypherHandler); 43 | 44 | app.start(); 45 | 46 | Runtime.getRuntime().addShutdownHook(new Thread(() -> { 47 | try { 48 | logger.info("Shutting down..."); 49 | AutoCloseables.close(app, bufferAllocator); 50 | logger.info("Stopped."); 51 | } catch (Exception e) { 52 | logger.error("Failure during shutdown!", e); 53 | } 54 | })); 55 | 56 | logger.info("Will terminate after timeout of {} {}", timeout, unit); 57 | app.awaitTermination(timeout, unit); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /server/src/main/java/org/neo4j/arrow/job/AsyncDriverJob.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.job; 2 | 3 | import org.neo4j.arrow.DriverRecord; 4 | import org.neo4j.arrow.RowBasedRecord; 5 | import org.neo4j.arrow.action.CypherMessage; 6 | import org.neo4j.driver.*; 7 | import 
org.neo4j.driver.async.AsyncSession; 8 | import org.neo4j.driver.summary.ResultSummary; 9 | import org.slf4j.LoggerFactory; 10 | 11 | import java.util.concurrent.CompletableFuture; 12 | import java.util.concurrent.ConcurrentHashMap; 13 | import java.util.concurrent.ConcurrentMap; 14 | import java.util.function.BiConsumer; 15 | 16 | /** 17 | * Implementation of a Neo4jJob that uses an AsyncSession via the Neo4j Java Driver. 18 | */ 19 | public class AsyncDriverJob extends ReadJob { 20 | private static final org.slf4j.Logger logger = LoggerFactory.getLogger(AsyncDriverJob.class); 21 | 22 | /* 1 Driver per identity for now...gross simplification. */ 23 | private static final ConcurrentMap driverMap = new ConcurrentHashMap<>(); 24 | 25 | private final AsyncSession session; 26 | private final CompletableFuture future; 27 | 28 | public AsyncDriverJob(CypherMessage msg, Mode mode, AuthToken authToken) { 29 | super(); 30 | 31 | Driver driver; 32 | if (!driverMap.containsKey(authToken)) { 33 | org.neo4j.driver.Config.ConfigBuilder builder = org.neo4j.driver.Config.builder(); 34 | driver = GraphDatabase.driver(org.neo4j.arrow.Config.neo4jUrl, authToken, 35 | builder.withUserAgent("Neo4j-Arrow-Proxy/alpha") 36 | .withMaxConnectionPoolSize(8) 37 | .withFetchSize(org.neo4j.arrow.Config.boltFetchSize) 38 | .build()); 39 | driverMap.put(authToken, driver); 40 | } else { 41 | driver = driverMap.get(authToken); 42 | } 43 | 44 | this.session = driver.asyncSession(SessionConfig.builder() 45 | .withDatabase(org.neo4j.arrow.Config.database) 46 | .withDefaultAccessMode(AccessMode.valueOf(mode.name())) 47 | .build()); 48 | 49 | future = session.runAsync(msg.getCypher(), msg.getParams()) 50 | .thenComposeAsync(resultCursor -> { 51 | logger.info("Job {} producing", session); 52 | setStatus(Status.PRODUCING); 53 | 54 | /* We need to inspect the first record and guess at a schema :-( */ 55 | final RowBasedRecord firstRecord = DriverRecord.wrap( 56 | 
resultCursor.peekAsync().toCompletableFuture().join()); 57 | onFirstRecord(firstRecord); 58 | 59 | final BiConsumer consumer = futureConsumer.join(); 60 | 61 | // Hacky for now 62 | return resultCursor.forEachAsync(record -> 63 | CompletableFuture.runAsync(() -> { 64 | // Trying to find a "cheap" way to partition the data lock-free 65 | int partition = (int) (System.currentTimeMillis() & ((1 << 30) - 1)); 66 | consumer.accept(DriverRecord.wrap(record), partition); 67 | })); 68 | }).whenCompleteAsync((resultSummary, throwable) -> { 69 | if (throwable != null) { 70 | setStatus(Status.ERROR); 71 | logger.error("Job failure", throwable); 72 | } else { 73 | logger.info("Job {} complete", session); 74 | setStatus(Status.COMPLETE); 75 | } 76 | onCompletion(DriverJobSummary.wrap(resultSummary)); 77 | session.closeAsync().toCompletableFuture().join(); 78 | }).toCompletableFuture(); 79 | } 80 | 81 | @Override 82 | public boolean cancel(boolean mayInterruptIfRunning) { 83 | return future.cancel(mayInterruptIfRunning); 84 | } 85 | 86 | @Override 87 | public void close() { 88 | future.cancel(true); 89 | session.closeAsync(); 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /server/src/main/java/org/neo4j/arrow/job/DriverJobSummary.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.job; 2 | 3 | import org.neo4j.driver.summary.ResultSummary; 4 | 5 | public class DriverJobSummary implements JobSummary { 6 | private final String resultSummary; 7 | private DriverJobSummary(ResultSummary summary) { 8 | this.resultSummary = summary.toString(); 9 | } 10 | 11 | public static JobSummary wrap(ResultSummary summary) { 12 | return summary::toString; 13 | } 14 | 15 | @Override 16 | public String asString() { 17 | return resultSummary; 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /settings.gradle: 
-------------------------------------------------------------------------------- 1 | rootProject.name = 'neo4j-arrow' 2 | include 'common' 3 | include 'client' 4 | include 'plugin' 5 | include 'server' 6 | include 'strawman' 7 | -------------------------------------------------------------------------------- /speed/26-aug-2021/README.md: -------------------------------------------------------------------------------- 1 | # Performance Testing 🏎 2 | **Date**: `26 Aug 2021` 3 | 4 | ## Test Environment 🔬 5 | 6 | ### Systems 💻 7 | * **Host** 8 | * Neo4j `v4.3.3 Enterprise` 9 | * GDS `v1.6.4` 10 | * GCP `c2-standard-30` (30 vCPUs, 120 GB memory) 11 | * Ubuntu `21.04` (Linux 5.11.0-1017-gcp) 12 | * Zone `northamerica-northeast1-a` 13 | * **Client** 14 | * Python `v3.9.5` 15 | * Neo4j Java Driver `v4.3.3` 16 | * Neo4j Python Driver `v4.3.4` 17 | * Ubuntu `21.04` (Linux 5.11.0-1017-gcp) 18 | * GCP `e2-standard-8` (8 vCPUs, 32 GB memory) 19 | * Zone `northamerica-northeast1-a` 20 | * **Data** 21 | * [PaySim](https://storage.googleapis.com/neo4j-paysim/paysim-07may2021.dump) 22 | * `1,892,751` Nodes 23 | * `5,576,578` Relationships 24 | 25 | ### The Graph Projection 📽 26 | The graph projection isn't complicated: it's just the entire graph :-) 27 | ```cypher 28 | CALL gds.graph.create('mygraph', '*', '*'); 29 | ``` 30 | From a cold start, this should take about `1,480 ms`. 31 | 32 | ### The Embeddings 🕸 33 | 34 | ```cypher 35 | CALL gds.fastRP.mutate('mygraph', { 36 | concurrency: 28, 37 | embeddingDimension: 256, 38 | mutateProperty: 'n' 39 | }); 40 | ``` 41 | 42 | From a cold start, this should take about `4,670 ms`. 43 | 44 | ## Methodology 🧪 45 | 46 | 1. Warm-up with 5 runs per client type 47 | 2. Run 5 times after warm-up, recording throughput 48 | 3. Average the 3 best runs for the final result 49 | 50 | ### The Python Driver 🐍 51 | The code is in [direct.py](../direct.py). 
It runs the following Cypher: 52 | 53 | ```cypher 54 | CALL gds.graph.streamNodeProperty('mygraph', 'n'); 55 | ``` 56 | 57 | Time is measured using `time.time()` after submitting the transaction and 58 | just prior to iterating through results. Time is stopped after the last result. 59 | 60 | ### The Java Driver ☕ 61 | The code is in the [strawman](../../strawman) subproject. 62 | 63 | It also runs the same Cypher as the Python driver: 64 | 65 | ```cypher 66 | CALL gds.graph.streamNodeProperty('mygraph', 'n'); 67 | ``` 68 | 69 | Time is measured using `System.currentTimeMillis()` after submitting the 70 | transaction and just prior to iterating through results. Time is stopped 71 | after the last result. 72 | 73 | ### The Neo4j-Arrow Client 🏹 74 | The code is in [client.py](../client.py). 75 | 76 | It submits the following `GdsMessage` and then requests the stream: 77 | 78 | ```json 79 | { 80 | "db": "neo4j", 81 | "graph": "mygraph", 82 | "filters": [], 83 | "properties": ["n"] 84 | } 85 | ``` 86 | 87 | Time is measured by `time.time()` after submitting the request for the 88 | stream and before iterating through the Arrow batches. 89 | 90 | --- 91 | 92 | ## Observations 🔍 93 | 94 | Each test is measured in "rows/second" where each "row" looks like: 95 | 96 | ``` 97 | [(long) nodeId, (float[256]) embedding] 98 | ``` 99 | 100 | Each row is roughly `8 + 256*[4..8] ==> 1032..2056` 101 | bytes, plus some potential metadata overhead. As a result, the 102 | raw payload of 1.9M vectors is approximately 1.8 GiB or greater.
103 | 104 | ### Table 1: Performance Results 📋 105 | 106 | The best observations used to generate the average: 107 | 108 | | Client | Best 1 | Best 2 | Best 3 | Average | Std.Dev | 109 | | -------------- | --------- | --------- | --------- | --------- | ------- | 110 | | Python Driver | 2,249 | 2,247 | 2,225 | 2,240 | 11 | 111 | | Java Driver | 57,356 | 57,356 | 55,669 | 56,794 | 795 | 112 | | Neo4j-Arrow | 1,055,474 | 1,030,174 | 1,008,737 | 1,031,461 | 19,102 | 113 | 114 | ### Table 2: Discarded Results 🗑 115 | 116 | The discarded observations (not the Top 3): 117 | 118 | | Client | Discarded 1 | Discarded 2 | 119 | | -------------- | ----------- | ----------- | 120 | | Python Driver | 2,177 | 2,137 | 121 | | Java Driver | 55,669 | 55,669 | 122 | | Neo4j-Arrow | 879,622 | 1,004,904 | 123 | 124 | ### Figure 1: Log plot of Streaming Speed 📈 125 | 126 | ![Average of Best 3 Results](./figure1.png?raw=true) 127 | 128 | > Note: the exponential trendline fits quite well 😁 129 | 130 | # Conclusions 🤔 131 | This test focused on a specific use-case: _streaming homogeneous feature 132 | vectors._ In this use case, Neo4j-Arrow is **orders of magnitude faster than 133 | traditional approaches with Python and over 1 order of magnitude faster than 134 | using the Java Driver.** 135 | 136 | ## Discussion 🦜 137 | There are some fundamental differences between Neo4j-Arrow and the existing 138 | drivers: 139 | 140 | 1. **Parallelism**: Arrow Vectors can be assembled in parallel from the GDS 141 | in-memory graph. It's not clear whether the GDS procs can be changed to 142 | benefit from concurrency, as that would require changes to the Cypher 143 | runtime and/or to Bolt. 144 | 145 | 2. **Efficiency**: Arrow can leverage a schema over homogeneous data to more 146 | efficiently pack data on the wire, especially sparse data. (In this test, 147 | the data isn't sparse.)
On the client side, clients don't need to 148 | "deserialize" the data off the wire; they interpret it only on demand, 149 | as needed, with the help of the schema. -------------------------------------------------------------------------------- /speed/26-aug-2021/figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neo4j-field/neo4j-arrow/4eed6311eb5b0f7beae77a668ed282ce1cdced83/speed/26-aug-2021/figure1.png -------------------------------------------------------------------------------- /speed/26-aug-2021/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neo4j-field/neo4j-arrow/4eed6311eb5b0f7beae77a668ed282ce1cdced83/speed/26-aug-2021/img.png -------------------------------------------------------------------------------- /speed/direct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import neo4j 3 | from time import time 4 | 5 | query = """ 6 | UNWIND range(1, $rows) AS row 7 | RETURN row, [_ IN range(1, $dimension) | rand()] as fauxEmbedding 8 | """ 9 | params = {"rows": 1_000_000, "dimension": 128} 10 | 11 | auth = ('neo4j', 'password') 12 | with neo4j.GraphDatabase.driver('neo4j://localhost:7687', auth=auth) as d: 13 | with d.session(fetch_size=10_000) as s: 14 | print(f"Starting query {query}") 15 | result = s.run(query, params) 16 | cnt = 0 17 | start = time() 18 | for row in result: 19 | cnt = cnt + 1 20 | if cnt % 50_000 == 0: 21 | print(f"Current Row @ {cnt:,}:\t[fields: {row.keys()}]") 22 | finish = time() 23 | print(f"Done!
Time Delta: {round(finish - start, 2):,}s") 24 | print(f"Count: {cnt:,}, Rate: {round(cnt / (finish - start)):,} rows/s") 25 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/App.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.apache.arrow.flight.FlightServer; 4 | import org.apache.arrow.flight.Location; 5 | import org.apache.arrow.flight.LocationSchemes; 6 | import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; 7 | import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; 8 | import org.apache.arrow.memory.BufferAllocator; 9 | import org.apache.arrow.util.AutoCloseables; 10 | import org.neo4j.arrow.action.ActionHandler; 11 | import org.neo4j.arrow.action.auth.HorribleBasicAuthValidator; 12 | 13 | import java.io.File; 14 | import java.io.IOException; 15 | import java.util.concurrent.TimeUnit; 16 | 17 | /** 18 | * An Arrow Flight Application for integrating Neo4j and Apache Arrow 19 | */ 20 | public class App implements AutoCloseable { 21 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(App.class); 22 | 23 | private final FlightServer server; 24 | private final Location location; 25 | private final Producer producer; 26 | private final BufferAllocator allocator; 27 | private final String name; 28 | 29 | /** 30 | * Create a new Arrow Flight application using the provided memory allocator. It will listen on 31 | * the provided {@link Location}. 
32 | * 33 | * @param rootAllocator main {@link BufferAllocator} for use by the application 34 | * @param location hostname/port to listen on for incoming API calls 35 | */ 36 | public App(BufferAllocator rootAllocator, Location location) { 37 | this(rootAllocator, location, "unnamed-app", 38 | new BasicCallHeaderAuthenticator(new HorribleBasicAuthValidator())); 39 | } 40 | 41 | /** 42 | * Create a new Arrow Flight application using the provided memory allocator. It will listen on 43 | * the provided {@link Location}. 44 | * 45 | * @param rootAllocator main {@link BufferAllocator} for use by the application 46 | * @param location hostname/port to listen on for incoming API calls 47 | * @param name identifiable name for the service 48 | */ 49 | @SuppressWarnings("unused") 50 | public App(BufferAllocator rootAllocator, Location location, String name) { 51 | this(rootAllocator, location, name, 52 | new BasicCallHeaderAuthenticator(new HorribleBasicAuthValidator())); 53 | } 54 | 55 | /** 56 | * Create a new Arrow Flight application using the provided memory allocator. It will listen on 57 | * the provided {@link Location}. 58 | *

59 | * Utilizes the provided {@link CallHeaderAuthenticator} for authenticating client calls and 60 | * requests. 61 | * @param rootAllocator main {@link BufferAllocator} for use by the application 62 | * @param location hostname/port to listen on for incoming API calls 63 | * @param name identifiable name for the service 64 | * @param authenticator a {@link CallHeaderAuthenticator} to use for authenticating API calls 65 | */ 66 | public App(BufferAllocator rootAllocator, Location location, String name, CallHeaderAuthenticator authenticator) { 67 | allocator = rootAllocator.newChildAllocator("neo4j-flight-server", 0, Config.maxArrowMemory); 68 | this.location = location; 69 | this.producer = new Producer(allocator, location); 70 | final FlightServer.Builder builder = FlightServer.builder(rootAllocator, location, this.producer) 71 | // XXX header auth expects basic HTTP headers in the GRPC calls 72 | .headerAuthenticator(authenticator); 73 | 74 | if (location.getUri().getScheme().equalsIgnoreCase(LocationSchemes.GRPC_TLS)) { 75 | try { 76 | final File cert = new File(Config.tlsCertficate); 77 | final File privateKey = new File(Config.tlsPrivateKey); 78 | builder.useTls(cert, privateKey); 79 | } catch (IOException e) { 80 | logger.error("could not initialize TLS FlightServer", e); 81 | throw new RuntimeException("failed to initialize a TLS FlightServer", e); 82 | } 83 | } 84 | 85 | this.server = builder.build(); 86 | this.name = name; 87 | } 88 | 89 | public void registerHandler(ActionHandler handler) { 90 | producer.registerHandler(handler); 91 | } 92 | 93 | public void start() throws IOException { 94 | server.start(); 95 | logger.info("server listening @ {}", location.getUri().toString()); 96 | } 97 | 98 | public void awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { 99 | server.awaitTermination(timeout, unit); 100 | } 101 | 102 | @SuppressWarnings("unused") 103 | public Location getLocation() { 104 | return location; 105 | } 106 | 107 | 
@Override 108 | public String toString() { 109 | return "NeojFlightApp { name: " + name + ", location: " + location.toString() + " }"; 110 | } 111 | 112 | @Override 113 | public void close() throws Exception { 114 | logger.info("closing"); 115 | AutoCloseables.close(producer, server, allocator); 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/Config.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import java.util.Map; 4 | import java.util.stream.Stream; 5 | 6 | /** 7 | * Super simple environment-based config. 8 | *

9 | * Warning: the password is stored in the environment in plaintext!!! 10 | *

11 | */ 12 | public class Config { 13 | public final static Long TiB = (1L << 40); 14 | public final static Long GiB = (1L << 30); 15 | public final static Long MiB = (1L << 20); 16 | public final static Map<String, Long> scaleMap = Map.of( 17 | "T", TiB, "t", TiB, 18 | "G", GiB, "g", GiB, 19 | "M", MiB, "m", MiB, 20 | "", 1L 21 | ); 22 | /** Parse a size such as "512M", "4g" or "1073741824" (plain bytes) into bytes; throws IllegalArgumentException on empty input or an unknown unit, ArithmeticException on overflow. */ 23 | public static Long parseMemory(String memory) { 24 | // Reject null/blank input up front instead of failing with an opaque StringIndexOutOfBoundsException 25 | final String s = (memory == null) ? "" : memory.trim(); 26 | if (s.isEmpty()) throw new IllegalArgumentException("memory value must not be empty"); 27 | // A trailing non-digit is a unit suffix looked up in scaleMap; otherwise the value is already in bytes 28 | final String unit = Character.isDigit(s.charAt(s.length() - 1)) ? "" : s.substring(s.length() - 1); 29 | final Long scale = scaleMap.get(unit); 30 | // Unknown units (e.g. "K") are rejected explicitly rather than surfacing as an NPE while unboxing 31 | if (scale == null) throw new IllegalArgumentException("unknown memory unit '" + unit + "' in: " + memory); 32 | // multiplyExact fails loudly instead of silently wrapping (e.g. "99999999T") 33 | return Math.multiplyExact(Long.parseLong(s.substring(0, s.length() - unit.length()).trim()), scale.longValue()); 34 | } 35 | 36 | /** Bolt URL for accessing a remote Neo4j database */ 37 | public final static String neo4jUrl = System.getenv().getOrDefault("NEO4J_URL", "neo4j://localhost:7687"); 38 | /** Username for any Neo4j driver connection */ 39 | public final static String username = System.getenv().getOrDefault("NEO4J_USERNAME", "neo4j"); 40 | /** Password for any Neo4j driver connection */ 41 | public final static String password = System.getenv().getOrDefault("NEO4J_PASSWORD", "password"); 42 | /** Name of the default Neo4j database to use */ 43 | public final static String database = System.getenv().getOrDefault("NEO4J_DATABASE", "neo4j"); 44 | 45 | /** Hostname or IP address to listen on when running as a server or connect to when a client */ 46 | public final static String host = System.getenv().getOrDefault("HOST", "localhost"); 47 | /** Port number to listen on or connect to.
*/ 48 | public final static int port = Integer.parseInt(System.getenv().getOrDefault("PORT", "9999")); 49 | 50 | /** Maximum native memory allowed to be allocated by the global allocator and its children */ 51 | public final static long maxArrowMemory = parseMemory( 52 | System.getenv().getOrDefault("MAX_MEM_GLOBAL", String.valueOf(Runtime.getRuntime().maxMemory()))); 53 | 54 | /** Maximum native memory allowed to be allocated by a single stream */ 55 | public final static long maxStreamMemory = parseMemory( 56 | System.getenv().getOrDefault("MAX_MEM_STREAM", String.valueOf(Runtime.getRuntime().maxMemory()))); 57 | 58 | /** Arrow Batch Size controls the size of the transmitted vectors.*/ 59 | public final static int arrowBatchSize = Math.abs(Integer.parseInt( 60 | System.getenv().getOrDefault("ARROW_BATCH_SIZE", Integer.toString(1024)) 61 | )); 62 | 63 | /** Arrow parallelism */ 64 | public final static int arrowMaxPartitions = Math.abs(Integer.parseInt( 65 | System.getenv().getOrDefault("ARROW_MAX_PARTITIONS", 66 | String.valueOf(Runtime.getRuntime().availableProcessors() * 2)))); 67 | 68 | /** Arrow flush timeout in seconds */ 69 | public final static int arrowFlushTimeout = Math.abs(Integer.parseInt( 70 | System.getenv().getOrDefault("ARROW_FLUSH_TIMEOUT", Integer.toString(60 * 30)) 71 | )); 72 | 73 | /** Arrow max List size, used by k-hop */ 74 | public final static int arrowMaxListSize = Math.abs(Integer.parseInt( 75 | System.getenv().getOrDefault("ARROW_MAX_LIST_SIZE", String.valueOf(2048)))); 76 | 77 | /** Bolt fetch size controls how many Records we PULL at a given time. Should be set lower 78 | * than the Arrow Batch size. 
79 | */ 80 | public final static long boltFetchSize = Math.abs(Long.parseLong( 81 | System.getenv().getOrDefault("BOLT_FETCH_SIZE", Long.toString(1_000)) 82 | )); 83 | 84 | public final static String tlsCertficate = System.getenv().getOrDefault("ARROW_TLS_CERTIFICATE", ""); 85 | public final static String tlsPrivateKey = System.getenv().getOrDefault("ARROW_TLS_PRIVATE_KEY", ""); 86 | } 87 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/RowBasedRecord.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import java.util.List; 4 | 5 | /** 6 | * Adapter for different ways to represent raw values from Neo4j. For instance, from a Driver or 7 | * the underlying Server tx api. 8 | */ 9 | public interface RowBasedRecord { 10 | 11 | /** 12 | * Supported record types. Each {@link Type} should have a logical mapping to an Arrow type. 13 | */ 14 | enum Type { 15 | /** Scalar integer (32 bit, signed) */ 16 | INT, 17 | /** Array of integers (32 bit, signed) */ 18 | INT_ARRAY, 19 | /** Scalar long (64 bit, signed) */ 20 | LONG, 21 | /** Array of longs (64 bit, signed) */ 22 | LONG_ARRAY, 23 | /** Scalar floating point (single precision, 32-bit) */ 24 | FLOAT, 25 | /** Array of floating points (single precision, 32-bit) */ 26 | FLOAT_ARRAY, 27 | /** Scalar of floating point (double precision, 64-bit) */ 28 | DOUBLE, 29 | /** Array of floating points (double precision, 64-bit) */ 30 | DOUBLE_ARRAY, 31 | /** Variable length UTF-8 string (varchar) */ 32 | STRING, 33 | /** List of Strings */ 34 | STRING_LIST, 35 | /** Heterogeneous array of supported {@link Type}s */ 36 | LIST, 37 | LONG_LIST, 38 | INT_LIST, 39 | /** Catch-all...TBD */ 40 | OBJECT; 41 | } 42 | 43 | /** Retrieve a {@link Value} from the record by positional index */ 44 | Value get(int index); 45 | /** Retrieve a {@link Value} from the record by named field */ 46 | Value get(String 
field); 47 | 48 | /** List of fields in this {@link RowBasedRecord} */ 49 | List keys(); 50 | 51 | /** 52 | * 53 | */ 54 | interface Value { 55 | /* Number of primitives or inner values */ 56 | default int size() { 57 | return 1; 58 | } 59 | 60 | default int asInt() { 61 | return 0; 62 | } 63 | 64 | default long asLong() { 65 | return 0L; 66 | } 67 | 68 | default float asFloat() { 69 | return 0f; 70 | } 71 | 72 | default double asDouble() { 73 | return 0d; 74 | } 75 | 76 | default String asString() { 77 | return ""; 78 | } 79 | 80 | default List asList() { 81 | return List.of(); 82 | } 83 | 84 | @SuppressWarnings("unused") 85 | default List asIntList() { 86 | return List.of(); 87 | } 88 | 89 | default int[] asIntArray() { 90 | return new int[0]; 91 | } 92 | 93 | @SuppressWarnings("unused") 94 | default List asLongList() { 95 | return List.of(); 96 | } 97 | 98 | default long[] asLongArray() { 99 | return new long[0]; 100 | } 101 | 102 | default List asFloatList() { 103 | return List.of(); 104 | } 105 | 106 | default float[] asFloatArray() { 107 | return new float[0]; 108 | } 109 | 110 | default List asDoubleList() { 111 | return List.of(); 112 | } 113 | 114 | default double[] asDoubleArray() { 115 | return new double[0]; 116 | } 117 | 118 | default List asStringList() { return List.of(); } 119 | 120 | default Type type() { 121 | return Type.OBJECT; 122 | } 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/action/ActionHandler.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import org.apache.arrow.flight.Action; 4 | import org.apache.arrow.flight.ActionType; 5 | import org.apache.arrow.flight.FlightProducer; 6 | import org.neo4j.arrow.Producer; 7 | 8 | import java.util.List; 9 | 10 | /** 11 | * An {@link ActionHandler} provides handling for different Arrow Flight {@link Action} types. 12 | *

13 | * When registered with a {@link Producer}, any {@link Action} that matches a type in an 14 | * instances {@link #actionTypes()} will be passed to the given {@link ActionHandler} via 15 | * its {@link #handle(FlightProducer.CallContext, Action, Producer)} method. 16 | *

17 | */ 18 | public interface ActionHandler { 19 | /** 20 | * Get the list of action types supported by this handler. 21 | * 22 | * @return a {@link List} of supported {@link Action} types as {@link String}s 23 | */ 24 | List actionTypes(); 25 | 26 | /** 27 | * Get a list of descriptions (as {@link ActionType} instances) supported by this handler. 28 | * 29 | * @return a {@link List} of {@link ActionType}s supported 30 | */ 31 | List actionDescriptions(); 32 | 33 | /** 34 | * Handle an Arrow Flight RPC {@link Action}. 35 | * 36 | * @param context a reference to the caller's {@link org.apache.arrow.flight.FlightProducer.CallContext} 37 | * providing access to the peer identity of the caller 38 | * @param action the {@link Action} to process 39 | * @param producer reference to the controlling {@link Producer} 40 | * @return new Outcome 41 | */ 42 | Outcome handle(FlightProducer.CallContext context, Action action, Producer producer); 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/action/Message.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | public interface Message { 4 | byte[] serialize(); 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/action/Outcome.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import org.apache.arrow.flight.CallStatus; 4 | import org.apache.arrow.flight.Result; 5 | 6 | import java.util.Optional; 7 | 8 | /** 9 | * The outcome of processing an Arrow Flight RPC call. 10 | *

11 | * Effectively a pair of {@link Optional} results: a {@link Result} on success and a 12 | * {@link CallStatus} on failure. 13 | *

14 | */ 15 | public class Outcome { 16 | private final Result result; 17 | private final CallStatus callStatus; 18 | 19 | protected Outcome(Result result, CallStatus callStatus) { 20 | this.result = result; 21 | this.callStatus = callStatus; 22 | 23 | assert ((result == null) ^ (callStatus == null)); 24 | } 25 | 26 | /** 27 | * Creates a new failure {@link Outcome} from the provided {@link CallStatus}. 28 | * 29 | * @param callStatus {@link CallStatus} representing the Arrow Flight RPC failure event 30 | * @return a new {@link Outcome} 31 | */ 32 | public static Outcome failure(CallStatus callStatus) { 33 | return new Outcome(null, callStatus); 34 | } 35 | 36 | /** 37 | * Creates a new successful {@link Outcome} from the provided Arrow Flight RPC {@link Result}. 38 | * 39 | * @param result the successful {@link Result} to return 40 | * @return a new {@link Outcome} 41 | */ 42 | public static Outcome success(Result result) { 43 | return new Outcome(result, null); 44 | } 45 | 46 | /** 47 | * Returns whether the {@link Outcome} is considered successful or not. 48 | * 49 | * @return true if successful, otherwise false. 
50 | */ 51 | public boolean isSuccessful() { 52 | return result != null; 53 | } 54 | 55 | public Optional getResult() { 56 | return Optional.ofNullable(result); 57 | } 58 | 59 | public Optional getCallStatus() { 60 | return Optional.ofNullable(callStatus); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/action/ServerInfoHandler.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import com.fasterxml.jackson.core.JsonProcessingException; 4 | import com.fasterxml.jackson.databind.ObjectMapper; 5 | import org.apache.arrow.flight.*; 6 | import org.neo4j.arrow.Producer; 7 | import org.neo4j.arrow.job.Job; 8 | import org.slf4j.Logger; 9 | import org.slf4j.LoggerFactory; 10 | 11 | import java.util.HashMap; 12 | import java.util.List; 13 | import java.util.Map; 14 | import java.util.jar.Manifest; 15 | import java.util.stream.Collectors; 16 | 17 | public class ServerInfoHandler implements ActionHandler { 18 | private static final Logger logger = LoggerFactory.getLogger(ServerInfoHandler.class); 19 | 20 | public static final String SERVER_JOBS = "info.jobs"; 21 | public static final String SERVER_VERSION = "info.version"; 22 | 23 | protected static final String VERSION_KEY = "neo4j-arrow-version"; 24 | private static final String MANIFEST_KEY_BASE = "neo4j-arrow"; 25 | private static final ObjectMapper mapper = new ObjectMapper(); 26 | 27 | /** Reference to the Producer's job map. Expected to be thread safe. 
*/ 28 | private final Map jobMap; 29 | 30 | public ServerInfoHandler(Map jobMap) { 31 | this.jobMap = jobMap; 32 | } 33 | 34 | @Override 35 | public List actionTypes() { 36 | return List.of(SERVER_JOBS, SERVER_VERSION); 37 | } 38 | 39 | @Override 40 | public List actionDescriptions() { 41 | return List.of( 42 | new ActionType(SERVER_JOBS, "List currently active Jobs"), 43 | new ActionType(SERVER_VERSION, "Get metadata on server info")); 44 | } 45 | 46 | protected static Map cachedArrowVersion = null; 47 | 48 | protected Map getJobs() { 49 | // TODO: accept a ticket or job id as filter? 50 | return jobMap.values() 51 | .stream() 52 | .collect(Collectors.toUnmodifiableMap( 53 | Job::getJobId, 54 | (job) -> job.getStatus().toString())); 55 | } 56 | 57 | protected synchronized static Map getArrowVersion() { 58 | 59 | if (cachedArrowVersion != null) { 60 | return cachedArrowVersion; 61 | } 62 | 63 | final Map map = new HashMap<>(); 64 | 65 | try { 66 | var iterator = ServerInfoHandler.class.getProtectionDomain().getClassLoader() 67 | .getResources("META-INF/MANIFEST.MF").asIterator(); 68 | 69 | while (iterator.hasNext()) { 70 | var url = iterator.next(); 71 | try { 72 | final Manifest m = new Manifest(url.openStream()); 73 | m.getMainAttributes().forEach((k, v) -> { 74 | if (k.toString().toLowerCase().startsWith(MANIFEST_KEY_BASE)) { 75 | logger.debug("caching server info ({}: {})", k, v); 76 | map.put(k.toString(), v.toString()); 77 | } 78 | }); 79 | } catch (Exception e) { 80 | // XXX unhandled 81 | } 82 | } 83 | } catch (Exception e) { 84 | logger.error("failure getting manifest", e); 85 | } 86 | cachedArrowVersion = map; 87 | return cachedArrowVersion; 88 | } 89 | 90 | @Override 91 | public Outcome handle(FlightProducer.CallContext context, Action action, Producer producer) { 92 | try { 93 | Map map; 94 | switch (action.getType().toLowerCase()) { 95 | case SERVER_JOBS: 96 | // TODO: support filtering by a ticket 97 | map = getJobs(); 98 | break; 99 | case 
SERVER_VERSION: 100 | map = getArrowVersion(); 101 | break; 102 | default: 103 | return Outcome.failure(CallStatus.UNKNOWN.withDescription("Unsupported action for job status handler")); 104 | } 105 | final byte[] payload = mapper.writeValueAsBytes(map); 106 | return Outcome.success(new Result(payload)); 107 | } catch (JsonProcessingException e) { 108 | logger.error("json error", e); 109 | return Outcome.failure(CallStatus.INTERNAL); 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/action/StatusHandler.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action; 2 | 3 | import org.apache.arrow.flight.*; 4 | import org.neo4j.arrow.Producer; 5 | import org.neo4j.arrow.job.Job; 6 | import org.slf4j.Logger; 7 | import org.slf4j.LoggerFactory; 8 | 9 | import java.io.IOException; 10 | import java.nio.ByteBuffer; 11 | import java.nio.charset.StandardCharsets; 12 | import java.util.List; 13 | 14 | /** 15 | * Report on the status of active, completed, or pending Jobs 16 | */ 17 | public class StatusHandler implements ActionHandler { 18 | private static final Logger logger = LoggerFactory.getLogger(StatusHandler.class); 19 | 20 | public static final String STATUS_ACTION = "job.status"; 21 | 22 | @Override 23 | public List actionTypes() { 24 | return List.of(STATUS_ACTION); 25 | } 26 | 27 | @Override 28 | public List actionDescriptions() { 29 | return List.of(new ActionType(STATUS_ACTION, "Check the status of a Job")); 30 | } 31 | 32 | @Override 33 | public Outcome handle(FlightProducer.CallContext context, Action action, Producer producer) { 34 | // TODO: standardize on matching logic? case sensitive/insensitive? 
35 | if (!action.getType().equalsIgnoreCase(STATUS_ACTION)) { 36 | return Outcome.failure(CallStatus.UNKNOWN.withDescription("Unsupported action for job status handler")); 37 | } 38 | 39 | try { 40 | final Ticket ticket = Ticket.deserialize(ByteBuffer.wrap(action.getBody())); 41 | final Job job = producer.getJob(ticket); 42 | if (job != null) { 43 | return Outcome.success(new Result(job.getStatus().toString().getBytes(StandardCharsets.UTF_8))); 44 | } 45 | return Outcome.failure(CallStatus.NOT_FOUND.withDescription("no job for ticket")); 46 | } catch (IOException e) { 47 | logger.error("Problem servicing cypher status action", e); 48 | return Outcome.failure(CallStatus.INTERNAL.withDescription(e.getMessage())); 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/action/auth/HorribleBasicAuthValidator.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.action.auth; 2 | 3 | import org.apache.arrow.flight.auth.BasicServerAuthHandler; 4 | import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; 5 | import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; 6 | 7 | import java.nio.charset.StandardCharsets; 8 | import java.util.Arrays; 9 | import java.util.Base64; 10 | import java.util.Optional; 11 | 12 | /** 13 | * Seriously, try not to use this. It's an auth validator that uses an in-memory (on jvm heap) token 14 | * containing a username and password for authenticating with the Arrow Flight service. 15 | *

16 | * THIS SHOULD BE USED FOR TESTING OR LOCAL HACKING ONLY! 17 | *

18 | */ 19 | public class HorribleBasicAuthValidator 20 | implements BasicServerAuthHandler.BasicAuthValidator, BasicCallHeaderAuthenticator.CredentialValidator { 21 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(HorribleBasicAuthValidator.class); 22 | 23 | // XXX you really shouldn't be using this auth validator! It simply keeps an in-memory token. 24 | /** The hardcoded auth token in the form of username:password. Set via environment variable. */ 25 | private static final String HARDCODED_TOKEN = System.getenv() 26 | .getOrDefault("NEO4J_ARROW_TOKEN", "neo4j:password"); 27 | 28 | @Override 29 | public byte[] getToken(String username, String password) { 30 | logger.debug("getToken called: username={}", username); // never log the password 31 | final String token = Base64.getEncoder() 32 | .encodeToString((username + ":" + password).getBytes(StandardCharsets.UTF_8)); 33 | // the derived token is itself a credential (Base64 of username:password), so it is deliberately not logged 34 | return token.getBytes(StandardCharsets.UTF_8); 35 | } 36 | 37 | @Override 38 | public Optional<String> isValid(byte[] token) { 39 | logger.debug("isValid called"); // the token is a credential; do not log it 40 | // Accept the raw HARDCODED_TOKEN (as before) or its Base64 form, which is exactly what getToken() hands out 41 | // TODO: make an auth handler that isn't this silly 42 | if (Arrays.equals(HARDCODED_TOKEN.getBytes(StandardCharsets.UTF_8), token) || Arrays.equals(Base64.getEncoder().encode(HARDCODED_TOKEN.getBytes(StandardCharsets.UTF_8)), token)) 43 | return Optional.of("neo4j"); 44 | else 45 | return Optional.empty(); 46 | } 47 | 48 | @Override 49 | public CallHeaderAuthenticator.AuthResult validate(String username, String password) throws Exception { 50 | logger.debug("validate called for username={}", username); 51 | if (HARDCODED_TOKEN.equals(username + ":" + password)) // honor NEO4J_ARROW_TOKEN; Basic-auth user-ids cannot contain ':' (RFC 7617) 52 | return () -> username; 53 | else 54 | throw new Exception("Oh, fiddlesticks"); 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/batch/ArrowBatch.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.batch; 2 | 3 | import
org.apache.arrow.memory.BufferAllocator; 4 | import org.apache.arrow.vector.FieldVector; 5 | import org.apache.arrow.vector.ValueVector; 6 | import org.apache.arrow.vector.VectorSchemaRoot; 7 | import org.apache.arrow.vector.complex.ListVector; 8 | import org.apache.arrow.vector.complex.reader.BaseReader; 9 | import org.apache.arrow.vector.complex.reader.FieldReader; 10 | import org.apache.arrow.vector.types.pojo.Field; 11 | import org.apache.arrow.vector.util.TransferPair; 12 | 13 | import java.util.Arrays; 14 | import java.util.List; 15 | import java.util.function.Consumer; 16 | import java.util.stream.IntStream; 17 | 18 | /** 19 | * TBD 20 | */ 21 | public class ArrowBatch implements AutoCloseable { 22 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ArrowBatch.class); 23 | 24 | private final String[] fieldNames; 25 | private final ValueVector[] vectors; 26 | private final int rowCount; 27 | 28 | protected ArrowBatch(String[] fieldNames, ValueVector[] vectors, int rowCount) { 29 | this.fieldNames = fieldNames; 30 | this.vectors = vectors; 31 | this.rowCount = rowCount; 32 | } 33 | 34 | /** 35 | * Build an {@link ArrowBatch} from the given {@link VectorSchemaRoot}, transferring buffers 36 | * to the given {@link BufferAllocator}; 37 | * 38 | * @param root VectorSchemaRoot source for the ValueVectors 39 | * @param intoAllocator BufferAllocator to take ownership of the current Arrow buffers 40 | * @return a new ArrowBatch 41 | */ 42 | public static ArrowBatch fromRoot(VectorSchemaRoot root, BufferAllocator intoAllocator) { 43 | final List fields = root.getSchema().getFields(); 44 | final String[] fieldNames = fields.stream() 45 | .map(Field::getName).toArray(String[]::new); 46 | final ValueVector[] vectors = new ValueVector[fieldNames.length]; 47 | 48 | IntStream.range(0, fieldNames.length) 49 | .forEach(idx -> { 50 | try { 51 | final FieldVector fv = root.getVector(idx); 52 | final TransferPair pair = 
fv.getTransferPair(intoAllocator); 53 | pair.transfer(); 54 | vectors[idx] = pair.getTo(); 55 | } catch (Exception e) { 56 | logger.error("unable to build ArrowBatch", e); 57 | throw new RuntimeException(e); 58 | } 59 | }); 60 | return new ArrowBatch(fieldNames, vectors, root.getRowCount()); 61 | } 62 | 63 | @Override 64 | public void close() throws Exception { 65 | // The ArrowBatch doesn't own its allocator, so just free its vectors 66 | Arrays.stream(vectors).forEach(ValueVector::close); 67 | } 68 | 69 | public String[] getFieldNames() { 70 | return fieldNames; 71 | } 72 | 73 | public ValueVector[] getVectors() { 74 | return vectors; 75 | } 76 | 77 | public int getRowCount() { 78 | return rowCount; 79 | } 80 | 81 | @Override 82 | public String toString() { 83 | return "ArrowBatch{" + 84 | "fieldNames=" + Arrays.toString(fieldNames) + 85 | ", rowCount=" + rowCount + 86 | '}'; 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/batch/ArrowBatches.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.batch; 2 | 3 | import org.apache.arrow.memory.ArrowBuf; 4 | import org.apache.arrow.memory.BufferAllocator; 5 | import org.apache.arrow.util.AutoCloseables; 6 | import org.apache.arrow.vector.ValueVector; 7 | import org.apache.arrow.vector.types.pojo.Field; 8 | import org.apache.arrow.vector.types.pojo.Schema; 9 | import org.apache.arrow.vector.util.TransferPair; 10 | import org.neo4j.arrow.Config; 11 | 12 | import java.util.ArrayList; 13 | import java.util.Arrays; 14 | import java.util.Collection; 15 | import java.util.List; 16 | import java.util.stream.Collectors; 17 | import java.util.stream.IntStream; 18 | 19 | /** 20 | * Container of Arrow vectors and metadata, schema, etc. !!! NOT THREAD SAFE.!!! 21 | *

22 | * This is a mess right now :-( 23 | *

24 | */ 25 | public class ArrowBatches implements AutoCloseable { 26 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ArrowBatches.class); 27 | 28 | final Schema schema; 29 | private BufferAllocator allocator; 30 | final List> vectorSpace = new ArrayList<>(); 31 | final String[] fieldNames; 32 | 33 | long rowCount = 0; 34 | int maxBatchSize = -1; 35 | 36 | public ArrowBatches(Schema schema, BufferAllocator parentAllocator, String name) { 37 | this.schema = schema; 38 | this.allocator = parentAllocator.newChildAllocator(name, 0L, Config.maxStreamMemory); 39 | final List fields = schema.getFields(); 40 | fieldNames = new String[fields.size()]; 41 | 42 | IntStream.range(0, fields.size()) 43 | .forEach(idx -> { 44 | final String fieldName = fields.get(idx).getName(); 45 | fieldNames[idx] = fieldName; 46 | vectorSpace.add(new ArrayList<>()); 47 | logger.info("added {} to vectorspace", fieldName); 48 | }); 49 | } 50 | 51 | public void appendBatch(ArrowBatch batch) { 52 | assert batch != null; 53 | final ValueVector[] vectors = batch.getVectors(); 54 | final int rows = batch.getRowCount(); 55 | 56 | if (rows < 1) { 57 | throw new RuntimeException("empty batch?"); 58 | } 59 | 60 | maxBatchSize = Math.max(maxBatchSize, rows); 61 | 62 | assert vectors.length == fieldNames.length; 63 | IntStream.range(0, vectors.length) 64 | .forEach(idx -> { 65 | final List vectorList = vectorSpace.get(idx); 66 | final TransferPair pair = vectors[idx].getTransferPair(allocator); 67 | pair.transfer(); 68 | vectorList.add(pair.getTo()); 69 | }); 70 | 71 | rowCount += rows; 72 | 73 | AutoCloseables.closeNoChecked(batch); 74 | } 75 | 76 | public long estimateSize() { 77 | return vectorSpace.stream() 78 | .flatMap(Collection::stream) 79 | .mapToLong(ValueVector::getBufferSize) 80 | .sum(); 81 | } 82 | 83 | public long actualSize() { 84 | return vectorSpace.stream() 85 | .flatMap(Collection::stream) 86 | .map(vec -> vec.getBuffers(false)) 87 | 
.flatMap(Arrays::stream) 88 | .mapToLong(ArrowBuf::capacity) 89 | .sum(); 90 | } 91 | 92 | public BatchedVector getVector(int index) { 93 | if (index < 0 || index >= fieldNames.length) 94 | throw new RuntimeException("index out of range"); 95 | 96 | final List vectors = vectorSpace.get(index); 97 | if (vectors == null) { 98 | System.out.println("index " + index + " returned nul list?!"); 99 | throw new ArrayIndexOutOfBoundsException("invalid vectorspace index"); 100 | } 101 | return new BatchedVector(fieldNames[index], vectorSpace.get(index), maxBatchSize, rowCount); 102 | } 103 | 104 | public BatchedVector getVector(String name) { 105 | logger.info("...finding vector for name {}", name); 106 | int index = 0; 107 | for ( ; index getFieldVectors() { 118 | return IntStream.range(0, fieldNames.length) 119 | .mapToObj(this::getVector) 120 | .collect(Collectors.toList()); 121 | } 122 | 123 | public int getRowCount() { 124 | return (int) rowCount; // XXX cast 125 | } 126 | 127 | public Schema getSchema() { 128 | return schema; 129 | } 130 | 131 | @Override 132 | public void close() { 133 | allocator.assertOpen(); 134 | vectorSpace.forEach(list -> list.forEach(ValueVector::close)); 135 | allocator.close(); 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/batch/BatchedVector.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.batch; 2 | 3 | import org.apache.arrow.vector.ValueVector; 4 | import org.apache.arrow.vector.complex.FixedSizeListVector; 5 | import org.apache.arrow.vector.complex.ListVector; 6 | import org.apache.arrow.vector.util.Text; 7 | 8 | import java.io.Closeable; 9 | import java.util.List; 10 | import java.util.Optional; 11 | import java.util.stream.Collectors; 12 | 13 | public class BatchedVector implements Closeable { 14 | private static final org.slf4j.Logger logger = 
org.slf4j.LoggerFactory.getLogger(BatchedVector.class); 15 | 16 | private final List vectors; 17 | private final int batchSize; 18 | private final String name; 19 | private final long rowCount; 20 | private int watermark = 0; 21 | 22 | BatchedVector(String name, List vectors, int batchSize, long rowCount) { 23 | this.name = name; 24 | this.vectors = vectors; 25 | this.batchSize = batchSize; 26 | this.rowCount = rowCount; 27 | 28 | if (vectors == null || vectors.size() < 1) 29 | throw new RuntimeException("invalid vectors list"); 30 | 31 | // XXX this is ugly...but we need to know where our search gets tough 32 | for (int i = 0; i < vectors.size(); i++) { 33 | final ValueVector vector = vectors.get(i); 34 | assert vector != null; 35 | if (vector.getValueCount() < batchSize) { 36 | watermark = i; 37 | break; 38 | } 39 | } 40 | } 41 | 42 | public List getVectors() { 43 | return this.vectors; 44 | } 45 | 46 | public String getName() { 47 | return name; 48 | } 49 | 50 | public Class getBaseType() { 51 | return vectors.get(0).getClass(); 52 | } 53 | 54 | /** 55 | * Find an item from the vector space, accounting for the fact the tail end might have batch sizes less than 56 | * the original batch size. 57 | * 58 | * @param index index of item to retrieve from the space 59 | * @return Object value if found, otherwise null 60 | */ 61 | private Object translateIndex(long index) { 62 | assert (index < rowCount); 63 | 64 | // assumption is our batches only become "short" at the end 65 | int column = (int) Math.floorDiv(index, batchSize); 66 | int offset = (int) (index % batchSize); 67 | 68 | try { 69 | if (column < watermark) { 70 | // trivial case 71 | ValueVector vector = vectors.get(column); 72 | return vector.getObject(offset); 73 | } 74 | 75 | // harder, we need to search varying size columns. start at our watermark. 
76 | int pos = watermark * batchSize; 77 | column = watermark; 78 | ValueVector vector = vectors.get(column); 79 | logger.trace("starting search from pos {} to find index {} (watermark: {})", pos, index, watermark); 80 | while ((index - pos) >= vector.getValueCount()) { 81 | column++; 82 | pos += vector.getValueCount(); 83 | vector = vectors.get(column); // XXX eventually will barf...need better handling here 84 | } 85 | return vector.getObject((int) (index - pos)); 86 | } catch (Exception e) { 87 | logger.error(String.format("failed to get index %d for %s (offset %d, column %d, batchSize %d)", 88 | index, vectors.get(0).getName(), offset, column, batchSize), e); 89 | logger.trace(String.format("debug: %s", 90 | vectors.stream() 91 | .map(ValueVector::getValueCount) 92 | .map(String::valueOf) 93 | .collect(Collectors.joining(", ")))); 94 | return null; 95 | } 96 | } 97 | 98 | public long getNodeId(long index) { 99 | final Long nodeId = (Long) translateIndex(index); 100 | if (nodeId == null) { 101 | throw new RuntimeException(String.format("cant get nodeId for index %d", index)); 102 | } 103 | if (nodeId < 0) { 104 | throw new RuntimeException("nodeId < 0?!?!"); 105 | } 106 | return nodeId; 107 | } 108 | 109 | public List getLabels(long index) { 110 | final List list = (List) translateIndex(index); 111 | if (list == null) { 112 | logger.warn("failed to find list at index {}, index", index); 113 | return List.of(); 114 | } 115 | return list.stream().map(Object::toString).collect(Collectors.toList()); 116 | } 117 | 118 | public String getType(long index) { 119 | // XXX Assumption for now is we're dealing with a VarCharVector 120 | final Text type = (Text) translateIndex(index); 121 | if (type == null) { 122 | logger.warn("failed to find type string at index {}, index", index); 123 | return ""; 124 | } 125 | return type.toString(); 126 | } 127 | 128 | public Object getObject(long index) { 129 | final Object o = translateIndex(index); 130 | if (o == null) { 131 | 
logger.warn("failed to find list at index {}, index", index); 132 | return null; 133 | } 134 | return o; 135 | } 136 | 137 | public List getList(long index) { 138 | final List list = (List) translateIndex(index); 139 | if (list == null) { 140 | logger.warn("failed to find list at index {}, index", index); 141 | return List.of(); 142 | } 143 | return list; 144 | } 145 | 146 | public Optional> getDataClass() { 147 | final ValueVector v = vectors.get(0); 148 | if (v instanceof ListVector) { 149 | return Optional.of(((ListVector) v).getDataVector().getClass()); 150 | } else if (v instanceof FixedSizeListVector) { 151 | final FixedSizeListVector fv = (FixedSizeListVector) v; 152 | return Optional.of(fv.getDataVector().getClass()); 153 | } 154 | return Optional.empty(); 155 | } 156 | 157 | @Override 158 | public void close() { 159 | vectors.forEach(ValueVector::close); 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/job/Job.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.job; 2 | 3 | import javax.annotation.Nonnull; 4 | import java.util.concurrent.*; 5 | import java.util.concurrent.atomic.AtomicLong; 6 | 7 | /** 8 | * The abstract base class for neo4j-arrow Jobs. 
9 | */ 10 | public abstract class Job implements AutoCloseable, Future { 11 | protected static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Job.class); 12 | private static final AtomicLong jobCounter = new AtomicLong(0); 13 | 14 | protected final String jobId; 15 | 16 | public enum Mode { 17 | READ, 18 | WRITE 19 | } 20 | 21 | public enum Status { 22 | INITIALIZING, 23 | PENDING, 24 | PRODUCING, 25 | COMPLETE, 26 | ERROR; 27 | 28 | @Override 29 | public String toString() { 30 | switch (this) { 31 | case INITIALIZING: 32 | return "INITIALIZING"; 33 | case PENDING: 34 | return "PENDING"; 35 | case COMPLETE: 36 | return "COMPLETE"; 37 | case ERROR: 38 | return "ERROR"; 39 | case PRODUCING: 40 | return "PRODUCING"; 41 | } 42 | return "UNKNOWN"; 43 | } 44 | } 45 | 46 | private Status jobStatus = Status.INITIALIZING; 47 | 48 | protected final CompletableFuture jobSummary = new CompletableFuture<>(); 49 | 50 | protected Job() { 51 | jobId = String.format("job-%d", jobCounter.getAndIncrement()); 52 | } 53 | 54 | public synchronized Status getStatus() { 55 | return jobStatus; 56 | } 57 | 58 | // XXX making public for now...API needs some future rework 59 | public synchronized void setStatus(Status status) { 60 | logger.info("status {} -> {}", jobStatus, status); 61 | jobStatus = status; 62 | } 63 | 64 | protected void onCompletion(JobSummary summary) { 65 | logger.info("Job {} completed", jobId); 66 | jobSummary.complete(summary); 67 | setStatus(Status.COMPLETE); 68 | } 69 | 70 | @Override 71 | public abstract boolean cancel(boolean mayInterruptIfRunning); 72 | 73 | @Override 74 | public boolean isCancelled() { 75 | return jobSummary.isCancelled(); 76 | } 77 | 78 | @Override 79 | public boolean isDone() { 80 | return jobSummary.isDone(); 81 | } 82 | 83 | public String getJobId() { 84 | return jobId; 85 | } 86 | 87 | @Override 88 | public JobSummary get() throws InterruptedException, ExecutionException { 89 | return jobSummary.get(); 90 | } 91 | 92 | 
@Override 93 | public JobSummary get(long timeout, @Nonnull TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { 94 | return jobSummary.get(timeout, unit); 95 | } 96 | 97 | @Override 98 | public abstract void close() throws Exception; 99 | 100 | @Override 101 | public String toString() { 102 | return "Neo4jJob{" + 103 | "status=" + jobStatus + 104 | ", id='" + jobId + '\'' + 105 | '}'; 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/job/JobCreator.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.job; 2 | 3 | import javax.annotation.Nullable; 4 | 5 | /** 6 | * The {@link JobCreator} provides a functional interface for creating an instance of a {@link Job}. 7 | *

8 | * Since each {@link Job} implementation potentially uses a distinct message format (and there's 9 | * currently no message interface or base class), the {@link JobCreator} is generic and 10 | * parameterized by the type T of the supported message. 11 | *

12 | *

13 | * It's assumed that things like the {@link Job.Mode}, a username, and password are common 14 | * enough to warrant being in the core signature. (Albeit username and password are optional.) 15 | *

16 | */ 17 | @FunctionalInterface 18 | public interface JobCreator { 19 | /** 20 | * Create a new {@link Job} given the job message, {@link Job.Mode}, and optional username and 21 | * password. 22 | * 23 | * @param msg a {@link Job}-specific message 24 | * @param mode the mode, e.g. READ vs WRITE 25 | * @param username optional username for the caller 26 | * @return new {@link Job} 27 | */ 28 | J newJob(T msg, Job.Mode mode, @Nullable String username); 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/job/JobSummary.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.job; 2 | 3 | /** 4 | * This is pretty much useless at the moment...probably will be killed. 5 | */ 6 | @FunctionalInterface 7 | public interface JobSummary { 8 | @SuppressWarnings("unused") 9 | String asString(); 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/job/ReadJob.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.job; 2 | 3 | import org.neo4j.arrow.Config; 4 | import org.neo4j.arrow.RowBasedRecord; 5 | 6 | import java.util.concurrent.CompletableFuture; 7 | import java.util.concurrent.Future; 8 | import java.util.function.BiConsumer; 9 | 10 | public abstract class ReadJob extends Job { 11 | 12 | public static final Mode mode = Mode.READ; 13 | 14 | /** Completes once the first record is received and ready from the Neo4j system. 
*/ 15 | private final CompletableFuture firstRecord = new CompletableFuture<>(); 16 | 17 | /** Provides a {@link BiConsumer} taking a {@link RowBasedRecord} with data and a partition id {@link Integer} */ 18 | protected final CompletableFuture> futureConsumer = new CompletableFuture<>(); 19 | 20 | /** Maximum number of simultaneous buffers to populate */ 21 | public final int maxPartitionCnt; 22 | 23 | /** Maximum number of "rows" in a batch, i.e. the max vector dimension. */ 24 | public final int maxRowCount; 25 | 26 | /** Maximum entries in a variable size list. Selectively used by certain jobs. */ 27 | public final int maxVariableListSize; 28 | 29 | protected ReadJob() { 30 | this(Config.arrowMaxPartitions, Config.arrowBatchSize, Config.arrowMaxListSize); 31 | } 32 | 33 | protected ReadJob(int maxPartitionCnt, int maxRowCount, int maxVariableListSize) { 34 | this.maxPartitionCnt = maxPartitionCnt; 35 | this.maxRowCount = maxRowCount; 36 | this.maxVariableListSize = maxVariableListSize; 37 | } 38 | 39 | protected void onFirstRecord(RowBasedRecord record) { 40 | logger.debug("First record received {}", record); 41 | firstRecord.complete(record); 42 | setStatus(Status.PENDING); 43 | } 44 | 45 | public Future getFirstRecord() { 46 | return firstRecord; 47 | } 48 | 49 | public void consume(BiConsumer consumer) { 50 | if (!futureConsumer.isDone()) 51 | futureConsumer.complete(consumer); 52 | else 53 | logger.error("Consumer already supplied for job {}", this); 54 | } 55 | 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/main/java/org/neo4j/arrow/job/WriteJob.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow.job; 2 | 3 | 4 | import org.apache.arrow.memory.BufferAllocator; 5 | import org.apache.arrow.vector.types.pojo.Schema; 6 | import org.neo4j.arrow.Config; 7 | import org.neo4j.arrow.batch.ArrowBatch; 8 | 9 | import 
java.util.concurrent.CompletableFuture; 10 | import java.util.concurrent.Future; 11 | import java.util.function.Consumer; 12 | 13 | public abstract class WriteJob extends Job { 14 | private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(WriteJob.class); 15 | 16 | public static final Mode mode = Mode.WRITE; 17 | 18 | private final BufferAllocator allocator; 19 | private CompletableFuture schema = new CompletableFuture<>(); 20 | 21 | private final CompletableFuture streamComplete = new CompletableFuture<>(); 22 | 23 | public WriteJob(BufferAllocator parentAllocator) { 24 | super(); 25 | allocator = parentAllocator.newChildAllocator(this.getJobId(), 0L, Config.maxStreamMemory); 26 | } 27 | 28 | @Override 29 | public boolean cancel(boolean mayInterruptIfRunning) { 30 | return false; 31 | } 32 | 33 | @Override 34 | public void close() throws Exception { 35 | allocator.close(); 36 | } 37 | 38 | public Future getStreamCompletion() { 39 | return streamComplete; 40 | } 41 | 42 | public void onStreamComplete(Schema schema) { 43 | logger.debug("stream completed"); 44 | streamComplete.complete(null); 45 | } 46 | 47 | public void onSchema(Schema schema) { 48 | logger.debug("received schema"); 49 | this.schema.complete(schema); 50 | } 51 | 52 | protected CompletableFuture getSchema() { 53 | return schema; 54 | } 55 | 56 | public abstract Consumer getConsumer(Schema schema); 57 | 58 | public BufferAllocator getAllocator() { 59 | return allocator; 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /src/test/java/org/neo4j/arrow/ArrowBatchesTest.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.apache.arrow.memory.BufferAllocator; 4 | import org.apache.arrow.memory.RootAllocator; 5 | import org.apache.arrow.vector.BigIntVector; 6 | import org.apache.arrow.vector.FieldVector; 7 | import 
org.apache.arrow.vector.ValueVector; 8 | import org.apache.arrow.vector.VectorSchemaRoot; 9 | import org.apache.arrow.vector.types.pojo.ArrowType; 10 | import org.apache.arrow.vector.types.pojo.Field; 11 | import org.apache.arrow.vector.types.pojo.FieldType; 12 | import org.apache.arrow.vector.types.pojo.Schema; 13 | import org.junit.jupiter.api.Assertions; 14 | import org.junit.jupiter.api.Test; 15 | import org.neo4j.arrow.batch.ArrowBatch; 16 | import org.neo4j.arrow.batch.ArrowBatches; 17 | import org.neo4j.arrow.batch.BatchedVector; 18 | 19 | import java.util.HashSet; 20 | import java.util.List; 21 | import java.util.Set; 22 | 23 | public class ArrowBatchesTest { 24 | 25 | @Test 26 | public void testSearchingTailEnd() { 27 | Field field = new Field("junk", FieldType.nullable(new ArrowType.Int(64, true)), null); 28 | Schema schema = new Schema(List.of(field)); 29 | 30 | try (BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) { 31 | final ArrowBatches batches = new ArrowBatches(schema, allocator, "batch"); 32 | 33 | long nodeId = 0; 34 | for (int i=0; i<10; i++) { 35 | try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { 36 | FieldVector vector = root.getVector(0); 37 | vector.allocateNewSafe(); 38 | BigIntVector biv = (BigIntVector) vector; 39 | for (int j=0; j<6; j++) { 40 | biv.set(j, nodeId++); 41 | } 42 | root.setRowCount(6); 43 | System.out.println("created vector: " + biv); 44 | 45 | final ArrowBatch batch = ArrowBatch.fromRoot(root, allocator); 46 | batches.appendBatch(batch); 47 | vector.close(); 48 | } 49 | } 50 | 51 | for (int i=0; i<48; i++) { 52 | try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { 53 | FieldVector vector = root.getVector(0); 54 | vector.allocateNewSafe(); 55 | BigIntVector biv = (BigIntVector) vector; 56 | for (int j=0; j<5; j++) { 57 | biv.set(j, nodeId++); 58 | } 59 | root.setRowCount(5); 60 | System.out.println("created vector: " + biv); 61 | 62 | final ArrowBatch batch 
= ArrowBatch.fromRoot(root, allocator); 63 | batches.appendBatch(batch); 64 | vector.close(); 65 | } 66 | } 67 | 68 | Assertions.assertEquals(6 * 10 + 5 * 48, batches.getRowCount()); 69 | 70 | final List batchedVectors = batches.getFieldVectors(); 71 | Assertions.assertNotNull(batchedVectors); 72 | Assertions.assertEquals(1, batchedVectors.size()); 73 | 74 | final BatchedVector bv = batches.getVector(0); 75 | final List list = bv.getVectors(); 76 | Assertions.assertEquals(58, list.size()); 77 | 78 | Set seen = new HashSet<>(); 79 | for (long i=0; i<300; i++) { 80 | final long id = bv.getNodeId(i); 81 | System.out.println("i=" + i + ", nodeId= " + id); 82 | if (seen.contains(id)) { 83 | Assertions.fail("already seen nodeId: " + id); 84 | } 85 | seen.add(id); 86 | } 87 | 88 | batches.close(); 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /src/test/java/org/neo4j/arrow/Neo4jFlightServerTest.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.apache.arrow.flight.Location; 4 | import org.apache.arrow.memory.BufferAllocator; 5 | import org.apache.arrow.memory.RootAllocator; 6 | import org.apache.arrow.util.AutoCloseables; 7 | 8 | import java.util.concurrent.TimeUnit; 9 | 10 | public class Neo4jFlightServerTest { 11 | private static final org.slf4j.Logger logger; 12 | 13 | static { 14 | // Set up nicer logging output. 
15 | System.setProperty("org.slf4j.simpleLogger.showDateTime", "true"); 16 | System.setProperty("org.slf4j.simpleLogger.dateTimeFormat", "[yyyy-MM-dd'T'HH:mm:ss:SSS]"); 17 | System.setProperty("org.slf4j.simpleLogger.logFile", "System.out"); 18 | logger = org.slf4j.LoggerFactory.getLogger(Neo4jFlightServerTest.class); 19 | } 20 | 21 | public static void main(String[] args) throws Exception { 22 | long timeout = 30; 23 | TimeUnit unit = TimeUnit.SECONDS; 24 | 25 | final BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); 26 | final App neo4jFlightServer = new App( 27 | bufferAllocator, 28 | Location.forGrpcInsecure("0.0.0.0", 9999)); 29 | neo4jFlightServer.start(); 30 | 31 | Runtime.getRuntime().addShutdownHook(new Thread(() -> { 32 | try { 33 | logger.info("Shutting down..."); 34 | AutoCloseables.close(neo4jFlightServer, bufferAllocator); 35 | logger.info("Stopped."); 36 | } catch (Exception e) { 37 | logger.error("Failure during shutdown", e); 38 | } 39 | })); 40 | 41 | neo4jFlightServer.awaitTermination(timeout, unit); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/test/java/org/neo4j/arrow/NoOpBenchmark.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.apache.arrow.flight.*; 4 | import org.apache.arrow.memory.RootAllocator; 5 | import org.apache.arrow.vector.types.FloatingPointPrecision; 6 | import org.apache.arrow.vector.types.pojo.ArrowType; 7 | import org.apache.arrow.vector.types.pojo.Field; 8 | import org.apache.arrow.vector.types.pojo.FieldType; 9 | import org.apache.arrow.vector.types.pojo.Schema; 10 | import org.junit.jupiter.api.Assertions; 11 | import org.junit.jupiter.api.Test; 12 | import org.neo4j.arrow.action.ActionHandler; 13 | import org.neo4j.arrow.action.Outcome; 14 | import org.neo4j.arrow.demo.Client; 15 | import org.neo4j.arrow.job.ReadJob; 16 | 17 | import java.util.Arrays; 18 | 
import java.util.List; 19 | import java.util.concurrent.CompletableFuture; 20 | import java.util.concurrent.TimeUnit; 21 | import java.util.function.BiConsumer; 22 | import java.util.stream.Collectors; 23 | import java.util.stream.IntStream; 24 | 25 | /** 26 | * A relatively simple NoOp test mostly for producing flamegraphs and doing some light integration testing. 27 | */ 28 | public class NoOpBenchmark { 29 | private static final org.slf4j.Logger logger; 30 | private final static String ACTION_NAME = "NoOp"; 31 | 32 | static { 33 | // Set up nicer logging output. 34 | System.setProperty("org.slf4j.simpleLogger.showDateTime", "true"); 35 | System.setProperty("org.slf4j.simpleLogger.dateTimeFormat", "[yyyy-MM-dd'T'HH:mm:ss:SSS]"); 36 | System.setProperty("org.slf4j.simpleLogger.logFile", "System.out"); 37 | logger = org.slf4j.LoggerFactory.getLogger(NoOpBenchmark.class); 38 | } 39 | 40 | private static final double[] PAYLOAD = { 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 41 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 42 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 43 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 44 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 45 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 46 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 47 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 48 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 49 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 50 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 51 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d, 9d, 10d, 52 | 1d, 2d, 3d, 4d, 5d, 6d, 7d, 8d }; 53 | 54 | private static class NoOpRecord implements RowBasedRecord { 55 | 56 | @Override 57 | public Value get(int index) { 58 | return new Value() { 59 | private final List doubleList = Arrays.stream(PAYLOAD).boxed().collect(Collectors.toList()); 60 | 61 | @Override 62 | public int size() { 63 | return doubleList.size(); 64 | } 65 | 66 | @Override 67 | public List asDoubleList() { 68 | return doubleList; 69 | } 70 | 71 | @Override 72 | public double[] 
asDoubleArray() { 73 | return PAYLOAD; 74 | } 75 | 76 | @Override 77 | public Type type() { 78 | return Type.DOUBLE_ARRAY; 79 | } 80 | }; 81 | } 82 | 83 | @Override 84 | public Value get(String field) { 85 | return get(1); 86 | } 87 | 88 | @Override 89 | public List keys() { 90 | return List.of("n"); 91 | } 92 | } 93 | 94 | private static class NoOpJob extends ReadJob { 95 | 96 | final CompletableFuture future; 97 | final int numResults; 98 | 99 | public NoOpJob(int numResults, CompletableFuture signal) { 100 | super(); 101 | this.numResults = numResults; 102 | 103 | future = CompletableFuture.supplyAsync(() -> { 104 | logger.info("Job starting"); 105 | final RowBasedRecord record = new NoOpRecord(); 106 | onFirstRecord(record); 107 | logger.info("Job feeding"); 108 | 109 | BiConsumer consumer = super.futureConsumer.join(); 110 | IntStream.range(1, numResults + 1) 111 | .parallel() 112 | .forEach(i -> consumer.accept(record, i)); 113 | 114 | signal.complete(System.currentTimeMillis()); 115 | logger.info("Job finished"); 116 | onCompletion(() -> "done"); 117 | return numResults; 118 | }); 119 | } 120 | 121 | @Override 122 | public boolean cancel(boolean mayInterruptIfRunning) { 123 | return false; 124 | } 125 | 126 | @Override 127 | public void close() { 128 | } 129 | } 130 | 131 | private static class NoOpHandler implements ActionHandler { 132 | 133 | final CompletableFuture signal; 134 | NoOpHandler(CompletableFuture signal) { 135 | this.signal = signal; 136 | } 137 | 138 | @Override 139 | public List actionTypes() { 140 | return List.of(ACTION_NAME); 141 | } 142 | 143 | @Override 144 | public List actionDescriptions() { 145 | return List.of(new ActionType(ACTION_NAME, "Nothing")); 146 | } 147 | 148 | @Override 149 | public Outcome handle(FlightProducer.CallContext context, Action action, Producer producer) { 150 | Assertions.assertEquals(ACTION_NAME, action.getType()); 151 | final Ticket ticket = producer.ticketJob(new NoOpJob(1_000_000, signal)); 152 | 
producer.setFlightInfo(ticket, new Schema( 153 | List.of(new Field("embedding", 154 | FieldType.nullable(new ArrowType.FixedSizeList(PAYLOAD.length)), 155 | List.of(new Field("embedding", 156 | FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null)))))); 157 | return Outcome.success(new Result(ticket.serialize().array())); 158 | } 159 | } 160 | 161 | @Test 162 | public void testSpeed() throws Exception { 163 | final Location location = Location.forGrpcInsecure("localhost", 12345); 164 | final CompletableFuture signal = new CompletableFuture<>(); 165 | 166 | App app = new App(new RootAllocator(Long.MAX_VALUE), location); 167 | Client client = new Client(new RootAllocator(Long.MAX_VALUE), location); 168 | 169 | try { 170 | app.registerHandler(new NoOpHandler(signal)); 171 | app.start(); 172 | 173 | long start = System.currentTimeMillis(); 174 | Action action = new Action(ACTION_NAME); 175 | client.run(action); 176 | long stop = signal.join(); 177 | logger.info(String.format("Client Lifecycle Time: %,d ms", stop - start)); 178 | 179 | app.awaitTermination(1, TimeUnit.SECONDS); 180 | } catch (Exception e) { 181 | e.printStackTrace(); 182 | } 183 | app.close(); 184 | client.close(); 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /src/test/java/org/neo4j/arrow/SillyAsyncTest.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import java.io.IOException; 4 | import java.util.concurrent.CompletableFuture; 5 | import java.util.concurrent.atomic.AtomicBoolean; 6 | import java.util.concurrent.atomic.AtomicInteger; 7 | import java.util.concurrent.atomic.AtomicReference; 8 | 9 | public class SillyAsyncTest { 10 | 11 | final static AtomicInteger cnt = new AtomicInteger(0); 12 | final static Runnable r = () -> { 13 | try { 14 | Thread.sleep(2000); 15 | System.out.printf("cnt is now %d\n", cnt.incrementAndGet()); 16 | } 
catch (InterruptedException e) { 17 | e.printStackTrace(); 18 | } 19 | }; 20 | 21 | public static void main(String[] args) throws IOException { 22 | final AtomicBoolean bool = new AtomicBoolean(false); 23 | AtomicReference> futureRef = new AtomicReference(CompletableFuture.runAsync(r)); 24 | 25 | while (true) { 26 | System.out.println("Command?"); 27 | int i = System.in.read(); 28 | System.out.printf("i=%d\n", i); 29 | if (i == 10) { 30 | System.out.println("adding new task to future"); 31 | 32 | futureRef.getAndUpdate(currentFuture -> currentFuture.thenRunAsync(r)); 33 | } else { 34 | break; 35 | } 36 | } 37 | System.out.println("Waiting on future..."); 38 | futureRef.get().join(); 39 | System.out.println("Future complete!"); 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/test/java/org/neo4j/arrow/SillyBenchmark.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.junit.jupiter.api.Disabled; 4 | import org.junit.jupiter.api.Test; 5 | import org.neo4j.driver.*; 6 | 7 | import java.util.Map; 8 | 9 | public class SillyBenchmark { 10 | private static final org.slf4j.Logger logger; 11 | static { 12 | // Set up nicer logging output. 
13 | System.setProperty("org.slf4j.simpleLogger.showDateTime", "true"); 14 | System.setProperty("org.slf4j.simpleLogger.dateTimeFormat", "[yyyy-MM-dd'T'HH:mm:ss:SSS]"); 15 | System.setProperty("org.slf4j.simpleLogger.logFile", "System.out"); 16 | logger = org.slf4j.LoggerFactory.getLogger(SillyBenchmark.class); 17 | } 18 | 19 | @Test 20 | @Disabled 21 | public void testSilly1MillionEmbeddings() { 22 | try (Driver driver = GraphDatabase.driver(Config.neo4jUrl, 23 | AuthTokens.basic(Config.username, Config.password))) { 24 | for (int i = 0; i < 1; i++) { 25 | try (Session session = driver.session(SessionConfig.builder() 26 | .withDefaultAccessMode(AccessMode.READ).build())) { 27 | try { 28 | long start = System.currentTimeMillis(); 29 | Result result = session.run( 30 | "UNWIND range(1, $rows) AS row\n" + 31 | "RETURN row, [_ IN range(1, $dimension) | rand()] as fauxEmbedding", 32 | Map.of("rows", 1_000_000, "dimension", 128)); 33 | long cnt = 0; 34 | while (result.hasNext()) { 35 | result.next(); 36 | cnt++; 37 | if (cnt % 25_000 == 0) 38 | logger.info("Current Row @ {} [fields: {}]", cnt, result.keys()); 39 | } 40 | long finish = System.currentTimeMillis(); 41 | logger.info(String.format("finished in %,d ms, rate of %,d rows/sec", 42 | (finish - start), 1000 * cnt / (finish - start))); 43 | } catch (Exception e) { 44 | logger.error("oops", e); 45 | } 46 | } 47 | } 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/test/java/org/neo4j/arrow/SillyStreamsTest.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.junit.jupiter.api.Disabled; 4 | import org.junit.jupiter.api.Test; 5 | 6 | import java.util.concurrent.*; 7 | import java.util.concurrent.atomic.AtomicInteger; 8 | import java.util.function.Function; 9 | import java.util.function.Supplier; 10 | import java.util.stream.IntStream; 11 | import java.util.stream.LongStream; 
import java.util.stream.Stream;
import org.junit.jupiter.api.Assertions;

public class SillyStreamsTest {

    /**
     * Demonstrates try-with-resources over several resources: each one prints when closed, and the
     * language closes them in reverse declaration order (three, two, one).
     */
    @Test
    public void autoclosingTest() throws Exception {
        Function<String, AutoCloseable> make = (name) -> new AutoCloseable() {
            @Override
            public void close() throws Exception {
                System.out.println(name + " closed!");
            }
        };

        try (var one = make.apply("one");
             var two = make.apply("two");
             var three = make.apply("three")) {
            System.out.println("Made the following: " + one + ", " + two + ", " + three);
            System.out.println("ok...closing now...");
        }
    }

    /**
     * Throttles a parallel stream with a {@link Semaphore} and checks that no more than
     * {@code permits} mapper invocations are ever inside the guarded section at once.
     */
    @Test
    public void testChokingStreams() throws Exception {
        final int permits = 3;
        final Semaphore semaphore = new Semaphore(permits);

        final AtomicInteger cnt = new AtomicInteger(0);
        final AtomicInteger maxSeen = new AtomicInteger(0);

        int i = IntStream.range(0, 100)
                .unordered()
                .parallel()
                .map(n -> {
                    try {
                        semaphore.acquire();
                    } catch (InterruptedException e) {
                        // Keep the interrupt visible to the runner; this element counts as 0,
                        // which makes the sum assertion below fail loudly.
                        Thread.currentThread().interrupt();
                        return 0;
                    }
                    try {
                        int c = cnt.incrementAndGet();
                        maxSeen.accumulateAndGet(c, Math::max);
                        System.out.printf("%d: c = %d\n", n, c);
                        return 1;
                    } finally {
                        // Always undo the bookkeeping and return the permit, even if printing throws;
                        // otherwise a leaked permit would shrink the pool for every later element.
                        cnt.decrementAndGet();
                        semaphore.release();
                    }
                }).sum();
        System.out.println("i = " + i + ", cnt = " + cnt.get());
        Assertions.assertEquals(100, i);
        Assertions.assertEquals(0, cnt.get());
        Assertions.assertTrue(maxSeen.get() <= permits);
    }

    /**
     * Slow exploratory test (disabled): sleeps in every element of a large parallel stream and
     * prints which thread handles every 100th element, then prints the total.
     */
    @Disabled
    @Test
    public void testStreams() throws ExecutionException, InterruptedException {
        var n = IntStream.range(1, 1_000_000).boxed().parallel().map(i -> {
                    try {
                        Thread.sleep(50);
                    } catch (InterruptedException e) {
                        // Don't swallow the interrupt; restore it for the owning thread.
                        Thread.currentThread().interrupt();
                    }
                    if (i % 100 == 0)
                        System.out.printf("%s : i=%d\n", Thread.currentThread(), i);
                    return i;
                }
        // Sum as long: 1 + ... + 999_999 is ~5e11, which overflows an int.
        ).map(Integer::longValue).reduce(Long::sum);
        System.out.println(n);
    }

    /**
     * Shows which onClose handlers run for a parallel stream that is flat-mapped into parallel
     * substreams. flatMap closes each substream it consumes. The parent's handler only runs
     * because the parent stream is closed explicitly via try-with-resources.
     */
    @Test
    public void testStreamClosing() {

        final AtomicInteger cnt = new AtomicInteger(0);
        final AtomicInteger parentCloses = new AtomicInteger(0);

        // Terminal operations such as forEach never close a stream themselves; without this
        // try-with-resources the parent's onClose handler would never fire.
        try (Stream<Long> longStream = LongStream.range(0, 36).boxed().parallel()
                .onClose(() -> {
                    parentCloses.incrementAndGet();
                    System.out.println("parents stream closed!");
                })) {

            longStream.flatMap(l -> {
                cnt.incrementAndGet();
                return LongStream.range(0, l)
                        .boxed()
                        .parallel()
                        .peek(m -> System.out.println("processing " + l + ", " + m))
                        .onClose(() -> {
                            int open = cnt.decrementAndGet();
                            System.out.printf("closing substream %d (%d still open)\n", l, open);
                        });
            }).forEach(l -> System.out.println("l = " + l));
        }

        // Every substream opened by flatMap was closed again, and the parent closed exactly once.
        Assertions.assertEquals(0, cnt.get());
        Assertions.assertEquals(1, parentCloses.get());
    }

    /**
     * Shows that a method-reference Supplier re-evaluates on every get() call.
     */
    @Test
    public void testSupplier() throws Exception {
        Supplier<Long> sup = System::currentTimeMillis;
        System.out.println("sup? " + sup.get());
        Thread.sleep(1233);
        System.out.println("sup? " + sup.get());

    }

    /**
     * Runs a parallel stream from inside a task submitted to a custom fixed-size executor and
     * prints the thread that handles each element. Parallel streams generally fork their work onto
     * the common ForkJoinPool rather than the submitting executor.
     */
    @Test
    public void testStreamExecutors() throws Exception {
        final ThreadGroup group = new ThreadGroup("test-group");

        // ThreadGroup#setDaemon does not make member threads daemons (and is deprecated), so mark
        // each pool thread as a daemon directly. The pool is also shut down explicitly below.
        var executor = Executors.newFixedThreadPool(3, runnable -> {
            Thread thread = new Thread(group, runnable);
            thread.setDaemon(true);
            return thread;
        });

        try {
            CompletableFuture.runAsync(() -> {
                LongStream.range(0, 33).parallel()
                        .forEach(l -> System.out.printf("%s: %d\n", Thread.currentThread(), l));

            }, executor).join();
        } finally {
            executor.shutdown();
        }
    }
}

--------------------------------------------------------------------------------
/src/test/java/org/neo4j/arrow/action/ServerInfoHandlerTest.java:
--------------------------------------------------------------------------------
package org.neo4j.arrow.action;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.Map;

public class ServerInfoHandlerTest {

    /**
     * Checks that the server-info handler reports a non-empty version map containing the value
     * "cool". That value appears in the test resource META-INF/MANIFEST.MF; presumably
     * VERSION_KEY names that manifest entry (verify against ServerInfoHandler).
     */
    @Test
    public void testGettingVersion() {

        Map<?, ?> map = ServerInfoHandler.getArrowVersion();
        Assertions.assertFalse(map.isEmpty());
        Assertions.assertEquals("cool", map.get(ServerInfoHandler.VERSION_KEY));
        System.out.println(map);
    }
}

--------------------------------------------------------------------------------
/src/test/resources/META-INF/MANIFEST.MF:
--------------------------------------------------------------------------------
neo4j-arrow-version: cool

-------------------------------------------------------------------------------- /strawman/README.md: -------------------------------------------------------------------------------- 1 | # A Strawman 2 | 3 | This subproject simply uses the Neo4j Java Driver to facilitate testing 4 | performance against an "out of the box" approach. That is to say, it uses 5 | the Java driver to run either: 6 | 7 | * A Cypher-only "export" of embeddings 8 | * Using GDS procedures to export from the in-memory Graph 9 | 10 | ## Building 11 | Run the following in the parent project: 12 | 13 | ``` 14 | $ ./gradlew :strawman:shadowJar 15 | ``` 16 | 17 | ## Running 18 | Most of the pertinent config options described in the [server](../server) 19 | project apply. However, there's one special environment variable: 20 | 21 | * `GDS`: set any value (e.g. `'ok'` or `1`) to use GDS procedures 22 | 23 | ``` 24 | $ java -jar strawman/build/libs/strawman-1.0-SNAPSHOT.jar 25 | ``` -------------------------------------------------------------------------------- /strawman/src/main/java/org/neo4j/arrow/Neo4jDirectClient.java: -------------------------------------------------------------------------------- 1 | package org.neo4j.arrow; 2 | 3 | import org.neo4j.driver.*; 4 | 5 | import java.util.Locale; 6 | import java.util.concurrent.atomic.AtomicInteger; 7 | 8 | public class Neo4jDirectClient { 9 | private static final org.slf4j.Logger logger; 10 | static { 11 | // Set up nicer logging output. 
12 | System.setProperty("org.slf4j.simpleLogger.showDateTime", "true"); 13 | System.setProperty("org.slf4j.simpleLogger.dateTimeFormat", "[yyyy-MM-dd'T'HH:mm:ss:SSS]"); 14 | System.setProperty("org.slf4j.simpleLogger.logFile", "System.out"); 15 | System.setProperty("org.slf4j.simpleLogger.defaultLogLevel", "info"); 16 | logger = org.slf4j.LoggerFactory.getLogger(Neo4jDirectClient.class); 17 | } 18 | 19 | private static final String cypher; 20 | static { 21 | switch (System.getenv() 22 | .getOrDefault("TEST_MODE", "CYPHER") 23 | .toUpperCase(Locale.ROOT)) { 24 | case "GDS": 25 | cypher = "CALL gds.graph.streamNodeProperty('mygraph', 'n')"; 26 | break; 27 | case "RANDOM": 28 | cypher = "UNWIND range(1, 10000000) AS row RETURN row, [_ IN range(1, 256) | rand()] as fauxEmbedding"; 29 | break; 30 | case "SIMPLE": 31 | cypher = "UNWIND range(1, toInteger(1e7)) AS row RETURN row, range(0, 64) AS data"; 32 | break; 33 | case "CYPHER": 34 | default: 35 | cypher = "MATCH (n) RETURN id(n) as nodeId, n.fastRp AS embedding"; 36 | break; 37 | } 38 | } 39 | 40 | public static void main(String[] args) { 41 | logger.info("Connecting to {} using Java Driver", Config.neo4jUrl); 42 | try (Driver driver = 43 | GraphDatabase.driver(Config.neo4jUrl, AuthTokens.basic(Config.username, Config.password))) { 44 | logger.info("Opening session with database {}", Config.database); 45 | final SessionConfig config = SessionConfig.builder() 46 | .withDatabase(Config.database) 47 | .withFetchSize(Config.boltFetchSize) 48 | .withDefaultAccessMode(AccessMode.READ) 49 | .build(); 50 | try (Session session = driver.session(config)) { 51 | AtomicInteger cnt = new AtomicInteger(0); 52 | logger.info("Executing cypher: {}", cypher); 53 | final long start = System.currentTimeMillis(); 54 | session.run(cypher).stream().forEach(record -> { 55 | int i = cnt.incrementAndGet(); 56 | if (i % 25_000 == 0) { 57 | logger.info("Current Row @ {}:\t[fields: {}]", i, record.keys()); 58 | } 59 | }); 60 | final float delta 
= (System.currentTimeMillis() - start) / 1000.0f; 61 | logger.info(String.format("Done! Time Delta: %,f s", delta)); 62 | logger.info(String.format("Count=%,d rows, Rate=%,d rows/s", 63 | cnt.get(), cnt.get() / Math.round(delta))); 64 | } 65 | } 66 | 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /todo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | find . -name '*.java' \ 3 | | xargs grep TODO \ 4 | | awk -F'//' '{ gsub(/[ :]/, "", $1); printf "- [ ]%s [%s](%s)\n", $2, substr($1, match($1, /\/\w+[\\.]java/) + 1), $1 }' \ 5 | | grep -v '

' --------------------------------------------------------------------------------