├── .codecov.yml
├── .coveragerc
├── .gitignore
├── .travis.yml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── build.sbt
├── build
│   └── sbt
├── build_and_package.sh
├── imgs
│   ├── sparkling_ml.jpg
│   └── sparkling_ml.png
├── project
│   ├── build.properties
│   └── plugins.sbt
├── requirements.txt
├── scalastyle-config.xml
├── setup.py
├── sparklingml
│   ├── __init__.py
│   ├── feature
│   │   ├── __init__.py
│   │   ├── lucene_analyzers.py
│   │   └── python_pipelines.py
│   ├── java_wrapper_ml.py
│   ├── param
│   │   ├── __init__.py
│   │   ├── _shared_params_code_gen.py
│   │   └── shared.py
│   ├── startup.py
│   └── transformation_functions.py
└── src
    ├── main
    │   └── scala
    │       └── com
    │           └── sparklingpandas
    │               └── sparklingml
    │                   ├── CodeGenerator.scala
    │                   ├── feature
    │                   │   ├── BasicPython.scala
    │                   │   ├── LuceneAnalyzer.scala
    │                   │   ├── LuceneAnalyzerGenerators.scala
    │                   │   ├── LuceneAnalyzers.scala
    │                   │   └── LuceneHelpers.scala
    │                   ├── param
    │                   │   ├── SharedParamsCodeGen.scala
    │                   │   └── sharedParams.scala
    │                   └── util
    │                       └── python
    │                           ├── Initialize.scala
    │                           └── PythonTransformer.scala
    └── test
        └── scala
            └── com
                └── sparklingpandas
                    └── sparklingml
                        ├── feature
                        │   ├── BasicPython.scala
                        │   ├── LuceneAnalyzerGeneratorsTest.scala
                        │   ├── LuceneAnalyzersTests.scala
                        │   └── LuceneBaseTests.scala
                        └── param
                            └── SharedParamsCodeGenTest.scala

--------------------------------------------------------------------------------
/.codecov.yml:
--------------------------------------------------------------------------------
coverage:
  ignore: #things we ignore
    - "sparklingml/startup.py" # This is tested from Java so doesn't show up in coverage #s

--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
[run]
concurrency=multiprocessing

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.class
*.log
build.sbt_back

# sbt specific
dist/*
target/
lib_managed/
src_managed/
project/boot/
project/plugins/project/
sbt/*.jar
mini-complete-example/sbt/*.jar

# Scala-IDE specific
.scala_dependencies

#Emacs
*~

#ignore the metastore
metastore_db/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.env
.Python
env/
bin/
build/
develop-eggs/
dist/
eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.cache
nosetests.xml
coverage.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# Rope
.ropeproject

# Django stuff:
*.log
*.pot

# Sphinx documentation
docs/_build/

# PyCharm files
*.idea

# emacs stuff
\#*\#
\.\#*

# Autoenv
.env
*~

--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
language: scala
# Don't really need sudo, just need more memory!
sudo: required
# These directories are cached to S3 at the end of the build
cache:
  directories:
    - $HOME/.ivy2/cache
    - $HOME/.sbt/boot/
    - $HOME/.sbt/launchers
    - $HOME/myenv
jdk:
  - oraclejdk8
scala:
  - 2.11.8
before_install:
  - echo "We are at:"
  - git rev-parse HEAD
  - git log -n 5
  - pip install --user virtualenv
  - if [ ! -d "myenv" ]; then virtualenv myenv && source myenv/bin/activate && pip install --upgrade pip; fi
  - source myenv/bin/activate
  - export PATH=myenv/bin/:$PATH
  - pip install spacy
  - pip install nose codecov coverage sphinx isort
  - pip install -r requirements.txt
  - python -m spacy download en
script:
  - echo "Scala style"
  - ./build/sbt scalastyle
  - echo "Python style"
  - "flake8 --ignore=E305,E402,F403,F405,F999 sparklingml/"
  - cd sparklingml
  - isort -c --skip lucene_analyzers.py
  - cd ..
  - echo "Build and test"
  - ./build/sbt clean coverage compile package pack test coverageReport
  - bash <(curl -s https://codecov.io/bash) -cF scala || echo "No Scala report uploaded"
  - "find ./target |grep -i jar"
  - "nosetests --logging-level=INFO --detailed-errors --verbosity=2 --with-coverage --cover-html-dir=./htmlcov --cover-package=sparklingml --with-doctest --doctest-options=+ELLIPSIS,+NORMALIZE_WHITESPACE"
  - coverage combine
  - bash <(curl -s https://codecov.io/bash) -cF python || echo "No Python report uploaded"

--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Contributor Covenant Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
## Our Standards

Examples of behavior that contributes to creating a positive environment include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at holden@pigscanfly.ca. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version]

[homepage]: http://contributor-covenant.org
[version]: http://contributor-covenant.org/version/1/4/

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
Yay! I'm so excited you are interested in contributing to Sparkling ML!

The first thing you should start off with is subscribing to our mailing list; there isn't a lot of traffic yet, but we can work through any questions you have there together.
https://groups.google.com/forum/#!forum/sparklingml-dev

Once you've subscribed, reach out about what kind of model/algorithm you want to bring into the fold.

If this is your first time adding a new model with Spark's pipeline API, there are some resources to get you started:
- (Blog post) O'Reilly Radar on extending Spark ML by Holden - https://www.oreilly.com/learning/extend-spark-ml-for-your-own-modeltransformer-types
- (Video) Spark Summit talk by Holden & Seth on extending Spark ML for custom models https://www.youtube.com/watch?v=gCfVVrgWgxY

The goal of the project is not to be home to a lot of complex model code, but rather to help bring existing ML tools into Spark's pipeline API while making them accessible across Python, Scala, and Java.

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.
      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include *.txt
include *.md
include sparklingml/jar/*.jar
recursive-include docs *.md

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
![sparkling ml logo](https://raw.githubusercontent.com/sparklingpandas/sparklingml/master/imgs/sparkling_ml.png)
[![buildstatus](https://travis-ci.org/sparklingpandas/sparklingml.svg?branch=master)](https://travis-ci.org/sparklingpandas/sparklingml)
[![codecov.io](http://codecov.io/github/sparklingpandas/sparklingml/coverage.svg?branch=master)](http://codecov.io/github/sparklingpandas/sparklingml?branch=master)

# sparklingml
Machine Learning Pipeline Stages for Spark (exposed in Scala/Java + Python)

## Why?

SparklingML's goal is to expose additional machine learning stages for Spark with the pipeline interface.

## Status

Super early! Come join!

Dev mailing list: https://groups.google.com/forum/#!forum/sparklingml-dev

## Building

Sparkling ML consists of two components, a Python component and a Java/Scala component. The Python component depends on having the Java/Scala component pre-built, which can be done by running `./build/sbt package`.


The Python component depends on the packages listed in requirements.txt (as well as in setup.py). Development and testing also require spacy, nose, codecov, pylint, and flake8.


The script `build_and_package.sh` builds & tests both the Scala and Python code.


For now this only works with Spark 2.3.2; it needs some changes to support other versions.

### Tests

Are your DocTests failing with

>Expected nothing
>Got:
>
> Warning: no model found for 'en'
>
> Only loading the 'en' tokenizer.
>


Make sure you've installed spacy & the en language pack (`python -m spacy download en`)

## Including in your build

SparklingML is not yet ready for production use.

## License

SparklingML is licensed under the Apache 2 license. Some additional components may be under a different license.
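## Example

As a quick taste of the API, here is a minimal sketch using one of the pure-Python stages, `StrLenPlusKTransformer` (adapted from its doctest; the stage exists to illustrate how to write a Python pipeline stage rather than for real use):

```python
from pyspark.sql import SparkSession
from sparklingml.feature.python_pipelines import StrLenPlusKTransformer

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame([("hi",), ("boo",)], ["values"])
# Outputs the string length of the input column plus k.
tr = StrLenPlusKTransformer(inputCol="values", outputCol="c", k=2)
tr.transform(df).show()
```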

--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
organization := "com.sparklingpandas"

name := "sparklingml"

publishMavenStyle := true

version := "0.0.1-SNAPSHOT"

sparkVersion := "2.3.2"

scalaVersion := "2.11.8"

coverageHighlighting := true

javacOptions ++= Seq("-source", "1.8", "-target", "1.8")

//tag::spName[]
spName := "sparklingpandas/sparklingml"
//end::spName[]

sparkComponents := Seq("core", "sql", "catalyst", "mllib")

parallelExecution in Test := false
fork in test := true


coverageHighlighting := true
coverageEnabled := true


javaOptions in test ++= Seq("-Xms1G", "-Xmx1G", "-XX:MaxPermSize=1024M", "-XX:+CMSClassUnloadingEnabled")

test in assembly := {}

libraryDependencies ++= Seq(
  // spark components
  "org.apache.spark" %% "spark-core" % sparkVersion.value % "provided",
  "org.apache.spark" %% "spark-hive" % sparkVersion.value % "provided",
  "org.apache.spark" %% "spark-sql" % sparkVersion.value % "provided",
  "org.apache.spark" %% "spark-catalyst" % sparkVersion.value % "provided",
  "org.apache.spark" %% "spark-mllib" % sparkVersion.value % "provided",
  // algorithm providers
  "org.apache.lucene" % "lucene-analyzers-common" % "6.6.0",
  "org.apache.lucene" % "lucene-analyzers-icu" % "6.6.0",
  "org.apache.lucene" % "lucene-analyzers-kuromoji" % "6.6.0",
  "org.apache.lucene" % "lucene-analyzers-morfologik" % "6.6.0",
  "org.apache.lucene" % "lucene-analyzers-phonetic" % "6.6.0",
  "org.apache.lucene" % "lucene-analyzers-smartcn" % "6.6.0",
  "org.apache.lucene" % "lucene-analyzers-stempel" % "6.6.0",
  "org.apache.lucene" % "lucene-analyzers-uima" % "6.6.0",
  // internals that are only used during code gen
  // TODO(holden): exclude from assembly but keep for runMain somehow?
  "org.scala-lang" % "scala-reflect" % "2.11.7" % "provided",
  "org.reflections" % "reflections" % "0.9.11" % "provided",
  // testing libraries
  "org.scalatest" %% "scalatest" % "3.0.1" % "test",
  "org.scalacheck" %% "scalacheck" % "1.13.4" % "test",
  "com.holdenkarau" %% "spark-testing-base" % "0.7.3" % "test")


scalacOptions ++= Seq("-deprecation", "-unchecked")

pomIncludeRepository := { x => false }

resolvers ++= Seq(
  "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/",
  "Typesafe repository" at "http://repo.typesafe.com/typesafe/releases/",
  "Second Typesafe repo" at "http://repo.typesafe.com/typesafe/maven-releases/",
  "Mesosphere Public Repository" at "http://downloads.mesosphere.io/maven",
  Resolver.sonatypeRepo("public"),
  // restlet hosts its artifacts on a separate maven repo
  "restlet" at "http://maven.restlet.com",
  // this doesn't seem to be included in DefaultMavenRepository, so add it explicitly
  "DefaultMavenRepository" at "https://repo1.maven.org/maven2/"
)

// publish settings
publishTo := {
  val nexus = "https://oss.sonatype.org/"
  if (isSnapshot.value)
    Some("snapshots" at nexus + "content/repositories/snapshots")
  else
    Some("releases" at nexus + "service/local/staging/deploy/maven2")
}

licenses := Seq("Apache License 2.0" ->
  url("http://www.apache.org/licenses/LICENSE-2.0.html"))

homepage := Some(url("https://github.com/sparklingpandas/sparklingml"))

pomExtra := (
  <scm>
    <url>git@github.com:sparklingpandas/sparklingml.git</url>
    <connection>scm:git@github.com:sparklingpandas/sparklingml.git</connection>
  </scm>
  <developers>
    <developer>
      <id>holdenk</id>
      <name>Holden Karau</name>
      <url>http://www.holdenkarau.com</url>
      <email>holden@pigscanfly.ca</email>
    </developer>
  </developers>
)

credentials ++= Seq(
  Credentials(Path.userHome / ".ivy2" / ".sbtcredentials"),
  Credentials(Path.userHome / ".ivy2" / ".sparkcredentials"))

spIncludeMaven := true

useGpg := true

testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a")

--------------------------------------------------------------------------------
/build/sbt:
--------------------------------------------------------------------------------
#!/bin/bash

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This script launches sbt for this project. If the sbt launch jar has
# already been downloaded and cached it is reused; otherwise the script
# attempts to download sbt locally.
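# Example invocations (the same ones .travis.yml and build_and_package.sh use):
#   ./build/sbt scalastyle
#   ./build/sbt clean coverage compile package pack test coverageReport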
SBT_VERSION=0.13.16
URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
JAR=build/sbt-launch-${SBT_VERSION}.jar

# Download sbt launch jar if it hasn't been downloaded yet
if [ ! -f ${JAR} ]; then
  # Download
  printf "Attempting to fetch sbt\n"
  set -x
  JAR_DL=${JAR}.part
  if hash wget 2>/dev/null; then
    (wget --progress=bar ${URL1} -O ${JAR_DL} || wget --progress=bar ${URL2} -O ${JAR_DL}) && mv ${JAR_DL} ${JAR}
  elif hash axel 2>/dev/null; then
    (axel ${URL1} -o ${JAR_DL} || axel ${URL2} -o ${JAR_DL}) && mv ${JAR_DL} ${JAR}
  else
    printf "You do not have wget or axel installed, please install sbt manually from http://www.scala-sbt.org/\n"
    exit -1
  fi
fi
if [ ! -f ${JAR} ]; then
  # We failed to download
  printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n"
  exit -1
fi
printf "Launching sbt from ${JAR}\n"
java \
  -Xmx1200m -XX:MaxPermSize=350m -XX:ReservedCodeCacheSize=256m \
  -jar ${JAR} \
  "$@"

--------------------------------------------------------------------------------
/build_and_package.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -x
set -e


echo "Installing python requirements"

pip install -r requirements.txt

echo "Checking if the spacy 'en_core_web_sm' model is installed, otherwise downloading it"

python -c "import spacy;spacy.load('en_core_web_sm')" || python -m spacy download en_core_web_sm && python -c "import spacy;spacy.load('en_core_web_sm')"

echo "Checking scala style issues"

./build/sbt scalastyle

echo "Checking python style issues"

flake8 --ignore=E402,F405,F401,F403 sparklingml/

echo "Building JVM code"

./build/sbt clean pack assembly

echo "Copying assembly jar to python loadable directory"

mkdir -p ./sparklingml/jar
cp target/scala-2.11/sparklingml-assembly-0.0.1-SNAPSHOT.jar ./sparklingml/jar/sparklingml.jar

echo "Testing Python code"

nosetests --logging-level=INFO --detailed-errors --verbosity=2 --with-coverage --cover-html-dir=./htmlcov --cover-package=sparklingml --with-doctest --doctest-options=+ELLIPSIS,+NORMALIZE_WHITESPACE

echo "Testing pip install of Python code"

pip install .
mkdir -p /tmp/abcd
pushd /tmp/abcd
python -c "import sparklingml"
popd


echo "Testing JVM code"

# Skip for now due to gateway issues.
# ./build/sbt test

echo "Finished"

--------------------------------------------------------------------------------
/imgs/sparkling_ml.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sparklingpandas/sparklingml/0fc718697774524f1c06fe1670c39c76aa4522cc/imgs/sparkling_ml.jpg

--------------------------------------------------------------------------------
/imgs/sparkling_ml.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sparklingpandas/sparklingml/0fc718697774524f1c06fe1670c39c76aa4522cc/imgs/sparkling_ml.png

--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
sbt.version=0.13.16

--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0")

resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/"

resolvers += "Spark Package Main Repo" at "https://dl.bintray.com/spark-packages/maven"

addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.6")

addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")

addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0")

addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3")

addSbtPlugin("org.xerial.sbt" % "sbt-pack" % "0.9.1")

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
coverage
pyspark==2.4.4
pypandoc
scipy
numpy
nose
unittest2>=1.0.0
pandas>=0.13
spacy
future
pyarrow==0.11.0
flake8==3.5.0
nltk

--------------------------------------------------------------------------------
/scalastyle-config.xml:
--------------------------------------------------------------------------------
<!-- Scalastyle standard configuration. The individual rule elements were
     stripped by the text extraction of this dump and are not recoverable
     here; see the file in the repository for the full rule set. -->

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

VERSION = '0.0.1'

setup(
    name='sparklingml',
    version=VERSION,
    author='Holden Karau',
    author_email='holden@pigscanfly.ca',
    packages=find_packages(),
    include_package_data=True,
    package_data={
        'sparklingml': ['jar/sparklingml.jar']
    },
    url='https://github.com/sparklingpandas/sparklingml',
    license='Apache License 2.0',
    description='Add additional ML algorithms to Spark',
    long_description=open('README.md').read(),
    install_requires=[
        'pyspark>=2.3.0',
        'nltk',
        'numpy',  # Required for PySpark ML
        'pandas',
        'spacy',
        'future',
        'pyarrow',
    ],
    tests_require=[
        'nose==1.3.7',
        'coverage>3.7.0',
        'unittest2>=1.0.0',
    ],
)

--------------------------------------------------------------------------------
/sparklingml/__init__.py:
--------------------------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Sparkling ML provides additional ML algorithms for Spark.
"""
from __future__ import print_function
import glob
import os
from pkg_resources import resource_filename


if 'IS_TEST' not in os.environ and "JARS" not in os.environ:
    VERSION = '0.0.1'
    # Check if target pack exists
    jar = None
    if os.path.exists("target/pack/lib/"):
        print("Using packed jars during development")
        jars = glob.glob("target/pack/lib/*.jar")
        abs_path_jars = map(os.path.abspath, jars)
        jar = ",".join(abs_path_jars)
    else:
        JAR_FILE = 'sparklingml-assembly-' + VERSION + '.jar'
        DEV_JAR = 'sparklingml-assembly-' + VERSION + '-SNAPSHOT.jar'
        my_location = os.path.dirname(os.path.realpath(__file__))
        local_prefixes = [
            # For development, use the sbt target scala-2.11 first;
            # since this init script is inside the sparklingml package, move up one dir
            os.path.join(my_location, '../target/scala-2.11/'),
            # Also try the present working directory
            os.path.join(os.getcwd(), '../target/scala-2.11/'),
            os.path.join(os.getcwd(), 'target/scala-2.11/')]
        prod_jars = [os.path.join(prefix, JAR_FILE)
                     for prefix in local_prefixes]
        dev_jars = [os.path.join(prefix, DEV_JAR)
                    for prefix in local_prefixes]

        jars = prod_jars + dev_jars
        try:
            jars.append(os.path.abspath(resource_filename(
                'sparklingml',
                "jar/sparklingml.jar")))
        except Exception as e:
            print("Could not resolve resource file %s. This is not"
                  " necessarily a problem (and is expected during"
                  " development) but should not occur in production"
                  " if pip installed." % str(e))
        try:
            jar = [jar_path
                   for jar_path in jars
                   if os.path.exists(jar_path)][0]
        except IndexError:
            print("Failed to find jars. Looked at paths %s." % jars)
            if 'SPARKLING_ML_SPECIFIC' not in os.environ:
                raise IOError("Failed to find jars. Looked at paths %s."
                              % jars)
            else:
                print("Failed to find jars, but launched from the JVM "
                      "so this _should_ be ok.")
    if jar is not None:
        os.environ["JARS"] = jar
        print("Using backing jar " + jar)
        os.environ["PYSPARK_SUBMIT_ARGS"] = (
            "--jars %s --driver-class-path %s pyspark-shell") % (jar, jar)

--------------------------------------------------------------------------------
/sparklingml/feature/__init__.py:
--------------------------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Sparkling ML provides additional ML algorithms for Spark.
"""
from __future__ import print_function

--------------------------------------------------------------------------------
/sparklingml/feature/python_pipelines.py:
--------------------------------------------------------------------------------
from __future__ import unicode_literals

from pyspark import keyword_only
from pyspark.ml import Model
from pyspark.ml.param import *
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.functions import (PandasUDFType, UserDefinedFunction,
                                   pandas_udf)

from sparklingml.transformation_functions import *

# Most of the Python models are wrappers of JavaModels, and we will need those
# as well. For now this simple example shows how to expose a simple Python only
# UDF. Look in Spark for exposing Java examples.


@ignore_unicode_prefix
class StrLenPlusKTransformer(Model, HasInputCol, HasOutputCol):
    """
    strLenPlusK takes one parameter, k, and returns
    the string length plus k. This is intended to illustrate how
    to make a Python stage and not for actual use.
    >>> from pyspark.sql import SparkSession
    >>> spark = SparkSession.builder.master("local[2]").getOrCreate()
    >>> df = spark.createDataFrame([("hi",), ("boo",)], ["values"])
    >>> tr = StrLenPlusKTransformer(inputCol="values", outputCol="c", k=2)
    >>> tr.getK()
    2
    >>> tr.transform(df).head().c
    4
    >>> tr.setK(1)
    StrLenPlusKTransformer_...
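    (Sanity check added for illustration: the updated param is visible immediately.)
    >>> tr.getK()
    1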
    >>> tr.transform(df).head().c
    3
    """

    # We need a parameter to configure k
    k = Param(Params._dummy(),
              "k", "amount to add to str len",
              typeConverter=TypeConverters.toInt)

    @keyword_only
    def __init__(self, k=None, inputCol=None, outputCol=None):
        super(StrLenPlusKTransformer, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, k=None, inputCol=None, outputCol=None):
        """
        setParams(self, k=None, inputCol=None, outputCol=None):
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setK(self, value):
        """
        Sets the value of :py:attr:`k`.
        """
        return self._set(k=value)

    def getK(self):
        """
        Gets the value of k or its default value.
        """
        return self.getOrDefault(self.k)

    def _transform(self, dataset):
        func = StrLenPlusK.func(self.getK())
        ret_type = StrLenPlusK.returnType()
        udf = UserDefinedFunction(func, ret_type)
        return dataset.withColumn(
            self.getOutputCol(), udf(self.getInputCol())
        )


@ignore_unicode_prefix
class SpacyTokenizeTransformer(Model, HasInputCol, HasOutputCol):
    """
    Tokenize the provided input using Spacy.
    >>> from pyspark.sql import SparkSession
    >>> spark = SparkSession.builder.master("local[2]").getOrCreate()
    >>> df = spark.createDataFrame([("hi boo", 0.0),
    ...                             ("bye boo", 1.0)],
    ...                            ["vals", "label"])
    >>> tr = SpacyTokenizeTransformer(inputCol="vals", outputCol="c")
    >>> str(tr.getLang())
    'en_core_web_sm'
    >>> tr.transform(df).head().c
    [u'hi', u'boo']
    >>> from pyspark.ml import Pipeline
    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.feature import HashingTF
    >>> hashingtf = HashingTF(inputCol="c", outputCol="features")
    >>> lr = LogisticRegression(featuresCol="features")
    >>> pipeline = Pipeline(stages=[tr, hashingtf, lr])
    >>> model = pipeline.fit(df)
    """

    # We need a parameter to configure the language
    lang = Param(Params._dummy(),
                 "lang", "language",
                 typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, lang="en_core_web_sm", inputCol=None, outputCol=None):
        super(SpacyTokenizeTransformer, self).__init__()
        self._setDefault(lang="en_core_web_sm")
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, lang="en_core_web_sm", inputCol=None, outputCol=None):
        """
        setParams(self, lang="en_core_web_sm", inputCol=None, outputCol=None):
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setLang(self, value):
        """
        Sets the value of :py:attr:`lang`.
        """
        return self._set(lang=value)

    def getLang(self):
        """
        Gets the value of lang or its default value.
        """
        return self.getOrDefault(self.lang)

    def _transform(self, dataset):
        SpacyTokenize.setup(dataset._sc, dataset.sql_ctx, self.getLang())
        func = SpacyTokenize.func(self.getLang())
        ret_type = SpacyTokenize.returnType()
        udf = pandas_udf(func, ret_type)
        return dataset.withColumn(
            self.getOutputCol(), udf(self.getInputCol())
        )


@ignore_unicode_prefix
class SpacyAdvancedTokenizeTransformer(Model, HasInputCol, HasOutputCol):
    """
    Tokenize the provided input using Spacy, keeping the requested
    fields of each token.
    >>> from pyspark.sql import SparkSession
    >>> spark = SparkSession.builder.master("local[2]").getOrCreate()
    >>> df = spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
    >>> tr = SpacyAdvancedTokenizeTransformer(inputCol="vals", outputCol="c")
    >>> str(tr.getLang())
    'en_core_web_sm'
    >>> tr.getSpacyFields()
    ['_', 'ancestors', ...
    >>> tr.setSpacyFields(["text", "lang_"])
    SpacyAdvancedTokenizeTransformer_...
    >>> r = tr.transform(df).head().c
    >>> l = list(map(lambda d: sorted(d.items()), r))
    >>> l[0]
    [(u'lang_', u'en'), (u'text', u'hi')]
    >>> l[1]
    [(u'lang_', u'en'), (u'text', u'boo')]
    """

    lang = Param(Params._dummy(),
                 "lang", "language",
                 typeConverter=TypeConverters.toString)

    spacyFields = Param(Params._dummy(),
                        "spacyFields", "fields of token to keep",
                        typeConverter=TypeConverters.toListString)

    @keyword_only
    def __init__(self, lang=None,
                 spacyFields=None,
                 inputCol=None, outputCol=None):
        super(SpacyAdvancedTokenizeTransformer, self).__init__()
        kwargs = self._input_kwargs
        if "spacyFields" not in kwargs:
            kwargs["spacyFields"] = list(SpacyAdvancedTokenize.default_fields)
        self._setDefault(
            lang="en_core_web_sm",
            spacyFields=list(SpacyAdvancedTokenize.default_fields))
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, lang="en_core_web_sm", spacyFields=None,
                  inputCol=None, outputCol=None):
        """
        setParams(self, lang="en_core_web_sm",
                  spacyFields=SpacyAdvancedTokenize.default_fields,
                  inputCol=None, outputCol=None):
        """
        kwargs = self._input_kwargs
        if "spacyFields" not in kwargs:
            kwargs["spacyFields"] = list(SpacyAdvancedTokenize.default_fields)
        return self._set(**kwargs)

    def setLang(self, value):
        """
        Sets the value of :py:attr:`lang`.
        """
        return self._set(lang=value)

    def getLang(self):
        """
        Gets the value of lang or its default value.
        """
        return self.getOrDefault(self.lang)

    def setSpacyFields(self, value):
        """
        Sets the value of :py:attr:`spacyFields`.
        """
        return self._set(spacyFields=value)

    def getSpacyFields(self):
        """
        Gets the value of spacyFields or its default value.
        """
        return self.getOrDefault(self.spacyFields)

    def _transform(self, dataset):
        SpacyAdvancedTokenize.setup(
            dataset._sc, dataset.sql_ctx, self.getLang())
        func = SpacyAdvancedTokenize.func(self.getLang(),
                                          self.getSpacyFields())
        ret_type = SpacyAdvancedTokenize.returnType(
            self.getLang(), self.getSpacyFields())
        udf = UserDefinedFunction(func, ret_type)
        return dataset.withColumn(
            self.getOutputCol(), udf(self.getInputCol())
        )


@ignore_unicode_prefix
class NltkPosTransformer(Model, HasInputCol, HasOutputCol):
    """
    Determine the positivity of the input sentence.
    >>> from pyspark.sql import SparkSession
    >>> spark = SparkSession.builder.master("local[2]").getOrCreate()
    >>> df = spark.createDataFrame([("Boo is happy",), ("sad Boo",),
    ...                             ("i enjoy rope burn",)], ["vals"])
    >>> tr = NltkPosTransformer(inputCol="vals", outputCol="c")
    >>> tr.transform(df).show()
    +-----------------+-----+
    |             vals|    c|
    +-----------------+-----+
    |     Boo is happy|0.6...|
    |          sad Boo|  0.0|
    |i enjoy rope burn|0.6...|
    +-----------------+-----+...
    """

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None):
        super(NltkPosTransformer, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
        """
        setParams(self, inputCol=None, outputCol=None):
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        func = NltkPos.func()
        ret_type = NltkPos.returnType()
        udf = pandas_udf(func, ret_type, PandasUDFType.SCALAR)
        return dataset.withColumn(
            self.getOutputCol(), udf(self.getInputCol())
        )


if __name__ == '__main__':
    import doctest
    doctest.testmod(optionflags=doctest.ELLIPSIS)

--------------------------------------------------------------------------------
/sparklingml/java_wrapper_ml.py:
--------------------------------------------------------------------------------
from __future__ import unicode_literals

from pyspark.ml.wrapper import JavaModel, JavaTransformer


class SparklingJavaTransformer(JavaTransformer):
    """
    Base class for Java transformers exposed in Python.
    """
    def __init__(self, jt=None):
        super(SparklingJavaTransformer, self).__init__()
        if not jt:
            self._java_obj = self._new_java_obj(self.transformer_name)
        else:
            self._java_obj = jt


class SparklingJavaModel(JavaModel):
    """
    Base class for Java models exposed in Python.
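    Subclasses are expected to define a ``model_name`` attribute naming the
    JVM class to instantiate, along the lines of (a hypothetical sketch, not
    a class that exists in this repo):

        class MyJavaModel(SparklingJavaModel):
            model_name = "com.sparklingpandas.sparklingml.feature.MyJavaModel"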
    """
    def __init__(self, jm=None):
        super(SparklingJavaModel, self).__init__()
        if not jm:
            self._java_obj = self._new_java_obj(self.model_name)
        else:
            self._java_obj = jm

--------------------------------------------------------------------------------
/sparklingml/param/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sparklingpandas/sparklingml/0fc718697774524f1c06fe1670c39c76aa4522cc/sparklingml/param/__init__.py

--------------------------------------------------------------------------------
/sparklingml/param/_shared_params_code_gen.py:
--------------------------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

header = """#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#"""

# Code generator for shared params (shared.py). Run under this folder with:
# python _shared_params_code_gen.py > shared.py


def _gen_param_header(name, doc, defaultValueStr, typeConverter):
    """
    Generates the header part for shared variables

    :param name: param name
    :param doc: param doc
    """
    template = '''class Has$Name(Params):
    """
    Mixin for param $name:
    $doc
    """

    $name = Param(
        Params._dummy(),
        "$name",
        "$doc",
        typeConverter=$typeConverter)

    def __init__(self):
        super(Has$Name, self).__init__()'''

    if defaultValueStr is not None:
        template += '''
        self._setDefault($name=$defaultValueStr)'''

    Name = name[0].upper() + name[1:]
    if typeConverter is None:
        typeConverter = str(None)
    return template \
        .replace("$name", name) \
        .replace("$Name", Name) \
        .replace("$doc", doc) \
        .replace("$defaultValueStr", str(defaultValueStr)) \
        .replace("$typeConverter", typeConverter)


def _gen_param_code(name, doc, defaultValueStr):
    """
    Generates Python code for a shared param class.

    :param name: param name
    :param doc: param doc
    :param defaultValueStr: string representation of the default value
    :return: code string
    """
    # TODO: How to correctly inherit instance attributes?
    template = '''
    def set$Name(self, value):
        """
        Sets the value of :py:attr:`$name`.
        """
        return self._set($name=value)

    def get$Name(self):
        """
        Gets the value of $name or its default value.
        """
        return self.getOrDefault(self.$name)'''

    Name = name[0].upper() + name[1:]
    return template \
        .replace("$name", name) \
        .replace("$Name", Name) \
        .replace("$doc", doc) \
        .replace("$defaultValueStr", str(defaultValueStr))


if __name__ == "__main__":
    print(header)
    print("\n# DO NOT MODIFY THIS FILE! It was generated by")
    print("# _shared_params_code_gen.py.\n")
    print("from pyspark.ml.param import *\n\n")
    shared = [
        ("stopwordCase",
         "If the case should be considered when filtering stopwords",
         False, "TypeConverters.toBoolean"),
        ("stopwords",
         "Stopwords to be filtered. Default depends on underlying transformer",
         None, "TypeConverters.toListString")
    ]

    code = []
    for name, doc, defaultValueStr, typeConverter in shared:
        param_code = _gen_param_header(name, doc, defaultValueStr,
                                       typeConverter)
        code.append(param_code + "\n" +
                    _gen_param_code(name, doc, defaultValueStr))
    print("\n\n\n".join(code))

--------------------------------------------------------------------------------
/sparklingml/param/shared.py:
--------------------------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | # DO NOT MODIFY THIS FILE! It was generated by 19 | # _shared_params_code_gen.py. 20 | 21 | from pyspark.ml.param import * 22 | 23 | 24 | class HasStopwordCase(Params): 25 | """ 26 | Mixin for param stopwordCase: 27 | If the case should be considered when filtering stopwords 28 | """ 29 | 30 | stopwordCase = Param( 31 | Params._dummy(), 32 | "stopwordCase", 33 | "If the case should be considered when filtering stopwords", 34 | typeConverter=TypeConverters.toBoolean) 35 | 36 | def __init__(self): 37 | super(HasStopwordCase, self).__init__() 38 | self._setDefault(stopwordCase=False) 39 | 40 | def setStopwordCase(self, value): 41 | """ 42 | Sets the value of :py:attr:`stopwordCase`. 43 | """ 44 | return self._set(stopwordCase=value) 45 | 46 | def getStopwordCase(self): 47 | """ 48 | Gets the value of stopwordCase or its default value. 49 | """ 50 | return self.getOrDefault(self.stopwordCase) 51 | 52 | 53 | class HasStopwords(Params): 54 | """ 55 | Mixin for param stopwords: 56 | Stopwords to be filtered. Default depends on underlying transformer 57 | """ 58 | 59 | stopwords = Param( 60 | Params._dummy(), 61 | "stopwords", 62 | "Stopwords to be filtered. Default depends on underlying transformer", 63 | typeConverter=TypeConverters.toListString) 64 | 65 | def __init__(self): 66 | super(HasStopwords, self).__init__() 67 | 68 | def setStopwords(self, value): 69 | """ 70 | Sets the value of :py:attr:`stopwords`. 71 | """ 72 | return self._set(stopwords=value) 73 | 74 | def getStopwords(self): 75 | """ 76 | Gets the value of stopwords or its default value. 77 | """ 78 | return self.getOrDefault(self.stopwords) 79 | -------------------------------------------------------------------------------- /sparklingml/startup.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | from py4j.java_gateway import * 4 | # Spark imports 5 | from pyspark.conf import SparkConf 6 | from pyspark.context import SparkContext 7 | from pyspark.sql import * 8 | from pyspark.sql.functions import UserDefinedFunction 9 | 10 | from sparklingml.transformation_functions import * 11 | 12 | # Hack to allow people to hook in more easily 13 | try: 14 | from user_functions import * 15 | setup_user() 16 | except ImportError: 17 | pass 18 | 19 | 20 | # This class is used to allow the Scala process to call into Python 21 | # It may not run in the same Python process as your regular Python 22 | # shell if you are running PySpark normally. 23 | class PythonRegistrationProvider(object): 24 | """ 25 | Provide an entry point for Scala to call to register functions. 26 | """ 27 | 28 | def __init__(self, gateway): 29 | self.gateway = gateway 30 | self._sc = None 31 | self._session = None 32 | self._count = 0 33 | 34 | def registerFunction(self, ssc, jsession, function_name, params): 35 | jvm = self.gateway.jvm 36 | # If we don't have a reference to a running SparkContext 37 | # Get the SparkContext from the provided SparkSession. 
38 | if not self._sc:
39 | master = ssc.master()
40 | jsc = jvm.org.apache.spark.api.java.JavaSparkContext(ssc)
41 | jsparkConf = ssc.conf()
42 | sparkConf = SparkConf(_jconf=jsparkConf)
43 | self._sc = SparkContext(
44 | master=master,
45 | conf=sparkConf,
46 | gateway=self.gateway,
47 | jsc=jsc)
48 | self._session = SparkSession.builder.getOrCreate()
49 | if function_name in functions_info:
50 | function_info = functions_info[function_name]
51 | if params:
52 | evaledParams = ast.literal_eval(params)
53 | else:
54 | evaledParams = []
55 | func = function_info.func(*evaledParams)
56 | ret_type = function_info.returnType()
57 | self._count = self._count + 1
58 | registration_name = function_name + str(self._count)
59 | udf = UserDefinedFunction(func, ret_type, registration_name)
60 | # Used to allow non-default (e.g. Arrow) UDFs
61 | udf.evalType = function_info.evalType()
62 | judf = udf._judf
63 | return judf
64 | else:
65 | print("Could not find function " + function_name)
66 | # We do this rather than raising an exception since Py4J debugging
67 | # is rough and the caller can check for None instead.
68 | return None
69 | 
70 | class Java:
71 | package = "com.sparklingpandas.sparklingml.util.python"
72 | className = "PythonRegisterationProvider"
73 | implements = [package + "." + className]
74 | 
75 | 
76 | if __name__ == "__main__":
77 | def spark_jvm_imports(jvm):
78 | # Import the classes used by PySpark
79 | java_import(jvm, "org.apache.spark.SparkConf")
80 | java_import(jvm, "org.apache.spark.api.java.*")
81 | java_import(jvm, "org.apache.spark.api.python.*")
82 | java_import(jvm, "org.apache.spark.ml.python.*")
83 | java_import(jvm, "org.apache.spark.mllib.api.python.*")
84 | # TODO(davies): move into sql
85 | java_import(jvm, "org.apache.spark.sql.*")
86 | java_import(jvm, "org.apache.spark.sql.hive.*")
87 | java_import(jvm, "scala.Tuple2")
88 | 
89 | import os
90 | if "SPARKLING_ML_SPECIFIC" in os.environ:
91 | # Py4J setup work so we can talk to the JVM gateway.
92 | gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
93 | gateway = JavaGateway(
94 | GatewayClient(port=gateway_port),
95 | # TODO: handle dynamic port binding here correctly.
96 | callback_server_parameters=CallbackServerParameters(port=0),
97 | auto_convert=True)
98 | # Retrieve the port on which the Python callback server was bound.
99 | python_port = gateway.get_callback_server().get_listening_port()
100 | # bind the callback server on the java side to the new python_port
101 | gateway.java_gateway_server.resetCallbackClient(
102 | gateway.java_gateway_server.getCallbackClient().getAddress(),
103 | python_port)
104 | # Create our registration provider interface for Py4J to call into
105 | provider = PythonRegistrationProvider(gateway)
106 | # SparklingML specific imports
107 | jvm = gateway.jvm
108 | java_import(jvm, "com.sparklingpandas.sparklingml")
109 | java_import(jvm, "com.sparklingpandas.sparklingml.util.python")
110 | # We need to re-do the Spark gateway imports as well
111 | spark_jvm_imports(jvm)
112 | python_utils = jvm.com.sparklingpandas.sparklingml.util.python
113 | pythonRegistrationObj = python_utils.PythonRegistration
114 | pythonRegistrationObj.register(provider)
115 | # Busy loop so we don't exit. This is also kind of a hack.
116 | import time
117 | while True:
118 | time.sleep(1)
119 | print("real exit")
120 | 
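To make the callback above concrete, here is a minimal sketch of what registerFunction boils down to once the Py4J and Spark plumbing is stripped away. It assumes sparklingml and its dependencies are importable; the "strlenplusk" name and k=2 are illustrative values drawn from this repo's own registry.

    from sparklingml.transformation_functions import functions_info

    info = functions_info["strlenplusk"]  # look up the TransformationFunction
    func = info.func(2)                   # construct the callable with k=2
    print(func("boo"))                    # 5, i.e. len("boo") + 2
    # registerFunction then wraps this callable in a UserDefinedFunction
    # using info.returnType() and info.evalType().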
--------------------------------------------------------------------------------
/sparklingml/transformation_functions.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 | 
3 | import inspect
4 | import sys
5 | import threading
6 | 
7 | import pandas # noqa: F401
8 | import spacy
9 | from pyspark.rdd import ignore_unicode_prefix
10 | from pyspark.sql.functions import PandasUDFType
11 | from pyspark.sql.types import *
12 | 
13 | if sys.version_info.major == 3:
14 | unicode = str
15 | 
16 | functions_info = dict()
17 | 
18 | 
19 | class TransformationFunction(object):
20 | @classmethod
21 | def setup(cls, sc, session, *args):
22 | """Perform any setup work (like global broadcasts)"""
23 | pass
24 | 
25 | @classmethod
26 | def returnType(cls, *args):
27 | """Return the sql return type"""
28 | return None
29 | 
30 | @classmethod
31 | def func(cls, *args):
32 | """Returns a function constructed using the args."""
33 | return None
34 | 
35 | @classmethod
36 | def evalType(cls):
37 | """Returns the eval type to be used."""
38 | from pyspark.rdd import PythonEvalType
39 | return PythonEvalType.SQL_BATCHED_UDF
40 | 
41 | 
42 | class ScalarVectorizedTransformationFunction(TransformationFunction):
43 | """Transformation functions which are Scalar Vectorized UDFs."""
44 | 
45 | @classmethod
46 | def evalType(cls):
47 | """Returns the eval type to be used."""
48 | return PandasUDFType.SCALAR
49 | 
50 | 
51 | @ignore_unicode_prefix
52 | class StrLenPlusK(TransformationFunction):
53 | """
54 | strLenPlusK takes one parameter, k, and returns
55 | the string length plus k. This is intended to illustrate how
56 | to make a Python stage usable from Scala, not for actual usage.
57 | """
58 | @classmethod
59 | def setup(cls, sc, session, *args):
60 | pass
61 | 
62 | @classmethod
63 | def func(cls, *args):
64 | k = args[0]
65 | 
66 | def inner(inputString):
67 | """Compute the string length plus K (based on parameters)."""
68 | return len(inputString) + k
69 | return inner
70 | 
71 | @classmethod
72 | def returnType(cls, *args):
73 | return IntegerType()
74 | 
75 | 
76 | functions_info["strlenplusk"] = StrLenPlusK
77 | 
78 | 
79 | # Spacy isn't serializable but loading it is semi-expensive
80 | @ignore_unicode_prefix
81 | class SpacyMagic(object):
82 | """
83 | Simple Spacy Magic to minimize loading time.
84 | >>> spm = SpacyMagic()
85 | >>> spm2 = SpacyMagic()
86 | >>> spm == spm2
87 | True
88 | >>> spm.get("en_core_web_sm")
89 | <spacy.lang.en.English object at ...>
90 | >>> spm.get("non-happy-language")
91 | Traceback (most recent call last):
92 | ...
93 | Exception: Failed to find or download language non-happy-language:...
94 | >>> spm.broadcast()
95 | <pyspark.broadcast.Broadcast object at ...>
162 | >>> spt = SpacyTokenize()
163 | >>> sp = spt.func("en_core_web_sm")
164 | >>> r = sp(pandas.Series(["hi boo"]))
165 | ...
166 | >>> r
167 | 0 [hi, boo]
168 | dtype: object
169 | """
170 | @classmethod
171 | def setup(cls, sc, session, *args):
172 | pass
173 | 
174 | @classmethod
175 | def func(cls, *args):
176 | lang = args[0]
177 | spm = SpacyMagic()
178 | 
179 | def inner(inputSeries):
180 | """Tokenize each element of the input series using spacy
181 | for the provided language."""
182 | nlp = spm.get(lang)
183 | 
184 | def tokenizeElem(elem):
185 | result_itr = map(lambda token: token.text,
186 | list(nlp(unicode(elem))))
187 | return list(result_itr)
188 | 
189 | return inputSeries.apply(tokenizeElem)
190 | return inner
191 | 
192 | @classmethod
193 | def returnType(cls, *args):
194 | return ArrayType(StringType())
195 | 
196 | 
197 | functions_info["spacytokenize"] = SpacyTokenize
198 | 
199 | 
200 | @ignore_unicode_prefix
201 | class SpacyAdvancedTokenize(TransformationFunction):
202 | """
203 | Tokenize input text using spacy and return the extra information.
204 | >>> spta = SpacyAdvancedTokenize()
205 | >>> spa = spta.func("en_core_web_sm", ["lower_", "text", "lang", "a"])
206 | >>> r = spa("Hi boo")
207 | >>> l = list(map(lambda d: sorted(d.items()), r))
208 | >>> l[0]
209 | [(u'a', None), (u'lang', '...'), (u'lower_', 'hi'), (u'text', 'Hi')]
210 | >>> l[1]
211 | [(u'a', None), (u'lang', '...'), (u'lower_', 'boo'), (u'text', 'boo')]
212 | """
213 | 
214 | default_fields = map(
215 | lambda x: x[0],
216 | inspect.getmembers(spacy.tokens.Token,
217 | lambda x: "<
257 | >>> sentences = pandas.Series(["Boo is happy", "Boo is sad", "confused."])
258 | >>> myFunc = NltkPos().func()
259 | >>> import math
260 | >>> myFunc(sentences).apply(math.ceil)
261 | 0 1...
262 | 1 0...
263 | 2 0...
264 | dtype: ...
265 | """
266 | 
267 | @classmethod
268 | def func(cls, *args):
269 | def inner(input_series):
270 | from nltk.sentiment.vader import SentimentIntensityAnalyzer
271 | # Hack until https://github.com/nteract/coffee_boat/issues/47
272 | try:
273 | sid = SentimentIntensityAnalyzer()
274 | except LookupError:
275 | import nltk
276 | nltk.download('vader_lexicon')
277 | sid = SentimentIntensityAnalyzer()
278 | result = input_series.apply(
279 | lambda sentence: sid.polarity_scores(sentence)['pos'])
280 | return result
281 | return inner
282 | 
283 | @classmethod
284 | def returnType(cls, *args):
285 | return DoubleType()
286 | 
287 | 
288 | functions_info["nltkpos"] = NltkPos
289 | 
290 | if __name__ == '__main__':
291 | import doctest
292 | doctest.testmod(optionflags=doctest.ELLIPSIS)
293 | 
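The classes in this file all follow one pattern: subclass TransformationFunction, return a callable from func, declare the SQL type from returnType, and register the class in functions_info under a lowercase name so the Scala side can look it up. As a minimal sketch of adding a new one (the Shout class is hypothetical, not part of sparklingml):

    from pyspark.sql.types import StringType
    from sparklingml.transformation_functions import (
        TransformationFunction, functions_info)


    class Shout(TransformationFunction):
        """Upper-cases its input; illustrative only."""

        @classmethod
        def func(cls, *args):
            def inner(input_string):
                return input_string.upper()
            return inner

        @classmethod
        def returnType(cls, *args):
            return StringType()


    # Register it so PythonRegistrationProvider.registerFunction can find it.
    functions_info["shout"] = Shout

The inherited evalType defaults to SQL_BATCHED_UDF; subclass ScalarVectorizedTransformationFunction instead for a pandas-based vectorized UDF.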
--------------------------------------------------------------------------------
/src/main/scala/com/sparklingpandas/sparklingml/CodeGenerator.scala:
--------------------------------------------------------------------------------
1 | package com.sparklingpandas.sparklingml
2 | 
3 | private[sparklingpandas] class CodeGenerator {
4 | val testRoot = "src/test/scala/com/sparklingpandas/sparklingml"
5 | val mainRoot = "src/main/scala/com/sparklingpandas/sparklingml"
6 | val pythonRoot = "sparklingml/"
7 | val scalaLicenseHeader = """/*
8 | | * Licensed to the Apache Software Foundation (ASF) under one or more
9 | | * contributor license agreements. See the NOTICE file distributed with
10 | | * this work for additional information regarding copyright ownership.
11 | | * The ASF licenses this file to You under the Apache License, Version 2.0
12 | | * (the "License"); you may not use this file except in compliance with
13 | | * the License. You may obtain a copy of the License at
14 | | *
15 | | * http://www.apache.org/licenses/LICENSE-2.0
16 | | *
17 | | * Unless required by applicable law or agreed to in writing, software
18 | | * distributed under the License is distributed on an "AS IS" BASIS,
19 | | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20 | | * See the License for the specific language governing permissions and
21 | | * limitations under the License.
22 | | */""".stripMargin('|')
23 | 
24 | val pythonLicenseHeader =
25 | """#
26 | |# Licensed to the Apache Software Foundation (ASF) under one or more
27 | |# contributor license agreements. See the NOTICE file distributed with
28 | |# this work for additional information regarding copyright ownership.
29 | |# The ASF licenses this file to You under the Apache License, Version 2.0
30 | |# (the "License"); you may not use this file except in compliance with
31 | |# the License. You may obtain a copy of the License at
32 | |#
33 | |# http://www.apache.org/licenses/LICENSE-2.0
34 | |#
35 | |# Unless required by applicable law or agreed to in writing, software
36 | |# distributed under the License is distributed on an "AS IS" BASIS,
37 | |# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
38 | |# See the License for the specific language governing permissions and
39 | |# limitations under the License.
40 | |#""".stripMargin('|')
41 | 
42 | val pythonDoctestFooter =
43 | """
44 | |if __name__ == "__main__":
45 | | import doctest
46 | | doctest.testmod(optionflags=doctest.ELLIPSIS)
47 | |""".stripMargin('|')
48 | 
49 | 
50 | }
51 | 
--------------------------------------------------------------------------------
/src/main/scala/com/sparklingpandas/sparklingml/feature/BasicPython.scala:
--------------------------------------------------------------------------------
1 | package com.sparklingpandas.sparklingml.feature
2 | 
3 | import org.apache.spark.ml.param._
4 | import org.apache.spark.ml.util.Identifiable
5 | import org.apache.spark.sql.types._
6 | 
7 | 
8 | import com.sparklingpandas.sparklingml.util.python.PythonTransformer
9 | 
10 | class NltkPosPython(override val uid: String) extends PythonTransformer {
11 | 
12 | def this() = this(Identifiable.randomUID("NltkPosPython"))
13 | 
14 | override val pythonFunctionName = "nltkpos"
15 | override protected def outputDataType = DoubleType
16 | override protected def validateInputType(inputType: DataType): Unit = {
17 | if (inputType != StringType) {
18 | throw new IllegalArgumentException(
19 | s"Expected input type StringType instead found ${inputType}")
20 | }
21 | }
22 | 
23 | override def copy(extra: ParamMap) = {
24 | defaultCopy(extra)
25 | }
26 | 
27 | def miniSerializeParams() = ""
28 | }
29 | 
30 | 
31 | class StrLenPlusKPython(override val uid: String) extends PythonTransformer {
32 | 
33 | final val k: IntParam = new IntParam(this, "k", "number to add to strlen")
34 | 
35 | /** @group getParam */
36 | final def getK: Int = $(k)
37 | 
38 | final def setK(value: Int): this.type = set(this.k, value)
39 | 
40 | def this() = this(Identifiable.randomUID("StrLenPlusKPython"))
41 | 
42 | override val pythonFunctionName = "strlenplusk"
43 | override protected def outputDataType = IntegerType
44 | override protected def validateInputType(inputType: DataType): Unit = {
45 | if (inputType != StringType) {
46 | throw new IllegalArgumentException(
47 | s"Expected input type StringType instead found ${inputType}")
48 | }
49 | }
50 | 
51 | override def copy(extra: ParamMap) = {
52 | 
defaultCopy(extra) 53 | } 54 | 55 | def miniSerializeParams() = { 56 | "[" + $(k) + "]" 57 | } 58 | } 59 | 60 | class SpacyTokenizePython(override val uid: String) extends PythonTransformer { 61 | 62 | final val lang = new Param[String](this, "lang", "language for tokenization") 63 | 64 | /** @group getParam */ 65 | final def getLang: String = $(lang) 66 | 67 | final def setLang(value: String): this.type = set(this.lang, value) 68 | 69 | def this() = this(Identifiable.randomUID("SpacyTokenizePython")) 70 | 71 | override val pythonFunctionName = "spacytokenize" 72 | override protected def outputDataType = ArrayType(StringType) 73 | override protected def validateInputType(inputType: DataType): Unit = { 74 | if (inputType != StringType) { 75 | throw new IllegalArgumentException( 76 | s"Expected input type StringType instead found ${inputType}") 77 | } 78 | } 79 | 80 | override def copy(extra: ParamMap) = { 81 | defaultCopy(extra) 82 | } 83 | 84 | def miniSerializeParams() = { 85 | "[\"" + $(lang) + "\"]" 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/main/scala/com/sparklingpandas/sparklingml/feature/LuceneAnalyzer.scala: -------------------------------------------------------------------------------- 1 | package com.sparklingpandas.sparklingml.feature 2 | 3 | import org.apache.spark.annotation.DeveloperApi 4 | import org.apache.spark.ml.UnaryTransformer 5 | import org.apache.spark.sql.Dataset 6 | import org.apache.spark.sql.types._ 7 | 8 | import org.apache.lucene.analysis.Analyzer 9 | import org.apache.lucene.analysis.tokenattributes.CharTermAttribute 10 | 11 | /** 12 | * Abstract trait for Lucene Transformer. An alternative option is to 13 | * use LuceneTextAnalyzerTransformer from the spark-solr project. 14 | */ 15 | @DeveloperApi 16 | trait LuceneTransformer[T <:LuceneTransformer[T]] 17 | extends UnaryTransformer[String, Array[String], T] { 18 | 19 | // Implement this function to construct an analyzer based on the provided settings. 
20 | def buildAnalyzer(): Analyzer
21 | 
22 | override def outputDataType: DataType = ArrayType(StringType)
23 | 
24 | override def validateInputType(inputType: DataType): Unit = {
25 | require(inputType.isInstanceOf[StringType],
26 | s"The input column must be StringType, but got $inputType.")
27 | }
28 | 
29 | override def createTransformFunc: String => Array[String] = {
30 | (inputText: String) => {
31 | val analyzer = buildAnalyzer()
32 | val inputStream = analyzer.tokenStream($(inputCol), inputText)
33 | val builder = Array.newBuilder[String]
34 | val charTermAttr = inputStream.addAttribute(classOf[CharTermAttribute])
35 | inputStream.reset()
36 | while (inputStream.incrementToken) builder += charTermAttr.toString
37 | inputStream.end()
38 | inputStream.close()
39 | builder.result()
40 | }
41 | }
42 | }
43 | 
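Every concrete analyzer built on this trait also ends up usable from PySpark through the generated wrappers in sparklingml/feature/lucene_analyzers.py. A minimal sketch, assuming a local SparkSession and the sparklingml jar on the Spark classpath, mirroring the doctests the generator below emits (EnglishAnalyzerLucene is one of the generated classes shown later in LuceneAnalyzers.scala):

    from pyspark.sql import SparkSession
    from sparklingml.feature.lucene_analyzers import EnglishAnalyzerLucene

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
    # Tokenize the "vals" column with Lucene's EnglishAnalyzer.
    transformer = EnglishAnalyzerLucene(inputCol="vals", outputCol="out")
    transformer.transform(df).show()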
--------------------------------------------------------------------------------
/src/main/scala/com/sparklingpandas/sparklingml/feature/LuceneAnalyzerGenerators.scala:
--------------------------------------------------------------------------------
1 | package com.sparklingpandas.sparklingml.feature
2 | 
3 | import java.lang.reflect.Modifier
4 | import java.io.PrintWriter
5 | 
6 | import scala.collection.JavaConverters._
7 | import scala.collection.mutable.StringBuilder
8 | 
9 | import com.sparklingpandas.sparklingml.CodeGenerator
10 | 
11 | import org.reflections.Reflections
12 | 
13 | 
14 | import org.apache.spark.annotation.DeveloperApi
15 | 
16 | import org.apache.lucene.analysis.{Analyzer, CharArraySet}
17 | 
18 | 
19 | /**
20 | * Code generator for LuceneAnalyzers (LuceneAnalyzers.scala). Run with
21 | * {{{
22 | * build/sbt "test:runMain com.sparklingpandas.sparklingml.feature.LuceneAnalyzerGenerators"
23 | * }}}
24 | */
25 | private[sparklingpandas] object LuceneAnalyzerGenerators extends CodeGenerator {
26 | 
27 | def writeScalaCode(code: String, file: String) = {
28 | val header = s"""${scalaLicenseHeader}
29 | |
30 | |package com.sparklingpandas.sparklingml.feature
31 | |
32 | |import org.apache.spark.ml.param._
33 | |import org.apache.spark.ml.util.Identifiable
34 | |
35 | |import org.apache.lucene.analysis.Analyzer
36 | |
37 | |import com.sparklingpandas.sparklingml.param._
38 | |
39 | |// DO NOT MODIFY THIS FILE!
40 | |// It was auto generated by LuceneAnalyzerGenerators.
41 | |
42 | """.stripMargin('|')
43 | val writer = new PrintWriter(file)
44 | writer.write(header)
45 | writer.write(code)
46 | writer.close()
47 | }
48 | 
49 | def writePythonCode(code: String, file: String) = {
50 | val pythonHeader = s"""${pythonLicenseHeader}
51 | |
52 | |# DO NOT MODIFY THIS FILE!
53 | |# It was auto generated by LuceneAnalyzerGenerators.
54 | |
55 | |from __future__ import unicode_literals
56 | |
57 | |from pyspark import keyword_only
58 | |from pyspark.ml.param import *
59 | |from pyspark.ml.param.shared import HasInputCol, HasOutputCol
60 | |# The shared params aren't really intended to be public currently.
61 | |from pyspark.ml.param.shared import *
62 | |from pyspark.ml.util import *
63 | |
64 | |from sparklingml.java_wrapper_ml import *
65 | |from sparklingml.param.shared import HasStopwords, HasStopwordCase
66 | |
67 | |""".stripMargin('|')
68 | val writer = new PrintWriter(file)
69 | writer.write(pythonHeader)
70 | writer.write(code)
71 | writer.write(pythonDoctestFooter)
72 | writer.close()
73 | }
74 | 
75 | def main(args: Array[String]): Unit = {
76 | val (testCode, transformerCode, pyCode) = generate()
77 | val testCodeFile = s"${testRoot}/feature/LuceneAnalyzersTests.scala"
78 | val transformerCodeFile = s"${mainRoot}/feature/LuceneAnalyzers.scala"
79 | val pyCodeFile = s"${pythonRoot}/feature/lucene_analyzers.py"
80 | 
81 | List((testCode, testCodeFile), (transformerCode, transformerCodeFile)).foreach {
82 | case (code: String, file: String) =>
83 | writeScalaCode(code, file)
84 | }
85 | writePythonCode(pyCode, pyCodeFile)
86 | }
87 | 
88 | def generate(): (String, String, String) = {
89 | val reflections = new Reflections("org.apache.lucene")
90 | val generalAnalyzers =
91 | reflections.getSubTypesOf(classOf[org.apache.lucene.analysis.Analyzer])
92 | .asScala.toList.sortBy(_.toString)
93 | val concreteAnalyzers =
94 | generalAnalyzers.filter(cls => !Modifier.isAbstract(cls.getModifiers))
95 | // A bit of a hack but strip out the factories and such
96 | val relevantAnalyzers = concreteAnalyzers.filter(cls =>
97 | !(cls.toString.contains("$") || cls.toString.contains("Factory")))
98 | val generated = relevantAnalyzers.map{ cls =>
99 | generateForClass(cls)
100 | }
101 | val testCode = new StringBuilder()
102 | val transformerCode = new StringBuilder()
103 | val pyCode = new StringBuilder()
104 | generated.foreach{case (test, transform, python) =>
105 | testCode ++= test
106 | transformerCode ++= transform
107 | pyCode ++= python
108 | }
109 | (testCode.toString, transformerCode.toString, pyCode.toString)
110 | }
111 | 
112 | def generateStopwordStage(cls: Class[_],
113 | constructorParametersSizes: List[Int],
114 | clsShortName: String,
115 | clsFullName: String): (String, String, String) = {
116 | val includeWarning = constructorParametersSizes.exists(_ > 1)
117 | val warning = if (includeWarning) {
118 | s"""
119 | | * There are additional parameters which cannot yet be
120 | | * controlled through this API.
121 | | * See https://github.com/sparklingpandas/sparklingml/issues/3"""
122 | .stripMargin('|')
123 | } else {
124 | ""
125 | }
126 | val testCode =
127 | s"""
128 | |/**
129 | | * A super simple test
130 | | */
131 | |class ${clsShortName}LuceneTest
132 | | extends LuceneStopwordTransformerTest[${clsShortName}Lucene] {
133 | | val transformer = new ${clsShortName}Lucene()
134 | |}
135 | |""".stripMargin('|')
136 | val code =
137 | s"""
138 | |/**
139 | | * A basic Transformer based on ${clsShortName}.
140 | | * Supports configuring stopwords.${warning}
141 | | */
142 | |
143 | |class ${clsShortName}Lucene(override val uid: String)
144 | | extends LuceneTransformer[${clsShortName}Lucene]
145 | | with HasStopwords with HasStopwordCase {
146 | |
147 | | def this() = this(Identifiable.randomUID("${clsShortName}"))
148 | |
149 | | def buildAnalyzer(): Analyzer = {
150 | | // In the future we can use getDefaultStopWords here to allow people
151 | | // to control the snowball stemmer distinctly from the stopwords.
152 | | // but that is a TODO for later.
153 | | if (isSet(stopwords)) { 154 | | new ${clsFullName}( 155 | | LuceneHelpers.wordstoCharArraySet($$(stopwords), !$$(stopwordCase))) 156 | | } else { 157 | | new ${clsFullName}() 158 | | } 159 | | } 160 | |} 161 | |""".stripMargin('|') 162 | val pyCode = 163 | s""" 164 | |class ${clsShortName}Lucene( 165 | | SparklingJavaTransformer, HasInputCol, HasOutputCol, 166 | | HasStopwords, HasStopwordCase): 167 | | \"\"\" 168 | | >>> from pyspark.sql import SparkSession 169 | | >>> spark = SparkSession.builder.master("local[2]").getOrCreate() 170 | | >>> df = spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) 171 | | >>> transformer = ${clsShortName}Lucene() 172 | | >>> transformer.setParams(inputCol="vals", outputCol="out") 173 | | ${clsShortName}Lucene_... 174 | | >>> result = transformer.transform(df) 175 | | >>> result.count() 176 | | 2 177 | | >>> transformer.setStopwordCase(True) 178 | | ${clsShortName}Lucene_... 179 | | >>> result = transformer.transform(df) 180 | | >>> result.count() 181 | | 2 182 | | \"\"\" 183 | | package_name = "com.sparklingpandas.sparklingml.feature" 184 | | class_name = "${clsShortName}Lucene" 185 | | transformer_name = package_name + "." + class_name 186 | | 187 | | @keyword_only 188 | | def __init__(self, inputCol=None, outputCol=None, 189 | | stopwords=None, stopwordCase=False): 190 | | \"\"\" 191 | | __init__(self, inputCol=None, outputCol=None, 192 | | stopwords=None, stopwordCase=False) 193 | | \"\"\" 194 | | super(${clsShortName}Lucene, self).__init__() 195 | | self._setDefault(stopwordCase=False) 196 | | kwargs = self._input_kwargs 197 | | self.setParams(**kwargs) 198 | | 199 | | @keyword_only 200 | | def setParams(self, inputCol=None, outputCol=None, 201 | | stopwords=None, stopwordCase=False): 202 | | \"\"\" 203 | | setParams(inputCol=None, outputCol=None, 204 | | stopwords=None, stopwordCase=False) 205 | | \"\"\" 206 | | kwargs = self._input_kwargs 207 | | return self._set(**kwargs) 208 | | 209 | |""".stripMargin('|') 210 | (testCode, code, pyCode) 211 | } 212 | 213 | def generateZeroArgStage(clsShortName: String, clsFullName: String): 214 | (String, String, String) = { 215 | val testCode = 216 | s""" 217 | |/** 218 | | * A super simple test 219 | | */ 220 | |class ${clsShortName}LuceneTest 221 | | extends LuceneTransformerTest[${clsShortName}Lucene] { 222 | | val transformer = new ${clsShortName}Lucene() 223 | |} 224 | |""".stripMargin('|') 225 | val code = 226 | s""" 227 | |/** 228 | | * A basic Transformer based on ${clsShortName} - does not support 229 | | * any configuration properties. 230 | | * See https://github.com/sparklingpandas/sparklingml/issues/3 231 | | * & LuceneAnalyzerGenerators for details. 232 | | */ 233 | | 234 | |class ${clsShortName}Lucene(override val uid: String) 235 | | extends LuceneTransformer[${clsShortName}Lucene] { 236 | | 237 | | def this() = this(Identifiable.randomUID("${clsShortName}")) 238 | | 239 | | def buildAnalyzer(): Analyzer = { 240 | | new ${clsFullName}() 241 | | } 242 | |} 243 | |""".stripMargin('|') 244 | val pyCode = 245 | s""" 246 | |class ${clsShortName}Lucene( 247 | | SparklingJavaTransformer, HasInputCol, HasOutputCol): 248 | | \"\"\" 249 | | >>> from pyspark.sql import SparkSession 250 | | >>> spark = SparkSession.builder.master("local[2]").getOrCreate() 251 | | >>> df = spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"]) 252 | | >>> transformer = ${clsShortName}Lucene() 253 | | >>> transformer.setParams(inputCol="vals", outputCol="out") 254 | | ${clsShortName}Lucene_... 
255 | | >>> result = transformer.transform(df)
256 | | >>> result.count()
257 | | 2
258 | | \"\"\"
259 | | package_name = "com.sparklingpandas.sparklingml.feature"
260 | | class_name = "${clsShortName}Lucene"
261 | | transformer_name = package_name + "." + class_name
262 | |
263 | | @keyword_only
264 | | def __init__(self, inputCol=None, outputCol=None):
265 | | \"\"\"
266 | | __init__(self, inputCol=None, outputCol=None)
267 | | \"\"\"
268 | | super(${clsShortName}Lucene, self).__init__()
269 | | kwargs = self._input_kwargs
270 | | self.setParams(**kwargs)
271 | |
272 | | @keyword_only
273 | | def setParams(self, inputCol=None, outputCol=None):
274 | | \"\"\"
275 | | setParams(inputCol=None, outputCol=None)
276 | | \"\"\"
277 | | kwargs = self._input_kwargs
278 | | return self._set(**kwargs)
279 | |
280 | |""".stripMargin('|')
281 | (testCode, code, pyCode)
282 | }
283 | 
284 | def generateForClass(cls: Class[_]): (String, String, String) = {
285 | import scala.reflect.runtime.universe._
286 | val rm = scala.reflect.runtime.currentMirror
287 | 
288 | val clsSymbol = rm.classSymbol(cls)
289 | val clsType = clsSymbol.toType
290 | val clsFullName = clsSymbol.fullName
291 | val clsShortName = clsSymbol.name.toString
292 | val constructors = clsType.members.collect{
293 | case m: MethodSymbol if m.isConstructor && m.isPublic => m }
294 | // Once we have the debug version, constructorParametersLists should be useful.
295 | val constructorParametersLists = constructors.map(_.paramLists).toList
296 | val constructorParametersSizes = constructorParametersLists.map(_(0).size)
297 | val javaReflectionConstructors = cls.getConstructors().toList
298 | val publicJavaReflectionConstructors =
299 | javaReflectionConstructors.filter(cls => Modifier.isPublic(cls.getModifiers()))
300 | val constructorParameterTypes = publicJavaReflectionConstructors.map(_.getParameterTypes())
301 | // We do this in Java as well since some of the scala reflection magic returns private
302 | // constructors even though it's filtered for public. See CustomAnalyzer for an example.
303 | val javaConstructorParametersSizes = constructorParameterTypes.map(_.size)
304 | // Since this isn't built with -parameters by default :(
305 | // we'd need a local version built with it to auto generate
306 | // the code here with the right parameters.
307 | // https://docs.oracle.com/javase/tutorial/reflect/member/methodparameterreflection.html
308 | // For now we could dump the class names and go from there,
309 | // or we could play a game of pin the field on the constructor.
310 | // A local build sounds like the best plan; let's do that later.
311 | 
312 | // Special case for handling stopword analyzers
313 | val baseClasses = clsType.baseClasses
314 | // Normally we'd do a check with <:< but the Lucene types have multiple
315 | // StopwordAnalyzerBase's that don't inherit from each other.
316 | val isStopWordAnalyzer = baseClasses.exists(_.asClass.fullName.contains("Stopword"))
317 | 
318 | val charsetConstructors =
319 | constructorParameterTypes.filter(! _.exists(_ != classOf[CharArraySet]))
320 | val charsetConstructorSizes = charsetConstructors.map(_.size)
321 | 
322 | // If it is a stop word analyzer and has a constructor taking a single CharArraySet
323 | // then it takes the stopwords as a parameter.
324 | if (isStopWordAnalyzer && charsetConstructorSizes.contains(1)) {
325 | // Any additional constructor parameters are noted in the generated warning.
326 | generateStopwordStage(cls, constructorParametersSizes, clsShortName, clsFullName)
327 | } else if (constructorParametersSizes.contains(0) &&
328 | javaConstructorParametersSizes.contains(0)) {
329 | generateZeroArgStage(clsShortName, clsFullName)
330 | } else {
331 | ("", s"""
332 | |/* There is no default zero arg constructor for
333 | | *${clsFullName}.
334 | | */
335 | |""".stripMargin('|'), "")
336 | }
337 | }
338 | }
339 | 
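For the stopword-capable stages this generator emits, HasStopwords and HasStopwordCase surface as ordinary Spark params on the generated wrappers. A small sketch of exercising them, under the same assumptions as the earlier example (local session, sparklingml jar on the classpath):

    from pyspark.sql import SparkSession
    from sparklingml.feature.lucene_analyzers import EnglishAnalyzerLucene

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
    transformer = EnglishAnalyzerLucene(inputCol="vals", outputCol="out")
    # Filter "boo" out of the tokens; stopwordCase=False ignores case.
    transformer.setStopwords(["boo"]).setStopwordCase(False)
    transformer.transform(df).show()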
--------------------------------------------------------------------------------
/src/main/scala/com/sparklingpandas/sparklingml/feature/LuceneAnalyzers.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | 
18 | package com.sparklingpandas.sparklingml.feature
19 | 
20 | import org.apache.spark.ml.param._
21 | import org.apache.spark.ml.util.Identifiable
22 | 
23 | import org.apache.lucene.analysis.Analyzer
24 | 
25 | import com.sparklingpandas.sparklingml.param._
26 | 
27 | // DO NOT MODIFY THIS FILE!
28 | // It was auto generated by LuceneAnalyzerGenerators.
29 | 
30 | 
31 | /**
32 | * A basic Transformer based on ArabicAnalyzer.
33 | * Supports configuring stopwords.
34 | * There are additional parameters which cannot yet be
35 | * controlled through this API.
36 | * See https://github.com/sparklingpandas/sparklingml/issues/3
37 | */
38 | 
39 | class ArabicAnalyzerLucene(override val uid: String)
40 | extends LuceneTransformer[ArabicAnalyzerLucene]
41 | with HasStopwords with HasStopwordCase {
42 | 
43 | def this() = this(Identifiable.randomUID("ArabicAnalyzer"))
44 | 
45 | def buildAnalyzer(): Analyzer = {
46 | // In the future we can use getDefaultStopWords here to allow people
47 | // to control the snowball stemmer distinctly from the stopwords.
48 | // but that is a TODO for later.
49 | if (isSet(stopwords)) {
50 | new org.apache.lucene.analysis.ar.ArabicAnalyzer(
51 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
52 | } else {
53 | new org.apache.lucene.analysis.ar.ArabicAnalyzer()
54 | }
55 | }
56 | }
57 | 
58 | /**
59 | * A basic Transformer based on BulgarianAnalyzer.
60 | * Supports configuring stopwords.
61 | * There are additional parameters which can not yet be contro 62 | lled through this API 63 | * See https://github.com/sparklingpandas/sparklingml/issues/3 64 | */ 65 | 66 | class BulgarianAnalyzerLucene(override val uid: String) 67 | extends LuceneTransformer[BulgarianAnalyzerLucene] 68 | with HasStopwords with HasStopwordCase { 69 | 70 | def this() = this(Identifiable.randomUID("BulgarianAnalyzer")) 71 | 72 | def buildAnalyzer(): Analyzer = { 73 | // In the future we can use getDefaultStopWords here to allow people 74 | // to control the snowball stemmer distinctly from the stopwords. 75 | // but that is a TODO for later. 76 | if (isSet(stopwords)) { 77 | new org.apache.lucene.analysis.bg.BulgarianAnalyzer( 78 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 79 | } else { 80 | new org.apache.lucene.analysis.bg.BulgarianAnalyzer() 81 | } 82 | } 83 | } 84 | 85 | /** 86 | * A basic Transformer based on BrazilianAnalyzer. 87 | * Supports configuring stopwords. 88 | * There are additional parameters which can not yet be contro 89 | lled through this API 90 | * See https://github.com/sparklingpandas/sparklingml/issues/3 91 | */ 92 | 93 | class BrazilianAnalyzerLucene(override val uid: String) 94 | extends LuceneTransformer[BrazilianAnalyzerLucene] 95 | with HasStopwords with HasStopwordCase { 96 | 97 | def this() = this(Identifiable.randomUID("BrazilianAnalyzer")) 98 | 99 | def buildAnalyzer(): Analyzer = { 100 | // In the future we can use getDefaultStopWords here to allow people 101 | // to control the snowball stemmer distinctly from the stopwords. 102 | // but that is a TODO for later. 103 | if (isSet(stopwords)) { 104 | new org.apache.lucene.analysis.br.BrazilianAnalyzer( 105 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 106 | } else { 107 | new org.apache.lucene.analysis.br.BrazilianAnalyzer() 108 | } 109 | } 110 | } 111 | 112 | /** 113 | * A basic Transformer based on CatalanAnalyzer. 114 | * Supports configuring stopwords. 115 | * There are additional parameters which can not yet be contro 116 | lled through this API 117 | * See https://github.com/sparklingpandas/sparklingml/issues/3 118 | */ 119 | 120 | class CatalanAnalyzerLucene(override val uid: String) 121 | extends LuceneTransformer[CatalanAnalyzerLucene] 122 | with HasStopwords with HasStopwordCase { 123 | 124 | def this() = this(Identifiable.randomUID("CatalanAnalyzer")) 125 | 126 | def buildAnalyzer(): Analyzer = { 127 | // In the future we can use getDefaultStopWords here to allow people 128 | // to control the snowball stemmer distinctly from the stopwords. 129 | // but that is a TODO for later. 130 | if (isSet(stopwords)) { 131 | new org.apache.lucene.analysis.ca.CatalanAnalyzer( 132 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 133 | } else { 134 | new org.apache.lucene.analysis.ca.CatalanAnalyzer() 135 | } 136 | } 137 | } 138 | 139 | /** 140 | * A basic Transformer based on CJKAnalyzer. 141 | * Supports configuring stopwords. 142 | */ 143 | 144 | class CJKAnalyzerLucene(override val uid: String) 145 | extends LuceneTransformer[CJKAnalyzerLucene] 146 | with HasStopwords with HasStopwordCase { 147 | 148 | def this() = this(Identifiable.randomUID("CJKAnalyzer")) 149 | 150 | def buildAnalyzer(): Analyzer = { 151 | // In the future we can use getDefaultStopWords here to allow people 152 | // to control the snowball stemmer distinctly from the stopwords. 153 | // but that is a TODO for later. 
154 | if (isSet(stopwords)) { 155 | new org.apache.lucene.analysis.cjk.CJKAnalyzer( 156 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 157 | } else { 158 | new org.apache.lucene.analysis.cjk.CJKAnalyzer() 159 | } 160 | } 161 | } 162 | 163 | /** 164 | * A basic Transformer based on SoraniAnalyzer. 165 | * Supports configuring stopwords. 166 | * There are additional parameters which can not yet be contro 167 | lled through this API 168 | * See https://github.com/sparklingpandas/sparklingml/issues/3 169 | */ 170 | 171 | class SoraniAnalyzerLucene(override val uid: String) 172 | extends LuceneTransformer[SoraniAnalyzerLucene] 173 | with HasStopwords with HasStopwordCase { 174 | 175 | def this() = this(Identifiable.randomUID("SoraniAnalyzer")) 176 | 177 | def buildAnalyzer(): Analyzer = { 178 | // In the future we can use getDefaultStopWords here to allow people 179 | // to control the snowball stemmer distinctly from the stopwords. 180 | // but that is a TODO for later. 181 | if (isSet(stopwords)) { 182 | new org.apache.lucene.analysis.ckb.SoraniAnalyzer( 183 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 184 | } else { 185 | new org.apache.lucene.analysis.ckb.SoraniAnalyzer() 186 | } 187 | } 188 | } 189 | 190 | /** 191 | * A basic Transformer based on SmartChineseAnalyzer - does not support 192 | * any configuration properties. 193 | * See https://github.com/sparklingpandas/sparklingml/issues/3 194 | * & LuceneAnalyzerGenerators for details. 195 | */ 196 | 197 | class SmartChineseAnalyzerLucene(override val uid: String) 198 | extends LuceneTransformer[SmartChineseAnalyzerLucene] { 199 | 200 | def this() = this(Identifiable.randomUID("SmartChineseAnalyzer")) 201 | 202 | def buildAnalyzer(): Analyzer = { 203 | new org.apache.lucene.analysis.cn.smart.SmartChineseAnalyzer() 204 | } 205 | } 206 | 207 | /** 208 | * A basic Transformer based on KeywordAnalyzer - does not support 209 | * any configuration properties. 210 | * See https://github.com/sparklingpandas/sparklingml/issues/3 211 | * & LuceneAnalyzerGenerators for details. 212 | */ 213 | 214 | class KeywordAnalyzerLucene(override val uid: String) 215 | extends LuceneTransformer[KeywordAnalyzerLucene] { 216 | 217 | def this() = this(Identifiable.randomUID("KeywordAnalyzer")) 218 | 219 | def buildAnalyzer(): Analyzer = { 220 | new org.apache.lucene.analysis.core.KeywordAnalyzer() 221 | } 222 | } 223 | 224 | /** 225 | * A basic Transformer based on SimpleAnalyzer - does not support 226 | * any configuration properties. 227 | * See https://github.com/sparklingpandas/sparklingml/issues/3 228 | * & LuceneAnalyzerGenerators for details. 229 | */ 230 | 231 | class SimpleAnalyzerLucene(override val uid: String) 232 | extends LuceneTransformer[SimpleAnalyzerLucene] { 233 | 234 | def this() = this(Identifiable.randomUID("SimpleAnalyzer")) 235 | 236 | def buildAnalyzer(): Analyzer = { 237 | new org.apache.lucene.analysis.core.SimpleAnalyzer() 238 | } 239 | } 240 | 241 | /** 242 | * A basic Transformer based on StopAnalyzer. 243 | * Supports configuring stopwords. 244 | */ 245 | 246 | class StopAnalyzerLucene(override val uid: String) 247 | extends LuceneTransformer[StopAnalyzerLucene] 248 | with HasStopwords with HasStopwordCase { 249 | 250 | def this() = this(Identifiable.randomUID("StopAnalyzer")) 251 | 252 | def buildAnalyzer(): Analyzer = { 253 | // In the future we can use getDefaultStopWords here to allow people 254 | // to control the snowball stemmer distinctly from the stopwords. 
255 | // but that is a TODO for later. 256 | if (isSet(stopwords)) { 257 | new org.apache.lucene.analysis.core.StopAnalyzer( 258 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 259 | } else { 260 | new org.apache.lucene.analysis.core.StopAnalyzer() 261 | } 262 | } 263 | } 264 | 265 | /** 266 | * A basic Transformer based on UnicodeWhitespaceAnalyzer - does not support 267 | * any configuration properties. 268 | * See https://github.com/sparklingpandas/sparklingml/issues/3 269 | * & LuceneAnalyzerGenerators for details. 270 | */ 271 | 272 | class UnicodeWhitespaceAnalyzerLucene(override val uid: String) 273 | extends LuceneTransformer[UnicodeWhitespaceAnalyzerLucene] { 274 | 275 | def this() = this(Identifiable.randomUID("UnicodeWhitespaceAnalyzer")) 276 | 277 | def buildAnalyzer(): Analyzer = { 278 | new org.apache.lucene.analysis.core.UnicodeWhitespaceAnalyzer() 279 | } 280 | } 281 | 282 | /** 283 | * A basic Transformer based on WhitespaceAnalyzer - does not support 284 | * any configuration properties. 285 | * See https://github.com/sparklingpandas/sparklingml/issues/3 286 | * & LuceneAnalyzerGenerators for details. 287 | */ 288 | 289 | class WhitespaceAnalyzerLucene(override val uid: String) 290 | extends LuceneTransformer[WhitespaceAnalyzerLucene] { 291 | 292 | def this() = this(Identifiable.randomUID("WhitespaceAnalyzer")) 293 | 294 | def buildAnalyzer(): Analyzer = { 295 | new org.apache.lucene.analysis.core.WhitespaceAnalyzer() 296 | } 297 | } 298 | 299 | /* There is no default zero arg constructor for 300 | *org.apache.lucene.analysis.custom.CustomAnalyzer. 301 | */ 302 | 303 | /** 304 | * A basic Transformer based on CzechAnalyzer. 305 | * Supports configuring stopwords. 306 | * There are additional parameters which can not yet be contro 307 | lled through this API 308 | * See https://github.com/sparklingpandas/sparklingml/issues/3 309 | */ 310 | 311 | class CzechAnalyzerLucene(override val uid: String) 312 | extends LuceneTransformer[CzechAnalyzerLucene] 313 | with HasStopwords with HasStopwordCase { 314 | 315 | def this() = this(Identifiable.randomUID("CzechAnalyzer")) 316 | 317 | def buildAnalyzer(): Analyzer = { 318 | // In the future we can use getDefaultStopWords here to allow people 319 | // to control the snowball stemmer distinctly from the stopwords. 320 | // but that is a TODO for later. 321 | if (isSet(stopwords)) { 322 | new org.apache.lucene.analysis.cz.CzechAnalyzer( 323 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 324 | } else { 325 | new org.apache.lucene.analysis.cz.CzechAnalyzer() 326 | } 327 | } 328 | } 329 | 330 | /** 331 | * A basic Transformer based on DanishAnalyzer. 332 | * Supports configuring stopwords. 333 | * There are additional parameters which can not yet be contro 334 | lled through this API 335 | * See https://github.com/sparklingpandas/sparklingml/issues/3 336 | */ 337 | 338 | class DanishAnalyzerLucene(override val uid: String) 339 | extends LuceneTransformer[DanishAnalyzerLucene] 340 | with HasStopwords with HasStopwordCase { 341 | 342 | def this() = this(Identifiable.randomUID("DanishAnalyzer")) 343 | 344 | def buildAnalyzer(): Analyzer = { 345 | // In the future we can use getDefaultStopWords here to allow people 346 | // to control the snowball stemmer distinctly from the stopwords. 347 | // but that is a TODO for later. 
348 | if (isSet(stopwords)) { 349 | new org.apache.lucene.analysis.da.DanishAnalyzer( 350 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 351 | } else { 352 | new org.apache.lucene.analysis.da.DanishAnalyzer() 353 | } 354 | } 355 | } 356 | 357 | /** 358 | * A basic Transformer based on GermanAnalyzer. 359 | * Supports configuring stopwords. 360 | * There are additional parameters which can not yet be contro 361 | lled through this API 362 | * See https://github.com/sparklingpandas/sparklingml/issues/3 363 | */ 364 | 365 | class GermanAnalyzerLucene(override val uid: String) 366 | extends LuceneTransformer[GermanAnalyzerLucene] 367 | with HasStopwords with HasStopwordCase { 368 | 369 | def this() = this(Identifiable.randomUID("GermanAnalyzer")) 370 | 371 | def buildAnalyzer(): Analyzer = { 372 | // In the future we can use getDefaultStopWords here to allow people 373 | // to control the snowball stemmer distinctly from the stopwords. 374 | // but that is a TODO for later. 375 | if (isSet(stopwords)) { 376 | new org.apache.lucene.analysis.de.GermanAnalyzer( 377 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 378 | } else { 379 | new org.apache.lucene.analysis.de.GermanAnalyzer() 380 | } 381 | } 382 | } 383 | 384 | /** 385 | * A basic Transformer based on GreekAnalyzer. 386 | * Supports configuring stopwords. 387 | */ 388 | 389 | class GreekAnalyzerLucene(override val uid: String) 390 | extends LuceneTransformer[GreekAnalyzerLucene] 391 | with HasStopwords with HasStopwordCase { 392 | 393 | def this() = this(Identifiable.randomUID("GreekAnalyzer")) 394 | 395 | def buildAnalyzer(): Analyzer = { 396 | // In the future we can use getDefaultStopWords here to allow people 397 | // to control the snowball stemmer distinctly from the stopwords. 398 | // but that is a TODO for later. 399 | if (isSet(stopwords)) { 400 | new org.apache.lucene.analysis.el.GreekAnalyzer( 401 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 402 | } else { 403 | new org.apache.lucene.analysis.el.GreekAnalyzer() 404 | } 405 | } 406 | } 407 | 408 | /** 409 | * A basic Transformer based on EnglishAnalyzer. 410 | * Supports configuring stopwords. 411 | * There are additional parameters which can not yet be contro 412 | lled through this API 413 | * See https://github.com/sparklingpandas/sparklingml/issues/3 414 | */ 415 | 416 | class EnglishAnalyzerLucene(override val uid: String) 417 | extends LuceneTransformer[EnglishAnalyzerLucene] 418 | with HasStopwords with HasStopwordCase { 419 | 420 | def this() = this(Identifiable.randomUID("EnglishAnalyzer")) 421 | 422 | def buildAnalyzer(): Analyzer = { 423 | // In the future we can use getDefaultStopWords here to allow people 424 | // to control the snowball stemmer distinctly from the stopwords. 425 | // but that is a TODO for later. 426 | if (isSet(stopwords)) { 427 | new org.apache.lucene.analysis.en.EnglishAnalyzer( 428 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 429 | } else { 430 | new org.apache.lucene.analysis.en.EnglishAnalyzer() 431 | } 432 | } 433 | } 434 | 435 | /** 436 | * A basic Transformer based on SpanishAnalyzer. 437 | * Supports configuring stopwords. 
438 | * There are additional parameters which can not yet be contro 439 | lled through this API 440 | * See https://github.com/sparklingpandas/sparklingml/issues/3 441 | */ 442 | 443 | class SpanishAnalyzerLucene(override val uid: String) 444 | extends LuceneTransformer[SpanishAnalyzerLucene] 445 | with HasStopwords with HasStopwordCase { 446 | 447 | def this() = this(Identifiable.randomUID("SpanishAnalyzer")) 448 | 449 | def buildAnalyzer(): Analyzer = { 450 | // In the future we can use getDefaultStopWords here to allow people 451 | // to control the snowball stemmer distinctly from the stopwords. 452 | // but that is a TODO for later. 453 | if (isSet(stopwords)) { 454 | new org.apache.lucene.analysis.es.SpanishAnalyzer( 455 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 456 | } else { 457 | new org.apache.lucene.analysis.es.SpanishAnalyzer() 458 | } 459 | } 460 | } 461 | 462 | /** 463 | * A basic Transformer based on BasqueAnalyzer. 464 | * Supports configuring stopwords. 465 | * There are additional parameters which can not yet be contro 466 | lled through this API 467 | * See https://github.com/sparklingpandas/sparklingml/issues/3 468 | */ 469 | 470 | class BasqueAnalyzerLucene(override val uid: String) 471 | extends LuceneTransformer[BasqueAnalyzerLucene] 472 | with HasStopwords with HasStopwordCase { 473 | 474 | def this() = this(Identifiable.randomUID("BasqueAnalyzer")) 475 | 476 | def buildAnalyzer(): Analyzer = { 477 | // In the future we can use getDefaultStopWords here to allow people 478 | // to control the snowball stemmer distinctly from the stopwords. 479 | // but that is a TODO for later. 480 | if (isSet(stopwords)) { 481 | new org.apache.lucene.analysis.eu.BasqueAnalyzer( 482 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 483 | } else { 484 | new org.apache.lucene.analysis.eu.BasqueAnalyzer() 485 | } 486 | } 487 | } 488 | 489 | /** 490 | * A basic Transformer based on PersianAnalyzer. 491 | * Supports configuring stopwords. 492 | */ 493 | 494 | class PersianAnalyzerLucene(override val uid: String) 495 | extends LuceneTransformer[PersianAnalyzerLucene] 496 | with HasStopwords with HasStopwordCase { 497 | 498 | def this() = this(Identifiable.randomUID("PersianAnalyzer")) 499 | 500 | def buildAnalyzer(): Analyzer = { 501 | // In the future we can use getDefaultStopWords here to allow people 502 | // to control the snowball stemmer distinctly from the stopwords. 503 | // but that is a TODO for later. 504 | if (isSet(stopwords)) { 505 | new org.apache.lucene.analysis.fa.PersianAnalyzer( 506 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 507 | } else { 508 | new org.apache.lucene.analysis.fa.PersianAnalyzer() 509 | } 510 | } 511 | } 512 | 513 | /** 514 | * A basic Transformer based on FinnishAnalyzer. 515 | * Supports configuring stopwords. 516 | * There are additional parameters which can not yet be contro 517 | lled through this API 518 | * See https://github.com/sparklingpandas/sparklingml/issues/3 519 | */ 520 | 521 | class FinnishAnalyzerLucene(override val uid: String) 522 | extends LuceneTransformer[FinnishAnalyzerLucene] 523 | with HasStopwords with HasStopwordCase { 524 | 525 | def this() = this(Identifiable.randomUID("FinnishAnalyzer")) 526 | 527 | def buildAnalyzer(): Analyzer = { 528 | // In the future we can use getDefaultStopWords here to allow people 529 | // to control the snowball stemmer distinctly from the stopwords. 530 | // but that is a TODO for later. 
531 | if (isSet(stopwords)) { 532 | new org.apache.lucene.analysis.fi.FinnishAnalyzer( 533 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 534 | } else { 535 | new org.apache.lucene.analysis.fi.FinnishAnalyzer() 536 | } 537 | } 538 | } 539 | 540 | /** 541 | * A basic Transformer based on FrenchAnalyzer. 542 | * Supports configuring stopwords. 543 | * There are additional parameters which can not yet be contro 544 | lled through this API 545 | * See https://github.com/sparklingpandas/sparklingml/issues/3 546 | */ 547 | 548 | class FrenchAnalyzerLucene(override val uid: String) 549 | extends LuceneTransformer[FrenchAnalyzerLucene] 550 | with HasStopwords with HasStopwordCase { 551 | 552 | def this() = this(Identifiable.randomUID("FrenchAnalyzer")) 553 | 554 | def buildAnalyzer(): Analyzer = { 555 | // In the future we can use getDefaultStopWords here to allow people 556 | // to control the snowball stemmer distinctly from the stopwords. 557 | // but that is a TODO for later. 558 | if (isSet(stopwords)) { 559 | new org.apache.lucene.analysis.fr.FrenchAnalyzer( 560 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 561 | } else { 562 | new org.apache.lucene.analysis.fr.FrenchAnalyzer() 563 | } 564 | } 565 | } 566 | 567 | /** 568 | * A basic Transformer based on IrishAnalyzer. 569 | * Supports configuring stopwords. 570 | * There are additional parameters which can not yet be contro 571 | lled through this API 572 | * See https://github.com/sparklingpandas/sparklingml/issues/3 573 | */ 574 | 575 | class IrishAnalyzerLucene(override val uid: String) 576 | extends LuceneTransformer[IrishAnalyzerLucene] 577 | with HasStopwords with HasStopwordCase { 578 | 579 | def this() = this(Identifiable.randomUID("IrishAnalyzer")) 580 | 581 | def buildAnalyzer(): Analyzer = { 582 | // In the future we can use getDefaultStopWords here to allow people 583 | // to control the snowball stemmer distinctly from the stopwords. 584 | // but that is a TODO for later. 585 | if (isSet(stopwords)) { 586 | new org.apache.lucene.analysis.ga.IrishAnalyzer( 587 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 588 | } else { 589 | new org.apache.lucene.analysis.ga.IrishAnalyzer() 590 | } 591 | } 592 | } 593 | 594 | /** 595 | * A basic Transformer based on GalicianAnalyzer. 596 | * Supports configuring stopwords. 597 | * There are additional parameters which can not yet be contro 598 | lled through this API 599 | * See https://github.com/sparklingpandas/sparklingml/issues/3 600 | */ 601 | 602 | class GalicianAnalyzerLucene(override val uid: String) 603 | extends LuceneTransformer[GalicianAnalyzerLucene] 604 | with HasStopwords with HasStopwordCase { 605 | 606 | def this() = this(Identifiable.randomUID("GalicianAnalyzer")) 607 | 608 | def buildAnalyzer(): Analyzer = { 609 | // In the future we can use getDefaultStopWords here to allow people 610 | // to control the snowball stemmer distinctly from the stopwords. 611 | // but that is a TODO for later. 612 | if (isSet(stopwords)) { 613 | new org.apache.lucene.analysis.gl.GalicianAnalyzer( 614 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 615 | } else { 616 | new org.apache.lucene.analysis.gl.GalicianAnalyzer() 617 | } 618 | } 619 | } 620 | 621 | /** 622 | * A basic Transformer based on HindiAnalyzer. 623 | * Supports configuring stopwords. 
624 | * There are additional parameters which can not yet be contro 625 | lled through this API 626 | * See https://github.com/sparklingpandas/sparklingml/issues/3 627 | */ 628 | 629 | class HindiAnalyzerLucene(override val uid: String) 630 | extends LuceneTransformer[HindiAnalyzerLucene] 631 | with HasStopwords with HasStopwordCase { 632 | 633 | def this() = this(Identifiable.randomUID("HindiAnalyzer")) 634 | 635 | def buildAnalyzer(): Analyzer = { 636 | // In the future we can use getDefaultStopWords here to allow people 637 | // to control the snowball stemmer distinctly from the stopwords. 638 | // but that is a TODO for later. 639 | if (isSet(stopwords)) { 640 | new org.apache.lucene.analysis.hi.HindiAnalyzer( 641 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 642 | } else { 643 | new org.apache.lucene.analysis.hi.HindiAnalyzer() 644 | } 645 | } 646 | } 647 | 648 | /** 649 | * A basic Transformer based on HungarianAnalyzer. 650 | * Supports configuring stopwords. 651 | * There are additional parameters which can not yet be contro 652 | lled through this API 653 | * See https://github.com/sparklingpandas/sparklingml/issues/3 654 | */ 655 | 656 | class HungarianAnalyzerLucene(override val uid: String) 657 | extends LuceneTransformer[HungarianAnalyzerLucene] 658 | with HasStopwords with HasStopwordCase { 659 | 660 | def this() = this(Identifiable.randomUID("HungarianAnalyzer")) 661 | 662 | def buildAnalyzer(): Analyzer = { 663 | // In the future we can use getDefaultStopWords here to allow people 664 | // to control the snowball stemmer distinctly from the stopwords. 665 | // but that is a TODO for later. 666 | if (isSet(stopwords)) { 667 | new org.apache.lucene.analysis.hu.HungarianAnalyzer( 668 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 669 | } else { 670 | new org.apache.lucene.analysis.hu.HungarianAnalyzer() 671 | } 672 | } 673 | } 674 | 675 | /** 676 | * A basic Transformer based on ArmenianAnalyzer. 677 | * Supports configuring stopwords. 678 | * There are additional parameters which can not yet be contro 679 | lled through this API 680 | * See https://github.com/sparklingpandas/sparklingml/issues/3 681 | */ 682 | 683 | class ArmenianAnalyzerLucene(override val uid: String) 684 | extends LuceneTransformer[ArmenianAnalyzerLucene] 685 | with HasStopwords with HasStopwordCase { 686 | 687 | def this() = this(Identifiable.randomUID("ArmenianAnalyzer")) 688 | 689 | def buildAnalyzer(): Analyzer = { 690 | // In the future we can use getDefaultStopWords here to allow people 691 | // to control the snowball stemmer distinctly from the stopwords. 692 | // but that is a TODO for later. 693 | if (isSet(stopwords)) { 694 | new org.apache.lucene.analysis.hy.ArmenianAnalyzer( 695 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase))) 696 | } else { 697 | new org.apache.lucene.analysis.hy.ArmenianAnalyzer() 698 | } 699 | } 700 | } 701 | 702 | /** 703 | * A basic Transformer based on IndonesianAnalyzer. 704 | * Supports configuring stopwords. 
705 | * There are additional parameters which can not yet be
706 | * controlled through this API
707 | * See https://github.com/sparklingpandas/sparklingml/issues/3
708 | */
709 | 
710 | class IndonesianAnalyzerLucene(override val uid: String)
711 | extends LuceneTransformer[IndonesianAnalyzerLucene]
712 | with HasStopwords with HasStopwordCase {
713 | 
714 | def this() = this(Identifiable.randomUID("IndonesianAnalyzer"))
715 | 
716 | def buildAnalyzer(): Analyzer = {
717 | // In the future we can use getDefaultStopWords here to allow people
718 | // to control the snowball stemmer distinctly from the stopwords.
719 | // but that is a TODO for later.
720 | if (isSet(stopwords)) {
721 | new org.apache.lucene.analysis.id.IndonesianAnalyzer(
722 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
723 | } else {
724 | new org.apache.lucene.analysis.id.IndonesianAnalyzer()
725 | }
726 | }
727 | }
728 | 
729 | /**
730 | * A basic Transformer based on ItalianAnalyzer.
731 | * Supports configuring stopwords.
732 | * There are additional parameters which can not yet be
733 | * controlled through this API
734 | * See https://github.com/sparklingpandas/sparklingml/issues/3
735 | */
736 | 
737 | class ItalianAnalyzerLucene(override val uid: String)
738 | extends LuceneTransformer[ItalianAnalyzerLucene]
739 | with HasStopwords with HasStopwordCase {
740 | 
741 | def this() = this(Identifiable.randomUID("ItalianAnalyzer"))
742 | 
743 | def buildAnalyzer(): Analyzer = {
744 | // In the future we can use getDefaultStopWords here to allow people
745 | // to control the snowball stemmer distinctly from the stopwords.
746 | // but that is a TODO for later.
747 | if (isSet(stopwords)) {
748 | new org.apache.lucene.analysis.it.ItalianAnalyzer(
749 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
750 | } else {
751 | new org.apache.lucene.analysis.it.ItalianAnalyzer()
752 | }
753 | }
754 | }
755 | 
756 | /**
757 | * A basic Transformer based on JapaneseAnalyzer - does not support
758 | * any configuration properties.
759 | * See https://github.com/sparklingpandas/sparklingml/issues/3
760 | * & LuceneAnalyzerGenerators for details.
761 | */
762 | 
763 | class JapaneseAnalyzerLucene(override val uid: String)
764 | extends LuceneTransformer[JapaneseAnalyzerLucene] {
765 | 
766 | def this() = this(Identifiable.randomUID("JapaneseAnalyzer"))
767 | 
768 | def buildAnalyzer(): Analyzer = {
769 | new org.apache.lucene.analysis.ja.JapaneseAnalyzer()
770 | }
771 | }
772 | 
773 | /**
774 | * A basic Transformer based on LithuanianAnalyzer.
775 | * Supports configuring stopwords.
776 | * There are additional parameters which can not yet be
777 | * controlled through this API
778 | * See https://github.com/sparklingpandas/sparklingml/issues/3
779 | */
780 | 
781 | class LithuanianAnalyzerLucene(override val uid: String)
782 | extends LuceneTransformer[LithuanianAnalyzerLucene]
783 | with HasStopwords with HasStopwordCase {
784 | 
785 | def this() = this(Identifiable.randomUID("LithuanianAnalyzer"))
786 | 
787 | def buildAnalyzer(): Analyzer = {
788 | // In the future we can use getDefaultStopWords here to allow people
789 | // to control the snowball stemmer distinctly from the stopwords.
790 | // but that is a TODO for later.
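// A minimal caller-side sketch of how these stopword-configurable stages are
// used (a hedged example, not generated code: it assumes a DataFrame `df`
// with a string column "text"; "ir" is just an illustrative stopword, and
// the output lands in the default output column, uid + "__output"):
//   val tokens = new LithuanianAnalyzerLucene()
//     .setInputCol("text")
//     .setStopwords(Array("ir"))
//     .transform(df)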
791 | if (isSet(stopwords)) {
792 | new org.apache.lucene.analysis.lt.LithuanianAnalyzer(
793 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
794 | } else {
795 | new org.apache.lucene.analysis.lt.LithuanianAnalyzer()
796 | }
797 | }
798 | }
799 | 
800 | /**
801 | * A basic Transformer based on LatvianAnalyzer.
802 | * Supports configuring stopwords.
803 | * There are additional parameters which can not yet be
804 | * controlled through this API
805 | * See https://github.com/sparklingpandas/sparklingml/issues/3
806 | */
807 | 
808 | class LatvianAnalyzerLucene(override val uid: String)
809 | extends LuceneTransformer[LatvianAnalyzerLucene]
810 | with HasStopwords with HasStopwordCase {
811 | 
812 | def this() = this(Identifiable.randomUID("LatvianAnalyzer"))
813 | 
814 | def buildAnalyzer(): Analyzer = {
815 | // In the future we can use getDefaultStopWords here to allow people
816 | // to control the snowball stemmer distinctly from the stopwords.
817 | // but that is a TODO for later.
818 | if (isSet(stopwords)) {
819 | new org.apache.lucene.analysis.lv.LatvianAnalyzer(
820 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
821 | } else {
822 | new org.apache.lucene.analysis.lv.LatvianAnalyzer()
823 | }
824 | }
825 | }
826 | 
827 | /* There is no default zero arg constructor for
828 | * org.apache.lucene.analysis.miscellaneous.LimitTokenCountAnalyzer.
829 | */
830 | 
831 | /* There is no default zero arg constructor for
832 | * org.apache.lucene.analysis.miscellaneous.PerFieldAnalyzerWrapper.
833 | */
834 | 
835 | /**
836 | * A basic Transformer based on MorfologikAnalyzer - does not support
837 | * any configuration properties.
838 | * See https://github.com/sparklingpandas/sparklingml/issues/3
839 | * & LuceneAnalyzerGenerators for details.
840 | */
841 | 
842 | class MorfologikAnalyzerLucene(override val uid: String)
843 | extends LuceneTransformer[MorfologikAnalyzerLucene] {
844 | 
845 | def this() = this(Identifiable.randomUID("MorfologikAnalyzer"))
846 | 
847 | def buildAnalyzer(): Analyzer = {
848 | new org.apache.lucene.analysis.morfologik.MorfologikAnalyzer()
849 | }
850 | }
851 | 
852 | /**
853 | * A basic Transformer based on DutchAnalyzer - does not support
854 | * any configuration properties.
855 | * See https://github.com/sparklingpandas/sparklingml/issues/3
856 | * & LuceneAnalyzerGenerators for details.
857 | */
858 | 
859 | class DutchAnalyzerLucene(override val uid: String)
860 | extends LuceneTransformer[DutchAnalyzerLucene] {
861 | 
862 | def this() = this(Identifiable.randomUID("DutchAnalyzer"))
863 | 
864 | def buildAnalyzer(): Analyzer = {
865 | new org.apache.lucene.analysis.nl.DutchAnalyzer()
866 | }
867 | }
868 | 
869 | /**
870 | * A basic Transformer based on NorwegianAnalyzer.
871 | * Supports configuring stopwords.
872 | * There are additional parameters which can not yet be
873 | * controlled through this API
874 | * See https://github.com/sparklingpandas/sparklingml/issues/3
875 | */
876 | 
877 | class NorwegianAnalyzerLucene(override val uid: String)
878 | extends LuceneTransformer[NorwegianAnalyzerLucene]
879 | with HasStopwords with HasStopwordCase {
880 | 
881 | def this() = this(Identifiable.randomUID("NorwegianAnalyzer"))
882 | 
883 | def buildAnalyzer(): Analyzer = {
884 | // In the future we can use getDefaultStopWords here to allow people
885 | // to control the snowball stemmer distinctly from the stopwords.
886 | // but that is a TODO for later.
887 | if (isSet(stopwords)) {
888 | new org.apache.lucene.analysis.no.NorwegianAnalyzer(
889 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
890 | } else {
891 | new org.apache.lucene.analysis.no.NorwegianAnalyzer()
892 | }
893 | }
894 | }
895 | 
896 | /**
897 | * A basic Transformer based on PolishAnalyzer.
898 | * Supports configuring stopwords.
899 | * There are additional parameters which can not yet be
900 | * controlled through this API
901 | * See https://github.com/sparklingpandas/sparklingml/issues/3
902 | */
903 | 
904 | class PolishAnalyzerLucene(override val uid: String)
905 | extends LuceneTransformer[PolishAnalyzerLucene]
906 | with HasStopwords with HasStopwordCase {
907 | 
908 | def this() = this(Identifiable.randomUID("PolishAnalyzer"))
909 | 
910 | def buildAnalyzer(): Analyzer = {
911 | // In the future we can use getDefaultStopWords here to allow people
912 | // to control the snowball stemmer distinctly from the stopwords.
913 | // but that is a TODO for later.
914 | if (isSet(stopwords)) {
915 | new org.apache.lucene.analysis.pl.PolishAnalyzer(
916 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
917 | } else {
918 | new org.apache.lucene.analysis.pl.PolishAnalyzer()
919 | }
920 | }
921 | }
922 | 
923 | /**
924 | * A basic Transformer based on PortugueseAnalyzer.
925 | * Supports configuring stopwords.
926 | * There are additional parameters which can not yet be
927 | * controlled through this API
928 | * See https://github.com/sparklingpandas/sparklingml/issues/3
929 | */
930 | 
931 | class PortugueseAnalyzerLucene(override val uid: String)
932 | extends LuceneTransformer[PortugueseAnalyzerLucene]
933 | with HasStopwords with HasStopwordCase {
934 | 
935 | def this() = this(Identifiable.randomUID("PortugueseAnalyzer"))
936 | 
937 | def buildAnalyzer(): Analyzer = {
938 | // In the future we can use getDefaultStopWords here to allow people
939 | // to control the snowball stemmer distinctly from the stopwords.
940 | // but that is a TODO for later.
941 | if (isSet(stopwords)) {
942 | new org.apache.lucene.analysis.pt.PortugueseAnalyzer(
943 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
944 | } else {
945 | new org.apache.lucene.analysis.pt.PortugueseAnalyzer()
946 | }
947 | }
948 | }
949 | 
950 | /* There is no default zero arg constructor for
951 | * org.apache.lucene.analysis.query.QueryAutoStopWordAnalyzer.
952 | */
953 | 
954 | /**
955 | * A basic Transformer based on RomanianAnalyzer.
956 | * Supports configuring stopwords.
957 | * There are additional parameters which can not yet be
958 | * controlled through this API
959 | * See https://github.com/sparklingpandas/sparklingml/issues/3
960 | */
961 | 
962 | class RomanianAnalyzerLucene(override val uid: String)
963 | extends LuceneTransformer[RomanianAnalyzerLucene]
964 | with HasStopwords with HasStopwordCase {
965 | 
966 | def this() = this(Identifiable.randomUID("RomanianAnalyzer"))
967 | 
968 | def buildAnalyzer(): Analyzer = {
969 | // In the future we can use getDefaultStopWords here to allow people
970 | // to control the snowball stemmer distinctly from the stopwords.
971 | // but that is a TODO for later.
972 | if (isSet(stopwords)) {
973 | new org.apache.lucene.analysis.ro.RomanianAnalyzer(
974 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
975 | } else {
976 | new org.apache.lucene.analysis.ro.RomanianAnalyzer()
977 | }
978 | }
979 | }
980 | 
981 | /**
982 | * A basic Transformer based on RussianAnalyzer.
983 | * Supports configuring stopwords.
984 | * There are additional parameters which can not yet be
985 | * controlled through this API
986 | * See https://github.com/sparklingpandas/sparklingml/issues/3
987 | */
988 | 
989 | class RussianAnalyzerLucene(override val uid: String)
990 | extends LuceneTransformer[RussianAnalyzerLucene]
991 | with HasStopwords with HasStopwordCase {
992 | 
993 | def this() = this(Identifiable.randomUID("RussianAnalyzer"))
994 | 
995 | def buildAnalyzer(): Analyzer = {
996 | // In the future we can use getDefaultStopWords here to allow people
997 | // to control the snowball stemmer distinctly from the stopwords.
998 | // but that is a TODO for later.
999 | if (isSet(stopwords)) {
1000 | new org.apache.lucene.analysis.ru.RussianAnalyzer(
1001 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
1002 | } else {
1003 | new org.apache.lucene.analysis.ru.RussianAnalyzer()
1004 | }
1005 | }
1006 | }
1007 | 
1008 | /**
1009 | * A basic Transformer based on ShingleAnalyzerWrapper - does not support
1010 | * any configuration properties.
1011 | * See https://github.com/sparklingpandas/sparklingml/issues/3
1012 | * & LuceneAnalyzerGenerators for details.
1013 | */
1014 | 
1015 | class ShingleAnalyzerWrapperLucene(override val uid: String)
1016 | extends LuceneTransformer[ShingleAnalyzerWrapperLucene] {
1017 | 
1018 | def this() = this(Identifiable.randomUID("ShingleAnalyzerWrapper"))
1019 | 
1020 | def buildAnalyzer(): Analyzer = {
1021 | new org.apache.lucene.analysis.shingle.ShingleAnalyzerWrapper()
1022 | }
1023 | }
1024 | 
1025 | /**
1026 | * A basic Transformer based on ClassicAnalyzer.
1027 | * Supports configuring stopwords.
1028 | */
1029 | 
1030 | class ClassicAnalyzerLucene(override val uid: String)
1031 | extends LuceneTransformer[ClassicAnalyzerLucene]
1032 | with HasStopwords with HasStopwordCase {
1033 | 
1034 | def this() = this(Identifiable.randomUID("ClassicAnalyzer"))
1035 | 
1036 | def buildAnalyzer(): Analyzer = {
1037 | // In the future we can use getDefaultStopWords here to allow people
1038 | // to control the snowball stemmer distinctly from the stopwords.
1039 | // but that is a TODO for later.
1040 | if (isSet(stopwords)) {
1041 | new org.apache.lucene.analysis.standard.ClassicAnalyzer(
1042 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
1043 | } else {
1044 | new org.apache.lucene.analysis.standard.ClassicAnalyzer()
1045 | }
1046 | }
1047 | }
1048 | 
1049 | /**
1050 | * A basic Transformer based on StandardAnalyzer.
1051 | * Supports configuring stopwords.
1052 | */
1053 | 
1054 | class StandardAnalyzerLucene(override val uid: String)
1055 | extends LuceneTransformer[StandardAnalyzerLucene]
1056 | with HasStopwords with HasStopwordCase {
1057 | 
1058 | def this() = this(Identifiable.randomUID("StandardAnalyzer"))
1059 | 
1060 | def buildAnalyzer(): Analyzer = {
1061 | // In the future we can use getDefaultStopWords here to allow people
1062 | // to control the snowball stemmer distinctly from the stopwords.
1063 | // but that is a TODO for later.
1064 | if (isSet(stopwords)) {
1065 | new org.apache.lucene.analysis.standard.StandardAnalyzer(
1066 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
1067 | } else {
1068 | new org.apache.lucene.analysis.standard.StandardAnalyzer()
1069 | }
1070 | }
1071 | }
1072 | 
1073 | /**
1074 | * A basic Transformer based on UAX29URLEmailAnalyzer.
1075 | * Supports configuring stopwords.
1076 | */
1077 | 
1078 | class UAX29URLEmailAnalyzerLucene(override val uid: String)
1079 | extends LuceneTransformer[UAX29URLEmailAnalyzerLucene]
1080 | with HasStopwords with HasStopwordCase {
1081 | 
1082 | def this() = this(Identifiable.randomUID("UAX29URLEmailAnalyzer"))
1083 | 
1084 | def buildAnalyzer(): Analyzer = {
1085 | // In the future we can use getDefaultStopWords here to allow people
1086 | // to control the snowball stemmer distinctly from the stopwords.
1087 | // but that is a TODO for later.
1088 | if (isSet(stopwords)) {
1089 | new org.apache.lucene.analysis.standard.UAX29URLEmailAnalyzer(
1090 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
1091 | } else {
1092 | new org.apache.lucene.analysis.standard.UAX29URLEmailAnalyzer()
1093 | }
1094 | }
1095 | }
1096 | 
1097 | /**
1098 | * A basic Transformer based on SwedishAnalyzer.
1099 | * Supports configuring stopwords.
1100 | * There are additional parameters which can not yet be
1101 | * controlled through this API
1102 | * See https://github.com/sparklingpandas/sparklingml/issues/3
1103 | */
1104 | 
1105 | class SwedishAnalyzerLucene(override val uid: String)
1106 | extends LuceneTransformer[SwedishAnalyzerLucene]
1107 | with HasStopwords with HasStopwordCase {
1108 | 
1109 | def this() = this(Identifiable.randomUID("SwedishAnalyzer"))
1110 | 
1111 | def buildAnalyzer(): Analyzer = {
1112 | // In the future we can use getDefaultStopWords here to allow people
1113 | // to control the snowball stemmer distinctly from the stopwords.
1114 | // but that is a TODO for later.
1115 | if (isSet(stopwords)) {
1116 | new org.apache.lucene.analysis.sv.SwedishAnalyzer(
1117 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
1118 | } else {
1119 | new org.apache.lucene.analysis.sv.SwedishAnalyzer()
1120 | }
1121 | }
1122 | }
1123 | 
1124 | /**
1125 | * A basic Transformer based on ThaiAnalyzer.
1126 | * Supports configuring stopwords.
1127 | */
1128 | 
1129 | class ThaiAnalyzerLucene(override val uid: String)
1130 | extends LuceneTransformer[ThaiAnalyzerLucene]
1131 | with HasStopwords with HasStopwordCase {
1132 | 
1133 | def this() = this(Identifiable.randomUID("ThaiAnalyzer"))
1134 | 
1135 | def buildAnalyzer(): Analyzer = {
1136 | // In the future we can use getDefaultStopWords here to allow people
1137 | // to control the snowball stemmer distinctly from the stopwords.
1138 | // but that is a TODO for later.
1139 | if (isSet(stopwords)) {
1140 | new org.apache.lucene.analysis.th.ThaiAnalyzer(
1141 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
1142 | } else {
1143 | new org.apache.lucene.analysis.th.ThaiAnalyzer()
1144 | }
1145 | }
1146 | }
1147 | 
1148 | /**
1149 | * A basic Transformer based on TurkishAnalyzer.
1150 | * Supports configuring stopwords.
1151 | * There are additional parameters which can not yet be
1152 | * controlled through this API
1153 | * See https://github.com/sparklingpandas/sparklingml/issues/3
1154 | */
1155 | 
1156 | class TurkishAnalyzerLucene(override val uid: String)
1157 | extends LuceneTransformer[TurkishAnalyzerLucene]
1158 | with HasStopwords with HasStopwordCase {
1159 | 
1160 | def this() = this(Identifiable.randomUID("TurkishAnalyzer"))
1161 | 
1162 | def buildAnalyzer(): Analyzer = {
1163 | // In the future we can use getDefaultStopWords here to allow people
1164 | // to control the snowball stemmer distinctly from the stopwords.
1165 | // but that is a TODO for later.
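// For reference (a sketch, not generated code): LuceneHelpers.wordstoCharArraySet,
// defined later in this repo, just wraps Lucene's CharArraySet, so the stopword
// branch below is roughly equivalent to the following, with stopwordList
// standing in for the configured words:
//   new org.apache.lucene.analysis.tr.TurkishAnalyzer(
//     new CharArraySet(stopwordList.asJava, /* ignoreCase = */ !$(stopwordCase)))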
1166 | if (isSet(stopwords)) {
1167 | new org.apache.lucene.analysis.tr.TurkishAnalyzer(
1168 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
1169 | } else {
1170 | new org.apache.lucene.analysis.tr.TurkishAnalyzer()
1171 | }
1172 | }
1173 | }
1174 | 
1175 | /* There is no default zero arg constructor for
1176 | * org.apache.lucene.analysis.uima.UIMABaseAnalyzer.
1177 | */
1178 | 
1179 | /* There is no default zero arg constructor for
1180 | * org.apache.lucene.analysis.uima.UIMATypeAwareAnalyzer.
1181 | */
1182 | 
1183 | /**
1184 | * A basic Transformer based on UkrainianMorfologikAnalyzer.
1185 | * Supports configuring stopwords.
1186 | * There are additional parameters which can not yet be
1187 | * controlled through this API
1188 | * See https://github.com/sparklingpandas/sparklingml/issues/3
1189 | */
1190 | 
1191 | class UkrainianMorfologikAnalyzerLucene(override val uid: String)
1192 | extends LuceneTransformer[UkrainianMorfologikAnalyzerLucene]
1193 | with HasStopwords with HasStopwordCase {
1194 | 
1195 | def this() = this(Identifiable.randomUID("UkrainianMorfologikAnalyzer"))
1196 | 
1197 | def buildAnalyzer(): Analyzer = {
1198 | // In the future we can use getDefaultStopWords here to allow people
1199 | // to control the snowball stemmer distinctly from the stopwords.
1200 | // but that is a TODO for later.
1201 | if (isSet(stopwords)) {
1202 | new org.apache.lucene.analysis.uk.UkrainianMorfologikAnalyzer(
1203 | LuceneHelpers.wordstoCharArraySet($(stopwords), !$(stopwordCase)))
1204 | } else {
1205 | new org.apache.lucene.analysis.uk.UkrainianMorfologikAnalyzer()
1206 | }
1207 | }
1208 | }
1209 | 
1210 | /* There is no default zero arg constructor for
1211 | * org.apache.lucene.collation.CollationKeyAnalyzer.
1212 | */
1213 | 
1214 | /* There is no default zero arg constructor for
1215 | * org.apache.lucene.collation.ICUCollationKeyAnalyzer.
1216 | */
1217 | 
-------------------------------------------------------------------------------- /src/main/scala/com/sparklingpandas/sparklingml/feature/LuceneHelpers.scala: --------------------------------------------------------------------------------
1 | package com.sparklingpandas.sparklingml.feature
2 | 
3 | import scala.collection.JavaConverters._
4 | 
5 | import org.apache.lucene.analysis.CharArraySet
6 | 
7 | object LuceneHelpers {
8 | /**
9 | * Convert a provided Array of strings into a CharArraySet.
10 | */
11 | def wordstoCharArraySet(input: Array[String], ignoreCase: Boolean):
12 | CharArraySet = {
13 | new CharArraySet(input.toList.asJava, ignoreCase)
14 | }
15 | }
-------------------------------------------------------------------------------- /src/main/scala/com/sparklingpandas/sparklingml/param/SharedParamsCodeGen.scala: --------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License.
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.sparklingpandas.sparklingml.param 19 | 20 | import java.io.PrintWriter 21 | 22 | import scala.reflect.ClassTag 23 | import scala.xml.Utility 24 | 25 | import com.sparklingpandas.sparklingml.CodeGenerator 26 | 27 | /** 28 | * Code generator for shared params (sharedParams.scala). Run with 29 | * {{{ 30 | * build/sbt "runMain com.sparklingpandas.sparklingml.param.SharedParamsCodeGen" 31 | * }}}. 32 | * 33 | * Based on the same param generators in Spark, but with extra params. 34 | */ 35 | private[sparklingpandas] object SharedParamsCodeGen extends CodeGenerator { 36 | 37 | def main(args: Array[String]): Unit = { 38 | val params = Seq( 39 | // SparklingML Params 40 | ParamDesc[Boolean]("stopwordCase", 41 | "If the case should be considered when filtering stopwords", Some("false")), 42 | ParamDesc[Array[String]]("stopwords", 43 | "Stopwords to be filtered. Default value depends on underlying transformer"), 44 | // Spark Params 45 | ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), 46 | ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), 47 | ParamDesc[String]("inputCol", "input column name"), 48 | ParamDesc[Array[String]]("inputCols", "input column names"), 49 | ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"") 50 | )) 51 | 52 | val code = genSharedParams(params) 53 | val file = "src/main/scala/com/sparklingpandas/sparklingml/param/sharedParams.scala" 54 | val writer = new PrintWriter(file) 55 | writer.write(code) 56 | writer.close() 57 | } 58 | 59 | /** Generates the HasParam trait code for the input param. */ 60 | private def genHasParamTrait(param: ParamDesc[_]): String = { 61 | val name = param.name 62 | val Name = name(0).toUpper +: name.substring(1) 63 | val Param = param.paramTypeName 64 | val T = param.valueTypeName 65 | val doc = param.doc 66 | val defaultValue = param.defaultValueStr 67 | val defaultValueDoc = defaultValue.map { v => 68 | s" (default: $v)" 69 | }.getOrElse("") 70 | val setDefault = defaultValue.map { v => 71 | s""" 72 | | setDefault($name, $v) 73 | |""".stripMargin 74 | }.getOrElse("") 75 | val isValid = if (param.isValid != "") { 76 | ", " + param.isValid 77 | } else { 78 | "" 79 | } 80 | val groupStr = if (param.isExpertParam) { 81 | Array("expertParam", "expertGetParam") 82 | } else { 83 | Array("param", "getParam") 84 | } 85 | val methodStr = if (param.finalMethods) { 86 | "final def" 87 | } else { 88 | "def" 89 | } 90 | val fieldStr = if (param.finalFields) { 91 | "final val" 92 | } else { 93 | "val" 94 | } 95 | 96 | val htmlCompliantDoc = Utility.escape(doc) 97 | 98 | s""" 99 | |/** 100 | | * Trait for shared param $name$defaultValueDoc. 101 | | */ 102 | |trait Has$Name extends Params { 103 | | 104 | | /** 105 | | * Param for $htmlCompliantDoc. 
106 | | * @group ${groupStr(0)} 107 | | */ 108 | | $fieldStr $name: $Param = new $Param(this, "$name", "$doc"$isValid) 109 | |$setDefault 110 | | /** @group ${groupStr(1)} */ 111 | | $methodStr get$Name: $T = $$($name) 112 | | 113 | | $methodStr set$Name(value: $T): this.type = set(this.$name, value) 114 | |} 115 | |""".stripMargin 116 | } 117 | 118 | /** Generates Scala source code for the input params with header. */ 119 | private def genSharedParams(params: Seq[ParamDesc[_]]): String = { 120 | val header = s"""${scalaLicenseHeader} 121 | | 122 | |package com.sparklingpandas.sparklingml.param 123 | | 124 | |import org.apache.spark.ml.param._ 125 | | 126 | |// DO NOT MODIFY THIS FILE! 127 | |// It was generated by SharedParamsCodeGen. 128 | | 129 | |// scalastyle:off 130 | |""".stripMargin 131 | 132 | val footer = "// scalastyle:on\n" 133 | 134 | val traits = params.map(genHasParamTrait).mkString 135 | 136 | header + traits + footer 137 | } 138 | } 139 | 140 | /** Description of a param. */ 141 | private case class ParamDesc[T: ClassTag]( 142 | name: String, 143 | doc: String, 144 | defaultValueStr: Option[String] = None, 145 | isValid: String = "", 146 | finalMethods: Boolean = true, 147 | finalFields: Boolean = true, 148 | isExpertParam: Boolean = false) { 149 | 150 | require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") 151 | require(doc.nonEmpty) // TODO: more rigorous on doc 152 | val c = implicitly[ClassTag[T]].runtimeClass 153 | 154 | def paramTypeName: String = { 155 | paramTypeNameFromClass(c) 156 | } 157 | 158 | def valueTypeName: String = { 159 | valueTypeNameFromClass(c) 160 | } 161 | private def paramTypeNameFromClass(c: Class[_]): String = { 162 | c match { 163 | case _ if c == classOf[Int] => "IntParam" 164 | case _ if c == classOf[Long] => "LongParam" 165 | case _ if c == classOf[Float] => "FloatParam" 166 | case _ if c == classOf[Double] => "DoubleParam" 167 | case _ if c == classOf[Boolean] => "BooleanParam" 168 | case _ if c.isArray && c.getComponentType == classOf[String] => 169 | s"StringArrayParam" 170 | case _ if c.isArray && c.getComponentType == classOf[Double] => 171 | s"DoubleArrayParam" 172 | case _ => s"Param[${getTypeString(c)}]" 173 | } 174 | } 175 | 176 | private def valueTypeNameFromClass(c: Class[_]): String = { 177 | getTypeString(c) 178 | } 179 | 180 | private def getTypeString(c: Class[_]): String = { 181 | c match { 182 | case _ if c == classOf[Int] => "Int" 183 | case _ if c == classOf[Long] => "Long" 184 | case _ if c == classOf[Float] => "Float" 185 | case _ if c == classOf[Double] => "Double" 186 | case _ if c == classOf[Boolean] => "Boolean" 187 | case _ if c == classOf[String] => "String" 188 | case _ if c.isArray => s"Array[${getTypeString(c.getComponentType)}]" 189 | } 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /src/main/scala/com/sparklingpandas/sparklingml/param/sharedParams.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.sparklingpandas.sparklingml.param 19 | 20 | import org.apache.spark.ml.param._ 21 | 22 | // DO NOT MODIFY THIS FILE! 23 | // It was generated by SharedParamsCodeGen. 24 | 25 | // scalastyle:off 26 | 27 | /** 28 | * Trait for shared param stopwordCase (default: false). 29 | */ 30 | trait HasStopwordCase extends Params { 31 | 32 | /** 33 | * Param for If the case should be considered when filtering stopwords. 34 | * @group param 35 | */ 36 | final val stopwordCase: BooleanParam = new BooleanParam(this, "stopwordCase", "If the case should be considered when filtering stopwords") 37 | 38 | setDefault(stopwordCase, false) 39 | 40 | /** @group getParam */ 41 | final def getStopwordCase: Boolean = $(stopwordCase) 42 | 43 | final def setStopwordCase(value: Boolean): this.type = set(this.stopwordCase, value) 44 | } 45 | 46 | /** 47 | * Trait for shared param stopwords. 48 | */ 49 | trait HasStopwords extends Params { 50 | 51 | /** 52 | * Param for Stopwords to be filtered. Default value depends on underlying transformer. 53 | * @group param 54 | */ 55 | final val stopwords: StringArrayParam = new StringArrayParam(this, "stopwords", "Stopwords to be filtered. Default value depends on underlying transformer") 56 | 57 | /** @group getParam */ 58 | final def getStopwords: Array[String] = $(stopwords) 59 | 60 | final def setStopwords(value: Array[String]): this.type = set(this.stopwords, value) 61 | } 62 | 63 | /** 64 | * Trait for shared param featuresCol (default: "features"). 65 | */ 66 | trait HasFeaturesCol extends Params { 67 | 68 | /** 69 | * Param for features column name. 70 | * @group param 71 | */ 72 | final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name") 73 | 74 | setDefault(featuresCol, "features") 75 | 76 | /** @group getParam */ 77 | final def getFeaturesCol: String = $(featuresCol) 78 | 79 | final def setFeaturesCol(value: String): this.type = set(this.featuresCol, value) 80 | } 81 | 82 | /** 83 | * Trait for shared param labelCol (default: "label"). 84 | */ 85 | trait HasLabelCol extends Params { 86 | 87 | /** 88 | * Param for label column name. 89 | * @group param 90 | */ 91 | final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name") 92 | 93 | setDefault(labelCol, "label") 94 | 95 | /** @group getParam */ 96 | final def getLabelCol: String = $(labelCol) 97 | 98 | final def setLabelCol(value: String): this.type = set(this.labelCol, value) 99 | } 100 | 101 | /** 102 | * Trait for shared param inputCol. 103 | */ 104 | trait HasInputCol extends Params { 105 | 106 | /** 107 | * Param for input column name. 108 | * @group param 109 | */ 110 | final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") 111 | 112 | /** @group getParam */ 113 | final def getInputCol: String = $(inputCol) 114 | 115 | final def setInputCol(value: String): this.type = set(this.inputCol, value) 116 | } 117 | 118 | /** 119 | * Trait for shared param inputCols. 
120 | */
121 | trait HasInputCols extends Params {
122 | 
123 | /**
124 | * Param for input column names.
125 | * @group param
126 | */
127 | final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names")
128 | 
129 | /** @group getParam */
130 | final def getInputCols: Array[String] = $(inputCols)
131 | 
132 | final def setInputCols(value: Array[String]): this.type = set(this.inputCols, value)
133 | }
134 | 
135 | /**
136 | * Trait for shared param outputCol (default: uid + "__output").
137 | */
138 | trait HasOutputCol extends Params {
139 | 
140 | /**
141 | * Param for output column name.
142 | * @group param
143 | */
144 | final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")
145 | 
146 | setDefault(outputCol, uid + "__output")
147 | 
148 | /** @group getParam */
149 | final def getOutputCol: String = $(outputCol)
150 | 
151 | final def setOutputCol(value: String): this.type = set(this.outputCol, value)
152 | }
153 | // scalastyle:on
154 | 
-------------------------------------------------------------------------------- /src/main/scala/com/sparklingpandas/sparklingml/util/python/Initialize.scala: --------------------------------------------------------------------------------
1 | /**
2 | * Initialize SparklingML on a given environment. This is done to allow
3 | * SparklingML to set up Python callbacks and does not need to be called
4 | * from the Python side. This is a little ugly; improvements are especially
5 | * welcome.
6 | */
7 | package com.sparklingpandas.sparklingml.util.python
8 | 
9 | import java.io._
10 | 
11 | import scala.concurrent.Promise
12 | import scala.collection.mutable.ArrayBuffer
13 | import scala.collection.Iterator
14 | import scala.collection.JavaConverters._
15 | import scala.util.Success
16 | 
17 | import org.apache.spark.SparkContext
18 | import org.apache.spark.deploy.PythonRunner._
19 | import org.apache.spark.sql._
20 | import org.apache.spark.internal.config._
21 | import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
22 | 
23 | import py4j.GatewayServer
24 | 
25 | /**
26 | * Abstract trait to implement in Python to allow Scala to call in to perform
27 | * registration.
28 | */
29 | trait PythonRegisterationProvider {
30 | // Takes a SparkContext, SparkSession, String, and String
31 | // Returns UserDefinedPythonFunction but types + py4j :(
32 | def registerFunction(
33 | sc: SparkContext, session: Object,
34 | functionName: Object, params: Object): Object
35 | }
36 | 
37 | /**
38 | * A utility class to redirect the child process's stdout or stderr.
39 | * This is copied from Spark.
40 | */
41 | private[python] class RedirectThread(
42 | in: InputStream,
43 | out: OutputStream,
44 | name: String,
45 | propagateEof: Boolean = false)
46 | extends Thread(name) {
47 | 
48 | setDaemon(true)
49 | override def run() {
50 | scala.util.control.Exception.ignoring(classOf[IOException]) {
51 | // FIXME: We copy the stream on the level of bytes to avoid encoding problems.
52 | tryWithSafeFinally {
53 | val buf = new Array[Byte](1024)
54 | Iterator.continually(in.read(buf))
55 | .takeWhile(_ != -1)
56 | .foreach{len =>
57 | out.write(buf, 0, len)
58 | out.flush()
59 | }
60 | } {
61 | if (propagateEof) {
62 | out.close()
63 | }
64 | }
65 | }
66 | }
67 | /**
68 | * Execute a block of code, then a finally block, but if exceptions happen in
69 | * the finally block, do not suppress the original exception.
70 | *
71 | * This is primarily an issue with `finally { out.close() }` blocks, where
72 | * close needs to be called to clean up `out`, but if an exception happened
73 | * in `out.write`, it's likely `out` may be corrupted and `out.close` will
74 | * fail as well. This would then suppress the original/likely more meaningful
75 | * exception from the original `out.write` call.
76 | */
77 | def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
78 | val ret = try {
79 | block
80 | } catch {
81 | case t: Throwable =>
82 | // Purposefully not using NonFatal, because we don't want our
83 | // finallyBlock to suppress even fatal exceptions
84 | try {
85 | finallyBlock
86 | } catch {
87 | case t2: Throwable =>
88 | t.addSuppressed(t2)
89 | }
90 | throw t
91 | }
92 | finallyBlock
93 | ret
94 | }
95 | }
96 | 
97 | 
98 | object PythonRegistration {
99 | val pyFiles = ""
100 | // TODO(holden): Use reflection to determine if we've got an existing gateway server
101 | // to hijack instead.
102 | val gatewayServer: GatewayServer = {
103 | // Based on PythonUtils
104 | def sparkPythonPath: String = {
105 | val pythonPath = new ArrayBuffer[String]
106 | for (sparkHome <- sys.env.get("SPARK_HOME")) {
107 | pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator)
108 | pythonPath += Seq(sparkHome,
109 | "python", "lib", "py4j-0.10.6-src.zip").mkString(File.separator)
110 | pythonPath += Seq(sparkHome,
111 | "python", "lib", "py4j-0.10.4-src.zip").mkString(File.separator)
112 | }
113 | pythonPath ++= SparkContext.jarOfObject(this)
114 | pythonPath.mkString(File.pathSeparator)
115 | }
116 | def mergePythonPaths(paths: String*): String = {
117 | paths.filter(_ != "").mkString(File.pathSeparator)
118 | }
119 | 
120 | // Format python file paths before adding them to the PYTHONPATH
121 | val formattedPyFiles = formatPaths(pyFiles)
122 | 
123 | // Launch a gateway server to handle registration, based on PythonRunner.scala
124 | val sparkConf = SparkContext.getOrCreate().getConf
125 | // Resolve which Python executable to launch, preferring Spark's conf settings.
126 | val pythonExec = sparkConf.getOption("spark.pyspark.driver.python")
127 | .orElse(sparkConf.getOption("spark.pyspark.python"))
128 | .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))
129 | .orElse(sys.env.get("PYSPARK_PYTHON"))
130 | .getOrElse("python")
131 | // Launch a Py4J gateway server for the process to connect to; this will let it see our
132 | // Java system properties and such
133 | val gatewayServer = new py4j.GatewayServer(PythonRegistration, 0)
134 | val thread = new Thread(new Runnable() {
135 | override def run(): Unit = {
136 | gatewayServer.start(true)
137 | }
138 | })
139 | thread.setName("py4j-gateway-init")
140 | thread.setDaemon(true)
141 | thread.start()
142 | 
143 | // Wait until the gateway server has started, so that we know which port it is bound to.
144 | // `gatewayServer.start()` will start a new thread and run the server code there, after
145 | // initializing the socket, so the thread started above will end as soon as the server is
146 | // ready to serve connections.
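// (Orientation note on the handshake that follows: the sparklingml.startup
// child process launched below is handed this server's port through the
// PYSPARK_GATEWAY_PORT environment variable, connects back over py4j, and is
// expected to call PythonRegistration.register with its
// PythonRegisterationProvider, completing the promise that PythonTransformer
// awaits.)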
147 | thread.join()
148 | 
149 | // Build up a PYTHONPATH that includes the Spark assembly (where this class is), the
150 | // python directories in SPARK_HOME (if set), and any files in the pyFiles argument
151 | val pathElements = new ArrayBuffer[String]
152 | pathElements ++= formattedPyFiles
153 | pathElements += sparkPythonPath
154 | pathElements += sys.env.getOrElse("PYTHONPATH", "")
155 | val pythonPath = mergePythonPaths(pathElements: _*)
156 | 
157 | // Launch Python process
158 | val builder = new ProcessBuilder((Seq(pythonExec, "-m", "sparklingml.startup")).asJava)
159 | val env = builder.environment()
160 | env.put("SPARKLING_ML_SPECIFIC", "YES")
161 | env.put("PYTHONPATH", pythonPath)
162 | // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
163 | env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
164 | env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
165 | // Pass spark.pyspark.python through to the child; the only way to get this
166 | // information to the Python process is through environment variables.
167 | env.put("PYSPARK_PYTHON", pythonExec)
168 | sys.env.get("PYTHONHASHSEED").foreach(env.put("PYTHONHASHSEED", _))
169 | // Ugly but needed for stdout and stderr to synchronize
170 | builder.redirectErrorStream(true)
171 | val pythonThread = new Thread(new Runnable() {
172 | override def run(): Unit = {
173 | try {
174 | val process = builder.start()
175 | 
176 | new RedirectThread(process.getInputStream, System.out, "redirect output").start()
177 | 
178 | val exitCode = process.waitFor()
179 | if (exitCode != 0) {
180 | throw new Exception(s"Exit code ${exitCode}")
181 | }
182 | } finally {
183 | gatewayServer.shutdown()
184 | }
185 | }
186 | })
187 | pythonThread.setName("python-udf-registrationProvider-thread")
188 | pythonThread.setDaemon(true)
189 | pythonThread.start()
190 | println(s"Waiting for friend on port ${gatewayServer.getListeningPort}")
191 | gatewayServer
192 | }
193 | 
194 | def register(provider: PythonRegisterationProvider) = {
195 | pythonRegistrationProvider.complete(Success(provider))
196 | }
197 | 
198 | val pythonRegistrationProvider = Promise[PythonRegisterationProvider]()
199 | }
-------------------------------------------------------------------------------- /src/main/scala/com/sparklingpandas/sparklingml/util/python/PythonTransformer.scala: --------------------------------------------------------------------------------
1 | /**
2 | * Base classes to allow Scala pipeline stages to be backed
3 | * by Python code.
4 | */
5 | package com.sparklingpandas.sparklingml.util.python
6 | 
7 | import com.sparklingpandas.sparklingml.param.{HasInputCol, HasOutputCol}
8 | 
9 | import scala.concurrent._
10 | import scala.concurrent.duration._
11 | 
12 | import org.apache.spark.sql._
13 | import org.apache.spark.sql.types._
14 | import org.apache.spark.ml.Transformer
15 | import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
16 | 
17 | trait PythonTransformer extends Transformer with HasInputCol with HasOutputCol {
18 | // Name of the python function to register as a UDF
19 | val pythonFunctionName: String
20 | 
21 | def constructUDF(session: SparkSession) = {
22 | val registrationProviderFuture =
23 | PythonRegistration.pythonRegistrationProvider.future
24 | val registrationProvider =
25 | Await.result(registrationProviderFuture, 10 seconds)
26 | // Call the registration provider from startup.py to get a Python UDF back.
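// For orientation, a hypothetical concrete stage (names illustrative, not
// from this repo) only needs to supply the function name and its params:
//   class MyPython(override val uid: String) extends PythonTransformer {
//     val pythonFunctionName = "myfunction" // assumed to match a function
//                                           // known to startup.py
//     def miniSerializeParams() = ""        // assumed: no params to pass
//     // ... plus outputDataType and validateInputType, as declared below.
//   }
// The call below then resolves that name into a UserDefinedPythonFunction.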
27 | val pythonUdf = Option(registrationProvider.registerFunction(
28 | session.sparkContext,
29 | session,
30 | pythonFunctionName,
31 | miniSerializeParams()))
32 | val castUdf = pythonUdf.map(_.asInstanceOf[UserDefinedPythonFunction])
33 | .getOrElse(throw new Exception("Failed to register PythonFunction."))
34 | castUdf
35 | }
36 | 
37 | override def transform(dataset: Dataset[_]): DataFrame = {
38 | transformSchema(dataset.schema, logging = true)
39 | val session = dataset.sparkSession
40 | val transformUDF = constructUDF(session)
41 | dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
42 | }
43 | 
44 | /**
45 | * Returns the data type of the output column.
46 | */
47 | protected def outputDataType: DataType
48 | 
49 | override def transformSchema(schema: StructType): StructType = {
50 | val inputType = schema($(inputCol)).dataType
51 | validateInputType(inputType)
52 | if (schema.fieldNames.contains($(outputCol))) {
53 | throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
54 | }
55 | val outputFields = schema.fields :+
56 | StructField($(outputCol), outputDataType, nullable = false)
57 | StructType(outputFields)
58 | }
59 | 
60 | 
61 | /**
62 | * Validates the input type. Throws an exception if it is invalid.
63 | */
64 | protected def validateInputType(inputType: DataType): Unit
65 | 
66 | /**
67 | * Do you need to pass some of your parameters to Python?
68 | * Put them in here and have them get evaluated with a lambda.
69 | * I know it's kind of sketchy -- sorry!
70 | * This should be considered temporary, unless it works.
71 | */
72 | def miniSerializeParams(): String
73 | }
-------------------------------------------------------------------------------- /src/test/scala/com/sparklingpandas/sparklingml/feature/BasicPython.scala: --------------------------------------------------------------------------------
1 | package com.sparklingpandas.sparklingml.feature
2 | 
3 | import org.apache.spark.ml.param._
4 | import org.apache.spark.ml.Pipeline
5 | import org.apache.spark.ml.util.Identifiable
6 | import org.apache.spark.sql.types._
7 | 
8 | import org.scalatest._
9 | 
10 | import com.holdenkarau.spark.testing.DataFrameSuiteBase
11 | 
12 | case class BadInputData(input: Double)
13 | 
14 | class NltkPosPythonSuite extends FunSuite with DataFrameSuiteBase with Matchers {
15 | 
16 | override implicit def reuseContextIfPossible: Boolean = true
17 | 
18 | override implicit def enableHiveSupport: Boolean = false
19 | 
20 | test("verify that the transformer runs") {
21 | import spark.implicits._
22 | val transformer = new NltkPosPython()
23 | val input = spark.createDataset(
24 | List(InputData("Boo is happy"), InputData("Boo is sad"),
25 | InputData("Boo says that the Sparking Pink Pandas are the coolest queer scooter club in SF")
26 | ))
27 | transformer.setInputCol("input")
28 | transformer.setOutputCol("output")
29 | val result = transformer.transform(input).collect()
30 | result.size shouldBe 3
31 | result(0)(0) shouldBe "Boo is happy"
32 | // TODO(Holden): Figure out why the +- 0.1 matcher syntax wasn't working here
33 | result(0)(1) shouldBe 0.649
34 | result(1)(0) shouldBe "Boo is sad"
35 | result(1)(1) shouldBe 0.0
36 | result(2)(1) shouldBe 0.0
37 | }
38 | 
39 | test("verify we validate input types") {
40 | import spark.implicits._
41 | val transformer = new NltkPosPython()
42 | val input = spark.createDataset(
43 | List(BadInputData(1.0), BadInputData(2.0)))
44 | transformer.setInputCol("input")
45 | transformer.setOutputCol("output")
46 | val pipeline = new Pipeline().setStages(Array(transformer))
47 | // We expect the exception here
48 | assertThrows[java.lang.IllegalArgumentException] {
49 | val model = pipeline.fit(input)
50 | }
51 | }
52 | 
53 | }
54 | 
55 | 
56 | class StrLenPlusKPythonSuite extends FunSuite with DataFrameSuiteBase with Matchers {
57 | 
58 | override implicit def reuseContextIfPossible: Boolean = true
59 | 
60 | override implicit def enableHiveSupport: Boolean = false
61 | 
62 | test("verify that the transformer runs") {
63 | import spark.implicits._
64 | val transformer = new StrLenPlusKPython()
65 | transformer.setK(1)
66 | val input = spark.createDataset(
67 | List(InputData("hi"), InputData("boo"), InputData("boop")))
68 | transformer.setInputCol("input")
69 | transformer.setOutputCol("output")
70 | val result = transformer.transform(input).collect()
71 | result.size shouldBe 3
72 | result(0)(0) shouldBe "hi"
73 | result(0)(1) shouldBe 3
74 | result(1)(0) shouldBe "boo"
75 | result(1)(1) shouldBe 4
76 | }
77 | 
78 | }
79 | 
80 | class SpacyTokenizePythonSuite extends FunSuite with DataFrameSuiteBase with Matchers {
81 | 
82 | override implicit def reuseContextIfPossible: Boolean = true
83 | 
84 | override implicit def enableHiveSupport: Boolean = false
85 | 
86 | test("verify spacy tokenization works") {
87 | import spark.implicits._
88 | val transformer = new SpacyTokenizePython()
89 | transformer.setLang("en_core_web_sm")
90 | val input = spark.createDataset(
91 | List(InputData("hi boo"), InputData("boo")))
92 | transformer.setInputCol("input")
93 | transformer.setOutputCol("output")
94 | val result = transformer.transform(input).collect()
95 | result.size shouldBe 2
96 | result(0)(1) shouldBe Array("hi", "boo")
97 | }
98 | 
99 | }
-------------------------------------------------------------------------------- /src/test/scala/com/sparklingpandas/sparklingml/feature/LuceneAnalyzerGeneratorsTest.scala: --------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 | 
18 | package com.sparklingpandas.sparklingml.feature
19 | 
20 | import sys.process._
21 | 
22 | import org.scalatest._
23 | 
24 | class LuceneAnalyzerGeneratorsTest extends FunSuite with Matchers {
25 | test("verify the generated code is up to date") {
26 | LuceneAnalyzerGenerators.main(Array[String]())
27 | val basePath = "scala/com/sparklingpandas/sparklingml/feature/"
28 | val testResult = s"git diff -q ./src/test/${basePath}LuceneAnalyzersTests.scala".!
29 | val transformResult = s"git diff -q ./src/main/${basePath}LuceneAnalyzers.scala".!
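// Note: sys.process's `.!` runs the command and returns its exit code, so the
// zero checks below assert that regenerating the code left no uncommitted diff.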
30 | testResult shouldBe 0 31 | transformResult shouldBe 0 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/test/scala/com/sparklingpandas/sparklingml/feature/LuceneAnalyzersTests.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package com.sparklingpandas.sparklingml.feature 19 | 20 | import org.apache.spark.ml.param._ 21 | import org.apache.spark.ml.util.Identifiable 22 | 23 | import org.apache.lucene.analysis.Analyzer 24 | 25 | import com.sparklingpandas.sparklingml.param._ 26 | 27 | // DO NOT MODIFY THIS FILE! 28 | // It was auto generated by LuceneAnalyzerGenerators. 29 | 30 | 31 | /** 32 | * A super simple test 33 | */ 34 | class ArabicAnalyzerLuceneTest 35 | extends LuceneStopwordTransformerTest[ArabicAnalyzerLucene] { 36 | val transformer = new ArabicAnalyzerLucene() 37 | } 38 | 39 | /** 40 | * A super simple test 41 | */ 42 | class BulgarianAnalyzerLuceneTest 43 | extends LuceneStopwordTransformerTest[BulgarianAnalyzerLucene] { 44 | val transformer = new BulgarianAnalyzerLucene() 45 | } 46 | 47 | /** 48 | * A super simple test 49 | */ 50 | class BrazilianAnalyzerLuceneTest 51 | extends LuceneStopwordTransformerTest[BrazilianAnalyzerLucene] { 52 | val transformer = new BrazilianAnalyzerLucene() 53 | } 54 | 55 | /** 56 | * A super simple test 57 | */ 58 | class CatalanAnalyzerLuceneTest 59 | extends LuceneStopwordTransformerTest[CatalanAnalyzerLucene] { 60 | val transformer = new CatalanAnalyzerLucene() 61 | } 62 | 63 | /** 64 | * A super simple test 65 | */ 66 | class CJKAnalyzerLuceneTest 67 | extends LuceneStopwordTransformerTest[CJKAnalyzerLucene] { 68 | val transformer = new CJKAnalyzerLucene() 69 | } 70 | 71 | /** 72 | * A super simple test 73 | */ 74 | class SoraniAnalyzerLuceneTest 75 | extends LuceneStopwordTransformerTest[SoraniAnalyzerLucene] { 76 | val transformer = new SoraniAnalyzerLucene() 77 | } 78 | 79 | /** 80 | * A super simple test 81 | */ 82 | class SmartChineseAnalyzerLuceneTest 83 | extends LuceneTransformerTest[SmartChineseAnalyzerLucene] { 84 | val transformer = new SmartChineseAnalyzerLucene() 85 | } 86 | 87 | /** 88 | * A super simple test 89 | */ 90 | class KeywordAnalyzerLuceneTest 91 | extends LuceneTransformerTest[KeywordAnalyzerLucene] { 92 | val transformer = new KeywordAnalyzerLucene() 93 | } 94 | 95 | /** 96 | * A super simple test 97 | */ 98 | class SimpleAnalyzerLuceneTest 99 | extends LuceneTransformerTest[SimpleAnalyzerLucene] { 100 | val transformer = new SimpleAnalyzerLucene() 101 | } 102 | 103 | /** 104 | * A super simple test 105 | */ 106 | class StopAnalyzerLuceneTest 107 | extends 
LuceneStopwordTransformerTest[StopAnalyzerLucene] {
108 |   val transformer = new StopAnalyzerLucene()
109 | }
110 | 
111 | /**
112 |  * A super simple test
113 |  */
114 | class UnicodeWhitespaceAnalyzerLuceneTest
115 |     extends LuceneTransformerTest[UnicodeWhitespaceAnalyzerLucene] {
116 |   val transformer = new UnicodeWhitespaceAnalyzerLucene()
117 | }
118 | 
119 | /**
120 |  * A super simple test
121 |  */
122 | class WhitespaceAnalyzerLuceneTest
123 |     extends LuceneTransformerTest[WhitespaceAnalyzerLucene] {
124 |   val transformer = new WhitespaceAnalyzerLucene()
125 | }
126 | 
127 | /**
128 |  * A super simple test
129 |  */
130 | class CzechAnalyzerLuceneTest
131 |     extends LuceneStopwordTransformerTest[CzechAnalyzerLucene] {
132 |   val transformer = new CzechAnalyzerLucene()
133 | }
134 | 
135 | /**
136 |  * A super simple test
137 |  */
138 | class DanishAnalyzerLuceneTest
139 |     extends LuceneStopwordTransformerTest[DanishAnalyzerLucene] {
140 |   val transformer = new DanishAnalyzerLucene()
141 | }
142 | 
143 | /**
144 |  * A super simple test
145 |  */
146 | class GermanAnalyzerLuceneTest
147 |     extends LuceneStopwordTransformerTest[GermanAnalyzerLucene] {
148 |   val transformer = new GermanAnalyzerLucene()
149 | }
150 | 
151 | /**
152 |  * A super simple test
153 |  */
154 | class GreekAnalyzerLuceneTest
155 |     extends LuceneStopwordTransformerTest[GreekAnalyzerLucene] {
156 |   val transformer = new GreekAnalyzerLucene()
157 | }
158 | 
159 | /**
160 |  * A super simple test
161 |  */
162 | class EnglishAnalyzerLuceneTest
163 |     extends LuceneStopwordTransformerTest[EnglishAnalyzerLucene] {
164 |   val transformer = new EnglishAnalyzerLucene()
165 | }
166 | 
167 | /**
168 |  * A super simple test
169 |  */
170 | class SpanishAnalyzerLuceneTest
171 |     extends LuceneStopwordTransformerTest[SpanishAnalyzerLucene] {
172 |   val transformer = new SpanishAnalyzerLucene()
173 | }
174 | 
175 | /**
176 |  * A super simple test
177 |  */
178 | class BasqueAnalyzerLuceneTest
179 |     extends LuceneStopwordTransformerTest[BasqueAnalyzerLucene] {
180 |   val transformer = new BasqueAnalyzerLucene()
181 | }
182 | 
183 | /**
184 |  * A super simple test
185 |  */
186 | class PersianAnalyzerLuceneTest
187 |     extends LuceneStopwordTransformerTest[PersianAnalyzerLucene] {
188 |   val transformer = new PersianAnalyzerLucene()
189 | }
190 | 
191 | /**
192 |  * A super simple test
193 |  */
194 | class FinnishAnalyzerLuceneTest
195 |     extends LuceneStopwordTransformerTest[FinnishAnalyzerLucene] {
196 |   val transformer = new FinnishAnalyzerLucene()
197 | }
198 | 
199 | /**
200 |  * A super simple test
201 |  */
202 | class FrenchAnalyzerLuceneTest
203 |     extends LuceneStopwordTransformerTest[FrenchAnalyzerLucene] {
204 |   val transformer = new FrenchAnalyzerLucene()
205 | }
206 | 
207 | /**
208 |  * A super simple test
209 |  */
210 | class IrishAnalyzerLuceneTest
211 |     extends LuceneStopwordTransformerTest[IrishAnalyzerLucene] {
212 |   val transformer = new IrishAnalyzerLucene()
213 | }
214 | 
215 | /**
216 |  * A super simple test
217 |  */
218 | class GalicianAnalyzerLuceneTest
219 |     extends LuceneStopwordTransformerTest[GalicianAnalyzerLucene] {
220 |   val transformer = new GalicianAnalyzerLucene()
221 | }
222 | 
223 | /**
224 |  * A super simple test
225 |  */
226 | class HindiAnalyzerLuceneTest
227 |     extends LuceneStopwordTransformerTest[HindiAnalyzerLucene] {
228 |   val transformer = new HindiAnalyzerLucene()
229 | }
230 | 
231 | /**
232 |  * A super simple test
233 |  */
234 | class HungarianAnalyzerLuceneTest
235 |     extends LuceneStopwordTransformerTest[HungarianAnalyzerLucene] {
236 |   val transformer = new HungarianAnalyzerLucene()
237 | }
238 | 
239 | /**
240 |  * A super simple test
241 |  */
242 | class ArmenianAnalyzerLuceneTest
243 |     extends LuceneStopwordTransformerTest[ArmenianAnalyzerLucene] {
244 |   val transformer = new ArmenianAnalyzerLucene()
245 | }
246 | 
247 | /**
248 |  * A super simple test
249 |  */
250 | class IndonesianAnalyzerLuceneTest
251 |     extends LuceneStopwordTransformerTest[IndonesianAnalyzerLucene] {
252 |   val transformer = new IndonesianAnalyzerLucene()
253 | }
254 | 
255 | /**
256 |  * A super simple test
257 |  */
258 | class ItalianAnalyzerLuceneTest
259 |     extends LuceneStopwordTransformerTest[ItalianAnalyzerLucene] {
260 |   val transformer = new ItalianAnalyzerLucene()
261 | }
262 | 
263 | /**
264 |  * A super simple test
265 |  */
266 | class JapaneseAnalyzerLuceneTest
267 |     extends LuceneTransformerTest[JapaneseAnalyzerLucene] {
268 |   val transformer = new JapaneseAnalyzerLucene()
269 | }
270 | 
271 | /**
272 |  * A super simple test
273 |  */
274 | class LithuanianAnalyzerLuceneTest
275 |     extends LuceneStopwordTransformerTest[LithuanianAnalyzerLucene] {
276 |   val transformer = new LithuanianAnalyzerLucene()
277 | }
278 | 
279 | /**
280 |  * A super simple test
281 |  */
282 | class LatvianAnalyzerLuceneTest
283 |     extends LuceneStopwordTransformerTest[LatvianAnalyzerLucene] {
284 |   val transformer = new LatvianAnalyzerLucene()
285 | }
286 | 
287 | /**
288 |  * A super simple test
289 |  */
290 | class MorfologikAnalyzerLuceneTest
291 |     extends LuceneTransformerTest[MorfologikAnalyzerLucene] {
292 |   val transformer = new MorfologikAnalyzerLucene()
293 | }
294 | 
295 | /**
296 |  * A super simple test
297 |  */
298 | class DutchAnalyzerLuceneTest
299 |     extends LuceneTransformerTest[DutchAnalyzerLucene] {
300 |   val transformer = new DutchAnalyzerLucene()
301 | }
302 | 
303 | /**
304 |  * A super simple test
305 |  */
306 | class NorwegianAnalyzerLuceneTest
307 |     extends LuceneStopwordTransformerTest[NorwegianAnalyzerLucene] {
308 |   val transformer = new NorwegianAnalyzerLucene()
309 | }
310 | 
311 | /**
312 |  * A super simple test
313 |  */
314 | class PolishAnalyzerLuceneTest
315 |     extends LuceneStopwordTransformerTest[PolishAnalyzerLucene] {
316 |   val transformer = new PolishAnalyzerLucene()
317 | }
318 | 
319 | /**
320 |  * A super simple test
321 |  */
322 | class PortugueseAnalyzerLuceneTest
323 |     extends LuceneStopwordTransformerTest[PortugueseAnalyzerLucene] {
324 |   val transformer = new PortugueseAnalyzerLucene()
325 | }
326 | 
327 | /**
328 |  * A super simple test
329 |  */
330 | class RomanianAnalyzerLuceneTest
331 |     extends LuceneStopwordTransformerTest[RomanianAnalyzerLucene] {
332 |   val transformer = new RomanianAnalyzerLucene()
333 | }
334 | 
335 | /**
336 |  * A super simple test
337 |  */
338 | class RussianAnalyzerLuceneTest
339 |     extends LuceneStopwordTransformerTest[RussianAnalyzerLucene] {
340 |   val transformer = new RussianAnalyzerLucene()
341 | }
342 | 
343 | /**
344 |  * A super simple test
345 |  */
346 | class ShingleAnalyzerWrapperLuceneTest
347 |     extends LuceneTransformerTest[ShingleAnalyzerWrapperLucene] {
348 |   val transformer = new ShingleAnalyzerWrapperLucene()
349 | }
350 | 
351 | /**
352 |  * A super simple test
353 |  */
354 | class ClassicAnalyzerLuceneTest
355 |     extends LuceneStopwordTransformerTest[ClassicAnalyzerLucene] {
356 |   val transformer = new ClassicAnalyzerLucene()
357 | }
358 | 
359 | /**
360 |  * A super simple test
361 |  */
362 | class StandardAnalyzerLuceneTest
363 |     extends LuceneStopwordTransformerTest[StandardAnalyzerLucene] {
364 |   val transformer = new StandardAnalyzerLucene()
365 | }
366 | 
367 | /**
368 |  * A super simple test
369 |  */
370 | class UAX29URLEmailAnalyzerLuceneTest
371 |     extends LuceneStopwordTransformerTest[UAX29URLEmailAnalyzerLucene] {
372 |   val transformer = new UAX29URLEmailAnalyzerLucene()
373 | }
374 | 
375 | /**
376 |  * A super simple test
377 |  */
378 | class SwedishAnalyzerLuceneTest
379 |     extends LuceneStopwordTransformerTest[SwedishAnalyzerLucene] {
380 |   val transformer = new SwedishAnalyzerLucene()
381 | }
382 | 
383 | /**
384 |  * A super simple test
385 |  */
386 | class ThaiAnalyzerLuceneTest
387 |     extends LuceneStopwordTransformerTest[ThaiAnalyzerLucene] {
388 |   val transformer = new ThaiAnalyzerLucene()
389 | }
390 | 
391 | /**
392 |  * A super simple test
393 |  */
394 | class TurkishAnalyzerLuceneTest
395 |     extends LuceneStopwordTransformerTest[TurkishAnalyzerLucene] {
396 |   val transformer = new TurkishAnalyzerLucene()
397 | }
398 | 
399 | /**
400 |  * A super simple test
401 |  */
402 | class UkrainianMorfologikAnalyzerLuceneTest
403 |     extends LuceneStopwordTransformerTest[UkrainianMorfologikAnalyzerLucene] {
404 |   val transformer = new UkrainianMorfologikAnalyzerLucene()
405 | }
406 | 
--------------------------------------------------------------------------------
/src/test/scala/com/sparklingpandas/sparklingml/feature/LuceneBaseTests.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
3 |  * contributor license agreements.  See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright ownership.
5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
6 |  * (the "License"); you may not use this file except in compliance with
7 |  * the License.  You may obtain a copy of the License at
8 |  *
9 |  *    http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.sparklingpandas.sparklingml.feature
19 | 
20 | import org.apache.spark.ml.param._
21 | 
22 | import org.apache.lucene.analysis.Analyzer
23 | 
24 | import org.scalatest._
25 | 
26 | import com.holdenkarau.spark.testing.DataFrameSuiteBase
27 | 
28 | import com.sparklingpandas.sparklingml.param._
29 | 
30 | case class InputData(input: String)
31 | 
32 | abstract class LuceneTransformerTest[T <: LuceneTransformer[_]] extends
33 |     FunSuite with DataFrameSuiteBase with Matchers {
34 | 
35 |   override implicit def reuseContextIfPossible: Boolean = true
36 | 
37 |   override implicit def enableHiveSupport: Boolean = false
38 | 
39 |   val transformer: T
40 | 
41 |   test("verify that the transformer runs") {
42 |     import spark.implicits._
43 |     val input = spark.createDataset(
44 |       List(InputData("hi"), InputData("boo"), InputData("boop")))
45 |     transformer.setInputCol("input")
46 |     val result = transformer.transform(input).collect()
47 |     result.size shouldBe 3
48 |   }
49 | }
50 | 
51 | abstract class LuceneStopwordTransformerTest[T <: LuceneTransformer[_]] extends
52 |     LuceneTransformerTest[T] {
53 |   test("verify the stopword is dropped, and nothing else") {
54 |     import spark.implicits._
55 |     val input = spark.createDataset(
56 |       List(InputData("hi"), InputData("boo"), InputData("boop")))
57 |     val stopwordTransformer = transformer.asInstanceOf[HasStopwords]
58 |     // Set the param through both APIs; the second call overrides the
59 |     // first, so only "boop" ends up configured as a stopword.
60 |     stopwordTransformer.set(stopwordTransformer.stopwords, Array("boo"))
61 |     stopwordTransformer.setStopwords(Array("boop"))
62 |     transformer.setInputCol("input")
63 |     val result = transformer.transform(input).collect()
64 |     result.size shouldBe 3
65 |     // The non-stopword rows should still produce tokens...
66 |     result(0).getSeq(1) should not be empty
67 |     result(1).getSeq(1) should not be empty
68 |     // ...while the stopword row should come back empty.
69 |     result(2).getSeq(1) shouldBe empty
70 |   }
71 | }
72 | 
--------------------------------------------------------------------------------
/src/test/scala/com/sparklingpandas/sparklingml/param/SharedParamsCodeGenTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
3 |  * contributor license agreements.  See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright ownership.
5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
6 |  * (the "License"); you may not use this file except in compliance with
7 |  * the License.  You may obtain a copy of the License at
8 |  *
9 |  *    http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package com.sparklingpandas.sparklingml.param
19 | 
20 | import sys.process._
21 | 
22 | import org.scalatest._
23 | 
24 | class SharedParamsCodeGenTest extends FunSuite with Matchers {
25 |   test("verify the generated code is up to date") {
26 |     // Regenerate the shared params file in place...
27 |     SharedParamsCodeGen.main(Array[String]())
28 |     val basePath = "scala/com/sparklingpandas/sparklingml/param/"
29 |     // ...then confirm the checked-in copy matches: --quiet makes git diff
30 |     // exit non-zero when the working copy differs from what is checked in.
31 |     val result = s"git diff --quiet ./src/main/${basePath}sharedParams.scala".!
32 |     result shouldBe 0
33 |   }
34 | }
35 | 
--------------------------------------------------------------------------------
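
Note: the per-analyzer test classes above all follow one pattern, namely extend one of the two base classes from LuceneBaseTests.scala and supply the transformer under test; the base classes then contribute the actual test bodies. They appear to be emitted by the generators in LuceneAnalyzerGenerators.scala rather than written by hand, which would explain the identical doc comments. As a minimal sketch of how a new analyzer would join the suite (the FooAnalyzerLucene transformer here is hypothetical and not part of this repository), the entire test is:

package com.sparklingpandas.sparklingml.feature

/**
 * A super simple test for a hypothetical FooAnalyzerLucene transformer.
 * Extending LuceneStopwordTransformerTest inherits both the smoke test
 * ("verify that the transformer runs") and the stopword-dropping test,
 * so no test bodies need to be written by hand.
 */
class FooAnalyzerLuceneTest
    extends LuceneStopwordTransformerTest[FooAnalyzerLucene] {
  val transformer = new FooAnalyzerLucene()
}

Analyzers without a configurable stopword set (e.g. the whitespace, Japanese, and Morfologik analyzers above) would instead extend LuceneTransformerTest, which only requires that the transformer run and preserve the row count.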