├── .gitignore ├── .gitmodules ├── .travis.yml ├── LICENSE ├── Makefile ├── README.md ├── build.sbt ├── log4j.properties ├── project ├── assembly.sbt ├── build.properties └── plugins.sbt ├── python ├── .gitignore ├── pyspark_cassandra │ ├── __init__.py │ ├── conf.py │ ├── context.py │ ├── format.py │ ├── rdd.py │ ├── streaming.py │ ├── tests.py │ ├── types.py │ └── util.py └── setup.py ├── sbin ├── local.sh ├── notebook.sh ├── profile.sh └── released.sh ├── src └── main │ └── scala │ ├── pyspark_cassandra │ ├── Pickling.scala │ ├── PythonHelper.scala │ ├── RowReaders.scala │ ├── RowTransformers.scala │ ├── RowWriter.scala │ ├── SpanBy.scala │ └── Utils.scala │ └── pyspark_util └── version.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # maven / sbt 2 | target 3 | .cache* 4 | 5 | # python 6 | *.pyc 7 | venv* 8 | 9 | # eclipse 10 | .classpath 11 | .project 12 | .pydevproject 13 | .settings 14 | 15 | #testing 16 | lib 17 | .ccm 18 | metastore_db 19 | /bin/ 20 | .sp-creds.txt 21 | *.ipynb 22 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pyspark-util"] 2 | path = pyspark-util 3 | url = https://github.com/TargetHolding/pyspark-util.git 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | 3 | scala: 4 | - "2.10.5" 5 | 6 | jdk: 7 | - oraclejdk8 8 | 9 | env: 10 | - CASSANDRA_VERSION=2.1.12 SPARK_VERSION=1.5.2 SPARK_PACKAGE_TYPE=hadoop2.6 11 | 12 | - CASSANDRA_VERSION=2.2.4 SPARK_VERSION=1.5.2 SPARK_PACKAGE_TYPE=hadoop2.6 13 | - CASSANDRA_VERSION=2.2.4 SPARK_VERSION=1.6.1 SPARK_PACKAGE_TYPE=hadoop2.6 14 | 15 | - CASSANDRA_VERSION=3.0.3 SPARK_VERSION=1.5.2 SPARK_PACKAGE_TYPE=hadoop2.6 16 | - CASSANDRA_VERSION=3.0.3 SPARK_VERSION=1.6.1 SPARK_PACKAGE_TYPE=hadoop2.6 17 | 18 | - CASSANDRA_VERSION=3.2.1 SPARK_VERSION=1.6.1 SPARK_PACKAGE_TYPE=hadoop2.6 19 | 20 | addons: 21 | apt: 22 | packages: 23 | - build-essential 24 | - python-dev 25 | - python-pip 26 | - python-virtualenv 27 | - libev4 28 | - libev-dev 29 | 30 | script: JVM_OPTS= make clean dist start-cassandra test-travis stop-cassandra 31 | 32 | sudo: false 33 | 34 | cache: 35 | directories: 36 | - $HOME/.m2 37 | - $HOME/.ivy2 38 | - $HOME/.sbt 39 | - $HOME/.local 40 | - $HOME/.cache/pip 41 | 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | 204 | 205 | THIRD-PARTY DEPENDENCIES 206 | ======================== 207 | Convenience copies of some third-party dependencies are distributed with 208 | Apache Cassandra as Java jar files in lib/. Licensing information for 209 | these files can be found in the lib/licenses directory. 210 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL = /bin/bash 2 | VERSION = $(shell cat version.txt) 3 | 4 | .PHONY: clean clean-pyc clean-dist dist test-travis 5 | 6 | 7 | 8 | clean: clean-dist clean-pyc 9 | 10 | clean-pyc: 11 | find . -name '*.pyc' -exec rm -f {} + 12 | find . -name '*.pyo' -exec rm -f {} + 13 | find . -name '*~' -exec rm -f {} + 14 | find . -name '__pycache__' -exec rm -fr {} + 15 | 16 | clean-dist: 17 | rm -rf target 18 | rm -rf python/build/ 19 | rm -rf python/*.egg-info 20 | 21 | 22 | 23 | install-venv: 24 | test -d venv || virtualenv venv 25 | 26 | install-cassandra-driver: install-venv 27 | venv/bin/pip install cassandra-driver 28 | 29 | install-ccm: install-venv 30 | venv/bin/pip install ccm 31 | 32 | start-cassandra: install-ccm 33 | mkdir -p .ccm 34 | venv/bin/ccm status || venv/bin/ccm create pyspark_cassandra_test -v $(CASSANDRA_VERSION) -n 1 -s 35 | 36 | stop-cassandra: 37 | venv/bin/ccm remove 38 | 39 | 40 | 41 | test: test-python test-scala test-integration 42 | 43 | test-python: 44 | 45 | test-scala: 46 | 47 | test-integration: \ 48 | test-integration-setup \ 49 | test-integration-matrix \ 50 | test-integration-teardown 51 | 52 | test-integration-setup: \ 53 | start-cassandra 54 | 55 | test-integration-teardown: \ 56 | stop-cassandra 57 | 58 | test-integration-matrix: \ 59 | install-cassandra-driver \ 60 | test-integration-spark-1.4.1 \ 61 | test-integration-spark-1.5.0 \ 62 | test-integration-spark-1.5.1 \ 63 | test-integration-spark-1.5.2 \ 64 | test-integration-spark-1.6.0 \ 65 | test-integration-spark-1.6.1 66 | 67 | test-travis: install-cassandra-driver 68 | $(call test-integration-for-version,$$SPARK_VERSION,$$SPARK_PACKAGE_TYPE) 69 | 70 | test-integration-spark-1.3.1: 71 | $(call test-integration-for-version,1.3.1,hadoop2.6) 72 | 73 | test-integration-spark-1.4.1: 74 | $(call test-integration-for-version,1.4.1,hadoop2.6) 75 | 76 | test-integration-spark-1.5.0: 77 | $(call test-integration-for-version,1.5.0,hadoop2.6) 78 | 79 | test-integration-spark-1.5.1: 80 | $(call test-integration-for-version,1.5.1,hadoop2.6) 81 | 82 | test-integration-spark-1.5.2: 83 | $(call test-integration-for-version,1.5.2,hadoop2.6) 84 | 85 | test-integration-spark-1.6.0: 86 | $(call test-integration-for-version,1.6.0,hadoop2.6) 87 | 88 | test-integration-spark-1.6.1: 89 | $(call test-integration-for-version,1.6.1,hadoop2.6) 90 | 91 | define test-integration-for-version 92 | echo ====================================================================== 93 | echo testing integration with spark-$1 94 | 95 | mkdir -p lib && test -d lib/spark-$1-bin-$2 || \ 96 | (pushd lib && curl 
http://ftp.tudelft.nl/apache/spark/spark-$1/spark-$1-bin-$2.tgz | tar xz && popd)
97 | 
98 | cp log4j.properties lib/spark-$1-bin-$2/conf/
99 | 
100 | source venv/bin/activate ; \
101 | lib/spark-$1-bin-$2/bin/spark-submit \
102 | --master local[*] \
103 | --driver-memory 512m \
104 | --conf spark.cassandra.connection.host="localhost" \
105 | --jars target/scala-2.10/pyspark-cassandra-assembly-$(VERSION).jar \
106 | --py-files target/scala-2.10/pyspark-cassandra-assembly-$(VERSION).jar \
107 | python/pyspark_cassandra/tests.py
108 | 
109 | echo ======================================================================
110 | endef
111 | 
112 | 
113 | 
114 | dist: clean-pyc
115 | sbt assembly
116 | cd python ; \
117 | find . -mindepth 2 -name '*.py' -print | \
118 | zip ../target/scala-2.10/pyspark-cassandra-assembly-$(VERSION).jar -@
119 | 
120 | 
121 | all: clean dist
122 | 
123 | 
124 | publish: clean
125 | # use spark packages to create the distribution
126 | sbt spDist
127 | 
128 | # push the python source files into the jar
129 | cd python ; \
130 | find . -mindepth 2 -name '*.py' -print | \
131 | zip ../target/scala-2.10/pyspark-cassandra_2.10-$(VERSION).jar -@
132 | 
133 | # copy it to the right name, and update the jar in the zip
134 | cp target/scala-2.10/pyspark-cassandra{_2.10,}-$(VERSION).jar
135 | cd target/scala-2.10 ;\
136 | zip ../pyspark-cassandra-$(VERSION).zip pyspark-cassandra-$(VERSION).jar
137 | 
138 | # send the package to spark-packages
139 | spark-package publish -c ".sp-creds.txt" -n "TargetHolding/pyspark-cassandra" -v $(VERSION) -f . -z target/pyspark-cassandra-$(VERSION).zip
140 | 
141 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | PySpark Cassandra
2 | =================
3 | 
4 | **This PySpark Cassandra repository is no longer maintained. Please check this repository for Spark 2.0+ support: https://github.com/anguenot/pyspark-cassandra**
5 | 
6 | ---
7 | 
8 | [![Build Status](https://travis-ci.org/TargetHolding/pyspark-cassandra.svg)](https://travis-ci.org/TargetHolding/pyspark-cassandra)
9 | [![Codacy Badge](https://api.codacy.com/project/badge/grade/1fb73418b06b4db18e3a4103a0ce056c)](https://www.codacy.com/app/frensjan/pyspark-cassandra)
10 | 
11 | PySpark Cassandra brings back the fun in working with Cassandra data in PySpark.
12 | 
13 | This module provides Python support for Apache Spark's Resilient Distributed Datasets from Apache Cassandra CQL rows, using the [Cassandra Spark Connector](https://github.com/datastax/spark-cassandra-connector) within PySpark, both in the interactive shell and in Python programs submitted with spark-submit.
14 | 
15 | This project was initially forked from https://github.com/Parsely/pyspark-cassandra, but in order to submit it to http://spark-packages.org/, a plain old repository was created.
16 | 
17 | **Contents:**
18 | * [Compatibility](#compatibility)
19 | * [Using with PySpark](#using-with-pyspark)
20 | * [Using with PySpark shell](#using-with-pyspark-shell)
21 | * [Building](#building)
22 | * [API](#api)
23 | * [Examples](#examples)
24 | * [Problems / ideas?](#problems--ideas)
25 | * [Contributing](#contributing)
26 | 
27 | 
28 | 
29 | Compatibility
30 | -------------
31 | Feedback on (in-)compatibility is much appreciated.
32 | 
33 | ### Spark
34 | The current version of PySpark Cassandra is successfully used with Spark versions 1.5 and 1.6. Use older versions of PySpark Cassandra for Spark 1.2, 1.3 or 1.4.
35 | 
36 | ### Cassandra
37 | PySpark Cassandra is compatible with Cassandra:
38 | * 2.1.5 and higher
39 | * 2.2
40 | * 3
41 | 
42 | ### Python
43 | PySpark Cassandra is used with Python 2.7, 3.3 and 3.4.
44 | 
45 | ### Scala
46 | PySpark Cassandra is currently only packaged for Scala 2.10.
47 | 
48 | 
49 | 
50 | Using with PySpark
51 | ------------------
52 | 
53 | ### With Spark Packages
54 | Pyspark Cassandra is published at [Spark Packages](http://spark-packages.org/package/TargetHolding/pyspark-cassandra). This allows easy usage with Spark through:
55 | ```bash
56 | spark-submit \
57 |     --packages TargetHolding/pyspark-cassandra:<version> \
58 |     --conf spark.cassandra.connection.host=your,cassandra,node,names
59 | ```
60 | 
61 | 
62 | ### Without Spark Packages
63 | 
64 | ```bash
65 | spark-submit \
66 |     --jars /path/to/pyspark-cassandra-assembly-<version>.jar \
67 |     --driver-class-path /path/to/pyspark-cassandra-assembly-<version>.jar \
68 |     --py-files /path/to/pyspark-cassandra-assembly-<version>.jar \
69 |     --conf spark.cassandra.connection.host=your,cassandra,node,names \
70 |     --master spark://spark-master:7077 \
71 |     yourscript.py
72 | ```
73 | (note that the `--driver-class-path` is required due to [SPARK-5185](https://issues.apache.org/jira/browse/SPARK-5185))
74 | (also note that the assembly will include the Python source files, quite similar to a Python source distribution)
75 | 
76 | 
77 | Using with PySpark shell
78 | ------------------------
79 | 
80 | Replace `spark-submit` with `pyspark` to start the interactive shell, don't provide a script as argument, and then import PySpark Cassandra. Note that when performing this import the `sc` variable in pyspark is augmented with the `cassandraTable(...)` method.
81 | 
82 | ```python
83 | import pyspark_cassandra
84 | ```
85 | 
86 | 
87 | 
88 | Building
89 | --------
90 | 
91 | ### For [Spark Packages](http://spark-packages.org/package/TargetHolding/pyspark-cassandra) Pyspark Cassandra can be compiled using:
92 | ```bash
93 | sbt compile
94 | ```
95 | The package can be published locally with:
96 | ```bash
97 | sbt spPublishLocal
98 | ```
99 | The package can be published to Spark Packages with (requires authentication and authorization):
100 | ```bash
101 | make publish
102 | ```
103 | 
104 | ### For local testing / without Spark Packages
105 | A Java / JVM library as well as a Python library is required to use PySpark Cassandra. They can be built with:
106 | 
107 | ```bash
108 | make dist
109 | ```
110 | 
111 | This creates a fat jar with the Spark Cassandra Connector, additional classes for bridging Spark and PySpark for Cassandra data, and the .py source files, at: `target/scala-2.10/pyspark-cassandra-assembly-<version>.jar`
112 | 
113 | 
114 | 
115 | API
116 | ---
117 | 
118 | The PySpark Cassandra API aims to stay close to the Cassandra Spark Connector API. Reading its [documentation](https://github.com/datastax/spark-cassandra-connector/#documentation) is a good place to start.
119 | 
120 | 
121 | ### pyspark_cassandra.RowFormat
122 | 
123 | The primary representation of CQL rows in PySpark Cassandra is the ROW format. However `sc.cassandraTable(...)` supports the `row_format` argument which can be any of the constants from `RowFormat` (see the short example after this list):
124 | * `DICT`: A CQL row is represented as a python dict with the CQL row columns as keys.
125 | * `TUPLE`: A CQL row is represented as a python tuple with the values in CQL table column order / the order of the selected columns.
126 | * `ROW`: A pyspark_cassandra.Row object representing a CQL row. This is the default.
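
A minimal sketch of how the formats differ (the keyspace and table names here are placeholders, not part of this project, and `sc` is assumed to be a `CassandraSparkContext`):

```python
from pyspark_cassandra import RowFormat

rows   = sc.cassandraTable("ks", "tab")                              # pyspark_cassandra.Row objects (the default)
dicts  = sc.cassandraTable("ks", "tab", row_format=RowFormat.DICT)   # one python dict per CQL row
tuples = sc.cassandraTable("ks", "tab", row_format=RowFormat.TUPLE)  # one python tuple per CQL row
```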
127 | 
128 | Column values are mapped between CQL and Python as follows:
129 | 
130 | | **CQL**   | **python**            |
131 | |:---------:|:---------------------:|
132 | | ascii     | unicode string        |
133 | | bigint    | long                  |
134 | | blob      | bytearray             |
135 | | boolean   | boolean               |
136 | | counter   | int, long             |
137 | | decimal   | decimal               |
138 | | double    | float                 |
139 | | float     | float                 |
140 | | inet      | str                   |
141 | | int       | int                   |
142 | | map       | dict                  |
143 | | set       | set                   |
144 | | list      | list                  |
145 | | text      | unicode string        |
146 | | timestamp | datetime.datetime     |
147 | | timeuuid  | uuid.UUID             |
148 | | varchar   | unicode string        |
149 | | varint    | long                  |
150 | | uuid      | uuid.UUID             |
151 | | _UDT_     | pyspark_cassandra.UDT |
152 | 
153 | 
154 | ### pyspark_cassandra.Row
155 | 
156 | This is the default type to which CQL rows are mapped. It is directly compatible with `pyspark.sql.Row` but is (correctly) mutable and provides some other improvements.
157 | 
158 | 
159 | ### pyspark_cassandra.UDT
160 | 
161 | This type is structurally identical to pyspark_cassandra.Row but serves user-defined types. Mapping to custom python types (e.g. via CQLEngine) is not yet supported.
162 | 
163 | 
164 | ### pyspark_cassandra.CassandraSparkContext
165 | 
166 | A `CassandraSparkContext` is very similar to a regular `SparkContext`. It is created in the same way, can be used to read files, parallelize local data, broadcast a variable, etc. See the [Spark Programming Guide](https://spark.apache.org/docs/1.2.0/programming-guide.html) for more details. *But* it exposes one additional method:
167 | 
168 | * ``cassandraTable(keyspace, table, ...)``: Returns a CassandraRDD for the given keyspace and table. Additional arguments which can be provided:
169 | 
170 |   * `row_format` can be set to any of the `pyspark_cassandra.RowFormat` values (defaults to `ROW`)
171 |   * `split_size` sets the size of each partition in number of CQL rows (defaults to `100000`)
172 |   * `fetch_size` sets the number of rows to fetch per request from Cassandra (defaults to `1000`)
173 |   * `consistency_level` sets the consistency level with which to read the data (defaults to `LOCAL_ONE`)
174 | 
175 | 
176 | ### pyspark.RDD
177 | 
178 | PySpark Cassandra supports saving arbitrary RDDs to Cassandra using:
179 | 
180 | * ``rdd.saveToCassandra(keyspace, table, ...)``: Saves an RDD to Cassandra. The RDD is expected to contain dicts with keys mapping to CQL columns. Additional arguments which can be supplied are:
181 | 
182 |   * ``columns(iterable)``: The columns to save, i.e. which keys to take from the dicts in the RDD.
183 |   * ``batch_size(int)``: The size in bytes to batch up in an unlogged batch of CQL inserts.
184 |   * ``batch_buffer_size(int)``: The maximum number of batches which are 'pending'.
185 |   * ``batch_grouping_key(string)``: The way batches are formed (defaults to "partition"):
186 |     * ``all``: any row can be added to any batch
187 |     * ``replicaset``: rows are batched for replica sets
188 |     * ``partition``: rows are batched by their partition key
189 |   * ``consistency_level(cassandra.ConsistencyLevel)``: The consistency level used in writing to Cassandra.
190 |   * ``parallelism_level(int)``: The maximum number of batches written in parallel.
191 |   * ``throughput_mibps``: Maximum write throughput allowed per single core in MiB/s.
192 |   * ``ttl(int or timedelta)``: The time to live in seconds or as a timedelta to use for the values.
193 |   * ``timestamp(int, date or datetime)``: The timestamp in microseconds, or as a date or datetime, to use for the values.
194 |   * ``metrics_enabled(bool)``: Whether to enable task metrics updates.
195 | 
196 | 
197 | ### pyspark_cassandra.CassandraRDD
198 | 
199 | A `CassandraRDD` is very similar to a regular `RDD` in pyspark. It is extended with the following methods:
200 | 
201 | * ``select(*columns)``: Creates a CassandraRDD with the select clause applied.
202 | * ``where(clause, *args)``: Creates a CassandraRDD with a CQL where clause applied. The clause can contain ? markers with the arguments supplied as *args.
203 | * ``limit(num)``: Creates a CassandraRDD with the limit clause applied.
204 | * ``take(num)``: Takes at most ``num`` records from the Cassandra table. Note that if ``limit()`` was invoked before ``take()`` a normal pyspark ``take()`` is performed. Otherwise, first the limit is set and _then_ a ``take()`` is performed.
205 | * ``cassandraCount()``: Lets Cassandra perform a count, instead of loading the data to Spark first.
206 | * ``saveToCassandra(...)``: As above, but the keyspace and/or table __may__ be omitted to save to the same keyspace and/or table.
207 | * ``spanBy(*columns)``: Groups rows by the given columns without shuffling.
208 | * ``joinWithCassandraTable(keyspace, table)``: Join an RDD with a Cassandra table on the partition key. Use .on(...) to specify other columns to join on. .select(...), .where(...) and .limit(...) can be used as well.
209 | 
210 | 
211 | ### pyspark_cassandra.streaming
212 | 
213 | When importing ``pyspark_cassandra.streaming`` the method ``saveToCassandra(...)`` is made available on DStreams. Support for joining with a Cassandra table is also added:
214 | * ``joinWithCassandraTable(keyspace, table, selected_columns, join_columns)``: Joins the DStream with a Cassandra table, selecting ``selected_columns`` from the table and joining on ``join_columns``.
215 | 
216 | 
217 | Examples
218 | --------
219 | 
220 | Creating a SparkContext with Cassandra support:
221 | 
222 | ```python
223 | from pyspark import SparkConf
224 | from pyspark_cassandra import CassandraSparkContext
225 | conf = SparkConf() \
226 |     .setAppName("PySpark Cassandra Test") \
227 |     .setMaster("spark://spark-master:7077") \
228 |     .set("spark.cassandra.connection.host", "cas-1")
229 | 
230 | sc = CassandraSparkContext(conf=conf)
231 | ```
232 | 
233 | Using select and where to narrow the data in an RDD and then filter, map, reduce and collect it:
234 | 
235 | ```python
236 | sc \
237 |     .cassandraTable("keyspace", "table") \
238 |     .select("col-a", "col-b") \
239 |     .where("key=?", "x") \
240 |     .filter(lambda r: "foo" in r["col-b"]) \
241 |     .map(lambda r: (r["col-a"], 1)) \
242 |     .reduceByKey(lambda a, b: a + b) \
243 |     .collect()
244 | ```
245 | 
246 | Storing data in Cassandra:
247 | 
248 | ```python
249 | rdd = sc.parallelize([{
250 |     "key": k,
251 |     "stamp": datetime.now(),
252 |     "val": random() * 10,
253 |     "tags": ["a", "b", "c"],
254 |     "options": {
255 |         "foo": "bar",
256 |         "baz": "qux",
257 |     }
258 | } for k in ["x", "y", "z"]])
259 | 
260 | rdd.saveToCassandra(
261 |     "keyspace",
262 |     "table",
263 |     ttl=timedelta(hours=1),
264 | )
265 | ```
266 | 
267 | Create a streaming context and convert every line to a generator of words which are saved to Cassandra. Through this example all unique words are stored in Cassandra.
268 | 
269 | The words are wrapped as a tuple so that they are in a format which can be stored. A dict or a pyspark_cassandra.Row object would have worked as well.
270 | 271 | ```python 272 | from pyspark.streaming import StreamingContext 273 | from pyspark_cassandra import streaming 274 | 275 | ssc = StreamingContext(sc, 2) 276 | 277 | ssc \ 278 | .socketTextStream("localhost", 9999) \ 279 | .flatMap(lambda l: ((w,) for w in (l,))) \ 280 | .saveToCassandra('keyspace', 'words') 281 | 282 | ssc.start() 283 | ``` 284 | 285 | Joining with Cassandra: 286 | 287 | ```python 288 | joined = rdd \ 289 | .joinWithCassandraTable('keyspace', 'accounts') \ 290 | .on('id') \ 291 | .select('e-mail', 'followers') 292 | 293 | for left, right in joined: 294 | ... 295 | ``` 296 | 297 | Or with a DStream: 298 | 299 | ```python 300 | joined = dstream.joinWithCassandraTable(self.keyspace, self.table, ['e-mail', 'followers'], ['id']) 301 | ``` 302 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import scala.io 2 | 3 | name := "pyspark-cassandra" 4 | 5 | version := io.Source.fromFile("version.txt").mkString.trim 6 | 7 | organization := "TargetHolding" 8 | 9 | scalaVersion := "2.10.5" 10 | 11 | credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials") 12 | 13 | licenses += "Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0") 14 | 15 | libraryDependencies ++= Seq( 16 | "com.datastax.spark" %% "spark-cassandra-connector-java" % "1.6.0-M1" 17 | ) 18 | 19 | spName := "TargetHolding/pyspark-cassandra" 20 | 21 | sparkVersion := "1.5.1" 22 | 23 | sparkComponents ++= Seq("streaming", "sql") 24 | 25 | javacOptions ++= Seq("-source", "1.8", "-target", "1.8") 26 | 27 | assemblyOption in assembly := (assemblyOption in assembly).value.copy( 28 | includeScala = false 29 | ) 30 | 31 | assemblyMergeStrategy in assembly <<= (assemblyMergeStrategy in assembly) { 32 | (old) => { 33 | case PathList("META-INF", "MANIFEST.MF") => MergeStrategy.discard 34 | case PathList("META-INF", xs @ _*) => MergeStrategy.last 35 | case x => MergeStrategy.last 36 | } 37 | } 38 | 39 | EclipseKeys.withSource := true 40 | -------------------------------------------------------------------------------- /log4j.properties: -------------------------------------------------------------------------------- 1 | # Set everything to be logged to the console 2 | log4j.rootCategory=WARN, console 3 | log4j.appender.console=org.apache.log4j.ConsoleAppender 4 | log4j.appender.console.target=System.err 5 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 7 | 8 | # Settings to quiet third party logs that are too verbose 9 | log4j.logger.org.eclipse.jetty=WARN 10 | log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR 11 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO 12 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO 13 | log4j.logger.org.apache.spark.metrics=ERROR 14 | log4j.logger.org.apache.spark.util=ERROR 15 | log4j.logger.org.apache.hadoop.util=ERROR 16 | -------------------------------------------------------------------------------- /project/assembly.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.0") 2 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | // This file should only 
contain the version of sbt to use.
2 | sbt.version=0.13.9
3 | 
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | resolvers += "Spark Package Main Repo" at "https://dl.bintray.com/spark-packages/maven"
2 | 
3 | addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.3")
4 | addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "4.0.0")
5 | 
--------------------------------------------------------------------------------
/python/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info
2 | build
--------------------------------------------------------------------------------
/python/pyspark_cassandra/__init__.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | 
13 | """
14 | This module provides Python support for Apache Spark's Resilient Distributed Datasets from
15 | Apache Cassandra CQL rows using the Spark Cassandra Connector from
16 | https://github.com/datastax/spark-cassandra-connector.
17 | """
18 | 
19 | import inspect
20 | 
21 | import pyspark.context
22 | import pyspark.rdd
23 | 
24 | import pyspark_cassandra.context
25 | import pyspark_cassandra.rdd
26 | 
27 | from .conf import ReadConf, WriteConf
28 | from .context import CassandraSparkContext, monkey_patch_sc
29 | from .rdd import RowFormat
30 | from .types import Row, UDT
31 | 
32 | 
33 | __all__ = [
34 |     "CassandraSparkContext",
35 |     "ReadConf",
36 |     "Row",
37 |     "RowFormat",
38 |     "streaming",
39 |     "UDT",
40 |     "WriteConf"
41 | ]
42 | 
43 | 
44 | # Monkey patch the default python RDD so that it can be stored to Cassandra as CQL rows
45 | from .rdd import saveToCassandra, joinWithCassandraTable
46 | pyspark.rdd.RDD.saveToCassandra = saveToCassandra
47 | pyspark.rdd.RDD.joinWithCassandraTable = joinWithCassandraTable
48 | 
49 | # Monkey patch the sc variable in the caller if any
50 | frame = inspect.currentframe().f_back
51 | # Go back at most 10 frames
52 | for _ in range(10):
53 |     if not frame:
54 |         break
55 |     elif "sc" in frame.f_globals:
56 |         monkey_patch_sc(frame.f_globals["sc"])
57 |         break
58 |     else:
59 |         frame = frame.f_back
60 | 
--------------------------------------------------------------------------------
/python/pyspark_cassandra/conf.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | from datetime import timedelta, datetime, date 14 | 15 | 16 | class _Conf(object): 17 | @classmethod 18 | def build(cls, conf=None, **kwargs): 19 | if conf and kwargs: 20 | settings = conf.settings() 21 | settings.update(kwargs) 22 | return cls(**settings) 23 | elif conf: 24 | return conf 25 | else: 26 | return cls(**kwargs) 27 | 28 | def settings(self): 29 | return {k:v for k, v in self.__dict__.items() if v is not None} 30 | 31 | def __str__(self): 32 | return '%s(%s)' % ( 33 | self.__class__.__name__, 34 | ', '.join('%s=%s' % (k, v) for k, v in self.settings().items()) 35 | ) 36 | 37 | 38 | class ReadConf(_Conf): 39 | def __init__(self, split_count=None, split_size=None, fetch_size=None, consistency_level=None, 40 | metrics_enabled=None): 41 | ''' 42 | TODO docstring 43 | ''' 44 | self.split_count = split_count 45 | self.split_size = split_size 46 | self.fetch_size = fetch_size 47 | self.consistency_level = consistency_level 48 | self.metrics_enabled = metrics_enabled 49 | 50 | 51 | class WriteConf(_Conf): 52 | def __init__(self, batch_size=None, batch_buffer_size=None, batch_grouping_key=None, 53 | consistency_level=None, parallelism_level=None, throughput_mibps=None, ttl=None, 54 | timestamp=None, metrics_enabled=None): 55 | ''' 56 | @param batch_size(int): 57 | The size in bytes to batch up in an unlogged batch of CQL inserts. 58 | If None given the default size of 16*1024 is used or 59 | spark.cassandra.output.batch.size.bytes if set. 60 | @param batch_buffer_size(int): 61 | The maximum number of batches which are 'pending'. 62 | If None given the default of 1000 is used. 63 | @param batch_grouping_key(string): 64 | The way batches are formed: 65 | * all: any row can be added to any batch 66 | * replicaset: rows are batched for replica sets 67 | * partition: rows are batched by their partition key 68 | * None: defaults to "partition" 69 | @param consistency_level(cassandra.ConsistencyLevel): 70 | The consistency level used in writing to Cassandra. 71 | If None defaults to LOCAL_ONE or spark.cassandra.output.consistency.level if set. 72 | @param parallelism_level(int): 73 | The maximum number of batches written in parallel. 74 | If None defaults to 8 or spark.cassandra.output.concurrent.writes if set. 75 | @param throughput_mibps(int): 76 | @param ttl(int or timedelta): 77 | The time to live as seconds or timedelta to use for the values. 78 | If None given no TTL is used. 79 | @param timestamp(int, date or datetime): 80 | The timestamp in microseconds, date or datetime to use for the values. 81 | If None given the Cassandra nodes determine the timestamp. 82 | @param metrics_enabled(bool): 83 | Whether to enable task metrics updates. 
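
        Example usage (an illustrative sketch only; 'some_keyspace', 'some_table' and the
        rdd variable are assumptions, not part of this module):

            from datetime import timedelta

            conf = WriteConf(ttl=timedelta(hours=1), batch_grouping_key='partition')
            rdd.saveToCassandra('some_keyspace', 'some_table', write_conf=conf)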
84 |         '''
85 |         self.batch_size = batch_size
86 |         self.batch_buffer_size = batch_buffer_size
87 |         self.batch_grouping_key = batch_grouping_key
88 |         self.consistency_level = consistency_level
89 |         self.parallelism_level = parallelism_level
90 |         self.throughput_mibps = throughput_mibps
91 | 
92 |         # convert a ttl given as a timedelta to seconds
93 |         if ttl and isinstance(ttl, timedelta):
94 |             ttl = int(ttl.total_seconds())
95 |         self.ttl = ttl
96 | 
97 |         # convert date or datetime objects to a timestamp in microseconds since the UNIX epoch
98 |         if timestamp and (isinstance(timestamp, datetime) or isinstance(timestamp, date)):
99 |             timestamp = (timestamp - timestamp.__class__(1970, 1, 1)).total_seconds()
100 |             timestamp = int(timestamp * 1000 * 1000)
101 |         self.timestamp = timestamp
102 | 
103 |         self.metrics_enabled = metrics_enabled
104 | 
--------------------------------------------------------------------------------
/python/pyspark_cassandra/context.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | 
13 | from functools import partial
14 | 
15 | import pyspark.context
16 | from pyspark_cassandra.rdd import CassandraTableScanRDD
17 | 
18 | 
19 | def monkey_patch_sc(sc):
20 |     sc.__class__ = CassandraSparkContext
21 |     sc.__dict__["cassandraTable"] = partial(CassandraSparkContext.cassandraTable, sc)
22 |     sc.__dict__["cassandraTable"].__doc__ = CassandraSparkContext.cassandraTable.__doc__
23 | 
24 | 
25 | class CassandraSparkContext(pyspark.context.SparkContext):
26 |     """Wraps a SparkContext which allows reading CQL rows from Cassandra"""
27 | 
28 |     def cassandraTable(self, *args, **kwargs):
29 |         """Returns a CassandraTableScanRDD for the given keyspace and table"""
30 |         return CassandraTableScanRDD(self, *args, **kwargs)
31 | 
--------------------------------------------------------------------------------
/python/pyspark_cassandra/format.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | 
13 | class RowFormat(object):
14 |     """An enumeration of CQL row formats used in Cassandra RDDs"""
15 | 
16 |     DICT = 0
17 |     TUPLE = 1
18 |     ROW = 2
19 | 
20 |     values = (DICT, TUPLE, ROW)
21 | 
22 |     def __init__(self):
23 |         raise NotImplementedError('RowFormat is not meant to be initialized, use e.g. RowFormat.DICT')
24 | 
25 | 
26 | class ColumnSelector(object):
27 |     def __init__(self, partition_key=False, primary_key=False, *columns):
28 |         if sum([bool(partition_key), bool(primary_key), bool(columns)]) > 1:
29 |             raise ValueError(
30 |                 "can't combine selection of partition_key and/or primary_key and/or columns")
31 | 
32 |         self.partition_key = partition_key
33 |         self.primary_key = primary_key
34 |         self.columns = columns
35 | 
36 |     @classmethod
37 |     def none(cls):
38 |         return ColumnSelector()
39 | 
40 |     @classmethod
41 |     def partition_key(cls):
42 |         return ColumnSelector(partition_key=True)
43 | 
44 |     @classmethod
45 |     def primary_key(cls):
46 |         return ColumnSelector(primary_key=True)
47 | 
48 |     @classmethod
49 |     def some(cls, *columns):
50 |         return ColumnSelector(False, False, *columns)
51 | 
52 |     def __str__(self):
53 |         s = '[column selection of: '
54 |         if self.partition_key:
55 |             s += 'partition_key'
56 |         elif self.primary_key:
57 |             s += 'primary_key'
58 |         elif self.columns:
59 |             s += ', '.join(c for c in self.columns)
60 |         else:
61 |             s += 'nothing'
62 |         return s + ']'
63 | 
--------------------------------------------------------------------------------
/python/pyspark_cassandra/rdd.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | 
13 | from copy import copy
14 | from itertools import groupby
15 | from operator import itemgetter
16 | import sys
17 | 
18 | from pyspark.rdd import RDD
19 | from pyspark_cassandra.conf import ReadConf, WriteConf
20 | from pyspark_cassandra.format import ColumnSelector, RowFormat
21 | from pyspark_cassandra.types import Row
22 | from pyspark_cassandra.util import as_java_array, as_java_object, helper
23 | 
24 | 
25 | if sys.version_info > (3,):
26 |     long = int  # @ReservedAssignment
27 | 
28 | 
29 | try:
30 |     import pandas as pd  # @UnusedImport, import used in SpanningRDD
31 | except ImportError:
32 |     pd = None  # pandas is optional
33 | 
34 | 
35 | def saveToCassandra(rdd, keyspace=None, table=None, columns=None, row_format=None, keyed=None,
36 |                     write_conf=None, **write_conf_kwargs):
37 |     '''
38 |     Saves an RDD to Cassandra. The RDD is expected to contain dicts with keys mapping to CQL
39 |     columns.
40 | 
41 |     Arguments:
42 |     @param rdd(RDD):
43 |         The RDD to save. Equals to self when invoking saveToCassandra on a monkey patched RDD.
44 |     @param keyspace(string):
45 |         The keyspace to save the RDD in. If not given and the rdd is a CassandraRDD the same
46 |         keyspace is used.
47 |     @param table(string):
48 |         The CQL table to save the RDD in. If not given and the rdd is a CassandraRDD the same
49 |         table is used.
50 | 
51 |     Keyword arguments:
52 |     @param columns(iterable):
53 |         The columns to save, i.e. which keys to take from the dicts in the RDD.
54 |         If None given, all columns are stored.
55 | 
56 |     @param row_format(RowFormat):
57 |         Make explicit how to map the RDD elements into Cassandra rows.
58 |         If None given the mapping is auto-detected as far as possible.
59 | @param keyed(bool): 60 | Make explicit that the RDD consists of key, value tuples (and not arrays of length 61 | two). 62 | 63 | @param write_conf(WriteConf): 64 | A WriteConf object to use when saving to Cassandra 65 | @param **write_conf_kwargs: 66 | WriteConf parameters to use when saving to Cassandra 67 | ''' 68 | 69 | keyspace = keyspace or getattr(rdd, 'keyspace', None) 70 | if not keyspace: 71 | raise ValueError("keyspace not set") 72 | 73 | table = table or getattr(rdd, 'table', None) 74 | if not table: 75 | raise ValueError("table not set") 76 | 77 | # create write config as map 78 | write_conf = WriteConf.build(write_conf, **write_conf_kwargs) 79 | write_conf = as_java_object(rdd.ctx._gateway, write_conf.settings()) 80 | # convert the columns to a string array 81 | columns = as_java_array(rdd.ctx._gateway, "String", columns) if columns else None 82 | 83 | helper(rdd.ctx) \ 84 | .saveToCassandra( 85 | rdd._jrdd, 86 | keyspace, 87 | table, 88 | columns, 89 | row_format, 90 | keyed, 91 | write_conf, 92 | ) 93 | 94 | 95 | class _CassandraRDD(RDD): 96 | ''' 97 | A Resilient Distributed Dataset of Cassandra CQL rows. As any RDD, objects of this class 98 | are immutable; i.e. operations on this RDD generate a new RDD. 99 | ''' 100 | 101 | def __init__(self, ctx, keyspace, table, row_format=None, read_conf=None, **read_conf_kwargs): 102 | if not keyspace: 103 | raise ValueError("keyspace not set") 104 | 105 | if not table: 106 | raise ValueError("table not set") 107 | 108 | if row_format is None: 109 | row_format = RowFormat.ROW 110 | elif row_format < 0 or row_format >= len(RowFormat.values): 111 | raise ValueError("invalid row_format %s" % row_format) 112 | 113 | self.keyspace = keyspace 114 | self.table = table 115 | self.row_format = row_format 116 | self.read_conf = ReadConf.build(read_conf, **read_conf_kwargs) 117 | self._limit = None 118 | 119 | # this jrdd is for compatibility with pyspark.rdd.RDD 120 | # while allowing this constructor to be use for type checking etc 121 | # and setting _jrdd //after// invoking this constructor 122 | class DummyJRDD(object): 123 | def id(self): 124 | return -1 125 | jrdd = DummyJRDD() 126 | 127 | super(_CassandraRDD, self).__init__(jrdd, ctx) 128 | 129 | 130 | @property 131 | def _helper(self): 132 | return helper(self.ctx) 133 | 134 | 135 | def _pickle_jrdd(self): 136 | jrdd = self._helper.pickleRows(self._crdd, self.row_format) 137 | return self._helper.javaRDD(jrdd) 138 | 139 | 140 | def get_crdd(self): 141 | return self._crdd 142 | 143 | def set_crdd(self, crdd): 144 | self._crdd = crdd 145 | self._jrdd = self._pickle_jrdd() 146 | self._id = self._jrdd.id 147 | 148 | crdd = property(get_crdd, set_crdd) 149 | 150 | 151 | saveToCassandra = saveToCassandra 152 | 153 | 154 | def select(self, *columns): 155 | """Creates a CassandraRDD with the select clause applied.""" 156 | columns = as_java_array(self.ctx._gateway, "String", (str(c) for c in columns)) 157 | return self._specialize('select', columns) 158 | 159 | 160 | def where(self, clause, *args): 161 | """Creates a CassandraRDD with a CQL where clause applied. 162 | @param clause: The where clause, either complete or with ? markers 163 | @param *args: The parameters for the ? markers in the where clause. 
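
        Example (illustrative only; the column names and values are assumptions):
            rdd.where('created > ? and category = ?', since, 'news')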
164 | """ 165 | args = as_java_array(self.ctx._gateway, "Object", args) 166 | return self._specialize('where', *[clause, args]) 167 | 168 | 169 | def limit(self, limit): 170 | """Creates a CassandraRDD with the limit clause applied.""" 171 | self._limit = limit 172 | return self._specialize('limit', long(limit)) 173 | 174 | 175 | def take(self, num): 176 | """Takes at most 'num' records from the Cassandra table. 177 | 178 | Note that if limit() was invoked before take() a normal pyspark take() 179 | is performed. Otherwise, first limit is set and _then_ a take() is 180 | performed. 181 | """ 182 | if self._limit: 183 | return super(_CassandraRDD, self).take(num) 184 | else: 185 | return self.limit(num).take(num) 186 | 187 | 188 | def cassandraCount(self): 189 | """Lets Cassandra perform a count, instead of loading data to Spark""" 190 | return self._crdd.cassandraCount() 191 | 192 | 193 | def _specialize(self, func_name, *args, **kwargs): 194 | func = getattr(self._helper, func_name) 195 | 196 | new = copy(self) 197 | new.crdd = func(new._crdd, *args, **kwargs) 198 | 199 | return new 200 | 201 | 202 | def spanBy(self, *columns): 203 | """"Groups rows by the given columns without shuffling. 204 | 205 | @param *columns: an iterable of columns by which to group. 206 | 207 | Note that: 208 | - The rows are grouped by comparing the given columns in order and 209 | starting a new group whenever the value of the given columns changes. 210 | This works well with using the partition keys and one or more of the 211 | clustering keys. Use rdd.groupBy(...) for any other grouping. 212 | - The grouping is applied on the partition level. I.e. any grouping 213 | will be a subset of its containing partition. 214 | """ 215 | 216 | return SpanningRDD(self.ctx, self._crdd, self._jrdd, self._helper, columns) 217 | 218 | 219 | def __copy__(self): 220 | c = self.__class__.__new__(self.__class__) 221 | c.__dict__.update(self.__dict__) 222 | return c 223 | 224 | 225 | 226 | class CassandraTableScanRDD(_CassandraRDD): 227 | def __init__(self, ctx, keyspace, table, row_format=None, read_conf=None, **read_conf_kwargs): 228 | super(CassandraTableScanRDD, self).__init__(ctx, keyspace, table, row_format, read_conf, 229 | **read_conf_kwargs) 230 | 231 | self._key_by = ColumnSelector.none() 232 | 233 | read_conf = as_java_object(ctx._gateway, self.read_conf.settings()) 234 | 235 | self.crdd = self._helper \ 236 | .cassandraTable( 237 | ctx._jsc, 238 | keyspace, 239 | table, 240 | read_conf, 241 | ) 242 | 243 | 244 | def by_primary_key(self): 245 | return self.key_by(primary_key=True) 246 | 247 | def key_by(self, primary_key=True, partition_key=False, *columns): 248 | # TODO implement keying by arbitrary columns 249 | if columns: 250 | raise NotImplementedError('keying by arbitrary columns is not (yet) supported') 251 | if partition_key: 252 | raise NotImplementedError('keying by partition key is not (yet) supported') 253 | 254 | new = copy(self) 255 | new._key_by = ColumnSelector(partition_key, primary_key, *columns) 256 | new.crdd = self.crdd 257 | 258 | return new 259 | 260 | 261 | def _pickle_jrdd(self): 262 | # TODO implement keying by arbitrary columns 263 | jrdd = self._helper.pickleRows(self.crdd, self.row_format, self._key_by.primary_key) 264 | return self._helper.javaRDD(jrdd) 265 | 266 | 267 | 268 | class SpanningRDD(RDD): 269 | ''' 270 | An RDD which groups rows with the same key (as defined through named 271 | columns) within each partition. 
272 |     '''
273 |     def __init__(self, ctx, crdd, jrdd, helper, columns):
274 |         self._crdd = crdd
275 |         self.columns = columns
276 |         self._helper = helper
277 | 
278 |         rdd = RDD(jrdd, ctx).mapPartitions(self._spanning_iterator())
279 |         super(SpanningRDD, self).__init__(rdd._jrdd, ctx)
280 | 
281 | 
282 |     def _spanning_iterator(self):
283 |         ''' implements basic spanning on the python side operating on Rows '''
284 |         # TODO implement in Java and support not only Rows
285 | 
286 |         columns = set(str(c) for c in self.columns)
287 | 
288 |         def spanning_iterator(partition):
289 |             def key_by(columns):
290 |                 for row in partition:
291 |                     k = Row(**{c: row.__getattr__(c) for c in columns})
292 |                     for c in columns:
293 |                         del row[c]
294 | 
295 |                     yield (k, row)
296 | 
297 |             for g, l in groupby(key_by(columns), itemgetter(0)):
298 |                 yield g, list(_[1] for _ in l)
299 | 
300 |         return spanning_iterator
301 | 
302 | 
303 |     def asDataFrames(self, *index_by):
304 |         '''
305 |         Reads the spanned rows as DataFrames if pandas is available, or as
306 |         a dict of numpy arrays if only numpy is available or as a dict with
307 |         primitives and objects otherwise.
308 | 
309 |         @param index_by If pandas is available, the dataframes will be
310 |             indexed by the given columns.
311 |         '''
312 |         for c in index_by:
313 |             if c in self.columns:
314 |                 raise ValueError(('column %s cannot be used as index in the data'
315 |                     'frames as it is a column by which the rows are spanned.') % c)
316 | 
317 |         columns = as_java_array(self.ctx._gateway, "String", (str(c) for c in self.columns))
318 |         jrdd = self._helper.spanBy(self._crdd, columns)
319 |         rdd = RDD(jrdd, self.ctx)
320 | 
321 |         global pd
322 |         if index_by and pd:
323 |             return rdd.mapValues(lambda _: _.set_index(*[str(c) for c in index_by]))
324 |         else:
325 |             return rdd
326 | 
327 | 
328 | def joinWithCassandraTable(left_rdd, keyspace, table):
329 |     '''
330 |     Join an RDD with a Cassandra table on the partition key. Use .on(...)
331 |     to specify other columns to join on. .select(...), .where(...) and
332 |     .limit(...) can be used as well.
333 | 
334 |     Arguments:
335 |     @param left_rdd(RDD):
336 |         The RDD to join. Equals to self when invoking joinWithCassandraTable on a monkey
337 |         patched RDD.
338 |     @param keyspace(string):
339 |         The keyspace to join on.
340 |     @param table(string):
341 |         The CQL table to join on.
342 |     '''
343 | 
344 |     return CassandraJoinRDD(left_rdd, keyspace, table)
345 | 
346 | 
347 | class CassandraJoinRDD(_CassandraRDD):
348 |     '''
349 |     An RDD of rows from the given RDD joined with rows from a Cassandra table.
350 |     '''
351 | 
352 |     def __init__(self, left_rdd, keyspace, table):
353 |         super(CassandraJoinRDD, self).__init__(left_rdd.ctx, keyspace, table)
354 |         self.crdd = self._helper \
355 |             .joinWithCassandraTable(
356 |                 left_rdd._jrdd,
357 |                 keyspace,
358 |                 table
359 |             )
360 | 
361 | 
362 |     def on(self, *columns):
363 |         columns = as_java_array(self.ctx._gateway, "String", (str(c) for c in columns))
364 |         return self._specialize('on', columns)
365 | 
--------------------------------------------------------------------------------
/python/pyspark_cassandra/streaming.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | from pyspark_cassandra.util import as_java_object, as_java_array 14 | 15 | from pyspark.streaming.dstream import DStream 16 | from pyspark_cassandra.conf import WriteConf 17 | from pyspark_cassandra.util import helper 18 | from pyspark.serializers import AutoBatchedSerializer, PickleSerializer 19 | 20 | 21 | def saveToCassandra(dstream, keyspace, table, columns=None, row_format=None, keyed=None, 22 | write_conf=None, **write_conf_kwargs): 23 | ctx = dstream._ssc._sc 24 | gw = ctx._gateway 25 | 26 | # create write config as map 27 | write_conf = WriteConf.build(write_conf, **write_conf_kwargs) 28 | write_conf = as_java_object(gw, write_conf.settings()) 29 | # convert the columns to a string array 30 | columns = as_java_array(gw, "String", columns) if columns else None 31 | 32 | return helper(ctx).saveToCassandra(dstream._jdstream, keyspace, table, columns, row_format, 33 | keyed, write_conf) 34 | 35 | 36 | def joinWithCassandraTable(dstream, keyspace, table, selected_columns=None, join_columns=None): 37 | """Joins a DStream (a stream of RDDs) with a Cassandra table 38 | 39 | Arguments: 40 | @param dstream(DStream) 41 | The DStream to join. Equals to self when invoking joinWithCassandraTable on a monkey 42 | patched RDD. 43 | @param keyspace(string): 44 | The keyspace to join on. 45 | @param table(string): 46 | The CQL table to join on. 47 | @param selected_columns(string): 48 | The columns to select from the Cassandra table. 49 | @param join_columns(string): 50 | The columns used to join on from the Cassandra table. 51 | """ 52 | 53 | ssc = dstream._ssc 54 | ctx = ssc._sc 55 | gw = ctx._gateway 56 | 57 | selected_columns = as_java_array(gw, "String", selected_columns) if selected_columns else None 58 | join_columns = as_java_array(gw, "String", join_columns) if join_columns else None 59 | 60 | h = helper(ctx) 61 | dstream = h.joinWithCassandraTable(dstream._jdstream, keyspace, table, selected_columns, 62 | join_columns) 63 | dstream = h.pickleRows(dstream) 64 | dstream = h.javaDStream(dstream) 65 | 66 | return DStream(dstream, ssc, AutoBatchedSerializer(PickleSerializer())) 67 | 68 | 69 | # Monkey patch the default python DStream so that data in it can be stored to and joined with 70 | # Cassandra tables 71 | DStream.saveToCassandra = saveToCassandra 72 | DStream.joinWithCassandraTable = joinWithCassandraTable 73 | -------------------------------------------------------------------------------- /python/pyspark_cassandra/tests.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | from _functools import partial 14 | from datetime import datetime, timedelta 15 | from decimal import Decimal 16 | import string 17 | import sys 18 | import time 19 | import unittest 20 | import uuid 21 | 22 | from cassandra import ConsistencyLevel 23 | from cassandra.cluster import Cluster 24 | from cassandra.util import uuid_from_time 25 | 26 | from pyspark import SparkConf 27 | from pyspark.accumulators import AddingAccumulatorParam 28 | from pyspark.streaming.context import StreamingContext 29 | 30 | from pyspark_cassandra import CassandraSparkContext, RowFormat, Row, UDT 31 | import pyspark_cassandra 32 | import pyspark_cassandra.streaming 33 | from pyspark_cassandra.conf import ReadConf, WriteConf 34 | from itertools import chain 35 | from math import sqrt 36 | from uuid import UUID 37 | 38 | 39 | class CassandraTestCase(unittest.TestCase): 40 | keyspace = "test_pyspark_cassandra" 41 | 42 | def rdd(self, keyspace=None, table=None, key=None, column=None, **kwargs): 43 | keyspace = keyspace or getattr(self, 'keyspace', None) 44 | table = table or getattr(self, 'table', None) 45 | rdd = self.sc.cassandraTable(keyspace, table, **kwargs) 46 | if key is not None: 47 | rdd = rdd.where('key=?', key) 48 | if column is not None: 49 | rdd = rdd.select(column) 50 | return rdd 51 | 52 | def read_test(self, type_name, value=None): 53 | rdd = self.rdd(key=type_name, column=type_name) 54 | self.assertEqual(rdd.count(), 1) 55 | read = getattr(rdd.first(), type_name) 56 | self.assertEqual(read, value) 57 | return read 58 | 59 | def read_write_test(self, type_name, value): 60 | row = {'key': type_name, type_name: value} 61 | rdd = self.sc.parallelize([row]) 62 | rdd.saveToCassandra(self.keyspace, self.table) 63 | return self.read_test(type_name, value) 64 | 65 | 66 | 67 | class SimpleTypesTestBase(CassandraTestCase): 68 | table = "simple_types" 69 | 70 | simple_types = [ 71 | 'ascii', 'bigint', 'blob', 'boolean', 'decimal', 'double', 'float', 72 | 'inet', 'int', 'text', 'timestamp', 'timeuuid', 'varchar', 'varint', 73 | 'uuid', 74 | ] 75 | 76 | @classmethod 77 | def setUpClass(cls): 78 | super(SimpleTypesTestBase, cls).setUpClass() 79 | cls.session.execute(''' 80 | CREATE TABLE IF NOT EXISTS ''' + cls.table + ''' ( 81 | key text primary key, %s 82 | ) 83 | ''' % ', '.join('{0} {0}'.format(t) for t in cls.simple_types)) 84 | 85 | def setUp(self): 86 | super(SimpleTypesTestBase, self).setUp() 87 | self.session.execute('TRUNCATE ' + self.table) 88 | 89 | 90 | class SimpleTypesTest(SimpleTypesTestBase): 91 | def test_ascii(self): 92 | self.read_write_test('ascii', 'some ascii') 93 | 94 | def test_bigint(self): 95 | self.read_write_test('bigint', sys.maxint) 96 | 97 | def test_blob(self): 98 | self.read_write_test('blob', bytearray('some blob')) 99 | 100 | def test_boolean(self): 101 | self.read_write_test('boolean', False) 102 | 103 | def test_decimal(self): 104 | self.read_write_test('decimal', Decimal(0.5)) 105 | 106 | def test_double(self): 107 | self.read_write_test('double', 0.5) 108 | 109 | def test_float(self): 110 | self.read_write_test('float', 0.5) 111 | 112 | # TODO returns resolved hostname with ip address (hostname/ip, 113 | # e.g. /127.0.0.1), but doesn't accept with / ... 
114 |     # def test_inet(self):
115 |     #     self.read_write_test('inet', u'/127.0.0.1')
116 | 
117 |     def test_int(self):
118 |         self.read_write_test('int', 1)
119 | 
120 |     def test_text(self):
121 |         self.read_write_test('text', u'some text')
122 | 
123 |     # TODO implement test with datetime with tzinfo without depending on pytz
124 |     # def test_timestamp(self):
125 |     #     self.read_write_test('timestamp', datetime(2015, 1, 1))
126 | 
127 |     def test_timeuuid(self):
128 |         uuid = uuid_from_time(datetime(2015, 1, 1))
129 |         self.read_write_test('timeuuid', uuid)
130 | 
131 |     def test_varchar(self):
132 |         self.read_write_test('varchar', u'some varchar')
133 | 
134 |     def test_varint(self):
135 |         self.read_write_test('varint', 1)
136 | 
137 |     def test_uuid(self):
138 |         self.read_write_test('uuid', uuid.UUID('22dadfd0-b971-11e4-a856-85a08dca5bbf'))
139 | 
140 | 
141 | 
142 | class CollectionTypesTest(CassandraTestCase):
143 |     table = "collection_types"
144 |     collection_types = {
145 |         'm': 'map<text, text>',
146 |         'l': 'list<text>',
147 |         's': 'set<text>',
148 |     }
149 | 
150 |     @classmethod
151 |     def setUpClass(cls):
152 |         super(CollectionTypesTest, cls).setUpClass()
153 |         cls.session.execute('''
154 |             CREATE TABLE IF NOT EXISTS %s (
155 |                 key text primary key, %s
156 |             )
157 |         ''' % (cls.table, ', '.join('%s %s' % (k, v) for k, v in cls.collection_types.items())))
158 | 
159 |     @classmethod
160 |     def tearDownClass(cls):
161 |         super(CollectionTypesTest, cls).tearDownClass()
162 | 
163 |     def setUp(self):
164 |         super(CollectionTypesTest, self).setUp()
165 |         self.session.execute('TRUNCATE %s' % self.table)
166 | 
167 |     def collections_common_tests(self, collection, column):
168 |         rows = [
169 |             {'key':k, column:v}
170 |             for k, v in collection.items()
171 |         ]
172 | 
173 |         self.sc.parallelize(rows).saveToCassandra(self.keyspace, self.table)
174 | 
175 |         rdd = self.sc.cassandraTable(self.keyspace, self.table).select('key', column).cache()
176 |         self.assertEqual(len(collection), rdd.count())
177 | 
178 |         collected = rdd.collect()
179 |         self.assertEqual(len(collection), len(collected))
180 | 
181 |         for row in collected:
182 |             self.assertEqual(collection[row.key], getattr(row, column))
183 | 
184 |         return rdd
185 | 
186 |     def test_list(self):
187 |         lists = {'l%s' % i: list(string.ascii_lowercase[:i]) for i in range(1, 10)}
188 |         self.collections_common_tests(lists, 'l')
189 | 
190 |     def test_map(self):
191 |         maps = {'m%s' % i : {k : 'x' for k in string.ascii_lowercase[:i]} for i in range(1, 10)}
192 |         self.collections_common_tests(maps, 'm')
193 | 
194 |     def test_set(self):
195 |         maps = {'s%s' % i : set(string.ascii_lowercase[:i]) for i in range(1, 10)}
196 |         self.collections_common_tests(maps, 's')
197 | 
198 | 
199 | 
200 | class UDTTest(CassandraTestCase):
201 |     table = "udt_types"
202 | 
203 |     types = {
204 |         'simple_udt': {
205 |             'col_text': 'text',
206 |             'col_int': 'int',
207 |             'col_boolean': 'boolean',
208 |         },
209 |         'udt_wset': {
210 |             'col_text': 'text',
211 |             'col_set': 'set<int>',
212 |         },
213 |     }
214 | 
215 |     @classmethod
216 |     def setUpClass(cls):
217 |         super(UDTTest, cls).setUpClass()
218 | 
219 |         cls.udt_support = cls.session.cluster.protocol_version >= 4
220 |         if cls.udt_support:
221 |             for name, udt in cls.types.items():
222 |                 cls.session.execute('''
223 |                     CREATE TYPE IF NOT EXISTS %s (
224 |                         %s
225 |                     )
226 |                 ''' % (name, ',\n\t'.join('%s %s' % f for f in udt.items())))
227 | 
228 |             fields = ', '.join(
229 |                 '{udt_type} frozen<{udt_type}>'.format(udt_type=udt_type)
230 |                 for udt_type in cls.types
231 |             )
232 | 
233 |             fields += ', ' + ', '.join(
234 |                 '{udt_type}_{col_type} {col_type}<frozen<{udt_type}>>'.format(udt_type=udt_type, col_type=col_type)
235 |                 for udt_type in cls.types
236 |                 for col_type in ('set', 'list')
237 |             )
238 | 
239 |             cls.session.execute('''
240 |                 CREATE TABLE IF NOT EXISTS %s (
241 |                     key text primary key, %s
242 |                 )
243 |             ''' % (cls.table, fields))
244 | 
245 |     def setUp(self):
246 |         if not self.udt_support:
247 |             self.skipTest("testing with Cassandra < 2.2, can't test with UDTs")
248 | 
249 |         super(UDTTest, self).setUp()
250 |         self.session.execute('TRUNCATE %s' % self.table)
251 | 
252 |     def read_write_test(self, type_name, value):
253 |         read = super(UDTTest, self).read_write_test(type_name, value)
254 |         self.assertTrue(isinstance(read, UDT),
255 |                         'value read is not an instance of UDT')
256 | 
257 |         udt = self.types[type_name]
258 |         for field in udt:
259 |             self.assertEqual(getattr(read, field), value[field])
260 | 
261 |     def test_simple_udt(self):
262 |         self.read_write_test('simple_udt', UDT(col_text='text', col_int=1, col_boolean=True))
263 | 
264 |     def test_simple_udt_null(self):
265 |         super(UDTTest, self).read_write_test('simple_udt', None)
266 | 
267 |     def test_simple_udt_null_field(self):
268 |         self.read_write_test('simple_udt', UDT(col_text='text', col_int=None, col_boolean=True))
269 |         self.read_write_test('simple_udt', UDT(col_text=None, col_int=1, col_boolean=True))
270 | 
271 |     def test_udt_wset(self):
272 |         self.read_write_test('udt_wset', UDT(col_text='text', col_set={1, 2, 3}))
273 | 
274 |     def test_collection_of_udts(self):
275 |         super(UDTTest, self).read_write_test('simple_udt_list', None)
276 | 
277 |         udts = [UDT(col_text='text ' + str(i), col_int=i, col_boolean=bool(i % 2)) for i in range(10)]
278 |         super(UDTTest, self).read_write_test('simple_udt_set', set(udts))
279 |         super(UDTTest, self).read_write_test('simple_udt_list', udts)
280 | 
281 |         udts = [UDT(col_text='text ' + str(i), col_int=i, col_boolean=None) for i in range(10)]
282 |         super(UDTTest, self).read_write_test('simple_udt_set', set(udts))
283 |         super(UDTTest, self).read_write_test('simple_udt_list', udts)
284 | 
285 | 
286 | 
287 | class SelectiveSaveTest(SimpleTypesTestBase):
288 |     def _save_and_get(self, *row):
289 |         columns = ['key', 'text']
290 |         self.sc.parallelize(row).saveToCassandra(self.keyspace, self.table, columns=columns)
291 |         rdd = self.rdd().select(*columns)
292 |         self.assertEqual(rdd.count(), 1)
293 |         return rdd.first()
294 | 
295 | 
296 |     def test_row(self):
297 |         row = Row(key='selective-save-test-row', int=2, text='a', boolean=False)
298 |         read = self._save_and_get(row)
299 | 
300 |         for k in ['key', 'text']:
301 |             self.assertEqual(getattr(row, k), getattr(read, k))
302 |         for k in ['boolean', 'int']:
303 |             self.assertIsNone(getattr(read, k, None))
304 | 
305 | 
306 |     def test_dict(self):
307 |         row = dict(key='selective-save-test-row', int=2, text='a', boolean=False)
308 |         read = self._save_and_get(row)
309 | 
310 |         for k in ['key', 'text']:
311 |             self.assertEqual(row[k], read[k])
312 |         for k in ['boolean', 'int']:
313 |             self.assertIsNone(getattr(read, k, None))
314 | 
315 | 
316 | 
317 | class LimitAndTakeTest(SimpleTypesTestBase):
318 |     size = 1000
319 | 
320 |     def setUp(self):
321 |         super(LimitAndTakeTest, self).setUp()
322 |         data = self.sc.parallelize(range(0, self.size)).map(lambda i: {'key':i, 'int':i})
323 |         data.saveToCassandra(self.keyspace, self.table)
324 | 
325 |     def test_limit(self):
326 |         data = self.rdd()
327 | 
328 |         for i in (5, 10, 100, 1000, 1500):
329 |             l = min(i, self.size)
330 |             self.assertEqual(len(data.take(i)), l)
331 | 
self.assertEqual(len(data.limit(i).collect()), l) 332 | self.assertEqual(len(data.limit(i * 2).take(i)), l) 333 | 334 | 335 | class FormatTest(SimpleTypesTestBase): 336 | expected = Row(key='format-test', int=2, text='a') 337 | 338 | def setUp(self): 339 | super(FormatTest, self).setUp() 340 | self.sc.parallelize([self.expected]).saveToCassandra(self.keyspace, self.table) 341 | 342 | def read_as(self, row_format, keyed): 343 | table = self.rdd(row_format=row_format) 344 | if keyed: 345 | table = table.by_primary_key() 346 | table = table.where('key=?', self.expected.key) 347 | return table.first() 348 | 349 | def assert_rowtype(self, row_format, row_type, keyed=False): 350 | row = self.read_as(row_format, keyed) 351 | self.assertEqual(type(row), row_type) 352 | return row 353 | 354 | def assert_kvtype(self, row_format, kv_type): 355 | row = self.assert_rowtype(row_format, tuple, keyed=True) 356 | self.assertEqual(len(row), 2) 357 | k, v = row 358 | self.assertEqual(type(k), kv_type) 359 | self.assertEqual(type(v), kv_type) 360 | return k, v 361 | 362 | def test_tuple(self): 363 | row = self.assert_rowtype(RowFormat.TUPLE, tuple) 364 | self.assertEqual(self.expected.key, row[0]) 365 | 366 | def test_kvtuple(self): 367 | k, _ = self.assert_kvtype(RowFormat.TUPLE, tuple) 368 | self.assertEqual(self.expected.key, k[0]) 369 | 370 | def test_dict(self): 371 | row = self.assert_rowtype(RowFormat.DICT, dict) 372 | self.assertEqual(self.expected.key, row['key']) 373 | 374 | def test_kvdict(self): 375 | k, _ = self.assert_kvtype(RowFormat.DICT, dict) 376 | self.assertEqual(self.expected.key, k['key']) 377 | 378 | def test_row(self): 379 | row = self.assert_rowtype(RowFormat.ROW, pyspark_cassandra.Row) 380 | self.assertEqual(self.expected.key, row.key) 381 | 382 | def test_kvrow(self): 383 | k, _ = self.assert_kvtype(RowFormat.ROW, pyspark_cassandra.Row) 384 | self.assertEqual(self.expected.key, k.key) 385 | 386 | 387 | 388 | class ConfTest(SimpleTypesTestBase): 389 | # TODO this is still a very basic test, more cases and (better) validation required 390 | def setUp(self): 391 | super(SimpleTypesTestBase, self).setUp() 392 | for i in range(100): 393 | self.session.execute( 394 | "INSERT INTO %s (key, text, int) values ('%s', '%s', %s)" 395 | % (self.table, i, i, i) 396 | ) 397 | 398 | def test_read_conf(self): 399 | self.rdd(split_count=100).collect() 400 | self.rdd(split_size=32).collect() 401 | self.rdd(fetch_size=100).collect() 402 | self.rdd(consistency_level='LOCAL_QUORUM').collect() 403 | self.rdd(consistency_level=ConsistencyLevel.LOCAL_QUORUM).collect() 404 | self.rdd(metrics_enabled=True).collect() 405 | self.rdd(read_conf=ReadConf(split_count=10, consistency_level='ALL')).collect() 406 | self.rdd(read_conf=ReadConf(consistency_level='ALL', metrics_enabled=True)).collect() 407 | 408 | def test_write_conf(self): 409 | rdd = self.sc.parallelize([{'key':i, 'text':i, 'int':i} for i in range(10)]) 410 | save = partial(rdd.saveToCassandra, self.keyspace, self.table) 411 | 412 | save(batch_size=100) 413 | save(batch_buffer_size=100) 414 | save(batch_grouping_key='replica_set') 415 | save(batch_grouping_key='partition') 416 | save(consistency_level='ALL') 417 | save(consistency_level=ConsistencyLevel.LOCAL_QUORUM) 418 | save(parallelism_level=10) 419 | save(throughput_mibps=10) 420 | save(ttl=5) 421 | save(ttl=timedelta(minutes=30)) 422 | save(timestamp=time.clock() * 1000 * 1000) 423 | save(timestamp=datetime.now()) 424 | save(metrics_enabled=True) 425 | save(write_conf=WriteConf(ttl=3, 
metrics_enabled=True)) 426 | 427 | 428 | class StreamingTest(SimpleTypesTestBase): 429 | interval = .1 430 | 431 | size = 10 432 | count = 3 433 | 434 | rows = [ 435 | [ 436 | {'key': str(j * size + i), 'text': str(j * size + i)} 437 | for i in range(size) 438 | ] 439 | for j in range(count) 440 | ] 441 | 442 | @classmethod 443 | def setUpClass(cls): 444 | super(StreamingTest, cls).setUpClass() 445 | cls.ssc = StreamingContext(cls.sc, cls.interval) 446 | 447 | def setUp(self): 448 | super(StreamingTest, self).setUp() 449 | self.rdds = list(map(self.sc.parallelize, self.rows)) 450 | self.stream = self.ssc.queueStream(self.rdds) 451 | 452 | def test(self): 453 | self.stream.saveToCassandra(self.keyspace, self.table) 454 | 455 | self.ssc.start() 456 | self.ssc.awaitTermination((self.count + 1) * self.interval) 457 | self.ssc.stop(stopSparkContext=False, stopGraceFully=True) 458 | 459 | tbl = self.rdd(row_format=RowFormat.TUPLE).select('key', 'text') 460 | read = tbl.by_primary_key().collect() 461 | self.assertEqual(len(read), self.size * self.count) 462 | for (k, v) in read: 463 | self.assertEqual(k, v) 464 | 465 | 466 | class JoinRDDTest(SimpleTypesTestBase): 467 | 468 | def setUp(self): 469 | super(JoinRDDTest, self).setUp() 470 | 471 | def test_simple_pk(self): 472 | table = 'join_rdd_test_simple_pk' 473 | 474 | self.session.execute(''' 475 | CREATE TABLE IF NOT EXISTS ''' + table + ''' ( 476 | key text primary key, value text 477 | ) 478 | ''') 479 | self.session.execute('TRUNCATE %s' % table) 480 | 481 | rows = { 482 | str(c) : str(i) for i, c in 483 | enumerate(string.ascii_lowercase) 484 | } 485 | 486 | for k, v in rows.items(): 487 | self.session.execute( 488 | 'INSERT INTO ' + table + ' (key, value) values (%s, %s)', (k, v) 489 | ) 490 | 491 | rdd = self.sc.parallelize(rows.items()) 492 | self.assertEqual(dict(rdd.collect()), rows) 493 | 494 | tbl = rdd.joinWithCassandraTable(self.keyspace, table) 495 | joined = tbl.on('key').select('key', 'value').cache() 496 | self.assertEqual(dict(joined.keys().collect()), dict(joined.values().collect())) 497 | for (k, v) in joined.collect(): 498 | self.assertEqual(k, v) 499 | 500 | 501 | def test_composite_pk(self): 502 | table = 'join_rdd_test_composite_pk' 503 | 504 | self.session.execute(''' 505 | CREATE TABLE IF NOT EXISTS ''' + table + ''' ( 506 | pk text, cc text, value text, 507 | primary key (pk, cc) 508 | ) 509 | ''') 510 | self.session.execute('TRUNCATE %s' % table) 511 | 512 | rows = [ 513 | # (pk, cc, pk + '-' + cc) 514 | (unicode(pk), unicode(cc), unicode(pk + '-' + cc)) 515 | for pk in string.ascii_lowercase[:3] 516 | for cc in (str(i) for i in range(3)) 517 | ] 518 | 519 | for row in rows: 520 | self.session.execute( 521 | 'INSERT INTO ' + table + ' (pk, cc, value) values (%s, %s, %s)', row 522 | ) 523 | 524 | rdd = self.sc.parallelize(rows) 525 | 526 | joined = rdd.joinWithCassandraTable(self.keyspace, table).on('pk', 'cc') 527 | self.assertEqual(sorted(zip(rows, rows)), sorted(joined.map(tuple).collect())) 528 | 529 | joined = rdd.joinWithCassandraTable(self.keyspace, table).on('pk') 530 | self.assertEqual(len(rows) * sqrt(len(rows)), joined.count()) 531 | 532 | 533 | # TODO test 534 | # .where() 535 | # .limit() 536 | 537 | 538 | 539 | 540 | class JoinDStreamTest(StreamingTest): 541 | def setUp(self): 542 | super(JoinDStreamTest, self).setUp() 543 | self.joined_rows = self.sc.accumulator([], accum_param=AddingAccumulatorParam([])) 544 | 545 | def checkRDD(self, time, rdd): 546 | self.joined_rows += rdd.collect() 547 | 548 | def 
test(self):
549 |         rows = list(chain(*self.rows))
550 |         rows_by_key = {row['key'] : row for row in rows}
551 | 
552 |         self.sc \
553 |             .parallelize(rows) \
554 |             .saveToCassandra(self.keyspace, self.table)
555 | 
556 |         self.stream \
557 |             .joinWithCassandraTable(self.keyspace, self.table, ['text'], ['key']) \
558 |             .foreachRDD(self.checkRDD)
559 | 
560 |         self.ssc.start()
561 |         self.ssc.awaitTermination((self.count + 1) * self.interval)
562 |         self.ssc.stop(stopSparkContext=False, stopGraceFully=True)
563 | 
564 |         joined_rows = self.joined_rows.value
565 |         self.assertEqual(len(joined_rows), len(rows))
566 |         for row in joined_rows:
567 |             self.assertEqual(len(row), 2)
568 |             left, right = row
569 | 
570 |             self.assertEqual(type(left), type(right))
571 |             self.assertEqual(rows_by_key[left['key']], left)
572 |             self.assertEqual(left['text'], right['text'])
573 |             self.assertEqual(len(right), 1)
574 | 
575 | 
576 | 
577 | class RegressionTest(CassandraTestCase):
578 |     def test_64(self):
579 |         self.session.execute('''
580 |             CREATE TABLE IF NOT EXISTS test_64 (
581 |                 delay double PRIMARY KEY,
582 |                 pdf list<double>,
583 |                 pos list<double>
584 |             )
585 |         ''')
586 |         self.session.execute('''TRUNCATE test_64''')
587 | 
588 |         res = ([0.0, 1.0, 2.0], [12.0, 3.0, 0.0], 0.0)
589 |         rdd = self.sc.parallelize([res])
590 |         rdd.saveToCassandra(self.keyspace, 'test_64', columns=['pos', 'pdf', 'delay'])
591 | 
592 |         row = self.rdd(table='test_64').first()
593 |         self.assertEqual(row.pos, res[0])
594 |         self.assertEqual(row.pdf, res[1])
595 |         self.assertEqual(row.delay, res[2])
596 | 
597 |     def test_89(self):
598 |         self.session.execute('''
599 |             CREATE TABLE IF NOT EXISTS test_89 (
600 |                 id text PRIMARY KEY,
601 |                 val text
602 |             )
603 |         ''')
604 |         self.session.execute('''TRUNCATE test_89''')
605 | 
606 |         self.sc.parallelize([dict(id='a', val='b')]).saveToCassandra(self.keyspace, 'test_89')
607 |         joined = (self.sc
608 |             .parallelize([dict(id='a', uuid=UUID('27776620-e46e-11e5-a837-0800200c9a66'))])
609 |             .joinWithCassandraTable(self.keyspace, 'test_89')
610 |             .collect()
611 |         )
612 | 
613 |         self.assertEqual(len(joined), 1)
614 |         self.assertEqual(len(joined[0]), 2)
615 |         left, right = joined[0]
616 |         self.assertEqual(left['id'], 'a')
617 |         self.assertEqual(left['uuid'], UUID('27776620-e46e-11e5-a837-0800200c9a66'))
618 |         self.assertEqual(right['id'], 'a')
619 |         self.assertEqual(right['val'], 'b')
620 | 
621 |     def test_93(self):
622 |         self.session.execute('''
623 |             CREATE TABLE IF NOT EXISTS test_93 (
624 |                 name text,
625 |                 data_final blob,
626 |                 data_inter blob,
627 |                 family_label text,
628 |                 rand double,
629 |                 source text,
630 |                 score float,
631 |                 PRIMARY KEY (name)
632 |             )
633 |         ''')
634 | 
635 |         self.sc.parallelize([
636 |             Row(name=str(i), data_final=bytearray(str(i)), data_inter=bytearray(str(i)),
637 |                 family_label=str(i), rand=i / 10, source=str(i), score=i * 10)
638 |             for i in range(4)
639 |         ]).saveToCassandra(self.keyspace, 'test_93')
640 | 
641 |         joined = (self.sc
642 |             .parallelize([
643 |                 Row(name='1', score=0.4),
644 |                 Row(name='2', score=0.5),
645 |             ])
646 |             .joinWithCassandraTable(self.keyspace, 'test_93')
647 |             .on('name').collect()
648 |         )
649 | 
650 |         self.assertEqual(len(joined), 2)
651 | 
652 | 
653 | 
654 | 
655 | if __name__ == '__main__':
656 |     try:
657 |         # connect to cassandra and create a keyspace for testing
658 |         CassandraTestCase.session = Cluster().connect()
659 |         CassandraTestCase.session.execute('''
660 |             CREATE KEYSPACE IF NOT EXISTS %s
661 |             WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1};
662 |         ''' %
(CassandraTestCase.keyspace,)) 663 | CassandraTestCase.session.set_keyspace(CassandraTestCase.keyspace) 664 | 665 | # create a cassandra spark context 666 | CassandraTestCase.sc = CassandraSparkContext(conf=SparkConf().setAppName("PySpark Cassandra Test")) 667 | 668 | # perform the unit tests 669 | unittest.main() 670 | # suite = unittest.TestLoader().loadTestsFromTestCase(RegressionTest) 671 | # unittest.TextTestRunner().run(suite) 672 | finally: 673 | # stop the spark context and cassandra session 674 | # stop the spark context and cassandra session 675 | if hasattr(CassandraTestCase, 'sc'): 676 | CassandraTestCase.sc.stop() 677 | if hasattr(CassandraTestCase, 'session'): 678 | CassandraTestCase.session.shutdown() 679 | 680 | -------------------------------------------------------------------------------- /python/pyspark_cassandra/types.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | from datetime import datetime, tzinfo, timedelta 14 | from itertools import chain 15 | from operator import itemgetter 16 | import struct 17 | 18 | 19 | try: 20 | # import accessed as globals, see _create_spanning_dataframe(...) 21 | import numpy as np # @UnusedImport 22 | import pandas as pd # @UnusedImport 23 | except ImportError: 24 | pass 25 | 26 | 27 | 28 | def _create_row(fields, values): 29 | return _create_struct(Row, fields, values) 30 | 31 | def _create_udt(fields, values): 32 | return _create_struct(UDT, fields, values) 33 | 34 | def _create_struct(cls, fields, values): 35 | d = {k: v for k, v in zip(fields, values)} 36 | return cls(**d) 37 | 38 | 39 | # TODO replace this datastructure with something faster 40 | # but functionally equivalent! 
41 | # 42 | # >>> %timeit dict(x=1, y=2, z=3) 43 | # 1000000 loops, best of 3: 292 ns per loop 44 | # 45 | # >>> %timeit pyspark_cassandra.Row(x=1, y=2, z=3) 46 | # 1000000 loops, best of 3: 1.41 µs per loop 47 | # 48 | # >>> %timeit FastRow(x=1, y=2, z=3) 49 | # 1000000 loops, best of 3: 666 ns per loop 50 | # 51 | # where FastRow = namedtuple('FastRow', ['x', 'y', 'z']) 52 | 53 | 54 | class Struct(tuple): 55 | """Adaptation from the pyspark.sql.Row which better supports adding fields""" 56 | 57 | def __new__(cls, **kwargs): 58 | if not kwargs: 59 | raise ValueError("Cannot construct empty %s" % cls) 60 | 61 | struct = tuple.__new__(cls) 62 | struct.__FIELDS__ = kwargs 63 | return struct 64 | 65 | 66 | def asDict(self): 67 | return self.__dict__() 68 | 69 | def __dict__(self): 70 | return self.__FIELDS__ 71 | 72 | def __iter__(self): 73 | return iter(self.__FIELDS__.values()) 74 | 75 | @property 76 | def _fields(self): 77 | return self.keys() 78 | 79 | def keys(self): 80 | return self.__FIELDS__.keys() 81 | 82 | def values(self): 83 | return self.__FIELDS__.values() 84 | 85 | 86 | def __len__(self): 87 | return len(self.__FIELDS__) 88 | 89 | def __hash__(self): 90 | h = 1 91 | for v in chain(self.keys(), self.values()): 92 | h = 31 * h + hash(v) 93 | return h 94 | 95 | def __eq__(self, other): 96 | try: 97 | return self.__FIELDS__.__eq__(other.__FIELDS__) 98 | except AttributeError: 99 | return False 100 | 101 | def __ne__(self, other): 102 | return not self == other 103 | 104 | 105 | def __add__(self, other): 106 | d = dict(self.__FIELDS__) 107 | d.update(other.__FIELDS__) 108 | return self.__class__(**d) 109 | 110 | def __sub__(self, other): 111 | d = { k:v for k, v in self.__FIELDS__.items() if k not in other } 112 | return self.__class__(**d) 113 | 114 | def __and__(self, other): 115 | d = { k:v for k, v in self.__FIELDS__.items() if k in other } 116 | return self.__class__(**d) 117 | 118 | 119 | def __contains__(self, name): 120 | return name in self.__FIELDS__ 121 | 122 | 123 | def __setitem__(self, name, value): 124 | self.__setattr__(name, value) 125 | 126 | def __delitem__(self, name): 127 | self.__delattr__(name) 128 | 129 | def __getitem__(self, name): 130 | return self.__getattr__(name) 131 | 132 | 133 | def __getattr__(self, name): 134 | try: 135 | return self.__FIELDS__[name] 136 | except KeyError: 137 | raise AttributeError(name) 138 | 139 | def __setattr__(self, name, value): 140 | if name == "__FIELDS__": 141 | tuple.__setattr__(self, name, value) 142 | else: 143 | self.__FIELDS__[name] = value 144 | 145 | def __delattr__(self, name): 146 | try: 147 | del self.__FIELDS__[name] 148 | except KeyError: 149 | raise AttributeError(name) 150 | 151 | 152 | def __getstate__(self): 153 | return self.__dict__() 154 | 155 | def __reduce__(self): 156 | keys = list(self.__FIELDS__.keys()) 157 | values = list(self.__FIELDS__.values()) 158 | return (self._creator(), (keys, values,)) 159 | 160 | 161 | def __repr__(self): 162 | fields = sorted(self.__FIELDS__.items(), key=itemgetter(0)) 163 | values = ", ".join("%s=%r" % (k, v) for k, v in fields if k != '__FIELDS__') 164 | return "%s(%s)" % (self.__class__.__name__, values) 165 | 166 | 167 | 168 | class Row(Struct): 169 | def _creator(self): 170 | return _create_row 171 | 172 | class UDT(Struct): 173 | def _creator(self): 174 | return _create_udt 175 | 176 | 177 | 178 | def _create_spanning_dataframe(cnames, ctypes, cvalues): 179 | ''' 180 | Constructs a 'dataframe' from column names, numpy column types and 181 | the column values. 
182 | 183 | @param cnames: An iterable of name strings 184 | @param ctypes: An iterable of numpy dtypes as strings (e.g. '>f4') 185 | @param cvalues: An iterable of 186 | 187 | Note that cnames, ctypes and cvalues are expected to have equal length. 188 | ''' 189 | 190 | if len(cnames) != len(ctypes) or len(ctypes) != len(cvalues): 191 | raise ValueError('The lengths of cnames, ctypes and cvalues must equal') 192 | 193 | # convert the column values to numpy arrays if numpy is available 194 | # otherwise use lists 195 | global np 196 | convert = _to_nparrays if np else _to_list 197 | arrays = {n : convert(t, v) for n, t, v in zip(cnames, ctypes, cvalues)} 198 | 199 | # if pandas is available, provide the arrays / lists as DataFrame 200 | # otherwise use pyspark_cassandra.Row 201 | global pd 202 | if pd: 203 | return pd.DataFrame(arrays) 204 | else: 205 | return Row(**arrays) 206 | 207 | 208 | def _to_nparrays(ctype, cvalue): 209 | if isinstance(cvalue, (bytes, bytearray)): 210 | # The array is byte swapped and set to little-endian. java encodes 211 | # ints, longs, floats, etc. in big-endian. 212 | # This costs some cycles (around 1 ms per 1*10^6 elements) but when 213 | # using it it saves some when using the array (around 25 to 50 % which 214 | # for summing amounts to half a ms) 215 | # (the perf numbers above are on an Intel i5-4200M) 216 | # Also it solves an issue with pickling datetime64 arrays see 217 | # https://github.com/numpy/numpy/issues/5883 218 | return np.frombuffer(cvalue, ctype).byteswap(True).newbyteorder('<') 219 | else: 220 | return np.fromiter(cvalue, None) 221 | 222 | 223 | def _to_list(ctype, cvalue): 224 | if isinstance(cvalue, (bytes, bytearray)): 225 | return _decode_primitives(ctype, cvalue) 226 | elif hasattr(cvalue, '__len__'): 227 | return cvalue 228 | else: 229 | return list(cvalue) 230 | 231 | # from https://docs.python.org/3/library/datetime.html 232 | ZERO = timedelta(0) 233 | 234 | class UTC(tzinfo): 235 | def utcoffset(self, dt): 236 | return ZERO 237 | 238 | def tzname(self, dt): 239 | return "UTC" 240 | 241 | def dst(self, dt): 242 | return ZERO 243 | 244 | def __repr__(self): 245 | return self.__class__.__name__ 246 | 247 | utc = UTC() 248 | 249 | 250 | _numpy_to_struct_formats = { 251 | '>b1': '?', 252 | 'i4': '>i', 253 | '>i8': '>q', 254 | '>f4': '>f', 255 | '>f8': '>d', 256 | '>M8[ms]': '>q', 257 | } 258 | 259 | def _decode_primitives(ctype, cvalue): 260 | fmt = _numpy_to_struct_formats.get(ctype) 261 | 262 | # if unsupported, return as the list if bytes it was 263 | if not fmt: 264 | return cvalue 265 | 266 | primitives = _unpack(fmt, cvalue) 267 | 268 | if ctype == '>M8[ms]': 269 | return [datetime.utcfromtimestamp(l).replace(tzinfo=UTC) for l in primitives] 270 | else: 271 | return primitives 272 | 273 | 274 | def _unpack(fmt, cvalue): 275 | stride = struct.calcsize(fmt) 276 | if len(cvalue) % stride != 0: 277 | raise ValueError('number of bytes must be a multiple of %s for format %s' % (stride, fmt)) 278 | 279 | return [struct.unpack(cvalue[o:o + stride]) for o in range(len(cvalue) / stride, stride)] 280 | 281 | -------------------------------------------------------------------------------- /python/pyspark_cassandra/util.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 
3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | 14 | 15 | from collections import Set, Iterable, Mapping 16 | from datetime import datetime 17 | from time import mktime 18 | 19 | from pyspark_cassandra.types import UDT 20 | 21 | 22 | def as_java_array(gateway, java_type, iterable): 23 | """Creates a Java array from a Python iterable, using the given p4yj gateway""" 24 | 25 | if iterable is None: 26 | return None 27 | 28 | java_type = gateway.jvm.__getattr__(java_type) 29 | lst = list(iterable) 30 | arr = gateway.new_array(java_type, len(lst)) 31 | 32 | for i, e in enumerate(lst): 33 | jobj = as_java_object(gateway, e) 34 | arr[i] = jobj 35 | 36 | return arr 37 | 38 | 39 | def as_java_object(gateway, obj): 40 | """ 41 | Converts a limited set of types to their corresponding types in java. Supported are 42 | 'primitives' (which aren't converted), datetime.datetime and the set-, dict- and 43 | iterable-like types. 44 | """ 45 | 46 | if obj is None: 47 | return None 48 | 49 | t = type(obj) 50 | 51 | if issubclass(t, (bool, int, float, str)): 52 | return obj 53 | 54 | elif issubclass(t, UDT): 55 | field_names = as_java_array(gateway, "String", obj.keys()) 56 | field_values = as_java_array(gateway, "Object", obj.values()) 57 | udt = gateway.jvm.UDTValueConverter(field_names, field_values) 58 | return udt.toConnectorType() 59 | 60 | elif issubclass(t, datetime): 61 | timestamp = int(mktime(obj.timetuple()) * 1000) 62 | return gateway.jvm.java.util.Date(timestamp) 63 | 64 | elif issubclass(t, (dict, Mapping)): 65 | hash_map = gateway.jvm.java.util.HashMap() 66 | for (k, v) in obj.items(): hash_map[k] = v 67 | return hash_map 68 | 69 | elif issubclass(t, (set, Set)): 70 | hash_set = gateway.jvm.java.util.HashSet() 71 | for e in obj: hash_set.add(e) 72 | return hash_set 73 | 74 | elif issubclass(t, (list, Iterable)): 75 | array_list = gateway.jvm.java.util.ArrayList() 76 | for e in obj: array_list.append(e) 77 | return array_list 78 | 79 | else: 80 | return obj 81 | 82 | 83 | def load_class(ctx, name): 84 | return ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ 85 | .loadClass(name) 86 | 87 | _helper = None 88 | 89 | def helper(ctx): 90 | global _helper 91 | 92 | if not _helper: 93 | _helper = load_class(ctx, "pyspark_cassandra.PythonHelper").newInstance() 94 | 95 | return _helper 96 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | 17 | from setuptools import setup, find_packages 18 | 19 | 20 | basedir = os.path.dirname(os.path.abspath(__file__)) 21 | os.chdir(basedir) 22 | 23 | def f(*path): 24 | return open(os.path.join(basedir, *path)) 25 | 26 | setup( 27 | name='pyspark_cassandra', 28 | maintainer='Frens Jan Rumph', 29 | maintainer_email='frens.jan.rumph@target-holding.nl', 30 | version=f('../version.txt').read().strip(), 31 | description='Utilities to asssist in working with Cassandra and PySpark.', 32 | long_description=f('../README.md').read(), 33 | url='https://github.com/TargetHolding/pyspark-cassandra', 34 | license='Apache License 2.0', 35 | 36 | packages=find_packages(), 37 | include_package_data=True, 38 | 39 | classifiers=[ 40 | 'Development Status :: 2 - Pre-Alpha', 41 | 'Environment :: Other Environment', 42 | 'Framework :: Django', 43 | 'Intended Audience :: Developers', 44 | 'License :: OSI Approved :: Apache Software License', 45 | 'Operating System :: OS Independent', 46 | 'Programming Language :: Python', 47 | 'Programming Language :: Python :: 2', 48 | 'Programming Language :: Python :: 2.7', 49 | 'Topic :: Database', 50 | 'Topic :: Software Development :: Libraries', 51 | 'Topic :: Scientific/Engineering :: Information Analysis', 52 | 'Topic :: Utilities', 53 | ] 54 | ) 55 | -------------------------------------------------------------------------------- /sbin/local.sh: -------------------------------------------------------------------------------- 1 | DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )/.. 2 | 3 | VERSION=`cat version.txt` 4 | 5 | PYSPARK_DRIVER_PYTHON=ipython \ 6 | $DIR/lib/spark-1.6.0-bin-hadoop2.6/bin/pyspark \ 7 | --conf spark.cassandra.connection.host="localhost" \ 8 | --driver-memory 2g \ 9 | --master local[*] \ 10 | --jars $DIR/target/scala-2.10/pyspark-cassandra-assembly-$VERSION.jar \ 11 | --py-files $DIR/target/scala-2.10/pyspark-cassandra-assembly-$VERSION.jar \ 12 | $@ 13 | 14 | -------------------------------------------------------------------------------- /sbin/notebook.sh: -------------------------------------------------------------------------------- 1 | DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )/.. 2 | 3 | VERSION=`cat version.txt` 4 | 5 | IPYTHON_OPTS=notebook \ 6 | PYSPARK_DRIVER_PYTHON=ipython \ 7 | $DIR/lib/spark-1.6.0-bin-hadoop2.6/bin/pyspark \ 8 | --conf spark.cassandra.connection.host="localhost" \ 9 | --driver-memory 2g \ 10 | --master local[*] \ 11 | --jars $DIR/target/scala-2.10/pyspark-cassandra-assembly-$VERSION.jar \ 12 | --py-files $DIR/target/scala-2.10/pyspark-cassandra-assembly-$VERSION.jar \ 13 | $@ 14 | 15 | -------------------------------------------------------------------------------- /sbin/profile.sh: -------------------------------------------------------------------------------- 1 | DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) 2 | 3 | $DIR/pyspark_cassandra.sh --conf spark.python.profile=true $@ 4 | 5 | -------------------------------------------------------------------------------- /sbin/released.sh: -------------------------------------------------------------------------------- 1 | #export IPYTHON_OPTS="notebook" 2 | 3 | DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )/.. 
4 | 5 | PYSPARK_DRIVER_PYTHON=ipython \ 6 | $DIR/lib/spark-1.4.1-bin-hadoop2.6/bin/pyspark \ 7 | --conf spark.cassandra.connection.host="localhost" \ 8 | --driver-memory 2g \ 9 | --master local[*] \ 10 | --packages TargetHolding/pyspark-cassandra:0.2.2 \ 11 | $@ 12 | 13 | -------------------------------------------------------------------------------- /src/main/scala/pyspark_cassandra/Pickling.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | package pyspark_cassandra 16 | 17 | import pyspark_util.Conversions._ 18 | import pyspark_util.{ Pickling => PicklingUtils, _ } 19 | import java.io.OutputStream 20 | import java.math.BigInteger 21 | import java.net.{ Inet4Address, Inet6Address, InetAddress } 22 | import java.nio.ByteBuffer 23 | import java.nio.channels.Channels 24 | import java.util.{ ArrayList, Collection, HashMap, List => JList, Map => JMap, UUID } 25 | import scala.reflect.ClassTag 26 | import scala.collection.JavaConversions._ 27 | import scala.collection.immutable.HashMap.HashTrieMap 28 | import scala.collection.immutable.List 29 | import scala.collection.immutable.Map.{ Map1, Map2, Map3, Map4, WithDefault } 30 | import scala.collection.mutable.{ ArraySeq, Buffer, WrappedArray } 31 | import scala.reflect.runtime.universe.typeTag 32 | import com.datastax.driver.core.{ ProtocolVersion, UDTValue => DriverUDTValue } 33 | import com.datastax.spark.connector.UDTValue 34 | import com.datastax.spark.connector.types.TypeConverter 35 | import net.razorvine.pickle.{ IObjectConstructor, IObjectPickler, Opcodes, PickleUtils, Pickler, Unpickler } 36 | import org.apache.spark.rdd.RDD 37 | import org.apache.spark.streaming.dstream.DStream 38 | import java.io.NotSerializableException 39 | import com.datastax.spark.connector.GettableData 40 | 41 | class Pickling extends PicklingUtils { 42 | override def register() { 43 | super.register() 44 | 45 | Unpickler.registerConstructor("pyspark.sql", "_create_row", PlainRowUnpickler) 46 | Unpickler.registerConstructor("pyspark_cassandra.types", "_create_row", PlainRowUnpickler) 47 | Unpickler.registerConstructor("pyspark_cassandra.types", "_create_udt", UDTValueUnpickler) 48 | 49 | Pickler.registerCustomPickler(classOf[Row], PlainRowPickler) 50 | Pickler.registerCustomPickler(classOf[UDTValue], UDTValuePickler) 51 | Pickler.registerCustomPickler(classOf[DriverUDTValue], DriverUDTValuePickler) 52 | Pickler.registerCustomPickler(classOf[DataFrame], DataFramePickler) 53 | } 54 | } 55 | 56 | object PlainRowPickler extends StructPickler { 57 | def creator = "pyspark_cassandra.types\n_create_row\n" 58 | def fields(o: Any) = o.asInstanceOf[Row].fields 59 | def values(o: Any, fields: Seq[_]) = o.asInstanceOf[Row].values 60 | } 61 | 62 | object PlainRowUnpickler extends StructUnpickler { 63 | def construct(fields: Seq[String], values: Seq[AnyRef]) = Row(fields, values) 64 | } 65 | 66 | object UDTValuePickler extends 
StructPickler { 67 | def creator = "pyspark_cassandra.types\n_create_udt\n" 68 | def fields(o: Any) = o.asInstanceOf[UDTValue].columnNames 69 | def values(o: Any, fields: Seq[_]) = o.asInstanceOf[UDTValue].columnValues 70 | } 71 | 72 | object DriverUDTValuePickler extends StructPickler { 73 | def creator = "pyspark_cassandra.types\n_create_udt\n" 74 | 75 | def fields(o: Any) = o.asInstanceOf[DriverUDTValue].getType().getFieldNames().toSeq 76 | 77 | def values(o: Any, fields: Seq[_]) = { 78 | val v = o.asInstanceOf[DriverUDTValue] 79 | v.getType().map { 80 | field => v.getObject(field.getName) 81 | }.toList 82 | } 83 | } 84 | 85 | object UDTValueUnpickler extends StructUnpickler { 86 | def construct(fields: Seq[String], values: Seq[AnyRef]) = { 87 | val f = asArray[String](fields) 88 | val v = asArray[AnyRef](values) 89 | UDTValue(f, v) 90 | } 91 | } 92 | 93 | object DataFramePickler extends IObjectPickler { 94 | def pickle(o: Any, out: OutputStream, pickler: Pickler): Unit = { 95 | val df = o.asInstanceOf[DataFrame] 96 | 97 | val columns = df.values.map { 98 | v => 99 | v(0) match { 100 | case c: ByteBuffer => new GatheredByteBuffers(v.asInstanceOf[List[ByteBuffer]]) 101 | case c => c 102 | } 103 | } 104 | 105 | out.write(Opcodes.GLOBAL) 106 | out.write("pyspark_cassandra.types\n_create_spanning_dataframe\n".getBytes()) 107 | out.write(Opcodes.MARK) 108 | pickler.save(df.names) 109 | pickler.save(df.types) 110 | pickler.save(columns) 111 | out.write(Opcodes.TUPLE) 112 | out.write(Opcodes.REDUCE) 113 | } 114 | } 115 | 116 | object UnpickledUUIDConverter extends TypeConverter[UUID] { 117 | val tt = typeTag[UUID] 118 | def targetTypeTag = tt 119 | def convertPF = { case holder: UUIDHolder => holder.uuid } 120 | } 121 | -------------------------------------------------------------------------------- /src/main/scala/pyspark_cassandra/PythonHelper.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 
13 | */ 14 | 15 | package pyspark_cassandra 16 | 17 | import pyspark_cassandra.Utils._ 18 | import pyspark_util.Pickling._ 19 | import pyspark_util.Conversions._ 20 | import java.lang.Boolean 21 | import java.util.{ List => JList, Map => JMap } 22 | import scala.collection.JavaConversions._ 23 | import org.apache.spark.SparkContext 24 | import org.apache.spark.api.java.{ JavaRDD, JavaSparkContext } 25 | import org.apache.spark.rdd.RDD 26 | import org.apache.spark.streaming.api.java.JavaDStream 27 | import org.apache.spark.streaming.dstream.DStream 28 | import com.datastax.driver.core.ConsistencyLevel 29 | import com.datastax.spark.connector._ 30 | import com.datastax.spark.connector.rdd._ 31 | import com.datastax.spark.connector.streaming.toDStreamFunctions 32 | import com.datastax.spark.connector.writer._ 33 | import com.datastax.spark.connector.types.TypeConverter 34 | 35 | class PythonHelper() { 36 | TypeConverter.registerConverter(UnpickledUUIDConverter) 37 | implicit val pickling = new Pickling() 38 | 39 | /* ----------------------------------------------------------------------- */ 40 | /* loading from cassandra ------------------------------------------------ */ 41 | /* ----------------------------------------------------------------------- */ 42 | 43 | def cassandraTable(jsc: JavaSparkContext, keyspace: String, table: String, readConf: JMap[String, Any]) = { 44 | val conf = parseReadConf(jsc.sc, Some(readConf)) 45 | implicit val rrf = new DeferringRowReaderFactory() 46 | jsc.sc.cassandraTable(keyspace, table).withReadConf(conf) 47 | } 48 | 49 | def select(rdd: CassandraRDD[UnreadRow], columns: Array[String]) = { 50 | rdd.select(columns.map { new ColumnName(_) }: _*) 51 | } 52 | 53 | def limit(rdd: CassandraRDD[UnreadRow], lim: Long) = { 54 | rdd.limit(lim) 55 | } 56 | 57 | def where(rdd: CassandraRDD[UnreadRow], cql: String, values: Array[Any]) = { 58 | rdd.where(cql, values: _*) 59 | } 60 | 61 | def cassandraCount(rdd: CassandraRDD[UnreadRow]) = rdd.cassandraCount() 62 | 63 | /* ----------------------------------------------------------------------- */ 64 | /* span by columns ------------------------------------------------------- */ 65 | /* ----------------------------------------------------------------------- */ 66 | 67 | def spanBy(rdd: RDD[UnreadRow], columns: Array[String]) = { 68 | SpanBy.binary(rdd, columns) 69 | } 70 | 71 | /* ----------------------------------------------------------------------- */ 72 | /* save to cassandra ----------------------------------------------------- */ 73 | /* ----------------------------------------------------------------------- */ 74 | 75 | /* rdds ------------------------------------------------------------------ */ 76 | 77 | def saveToCassandra(rdd: JavaRDD[Array[Byte]], keyspace: String, table: String, columns: Array[String], 78 | rowFormat: Integer, keyed: Boolean, writeConf: JMap[String, Any]) = { 79 | 80 | val selectedColumns = columnSelector(columns) 81 | val conf = parseWriteConf(Some(writeConf)) 82 | 83 | implicit val rwf = new GenericRowWriterFactory(Format(rowFormat), asBooleanOption(keyed)) 84 | rdd.rdd.unpickle().saveToCassandra(keyspace, table, selectedColumns, conf) 85 | } 86 | 87 | /* dstreams -------------------------------------------------------------- */ 88 | 89 | def saveToCassandra(dstream: JavaDStream[Array[Byte]], keyspace: String, table: String, columns: Array[String], 90 | rowFormat: Integer, keyed: Boolean, writeConf: JMap[String, Any]) = { 91 | 92 | val selectedColumns = columnSelector(columns) 93 | val conf = 
parseWriteConf(Some(writeConf)) 94 | 95 | implicit val rwf = new GenericRowWriterFactory(Format(rowFormat), asBooleanOption(keyed)) 96 | dstream.dstream.unpickle().saveToCassandra(keyspace, table, selectedColumns, conf) 97 | } 98 | 99 | /* ----------------------------------------------------------------------- */ 100 | /* join with cassandra tables -------------------------------------------- */ 101 | /* ----------------------------------------------------------------------- */ 102 | 103 | /* rdds ------------------------------------------------------------------ */ 104 | 105 | def joinWithCassandraTable(rdd: JavaRDD[Array[Byte]], keyspace: String, table: String): CassandraJoinRDD[Any, UnreadRow] = { 106 | implicit val rwf = new GenericRowWriterFactory(None, None) 107 | implicit val rrf = new DeferringRowReaderFactory() 108 | rdd.rdd.unpickle().joinWithCassandraTable(keyspace, table) 109 | } 110 | 111 | def on(rdd: CassandraJoinRDD[Any, UnreadRow], columns: Array[String]) = { 112 | rdd.on(columnSelector(columns, PartitionKeyColumns)) 113 | } 114 | 115 | /* dstreams -------------------------------------------------------------- */ 116 | 117 | def joinWithCassandraTable(dstream: JavaDStream[Array[Byte]], keyspace: String, table: String, 118 | selectedColumns: Array[String], joinColumns: Array[String]): DStream[(Any, UnreadRow)] = { 119 | val columns = columnSelector(selectedColumns) 120 | val joinOn = columnSelector(joinColumns, PartitionKeyColumns) 121 | implicit val rwf = new GenericRowWriterFactory(None, None) 122 | implicit val rrf = new DeferringRowReaderFactory() 123 | dstream.dstream.unpickle().joinWithCassandraTable(keyspace, table, columns, joinOn) 124 | } 125 | 126 | /* ----------------------------------------------------------------------- */ 127 | /* utilities for moving rdds and dstreams from and to pyspark ------------ */ 128 | /* ----------------------------------------------------------------------- */ 129 | 130 | def pickleRows(rdd: CassandraRDD[UnreadRow]): RDD[Array[Byte]] = 131 | pickleRows(rdd, null) 132 | 133 | def pickleRows(rdd: CassandraRDD[UnreadRow], rowFormat: Integer): RDD[Array[Byte]] = 134 | pickleRows(rdd, rowFormat, false) 135 | 136 | def pickleRows(rdd: CassandraRDD[UnreadRow], rowFormat: Integer, keyed: Boolean) = { 137 | // TODO implement keying by arbitrary columns, analogues to the spanBy(...) 
and spark-cassandra-connector 138 | val parser = Format.parser(Format(rowFormat), asBooleanOption(keyed)) 139 | rdd.map(parser).pickle() 140 | } 141 | 142 | def pickleRows(rdd: CassandraJoinRDD[Any, UnreadRow], rowFormat: Integer, keyed: Boolean): RDD[Array[Byte]] = 143 | pickleRows(rdd) 144 | 145 | def pickleRows(rdd: CassandraJoinRDD[Any, UnreadRow], rowFormat: Integer): RDD[Array[Byte]] = 146 | pickleRows(rdd) 147 | 148 | def pickleRows(rdd: CassandraJoinRDD[Any, UnreadRow]) = 149 | rdd.map(new JoinedRowTransformer()).pickle() 150 | 151 | def pickleRows(dstream: DStream[(Any, UnreadRow)]) = { 152 | dstream.map(new JoinedRowTransformer()).transform((rdd, time) => rdd.pickle()) 153 | } 154 | 155 | def javaRDD(rdd: RDD[_]) = JavaRDD.fromRDD(rdd) 156 | 157 | def javaDStream(dstream: DStream[_]) = JavaDStream.fromDStream(dstream) 158 | } 159 | -------------------------------------------------------------------------------- /src/main/scala/pyspark_cassandra/RowReaders.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | package pyspark_cassandra 16 | 17 | import com.datastax.driver.core.{ ProtocolVersion, Row => DriverRow } 18 | import com.datastax.spark.connector.ColumnRef 19 | import com.datastax.spark.connector.cql.TableDef 20 | import com.datastax.spark.connector.rdd.reader.{RowReader, RowReaderFactory} 21 | import com.datastax.spark.connector.GettableData 22 | 23 | /** A container for a 'raw' row from the java driver, to be deserialized. */ 24 | case class UnreadRow(row: DriverRow, columnNames: Array[String], table: TableDef) { 25 | def deserialize(c: String) = { 26 | if (row.isNull(c)) null else GettableData.get(row, c) 27 | } 28 | 29 | def deserialize(c: Int) = { 30 | if (row.isNull(c)) null else GettableData.get(row, c) 31 | } 32 | } 33 | 34 | class DeferringRowReader(table: TableDef, selectedColumns: IndexedSeq[ColumnRef]) 35 | extends RowReader[UnreadRow] { 36 | 37 | def targetClass = classOf[UnreadRow] 38 | 39 | override def neededColumns: Option[Seq[ColumnRef]] = None // TODO or selected columns? 
40 | 41 | override def read(row: DriverRow, columns: Array[String]): UnreadRow = { 42 | assert(row.getColumnDefinitions().size() >= columns.size, "Not enough columns available in row") 43 | UnreadRow(row, columns, table) 44 | } 45 | } 46 | 47 | class DeferringRowReaderFactory extends RowReaderFactory[UnreadRow] { 48 | def targetClass: Class[UnreadRow] = classOf[UnreadRow] 49 | 50 | def rowReader(table: TableDef, selectedColumns: IndexedSeq[ColumnRef]): RowReader[UnreadRow] = { 51 | new DeferringRowReader(table, selectedColumns) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/scala/pyspark_cassandra/RowTransformers.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | package pyspark_cassandra 16 | 17 | trait FromUnreadRow[T] extends (UnreadRow => T) with Serializable 18 | 19 | // TODO consider replacying array of Map[String, Object] with a real tuple 20 | // not just here by the way, but all over the place ... this is Scala! 21 | trait ToKV[KV] extends FromUnreadRow[Array[Any]] { 22 | def apply(row: UnreadRow): Array[Any] = { 23 | val k = transform(row, row.columnNames.intersect(row.table.primaryKey.map { _.columnName })) 24 | val v = transform(row, row.columnNames.intersect(row.table.regularColumns.map { _.columnName })) 25 | Array(k, v) 26 | } 27 | 28 | def transform(row: UnreadRow, columns: Array[String]): KV 29 | } 30 | 31 | // TODO why ship field names for every row? 
32 | case class Row(fields: Seq[String], values: Seq[AnyRef]) 33 | 34 | object ToRow extends FromUnreadRow[Row] { 35 | override def apply(row: UnreadRow): Row = { 36 | Row( 37 | row.columnNames, 38 | row.columnNames.map { 39 | c => row.deserialize(c) 40 | }) 41 | } 42 | } 43 | 44 | object ToKVRows extends ToKV[Row] { 45 | def transform(row: UnreadRow, columns: Array[String]): Row = { 46 | Row(columns, columns.map { c => row.deserialize(c) }) 47 | } 48 | } 49 | 50 | object ToTuple extends FromUnreadRow[Array[Any]] { 51 | def apply(row: UnreadRow): Array[Any] = { 52 | (row.columnNames.indices map { c => row.deserialize(c) }).toArray 53 | } 54 | } 55 | 56 | object ToKVTuple extends ToKV[Array[Any]] { 57 | def transform(row: UnreadRow, columns: Array[String]): Array[Any] = { 58 | columns.map { c => row.deserialize(c) } 59 | } 60 | } 61 | 62 | object ToDict extends FromUnreadRow[Map[String, Object]] { 63 | def apply(row: UnreadRow): Map[String, Object] = { 64 | Map(row.columnNames.zipWithIndex.map { case (c, i) => c -> row.deserialize(i) }: _*) 65 | } 66 | } 67 | 68 | object ToKVDicts extends ToKV[Map[String, Object]] { 69 | def transform(row: UnreadRow, columns: Array[String]): Map[String, Object] = { 70 | Map(columns.map { c => c -> row.deserialize(c) }: _*) 71 | } 72 | } 73 | 74 | class JoinedRowTransformer extends (((Any, UnreadRow)) => (Any, Any)) with Serializable { 75 | def apply(pair: (Any, UnreadRow)): (Any, Any) = { 76 | val parsed = Format.parser(pair._1).apply(pair._2) 77 | (pair._1, parsed) 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/main/scala/pyspark_cassandra/RowWriter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 
13 | */ 14 | 15 | package pyspark_cassandra 16 | 17 | import java.util.{ Map => JMap } 18 | 19 | import scala.collection.{ IndexedSeq, Seq } 20 | 21 | import com.datastax.spark.connector.ColumnRef 22 | import com.datastax.spark.connector.cql.TableDef 23 | import com.datastax.spark.connector.writer.{ RowWriter, RowWriterFactory } 24 | 25 | class GenericRowWriterFactory(format: Option[Format.Value], keyed: Option[Boolean]) extends RowWriterFactory[Any] { 26 | def rowWriter(table: TableDef, selectedColumns: IndexedSeq[ColumnRef]): RowWriter[Any] = { 27 | new GenericRowWriter(format, keyed, selectedColumns) 28 | } 29 | } 30 | 31 | class GenericRowWriter(format: Option[Format.Value], keyed: Option[Boolean], columns: IndexedSeq[ColumnRef]) extends RowWriter[Any] { 32 | val cNames = columns.map { _.columnName } 33 | val idxedCols = cNames.zipWithIndex 34 | 35 | def columnNames: Seq[String] = cNames 36 | def indexedColumns = idxedCols 37 | 38 | var fmt: Option[(Format.Value, Boolean)] = None 39 | 40 | def readColumnValues(row: Any, buffer: Array[Any]): Unit = { 41 | if (fmt.isEmpty) { 42 | fmt = Some((format, keyed) match { 43 | case (Some(x: Format.Value), Some(y: Boolean)) => (x, y) 44 | case _ => Format.detect(row) 45 | }) 46 | } 47 | 48 | fmt.get match { 49 | case (Format.TUPLE, false) => readAsTuple(row, buffer) 50 | case (Format.TUPLE, true) => readAsKVTuples(row, buffer) 51 | case (Format.DICT, false) => readAsDict(row, buffer) 52 | case (Format.DICT, true) => readAsKVDicts(row, buffer) 53 | case (Format.ROW, false) => readAsRow(row, buffer) 54 | case (Format.ROW, true) => readAsKVRows(row, buffer) 55 | case _ => throw new IllegalArgumentException("Unsupported or unknown cassandra row format") 56 | } 57 | } 58 | 59 | def readAsTuple(row: Any, buffer: Array[Any]) = { 60 | val v = row.asInstanceOf[Array[Any]] 61 | System.arraycopy(v, 0, buffer, 0, Math.min(v.length, buffer.length)); 62 | } 63 | 64 | def readAsKVTuples(row: Any, buffer: Array[Any]) = { 65 | val Array(k, v) = row.asInstanceOf[Array[Array[Any]]] 66 | 67 | val keyLength = Math.min(k.length, buffer.length); 68 | System.arraycopy(k, 0, buffer, 0, keyLength); 69 | System.arraycopy(v, 0, buffer, keyLength, Math.min(v.length, buffer.length - keyLength)); 70 | } 71 | 72 | def readAsDict(row: Any, buffer: Array[Any]) = { 73 | val v = row.asInstanceOf[JMap[Any, Any]] 74 | 75 | indexedColumns.foreach { 76 | case (s, i) => buffer(i) = v.get(s) 77 | } 78 | } 79 | 80 | def readAsKVDicts(row: Any, buffer: Array[Any]) = { 81 | val Array(k, v) = row.asInstanceOf[Array[JMap[Any, Any]]] 82 | 83 | indexedColumns.foreach { 84 | case (c, i) => { 85 | buffer(i) = if (v.containsKey(c)) v.get(c) else k.get(c) 86 | } 87 | } 88 | } 89 | 90 | def readAsRow(row: Any, buffer: Array[Any]): Unit = { 91 | val v = row.asInstanceOf[Row] 92 | readAsRow(v, buffer, 0) 93 | } 94 | 95 | def readAsKVRows(row: Any, buffer: Array[Any]) = { 96 | val Array(k, v) = row.asInstanceOf[Array[Row]] 97 | readAsRow(k, buffer, 0) 98 | readAsRow(v, buffer, k.fields.length) 99 | } 100 | 101 | private def readAsRow(row: Row, buffer: Array[Any], offset: Int): Unit = { 102 | row.fields.zipWithIndex.foreach { 103 | case (f, srcIdx) => 104 | val dstIdx = columnNames.indexOf(f) 105 | if (dstIdx >= 0) { 106 | buffer(dstIdx + offset) = row.values(srcIdx) 107 | } 108 | } 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /src/main/scala/pyspark_cassandra/SpanBy.scala: -------------------------------------------------------------------------------- 1 | /*
2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | package pyspark_cassandra 16 | 17 | import java.nio.ByteBuffer 18 | import scala.collection.JavaConversions.asScalaBuffer 19 | import scala.collection.mutable.ArrayBuffer 20 | import org.apache.spark.rdd.RDD 21 | import com.datastax.driver.core.{ DataType, ProtocolVersion } 22 | import com.datastax.spark.connector.toRDDFunctions 23 | import com.datastax.spark.connector.GettableData 24 | 25 | case class DataFrame(names: Array[String], types: Array[String], values: Seq[ArrayBuffer[Any]]) 26 | 27 | object SpanBy { 28 | def binary(rdd: RDD[UnreadRow], columns: Array[String]) = { 29 | // span by the given columns 30 | val spanned = rdd.spanBy { r => columns.map { c => r.row.getBytesUnsafe(c) } } 31 | 32 | // deserialize the spans 33 | spanned.map { 34 | case (k, rows) => { 35 | // get the columns for the data frame (i.e. excluding the span-by columns) 36 | val colDefs = rows.head.row.getColumnDefinitions.asList() 37 | val colTypesWithIdx = colDefs.map { 38 | d => d.getType 39 | }.zipWithIndex.filter { 40 | case (c, i) => columns.contains(c.getName) 41 | } 42 | 43 | // deserialize the spanning key 44 | val deserializedKey = k.zipWithIndex.map { 45 | case (bb, i) => GettableData.get(rows.head.row, i) 46 | } 47 | 48 | // transpose the rows into columns and 'deserialize' 49 | val df = colDefs.map { x => new ArrayBuffer[Any] } 50 | for { 51 | row <- rows 52 | (ct, i) <- colTypesWithIdx 53 | } { 54 | df(i) += deserialize(row, ct, i) 55 | } 56 | 57 | // list the numpy types of the columns in the span (i.e. the non-key columns) 58 | val numpyTypes = colTypesWithIdx.map { case (c, i) => numpyType(c).getOrElse(null) } 59 | 60 | // return the key and 'dataframe container' 61 | (deserializedKey, new DataFrame(columns, numpyTypes.toArray, df)) 62 | } 63 | } 64 | } 65 | 66 | /** 67 | * 'Deserializes' a value of the given type _only if_ there is no binary representation which can be 68 | * converted into a numpy array. I.e. longs will _not_ actually be deserialized, but Strings or UUIDs 69 | * will. Where possible the value is kept as a binary string so that an entire column can be converted 70 | * to a Numpy array. 71 | */ 72 | private def deserialize(row: UnreadRow, dt: DataType, i: Int) = { 73 | if (binarySupport(dt)) 74 | row.row.getBytesUnsafe(i) 75 | else 76 | GettableData.get(row.row, i) 77 | } 78 | 79 | /** Checks if a Cassandra type can be represented as a binary string. */ 80 | private def binarySupport(dataType: DataType) = { 81 | numpyType(dataType) match { 82 | case Some(x) => true 83 | case None => false 84 | } 85 | } 86 | 87 | /** Provides a Numpy type string for every Cassandra type supported.
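Types without a numpy equivalent map to None (via Option(null)), which binarySupport above treats as 'not representable as a binary string'. The '>' prefix marks big-endian byte order, which matches the Cassandra wire format for these types, and '>M8[ms]' is a millisecond-resolution numpy datetime64.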
*/ 88 | private def numpyType(dataType: DataType) = { 89 | Option(dataType.getName match { 90 | case DataType.Name.BOOLEAN => ">b1" 91 | case DataType.Name.INT => ">i4" 92 | case DataType.Name.BIGINT => ">i8" 93 | case DataType.Name.COUNTER => ">i8" 94 | case DataType.Name.FLOAT => ">f4" 95 | case DataType.Name.DOUBLE => ">f8" 96 | case DataType.Name.TIMESTAMP => ">M8[ms]" 97 | case _ => null 98 | }) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /src/main/scala/pyspark_cassandra/Utils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed under the Apache License, Version 2.0 (the "License"); 3 | * you may not use this file except in compliance with the License. 4 | * You may obtain a copy of the License at 5 | * 6 | * http://www.apache.org/licenses/LICENSE-2.0 7 | * 8 | * Unless required by applicable law or agreed to in writing, software 9 | * distributed under the License is distributed on an "AS IS" BASIS, 10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | * See the License for the specific language governing permissions and 12 | * limitations under the License. 13 | */ 14 | 15 | package pyspark_cassandra 16 | 17 | import java.nio.ByteBuffer 18 | import java.lang.{ Boolean => JBoolean } 19 | import java.util.{ List => JList, Map => JMap } 20 | import scala.reflect.ClassTag 21 | import scala.collection.JavaConversions._ 22 | import scala.collection.mutable.Buffer 23 | import org.apache.spark.SparkContext 24 | import com.datastax.driver.core.{ CodecRegistry, ConsistencyLevel, DataType, ProtocolVersion } 25 | import com.datastax.spark.connector._ 26 | import com.datastax.spark.connector.rdd._ 27 | import com.datastax.spark.connector.writer._ 28 | 29 | object Utils { 30 | def columnSelector(columns: Array[String], default: ColumnSelector = AllColumns) = { 31 | if (columns != null && columns.length > 0) { 32 | SomeColumns(columns.map { ColumnName(_) }: _*) 33 | } else { 34 | default 35 | } 36 | } 37 | 38 | def parseReadConf(sc: SparkContext, readConf: Option[JMap[String, Any]]) = { 39 | var conf = ReadConf.fromSparkConf(sc.getConf) 40 | 41 | readConf match { 42 | case Some(rc) => 43 | for { (k, v) <- rc } { 44 | (k, v) match { 45 | case ("split_count", v: Int) => conf = conf.copy(splitCount = Option(v)) 46 | case ("split_size", v: Int) => conf = conf.copy(splitSizeInMB = v) 47 | case ("fetch_size", v: Int) => conf = conf.copy(fetchSizeInRows = v) 48 | case ("consistency_level", v: Int) => conf = conf.copy(consistencyLevel = ConsistencyLevel.values()(v)) 49 | case ("consistency_level", v) => conf = conf.copy(consistencyLevel = ConsistencyLevel.valueOf(v.toString)) 50 | case ("metrics_enabled", v: Boolean) => conf = conf.copy(taskMetricsEnabled = v) 51 | case _ => throw new IllegalArgumentException(s"Read conf key $k with value $v unsupported") 52 | } 53 | } 54 | case None => // do nothing 55 | } 56 | 57 | conf 58 | } 59 | 60 | def parseWriteConf(writeConf: Option[JMap[String, Any]]) = { 61 | var conf = WriteConf() 62 | 63 | writeConf match { 64 | case Some(wc) => 65 | for { (k, v) <- wc } { 66 | (k, v) match { 67 | case ("batch_size", v: Int) => conf = conf.copy(batchSize = BytesInBatch(v)) 68 | case ("batch_buffer_size", v: Int) => conf = conf.copy(batchGroupingBufferSize = v) 69 | case ("batch_grouping_key", "replica_set") => conf = conf.copy(batchGroupingKey = BatchGroupingKey.ReplicaSet) 70 | case ("batch_grouping_key", "partition") => conf = 
conf.copy(batchGroupingKey = BatchGroupingKey.Partition) 71 | case ("consistency_level", v: Int) => conf = conf.copy(consistencyLevel = ConsistencyLevel.values()(v)) 72 | case ("consistency_level", v) => conf = conf.copy(consistencyLevel = ConsistencyLevel.valueOf(v.toString)) 73 | case ("parallelism_level", v: Int) => conf = conf.copy(parallelismLevel = v) 74 | case ("throughput_mibps", v: Number) => conf = conf.copy(throughputMiBPS = v.doubleValue()) 75 | case ("ttl", v: Int) => conf = conf.copy(ttl = TTLOption.constant(v)) 76 | case ("timestamp", v: Number) => conf = conf.copy(timestamp = TimestampOption.constant(v.longValue())) 77 | case ("metrics_enabled", v: Boolean) => conf = conf.copy(taskMetricsEnabled = v) 78 | case _ => throw new IllegalArgumentException(s"Write conf key $k with value $v unsupported") 79 | } 80 | } 81 | case None => // do nothing 82 | } 83 | 84 | conf 85 | } 86 | } 87 | 88 | object Format extends Enumeration { 89 | val DICT, TUPLE, ROW = Value 90 | 91 | def apply(format: Integer): Option[Format.Value] = 92 | if (format != null) Some(Format(format)) else None 93 | 94 | def parser(example: Any): FromUnreadRow[_] = { 95 | val format = detect(example) 96 | return Format.parser(format._1, format._2) 97 | } 98 | 99 | def parser(format: Option[Format.Value], keyed: Option[Boolean]): FromUnreadRow[_] = 100 | parser(format.getOrElse(Format.ROW), keyed.getOrElse(false)) 101 | 102 | def parser(format: Format.Value, keyed: Boolean): FromUnreadRow[_] = { 103 | (format, keyed) match { 104 | case (Format.ROW, false) => ToRow 105 | case (Format.ROW, true) => ToKVRows 106 | 107 | case (Format.TUPLE, false) => ToTuple 108 | case (Format.TUPLE, true) => ToKVTuple 109 | 110 | case (Format.DICT, false) => ToDict 111 | case (Format.DICT, true) => ToKVDicts 112 | 113 | case _ => throw new IllegalArgumentException() 114 | } 115 | } 116 | 117 | def detect(row: Any) = { 118 | // The detection works because primary keys can't be maps, sets or lists. If the detection still fails, a 119 | // user must set the row_format explicitly 120 | 121 | row match { 122 | // Rows map to ROW of course 123 | case row: Row => (Format.ROW, false) 124 | 125 | // If the row is a map, the only possible format is DICT 126 | case row: Map[_, _] => (Format.DICT, false) 127 | case row: JMap[_, _] => (Format.DICT, false) 128 | 129 | // otherwise it must be a tuple 130 | case row: Array[_] => 131 | // If the row is a tuple of length two, try to figure out if it's a (key,value) tuple 132 | if (row.length == 2) { 133 | val Array(k, v) = row 134 | 135 | k match { 136 | case k: Map[_, _] => (Format.DICT, true) 137 | case k: Array[_] => (Format.TUPLE, true) 138 | case k: Row => (Format.ROW, true) 139 | case _ => (Format.TUPLE, false) 140 | } 141 | } else { 142 | (Format.TUPLE, false) 143 | } 144 | 145 | // or if we really can't figure it out, request the user to set it explicitly 146 | case _ => 147 | val cls = row.getClass() 148 | throw new RuntimeException(s"Unable to detect or unsupported row format ($cls). 
Set it explicitly " + 149 | "with saveToCassandra(..., row_format=...).") 150 | } 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /src/main/scala/pyspark_util: -------------------------------------------------------------------------------- 1 | ../../../pyspark-util/src/main/scala/pyspark_util/ -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.3.5 2 | --------------------------------------------------------------------------------