├── version.txt
├── python
│   ├── requirements.txt
│   ├── MANIFEST.in
│   ├── dirty_cat_spark
│   │   ├── feature
│   │   │   ├── __init__.py
│   │   │   └── encoder.py
│   │   ├── __init__.py
│   │   └── utils
│   │       ├── java_reader.py
│   │       └── __init__.py
│   ├── README.md
│   └── setup.py
├── Makefile
├── src
│   ├── main
│   │   └── scala
│   │       └── com
│   │           └── rakuten
│   │               └── dirty_cat
│   │                   ├── spark_utils
│   │                   │   ├── Utils.scala
│   │                   │   ├── OpenHashMap.scala
│   │                   │   ├── BitSet.scala
│   │                   │   └── OpenHashSet.scala
│   │                   ├── utils
│   │                   │   ├── TextUtils.scala
│   │                   │   └── StringDistances.scala
│   │                   ├── spark_persistence
│   │                   │   └── ReadWrite.scala
│   │                   └── features
│   │                       └── SimilarityEncoder.scala
│   └── test
│       └── scala
│           └── com
│               └── rakuten
│                   └── dirty_cat
│                       └── feature
│                           └── SimilarityEncoderTestSuite.scala
├── README.md
├── LICENSE
└── examples
    └── Midwest_Survey.ipynb
/version.txt:
--------------------------------------------------------------------------------
1 | 0.1-SNAPSHOT
2 |
3 |
--------------------------------------------------------------------------------
/python/requirements.txt:
--------------------------------------------------------------------------------
1 | pyspark
2 | setuptools
--------------------------------------------------------------------------------
/python/MANIFEST.in:
--------------------------------------------------------------------------------
1 | # Include jar files
2 | include dirty_cat_spark/jars/*.jar
3 |
--------------------------------------------------------------------------------
/python/dirty_cat_spark/feature/__init__.py:
--------------------------------------------------------------------------------
1 | # from pydirtycat.feature.encoder import SimilarityEncoder
2 | # # from pydirtycat.feature.encoder import SimilarityEncoderModel
3 |
4 | # __all__ = ["SimilarityEncoder"]
5 |
--------------------------------------------------------------------------------
/python/dirty_cat_spark/__init__.py:
--------------------------------------------------------------------------------
1 | # import sys
2 | # from sparknlp import annotator
3 | # sys.modules['com.johnsnowlabs.nlp.annotators'] = annotator
4 |
5 | # from pydirtycat import feature
6 |
7 | # __all__ = ["feature"]
8 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | SHELL = /bin/bash
2 | VERSION = $(shell cat version.txt)
3 |
4 | .PHONY: clean clean-pyc clean-dist
5 |
6 | clean: clean-dist clean-pyc
7 |
8 | clean-pyc:
9 | find . -name '*.pyc' -exec rm -f {} +
10 | find . -name '*.pyo' -exec rm -f {} +
11 | find . -name '*~' -exec rm -f {} +
12 | find . -name '__pycache__' -exec rm -fr {} +
13 |
14 | clean-dist:
15 | rm -rf target
16 | rm -rf python/build/
17 | rm -rf python/*.egg-info
18 |
19 |
20 | publish: clean
21 | # use spark packages to create the distribution
22 | sbt clean
23 | sbt compile
24 | sbt package
25 |
26 | cd python; python setup.py sdist
27 |
28 |
--------------------------------------------------------------------------------
/python/dirty_cat_spark/utils/java_reader.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from pyspark.ml.util import MLReader, _jvm
4 |
5 |
6 | class CustomJavaMLReader(MLReader):
7 | """
8 | (Custom) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
9 | """
10 |
11 | def __init__(self, clazz, java_class):
12 | self._clazz = clazz
13 | self._jread = self._load_java_obj(java_class).read()
14 |
15 | def load(self, path):
16 | """Load the ML instance from the input path."""
17 | java_obj = self._jread.load(path)
18 | return self._clazz._from_java(java_obj)
19 |
20 | @classmethod
21 | def _load_java_obj(cls, java_class):
22 | """Load the peer Java object of the ML instance."""
23 | java_obj = _jvm()
24 | for name in java_class.split("."):
25 | java_obj = getattr(java_obj, name)
26 | return java_obj
27 |
--------------------------------------------------------------------------------
/python/README.md:
--------------------------------------------------------------------------------
1 | # Python wrapper
2 |
3 | This is a Python wrapper of com.rakuten.dirty_cat.
4 |
5 |
6 | ## Install
7 |
8 | Please make sure you have built target/scala_VERSION/PACKAGE.jar, as the jar is
9 | required to use this wrapper.
10 |
11 |
12 | * Local machine:
13 |
14 | ```{.bash}
15 | python setup.py install --user
16 | ```
17 |
18 | * Distributed: build a source distribution
19 | ```{.bash}
20 | python setup.py sdist
21 | ```
22 |
23 |
24 | ## Testing
25 |
26 | ```{.bash}
27 | spark-submit --jars target/scala-2.11/com-rakuten-dirty_cat_2.11-0.1-SNAPSHOT.jar python/test/test_similarity_encoder.py
28 | ```
29 |
30 |
31 |
32 | ## Usage
33 |
34 | #### Declaration
35 | ```{.python}
36 | from dirty_cat_spark.feature.encoder import SimilarityEncoder
37 |
38 | encoder = (SimilarityEncoder()
39 | .setInputCol("devices")
40 | .setOutputCol("devicesEncoded")
41 | .setSimilarityType("nGram")
42 | .setVocabSize(1000))
43 | ```
44 |
45 | #### Using it in a pipeline
46 | ```{.python}
47 | from pyspark.ml import Pipeline
48 |
49 | pipeline = Pipeline(stages=[encoder, YOUR_ESTIMATOR])
50 | pipelineModel = pipeline.fit(dataframe)
51 | ```
52 |
53 | #### Serialization
54 | ```{.python}
55 | pipelineModel.write().overwrite().save("pipeline.parquet")
56 | ```
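#### Loading a saved pipeline
A minimal sketch using the standard PySpark persistence API, reusing the `dataframe` from the fit step above (the jar must be on the Spark classpath, e.g. via `--jars` as in the Testing section):
```{.python}
from pyspark.ml import PipelineModel

# load the pipeline saved above and apply it to new data
pipelineModel = PipelineModel.load("pipeline.parquet")
dataframeEncoded = pipelineModel.transform(dataframe)
```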
57 |
58 |
59 | ## Reference
60 | Patricio Cerda, Gaël Varoquaux, Balázs Kégl. Similarity encoding for learning with dirty categorical variables. 2018. Accepted for publication in: Machine Learning journal, Springer.
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/src/main/scala/com/rakuten/dirty_cat/spark_utils/Utils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | *
17 | * This code is a modified version of the original Spark 2.1 implementation.
18 | */
19 |
20 | package com.rakuten.dirty_cat.utils
21 |
22 | // Based on a copy/paste of org.apache.spark.util.Utils
23 | private[dirty_cat] object Utils {
24 |
25 | def getSparkClassLoader: ClassLoader = getClass.getClassLoader
26 |
27 | def getContextOrSparkClassLoader: ClassLoader =
28 | Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)
29 |
30 | // scalastyle:off classforname
31 | /** Preferred alternative to Class.forName(className) */
32 | def classForName(className: String): Class[_] = {
33 | Class.forName(className, true, getContextOrSparkClassLoader)
34 | // scalastyle:on classforname
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/src/main/scala/com/rakuten/dirty_cat/utils/TextUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Rakuten Institute of Technology and Andres Hoyos-Idrobo
3 | * under one or more contributor license agreements.
4 | * See the NOTICE file distributed with this work for additional information
5 | * regarding copyright ownership.
6 | * Rakuten Institute of technology and Andres Hoyos-Idrobo licenses this file
7 | * to You under the Apache License, Version 2.0
8 | * (the "License"); you may not use this file except in compliance with
9 | * the License. You may obtain a copy of the License at
10 | *
11 | * http://www.apache.org/licenses/LICENSE-2.0
12 | *
13 | * Unless required by applicable law or agreed to in writing, software
14 | * distributed under the License is distributed on an "AS IS" BASIS,
15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | * See the License for the specific language governing permissions and
17 | * limitations under the License.
18 | */
19 |
20 | package com.rakuten.dirty_cat.utils
21 |
22 |
23 | import org.apache.spark.sql.functions.udf
24 |
25 | package object TextUtils {
26 |
27 | import java.text.Normalizer
28 | import scala.util.Try
29 |
30 | def normalizeString(input: String): String = {
31 | Try{
32 | val cleaned = input.trim.toLowerCase
33 | val normalized = (Normalizer.normalize(cleaned, Normalizer.Form.NFD)
34 | .replaceAll("[\\p{InCombiningDiacriticalMarks}\\p{IsM}\\p{IsLm}\\p{IsSk}]+", ""))
35 | (normalized
36 | .replaceAll("'s", "")
37 | .replaceAll("ß", "ss")
38 | .replaceAll("ø", "o")
39 | .replaceAll("[^a-zA-Z0-9-]+", "-")
40 | .replaceAll("-+", "-")
41 | .stripSuffix("-"))
42 | }.getOrElse(input)
43 | }
44 |
45 | val normalizeStringUDF = udf[String, String](normalizeString(_))
46 | }
47 |
48 |
--------------------------------------------------------------------------------
/python/dirty_cat_spark/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import py4j.protocol
2 | from py4j.protocol import Py4JJavaError
3 | from py4j.java_gateway import JavaObject
4 | from py4j.java_collections import JavaArray, JavaList, JavaMap
5 |
6 | from pyspark import RDD, SparkContext
7 | from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
8 | from pyspark.sql import DataFrame, SQLContext
9 |
10 | # Hack to support float('inf') in Py4J
11 | _old_smart_decode = py4j.protocol.smart_decode
12 |
13 | _float_str_mapping = {
14 | 'nan': 'NaN',
15 | 'inf': 'Infinity',
16 | '-inf': '-Infinity',
17 | }
18 |
19 |
20 | def _new_smart_decode(obj):
21 | if isinstance(obj, float):
22 | s = str(obj)
23 | return _float_str_mapping.get(s, s)
24 | return _old_smart_decode(obj)
25 |
26 | py4j.protocol.smart_decode = _new_smart_decode
27 |
28 |
29 | _picklable_classes = [
30 | 'SparseVector',
31 | 'DenseVector',
32 | 'SparseMatrix',
33 | 'DenseMatrix',
34 | ]
35 |
36 |
37 | def _java2py(sc, r, encoding="bytes"):
38 | if isinstance(r, JavaObject):
39 | clsName = r.getClass().getSimpleName()
40 | # convert RDD into JavaRDD
41 | if clsName != 'JavaRDD' and clsName.endswith("RDD"):
42 | r = r.toJavaRDD()
43 | clsName = 'JavaRDD'
44 |
45 | if clsName == 'JavaRDD':
46 | jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r)
47 | return RDD(jrdd, sc)
48 |
49 | if clsName == 'Dataset':
50 | return DataFrame(r, SQLContext.getOrCreate(sc))
51 |
52 | if clsName in _picklable_classes:
53 | r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
54 |
55 | elif isinstance(r, (JavaArray, JavaList, JavaMap)):
56 | try:
57 | r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
58 | except Py4JJavaError:
59 | pass # not pickable
60 |
61 | if isinstance(r, (bytearray, bytes)):
62 | r = PickleSerializer().loads(bytes(r), encoding=encoding)
63 | return r
64 |
65 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dirty Cat: Dealing with dirty categorical variables (strings).
2 |
3 | DirtyCat (Scala) is a package that leverages Spark ML to perform large-scale machine learning and provides an alternative way to encode string variables.
4 | This package is largely based on the original Python code: https://github.com/dirty-cat
5 |
6 |
7 | ## Documentation
8 | * https://github.com/dirty-cat
9 | * Patricio Cerda, Gaël Varoquaux, Balázs Kégl. Similarity encoding for learning with dirty categorical variables. Machine Learning journal, Springer. 2018.
10 |
11 |
12 | ### Getting started: How to use it
13 |
14 | The DirtyCat project is built for Scala 2.11.x against Spark v2.3.0.
15 |
16 | This package is provided as-is, so you will have to install it
17 | yourself. Here are some indications to get you started.
18 |
19 |
20 | ### Build it by yourself: Installation
21 |
22 | This project can be built with [SBT](https://www.scala-sbt.org/) 1.1.x.
23 |
24 |
25 | Change build.sbt to match your Scala/Spark installation.
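For reference, a minimal build.sbt could look like the sketch below; the organization, name, and dependency versions shown here are assumptions, not the project's actual build file, so adjust them to your setup.

```{.scala}
// Sketch of a minimal build.sbt -- names and versions are assumptions; adjust to your installation.
name := "dirty_cat"
organization := "com.rakuten"
version := "0.1-SNAPSHOT"
scalaVersion := "2.11.12"

libraryDependencies ++= Seq(
  "org.apache.spark"   %% "spark-sql"     % "2.3.0" % "provided",
  "org.apache.spark"   %% "spark-mllib"   % "2.3.0" % "provided",
  "org.apache.commons"  % "commons-lang3" % "3.5"
)
```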
26 | Then, run on the command line
27 | ```{.bash}
28 | sbt clean
29 |
30 | sbt compile
31 |
32 | sbt package
33 | ```
34 |
35 | This will generate a .jar file at target/scala_VERSION/PACKAGE.jar, where
36 | PACKAGE is something like com-rakuten-dirty_cat_2.11-0.1-SNAPSHOT.jar.
37 |
38 |
39 | If you are using Jupyter notebooks (Scala), you can
40 | add this jar to the Toree Spark options of your Jupyter kernel.
41 |
42 | * Find your available kernels by running:
43 | ```
44 | jupyter kernelspec list
45 | ```
46 | * Go to your Scala kernel and add:
47 | ```{.json}
48 | "env": {
49 | "DEFAULT_INTERPRETER": "Scala",
50 |     "__TOREE_SPARK_OPTS__": "--conf spark.driver.memory=2g --conf spark.executor.cores=4 --conf spark.executor.memory=1g --jars PATH/target/scala_VERSION/PACKAGE.jar"
51 | }
52 | ```
53 |
54 |
55 | To submit your Spark application, run
56 | ```{.bash}
57 | spark-submit --master local[3] --jars target/scala-2.11/com-rakuten-dirty_cat_2.11-0.1-SNAPSHOT.jar YOUR_APPLICATION
58 | ```
59 |
60 |
61 | ### Create a local package
62 | ```{.bash}
63 | make publish
64 | ```
65 |
66 | ### Usage with Spark ML
67 |
68 | #### Declaration
69 | ```{.scala}
70 | import com.rakuten.dirty_cat.feature.SimilarityEncoder
71 |
72 | val encoder = (new SimilarityEncoder()
73 | .setInputCol("devices")
74 | .setOutputCol("devicesEncoded")
75 | .setSimilarityType("nGram")
76 | .setVocabSize(1000))
77 | ```
78 |
79 | #### Using it in a pipeline
80 | ```{.scala}
81 | import org.apache.spark.ml.Pipeline
82 |
83 | val pipeline = (new Pipeline().setStages(Array(encoder, YOUR_ESTIMATOR)))
84 | val pipelineModel = pipeline.fit(dataframe)
85 | ```
86 |
87 | #### Serialization
88 | ```{.scala}
89 | pipelineModel.write.overwrite().save("pipeline.parquet")
90 | ```
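#### Loading a saved pipeline
A minimal sketch using the standard Spark ML persistence API, reusing the `dataframe` from the fit step above:
```{.scala}
import org.apache.spark.ml.PipelineModel

// load the pipeline saved above and apply it to new data
val loadedModel = PipelineModel.load("pipeline.parquet")
val dataframeEncoded = loadedModel.transform(dataframe)
```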
91 |
92 |
93 |
94 | ## History
95 |
96 | Andrés Hoyos-Idrobo started this implementation of DirtyCat as a way to improve his Spark/Scala skills.
97 |
98 | Contributions from:
99 |
100 | * Andrés Hoyos-Idrobo
101 |
102 |
103 | Corporate (Code) Contributors:
104 | * Rakuten Institute of Technology
105 |
--------------------------------------------------------------------------------
/src/main/scala/com/rakuten/dirty_cat/utils/StringDistances.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Rakuten Institute of Technology and Andres Hoyos-Idrobo
3 | * under one or more contributor license agreements.
4 | * See the NOTICE file distributed with this work for additional information
5 | * regarding copyright ownership.
6 | * Rakuten Institute of technology and Andres Hoyos-Idrobo licenses this file
7 | * to You under the Apache License, Version 2.0
8 | * (the "License"); you may not use this file except in compliance with
9 | * the License. You may obtain a copy of the License at
10 | *
11 | * http://www.apache.org/licenses/LICENSE-2.0
12 | *
13 | * Unless required by applicable law or agreed to in writing, software
14 | * distributed under the License is distributed on an "AS IS" BASIS,
15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | * See the License for the specific language governing permissions and
17 | * limitations under the License.
18 | */
19 |
20 | package com.rakuten.dirty_cat.util
21 |
22 |
23 | import scala.collection.mutable
24 | import scala.collection.parallel.ParSeq
25 | import org.apache.commons.lang3.StringUtils
26 |
27 |
28 | private[dirty_cat] object StringSimilarity{
29 |
30 |
31 | private[util] def getNGrams(string: String, n: Int): List[List[String]] = {
32 |     // pad the token list with an empty start and end token before building n-grams
33 | val tokens = List("") ::: string.toLowerCase().split("").toList ::: List("")
34 | val nGram = tokens.sliding(n).toList
35 | nGram }
36 |
37 |
38 | private[util] def getCounts(nGram: List[List[String]]): Map[List[String],Int] = {
39 | nGram.groupBy(identity).mapValues(_.size) }
40 |
41 |
42 | private[util] def getNGramSimilarity(string1: String, string2: String, n: Int): Double = {
43 | val ngrams1 = getNGrams(string1, n)
44 | val ngrams2 = getNGrams(string2, n)
45 | val counts1 = getCounts(ngrams1)
46 | val counts2 = getCounts(ngrams2)
47 |
48 | val sameGrams = (counts1.keySet.intersect(counts2.keySet)
49 | .map(k => k -> List(counts1(k), counts2(k))).toMap)
50 |
51 | val nSameGrams = sameGrams.size
52 | val nAllGrams = ngrams1.length + ngrams2.length
53 |
54 | val similarity = nSameGrams.toDouble / (nAllGrams.toDouble - nSameGrams.toDouble)
55 |
56 | similarity }
57 |
58 |
59 | def getLevenshteinRatio(string1: String, string2: String): Double = {
60 | val totalLength = (string1.length + string2.length).toDouble
61 | if (totalLength == 0D){ 1D } else { (totalLength - StringUtils.getLevenshteinDistance(string1, string2)) / totalLength }}
62 |
63 |
64 | def getJaroWinklerRatio(string1: String, string2: String): Double = {
65 | val totalLength = (string1.length + string2.length).toDouble
66 | if (totalLength == 0D){ 1D } else { (totalLength - StringUtils.getJaroWinklerDistance(string1, string2)) / totalLength }}
67 |
68 |
69 | def getNGramSimilaritySeq(data: Seq[String], categories: Seq[String], n: Int): Seq[Seq[Double]] = {
70 | data.map{xi => categories.map{yi => getNGramSimilarity(xi, yi, n)}}}
71 |
72 |
73 | def getLevenshteinSimilaritySeq(data: Seq[String], categories: Seq[String]): Seq[Seq[Double]] = {
74 | data.map{xi => categories.map{yi => getLevenshteinRatio(xi, yi)}}}
75 |
76 |
77 | def getJaroWinklerSimilaritySeq(data: Seq[String], categories: Seq[String]): Seq[Seq[Double]] = {
78 | data.map{xi => categories.map{yi => getJaroWinklerRatio(xi, yi)}}}
79 | }
80 |
81 |
--------------------------------------------------------------------------------
/python/setup.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | from __future__ import print_function
13 | import sys
14 | import os
15 | import glob
16 | from setuptools import setup, find_packages
17 | from shutil import copyfile, copytree, rmtree
18 |
19 |
20 | basedir = os.path.dirname(os.path.abspath(__file__))
21 | os.chdir(basedir)
22 |
23 | # A temporary path so we can access above the Python project root and fetch scripts and jars we need
24 | JARS_TARGET = os.path.join(basedir, "dirty_cat_spark/jars")
25 | DIRTY_CAT_HOME = os.path.abspath("../")
26 |
27 | # Figure out where the jars that we need to package with PySpark are.
28 | JARS_PATH = glob.glob(os.path.join(DIRTY_CAT_HOME, "target/scala-*"))
29 |
30 | if len(JARS_PATH) == 1:
31 | JARS_PATH = JARS_PATH[0]
32 | else:
33 |     raise IOError("cannot find jar files."
34 |                   " Please make sure to run sbt package first")
35 |
36 |
37 | def _supports_symlinks():
38 | """Check if the system supports symlinks (e.g. *nix) or not."""
39 | return getattr(os, "symlink", None) is not None
40 |
41 |
42 | if _supports_symlinks():
43 | os.symlink(JARS_PATH, JARS_TARGET)
44 | else:
45 | # For windows fall back to the slower copytree
46 | copytree(JARS_PATH, JARS_TARGET)
47 |
48 |
49 |
50 | def setup_package():
51 |
52 | try:
53 | def f(*path):
54 | return open(os.path.join(basedir, *path))
55 |
56 | setup(
57 | name='dirty_cat_spark',
58 | maintainer='Andres Hoyos Idrobo',
59 | maintainer_email='andres.hoyosidrobo@rakuten.com',
60 | version=f('../version.txt').read().strip(),
61 | description='Similarity-based embedding to encode dirty categorical strings in PySpark.',
62 | long_description=f('../README.md').read(),
63 | # url='https://github.com/XXX/dirty_cat',
64 | license='Apache License 2.0',
65 |
66 | keywords='spark pyspark categorical ml',
67 |
68 | classifiers=[
69 | 'Development Status :: 2 - Pre-Alpha',
70 | 'Environment :: Other Environment',
71 | 'Intended Audience :: Developers',
72 | 'License :: OSI Approved :: Apache Software License',
73 | 'Operating System :: OS Independent',
74 | 'Programming Language :: Python',
75 | 'Programming Language :: Python :: 2',
76 | 'Programming Language :: Python :: 2.7',
77 | 'Topic :: Software Development :: Libraries',
78 | 'Topic :: Scientific/Engineering :: Information Analysis',
79 | 'Topic :: Utilities',
80 | ],
81 | install_requires=open('./requirements.txt').read().split(),
82 |
83 | packages=find_packages(exclude=['test']),
84 | include_package_data=True, # Needed to install jar file
85 | )
86 |
87 | finally:
88 | # print("here")
89 | if _supports_symlinks():
90 | os.remove(JARS_TARGET)
91 | else:
92 | rmtree(JARS_TARGET)
93 |
94 |
95 | if __name__ == '__main__':
96 | setup_package()
97 |
--------------------------------------------------------------------------------
/src/test/scala/com/rakuten/dirty_cat/feature/SimilarityEncoderTestSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Rakuten Institute of Technology and Andres Hoyos-Idrobo
3 | * under one or more contributor license agreements.
4 | * See the NOTICE file distributed with this work for additional information
5 | * regarding copyright ownership.
6 | * Rakuten Institute of technology and Andres Hoyos-Idrobo licenses this file
7 | * to You under the Apache License, Version 2.0
8 | * (the "License"); you may not use this file except in compliance with
9 | * the License. You may obtain a copy of the License at
10 | *
11 | * http://www.apache.org/licenses/LICENSE-2.0
12 | *
13 | * Unless required by applicable law or agreed to in writing, software
14 | * distributed under the License is distributed on an "AS IS" BASIS,
15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | * See the License for the specific language governing permissions and
17 | * limitations under the License.
18 | */
19 |
20 | package com.rakuten.dirty_cat.feature
21 |
22 |
23 | import com.holdenkarau.spark.testing.SharedSparkContext
24 | import org.apache.spark.sql.{SQLContext, DataFrame}
25 | import org.apache.spark.SparkException
26 | import org.scalatest.FunSuite
27 | import org.scalatest.Matchers
28 | import org.apache.spark.sql.types.{StringType, IntegerType, StructType, StructField}
29 | import org.apache.spark.sql.Row
30 |
31 |
32 |
33 | class SimilarityEncoderSuite extends FunSuite with SharedSparkContext {
34 |
35 | import com.rakuten.dirty_cat.feature.{SimilarityEncoder, SimilarityEncoderModel}
36 |
37 |
38 | private def generateDataFrame(): DataFrame = {
39 |
40 | val schema = StructType(List(StructField("id", IntegerType),
41 | StructField("name", StringType)))
42 |
43 | val sqlContext = new SQLContext(sc)
44 |
45 | val rdd = sc.parallelize(Seq(
46 | Row(0, "andres"),
47 | Row(1, "andrea"),
48 | Row(2, "carlos I"),
49 | Row(3, "camilo II"),
50 | Row(4, "camila de aragon"),
51 | Row(5, "guido"),
52 | Row(6, "camilo II"),
53 | Row(7, "guido"),
54 | Row(8, "guido"),
55 | Row(9, "andrea"),
56 | Row(10, "andrea"),
57 | Row(11, "camila de aragon")))
58 |
59 | val dataframe = sqlContext.createDataFrame(rdd, schema)
60 |
61 | dataframe
62 |
63 | }
64 |
65 |
66 | test("coverage") {
67 |
68 | List("leverstein", "nGram", "jako").map{similarity =>
69 |
70 | val encoder = (new SimilarityEncoder()
71 | .setInputCol("name")
72 | .setOutputCol("nameEncoded")
73 | .setSimilarityType(similarity))
74 | }
75 | }
76 |
77 |
78 |
79 | test("fit"){
80 |
81 | val dataframe = generateDataFrame()
82 |
83 | val encoder = (new SimilarityEncoder()
84 | .setInputCol("name")
85 | .setOutputCol("nameEncoded")
86 | .setSimilarityType("nGram"))
87 |
88 | val encoderModel = encoder.fit(dataframe)
89 |
90 | val dataframeEncoded = encoderModel.transform(dataframe)
91 |
92 |
93 | }
94 |
95 |
96 | test("pipeline"){
97 |
98 | import org.apache.spark.ml.{Pipeline, PipelineModel}
99 | import org.apache.spark.ml.feature.StandardScaler
100 |
101 | val dataframe = generateDataFrame()
102 |
103 | val encoder = (new SimilarityEncoder()
104 | .setInputCol("name")
105 | .setOutputCol("nameEncoded")
106 | .setSimilarityType("nGram"))
107 |
108 | val scaler = (new StandardScaler()
109 | .setInputCol("nameEncoded")
110 | .setOutputCol("scaledFeatures"))
111 |
112 | val pipeline = (new Pipeline()
113 | .setStages(Array(encoder, scaler)))
114 |
115 | val pipelineModel = pipeline.fit(dataframe)
116 |
117 | val dataframeFeatures = pipelineModel.transform(dataframe)
118 |
119 | }
120 |
121 | }
122 |
123 |
124 |
--------------------------------------------------------------------------------
/src/main/scala/com/rakuten/dirty_cat/spark_utils/OpenHashMap.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | *
17 | * This code is a modified version of the original Spark 2.1 implementation.
18 | */
19 | package com.rakuten.dirty_cat.util.collection
20 |
21 | import scala.reflect.ClassTag
22 |
23 | // import org.apache.spark.annotation.DeveloperApi
24 |
25 | /**
26 | * :: DeveloperApi ::
27 | * A fast hash map implementation for nullable keys. This hash map supports insertions and updates,
28 | * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less
29 | * space overhead.
30 | *
31 | * Under the hood, it uses our OpenHashSet implementation.
32 | */
33 | // @DeveloperApi
34 | private[dirty_cat]
35 | class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
36 | initialCapacity: Int)
37 | extends Iterable[(K, V)]
38 | with Serializable {
39 |
40 | def this() = this(64)
41 |
42 | protected var _keySet = new OpenHashSet[K](initialCapacity)
43 |
44 | // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
45 | // bug that would generate two arrays (one for Object and one for specialized T).
46 | private var _values: Array[V] = _
47 | _values = new Array[V](_keySet.capacity)
48 |
49 | @transient private var _oldValues: Array[V] = null
50 |
51 | // Treat the null key differently so we can use nulls in "data" to represent empty items.
52 | private var haveNullValue = false
53 | private var nullValue: V = null.asInstanceOf[V]
54 |
55 | override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
56 |
57 | /** Tests whether this map contains a binding for a key. */
58 | def contains(k: K): Boolean = {
59 | if (k == null) {
60 | haveNullValue
61 | } else {
62 | _keySet.getPos(k) != OpenHashSet.INVALID_POS
63 | }
64 | }
65 |
66 | /** Get the value for a given key */
67 | def apply(k: K): V = {
68 | if (k == null) {
69 | nullValue
70 | } else {
71 | val pos = _keySet.getPos(k)
72 | if (pos < 0) {
73 | null.asInstanceOf[V]
74 | } else {
75 | _values(pos)
76 | }
77 | }
78 | }
79 |
80 | /** Set the value for a key */
81 | def update(k: K, v: V) {
82 | if (k == null) {
83 | haveNullValue = true
84 | nullValue = v
85 | } else {
86 | val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
87 | _values(pos) = v
88 | _keySet.rehashIfNeeded(k, grow, move)
89 | _oldValues = null
90 | }
91 | }
92 |
93 | /**
94 | * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
95 | * set its value to mergeValue(oldValue).
96 | *
97 | * @return the newly updated value.
98 | */
99 | def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
100 | if (k == null) {
101 | if (haveNullValue) {
102 | nullValue = mergeValue(nullValue)
103 | } else {
104 | haveNullValue = true
105 | nullValue = defaultValue
106 | }
107 | nullValue
108 | } else {
109 | val pos = _keySet.addWithoutResize(k)
110 | if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
111 | val newValue = defaultValue
112 | _values(pos & OpenHashSet.POSITION_MASK) = newValue
113 | _keySet.rehashIfNeeded(k, grow, move)
114 | newValue
115 | } else {
116 | _values(pos) = mergeValue(_values(pos))
117 | _values(pos)
118 | }
119 | }
120 | }
121 |
122 | override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
123 | var pos = -1
124 | var nextPair: (K, V) = computeNextPair()
125 |
126 | /** Get the next value we should return from next(), or null if we're finished iterating */
127 | def computeNextPair(): (K, V) = {
128 | if (pos == -1) { // Treat position -1 as looking at the null value
129 | if (haveNullValue) {
130 | pos += 1
131 | return (null.asInstanceOf[K], nullValue)
132 | }
133 | pos += 1
134 | }
135 | pos = _keySet.nextPos(pos)
136 | if (pos >= 0) {
137 | val ret = (_keySet.getValue(pos), _values(pos))
138 | pos += 1
139 | ret
140 | } else {
141 | null
142 | }
143 | }
144 |
145 | def hasNext: Boolean = nextPair != null
146 |
147 | def next(): (K, V) = {
148 | val pair = nextPair
149 | nextPair = computeNextPair()
150 | pair
151 | }
152 | }
153 |
154 | // The following member variables are declared as protected instead of private for the
155 | // specialization to work (specialized class extends the non-specialized one and needs access
156 | // to the "private" variables).
157 | // They also should have been val's. We use var's because there is a Scala compiler bug that
158 | // would throw illegal access error at runtime if they are declared as val's.
159 | protected var grow = (newCapacity: Int) => {
160 | _oldValues = _values
161 | _values = new Array[V](newCapacity)
162 | }
163 |
164 | protected var move = (oldPos: Int, newPos: Int) => {
165 | _values(newPos) = _oldValues(oldPos)
166 | }
167 | }
168 |
169 |
--------------------------------------------------------------------------------
/python/dirty_cat_spark/feature/encoder.py:
--------------------------------------------------------------------------------
1 | from pyspark import since, keyword_only, SparkContext
2 | from pyspark.ml.param import Param, Params, TypeConverters
3 | from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasHandleInvalid
4 | from pyspark.ml.util import JavaMLReadable, JavaMLWritable
5 | from pyspark.ml.wrapper import _jvm
6 | from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper
7 |
8 | from dirty_cat_spark.utils.java_reader import CustomJavaMLReader
9 |
10 |
11 |
12 | class SimilarityEncoder(JavaEstimator, HasInputCol, HasOutputCol,
13 | HasHandleInvalid, JavaMLReadable, JavaMLWritable):
14 |     """Encodes a dirty categorical string column as a vector of string
15 |     similarities to a vocabulary of frequent categories (Cerda et al., 2018).
16 |     >>> encoder = SimilarityEncoder(inputCol="names", outputCol="encoderNames")
17 |     """
18 |
19 | vocabSize = Param(Params._dummy(), "vocabSize", "",
20 | typeConverter=TypeConverters.toInt)
21 |
22 | nGramSize = Param(Params._dummy(), "nGramSize", "",
23 | typeConverter=TypeConverters.toInt)
24 |
25 | similarityType = Param(Params._dummy(), "similarityType", "",
26 | typeConverter=TypeConverters.toString)
27 |
28 | handleInvalid = Param(Params._dummy(), "handleInvalid", "",
29 | typeConverter=TypeConverters.toString)
30 |
31 | stringOrderType = Param(Params._dummy(), "stringOrderType", "",
32 | typeConverter=TypeConverters.toString)
33 |
34 | @keyword_only
35 | def __init__(self, inputCol=None, outputCol=None,
36 | nGramSize=3, similarityType="nGram",
37 | handleInvalid="keep",
38 | stringOrderType="frequencyDesc",
39 | vocabSize=100):
40 | """
41 | __init__(self, inputCol=None, outputCol=None,
42 | nGramSize=3, similarityType="nGram",
43 | handleInvalid="keep", stringOrderType="frequencyDesc",
44 | vocabSize=100)
45 | """
46 | super(SimilarityEncoder, self).__init__()
47 |
48 | self._java_obj = self._new_java_obj(
49 | "com.rakuten.dirty_cat.feature.SimilarityEncoder", self.uid)
50 |
51 | self._setDefault(nGramSize=3,
52 | # vocabSize=100,
53 | stringOrderType="frequencyDesc",
54 | handleInvalid="keep",
55 | similarityType="nGram")
56 |
57 | kwargs = self._input_kwargs
58 | self.setParams(**kwargs)
59 |
60 | @keyword_only
61 | def setParams(self, inputCol=None, outputCol=None,
62 | nGramSize=3, similarityType="nGram",
63 | handleInvalid="keep",
64 | stringOrderType="frequencyDesc",
65 | vocabSize=100):
66 | """
67 | setParams(self, inputCol=None, outputCol=None, nGramSize=3,
68 | similarityType="nGram", handleInvalid="keep",
69 | stringOrderType="frequencyDesc", vocabSize=100)
70 |
71 | Set the params for the SimilarityEncoder
72 | """
73 | kwargs = self._input_kwargs
74 | return self._set(**kwargs)
75 |
76 |
77 | def setStringOrderType(self, value):
78 | return self._set(stringOrderType=value)
79 |
80 | def setSimilarityType(self, value):
81 | return self._set(similarityType=value)
82 |
83 | def setNGramSize(self, value):
84 | return self._set(nGramSize=value)
85 |
86 | def setVocabSize(self, value):
87 | return self._set(vocabSize=value)
88 |
89 |
90 | def getStringOrderType(self):
91 | return self.getOrDefault(self.stringOrderType)
92 |
93 | def getSimilarityType(self):
94 | return self.getOrDefault(self.similarityType)
95 |
96 | def getNGramSize(self):
97 | return self.getOrDefault(self.nGramSize)
98 |
99 | def getVocabSize(self):
100 | return self.getOrDefault(self.vocabSize)
101 |
102 |
103 | def _create_model(self, java_model):
104 | return SimilarityEncoderModel(java_model)
105 |
106 |
107 |
108 | class SimilarityEncoderModel(JavaModel, JavaMLReadable, JavaMLWritable):
109 | """Model fitted by :py:class:`SimilarityEncoder`. """
110 |
111 | @property
112 | def vocabularyReference(self):
113 | """
114 | """
115 | return self._call_java("vocabularyReference")
116 |
117 | # # @classmethod
118 | # # def from_vocabularyReference(cls, vocabularyReference, inputCol,
119 | # # outputCol=None, nGramSize=None,
120 | # # similarityType=None, handleInvalid=None,
121 | # # stringOrderType=None, vocabSize=None):
122 | # # """
123 | # # Construct the model directly from an array of label strings,
124 | # # requires an active SparkContext.
125 | # # """
126 | # # sc = SparkContext._active_spark_context
127 | # # java_class = sc._gateway.jvm.java.lang.String
128 | # # jVocabularyReference = SimilarityEncoderModel._new_java_array(
129 | # # vocabularyReference, java_class)
130 | # # model = SimilarityEncoderModel._create_from_java_class(
131 | # # 'dirty_cat.feature.SimilarityEncoderModel', jVocabularyReference)
132 | # # model.setInputCol(inputCol)
133 | # # if outputCol is not None:
134 | # # model.setOutputCol(outputCol)
135 | # # if nGramSize is not None:
136 | # # model.setNGramSize(nGramSize)
137 | # # if similarityType is not None:
138 | # # model.setSimilarityType(similarityType)
139 | # # if handleInvalid is not None:
140 | # # model.setHandleInvalid(handleInvalid)
141 | # # if stringOrderType is not None:
142 | # # model.setStringOrderType(stringOrderType)
143 | # # if vocabSize is not None:
144 | # # model.setVocabSize(vocabSize)
145 | # # return model
146 |
147 |
148 | # # @staticmethod
149 | # # def _from_java(java_stage):
150 | # # """
151 | # # Given a Java object, create and return a Python wrapper of it.
152 | # # Used for ML persistence.
153 | # # Meta-algorithms such as Pipeline should override this method as a classmethod.
154 | # # """
155 | # # # Generate a default new instance from the stage_name class.
156 | # # py_type =SimilarityEncoderModel
157 | # # if issubclass(py_type, JavaParams):
158 | # # # Load information from java_stage to the instance.
159 | # # py_stage = py_type()
160 | # # py_stage._java_obj = java_stage
161 | # # py_stage._resetUid(java_stage.uid())
162 | # # py_stage._transfer_params_from_java()
163 |
164 | # # return py_stage
165 |
166 | # # @classmethod
167 | # # def read(cls):
168 | # # """Returns an MLReader instance for this class."""
169 | # # return CustomJavaMLReader(
170 | # # cls, 'dirty_cat.feature.SimilarityEncoderModel')
171 |
--------------------------------------------------------------------------------
/src/main/scala/com/rakuten/dirty_cat/spark_utils/BitSet.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | *
17 | * This code is a modified version of the original Spark 2.1 implementation.
18 | */
19 | package com.rakuten.dirty_cat.util.collection
20 |
21 | import java.util.Arrays
22 |
23 | /**
24 | * A simple, fixed-size bit set implementation. This implementation is fast because it avoids
25 | * safety/bound checking.
26 | */
27 | class BitSet(numBits: Int) extends Serializable {
28 |
29 | private val words = new Array[Long](bit2words(numBits))
30 | private val numWords = words.length
31 |
32 | /**
33 | * Compute the capacity (number of bits) that can be represented
34 | * by this bitset.
35 | */
36 | def capacity: Int = numWords * 64
37 |
38 | /**
39 | * Clear all set bits.
40 | */
41 | def clear(): Unit = Arrays.fill(words, 0)
42 |
43 | /**
44 | * Set all the bits up to a given index
45 | */
46 | def setUntil(bitIndex: Int): Unit = {
47 | val wordIndex = bitIndex >> 6 // divide by 64
48 | Arrays.fill(words, 0, wordIndex, -1)
49 | if(wordIndex < words.length) {
50 | // Set the remaining bits (note that the mask could still be zero)
51 | val mask = ~(-1L << (bitIndex & 0x3f))
52 | words(wordIndex) |= mask
53 | }
54 | }
55 |
56 | /**
57 | * Clear all the bits up to a given index
58 | */
59 | def clearUntil(bitIndex: Int): Unit = {
60 | val wordIndex = bitIndex >> 6 // divide by 64
61 | Arrays.fill(words, 0, wordIndex, 0)
62 | if(wordIndex < words.length) {
63 | // Clear the remaining bits
64 | val mask = -1L << (bitIndex & 0x3f)
65 | words(wordIndex) &= mask
66 | }
67 | }
68 |
69 | /**
70 | * Compute the bit-wise AND of the two sets returning the
71 | * result.
72 | */
73 | def &(other: BitSet): BitSet = {
74 | val newBS = new BitSet(math.max(capacity, other.capacity))
75 | val smaller = math.min(numWords, other.numWords)
76 | assert(newBS.numWords >= numWords)
77 | assert(newBS.numWords >= other.numWords)
78 | var ind = 0
79 | while( ind < smaller ) {
80 | newBS.words(ind) = words(ind) & other.words(ind)
81 | ind += 1
82 | }
83 | newBS
84 | }
85 |
86 | /**
87 | * Compute the bit-wise OR of the two sets returning the
88 | * result.
89 | */
90 | def |(other: BitSet): BitSet = {
91 | val newBS = new BitSet(math.max(capacity, other.capacity))
92 | assert(newBS.numWords >= numWords)
93 | assert(newBS.numWords >= other.numWords)
94 | val smaller = math.min(numWords, other.numWords)
95 | var ind = 0
96 | while( ind < smaller ) {
97 | newBS.words(ind) = words(ind) | other.words(ind)
98 | ind += 1
99 | }
100 | while( ind < numWords ) {
101 | newBS.words(ind) = words(ind)
102 | ind += 1
103 | }
104 | while( ind < other.numWords ) {
105 | newBS.words(ind) = other.words(ind)
106 | ind += 1
107 | }
108 | newBS
109 | }
110 |
111 | /**
112 | * Compute the symmetric difference by performing bit-wise XOR of the two sets returning the
113 | * result.
114 | */
115 | def ^(other: BitSet): BitSet = {
116 | val newBS = new BitSet(math.max(capacity, other.capacity))
117 | val smaller = math.min(numWords, other.numWords)
118 | var ind = 0
119 | while (ind < smaller) {
120 | newBS.words(ind) = words(ind) ^ other.words(ind)
121 | ind += 1
122 | }
123 | if (ind < numWords) {
124 | Array.copy( words, ind, newBS.words, ind, numWords - ind )
125 | }
126 | if (ind < other.numWords) {
127 | Array.copy( other.words, ind, newBS.words, ind, other.numWords - ind )
128 | }
129 | newBS
130 | }
131 |
132 | /**
133 | * Compute the difference of the two sets by performing bit-wise AND-NOT returning the
134 | * result.
135 | */
136 | def andNot(other: BitSet): BitSet = {
137 | val newBS = new BitSet(capacity)
138 | val smaller = math.min(numWords, other.numWords)
139 | var ind = 0
140 | while (ind < smaller) {
141 | newBS.words(ind) = words(ind) & ~other.words(ind)
142 | ind += 1
143 | }
144 | if (ind < numWords) {
145 | Array.copy( words, ind, newBS.words, ind, numWords - ind )
146 | }
147 | newBS
148 | }
149 |
150 | /**
151 | * Sets the bit at the specified index to true.
152 | * @param index the bit index
153 | */
154 | def set(index: Int) {
155 | val bitmask = 1L << (index & 0x3f) // mod 64 and shift
156 | words(index >> 6) |= bitmask // div by 64 and mask
157 | }
158 |
159 | def unset(index: Int) {
160 | val bitmask = 1L << (index & 0x3f) // mod 64 and shift
161 | words(index >> 6) &= ~bitmask // div by 64 and mask
162 | }
163 |
164 | /**
165 | * Return the value of the bit with the specified index. The value is true if the bit with
166 | * the index is currently set in this BitSet; otherwise, the result is false.
167 | *
168 | * @param index the bit index
169 | * @return the value of the bit with the specified index
170 | */
171 | def get(index: Int): Boolean = {
172 | val bitmask = 1L << (index & 0x3f) // mod 64 and shift
173 | (words(index >> 6) & bitmask) != 0 // div by 64 and mask
174 | }
175 |
176 | /**
177 | * Get an iterator over the set bits.
178 | */
179 | def iterator: Iterator[Int] = new Iterator[Int] {
180 | var ind = nextSetBit(0)
181 | override def hasNext: Boolean = ind >= 0
182 | override def next(): Int = {
183 | val tmp = ind
184 | ind = nextSetBit(ind + 1)
185 | tmp
186 | }
187 | }
188 |
189 |
190 | /** Return the number of bits set to true in this BitSet. */
191 | def cardinality(): Int = {
192 | var sum = 0
193 | var i = 0
194 | while (i < numWords) {
195 | sum += java.lang.Long.bitCount(words(i))
196 | i += 1
197 | }
198 | sum
199 | }
200 |
201 | /**
202 | * Returns the index of the first bit that is set to true that occurs on or after the
203 | * specified starting index. If no such bit exists then -1 is returned.
204 | *
205 | * To iterate over the true bits in a BitSet, use the following loop:
206 | *
207 | * for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) {
208 | * // operate on index i here
209 | * }
210 | *
211 | * @param fromIndex the index to start checking from (inclusive)
212 | * @return the index of the next set bit, or -1 if there is no such bit
213 | */
214 | def nextSetBit(fromIndex: Int): Int = {
215 | var wordIndex = fromIndex >> 6
216 | if (wordIndex >= numWords) {
217 | return -1
218 | }
219 |
220 | // Try to find the next set bit in the current word
221 | val subIndex = fromIndex & 0x3f
222 | var word = words(wordIndex) >> subIndex
223 | if (word != 0) {
224 | return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word)
225 | }
226 |
227 | // Find the next set bit in the rest of the words
228 | wordIndex += 1
229 | while (wordIndex < numWords) {
230 | word = words(wordIndex)
231 | if (word != 0) {
232 | return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word)
233 | }
234 | wordIndex += 1
235 | }
236 |
237 | -1
238 | }
239 |
240 | /** Return the number of longs it would take to hold numBits. */
241 | private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
242 | }
243 |
244 |
--------------------------------------------------------------------------------
/src/main/scala/com/rakuten/dirty_cat/spark_persistence/ReadWrite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | *
17 | * This code is a modified version of the original Spark 2.1 implementation.
18 | */
19 | package com.rakuten.dirty_cat.persistence
20 |
21 | import com.rakuten.dirty_cat.utils.Utils
22 |
23 | import org.apache.hadoop.fs.Path
24 | import org.apache.spark.SparkContext
25 | import org.apache.spark.ml.param.{ParamPair, Params}
26 | import org.apache.spark.ml.util.{MLReader, MLWriter}
27 | import org.json4s.JsonAST.{JObject, JValue, JArray, JField}
28 | import org.json4s.JsonDSL._
29 | import org.json4s.jackson.JsonMethods._
30 | import org.json4s.{DefaultFormats, _}
31 |
32 |
33 |
34 | // This originates from a copy/paste of Apache Spark's DefaultParamsWriter
35 | private[dirty_cat] object DefaultParamsWriter {
36 |
37 | /**
38 | * Saves metadata + Params to: path + "/metadata"
39 | * - class
40 | * - timestamp
41 | * - sparkVersion
42 | * - uid
43 | * - paramMap
44 | * - (optionally, extra metadata)
45 | *
46 | * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
47 | * @param paramMap If given, this is saved in the "paramMap" field.
48 | * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
49 | * [[org.apache.spark.ml.param.Param.jsonEncode()]].
50 | */
51 | def saveMetadata(
52 | instance: Params,
53 | path: String,
54 | sc: SparkContext,
55 | extraMetadata: Option[JObject] = None,
56 | paramMap: Option[JValue] = None): Unit = {
57 |
58 | val metadataPath = new Path(path, "metadata").toString
59 | val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
60 | sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
61 | }
62 |
63 | /**
64 | * Helper for [[saveMetadata()]] which extracts the JSON to save.
65 | * This is useful for ensemble models which need to save metadata for many sub-models.
66 | *
67 | * @see [[saveMetadata()]] for details on what this includes.
68 | */
69 | def getMetadataToSave(
70 | instance: Params,
71 | sc: SparkContext,
72 | extraMetadata: Option[JObject] = None,
73 | paramMap: Option[JValue] = None): String = {
74 | val uid = instance.uid
75 | val cls = instance.getClass.getName
76 | val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
77 | val jsonParams = (paramMap.getOrElse(render(params.map{ case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v))}.toList)))
78 |
79 | val basicMetadata = ("class" -> cls) ~
80 | ("timestamp" -> System.currentTimeMillis()) ~
81 | ("sparkVersion" -> sc.version) ~
82 | ("uid" -> uid) ~
83 | ("paramMap" -> jsonParams)
84 | val metadata = extraMetadata match {
85 | case Some(jObject) =>
86 | basicMetadata ~ jObject
87 | case None =>
88 | basicMetadata
89 | }
90 | val metadataJson: String = compact(render(metadata))
91 | metadataJson
92 | }
93 | }
94 |
95 |
96 |
97 | // This originates from a copy/paste of Apache Spark's DefaultParamsReader
98 | private[dirty_cat] object DefaultParamsReader {
99 |
100 | /**
101 | * All info from metadata file.
102 | *
103 | * @param params paramMap, as a `JValue`
104 | * @param metadata All metadata, including the other fields
105 | * @param metadataJson Full metadata file String (for debugging)
106 | */
107 | case class Metadata(className: String,
108 | uid: String,
109 | timestamp: Long,
110 | sparkVersion: String,
111 | params: JValue,
112 | metadata: JValue,
113 | metadataJson: String) {
114 |
115 | /**
116 | * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
117 | * This can be useful for getting a Param value before an instance of `Params`
118 | * is available.
119 | */
120 | def getParamValue(paramName: String): JValue = {
121 | implicit val format = DefaultFormats
122 | params match {
123 | case JObject(pairs) =>
124 | val values = pairs.filter { case (pName, jsonValue) =>
125 | pName == paramName
126 | }.map(_._2)
127 | assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
128 | s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
129 | values.head
130 | case _ =>
131 | throw new IllegalArgumentException(
132 | s"Cannot recognize JSON metadata: $metadataJson.")
133 | }
134 | }
135 | }
136 |
137 | /**
138 | * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]]
139 | *
140 | * @param expectedClassName If non empty, this is checked against the loaded metadata.
141 | * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
142 | */
143 | def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
144 | val metadataPath = new Path(path, "metadata").toString
145 | val metadataStr = sc.textFile(metadataPath, 1).first()
146 | parseMetadata(metadataStr, expectedClassName)
147 | }
148 |
149 | /**
150 | * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
151 | * This is a helper function for [[loadMetadata()]].
152 | *
153 | * @param metadataStr JSON string of metadata
154 | * @param expectedClassName If non empty, this is checked against the loaded metadata.
155 | * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
156 | */
157 | def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
158 | val metadata = parse(metadataStr)
159 |
160 | implicit val format = DefaultFormats
161 | val className = (metadata \ "class").extract[String]
162 | val uid = (metadata \ "uid").extract[String]
163 | val timestamp = (metadata \ "timestamp").extract[Long]
164 | val sparkVersion = (metadata \ "sparkVersion").extract[String]
165 | val params = metadata \ "paramMap"
166 | if (expectedClassName.nonEmpty) {
167 | require(className == expectedClassName, s"Error loading metadata: Expected class name" +
168 | s" $expectedClassName but found class name $className")
169 | }
170 |
171 | Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
172 | }
173 |
174 | /**
175 | * Extract Params from metadata, and set them in the instance.
176 | * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
177 | * TODO: Move to [[Metadata]] method
178 | */
179 | def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
180 | implicit val format = DefaultFormats
181 | metadata.params match {
182 | case JObject(pairs) =>
183 | pairs.foreach { case (paramName, jsonValue) =>
184 | val param = instance.getParam(paramName)
185 | val value = param.jsonDecode(compact(render(jsonValue)))
186 | instance.set(param, value)
187 | }
188 | case _ =>
189 | throw new IllegalArgumentException(
190 | s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
191 | }
192 | }
193 |
194 | /**
195 | * Load a `Params` instance from the given path, and return it.
196 | * This assumes the instance implements [[org.apache.spark.ml.util.MLReadable]].
197 | */
198 | def loadParamsInstance[T](path: String, sc: SparkContext): T = {
199 | val metadata = DefaultParamsReader.loadMetadata(path, sc)
200 | val cls = Utils.classForName(metadata.className)
201 | cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
202 | }
203 | }
204 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
--------------------------------------------------------------------------------
/src/main/scala/com/rakuten/dirty_cat/spark_utils/OpenHashSet.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | *
17 | * This code is a modified version of the original Spark 2.1 implementation.
18 | */
19 | package com.rakuten.dirty_cat.util.collection
20 |
21 | import scala.reflect._
22 | import com.google.common.hash.Hashing.murmur3_32
23 |
24 | /**
25 | * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
26 | * removed.
27 | *
28 | * The underlying implementation uses Scala compiler's specialization to generate optimized
29 | * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet
30 | * while incurring much less memory overhead. This can serve as building blocks for higher level
31 | * data structures such as an optimized HashMap.
32 | *
33 | * This OpenHashSet is designed to serve as building blocks for higher level data structures
34 | * such as an optimized hash map. Compared with standard hash set implementations, this class
35 | * provides its various callbacks interfaces (e.g. allocateFunc, moveFunc) and interfaces to
36 | * retrieve the position of a key in the underlying array.
37 | *
38 | * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed
39 | * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
40 | */
41 | private[dirty_cat]
42 | class OpenHashSet[@specialized(Long, Int) T: ClassTag](
43 | initialCapacity: Int,
44 | loadFactor: Double)
45 | extends Serializable {
46 |
47 | require(initialCapacity <= OpenHashSet.MAX_CAPACITY,
48 | s"Can't make capacity bigger than ${OpenHashSet.MAX_CAPACITY} elements")
49 | require(initialCapacity >= 1, "Invalid initial capacity")
50 | require(loadFactor < 1.0, "Load factor must be less than 1.0")
51 | require(loadFactor > 0.0, "Load factor must be greater than 0.0")
52 |
53 | import OpenHashSet._
54 |
55 | def this(initialCapacity: Int) = this(initialCapacity, 0.7)
56 |
57 | def this() = this(64)
58 |
59 | // The following member variables are declared as protected instead of private for the
60 | // specialization to work (specialized class extends the non-specialized one and needs access
61 | // to the "private" variables).
62 |
63 | protected val hasher: Hasher[T] = {
64 | // It would've been more natural to write the following using pattern matching. But Scala 2.9.x
65 | // compiler has a bug when specialization is used together with this pattern matching, and
66 | // throws:
67 | // scala.tools.nsc.symtab.Types$TypeError: type mismatch;
68 | // found : scala.reflect.AnyValManifest[Long]
69 | // required: scala.reflect.ClassTag[Int]
70 | // at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298)
71 | // at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207)
72 | // ...
73 | val mt = classTag[T]
74 | if (mt == ClassTag.Long) {
75 | (new LongHasher).asInstanceOf[Hasher[T]]
76 | } else if (mt == ClassTag.Int) {
77 | (new IntHasher).asInstanceOf[Hasher[T]]
78 | } else {
79 | new Hasher[T]
80 | }
81 | }
82 |
83 | protected var _capacity = nextPowerOf2(initialCapacity)
84 | protected var _mask = _capacity - 1
85 | protected var _size = 0
86 | protected var _growThreshold = (loadFactor * _capacity).toInt
87 |
88 | protected var _bitset = new BitSet(_capacity)
89 |
90 | def getBitSet: BitSet = _bitset
91 |
92 | // Init of the array in constructor (instead of in declaration) to work around a Scala compiler
93 | // specialization bug that would generate two arrays (one for Object and one for specialized T).
94 | protected var _data: Array[T] = _
95 | _data = new Array[T](_capacity)
96 |
97 | /** Number of elements in the set. */
98 | def size: Int = _size
99 |
100 | /** The capacity of the set (i.e. size of the underlying array). */
101 | def capacity: Int = _capacity
102 |
103 | /** Return true if this set contains the specified element. */
104 | def contains(k: T): Boolean = getPos(k) != INVALID_POS
105 |
106 | /**
107 | * Add an element to the set. If the set is over capacity after the insertion, grow the set
108 | * and rehash all elements.
109 | */
110 | def add(k: T) {
111 | addWithoutResize(k)
112 | rehashIfNeeded(k, grow, move)
113 | }
114 |
115 | def union(other: OpenHashSet[T]): OpenHashSet[T] = {
116 | val iterator = other.iterator
117 | while (iterator.hasNext) {
118 | add(iterator.next())
119 | }
120 | this
121 | }
122 |
123 | /**
124 | * Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
125 | * The caller is responsible for calling rehashIfNeeded.
126 | *
127 | * Use (retval & POSITION_MASK) to get the actual position, and
128 | * (retval & NONEXISTENCE_MASK) == 0 for prior existence.
129 | *
130 | * @return The position where the key is placed, plus the highest order bit is set if the key
131 |    *         did not exist previously.
132 | */
133 | def addWithoutResize(k: T): Int = {
134 | var pos = hashcode(hasher.hash(k)) & _mask
135 | var delta = 1
136 | while (true) {
137 | if (!_bitset.get(pos)) {
138 | // This is a new key.
139 | _data(pos) = k
140 | _bitset.set(pos)
141 | _size += 1
142 | return pos | NONEXISTENCE_MASK
143 | } else if (_data(pos) == k) {
144 | // Found an existing key.
145 | return pos
146 | } else {
147 |         // quadratic probing with probe increments of 1, 2, 3, ...
148 | pos = (pos + delta) & _mask
149 | delta += 1
150 | }
151 | }
152 | throw new RuntimeException("Should never reach here.")
153 | }
154 |
155 | /**
156 | * Rehash the set if it is overloaded.
157 | * @param k A parameter unused in the function, but to force the Scala compiler to specialize
158 | * this method.
159 | * @param allocateFunc Callback invoked when we are allocating a new, larger array.
160 | * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
161 | * to a new position (in the new data array).
162 | */
163 | def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
164 | if (_size > _growThreshold) {
165 | rehash(k, allocateFunc, moveFunc)
166 | }
167 | }
168 |
169 | /**
170 | * Return the position of the element in the underlying array, or INVALID_POS if it is not found.
171 | */
172 | def getPos(k: T): Int = {
173 | var pos = hashcode(hasher.hash(k)) & _mask
174 | var delta = 1
175 | while (true) {
176 | if (!_bitset.get(pos)) {
177 | return INVALID_POS
178 | } else if (k == _data(pos)) {
179 | return pos
180 | } else {
181 |         // quadratic probing with probe increments of 1, 2, 3, ...
182 | pos = (pos + delta) & _mask
183 | delta += 1
184 | }
185 | }
186 | throw new RuntimeException("Should never reach here.")
187 | }
188 |
189 | /** Return the value at the specified position. */
190 | def getValue(pos: Int): T = _data(pos)
191 |
192 | def iterator: Iterator[T] = new Iterator[T] {
193 | var pos = nextPos(0)
194 | override def hasNext: Boolean = pos != INVALID_POS
195 | override def next(): T = {
196 | val tmp = getValue(pos)
197 | pos = nextPos(pos + 1)
198 | tmp
199 | }
200 | }
201 |
202 | /** Return the value at the specified position. */
203 | def getValueSafe(pos: Int): T = {
204 | assert(_bitset.get(pos))
205 | _data(pos)
206 | }
207 |
208 | /**
209 | * Return the next position with an element stored, starting from the given position inclusively.
210 | */
211 | def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
212 |
213 | /**
214 | * Double the table's size and re-hash everything. We are not really using k, but it is declared
215 | * so Scala compiler can specialize this method (which leads to calling the specialized version
216 | * of putInto).
217 | *
218 | * @param k A parameter unused in the function, but to force the Scala compiler to specialize
219 | * this method.
220 | * @param allocateFunc Callback invoked when we are allocating a new, larger array.
221 | * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
222 | * to a new position (in the new data array).
223 | */
224 | private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
225 | val newCapacity = _capacity * 2
226 | require(newCapacity > 0 && newCapacity <= OpenHashSet.MAX_CAPACITY,
227 | s"Can't contain more than ${(loadFactor * OpenHashSet.MAX_CAPACITY).toInt} elements")
228 | allocateFunc(newCapacity)
229 | val newBitset = new BitSet(newCapacity)
230 | val newData = new Array[T](newCapacity)
231 | val newMask = newCapacity - 1
232 |
233 | var oldPos = 0
234 | while (oldPos < capacity) {
235 | if (_bitset.get(oldPos)) {
236 | val key = _data(oldPos)
237 | var newPos = hashcode(hasher.hash(key)) & newMask
238 | var i = 1
239 | var keepGoing = true
240 | // No need to check for equality here when we insert so this has one less if branch than
241 | // the similar code path in addWithoutResize.
242 | while (keepGoing) {
243 | if (!newBitset.get(newPos)) {
244 | // Inserting the key at newPos
245 | newData(newPos) = key
246 | newBitset.set(newPos)
247 | moveFunc(oldPos, newPos)
248 | keepGoing = false
249 | } else {
250 | val delta = i
251 | newPos = (newPos + delta) & newMask
252 | i += 1
253 | }
254 | }
255 | }
256 | oldPos += 1
257 | }
258 |
259 | _bitset = newBitset
260 | _data = newData
261 | _capacity = newCapacity
262 | _mask = newMask
263 | _growThreshold = (loadFactor * newCapacity).toInt
264 | }
265 |
266 | /**
267 | * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
268 | */
269 | private def hashcode(h: Int): Int = murmur3_32().hashInt(h).asInt()
270 |
271 | private def nextPowerOf2(n: Int): Int = {
272 | val highBit = Integer.highestOneBit(n)
273 | if (highBit == n) n else highBit << 1
274 | }
275 | }
276 |
277 |
278 | private[dirty_cat]
279 | object OpenHashSet {
280 |
281 | val MAX_CAPACITY = 1 << 30
282 | val INVALID_POS = -1
283 | val NONEXISTENCE_MASK = 1 << 31
284 | val POSITION_MASK = (1 << 31) - 1
285 |
286 | /**
287 | * A set of specialized hash function implementation to avoid boxing hash code computation
288 | * in the specialized implementation of OpenHashSet.
289 | */
290 | sealed class Hasher[@specialized(Long, Int) T] extends Serializable {
291 | def hash(o: T): Int = o.hashCode()
292 | }
293 |
294 | class LongHasher extends Hasher[Long] {
295 | override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt
296 | }
297 |
298 | class IntHasher extends Hasher[Int] {
299 | override def hash(o: Int): Int = o
300 | }
301 |
302 | private def grow1(newSize: Int) {}
303 | private def move1(oldPos: Int, newPos: Int) { }
304 |
305 | private val grow = grow1 _
306 | private val move = move1 _
307 | }
308 |
309 |
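
A minimal sketch (not part of the repository) of how this insertion-only set is meant to be used, in particular the bit-packed return value of addWithoutResize. Because the class is private[dirty_cat], the calling code is assumed to live somewhere under the com.rakuten.dirty_cat package; the package and object names below are hypothetical.

package com.rakuten.dirty_cat.example   // hypothetical sub-package of dirty_cat

import com.rakuten.dirty_cat.util.collection.OpenHashSet

object OpenHashSetSketch {
  def main(args: Array[String]): Unit = {
    // Default constructor: initial capacity 64, load factor 0.7.
    val set = new OpenHashSet[Int]()

    // add() inserts and transparently grows/rehashes once the load factor is exceeded.
    (1 to 100).foreach(set.add)
    assert(set.size == 100 && set.contains(42))

    // addWithoutResize() packs two facts into one Int: the slot
    // (retval & POSITION_MASK) and a "new key" flag
    // ((retval & NONEXISTENCE_MASK) != 0). The caller must rehash itself.
    val ret   = set.addWithoutResize(101)
    val isNew = (ret & OpenHashSet.NONEXISTENCE_MASK) != 0
    val pos   = ret & OpenHashSet.POSITION_MASK
    println(s"slot=$pos newlyInserted=$isNew value=${set.getValue(pos)}")
    set.rehashIfNeeded(101, _ => (), (_, _) => ())   // caller-supplied no-op callbacks

    // Iterate over the stored keys via the bitset-backed iterator.
    println(set.iterator.toSeq.sorted.take(5))
  }
}
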
--------------------------------------------------------------------------------
/examples/Midwest_Survey.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from pyspark.sql.functions import col, when"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "data = (spark\n",
19 | " .read\n",
20 | " .option(\"header\", \"true\")\n",
21 | " .csv(\"../data/FiveThirtyEight_Midwest_Survey.csv\"))\n",
22 | "data = data.where(col('Location (Census Region)').isNotNull())"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 3,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "# Encoding nulls\n",
32 | "columns_with_null = [\n",
33 | " 'Location (Census Region)',\n",
34 | " 'Gender', 'Age', \n",
35 | " 'Household Income', 'Education']\n",
36 | "for column in columns_with_null:\n",
37 | " data = data.withColumn(column, when(col(column).isNull(), \"__null\")\n",
38 | " .otherwise(col(column)))"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 4,
44 | "metadata": {},
45 | "outputs": [
46 | {
47 | "name": "stdout",
48 | "output_type": "stream",
49 | "text": [
50 | "root\n",
51 | " |-- RespondentID: string (nullable = true)\n",
52 | " |-- In your own words, what would you call the part of the country you live in now?: string (nullable = true)\n",
53 | " |-- Personally identification as a Midwesterner?: string (nullable = true)\n",
54 | " |-- Illinois in MW?: string (nullable = true)\n",
55 | " |-- Indiana in MW?: string (nullable = true)\n",
56 | " |-- Iowa in MW?: string (nullable = true)\n",
57 | " |-- Kansas in MW?: string (nullable = true)\n",
58 | " |-- Michigan in MW?: string (nullable = true)\n",
59 | " |-- Minnesota in MW?: string (nullable = true)\n",
60 | " |-- Missouri in MW?: string (nullable = true)\n",
61 | " |-- Nebraska in MW?: string (nullable = true)\n",
62 | " |-- North Dakota in MW?: string (nullable = true)\n",
63 | " |-- Ohio in MW?: string (nullable = true)\n",
64 | " |-- South Dakota in MW?: string (nullable = true)\n",
65 | " |-- Wisconsin in MW?: string (nullable = true)\n",
66 | " |-- Arkansas in MW?: string (nullable = true)\n",
67 | " |-- Colorado in MW?: string (nullable = true)\n",
68 | " |-- Kentucky in MW?: string (nullable = true)\n",
69 | " |-- Oklahoma in MW?: string (nullable = true)\n",
70 | " |-- Pennsylvania in MW?: string (nullable = true)\n",
71 | " |-- West Virginia in MW?: string (nullable = true)\n",
72 | " |-- Montana in MW?: string (nullable = true)\n",
73 | " |-- Wyoming in MW?: string (nullable = true)\n",
74 | " |-- ZIP Code: string (nullable = true)\n",
75 | " |-- Gender: string (nullable = true)\n",
76 | " |-- Age: string (nullable = true)\n",
77 | " |-- Household Income: string (nullable = true)\n",
78 | " |-- Education: string (nullable = true)\n",
79 | " |-- Location (Census Region): string (nullable = true)\n",
80 | "\n"
81 | ]
82 | }
83 | ],
84 | "source": [
85 | "data.printSchema()"
86 | ]
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {},
91 | "source": [
92 | "## Splitting data into train and test "
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 5,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "data_train, data_test = data.randomSplit([0.6, 0.4], seed=5)"
102 | ]
103 | },
104 | {
105 | "cell_type": "markdown",
106 | "metadata": {},
107 | "source": [
108 |     "## Separating clean and dirty columns, as well as the column we will try to predict"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 6,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "target_column = 'Location (Census Region)'\n",
118 | "dirty_column = 'In your own words, what would you call the part of the country you live in now?'\n",
119 | "clean_columns = [\n",
120 | " 'Personally identification as a Midwesterner?',\n",
121 | " 'Illinois in MW?',\n",
122 | " 'Indiana in MW?',\n",
123 | " 'Kansas in MW?',\n",
124 | " 'Iowa in MW?',\n",
125 | " 'Michigan in MW?',\n",
126 | " 'Minnesota in MW?',\n",
127 | " 'Missouri in MW?',\n",
128 | " 'Nebraska in MW?',\n",
129 | " 'North Dakota in MW?',\n",
130 | " 'Ohio in MW?',\n",
131 | " 'South Dakota in MW?',\n",
132 | " 'Wisconsin in MW?',\n",
133 | " 'Arkansas in MW?',\n",
134 | " 'Colorado in MW?',\n",
135 | " 'Kentucky in MW?',\n",
136 | " 'Oklahoma in MW?',\n",
137 | " 'Pennsylvania in MW?',\n",
138 | " 'West Virginia in MW?',\n",
139 | " 'Montana in MW?',\n",
140 | " 'Wyoming in MW?',\n",
141 | " 'Gender',\n",
142 | " 'Age',\n",
143 | " 'Household Income',\n",
144 | " 'Education'\n",
145 | "]"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": 7,
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "from pyspark.ml import Pipeline\n",
155 | "from pyspark.ml.feature import OneHotEncoder\n",
156 | "from pyspark.ml.feature import StandardScaler, VectorAssembler\n",
157 | "from pyspark.ml.feature import StringIndexer, VectorIndexer\n",
158 | "from pyspark.ml.classification import RandomForestClassifier"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 8,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "from dirty_cat_spark.feature.encoder import SimilarityEncoder\n",
168 | "\n",
169 | "\n",
170 | "encoder_similarity = (SimilarityEncoder()\n",
171 | " .setInputCol(dirty_column)\n",
172 | " .setOutputCol(\"encoded\")\n",
173 | " .setSimilarityType(\"nGram\")\n",
174 | " .setVocabSize(200))\n",
175 | "\n",
176 | "string_indexer_dirty = (StringIndexer()\n",
177 | " .setInputCol(dirty_column)\n",
178 | " .setOutputCol(dirty_column + \"_indexed\")\n",
179 | " .setHandleInvalid(\"keep\")) \n",
180 | "\n",
181 | "encoder_hot = (OneHotEncoder()\n",
182 | " .setInputCol(dirty_column + \"_indexed\")\n",
183 | " .setOutputCol(\"encoded\"))"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": 9,
189 | "metadata": {},
190 | "outputs": [],
191 | "source": [
192 | "string_indexer_clean = [(StringIndexer()\n",
193 | " .setInputCol(clean_column)\n",
194 | " .setOutputCol(clean_column + \"_indexed\")\n",
195 | " .setHandleInvalid(\"keep\")) \n",
196 | " for clean_column in clean_columns]\n",
197 | "\n",
198 | "assembler = (VectorAssembler()\n",
199 | " .setInputCols([c + \"_indexed\" for c in clean_columns] + [\"encoded\"])\n",
200 | " .setOutputCol(\"features\"))\n",
201 | "\n",
202 | "vector_indexer = (VectorIndexer()\n",
203 | " .setInputCol(\"features\")\n",
204 | " .setOutputCol(\"featuresIndexed\")\n",
205 | " .setMaxCategories(10)\n",
206 | " .setHandleInvalid(\"skip\"))\n",
207 | "\n",
208 | "scaler = (StandardScaler()\n",
209 | " .setInputCol(\"featuresIndexed\")\n",
210 | " .setOutputCol(\"scaledFeatures\")\n",
211 | " .setWithMean(False))\n",
212 | "\n",
213 | "\n",
214 | "indexed_label = StringIndexer(inputCol=target_column, \n",
215 | " outputCol=\"indexedLabel\")\n",
216 | "\n",
217 | "classifier = (RandomForestClassifier()\n",
218 | " .setFeaturesCol(\"scaledFeatures\")\n",
219 | " .setLabelCol(\"indexedLabel\")\n",
220 | " .setNumTrees(10)\n",
221 | " .setSeed(5))"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 10,
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "pipeline_similarity = Pipeline(stages=string_indexer_clean + \n",
231 | " [encoder_similarity, assembler, \n",
232 | " vector_indexer, \n",
233 | " scaler, indexed_label, classifier])"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": 11,
239 | "metadata": {},
240 | "outputs": [],
241 | "source": [
242 | "pipeline_similarity_hot = Pipeline(stages=string_indexer_clean + \n",
243 | " [string_indexer_dirty, encoder_hot, \n",
244 | " assembler, vector_indexer, \n",
245 | " scaler, indexed_label, classifier])"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": null,
251 | "metadata": {},
252 | "outputs": [],
253 | "source": [
254 | "pipeline_similarity_hot_model = pipeline_similarity_hot.fit(data_train)"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "pipeline_similarity_model = pipeline_similarity.fit(data_train)"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": null,
269 | "metadata": {},
270 | "outputs": [],
271 | "source": [
272 | "res = pipeline_similarity_model.transform(data_test)\n",
273 | "res_hot = pipeline_similarity_hot_model.transform(data_test)"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": null,
279 | "metadata": {},
280 | "outputs": [],
281 | "source": [
282 | "res_df = res.select(\"probability\", \"indexedLabel\", \"prediction\").toPandas()\n",
283 | "res_hot_df = res_hot.select(\"probability\", \"indexedLabel\", \"prediction\").toPandas()"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": null,
289 | "metadata": {},
290 | "outputs": [],
291 | "source": [
292 | "%matplotlib inline\n",
293 | "import numpy as np\n",
294 | "from matplotlib import pyplot as plt"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": null,
300 | "metadata": {},
301 | "outputs": [],
302 | "source": [
303 | "y_true = np.hstack(res_df.indexedLabel.values)\n",
304 | "y_pred = res_df.prediction.values\n",
305 | "\n",
306 | "accuracy = np.sum(y_pred == y_true) / y_true.shape[0]\n",
307 | "\n",
308 | "accuracy"
309 | ]
310 | },
311 | {
312 | "cell_type": "code",
313 | "execution_count": null,
314 | "metadata": {},
315 | "outputs": [],
316 | "source": [
317 | "y_true = np.hstack(res_hot_df.indexedLabel.values)\n",
318 | "y_hot_pred = res_hot_df.prediction.values\n",
319 | "\n",
320 | "accuracy_hot = np.sum(y_hot_pred == y_true) / y_true.shape[0]\n",
321 | "\n",
322 | "accuracy_hot"
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": null,
328 | "metadata": {},
329 | "outputs": [],
330 | "source": []
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": null,
335 | "metadata": {},
336 | "outputs": [],
337 | "source": []
338 | }
339 | ],
340 | "metadata": {
341 | "kernelspec": {
342 | "display_name": "Python 3",
343 | "language": "python",
344 | "name": "python3"
345 | },
346 | "language_info": {
347 | "codemirror_mode": {
348 | "name": "ipython",
349 | "version": 3
350 | },
351 | "file_extension": ".py",
352 | "mimetype": "text/x-python",
353 | "name": "python",
354 | "nbconvert_exporter": "python",
355 | "pygments_lexer": "ipython3",
356 | "version": "3.6.4"
357 | }
358 | },
359 | "nbformat": 4,
360 | "nbformat_minor": 2
361 | }
362 |
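
The last cells above pull the predictions into pandas and compute accuracy by hand with numpy. The same check can be done without leaving Spark. Below is a minimal sketch (not part of the notebook), written in Scala since that is the project's core language; it only assumes a predictions DataFrame carrying the indexedLabel and prediction columns produced by the pipelines above.

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.DataFrame

// Computes accuracy with Spark's built-in evaluator instead of collecting to pandas.
// `predictions` stands for whatever pipelineModel.transform(dataTest) returned.
def evaluateAccuracy(predictions: DataFrame): Double =
  new MulticlassClassificationEvaluator()
    .setLabelCol("indexedLabel")     // produced by the StringIndexer on the target column
    .setPredictionCol("prediction")  // default output column of RandomForestClassifier
    .setMetricName("accuracy")
    .evaluate(predictions)
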
--------------------------------------------------------------------------------
/src/main/scala/com/rakuten/dirty_cat/features/SimilarityEncoder.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Rakuten Institute of Technology and Andres Hoyos-Idrobo
3 | * under one or more contributor license agreements.
4 | * See the NOTICE file distributed with this work for additional information
5 | * regarding copyright ownership.
6 | * Rakuten Institute of technology and Andres Hoyos-Idrobo licenses this file
7 | * to You under the Apache License, Version 2.0
8 | * (the "License"); you may not use this file except in compliance with
9 | * the License. You may obtain a copy of the License at
10 | *
11 | * http://www.apache.org/licenses/LICENSE-2.0
12 | *
13 | * Unless required by applicable law or agreed to in writing, software
14 | * distributed under the License is distributed on an "AS IS" BASIS,
15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | * See the License for the specific language governing permissions and
17 | * limitations under the License.
18 | */
19 |
20 | package com.rakuten.dirty_cat.feature
21 |
22 | import com.rakuten.dirty_cat.persistence.{DefaultParamsReader, DefaultParamsWriter}
23 | import com.rakuten.dirty_cat.util.StringSimilarity
24 | import com.rakuten.dirty_cat.util.collection.OpenHashMap
25 |
26 | import org.json4s.JsonAST.{JObject}
27 | import org.apache.spark.ml.linalg.{Vector, Vectors}
28 | // Replace this by VectorUDT
29 | import org.apache.spark.ml.linalg.SQLDataTypes.{VectorType}
30 |
31 | import org.apache.hadoop.fs.Path
32 |
33 | import org.apache.spark.annotation.Since
34 | import org.apache.spark.ml.{Estimator, Model, Transformer}
35 | import org.apache.spark.sql.Row
36 | import org.apache.spark.sql.{Dataset, DataFrame}
37 | //
38 | import org.apache.spark.ml.attribute.{Attribute, AttributeGroup}
39 | import org.apache.spark.ml.feature.{IndexToString, StringIndexer}
40 |
41 | import org.apache.spark.rdd.RDD
42 | import org.apache.spark.sql.functions.{col, udf}
43 | import org.apache.spark.sql.types.{StringType, StructType, StructField}
44 |
45 | import org.apache.spark.ml.util.Identifiable
46 | import org.apache.spark.ml.util.{MLWritable, MLReadable, DefaultParamsWritable, DefaultParamsReadable}
47 | import org.apache.spark.ml.util.{MLWriter, MLReader}
48 | import org.apache.spark.ml.param.shared._
49 | import org.apache.spark.SparkException
50 |
51 | import org.apache.spark.ml.param.{ParamValidators, ParamPair, Param, Params, ParamMap, DoubleParam, IntParam}
52 |
53 | import scala.collection.JavaConverters._
54 | import java.lang.{Double => JDouble, Integer => JInt, String => JString}
55 | import java.util.{NoSuchElementException, Map => JMap}
56 | import scala.language.implicitConversions
57 |
58 |
59 |
60 |
61 | private[feature] trait SimilarityBase extends Params with HasInputCol with HasOutputCol with HasHandleInvalid {
62 |
63 |   final val nGramSize = new IntParam(this, "nGramSize", "Size of the character n-grams used by the nGram similarity")
64 |
65 | final val vocabSize = new IntParam(this, "vocabSize", "Number of dimensions of the encoding")
66 |
67 | override val handleInvalid = new Param[String](this, "handleInvalid",
68 | "How to handle invalid data (unseen labels or NULL values). " +
69 | "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
70 | "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
71 | ParamValidators.inArray(SimilarityEncoder.supportedHandleInvalids))
72 |
73 |   final val similarityType = new Param[String](this, "similarityType", "Type of string similarity used for the encoding. " +
74 |     s"Supported options: ${SimilarityEncoder.supportedSimilarityType.mkString(", ")}.",
75 | ParamValidators.inArray(SimilarityEncoder.supportedSimilarityType))
76 |
77 |
78 | final val stringOrderType = new Param[String](this, "stringOrderType",
79 | "How to order labels of string column. " +
80 | "The first label after ordering is assigned an index of 0. " +
81 | s"Supported options: ${SimilarityEncoder.supportedStringOrderType.mkString(", ")}.",
82 | ParamValidators.inArray(SimilarityEncoder.supportedStringOrderType))
83 |
84 | final def getNGramSize: Int = $(nGramSize)
85 |
86 | final def getVocabSize: Int = $(vocabSize)
87 |
88 | final def getSimilarityType: String = $(similarityType)
89 |
90 | final def getStringOrderType: String = $(stringOrderType)
91 |
92 |
93 | setDefault(nGramSize -> 3,
94 | similarityType -> SimilarityEncoder.nGram,
95 | vocabSize -> 100,
96 |     stringOrderType -> SimilarityEncoder.frequencyDesc,
97 | handleInvalid -> SimilarityEncoder.KEEP_INVALID)
98 |
99 |
100 | /** Validates and transforms the input schema. */
101 | protected def validateAndTransformSchema(schema: StructType): StructType = {
102 | val inputColName = $(inputCol)
103 | val inputDataType = schema(inputColName).dataType
104 | require(inputDataType == StringType,
105 |       s"The input column $inputColName must be a string " +
106 | s"but got $inputDataType.")
107 | val inputFields = schema.fields
108 | val outputColName = $(outputCol)
109 | require(inputFields.forall(_.name != outputColName),
110 | s"Output column $outputColName already exists.")
111 |
112 | val outputFields = inputFields :+ new StructField(outputColName, VectorType, true)
113 |
114 | StructType(outputFields)
115 | }
116 |
117 | protected def getCategories(dataset: Dataset[_]): Array[(String, Int)] = {
118 | val inputColName = $(inputCol)
119 |
120 |     val labels = (dataset
121 | .select(col($(inputCol)).cast(StringType))
122 | .na.drop(Array($(inputCol)))
123 | .groupBy($(inputCol))
124 | .count())
125 |
126 | // Different options to sort
127 | val vocabulary: Array[Row] = $(stringOrderType) match {
128 | case SimilarityEncoder.frequencyDesc => labels.sort(col("count").desc).take($(vocabSize))
129 | case SimilarityEncoder.frequencyAsc => labels.sort(col("count")).take($(vocabSize))
130 | }
131 |
132 |     vocabulary.map(row => (row.getAs[String](0), row.getAs[Long](1).toInt)).toArray
133 |
134 | }
135 | }
136 |
137 |
138 | class SimilarityEncoder private[dirty_cat] (override val uid: String) extends Estimator[SimilarityEncoderModel] with SimilarityBase {
139 |
140 | def this() = this(Identifiable.randomUID("SimilarityEncoder"))
141 |
142 | def setInputCol(value: String): this.type = set(inputCol, value)
143 |
144 | def setOutputCol(value: String): this.type = set(outputCol, value)
145 |
146 | def setNGramSize(value: Int): this.type = set(nGramSize, value)
147 |
148 | def setVocabSize(value: Int): this.type = set(vocabSize, value)
149 |
150 | def setSimilarityType(value: String): this.type = set(similarityType, value)
151 |
152 | def setStringOrderType(value: String): this.type = set(stringOrderType, value)
153 |
154 | def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
155 |
156 |
157 | override def copy(extra: ParamMap): SimilarityEncoder = { defaultCopy(extra) }
158 |
159 |
160 | override def transformSchema(schema: StructType): StructType = {
161 | validateAndTransformSchema(schema)
162 | }
163 |
164 | override def fit(dataset: Dataset[_]): SimilarityEncoderModel = {
165 |
166 | transformSchema(dataset.schema, logging = true)
167 |
168 | val vocabularyReference = getCategories(dataset).take($(vocabSize))
169 |
170 | copyValues(new SimilarityEncoderModel(uid,
171 | vocabularyReference.toMap).setParent(this))
172 |
173 | }
174 | }
175 |
176 |
177 | object SimilarityEncoder extends DefaultParamsReadable[SimilarityEncoder] {
178 | private[feature] val SKIP_INVALID: String = "skip"
179 | private[feature] val KEEP_INVALID: String = "keep"
180 | private[feature] val supportedHandleInvalids: Array[String] =
181 | Array(SKIP_INVALID, KEEP_INVALID)
182 | private[feature] val frequencyDesc: String = "frequencyDesc"
183 | private[feature] val frequencyAsc: String = "frequencyAsc"
184 | private[feature] val supportedStringOrderType: Array[String] =
185 | Array(frequencyDesc, frequencyAsc)
186 | private[feature] val nGram: String = "nGram"
187 | private[feature] val leverstein: String = "leverstein"
188 | private[feature] val jako: String = "jako"
189 | private[feature] val supportedSimilarityType: Array[String] = Array(nGram, leverstein, jako)
190 |
191 | override def load(path: String): SimilarityEncoder = super.load(path)
192 | }
193 |
194 |
195 |
196 |
197 | /* This encoding is an alternative to OneHotEncoder in the case of
198 | dirty categorical variables. */
199 | class SimilarityEncoderModel private[dirty_cat] (override val uid: String,
200 | val vocabularyReference: Map[String, Int]) extends
201 | Model[SimilarityEncoderModel] with SimilarityBase with MLWritable with Serializable{
202 |
203 | import SimilarityEncoderModel._
204 |
205 | // only called in copy()
206 | def this(uid: String) = this(uid, null)
207 |
208 |
209 | private def int2Integer(x: Int) = java.lang.Integer.valueOf(x)
210 |
211 | private def string2String(x: String) = java.lang.String.valueOf(x)
212 |
213 | /* Java-friendly version of [[vocabularyReference]] */
214 | def javaVocabularyReference: JMap[JString, JInt] = {
215 | vocabularyReference.map{ case (k, v) => string2String(k) -> int2Integer(v) }.asJava}
216 |
217 |
218 | def setInputCol(value: String): this.type = set(inputCol, value)
219 |
220 | def setOutputCol(value: String): this.type = set(outputCol, value)
221 |
222 | def setNGramSize(value: Int): this.type = set(nGramSize, value)
223 |
224 | def setSimilarityType(value: String): this.type = set(similarityType, value)
225 |
226 |
227 | override def transformSchema(schema: StructType): StructType = {
228 | if (schema.fieldNames.contains($(inputCol))) {
229 | validateAndTransformSchema(schema)
230 | } else {
231 | // If the input column does not exist during transformation, we skip
232 | // SimilarityEncoderModel.
233 | schema
234 | }
235 | }
236 |
237 |
238 | override def transform(dataset: Dataset[_]): DataFrame = {
239 |
240 | if (!dataset.schema.fieldNames.contains($(inputCol))) {
241 | logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
242 | "Skip SimilarityEncoderModel.")
243 | return dataset.toDF
244 | }
245 |
246 | transformSchema(dataset.schema, logging = true)
247 |
248 | val inputColName = $(inputCol)
249 | val outputColName = $(outputCol)
250 | // Transformation
251 | val vocabulary = getCategories(dataset).toArray
252 | val vocabularyLabels = vocabulary.map(_._1).toSeq
253 | val vocabularyReferenceLabels = vocabularyReference.toArray.map(_._1)
254 |
255 |
256 | val similarityValues = $(similarityType) match {
257 | case SimilarityEncoder.nGram => StringSimilarity.getNGramSimilaritySeq(vocabularyLabels, vocabularyReferenceLabels, $(nGramSize))
258 | case SimilarityEncoder.leverstein => StringSimilarity.getLevenshteinSimilaritySeq(vocabularyLabels, vocabularyReferenceLabels)
259 | case SimilarityEncoder.jako => StringSimilarity.getJaroWinklerSimilaritySeq(vocabularyLabels, vocabularyReferenceLabels)
260 | }
261 |
262 | val labelToEncode: OpenHashMap[String, Array[Double]] = {
263 | val n = vocabularyLabels.length
264 | val map = new OpenHashMap[String, Array[Double]](n)
265 | var i = 0
266 | while (i < n ){
267 | map.update(vocabularyLabels(i), similarityValues(i).toArray)
268 | i += 1
269 | }
270 | map
271 | }
272 |
273 | val (filteredDataset, keepInvalid) = $(handleInvalid) match {
274 | case SimilarityEncoder.SKIP_INVALID =>
275 | val filterer = udf { label: String =>
276 | labelToEncode.contains(label)
277 | }
278 | (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
279 | case _ => (dataset, getHandleInvalid == SimilarityEncoder.KEEP_INVALID)
280 | }
281 |
282 | val emptyValues = Vectors.dense(Array.fill($(vocabSize))(0D))
283 |
284 | val dirtyCatUDF = udf { label: String =>
285 | if (label == null) {
286 | if (keepInvalid) {
287 | emptyValues
288 | } else {
289 | throw (new SparkException("SimilarityEncoder encountered NULL value. To handle or skip " +
290 | "NULLS, try setting SimilarityEncoder.handleInvalid."))
291 | }
292 | } else {
293 | if (labelToEncode.contains(label)) {
294 | Vectors.dense(labelToEncode(label))
295 | } else if (keepInvalid) {
296 | emptyValues
297 | } else {
298 | throw (new SparkException(s"Unseen label: $label. To handle unseen labels, " +
299 | s"set Param handleInvalid to ${SimilarityEncoder.KEEP_INVALID}."))
300 | }
301 | }
302 | }.asNondeterministic()
303 |
304 | filteredDataset.withColumn($(outputCol), dirtyCatUDF(col($(inputCol))))
305 | }
306 |
307 | override def copy(extra: ParamMap) = { defaultCopy(extra) }
308 |
309 |
310 | override def write: MLWriter = new SimilarityEncoderModelWriter(this)
311 |
312 | }
313 |
314 |
315 | /// Add read and write
316 | /** [[MLWriter]] instance for [[SimilarityEncoderModel]] */
317 | object SimilarityEncoderModel extends MLReadable[SimilarityEncoderModel] {
318 | //
319 | override def read: MLReader[SimilarityEncoderModel] = new SimilarityEncoderModelReader
320 |
321 | override def load(path: String): SimilarityEncoderModel = super.load(path)
322 |
323 | private[SimilarityEncoderModel]
324 | class SimilarityEncoderModelWriter(instance: SimilarityEncoderModel) extends MLWriter {
325 |
326 | private case class Data(vocabularyReference: Map[String, Int])
327 |
328 | override protected def saveImpl(path: String): Unit = {
329 |
330 | implicit val sparkSession = super.sparkSession
331 | implicit val sc = sparkSession.sparkContext
332 |
333 | // Save metadata and Params
334 | DefaultParamsWriter.saveMetadata(instance, path, sc)
335 | // Save model data
336 | val data = Data(instance.vocabularyReference)
337 | val dataPath = new Path(path, "data").toString
338 |
339 | sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
340 | }
341 | }
342 |
343 |
344 | private class SimilarityEncoderModelReader extends MLReader[SimilarityEncoderModel] {
345 |
346 | private val className = classOf[SimilarityEncoderModel].getName
347 |
348 | override def load(path: String): SimilarityEncoderModel = {
349 |
350 | implicit val sc = super.sparkSession.sparkContext
351 |
352 | val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
353 |
354 | val dataPath = new Path(path, "data").toString
355 | val data = sparkSession.read.parquet(dataPath).select("vocabularyReference")
356 |
357 | val vocabularyReference = data.head().getAs[Map[String, Int]](0)
358 | val model = new SimilarityEncoderModel(metadata.uid, vocabularyReference)
359 | DefaultParamsReader.getAndSetParams(model, metadata)
360 |
361 | model
362 | }
363 | }
364 |
365 |
366 | }
367 |
368 |
369 |
370 |
371 |
--------------------------------------------------------------------------------
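
To round the section off, a minimal Scala usage sketch of the estimator defined above. The DataFrame df, the column names and the save path are placeholders; the setters, the fit/transform calls and the save/load round trip come directly from the class definitions in SimilarityEncoder.scala.

import com.rakuten.dirty_cat.feature.{SimilarityEncoder, SimilarityEncoderModel}
import org.apache.spark.sql.DataFrame

// `df` is a placeholder for any DataFrame with a dirty string column named "city".
def encodeCity(df: DataFrame): DataFrame = {
  val encoder = new SimilarityEncoder()
    .setInputCol("city")
    .setOutputCol("cityEncoded")
    .setSimilarityType("nGram")          // also supported: "leverstein", "jako"
    .setNGramSize(3)
    .setVocabSize(100)
    .setStringOrderType("frequencyDesc")
    .setHandleInvalid("keep")            // unseen labels / NULLs map to an all-zero vector

  val model: SimilarityEncoderModel = encoder.fit(df)

  // Persist and reload through the project's DefaultParamsWriter/Reader machinery.
  model.write.overwrite().save("/tmp/similarity-encoder-model")   // placeholder path
  val reloaded = SimilarityEncoderModel.load("/tmp/similarity-encoder-model")

  reloaded.transform(df)                 // appends the "cityEncoded" vector column
}
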