├── .gitattributes ├── .gitignore ├── .repo ├── 243 │ └── source │ │ └── org │ │ └── apache │ │ └── spark │ │ └── WowRowEncoder.scala ├── 311 │ └── source │ │ └── org │ │ └── apache │ │ └── spark │ │ └── WowRowEncoder.scala ├── .DS_Store └── pom.template.xml ├── README.md ├── dev ├── change-scala-version.sh ├── change-version-to-2.11.sh ├── change-version-to-2.12.sh └── release.sh ├── examples └── pyproject1 │ └── train.py ├── pom.xml ├── python ├── README.md ├── install.sh ├── pyjava │ ├── __init__.py │ ├── api │ │ ├── __init__.py │ │ ├── mlsql.py │ │ └── serve.py │ ├── cache │ │ ├── __init__.py │ │ └── code_cache.py │ ├── cloudpickle.py │ ├── daemon.py │ ├── datatype │ │ ├── __init__.py │ │ └── types.py │ ├── example │ │ ├── OnceServerExample.py │ │ ├── RayServerExample.py │ │ ├── __init__.py │ │ ├── test.py │ │ ├── test2.py │ │ └── test3.py │ ├── rayfix.py │ ├── serializers.py │ ├── storage │ │ ├── __init__.py │ │ └── streaming_tar.py │ ├── tests │ │ └── test_context.py │ ├── udf │ │ └── __init__.py │ ├── utils.py │ ├── version.py │ └── worker.py ├── requirements.txt ├── setup.cfg └── setup.py └── src ├── main └── java │ ├── org │ └── apache │ │ └── spark │ │ ├── WowRowEncoder.scala │ │ └── sql │ │ └── SparkUtils.scala │ └── tech │ └── mlsql │ └── arrow │ ├── ArrowConverters.scala │ ├── ArrowUtils.scala │ ├── ArrowWriter.scala │ ├── ByteBufferOutputStream.scala │ ├── Utils.scala │ ├── api │ └── RedirectStreams.scala │ ├── context │ └── CommonTaskContext.scala │ ├── javadoc.java │ └── python │ ├── PyJavaException.scala │ ├── PythonWorkerFactory.scala │ ├── iapp │ └── AppContextImpl.scala │ ├── ispark │ └── SparkContextImp.scala │ └── runner │ ├── ArrowPythonRunner.scala │ ├── PythonProjectRunner.scala │ ├── PythonRunner.scala │ └── SparkSocketRunner.scala └── test └── java └── tech └── mlsql └── test ├── ApplyPythonScript.scala ├── JavaApp1Spec.scala ├── JavaAppSpec.scala ├── JavaArrowServer.scala ├── Main.scala ├── PythonProjectSpec.scala ├── RayEnv.scala ├── SparkSpec.scala └── function └── SparkFunctions.scala /.gitattributes: -------------------------------------------------------------------------------- 1 | pom.xml merge=ours 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | pyjava.iml 3 | python/build/ 4 | python/dist/ 5 | python/pyjava.egg-info/ 6 | target/ 7 | python/.eggs/ 8 | __pycache__ 9 | spark-warehouse 10 | 11 | *.iml -------------------------------------------------------------------------------- /.repo/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/pyjava/83112f57ce8e3a87b6a5a59e6b33d2f9004caf60/.repo/.DS_Store -------------------------------------------------------------------------------- /.repo/243/source/org/apache/spark/WowRowEncoder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * DO NOT EDIT THIS FILE DIRECTLY, ANY CHANGE MAY BE OVERWRITE 3 | */ 4 | package org.apache.spark 5 | 6 | import org.apache.spark.sql.Row 7 | import org.apache.spark.sql.catalyst.InternalRow 8 | import org.apache.spark.sql.catalyst.encoders.RowEncoder 9 | import org.apache.spark.sql.types.StructType 10 | 11 | object WowRowEncoder { 12 | def toRow(schema: StructType) = { 13 | val rab = RowEncoder.apply(schema).resolveAndBind() 14 | (irow: InternalRow) => { 15 | rab.fromRow(irow) 16 | } 17 | } 18 | 19 | def 
fromRow(schema: StructType) = {
20 | val rab = RowEncoder.apply(schema).resolveAndBind()
21 | (row: Row) => {
22 | rab.toRow(row)
23 | }
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/.repo/311/source/org/apache/spark/WowRowEncoder.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * DO NOT EDIT THIS FILE DIRECTLY, ANY CHANGE MAY BE OVERWRITE
3 | */
4 | package org.apache.spark
5 | import org.apache.spark.sql.Row
6 | import org.apache.spark.sql.catalyst.InternalRow
7 | import org.apache.spark.sql.catalyst.encoders.RowEncoder
8 | import org.apache.spark.sql.types.StructType
9 |
10 | object WowRowEncoder {
11 | def toRow(schema: StructType) = {
12 | RowEncoder.apply(schema).resolveAndBind().createDeserializer()
13 |
14 | }
15 |
16 | def fromRow(schema: StructType) = {
17 | RowEncoder.apply(schema).resolveAndBind().createSerializer()
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## PyJava
2 |
3 | This library is an ongoing effort to make it easy to exchange data
4 | between Java/Scala and Python. PyJava uses Apache Arrow as the exchange format,
5 | which avoids the usual serialization/deserialization between Java/Scala and Python
6 | and makes the communication much more efficient than the traditional way.
7 |
8 | When you invoke Python code from the Java/Scala side, PyJava starts some Python workers automatically,
9 | sends the data to them, and sends the processed data back. The Python workers are reused
10 | by default.
11 |
12 |
13 | > The initial code in this lib is from Apache Spark.
14 |
15 |
16 | ## Install
17 |
18 | Set up the Python (>= 3.6) environment (Conda is recommended):
19 |
20 | ```shell
21 | pip uninstall pyjava && pip install pyjava
22 | ```
23 |
24 | Set up the Java environment (Maven is recommended):
25 |
26 | For Scala 2.11/Spark 2.4.3
27 |
28 | ```xml
29 | <dependency>
30 |     <groupId>tech.mlsql</groupId>
31 |     <artifactId>pyjava-2.4_2.11</artifactId>
32 |     <version>0.3.2</version>
33 | </dependency>
34 | ```
35 |
36 | For Scala 2.12/Spark 3.1.1
37 |
38 | ```xml
39 | <dependency>
40 |     <groupId>tech.mlsql</groupId>
41 |     <artifactId>pyjava-3.0_2.12</artifactId>
42 |     <version>0.3.2</version>
43 | </dependency>
44 | ```
45 |
46 | ## Build Manually
47 |
48 | Install the build tool:
49 |
50 | ```
51 | pip install mlsql_plugin_tool
52 | ```
53 |
54 | Build for Spark 3.1.1:
55 |
56 | ```
57 | mlsql_plugin_tool spark311
58 | mvn clean install -DskipTests -Pdisable-java8-doclint -Prelease-sign-artifacts
59 | ```
60 |
61 | Build for Spark 2.4.3:
62 |
63 | ```
64 | mlsql_plugin_tool spark243
65 | mvn clean install -DskipTests -Pdisable-java8-doclint -Prelease-sign-artifacts
66 | ```
67 |
68 |
69 | ## Using python code snippet to process data in Java/Scala
70 |
71 | With pyjava, you can run any python code in your Java/Scala application.
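The snippet you pass in runs inside a pyjava Python worker, where a `context` object (a `pyjava.api.mlsql.PythonContext`, see `python/pyjava/api/mlsql.py` later in this repo) is made available to the code. As a rough, minimal sketch of the worker-side view (the same logic is embedded as a string in the Scala example below; the `value`/`value1` column names simply follow that example's schema):

```python
# Worker-side sketch: assumes the pyjava worker has injected `context`,
# a pyjava.api.mlsql.PythonContext, into the snippet's globals.

def process():
    # fetch_once_as_rows() yields the rows streamed from Java/Scala as dicts;
    # the input stream can only be consumed once per run.
    for row in context.fetch_once_as_rows():
        row["value1"] = row["value"] + "_suffix"
        yield row

# build_result() re-batches the yielded rows (1024 rows per Arrow record
# batch by default) and streams them back to the Java/Scala side.
context.build_result(process())
```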
72 |
73 | ```scala
74 |
75 | val envs = new util.HashMap[String, String]()
76 | // prepare python environment
77 | envs.put(str(PythonConf.PYTHON_ENV), "source activate dev && export ARROW_PRE_0_15_IPC_FORMAT=1 ")
78 |
79 | // describe the data which will be transferred to python
80 | val sourceSchema = StructType(Seq(StructField("value", StringType)))
81 |
82 | val batch = new ArrowPythonRunner(
83 | Seq(ChainedPythonFunctions(Seq(PythonFunction(
84 | """
85 | |import pandas as pd
86 | |import numpy as np
87 | |
88 | |def process():
89 | | for item in context.fetch_once_as_rows():
90 | | item["value1"] = item["value"] + "_suffix"
91 | | yield item
92 | |
93 | |context.build_result(process())
94 | """.stripMargin, envs, "python", "3.6")))), sourceSchema,
95 | "GMT", Map()
96 | )
97 |
98 | // prepare data
99 | val sourceEnconder = RowEncoder.apply(sourceSchema).resolveAndBind()
100 | val newIter = Seq(Row.fromSeq(Seq("a1")), Row.fromSeq(Seq("a2"))).map { irow =>
101 | sourceEnconder.toRow(irow).copy()
102 | }.iterator
103 |
104 | // run the code and get the return result
105 | val javaConext = new JavaContext
106 | val commonTaskContext = new AppContextImpl(javaConext, batch)
107 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext)
108 |
109 | //f.copy(), copy function is required
110 | columnarBatchIter.flatMap { batch =>
111 | batch.rowIterator.asScala
112 | }.foreach(f => println(f.copy()))
113 | javaConext.markComplete
114 | javaConext.close
115 | ```
116 |
117 | ## Using python code snippet to process data in Spark
118 |
119 | ```scala
120 | val session = spark
121 | import session.implicits._
122 | val timezoneid = session.sessionState.conf.sessionLocalTimeZone
123 | val df = session.createDataset[String](Seq("a1", "b1")).toDF("value")
124 | val struct = df.schema
125 | val abc = df.rdd.mapPartitions { iter =>
126 | val enconder = RowEncoder.apply(struct).resolveAndBind()
127 | val envs = new util.HashMap[String, String]()
128 | envs.put(str(PythonConf.PYTHON_ENV), "source activate streamingpro-spark-2.4.x")
129 | val batch = new ArrowPythonRunner(
130 | Seq(ChainedPythonFunctions(Seq(PythonFunction(
131 | """
132 | |import pandas as pd
133 | |import numpy as np
134 | |for item in data_manager.fetch_once():
135 | | print(item)
136 | |df = pd.DataFrame({'AAA': [4, 5, 6, 7],'BBB': [10, 20, 30, 40],'CCC': [100, 50, -30, -50]})
137 | |data_manager.set_output([[df['AAA'],df['BBB']]])
138 | """.stripMargin, envs, "python", "3.6")))), struct,
139 | timezoneid, Map()
140 | )
141 | val newIter = iter.map { irow =>
142 | enconder.toRow(irow)
143 | }
144 | val commonTaskContext = new SparkContextImp(TaskContext.get(), batch)
145 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext)
146 | columnarBatchIter.flatMap { batch =>
147 | batch.rowIterator.asScala.map(_.copy)
148 | }
149 | }
150 |
151 | val wow = SparkUtils.internalCreateDataFrame(session, abc, StructType(Seq(StructField("AAA", LongType), StructField("BBB", LongType))), false)
152 | wow.show()
153 | ```
154 |
155 | ## Run Python Project
156 |
157 | With PyJava, you can tell the system where the Python project is and which file is the entrypoint;
158 | then you can run this project in Java/Scala.
159 |
160 | ```scala
161 | import tech.mlsql.arrow.python.runner.PythonProjectRunner
162 |
163 | val runner = new PythonProjectRunner("./pyjava/examples/pyproject1", Map())
164 | val output = runner.run(Seq("bash", "-c", "source activate dev && python train.py"), Map(
165 | "tempDataLocalPath" -> "/tmp/data",
166 | "tempModelLocalPath" -> "/tmp/model"
167 | ))
168 | output.foreach(println)
169 | ```
170 |
171 |
172 | ## Example In MLSQL
173 |
174 | ### Non-Interactive Mode:
175 |
176 | ```sql
177 | !python env "PYTHON_ENV=source activate streamingpro-spark-2.4.x";
178 | !python conf "schema=st(field(a,long),field(b,long))";
179 |
180 | select 1 as a as table1;
181 |
182 | !python on table1 '''
183 |
184 | import pandas as pd
185 | import numpy as np
186 | for item in data_manager.fetch_once():
187 | print(item)
188 | df = pd.DataFrame({'AAA': [4, 5, 6, 8],'BBB': [10, 20, 30, 40],'CCC': [100, 50, -30, -50]})
189 | data_manager.set_output([[df['AAA'],df['BBB']]])
190 |
191 | ''' named mlsql_temp_table2;
192 |
193 | select * from mlsql_temp_table2 as output;
194 | ```
195 |
196 | ### Interactive Mode:
197 |
198 | ```sql
199 | !python start;
200 |
201 | !python env "PYTHON_ENV=source activate streamingpro-spark-2.4.x";
202 | !python env "schema=st(field(a,integer),field(b,integer))";
203 |
204 |
205 | !python '''
206 | import pandas as pd
207 | import numpy as np
208 | ''';
209 |
210 | !python '''
211 | for item in data_manager.fetch_once():
212 | print(item)
213 | df = pd.DataFrame({'AAA': [4, 5, 6, 8],'BBB': [10, 20, 30, 40],'CCC': [100, 50, -30, -50]})
214 | data_manager.set_output([[df['AAA'],df['BBB']]])
215 | ''';
216 | !python close;
217 | ```
218 |
219 |
220 |
221 | ## Using PyJava as Arrow Server/Client
222 |
223 | Java Server side:
224 |
225 | ```scala
226 | val socketRunner = new SparkSocketRunner("wow", NetUtils.getHost, "Asia/Harbin")
227 |
228 | val dataSchema = StructType(Seq(StructField("value", StringType)))
229 | val enconder = RowEncoder.apply(dataSchema).resolveAndBind()
230 | val newIter = Seq(Row.fromSeq(Seq("a1")), Row.fromSeq(Seq("a2"))).map { irow =>
231 | enconder.toRow(irow)
232 | }.iterator
233 | val javaConext = new JavaContext
234 | val commonTaskContext = new AppContextImpl(javaConext, null)
235 |
236 | val Array(_, host, port) = socketRunner.serveToStreamWithArrow(newIter, dataSchema, 10, commonTaskContext)
237 | println(s"${host}:${port}")
238 | Thread.currentThread().join()
239 | ```
240 |
241 | Python Client side:
242 |
243 | ```python
244 | import os
245 | import socket
246 |
247 | from pyjava.serializers import \
248 | ArrowStreamPandasSerializer
249 |
250 | out_ser = ArrowStreamPandasSerializer(None, True, True)
251 |
252 | out_ser = ArrowStreamPandasSerializer("Asia/Harbin", False, None)
253 | HOST = ""
254 | PORT = -1
255 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
256 | sock.connect((HOST, PORT))
257 | buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
258 | infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size)
259 | outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size)
260 | kk = out_ser.load_stream(infile)
261 | for item in kk:
262 | print(item)
263 | ```
264 |
265 | Python Server side:
266 |
267 | ```python
268 | import os
269 |
270 | import pandas as pd
271 |
272 | os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
273 | from pyjava.api.serve import OnceServer
274 |
275 | ddata = pd.DataFrame(data=[[1, 2, 3, 4], [2, 3, 4, 5]])
276 |
277 | server = OnceServer("127.0.0.1", 11111, "Asia/Harbin")
278 |
server.bind() 279 | server.serve([{'id': 9, 'label': 1}]) 280 | ``` 281 | 282 | Java Client side: 283 | 284 | ```scala 285 | import org.apache.spark.sql.Row 286 | import org.apache.spark.sql.catalyst.encoders.RowEncoder 287 | import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType} 288 | import org.scalatest.{BeforeAndAfterAll, FunSuite} 289 | import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext} 290 | import tech.mlsql.arrow.python.runner.SparkSocketRunner 291 | import tech.mlsql.common.utils.network.NetUtils 292 | 293 | val enconder = RowEncoder.apply(StructType(Seq(StructField("a", LongType),StructField("b", LongType)))).resolveAndBind() 294 | val socketRunner = new SparkSocketRunner("wow", NetUtils.getHost, "Asia/Harbin") 295 | val javaConext = new JavaContext 296 | val commonTaskContext = new AppContextImpl(javaConext, null) 297 | val iter = socketRunner.readFromStreamWithArrow("127.0.0.1", 11111, commonTaskContext) 298 | iter.foreach(i => println(enconder.fromRow(i.copy()))) 299 | javaConext.close 300 | ``` 301 | 302 | ## How to configure python worker runs in Docker (todo) 303 | 304 | 305 | -------------------------------------------------------------------------------- /dev/change-scala-version.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | set -e 21 | 22 | VALID_VERSIONS=( 2.11 2.12 ) 23 | 24 | usage() { 25 | echo "Usage: $(basename $0) [-h|--help] 26 | where : 27 | -h| --help Display this help text 28 | valid version values : ${VALID_VERSIONS[*]} 29 | " 1>&2 30 | exit 1 31 | } 32 | 33 | if [[ ($# -ne 1) || ( $1 == "--help") || $1 == "-h" ]]; then 34 | usage 35 | fi 36 | 37 | TO_VERSION=$1 38 | 39 | check_scala_version() { 40 | for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done 41 | echo "Invalid Scala version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2 42 | exit 1 43 | } 44 | 45 | check_scala_version "$TO_VERSION" 46 | 47 | if [ $TO_VERSION = "2.12" ]; then 48 | FROM_VERSION="2.11" 49 | else 50 | FROM_VERSION="2.12" 51 | fi 52 | 53 | sed_i() { 54 | sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2" 55 | } 56 | 57 | export -f sed_i 58 | 59 | BASEDIR=$(dirname $0)/.. 
60 | find "$BASEDIR" -name 'pom.xml' -not -path '*target*' -print \ 61 | -exec bash -c "sed_i 's/\(artifactId.*\)_'$FROM_VERSION'/\1_'$TO_VERSION'/g' {}" \; 62 | 63 | # Also update in parent POM 64 | # Match any scala binary version to ensure idempotency 65 | sed_i '1,/[0-9]*\.[0-9]*[0-9]*\.[0-9]*'$TO_VERSION' 2 | 5 | 4.0.0 6 | 7 | tech.mlsql 8 | pyjava-3.0_2.12 9 | 0.3.2 10 | pyjava with arrow support 11 | https://github.com/allwefantasy/pyjava 12 | 13 | Communication between Python And Java with Apache Arrow. 14 | 15 | 16 | 17 | Apache 2.0 License 18 | http://www.apache.org/licenses/LICENSE-2.0.html 19 | repo 20 | 21 | 22 | 23 | 24 | allwefantasy 25 | ZhuHaiLin 26 | allwefantasy@gmail.com 27 | 28 | 29 | 30 | 31 | scm:git:git@github.com:allwefantasy/pyjava.git 32 | 33 | 34 | scm:git:git@github.com:allwefantasy/pyjava.git 35 | 36 | https://github.com/allwefantasy/pyjava 37 | 38 | 39 | https://github.com/allwefantasy/pyjava/issues 40 | 41 | 42 | 43 | UTF-8 44 | 45 | 2021.1.12 46 | 47 | 2.12.10 48 | 2.12 49 | 3.1.1 50 | 3.0 51 | 2.0.0 52 | 53 | 16.0 54 | 4.5.3 55 | provided 56 | 2.6.5 57 | 0.3.6 58 | 2.0.6 59 | 60 | 61 | 62 | 63 | 64 | disable-java8-doclint 65 | 66 | [1.8,) 67 | 68 | 69 | -Xdoclint:none 70 | none 71 | 72 | 73 | 74 | release-sign-artifacts 75 | 76 | 77 | performRelease 78 | true 79 | 80 | 81 | 82 | 83 | 84 | org.apache.maven.plugins 85 | maven-gpg-plugin 86 | 1.5 87 | 88 | 89 | sign-artifacts 90 | verify 91 | 92 | sign 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | tech.mlsql 105 | common-utils_${scala.binary.version} 106 | ${common-utils-version} 107 | 108 | 109 | org.scalactic 110 | scalactic_${scala.binary.version} 111 | 3.0.0 112 | test 113 | 114 | 115 | org.scalatest 116 | scalatest_${scala.binary.version} 117 | 3.0.0 118 | test 119 | 120 | 121 | org.apache.arrow 122 | arrow-vector 123 | ${arrow.version} 124 | 125 | 126 | com.fasterxml.jackson.core 127 | jackson-annotations 128 | 129 | 130 | com.fasterxml.jackson.core 131 | jackson-databind 132 | 133 | 134 | io.netty 135 | netty-buffer 136 | 137 | 138 | io.netty 139 | netty-common 140 | 141 | 142 | io.netty 143 | netty-handler 144 | 145 | 146 | 147 | 148 | 149 | 150 | com.fasterxml.jackson.core 151 | jackson-core 152 | 2.10.1 153 | 154 | 155 | 156 | 157 | 158 | org.apache.spark 159 | spark-core_${scala.binary.version} 160 | ${spark.version} 161 | ${scope} 162 | 163 | 164 | org.apache.spark 165 | spark-sql_${scala.binary.version} 166 | ${spark.version} 167 | ${scope} 168 | 169 | 170 | 171 | org.apache.spark 172 | spark-mllib_${scala.binary.version} 173 | ${spark.version} 174 | ${scope} 175 | 176 | 177 | 178 | org.apache.spark 179 | spark-graphx_${scala.binary.version} 180 | ${spark.version} 181 | ${scope} 182 | 183 | 184 | 185 | org.apache.spark 186 | spark-catalyst_${scala.binary.version} 187 | ${spark.version} 188 | tests 189 | test 190 | 191 | 192 | 193 | org.apache.spark 194 | spark-core_${scala.binary.version} 195 | ${spark.version} 196 | tests 197 | test 198 | 199 | 200 | 201 | org.apache.spark 202 | spark-sql_${scala.binary.version} 203 | ${spark.version} 204 | tests 205 | test 206 | 207 | 208 | 209 | org.pegdown 210 | pegdown 211 | 1.6.0 212 | test 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | src/main/resources 221 | 222 | 223 | 224 | 225 | org.apache.maven.plugins 226 | maven-surefire-plugin 227 | 3.0.0-M1 228 | 229 | 1 230 | true 231 | -Xmx4024m 232 | 233 | **/*.java 234 | **/*.scala 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | org.scala-tools 243 | maven-scala-plugin 244 | 2.15.2 245 | 
246 | 247 | 248 | -g:vars 249 | 250 | 251 | true 252 | 253 | 254 | 255 | compile 256 | 257 | compile 258 | 259 | compile 260 | 261 | 262 | testCompile 263 | 264 | testCompile 265 | 266 | test 267 | 268 | 269 | process-resources 270 | 271 | compile 272 | 273 | 274 | 275 | 276 | 277 | 278 | org.apache.maven.plugins 279 | maven-compiler-plugin 280 | 2.3.2 281 | 282 | 283 | -g 284 | true 285 | 1.8 286 | 1.8 287 | 288 | 289 | 290 | 291 | 292 | 293 | maven-source-plugin 294 | 2.1 295 | 296 | true 297 | 298 | 299 | 300 | compile 301 | 302 | jar 303 | 304 | 305 | 306 | 307 | 308 | org.apache.maven.plugins 309 | maven-javadoc-plugin 310 | 311 | 312 | attach-javadocs 313 | 314 | jar 315 | 316 | 317 | 318 | 319 | 320 | org.sonatype.plugins 321 | nexus-staging-maven-plugin 322 | 1.6.7 323 | true 324 | 325 | sonatype-nexus-staging 326 | https://oss.sonatype.org/ 327 | true 328 | 329 | 330 | 331 | 332 | org.scalatest 333 | scalatest-maven-plugin 334 | 2.0.0 335 | 336 | streaming.core.NotToRunTag 337 | ${project.build.directory}/surefire-reports 338 | . 339 | WDF TestSuite.txt 340 | ${project.build.directory}/html/scalatest 341 | false 342 | 343 | 344 | 345 | test 346 | 347 | test 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | sonatype-nexus-snapshots 381 | https://oss.sonatype.org/content/repositories/snapshots 382 | 383 | 384 | sonatype-nexus-staging 385 | https://oss.sonatype.org/service/local/staging/deploy/maven2/ 386 | 387 | 388 | 389 | 390 | -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | ## Python side in PyJava 2 | 3 | This library is designed for Java/Scala invoking python code with Apache 4 | Arrow as the exchanging data format. This means if you want to run some python code 5 | in Scala/Java, the Scala/Java will start some python workers with this lib's help. 6 | Then the Scala/Java will send the data/code to python worker in socket and the python will 7 | return the processed data back. 8 | 9 | The initial code in this lib is from Apache Spark. 10 | 11 | -------------------------------------------------------------------------------- /python/install.sh: -------------------------------------------------------------------------------- 1 | project=pyjava 2 | version=0.3.3 3 | rm -rf ./dist/* 4 | pip uninstall -y ${project} 5 | python setup.py sdist bdist_wheel 6 | cd ./dist/ 7 | pip install ${project}-${version}-py3-none-any.whl && cd - 8 | twine upload dist/* 9 | -------------------------------------------------------------------------------- /python/pyjava/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. 
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import types 19 | from functools import wraps 20 | 21 | 22 | def since(version): 23 | """ 24 | A decorator that annotates a function to append the version of Spark the function was added. 25 | """ 26 | import re 27 | indent_p = re.compile(r'\n( +)') 28 | 29 | def deco(f): 30 | indents = indent_p.findall(f.__doc__) 31 | indent = ' ' * (min(len(m) for m in indents) if indents else 0) 32 | f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) 33 | return f 34 | 35 | return deco 36 | 37 | 38 | def copy_func(f, name=None, sinceversion=None, doc=None): 39 | """ 40 | Returns a function with same code, globals, defaults, closure, and 41 | name (or provide a new name). 42 | """ 43 | # See 44 | # http://stackoverflow.com/questions/6527633/how-can-i-make-a-deepcopy-of-a-function-in-python 45 | fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__, f.__defaults__, 46 | f.__closure__) 47 | # in case f was given attrs (note this dict is a shallow copy): 48 | fn.__dict__.update(f.__dict__) 49 | if doc is not None: 50 | fn.__doc__ = doc 51 | if sinceversion is not None: 52 | fn = since(sinceversion)(fn) 53 | return fn 54 | 55 | 56 | def keyword_only(func): 57 | """ 58 | A decorator that forces keyword arguments in the wrapped method 59 | and saves actual input keyword arguments in `_input_kwargs`. 60 | 61 | .. note:: Should only be used to wrap a method where first arg is `self` 62 | """ 63 | 64 | @wraps(func) 65 | def wrapper(self, *args, **kwargs): 66 | if len(args) > 0: 67 | raise TypeError("Method %s forces keyword arguments." 
% func.__name__) 68 | self._input_kwargs = kwargs 69 | return func(self, **kwargs) 70 | 71 | return wrapper 72 | -------------------------------------------------------------------------------- /python/pyjava/api/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, NoReturn, Callable, Dict, List 2 | import tempfile 3 | import os 4 | 5 | from pyjava.api.mlsql import PythonContext 6 | 7 | 8 | class Utils(object): 9 | 10 | @staticmethod 11 | def show_plt(plt: Any, context: PythonContext) -> NoReturn: 12 | content = Utils.gen_img(plt) 13 | context.build_result([{"content": content, "mime": "image"}]) 14 | 15 | @staticmethod 16 | def gen_img(plt: Any) -> str: 17 | import base64 18 | import uuid 19 | img_path = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()) + ".png") 20 | plt.savefig(img_path) 21 | with open(img_path, mode='rb') as file: 22 | file_content = base64.b64encode(file.read()).decode() 23 | return file_content 24 | -------------------------------------------------------------------------------- /python/pyjava/api/mlsql.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import socket 4 | from distutils.version import StrictVersion 5 | import uuid 6 | 7 | import pandas as pd 8 | import sys 9 | from typing import Dict 10 | 11 | import pyjava.utils as utils 12 | import requests 13 | from pyjava.serializers import ArrowStreamSerializer 14 | from pyjava.serializers import read_int 15 | from pyjava.utils import utf8_deserializer 16 | from pyjava.storage import streaming_tar 17 | 18 | if sys.version >= '3': 19 | basestring = str 20 | else: 21 | pass 22 | 23 | 24 | class DataServer(object): 25 | def __init__(self, host, port, timezone): 26 | self.host = host 27 | self.port = port 28 | self.timezone = timezone 29 | 30 | 31 | class LogClient(object): 32 | def __init__(self, conf): 33 | self.conf = conf 34 | if 'spark.mlsql.log.driver.url' in self.conf: 35 | self.url = self.conf['spark.mlsql.log.driver.url'] 36 | self.log_user = self.conf['PY_EXECUTE_USER'] 37 | self.log_token = self.conf['spark.mlsql.log.driver.token'] 38 | self.log_group_id = self.conf['groupId'] 39 | 40 | def log_to_driver(self, msg): 41 | if 'spark.mlsql.log.driver.url' not in self.conf: 42 | if self.conf['PY_EXECUTE_USER'] and self.conf['groupId']: 43 | logging.info("[owner] [{}] [groupId] [{}] __MMMMMM__ {}".format(self.conf['PY_EXECUTE_USER'], 44 | self.conf['groupId'], msg)) 45 | else: 46 | logging.info(msg) 47 | return 48 | import json 49 | resp = json.dumps( 50 | {"sendLog": { 51 | "token": self.log_token, 52 | "logLine": "[owner] [{}] [groupId] [{}] __MMMMMM__ {}".format(self.log_user, self.log_group_id, msg) 53 | }}, ensure_ascii=False) 54 | requests.post(self.url, data=resp, headers={'content-type': 'application/x-www-form-urlencoded;charset=UTF-8'}) 55 | 56 | def close(self): 57 | if hasattr(self, "conn"): 58 | self.conn.close() 59 | self.conn = None 60 | 61 | 62 | class PythonContext(object): 63 | cache = {} 64 | 65 | def __init__(self, context_id, iterator, conf): 66 | self.context_id = context_id 67 | self.data_mmap_file_ref = {} 68 | self.input_data = iterator 69 | self.output_data = [[]] 70 | self.conf = conf 71 | self.schema = "" 72 | self.have_fetched = False 73 | self.log_client = LogClient(self.conf) 74 | if "pythonMode" in conf and conf["pythonMode"] == "ray": 75 | self.rayContext = RayContext(self) 76 | 77 | def set_output(self, value, schema=""): 78 | 
self.output_data = value 79 | self.schema = schema 80 | 81 | @staticmethod 82 | def build_chunk_result(items, block_size=1024): 83 | buffer = [] 84 | for item in items: 85 | buffer.append(item) 86 | if len(buffer) == block_size: 87 | df = pd.DataFrame(buffer, columns=buffer[0].keys()) 88 | buffer.clear() 89 | yield df 90 | 91 | if len(buffer) > 0: 92 | df = pd.DataFrame(buffer, columns=buffer[0].keys()) 93 | buffer.clear() 94 | yield df 95 | 96 | def build_result(self, items, block_size=1024): 97 | self.output_data = ([df[name] for name in df] 98 | for df in PythonContext.build_chunk_result(items, block_size)) 99 | 100 | def build_result_from_dir(self, target_dir, block_size=1024): 101 | items = streaming_tar.build_rows_from_file(target_dir) 102 | self.build_result(items, block_size) 103 | 104 | def output(self): 105 | return self.output_data 106 | 107 | def __del__(self): 108 | logging.info("==clean== context") 109 | if self.log_client is not None: 110 | try: 111 | self.log_client.close() 112 | except Exception as e: 113 | pass 114 | 115 | if 'data_mmap_file_ref' in self.data_mmap_file_ref: 116 | try: 117 | self.data_mmap_file_ref['data_mmap_file_ref'].close() 118 | except Exception as e: 119 | pass 120 | 121 | def noops_fetch(self): 122 | for item in self.fetch_once(): 123 | pass 124 | 125 | def fetch_once_as_dataframe(self): 126 | for df in self.fetch_once(): 127 | yield df 128 | 129 | def fetch_once_as_rows(self): 130 | for df in self.fetch_once_as_dataframe(): 131 | for row in df.to_dict('records'): 132 | yield row 133 | 134 | def fetch_once_as_batch_rows(self): 135 | for df in self.fetch_once_as_dataframe(): 136 | yield (row for row in df.to_dict('records')) 137 | 138 | def fetch_once(self): 139 | import pyarrow as pa 140 | if self.have_fetched: 141 | raise Exception("input data can only be fetched once") 142 | self.have_fetched = True 143 | for items in self.input_data: 144 | yield pa.Table.from_batches([items]).to_pandas() 145 | 146 | def fetch_as_dir(self, target_dir): 147 | if len(self.data_servers()) > 1: 148 | raise Exception("Please make sure you have only one partition on Java/Spark Side") 149 | items = self.fetch_once_as_rows() 150 | streaming_tar.save_rows_as_file(items, target_dir) 151 | 152 | 153 | class PythonProjectContext(object): 154 | def __init__(self): 155 | self.params_read = False 156 | self.conf = {} 157 | self.read_params_once() 158 | self.log_client = LogClient(self.conf) 159 | 160 | def read_params_once(self): 161 | if not self.params_read: 162 | self.params_read = True 163 | infile = sys.stdin.buffer 164 | for i in range(read_int(infile)): 165 | k = utf8_deserializer.loads(infile) 166 | v = utf8_deserializer.loads(infile) 167 | self.conf[k] = v 168 | 169 | def input_data_dir(self): 170 | return self.conf["tempDataLocalPath"] 171 | 172 | def output_model_dir(self): 173 | return self.conf["tempModelLocalPath"] 174 | 175 | def __del__(self): 176 | self.log_client.close() 177 | 178 | 179 | class RayContext(object): 180 | cache = {} 181 | conn_cache = {} 182 | 183 | def __init__(self, python_context): 184 | self.python_context = python_context 185 | self.servers = [] 186 | self.server_ids_in_ray = [] 187 | self.is_setup = False 188 | self.is_dev = utils.is_dev() 189 | self.is_in_mlsql = True 190 | self.mock_data = [] 191 | if "directData" not in python_context.conf: 192 | for item in self.python_context.fetch_once_as_rows(): 193 | self.server_ids_in_ray.append(str(uuid.uuid4())) 194 | self.servers.append(DataServer( 195 | item["host"], int(item["port"]), 
item["timezone"])) 196 | 197 | def data_servers(self): 198 | return self.servers 199 | 200 | def conf(self): 201 | return self.python_context.conf 202 | 203 | def data_servers_in_ray(self): 204 | import ray 205 | from pyjava.rayfix import RayWrapper 206 | rayw = RayWrapper() 207 | for server_id in self.server_ids_in_ray: 208 | server = rayw.get_actor(server_id) 209 | yield ray.get(server.connect_info.remote()) 210 | 211 | def build_servers_in_ray(self): 212 | from pyjava.rayfix import RayWrapper 213 | from pyjava.api.serve import RayDataServer 214 | import ray 215 | buffer = [] 216 | rayw = RayWrapper() 217 | for (server_id, java_server) in zip(self.server_ids_in_ray, self.servers): 218 | # rds = RayDataServer.options(name=server_id, detached=True, max_concurrency=2).remote(server_id, java_server, 219 | # 0, 220 | # java_server.timezone) 221 | rds = rayw.options(RayDataServer, name=server_id, detached=True, max_concurrency=2).remote(server_id, 222 | java_server, 223 | 0, 224 | java_server.timezone) 225 | res = ray.get(rds.connect_info.remote()) 226 | logging.debug("build ray data server server_id:{} java_server: {} servers:{}".format(server_id, 227 | str(vars( 228 | java_server)), 229 | str(vars(res)))) 230 | buffer.append(res) 231 | return buffer 232 | 233 | @staticmethod 234 | def connect(_context, url, **kwargs): 235 | if isinstance(_context, PythonContext): 236 | context = _context 237 | elif isinstance(_context, dict): 238 | if 'context' in _context: 239 | context = _context['context'] 240 | else: 241 | ''' 242 | we are not in MLSQL 243 | ''' 244 | context = PythonContext("", [], {"pythonMode": "ray"}) 245 | context.rayContext.is_in_mlsql = False 246 | else: 247 | raise Exception("context is not detect. make sure it's in globals().") 248 | 249 | if url == "local": 250 | from pyjava.rayfix import RayWrapper 251 | ray = RayWrapper() 252 | if ray.ray_version < StrictVersion('1.6.0'): 253 | raise Exception("URL:local is only support in ray >= 1.6.0") 254 | # if not ray.ray_instance.is_initialized: 255 | ray.ray_instance.shutdown() 256 | ray.ray_instance.init(namespace="default") 257 | 258 | elif url is not None: 259 | from pyjava.rayfix import RayWrapper 260 | ray = RayWrapper() 261 | is_udf_client = context.conf.get("UDF_CLIENT") 262 | if is_udf_client is None: 263 | ray.shutdown() 264 | ray.init(url, **kwargs) 265 | if is_udf_client and url not in RayContext.conn_cache: 266 | ray.init(url, **kwargs) 267 | RayContext.conn_cache[url] = 1 268 | 269 | return context.rayContext 270 | 271 | def setup(self, func_for_row, func_for_rows=None): 272 | if self.is_setup: 273 | raise ValueError("setup can be only invoke once") 274 | self.is_setup = True 275 | 276 | is_data_mode = "dataMode" in self.conf() and self.conf()["dataMode"] == "data" 277 | 278 | if not is_data_mode: 279 | raise Exception(''' 280 | Please setup dataMode as data instead of model. 281 | Try run: `!python conf "dataMode=data"` or 282 | add comment like: `#%dataMode=data` if you are in notebook. 
283 | ''') 284 | 285 | import ray 286 | from pyjava.rayfix import RayWrapper 287 | rayw = RayWrapper() 288 | 289 | if not self.is_in_mlsql: 290 | if func_for_rows is not None: 291 | func = ray.remote(func_for_rows) 292 | return ray.get(func.remote(self.mock_data)) 293 | else: 294 | func = ray.remote(func_for_row) 295 | 296 | def iter_all(rows): 297 | return [ray.get(func.remote(row)) for row in rows] 298 | 299 | iter_all_func = ray.remote(iter_all) 300 | return ray.get(iter_all_func.remote(self.mock_data)) 301 | 302 | buffer = [] 303 | for server_info in self.build_servers_in_ray(): 304 | server = rayw.get_actor(server_info.server_id) 305 | rci = ray.get(server.connect_info.remote()) 306 | buffer.append(rci) 307 | server.serve.remote(func_for_row, func_for_rows) 308 | items = [vars(server) for server in buffer] 309 | self.python_context.build_result(items, 1024) 310 | return buffer 311 | 312 | def foreach(self, func_for_row): 313 | return self.setup(func_for_row) 314 | 315 | def map_iter(self, func_for_rows): 316 | return self.setup(None, func_for_rows) 317 | 318 | def collect(self): 319 | for shard in self.data_servers(): 320 | for row in RayContext.fetch_once_as_rows(shard): 321 | yield row 322 | 323 | def fetch_as_dir(self, target_dir, servers=None): 324 | if not servers: 325 | servers = self.data_servers() 326 | if len(servers) > 1: 327 | raise Exception("Please make sure you have only one partition on Java/Spark Side") 328 | 329 | items = RayContext.collect_from(servers) 330 | streaming_tar.save_rows_as_file(items, target_dir) 331 | 332 | def build_result(self, items, block_size=1024): 333 | self.python_context.build_result(items, block_size) 334 | 335 | def build_result_from_dir(self, target_path): 336 | self.python_context.build_result_from_dir(target_path) 337 | 338 | @staticmethod 339 | def parse_servers(host_ports): 340 | hosts = host_ports.split(",") 341 | hosts = [item.split(":") for item in hosts] 342 | return [DataServer(item[0], int(item[1]), "") for item in hosts] 343 | 344 | @staticmethod 345 | def fetch_as_repeatable_file(context_id, data_servers, file_ref, batch_size): 346 | import pyarrow as pa 347 | 348 | def inner_fetch(): 349 | for data_server in data_servers: 350 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 351 | out_ser = ArrowStreamSerializer() 352 | sock.connect((data_server.host, data_server.port)) 353 | buffer_size = int(os.environ.get("BUFFER_SIZE", 65536)) 354 | infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size) 355 | result = out_ser.load_stream(infile) 356 | for batch in result: 357 | yield batch 358 | 359 | def gen_by_batch(): 360 | import numpy as np 361 | import math 362 | if 'data_mmap_file_ref' not in file_ref: 363 | file_ref['data_mmap_file_ref'] = pa.memory_map(context_id + "/__input__.dat") 364 | reader = pa.ipc.open_file(file_ref['data_mmap_file_ref']) 365 | num_record_batches = reader.num_record_batches 366 | for i in range(num_record_batches): 367 | df = reader.get_batch(i).to_pandas() 368 | for small_batch in np.array_split(df, math.floor(df.shape[0] / batch_size)): 369 | yield small_batch 370 | 371 | if 'data_mmap_file_ref' in file_ref: 372 | return gen_by_batch() 373 | else: 374 | writer = None 375 | for batch in inner_fetch(): 376 | if writer is None: 377 | writer = pa.RecordBatchFileWriter(context_id + "/__input__.dat", batch.schema) 378 | writer.write_batch(batch) 379 | writer.close() 380 | return gen_by_batch() 381 | 382 | def collect_as_file(self, batch_size): 383 | data_servers = self.data_servers() 384 | 
python_context = self.python_context 385 | return RayContext.fetch_as_repeatable_file(python_context.context_id, data_servers, 386 | python_context.data_mmap_file_ref, 387 | batch_size) 388 | 389 | @staticmethod 390 | def collect_from(servers): 391 | for shard in servers: 392 | for row in RayContext.fetch_once_as_rows(shard): 393 | yield row 394 | 395 | def to_pandas(self): 396 | items = [row for row in self.collect()] 397 | return pd.DataFrame(data=items) 398 | 399 | @staticmethod 400 | def fetch_once_as_rows(data_server): 401 | for df in RayContext.fetch_data_from_single_data_server(data_server): 402 | for row in df.to_dict('records'): 403 | yield row 404 | 405 | @staticmethod 406 | def fetch_data_from_single_data_server(data_server): 407 | out_ser = ArrowStreamSerializer() 408 | import pyarrow as pa 409 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 410 | sock.connect((data_server.host, data_server.port)) 411 | buffer_size = int(os.environ.get("BUFFER_SIZE", 65536)) 412 | infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size) 413 | result = out_ser.load_stream(infile) 414 | for items in result: 415 | yield pa.Table.from_batches([items]).to_pandas() 416 | -------------------------------------------------------------------------------- /python/pyjava/api/serve.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import socket 4 | import traceback 5 | 6 | import ray 7 | 8 | import pyjava.utils as utils 9 | from pyjava.api.mlsql import RayContext 10 | from pyjava.rayfix import RayWrapper 11 | from pyjava.serializers import \ 12 | write_with_length, \ 13 | write_int, read_int, \ 14 | SpecialLengths, ArrowStreamPandasSerializer 15 | 16 | os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1" 17 | 18 | 19 | class SocketNotBindException(Exception): 20 | def __init__(self, message): 21 | Exception.__init__(self) 22 | self.message = message 23 | 24 | 25 | class DataServerWithId(object): 26 | def __init__(self, host, port, server_id): 27 | self.host = host 28 | self.port = port 29 | self.server_id = server_id 30 | 31 | 32 | class OnceServer(object): 33 | def __init__(self, host, port, timezone): 34 | self.host = host 35 | self.port = port 36 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 37 | self.socket.settimeout(5 * 60) 38 | self.out_ser = ArrowStreamPandasSerializer(timezone, False, None) 39 | self.is_bind = False 40 | self.is_dev = utils.is_dev() 41 | 42 | def bind(self): 43 | try: 44 | self.socket.bind((self.host, self.port)) 45 | self.is_bind = True 46 | self.socket.listen(1) 47 | except Exception: 48 | print(traceback.format_exc()) 49 | 50 | return self.socket.getsockname() 51 | 52 | def close(self): 53 | self.socket.close() 54 | 55 | def serve(self, data): 56 | from pyjava.api.mlsql import PythonContext 57 | if not self.is_bind: 58 | raise SocketNotBindException( 59 | "Please invoke server.bind() before invoke server.serve") 60 | conn, addr = self.socket.accept() 61 | sockfile = conn.makefile("rwb", int( 62 | os.environ.get("BUFFER_SIZE", 65536))) 63 | infile = sockfile # os.fdopen(os.dup(conn.fileno()), "rb", 65536) 64 | out = sockfile # os.fdopen(os.dup(conn.fileno()), "wb", 65536) 65 | try: 66 | write_int(SpecialLengths.START_ARROW_STREAM, out) 67 | out_data = ([df[name] for name in df] for df in 68 | PythonContext.build_chunk_result(data, 1024)) 69 | self.out_ser.dump_stream(out_data, out) 70 | 71 | write_int(SpecialLengths.END_OF_DATA_SECTION, out) 72 | 
write_int(SpecialLengths.END_OF_STREAM, out) 73 | out.flush() 74 | if self.is_dev: 75 | print("all data in ray task have been consumed.") 76 | read_int(infile) 77 | except Exception: 78 | try: 79 | write_int(SpecialLengths.ARROW_STREAM_CRASH, out) 80 | ex = traceback.format_exc() 81 | print(ex) 82 | write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, out) 83 | write_with_length(ex.encode("utf-8"), out) 84 | out.flush() 85 | read_int(infile) 86 | except IOError: 87 | # JVM close the socket 88 | pass 89 | except Exception: 90 | # Write the error to stderr if it happened while serializing 91 | print("Py worker failed with exception:") 92 | print(traceback.format_exc()) 93 | pass 94 | 95 | conn.close() 96 | 97 | 98 | @ray.remote 99 | class RayDataServer(object): 100 | 101 | def __init__(self, server_id, java_server, port=0, timezone="Asia/Harbin"): 102 | self.server = OnceServer( 103 | RayWrapper().get_address(), port, java_server.timezone) 104 | try: 105 | (rel_host, rel_port) = self.server.bind() 106 | except Exception as e: 107 | print(traceback.format_exc()) 108 | raise e 109 | 110 | self.host = rel_host 111 | self.port = rel_port 112 | self.timezone = timezone 113 | self.server_id = server_id 114 | self.java_server = java_server 115 | self.is_dev = utils.is_dev() 116 | 117 | def serve(self, func_for_row=None, func_for_rows=None): 118 | try: 119 | if func_for_row is not None: 120 | data = (func_for_row(item) 121 | for item in RayContext.fetch_once_as_rows(self.java_server)) 122 | elif func_for_rows is not None: 123 | data = func_for_rows( 124 | RayContext.fetch_once_as_rows(self.java_server)) 125 | self.server.serve(data) 126 | except Exception as e: 127 | logging.error(f"Fail to processing data in Ray Data Server {self.host}:{self.port}") 128 | raise e 129 | finally: 130 | self.close() 131 | 132 | def close(self): 133 | try: 134 | self.server.close() 135 | ray.actor.exit_actor() 136 | except Exception: 137 | print(traceback.format_exc()) 138 | 139 | def connect_info(self): 140 | return DataServerWithId(self.host, self.port, self.server_id) 141 | -------------------------------------------------------------------------------- /python/pyjava/cache/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/pyjava/83112f57ce8e3a87b6a5a59e6b33d2f9004caf60/python/pyjava/cache/__init__.py -------------------------------------------------------------------------------- /python/pyjava/cache/code_cache.py: -------------------------------------------------------------------------------- 1 | class CodeCache(object): 2 | cache = {} 3 | cache_max_size = 10 * 10000 4 | 5 | @staticmethod 6 | def get(code): 7 | if code in CodeCache.cache: 8 | return CodeCache.cache[code] 9 | else: 10 | return CodeCache.gen_cache(code) 11 | 12 | @staticmethod 13 | def gen_cache(code): 14 | if len(CodeCache.cache) > CodeCache.cache_max_size: 15 | CodeCache.cache.clear() 16 | CodeCache.cache[code] = compile(code, '', 'exec') 17 | return CodeCache.cache[code] 18 | -------------------------------------------------------------------------------- /python/pyjava/daemon.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 
5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import numbers 19 | import os 20 | import signal 21 | import select 22 | import socket 23 | import sys 24 | import traceback 25 | import time 26 | import gc 27 | from errno import EINTR, EAGAIN 28 | from socket import AF_INET, SOCK_STREAM, SOMAXCONN 29 | from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT 30 | 31 | from pyjava.worker import main as worker_main 32 | from pyjava.serializers import read_int, write_int, write_with_length, UTF8Deserializer 33 | 34 | 35 | def compute_real_exit_code(exit_code): 36 | # SystemExit's code can be integer or string, but os._exit only accepts integers 37 | if isinstance(exit_code, numbers.Integral): 38 | return exit_code 39 | else: 40 | return 1 41 | 42 | 43 | def worker(sock): 44 | """ 45 | Called by a worker process after the fork(). 46 | """ 47 | signal.signal(SIGHUP, SIG_DFL) 48 | signal.signal(SIGCHLD, SIG_DFL) 49 | signal.signal(SIGTERM, SIG_DFL) 50 | # restore the handler for SIGINT, 51 | # it's useful for debugging (show the stacktrace before exit) 52 | signal.signal(SIGINT, signal.default_int_handler) 53 | 54 | # Read the socket using fdopen instead of socket.makefile() because the latter 55 | # seems to be very slow; note that we need to dup() the file descriptor because 56 | # otherwise writes also cause a seek that makes us miss data on the read side. 
57 | buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536)) 58 | infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size) 59 | outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size) 60 | 61 | exit_code = 0 62 | try: 63 | worker_main(infile, outfile) 64 | except SystemExit as exc: 65 | exit_code = compute_real_exit_code(exc.code) 66 | finally: 67 | try: 68 | outfile.flush() 69 | except Exception: 70 | pass 71 | return exit_code 72 | 73 | 74 | def manager(): 75 | # Create a new process group to corral our children 76 | os.setpgid(0, 0) 77 | 78 | # Create a listening socket on the AF_INET loopback interface 79 | listen_sock = socket.socket(AF_INET, SOCK_STREAM) 80 | listen_sock.bind(('127.0.0.1', 0)) 81 | listen_sock.listen(max(1024, SOMAXCONN)) 82 | listen_host, listen_port = listen_sock.getsockname() 83 | 84 | # re-open stdin/stdout in 'wb' mode 85 | stdin_bin = os.fdopen(sys.stdin.fileno(), 'rb', 4) 86 | stdout_bin = os.fdopen(sys.stdout.fileno(), 'wb', 4) 87 | write_int(listen_port, stdout_bin) 88 | stdout_bin.flush() 89 | 90 | def shutdown(code): 91 | signal.signal(SIGTERM, SIG_DFL) 92 | # Send SIGHUP to notify workers of shutdown 93 | os.kill(0, SIGHUP) 94 | sys.exit(code) 95 | 96 | def handle_sigterm(*args): 97 | shutdown(1) 98 | signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM 99 | signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP 100 | signal.signal(SIGCHLD, SIG_IGN) 101 | 102 | reuse = os.environ.get("PY_WORKER_REUSE") 103 | 104 | # Initialization complete 105 | try: 106 | while True: 107 | try: 108 | ready_fds = select.select([0, listen_sock], [], [], 1)[0] 109 | except select.error as ex: 110 | if ex[0] == EINTR: 111 | continue 112 | else: 113 | raise 114 | 115 | if 0 in ready_fds: 116 | try: 117 | worker_pid = read_int(stdin_bin) 118 | except EOFError: 119 | # Spark told us to exit by closing stdin 120 | shutdown(0) 121 | try: 122 | os.kill(worker_pid, signal.SIGKILL) 123 | except OSError: 124 | pass # process already died 125 | 126 | if listen_sock in ready_fds: 127 | try: 128 | sock, _ = listen_sock.accept() 129 | except OSError as e: 130 | if e.errno == EINTR: 131 | continue 132 | raise 133 | 134 | # Launch a worker process 135 | try: 136 | pid = os.fork() 137 | except OSError as e: 138 | if e.errno in (EAGAIN, EINTR): 139 | time.sleep(1) 140 | pid = os.fork() # error here will shutdown daemon 141 | else: 142 | outfile = sock.makefile(mode='wb') 143 | write_int(e.errno, outfile) # Signal that the fork failed 144 | outfile.flush() 145 | outfile.close() 146 | sock.close() 147 | continue 148 | 149 | if pid == 0: 150 | # in child process 151 | listen_sock.close() 152 | 153 | # It should close the standard input in the child process so that 154 | # Python native function executions stay intact. 155 | # 156 | # Note that if we just close the standard input (file descriptor 0), 157 | # the lowest file descriptor (file descriptor 0) will be allocated, 158 | # later when other file descriptors should happen to open. 159 | # 160 | # Therefore, here we redirects it to '/dev/null' by duplicating 161 | # another file descriptor for '/dev/null' to the standard input (0). 162 | # See SPARK-26175. 
163 | devnull = open(os.devnull, 'r') 164 | os.dup2(devnull.fileno(), 0) 165 | devnull.close() 166 | 167 | try: 168 | # Acknowledge that the fork was successful 169 | outfile = sock.makefile(mode="wb") 170 | write_int(os.getpid(), outfile) 171 | outfile.flush() 172 | outfile.close() 173 | while True: 174 | code = worker(sock) 175 | if not reuse or code: 176 | # wait for closing 177 | try: 178 | while sock.recv(1024): 179 | pass 180 | except Exception: 181 | pass 182 | break 183 | gc.collect() 184 | except: 185 | traceback.print_exc() 186 | os._exit(1) 187 | else: 188 | os._exit(0) 189 | else: 190 | sock.close() 191 | 192 | finally: 193 | shutdown(1) 194 | 195 | 196 | if __name__ == '__main__': 197 | manager() 198 | -------------------------------------------------------------------------------- /python/pyjava/datatype/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/pyjava/83112f57ce8e3a87b6a5a59e6b33d2f9004caf60/python/pyjava/datatype/__init__.py -------------------------------------------------------------------------------- /python/pyjava/example/OnceServerExample.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | 5 | os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1" 6 | from pyjava.api.serve import OnceServer 7 | 8 | ddata = pd.DataFrame(data=[[1, 2, 3, 4], [2, 3, 4, 5]]) 9 | 10 | server = OnceServer("127.0.0.1", 11111, "Asia/Harbin") 11 | server.bind() 12 | server.serve([{'id': 9, 'label': 1}]) 13 | -------------------------------------------------------------------------------- /python/pyjava/example/RayServerExample.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import sys 5 | sys.path.append("../../") 6 | 7 | from pyjava.api.mlsql import DataServer 8 | from pyjava.api.serve import RayDataServer 9 | from pyjava.rayfix import RayWrapper 10 | 11 | # import ray 12 | 13 | os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1" 14 | ray = RayWrapper() 15 | ray.init(address='auto', _redis_password='5241590000000000') 16 | 17 | ddata = pd.DataFrame(data=[[1, 2, 3, 4], [2, 3, 4, 5]]) 18 | 19 | server_id = "wow1" 20 | 21 | java_server = DataServer("127.0.0.1", 11111, "Asia/Harbin") 22 | rds = RayDataServer.options(name=server_id, max_concurrency=2).remote(server_id, java_server, 0, 23 | "Asia/Harbin") 24 | print(ray.get(rds.connect_info.remote())) 25 | 26 | 27 | def echo(row): 28 | return row 29 | 30 | 31 | rds.serve.remote(echo) 32 | 33 | rds = ray.get_actor(server_id) 34 | print(vars(ray.get(rds.connect_info.remote()))) 35 | -------------------------------------------------------------------------------- /python/pyjava/example/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/pyjava/83112f57ce8e3a87b6a5a59e6b33d2f9004caf60/python/pyjava/example/__init__.py -------------------------------------------------------------------------------- /python/pyjava/example/test.py: -------------------------------------------------------------------------------- 1 | import ray 2 | 3 | ray.init() 4 | 5 | 6 | @ray.remote 7 | def slow_function(): 8 | return 1 9 | 10 | 11 | print(ray.get(slow_function.remote())) 12 | -------------------------------------------------------------------------------- /python/pyjava/example/test2.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | 4 | from pyjava.serializers import ArrowStreamPandasSerializer, read_int, write_int 5 | 6 | out_ser = ArrowStreamPandasSerializer("Asia/Harbin", False, None) 7 | HOST = "127.0.0.1" 8 | PORT = 11111 9 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: 10 | sock.connect((HOST, PORT)) 11 | buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536)) 12 | infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size) 13 | outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size) 14 | # arrow start 15 | print(read_int(infile)) 16 | kk = out_ser.load_stream(infile) 17 | 18 | for item in kk: 19 | print(item) 20 | # end data 21 | print(read_int(infile)) 22 | # end stream 23 | print(read_int(infile)) 24 | write_int(-4,outfile) 25 | -------------------------------------------------------------------------------- /python/pyjava/example/test3.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | request_url = "http://127.0.0.1:9007/run" 4 | 5 | 6 | def registerPyAction(enableAdmin, params={}): 7 | datas = {"code": """from pyjava.api.mlsql import PythonContext 8 | for row in context.fetch_once(): 9 | print(row) 10 | context.build_result([{"content": "{}"}], 1) 11 | """, "action": "registerPyAction", "codeName": "echo"} 12 | if enableAdmin: 13 | datas = {**datas, **{"admin_token": "admin"}} 14 | r = requests.post(request_url, data={**datas, **params}) 15 | print(r.text) 16 | print(r.status_code) 17 | 18 | 19 | # echo 20 | def pyAction(codeName, enableAdmin=True, params={}): 21 | datas = {"codeName": codeName, "action": "pyAction"} 22 | if enableAdmin: 23 | datas = {**datas, **{"admin_token": "admin"}} 24 | r = requests.post(request_url, data={**datas, **params}) 25 | print(r.text) 26 | print(r.status_code) 27 | 28 | 29 | def loadDB(db): 30 | datas = {"name": db, "action": "loadDB"} 31 | r = requests.post(request_url, data=datas) 32 | print(r.text) 33 | print(r.status_code) 34 | 35 | 36 | def http(action, params): 37 | datas = {**{"action": action}, **params} 38 | r = requests.post(request_url, data=datas) 39 | return r 40 | 41 | 42 | def controlReg(): 43 | r = http("controlReg", {"enable": "true"}) 44 | print(r.text) 45 | print(r.status_code) 46 | 47 | 48 | def printRespose(r): 49 | print(r.text) 50 | print(r.status_code) 51 | 52 | 53 | def test_db_config(db, user, password): 54 | return """ 55 | {}: 56 | host: 127.0.0.1 57 | port: 3306 58 | database: {} 59 | username: {} 60 | password: {} 61 | initialSize: 8 62 | disable: true 63 | removeAbandoned: true 64 | testWhileIdle: true 65 | removeAbandonedTimeout: 30 66 | maxWait: 100 67 | filters: stat,log4j 68 | """.format(db, db, user, password) 69 | 70 | 71 | def addDB(instanceName, db, user, password): 72 | # user-system 73 | datas = {"dbName": db, "instanceName": instanceName, 74 | "dbConfig": test_db_config(db, user, password), "admin_token": "admin"} 75 | r = http(addDB.__name__, datas) 76 | printRespose(r) 77 | 78 | 79 | def userReg(name, password): 80 | r = http(userReg.__name__, {"userName": name, "password": password}) 81 | printRespose(r) 82 | 83 | 84 | def users(): 85 | r = http(users.__name__, {}) 86 | printRespose(r) 87 | 88 | 89 | def uploadPlugin(file_path, data): 90 | values = {**data, **{"action": "uploadPlugin"}} 91 | files = {file_path.split("/")[-1]: open(file_path, 'rb')} 92 | r = requests.post(request_url, files=files, data=values) 93 | 
printRespose(r) 94 | 95 | 96 | def addProxy(name, value): 97 | r = http(addProxy.__name__, {"name": name, "value": value}) 98 | printRespose(r) 99 | 100 | 101 | import json 102 | 103 | 104 | def userLogin(userName, password): 105 | r = http(userLogin.__name__, {"userName": userName, "password": password}) 106 | return json.loads(r.text)[0]['token'] 107 | 108 | 109 | def enablePythonAdmin(userName, token): 110 | r = http("pyAuthAction", 111 | {"userName": userName, "access-token": token, 112 | "resourceType": "admin", 113 | "resourceName": "admin", "admin_token": "admin", "authUser": "jack"}) 114 | return r 115 | 116 | 117 | def enablePythonRegister(userName, token): 118 | r = http("pyAuthAction", 119 | {"userName": userName, "access-token": token, 120 | "resourceType": "action", 121 | "resourceName": "registerPyAction", "authUser": "william"}) 122 | return r 123 | 124 | 125 | def enablePythonExecute(userName, token, codeName): 126 | r = http("pyAuthAction", 127 | {"userName": userName, "access-token": token, 128 | "resourceType": "custom", 129 | "resourceName": codeName, "authUser": "william"}) 130 | return r 131 | 132 | 133 | # userReg() 134 | # users() 135 | # addDB("ar_plugin_repo") 136 | # pyAction("echo") 137 | # addProxy("user-system", "http://127.0.0.1:9007/run") 138 | # uploadPlugin("/Users/allwefantasy/CSDNWorkSpace/user-system/release/user-system-bin_2.11-1.0.0.jar", 139 | # {"name": "jack", "password": "123", "pluginName": "user-system"}) 140 | 141 | # addDB("user-system", "mlsql_python_predictor", "root", "mlsql") 142 | # userReg("william", "mm") 143 | # registerPyAction(False) 144 | # pyAction("echo", False) 145 | token = userLogin("jack", "123") 146 | # r = enablePythonAdmin("jack", token) 147 | # r = enablePythonExecute("jack", token, "echo") 148 | # printRespose(r) 149 | 150 | token = userLogin("william", "mm") 151 | # registerPyAction(False, {"userName": "william", "access-token": token}) 152 | # enablePythonRegister("jack", token) 153 | pyAction("echo", False, {"userName": "william", "access-token": token}) 154 | -------------------------------------------------------------------------------- /python/pyjava/rayfix.py: -------------------------------------------------------------------------------- 1 | from distutils.version import StrictVersion 2 | 3 | import ray 4 | import logging 5 | 6 | 7 | def last(func): 8 | func.__module__ = "pyjava_auto_generate__exec__" 9 | return func 10 | 11 | 12 | class RayWrapper: 13 | 14 | def __init__(self): 15 | self.ray_version = "2.0.0" if "dev" in ray.__version__ else StrictVersion(ray.__version__) 16 | self.ray_instance = ray 17 | 18 | def __getattr__(self, attr): 19 | return getattr(self.ray_instance, attr) 20 | 21 | def get_address(self): 22 | if self.ray_version >= StrictVersion('1.0.0'): 23 | return ray.get_runtime_context().worker.node_ip_address 24 | else: 25 | return ray.services.get_node_ip_address() 26 | 27 | def init(self, address, **kwargs): 28 | logging.debug(f"address {address} {kwargs}") 29 | if self.ray_version >= StrictVersion('1.4.0'): 30 | if "namespace" in kwargs.keys(): 31 | ray.util.connect(conn_str=address, **kwargs) 32 | else: 33 | ray.util.connect(conn_str=address, namespace="default", **kwargs) 34 | elif self.ray_version >= StrictVersion('1.0.0'): 35 | logging.debug(f"try to connect to ray {address}") 36 | ray.util.connect(conn_str=address, **kwargs) 37 | elif self.ray_version == StrictVersion('0.8.7'): 38 | ray.init(address=address, **kwargs) 39 | else: 40 | ray.init(redis_address=address, **kwargs) 41 | 42 | 
def shutdown(self): 43 | if self.ray_version >= StrictVersion('1.0.0'): 44 | try: 45 | ray.util.disconnect() 46 | except Exception as e: 47 | pass 48 | else: 49 | ray.shutdown(exiting_interpreter=False) 50 | 51 | def options(self, actor_class, **kwargs): 52 | if 'detached' in kwargs and self.ray_version >= StrictVersion('1.0.0'): 53 | del kwargs['detached'] 54 | kwargs['lifetime'] = 'detached' 55 | logging.debug(f"actor build options: {kwargs}") 56 | return actor_class.options(**kwargs) 57 | 58 | def get_actor(self, name): 59 | if self.ray_version >= StrictVersion('1.0.0'): 60 | return ray.get_actor(name) 61 | elif self.ray_version == StrictVersion('0.8.7'): 62 | return ray.get_actor(name) 63 | else: 64 | return ray.experimental.get_actor(name) 65 | -------------------------------------------------------------------------------- /python/pyjava/storage/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allwefantasy/pyjava/83112f57ce8e3a87b6a5a59e6b33d2f9004caf60/python/pyjava/storage/__init__.py -------------------------------------------------------------------------------- /python/pyjava/storage/streaming_tar.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import tarfile 4 | import tempfile 5 | 6 | import uuid 7 | 8 | BLOCK_SIZE = 1024 * 64 9 | 10 | 11 | class FileStream(object): 12 | def __init__(self): 13 | self.buffer = io.BytesIO() 14 | self.offset = 0 15 | 16 | def write(self, s): 17 | self.buffer.write(s) 18 | self.offset += len(s) 19 | 20 | def tell(self): 21 | return self.offset 22 | 23 | def close(self): 24 | self.buffer.close() 25 | 26 | def pop(self): 27 | s = self.buffer.getvalue() 28 | self.buffer.close() 29 | self.buffer = io.BytesIO() 30 | return s 31 | 32 | 33 | def build_rows_from_file(target_dir): 34 | file = FileStream() 35 | start = 0 36 | for i in stream_build_tar(target_dir, file): 37 | v = file.pop() 38 | offset = len(v) 39 | if len(v) > 0: 40 | yield {"start": start, "offset": offset, "value": v} 41 | start = start + offset 42 | 43 | 44 | def save_rows_as_file(data, target_dir): 45 | if not os.path.exists(target_dir): 46 | os.makedirs(target_dir) 47 | tf_path = os.path.join(tempfile.gettempdir(), str(uuid.uuid4())) 48 | with open(tf_path, "wb") as tf: 49 | for block_row in data: 50 | tf.write(block_row["value"]) 51 | with open(tf_path, "rb") as tf: 52 | tt = tarfile.open(tf.name, mode="r:") 53 | tt.extractall(target_dir) 54 | tt.close() 55 | os.remove(tf_path) 56 | 57 | 58 | def stream_build_tar(file_dir, streaming_fp): 59 | tar = tarfile.TarFile.open(fileobj=streaming_fp, mode="w:") 60 | for root, dirs, files in os.walk(file_dir): 61 | for in_filename in files: 62 | stat = os.stat(os.path.join(root, in_filename)) 63 | tar_info = tarfile.TarInfo(os.path.join(root.lstrip(file_dir), in_filename)) 64 | # tar_info.path = root.lstrip(file_dir) 65 | tar_info.mtime = stat.st_mtime 66 | tar_info.size = stat.st_size 67 | 68 | tar.addfile(tar_info) 69 | 70 | yield 71 | 72 | with open(os.path.join(root, in_filename), 'rb') as in_fp: 73 | # total_size = 0 74 | while True: 75 | s = in_fp.read(BLOCK_SIZE) 76 | 77 | if len(s) > 0: 78 | tar.fileobj.write(s) 79 | yield 80 | 81 | if len(s) < BLOCK_SIZE: 82 | blocks, remainder = divmod(tar_info.size, tarfile.BLOCKSIZE) 83 | 84 | if remainder > 0: 85 | tar.fileobj.write(tarfile.NUL * 86 | (tarfile.BLOCKSIZE - remainder)) 87 | 88 | yield 89 | 90 | blocks += 1 91 | 92 | tar.offset += blocks * 
tarfile.BLOCKSIZE 93 | break 94 | tar.close() 95 | 96 | 97 | def main(): 98 | rows = build_rows_from_file("/Users/allwefantasy/data/mlsql/homes/demo/tmp/minist_model") 99 | save_rows_as_file(rows, "/Users/allwefantasy/data/mlsql/homes/demo/tmp/minist_model3") 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /python/pyjava/tests/test_context.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from pyjava.api.mlsql import RayContext 4 | 5 | 6 | class RayContextTestCase(unittest.TestCase): 7 | def test_raycontext_collect_as_file(self): 8 | ray_context = RayContext.connect(globals(), None) 9 | dfs = ray_context.collect_as_file(32) 10 | 11 | for i in range(2): 12 | print("======={}======".format(str(i))) 13 | for df in dfs: 14 | print(df) 15 | 16 | ray_context.context.build_result([{"content": "jackma"}]) 17 | 18 | 19 | if __name__ == '__main__': 20 | unittest.main() 21 | -------------------------------------------------------------------------------- /python/pyjava/udf/__init__.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import Any, NoReturn, Callable, Dict, List 3 | import ray 4 | import time 5 | from ray.util.client.common import ClientActorHandle, ClientObjectRef 6 | 7 | from pyjava.api.mlsql import RayContext 8 | from pyjava.storage import streaming_tar 9 | 10 | 11 | @ray.remote 12 | class UDFMaster(object): 13 | def __init__(self, num: int, conf: Dict[str, str], 14 | init_func: Callable[[List[ClientObjectRef], Dict[str, str]], Any], 15 | apply_func: Callable[[Any, Any], Any]): 16 | model_servers = RayContext.parse_servers(conf["modelServers"]) 17 | items = RayContext.collect_from(model_servers) 18 | model_refs = [ray.put(item) for item in items] 19 | self.actors = dict([(index, UDFWorker.remote(model_refs, conf, init_func, apply_func)) for index in range(num)]) 20 | self._idle_actors = [index for index in range(num)] 21 | 22 | def get(self) -> List[Any]: 23 | while len(self._idle_actors) == 0: 24 | time.sleep(0.001) 25 | index = self._idle_actors.pop() 26 | return [index, self.actors[index]] 27 | 28 | def give_back(self, v) -> NoReturn: 29 | self._idle_actors.append(v) 30 | 31 | def shutdown(self) -> NoReturn: 32 | [ray.kill(self.actors[index]) for index in self._idle_actors] 33 | 34 | 35 | @ray.remote 36 | class UDFWorker(object): 37 | def __init__(self, 38 | model_refs: List[ClientObjectRef], 39 | conf: Dict[str, str], 40 | init_func: Callable[[List[ClientObjectRef], Dict[str, str]], Any], 41 | apply_func: Callable[[Any, Any], Any]): 42 | self.model = init_func(model_refs, conf) 43 | self.apply_func = apply_func 44 | 45 | def apply(self, v: Any) -> Any: 46 | return self.apply_func(self.model, v) 47 | 48 | def shutdown(self): 49 | ray.actor.exit_actor() 50 | 51 | 52 | class UDFBuilder(object): 53 | @staticmethod 54 | def build(ray_context: RayContext, 55 | init_func: Callable[[List[ClientObjectRef], Dict[str, str]], Any], 56 | apply_func: Callable[[Any, Any], Any]) -> NoReturn: 57 | conf = ray_context.conf() 58 | udf_name = conf["UDF_CLIENT"] 59 | max_concurrency = int(conf.get("maxConcurrency", "3")) 60 | 61 | try: 62 | temp_udf_master = ray.get_actor(udf_name) 63 | ray.kill(temp_udf_master) 64 | time.sleep(1) 65 | except Exception as inst: 66 | print(inst) 67 | pass 68 | 69 | UDFMaster.options(name=udf_name, lifetime="detached", 
max_concurrency=max_concurrency).remote( 70 | max_concurrency, conf, init_func, apply_func) 71 | ray_context.build_result([]) 72 | 73 | @staticmethod 74 | def apply(ray_context: RayContext): 75 | conf = ray_context.conf() 76 | udf_name = conf["UDF_CLIENT"] 77 | udf_master = ray.get_actor(udf_name) 78 | [index, worker] = ray.get(udf_master.get.remote()) 79 | input = [row["value"] for row in ray_context.python_context.fetch_once_as_rows()] 80 | try: 81 | res = ray.get(worker.apply.remote(input)) 82 | except Exception as inst: 83 | res = [] 84 | print(inst) 85 | udf_master.give_back.remote(index) 86 | ray_context.build_result([res]) 87 | 88 | 89 | class UDFBuildInFunc(object): 90 | @staticmethod 91 | def init_tf(model_refs: List[ClientObjectRef], conf: Dict[str, str]) -> Any: 92 | from tensorflow.keras import models 93 | model_path = "./tmp/model/{}".format(str(uuid.uuid4())) 94 | streaming_tar.save_rows_as_file((ray.get(ref) for ref in model_refs), model_path) 95 | return models.load_model(model_path) 96 | -------------------------------------------------------------------------------- /python/pyjava/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from pyjava.serializers import write_with_length, UTF8Deserializer, \ 5 | PickleSerializer 6 | 7 | if sys.version >= '3': 8 | basestring = str 9 | else: 10 | pass 11 | 12 | pickleSer = PickleSerializer() 13 | utf8_deserializer = UTF8Deserializer() 14 | 15 | 16 | def is_dev(): 17 | return 'MLSQL_DEV' in os.environ and int(os.environ["MLSQL_DEV"]) == 1 18 | 19 | 20 | def require_minimum_pandas_version(): 21 | """ Raise ImportError if minimum version of Pandas is not installed 22 | """ 23 | # TODO(HyukjinKwon): Relocate and deduplicate the version specification. 24 | minimum_pandas_version = "0.23.2" 25 | 26 | from distutils.version import LooseVersion 27 | try: 28 | import pandas 29 | have_pandas = True 30 | except ImportError: 31 | have_pandas = False 32 | if not have_pandas: 33 | raise ImportError("Pandas >= %s must be installed; however, " 34 | "it was not found." % minimum_pandas_version) 35 | if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version): 36 | raise ImportError("Pandas >= %s must be installed; however, " 37 | "your version was %s." % (minimum_pandas_version, pandas.__version__)) 38 | 39 | 40 | def require_minimum_pyarrow_version(): 41 | """ Raise ImportError if minimum version of pyarrow is not installed 42 | """ 43 | # TODO(HyukjinKwon): Relocate and deduplicate the version specification. 44 | minimum_pyarrow_version = "0.12.1" 45 | 46 | from distutils.version import LooseVersion 47 | try: 48 | import pyarrow 49 | have_arrow = True 50 | except ImportError: 51 | have_arrow = False 52 | if not have_arrow: 53 | raise ImportError("PyArrow >= %s must be installed; however, " 54 | "it was not found." % minimum_pyarrow_version) 55 | if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version): 56 | raise ImportError("PyArrow >= %s must be installed; however, " 57 | "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__)) 58 | 59 | 60 | def _exception_message(excp): 61 | """Return the message from an exception as either a str or unicode object. Supports both 62 | Python 2 and Python 3. 
63 | 64 | >>> msg = "Exception message" 65 | >>> excp = Exception(msg) 66 | >>> msg == _exception_message(excp) 67 | True 68 | 69 | >>> msg = u"unicöde" 70 | >>> excp = Exception(msg) 71 | >>> msg == _exception_message(excp) 72 | True 73 | """ 74 | if hasattr(excp, "message"): 75 | return excp.message 76 | return str(excp) 77 | 78 | 79 | def _do_server_auth(conn, auth_secret): 80 | """ 81 | Performs the authentication protocol defined by the SocketAuthHelper class on the given 82 | file-like object 'conn'. 83 | """ 84 | write_with_length(auth_secret.encode("utf-8"), conn) 85 | conn.flush() 86 | reply = UTF8Deserializer().loads(conn) 87 | if reply != "ok": 88 | conn.close() 89 | raise Exception("Unexpected reply from iterator server.") 90 | 91 | 92 | def local_connect_and_auth(port): 93 | """ 94 | Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection. 95 | Handles IPV4 & IPV6, does some error handling. 96 | :param port 97 | :param auth_secret 98 | :return: a tuple with (sockfile, sock) 99 | """ 100 | import socket 101 | sock = None 102 | errors = [] 103 | # Support for both IPv4 and IPv6. 104 | # On most of IPv6-ready systems, IPv6 will take precedence. 105 | for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM): 106 | af, socktype, proto, _, sa = res 107 | try: 108 | sock = socket.socket(af, socktype, proto) 109 | sock.settimeout(15) 110 | sock.connect(sa) 111 | sockfile = sock.makefile("rwb", int(os.environ.get("BUFFER_SIZE", 65536))) 112 | return (sockfile, sock) 113 | except socket.error as e: 114 | emsg = _exception_message(e) 115 | errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg)) 116 | sock.close() 117 | sock = None 118 | raise Exception("could not open socket: %s" % errors) 119 | -------------------------------------------------------------------------------- /python/pyjava/version.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | __version__ = "0.3.3" 20 | -------------------------------------------------------------------------------- /python/pyjava/worker.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. 
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | from __future__ import print_function 19 | 20 | from pyjava.api.mlsql import PythonContext 21 | from pyjava.cache.code_cache import CodeCache 22 | from pyjava.utils import * 23 | 24 | # 'resource' is a Unix specific module. 25 | has_resource_module = True 26 | try: 27 | import resource 28 | except ImportError: 29 | has_resource_module = False 30 | 31 | import traceback 32 | import logging 33 | 34 | from pyjava.serializers import \ 35 | write_with_length, \ 36 | write_int, \ 37 | read_int, read_bool, SpecialLengths, UTF8Deserializer, \ 38 | PickleSerializer, ArrowStreamPandasSerializer, ArrowStreamSerializer 39 | 40 | if sys.version >= '3': 41 | basestring = str 42 | else: 43 | pass 44 | 45 | # logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level='DEBUG') 46 | 47 | pickleSer = PickleSerializer() 48 | utf8_deserializer = UTF8Deserializer() 49 | 50 | globals_namespace = globals() 51 | 52 | 53 | def read_command(serializer, file): 54 | command = serializer.load_stream(file) 55 | return command 56 | 57 | 58 | def chain(f, g): 59 | """chain two functions together """ 60 | return lambda *a: g(f(*a)) 61 | 62 | 63 | def main(infile, outfile): 64 | try: 65 | try: 66 | import ray 67 | except ImportError: 68 | pass 69 | # set up memory limits 70 | memory_limit_mb = int(os.environ.get('PY_EXECUTOR_MEMORY', "-1")) 71 | if memory_limit_mb > 0 and has_resource_module: 72 | total_memory = resource.RLIMIT_AS 73 | try: 74 | (soft_limit, hard_limit) = resource.getrlimit(total_memory) 75 | msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit) 76 | logging.info(msg) 77 | 78 | # convert to bytes 79 | new_limit = memory_limit_mb * 1024 * 1024 80 | 81 | if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit: 82 | msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit) 83 | logging.info(msg) 84 | resource.setrlimit(total_memory, (new_limit, new_limit)) 85 | 86 | except (resource.error, OSError, ValueError) as e: 87 | # not all systems support resource limits, so warn instead of failing 88 | logging.warning("WARN: Failed to set memory limit: {0}\n".format(e)) 89 | split_index = read_int(infile) 90 | logging.info("split_index:%s" % split_index) 91 | if split_index == -1: # for unit tests 92 | sys.exit(-1) 93 | 94 | is_barrier = read_bool(infile) 95 | bound_port = read_int(infile) 96 | logging.info(f"is_barrier {is_barrier}, port {bound_port}") 97 | 98 | conf = {} 99 | for i in range(read_int(infile)): 100 | k = utf8_deserializer.loads(infile) 101 | v = utf8_deserializer.loads(infile) 102 | conf[k] = v 103 | logging.debug(f"conf {k}:{v}") 104 | 105 | command = utf8_deserializer.loads(infile) 106 | ser = ArrowStreamSerializer() 107 | 108 | timezone = conf["timezone"] if "timezone" in conf else None 109 | 110 | out_ser = ArrowStreamPandasSerializer(timezone, True, True) 111 | is_interactive = os.environ.get('PY_INTERACTIVE', "no") == "yes" 112 | import uuid 113 | context_id = str(uuid.uuid4()) 114 | 115 | if not os.path.exists(context_id): 116 | os.mkdir(context_id) 117 | 118 | def process(): 119 
| try: 120 | input_data = ser.load_stream(infile) 121 | code = CodeCache.get(command) 122 | if is_interactive: 123 | global data_manager 124 | global context 125 | data_manager = PythonContext(context_id, input_data, conf) 126 | context = data_manager 127 | global globals_namespace 128 | exec(code, globals_namespace, globals_namespace) 129 | else: 130 | data_manager = PythonContext(context_id, input_data, conf) 131 | n_local = {"data_manager": data_manager, "context": data_manager} 132 | exec(code, n_local, n_local) 133 | out_iter = data_manager.output() 134 | write_int(SpecialLengths.START_ARROW_STREAM, outfile) 135 | out_ser.dump_stream(out_iter, outfile) 136 | finally: 137 | 138 | try: 139 | import shutil 140 | shutil.rmtree(context_id) 141 | except: 142 | pass 143 | 144 | try: 145 | if hasattr(out_iter, 'close'): 146 | out_iter.close() 147 | except: 148 | pass 149 | 150 | try: 151 | del data_manager 152 | except: 153 | pass 154 | 155 | process() 156 | 157 | 158 | except Exception: 159 | try: 160 | write_int(SpecialLengths.ARROW_STREAM_CRASH, outfile) 161 | write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) 162 | write_with_length(traceback.format_exc().encode("utf-8"), outfile) 163 | except IOError: 164 | # JVM close the socket 165 | pass 166 | except Exception: 167 | # Write the error to stderr if it happened while serializing 168 | print("Py worker failed with exception:", file=sys.stderr) 169 | print(traceback.format_exc(), file=sys.stderr) 170 | sys.exit(-1) 171 | 172 | write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) 173 | flag = read_int(infile) 174 | if flag == SpecialLengths.END_OF_STREAM: 175 | write_int(SpecialLengths.END_OF_STREAM, outfile) 176 | else: 177 | # write a different value to tell JVM to not reuse this worker 178 | write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) 179 | sys.exit(-1) 180 | 181 | 182 | if __name__ == '__main__': 183 | # Read information about how to connect back to the JVM from the environment. 184 | java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) 185 | (sock_file, _) = local_connect_and_auth(java_port) 186 | main(sock_file, sock_file) 187 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | pyarrow==4.0.1 2 | ray>=1.3.0 3 | pandas>=1.0.5; python_version < '3.7' 4 | pandas>=1.2.0; python_version >= '3.7' 5 | requests 6 | matplotlib~=3.3.4 7 | uuid~=1.30 -------------------------------------------------------------------------------- /python/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership. 7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. 
You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | from __future__ import print_function 20 | 21 | import sys 22 | 23 | from setuptools import setup 24 | 25 | if sys.version_info < (3, 6): 26 | print("Python versions prior to 3.6 are not supported for pip installed pyjava.", 27 | file=sys.stderr) 28 | sys.exit(-1) 29 | 30 | try: 31 | exec(open('pyjava/version.py').read()) 32 | except IOError: 33 | print("Failed to load PyJava version file for packaging. You must be in PyJava's python dir.", 34 | file=sys.stderr) 35 | sys.exit(-1) 36 | VERSION = __version__ # noqa 37 | # A temporary path so we can access above the Python project root and fetch scripts and jars we need 38 | TEMP_PATH = "deps" 39 | 40 | # Provide guidance about how to use setup.py 41 | incorrect_invocation_message = """ 42 | If you are installing pyjava from source, you must first build 43 | run sdist. 44 | Building the source dist is done in the Python directory: 45 | cd python 46 | python setup.py sdist 47 | pip install dist/*.tar.gz""" 48 | 49 | try: 50 | setup( 51 | name='pyjava', 52 | version=VERSION, 53 | description='PyJava Python API', 54 | long_description="", 55 | author='allwefantasy', 56 | author_email='allwefantasy@gmail.com', 57 | url='https://github.com/allwefantasy/pyjava', 58 | packages=['pyjava', 59 | 'pyjava.api', 60 | 'pyjava.udf', 61 | 'pyjava.datatype', 62 | 'pyjava.storage', 63 | 'pyjava.cache'], 64 | include_package_data=True, 65 | package_dir={ 66 | 'pyjava.sbin': 'deps/sbin' 67 | }, 68 | package_data={ 69 | 'pyjava.sbin': []}, 70 | license='http://www.apache.org/licenses/LICENSE-2.0', 71 | setup_requires=['pypandoc'], 72 | extras_require={ 73 | }, 74 | classifiers=[ 75 | 'Development Status :: 5 - Production/Stable', 76 | 'License :: OSI Approved :: Apache Software License', 77 | 'Programming Language :: Python :: 3.6', 78 | 'Programming Language :: Python :: 3.7', 79 | 'Programming Language :: Python :: Implementation :: CPython', 80 | 'Programming Language :: Python :: Implementation :: PyPy'] 81 | ) 82 | finally: 83 | print("--------") 84 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/WowRowEncoder.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * DO NOT EDIT THIS FILE DIRECTLY, ANY CHANGE MAY BE OVERWRITE 3 | */ 4 | package org.apache.spark 5 | import org.apache.spark.sql.Row 6 | import org.apache.spark.sql.catalyst.InternalRow 7 | import org.apache.spark.sql.catalyst.encoders.RowEncoder 8 | import org.apache.spark.sql.types.StructType 9 | 10 | object WowRowEncoder { 11 | def toRow(schema: StructType) = { 12 | RowEncoder.apply(schema).resolveAndBind().createDeserializer() 13 | 14 | } 15 | 16 | def fromRow(schema: StructType) = { 17 | RowEncoder.apply(schema).resolveAndBind().createSerializer() 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/org/apache/spark/sql/SparkUtils.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.sql 2 | 3 | import 
org.apache.arrow.vector.types.pojo.ArrowType 4 | import org.apache.spark.TaskContext 5 | import org.apache.spark.rdd.RDD 6 | import org.apache.spark.sql.catalyst.InternalRow 7 | import org.apache.spark.sql.execution.LogicalRDD 8 | import org.apache.spark.sql.types.{DataType, DecimalType, StructType} 9 | 10 | /** 11 | * 2019-08-13 WilliamZhu(allwefantasy@gmail.com) 12 | */ 13 | object SparkUtils { 14 | def internalCreateDataFrame(self: SparkSession, 15 | catalystRows: RDD[InternalRow], 16 | schema: StructType, 17 | isStreaming: Boolean = false): DataFrame = { 18 | val logicalPlan = LogicalRDD( 19 | schema.toAttributes, 20 | catalystRows, 21 | isStreaming = isStreaming)(self) 22 | Dataset.ofRows(self, logicalPlan) 23 | } 24 | 25 | def isFixDecimal(dt: DataType) = { 26 | dt match { 27 | case t@DecimalType.Fixed(precision, scale) => Option(new ArrowType.Decimal(precision, scale)) 28 | case _ => None 29 | } 30 | } 31 | 32 | def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) 33 | 34 | def getKillReason(tc: TaskContext) = tc.getKillReason() 35 | 36 | def killTaskIfInterrupted(tc: TaskContext) = { 37 | tc.killTaskIfInterrupted() 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/ArrowConverters.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package tech.mlsql.arrow 19 | 20 | import org.apache.arrow.flatbuf.MessageHeader 21 | import org.apache.arrow.memory.BufferAllocator 22 | import org.apache.arrow.vector._ 23 | import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} 24 | import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel} 25 | import org.apache.spark.TaskContext 26 | import org.apache.spark.api.java.JavaRDD 27 | import org.apache.spark.network.util.JavaUtils 28 | import org.apache.spark.sql.catalyst.InternalRow 29 | import org.apache.spark.sql.types.{StructType, _} 30 | import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} 31 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkUtils} 32 | import org.apache.spark.util.TaskCompletionListener 33 | import tech.mlsql.arrow.context.CommonTaskContext 34 | import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext} 35 | import tech.mlsql.arrow.python.ispark.SparkContextImp 36 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} 37 | import java.nio.channels.{Channels, ReadableByteChannel} 38 | 39 | import tech.mlsql.common.utils.lang.sc.ScalaReflect 40 | 41 | import scala.collection.JavaConverters._ 42 | 43 | 44 | /** 45 | * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format. 46 | */ 47 | class ArrowBatchStreamWriter( 48 | schema: StructType, 49 | out: OutputStream, 50 | timeZoneId: String) { 51 | 52 | val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) 53 | val writeChannel = new WriteChannel(Channels.newChannel(out)) 54 | 55 | // Write the Arrow schema first, before batches 56 | MessageSerializer.serialize(writeChannel, arrowSchema) 57 | 58 | /** 59 | * Consume iterator to write each serialized ArrowRecordBatch to the stream. 60 | */ 61 | def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = { 62 | arrowBatchIter.foreach(writeChannel.write) 63 | } 64 | 65 | /** 66 | * End the Arrow stream, does not close output stream. 67 | */ 68 | def end(): Unit = { 69 | // Adjust this line once compatibility with older Arrow versions is no longer required 70 | // writeChannel.writeIntLittleEndian(MessageSerializer.IPC_CONTINUATION_TOKEN); 71 | writeChannel.writeIntLittleEndian(0); 72 | } 73 | } 74 | 75 | object ArrowConverters { 76 | 77 | /** 78 | * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size 79 | * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter. 
80 | */ 81 | def toBatchIterator( 82 | rowIter: Iterator[InternalRow], 83 | schema: StructType, 84 | maxRecordsPerBatch: Int, 85 | timeZoneId: String, 86 | context: CommonTaskContext): Iterator[Array[Byte]] = { 87 | 88 | val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) 89 | val allocator = 90 | ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue) 91 | 92 | val root = VectorSchemaRoot.create(arrowSchema, allocator) 93 | val unloader = new VectorUnloader(root) 94 | val arrowWriter = ArrowWriter.create(root) 95 | 96 | context match { 97 | case c: AppContextImpl => c.innerContext.asInstanceOf[JavaContext].addTaskCompletionListener { _ => 98 | root.close() 99 | allocator.close() 100 | } 101 | case c: SparkContextImp => c.innerContext.asInstanceOf[TaskContext].addTaskCompletionListener(new TaskCompletionListener { 102 | override def onTaskCompletion(context: TaskContext): Unit = { 103 | root.close() 104 | allocator.close() 105 | } 106 | }) 107 | } 108 | 109 | new Iterator[Array[Byte]] { 110 | 111 | override def hasNext: Boolean = rowIter.hasNext || { 112 | root.close() 113 | allocator.close() 114 | false 115 | } 116 | 117 | override def next(): Array[Byte] = { 118 | val out = new ByteArrayOutputStream() 119 | val writeChannel = new WriteChannel(Channels.newChannel(out)) 120 | 121 | Utils.tryWithSafeFinally { 122 | var rowCount = 0 123 | while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { 124 | val row = rowIter.next() 125 | arrowWriter.write(row) 126 | rowCount += 1 127 | } 128 | arrowWriter.finish() 129 | val batch = unloader.getRecordBatch() 130 | MessageSerializer.serialize(writeChannel, batch) 131 | batch.close() 132 | } { 133 | arrowWriter.reset() 134 | } 135 | 136 | out.toByteArray 137 | } 138 | } 139 | } 140 | 141 | /** 142 | * Maps iterator from serialized ArrowRecordBatches to InternalRows. 
143 | */ 144 | def fromBatchIterator( 145 | arrowBatchIter: Iterator[Array[Byte]], 146 | schema: StructType, 147 | timeZoneId: String, 148 | context: CommonTaskContext): Iterator[InternalRow] = { 149 | val allocator = 150 | ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue) 151 | 152 | val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) 153 | val root = VectorSchemaRoot.create(arrowSchema, allocator) 154 | 155 | new Iterator[InternalRow] { 156 | private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty 157 | 158 | context.innerContext match { 159 | case c: AppContextImpl => c.innerContext.asInstanceOf[JavaContext].addTaskCompletionListener { _ => 160 | root.close() 161 | allocator.close() 162 | } 163 | case c: SparkContextImp => c.innerContext.asInstanceOf[TaskContext].addTaskCompletionListener(new TaskCompletionListener { 164 | override def onTaskCompletion(context: TaskContext): Unit = { 165 | root.close() 166 | allocator.close() 167 | } 168 | }) 169 | } 170 | 171 | override def hasNext: Boolean = rowIter.hasNext || { 172 | if (arrowBatchIter.hasNext) { 173 | rowIter = nextBatch() 174 | true 175 | } else { 176 | root.close() 177 | allocator.close() 178 | false 179 | } 180 | } 181 | 182 | override def next(): InternalRow = rowIter.next() 183 | 184 | private def nextBatch(): Iterator[InternalRow] = { 185 | val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) 186 | val vectorLoader = new VectorLoader(root) 187 | vectorLoader.load(arrowRecordBatch) 188 | arrowRecordBatch.close() 189 | 190 | val columns = root.getFieldVectors.asScala.map { vector => 191 | new ArrowColumnVector(vector).asInstanceOf[ColumnVector] 192 | }.toArray 193 | 194 | val batch = new ColumnarBatch(columns) 195 | batch.setNumRows(root.getRowCount) 196 | batch.rowIterator().asScala 197 | } 198 | } 199 | } 200 | 201 | /** 202 | * Load a serialized ArrowRecordBatch. 203 | */ 204 | def loadBatch( 205 | batchBytes: Array[Byte], 206 | allocator: BufferAllocator): ArrowRecordBatch = { 207 | val in = new ByteArrayInputStream(batchBytes) 208 | MessageSerializer.deserializeRecordBatch( 209 | new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException 210 | } 211 | 212 | /** 213 | * Create a DataFrame from an RDD of serialized ArrowRecordBatches. 214 | */ 215 | def toDataFrame( 216 | arrowBatchRDD: JavaRDD[Array[Byte]], 217 | schemaString: String, 218 | sqlContext: SQLContext): DataFrame = { 219 | val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] 220 | val timeZoneId = sqlContext.sparkSession.sessionState.conf.sessionLocalTimeZone 221 | val rdd = arrowBatchRDD.rdd.mapPartitions { iter => 222 | val context = new SparkContextImp(TaskContext.get(), null) 223 | ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context) 224 | } 225 | SparkUtils.internalCreateDataFrame(sqlContext.sparkSession, rdd.setName("arrow"), schema) 226 | } 227 | 228 | /** 229 | * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches. 
230 | */ 231 | def readArrowStreamFromFile( 232 | sqlContext: SQLContext, 233 | filename: String): JavaRDD[Array[Byte]] = { 234 | Utils.tryWithResource(new FileInputStream(filename)) { fileStream => 235 | // Create array to consume iterator so that we can safely close the file 236 | val batches = getBatchesFromStream(fileStream.getChannel).toArray 237 | // Parallelize the record batches to create an RDD 238 | JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length)) 239 | } 240 | } 241 | 242 | /** 243 | * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches. 244 | */ 245 | def getBatchesFromStream(in: ReadableByteChannel): Iterator[Array[Byte]] = { 246 | 247 | // Iterate over the serialized Arrow RecordBatch messages from a stream 248 | new Iterator[Array[Byte]] { 249 | var batch: Array[Byte] = readNextBatch() 250 | 251 | override def hasNext: Boolean = batch != null 252 | 253 | override def next(): Array[Byte] = { 254 | val prevBatch = batch 255 | batch = readNextBatch() 256 | prevBatch 257 | } 258 | 259 | // This gets the next serialized ArrowRecordBatch by reading message metadata to check if it 260 | // is a RecordBatch message and then returning the complete serialized message which consists 261 | // of a int32 length, serialized message metadata and a serialized RecordBatch message body 262 | def readNextBatch(): Array[Byte] = { 263 | val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in)) 264 | if (msgMetadata == null) { 265 | return null 266 | } 267 | 268 | // Get the length of the body, which has not been read at this point 269 | val bodyLength = msgMetadata.getMessageBodyLength.toInt 270 | 271 | // Only care about RecordBatch messages, skip Schema and unsupported Dictionary messages 272 | if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) { 273 | 274 | // Buffer backed output large enough to hold the complete serialized message 275 | val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength) 276 | 277 | // Write message metadata to ByteBuffer output stream 278 | MessageSerializer.writeMessageBuffer( 279 | new WriteChannel(Channels.newChannel(bbout)), 280 | msgMetadata.getMessageLength, 281 | msgMetadata.getMessageBuffer) 282 | 283 | // Get a zero-copy ByteBuffer with already contains message metadata, must close first 284 | bbout.close() 285 | val bb = bbout.toByteBuffer 286 | bb.position(bbout.getCount()) 287 | 288 | // Read message body directly into the ByteBuffer to avoid copy, return backed byte array 289 | bb.limit(bb.capacity()) 290 | JavaUtils.readFully(in, bb) 291 | bb.array() 292 | } else { 293 | if (bodyLength > 0) { 294 | // Skip message body if not a RecordBatch 295 | Channels.newInputStream(in).skip(bodyLength) 296 | } 297 | 298 | // Proceed to next message 299 | readNextBatch() 300 | } 301 | } 302 | } 303 | } 304 | } 305 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/ArrowUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package tech.mlsql.arrow 19 | 20 | import org.apache.arrow.memory.RootAllocator 21 | import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} 22 | import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} 23 | import org.apache.spark.sql.SparkUtils 24 | import org.apache.spark.sql.types._ 25 | 26 | import scala.collection.JavaConverters._ 27 | 28 | object ArrowUtils { 29 | 30 | val rootAllocator = new RootAllocator(Long.MaxValue) 31 | 32 | // todo: support more types. 33 | 34 | /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ 35 | def toArrowType(dt: DataType, timeZoneId: String): ArrowType = { 36 | SparkUtils.isFixDecimal(dt).getOrElse { 37 | dt match { 38 | case BooleanType => ArrowType.Bool.INSTANCE 39 | case ByteType => new ArrowType.Int(8, true) 40 | case ShortType => new ArrowType.Int(8 * 2, true) 41 | case IntegerType => new ArrowType.Int(8 * 4, true) 42 | case LongType => new ArrowType.Int(8 * 8, true) 43 | case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) 44 | case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) 45 | case StringType => ArrowType.Utf8.INSTANCE 46 | case BinaryType => ArrowType.Binary.INSTANCE 47 | case DateType => new ArrowType.Date(DateUnit.DAY) 48 | case TimestampType => 49 | if (timeZoneId == null) { 50 | throw new UnsupportedOperationException( 51 | s"${TimestampType.catalogString} must supply timeZoneId parameter") 52 | } else { 53 | new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) 54 | } 55 | case _ => 56 | 57 | throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}") 58 | } 59 | } 60 | 61 | } 62 | 63 | def fromArrowType(dt: ArrowType): DataType = dt match { 64 | case ArrowType.Bool.INSTANCE => BooleanType 65 | case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType 66 | case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType 67 | case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType 68 | case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType 69 | case float: ArrowType.FloatingPoint 70 | if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType 71 | case float: ArrowType.FloatingPoint 72 | if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType 73 | case ArrowType.Utf8.INSTANCE => StringType 74 | case ArrowType.Binary.INSTANCE => BinaryType 75 | case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) 76 | case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType 77 | case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType 78 | case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") 79 | } 80 | 81 | /** Maps field from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType */ 82 | def toArrowField( 83 | name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = { 84 | dt match { 85 | case ArrayType(elementType, containsNull) => 86 | val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) 87 | new Field(name, fieldType, 88 | Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava) 89 | case StructType(fields) => 90 | val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) 91 | new Field(name, fieldType, 92 | fields.map { field => 93 | toArrowField(field.name, field.dataType, field.nullable, timeZoneId) 94 | }.toSeq.asJava) 95 | case dataType => 96 | val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null) 97 | new Field(name, fieldType, Seq.empty[Field].asJava) 98 | } 99 | } 100 | 101 | def fromArrowField(field: Field): DataType = { 102 | field.getType match { 103 | case ArrowType.List.INSTANCE => 104 | val elementField = field.getChildren().get(0) 105 | val elementType = fromArrowField(elementField) 106 | ArrayType(elementType, containsNull = elementField.isNullable) 107 | case ArrowType.Struct.INSTANCE => 108 | val fields = field.getChildren().asScala.map { child => 109 | val dt = fromArrowField(child) 110 | StructField(child.getName, dt, child.isNullable) 111 | } 112 | StructType(fields) 113 | case arrowType => fromArrowType(arrowType) 114 | } 115 | } 116 | 117 | /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */ 118 | def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { 119 | new Schema(schema.map { field => 120 | toArrowField(field.name, field.dataType, field.nullable, timeZoneId) 121 | }.asJava) 122 | } 123 | 124 | def fromArrowSchema(schema: Schema): StructType = { 125 | StructType(schema.getFields.asScala.map { field => 126 | val dt = fromArrowField(field) 127 | StructField(field.getName, dt, field.isNullable) 128 | }) 129 | } 130 | 131 | 132 | } 133 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/ArrowWriter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package tech.mlsql.arrow 19 | 20 | import org.apache.arrow.vector._ 21 | import org.apache.arrow.vector.complex._ 22 | import org.apache.spark.sql.SparkUtils 23 | import org.apache.spark.sql.catalyst.InternalRow 24 | import org.apache.spark.sql.catalyst.expressions.SpecializedGetters 25 | import org.apache.spark.sql.types._ 26 | 27 | import scala.collection.JavaConverters._ 28 | 29 | object ArrowWriter { 30 | 31 | def create(schema: StructType, timeZoneId: String): ArrowWriter = { 32 | val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) 33 | val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) 34 | create(root) 35 | } 36 | 37 | def create(root: VectorSchemaRoot): ArrowWriter = { 38 | val children = root.getFieldVectors().asScala.map { vector => 39 | vector.allocateNew() 40 | createFieldWriter(vector) 41 | } 42 | new ArrowWriter(root, children.toArray) 43 | } 44 | 45 | private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { 46 | val field = vector.getField() 47 | val fixDecimalOpt = SparkUtils.isFixDecimal(ArrowUtils.fromArrowField(field)) 48 | (fixDecimalOpt, vector) match { 49 | case (Some(i), vector: DecimalVector) => return new DecimalWriter(vector, i.getPrecision, i.getScale) 50 | case (None, _) => 51 | } 52 | 53 | (ArrowUtils.fromArrowField(field), vector) match { 54 | case (BooleanType, vector: BitVector) => new BooleanWriter(vector) 55 | case (ByteType, vector: TinyIntVector) => new ByteWriter(vector) 56 | case (ShortType, vector: SmallIntVector) => new ShortWriter(vector) 57 | case (IntegerType, vector: IntVector) => new IntegerWriter(vector) 58 | case (LongType, vector: BigIntVector) => new LongWriter(vector) 59 | case (FloatType, vector: Float4Vector) => new FloatWriter(vector) 60 | case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector) 61 | case (StringType, vector: VarCharVector) => new StringWriter(vector) 62 | case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector) 63 | case (DateType, vector: DateDayVector) => new DateWriter(vector) 64 | case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector) 65 | case (ArrayType(_, _), vector: ListVector) => 66 | val elementVector = createFieldWriter(vector.getDataVector()) 67 | new ArrayWriter(vector, elementVector) 68 | case (StructType(_), vector: StructVector) => 69 | val children = (0 until vector.size()).map { ordinal => 70 | createFieldWriter(vector.getChildByOrdinal(ordinal)) 71 | } 72 | new StructWriter(vector, children.toArray) 73 | case (dt, _) => 74 | throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}") 75 | } 76 | } 77 | } 78 | 79 | class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { 80 | 81 | def schema: StructType = StructType(fields.map { f => 82 | StructField(f.name, f.dataType, f.nullable) 83 | }) 84 | 85 | private var count: Int = 0 86 | 87 | def write(row: InternalRow): Unit = { 88 | var i = 0 89 | while (i < fields.size) { 90 | fields(i).write(row, i) 91 | i += 1 92 | } 93 | count += 1 94 | } 95 | 96 | def finish(): Unit = { 97 | root.setRowCount(count) 98 | fields.foreach(_.finish()) 99 | } 100 | 101 | def reset(): Unit = { 102 | root.setRowCount(0) 103 | count = 0 104 | fields.foreach(_.reset()) 105 | } 106 | } 107 | 108 | private[arrow] abstract class ArrowFieldWriter { 109 | 110 | def valueVector: ValueVector 111 | 112 | def name: String = valueVector.getField().getName() 113 | 114 | def dataType: DataType = 
ArrowUtils.fromArrowField(valueVector.getField()) 115 | 116 | def nullable: Boolean = valueVector.getField().isNullable() 117 | 118 | def setNull(): Unit 119 | 120 | def setValue(input: SpecializedGetters, ordinal: Int): Unit 121 | 122 | private[arrow] var count: Int = 0 123 | 124 | def write(input: SpecializedGetters, ordinal: Int): Unit = { 125 | if (input.isNullAt(ordinal)) { 126 | setNull() 127 | } else { 128 | setValue(input, ordinal) 129 | } 130 | count += 1 131 | } 132 | 133 | def finish(): Unit = { 134 | valueVector.setValueCount(count) 135 | } 136 | 137 | def reset(): Unit = { 138 | valueVector.reset() 139 | count = 0 140 | } 141 | } 142 | 143 | private[arrow] class BooleanWriter(val valueVector: BitVector) extends ArrowFieldWriter { 144 | 145 | override def setNull(): Unit = { 146 | valueVector.setNull(count) 147 | } 148 | 149 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 150 | valueVector.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0) 151 | } 152 | } 153 | 154 | private[arrow] class ByteWriter(val valueVector: TinyIntVector) extends ArrowFieldWriter { 155 | 156 | override def setNull(): Unit = { 157 | valueVector.setNull(count) 158 | } 159 | 160 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 161 | valueVector.setSafe(count, input.getByte(ordinal)) 162 | } 163 | } 164 | 165 | private[arrow] class ShortWriter(val valueVector: SmallIntVector) extends ArrowFieldWriter { 166 | 167 | override def setNull(): Unit = { 168 | valueVector.setNull(count) 169 | } 170 | 171 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 172 | valueVector.setSafe(count, input.getShort(ordinal)) 173 | } 174 | } 175 | 176 | private[arrow] class IntegerWriter(val valueVector: IntVector) extends ArrowFieldWriter { 177 | 178 | override def setNull(): Unit = { 179 | valueVector.setNull(count) 180 | } 181 | 182 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 183 | valueVector.setSafe(count, input.getInt(ordinal)) 184 | } 185 | } 186 | 187 | private[arrow] class LongWriter(val valueVector: BigIntVector) extends ArrowFieldWriter { 188 | 189 | override def setNull(): Unit = { 190 | valueVector.setNull(count) 191 | } 192 | 193 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 194 | valueVector.setSafe(count, input.getLong(ordinal)) 195 | } 196 | } 197 | 198 | private[arrow] class FloatWriter(val valueVector: Float4Vector) extends ArrowFieldWriter { 199 | 200 | override def setNull(): Unit = { 201 | valueVector.setNull(count) 202 | } 203 | 204 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 205 | valueVector.setSafe(count, input.getFloat(ordinal)) 206 | } 207 | } 208 | 209 | private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFieldWriter { 210 | 211 | override def setNull(): Unit = { 212 | valueVector.setNull(count) 213 | } 214 | 215 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 216 | valueVector.setSafe(count, input.getDouble(ordinal)) 217 | } 218 | } 219 | 220 | private[arrow] class DecimalWriter( 221 | val valueVector: DecimalVector, 222 | precision: Int, 223 | scale: Int) extends ArrowFieldWriter { 224 | 225 | override def setNull(): Unit = { 226 | valueVector.setNull(count) 227 | } 228 | 229 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 230 | val decimal = input.getDecimal(ordinal, precision, scale) 231 | if (decimal.changePrecision(precision, scale)) { 232 | 
valueVector.setSafe(count, decimal.toJavaBigDecimal) 233 | } else { 234 | setNull() 235 | } 236 | } 237 | } 238 | 239 | private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter { 240 | 241 | override def setNull(): Unit = { 242 | valueVector.setNull(count) 243 | } 244 | 245 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 246 | val utf8 = input.getUTF8String(ordinal) 247 | val utf8ByteBuffer = utf8.getByteBuffer 248 | // todo: for off-heap UTF8String, how to pass in to arrow without copy? 249 | valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes()) 250 | } 251 | } 252 | 253 | private[arrow] class BinaryWriter( 254 | val valueVector: VarBinaryVector) extends ArrowFieldWriter { 255 | 256 | override def setNull(): Unit = { 257 | valueVector.setNull(count) 258 | } 259 | 260 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 261 | val bytes = input.getBinary(ordinal) 262 | valueVector.setSafe(count, bytes, 0, bytes.length) 263 | } 264 | } 265 | 266 | private[arrow] class DateWriter(val valueVector: DateDayVector) extends ArrowFieldWriter { 267 | 268 | override def setNull(): Unit = { 269 | valueVector.setNull(count) 270 | } 271 | 272 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 273 | valueVector.setSafe(count, input.getInt(ordinal)) 274 | } 275 | } 276 | 277 | private[arrow] class TimestampWriter( 278 | val valueVector: TimeStampMicroTZVector) extends ArrowFieldWriter { 279 | 280 | override def setNull(): Unit = { 281 | valueVector.setNull(count) 282 | } 283 | 284 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 285 | valueVector.setSafe(count, input.getLong(ordinal)) 286 | } 287 | } 288 | 289 | private[arrow] class ArrayWriter( 290 | val valueVector: ListVector, 291 | val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { 292 | 293 | override def setNull(): Unit = { 294 | } 295 | 296 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 297 | val array = input.getArray(ordinal) 298 | var i = 0 299 | valueVector.startNewValue(count) 300 | while (i < array.numElements()) { 301 | elementWriter.write(array, i) 302 | i += 1 303 | } 304 | valueVector.endValue(count, array.numElements()) 305 | } 306 | 307 | override def finish(): Unit = { 308 | super.finish() 309 | elementWriter.finish() 310 | } 311 | 312 | override def reset(): Unit = { 313 | super.reset() 314 | elementWriter.reset() 315 | } 316 | } 317 | 318 | private[arrow] class StructWriter( 319 | val valueVector: StructVector, 320 | children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { 321 | 322 | override def setNull(): Unit = { 323 | var i = 0 324 | while (i < children.length) { 325 | children(i).setNull() 326 | children(i).count += 1 327 | i += 1 328 | } 329 | valueVector.setNull(count) 330 | } 331 | 332 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { 333 | val struct = input.getStruct(ordinal, children.length) 334 | var i = 0 335 | while (i < struct.numFields) { 336 | children(i).write(struct, i) 337 | i += 1 338 | } 339 | valueVector.setIndexDefined(count) 340 | } 341 | 342 | override def finish(): Unit = { 343 | super.finish() 344 | children.foreach(_.finish()) 345 | } 346 | 347 | override def reset(): Unit = { 348 | super.reset() 349 | children.foreach(_.reset()) 350 | } 351 | } 352 | -------------------------------------------------------------------------------- 
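For reference, a minimal sketch of driving ArrowWriter directly, outside of ArrowConverters.toBatchIterator. This is not part of the repository: the object name, schema, rows, and "UTC" time zone below are illustrative assumptions; only the ArrowWriter API (create/write/finish/reset and the underlying root) comes from the file above, and Spark plus Arrow are assumed to be on the classpath.

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import tech.mlsql.arrow.ArrowWriter

object ArrowWriterSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative schema and rows, not taken from the repository.
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType)))

    // ArrowWriter.create(schema, timeZoneId) allocates a VectorSchemaRoot from the shared root allocator.
    val writer = ArrowWriter.create(schema, "UTC")

    // Write Spark InternalRows; strings must be passed as UTF8String.
    Seq((1, "a"), (2, "b")).foreach { case (id, name) =>
      writer.write(new GenericInternalRow(Array[Any](id, UTF8String.fromString(name))))
    }

    // finish() sets the row count on the root so the batch can be unloaded or serialized.
    writer.finish()
    println(writer.root.getRowCount) // expected: 2

    // reset() clears the vectors so the same writer can be reused for the next batch.
    writer.reset()
    writer.root.close()
  }
}

In the repository itself this pattern appears inside ArrowConverters.toBatchIterator, which pairs the writer with a VectorUnloader and MessageSerializer to emit each record batch as serialized bytes.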
/src/main/java/tech/mlsql/arrow/ByteBufferOutputStream.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow 2 | 3 | import java.io.ByteArrayOutputStream 4 | import java.nio.ByteBuffer 5 | 6 | /** 7 | * 2019-08-13 WilliamZhu(allwefantasy@gmail.com) 8 | */ 9 | class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutputStream(capacity) { 10 | 11 | def this() = this(32) 12 | 13 | def getCount(): Int = count 14 | 15 | private[this] var closed: Boolean = false 16 | 17 | override def write(b: Int): Unit = { 18 | require(!closed, "cannot write to a closed ByteBufferOutputStream") 19 | super.write(b) 20 | } 21 | 22 | override def write(b: Array[Byte], off: Int, len: Int): Unit = { 23 | require(!closed, "cannot write to a closed ByteBufferOutputStream") 24 | super.write(b, off, len) 25 | } 26 | 27 | override def reset(): Unit = { 28 | require(!closed, "cannot reset a closed ByteBufferOutputStream") 29 | super.reset() 30 | } 31 | 32 | override def close(): Unit = { 33 | if (!closed) { 34 | super.close() 35 | closed = true 36 | } 37 | } 38 | 39 | def toByteBuffer: ByteBuffer = { 40 | require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed") 41 | ByteBuffer.wrap(buf, 0, count) 42 | } 43 | } 44 | 45 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/Utils.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow 2 | 3 | import java.io._ 4 | import java.nio.charset.StandardCharsets 5 | import java.util.concurrent.TimeUnit 6 | 7 | import org.apache.spark.network.util.JavaUtils 8 | import tech.mlsql.arrow.api.RedirectStreams 9 | import tech.mlsql.arrow.python.PythonWorkerFactory.Tool.REDIRECT_IMPL 10 | import tech.mlsql.common.utils.log.Logging 11 | 12 | import scala.io.Source 13 | import scala.util.Try 14 | import scala.util.control.{ControlThrowable, NonFatal} 15 | 16 | 17 | /** 18 | * 2019-08-13 WilliamZhu(allwefantasy@gmail.com) 19 | */ 20 | object Utils extends Logging { 21 | def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = { 22 | var originalThrowable: Throwable = null 23 | try { 24 | block 25 | } catch { 26 | case t: Throwable => 27 | // Purposefully not using NonFatal, because even fatal exceptions 28 | // we don't want to have our finallyBlock suppress 29 | originalThrowable = t 30 | throw originalThrowable 31 | } finally { 32 | try { 33 | finallyBlock 34 | } catch { 35 | case t: Throwable if (originalThrowable != null && originalThrowable != t) => 36 | originalThrowable.addSuppressed(t) 37 | logWarning(s"Suppressing exception in finally: ${t.getMessage}", t) 38 | throw originalThrowable 39 | } 40 | } 41 | } 42 | 43 | 44 | def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { 45 | val resource = createResource 46 | try f.apply(resource) finally resource.close() 47 | } 48 | 49 | 50 | /** 51 | * Return the stderr of a process after waiting for the process to terminate. 52 | * If the process does not terminate within the specified timeout, return None. 
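 *
 * A minimal usage sketch (the command and timeout below are illustrative, not part of this codebase):
 *   val proc = new ProcessBuilder("python", "--version").start()
 *   Utils.getStderr(proc, 5000) match {
 *     case Some(err) => logWarning(s"process stderr: $err")
 *     case None => logWarning("process did not terminate within 5 seconds")
 *   }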
53 | */ 54 | def getStderr(process: Process, timeoutMs: Long): Option[String] = { 55 | val terminated = process.waitFor(timeoutMs, TimeUnit.MILLISECONDS) 56 | if (terminated) { 57 | Some(Source.fromInputStream(process.getErrorStream).getLines().mkString("\n")) 58 | } else { 59 | None 60 | } 61 | } 62 | 63 | /** 64 | * Execute the given block, logging and re-throwing any uncaught exception. 65 | * This is particularly useful for wrapping code that runs in a thread, to ensure 66 | * that exceptions are printed, and to avoid having to catch Throwable. 67 | */ 68 | def logUncaughtExceptions[T](f: => T): T = { 69 | try { 70 | f 71 | } catch { 72 | case ct: ControlThrowable => 73 | throw ct 74 | case t: Throwable => 75 | logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) 76 | throw t 77 | } 78 | } 79 | 80 | 81 | class RedirectThread( 82 | in: InputStream, 83 | out: OutputStream, 84 | name: String, 85 | propagateEof: Boolean = false) 86 | extends Thread(name) { 87 | 88 | setDaemon(true) 89 | 90 | override def run() { 91 | scala.util.control.Exception.ignoring(classOf[IOException]) { 92 | // FIXME: We copy the stream on the level of bytes to avoid encoding problems. 93 | Utils.tryWithSafeFinally { 94 | val buf = new Array[Byte](1024) 95 | var len = in.read(buf) 96 | while (len != -1) { 97 | out.write(buf, 0, len) 98 | out.flush() 99 | len = in.read(buf) 100 | } 101 | } { 102 | if (propagateEof) { 103 | out.close() 104 | } 105 | } 106 | } 107 | } 108 | } 109 | 110 | /** Executes the given block in a Try, logging any uncaught exceptions. */ 111 | def tryLog[T](f: => T): Try[T] = { 112 | try { 113 | val res = f 114 | scala.util.Success(res) 115 | } catch { 116 | case ct: ControlThrowable => 117 | throw ct 118 | case t: Throwable => 119 | logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) 120 | scala.util.Failure(t) 121 | } 122 | } 123 | 124 | /** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */ 125 | def isFatalError(e: Throwable): Boolean = { 126 | e match { 127 | case NonFatal(_) | 128 | _: InterruptedException | 129 | _: NotImplementedError | 130 | _: ControlThrowable | 131 | _: LinkageError => 132 | false 133 | case _ => 134 | true 135 | } 136 | } 137 | 138 | def deleteRecursively(file: File): Unit = { 139 | if (file != null) { 140 | JavaUtils.deleteRecursively(file) 141 | } 142 | } 143 | 144 | def redirectStream(conf: Map[String, String], stdout: InputStream) { 145 | try { 146 | conf.get(REDIRECT_IMPL) match { 147 | case None => 148 | new RedirectThread(stdout, System.err, "stdout reader ").start() 149 | case Some(clzz) => 150 | val instance = Class.forName(clzz).newInstance().asInstanceOf[RedirectStreams] 151 | instance.setConf(conf) 152 | instance.stdOut(stdout) 153 | } 154 | } catch { 155 | case e: Exception => 156 | logError("Exception in redirecting streams", e) 157 | } 158 | } 159 | 160 | def writeUTF(str: String, dataOut: DataOutputStream) { 161 | val bytes = str.getBytes(StandardCharsets.UTF_8) 162 | dataOut.writeInt(bytes.length) 163 | dataOut.write(bytes) 164 | } 165 | 166 | } 167 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/api/RedirectStreams.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow.api 2 | 3 | import java.io.InputStream 4 | 5 | 6 | /** 7 | * We will redirect all python daemon(or worker) stdout/stderr to the the java's stderr 8 | * by default. 
If you want to change this behavior, implement this trait and 9 | configure it via the conf passed to ArrowPythonRunner. 10 | * 11 | * Example: 12 | * 13 | * new ArrowPythonRunner( 14 | * .... 15 | * conf=Map("python.redirect.impl"->"your impl class name") 16 | * ) 17 | * 18 | */ 19 | trait RedirectStreams { 20 | 21 | def setConf(conf: Map[String, String]): Unit 22 | 23 | def conf: Map[String, String] 24 | 25 | def stdOut(stdOut: InputStream): Unit 26 | 27 | def stdErr(stdErr: InputStream): Unit 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/context/CommonTaskContext.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow.context 2 | 3 | import java.net.{ServerSocket, Socket} 4 | import java.util.concurrent.atomic.AtomicBoolean 5 | 6 | import org.apache.arrow.memory.BufferAllocator 7 | import org.apache.arrow.vector.ipc.ArrowStreamReader 8 | import tech.mlsql.arrow.python.runner.ArrowPythonRunner 9 | 10 | /** 11 | * This class is used to manage your tasks. When you create a thread to execute something which uses ArrowPythonRunner, 12 | * we should take care of both the case where the task completes normally and the case where it finishes unexpectedly; in either case some callbacks or actions 13 | * should be performed 14 | */ 15 | trait CommonTaskContext { 16 | val arrowPythonRunner: ArrowPythonRunner 17 | 18 | /** 19 | * When the reader begins to read, it will invoke the returned function. Please use something to remember them 20 | * and when the task is done, use the returned function to clean up the reader resources 21 | */ 22 | def readerRegister(callback: () => Unit): (ArrowStreamReader, BufferAllocator) => Unit 23 | 24 | /** 25 | * releasedOrClosed: AtomicBoolean - whether to release or close the python worker we have used. 26 | * reuseWorker: whether to reuse the python worker 27 | * socket: the python worker 28 | * When the task starts, it will invoke the returned function. Please use something to remember them and 29 | * when the task is done, use the returned function to clean up the python resources 30 | */ 31 | def pythonWorkerRegister(callback: () => Unit): ( 32 | AtomicBoolean, 33 | Boolean, 34 | Socket) => Unit 35 | 36 | 37 | /** 38 | * 39 | * If isBarrier is true, you should also set shutdownServerSocket 40 | */ 41 | def isBarrier: Boolean 42 | 43 | /** 44 | * When the task starts, it will invoke the returned function.
Please use something to remember them and 45 | * when the task is done, use the returned function to clean the python server side SocketServer 46 | */ 47 | def javaSideSocketServerRegister(): (ServerSocket) => Unit 48 | 49 | def monitor(callback: () => Unit): (Long, String, Map[String, String], Socket) => Unit 50 | 51 | def assertTaskIsCompleted(callback: () => Unit): () => Unit 52 | 53 | /** 54 | * 55 | * set innerContext, and you can use innerContext to get it again 56 | */ 57 | def setTaskContext(): () => Unit 58 | 59 | def innerContext: Any 60 | 61 | /** 62 | * check task status 63 | */ 64 | def isTaskCompleteOrInterrupt(): () => Boolean 65 | 66 | def isTaskInterrupt(): () => Boolean 67 | 68 | def getTaskKillReason(): () => Option[String] 69 | 70 | def killTaskIfInterrupted(): () => Unit 71 | } 72 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/javadoc.java: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow; 2 | 3 | /** 4 | * 2019-08-17 WilliamZhu(allwefantasy@gmail.com) 5 | */ 6 | public class javadoc { 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/python/PyJavaException.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow.python 2 | 3 | /** 4 | * 2019-08-15 WilliamZhu(allwefantasy@gmail.com) 5 | */ 6 | class PyJavaException(message: String, cause: Throwable) 7 | extends Exception(message, cause) { 8 | 9 | def this(message: String) = this(message, null) 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/python/PythonWorkerFactory.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow.python 2 | 3 | /** 4 | * 2019-08-14 WilliamZhu(allwefantasy@gmail.com) 5 | */ 6 | 7 | import java.io._ 8 | import java.net.{InetAddress, ServerSocket, Socket, SocketException} 9 | import java.util.Arrays 10 | import java.util.concurrent.TimeUnit 11 | import java.util.concurrent.atomic.AtomicLong 12 | 13 | import javax.annotation.concurrent.GuardedBy 14 | import tech.mlsql.arrow.Utils 15 | import tech.mlsql.arrow.python.runner.PythonConf 16 | import tech.mlsql.common.utils.lang.sc.ScalaMethodMacros 17 | import tech.mlsql.common.utils.log.Logging 18 | 19 | import scala.collection.JavaConverters._ 20 | import scala.collection.mutable 21 | 22 | class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String], conf: Map[String, String]) 23 | extends Logging { 24 | self => 25 | 26 | import PythonWorkerFactory.Tool._ 27 | 28 | // Because forking processes from Java is expensive, we prefer to launch a single Python daemon, 29 | // pyspark/daemon.py (by default) and tell it to fork new workers for our tasks. This daemon 30 | // currently only works on UNIX-based systems now because it uses signals for child management, 31 | // so we can also fall back to launching workers, pyspark/worker.py (by default) directly. 32 | private val useDaemon = { 33 | val useDaemonEnabled = true 34 | 35 | // This flag is ignored on Windows as it's unable to fork. 36 | !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled 37 | } 38 | 39 | // WARN: Both configurations, 'spark.python.daemon.module' and 'spark.python.worker.module' are 40 | // for very advanced users and they are experimental. 
This should be considered 41 | // as expert-only option, and shouldn't be used before knowing what it means exactly. 42 | 43 | // This configuration indicates the module to run the daemon to execute its Python workers. 44 | private val daemonModule = conf.getOrElse(PYTHON_DAEMON_MODULE, "pyjava.daemon") 45 | 46 | 47 | // This configuration indicates the module to run each Python worker. 48 | private val workerModule = conf.getOrElse(PYTHON_WORKER_MODULE, "pyjava.worker") 49 | 50 | private val workerIdleTime = conf.getOrElse(PYTHON_WORKER_IDLE_TIME, "1").toInt 51 | 52 | @GuardedBy("self") 53 | private var daemon: Process = null 54 | val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) 55 | @GuardedBy("self") 56 | private var daemonPort: Int = 0 57 | @GuardedBy("self") 58 | private val daemonWorkers = new mutable.WeakHashMap[Socket, Int]() 59 | @GuardedBy("self") 60 | private val idleWorkers = new mutable.Queue[Socket]() 61 | @GuardedBy("self") 62 | private var lastActivityNs = 0L 63 | 64 | 65 | private val monitorThread = new MonitorThread() 66 | monitorThread.setWorkerIdleTime(workerIdleTime) 67 | monitorThread.start() 68 | 69 | @GuardedBy("self") 70 | private val simpleWorkers = new mutable.WeakHashMap[Socket, Process]() 71 | 72 | private val pythonPath = mergePythonPaths( 73 | envVars.getOrElse("PYTHONPATH", ""), 74 | sys.env.getOrElse("PYTHONPATH", "")) 75 | 76 | def create(): Socket = { 77 | val socket = if (useDaemon) { 78 | self.synchronized { 79 | if (idleWorkers.nonEmpty) { 80 | return idleWorkers.dequeue() 81 | } 82 | } 83 | createThroughDaemon() 84 | } else { 85 | createSimpleWorker() 86 | } 87 | socket 88 | } 89 | 90 | /** 91 | * Connect to a worker launched through pyspark/daemon.py (by default), which forks python 92 | * processes itself to avoid the high cost of forking from Java. This currently only works 93 | * on UNIX-based systems. 94 | */ 95 | private def createThroughDaemon(): Socket = { 96 | 97 | def createSocket(): Socket = { 98 | val socket = new Socket(daemonHost, daemonPort) 99 | val pid = new DataInputStream(socket.getInputStream).readInt() 100 | if (pid < 0) { 101 | throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) 102 | } 103 | daemonWorkers.put(socket, pid) 104 | socket 105 | } 106 | 107 | self.synchronized { 108 | // Start the daemon if it hasn't been started 109 | startDaemon() 110 | 111 | // Attempt to connect, restart and retry once if it fails 112 | try { 113 | createSocket() 114 | } catch { 115 | case exc: SocketException => 116 | logWarning("Failed to open socket to Python daemon:", exc) 117 | logWarning("Assuming that daemon unexpectedly quit, attempting to restart") 118 | stopDaemon() 119 | startDaemon() 120 | createSocket() 121 | } 122 | } 123 | } 124 | 125 | /** 126 | * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. 
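 *
 * Callers normally do not invoke this directly; the usual entry point is the companion
 * object (a sketch, the exec and PYTHONPATH values are illustrative):
 *   val worker: java.net.Socket = PythonWorkerFactory.createPythonWorker(
 *     "python", Map("PYTHONPATH" -> "/path/to/pyjava/python"), Map())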
127 | */ 128 | private def createSimpleWorker(): Socket = { 129 | var serverSocket: ServerSocket = null 130 | try { 131 | serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) 132 | 133 | // Create and start the worker 134 | val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule)) 135 | val workerEnv = pb.environment() 136 | workerEnv.putAll(envVars.asJava) 137 | workerEnv.put("PYTHONPATH", pythonPath) 138 | // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: 139 | workerEnv.put("PYTHONUNBUFFERED", "YES") 140 | workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString) 141 | val worker = pb.start() 142 | 143 | // Redirect worker stdout and stderr 144 | Utils.redirectStream(conf, worker.getInputStream) 145 | Utils.redirectStream(conf, worker.getErrorStream) 146 | 147 | // Wait for it to connect to our socket, and validate the auth secret. 148 | serverSocket.setSoTimeout(10000) 149 | 150 | try { 151 | val socket = serverSocket.accept() 152 | self.synchronized { 153 | simpleWorkers.put(socket, worker) 154 | } 155 | return socket 156 | } catch { 157 | case e: Exception => 158 | throw new RuntimeException("Python worker failed to connect back.", e) 159 | } 160 | } finally { 161 | if (serverSocket != null) { 162 | serverSocket.close() 163 | } 164 | } 165 | null 166 | } 167 | 168 | private def startDaemon() { 169 | self.synchronized { 170 | // Is it already running? 171 | if (daemon != null) { 172 | return 173 | } 174 | 175 | try { 176 | // Create and start the daemon 177 | val envCommand = envVars.getOrElse(ScalaMethodMacros.str(PythonConf.PYTHON_ENV), "") 178 | val command = Seq("bash", "-c", envCommand + s" && python -m ${daemonModule}") 179 | val pb = new ProcessBuilder(command.asJava) 180 | val workerEnv = pb.environment() 181 | workerEnv.putAll(envVars.asJava) 182 | workerEnv.put("PYTHONPATH", pythonPath) 183 | // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: 184 | workerEnv.put("PYTHONUNBUFFERED", "YES") 185 | daemon = pb.start() 186 | 187 | val in = new DataInputStream(daemon.getInputStream) 188 | try { 189 | daemonPort = in.readInt() 190 | } catch { 191 | case _: EOFException => 192 | throw new RuntimeException(s"No port number in $daemonModule's stdout") 193 | } 194 | 195 | // test that the returned port number is within a valid range. 196 | // note: this does not cover the case where the port number 197 | // is arbitrary data but is also coincidentally within range 198 | if (daemonPort < 1 || daemonPort > 0xffff) { 199 | val exceptionMessage = 200 | f""" 201 | |Bad data in $daemonModule's standard output. 
Invalid port number: 202 | | $daemonPort (0x$daemonPort%08x) 203 | |Python command to execute the daemon was: 204 | | ${command.mkString(" ")} 205 | |Check that you don't have any unexpected modules or libraries in 206 | |your PYTHONPATH: 207 | | $pythonPath 208 | |Also, check if you have a sitecustomize.py module in your python path, 209 | |or in your python installation, that is printing to standard output""" 210 | throw new RuntimeException(exceptionMessage.stripMargin) 211 | } 212 | 213 | // Redirect daemon stdout and stderr 214 | Utils.redirectStream(conf, in) 215 | Utils.redirectStream(conf, daemon.getErrorStream) 216 | } catch { 217 | case e: Exception => 218 | 219 | // If the daemon exists, wait for it to finish and get its stderr 220 | val stderr = Option(daemon) 221 | .flatMap { d => Utils.getStderr(d, PROCESS_WAIT_TIMEOUT_MS) } 222 | .getOrElse("") 223 | 224 | stopDaemon() 225 | 226 | if (stderr != "") { 227 | val formattedStderr = stderr.replace("\n", "\n ") 228 | val errorMessage = 229 | s""" 230 | |Error from python worker: 231 | | $formattedStderr 232 | |PYTHONPATH was: 233 | | $pythonPath 234 | |$e""" 235 | 236 | // Append error message from python daemon, but keep original stack trace 237 | val wrappedException = new RuntimeException(errorMessage.stripMargin) 238 | wrappedException.setStackTrace(e.getStackTrace) 239 | throw wrappedException 240 | } else { 241 | throw e 242 | } 243 | } 244 | 245 | // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly 246 | // detect our disappearance. 247 | } 248 | } 249 | 250 | 251 | /** 252 | * Monitor all the idle workers, kill them after timeout. 253 | */ 254 | private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") { 255 | //minutes 256 | val IDLE_WORKER_TIMEOUT_NS_REF = new AtomicLong(TimeUnit.MINUTES.toNanos(1)) 257 | 258 | def setWorkerIdleTime(minutes: Int) = { 259 | IDLE_WORKER_TIMEOUT_NS_REF.set(TimeUnit.MINUTES.toNanos(minutes)) 260 | } 261 | 262 | setDaemon(true) 263 | 264 | override def run() { 265 | while (true) { 266 | self.synchronized { 267 | if (IDLE_WORKER_TIMEOUT_NS_REF.get() < System.nanoTime() - lastActivityNs) { 268 | cleanupIdleWorkers() 269 | lastActivityNs = System.nanoTime() 270 | } 271 | } 272 | Thread.sleep(10000) 273 | } 274 | } 275 | } 276 | 277 | private def cleanupIdleWorkers() { 278 | while (idleWorkers.nonEmpty) { 279 | val worker = idleWorkers.dequeue() 280 | try { 281 | // the worker will exit after closing the socket 282 | worker.close() 283 | } catch { 284 | case e: Exception => 285 | logWarning("Failed to close worker socket", e) 286 | } 287 | } 288 | } 289 | 290 | private def stopDaemon() { 291 | self.synchronized { 292 | if (useDaemon) { 293 | cleanupIdleWorkers() 294 | 295 | // Request shutdown of existing daemon by sending SIGTERM 296 | if (daemon != null) { 297 | daemon.destroy() 298 | } 299 | 300 | daemon = null 301 | daemonPort = 0 302 | } else { 303 | simpleWorkers.mapValues(_.destroy()) 304 | } 305 | } 306 | } 307 | 308 | def stop() { 309 | stopDaemon() 310 | } 311 | 312 | def stopWorker(worker: Socket) { 313 | self.synchronized { 314 | if (useDaemon) { 315 | if (daemon != null) { 316 | daemonWorkers.get(worker).foreach { pid => 317 | // tell daemon to kill worker by pid 318 | val output = new DataOutputStream(daemon.getOutputStream) 319 | output.writeInt(pid) 320 | output.flush() 321 | daemon.getOutputStream.flush() 322 | } 323 | } 324 | } else { 325 | simpleWorkers.get(worker).foreach(_.destroy()) 326 | } 327 | } 328 | 
worker.close() 329 | } 330 | 331 | def releaseWorker(worker: Socket) { 332 | if (useDaemon) { 333 | self.synchronized { 334 | lastActivityNs = System.nanoTime() 335 | idleWorkers.enqueue(worker) 336 | } 337 | } else { 338 | // Cleanup the worker socket. This will also cause the Python worker to exit. 339 | try { 340 | worker.close() 341 | } catch { 342 | case e: Exception => 343 | logWarning("Failed to close worker socket", e) 344 | } 345 | } 346 | } 347 | } 348 | 349 | object PythonWorkerFactory { 350 | 351 | private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() 352 | 353 | def createPythonWorker(pythonExec: String, envVars: Map[String, String], conf: Map[String, String]): java.net.Socket = { 354 | synchronized { 355 | val key = (pythonExec, envVars) 356 | pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars, conf)).create() 357 | } 358 | } 359 | 360 | 361 | def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { 362 | synchronized { 363 | val key = (pythonExec, envVars) 364 | pythonWorkers.get(key).foreach(_.stopWorker(worker)) 365 | } 366 | } 367 | 368 | 369 | def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { 370 | synchronized { 371 | val key = (pythonExec, envVars) 372 | pythonWorkers.get(key).foreach(_.releaseWorker(worker)) 373 | } 374 | } 375 | 376 | 377 | object Tool { 378 | val PROCESS_WAIT_TIMEOUT_MS = 10000 379 | val PYTHON_DAEMON_MODULE = "python.daemon.module" 380 | val PYTHON_WORKER_MODULE = "python.worker.module" 381 | val PYTHON_WORKER_IDLE_TIME = "python.worker.idle.time" 382 | val PYTHON_TASK_KILL_TIMEOUT = "python.task.killTimeout" 383 | val REDIRECT_IMPL = "python.redirect.impl" 384 | 385 | def mergePythonPaths(paths: String*): String = { 386 | paths.filter(_ != "").mkString(File.pathSeparator) 387 | } 388 | } 389 | 390 | } 391 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/python/iapp/AppContextImpl.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow.python.iapp 2 | 3 | import java.net.{ServerSocket, Socket} 4 | import java.util.concurrent.atomic.AtomicBoolean 5 | import java.util.{EventListener, UUID} 6 | 7 | import org.apache.arrow.memory.BufferAllocator 8 | import org.apache.arrow.vector.ipc.ArrowStreamReader 9 | import tech.mlsql.arrow.context.CommonTaskContext 10 | import tech.mlsql.arrow.python.PythonWorkerFactory 11 | import tech.mlsql.arrow.python.runner.ArrowPythonRunner 12 | import tech.mlsql.common.utils.log.Logging 13 | 14 | import scala.collection.mutable.ArrayBuffer 15 | 16 | /** 17 | * 2019-08-15 WilliamZhu(allwefantasy@gmail.com) 18 | */ 19 | 20 | class AppContextImpl(context: JavaContext, _arrowPythonRunner: ArrowPythonRunner) extends CommonTaskContext with Logging { 21 | override def pythonWorkerRegister(callback: () => Unit) = { 22 | (releasedOrClosed: AtomicBoolean, 23 | reuseWorker: Boolean, 24 | worker: Socket 25 | ) => { 26 | context.addTaskCompletionListener[Unit] { _ => 27 | //writerThread.shutdownOnTaskCompletion() 28 | callback() 29 | if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) { 30 | try { 31 | worker.close() 32 | } catch { 33 | case e: Exception => 34 | logWarning("Failed to close worker socket", e) 35 | } 36 | } 37 | } 38 | } 39 | } 40 | 41 | override def assertTaskIsCompleted(callback: () => Unit) = { 42 | () => { 43 | 
assert(context.isCompleted) 44 | } 45 | } 46 | 47 | override def setTaskContext(): () => Unit = { 48 | () => { 49 | 50 | } 51 | } 52 | 53 | override def innerContext: Any = context 54 | 55 | override def isBarrier: Boolean = false 56 | 57 | override def monitor(callback: () => Unit) = { 58 | (taskKillTimeout: Long, pythonExec: String, envVars: Map[String, String], worker: Socket) => { 59 | // Kill the worker if it is interrupted, checking until task completion. 60 | // TODO: This has a race condition if interruption occurs, as completed may still become true. 61 | while (!context.isInterrupted && !context.isCompleted) { 62 | Thread.sleep(2000) 63 | } 64 | if (!context.isCompleted) { 65 | Thread.sleep(taskKillTimeout) 66 | if (!context.isCompleted) { 67 | try { 68 | // Mimic the task name used in `Executor` to help the user find out the task to blame. 69 | val taskName = s"${context.partitionId}" 70 | logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker") 71 | PythonWorkerFactory.destroyPythonWorker(pythonExec, envVars, worker) 72 | } catch { 73 | case e: Exception => 74 | logError("Exception when trying to kill worker", e) 75 | } 76 | } 77 | } 78 | } 79 | } 80 | 81 | override val arrowPythonRunner: ArrowPythonRunner = _arrowPythonRunner 82 | 83 | override def javaSideSocketServerRegister(): ServerSocket => Unit = { 84 | (server: ServerSocket) => { 85 | context.addTaskCompletionListener[Unit](_ => server.close()) 86 | } 87 | } 88 | 89 | override def isTaskCompleteOrInterrupt(): () => Boolean = { 90 | () => { 91 | context.isCompleted || context.isInterrupted 92 | } 93 | } 94 | 95 | override def isTaskInterrupt(): () => Boolean = { 96 | () => { 97 | context.isInterrupted 98 | } 99 | } 100 | 101 | override def getTaskKillReason(): () => Option[String] = { 102 | () => { 103 | context.getKillReason 104 | } 105 | } 106 | 107 | override def killTaskIfInterrupted(): () => Unit = { 108 | () => { 109 | context.killTaskIfInterrupted 110 | } 111 | } 112 | 113 | override def readerRegister(callback: () => Unit): (ArrowStreamReader, BufferAllocator) => Unit = { 114 | (reader, allocator) => { 115 | context.addTaskCompletionListener[Unit] { _ => 116 | if (reader != null) { 117 | reader.close(false) 118 | } 119 | try { 120 | allocator.close() 121 | } catch { 122 | case e: Exception => 123 | logError("allocator.close()", e) 124 | } 125 | } 126 | } 127 | } 128 | } 129 | 130 | 131 | class JavaContext { 132 | val buffer = new ArrayBuffer[AppTaskCompletionListener]() 133 | 134 | var _isCompleted = false 135 | var _partitionId = UUID.randomUUID().toString 136 | var reasonIfKilled: Option[String] = None 137 | var getKillReason = reasonIfKilled 138 | 139 | def isInterrupted = reasonIfKilled.isDefined 140 | 141 | def isCompleted = _isCompleted 142 | 143 | def partitionId = _partitionId 144 | 145 | def markComplete = { 146 | _isCompleted = true 147 | } 148 | 149 | def markInterrupted(reason: String): Unit = { 150 | reasonIfKilled = Some(reason) 151 | } 152 | 153 | 154 | def killTaskIfInterrupted = { 155 | val reason = reasonIfKilled 156 | if (reason.isDefined) { 157 | throw new TaskKilledException(reason.get) 158 | } 159 | } 160 | 161 | 162 | def addTaskCompletionListener[U](f: (JavaContext) => U): JavaContext = { 163 | // Note that due to this scala bug: https://github.com/scala/bug/issues/11016, we need to make 164 | // this function polymorphic for every scala version >= 2.12, otherwise an overloaded method 165 | // resolution error occurs at compile time. 
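    // Usage sketch (mirrors AppContextImpl): javaContext.addTaskCompletionListener[Unit](_ => serverSocket.close())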
166 | _addTaskCompletionListener(new AppTaskCompletionListener { 167 | override def onTaskCompletion(context: JavaContext): Unit = f(context) 168 | }) 169 | } 170 | 171 | def _addTaskCompletionListener(listener: AppTaskCompletionListener): JavaContext = { 172 | buffer += listener 173 | this 174 | } 175 | 176 | def close = { 177 | buffer.foreach(_.onTaskCompletion(this)) 178 | } 179 | } 180 | 181 | trait AppTaskCompletionListener extends EventListener { 182 | def onTaskCompletion(context: JavaContext): Unit 183 | } 184 | 185 | class TaskKilledException(val reason: String) extends RuntimeException { 186 | def this() = this("unknown reason") 187 | } 188 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/python/ispark/SparkContextImp.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow.python.ispark 2 | 3 | import java.net.{ServerSocket, Socket} 4 | import java.util.concurrent.atomic.AtomicBoolean 5 | 6 | import org.apache.arrow.memory.BufferAllocator 7 | import org.apache.arrow.vector.ipc.ArrowStreamReader 8 | import org.apache.spark.TaskContext 9 | import org.apache.spark.sql.SparkUtils 10 | import org.apache.spark.util.TaskCompletionListener 11 | import tech.mlsql.arrow.context.CommonTaskContext 12 | import tech.mlsql.arrow.python.PythonWorkerFactory 13 | import tech.mlsql.arrow.python.runner.ArrowPythonRunner 14 | import tech.mlsql.common.utils.log.Logging 15 | 16 | /** 17 | * 2019-08-15 WilliamZhu(allwefantasy@gmail.com) 18 | */ 19 | class SparkContextImp(context: TaskContext, _arrowPythonRunner: ArrowPythonRunner) extends CommonTaskContext with Logging { 20 | override def pythonWorkerRegister(callback: () => Unit) = { 21 | (releasedOrClosed: AtomicBoolean, 22 | reuseWorker: Boolean, 23 | worker: Socket 24 | ) => { 25 | context.addTaskCompletionListener(new TaskCompletionListener { 26 | override def onTaskCompletion(context: TaskContext): Unit = { 27 | //writerThread.shutdownOnTaskCompletion() 28 | callback() 29 | if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) { 30 | try { 31 | worker.close() 32 | } catch { 33 | case e: Exception => 34 | logWarning("Failed to close worker socket", e) 35 | } 36 | } 37 | } 38 | }) 39 | } 40 | } 41 | 42 | override def assertTaskIsCompleted(callback: () => Unit) = { 43 | () => { 44 | assert(context.isCompleted) 45 | } 46 | } 47 | 48 | override def setTaskContext(): () => Unit = { 49 | () => { 50 | SparkUtils.setTaskContext(context) 51 | } 52 | } 53 | 54 | override def innerContext: Any = context 55 | 56 | override def isBarrier: Boolean = context.getClass.getName == "org.apache.spark.BarrierTaskContext" 57 | 58 | override def monitor(callback: () => Unit) = { 59 | (taskKillTimeout: Long, pythonExec: String, envVars: Map[String, String], worker: Socket) => { 60 | // Kill the worker if it is interrupted, checking until task completion. 61 | // TODO: This has a race condition if interruption occurs, as completed may still become true. 62 | while (!context.isInterrupted && !context.isCompleted) { 63 | Thread.sleep(2000) 64 | } 65 | if (!context.isCompleted) { 66 | Thread.sleep(taskKillTimeout) 67 | if (!context.isCompleted) { 68 | try { 69 | // Mimic the task name used in `Executor` to help the user find out the task to blame. 
70 | val taskName = s"${context.partitionId}.${context.attemptNumber} " + 71 | s"in stage ${context.stageId} (TID ${context.taskAttemptId})" 72 | logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker") 73 | PythonWorkerFactory.destroyPythonWorker(pythonExec, envVars, worker) 74 | } catch { 75 | case e: Exception => 76 | logError("Exception when trying to kill worker", e) 77 | } 78 | } 79 | } 80 | } 81 | } 82 | 83 | override val arrowPythonRunner: ArrowPythonRunner = _arrowPythonRunner 84 | 85 | override def javaSideSocketServerRegister(): ServerSocket => Unit = { 86 | (server: ServerSocket) => { 87 | context.addTaskCompletionListener(new TaskCompletionListener { 88 | override def onTaskCompletion(context: TaskContext): Unit = { 89 | server.close() 90 | } 91 | }) 92 | } 93 | } 94 | 95 | override def isTaskCompleteOrInterrupt(): () => Boolean = { 96 | () => { 97 | context.isCompleted || context.isInterrupted 98 | } 99 | } 100 | 101 | override def isTaskInterrupt(): () => Boolean = { 102 | () => { 103 | context.isInterrupted 104 | } 105 | } 106 | 107 | override def getTaskKillReason(): () => Option[String] = { 108 | () => { 109 | SparkUtils.getKillReason(context) 110 | } 111 | } 112 | 113 | override def killTaskIfInterrupted(): () => Unit = { 114 | () => { 115 | SparkUtils.killTaskIfInterrupted(context) 116 | } 117 | } 118 | 119 | override def readerRegister(callback: () => Unit): (ArrowStreamReader, BufferAllocator) => Unit = { 120 | (reader, allocator) => { 121 | context.addTaskCompletionListener(new TaskCompletionListener { 122 | override def onTaskCompletion(context: TaskContext): Unit = { 123 | if (reader != null) { 124 | reader.close(false) 125 | } 126 | // Special case: the user may read only part of the data. Closing the allocator at that point 127 | // reports a memory leak and the close call throws, so we need to catch that error. 128 | // As far as we can tell the resources are still released; in most cases all data is consumed normally. 129 | try { 130 | allocator.close() 131 | } catch { 132 | case e: Exception => 133 | logError("allocator.close()", e) 134 | } 135 | 136 | } 137 | }) 138 | } 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/python/runner/ArrowPythonRunner.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow.python.runner 2 | 3 | import org.apache.arrow.vector.VectorSchemaRoot 4 | import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} 5 | import org.apache.spark.sql.catalyst.InternalRow 6 | import org.apache.spark.sql.types._ 7 | import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} 8 | import tech.mlsql.arrow.context.CommonTaskContext 9 | import tech.mlsql.arrow.{ArrowUtils, ArrowWriter, Utils} 10 | 11 | import java.io._ 12 | import java.net._ 13 | import java.util.concurrent.atomic.AtomicBoolean 14 | import scala.collection.JavaConverters._ 15 | 16 | 17 | /** 18 | * Similar to `PythonUDFRunner`, but exchanges data with the Python worker via an Arrow stream.
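 *
 * A condensed usage sketch based on the JavaAppSpec test (the script, envs, schema and
 * iterator values are illustrative):
 *   val runner = new ArrowPythonRunner(
 *     Seq(ChainedPythonFunctions(Seq(PythonFunction(script, envs, "python", "3.6")))),
 *     sourceSchema, "GMT", Map())
 *   val taskContext = new AppContextImpl(new JavaContext, runner)
 *   val batches = runner.compute(Iterator(rowIter), TaskContext.getPartitionId(), taskContext)
 *   batches.flatMap(_.rowIterator.asScala).foreach(row => println(row.copy()))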
19 | */ 20 | class ArrowPythonRunner( 21 | funcs: Seq[ChainedPythonFunctions], 22 | schema: StructType, 23 | timeZoneId: String, 24 | conf: Map[String, String]) 25 | extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( 26 | funcs, conf) { 27 | 28 | protected override def newWriterThread( 29 | worker: Socket, 30 | inputIterator: Iterator[Iterator[InternalRow]], 31 | partitionIndex: Int, 32 | context: CommonTaskContext): WriterThread = { 33 | new WriterThread(worker, inputIterator, partitionIndex, context) { 34 | 35 | protected override def writeCommand(dataOut: DataOutputStream): Unit = { 36 | 37 | // Write config for the worker as a number of key -> value pairs of strings 38 | dataOut.writeInt(conf.size + 1) 39 | for ((k, v) <- conf) { 40 | Utils.writeUTF(k, dataOut) 41 | Utils.writeUTF(v, dataOut) 42 | } 43 | Utils.writeUTF("timezone", dataOut) 44 | Utils.writeUTF(timeZoneId, dataOut) 45 | 46 | val command = funcs.head.funcs.head.command 47 | Utils.writeUTF(command, dataOut) 48 | 49 | } 50 | 51 | protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { 52 | val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) 53 | val allocator = ArrowUtils.rootAllocator.newChildAllocator( 54 | s"stdout writer for $pythonExec", 0, Long.MaxValue) 55 | val root = VectorSchemaRoot.create(arrowSchema, allocator) 56 | 57 | Utils.tryWithSafeFinally { 58 | val arrowWriter = ArrowWriter.create(root) 59 | val writer = new ArrowStreamWriter(root, null, dataOut) 60 | writer.start() 61 | 62 | while (inputIterator.hasNext) { 63 | val nextBatch = inputIterator.next() 64 | 65 | while (nextBatch.hasNext) { 66 | arrowWriter.write(nextBatch.next()) 67 | } 68 | 69 | arrowWriter.finish() 70 | writer.writeBatch() 71 | arrowWriter.reset() 72 | } 73 | // end writes footer to the output stream and doesn't clean any resources. 74 | // It could throw exception if the output stream is closed, so it should be 75 | // in the try block. 76 | writer.end() 77 | } { 78 | // If we close root and allocator in TaskCompletionListener, there could be a race 79 | // condition where the writer thread keeps writing to the VectorSchemaRoot while 80 | // it's being closed by the TaskCompletion listener. 81 | // Closing root and allocator here is cleaner because root and allocator is owned 82 | // by the writer thread and is only visible to the writer thread. 83 | // 84 | // If the writer thread is interrupted by TaskCompletionListener, it should either 85 | // (1) in the try block, in which case it will get an InterruptedException when 86 | // performing io, and goes into the finally block or (2) in the finally block, 87 | // in which case it will ignore the interruption and close the resources. 
88 | root.close() 89 | allocator.close() 90 | } 91 | } 92 | } 93 | } 94 | 95 | protected override def newReaderIterator( 96 | stream: DataInputStream, 97 | writerThread: WriterThread, 98 | startTime: Long, 99 | worker: Socket, 100 | releasedOrClosed: AtomicBoolean, 101 | context: CommonTaskContext): Iterator[ColumnarBatch] = { 102 | new ReaderIterator(stream, writerThread, startTime, worker, releasedOrClosed, context) { 103 | 104 | private val allocator = ArrowUtils.rootAllocator.newChildAllocator( 105 | s"stdin reader for $pythonExec", 0, Long.MaxValue) 106 | 107 | private var reader: ArrowStreamReader = _ 108 | private var root: VectorSchemaRoot = _ 109 | private var schema: StructType = _ 110 | private var vectors: Array[ColumnVector] = _ 111 | context.readerRegister(() => {})(reader, allocator) 112 | 113 | private var batchLoaded = true 114 | 115 | protected override def read(): ColumnarBatch = { 116 | if (writerThread.exception.isDefined) { 117 | throw writerThread.exception.get 118 | } 119 | try { 120 | if (reader != null && batchLoaded) { 121 | batchLoaded = reader.loadNextBatch() 122 | if (batchLoaded) { 123 | val batch = new ColumnarBatch(vectors) 124 | batch.setNumRows(root.getRowCount) 125 | batch 126 | } else { 127 | reader.close(false) 128 | allocator.close() 129 | // Reach end of stream. Call `read()` again to read control data. 130 | read() 131 | } 132 | } else { 133 | stream.readInt() match { 134 | case SpecialLengths.START_ARROW_STREAM => 135 | try { 136 | reader = new ArrowStreamReader(stream, allocator) 137 | root = reader.getVectorSchemaRoot() 138 | schema = ArrowUtils.fromArrowSchema(root.getSchema()) 139 | vectors = root.getFieldVectors().asScala.map { vector => 140 | new ArrowColumnVector(vector) 141 | }.toArray[ColumnVector] 142 | read() 143 | } catch { 144 | case e: IOException if (e.getMessage.contains("Missing schema") || e.getMessage.contains("Expected schema but header was")) => 145 | logInfo("Arrow read schema fail", e) 146 | reader = null 147 | read() 148 | } 149 | 150 | case SpecialLengths.ARROW_STREAM_CRASH => 151 | read() 152 | 153 | case SpecialLengths.PYTHON_EXCEPTION_THROWN => 154 | throw handlePythonException() 155 | 156 | case SpecialLengths.PYTHON_EXCEPTION_THROWN => 157 | throw handlePythonException() 158 | 159 | case SpecialLengths.END_OF_DATA_SECTION => 160 | handleEndOfDataSection() 161 | null 162 | } 163 | } 164 | } catch handleException 165 | } 166 | } 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/python/runner/PythonProjectRunner.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow.python.runner 2 | 3 | import java.io._ 4 | import java.util.concurrent.atomic.AtomicReference 5 | 6 | import os.SubProcess 7 | import tech.mlsql.arrow.Utils 8 | import tech.mlsql.common.utils.log.Logging 9 | import tech.mlsql.common.utils.shell.ShellCommand 10 | 11 | import scala.io.Source 12 | 13 | /** 14 | * 2019-08-22 WilliamZhu(allwefantasy@gmail.com) 15 | */ 16 | class PythonProjectRunner(projectDirectory: String, 17 | env: Map[String, String]) extends Logging { 18 | 19 | import PythonProjectRunner._ 20 | 21 | private var innerProcess: Option[SubProcess] = None 22 | 23 | def getPythonProcess = innerProcess 24 | 25 | def run(command: Seq[String], 26 | conf: Map[String, String] 27 | ) = { 28 | val proc = os.proc(command).spawn( 29 | cwd = os.Path(projectDirectory), 30 | env = env) 31 | innerProcess = Option(proc) 
32 | val (_, pythonPid) = try { 33 | val f = proc.wrapped.getClass.getDeclaredField("pid") 34 | f.setAccessible(true) 35 | val parentPid = f.getLong(proc.wrapped) 36 | val subPid = ShellCommand.execCmdV2("pgrep", "-P", parentPid).out.lines.mkString("") 37 | (parentPid, subPid) 38 | } catch { 39 | case e: Exception => 40 | logWarning( 41 | s""" 42 | |${command.mkString(" ")} may not be killed since we cannot get its pid. 43 | |Make sure you are running on mac/linux and pgrep is installed. 44 | |""".stripMargin) 45 | (-1, -1) 46 | } 47 | 48 | 49 | val lines = Source.fromInputStream(proc.stdout.wrapped)("utf-8").getLines 50 | val childThreadException = new AtomicReference[Throwable](null) 51 | // Start a thread to write the conf to the process's stdin 52 | new Thread(s"stdin writer for $command") { 53 | def writeConf = { 54 | val dataOut = new DataOutputStream(proc.stdin) 55 | dataOut.writeInt(conf.size) 56 | for ((k, v) <- conf) { 57 | Utils.writeUTF(k, dataOut) 58 | Utils.writeUTF(v, dataOut) 59 | } 60 | } 61 | 62 | override def run(): Unit = { 63 | try { 64 | writeConf 65 | } catch { 66 | case t: Throwable => childThreadException.set(t) 67 | } finally { 68 | proc.stdin.close() 69 | } 70 | } 71 | }.start() 72 | 73 | // redirect stderr elsewhere (e.g. send it to the driver) 74 | new Thread(s"stderr reader for $command") { 75 | override def run(): Unit = { 76 | if (conf.getOrElse("throwErr", "true").toBoolean) { 77 | val err = proc.stderr.lines.mkString("\n") 78 | if (!err.isEmpty) { 79 | childThreadException.set(new PythonErrException(err)) 80 | } 81 | } else { 82 | Utils.redirectStream(conf, proc.stderr) 83 | } 84 | } 85 | }.start() 86 | 87 | 88 | new Iterator[String] { 89 | def next(): String = { 90 | if (!hasNext()) { 91 | throw new NoSuchElementException() 92 | } 93 | val line = lines.next() 94 | line 95 | } 96 | 97 | def hasNext(): Boolean = { 98 | val result = if (lines.hasNext) { 99 | true 100 | } else { 101 | try { 102 | proc.waitFor() 103 | } 104 | catch { 105 | case e: InterruptedException => 106 | 0 107 | } 108 | cleanup() 109 | if (proc.exitCode() != 0) { 110 | val msg = s"Subprocess exited with status ${proc.exitCode()}. 
" + 111 | s"Command ran: " + command.mkString(" ") 112 | if(childThreadException.get()!=null){ 113 | throw childThreadException.get() 114 | }else { 115 | throw new IllegalStateException(msg) 116 | } 117 | } 118 | false 119 | } 120 | propagateChildException 121 | result 122 | } 123 | 124 | private def cleanup(): Unit = { 125 | ShellCommand.execCmdV2("kill", "-9", pythonPid + "") 126 | // cleanup task working directory if used 127 | scala.util.control.Exception.ignoring(classOf[IOException]) { 128 | if (conf.get(KEEP_LOCAL_DIR).map(_.toBoolean).getOrElse(false)) { 129 | Utils.deleteRecursively(new File(projectDirectory)) 130 | } 131 | } 132 | log.debug(s"Removed task working directory $projectDirectory") 133 | } 134 | 135 | private def propagateChildException(): Unit = { 136 | val t = childThreadException.get() 137 | if (t != null) { 138 | proc.destroy() 139 | cleanup() 140 | throw t 141 | } 142 | } 143 | 144 | } 145 | } 146 | } 147 | 148 | object PythonProjectRunner { 149 | val KEEP_LOCAL_DIR = "keepLocalDir" 150 | } 151 | 152 | class PythonErrException(message: String, cause: Throwable) 153 | extends Exception(message, cause) { 154 | 155 | def this(message: String) = this(message, null) 156 | } 157 | -------------------------------------------------------------------------------- /src/main/java/tech/mlsql/arrow/python/runner/SparkSocketRunner.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.arrow.python.runner 2 | 3 | import org.apache.arrow.vector.VectorSchemaRoot 4 | import org.apache.arrow.vector.ipc.ArrowStreamReader 5 | import org.apache.spark.sql.catalyst.InternalRow 6 | import org.apache.spark.sql.types.StructType 7 | import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} 8 | import org.apache.spark.{SparkException, TaskKilledException} 9 | import tech.mlsql.arrow._ 10 | import tech.mlsql.arrow.context.CommonTaskContext 11 | import tech.mlsql.common.utils.distribute.socket.server.SocketServerInExecutor 12 | import tech.mlsql.common.utils.log.Logging 13 | 14 | import java.io._ 15 | import java.net.Socket 16 | import java.nio.charset.StandardCharsets 17 | import scala.collection.JavaConverters._ 18 | 19 | 20 | class SparkSocketRunner(runnerName: String, host: String, timeZoneId: String) { 21 | 22 | def serveToStream(threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = { 23 | val (_server, _host, _port) = SocketServerInExecutor.setupOneConnectionServer(host, runnerName)(s => { 24 | val out = new BufferedOutputStream(s.getOutputStream()) 25 | Utils.tryWithSafeFinally { 26 | writeFunc(out) 27 | } { 28 | out.close() 29 | } 30 | }) 31 | 32 | Array(_server, _host, _port) 33 | } 34 | 35 | def serveToStreamWithArrow(iter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int, context: CommonTaskContext) = { 36 | serveToStream(runnerName) { out => 37 | val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId) 38 | val arrowBatch = ArrowConverters.toBatchIterator( 39 | iter, schema, maxRecordsPerBatch, timeZoneId, context) 40 | batchWriter.writeBatches(arrowBatch) 41 | batchWriter.end() 42 | } 43 | } 44 | 45 | def readFromStreamWithArrow(host: String, port: Int, context: CommonTaskContext) = { 46 | val socket = new Socket(host, port) 47 | val stream = new DataInputStream(socket.getInputStream) 48 | val outfile = new DataOutputStream(socket.getOutputStream) 49 | new ReaderIterator[ColumnarBatch](stream, System.currentTimeMillis(), context) { 50 | private val 
allocator = ArrowUtils.rootAllocator.newChildAllocator( 51 | s"stdin reader ", 0, Long.MaxValue) 52 | 53 | private var reader: ArrowStreamReader = _ 54 | private var root: VectorSchemaRoot = _ 55 | private var schema: StructType = _ 56 | private var vectors: Array[ColumnVector] = _ 57 | context.readerRegister(() => {})(reader, allocator) 58 | 59 | private var batchLoaded = true 60 | 61 | protected override def read(): ColumnarBatch = { 62 | try { 63 | if (reader != null && batchLoaded) { 64 | batchLoaded = reader.loadNextBatch() 65 | if (batchLoaded) { 66 | val batch = new ColumnarBatch(vectors) 67 | batch.setNumRows(root.getRowCount) 68 | batch 69 | } else { 70 | reader.close(false) 71 | allocator.close() 72 | // Reach end of stream. Call `read()` again to read control data. 73 | read() 74 | } 75 | } else { 76 | stream.readInt() match { 77 | case SpecialLengths.START_ARROW_STREAM => 78 | 79 | try { 80 | reader = new ArrowStreamReader(stream, allocator) 81 | root = reader.getVectorSchemaRoot() 82 | schema = ArrowUtils.fromArrowSchema(root.getSchema()) 83 | vectors = root.getFieldVectors().asScala.map { vector => 84 | new ArrowColumnVector(vector) 85 | }.toArray[ColumnVector] 86 | read() 87 | } catch { 88 | case e: IOException if (e.getMessage.contains("Missing schema") || e.getMessage.contains("Expected schema but header was")) => 89 | logInfo("Arrow read schema fail", e) 90 | reader = null 91 | read() 92 | } 93 | 94 | case SpecialLengths.ARROW_STREAM_CRASH => 95 | read() 96 | 97 | case SpecialLengths.PYTHON_EXCEPTION_THROWN => 98 | throw handlePythonException(outfile) 99 | 100 | case SpecialLengths.END_OF_DATA_SECTION => 101 | handleEndOfDataSection(outfile) 102 | null 103 | } 104 | } 105 | } catch handleException 106 | } 107 | }.flatMap { batch => 108 | batch.rowIterator.asScala 109 | } 110 | } 111 | 112 | } 113 | 114 | object SparkSocketRunner { 115 | 116 | } 117 | 118 | abstract class ReaderIterator[OUT]( 119 | stream: DataInputStream, 120 | startTime: Long, 121 | context: CommonTaskContext) 122 | extends Iterator[OUT] with Logging { 123 | 124 | private var nextObj: OUT = _ 125 | private var eos = false 126 | 127 | override def hasNext: Boolean = nextObj != null || { 128 | if (!eos) { 129 | nextObj = read() 130 | hasNext 131 | } else { 132 | false 133 | } 134 | } 135 | 136 | override def next(): OUT = { 137 | if (hasNext) { 138 | val obj = nextObj 139 | nextObj = null.asInstanceOf[OUT] 140 | obj 141 | } else { 142 | Iterator.empty.next() 143 | } 144 | } 145 | 146 | /** 147 | * Reads next object from the stream. 148 | * When the stream reaches end of data, needs to process the following sections, 149 | * and then returns null. 
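 *
 * Concrete implementations (see readFromStreamWithArrow above and ArrowPythonRunner) dispatch on
 * the SpecialLengths control codes: START_ARROW_STREAM begins an Arrow IPC stream,
 * ARROW_STREAM_CRASH retries the read, PYTHON_EXCEPTION_THROWN is rethrown via
 * handlePythonException, and END_OF_DATA_SECTION is acknowledged in handleEndOfDataSection.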
150 | */ 151 | protected def read(): OUT 152 | 153 | 154 | protected def handlePythonException(out: DataOutputStream): SparkException = { 155 | // Signals that an exception has been thrown in python 156 | val exLength = stream.readInt() 157 | val obj = new Array[Byte](exLength) 158 | stream.readFully(obj) 159 | try { 160 | out.writeInt(SpecialLengths.END_OF_STREAM) 161 | out.flush() 162 | } catch { 163 | case e: Exception => logError("", e) 164 | } 165 | new SparkException(new String(obj, StandardCharsets.UTF_8), null) 166 | } 167 | 168 | protected def handleEndOfStream(out: DataOutputStream): Unit = { 169 | 170 | eos = true 171 | } 172 | 173 | protected def handleEndOfDataSection(out: DataOutputStream): Unit = { 174 | //read end of stream 175 | val flag = stream.readInt() 176 | if (flag != SpecialLengths.END_OF_STREAM) { 177 | logWarning( 178 | s""" 179 | |-----------------------WARNING-------------------------------------------------------------------- 180 | |Here we should received message is SpecialLengths.END_OF_STREAM:${SpecialLengths.END_OF_STREAM} 181 | |But It's now ${flag}. 182 | |This may cause the **** python worker leak **** and make the ***interactive mode fails***. 183 | |-------------------------------------------------------------------------------------------------- 184 | """.stripMargin) 185 | } 186 | try { 187 | out.writeInt(SpecialLengths.END_OF_STREAM) 188 | out.flush() 189 | } catch { 190 | case e: Exception => logError("", e) 191 | } 192 | 193 | eos = true 194 | } 195 | 196 | protected val handleException: PartialFunction[Throwable, OUT] = { 197 | case e: Exception if context.isTaskInterrupt()() => 198 | logDebug("Exception thrown after task interruption", e) 199 | throw new TaskKilledException(context.getTaskKillReason()().getOrElse("unknown reason")) 200 | case e: SparkException => 201 | throw e 202 | case eof: EOFException => 203 | throw new SparkException("Python worker exited unexpectedly (crashed)", eof) 204 | 205 | case e: Exception => 206 | throw new SparkException("Error to read", e) 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /src/test/java/tech/mlsql/test/ApplyPythonScript.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.test 2 | 3 | import org.apache.spark.{TaskContext, WowRowEncoder} 4 | import org.apache.spark.sql.Row 5 | import org.apache.spark.sql.catalyst.InternalRow 6 | import org.apache.spark.sql.catalyst.encoders.RowEncoder 7 | import org.apache.spark.sql.types.StructType 8 | import tech.mlsql.arrow.python.ispark.SparkContextImp 9 | import tech.mlsql.arrow.python.runner.{ArrowPythonRunner, ChainedPythonFunctions, PythonFunction} 10 | 11 | import scala.collection.JavaConverters._ 12 | 13 | /** 14 | * 19/4/2021 WilliamZhu(allwefantasy@gmail.com) 15 | */ 16 | class ApplyPythonScript(_rayAddress: String, _envs: java.util.HashMap[String, String], _timezoneId: String, _pythonMode: String = "ray") { 17 | def execute(script: String, struct: StructType): Iterator[Row] => Iterator[InternalRow] = { 18 | // val rayAddress = _rayAddress 19 | val envs = _envs 20 | val timezoneId = _timezoneId 21 | val pythonMode = _pythonMode 22 | 23 | iter => 24 | val encoder = WowRowEncoder.fromRow(struct) 25 | //RowEncoder.apply(struct).resolveAndBind() 26 | val batch = new ArrowPythonRunner( 27 | Seq(ChainedPythonFunctions(Seq(PythonFunction( 28 | script, envs, "python", "3.6")))), struct, 29 | timezoneId, Map("pythonMode" -> pythonMode) 30 | ) 31 | val 
newIter = iter.map(encoder) 32 | val commonTaskContext = new SparkContextImp(TaskContext.get(), batch) 33 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext) 34 | columnarBatchIter.flatMap(_.rowIterator.asScala).map(f => f.copy()) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/test/java/tech/mlsql/test/JavaApp1Spec.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.test 2 | 3 | import java.util 4 | 5 | import org.apache.spark.sql.Row 6 | import org.apache.spark.sql.types.{ArrayType, DoubleType, StructField, StructType} 7 | import org.apache.spark.{TaskContext, WowRowEncoder} 8 | import org.scalatest.{BeforeAndAfterAll, FunSuite} 9 | import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext} 10 | import tech.mlsql.arrow.python.runner.{ArrowPythonRunner, ChainedPythonFunctions, PythonConf, PythonFunction} 11 | import tech.mlsql.common.utils.lang.sc.ScalaMethodMacros.str 12 | 13 | import scala.collection.JavaConverters._ 14 | 15 | /** 16 | * 2019-08-15 WilliamZhu(allwefantasy@gmail.com) 17 | */ 18 | class JavaApp1Spec extends FunSuite 19 | with BeforeAndAfterAll { 20 | 21 | def condaEnv = "source /Users/allwefantasy/opt/anaconda3/bin/activate ray-dev" 22 | 23 | test("normal java application") { 24 | val envs = new util.HashMap[String, String]() 25 | envs.put(str(PythonConf.PYTHON_ENV), s"${condaEnv} && export ARROW_PRE_0_15_IPC_FORMAT=1 ") 26 | val sourceSchema = StructType(Seq(StructField("value", ArrayType(DoubleType)))) 27 | 28 | val runnerConf = Map( 29 | "HOME" -> "", 30 | "OWNER" -> "", 31 | "GROUP_ID" -> "", 32 | "directData" -> "true", 33 | "runIn" -> "driver", 34 | "mode" -> "model", 35 | "rayAddress" -> "127.0.0.1:10001", 36 | "pythonMode" -> "ray" 37 | ) 38 | 39 | val batch = new ArrowPythonRunner( 40 | Seq(ChainedPythonFunctions(Seq(PythonFunction( 41 | """ 42 | |import ray 43 | |from pyjava.api.mlsql import RayContext 44 | |ray_context = RayContext.connect(globals(), context.conf["rayAddress"],namespace="default") 45 | |udfMaster = ray.get_actor("model_predict") 46 | |[index,worker] = ray.get(udfMaster.get.remote()) 47 | | 48 | |input = [row["value"] for row in context.fetch_once_as_rows()] 49 | |try: 50 | | res = ray.get(worker.apply.remote(input)) 51 | |except Exception as inst: 52 | | res=[] 53 | | print(inst) 54 | | 55 | |udfMaster.give_back.remote(index) 56 | |ray_context.build_result([res]) 57 | """.stripMargin, envs, "python", "3.6")))), sourceSchema, 58 | "GMT", runnerConf 59 | ) 60 | 61 | 62 | val sourceEnconder = WowRowEncoder.fromRow(sourceSchema) //RowEncoder.apply(sourceSchema).resolveAndBind() 63 | val newIter = Seq(Row.fromSeq(Seq(Seq(1.1))), Row.fromSeq(Seq(Seq(1.2)))).map { irow => 64 | sourceEnconder(irow).copy() 65 | }.iterator 66 | 67 | val javaConext = new JavaContext 68 | val commonTaskContext = new AppContextImpl(javaConext, batch) 69 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext) 70 | val outputSchema = StructType(Seq(StructField("value", ArrayType(ArrayType(DoubleType))))) 71 | val outputEnconder = WowRowEncoder.toRow(outputSchema) 72 | columnarBatchIter.flatMap { batch => 73 | batch.rowIterator.asScala 74 | }.map { r => 75 | outputEnconder(r) 76 | }.toList.foreach(item => println(item.getAs[Seq[Seq[DoubleType]]](0))) 77 | 78 | javaConext.markComplete 79 | javaConext.close 80 | } 81 | 82 | } 83 | 
-------------------------------------------------------------------------------- /src/test/java/tech/mlsql/test/JavaAppSpec.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.test 2 | 3 | import java.util 4 | 5 | import org.apache.spark.{TaskContext, WowRowEncoder} 6 | import org.apache.spark.sql.Row 7 | import org.apache.spark.sql.catalyst.encoders.RowEncoder 8 | import org.apache.spark.sql.types.{StringType, StructField, StructType} 9 | import org.scalatest.{BeforeAndAfterAll, FunSuite} 10 | import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext} 11 | import tech.mlsql.arrow.python.runner.{ArrowPythonRunner, ChainedPythonFunctions, PythonConf, PythonFunction} 12 | import tech.mlsql.common.utils.lang.sc.ScalaMethodMacros.str 13 | 14 | import scala.collection.JavaConverters._ 15 | 16 | /** 17 | * 2019-08-15 WilliamZhu(allwefantasy@gmail.com) 18 | */ 19 | class JavaAppSpec extends FunSuite 20 | with BeforeAndAfterAll { 21 | 22 | def condaEnv = "source /Users/allwefantasy/opt/anaconda3/bin/activate dev" 23 | 24 | test("normal java application") { 25 | val envs = new util.HashMap[String, String]() 26 | envs.put(str(PythonConf.PYTHON_ENV), s"${condaEnv} && export ARROW_PRE_0_15_IPC_FORMAT=1 ") 27 | val sourceSchema = StructType(Seq(StructField("value", StringType))) 28 | val batch = new ArrowPythonRunner( 29 | Seq(ChainedPythonFunctions(Seq(PythonFunction( 30 | """ 31 | |import pandas as pd 32 | |import numpy as np 33 | | 34 | |def process(): 35 | | for item in context.fetch_once_as_rows(): 36 | | item["value1"] = item["value"] + "_suffix" 37 | | yield item 38 | | 39 | |context.build_result(process()) 40 | """.stripMargin, envs, "python", "3.6")))), sourceSchema, 41 | "GMT", Map() 42 | ) 43 | 44 | 45 | val sourceEnconder = WowRowEncoder.fromRow(sourceSchema) //RowEncoder.apply(sourceSchema).resolveAndBind() 46 | val newIter = Seq(Row.fromSeq(Seq("a1")), Row.fromSeq(Seq("a2"))).map { irow => 47 | sourceEnconder(irow).copy() 48 | }.iterator 49 | 50 | val javaConext = new JavaContext 51 | val commonTaskContext = new AppContextImpl(javaConext, batch) 52 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext) 53 | //copy is required 54 | columnarBatchIter.flatMap { batch => 55 | batch.rowIterator.asScala 56 | }.foreach(f => println(f.copy())) 57 | javaConext.markComplete 58 | javaConext.close 59 | } 60 | 61 | } 62 | -------------------------------------------------------------------------------- /src/test/java/tech/mlsql/test/JavaArrowServer.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.test 2 | 3 | import org.apache.spark.WowRowEncoder 4 | import org.apache.spark.sql.Row 5 | import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType} 6 | import org.scalatest.{BeforeAndAfterAll, FunSuite} 7 | import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext} 8 | import tech.mlsql.arrow.python.runner.SparkSocketRunner 9 | import tech.mlsql.common.utils.network.NetUtils 10 | 11 | /** 12 | * 24/12/2019 WilliamZhu(allwefantasy@gmail.com) 13 | */ 14 | class JavaArrowServer extends FunSuite with BeforeAndAfterAll { 15 | 16 | test("test java arrow server") { 17 | val socketRunner = new SparkSocketRunner("wow", NetUtils.getHost, "Asia/Harbin") 18 | 19 | val dataSchema = StructType(Seq(StructField("value", StringType))) 20 | val encoder = WowRowEncoder.fromRow(dataSchema) 
//RowEncoder.apply(dataSchema).resolveAndBind() 21 | val newIter = Seq(Row.fromSeq(Seq("a1")), Row.fromSeq(Seq("a2"))).map { irow => 22 | encoder(irow) 23 | }.iterator 24 | val javaConext = new JavaContext 25 | val commonTaskContext = new AppContextImpl(javaConext, null) 26 | 27 | val Array(_, host, port) = socketRunner.serveToStreamWithArrow(newIter, dataSchema, 10, commonTaskContext) 28 | println(s"${host}:${port}") 29 | Thread.currentThread().join() 30 | } 31 | 32 | test("test read python arrow server") { 33 | val enconder = WowRowEncoder.toRow(StructType(Seq(StructField("a", LongType),StructField("b", LongType)))) 34 | val socketRunner = new SparkSocketRunner("wow", NetUtils.getHost, "Asia/Harbin") 35 | val javaConext = new JavaContext 36 | val commonTaskContext = new AppContextImpl(javaConext, null) 37 | val iter = socketRunner.readFromStreamWithArrow("127.0.0.1", 11111, commonTaskContext) 38 | iter.foreach(i => println(enconder(i.copy()))) 39 | javaConext.close 40 | } 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/test/java/tech/mlsql/test/Main.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.test 2 | 3 | import tech.mlsql.arrow.python.runner.PythonProjectRunner 4 | import tech.mlsql.common.utils.path.PathFun 5 | 6 | /** 7 | * 4/3/2020 WilliamZhu(allwefantasy@gmail.com) 8 | */ 9 | object Main { 10 | def main(args: Array[String]): Unit = { 11 | val project = getExampleProject("pyproject1") 12 | val runner = new PythonProjectRunner(project, Map()) 13 | val output = runner.run(Seq("bash", "-c", "source activate dev && python -u train.py"), Map( 14 | "tempDataLocalPath" -> "/tmp/data", 15 | "tempModelLocalPath" -> "/tmp/model" 16 | )) 17 | output.foreach(println) 18 | } 19 | def getExampleProject(name: String) = { 20 | PathFun(getHome).add("examples").add(name).toPath 21 | } 22 | 23 | def getHome = { 24 | getClass.getResource("").getPath.split("target/test-classes").head 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/test/java/tech/mlsql/test/PythonProjectSpec.scala: -------------------------------------------------------------------------------- 1 | package tech.mlsql.test 2 | 3 | import org.scalatest.{BeforeAndAfterAll, FunSuite} 4 | import tech.mlsql.arrow.python.runner.PythonProjectRunner 5 | import tech.mlsql.common.utils.path.PathFun 6 | 7 | /** 8 | * 2019-08-22 WilliamZhu(allwefantasy@gmail.com) 9 | */ 10 | class PythonProjectSpec extends FunSuite with BeforeAndAfterAll { 11 | test("test python project") { 12 | val project = getExampleProject("pyproject1") 13 | val runner = new PythonProjectRunner(project, Map()) 14 | val output = runner.run(Seq("bash", "-c", "source activate dev && python -u train.py"), Map( 15 | "tempDataLocalPath" -> "/tmp/data", 16 | "tempModelLocalPath" -> "/tmp/model" 17 | )) 18 | output.foreach(println) 19 | } 20 | 21 | def getExampleProject(name: String) = { 22 | PathFun(getHome).add("examples").add(name).toPath 23 | } 24 | 25 | def getHome = { 26 | getClass.getResource("").getPath.split("target/test-classes").head 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/test/java/tech/mlsql/test/RayEnv.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2016 Kyligence Inc. All rights reserved. 
3 | * 4 | * http://kyligence.io 5 | * 6 | * This software is the confidential and proprietary information of 7 | * Kyligence Inc. ("Confidential Information"). You shall not disclose 8 | * such Confidential Information and shall use it only in accordance 9 | * with the terms of the license agreement you entered into with 10 | * Kyligence Inc. 11 | * 12 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 13 | * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 14 | * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 15 | * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 16 | * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 17 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 18 | * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 19 | * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 20 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 22 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | */ 24 | package tech.mlsql.test 25 | 26 | import java.io.File 27 | import java.nio.charset.StandardCharsets 28 | import java.util.TimeZone 29 | import java.util.regex.Pattern 30 | 31 | import com.google.common.io.Files 32 | import org.apache.spark.{TaskContext, WowRowEncoder} 33 | import org.apache.spark.rdd.RDD 34 | import org.apache.spark.sql.catalyst.InternalRow 35 | import org.apache.spark.sql.{DataFrame, Row} 36 | import os.CommandResult 37 | import tech.mlsql.arrow.python.ispark.SparkContextImp 38 | import tech.mlsql.arrow.python.runner.SparkSocketRunner 39 | import tech.mlsql.common.utils.log.Logging 40 | import tech.mlsql.common.utils.net.NetTool 41 | import tech.mlsql.common.utils.network.NetUtils 42 | import tech.mlsql.common.utils.shell.ShellCommand 43 | import tech.mlsql.test.RayEnv.ServerInfo 44 | 45 | class RayEnv extends Logging with Serializable { 46 | 47 | @transient var rayAddress: String = _ 48 | @transient var rayOptions: Map[String, String] = Map.empty 49 | 50 | @transient var dataServers: Seq[ServerInfo] = _ 51 | 52 | def startRay(envName: String): Unit = { 53 | 54 | val rayVersionRes = runWithProfile(envName, "ray --version") 55 | val rayVersion = rayVersionRes.out.lines.toList.head.split("version").last.trim 56 | 57 | val result = if (rayVersion >= "1.0.0" || rayVersion == "0.8.7") { 58 | runWithProfile(envName, "ray start --head") 59 | } else { 60 | runWithProfile(envName, "ray start --head --include-webui") 61 | } 62 | val initRegex = Pattern.compile(".*ray.init\\((.*)\\).*") 63 | val errResultLines = result.err.lines.toList 64 | val outResultLines = result.out.lines.toList 65 | 66 | errResultLines.foreach(logInfo(_)) 67 | outResultLines.foreach(logInfo(_)) 68 | 69 | (errResultLines ++ outResultLines).foreach(line => { 70 | val matcher = initRegex.matcher(line) 71 | if (matcher.matches()) { 72 | val params = matcher.group(1) 73 | 74 | 75 | val options = params.split(",") 76 | .filter(_.contains("=")).map(param => { 77 | val words = param.trim.split("=") 78 | (words.head, words(1).substring(1, words(1).length - 1)) 79 | }).toMap 80 | 81 | 82 | if (rayAddress == null) { 83 | rayAddress = options.getOrElse("address", options.getOrElse("redis_address", null)) 84 | rayOptions = options - "address" - "redis_address" 85 | logInfo(s"Start Ray:${rayAddress}") 86 | } 87 | } 88 | }) 89 | 90 | if (rayVersion 
>= "1.0.0") { 91 | val initRegex2 = Pattern.compile(".*--address='(.*?)'.*") 92 | (errResultLines ++ outResultLines).foreach(line => { 93 | val matcher = initRegex2.matcher(line) 94 | if (matcher.matches()) { 95 | val params = matcher.group(1) 96 | 97 | val host = params.split(":").head 98 | rayAddress = host + ":10001" 99 | rayOptions = rayOptions 100 | new Thread(new Runnable { 101 | override def run(): Unit = { 102 | runWithProfile(envName, 103 | s""" 104 | |python -m ray.util.client.server --host ${host} --port 10001 105 | |""".stripMargin) 106 | } 107 | }).start() 108 | Thread.sleep(3000) 109 | logInfo(s"Start Ray:${rayAddress}") 110 | } 111 | }) 112 | } 113 | 114 | 115 | if (rayAddress == null) { 116 | throw new RuntimeException("Fail to start ray") 117 | } 118 | 119 | 120 | } 121 | 122 | def stopRay(envName: String): Unit = { 123 | val result = runWithProfile(envName, "ray stop") 124 | result 125 | } 126 | 127 | def startDataServer(df: DataFrame): Unit = { 128 | val dataSchema = df.schema 129 | dataServers = df.repartition(1).rdd.mapPartitions(iter => { 130 | val socketServer = new SparkSocketRunner("serve-runner-for-ut", NetTool.localHostName(), TimeZone.getDefault.getID) 131 | val commonTaskContext = new SparkContextImp(TaskContext.get(), null) 132 | val rab = WowRowEncoder.fromRow(dataSchema) //RowEncoder.apply(dataSchema).resolveAndBind() 133 | val newIter = iter.map(row => { 134 | rab(row) 135 | }) 136 | val Array(_server, _host: String, _port: Int) = socketServer.serveToStreamWithArrow(newIter, dataSchema, 10, commonTaskContext) 137 | Seq(ServerInfo(_host, _port, TimeZone.getDefault.getID)).iterator 138 | }).collect().toSeq 139 | } 140 | 141 | def collectResult(rdd: RDD[Row]): RDD[InternalRow] = { 142 | rdd.flatMap { row => 143 | val socketRunner = new SparkSocketRunner("read-runner-for-ut", NetUtils.getHost, TimeZone.getDefault.getID) 144 | val commonTaskContext = new SparkContextImp(TaskContext.get(), null) 145 | val pythonWorkerHost = row.getAs[String]("host") 146 | val pythonWorkerPort = row.getAs[Long]("port").toInt 147 | logInfo(s" Ray On Data Mode: connect python worker[${pythonWorkerHost}:${pythonWorkerPort}] ") 148 | val iter = socketRunner.readFromStreamWithArrow(pythonWorkerHost, pythonWorkerPort, commonTaskContext) 149 | iter.map(f => f.copy()) 150 | } 151 | } 152 | 153 | 154 | private def runWithProfile(envName: String, command: String): CommandResult = { 155 | val tmpShellFile = File.createTempFile("shell", ".sh") 156 | val setupEnv = if (envName.trim.startsWith("conda") || envName.trim.startsWith("source")) envName else s"""conda activate ${envName}""" 157 | try { 158 | Files.write( 159 | s""" 160 | |#!/bin/bash 161 | |export LC_ALL=en_US.utf-8 162 | |export LANG=en_US.utf-8 163 | | 164 | |# source ~/.bash_profile 165 | |${setupEnv} 166 | |${command} 167 | |""".stripMargin, tmpShellFile, StandardCharsets.UTF_8) 168 | val cmdResult = ShellCommand.execCmdV2("/bin/bash", tmpShellFile.getAbsolutePath) 169 | // if (cmdResult.exitCode != 0) { 170 | // throw new RuntimeException(s"run command failed ${cmdResult.toString()}") 171 | // } 172 | cmdResult 173 | } finally { 174 | tmpShellFile.delete() 175 | } 176 | } 177 | 178 | } 179 | 180 | object RayEnv { 181 | 182 | case class ServerInfo(host: String, port: Long, timezone: String) 183 | 184 | } -------------------------------------------------------------------------------- /src/test/java/tech/mlsql/test/SparkSpec.scala: -------------------------------------------------------------------------------- 1 | package 
tech.mlsql.test 2 | 3 | import java.util 4 | 5 | import org.apache.spark.sql.{Row, SparkSession, SparkUtils} 6 | import org.scalatest.{BeforeAndAfterAll, FunSuite} 7 | import tech.mlsql.common.utils.log.Logging 8 | import tech.mlsql.test.function.SparkFunctions.MockData 9 | 10 | /** 11 | * 2019-08-14 WilliamZhu(allwefantasy@gmail.com) 12 | */ 13 | class SparkSpec extends FunSuite with BeforeAndAfterAll with Logging{ 14 | 15 | val rayEnv = new RayEnv 16 | var spark: SparkSession = null 17 | 18 | def condaEnv = "source /Users/allwefantasy/opt/anaconda3/bin/activate ray1.2" 19 | 20 | //spark.executor.heartbeatInterval 21 | test("test python ray connect") { 22 | val session = spark 23 | import session.implicits._ 24 | val timezoneId = session.sessionState.conf.sessionLocalTimeZone 25 | 26 | val dataDF = session.createDataset(Range(0, 100).map(i => MockData(s"Title${i}", s"body-${i}"))).toDF() 27 | rayEnv.startDataServer(dataDF) 28 | 29 | val df = session.createDataset(rayEnv.dataServers).toDF() 30 | 31 | val envs = new util.HashMap[String, String]() 32 | envs.put("PYTHON_ENV", s"${condaEnv};export ARROW_PRE_0_15_IPC_FORMAT=1") 33 | //envs.put("PYTHONPATH", (os.pwd / "python").toString()) 34 | 35 | val aps = new ApplyPythonScript(rayEnv.rayAddress, envs, timezoneId) 36 | val rayAddress = rayEnv.rayAddress 37 | logInfo(rayAddress) 38 | val func = aps.execute( 39 | s""" 40 | |import ray 41 | |import time 42 | |from pyjava.api.mlsql import RayContext 43 | |import numpy as np; 44 | |ray_context = RayContext.connect(globals(),"${rayAddress}") 45 | |def echo(row): 46 | | row1 = {} 47 | | row1["title"]=row['title'][1:] 48 | | row1["body"]= row["body"] + ',' + row["body"] 49 | | return row1 50 | |ray_context.foreach(echo) 51 | """.stripMargin, df.schema) 52 | 53 | val outputDF = df.rdd.mapPartitions(func) 54 | 55 | val pythonServers = SparkUtils.internalCreateDataFrame(session, outputDF, df.schema).collect() 56 | 57 | val rdd = session.sparkContext.makeRDD[Row](pythonServers, numSlices = pythonServers.length) 58 | val pythonOutputDF = rayEnv.collectResult(rdd) 59 | val output = SparkUtils.internalCreateDataFrame(session, pythonOutputDF, dataDF.schema).collect() 60 | assert(output.length == 100) 61 | output.zipWithIndex.foreach({ 62 | case (row, index) => 63 | assert(row.getString(0) == s"itle${index}") 64 | assert(row.getString(1) == s"body-${index},body-${index}") 65 | }) 66 | } 67 | 68 | override def beforeAll(): Unit = { 69 | spark = SparkSession.builder().master("local[*]").appName("test").getOrCreate() 70 | super.beforeAll() 71 | rayEnv.startRay(condaEnv) 72 | } 73 | 74 | override def afterAll(): Unit = { 75 | if (spark != null) { 76 | spark.sparkContext.stop() 77 | } 78 | rayEnv.stopRay(condaEnv) 79 | super.afterAll() 80 | } 81 | 82 | } 83 | -------------------------------------------------------------------------------- /src/test/java/tech/mlsql/test/function/SparkFunctions.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2016 Kyligence Inc. All rights reserved. 3 | * 4 | * http://kyligence.io 5 | * 6 | * This software is the confidential and proprietary information of 7 | * Kyligence Inc. ("Confidential Information"). You shall not disclose 8 | * such Confidential Information and shall use it only in accordance 9 | * with the terms of the license agreement you entered into with 10 | * Kyligence Inc. 
11 | * 12 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 13 | * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 14 | * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 15 | * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 16 | * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 17 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 18 | * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 19 | * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 20 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 22 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | */ 24 | package tech.mlsql.test.function 25 | 26 | import org.apache.spark.{TaskContext, WowRowEncoder} 27 | import org.apache.spark.sql.Row 28 | import org.apache.spark.sql.catalyst.InternalRow 29 | import org.apache.spark.sql.types.StructType 30 | import tech.mlsql.arrow.python.ispark.SparkContextImp 31 | import tech.mlsql.arrow.python.runner.{ArrowPythonRunner, ChainedPythonFunctions, PythonConf, PythonFunction} 32 | import tech.mlsql.common.utils.lang.sc.ScalaMethodMacros.str 33 | import java.util 34 | 35 | import scala.collection.JavaConverters.asScalaIteratorConverter 36 | 37 | object SparkFunctions { 38 | 39 | case class MockData(title: String, body: String) 40 | 41 | def testScript1(struct: StructType, rayAddress: String, timezoneId: String): Iterator[Row] => Iterator[InternalRow] = { 42 | iter => 43 | val encoder = WowRowEncoder.fromRow(struct) 44 | val envs = new util.HashMap[String, String]() 45 | envs.put(str(PythonConf.PYTHON_ENV), "source ~/.bash_profile && conda activate dev && export ARROW_PRE_0_15_IPC_FORMAT=1") 46 | envs.put("PYTHONPATH", (os.pwd / "python").toString()) 47 | val batch = new ArrowPythonRunner( 48 | Seq(ChainedPythonFunctions(Seq(PythonFunction( 49 | s""" 50 | |import ray 51 | |import time 52 | |from pyjava.api.mlsql import RayContext 53 | |import numpy as np; 54 | |ray_context = RayContext.connect(globals(),"${rayAddress}") 55 | |def echo(row): 56 | | row1 = {} 57 | | row1["title"]=row['title'][1:] 58 | | row1["body"]= row["body"] + ',' + row["body"] 59 | | return row1 60 | |ray_context.foreach(echo) 61 | """.stripMargin, envs, "python", "3.6")))), struct, 62 | timezoneId, Map("pythonMode" -> "ray") 63 | ) 64 | val newIter = iter.map(encoder) 65 | val commonTaskContext = new SparkContextImp(TaskContext.get(), batch) 66 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext) 67 | columnarBatchIter.flatMap(_.rowIterator.asScala).map(f => f.copy()) 68 | } 69 | } 70 | --------------------------------------------------------------------------------
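A minimal round-trip sketch of the WowRowEncoder convention used throughout the specs above (the object name below is illustrative and not part of the project): fromRow(schema) yields the Row => InternalRow converter that feeds ArrowPythonRunner and serveToStreamWithArrow, while toRow(schema) yields the InternalRow => Row converter used to decode rows coming back from Python.

import org.apache.spark.WowRowEncoder
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object EncoderRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("value", StringType)))
    val toInternal = WowRowEncoder.fromRow(schema) // Row => InternalRow
    val toExternal = WowRowEncoder.toRow(schema)   // InternalRow => Row
    // .copy() mirrors the specs above, which note that copying is required when rows are buffered
    val internal = toInternal(Row.fromSeq(Seq("a1"))).copy()
    println(toExternal(internal)) // expected to print something like [a1]
  }
}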