├── .gitattributes
├── .gitignore
├── .repo
├── 243
│ └── source
│ │ └── org
│ │ └── apache
│ │ └── spark
│ │ └── WowRowEncoder.scala
├── 311
│ └── source
│ │ └── org
│ │ └── apache
│ │ └── spark
│ │ └── WowRowEncoder.scala
├── .DS_Store
└── pom.template.xml
├── README.md
├── dev
├── change-scala-version.sh
├── change-version-to-2.11.sh
├── change-version-to-2.12.sh
└── release.sh
├── examples
└── pyproject1
│ └── train.py
├── pom.xml
├── python
├── README.md
├── install.sh
├── pyjava
│ ├── __init__.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── mlsql.py
│ │ └── serve.py
│ ├── cache
│ │ ├── __init__.py
│ │ └── code_cache.py
│ ├── cloudpickle.py
│ ├── daemon.py
│ ├── datatype
│ │ ├── __init__.py
│ │ └── types.py
│ ├── example
│ │ ├── OnceServerExample.py
│ │ ├── RayServerExample.py
│ │ ├── __init__.py
│ │ ├── test.py
│ │ ├── test2.py
│ │ └── test3.py
│ ├── rayfix.py
│ ├── serializers.py
│ ├── storage
│ │ ├── __init__.py
│ │ └── streaming_tar.py
│ ├── tests
│ │ └── test_context.py
│ ├── udf
│ │ └── __init__.py
│ ├── utils.py
│ ├── version.py
│ └── worker.py
├── requirements.txt
├── setup.cfg
└── setup.py
└── src
├── main
└── java
│ ├── org
│ └── apache
│ │ └── spark
│ │ ├── WowRowEncoder.scala
│ │ └── sql
│ │ └── SparkUtils.scala
│ └── tech
│ └── mlsql
│ └── arrow
│ ├── ArrowConverters.scala
│ ├── ArrowUtils.scala
│ ├── ArrowWriter.scala
│ ├── ByteBufferOutputStream.scala
│ ├── Utils.scala
│ ├── api
│ └── RedirectStreams.scala
│ ├── context
│ └── CommonTaskContext.scala
│ ├── javadoc.java
│ └── python
│ ├── PyJavaException.scala
│ ├── PythonWorkerFactory.scala
│ ├── iapp
│ └── AppContextImpl.scala
│ ├── ispark
│ └── SparkContextImp.scala
│ └── runner
│ ├── ArrowPythonRunner.scala
│ ├── PythonProjectRunner.scala
│ ├── PythonRunner.scala
│ └── SparkSocketRunner.scala
└── test
└── java
└── tech
└── mlsql
└── test
├── ApplyPythonScript.scala
├── JavaApp1Spec.scala
├── JavaAppSpec.scala
├── JavaArrowServer.scala
├── Main.scala
├── PythonProjectSpec.scala
├── RayEnv.scala
├── SparkSpec.scala
└── function
└── SparkFunctions.scala
/.gitattributes:
--------------------------------------------------------------------------------
1 | pom.xml merge=ours
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | pyjava.iml
3 | python/build/
4 | python/dist/
5 | python/pyjava.egg-info/
6 | target/
7 | python/.eggs/
8 | __pycache__
9 | spark-warehouse
10 |
11 | *.iml
--------------------------------------------------------------------------------
/.repo/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/pyjava/83112f57ce8e3a87b6a5a59e6b33d2f9004caf60/.repo/.DS_Store
--------------------------------------------------------------------------------
/.repo/243/source/org/apache/spark/WowRowEncoder.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * DO NOT EDIT THIS FILE DIRECTLY, ANY CHANGE MAY BE OVERWRITTEN
3 | */
4 | package org.apache.spark
5 |
6 | import org.apache.spark.sql.Row
7 | import org.apache.spark.sql.catalyst.InternalRow
8 | import org.apache.spark.sql.catalyst.encoders.RowEncoder
9 | import org.apache.spark.sql.types.StructType
10 |
11 | object WowRowEncoder {
12 | def toRow(schema: StructType) = {
13 | val rab = RowEncoder.apply(schema).resolveAndBind()
14 | (irow: InternalRow) => {
15 | rab.fromRow(irow)
16 | }
17 | }
18 |
19 | def fromRow(schema: StructType) = {
20 | val rab = RowEncoder.apply(schema).resolveAndBind()
21 | (row: Row) => {
22 | rab.toRow(row)
23 | }
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/.repo/311/source/org/apache/spark/WowRowEncoder.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * DO NOT EDIT THIS FILE DIRECTLY, ANY CHANGE MAY BE OVERWRITTEN
3 | */
4 | package org.apache.spark
5 | import org.apache.spark.sql.Row
6 | import org.apache.spark.sql.catalyst.InternalRow
7 | import org.apache.spark.sql.catalyst.encoders.RowEncoder
8 | import org.apache.spark.sql.types.StructType
9 |
10 | object WowRowEncoder {
11 | def toRow(schema: StructType) = {
12 | RowEncoder.apply(schema).resolveAndBind().createDeserializer()
13 |
14 | }
15 |
16 | def fromRow(schema: StructType) = {
17 | RowEncoder.apply(schema).resolveAndBind().createSerializer()
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## PyJava
2 |
3 | This library is an ongoing effort to bring data-exchange capability
4 | between Java/Scala and Python. PyJava uses Apache Arrow as the exchange format,
5 | which means serialization/deserialization between Java/Scala and Python can be avoided
6 | and communication is much faster than with the traditional approach.
7 |
8 | When you invoke python code from the Java/Scala side, PyJava automatically starts some python workers,
9 | sends the data to them, and sends the results back once they are processed. The python workers are reused
10 | by default.
11 |
12 |
13 | > The initial code in this lib is from Apache Spark.
14 |
15 |
16 | ## Install
17 |
18 | Set up a python (>= 3.6) environment (Conda is recommended):
19 |
20 | ```shell
21 | pip uninstall pyjava && pip install pyjava
22 | ```
23 |
24 | Set up the Java environment (Maven is recommended):
25 |
26 | For Scala 2.11/Spark 2.4.3
27 |
28 | ```xml
29 | <dependency>
30 |     <groupId>tech.mlsql</groupId>
31 |     <artifactId>pyjava-2.4_2.11</artifactId>
32 |     <version>0.3.2</version>
33 | </dependency>
34 | ```
35 |
36 | For Scala 2.12/Spark 3.1.1
37 |
38 | ```xml
39 | <dependency>
40 |     <groupId>tech.mlsql</groupId>
41 |     <artifactId>pyjava-3.0_2.12</artifactId>
42 |     <version>0.3.2</version>
43 | </dependency>
44 | ```
45 |
46 | ## Build Manually
47 |
48 | Install Build Tool:
49 |
50 | ```
51 | pip install mlsql_plugin_tool
52 | ```
53 |
54 | Build for Spark 3.1.1:
55 |
56 | ```
57 | mlsql_plugin_tool spark311
58 | mvn clean install -DskipTests -Pdisable-java8-doclint -Prelease-sign-artifacts
59 | ```
60 |
61 | Build For Spark 2.4.3
62 |
63 | ```
64 | mlsql_plugin_tool spark243
65 | mvn clean install -DskipTests -Pdisable-java8-doclint -Prelease-sign-artifacts
66 | ```
67 |
68 |
69 | ## Using python code snippet to process data in Java/Scala
70 |
71 | With pyjava, you can run any python code in your Java/Scala application.
72 |
73 | ```scala
74 |
75 | val envs = new util.HashMap[String, String]()
76 | // prepare python environment
77 | envs.put(str(PythonConf.PYTHON_ENV), "source activate dev && export ARROW_PRE_0_15_IPC_FORMAT=1 ")
78 |
79 | // describe the data which will be transferred to python
80 | val sourceSchema = StructType(Seq(StructField("value", StringType)))
81 |
82 | val batch = new ArrowPythonRunner(
83 | Seq(ChainedPythonFunctions(Seq(PythonFunction(
84 | """
85 | |import pandas as pd
86 | |import numpy as np
87 | |
88 | |def process():
89 | | for item in context.fetch_once_as_rows():
90 | | item["value1"] = item["value"] + "_suffix"
91 | | yield item
92 | |
93 | |context.build_result(process())
94 | """.stripMargin, envs, "python", "3.6")))), sourceSchema,
95 | "GMT", Map()
96 | )
97 |
98 | // prepare data
99 | val sourceEnconder = RowEncoder.apply(sourceSchema).resolveAndBind()
100 | val newIter = Seq(Row.fromSeq(Seq("a1")), Row.fromSeq(Seq("a2"))).map { irow =>
101 | sourceEnconder.toRow(irow).copy()
102 | }.iterator
103 |
104 | // run the code and get the return result
105 | val javaConext = new JavaContext
106 | val commonTaskContext = new AppContextImpl(javaConext, batch)
107 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext)
108 |
109 | // f.copy() is required because the underlying row object is reused by the iterator
110 | columnarBatchIter.flatMap { batch =>
111 | batch.rowIterator.asScala
112 | }.foreach(f => println(f.copy()))
113 | javaConext.markComplete
114 | javaConext.close
115 | ```
116 |
117 | ## Using python code snippet to process data in Spark
118 |
119 | ```scala
120 | val session = spark
121 | import session.implicits._
122 | val timezoneid = session.sessionState.conf.sessionLocalTimeZone
123 | val df = session.createDataset[String](Seq("a1", "b1")).toDF("value")
124 | val struct = df.schema
125 | val abc = df.rdd.mapPartitions { iter =>
126 | val enconder = RowEncoder.apply(struct).resolveAndBind()
127 | val envs = new util.HashMap[String, String]()
128 | envs.put(str(PythonConf.PYTHON_ENV), "source activate streamingpro-spark-2.4.x")
129 | val batch = new ArrowPythonRunner(
130 | Seq(ChainedPythonFunctions(Seq(PythonFunction(
131 | """
132 | |import pandas as pd
133 | |import numpy as np
134 | |for item in data_manager.fetch_once():
135 | | print(item)
136 | |df = pd.DataFrame({'AAA': [4, 5, 6, 7],'BBB': [10, 20, 30, 40],'CCC': [100, 50, -30, -50]})
137 | |data_manager.set_output([[df['AAA'],df['BBB']]])
138 | """.stripMargin, envs, "python", "3.6")))), struct,
139 | timezoneid, Map()
140 | )
141 | val newIter = iter.map { irow =>
142 | enconder.toRow(irow)
143 | }
144 | val commonTaskContext = new SparkContextImp(TaskContext.get(), batch)
145 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext)
146 | columnarBatchIter.flatMap { batch =>
147 | batch.rowIterator.asScala.map(_.copy)
148 | }
149 | }
150 |
151 | val wow = SparkUtils.internalCreateDataFrame(session, abc, StructType(Seq(StructField("AAA", LongType), StructField("BBB", LongType))), false)
152 | wow.show()
153 | ```
154 |
155 | ## Run Python Project
156 |
157 | With PyJava, you can tell the system where the python project is located and which file is the entrypoint,
158 | then run this project from Java/Scala (the corresponding python side is shown after the Scala snippet below).
159 |
160 | ```scala
161 | import tech.mlsql.arrow.python.runner.PythonProjectRunner
162 |
163 | val runner = new PythonProjectRunner("./pyjava/examples/pyproject1", Map())
164 | val output = runner.run(Seq("bash", "-c", "source activate dev && python train.py"), Map(
165 | "tempDataLocalPath" -> "/tmp/data",
166 | "tempModelLocalPath" -> "/tmp/model"
167 | ))
168 | output.foreach(println)
169 | ```
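The python side of this example project (`examples/pyproject1/train.py` in this repo) reads those parameters through `PythonProjectContext`; a minimal sketch:

```python
from pyjava.api.mlsql import PythonProjectContext

# The runner passes key/value parameters (e.g. tempDataLocalPath, tempModelLocalPath)
# to the python process over stdin; the context reads them once.
context = PythonProjectContext()
context.read_params_once()

print(context.conf)               # all parameters passed from the Java/Scala side
print(context.input_data_dir())   # value of "tempDataLocalPath"
print(context.output_model_dir()) # value of "tempModelLocalPath"
```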
170 |
171 |
172 | ## Example In MLSQL
173 |
174 | ### Non-Interactive Mode:
175 |
176 | ```sql
177 | !python env "PYTHON_ENV=source activate streamingpro-spark-2.4.x";
178 | !python conf "schema=st(field(a,long),field(b,long))";
179 |
180 | select 1 as a as table1;
181 |
182 | !python on table1 '''
183 |
184 | import pandas as pd
185 | import numpy as np
186 | for item in data_manager.fetch_once():
187 | print(item)
188 | df = pd.DataFrame({'AAA': [4, 5, 6, 8],'BBB': [10, 20, 30, 40],'CCC': [100, 50, -30, -50]})
189 | data_manager.set_output([[df['AAA'],df['BBB']]])
190 |
191 | ''' named mlsql_temp_table2;
192 |
193 | select * from mlsql_temp_table2 as output;
194 | ```
195 |
196 | ### Interactive Mode:
197 |
198 | ```sql
199 | !python start;
200 |
201 | !python env "PYTHON_ENV=source activate streamingpro-spark-2.4.x";
202 | !python env "schema=st(field(a,integer),field(b,integer))";
203 |
204 |
205 | !python '''
206 | import pandas as pd
207 | import numpy as np
208 | ''';
209 |
210 | !python '''
211 | for item in data_manager.fetch_once():
212 | print(item)
213 | df = pd.DataFrame({'AAA': [4, 5, 6, 8],'BBB': [10, 20, 30, 40],'CCC': [100, 50, -30, -50]})
214 | data_manager.set_output([[df['AAA'],df['BBB']]])
215 | ''';
216 | !python close;
217 | ```
218 |
219 |
220 |
221 | ## Using PyJava as Arrow Server/Client
222 |
223 | Java Server side:
224 |
225 | ```scala
226 | val socketRunner = new SparkSocketRunner("wow", NetUtils.getHost, "Asia/Harbin")
227 |
228 | val dataSchema = StructType(Seq(StructField("value", StringType)))
229 | val enconder = RowEncoder.apply(dataSchema).resolveAndBind()
230 | val newIter = Seq(Row.fromSeq(Seq("a1")), Row.fromSeq(Seq("a2"))).map { irow =>
231 | enconder.toRow(irow)
232 | }.iterator
233 | val javaConext = new JavaContext
234 | val commonTaskContext = new AppContextImpl(javaConext, null)
235 |
236 | val Array(_, host, port) = socketRunner.serveToStreamWithArrow(newIter, dataSchema, 10, commonTaskContext)
237 | println(s"${host}:${port}")
238 | Thread.currentThread().join()
239 | ```
240 |
241 | Python Client side:
242 |
243 | ```python
244 | import os
245 | import socket
246 |
247 | from pyjava.serializers import \
248 | ArrowStreamPandasSerializer
249 |
250 | out_ser = ArrowStreamPandasSerializer(None, True, True)
251 |
252 | out_ser = ArrowStreamPandasSerializer("Asia/Harbin", False, None)
253 | HOST = ""
254 | PORT = -1
255 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
256 | sock.connect((HOST, PORT))
257 | buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
258 | infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size)
259 | outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size)
260 | kk = out_ser.load_stream(infile)
261 | for item in kk:
262 | print(item)
263 | ```
264 |
265 | Python Server side:
266 |
267 | ```python
268 | import os
269 |
270 | import pandas as pd
271 |
272 | os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
273 | from pyjava.api.serve import OnceServer
274 |
275 | ddata = pd.DataFrame(data=[[1, 2, 3, 4], [2, 3, 4, 5]])
276 |
277 | server = OnceServer("127.0.0.1", 11111, "Asia/Harbin")
278 | server.bind()
279 | server.serve([{'id': 9, 'label': 1}])
280 | ```
281 |
282 | Java Client side:
283 |
284 | ```scala
285 | import org.apache.spark.sql.Row
286 | import org.apache.spark.sql.catalyst.encoders.RowEncoder
287 | import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
288 | import org.scalatest.{BeforeAndAfterAll, FunSuite}
289 | import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext}
290 | import tech.mlsql.arrow.python.runner.SparkSocketRunner
291 | import tech.mlsql.common.utils.network.NetUtils
292 |
293 | val enconder = RowEncoder.apply(StructType(Seq(StructField("a", LongType),StructField("b", LongType)))).resolveAndBind()
294 | val socketRunner = new SparkSocketRunner("wow", NetUtils.getHost, "Asia/Harbin")
295 | val javaConext = new JavaContext
296 | val commonTaskContext = new AppContextImpl(javaConext, null)
297 | val iter = socketRunner.readFromStreamWithArrow("127.0.0.1", 11111, commonTaskContext)
298 | iter.foreach(i => println(enconder.fromRow(i.copy())))
299 | javaConext.close
300 | ```
301 |
302 | ## How to configure the python worker to run in Docker (todo)
303 |
304 |
305 |
--------------------------------------------------------------------------------
/dev/change-scala-version.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | #
4 | # Licensed to the Apache Software Foundation (ASF) under one or more
5 | # contributor license agreements. See the NOTICE file distributed with
6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0
8 | # (the "License"); you may not use this file except in compliance with
9 | # the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | #
19 |
20 | set -e
21 |
22 | VALID_VERSIONS=( 2.11 2.12 )
23 |
24 | usage() {
25 | echo "Usage: $(basename $0) [-h|--help]
26 | where :
27 | -h| --help Display this help text
28 | valid version values : ${VALID_VERSIONS[*]}
29 | " 1>&2
30 | exit 1
31 | }
32 |
33 | if [[ ($# -ne 1) || ( $1 == "--help") || $1 == "-h" ]]; then
34 | usage
35 | fi
36 |
37 | TO_VERSION=$1
38 |
39 | check_scala_version() {
40 | for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done
41 | echo "Invalid Scala version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2
42 | exit 1
43 | }
44 |
45 | check_scala_version "$TO_VERSION"
46 |
47 | if [ $TO_VERSION = "2.12" ]; then
48 | FROM_VERSION="2.11"
49 | else
50 | FROM_VERSION="2.12"
51 | fi
52 |
53 | sed_i() {
54 | sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2"
55 | }
56 |
57 | export -f sed_i
58 |
59 | BASEDIR=$(dirname $0)/..
60 | find "$BASEDIR" -name 'pom.xml' -not -path '*target*' -print \
61 | -exec bash -c "sed_i 's/\(artifactId.*\)_'$FROM_VERSION'/\1_'$TO_VERSION'/g' {}" \;
62 |
63 | # Also update <scala.binary.version> in parent POM
64 | # Match any scala binary version to ensure idempotency
65 | sed_i '1,/<scala\.binary\.version>[0-9]*\.[0-9]*</s/<scala\.binary\.version>[0-9]*\.[0-9]*</<scala.binary.version>'$TO_VERSION'</' \
66 |   "$BASEDIR/pom.xml"
67 |
68 | # Update source of scaladocs
69 | # echo "$BASEDIR/docs/_plugins/copy_api_dirs.rb"
70 | # sed_i 's/scala\-'$FROM_VERSION'/scala\-'$TO_VERSION'/' "$BASEDIR/docs/_plugins/copy_api_dirs.rb"
71 |
--------------------------------------------------------------------------------
/dev/change-version-to-2.11.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | #
4 | # Licensed to the Apache Software Foundation (ASF) under one or more
5 | # contributor license agreements. See the NOTICE file distributed with
6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0
8 | # (the "License"); you may not use this file except in compliance with
9 | # the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | #
19 |
20 | # This script exists for backwards compatibility. Use change-scala-version.sh instead.
21 | echo "This script is deprecated. Please instead run: change-scala-version.sh 2.11"
22 |
23 | $(dirname $0)/change-scala-version.sh 2.11
24 |
--------------------------------------------------------------------------------
/dev/change-version-to-2.12.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | #
4 | # Licensed to the Apache Software Foundation (ASF) under one or more
5 | # contributor license agreements. See the NOTICE file distributed with
6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0
8 | # (the "License"); you may not use this file except in compliance with
9 | # the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | #
19 |
20 | # This script exists for backwards compatibility. Use change-scala-version.sh instead.
21 | echo "This script is deprecated. Please instead run: change-scala-version.sh 2.12"
22 |
23 | $(dirname $0)/change-scala-version.sh 2.12
24 |
--------------------------------------------------------------------------------
/dev/release.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | version=${1:-0.3.3}
4 |
5 | quoteVersion=$(cat python/pyjava/version.py|grep "__version__" |awk -F'=' '{print $2}'| xargs )
6 |
7 | if [[ "${version}" != "${quoteVersion}" ]];then
8 | echo "version[${quoteVersion}] in python/pyjava/version.py does not match the version[${version}] you specified"
9 | exit 1
10 | fi
11 |
12 | if [[ ! -d '.repo' ]];then
13 | echo "Make sure this script executed in root directory of pyjava"
14 | exit 1
15 | fi
16 |
17 | echo "deploy pyjava jar based on spark243...."
18 | mlsql_plugin_tool spark243
19 | mvn clean deploy -DskipTests -Pdisable-java8-doclint -Prelease-sign-artifacts
20 |
21 | echo "deploy pyjava jar based on spark311...."
22 | mlsql_plugin_tool spark311
23 | mvn clean deploy -DskipTests -Pdisable-java8-doclint -Prelease-sign-artifacts
24 |
25 | echo "deploy pyjava pip...."
26 | cd python
27 | rm -rf dist
28 | pip uninstall -y pyjava && python setup.py sdist bdist_wheel && cd ./dist/ && pip install pyjava-${version}-py3-none-any.whl && cd -
29 | twine upload dist/*
--------------------------------------------------------------------------------
/examples/pyproject1/train.py:
--------------------------------------------------------------------------------
1 | from pyjava.api.mlsql import PythonProjectContext
2 |
3 | context = PythonProjectContext()
4 | context.read_params_once()
5 | print(context.conf)
6 |
7 | import time
8 | print('foo', flush=True)
9 | print("yes")
10 | # time.sleep(10)
11 | print("10 yes")
12 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | 4.0.0
6 |
7 | tech.mlsql
8 | pyjava-3.0_2.12
9 | 0.3.2
10 | pyjava with arrow support
11 | https://github.com/allwefantasy/pyjava
12 |
13 | Communication between Python And Java with Apache Arrow.
14 |
15 |
16 |
17 | Apache 2.0 License
18 | http://www.apache.org/licenses/LICENSE-2.0.html
19 | repo
20 |
21 |
22 |
23 |
24 | allwefantasy
25 | ZhuHaiLin
26 | allwefantasy@gmail.com
27 |
28 |
29 |
30 |
31 | scm:git:git@github.com:allwefantasy/pyjava.git
32 |
33 |
34 | scm:git:git@github.com:allwefantasy/pyjava.git
35 |
36 | https://github.com/allwefantasy/pyjava
37 |
38 |
39 | https://github.com/allwefantasy/pyjava/issues
40 |
41 |
42 |
43 | UTF-8
44 |
45 | 2021.1.12
46 |
47 | 2.12.10
48 | 2.12
49 | 3.1.1
50 | 3.0
51 | 2.0.0
52 |
53 | 16.0
54 | 4.5.3
55 | provided
56 | 2.6.5
57 | 0.3.6
58 | 2.0.6
59 |
60 |
61 |
62 |
63 |
64 | disable-java8-doclint
65 |
66 | [1.8,)
67 |
68 |
69 | -Xdoclint:none
70 | none
71 |
72 |
73 |
74 | release-sign-artifacts
75 |
76 |
77 | performRelease
78 | true
79 |
80 |
81 |
82 |
83 |
84 | org.apache.maven.plugins
85 | maven-gpg-plugin
86 | 1.5
87 |
88 |
89 | sign-artifacts
90 | verify
91 |
92 | sign
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 | tech.mlsql
105 | common-utils_${scala.binary.version}
106 | ${common-utils-version}
107 |
108 |
109 | org.scalactic
110 | scalactic_${scala.binary.version}
111 | 3.0.0
112 | test
113 |
114 |
115 | org.scalatest
116 | scalatest_${scala.binary.version}
117 | 3.0.0
118 | test
119 |
120 |
121 | org.apache.arrow
122 | arrow-vector
123 | ${arrow.version}
124 |
125 |
126 | com.fasterxml.jackson.core
127 | jackson-annotations
128 |
129 |
130 | com.fasterxml.jackson.core
131 | jackson-databind
132 |
133 |
134 | io.netty
135 | netty-buffer
136 |
137 |
138 | io.netty
139 | netty-common
140 |
141 |
142 | io.netty
143 | netty-handler
144 |
145 |
146 |
147 |
148 |
149 |
150 | com.fasterxml.jackson.core
151 | jackson-core
152 | 2.10.1
153 |
154 |
155 |
156 |
157 |
158 | org.apache.spark
159 | spark-core_${scala.binary.version}
160 | ${spark.version}
161 | ${scope}
162 |
163 |
164 | org.apache.spark
165 | spark-sql_${scala.binary.version}
166 | ${spark.version}
167 | ${scope}
168 |
169 |
170 |
171 | org.apache.spark
172 | spark-mllib_${scala.binary.version}
173 | ${spark.version}
174 | ${scope}
175 |
176 |
177 |
178 | org.apache.spark
179 | spark-graphx_${scala.binary.version}
180 | ${spark.version}
181 | ${scope}
182 |
183 |
184 |
185 | org.apache.spark
186 | spark-catalyst_${scala.binary.version}
187 | ${spark.version}
188 | tests
189 | test
190 |
191 |
192 |
193 | org.apache.spark
194 | spark-core_${scala.binary.version}
195 | ${spark.version}
196 | tests
197 | test
198 |
199 |
200 |
201 | org.apache.spark
202 | spark-sql_${scala.binary.version}
203 | ${spark.version}
204 | tests
205 | test
206 |
207 |
208 |
209 | org.pegdown
210 | pegdown
211 | 1.6.0
212 | test
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 | src/main/resources
221 |
222 |
223 |
224 |
225 | org.apache.maven.plugins
226 | maven-surefire-plugin
227 | 3.0.0-M1
228 |
229 | 1
230 | true
231 | -Xmx4024m
232 |
233 | **/*.java
234 | **/*.scala
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 | org.scala-tools
243 | maven-scala-plugin
244 | 2.15.2
245 |
246 |
247 |
248 | -g:vars
249 |
250 |
251 | true
252 |
253 |
254 |
255 | compile
256 |
257 | compile
258 |
259 | compile
260 |
261 |
262 | testCompile
263 |
264 | testCompile
265 |
266 | test
267 |
268 |
269 | process-resources
270 |
271 | compile
272 |
273 |
274 |
275 |
276 |
277 |
278 | org.apache.maven.plugins
279 | maven-compiler-plugin
280 | 2.3.2
281 |
282 |
283 | -g
284 | true
285 | 1.8
286 | 1.8
287 |
288 |
289 |
290 |
291 |
292 |
293 | maven-source-plugin
294 | 2.1
295 |
296 | true
297 |
298 |
299 |
300 | compile
301 |
302 | jar
303 |
304 |
305 |
306 |
307 |
308 | org.apache.maven.plugins
309 | maven-javadoc-plugin
310 |
311 |
312 | attach-javadocs
313 |
314 | jar
315 |
316 |
317 |
318 |
319 |
320 | org.sonatype.plugins
321 | nexus-staging-maven-plugin
322 | 1.6.7
323 | true
324 |
325 | sonatype-nexus-staging
326 | https://oss.sonatype.org/
327 | true
328 |
329 |
330 |
331 |
332 | org.scalatest
333 | scalatest-maven-plugin
334 | 2.0.0
335 |
336 | streaming.core.NotToRunTag
337 | ${project.build.directory}/surefire-reports
338 | .
339 | WDF TestSuite.txt
340 | ${project.build.directory}/html/scalatest
341 | false
342 |
343 |
344 |
345 | test
346 |
347 | test
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 |
368 |
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 | sonatype-nexus-snapshots
381 | https://oss.sonatype.org/content/repositories/snapshots
382 |
383 |
384 | sonatype-nexus-staging
385 | https://oss.sonatype.org/service/local/staging/deploy/maven2/
386 |
387 |
388 |
389 |
390 |
--------------------------------------------------------------------------------
/python/README.md:
--------------------------------------------------------------------------------
1 | ## Python side in PyJava
2 |
3 | This library is designed for invoking python code from Java/Scala, with Apache
4 | Arrow as the data-exchange format. This means that if you want to run some python code
5 | from Scala/Java, the Scala/Java side will start some python workers with this library's help,
6 | send the data/code to the python workers over a socket, and the python side will
7 | return the processed data back.
8 |
9 | The initial code in this lib is from Apache Spark.
10 |
11 |
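A minimal sketch of what such a python worker runs (the `context` object is a
`PythonContext` injected by the worker at runtime; this mirrors the snippet in the
top-level README):

```python
# `context` is provided by the pyjava worker; do not create it yourself.
def process():
    for item in context.fetch_once_as_rows():   # pull rows sent from Java/Scala
        item["value1"] = item["value"] + "_suffix"
        yield item

context.build_result(process())                  # stream the rows back via Arrow
```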
--------------------------------------------------------------------------------
/python/install.sh:
--------------------------------------------------------------------------------
1 | project=pyjava
2 | version=0.3.3
3 | rm -rf ./dist/*
4 | pip uninstall -y ${project}
5 | python setup.py sdist bdist_wheel
6 | cd ./dist/
7 | pip install ${project}-${version}-py3-none-any.whl && cd -
8 | twine upload dist/*
9 |
--------------------------------------------------------------------------------
/python/pyjava/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import types
19 | from functools import wraps
20 |
21 |
22 | def since(version):
23 | """
24 | A decorator that annotates a function to append the version of Spark the function was added.
25 | """
26 | import re
27 | indent_p = re.compile(r'\n( +)')
28 |
29 | def deco(f):
30 | indents = indent_p.findall(f.__doc__)
31 | indent = ' ' * (min(len(m) for m in indents) if indents else 0)
32 | f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
33 | return f
34 |
35 | return deco
36 |
37 |
38 | def copy_func(f, name=None, sinceversion=None, doc=None):
39 | """
40 | Returns a function with same code, globals, defaults, closure, and
41 | name (or provide a new name).
42 | """
43 | # See
44 | # http://stackoverflow.com/questions/6527633/how-can-i-make-a-deepcopy-of-a-function-in-python
45 | fn = types.FunctionType(f.__code__, f.__globals__, name or f.__name__, f.__defaults__,
46 | f.__closure__)
47 | # in case f was given attrs (note this dict is a shallow copy):
48 | fn.__dict__.update(f.__dict__)
49 | if doc is not None:
50 | fn.__doc__ = doc
51 | if sinceversion is not None:
52 | fn = since(sinceversion)(fn)
53 | return fn
54 |
55 |
56 | def keyword_only(func):
57 | """
58 | A decorator that forces keyword arguments in the wrapped method
59 | and saves actual input keyword arguments in `_input_kwargs`.
60 |
61 | .. note:: Should only be used to wrap a method where first arg is `self`
62 | """
63 |
64 | @wraps(func)
65 | def wrapper(self, *args, **kwargs):
66 | if len(args) > 0:
67 | raise TypeError("Method %s forces keyword arguments." % func.__name__)
68 | self._input_kwargs = kwargs
69 | return func(self, **kwargs)
70 |
71 | return wrapper
72 |
--------------------------------------------------------------------------------
/python/pyjava/api/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Any, NoReturn, Callable, Dict, List
2 | import tempfile
3 | import os
4 |
5 | from pyjava.api.mlsql import PythonContext
6 |
7 |
8 | class Utils(object):
9 |
10 | @staticmethod
11 | def show_plt(plt: Any, context: PythonContext) -> NoReturn:
12 | content = Utils.gen_img(plt)
13 | context.build_result([{"content": content, "mime": "image"}])
14 |
15 | @staticmethod
16 | def gen_img(plt: Any) -> str:
17 | import base64
18 | import uuid
19 | img_path = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()) + ".png")
20 | plt.savefig(img_path)
21 | with open(img_path, mode='rb') as file:
22 | file_content = base64.b64encode(file.read()).decode()
23 | return file_content
24 |
--------------------------------------------------------------------------------
/python/pyjava/api/mlsql.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import socket
4 | from distutils.version import StrictVersion
5 | import uuid
6 |
7 | import pandas as pd
8 | import sys
9 | from typing import Dict
10 |
11 | import pyjava.utils as utils
12 | import requests
13 | from pyjava.serializers import ArrowStreamSerializer
14 | from pyjava.serializers import read_int
15 | from pyjava.utils import utf8_deserializer
16 | from pyjava.storage import streaming_tar
17 |
18 | if sys.version >= '3':
19 | basestring = str
20 | else:
21 | pass
22 |
23 |
24 | class DataServer(object):
25 | def __init__(self, host, port, timezone):
26 | self.host = host
27 | self.port = port
28 | self.timezone = timezone
29 |
30 |
31 | class LogClient(object):
32 | def __init__(self, conf):
33 | self.conf = conf
34 | if 'spark.mlsql.log.driver.url' in self.conf:
35 | self.url = self.conf['spark.mlsql.log.driver.url']
36 | self.log_user = self.conf['PY_EXECUTE_USER']
37 | self.log_token = self.conf['spark.mlsql.log.driver.token']
38 | self.log_group_id = self.conf['groupId']
39 |
40 | def log_to_driver(self, msg):
41 | if 'spark.mlsql.log.driver.url' not in self.conf:
42 | if self.conf['PY_EXECUTE_USER'] and self.conf['groupId']:
43 | logging.info("[owner] [{}] [groupId] [{}] __MMMMMM__ {}".format(self.conf['PY_EXECUTE_USER'],
44 | self.conf['groupId'], msg))
45 | else:
46 | logging.info(msg)
47 | return
48 | import json
49 | resp = json.dumps(
50 | {"sendLog": {
51 | "token": self.log_token,
52 | "logLine": "[owner] [{}] [groupId] [{}] __MMMMMM__ {}".format(self.log_user, self.log_group_id, msg)
53 | }}, ensure_ascii=False)
54 | requests.post(self.url, data=resp, headers={'content-type': 'application/x-www-form-urlencoded;charset=UTF-8'})
55 |
56 | def close(self):
57 | if hasattr(self, "conn"):
58 | self.conn.close()
59 | self.conn = None
60 |
61 |
62 | class PythonContext(object):
63 | cache = {}
64 |
65 | def __init__(self, context_id, iterator, conf):
66 | self.context_id = context_id
67 | self.data_mmap_file_ref = {}
68 | self.input_data = iterator
69 | self.output_data = [[]]
70 | self.conf = conf
71 | self.schema = ""
72 | self.have_fetched = False
73 | self.log_client = LogClient(self.conf)
74 | if "pythonMode" in conf and conf["pythonMode"] == "ray":
75 | self.rayContext = RayContext(self)
76 |
77 | def set_output(self, value, schema=""):
78 | self.output_data = value
79 | self.schema = schema
80 |
81 | @staticmethod
82 | def build_chunk_result(items, block_size=1024):
83 | buffer = []
84 | for item in items:
85 | buffer.append(item)
86 | if len(buffer) == block_size:
87 | df = pd.DataFrame(buffer, columns=buffer[0].keys())
88 | buffer.clear()
89 | yield df
90 |
91 | if len(buffer) > 0:
92 | df = pd.DataFrame(buffer, columns=buffer[0].keys())
93 | buffer.clear()
94 | yield df
95 |
96 | def build_result(self, items, block_size=1024):
97 | self.output_data = ([df[name] for name in df]
98 | for df in PythonContext.build_chunk_result(items, block_size))
99 |
100 | def build_result_from_dir(self, target_dir, block_size=1024):
101 | items = streaming_tar.build_rows_from_file(target_dir)
102 | self.build_result(items, block_size)
103 |
104 | def output(self):
105 | return self.output_data
106 |
107 | def __del__(self):
108 | logging.info("==clean== context")
109 | if self.log_client is not None:
110 | try:
111 | self.log_client.close()
112 | except Exception as e:
113 | pass
114 |
115 | if 'data_mmap_file_ref' in self.data_mmap_file_ref:
116 | try:
117 | self.data_mmap_file_ref['data_mmap_file_ref'].close()
118 | except Exception as e:
119 | pass
120 |
121 | def noops_fetch(self):
122 | for item in self.fetch_once():
123 | pass
124 |
125 | def fetch_once_as_dataframe(self):
126 | for df in self.fetch_once():
127 | yield df
128 |
129 | def fetch_once_as_rows(self):
130 | for df in self.fetch_once_as_dataframe():
131 | for row in df.to_dict('records'):
132 | yield row
133 |
134 | def fetch_once_as_batch_rows(self):
135 | for df in self.fetch_once_as_dataframe():
136 | yield (row for row in df.to_dict('records'))
137 |
138 | def fetch_once(self):
139 | import pyarrow as pa
140 | if self.have_fetched:
141 | raise Exception("input data can only be fetched once")
142 | self.have_fetched = True
143 | for items in self.input_data:
144 | yield pa.Table.from_batches([items]).to_pandas()
145 |
146 | def fetch_as_dir(self, target_dir):
147 | if len(self.data_servers()) > 1:
148 | raise Exception("Please make sure you have only one partition on Java/Spark Side")
149 | items = self.fetch_once_as_rows()
150 | streaming_tar.save_rows_as_file(items, target_dir)
151 |
152 |
153 | class PythonProjectContext(object):
154 | def __init__(self):
155 | self.params_read = False
156 | self.conf = {}
157 | self.read_params_once()
158 | self.log_client = LogClient(self.conf)
159 |
160 | def read_params_once(self):
161 | if not self.params_read:
162 | self.params_read = True
163 | infile = sys.stdin.buffer
164 | for i in range(read_int(infile)):
165 | k = utf8_deserializer.loads(infile)
166 | v = utf8_deserializer.loads(infile)
167 | self.conf[k] = v
168 |
169 | def input_data_dir(self):
170 | return self.conf["tempDataLocalPath"]
171 |
172 | def output_model_dir(self):
173 | return self.conf["tempModelLocalPath"]
174 |
175 | def __del__(self):
176 | self.log_client.close()
177 |
178 |
179 | class RayContext(object):
180 | cache = {}
181 | conn_cache = {}
182 |
183 | def __init__(self, python_context):
184 | self.python_context = python_context
185 | self.servers = []
186 | self.server_ids_in_ray = []
187 | self.is_setup = False
188 | self.is_dev = utils.is_dev()
189 | self.is_in_mlsql = True
190 | self.mock_data = []
191 | if "directData" not in python_context.conf:
192 | for item in self.python_context.fetch_once_as_rows():
193 | self.server_ids_in_ray.append(str(uuid.uuid4()))
194 | self.servers.append(DataServer(
195 | item["host"], int(item["port"]), item["timezone"]))
196 |
197 | def data_servers(self):
198 | return self.servers
199 |
200 | def conf(self):
201 | return self.python_context.conf
202 |
203 | def data_servers_in_ray(self):
204 | import ray
205 | from pyjava.rayfix import RayWrapper
206 | rayw = RayWrapper()
207 | for server_id in self.server_ids_in_ray:
208 | server = rayw.get_actor(server_id)
209 | yield ray.get(server.connect_info.remote())
210 |
211 | def build_servers_in_ray(self):
212 | from pyjava.rayfix import RayWrapper
213 | from pyjava.api.serve import RayDataServer
214 | import ray
215 | buffer = []
216 | rayw = RayWrapper()
217 | for (server_id, java_server) in zip(self.server_ids_in_ray, self.servers):
218 | # rds = RayDataServer.options(name=server_id, detached=True, max_concurrency=2).remote(server_id, java_server,
219 | # 0,
220 | # java_server.timezone)
221 | rds = rayw.options(RayDataServer, name=server_id, detached=True, max_concurrency=2).remote(server_id,
222 | java_server,
223 | 0,
224 | java_server.timezone)
225 | res = ray.get(rds.connect_info.remote())
226 | logging.debug("build ray data server server_id:{} java_server: {} servers:{}".format(server_id,
227 | str(vars(
228 | java_server)),
229 | str(vars(res))))
230 | buffer.append(res)
231 | return buffer
232 |
233 | @staticmethod
234 | def connect(_context, url, **kwargs):
235 | if isinstance(_context, PythonContext):
236 | context = _context
237 | elif isinstance(_context, dict):
238 | if 'context' in _context:
239 | context = _context['context']
240 | else:
241 | '''
242 | we are not in MLSQL
243 | '''
244 | context = PythonContext("", [], {"pythonMode": "ray"})
245 | context.rayContext.is_in_mlsql = False
246 | else:
247 | raise Exception("context is not detected. make sure it's in globals().")
248 |
249 | if url == "local":
250 | from pyjava.rayfix import RayWrapper
251 | ray = RayWrapper()
252 | if ray.ray_version < StrictVersion('1.6.0'):
253 | raise Exception("URL:local is only supported in ray >= 1.6.0")
254 | # if not ray.ray_instance.is_initialized:
255 | ray.ray_instance.shutdown()
256 | ray.ray_instance.init(namespace="default")
257 |
258 | elif url is not None:
259 | from pyjava.rayfix import RayWrapper
260 | ray = RayWrapper()
261 | is_udf_client = context.conf.get("UDF_CLIENT")
262 | if is_udf_client is None:
263 | ray.shutdown()
264 | ray.init(url, **kwargs)
265 | if is_udf_client and url not in RayContext.conn_cache:
266 | ray.init(url, **kwargs)
267 | RayContext.conn_cache[url] = 1
268 |
269 | return context.rayContext
270 |
271 | def setup(self, func_for_row, func_for_rows=None):
272 | if self.is_setup:
273 | raise ValueError("setup can only be invoked once")
274 | self.is_setup = True
275 |
276 | is_data_mode = "dataMode" in self.conf() and self.conf()["dataMode"] == "data"
277 |
278 | if not is_data_mode:
279 | raise Exception('''
280 | Please setup dataMode as data instead of model.
281 | Try run: `!python conf "dataMode=data"` or
282 | add comment like: `#%dataMode=data` if you are in notebook.
283 | ''')
284 |
285 | import ray
286 | from pyjava.rayfix import RayWrapper
287 | rayw = RayWrapper()
288 |
289 | if not self.is_in_mlsql:
290 | if func_for_rows is not None:
291 | func = ray.remote(func_for_rows)
292 | return ray.get(func.remote(self.mock_data))
293 | else:
294 | func = ray.remote(func_for_row)
295 |
296 | def iter_all(rows):
297 | return [ray.get(func.remote(row)) for row in rows]
298 |
299 | iter_all_func = ray.remote(iter_all)
300 | return ray.get(iter_all_func.remote(self.mock_data))
301 |
302 | buffer = []
303 | for server_info in self.build_servers_in_ray():
304 | server = rayw.get_actor(server_info.server_id)
305 | rci = ray.get(server.connect_info.remote())
306 | buffer.append(rci)
307 | server.serve.remote(func_for_row, func_for_rows)
308 | items = [vars(server) for server in buffer]
309 | self.python_context.build_result(items, 1024)
310 | return buffer
311 |
312 | def foreach(self, func_for_row):
313 | return self.setup(func_for_row)
314 |
315 | def map_iter(self, func_for_rows):
316 | return self.setup(None, func_for_rows)
317 |
318 | def collect(self):
319 | for shard in self.data_servers():
320 | for row in RayContext.fetch_once_as_rows(shard):
321 | yield row
322 |
323 | def fetch_as_dir(self, target_dir, servers=None):
324 | if not servers:
325 | servers = self.data_servers()
326 | if len(servers) > 1:
327 | raise Exception("Please make sure you have only one partition on Java/Spark Side")
328 |
329 | items = RayContext.collect_from(servers)
330 | streaming_tar.save_rows_as_file(items, target_dir)
331 |
332 | def build_result(self, items, block_size=1024):
333 | self.python_context.build_result(items, block_size)
334 |
335 | def build_result_from_dir(self, target_path):
336 | self.python_context.build_result_from_dir(target_path)
337 |
338 | @staticmethod
339 | def parse_servers(host_ports):
340 | hosts = host_ports.split(",")
341 | hosts = [item.split(":") for item in hosts]
342 | return [DataServer(item[0], int(item[1]), "") for item in hosts]
343 |
344 | @staticmethod
345 | def fetch_as_repeatable_file(context_id, data_servers, file_ref, batch_size):
346 | import pyarrow as pa
347 |
348 | def inner_fetch():
349 | for data_server in data_servers:
350 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
351 | out_ser = ArrowStreamSerializer()
352 | sock.connect((data_server.host, data_server.port))
353 | buffer_size = int(os.environ.get("BUFFER_SIZE", 65536))
354 | infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size)
355 | result = out_ser.load_stream(infile)
356 | for batch in result:
357 | yield batch
358 |
359 | def gen_by_batch():
360 | import numpy as np
361 | import math
362 | if 'data_mmap_file_ref' not in file_ref:
363 | file_ref['data_mmap_file_ref'] = pa.memory_map(context_id + "/__input__.dat")
364 | reader = pa.ipc.open_file(file_ref['data_mmap_file_ref'])
365 | num_record_batches = reader.num_record_batches
366 | for i in range(num_record_batches):
367 | df = reader.get_batch(i).to_pandas()
368 | for small_batch in np.array_split(df, math.floor(df.shape[0] / batch_size)):
369 | yield small_batch
370 |
371 | if 'data_mmap_file_ref' in file_ref:
372 | return gen_by_batch()
373 | else:
374 | writer = None
375 | for batch in inner_fetch():
376 | if writer is None:
377 | writer = pa.RecordBatchFileWriter(context_id + "/__input__.dat", batch.schema)
378 | writer.write_batch(batch)
379 | writer.close()
380 | return gen_by_batch()
381 |
382 | def collect_as_file(self, batch_size):
383 | data_servers = self.data_servers()
384 | python_context = self.python_context
385 | return RayContext.fetch_as_repeatable_file(python_context.context_id, data_servers,
386 | python_context.data_mmap_file_ref,
387 | batch_size)
388 |
389 | @staticmethod
390 | def collect_from(servers):
391 | for shard in servers:
392 | for row in RayContext.fetch_once_as_rows(shard):
393 | yield row
394 |
395 | def to_pandas(self):
396 | items = [row for row in self.collect()]
397 | return pd.DataFrame(data=items)
398 |
399 | @staticmethod
400 | def fetch_once_as_rows(data_server):
401 | for df in RayContext.fetch_data_from_single_data_server(data_server):
402 | for row in df.to_dict('records'):
403 | yield row
404 |
405 | @staticmethod
406 | def fetch_data_from_single_data_server(data_server):
407 | out_ser = ArrowStreamSerializer()
408 | import pyarrow as pa
409 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
410 | sock.connect((data_server.host, data_server.port))
411 | buffer_size = int(os.environ.get("BUFFER_SIZE", 65536))
412 | infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size)
413 | result = out_ser.load_stream(infile)
414 | for items in result:
415 | yield pa.Table.from_batches([items]).to_pandas()
416 |
--------------------------------------------------------------------------------
/python/pyjava/api/serve.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import socket
4 | import traceback
5 |
6 | import ray
7 |
8 | import pyjava.utils as utils
9 | from pyjava.api.mlsql import RayContext
10 | from pyjava.rayfix import RayWrapper
11 | from pyjava.serializers import \
12 | write_with_length, \
13 | write_int, read_int, \
14 | SpecialLengths, ArrowStreamPandasSerializer
15 |
16 | os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
17 |
18 |
19 | class SocketNotBindException(Exception):
20 | def __init__(self, message):
21 | Exception.__init__(self)
22 | self.message = message
23 |
24 |
25 | class DataServerWithId(object):
26 | def __init__(self, host, port, server_id):
27 | self.host = host
28 | self.port = port
29 | self.server_id = server_id
30 |
31 |
32 | class OnceServer(object):
33 | def __init__(self, host, port, timezone):
34 | self.host = host
35 | self.port = port
36 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
37 | self.socket.settimeout(5 * 60)
38 | self.out_ser = ArrowStreamPandasSerializer(timezone, False, None)
39 | self.is_bind = False
40 | self.is_dev = utils.is_dev()
41 |
42 | def bind(self):
43 | try:
44 | self.socket.bind((self.host, self.port))
45 | self.is_bind = True
46 | self.socket.listen(1)
47 | except Exception:
48 | print(traceback.format_exc())
49 |
50 | return self.socket.getsockname()
51 |
52 | def close(self):
53 | self.socket.close()
54 |
55 | def serve(self, data):
56 | from pyjava.api.mlsql import PythonContext
57 | if not self.is_bind:
58 | raise SocketNotBindException(
59 | "Please invoke server.bind() before invoking server.serve")
60 | conn, addr = self.socket.accept()
61 | sockfile = conn.makefile("rwb", int(
62 | os.environ.get("BUFFER_SIZE", 65536)))
63 | infile = sockfile # os.fdopen(os.dup(conn.fileno()), "rb", 65536)
64 | out = sockfile # os.fdopen(os.dup(conn.fileno()), "wb", 65536)
65 | try:
66 | write_int(SpecialLengths.START_ARROW_STREAM, out)
67 | out_data = ([df[name] for name in df] for df in
68 | PythonContext.build_chunk_result(data, 1024))
69 | self.out_ser.dump_stream(out_data, out)
70 |
71 | write_int(SpecialLengths.END_OF_DATA_SECTION, out)
72 | write_int(SpecialLengths.END_OF_STREAM, out)
73 | out.flush()
74 | if self.is_dev:
75 | print("all data in ray task have been consumed.")
76 | read_int(infile)
77 | except Exception:
78 | try:
79 | write_int(SpecialLengths.ARROW_STREAM_CRASH, out)
80 | ex = traceback.format_exc()
81 | print(ex)
82 | write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, out)
83 | write_with_length(ex.encode("utf-8"), out)
84 | out.flush()
85 | read_int(infile)
86 | except IOError:
87 | # JVM close the socket
88 | pass
89 | except Exception:
90 | # Write the error to stderr if it happened while serializing
91 | print("Py worker failed with exception:")
92 | print(traceback.format_exc())
93 | pass
94 |
95 | conn.close()
96 |
97 |
98 | @ray.remote
99 | class RayDataServer(object):
100 |
101 | def __init__(self, server_id, java_server, port=0, timezone="Asia/Harbin"):
102 | self.server = OnceServer(
103 | RayWrapper().get_address(), port, java_server.timezone)
104 | try:
105 | (rel_host, rel_port) = self.server.bind()
106 | except Exception as e:
107 | print(traceback.format_exc())
108 | raise e
109 |
110 | self.host = rel_host
111 | self.port = rel_port
112 | self.timezone = timezone
113 | self.server_id = server_id
114 | self.java_server = java_server
115 | self.is_dev = utils.is_dev()
116 |
117 | def serve(self, func_for_row=None, func_for_rows=None):
118 | try:
119 | if func_for_row is not None:
120 | data = (func_for_row(item)
121 | for item in RayContext.fetch_once_as_rows(self.java_server))
122 | elif func_for_rows is not None:
123 | data = func_for_rows(
124 | RayContext.fetch_once_as_rows(self.java_server))
125 | self.server.serve(data)
126 | except Exception as e:
127 | logging.error(f"Failed to process data in Ray Data Server {self.host}:{self.port}")
128 | raise e
129 | finally:
130 | self.close()
131 |
132 | def close(self):
133 | try:
134 | self.server.close()
135 | ray.actor.exit_actor()
136 | except Exception:
137 | print(traceback.format_exc())
138 |
139 | def connect_info(self):
140 | return DataServerWithId(self.host, self.port, self.server_id)
141 |
--------------------------------------------------------------------------------
/python/pyjava/cache/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/pyjava/83112f57ce8e3a87b6a5a59e6b33d2f9004caf60/python/pyjava/cache/__init__.py
--------------------------------------------------------------------------------
/python/pyjava/cache/code_cache.py:
--------------------------------------------------------------------------------
1 | class CodeCache(object):
2 | cache = {}
3 | cache_max_size = 10 * 10000
4 |
5 | @staticmethod
6 | def get(code):
7 | if code in CodeCache.cache:
8 | return CodeCache.cache[code]
9 | else:
10 | return CodeCache.gen_cache(code)
11 |
12 | @staticmethod
13 | def gen_cache(code):
14 | if len(CodeCache.cache) > CodeCache.cache_max_size:
15 | CodeCache.cache.clear()
16 | CodeCache.cache[code] = compile(code, '', 'exec')
17 | return CodeCache.cache[code]
18 |
--------------------------------------------------------------------------------
/python/pyjava/daemon.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import numbers
19 | import os
20 | import signal
21 | import select
22 | import socket
23 | import sys
24 | import traceback
25 | import time
26 | import gc
27 | from errno import EINTR, EAGAIN
28 | from socket import AF_INET, SOCK_STREAM, SOMAXCONN
29 | from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
30 |
31 | from pyjava.worker import main as worker_main
32 | from pyjava.serializers import read_int, write_int, write_with_length, UTF8Deserializer
33 |
34 |
35 | def compute_real_exit_code(exit_code):
36 | # SystemExit's code can be integer or string, but os._exit only accepts integers
37 | if isinstance(exit_code, numbers.Integral):
38 | return exit_code
39 | else:
40 | return 1
41 |
42 |
43 | def worker(sock):
44 | """
45 | Called by a worker process after the fork().
46 | """
47 | signal.signal(SIGHUP, SIG_DFL)
48 | signal.signal(SIGCHLD, SIG_DFL)
49 | signal.signal(SIGTERM, SIG_DFL)
50 | # restore the handler for SIGINT,
51 | # it's useful for debugging (show the stacktrace before exit)
52 | signal.signal(SIGINT, signal.default_int_handler)
53 |
54 | # Read the socket using fdopen instead of socket.makefile() because the latter
55 | # seems to be very slow; note that we need to dup() the file descriptor because
56 | # otherwise writes also cause a seek that makes us miss data on the read side.
57 | buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
58 | infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size)
59 | outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size)
60 |
61 | exit_code = 0
62 | try:
63 | worker_main(infile, outfile)
64 | except SystemExit as exc:
65 | exit_code = compute_real_exit_code(exc.code)
66 | finally:
67 | try:
68 | outfile.flush()
69 | except Exception:
70 | pass
71 | return exit_code
72 |
73 |
74 | def manager():
75 | # Create a new process group to corral our children
76 | os.setpgid(0, 0)
77 |
78 | # Create a listening socket on the AF_INET loopback interface
79 | listen_sock = socket.socket(AF_INET, SOCK_STREAM)
80 | listen_sock.bind(('127.0.0.1', 0))
81 | listen_sock.listen(max(1024, SOMAXCONN))
82 | listen_host, listen_port = listen_sock.getsockname()
83 |
84 | # re-open stdin/stdout in 'wb' mode
85 | stdin_bin = os.fdopen(sys.stdin.fileno(), 'rb', 4)
86 | stdout_bin = os.fdopen(sys.stdout.fileno(), 'wb', 4)
87 | write_int(listen_port, stdout_bin)
88 | stdout_bin.flush()
89 |
90 | def shutdown(code):
91 | signal.signal(SIGTERM, SIG_DFL)
92 | # Send SIGHUP to notify workers of shutdown
93 | os.kill(0, SIGHUP)
94 | sys.exit(code)
95 |
96 | def handle_sigterm(*args):
97 | shutdown(1)
98 | signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM
99 | signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP
100 | signal.signal(SIGCHLD, SIG_IGN)
101 |
102 | reuse = os.environ.get("PY_WORKER_REUSE")
103 |
104 | # Initialization complete
105 | try:
106 | while True:
107 | try:
108 | ready_fds = select.select([0, listen_sock], [], [], 1)[0]
109 | except select.error as ex:
110 | if ex[0] == EINTR:
111 | continue
112 | else:
113 | raise
114 |
115 | if 0 in ready_fds:
116 | try:
117 | worker_pid = read_int(stdin_bin)
118 | except EOFError:
119 | # Spark told us to exit by closing stdin
120 | shutdown(0)
121 | try:
122 | os.kill(worker_pid, signal.SIGKILL)
123 | except OSError:
124 | pass # process already died
125 |
126 | if listen_sock in ready_fds:
127 | try:
128 | sock, _ = listen_sock.accept()
129 | except OSError as e:
130 | if e.errno == EINTR:
131 | continue
132 | raise
133 |
134 | # Launch a worker process
135 | try:
136 | pid = os.fork()
137 | except OSError as e:
138 | if e.errno in (EAGAIN, EINTR):
139 | time.sleep(1)
140 | pid = os.fork() # error here will shutdown daemon
141 | else:
142 | outfile = sock.makefile(mode='wb')
143 | write_int(e.errno, outfile) # Signal that the fork failed
144 | outfile.flush()
145 | outfile.close()
146 | sock.close()
147 | continue
148 |
149 | if pid == 0:
150 | # in child process
151 | listen_sock.close()
152 |
153 | # It should close the standard input in the child process so that
154 | # Python native function executions stay intact.
155 | #
156 | # Note that if we just close the standard input (file descriptor 0),
157 | # the lowest file descriptor (file descriptor 0) will be allocated,
158 | # later when other file descriptors should happen to open.
159 | #
160 | # Therefore, here we redirects it to '/dev/null' by duplicating
161 | # another file descriptor for '/dev/null' to the standard input (0).
162 | # See SPARK-26175.
163 | devnull = open(os.devnull, 'r')
164 | os.dup2(devnull.fileno(), 0)
165 | devnull.close()
166 |
167 | try:
168 | # Acknowledge that the fork was successful
169 | outfile = sock.makefile(mode="wb")
170 | write_int(os.getpid(), outfile)
171 | outfile.flush()
172 | outfile.close()
173 | while True:
174 | code = worker(sock)
175 | if not reuse or code:
176 | # wait for closing
177 | try:
178 | while sock.recv(1024):
179 | pass
180 | except Exception:
181 | pass
182 | break
183 | gc.collect()
184 | except:
185 | traceback.print_exc()
186 | os._exit(1)
187 | else:
188 | os._exit(0)
189 | else:
190 | sock.close()
191 |
192 | finally:
193 | shutdown(1)
194 |
195 |
196 | if __name__ == '__main__':
197 | manager()
198 |
--------------------------------------------------------------------------------
/python/pyjava/datatype/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/pyjava/83112f57ce8e3a87b6a5a59e6b33d2f9004caf60/python/pyjava/datatype/__init__.py
--------------------------------------------------------------------------------
/python/pyjava/example/OnceServerExample.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 |
5 | os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
6 | from pyjava.api.serve import OnceServer
7 |
8 | ddata = pd.DataFrame(data=[[1, 2, 3, 4], [2, 3, 4, 5]])
9 |
10 | server = OnceServer("127.0.0.1", 11111, "Asia/Harbin")
11 | server.bind()
12 | server.serve([{'id': 9, 'label': 1}])
13 |
--------------------------------------------------------------------------------
/python/pyjava/example/RayServerExample.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 | import sys
5 | sys.path.append("../../")
6 |
7 | from pyjava.api.mlsql import DataServer
8 | from pyjava.api.serve import RayDataServer
9 | from pyjava.rayfix import RayWrapper
10 |
11 | # import ray
12 |
13 | os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
14 | ray = RayWrapper()
15 | ray.init(address='auto', _redis_password='5241590000000000')
16 |
17 | ddata = pd.DataFrame(data=[[1, 2, 3, 4], [2, 3, 4, 5]])
18 |
19 | server_id = "wow1"
20 |
21 | java_server = DataServer("127.0.0.1", 11111, "Asia/Harbin")
22 | rds = RayDataServer.options(name=server_id, max_concurrency=2).remote(server_id, java_server, 0,
23 | "Asia/Harbin")
24 | print(ray.get(rds.connect_info.remote()))
25 |
26 |
27 | def echo(row):
28 | return row
29 |
30 |
31 | rds.serve.remote(echo)
32 |
33 | rds = ray.get_actor(server_id)
34 | print(vars(ray.get(rds.connect_info.remote())))
35 |
--------------------------------------------------------------------------------
/python/pyjava/example/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/pyjava/83112f57ce8e3a87b6a5a59e6b33d2f9004caf60/python/pyjava/example/__init__.py
--------------------------------------------------------------------------------
/python/pyjava/example/test.py:
--------------------------------------------------------------------------------
1 | import ray
2 |
3 | ray.init()
4 |
5 |
6 | @ray.remote
7 | def slow_function():
8 | return 1
9 |
10 |
11 | print(ray.get(slow_function.remote()))
12 |
--------------------------------------------------------------------------------
/python/pyjava/example/test2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import socket
3 |
4 | from pyjava.serializers import ArrowStreamPandasSerializer, read_int, write_int
5 |
6 | out_ser = ArrowStreamPandasSerializer("Asia/Harbin", False, None)
7 | HOST = "127.0.0.1"
8 | PORT = 11111
9 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
10 | sock.connect((HOST, PORT))
11 | buffer_size = int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
12 | infile = os.fdopen(os.dup(sock.fileno()), "rb", buffer_size)
13 | outfile = os.fdopen(os.dup(sock.fileno()), "wb", buffer_size)
14 | # arrow start
15 | print(read_int(infile))
16 | kk = out_ser.load_stream(infile)
17 |
18 | for item in kk:
19 | print(item)
20 | # end data
21 | print(read_int(infile))
22 | # end stream
23 | print(read_int(infile))
24 | write_int(-4, outfile)
25 |
--------------------------------------------------------------------------------
/python/pyjava/example/test3.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | request_url = "http://127.0.0.1:9007/run"
4 |
5 |
6 | def registerPyAction(enableAdmin, params={}):
7 | datas = {"code": """from pyjava.api.mlsql import PythonContext
8 | for row in context.fetch_once():
9 | print(row)
10 | context.build_result([{"content": "{}"}], 1)
11 | """, "action": "registerPyAction", "codeName": "echo"}
12 | if enableAdmin:
13 | datas = {**datas, **{"admin_token": "admin"}}
14 | r = requests.post(request_url, data={**datas, **params})
15 | print(r.text)
16 | print(r.status_code)
17 |
18 |
19 | # echo
20 | def pyAction(codeName, enableAdmin=True, params={}):
21 | datas = {"codeName": codeName, "action": "pyAction"}
22 | if enableAdmin:
23 | datas = {**datas, **{"admin_token": "admin"}}
24 | r = requests.post(request_url, data={**datas, **params})
25 | print(r.text)
26 | print(r.status_code)
27 |
28 |
29 | def loadDB(db):
30 | datas = {"name": db, "action": "loadDB"}
31 | r = requests.post(request_url, data=datas)
32 | print(r.text)
33 | print(r.status_code)
34 |
35 |
36 | def http(action, params):
37 | datas = {**{"action": action}, **params}
38 | r = requests.post(request_url, data=datas)
39 | return r
40 |
41 |
42 | def controlReg():
43 | r = http("controlReg", {"enable": "true"})
44 | print(r.text)
45 | print(r.status_code)
46 |
47 |
48 | def printRespose(r):
49 | print(r.text)
50 | print(r.status_code)
51 |
52 |
53 | def test_db_config(db, user, password):
54 | return """
55 | {}:
56 | host: 127.0.0.1
57 | port: 3306
58 | database: {}
59 | username: {}
60 | password: {}
61 | initialSize: 8
62 | disable: true
63 | removeAbandoned: true
64 | testWhileIdle: true
65 | removeAbandonedTimeout: 30
66 | maxWait: 100
67 | filters: stat,log4j
68 | """.format(db, db, user, password)
69 |
70 |
71 | def addDB(instanceName, db, user, password):
72 | # user-system
73 | datas = {"dbName": db, "instanceName": instanceName,
74 | "dbConfig": test_db_config(db, user, password), "admin_token": "admin"}
75 | r = http(addDB.__name__, datas)
76 | printRespose(r)
77 |
78 |
79 | def userReg(name, password):
80 | r = http(userReg.__name__, {"userName": name, "password": password})
81 | printRespose(r)
82 |
83 |
84 | def users():
85 | r = http(users.__name__, {})
86 | printRespose(r)
87 |
88 |
89 | def uploadPlugin(file_path, data):
90 | values = {**data, **{"action": "uploadPlugin"}}
91 | files = {file_path.split("/")[-1]: open(file_path, 'rb')}
92 | r = requests.post(request_url, files=files, data=values)
93 | printRespose(r)
94 |
95 |
96 | def addProxy(name, value):
97 | r = http(addProxy.__name__, {"name": name, "value": value})
98 | printRespose(r)
99 |
100 |
101 | import json
102 |
103 |
104 | def userLogin(userName, password):
105 | r = http(userLogin.__name__, {"userName": userName, "password": password})
106 | return json.loads(r.text)[0]['token']
107 |
108 |
109 | def enablePythonAdmin(userName, token):
110 | r = http("pyAuthAction",
111 | {"userName": userName, "access-token": token,
112 | "resourceType": "admin",
113 | "resourceName": "admin", "admin_token": "admin", "authUser": "jack"})
114 | return r
115 |
116 |
117 | def enablePythonRegister(userName, token):
118 | r = http("pyAuthAction",
119 | {"userName": userName, "access-token": token,
120 | "resourceType": "action",
121 | "resourceName": "registerPyAction", "authUser": "william"})
122 | return r
123 |
124 |
125 | def enablePythonExecute(userName, token, codeName):
126 | r = http("pyAuthAction",
127 | {"userName": userName, "access-token": token,
128 | "resourceType": "custom",
129 | "resourceName": codeName, "authUser": "william"})
130 | return r
131 |
132 |
133 | # userReg()
134 | # users()
135 | # addDB("ar_plugin_repo")
136 | # pyAction("echo")
137 | # addProxy("user-system", "http://127.0.0.1:9007/run")
138 | # uploadPlugin("/Users/allwefantasy/CSDNWorkSpace/user-system/release/user-system-bin_2.11-1.0.0.jar",
139 | # {"name": "jack", "password": "123", "pluginName": "user-system"})
140 |
141 | # addDB("user-system", "mlsql_python_predictor", "root", "mlsql")
142 | # userReg("william", "mm")
143 | # registerPyAction(False)
144 | # pyAction("echo", False)
145 | token = userLogin("jack", "123")
146 | # r = enablePythonAdmin("jack", token)
147 | # r = enablePythonExecute("jack", token, "echo")
148 | # printRespose(r)
149 |
150 | token = userLogin("william", "mm")
151 | # registerPyAction(False, {"userName": "william", "access-token": token})
152 | # enablePythonRegister("jack", token)
153 | pyAction("echo", False, {"userName": "william", "access-token": token})
154 |
--------------------------------------------------------------------------------
/python/pyjava/rayfix.py:
--------------------------------------------------------------------------------
1 | from distutils.version import StrictVersion
2 |
3 | import ray
4 | import logging
5 |
6 |
7 | def last(func):
8 | func.__module__ = "pyjava_auto_generate__exec__"
9 | return func
10 |
11 |
12 | class RayWrapper:
13 |
14 | def __init__(self):
15 |         self.ray_version = StrictVersion("2.0.0") if "dev" in ray.__version__ else StrictVersion(ray.__version__)
16 | self.ray_instance = ray
17 |
18 | def __getattr__(self, attr):
19 | return getattr(self.ray_instance, attr)
20 |
21 | def get_address(self):
22 | if self.ray_version >= StrictVersion('1.0.0'):
23 | return ray.get_runtime_context().worker.node_ip_address
24 | else:
25 | return ray.services.get_node_ip_address()
26 |
27 | def init(self, address, **kwargs):
28 | logging.debug(f"address {address} {kwargs}")
29 | if self.ray_version >= StrictVersion('1.4.0'):
30 | if "namespace" in kwargs.keys():
31 | ray.util.connect(conn_str=address, **kwargs)
32 | else:
33 | ray.util.connect(conn_str=address, namespace="default", **kwargs)
34 | elif self.ray_version >= StrictVersion('1.0.0'):
35 | logging.debug(f"try to connect to ray {address}")
36 | ray.util.connect(conn_str=address, **kwargs)
37 | elif self.ray_version == StrictVersion('0.8.7'):
38 | ray.init(address=address, **kwargs)
39 | else:
40 | ray.init(redis_address=address, **kwargs)
41 |
42 | def shutdown(self):
43 | if self.ray_version >= StrictVersion('1.0.0'):
44 | try:
45 | ray.util.disconnect()
46 | except Exception as e:
47 | pass
48 | else:
49 | ray.shutdown(exiting_interpreter=False)
50 |
51 | def options(self, actor_class, **kwargs):
52 | if 'detached' in kwargs and self.ray_version >= StrictVersion('1.0.0'):
53 | del kwargs['detached']
54 | kwargs['lifetime'] = 'detached'
55 | logging.debug(f"actor build options: {kwargs}")
56 | return actor_class.options(**kwargs)
57 |
58 | def get_actor(self, name):
59 | if self.ray_version >= StrictVersion('1.0.0'):
60 | return ray.get_actor(name)
61 | elif self.ray_version == StrictVersion('0.8.7'):
62 | return ray.get_actor(name)
63 | else:
64 | return ray.experimental.get_actor(name)
65 |
--------------------------------------------------------------------------------
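RayWrapper above exists to paper over API differences between Ray releases (redis_address vs address, ray.util.connect, detached vs lifetime='detached', named-actor lookup). A minimal usage sketch; the address and actor name are illustrative and assume a reachable Ray cluster:

    from pyjava.rayfix import RayWrapper

    ray = RayWrapper()                      # picks code paths from ray.__version__
    ray.init(address="127.0.0.1:10001")     # illustrative address

    @ray.remote                             # __getattr__ forwards to the real ray module
    class Counter:
        def __init__(self):
            self.n = 0

        def incr(self):
            self.n += 1
            return self.n

    counter = ray.options(Counter, name="counter1", detached=True).remote()
    print(ray.get(counter.incr.remote()))   # -> 1
    same_actor = ray.get_actor("counter1")  # version-appropriate named-actor lookup
    ray.shutdown()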
/python/pyjava/storage/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allwefantasy/pyjava/83112f57ce8e3a87b6a5a59e6b33d2f9004caf60/python/pyjava/storage/__init__.py
--------------------------------------------------------------------------------
/python/pyjava/storage/streaming_tar.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import tarfile
4 | import tempfile
5 |
6 | import uuid
7 |
8 | BLOCK_SIZE = 1024 * 64
9 |
10 |
11 | class FileStream(object):
12 | def __init__(self):
13 | self.buffer = io.BytesIO()
14 | self.offset = 0
15 |
16 | def write(self, s):
17 | self.buffer.write(s)
18 | self.offset += len(s)
19 |
20 | def tell(self):
21 | return self.offset
22 |
23 | def close(self):
24 | self.buffer.close()
25 |
26 | def pop(self):
27 | s = self.buffer.getvalue()
28 | self.buffer.close()
29 | self.buffer = io.BytesIO()
30 | return s
31 |
32 |
33 | def build_rows_from_file(target_dir):
34 | file = FileStream()
35 | start = 0
36 | for i in stream_build_tar(target_dir, file):
37 | v = file.pop()
38 | offset = len(v)
39 | if len(v) > 0:
40 | yield {"start": start, "offset": offset, "value": v}
41 | start = start + offset
42 |
43 |
44 | def save_rows_as_file(data, target_dir):
45 | if not os.path.exists(target_dir):
46 | os.makedirs(target_dir)
47 | tf_path = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
48 | with open(tf_path, "wb") as tf:
49 | for block_row in data:
50 | tf.write(block_row["value"])
51 | with open(tf_path, "rb") as tf:
52 | tt = tarfile.open(tf.name, mode="r:")
53 | tt.extractall(target_dir)
54 | tt.close()
55 | os.remove(tf_path)
56 |
57 |
58 | def stream_build_tar(file_dir, streaming_fp):
59 | tar = tarfile.TarFile.open(fileobj=streaming_fp, mode="w:")
60 | for root, dirs, files in os.walk(file_dir):
61 | for in_filename in files:
62 | stat = os.stat(os.path.join(root, in_filename))
63 |             # str.lstrip strips a character set rather than a prefix, so build a real relative path
64 |             tar_info = tarfile.TarInfo(os.path.relpath(os.path.join(root, in_filename), file_dir))
65 | tar_info.mtime = stat.st_mtime
66 | tar_info.size = stat.st_size
67 |
68 | tar.addfile(tar_info)
69 |
70 | yield
71 |
72 | with open(os.path.join(root, in_filename), 'rb') as in_fp:
73 | # total_size = 0
74 | while True:
75 | s = in_fp.read(BLOCK_SIZE)
76 |
77 | if len(s) > 0:
78 | tar.fileobj.write(s)
79 | yield
80 |
81 | if len(s) < BLOCK_SIZE:
82 | blocks, remainder = divmod(tar_info.size, tarfile.BLOCKSIZE)
83 |
84 | if remainder > 0:
85 | tar.fileobj.write(tarfile.NUL *
86 | (tarfile.BLOCKSIZE - remainder))
87 |
88 | yield
89 |
90 | blocks += 1
91 |
92 | tar.offset += blocks * tarfile.BLOCKSIZE
93 | break
94 | tar.close()
95 |
96 |
97 | def main():
98 | rows = build_rows_from_file("/Users/allwefantasy/data/mlsql/homes/demo/tmp/minist_model")
99 | save_rows_as_file(rows, "/Users/allwefantasy/data/mlsql/homes/demo/tmp/minist_model3")
100 |
101 |
102 | if __name__ == '__main__':
103 | main()
104 |
--------------------------------------------------------------------------------
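The divmod arithmetic at the end of stream_build_tar pads each member to a whole number of 512-byte tar blocks (tarfile.BLOCKSIZE) and advances tar.offset accordingly. A small worked check with an illustrative member size:

    import tarfile  # tarfile.BLOCKSIZE == 512

    size = 1300                                          # illustrative member size in bytes
    blocks, remainder = divmod(size, tarfile.BLOCKSIZE)  # 2 full blocks, 276 bytes left over
    padding = tarfile.BLOCKSIZE - remainder if remainder else 0
    if remainder:
        blocks += 1                                      # the partial block becomes a padded one
    assert blocks * tarfile.BLOCKSIZE == size + padding == 1536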
/python/pyjava/tests/test_context.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from pyjava.api.mlsql import RayContext
4 |
5 |
6 | class RayContextTestCase(unittest.TestCase):
7 | def test_raycontext_collect_as_file(self):
8 | ray_context = RayContext.connect(globals(), None)
9 | dfs = ray_context.collect_as_file(32)
10 |
11 | for i in range(2):
12 | print("======={}======".format(str(i)))
13 | for df in dfs:
14 | print(df)
15 |
16 | ray_context.context.build_result([{"content": "jackma"}])
17 |
18 |
19 | if __name__ == '__main__':
20 | unittest.main()
21 |
--------------------------------------------------------------------------------
/python/pyjava/udf/__init__.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from typing import Any, NoReturn, Callable, Dict, List
3 | import ray
4 | import time
5 | from ray.util.client.common import ClientActorHandle, ClientObjectRef
6 |
7 | from pyjava.api.mlsql import RayContext
8 | from pyjava.storage import streaming_tar
9 |
10 |
11 | @ray.remote
12 | class UDFMaster(object):
13 | def __init__(self, num: int, conf: Dict[str, str],
14 | init_func: Callable[[List[ClientObjectRef], Dict[str, str]], Any],
15 | apply_func: Callable[[Any, Any], Any]):
16 | model_servers = RayContext.parse_servers(conf["modelServers"])
17 | items = RayContext.collect_from(model_servers)
18 | model_refs = [ray.put(item) for item in items]
19 | self.actors = dict([(index, UDFWorker.remote(model_refs, conf, init_func, apply_func)) for index in range(num)])
20 | self._idle_actors = [index for index in range(num)]
21 |
22 | def get(self) -> List[Any]:
23 | while len(self._idle_actors) == 0:
24 | time.sleep(0.001)
25 | index = self._idle_actors.pop()
26 | return [index, self.actors[index]]
27 |
28 | def give_back(self, v) -> NoReturn:
29 | self._idle_actors.append(v)
30 |
31 | def shutdown(self) -> NoReturn:
32 | [ray.kill(self.actors[index]) for index in self._idle_actors]
33 |
34 |
35 | @ray.remote
36 | class UDFWorker(object):
37 | def __init__(self,
38 | model_refs: List[ClientObjectRef],
39 | conf: Dict[str, str],
40 | init_func: Callable[[List[ClientObjectRef], Dict[str, str]], Any],
41 | apply_func: Callable[[Any, Any], Any]):
42 | self.model = init_func(model_refs, conf)
43 | self.apply_func = apply_func
44 |
45 | def apply(self, v: Any) -> Any:
46 | return self.apply_func(self.model, v)
47 |
48 | def shutdown(self):
49 | ray.actor.exit_actor()
50 |
51 |
52 | class UDFBuilder(object):
53 | @staticmethod
54 | def build(ray_context: RayContext,
55 | init_func: Callable[[List[ClientObjectRef], Dict[str, str]], Any],
56 | apply_func: Callable[[Any, Any], Any]) -> NoReturn:
57 | conf = ray_context.conf()
58 | udf_name = conf["UDF_CLIENT"]
59 | max_concurrency = int(conf.get("maxConcurrency", "3"))
60 |
61 | try:
62 | temp_udf_master = ray.get_actor(udf_name)
63 | ray.kill(temp_udf_master)
64 | time.sleep(1)
65 | except Exception as inst:
66 | print(inst)
67 | pass
68 |
69 | UDFMaster.options(name=udf_name, lifetime="detached", max_concurrency=max_concurrency).remote(
70 | max_concurrency, conf, init_func, apply_func)
71 | ray_context.build_result([])
72 |
73 | @staticmethod
74 | def apply(ray_context: RayContext):
75 | conf = ray_context.conf()
76 | udf_name = conf["UDF_CLIENT"]
77 | udf_master = ray.get_actor(udf_name)
78 | [index, worker] = ray.get(udf_master.get.remote())
79 | input = [row["value"] for row in ray_context.python_context.fetch_once_as_rows()]
80 | try:
81 | res = ray.get(worker.apply.remote(input))
82 | except Exception as inst:
83 | res = []
84 | print(inst)
85 | udf_master.give_back.remote(index)
86 | ray_context.build_result([res])
87 |
88 |
89 | class UDFBuildInFunc(object):
90 | @staticmethod
91 | def init_tf(model_refs: List[ClientObjectRef], conf: Dict[str, str]) -> Any:
92 | from tensorflow.keras import models
93 | model_path = "./tmp/model/{}".format(str(uuid.uuid4()))
94 | streaming_tar.save_rows_as_file((ray.get(ref) for ref in model_refs), model_path)
95 | return models.load_model(model_path)
96 |
--------------------------------------------------------------------------------
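UDFBuilder above pairs an init_func (turning the distributed model bytes into a live model) with an apply_func (running one prediction batch on a borrowed worker). A hedged sketch of wiring the two together; in the real flow the RayContext and its conf (UDF_CLIENT, modelServers, maxConcurrency) come from the Java side, so the connect address and the apply_tf body here are assumptions:

    from typing import Any, List

    import numpy as np
    from pyjava.api.mlsql import RayContext
    from pyjava.udf import UDFBuilder, UDFBuildInFunc

    def apply_tf(model: Any, rows: List[Any]) -> Any:
        # 'model' is whatever init_func returned; init_tf returns a Keras model.
        return [float(x) for x in model.predict(np.array(rows)).flatten()]

    # Illustrative connection; normally created by the enclosing pyjava script.
    ray_context = RayContext.connect(globals(), "127.0.0.1:10001")

    # Registration step: spins up the detached UDFMaster and its workers.
    UDFBuilder.build(ray_context, UDFBuildInFunc.init_tf, apply_tf)

    # Serving step: a later request carrying the same UDF_CLIENT conf would
    # call UDFBuilder.apply(ray_context) to borrow a worker and return its result.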
/python/pyjava/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | from pyjava.serializers import write_with_length, UTF8Deserializer, \
5 | PickleSerializer
6 |
7 | if sys.version >= '3':
8 | basestring = str
9 | else:
10 | pass
11 |
12 | pickleSer = PickleSerializer()
13 | utf8_deserializer = UTF8Deserializer()
14 |
15 |
16 | def is_dev():
17 | return 'MLSQL_DEV' in os.environ and int(os.environ["MLSQL_DEV"]) == 1
18 |
19 |
20 | def require_minimum_pandas_version():
21 | """ Raise ImportError if minimum version of Pandas is not installed
22 | """
23 | # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
24 | minimum_pandas_version = "0.23.2"
25 |
26 | from distutils.version import LooseVersion
27 | try:
28 | import pandas
29 | have_pandas = True
30 | except ImportError:
31 | have_pandas = False
32 | if not have_pandas:
33 | raise ImportError("Pandas >= %s must be installed; however, "
34 | "it was not found." % minimum_pandas_version)
35 | if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
36 | raise ImportError("Pandas >= %s must be installed; however, "
37 | "your version was %s." % (minimum_pandas_version, pandas.__version__))
38 |
39 |
40 | def require_minimum_pyarrow_version():
41 | """ Raise ImportError if minimum version of pyarrow is not installed
42 | """
43 | # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
44 | minimum_pyarrow_version = "0.12.1"
45 |
46 | from distutils.version import LooseVersion
47 | try:
48 | import pyarrow
49 | have_arrow = True
50 | except ImportError:
51 | have_arrow = False
52 | if not have_arrow:
53 | raise ImportError("PyArrow >= %s must be installed; however, "
54 | "it was not found." % minimum_pyarrow_version)
55 | if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
56 | raise ImportError("PyArrow >= %s must be installed; however, "
57 | "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
58 |
59 |
60 | def _exception_message(excp):
61 | """Return the message from an exception as either a str or unicode object. Supports both
62 | Python 2 and Python 3.
63 |
64 | >>> msg = "Exception message"
65 | >>> excp = Exception(msg)
66 | >>> msg == _exception_message(excp)
67 | True
68 |
69 | >>> msg = u"unicöde"
70 | >>> excp = Exception(msg)
71 | >>> msg == _exception_message(excp)
72 | True
73 | """
74 | if hasattr(excp, "message"):
75 | return excp.message
76 | return str(excp)
77 |
78 |
79 | def _do_server_auth(conn, auth_secret):
80 | """
81 | Performs the authentication protocol defined by the SocketAuthHelper class on the given
82 | file-like object 'conn'.
83 | """
84 | write_with_length(auth_secret.encode("utf-8"), conn)
85 | conn.flush()
86 | reply = UTF8Deserializer().loads(conn)
87 | if reply != "ok":
88 | conn.close()
89 | raise Exception("Unexpected reply from iterator server.")
90 |
91 |
92 | def local_connect_and_auth(port):
93 | """
94 |     Connect to the local host and return a (sockfile, sock) pair for that connection.
95 |     Despite the name, no authentication is performed here; see _do_server_auth.
96 |     Handles both IPv4 and IPv6 and does some error handling.
97 |     :param port: port to connect to on 127.0.0.1
98 |     :return: a tuple of (sockfile, sock)
99 | """
100 | import socket
101 | sock = None
102 | errors = []
103 | # Support for both IPv4 and IPv6.
104 |     # On most IPv6-ready systems, IPv6 will take precedence.
105 | for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
106 | af, socktype, proto, _, sa = res
107 | try:
108 | sock = socket.socket(af, socktype, proto)
109 | sock.settimeout(15)
110 | sock.connect(sa)
111 | sockfile = sock.makefile("rwb", int(os.environ.get("BUFFER_SIZE", 65536)))
112 | return (sockfile, sock)
113 | except socket.error as e:
114 | emsg = _exception_message(e)
115 |             errors.append("tried to connect to %s, but an error occurred: %s" % (sa, emsg))
116 | sock.close()
117 | sock = None
118 | raise Exception("could not open socket: %s" % errors)
119 |
--------------------------------------------------------------------------------
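Despite its name, local_connect_and_auth above only opens the connection; the secret handshake lives in _do_server_auth. A short sketch of using the two together, mirroring how worker.py connects back to the JVM; the secret environment variable name is an assumption and may not be set in every deployment:

    import os
    from pyjava.utils import local_connect_and_auth, _do_server_auth

    port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
    sockfile, sock = local_connect_and_auth(port)

    secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")  # assumed name, optional
    if secret:
        _do_server_auth(sockfile, secret)  # raises if the server does not reply "ok"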
/python/pyjava/version.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | #
4 | # Licensed to the Apache Software Foundation (ASF) under one or more
5 | # contributor license agreements. See the NOTICE file distributed with
6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0
8 | # (the "License"); you may not use this file except in compliance with
9 | # the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | __version__ = "0.3.3"
20 |
--------------------------------------------------------------------------------
/python/pyjava/worker.py:
--------------------------------------------------------------------------------
1 | #
2 | # Licensed to the Apache Software Foundation (ASF) under one or more
3 | # contributor license agreements. See the NOTICE file distributed with
4 | # this work for additional information regarding copyright ownership.
5 | # The ASF licenses this file to You under the Apache License, Version 2.0
6 | # (the "License"); you may not use this file except in compliance with
7 | # the License. You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | from __future__ import print_function
19 |
20 | from pyjava.api.mlsql import PythonContext
21 | from pyjava.cache.code_cache import CodeCache
22 | from pyjava.utils import *
23 |
24 | # 'resource' is a Unix specific module.
25 | has_resource_module = True
26 | try:
27 | import resource
28 | except ImportError:
29 | has_resource_module = False
30 |
31 | import traceback
32 | import logging
33 |
34 | from pyjava.serializers import \
35 | write_with_length, \
36 | write_int, \
37 | read_int, read_bool, SpecialLengths, UTF8Deserializer, \
38 | PickleSerializer, ArrowStreamPandasSerializer, ArrowStreamSerializer
39 |
40 | if sys.version >= '3':
41 | basestring = str
42 | else:
43 | pass
44 |
45 | # logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level='DEBUG')
46 |
47 | pickleSer = PickleSerializer()
48 | utf8_deserializer = UTF8Deserializer()
49 |
50 | globals_namespace = globals()
51 |
52 |
53 | def read_command(serializer, file):
54 | command = serializer.load_stream(file)
55 | return command
56 |
57 |
58 | def chain(f, g):
59 | """chain two functions together """
60 | return lambda *a: g(f(*a))
61 |
62 |
63 | def main(infile, outfile):
64 | try:
65 | try:
66 | import ray
67 | except ImportError:
68 | pass
69 | # set up memory limits
70 | memory_limit_mb = int(os.environ.get('PY_EXECUTOR_MEMORY', "-1"))
71 | if memory_limit_mb > 0 and has_resource_module:
72 | total_memory = resource.RLIMIT_AS
73 | try:
74 | (soft_limit, hard_limit) = resource.getrlimit(total_memory)
75 | msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
76 | logging.info(msg)
77 |
78 | # convert to bytes
79 | new_limit = memory_limit_mb * 1024 * 1024
80 |
81 | if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
82 | msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
83 | logging.info(msg)
84 | resource.setrlimit(total_memory, (new_limit, new_limit))
85 |
86 | except (resource.error, OSError, ValueError) as e:
87 | # not all systems support resource limits, so warn instead of failing
88 | logging.warning("WARN: Failed to set memory limit: {0}\n".format(e))
89 | split_index = read_int(infile)
90 | logging.info("split_index:%s" % split_index)
91 | if split_index == -1: # for unit tests
92 | sys.exit(-1)
93 |
94 | is_barrier = read_bool(infile)
95 | bound_port = read_int(infile)
96 | logging.info(f"is_barrier {is_barrier}, port {bound_port}")
97 |
98 | conf = {}
99 | for i in range(read_int(infile)):
100 | k = utf8_deserializer.loads(infile)
101 | v = utf8_deserializer.loads(infile)
102 | conf[k] = v
103 | logging.debug(f"conf {k}:{v}")
104 |
105 | command = utf8_deserializer.loads(infile)
106 | ser = ArrowStreamSerializer()
107 |
108 | timezone = conf["timezone"] if "timezone" in conf else None
109 |
110 | out_ser = ArrowStreamPandasSerializer(timezone, True, True)
111 | is_interactive = os.environ.get('PY_INTERACTIVE', "no") == "yes"
112 | import uuid
113 | context_id = str(uuid.uuid4())
114 |
115 | if not os.path.exists(context_id):
116 | os.mkdir(context_id)
117 |
118 | def process():
119 | try:
120 | input_data = ser.load_stream(infile)
121 | code = CodeCache.get(command)
122 | if is_interactive:
123 | global data_manager
124 | global context
125 | data_manager = PythonContext(context_id, input_data, conf)
126 | context = data_manager
127 | global globals_namespace
128 | exec(code, globals_namespace, globals_namespace)
129 | else:
130 | data_manager = PythonContext(context_id, input_data, conf)
131 | n_local = {"data_manager": data_manager, "context": data_manager}
132 | exec(code, n_local, n_local)
133 | out_iter = data_manager.output()
134 | write_int(SpecialLengths.START_ARROW_STREAM, outfile)
135 | out_ser.dump_stream(out_iter, outfile)
136 | finally:
137 |
138 | try:
139 | import shutil
140 | shutil.rmtree(context_id)
141 | except:
142 | pass
143 |
144 | try:
145 | if hasattr(out_iter, 'close'):
146 | out_iter.close()
147 | except:
148 | pass
149 |
150 | try:
151 | del data_manager
152 | except:
153 | pass
154 |
155 | process()
156 |
157 |
158 | except Exception:
159 | try:
160 | write_int(SpecialLengths.ARROW_STREAM_CRASH, outfile)
161 | write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
162 | write_with_length(traceback.format_exc().encode("utf-8"), outfile)
163 | except IOError:
164 |             # the JVM closed the socket
165 | pass
166 | except Exception:
167 | # Write the error to stderr if it happened while serializing
168 | print("Py worker failed with exception:", file=sys.stderr)
169 | print(traceback.format_exc(), file=sys.stderr)
170 | sys.exit(-1)
171 |
172 | write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
173 | flag = read_int(infile)
174 | if flag == SpecialLengths.END_OF_STREAM:
175 | write_int(SpecialLengths.END_OF_STREAM, outfile)
176 | else:
177 | # write a different value to tell JVM to not reuse this worker
178 | write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
179 | sys.exit(-1)
180 |
181 |
182 | if __name__ == '__main__':
183 | # Read information about how to connect back to the JVM from the environment.
184 | java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
185 | (sock_file, _) = local_connect_and_auth(java_port)
186 | main(sock_file, sock_file)
187 |
--------------------------------------------------------------------------------
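For reference, main() above expects the JVM side to send, in order: a split index, a barrier flag, a bound port, a count of conf key/value pairs followed by the pairs, and finally the Python source to execute; Arrow-serialized input batches follow on the same stream. A sketch of writing that header with the pyjava serializers; the single-byte '!?' framing for the barrier flag is assumed from the PySpark-derived read_bool:

    import struct
    from pyjava.serializers import write_int, write_with_length

    def write_header(outfile, conf, code):
        write_int(0, outfile)                    # split_index (-1 makes the worker exit)
        outfile.write(struct.pack("!?", False))  # is_barrier, consumed by read_bool
        write_int(-1, outfile)                   # bound_port (only logged by the worker)
        write_int(len(conf), outfile)            # number of conf key/value pairs
        for k, v in conf.items():
            write_with_length(k.encode("utf-8"), outfile)
            write_with_length(v.encode("utf-8"), outfile)
        write_with_length(code.encode("utf-8"), outfile)  # the 'command' script
        outfile.flush()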
/python/requirements.txt:
--------------------------------------------------------------------------------
1 | pyarrow==4.0.1
2 | ray>=1.3.0
3 | pandas>=1.0.5; python_version < '3.7'
4 | pandas>=1.2.0; python_version >= '3.7'
5 | requests
6 | matplotlib~=3.3.4
7 | uuid~=1.30
--------------------------------------------------------------------------------
/python/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
--------------------------------------------------------------------------------
/python/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | #
4 | # Licensed to the Apache Software Foundation (ASF) under one or more
5 | # contributor license agreements. See the NOTICE file distributed with
6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0
8 | # (the "License"); you may not use this file except in compliance with
9 | # the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | from __future__ import print_function
20 |
21 | import sys
22 |
23 | from setuptools import setup
24 |
25 | if sys.version_info < (3, 6):
26 |     print("Python versions prior to 3.6 are not supported for pip installed pyjava.",
27 | file=sys.stderr)
28 | sys.exit(-1)
29 |
30 | try:
31 | exec(open('pyjava/version.py').read())
32 | except IOError:
33 | print("Failed to load PyJava version file for packaging. You must be in PyJava's python dir.",
34 | file=sys.stderr)
35 | sys.exit(-1)
36 | VERSION = __version__ # noqa
37 | # A temporary path so we can access above the Python project root and fetch scripts and jars we need
38 | TEMP_PATH = "deps"
39 |
40 | # Provide guidance about how to use setup.py
41 | incorrect_invocation_message = """
42 | If you are installing pyjava from source, you must first build
43 | run sdist.
44 | Building the source dist is done in the Python directory:
45 | cd python
46 | python setup.py sdist
47 | pip install dist/*.tar.gz"""
48 |
49 | try:
50 | setup(
51 | name='pyjava',
52 | version=VERSION,
53 | description='PyJava Python API',
54 | long_description="",
55 | author='allwefantasy',
56 | author_email='allwefantasy@gmail.com',
57 | url='https://github.com/allwefantasy/pyjava',
58 | packages=['pyjava',
59 | 'pyjava.api',
60 | 'pyjava.udf',
61 | 'pyjava.datatype',
62 | 'pyjava.storage',
63 | 'pyjava.cache'],
64 | include_package_data=True,
65 | package_dir={
66 | 'pyjava.sbin': 'deps/sbin'
67 | },
68 | package_data={
69 | 'pyjava.sbin': []},
70 | license='http://www.apache.org/licenses/LICENSE-2.0',
71 | setup_requires=['pypandoc'],
72 | extras_require={
73 | },
74 | classifiers=[
75 | 'Development Status :: 5 - Production/Stable',
76 | 'License :: OSI Approved :: Apache Software License',
77 | 'Programming Language :: Python :: 3.6',
78 | 'Programming Language :: Python :: 3.7',
79 | 'Programming Language :: Python :: Implementation :: CPython',
80 | 'Programming Language :: Python :: Implementation :: PyPy']
81 | )
82 | finally:
83 | print("--------")
84 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/WowRowEncoder.scala:
--------------------------------------------------------------------------------
1 | /**
2 |  * DO NOT EDIT THIS FILE DIRECTLY, ANY CHANGE MAY BE OVERWRITTEN
3 | */
4 | package org.apache.spark
5 | import org.apache.spark.sql.Row
6 | import org.apache.spark.sql.catalyst.InternalRow
7 | import org.apache.spark.sql.catalyst.encoders.RowEncoder
8 | import org.apache.spark.sql.types.StructType
9 |
10 | object WowRowEncoder {
11 | def toRow(schema: StructType) = {
12 | RowEncoder.apply(schema).resolveAndBind().createDeserializer()
13 |
14 | }
15 |
16 | def fromRow(schema: StructType) = {
17 | RowEncoder.apply(schema).resolveAndBind().createSerializer()
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/src/main/java/org/apache/spark/sql/SparkUtils.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.sql
2 |
3 | import org.apache.arrow.vector.types.pojo.ArrowType
4 | import org.apache.spark.TaskContext
5 | import org.apache.spark.rdd.RDD
6 | import org.apache.spark.sql.catalyst.InternalRow
7 | import org.apache.spark.sql.execution.LogicalRDD
8 | import org.apache.spark.sql.types.{DataType, DecimalType, StructType}
9 |
10 | /**
11 | * 2019-08-13 WilliamZhu(allwefantasy@gmail.com)
12 | */
13 | object SparkUtils {
14 | def internalCreateDataFrame(self: SparkSession,
15 | catalystRows: RDD[InternalRow],
16 | schema: StructType,
17 | isStreaming: Boolean = false): DataFrame = {
18 | val logicalPlan = LogicalRDD(
19 | schema.toAttributes,
20 | catalystRows,
21 | isStreaming = isStreaming)(self)
22 | Dataset.ofRows(self, logicalPlan)
23 | }
24 |
25 | def isFixDecimal(dt: DataType) = {
26 | dt match {
27 | case t@DecimalType.Fixed(precision, scale) => Option(new ArrowType.Decimal(precision, scale))
28 | case _ => None
29 | }
30 | }
31 |
32 | def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc)
33 |
34 | def getKillReason(tc: TaskContext) = tc.getKillReason()
35 |
36 | def killTaskIfInterrupted(tc: TaskContext) = {
37 | tc.killTaskIfInterrupted()
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/ArrowConverters.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package tech.mlsql.arrow
19 |
20 | import org.apache.arrow.flatbuf.MessageHeader
21 | import org.apache.arrow.memory.BufferAllocator
22 | import org.apache.arrow.vector._
23 | import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer}
24 | import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel}
25 | import org.apache.spark.TaskContext
26 | import org.apache.spark.api.java.JavaRDD
27 | import org.apache.spark.network.util.JavaUtils
28 | import org.apache.spark.sql.catalyst.InternalRow
29 | import org.apache.spark.sql.types.{StructType, _}
30 | import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch}
31 | import org.apache.spark.sql.{DataFrame, SQLContext, SparkUtils}
32 | import org.apache.spark.util.TaskCompletionListener
33 | import tech.mlsql.arrow.context.CommonTaskContext
34 | import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext}
35 | import tech.mlsql.arrow.python.ispark.SparkContextImp
36 | import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream}
37 | import java.nio.channels.{Channels, ReadableByteChannel}
38 |
39 | import tech.mlsql.common.utils.lang.sc.ScalaReflect
40 |
41 | import scala.collection.JavaConverters._
42 |
43 |
44 | /**
45 | * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format.
46 | */
47 | class ArrowBatchStreamWriter(
48 | schema: StructType,
49 | out: OutputStream,
50 | timeZoneId: String) {
51 |
52 | val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
53 | val writeChannel = new WriteChannel(Channels.newChannel(out))
54 |
55 | // Write the Arrow schema first, before batches
56 | MessageSerializer.serialize(writeChannel, arrowSchema)
57 |
58 | /**
59 | * Consume iterator to write each serialized ArrowRecordBatch to the stream.
60 | */
61 | def writeBatches(arrowBatchIter: Iterator[Array[Byte]]): Unit = {
62 | arrowBatchIter.foreach(writeChannel.write)
63 | }
64 |
65 | /**
66 | * End the Arrow stream, does not close output stream.
67 | */
68 | def end(): Unit = {
69 |     // Once we no longer need to stay compatible with older Arrow versions, adjust this line:
70 | // writeChannel.writeIntLittleEndian(MessageSerializer.IPC_CONTINUATION_TOKEN);
71 | writeChannel.writeIntLittleEndian(0);
72 | }
73 | }
74 |
75 | object ArrowConverters {
76 |
77 | /**
78 | * Maps Iterator from InternalRow to serialized ArrowRecordBatches. Limit ArrowRecordBatch size
79 | * in a batch by setting maxRecordsPerBatch or use 0 to fully consume rowIter.
80 | */
81 | def toBatchIterator(
82 | rowIter: Iterator[InternalRow],
83 | schema: StructType,
84 | maxRecordsPerBatch: Int,
85 | timeZoneId: String,
86 | context: CommonTaskContext): Iterator[Array[Byte]] = {
87 |
88 | val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
89 | val allocator =
90 | ArrowUtils.rootAllocator.newChildAllocator("toBatchIterator", 0, Long.MaxValue)
91 |
92 | val root = VectorSchemaRoot.create(arrowSchema, allocator)
93 | val unloader = new VectorUnloader(root)
94 | val arrowWriter = ArrowWriter.create(root)
95 |
96 | context match {
97 | case c: AppContextImpl => c.innerContext.asInstanceOf[JavaContext].addTaskCompletionListener { _ =>
98 | root.close()
99 | allocator.close()
100 | }
101 | case c: SparkContextImp => c.innerContext.asInstanceOf[TaskContext].addTaskCompletionListener(new TaskCompletionListener {
102 | override def onTaskCompletion(context: TaskContext): Unit = {
103 | root.close()
104 | allocator.close()
105 | }
106 | })
107 | }
108 |
109 | new Iterator[Array[Byte]] {
110 |
111 | override def hasNext: Boolean = rowIter.hasNext || {
112 | root.close()
113 | allocator.close()
114 | false
115 | }
116 |
117 | override def next(): Array[Byte] = {
118 | val out = new ByteArrayOutputStream()
119 | val writeChannel = new WriteChannel(Channels.newChannel(out))
120 |
121 | Utils.tryWithSafeFinally {
122 | var rowCount = 0
123 | while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
124 | val row = rowIter.next()
125 | arrowWriter.write(row)
126 | rowCount += 1
127 | }
128 | arrowWriter.finish()
129 | val batch = unloader.getRecordBatch()
130 | MessageSerializer.serialize(writeChannel, batch)
131 | batch.close()
132 | } {
133 | arrowWriter.reset()
134 | }
135 |
136 | out.toByteArray
137 | }
138 | }
139 | }
140 |
141 | /**
142 | * Maps iterator from serialized ArrowRecordBatches to InternalRows.
143 | */
144 | def fromBatchIterator(
145 | arrowBatchIter: Iterator[Array[Byte]],
146 | schema: StructType,
147 | timeZoneId: String,
148 | context: CommonTaskContext): Iterator[InternalRow] = {
149 | val allocator =
150 | ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue)
151 |
152 | val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
153 | val root = VectorSchemaRoot.create(arrowSchema, allocator)
154 |
155 | new Iterator[InternalRow] {
156 | private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty
157 |
158 |       context match {
159 | case c: AppContextImpl => c.innerContext.asInstanceOf[JavaContext].addTaskCompletionListener { _ =>
160 | root.close()
161 | allocator.close()
162 | }
163 | case c: SparkContextImp => c.innerContext.asInstanceOf[TaskContext].addTaskCompletionListener(new TaskCompletionListener {
164 | override def onTaskCompletion(context: TaskContext): Unit = {
165 | root.close()
166 | allocator.close()
167 | }
168 | })
169 | }
170 |
171 | override def hasNext: Boolean = rowIter.hasNext || {
172 | if (arrowBatchIter.hasNext) {
173 | rowIter = nextBatch()
174 | true
175 | } else {
176 | root.close()
177 | allocator.close()
178 | false
179 | }
180 | }
181 |
182 | override def next(): InternalRow = rowIter.next()
183 |
184 | private def nextBatch(): Iterator[InternalRow] = {
185 | val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator)
186 | val vectorLoader = new VectorLoader(root)
187 | vectorLoader.load(arrowRecordBatch)
188 | arrowRecordBatch.close()
189 |
190 | val columns = root.getFieldVectors.asScala.map { vector =>
191 | new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
192 | }.toArray
193 |
194 | val batch = new ColumnarBatch(columns)
195 | batch.setNumRows(root.getRowCount)
196 | batch.rowIterator().asScala
197 | }
198 | }
199 | }
200 |
201 | /**
202 | * Load a serialized ArrowRecordBatch.
203 | */
204 | def loadBatch(
205 | batchBytes: Array[Byte],
206 | allocator: BufferAllocator): ArrowRecordBatch = {
207 | val in = new ByteArrayInputStream(batchBytes)
208 | MessageSerializer.deserializeRecordBatch(
209 | new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException
210 | }
211 |
212 | /**
213 | * Create a DataFrame from an RDD of serialized ArrowRecordBatches.
214 | */
215 | def toDataFrame(
216 | arrowBatchRDD: JavaRDD[Array[Byte]],
217 | schemaString: String,
218 | sqlContext: SQLContext): DataFrame = {
219 | val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
220 | val timeZoneId = sqlContext.sparkSession.sessionState.conf.sessionLocalTimeZone
221 | val rdd = arrowBatchRDD.rdd.mapPartitions { iter =>
222 | val context = new SparkContextImp(TaskContext.get(), null)
223 | ArrowConverters.fromBatchIterator(iter, schema, timeZoneId, context)
224 | }
225 | SparkUtils.internalCreateDataFrame(sqlContext.sparkSession, rdd.setName("arrow"), schema)
226 | }
227 |
228 | /**
229 | * Read a file as an Arrow stream and parallelize as an RDD of serialized ArrowRecordBatches.
230 | */
231 | def readArrowStreamFromFile(
232 | sqlContext: SQLContext,
233 | filename: String): JavaRDD[Array[Byte]] = {
234 | Utils.tryWithResource(new FileInputStream(filename)) { fileStream =>
235 | // Create array to consume iterator so that we can safely close the file
236 | val batches = getBatchesFromStream(fileStream.getChannel).toArray
237 | // Parallelize the record batches to create an RDD
238 | JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
239 | }
240 | }
241 |
242 | /**
243 | * Read an Arrow stream input and return an iterator of serialized ArrowRecordBatches.
244 | */
245 | def getBatchesFromStream(in: ReadableByteChannel): Iterator[Array[Byte]] = {
246 |
247 | // Iterate over the serialized Arrow RecordBatch messages from a stream
248 | new Iterator[Array[Byte]] {
249 | var batch: Array[Byte] = readNextBatch()
250 |
251 | override def hasNext: Boolean = batch != null
252 |
253 | override def next(): Array[Byte] = {
254 | val prevBatch = batch
255 | batch = readNextBatch()
256 | prevBatch
257 | }
258 |
259 | // This gets the next serialized ArrowRecordBatch by reading message metadata to check if it
260 | // is a RecordBatch message and then returning the complete serialized message which consists
261 | // of a int32 length, serialized message metadata and a serialized RecordBatch message body
262 | def readNextBatch(): Array[Byte] = {
263 | val msgMetadata = MessageSerializer.readMessage(new ReadChannel(in))
264 | if (msgMetadata == null) {
265 | return null
266 | }
267 |
268 | // Get the length of the body, which has not been read at this point
269 | val bodyLength = msgMetadata.getMessageBodyLength.toInt
270 |
271 | // Only care about RecordBatch messages, skip Schema and unsupported Dictionary messages
272 | if (msgMetadata.getMessage.headerType() == MessageHeader.RecordBatch) {
273 |
274 | // Buffer backed output large enough to hold the complete serialized message
275 | val bbout = new ByteBufferOutputStream(4 + msgMetadata.getMessageLength + bodyLength)
276 |
277 | // Write message metadata to ByteBuffer output stream
278 | MessageSerializer.writeMessageBuffer(
279 | new WriteChannel(Channels.newChannel(bbout)),
280 | msgMetadata.getMessageLength,
281 | msgMetadata.getMessageBuffer)
282 |
283 |         // Get a zero-copy ByteBuffer that already contains the message metadata; must close first
284 | bbout.close()
285 | val bb = bbout.toByteBuffer
286 | bb.position(bbout.getCount())
287 |
288 | // Read message body directly into the ByteBuffer to avoid copy, return backed byte array
289 | bb.limit(bb.capacity())
290 | JavaUtils.readFully(in, bb)
291 | bb.array()
292 | } else {
293 | if (bodyLength > 0) {
294 | // Skip message body if not a RecordBatch
295 | Channels.newInputStream(in).skip(bodyLength)
296 | }
297 |
298 | // Proceed to next message
299 | readNextBatch()
300 | }
301 | }
302 | }
303 | }
304 | }
305 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/ArrowUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package tech.mlsql.arrow
19 |
20 | import org.apache.arrow.memory.RootAllocator
21 | import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
22 | import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
23 | import org.apache.spark.sql.SparkUtils
24 | import org.apache.spark.sql.types._
25 |
26 | import scala.collection.JavaConverters._
27 |
28 | object ArrowUtils {
29 |
30 | val rootAllocator = new RootAllocator(Long.MaxValue)
31 |
32 | // todo: support more types.
33 |
34 | /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
35 | def toArrowType(dt: DataType, timeZoneId: String): ArrowType = {
36 | SparkUtils.isFixDecimal(dt).getOrElse {
37 | dt match {
38 | case BooleanType => ArrowType.Bool.INSTANCE
39 | case ByteType => new ArrowType.Int(8, true)
40 | case ShortType => new ArrowType.Int(8 * 2, true)
41 | case IntegerType => new ArrowType.Int(8 * 4, true)
42 | case LongType => new ArrowType.Int(8 * 8, true)
43 | case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
44 | case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
45 | case StringType => ArrowType.Utf8.INSTANCE
46 | case BinaryType => ArrowType.Binary.INSTANCE
47 | case DateType => new ArrowType.Date(DateUnit.DAY)
48 | case TimestampType =>
49 | if (timeZoneId == null) {
50 | throw new UnsupportedOperationException(
51 | s"${TimestampType.catalogString} must supply timeZoneId parameter")
52 | } else {
53 | new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
54 | }
55 | case _ =>
56 |
57 | throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}")
58 | }
59 | }
60 |
61 | }
62 |
63 | def fromArrowType(dt: ArrowType): DataType = dt match {
64 | case ArrowType.Bool.INSTANCE => BooleanType
65 | case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
66 | case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
67 | case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType
68 | case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
69 | case float: ArrowType.FloatingPoint
70 | if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType
71 | case float: ArrowType.FloatingPoint
72 | if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType
73 | case ArrowType.Utf8.INSTANCE => StringType
74 | case ArrowType.Binary.INSTANCE => BinaryType
75 | case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
76 | case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
77 | case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
78 | case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt")
79 | }
80 |
81 | /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
82 | def toArrowField(
83 | name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = {
84 | dt match {
85 | case ArrayType(elementType, containsNull) =>
86 | val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
87 | new Field(name, fieldType,
88 | Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava)
89 | case StructType(fields) =>
90 | val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
91 | new Field(name, fieldType,
92 | fields.map { field =>
93 | toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
94 | }.toSeq.asJava)
95 | case dataType =>
96 | val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null)
97 | new Field(name, fieldType, Seq.empty[Field].asJava)
98 | }
99 | }
100 |
101 | def fromArrowField(field: Field): DataType = {
102 | field.getType match {
103 | case ArrowType.List.INSTANCE =>
104 | val elementField = field.getChildren().get(0)
105 | val elementType = fromArrowField(elementField)
106 | ArrayType(elementType, containsNull = elementField.isNullable)
107 | case ArrowType.Struct.INSTANCE =>
108 | val fields = field.getChildren().asScala.map { child =>
109 | val dt = fromArrowField(child)
110 | StructField(child.getName, dt, child.isNullable)
111 | }
112 | StructType(fields)
113 | case arrowType => fromArrowType(arrowType)
114 | }
115 | }
116 |
117 | /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */
118 | def toArrowSchema(schema: StructType, timeZoneId: String): Schema = {
119 | new Schema(schema.map { field =>
120 | toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
121 | }.asJava)
122 | }
123 |
124 | def fromArrowSchema(schema: Schema): StructType = {
125 | StructType(schema.getFields.asScala.map { field =>
126 | val dt = fromArrowField(field)
127 | StructField(field.getName, dt, field.isNullable)
128 | })
129 | }
130 |
131 |
132 | }
133 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/ArrowWriter.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package tech.mlsql.arrow
19 |
20 | import org.apache.arrow.vector._
21 | import org.apache.arrow.vector.complex._
22 | import org.apache.spark.sql.SparkUtils
23 | import org.apache.spark.sql.catalyst.InternalRow
24 | import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
25 | import org.apache.spark.sql.types._
26 |
27 | import scala.collection.JavaConverters._
28 |
29 | object ArrowWriter {
30 |
31 | def create(schema: StructType, timeZoneId: String): ArrowWriter = {
32 | val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
33 | val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator)
34 | create(root)
35 | }
36 |
37 | def create(root: VectorSchemaRoot): ArrowWriter = {
38 | val children = root.getFieldVectors().asScala.map { vector =>
39 | vector.allocateNew()
40 | createFieldWriter(vector)
41 | }
42 | new ArrowWriter(root, children.toArray)
43 | }
44 |
45 | private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
46 | val field = vector.getField()
47 | val fixDecimalOpt = SparkUtils.isFixDecimal(ArrowUtils.fromArrowField(field))
48 | (fixDecimalOpt, vector) match {
49 | case (Some(i), vector: DecimalVector) => return new DecimalWriter(vector, i.getPrecision, i.getScale)
50 | case (None, _) =>
51 | }
52 |
53 | (ArrowUtils.fromArrowField(field), vector) match {
54 | case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
55 | case (ByteType, vector: TinyIntVector) => new ByteWriter(vector)
56 | case (ShortType, vector: SmallIntVector) => new ShortWriter(vector)
57 | case (IntegerType, vector: IntVector) => new IntegerWriter(vector)
58 | case (LongType, vector: BigIntVector) => new LongWriter(vector)
59 | case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
60 | case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
61 | case (StringType, vector: VarCharVector) => new StringWriter(vector)
62 | case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
63 | case (DateType, vector: DateDayVector) => new DateWriter(vector)
64 | case (TimestampType, vector: TimeStampMicroTZVector) => new TimestampWriter(vector)
65 | case (ArrayType(_, _), vector: ListVector) =>
66 | val elementVector = createFieldWriter(vector.getDataVector())
67 | new ArrayWriter(vector, elementVector)
68 | case (StructType(_), vector: StructVector) =>
69 | val children = (0 until vector.size()).map { ordinal =>
70 | createFieldWriter(vector.getChildByOrdinal(ordinal))
71 | }
72 | new StructWriter(vector, children.toArray)
73 | case (dt, _) =>
74 | throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}")
75 | }
76 | }
77 | }
78 |
79 | class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {
80 |
81 | def schema: StructType = StructType(fields.map { f =>
82 | StructField(f.name, f.dataType, f.nullable)
83 | })
84 |
85 | private var count: Int = 0
86 |
87 | def write(row: InternalRow): Unit = {
88 | var i = 0
89 | while (i < fields.size) {
90 | fields(i).write(row, i)
91 | i += 1
92 | }
93 | count += 1
94 | }
95 |
96 | def finish(): Unit = {
97 | root.setRowCount(count)
98 | fields.foreach(_.finish())
99 | }
100 |
101 | def reset(): Unit = {
102 | root.setRowCount(0)
103 | count = 0
104 | fields.foreach(_.reset())
105 | }
106 | }
107 |
108 | private[arrow] abstract class ArrowFieldWriter {
109 |
110 | def valueVector: ValueVector
111 |
112 | def name: String = valueVector.getField().getName()
113 |
114 | def dataType: DataType = ArrowUtils.fromArrowField(valueVector.getField())
115 |
116 | def nullable: Boolean = valueVector.getField().isNullable()
117 |
118 | def setNull(): Unit
119 |
120 | def setValue(input: SpecializedGetters, ordinal: Int): Unit
121 |
122 | private[arrow] var count: Int = 0
123 |
124 | def write(input: SpecializedGetters, ordinal: Int): Unit = {
125 | if (input.isNullAt(ordinal)) {
126 | setNull()
127 | } else {
128 | setValue(input, ordinal)
129 | }
130 | count += 1
131 | }
132 |
133 | def finish(): Unit = {
134 | valueVector.setValueCount(count)
135 | }
136 |
137 | def reset(): Unit = {
138 | valueVector.reset()
139 | count = 0
140 | }
141 | }
142 |
143 | private[arrow] class BooleanWriter(val valueVector: BitVector) extends ArrowFieldWriter {
144 |
145 | override def setNull(): Unit = {
146 | valueVector.setNull(count)
147 | }
148 |
149 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
150 | valueVector.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0)
151 | }
152 | }
153 |
154 | private[arrow] class ByteWriter(val valueVector: TinyIntVector) extends ArrowFieldWriter {
155 |
156 | override def setNull(): Unit = {
157 | valueVector.setNull(count)
158 | }
159 |
160 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
161 | valueVector.setSafe(count, input.getByte(ordinal))
162 | }
163 | }
164 |
165 | private[arrow] class ShortWriter(val valueVector: SmallIntVector) extends ArrowFieldWriter {
166 |
167 | override def setNull(): Unit = {
168 | valueVector.setNull(count)
169 | }
170 |
171 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
172 | valueVector.setSafe(count, input.getShort(ordinal))
173 | }
174 | }
175 |
176 | private[arrow] class IntegerWriter(val valueVector: IntVector) extends ArrowFieldWriter {
177 |
178 | override def setNull(): Unit = {
179 | valueVector.setNull(count)
180 | }
181 |
182 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
183 | valueVector.setSafe(count, input.getInt(ordinal))
184 | }
185 | }
186 |
187 | private[arrow] class LongWriter(val valueVector: BigIntVector) extends ArrowFieldWriter {
188 |
189 | override def setNull(): Unit = {
190 | valueVector.setNull(count)
191 | }
192 |
193 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
194 | valueVector.setSafe(count, input.getLong(ordinal))
195 | }
196 | }
197 |
198 | private[arrow] class FloatWriter(val valueVector: Float4Vector) extends ArrowFieldWriter {
199 |
200 | override def setNull(): Unit = {
201 | valueVector.setNull(count)
202 | }
203 |
204 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
205 | valueVector.setSafe(count, input.getFloat(ordinal))
206 | }
207 | }
208 |
209 | private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends ArrowFieldWriter {
210 |
211 | override def setNull(): Unit = {
212 | valueVector.setNull(count)
213 | }
214 |
215 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
216 | valueVector.setSafe(count, input.getDouble(ordinal))
217 | }
218 | }
219 |
220 | private[arrow] class DecimalWriter(
221 | val valueVector: DecimalVector,
222 | precision: Int,
223 | scale: Int) extends ArrowFieldWriter {
224 |
225 | override def setNull(): Unit = {
226 | valueVector.setNull(count)
227 | }
228 |
229 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
230 | val decimal = input.getDecimal(ordinal, precision, scale)
231 | if (decimal.changePrecision(precision, scale)) {
232 | valueVector.setSafe(count, decimal.toJavaBigDecimal)
233 | } else {
234 | setNull()
235 | }
236 | }
237 | }
238 |
239 | private[arrow] class StringWriter(val valueVector: VarCharVector) extends ArrowFieldWriter {
240 |
241 | override def setNull(): Unit = {
242 | valueVector.setNull(count)
243 | }
244 |
245 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
246 | val utf8 = input.getUTF8String(ordinal)
247 | val utf8ByteBuffer = utf8.getByteBuffer
248 | // todo: for off-heap UTF8String, how to pass in to arrow without copy?
249 | valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes())
250 | }
251 | }
252 |
253 | private[arrow] class BinaryWriter(
254 | val valueVector: VarBinaryVector) extends ArrowFieldWriter {
255 |
256 | override def setNull(): Unit = {
257 | valueVector.setNull(count)
258 | }
259 |
260 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
261 | val bytes = input.getBinary(ordinal)
262 | valueVector.setSafe(count, bytes, 0, bytes.length)
263 | }
264 | }
265 |
266 | private[arrow] class DateWriter(val valueVector: DateDayVector) extends ArrowFieldWriter {
267 |
268 | override def setNull(): Unit = {
269 | valueVector.setNull(count)
270 | }
271 |
272 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
273 | valueVector.setSafe(count, input.getInt(ordinal))
274 | }
275 | }
276 |
277 | private[arrow] class TimestampWriter(
278 | val valueVector: TimeStampMicroTZVector) extends ArrowFieldWriter {
279 |
280 | override def setNull(): Unit = {
281 | valueVector.setNull(count)
282 | }
283 |
284 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
285 | valueVector.setSafe(count, input.getLong(ordinal))
286 | }
287 | }
288 |
289 | private[arrow] class ArrayWriter(
290 | val valueVector: ListVector,
291 | val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter {
292 |
293 | override def setNull(): Unit = {
294 | }
295 |
296 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
297 | val array = input.getArray(ordinal)
298 | var i = 0
299 | valueVector.startNewValue(count)
300 | while (i < array.numElements()) {
301 | elementWriter.write(array, i)
302 | i += 1
303 | }
304 | valueVector.endValue(count, array.numElements())
305 | }
306 |
307 | override def finish(): Unit = {
308 | super.finish()
309 | elementWriter.finish()
310 | }
311 |
312 | override def reset(): Unit = {
313 | super.reset()
314 | elementWriter.reset()
315 | }
316 | }
317 |
318 | private[arrow] class StructWriter(
319 | val valueVector: StructVector,
320 | children: Array[ArrowFieldWriter]) extends ArrowFieldWriter {
321 |
322 | override def setNull(): Unit = {
323 | var i = 0
324 | while (i < children.length) {
325 | children(i).setNull()
326 | children(i).count += 1
327 | i += 1
328 | }
329 | valueVector.setNull(count)
330 | }
331 |
332 | override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
333 | val struct = input.getStruct(ordinal, children.length)
334 | var i = 0
335 | while (i < struct.numFields) {
336 | children(i).write(struct, i)
337 | i += 1
338 | }
339 | valueVector.setIndexDefined(count)
340 | }
341 |
342 | override def finish(): Unit = {
343 | super.finish()
344 | children.foreach(_.finish())
345 | }
346 |
347 | override def reset(): Unit = {
348 | super.reset()
349 | children.foreach(_.reset())
350 | }
351 | }
352 |
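Usage sketch for the writer above, assuming a placeholder `rows` iterator: ArrowWriter follows a create / write / finish / reset lifecycle, which is how ArrowPythonRunner drives it further below.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types._
    import tech.mlsql.arrow.ArrowWriter

    val schema = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
    val rows: Iterator[InternalRow] = Iterator.empty      // stand-in for real data
    val writer = ArrowWriter.create(schema, "UTC")         // allocates vectors from the root allocator
    rows.foreach(row => writer.write(row))
    writer.finish()                                        // sets the row count on writer.root
    // hand writer.root to an org.apache.arrow.vector.ipc.ArrowStreamWriter, then:
    writer.reset()                                         // clear the vectors before the next batch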
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/ByteBufferOutputStream.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow
2 |
3 | import java.io.ByteArrayOutputStream
4 | import java.nio.ByteBuffer
5 |
6 | /**
7 | * 2019-08-13 WilliamZhu(allwefantasy@gmail.com)
8 | */
9 | class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutputStream(capacity) {
10 |
11 | def this() = this(32)
12 |
13 | def getCount(): Int = count
14 |
15 | private[this] var closed: Boolean = false
16 |
17 | override def write(b: Int): Unit = {
18 | require(!closed, "cannot write to a closed ByteBufferOutputStream")
19 | super.write(b)
20 | }
21 |
22 | override def write(b: Array[Byte], off: Int, len: Int): Unit = {
23 | require(!closed, "cannot write to a closed ByteBufferOutputStream")
24 | super.write(b, off, len)
25 | }
26 |
27 | override def reset(): Unit = {
28 | require(!closed, "cannot reset a closed ByteBufferOutputStream")
29 | super.reset()
30 | }
31 |
32 | override def close(): Unit = {
33 | if (!closed) {
34 | super.close()
35 | closed = true
36 | }
37 | }
38 |
39 | def toByteBuffer: ByteBuffer = {
40 | require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed")
41 | ByteBuffer.wrap(buf, 0, count)
42 | }
43 | }
44 |
45 |
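Usage sketch: toByteBuffer may only be called after the stream has been closed.

    import java.nio.charset.StandardCharsets
    import tech.mlsql.arrow.ByteBufferOutputStream

    val out = new ByteBufferOutputStream()
    out.write("hello".getBytes(StandardCharsets.UTF_8))
    out.close()                    // required before toByteBuffer
    val buf = out.toByteBuffer     // wraps the internal buffer without copying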
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/Utils.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow
2 |
3 | import java.io._
4 | import java.nio.charset.StandardCharsets
5 | import java.util.concurrent.TimeUnit
6 |
7 | import org.apache.spark.network.util.JavaUtils
8 | import tech.mlsql.arrow.api.RedirectStreams
9 | import tech.mlsql.arrow.python.PythonWorkerFactory.Tool.REDIRECT_IMPL
10 | import tech.mlsql.common.utils.log.Logging
11 |
12 | import scala.io.Source
13 | import scala.util.Try
14 | import scala.util.control.{ControlThrowable, NonFatal}
15 |
16 |
17 | /**
18 | * 2019-08-13 WilliamZhu(allwefantasy@gmail.com)
19 | */
20 | object Utils extends Logging {
21 | def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
22 | var originalThrowable: Throwable = null
23 | try {
24 | block
25 | } catch {
26 | case t: Throwable =>
27 | // Purposefully not using NonFatal, because even fatal exceptions
28 | // we don't want to have our finallyBlock suppress
29 | originalThrowable = t
30 | throw originalThrowable
31 | } finally {
32 | try {
33 | finallyBlock
34 | } catch {
35 | case t: Throwable if (originalThrowable != null && originalThrowable != t) =>
36 | originalThrowable.addSuppressed(t)
37 | logWarning(s"Suppressing exception in finally: ${t.getMessage}", t)
38 | throw originalThrowable
39 | }
40 | }
41 | }
42 |
43 |
44 | def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = {
45 | val resource = createResource
46 | try f.apply(resource) finally resource.close()
47 | }
48 |
49 |
50 | /**
51 | * Return the stderr of a process after waiting for the process to terminate.
52 | * If the process does not terminate within the specified timeout, return None.
53 | */
54 | def getStderr(process: Process, timeoutMs: Long): Option[String] = {
55 | val terminated = process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)
56 | if (terminated) {
57 | Some(Source.fromInputStream(process.getErrorStream).getLines().mkString("\n"))
58 | } else {
59 | None
60 | }
61 | }
62 |
63 | /**
64 | * Execute the given block, logging and re-throwing any uncaught exception.
65 | * This is particularly useful for wrapping code that runs in a thread, to ensure
66 | * that exceptions are printed, and to avoid having to catch Throwable.
67 | */
68 | def logUncaughtExceptions[T](f: => T): T = {
69 | try {
70 | f
71 | } catch {
72 | case ct: ControlThrowable =>
73 | throw ct
74 | case t: Throwable =>
75 | logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t)
76 | throw t
77 | }
78 | }
79 |
80 |
81 | class RedirectThread(
82 | in: InputStream,
83 | out: OutputStream,
84 | name: String,
85 | propagateEof: Boolean = false)
86 | extends Thread(name) {
87 |
88 | setDaemon(true)
89 |
90 | override def run() {
91 | scala.util.control.Exception.ignoring(classOf[IOException]) {
92 | // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
93 | Utils.tryWithSafeFinally {
94 | val buf = new Array[Byte](1024)
95 | var len = in.read(buf)
96 | while (len != -1) {
97 | out.write(buf, 0, len)
98 | out.flush()
99 | len = in.read(buf)
100 | }
101 | } {
102 | if (propagateEof) {
103 | out.close()
104 | }
105 | }
106 | }
107 | }
108 | }
109 |
110 | /** Executes the given block in a Try, logging any uncaught exceptions. */
111 | def tryLog[T](f: => T): Try[T] = {
112 | try {
113 | val res = f
114 | scala.util.Success(res)
115 | } catch {
116 | case ct: ControlThrowable =>
117 | throw ct
118 | case t: Throwable =>
119 | logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t)
120 | scala.util.Failure(t)
121 | }
122 | }
123 |
124 | /** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */
125 | def isFatalError(e: Throwable): Boolean = {
126 | e match {
127 | case NonFatal(_) |
128 | _: InterruptedException |
129 | _: NotImplementedError |
130 | _: ControlThrowable |
131 | _: LinkageError =>
132 | false
133 | case _ =>
134 | true
135 | }
136 | }
137 |
138 | def deleteRecursively(file: File): Unit = {
139 | if (file != null) {
140 | JavaUtils.deleteRecursively(file)
141 | }
142 | }
143 |
144 | def redirectStream(conf: Map[String, String], stdout: InputStream) {
145 | try {
146 | conf.get(REDIRECT_IMPL) match {
147 | case None =>
148 | new RedirectThread(stdout, System.err, "stdout reader ").start()
149 | case Some(clzz) =>
150 | val instance = Class.forName(clzz).newInstance().asInstanceOf[RedirectStreams]
151 | instance.setConf(conf)
152 | instance.stdOut(stdout)
153 | }
154 | } catch {
155 | case e: Exception =>
156 | logError("Exception in redirecting streams", e)
157 | }
158 | }
159 |
160 | def writeUTF(str: String, dataOut: DataOutputStream) {
161 | val bytes = str.getBytes(StandardCharsets.UTF_8)
162 | dataOut.writeInt(bytes.length)
163 | dataOut.write(bytes)
164 | }
165 |
166 | }
167 |
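Usage sketch for the two resource helpers above; the file path is hypothetical.

    import java.io.FileInputStream
    import tech.mlsql.arrow.Utils

    // tryWithResource always closes the resource, even if the body throws.
    val firstByte = Utils.tryWithResource(new FileInputStream("/tmp/data.bin")) { in =>
      in.read()
    }

    // tryWithSafeFinally rethrows the original exception and attaches any failure
    // from the finally block as a suppressed exception.
    Utils.tryWithSafeFinally {
      // work that may throw
    } {
      // cleanup that may also throw
    }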
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/api/RedirectStreams.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow.api
2 |
3 | import java.io.InputStream
4 |
5 |
6 | /**
7 |  * By default, we redirect all python daemon (or worker) stdout/stderr to the Java process's
8 |  * stderr. If you want to change this behavior, implement this trait and configure it
9 |  * in ArrowPythonRunner.
10 | *
11 | * Example:
12 | *
13 | * new ArrowPythonRunner(
14 | * ....
15 | * conf=Map("python.redirect.impl"->"your impl class name")
16 | * )
17 | *
18 | */
19 | trait RedirectStreams {
20 |
21 | def setConf(conf: Map[String, String]): Unit
22 |
23 | def conf: Map[String, String]
24 |
25 | def stdOut(stdOut: InputStream): Unit
26 |
27 | def stdErr(stdErr: InputStream): Unit
28 | }
29 |
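A minimal sketch of an implementation that forwards worker output to this JVM's stdout instead of stderr; the class name is hypothetical.

    import java.io.InputStream
    import tech.mlsql.arrow.Utils
    import tech.mlsql.arrow.api.RedirectStreams

    class StdoutRedirect extends RedirectStreams {
      private var _conf: Map[String, String] = Map()
      override def setConf(conf: Map[String, String]): Unit = _conf = conf
      override def conf: Map[String, String] = _conf
      override def stdOut(stdOut: InputStream): Unit =
        new Utils.RedirectThread(stdOut, System.out, "python stdout reader").start()
      override def stdErr(stdErr: InputStream): Unit =
        new Utils.RedirectThread(stdErr, System.out, "python stderr reader").start()
    }

    // then: conf = Map("python.redirect.impl" -> "your.package.StdoutRedirect")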
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/context/CommonTaskContext.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow.context
2 |
3 | import java.net.{ServerSocket, Socket}
4 | import java.util.concurrent.atomic.AtomicBoolean
5 |
6 | import org.apache.arrow.memory.BufferAllocator
7 | import org.apache.arrow.vector.ipc.ArrowStreamReader
8 | import tech.mlsql.arrow.python.runner.ArrowPythonRunner
9 |
10 | /**
11 |  * This trait is used to manage your tasks. When you create a thread that runs something using
12 |  * ArrowPythonRunner, we should take care of both the case where the task completes normally and
13 |  * the case where it finishes unexpectedly; in both cases some callbacks or actions should be performed.
14 | */
15 | trait CommonTaskContext {
16 | val arrowPythonRunner: ArrowPythonRunner
17 |
18 | /**
19 |    * When the reader begins to read, invoke the returned function with the reader and allocator so
20 |    * they are remembered; when the task is done, they will be used to clean up the reader resources.
21 | */
22 | def readerRegister(callback: () => Unit): (ArrowStreamReader, BufferAllocator) => Unit
23 |
24 | /**
25 |    * releasedOrClosed: AtomicBoolean indicating whether the python worker we have used should be released or closed.
26 |    * reuseWorker: whether the python worker should be reused.
27 |    * socket: the python worker socket.
28 |    * When the task starts, it will invoke the returned function. Remember these values, and
29 |    * when the task is done, use the returned function to clean up the python resources.
30 | */
31 | def pythonWorkerRegister(callback: () => Unit): (
32 | AtomicBoolean,
33 | Boolean,
34 | Socket) => Unit
35 |
36 |
37 | /**
38 | *
39 |    * If isBarrier is true, you should also set shutdownServerSocket.
40 | */
41 | def isBarrier: Boolean
42 |
43 | /**
44 |    * When the task starts, it will invoke the returned function. Remember the ServerSocket it is given,
45 |    * and when the task is done, use the returned function to close the Java-side ServerSocket.
46 | */
47 | def javaSideSocketServerRegister(): (ServerSocket) => Unit
48 |
49 | def monitor(callback: () => Unit): (Long, String, Map[String, String], Socket) => Unit
50 |
51 | def assertTaskIsCompleted(callback: () => Unit): () => Unit
52 |
53 | /**
54 | *
55 |    * Set the inner context; it can be retrieved again through innerContext.
56 | */
57 | def setTaskContext(): () => Unit
58 |
59 | def innerContext: Any
60 |
61 | /**
62 | * check task status
63 | */
64 | def isTaskCompleteOrInterrupt(): () => Boolean
65 |
66 | def isTaskInterrupt(): () => Boolean
67 |
68 | def getTaskKillReason(): () => Option[String]
69 |
70 | def killTaskIfInterrupted(): () => Unit
71 | }
72 |
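The register methods are curried: calling one returns a function that the runner later invokes with the live resources, so they can be cleaned up when the task ends. A minimal sketch of the calling pattern, mirroring ArrowPythonRunner further below (`taskContext`, `reader` and `allocator` are assumed to exist):

    // taskContext: CommonTaskContext, e.g. an AppContextImpl or SparkContextImp (defined later in this repo)
    taskContext.readerRegister(() => {})(reader, allocator)   // close reader/allocator when the task completes

    val isDone: () => Boolean = taskContext.isTaskCompleteOrInterrupt()
    while (!isDone()) { Thread.sleep(1000) }                  // poll task status from a monitoring loop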
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/javadoc.java:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow;
2 |
3 | /**
4 | * 2019-08-17 WilliamZhu(allwefantasy@gmail.com)
5 | */
6 | public class javadoc {
7 | }
8 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/python/PyJavaException.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow.python
2 |
3 | /**
4 | * 2019-08-15 WilliamZhu(allwefantasy@gmail.com)
5 | */
6 | class PyJavaException(message: String, cause: Throwable)
7 | extends Exception(message, cause) {
8 |
9 | def this(message: String) = this(message, null)
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/python/PythonWorkerFactory.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow.python
2 |
3 | /**
4 | * 2019-08-14 WilliamZhu(allwefantasy@gmail.com)
5 | */
6 |
7 | import java.io._
8 | import java.net.{InetAddress, ServerSocket, Socket, SocketException}
9 | import java.util.Arrays
10 | import java.util.concurrent.TimeUnit
11 | import java.util.concurrent.atomic.AtomicLong
12 |
13 | import javax.annotation.concurrent.GuardedBy
14 | import tech.mlsql.arrow.Utils
15 | import tech.mlsql.arrow.python.runner.PythonConf
16 | import tech.mlsql.common.utils.lang.sc.ScalaMethodMacros
17 | import tech.mlsql.common.utils.log.Logging
18 |
19 | import scala.collection.JavaConverters._
20 | import scala.collection.mutable
21 |
22 | class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String], conf: Map[String, String])
23 | extends Logging {
24 | self =>
25 |
26 | import PythonWorkerFactory.Tool._
27 |
28 | // Because forking processes from Java is expensive, we prefer to launch a single Python daemon,
29 | // pyjava.daemon (by default), and tell it to fork new workers for our tasks. This daemon
30 | // currently only works on UNIX-based systems because it uses signals for child management,
31 | // so on other systems we fall back to launching workers, pyjava.worker (by default), directly.
32 | private val useDaemon = {
33 | val useDaemonEnabled = true
34 |
35 | // This flag is ignored on Windows as it's unable to fork.
36 | !System.getProperty("os.name").startsWith("Windows") && useDaemonEnabled
37 | }
38 |
39 | // WARN: Both configurations, 'python.daemon.module' and 'python.worker.module', are
40 | // for very advanced users and they are experimental. They should be considered
41 | // expert-only options and shouldn't be used before knowing exactly what they mean.
42 |
43 | // This configuration indicates the module to run the daemon to execute its Python workers.
44 | private val daemonModule = conf.getOrElse(PYTHON_DAEMON_MODULE, "pyjava.daemon")
45 |
46 |
47 | // This configuration indicates the module to run each Python worker.
48 | private val workerModule = conf.getOrElse(PYTHON_WORKER_MODULE, "pyjava.worker")
49 |
50 | private val workerIdleTime = conf.getOrElse(PYTHON_WORKER_IDLE_TIME, "1").toInt
51 |
52 | @GuardedBy("self")
53 | private var daemon: Process = null
54 | val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
55 | @GuardedBy("self")
56 | private var daemonPort: Int = 0
57 | @GuardedBy("self")
58 | private val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
59 | @GuardedBy("self")
60 | private val idleWorkers = new mutable.Queue[Socket]()
61 | @GuardedBy("self")
62 | private var lastActivityNs = 0L
63 |
64 |
65 | private val monitorThread = new MonitorThread()
66 | monitorThread.setWorkerIdleTime(workerIdleTime)
67 | monitorThread.start()
68 |
69 | @GuardedBy("self")
70 | private val simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
71 |
72 | private val pythonPath = mergePythonPaths(
73 | envVars.getOrElse("PYTHONPATH", ""),
74 | sys.env.getOrElse("PYTHONPATH", ""))
75 |
76 | def create(): Socket = {
77 | val socket = if (useDaemon) {
78 | self.synchronized {
79 | if (idleWorkers.nonEmpty) {
80 | return idleWorkers.dequeue()
81 | }
82 | }
83 | createThroughDaemon()
84 | } else {
85 | createSimpleWorker()
86 | }
87 | socket
88 | }
89 |
90 | /**
91 |    * Connect to a worker launched through the daemon module (pyjava.daemon by default), which forks python
92 | * processes itself to avoid the high cost of forking from Java. This currently only works
93 | * on UNIX-based systems.
94 | */
95 | private def createThroughDaemon(): Socket = {
96 |
97 | def createSocket(): Socket = {
98 | val socket = new Socket(daemonHost, daemonPort)
99 | val pid = new DataInputStream(socket.getInputStream).readInt()
100 | if (pid < 0) {
101 | throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
102 | }
103 | daemonWorkers.put(socket, pid)
104 | socket
105 | }
106 |
107 | self.synchronized {
108 | // Start the daemon if it hasn't been started
109 | startDaemon()
110 |
111 | // Attempt to connect, restart and retry once if it fails
112 | try {
113 | createSocket()
114 | } catch {
115 | case exc: SocketException =>
116 | logWarning("Failed to open socket to Python daemon:", exc)
117 | logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
118 | stopDaemon()
119 | startDaemon()
120 | createSocket()
121 | }
122 | }
123 | }
124 |
125 | /**
126 |    * Launch a worker by executing the worker module (pyjava.worker by default) directly and telling it to connect to us.
127 | */
128 | private def createSimpleWorker(): Socket = {
129 | var serverSocket: ServerSocket = null
130 | try {
131 | serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
132 |
133 | // Create and start the worker
134 | val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule))
135 | val workerEnv = pb.environment()
136 | workerEnv.putAll(envVars.asJava)
137 | workerEnv.put("PYTHONPATH", pythonPath)
138 | // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
139 | workerEnv.put("PYTHONUNBUFFERED", "YES")
140 | workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString)
141 | val worker = pb.start()
142 |
143 | // Redirect worker stdout and stderr
144 | Utils.redirectStream(conf, worker.getInputStream)
145 | Utils.redirectStream(conf, worker.getErrorStream)
146 |
147 | // Wait for it to connect to our socket, and validate the auth secret.
148 | serverSocket.setSoTimeout(10000)
149 |
150 | try {
151 | val socket = serverSocket.accept()
152 | self.synchronized {
153 | simpleWorkers.put(socket, worker)
154 | }
155 | return socket
156 | } catch {
157 | case e: Exception =>
158 | throw new RuntimeException("Python worker failed to connect back.", e)
159 | }
160 | } finally {
161 | if (serverSocket != null) {
162 | serverSocket.close()
163 | }
164 | }
165 | null
166 | }
167 |
168 | private def startDaemon() {
169 | self.synchronized {
170 | // Is it already running?
171 | if (daemon != null) {
172 | return
173 | }
174 |
175 | try {
176 | // Create and start the daemon
177 | val envCommand = envVars.getOrElse(ScalaMethodMacros.str(PythonConf.PYTHON_ENV), "")
178 | val command = Seq("bash", "-c", envCommand + s" && python -m ${daemonModule}")
179 | val pb = new ProcessBuilder(command.asJava)
180 | val workerEnv = pb.environment()
181 | workerEnv.putAll(envVars.asJava)
182 | workerEnv.put("PYTHONPATH", pythonPath)
183 | // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
184 | workerEnv.put("PYTHONUNBUFFERED", "YES")
185 | daemon = pb.start()
186 |
187 | val in = new DataInputStream(daemon.getInputStream)
188 | try {
189 | daemonPort = in.readInt()
190 | } catch {
191 | case _: EOFException =>
192 | throw new RuntimeException(s"No port number in $daemonModule's stdout")
193 | }
194 |
195 | // test that the returned port number is within a valid range.
196 | // note: this does not cover the case where the port number
197 | // is arbitrary data but is also coincidentally within range
198 | if (daemonPort < 1 || daemonPort > 0xffff) {
199 | val exceptionMessage =
200 | f"""
201 | |Bad data in $daemonModule's standard output. Invalid port number:
202 | | $daemonPort (0x$daemonPort%08x)
203 | |Python command to execute the daemon was:
204 | | ${command.mkString(" ")}
205 | |Check that you don't have any unexpected modules or libraries in
206 | |your PYTHONPATH:
207 | | $pythonPath
208 | |Also, check if you have a sitecustomize.py module in your python path,
209 | |or in your python installation, that is printing to standard output"""
210 | throw new RuntimeException(exceptionMessage.stripMargin)
211 | }
212 |
213 | // Redirect daemon stdout and stderr
214 | Utils.redirectStream(conf, in)
215 | Utils.redirectStream(conf, daemon.getErrorStream)
216 | } catch {
217 | case e: Exception =>
218 |
219 | // If the daemon exists, wait for it to finish and get its stderr
220 | val stderr = Option(daemon)
221 | .flatMap { d => Utils.getStderr(d, PROCESS_WAIT_TIMEOUT_MS) }
222 | .getOrElse("")
223 |
224 | stopDaemon()
225 |
226 | if (stderr != "") {
227 | val formattedStderr = stderr.replace("\n", "\n ")
228 | val errorMessage =
229 | s"""
230 | |Error from python worker:
231 | | $formattedStderr
232 | |PYTHONPATH was:
233 | | $pythonPath
234 | |$e"""
235 |
236 | // Append error message from python daemon, but keep original stack trace
237 | val wrappedException = new RuntimeException(errorMessage.stripMargin)
238 | wrappedException.setStackTrace(e.getStackTrace)
239 | throw wrappedException
240 | } else {
241 | throw e
242 | }
243 | }
244 |
245 | // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
246 | // detect our disappearance.
247 | }
248 | }
249 |
250 |
251 | /**
252 | * Monitor all the idle workers, kill them after timeout.
253 | */
254 | private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {
255 |     // idle timeout in nanoseconds (default: 1 minute; configurable in minutes via setWorkerIdleTime)
256 | val IDLE_WORKER_TIMEOUT_NS_REF = new AtomicLong(TimeUnit.MINUTES.toNanos(1))
257 |
258 | def setWorkerIdleTime(minutes: Int) = {
259 | IDLE_WORKER_TIMEOUT_NS_REF.set(TimeUnit.MINUTES.toNanos(minutes))
260 | }
261 |
262 | setDaemon(true)
263 |
264 | override def run() {
265 | while (true) {
266 | self.synchronized {
267 | if (IDLE_WORKER_TIMEOUT_NS_REF.get() < System.nanoTime() - lastActivityNs) {
268 | cleanupIdleWorkers()
269 | lastActivityNs = System.nanoTime()
270 | }
271 | }
272 | Thread.sleep(10000)
273 | }
274 | }
275 | }
276 |
277 | private def cleanupIdleWorkers() {
278 | while (idleWorkers.nonEmpty) {
279 | val worker = idleWorkers.dequeue()
280 | try {
281 | // the worker will exit after closing the socket
282 | worker.close()
283 | } catch {
284 | case e: Exception =>
285 | logWarning("Failed to close worker socket", e)
286 | }
287 | }
288 | }
289 |
290 | private def stopDaemon() {
291 | self.synchronized {
292 | if (useDaemon) {
293 | cleanupIdleWorkers()
294 |
295 | // Request shutdown of existing daemon by sending SIGTERM
296 | if (daemon != null) {
297 | daemon.destroy()
298 | }
299 |
300 | daemon = null
301 | daemonPort = 0
302 | } else {
303 |         simpleWorkers.values.foreach(_.destroy()) // use foreach: mapValues is lazy and would never run destroy()
304 | }
305 | }
306 | }
307 |
308 | def stop() {
309 | stopDaemon()
310 | }
311 |
312 | def stopWorker(worker: Socket) {
313 | self.synchronized {
314 | if (useDaemon) {
315 | if (daemon != null) {
316 | daemonWorkers.get(worker).foreach { pid =>
317 | // tell daemon to kill worker by pid
318 | val output = new DataOutputStream(daemon.getOutputStream)
319 | output.writeInt(pid)
320 | output.flush()
321 | daemon.getOutputStream.flush()
322 | }
323 | }
324 | } else {
325 | simpleWorkers.get(worker).foreach(_.destroy())
326 | }
327 | }
328 | worker.close()
329 | }
330 |
331 | def releaseWorker(worker: Socket) {
332 | if (useDaemon) {
333 | self.synchronized {
334 | lastActivityNs = System.nanoTime()
335 | idleWorkers.enqueue(worker)
336 | }
337 | } else {
338 | // Cleanup the worker socket. This will also cause the Python worker to exit.
339 | try {
340 | worker.close()
341 | } catch {
342 | case e: Exception =>
343 | logWarning("Failed to close worker socket", e)
344 | }
345 | }
346 | }
347 | }
348 |
349 | object PythonWorkerFactory {
350 |
351 | private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
352 |
353 | def createPythonWorker(pythonExec: String, envVars: Map[String, String], conf: Map[String, String]): java.net.Socket = {
354 | synchronized {
355 | val key = (pythonExec, envVars)
356 | pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars, conf)).create()
357 | }
358 | }
359 |
360 |
361 | def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
362 | synchronized {
363 | val key = (pythonExec, envVars)
364 | pythonWorkers.get(key).foreach(_.stopWorker(worker))
365 | }
366 | }
367 |
368 |
369 | def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
370 | synchronized {
371 | val key = (pythonExec, envVars)
372 | pythonWorkers.get(key).foreach(_.releaseWorker(worker))
373 | }
374 | }
375 |
376 |
377 | object Tool {
378 | val PROCESS_WAIT_TIMEOUT_MS = 10000
379 | val PYTHON_DAEMON_MODULE = "python.daemon.module"
380 | val PYTHON_WORKER_MODULE = "python.worker.module"
381 | val PYTHON_WORKER_IDLE_TIME = "python.worker.idle.time"
382 | val PYTHON_TASK_KILL_TIMEOUT = "python.task.killTimeout"
383 | val REDIRECT_IMPL = "python.redirect.impl"
384 |
385 | def mergePythonPaths(paths: String*): String = {
386 | paths.filter(_ != "").mkString(File.pathSeparator)
387 | }
388 | }
389 |
390 | }
391 |
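Usage sketch of the companion-object API above; the python executable, PYTHONPATH and conf values are hypothetical.

    import tech.mlsql.arrow.python.PythonWorkerFactory

    val envVars = Map("PYTHONPATH" -> "/opt/pyjava/python")
    val conf = Map(
      "python.daemon.module" -> "pyjava.daemon",   // default shown for clarity
      "python.worker.idle.time" -> "5"             // minutes before idle workers are cleaned up
    )
    val worker = PythonWorkerFactory.createPythonWorker("python", envVars, conf)
    try {
      // exchange data with the worker over its socket streams ...
    } finally {
      PythonWorkerFactory.releasePythonWorker("python", envVars, worker)
    }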
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/python/iapp/AppContextImpl.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow.python.iapp
2 |
3 | import java.net.{ServerSocket, Socket}
4 | import java.util.concurrent.atomic.AtomicBoolean
5 | import java.util.{EventListener, UUID}
6 |
7 | import org.apache.arrow.memory.BufferAllocator
8 | import org.apache.arrow.vector.ipc.ArrowStreamReader
9 | import tech.mlsql.arrow.context.CommonTaskContext
10 | import tech.mlsql.arrow.python.PythonWorkerFactory
11 | import tech.mlsql.arrow.python.runner.ArrowPythonRunner
12 | import tech.mlsql.common.utils.log.Logging
13 |
14 | import scala.collection.mutable.ArrayBuffer
15 |
16 | /**
17 | * 2019-08-15 WilliamZhu(allwefantasy@gmail.com)
18 | */
19 |
20 | class AppContextImpl(context: JavaContext, _arrowPythonRunner: ArrowPythonRunner) extends CommonTaskContext with Logging {
21 | override def pythonWorkerRegister(callback: () => Unit) = {
22 | (releasedOrClosed: AtomicBoolean,
23 | reuseWorker: Boolean,
24 | worker: Socket
25 | ) => {
26 | context.addTaskCompletionListener[Unit] { _ =>
27 | //writerThread.shutdownOnTaskCompletion()
28 | callback()
29 | if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) {
30 | try {
31 | worker.close()
32 | } catch {
33 | case e: Exception =>
34 | logWarning("Failed to close worker socket", e)
35 | }
36 | }
37 | }
38 | }
39 | }
40 |
41 | override def assertTaskIsCompleted(callback: () => Unit) = {
42 | () => {
43 | assert(context.isCompleted)
44 | }
45 | }
46 |
47 | override def setTaskContext(): () => Unit = {
48 | () => {
49 |
50 | }
51 | }
52 |
53 | override def innerContext: Any = context
54 |
55 | override def isBarrier: Boolean = false
56 |
57 | override def monitor(callback: () => Unit) = {
58 | (taskKillTimeout: Long, pythonExec: String, envVars: Map[String, String], worker: Socket) => {
59 | // Kill the worker if it is interrupted, checking until task completion.
60 | // TODO: This has a race condition if interruption occurs, as completed may still become true.
61 | while (!context.isInterrupted && !context.isCompleted) {
62 | Thread.sleep(2000)
63 | }
64 | if (!context.isCompleted) {
65 | Thread.sleep(taskKillTimeout)
66 | if (!context.isCompleted) {
67 | try {
68 | // Mimic the task name used in `Executor` to help the user find out the task to blame.
69 | val taskName = s"${context.partitionId}"
70 | logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker")
71 | PythonWorkerFactory.destroyPythonWorker(pythonExec, envVars, worker)
72 | } catch {
73 | case e: Exception =>
74 | logError("Exception when trying to kill worker", e)
75 | }
76 | }
77 | }
78 | }
79 | }
80 |
81 | override val arrowPythonRunner: ArrowPythonRunner = _arrowPythonRunner
82 |
83 | override def javaSideSocketServerRegister(): ServerSocket => Unit = {
84 | (server: ServerSocket) => {
85 | context.addTaskCompletionListener[Unit](_ => server.close())
86 | }
87 | }
88 |
89 | override def isTaskCompleteOrInterrupt(): () => Boolean = {
90 | () => {
91 | context.isCompleted || context.isInterrupted
92 | }
93 | }
94 |
95 | override def isTaskInterrupt(): () => Boolean = {
96 | () => {
97 | context.isInterrupted
98 | }
99 | }
100 |
101 | override def getTaskKillReason(): () => Option[String] = {
102 | () => {
103 | context.getKillReason
104 | }
105 | }
106 |
107 | override def killTaskIfInterrupted(): () => Unit = {
108 | () => {
109 | context.killTaskIfInterrupted
110 | }
111 | }
112 |
113 | override def readerRegister(callback: () => Unit): (ArrowStreamReader, BufferAllocator) => Unit = {
114 | (reader, allocator) => {
115 | context.addTaskCompletionListener[Unit] { _ =>
116 | if (reader != null) {
117 | reader.close(false)
118 | }
119 | try {
120 | allocator.close()
121 | } catch {
122 | case e: Exception =>
123 | logError("allocator.close()", e)
124 | }
125 | }
126 | }
127 | }
128 | }
129 |
130 |
131 | class JavaContext {
132 | val buffer = new ArrayBuffer[AppTaskCompletionListener]()
133 |
134 | var _isCompleted = false
135 | var _partitionId = UUID.randomUUID().toString
136 | var reasonIfKilled: Option[String] = None
137 | var getKillReason = reasonIfKilled
138 |
139 | def isInterrupted = reasonIfKilled.isDefined
140 |
141 | def isCompleted = _isCompleted
142 |
143 | def partitionId = _partitionId
144 |
145 | def markComplete = {
146 | _isCompleted = true
147 | }
148 |
149 | def markInterrupted(reason: String): Unit = {
150 | reasonIfKilled = Some(reason)
151 | }
152 |
153 |
154 | def killTaskIfInterrupted = {
155 | val reason = reasonIfKilled
156 | if (reason.isDefined) {
157 | throw new TaskKilledException(reason.get)
158 | }
159 | }
160 |
161 |
162 | def addTaskCompletionListener[U](f: (JavaContext) => U): JavaContext = {
163 | // Note that due to this scala bug: https://github.com/scala/bug/issues/11016, we need to make
164 | // this function polymorphic for every scala version >= 2.12, otherwise an overloaded method
165 | // resolution error occurs at compile time.
166 | _addTaskCompletionListener(new AppTaskCompletionListener {
167 | override def onTaskCompletion(context: JavaContext): Unit = f(context)
168 | })
169 | }
170 |
171 | def _addTaskCompletionListener(listener: AppTaskCompletionListener): JavaContext = {
172 | buffer += listener
173 | this
174 | }
175 |
176 | def close = {
177 | buffer.foreach(_.onTaskCompletion(this))
178 | }
179 | }
180 |
181 | trait AppTaskCompletionListener extends EventListener {
182 | def onTaskCompletion(context: JavaContext): Unit
183 | }
184 |
185 | class TaskKilledException(val reason: String) extends RuntimeException {
186 | def this() = this("unknown reason")
187 | }
188 |
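Usage sketch of the JavaContext lifecycle used with AppContextImpl in non-Spark applications; the runner setup is elided.

    import tech.mlsql.arrow.python.iapp.JavaContext

    val context = new JavaContext
    context.addTaskCompletionListener[Unit] { _ =>
      // release sockets, readers and allocators registered during the run
    }
    try {
      // run an ArrowPythonRunner with new AppContextImpl(context, runner) ...
    } finally {
      context.markComplete   // or context.markInterrupted("reason")
      context.close          // fires all registered completion listeners
    }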
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/python/ispark/SparkContextImp.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow.python.ispark
2 |
3 | import java.net.{ServerSocket, Socket}
4 | import java.util.concurrent.atomic.AtomicBoolean
5 |
6 | import org.apache.arrow.memory.BufferAllocator
7 | import org.apache.arrow.vector.ipc.ArrowStreamReader
8 | import org.apache.spark.TaskContext
9 | import org.apache.spark.sql.SparkUtils
10 | import org.apache.spark.util.TaskCompletionListener
11 | import tech.mlsql.arrow.context.CommonTaskContext
12 | import tech.mlsql.arrow.python.PythonWorkerFactory
13 | import tech.mlsql.arrow.python.runner.ArrowPythonRunner
14 | import tech.mlsql.common.utils.log.Logging
15 |
16 | /**
17 | * 2019-08-15 WilliamZhu(allwefantasy@gmail.com)
18 | */
19 | class SparkContextImp(context: TaskContext, _arrowPythonRunner: ArrowPythonRunner) extends CommonTaskContext with Logging {
20 | override def pythonWorkerRegister(callback: () => Unit) = {
21 | (releasedOrClosed: AtomicBoolean,
22 | reuseWorker: Boolean,
23 | worker: Socket
24 | ) => {
25 | context.addTaskCompletionListener(new TaskCompletionListener {
26 | override def onTaskCompletion(context: TaskContext): Unit = {
27 | //writerThread.shutdownOnTaskCompletion()
28 | callback()
29 | if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) {
30 | try {
31 | worker.close()
32 | } catch {
33 | case e: Exception =>
34 | logWarning("Failed to close worker socket", e)
35 | }
36 | }
37 | }
38 | })
39 | }
40 | }
41 |
42 | override def assertTaskIsCompleted(callback: () => Unit) = {
43 | () => {
44 | assert(context.isCompleted)
45 | }
46 | }
47 |
48 | override def setTaskContext(): () => Unit = {
49 | () => {
50 | SparkUtils.setTaskContext(context)
51 | }
52 | }
53 |
54 | override def innerContext: Any = context
55 |
56 | override def isBarrier: Boolean = context.getClass.getName == "org.apache.spark.BarrierTaskContext"
57 |
58 | override def monitor(callback: () => Unit) = {
59 | (taskKillTimeout: Long, pythonExec: String, envVars: Map[String, String], worker: Socket) => {
60 | // Kill the worker if it is interrupted, checking until task completion.
61 | // TODO: This has a race condition if interruption occurs, as completed may still become true.
62 | while (!context.isInterrupted && !context.isCompleted) {
63 | Thread.sleep(2000)
64 | }
65 | if (!context.isCompleted) {
66 | Thread.sleep(taskKillTimeout)
67 | if (!context.isCompleted) {
68 | try {
69 | // Mimic the task name used in `Executor` to help the user find out the task to blame.
70 | val taskName = s"${context.partitionId}.${context.attemptNumber} " +
71 | s"in stage ${context.stageId} (TID ${context.taskAttemptId})"
72 | logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker")
73 | PythonWorkerFactory.destroyPythonWorker(pythonExec, envVars, worker)
74 | } catch {
75 | case e: Exception =>
76 | logError("Exception when trying to kill worker", e)
77 | }
78 | }
79 | }
80 | }
81 | }
82 |
83 | override val arrowPythonRunner: ArrowPythonRunner = _arrowPythonRunner
84 |
85 | override def javaSideSocketServerRegister(): ServerSocket => Unit = {
86 | (server: ServerSocket) => {
87 | context.addTaskCompletionListener(new TaskCompletionListener {
88 | override def onTaskCompletion(context: TaskContext): Unit = {
89 | server.close()
90 | }
91 | })
92 | }
93 | }
94 |
95 | override def isTaskCompleteOrInterrupt(): () => Boolean = {
96 | () => {
97 | context.isCompleted || context.isInterrupted
98 | }
99 | }
100 |
101 | override def isTaskInterrupt(): () => Boolean = {
102 | () => {
103 | context.isInterrupted
104 | }
105 | }
106 |
107 | override def getTaskKillReason(): () => Option[String] = {
108 | () => {
109 | SparkUtils.getKillReason(context)
110 | }
111 | }
112 |
113 | override def killTaskIfInterrupted(): () => Unit = {
114 | () => {
115 | SparkUtils.killTaskIfInterrupted(context)
116 | }
117 | }
118 |
119 | override def readerRegister(callback: () => Unit): (ArrowStreamReader, BufferAllocator) => Unit = {
120 | (reader, allocator) => {
121 | context.addTaskCompletionListener(new TaskCompletionListener {
122 | override def onTaskCompletion(context: TaskContext): Unit = {
123 | if (reader != null) {
124 | reader.close(false)
125 | }
126 |           // Special case: the user may only read part of the data. Closing the allocator at that
127 |           // point reports a memory leak and the close call throws, so we need to catch that error.
128 |           // As far as we can tell the resources still get released; in most cases all data is consumed normally.
129 | try {
130 | allocator.close()
131 | } catch {
132 | case e: Exception =>
133 | logError("allocator.close()", e)
134 | }
135 |
136 | }
137 | })
138 | }
139 | }
140 | }
141 |
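Usage sketch: inside rdd.mapPartitions, Spark's TaskContext is wrapped so the runner's cleanup hooks are tied to the Spark task; `runner` is an ArrowPythonRunner assumed to be built elsewhere.

    import org.apache.spark.TaskContext
    import tech.mlsql.arrow.python.ispark.SparkContextImp

    // inside rdd.mapPartitions { iter => ... }
    val taskContext = new SparkContextImp(TaskContext.get(), runner)
    // pass taskContext wherever a CommonTaskContext is required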
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/python/runner/ArrowPythonRunner.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow.python.runner
2 |
3 | import org.apache.arrow.vector.VectorSchemaRoot
4 | import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
5 | import org.apache.spark.sql.catalyst.InternalRow
6 | import org.apache.spark.sql.types._
7 | import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch}
8 | import tech.mlsql.arrow.context.CommonTaskContext
9 | import tech.mlsql.arrow.{ArrowUtils, ArrowWriter, Utils}
10 |
11 | import java.io._
12 | import java.net._
13 | import java.util.concurrent.atomic.AtomicBoolean
14 | import scala.collection.JavaConverters._
15 |
16 |
17 | /**
18 | * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream.
19 | */
20 | class ArrowPythonRunner(
21 | funcs: Seq[ChainedPythonFunctions],
22 | schema: StructType,
23 | timeZoneId: String,
24 | conf: Map[String, String])
25 | extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
26 | funcs, conf) {
27 |
28 | protected override def newWriterThread(
29 | worker: Socket,
30 | inputIterator: Iterator[Iterator[InternalRow]],
31 | partitionIndex: Int,
32 | context: CommonTaskContext): WriterThread = {
33 | new WriterThread(worker, inputIterator, partitionIndex, context) {
34 |
35 | protected override def writeCommand(dataOut: DataOutputStream): Unit = {
36 |
37 | // Write config for the worker as a number of key -> value pairs of strings
38 | dataOut.writeInt(conf.size + 1)
39 | for ((k, v) <- conf) {
40 | Utils.writeUTF(k, dataOut)
41 | Utils.writeUTF(v, dataOut)
42 | }
43 | Utils.writeUTF("timezone", dataOut)
44 | Utils.writeUTF(timeZoneId, dataOut)
45 |
46 | val command = funcs.head.funcs.head.command
47 | Utils.writeUTF(command, dataOut)
48 |
49 | }
50 |
51 | protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
52 | val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
53 | val allocator = ArrowUtils.rootAllocator.newChildAllocator(
54 | s"stdout writer for $pythonExec", 0, Long.MaxValue)
55 | val root = VectorSchemaRoot.create(arrowSchema, allocator)
56 |
57 | Utils.tryWithSafeFinally {
58 | val arrowWriter = ArrowWriter.create(root)
59 | val writer = new ArrowStreamWriter(root, null, dataOut)
60 | writer.start()
61 |
62 | while (inputIterator.hasNext) {
63 | val nextBatch = inputIterator.next()
64 |
65 | while (nextBatch.hasNext) {
66 | arrowWriter.write(nextBatch.next())
67 | }
68 |
69 | arrowWriter.finish()
70 | writer.writeBatch()
71 | arrowWriter.reset()
72 | }
73 | // end writes footer to the output stream and doesn't clean any resources.
74 | // It could throw exception if the output stream is closed, so it should be
75 | // in the try block.
76 | writer.end()
77 | } {
78 | // If we close root and allocator in TaskCompletionListener, there could be a race
79 | // condition where the writer thread keeps writing to the VectorSchemaRoot while
80 | // it's being closed by the TaskCompletion listener.
81 | // Closing root and allocator here is cleaner because root and allocator is owned
82 | // by the writer thread and is only visible to the writer thread.
83 | //
84 | // If the writer thread is interrupted by TaskCompletionListener, it should either
85 | // (1) in the try block, in which case it will get an InterruptedException when
86 | // performing io, and goes into the finally block or (2) in the finally block,
87 | // in which case it will ignore the interruption and close the resources.
88 | root.close()
89 | allocator.close()
90 | }
91 | }
92 | }
93 | }
94 |
95 | protected override def newReaderIterator(
96 | stream: DataInputStream,
97 | writerThread: WriterThread,
98 | startTime: Long,
99 | worker: Socket,
100 | releasedOrClosed: AtomicBoolean,
101 | context: CommonTaskContext): Iterator[ColumnarBatch] = {
102 | new ReaderIterator(stream, writerThread, startTime, worker, releasedOrClosed, context) {
103 |
104 | private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
105 | s"stdin reader for $pythonExec", 0, Long.MaxValue)
106 |
107 | private var reader: ArrowStreamReader = _
108 | private var root: VectorSchemaRoot = _
109 | private var schema: StructType = _
110 | private var vectors: Array[ColumnVector] = _
111 | context.readerRegister(() => {})(reader, allocator)
112 |
113 | private var batchLoaded = true
114 |
115 | protected override def read(): ColumnarBatch = {
116 | if (writerThread.exception.isDefined) {
117 | throw writerThread.exception.get
118 | }
119 | try {
120 | if (reader != null && batchLoaded) {
121 | batchLoaded = reader.loadNextBatch()
122 | if (batchLoaded) {
123 | val batch = new ColumnarBatch(vectors)
124 | batch.setNumRows(root.getRowCount)
125 | batch
126 | } else {
127 | reader.close(false)
128 | allocator.close()
129 | // Reach end of stream. Call `read()` again to read control data.
130 | read()
131 | }
132 | } else {
133 | stream.readInt() match {
134 | case SpecialLengths.START_ARROW_STREAM =>
135 | try {
136 | reader = new ArrowStreamReader(stream, allocator)
137 | root = reader.getVectorSchemaRoot()
138 | schema = ArrowUtils.fromArrowSchema(root.getSchema())
139 | vectors = root.getFieldVectors().asScala.map { vector =>
140 | new ArrowColumnVector(vector)
141 | }.toArray[ColumnVector]
142 | read()
143 | } catch {
144 | case e: IOException if (e.getMessage.contains("Missing schema") || e.getMessage.contains("Expected schema but header was")) =>
145 |                   logInfo("Arrow read schema failed", e)
146 | reader = null
147 | read()
148 | }
149 |
150 | case SpecialLengths.ARROW_STREAM_CRASH =>
151 | read()
152 |
153 |               case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
154 |                 throw handlePythonException()
155 |
159 | case SpecialLengths.END_OF_DATA_SECTION =>
160 | handleEndOfDataSection()
161 | null
162 | }
163 | }
164 | } catch handleException
165 | }
166 | }
167 | }
168 | }
169 |
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/python/runner/PythonProjectRunner.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow.python.runner
2 |
3 | import java.io._
4 | import java.util.concurrent.atomic.AtomicReference
5 |
6 | import os.SubProcess
7 | import tech.mlsql.arrow.Utils
8 | import tech.mlsql.common.utils.log.Logging
9 | import tech.mlsql.common.utils.shell.ShellCommand
10 |
11 | import scala.io.Source
12 |
13 | /**
14 | * 2019-08-22 WilliamZhu(allwefantasy@gmail.com)
15 | */
16 | class PythonProjectRunner(projectDirectory: String,
17 | env: Map[String, String]) extends Logging {
18 |
19 | import PythonProjectRunner._
20 |
21 | private var innerProcess: Option[SubProcess] = None
22 |
23 | def getPythonProcess = innerProcess
24 |
25 | def run(command: Seq[String],
26 | conf: Map[String, String]
27 | ) = {
28 | val proc = os.proc(command).spawn(
29 | cwd = os.Path(projectDirectory),
30 | env = env)
31 | innerProcess = Option(proc)
32 | val (_, pythonPid) = try {
33 | val f = proc.wrapped.getClass.getDeclaredField("pid")
34 | f.setAccessible(true)
35 | val parentPid = f.getLong(proc.wrapped)
36 | val subPid = ShellCommand.execCmdV2("pgrep", "-P", parentPid).out.lines.mkString("")
37 | (parentPid, subPid)
38 | } catch {
39 | case e: Exception =>
40 | logWarning(
41 | s"""
42 |              |${command.mkString(" ")} may not be killed since we cannot get its pid.
43 |              |Make sure you are running on mac/linux and pgrep is installed.
44 | |""".stripMargin)
45 | (-1, -1)
46 | }
47 |
48 |
49 | val lines = Source.fromInputStream(proc.stdout.wrapped)("utf-8").getLines
50 | val childThreadException = new AtomicReference[Throwable](null)
51 |     // Start a thread to write the conf to the process's stdin
52 | new Thread(s"stdin writer for $command") {
53 | def writeConf = {
54 | val dataOut = new DataOutputStream(proc.stdin)
55 | dataOut.writeInt(conf.size)
56 | for ((k, v) <- conf) {
57 | Utils.writeUTF(k, dataOut)
58 | Utils.writeUTF(v, dataOut)
59 | }
60 | }
61 |
62 | override def run(): Unit = {
63 | try {
64 | writeConf
65 | } catch {
66 | case t: Throwable => childThreadException.set(t)
67 | } finally {
68 | proc.stdin.close()
69 | }
70 | }
71 | }.start()
72 |
73 |     // Read stderr: either collect it as an error (throwErr) or redirect it elsewhere (e.g. send it to the driver)
74 | new Thread(s"stderr reader for $command") {
75 | override def run(): Unit = {
76 | if (conf.getOrElse("throwErr", "true").toBoolean) {
77 | val err = proc.stderr.lines.mkString("\n")
78 | if (!err.isEmpty) {
79 | childThreadException.set(new PythonErrException(err))
80 | }
81 | } else {
82 | Utils.redirectStream(conf, proc.stderr)
83 | }
84 | }
85 | }.start()
86 |
87 |
88 | new Iterator[String] {
89 | def next(): String = {
90 | if (!hasNext()) {
91 | throw new NoSuchElementException()
92 | }
93 | val line = lines.next()
94 | line
95 | }
96 |
97 | def hasNext(): Boolean = {
98 | val result = if (lines.hasNext) {
99 | true
100 | } else {
101 | try {
102 | proc.waitFor()
103 | }
104 | catch {
105 | case e: InterruptedException =>
106 | 0
107 | }
108 | cleanup()
109 | if (proc.exitCode() != 0) {
110 | val msg = s"Subprocess exited with status ${proc.exitCode()}. " +
111 | s"Command ran: " + command.mkString(" ")
112 |             if (childThreadException.get() != null) {
113 |               throw childThreadException.get()
114 |             } else {
115 | throw new IllegalStateException(msg)
116 | }
117 | }
118 | false
119 | }
120 | propagateChildException
121 | result
122 | }
123 |
124 | private def cleanup(): Unit = {
125 | ShellCommand.execCmdV2("kill", "-9", pythonPid + "")
126 | // cleanup task working directory if used
127 | scala.util.control.Exception.ignoring(classOf[IOException]) {
128 | if (conf.get(KEEP_LOCAL_DIR).map(_.toBoolean).getOrElse(false)) {
129 | Utils.deleteRecursively(new File(projectDirectory))
130 | }
131 | }
132 | log.debug(s"Removed task working directory $projectDirectory")
133 | }
134 |
135 | private def propagateChildException(): Unit = {
136 | val t = childThreadException.get()
137 | if (t != null) {
138 | proc.destroy()
139 | cleanup()
140 | throw t
141 | }
142 | }
143 |
144 | }
145 | }
146 | }
147 |
148 | object PythonProjectRunner {
149 | val KEEP_LOCAL_DIR = "keepLocalDir"
150 | }
151 |
152 | class PythonErrException(message: String, cause: Throwable)
153 | extends Exception(message, cause) {
154 |
155 | def this(message: String) = this(message, null)
156 | }
157 |
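Usage sketch for the runner above; the project directory and command are hypothetical.

    import tech.mlsql.arrow.python.runner.PythonProjectRunner

    val runner = new PythonProjectRunner("/tmp/pyproject1", Map("PYTHONUNBUFFERED" -> "YES"))
    val output = runner.run(Seq("python", "train.py"), Map("throwErr" -> "true"))
    output.foreach(println)   // stdout lines; a non-zero exit code or stderr output raises an exception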
--------------------------------------------------------------------------------
/src/main/java/tech/mlsql/arrow/python/runner/SparkSocketRunner.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.arrow.python.runner
2 |
3 | import org.apache.arrow.vector.VectorSchemaRoot
4 | import org.apache.arrow.vector.ipc.ArrowStreamReader
5 | import org.apache.spark.sql.catalyst.InternalRow
6 | import org.apache.spark.sql.types.StructType
7 | import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch}
8 | import org.apache.spark.{SparkException, TaskKilledException}
9 | import tech.mlsql.arrow._
10 | import tech.mlsql.arrow.context.CommonTaskContext
11 | import tech.mlsql.common.utils.distribute.socket.server.SocketServerInExecutor
12 | import tech.mlsql.common.utils.log.Logging
13 |
14 | import java.io._
15 | import java.net.Socket
16 | import java.nio.charset.StandardCharsets
17 | import scala.collection.JavaConverters._
18 |
19 |
20 | class SparkSocketRunner(runnerName: String, host: String, timeZoneId: String) {
21 |
22 | def serveToStream(threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
23 | val (_server, _host, _port) = SocketServerInExecutor.setupOneConnectionServer(host, runnerName)(s => {
24 | val out = new BufferedOutputStream(s.getOutputStream())
25 | Utils.tryWithSafeFinally {
26 | writeFunc(out)
27 | } {
28 | out.close()
29 | }
30 | })
31 |
32 | Array(_server, _host, _port)
33 | }
34 |
35 | def serveToStreamWithArrow(iter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int, context: CommonTaskContext) = {
36 | serveToStream(runnerName) { out =>
37 | val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
38 | val arrowBatch = ArrowConverters.toBatchIterator(
39 | iter, schema, maxRecordsPerBatch, timeZoneId, context)
40 | batchWriter.writeBatches(arrowBatch)
41 | batchWriter.end()
42 | }
43 | }
44 |
45 | def readFromStreamWithArrow(host: String, port: Int, context: CommonTaskContext) = {
46 | val socket = new Socket(host, port)
47 | val stream = new DataInputStream(socket.getInputStream)
48 | val outfile = new DataOutputStream(socket.getOutputStream)
49 | new ReaderIterator[ColumnarBatch](stream, System.currentTimeMillis(), context) {
50 | private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
51 | s"stdin reader ", 0, Long.MaxValue)
52 |
53 | private var reader: ArrowStreamReader = _
54 | private var root: VectorSchemaRoot = _
55 | private var schema: StructType = _
56 | private var vectors: Array[ColumnVector] = _
57 | context.readerRegister(() => {})(reader, allocator)
58 |
59 | private var batchLoaded = true
60 |
61 | protected override def read(): ColumnarBatch = {
62 | try {
63 | if (reader != null && batchLoaded) {
64 | batchLoaded = reader.loadNextBatch()
65 | if (batchLoaded) {
66 | val batch = new ColumnarBatch(vectors)
67 | batch.setNumRows(root.getRowCount)
68 | batch
69 | } else {
70 | reader.close(false)
71 | allocator.close()
72 | // Reach end of stream. Call `read()` again to read control data.
73 | read()
74 | }
75 | } else {
76 | stream.readInt() match {
77 | case SpecialLengths.START_ARROW_STREAM =>
78 |
79 | try {
80 | reader = new ArrowStreamReader(stream, allocator)
81 | root = reader.getVectorSchemaRoot()
82 | schema = ArrowUtils.fromArrowSchema(root.getSchema())
83 | vectors = root.getFieldVectors().asScala.map { vector =>
84 | new ArrowColumnVector(vector)
85 | }.toArray[ColumnVector]
86 | read()
87 | } catch {
88 | case e: IOException if (e.getMessage.contains("Missing schema") || e.getMessage.contains("Expected schema but header was")) =>
89 |                   logInfo("Arrow read schema failed", e)
90 | reader = null
91 | read()
92 | }
93 |
94 | case SpecialLengths.ARROW_STREAM_CRASH =>
95 | read()
96 |
97 | case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
98 | throw handlePythonException(outfile)
99 |
100 | case SpecialLengths.END_OF_DATA_SECTION =>
101 | handleEndOfDataSection(outfile)
102 | null
103 | }
104 | }
105 | } catch handleException
106 | }
107 | }.flatMap { batch =>
108 | batch.rowIterator.asScala
109 | }
110 | }
111 |
112 | }
113 |
114 | object SparkSocketRunner {
115 |
116 | }
117 |
118 | abstract class ReaderIterator[OUT](
119 | stream: DataInputStream,
120 | startTime: Long,
121 | context: CommonTaskContext)
122 | extends Iterator[OUT] with Logging {
123 |
124 | private var nextObj: OUT = _
125 | private var eos = false
126 |
127 | override def hasNext: Boolean = nextObj != null || {
128 | if (!eos) {
129 | nextObj = read()
130 | hasNext
131 | } else {
132 | false
133 | }
134 | }
135 |
136 | override def next(): OUT = {
137 | if (hasNext) {
138 | val obj = nextObj
139 | nextObj = null.asInstanceOf[OUT]
140 | obj
141 | } else {
142 | Iterator.empty.next()
143 | }
144 | }
145 |
146 | /**
147 |    * Reads the next object from the stream.
148 |    * When the stream reaches the end of data, this method processes the remaining
149 |    * control sections and then returns null.
150 | */
151 | protected def read(): OUT
152 |
153 |
154 | protected def handlePythonException(out: DataOutputStream): SparkException = {
155 | // Signals that an exception has been thrown in python
156 | val exLength = stream.readInt()
157 | val obj = new Array[Byte](exLength)
158 | stream.readFully(obj)
159 | try {
160 | out.writeInt(SpecialLengths.END_OF_STREAM)
161 | out.flush()
162 | } catch {
163 |       case e: Exception => logError("Failed to write END_OF_STREAM back to the peer", e)
164 | }
165 | new SparkException(new String(obj, StandardCharsets.UTF_8), null)
166 | }
167 |
168 | protected def handleEndOfStream(out: DataOutputStream): Unit = {
169 |
170 | eos = true
171 | }
172 |
173 | protected def handleEndOfDataSection(out: DataOutputStream): Unit = {
174 | //read end of stream
175 | val flag = stream.readInt()
176 | if (flag != SpecialLengths.END_OF_STREAM) {
177 | logWarning(
178 | s"""
179 | |-----------------------WARNING--------------------------------------------------------------------
180 |           |Expected to receive SpecialLengths.END_OF_STREAM:${SpecialLengths.END_OF_STREAM} here,
181 |           |but got ${flag} instead.
182 |           |This may cause a **** python worker leak **** and break the ***interactive mode***.
183 | |--------------------------------------------------------------------------------------------------
184 | """.stripMargin)
185 | }
186 | try {
187 | out.writeInt(SpecialLengths.END_OF_STREAM)
188 | out.flush()
189 | } catch {
190 |       case e: Exception => logError("Failed to write END_OF_STREAM back to the peer", e)
191 | }
192 |
193 | eos = true
194 | }
195 |
196 | protected val handleException: PartialFunction[Throwable, OUT] = {
197 | case e: Exception if context.isTaskInterrupt()() =>
198 | logDebug("Exception thrown after task interruption", e)
199 | throw new TaskKilledException(context.getTaskKillReason()().getOrElse("unknown reason"))
200 | case e: SparkException =>
201 | throw e
202 | case eof: EOFException =>
203 | throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
204 |
205 | case e: Exception =>
206 |       throw new SparkException("Error reading from stream", e)
207 | }
208 | }
209 |
--------------------------------------------------------------------------------
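
serveToStreamWithArrow and readFromStreamWithArrow are complementary: the first publishes an iterator of InternalRow as a one-connection Arrow socket server (normally consumed by the Python side), the second connects to a Python-side Arrow server and yields InternalRow. A condensed sketch of both directions, borrowing the JavaContext/AppContextImpl setup used by the tests further below; the object name, the "UTC" time zone and the 127.0.0.1:11111 address are illustrative placeholders, not part of the source:

import org.apache.spark.WowRowEncoder
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext}
import tech.mlsql.arrow.python.runner.SparkSocketRunner
import tech.mlsql.common.utils.network.NetUtils

object ArrowSocketSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("value", StringType)))
    val javaContext = new JavaContext
    val taskContext = new AppContextImpl(javaContext, null)

    // Direction 1: serve rows as a one-connection Arrow stream; a Python-side
    // consumer (e.g. pyjava's RayContext) connects to host:port and reads it.
    val toInternal = WowRowEncoder.fromRow(schema) // Row -> InternalRow
    val rows = Seq(Row.fromSeq(Seq("a1")), Row.fromSeq(Seq("a2"))).map(r => toInternal(r).copy()).iterator
    val server = new SparkSocketRunner("serve-sketch", NetUtils.getHost, "UTC")
    val Array(_, host: String, port: Int) = server.serveToStreamWithArrow(rows, schema, 10, taskContext)
    println(s"serving arrow stream at ${host}:${port}")

    // Direction 2: connect to a Python-side Arrow server (placeholder address)
    // and decode the incoming InternalRows back to external Rows.
    val toExternal = WowRowEncoder.toRow(schema) // InternalRow -> Row
    val client = new SparkSocketRunner("read-sketch", NetUtils.getHost, "UTC")
    client.readFromStreamWithArrow("127.0.0.1", 11111, taskContext)
      .foreach(internalRow => println(toExternal(internalRow.copy())))

    javaContext.markComplete
    javaContext.close
  }
}

Direction 1 is what RayEnv.startDataServer does per partition; direction 2 is how RayEnv.collectResult pulls the Python side's results back (both appear later in this dump).
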
/src/test/java/tech/mlsql/test/ApplyPythonScript.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.test
2 |
3 | import org.apache.spark.{TaskContext, WowRowEncoder}
4 | import org.apache.spark.sql.Row
5 | import org.apache.spark.sql.catalyst.InternalRow
6 | import org.apache.spark.sql.catalyst.encoders.RowEncoder
7 | import org.apache.spark.sql.types.StructType
8 | import tech.mlsql.arrow.python.ispark.SparkContextImp
9 | import tech.mlsql.arrow.python.runner.{ArrowPythonRunner, ChainedPythonFunctions, PythonFunction}
10 |
11 | import scala.collection.JavaConverters._
12 |
13 | /**
14 | * 19/4/2021 WilliamZhu(allwefantasy@gmail.com)
15 | */
16 | class ApplyPythonScript(_rayAddress: String, _envs: java.util.HashMap[String, String], _timezoneId: String, _pythonMode: String = "ray") {
17 | def execute(script: String, struct: StructType): Iterator[Row] => Iterator[InternalRow] = {
18 | // val rayAddress = _rayAddress
19 | val envs = _envs
20 | val timezoneId = _timezoneId
21 | val pythonMode = _pythonMode
22 |
23 | iter =>
24 | val encoder = WowRowEncoder.fromRow(struct)
25 | //RowEncoder.apply(struct).resolveAndBind()
26 | val batch = new ArrowPythonRunner(
27 | Seq(ChainedPythonFunctions(Seq(PythonFunction(
28 | script, envs, "python", "3.6")))), struct,
29 | timezoneId, Map("pythonMode" -> pythonMode)
30 | )
31 | val newIter = iter.map(encoder)
32 | val commonTaskContext = new SparkContextImp(TaskContext.get(), batch)
33 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext)
34 | columnarBatchIter.flatMap(_.rowIterator.asScala).map(f => f.copy())
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
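
ApplyPythonScript.execute returns a serializable Iterator[Row] => Iterator[InternalRow], which is exactly the shape Spark's mapPartitions expects. A minimal sketch of plugging it into a DataFrame job, modelled on SparkSpec further below; the object and method names and the PYTHON_ENV value are illustrative, and the sketch assumes the Python script emits rows with the same schema as its input:

import java.util

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.{DataFrame, SparkSession, SparkUtils}
import tech.mlsql.test.ApplyPythonScript

object ApplyPythonScriptSketch {
  def applyScript(session: SparkSession, df: DataFrame, rayAddress: String, script: String): DataFrame = {
    val envs = new util.HashMap[String, String]()
    envs.put("PYTHON_ENV", "source activate dev && export ARROW_PRE_0_15_IPC_FORMAT=1")
    val aps = new ApplyPythonScript(rayAddress, envs, session.sessionState.conf.sessionLocalTimeZone)
    val func = aps.execute(script, df.schema)              // Iterator[Row] => Iterator[InternalRow]
    val out: RDD[InternalRow] = df.rdd.mapPartitions(func) // the script runs once per partition
    // Rebuild a DataFrame from the internal rows; the output schema is assumed to equal df.schema here.
    SparkUtils.internalCreateDataFrame(session, out, df.schema)
  }
}
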
/src/test/java/tech/mlsql/test/JavaApp1Spec.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.test
2 |
3 | import java.util
4 |
5 | import org.apache.spark.sql.Row
6 | import org.apache.spark.sql.types.{ArrayType, DoubleType, StructField, StructType}
7 | import org.apache.spark.{TaskContext, WowRowEncoder}
8 | import org.scalatest.{BeforeAndAfterAll, FunSuite}
9 | import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext}
10 | import tech.mlsql.arrow.python.runner.{ArrowPythonRunner, ChainedPythonFunctions, PythonConf, PythonFunction}
11 | import tech.mlsql.common.utils.lang.sc.ScalaMethodMacros.str
12 |
13 | import scala.collection.JavaConverters._
14 |
15 | /**
16 | * 2019-08-15 WilliamZhu(allwefantasy@gmail.com)
17 | */
18 | class JavaApp1Spec extends FunSuite
19 | with BeforeAndAfterAll {
20 |
21 | def condaEnv = "source /Users/allwefantasy/opt/anaconda3/bin/activate ray-dev"
22 |
23 | test("normal java application") {
24 | val envs = new util.HashMap[String, String]()
25 | envs.put(str(PythonConf.PYTHON_ENV), s"${condaEnv} && export ARROW_PRE_0_15_IPC_FORMAT=1 ")
26 | val sourceSchema = StructType(Seq(StructField("value", ArrayType(DoubleType))))
27 |
28 | val runnerConf = Map(
29 | "HOME" -> "",
30 | "OWNER" -> "",
31 | "GROUP_ID" -> "",
32 | "directData" -> "true",
33 | "runIn" -> "driver",
34 | "mode" -> "model",
35 | "rayAddress" -> "127.0.0.1:10001",
36 | "pythonMode" -> "ray"
37 | )
38 |
39 | val batch = new ArrowPythonRunner(
40 | Seq(ChainedPythonFunctions(Seq(PythonFunction(
41 | """
42 | |import ray
43 | |from pyjava.api.mlsql import RayContext
44 | |ray_context = RayContext.connect(globals(), context.conf["rayAddress"],namespace="default")
45 | |udfMaster = ray.get_actor("model_predict")
46 | |[index,worker] = ray.get(udfMaster.get.remote())
47 | |
48 | |input = [row["value"] for row in context.fetch_once_as_rows()]
49 | |try:
50 | | res = ray.get(worker.apply.remote(input))
51 | |except Exception as inst:
52 | | res=[]
53 | | print(inst)
54 | |
55 | |udfMaster.give_back.remote(index)
56 | |ray_context.build_result([res])
57 | """.stripMargin, envs, "python", "3.6")))), sourceSchema,
58 | "GMT", runnerConf
59 | )
60 |
61 |
62 |     val sourceEncoder = WowRowEncoder.fromRow(sourceSchema) //RowEncoder.apply(sourceSchema).resolveAndBind()
63 | val newIter = Seq(Row.fromSeq(Seq(Seq(1.1))), Row.fromSeq(Seq(Seq(1.2)))).map { irow =>
64 |       sourceEncoder(irow).copy()
65 | }.iterator
66 |
67 |     val javaContext = new JavaContext
68 |     val commonTaskContext = new AppContextImpl(javaContext, batch)
69 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext)
70 | val outputSchema = StructType(Seq(StructField("value", ArrayType(ArrayType(DoubleType)))))
71 |     val outputEncoder = WowRowEncoder.toRow(outputSchema)
72 | columnarBatchIter.flatMap { batch =>
73 | batch.rowIterator.asScala
74 | }.map { r =>
75 |       outputEncoder(r)
76 |     }.toList.foreach(item => println(item.getAs[Seq[Seq[Double]]](0)))
77 |
78 |     javaContext.markComplete
79 |     javaContext.close
80 | }
81 |
82 | }
83 |
--------------------------------------------------------------------------------
/src/test/java/tech/mlsql/test/JavaAppSpec.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.test
2 |
3 | import java.util
4 |
5 | import org.apache.spark.{TaskContext, WowRowEncoder}
6 | import org.apache.spark.sql.Row
7 | import org.apache.spark.sql.catalyst.encoders.RowEncoder
8 | import org.apache.spark.sql.types.{StringType, StructField, StructType}
9 | import org.scalatest.{BeforeAndAfterAll, FunSuite}
10 | import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext}
11 | import tech.mlsql.arrow.python.runner.{ArrowPythonRunner, ChainedPythonFunctions, PythonConf, PythonFunction}
12 | import tech.mlsql.common.utils.lang.sc.ScalaMethodMacros.str
13 |
14 | import scala.collection.JavaConverters._
15 |
16 | /**
17 | * 2019-08-15 WilliamZhu(allwefantasy@gmail.com)
18 | */
19 | class JavaAppSpec extends FunSuite
20 | with BeforeAndAfterAll {
21 |
22 | def condaEnv = "source /Users/allwefantasy/opt/anaconda3/bin/activate dev"
23 |
24 | test("normal java application") {
25 | val envs = new util.HashMap[String, String]()
26 | envs.put(str(PythonConf.PYTHON_ENV), s"${condaEnv} && export ARROW_PRE_0_15_IPC_FORMAT=1 ")
27 | val sourceSchema = StructType(Seq(StructField("value", StringType)))
28 | val batch = new ArrowPythonRunner(
29 | Seq(ChainedPythonFunctions(Seq(PythonFunction(
30 | """
31 | |import pandas as pd
32 | |import numpy as np
33 | |
34 | |def process():
35 | | for item in context.fetch_once_as_rows():
36 | | item["value1"] = item["value"] + "_suffix"
37 | | yield item
38 | |
39 | |context.build_result(process())
40 | """.stripMargin, envs, "python", "3.6")))), sourceSchema,
41 | "GMT", Map()
42 | )
43 |
44 |
45 |     val sourceEncoder = WowRowEncoder.fromRow(sourceSchema) //RowEncoder.apply(sourceSchema).resolveAndBind()
46 | val newIter = Seq(Row.fromSeq(Seq("a1")), Row.fromSeq(Seq("a2"))).map { irow =>
47 |       sourceEncoder(irow).copy()
48 | }.iterator
49 |
50 |     val javaContext = new JavaContext
51 |     val commonTaskContext = new AppContextImpl(javaContext, batch)
52 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext)
53 |     // copy() is required: rows yielded by ColumnarBatch.rowIterator are reused, so copy them before they leave the iterator
54 | columnarBatchIter.flatMap { batch =>
55 | batch.rowIterator.asScala
56 | }.foreach(f => println(f.copy()))
57 |     javaContext.markComplete
58 |     javaContext.close
59 | }
60 |
61 | }
62 |
--------------------------------------------------------------------------------
/src/test/java/tech/mlsql/test/JavaArrowServer.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.test
2 |
3 | import org.apache.spark.WowRowEncoder
4 | import org.apache.spark.sql.Row
5 | import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
6 | import org.scalatest.{BeforeAndAfterAll, FunSuite}
7 | import tech.mlsql.arrow.python.iapp.{AppContextImpl, JavaContext}
8 | import tech.mlsql.arrow.python.runner.SparkSocketRunner
9 | import tech.mlsql.common.utils.network.NetUtils
10 |
11 | /**
12 | * 24/12/2019 WilliamZhu(allwefantasy@gmail.com)
13 | */
14 | class JavaArrowServer extends FunSuite with BeforeAndAfterAll {
15 |
16 | test("test java arrow server") {
17 | val socketRunner = new SparkSocketRunner("wow", NetUtils.getHost, "Asia/Harbin")
18 |
19 | val dataSchema = StructType(Seq(StructField("value", StringType)))
20 | val encoder = WowRowEncoder.fromRow(dataSchema) //RowEncoder.apply(dataSchema).resolveAndBind()
21 | val newIter = Seq(Row.fromSeq(Seq("a1")), Row.fromSeq(Seq("a2"))).map { irow =>
22 | encoder(irow)
23 | }.iterator
24 |     val javaContext = new JavaContext
25 |     val commonTaskContext = new AppContextImpl(javaContext, null)
26 |
27 | val Array(_, host, port) = socketRunner.serveToStreamWithArrow(newIter, dataSchema, 10, commonTaskContext)
28 | println(s"${host}:${port}")
29 | Thread.currentThread().join()
30 | }
31 |
32 | test("test read python arrow server") {
33 |     val encoder = WowRowEncoder.toRow(StructType(Seq(StructField("a", LongType), StructField("b", LongType))))
34 | val socketRunner = new SparkSocketRunner("wow", NetUtils.getHost, "Asia/Harbin")
35 |     val javaContext = new JavaContext
36 |     val commonTaskContext = new AppContextImpl(javaContext, null)
37 | val iter = socketRunner.readFromStreamWithArrow("127.0.0.1", 11111, commonTaskContext)
38 |     iter.foreach(i => println(encoder(i.copy())))
39 |     javaContext.close
40 | }
41 |
42 | }
43 |
--------------------------------------------------------------------------------
/src/test/java/tech/mlsql/test/Main.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.test
2 |
3 | import tech.mlsql.arrow.python.runner.PythonProjectRunner
4 | import tech.mlsql.common.utils.path.PathFun
5 |
6 | /**
7 | * 4/3/2020 WilliamZhu(allwefantasy@gmail.com)
8 | */
9 | object Main {
10 | def main(args: Array[String]): Unit = {
11 | val project = getExampleProject("pyproject1")
12 | val runner = new PythonProjectRunner(project, Map())
13 | val output = runner.run(Seq("bash", "-c", "source activate dev && python -u train.py"), Map(
14 | "tempDataLocalPath" -> "/tmp/data",
15 | "tempModelLocalPath" -> "/tmp/model"
16 | ))
17 | output.foreach(println)
18 | }
19 | def getExampleProject(name: String) = {
20 | PathFun(getHome).add("examples").add(name).toPath
21 | }
22 |
23 | def getHome = {
24 | getClass.getResource("").getPath.split("target/test-classes").head
25 | }
26 | }
27 |
--------------------------------------------------------------------------------
/src/test/java/tech/mlsql/test/PythonProjectSpec.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.test
2 |
3 | import org.scalatest.{BeforeAndAfterAll, FunSuite}
4 | import tech.mlsql.arrow.python.runner.PythonProjectRunner
5 | import tech.mlsql.common.utils.path.PathFun
6 |
7 | /**
8 | * 2019-08-22 WilliamZhu(allwefantasy@gmail.com)
9 | */
10 | class PythonProjectSpec extends FunSuite with BeforeAndAfterAll {
11 | test("test python project") {
12 | val project = getExampleProject("pyproject1")
13 | val runner = new PythonProjectRunner(project, Map())
14 | val output = runner.run(Seq("bash", "-c", "source activate dev && python -u train.py"), Map(
15 | "tempDataLocalPath" -> "/tmp/data",
16 | "tempModelLocalPath" -> "/tmp/model"
17 | ))
18 | output.foreach(println)
19 | }
20 |
21 | def getExampleProject(name: String) = {
22 | PathFun(getHome).add("examples").add(name).toPath
23 | }
24 |
25 | def getHome = {
26 | getClass.getResource("").getPath.split("target/test-classes").head
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/src/test/java/tech/mlsql/test/RayEnv.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2016 Kyligence Inc. All rights reserved.
3 | *
4 | * http://kyligence.io
5 | *
6 | * This software is the confidential and proprietary information of
7 | * Kyligence Inc. ("Confidential Information"). You shall not disclose
8 | * such Confidential Information and shall use it only in accordance
9 | * with the terms of the license agreement you entered into with
10 | * Kyligence Inc.
11 | *
12 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
13 | * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
14 | * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
15 | * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
16 | * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
17 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
18 | * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
19 | * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
20 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
21 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 | */
24 | package tech.mlsql.test
25 |
26 | import java.io.File
27 | import java.nio.charset.StandardCharsets
28 | import java.util.TimeZone
29 | import java.util.regex.Pattern
30 |
31 | import com.google.common.io.Files
32 | import org.apache.spark.{TaskContext, WowRowEncoder}
33 | import org.apache.spark.rdd.RDD
34 | import org.apache.spark.sql.catalyst.InternalRow
35 | import org.apache.spark.sql.{DataFrame, Row}
36 | import os.CommandResult
37 | import tech.mlsql.arrow.python.ispark.SparkContextImp
38 | import tech.mlsql.arrow.python.runner.SparkSocketRunner
39 | import tech.mlsql.common.utils.log.Logging
40 | import tech.mlsql.common.utils.net.NetTool
41 | import tech.mlsql.common.utils.network.NetUtils
42 | import tech.mlsql.common.utils.shell.ShellCommand
43 | import tech.mlsql.test.RayEnv.ServerInfo
44 |
45 | class RayEnv extends Logging with Serializable {
46 |
47 | @transient var rayAddress: String = _
48 | @transient var rayOptions: Map[String, String] = Map.empty
49 |
50 | @transient var dataServers: Seq[ServerInfo] = _
51 |
52 | def startRay(envName: String): Unit = {
53 |
54 | val rayVersionRes = runWithProfile(envName, "ray --version")
55 | val rayVersion = rayVersionRes.out.lines.toList.head.split("version").last.trim
56 |
57 | val result = if (rayVersion >= "1.0.0" || rayVersion == "0.8.7") {
58 | runWithProfile(envName, "ray start --head")
59 | } else {
60 | runWithProfile(envName, "ray start --head --include-webui")
61 | }
62 | val initRegex = Pattern.compile(".*ray.init\\((.*)\\).*")
63 | val errResultLines = result.err.lines.toList
64 | val outResultLines = result.out.lines.toList
65 |
66 | errResultLines.foreach(logInfo(_))
67 | outResultLines.foreach(logInfo(_))
68 |
69 | (errResultLines ++ outResultLines).foreach(line => {
70 | val matcher = initRegex.matcher(line)
71 | if (matcher.matches()) {
72 | val params = matcher.group(1)
73 |
74 |
75 | val options = params.split(",")
76 | .filter(_.contains("=")).map(param => {
77 | val words = param.trim.split("=")
78 | (words.head, words(1).substring(1, words(1).length - 1))
79 | }).toMap
80 |
81 |
82 | if (rayAddress == null) {
83 | rayAddress = options.getOrElse("address", options.getOrElse("redis_address", null))
84 | rayOptions = options - "address" - "redis_address"
85 | logInfo(s"Start Ray:${rayAddress}")
86 | }
87 | }
88 | })
89 |
90 | if (rayVersion >= "1.0.0") {
91 | val initRegex2 = Pattern.compile(".*--address='(.*?)'.*")
92 | (errResultLines ++ outResultLines).foreach(line => {
93 | val matcher = initRegex2.matcher(line)
94 | if (matcher.matches()) {
95 | val params = matcher.group(1)
96 |
97 | val host = params.split(":").head
98 | rayAddress = host + ":10001"
99 | rayOptions = rayOptions
100 | new Thread(new Runnable {
101 | override def run(): Unit = {
102 | runWithProfile(envName,
103 | s"""
104 | |python -m ray.util.client.server --host ${host} --port 10001
105 | |""".stripMargin)
106 | }
107 | }).start()
108 | Thread.sleep(3000)
109 | logInfo(s"Start Ray:${rayAddress}")
110 | }
111 | })
112 | }
113 |
114 |
115 | if (rayAddress == null) {
116 |       throw new RuntimeException("Failed to start ray")
117 | }
118 |
119 |
120 | }
121 |
122 | def stopRay(envName: String): Unit = {
123 | val result = runWithProfile(envName, "ray stop")
124 | result
125 | }
126 |
127 | def startDataServer(df: DataFrame): Unit = {
128 | val dataSchema = df.schema
129 | dataServers = df.repartition(1).rdd.mapPartitions(iter => {
130 | val socketServer = new SparkSocketRunner("serve-runner-for-ut", NetTool.localHostName(), TimeZone.getDefault.getID)
131 | val commonTaskContext = new SparkContextImp(TaskContext.get(), null)
132 | val rab = WowRowEncoder.fromRow(dataSchema) //RowEncoder.apply(dataSchema).resolveAndBind()
133 | val newIter = iter.map(row => {
134 | rab(row)
135 | })
136 | val Array(_server, _host: String, _port: Int) = socketServer.serveToStreamWithArrow(newIter, dataSchema, 10, commonTaskContext)
137 | Seq(ServerInfo(_host, _port, TimeZone.getDefault.getID)).iterator
138 | }).collect().toSeq
139 | }
140 |
141 | def collectResult(rdd: RDD[Row]): RDD[InternalRow] = {
142 | rdd.flatMap { row =>
143 | val socketRunner = new SparkSocketRunner("read-runner-for-ut", NetUtils.getHost, TimeZone.getDefault.getID)
144 | val commonTaskContext = new SparkContextImp(TaskContext.get(), null)
145 | val pythonWorkerHost = row.getAs[String]("host")
146 | val pythonWorkerPort = row.getAs[Long]("port").toInt
147 |         logInfo(s"Ray on-data mode: connecting to python worker [${pythonWorkerHost}:${pythonWorkerPort}]")
148 | val iter = socketRunner.readFromStreamWithArrow(pythonWorkerHost, pythonWorkerPort, commonTaskContext)
149 | iter.map(f => f.copy())
150 | }
151 | }
152 |
153 |
154 | private def runWithProfile(envName: String, command: String): CommandResult = {
155 | val tmpShellFile = File.createTempFile("shell", ".sh")
156 | val setupEnv = if (envName.trim.startsWith("conda") || envName.trim.startsWith("source")) envName else s"""conda activate ${envName}"""
157 | try {
158 | Files.write(
159 | s"""
160 | |#!/bin/bash
161 | |export LC_ALL=en_US.utf-8
162 | |export LANG=en_US.utf-8
163 | |
164 | |# source ~/.bash_profile
165 | |${setupEnv}
166 | |${command}
167 | |""".stripMargin, tmpShellFile, StandardCharsets.UTF_8)
168 | val cmdResult = ShellCommand.execCmdV2("/bin/bash", tmpShellFile.getAbsolutePath)
169 | // if (cmdResult.exitCode != 0) {
170 | // throw new RuntimeException(s"run command failed ${cmdResult.toString()}")
171 | // }
172 | cmdResult
173 | } finally {
174 | tmpShellFile.delete()
175 | }
176 | }
177 |
178 | }
179 |
180 | object RayEnv {
181 |
182 | case class ServerInfo(host: String, port: Long, timezone: String)
183 |
184 | }
--------------------------------------------------------------------------------
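
RayEnv wires the pieces together for the tests: startRay boots a local Ray head (plus a ray.util.client.server process for Ray >= 1.0), startDataServer publishes each partition of a DataFrame as an Arrow socket server and records the endpoints as ServerInfo rows, and collectResult connects to the Arrow servers that the Python script publishes in return, pulling the results into an RDD[InternalRow]. A condensed sketch of that round trip, modelled on SparkSpec (the next file); the object and parameter names are illustrative, runScript stands in for the ApplyPythonScript step, and the sketch assumes the script's output schema matches the input DataFrame:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.{DataFrame, Row, SparkSession, SparkUtils}
import tech.mlsql.test.RayEnv

object RayEnvSketch {
  def roundTrip(session: SparkSession, dataDF: DataFrame, rayEnv: RayEnv,
                runScript: DataFrame => Array[Row]): DataFrame = {
    import session.implicits._
    // Assumes rayEnv.startRay(...) has already been called with a working conda env.

    rayEnv.startDataServer(dataDF)                                   // one Arrow server per partition
    val serversDF = session.createDataset(rayEnv.dataServers).toDF() // host / port / timezone rows

    val pythonServers = runScript(serversDF)                         // the script publishes its own servers
    val rdd = session.sparkContext.makeRDD[Row](pythonServers, numSlices = pythonServers.length)
    val resultRows: RDD[InternalRow] = rayEnv.collectResult(rdd)     // pull the results back over Arrow
    SparkUtils.internalCreateDataFrame(session, resultRows, dataDF.schema)
  }
}
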
/src/test/java/tech/mlsql/test/SparkSpec.scala:
--------------------------------------------------------------------------------
1 | package tech.mlsql.test
2 |
3 | import java.util
4 |
5 | import org.apache.spark.sql.{Row, SparkSession, SparkUtils}
6 | import org.scalatest.{BeforeAndAfterAll, FunSuite}
7 | import tech.mlsql.common.utils.log.Logging
8 | import tech.mlsql.test.function.SparkFunctions.MockData
9 |
10 | /**
11 | * 2019-08-14 WilliamZhu(allwefantasy@gmail.com)
12 | */
13 | class SparkSpec extends FunSuite with BeforeAndAfterAll with Logging{
14 |
15 | val rayEnv = new RayEnv
16 | var spark: SparkSession = null
17 |
18 | def condaEnv = "source /Users/allwefantasy/opt/anaconda3/bin/activate ray1.2"
19 |
20 | //spark.executor.heartbeatInterval
21 | test("test python ray connect") {
22 | val session = spark
23 | import session.implicits._
24 | val timezoneId = session.sessionState.conf.sessionLocalTimeZone
25 |
26 | val dataDF = session.createDataset(Range(0, 100).map(i => MockData(s"Title${i}", s"body-${i}"))).toDF()
27 | rayEnv.startDataServer(dataDF)
28 |
29 | val df = session.createDataset(rayEnv.dataServers).toDF()
30 |
31 | val envs = new util.HashMap[String, String]()
32 | envs.put("PYTHON_ENV", s"${condaEnv};export ARROW_PRE_0_15_IPC_FORMAT=1")
33 | //envs.put("PYTHONPATH", (os.pwd / "python").toString())
34 |
35 | val aps = new ApplyPythonScript(rayEnv.rayAddress, envs, timezoneId)
36 | val rayAddress = rayEnv.rayAddress
37 | logInfo(rayAddress)
38 | val func = aps.execute(
39 | s"""
40 | |import ray
41 | |import time
42 | |from pyjava.api.mlsql import RayContext
43 | |import numpy as np;
44 | |ray_context = RayContext.connect(globals(),"${rayAddress}")
45 | |def echo(row):
46 | | row1 = {}
47 | | row1["title"]=row['title'][1:]
48 | | row1["body"]= row["body"] + ',' + row["body"]
49 | | return row1
50 | |ray_context.foreach(echo)
51 | """.stripMargin, df.schema)
52 |
53 | val outputDF = df.rdd.mapPartitions(func)
54 |
55 | val pythonServers = SparkUtils.internalCreateDataFrame(session, outputDF, df.schema).collect()
56 |
57 | val rdd = session.sparkContext.makeRDD[Row](pythonServers, numSlices = pythonServers.length)
58 | val pythonOutputDF = rayEnv.collectResult(rdd)
59 | val output = SparkUtils.internalCreateDataFrame(session, pythonOutputDF, dataDF.schema).collect()
60 | assert(output.length == 100)
61 | output.zipWithIndex.foreach({
62 | case (row, index) =>
63 | assert(row.getString(0) == s"itle${index}")
64 | assert(row.getString(1) == s"body-${index},body-${index}")
65 | })
66 | }
67 |
68 | override def beforeAll(): Unit = {
69 | spark = SparkSession.builder().master("local[*]").appName("test").getOrCreate()
70 | super.beforeAll()
71 | rayEnv.startRay(condaEnv)
72 | }
73 |
74 | override def afterAll(): Unit = {
75 | if (spark != null) {
76 | spark.sparkContext.stop()
77 | }
78 | rayEnv.stopRay(condaEnv)
79 | super.afterAll()
80 | }
81 |
82 | }
83 |
--------------------------------------------------------------------------------
/src/test/java/tech/mlsql/test/function/SparkFunctions.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2016 Kyligence Inc. All rights reserved.
3 | *
4 | * http://kyligence.io
5 | *
6 | * This software is the confidential and proprietary information of
7 | * Kyligence Inc. ("Confidential Information"). You shall not disclose
8 | * such Confidential Information and shall use it only in accordance
9 | * with the terms of the license agreement you entered into with
10 | * Kyligence Inc.
11 | *
12 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
13 | * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
14 | * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
15 | * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
16 | * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
17 | * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
18 | * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
19 | * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
20 | * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
21 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 | */
24 | package tech.mlsql.test.function
25 |
26 | import org.apache.spark.{TaskContext, WowRowEncoder}
27 | import org.apache.spark.sql.Row
28 | import org.apache.spark.sql.catalyst.InternalRow
29 | import org.apache.spark.sql.types.StructType
30 | import tech.mlsql.arrow.python.ispark.SparkContextImp
31 | import tech.mlsql.arrow.python.runner.{ArrowPythonRunner, ChainedPythonFunctions, PythonConf, PythonFunction}
32 | import tech.mlsql.common.utils.lang.sc.ScalaMethodMacros.str
33 | import java.util
34 |
35 | import scala.collection.JavaConverters.asScalaIteratorConverter
36 |
37 | object SparkFunctions {
38 |
39 | case class MockData(title: String, body: String)
40 |
41 | def testScript1(struct: StructType, rayAddress: String, timezoneId: String): Iterator[Row] => Iterator[InternalRow] = {
42 | iter =>
43 | val encoder = WowRowEncoder.fromRow(struct)
44 | val envs = new util.HashMap[String, String]()
45 | envs.put(str(PythonConf.PYTHON_ENV), "source ~/.bash_profile && conda activate dev && export ARROW_PRE_0_15_IPC_FORMAT=1")
46 | envs.put("PYTHONPATH", (os.pwd / "python").toString())
47 | val batch = new ArrowPythonRunner(
48 | Seq(ChainedPythonFunctions(Seq(PythonFunction(
49 | s"""
50 | |import ray
51 | |import time
52 | |from pyjava.api.mlsql import RayContext
53 | |import numpy as np;
54 | |ray_context = RayContext.connect(globals(),"${rayAddress}")
55 | |def echo(row):
56 | | row1 = {}
57 | | row1["title"]=row['title'][1:]
58 | | row1["body"]= row["body"] + ',' + row["body"]
59 | | return row1
60 | |ray_context.foreach(echo)
61 | """.stripMargin, envs, "python", "3.6")))), struct,
62 | timezoneId, Map("pythonMode" -> "ray")
63 | )
64 | val newIter = iter.map(encoder)
65 | val commonTaskContext = new SparkContextImp(TaskContext.get(), batch)
66 | val columnarBatchIter = batch.compute(Iterator(newIter), TaskContext.getPartitionId(), commonTaskContext)
67 | columnarBatchIter.flatMap(_.rowIterator.asScala).map(f => f.copy())
68 | }
69 | }
70 |
--------------------------------------------------------------------------------