├── version.txt ├── python ├── requirements.txt ├── MANIFEST.in ├── dirty_cat_spark │ ├── feature │ │ ├── __init__.py │ │ └── encoder.py │ ├── __init__.py │ └── utils │ │ ├── java_reader.py │ │ └── __init__.py ├── README.md └── setup.py ├── Makefile ├── src ├── main │ └── scala │ │ └── com │ │ └── rakuten │ │ └── dirty_cat │ │ ├── spark_utils │ │ ├── Utils.scala │ │ ├── OpenHashMap.scala │ │ ├── BitSet.scala │ │ └── OpenHashSet.scala │ │ ├── utils │ │ ├── TextUtils.scala │ │ └── StringDistances.scala │ │ ├── spark_persistence │ │ └── ReadWrite.scala │ │ └── features │ │ └── SimilarityEncoder.scala └── test │ └── scala │ └── com │ └── rakuten │ └── dirty_cat │ └── feature │ └── SimilarityEncoderTestSuite.scala ├── README.md ├── LICENSE └── examples └── Midwest_Survey.ipynb /version.txt: -------------------------------------------------------------------------------- 1 | 0.1-SNAPSHOT 2 | 3 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | pyspark 2 | setuptools -------------------------------------------------------------------------------- /python/MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Include jar files 2 | include dirty_cat_spark/jars/*.jar 3 | -------------------------------------------------------------------------------- /python/dirty_cat_spark/feature/__init__.py: -------------------------------------------------------------------------------- 1 | # from pydirtycat.feature.encoder import SimilarityEncoder 2 | # # from pydirtycat.feature.encoder import SimilarityEncoderModel 3 | 4 | # __all__ = ["SimilarityEncoder"] 5 | -------------------------------------------------------------------------------- /python/dirty_cat_spark/__init__.py: -------------------------------------------------------------------------------- 1 | # import sys 2 | # from sparknlp import annotator 3 | # sys.modules['com.johnsnowlabs.nlp.annotators'] = annotator 4 | 5 | # from pydirtycat import feature 6 | 7 | # __all__ = ["feature"] 8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL = /bin/bash 2 | VERSION = $(shell cat version.txt) 3 | 4 | .PHONY: clean clean-pyc clean-dist 5 | 6 | clean: clean-dist clean-pyc 7 | 8 | clean-pyc: 9 | find . -name '*.pyc' -exec rm -f {} + 10 | find . -name '*.pyo' -exec rm -f {} + 11 | find . -name '*~' -exec rm -f {} + 12 | find . 
-name '__pycache__' -exec rm -fr {} + 13 | 14 | clean-dist: 15 | rm -rf target 16 | rm -rf python/build/ 17 | rm -rf python/*.egg-info 18 | 19 | 20 | publish: clean 21 | # use spark packages to create the distribution 22 | sbt clean 23 | sbt compile 24 | sbt package 25 | 26 | cd python; python setup.py sdist 27 | 28 | -------------------------------------------------------------------------------- /python/dirty_cat_spark/utils/java_reader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from pyspark.ml.util import MLReader, _jvm 4 | 5 | 6 | class CustomJavaMLReader(MLReader): 7 | """ 8 | (Custom) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types 9 | """ 10 | 11 | def __init__(self, clazz, java_class): 12 | self._clazz = clazz 13 | self._jread = self._load_java_obj(java_class).read() 14 | 15 | def load(self, path): 16 | """Load the ML instance from the input path.""" 17 | java_obj = self._jread.load(path) 18 | return self._clazz._from_java(java_obj) 19 | 20 | @classmethod 21 | def _load_java_obj(cls, java_class): 22 | """Load the peer Java object of the ML instance.""" 23 | java_obj = _jvm() 24 | for name in java_class.split("."): 25 | java_obj = getattr(java_obj, name) 26 | return java_obj 27 | -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | # Python wrapper 2 | 3 | This is a Python wrapper of com.rakuten.dirty_cat. 4 | 5 | 6 | ## Install 7 | 8 | Please make sure you have built target/scala_VERSION/PACKAGE.jar, as it is 9 | required to use this wrapper. 10 | 11 | 12 | * Local machine: 13 | 14 | ```{.bash} 15 | python setup.py install --user 16 | ``` 17 | 18 | * Distributed: build a source distribution 19 | ```{.bash} 20 | python setup.py sdist 21 | ``` 22 | 23 | 24 | ## Testing 25 | 26 | ```{.bash} 27 | spark-submit --jars target/scala-2.11/com-rakuten-dirty_cat_2.11-0.1-SNAPSHOT.jar python/test/test_similarity_encoder.py 28 | ``` 29 | 30 | 31 | 32 | ## Usage 33 | 34 | #### Declaration 35 | ```{.python} 36 | from dirty_cat_spark.feature.encoder import SimilarityEncoder 37 | 38 | encoder = (SimilarityEncoder() 39 | .setInputCol("devices") 40 | .setOutputCol("devicesEncoded") 41 | .setSimilarityType("nGram") 42 | .setVocabSize(1000)) 43 | ``` 44 | 45 | #### Using it in a pipeline 46 | ```{.python} 47 | from pyspark.ml import Pipeline 48 | 49 | pipeline = Pipeline(stages=[encoder, YOUR_ESTIMATOR]) 50 | pipelineModel = pipeline.fit(dataframe) 51 | ``` 52 | 53 | #### Serialization 54 | ```{.python} 55 | pipelineModel.write().overwrite().save("pipeline.parquet") 56 | ``` 57 | 58 | 59 | ## Reference 60 | Patricio Cerda, Gaël Varoquaux, Balázs Kégl. Similarity encoding for learning with dirty categorical variables. 2018. Accepted for publication in: Machine Learning journal, Springer. 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /src/main/scala/com/rakuten/dirty_cat/spark_utils/Utils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | * 17 | * This code is a modified version of the original Spark 2.1 implementation. 18 | */ 19 | 20 | package com.rakuten.dirty_cat.utils 21 | 22 | // based on org.apache.spark.util copy /paste 23 | private[dirty_cat] object Utils { 24 | 25 | def getSparkClassLoader: ClassLoader = getClass.getClassLoader 26 | 27 | def getContextOrSparkClassLoader: ClassLoader = 28 | Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader) 29 | 30 | // scalastyle:off classforname 31 | /** Preferred alternative to Class.forName(className) */ 32 | def classForName(className: String): Class[_] = { 33 | Class.forName(className, true, getContextOrSparkClassLoader) 34 | // scalastyle:on classforname 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/main/scala/com/rakuten/dirty_cat/utils/TextUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Rakuten Institute of Technology and Andres Hoyos-Idrobo 3 | * under one or more contributor license agreements. 4 | * See the NOTICE file distributed with this work for additional information 5 | * regarding copyright ownership. 6 | * Rakuten Institute of technology and Andres Hoyos-Idrobo licenses this file 7 | * to You under the Apache License, Version 2.0 8 | * (the "License"); you may not use this file except in compliance with 9 | * the License. You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 
18 | */ 19 | 20 | package com.rakuten.dirty_cat.utils 21 | 22 | 23 | import org.apache.spark.sql.functions.udf 24 | 25 | package object TextUtils { 26 | 27 | import java.text.Normalizer 28 | import scala.util.Try 29 | 30 | def normalizeString(input: String): String = { 31 | Try{ 32 | val cleaned = input.trim.toLowerCase 33 | val normalized = (Normalizer.normalize(cleaned, Normalizer.Form.NFD) 34 | .replaceAll("[\\p{InCombiningDiacriticalMarks}\\p{IsM}\\p{IsLm}\\p{IsSk}]+", "")) 35 | (normalized 36 | .replaceAll("'s", "") 37 | .replaceAll("ß", "ss") 38 | .replaceAll("ø", "o") 39 | .replaceAll("[^a-zA-Z0-9-]+", "-") 40 | .replaceAll("-+", "-") 41 | .stripSuffix("-")) 42 | }.getOrElse(input) 43 | } 44 | 45 | val normalizeStringUDF = udf[String, String](normalizeString(_)) 46 | } 47 | 48 | -------------------------------------------------------------------------------- /python/dirty_cat_spark/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import py4j.protocol 2 | from py4j.protocol import Py4JJavaError 3 | from py4j.java_gateway import JavaObject 4 | from py4j.java_collections import JavaArray, JavaList, JavaMap 5 | 6 | from pyspark import RDD, SparkContext 7 | from pyspark.serializers import PickleSerializer, AutoBatchedSerializer 8 | from pyspark.sql import DataFrame, SQLContext 9 | 10 | # Hack to support float('inf') in Py4j 11 | _old_smart_decode = py4j.protocol.smart_decode 12 | 13 | _float_str_mapping = { 14 | 'nan': 'NaN', 15 | 'inf': 'Infinity', 16 | '-inf': '-Infinity', 17 | } 18 | 19 | 20 | def _new_smart_decode(obj): 21 | if isinstance(obj, float): 22 | s = str(obj) 23 | return _float_str_mapping.get(s, s) 24 | return _old_smart_decode(obj) 25 | 26 | py4j.protocol.smart_decode = _new_smart_decode 27 | 28 | 29 | _picklable_classes = [ 30 | 'SparseVector', 31 | 'DenseVector', 32 | 'SparseMatrix', 33 | 'DenseMatrix', 34 | ] 35 | 36 | 37 | def _java2py(sc, r, encoding="bytes"): 38 | if isinstance(r, JavaObject): 39 | clsName = r.getClass().getSimpleName() 40 | # convert RDD into JavaRDD 41 | if clsName != 'JavaRDD' and clsName.endswith("RDD"): 42 | r = r.toJavaRDD() 43 | clsName = 'JavaRDD' 44 | 45 | if clsName == 'JavaRDD': 46 | jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r) 47 | return RDD(jrdd, sc) 48 | 49 | if clsName == 'Dataset': 50 | return DataFrame(r, SQLContext.getOrCreate(sc)) 51 | 52 | if clsName in _picklable_classes: 53 | r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r) 54 | 55 | elif isinstance(r, (JavaArray, JavaList, JavaMap)): 56 | try: 57 | r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r) 58 | except Py4JJavaError: 59 | pass # not picklable 60 | 61 | if isinstance(r, (bytearray, bytes)): 62 | r = PickleSerializer().loads(bytes(r), encoding=encoding) 63 | return r 64 | 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dirty Cat: Dealing with dirty categorical variables (strings). 2 | 3 | DirtyCat (Scala) is a package that leverages Spark ML to perform large-scale machine learning and provides an alternative way to encode string variables. This package is largely based on the original Python code: https://github.com/dirty-cat 4 | 5 | 6 | ## Documentation 7 | * https://github.com/dirty-cat 8 | * Patricio Cerda, Gaël Varoquaux, Balázs Kégl. Similarity encoding for learning with dirty categorical variables. Machine Learning journal, Springer. 2018. 
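To make the idea concrete, below is a minimal, self-contained sketch of what similarity encoding does (plain Scala, not the package API; the object name, the tiny vocabulary and the dirty values are made up for illustration). Each dirty string is turned into a vector of its n-gram similarities to a reference vocabulary, so near-duplicates such as "London" and "Lond0n" receive close encodings:

```{.scala}
// Toy sketch of 3-gram similarity encoding. The real Spark ML stage is the
// SimilarityEncoder described below.
object SimilarityEncodingSketch {

  // Character n-grams with empty boundary tokens, mirroring StringSimilarity.getNGrams.
  private def nGrams(s: String, n: Int): List[List[String]] =
    (List("") ::: s.toLowerCase.split("").toList ::: List("")).sliding(n).toList

  // Ratio of shared n-grams to non-shared ones, in the spirit of
  // StringSimilarity.getNGramSimilarity.
  def nGramSimilarity(a: String, b: String, n: Int = 3): Double = {
    val (ga, gb) = (nGrams(a, n), nGrams(b, n))
    val shared = ga.toSet.intersect(gb.toSet).size.toDouble
    shared / (ga.length + gb.length - shared)
  }

  def main(args: Array[String]): Unit = {
    val vocabulary = Seq("london", "paris", "madrid")        // reference categories
    val dirty = Seq("London", "london ", "Lond0n", "paris")  // dirty inputs
    dirty.foreach { value =>
      // One similarity per vocabulary entry: this vector is the encoding of `value`.
      val encoded = vocabulary.map(c => f"${nGramSimilarity(value.trim, c)}%.2f")
      println(f"$value%-8s -> [${encoded.mkString(", ")}]")
    }
  }
}
```

In the package itself this role is played by the `SimilarityEncoder` stage shown below: it fits a reference vocabulary from the input column (up to `vocabSize` values, ordered according to `stringOrderType`) and outputs one similarity vector per row.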
10 | 11 | 12 | ### Getting started: How to use it 13 | 14 | The DirtyCat project is built for Scala 2.11.x against Spark v2.3.0. 15 | 16 | This package is provided as is, so you will have to install it 17 | yourself. Here are some pointers to get started. 18 | 19 | 20 | ### Build it by yourself: Installation 21 | 22 | This project can be built with [SBT](https://www.scala-sbt.org/) 1.1.x. 23 | 24 | 25 | Edit build.sbt to match your Scala/Spark installation. 26 | Then run on the command line: 27 | ```{.bash} 28 | sbt clean 29 | 30 | sbt compile 31 | 32 | sbt package 33 | ``` 34 | 35 | This will generate a .jar file at target/scala_VERSION/PACKAGE.jar, where 36 | PACKAGE = com-rakuten-dirty_cat_SCALA_VERSION-0.1-SNAPSHOT.jar (e.g. target/scala-2.11/com-rakuten-dirty_cat_2.11-0.1-SNAPSHOT.jar) 37 | 38 | 39 | If you are using Jupyter notebooks (Scala), you can 40 | add this jar to the Toree Spark options of your Jupyter kernel. 41 | 42 | * List your available kernels by running: 43 | ``` 44 | jupyter kernelspec list 45 | ``` 46 | * Go to your Scala kernel and add: 47 | ```{.json} 48 | "env": { 49 | "DEFAULT_INTERPRETER": "Scala", 50 | "__TOREE_SPARK_OPTS__": "--conf spark.driver.memory=2g --conf spark.executor.cores=4 --conf spark.executor.memory=1g --jars PATH/target/scala_VERSION/PACKAGE.jar" 51 | } 52 | ``` 53 | 54 | 55 | To submit your Spark application, run 56 | ```{.bash} 57 | spark-submit --master local[3] --jars target/scala-2.11/com-rakuten-dirty_cat_2.11-0.1-SNAPSHOT.jar YOUR_APPLICATION 58 | ``` 59 | 60 | 61 | ### Create a local package 62 | ```{.bash} 63 | make publish 64 | ``` 65 | 66 | ### Usage with Spark ML 67 | 68 | #### Declaration 69 | ```{.scala} 70 | import com.rakuten.dirty_cat.feature.SimilarityEncoder 71 | 72 | val encoder = (new SimilarityEncoder() 73 | .setInputCol("devices") 74 | .setOutputCol("devicesEncoded") 75 | .setSimilarityType("nGram") 76 | .setVocabSize(1000)) 77 | ``` 78 | 79 | #### Using it in a pipeline 80 | ```{.scala} 81 | import org.apache.spark.ml.Pipeline 82 | 83 | val pipeline = (new Pipeline().setStages(Array(encoder, YOUR_ESTIMATOR))) 84 | val pipelineModel = pipeline.fit(dataframe) 85 | ``` 86 | 87 | #### Serialization 88 | ```{.scala} 89 | pipelineModel.write.overwrite().save("pipeline.parquet") 90 | ``` 91 | 92 | 93 | 94 | ## History 95 | 96 | Andrés Hoyos-Idrobo started this implementation of DirtyCat as a way to improve his Spark/Scala skills. 97 | 98 | Contributions from: 99 | 100 | * Andrés Hoyos-Idrobo 101 | 102 | 103 | Corporate (Code) Contributors: 104 | * Rakuten Institute of Technology 105 | -------------------------------------------------------------------------------- /src/main/scala/com/rakuten/dirty_cat/utils/StringDistances.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Rakuten Institute of Technology and Andres Hoyos-Idrobo 3 | * under one or more contributor license agreements. 4 | * See the NOTICE file distributed with this work for additional information 5 | * regarding copyright ownership. 6 | * Rakuten Institute of technology and Andres Hoyos-Idrobo licenses this file 7 | * to You under the Apache License, Version 2.0 8 | * (the "License"); you may not use this file except in compliance with 9 | * the License. 
You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 18 | */ 19 | 20 | package com.rakuten.dirty_cat.util 21 | 22 | 23 | import scala.collection.mutable 24 | import scala.collection.parallel.ParSeq 25 | import org.apache.commons.lang3.StringUtils 26 | 27 | 28 | private[dirty_cat] object StringSimilarity{ 29 | 30 | 31 | private[util] def getNGrams(string: String, n: Int): List[List[String]] = { 32 | // start = and end = 33 | val tokens = List("") ::: string.toLowerCase().split("").toList ::: List("") 34 | val nGram = tokens.sliding(n).toList 35 | nGram } 36 | 37 | 38 | private[util] def getCounts(nGram: List[List[String]]): Map[List[String],Int] = { 39 | nGram.groupBy(identity).mapValues(_.size) } 40 | 41 | 42 | private[util] def getNGramSimilarity(string1: String, string2: String, n: Int): Double = { 43 | val ngrams1 = getNGrams(string1, n) 44 | val ngrams2 = getNGrams(string2, n) 45 | val counts1 = getCounts(ngrams1) 46 | val counts2 = getCounts(ngrams2) 47 | 48 | val sameGrams = (counts1.keySet.intersect(counts2.keySet) 49 | .map(k => k -> List(counts1(k), counts2(k))).toMap) 50 | 51 | val nSameGrams = sameGrams.size 52 | val nAllGrams = ngrams1.length + ngrams2.length 53 | 54 | val similarity = nSameGrams.toDouble / (nAllGrams.toDouble - nSameGrams.toDouble) 55 | 56 | similarity } 57 | 58 | 59 | def getLevenshteinRatio(string1: String, string2: String): Double = { 60 | val totalLength = (string1.length + string2.length).toDouble 61 | if (totalLength == 0D){ 1D } else { (totalLength - StringUtils.getLevenshteinDistance(string1, string2)) / totalLength }} 62 | 63 | 64 | def getJaroWinklerRatio(string1: String, string2: String): Double = { 65 | val totalLength = (string1.length + string2.length).toDouble 66 | if (totalLength == 0D){ 1D } else { (totalLength - StringUtils.getJaroWinklerDistance(string1, string2)) / totalLength }} 67 | 68 | 69 | def getNGramSimilaritySeq(data: Seq[String], categories: Seq[String], n: Int): Seq[Seq[Double]] = { 70 | data.map{xi => categories.map{yi => getNGramSimilarity(xi, yi, n)}}} 71 | 72 | 73 | def getLevenshteinSimilaritySeq(data: Seq[String], categories: Seq[String]): Seq[Seq[Double]] = { 74 | data.map{xi => categories.map{yi => getLevenshteinRatio(xi, yi)}}} 75 | 76 | 77 | def getJaroWinklerSimilaritySeq(data: Seq[String], categories: Seq[String]): Seq[Seq[Double]] = { 78 | data.map{xi => categories.map{yi => getJaroWinklerRatio(xi, yi)}}} 79 | } 80 | 81 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | from __future__ import print_function 13 | import sys 14 | import os 15 | import glob 16 | from setuptools import setup, find_packages 17 | from shutil import copyfile, copytree, rmtree 18 | 19 | 20 | basedir = os.path.dirname(os.path.abspath(__file__)) 21 | os.chdir(basedir) 22 | 23 | # A temporary path so we can access above the Python project root and fetch scripts and jars we need 24 | JARS_TARGET = os.path.join(basedir, "dirty_cat_spark/lib") 25 | DIRTY_CAT_HOME = os.path.abspath("../") 26 | 27 | # Figure out where the jars are we need to package with PySpark. 28 | JARS_PATH = glob.glob(os.path.join(DIRTY_CAT_HOME, "target/scala-*")) 29 | 30 | if len(JARS_PATH) == 1: 31 | JARS_PATH = JARS_PATH[0] 32 | else: 33 | raise IOError("cannot find jar files." 34 | " Please make user to run sbt package first") 35 | 36 | 37 | def _supports_symlinks(): 38 | """Check if the system supports symlinks (e.g. *nix) or not.""" 39 | return getattr(os, "symlink", None) is not None 40 | 41 | 42 | if _supports_symlinks(): 43 | os.symlink(JARS_PATH, JARS_TARGET) 44 | else: 45 | # For windows fall back to the slower copytree 46 | copytree(JARS_PATH, JARS_TARGET) 47 | 48 | 49 | 50 | def setup_package(): 51 | 52 | try: 53 | def f(*path): 54 | return open(os.path.join(basedir, *path)) 55 | 56 | setup( 57 | name='dirty_cat_spark', 58 | maintainer='Andres Hoyos Idrobo', 59 | maintainer_email='andres.hoyosidrobo@rakuten.com', 60 | version=f('../version.txt').read().strip(), 61 | description='Similarity-based embedding to encode dirty categorical strings in PySpark.', 62 | long_description=f('../README.md').read(), 63 | # url='https://github.com/XXX/dirty_cat', 64 | license='Apache License 2.0', 65 | 66 | keywords='spark pyspark categorical ml', 67 | 68 | classifiers=[ 69 | 'Development Status :: 2 - Pre-Alpha', 70 | 'Environment :: Other Environment', 71 | 'Intended Audience :: Developers', 72 | 'License :: OSI Approved :: Apache Software License', 73 | 'Operating System :: OS Independent', 74 | 'Programming Language :: Python', 75 | 'Programming Language :: Python :: 2', 76 | 'Programming Language :: Python :: 2.7', 77 | 'Topic :: Software Development :: Libraries', 78 | 'Topic :: Scientific/Engineering :: Information Analysis', 79 | 'Topic :: Utilities', 80 | ], 81 | install_requires=open('./requirements.txt').read().split(), 82 | 83 | packages=find_packages(exclude=['test']), 84 | include_package_data=True, # Needed to install jar file 85 | ) 86 | 87 | finally: 88 | # print("here") 89 | if _supports_symlinks(): 90 | os.remove(JARS_TARGET) 91 | else: 92 | rmtree(JARS_TARGET) 93 | 94 | 95 | if __name__ == '__main__': 96 | setup_package() 97 | -------------------------------------------------------------------------------- /src/test/scala/com/rakuten/dirty_cat/feature/SimilarityEncoderTestSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Rakuten Institute of Technology and Andres Hoyos-Idrobo 3 | * under one or more contributor license agreements. 4 | * See the NOTICE file distributed with this work for additional information 5 | * regarding copyright ownership. 6 | * Rakuten Institute of technology and Andres Hoyos-Idrobo licenses this file 7 | * to You under the Apache License, Version 2.0 8 | * (the "License"); you may not use this file except in compliance with 9 | * the License. 
You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 18 | */ 19 | 20 | package com.rakuten.dirty_cat.feature 21 | 22 | 23 | import com.holdenkarau.spark.testing.SharedSparkContext 24 | import org.apache.spark.sql.{SQLContext, DataFrame} 25 | import org.apache.spark.SparkException 26 | import org.scalatest.FunSuite 27 | import org.scalatest.Matchers 28 | import org.apache.spark.sql.types.{StringType, IntegerType, StructType, StructField} 29 | import org.apache.spark.sql.Row 30 | 31 | 32 | 33 | class SimilarityEncoderSuite extends FunSuite with SharedSparkContext { 34 | 35 | import com.rakuten.dirty_cat.feature.{SimilarityEncoder, SimilarityEncoderModel} 36 | 37 | 38 | private def generateDataFrame(): DataFrame = { 39 | 40 | val schema = StructType(List(StructField("id", IntegerType), 41 | StructField("name", StringType))) 42 | 43 | val sqlContext = new SQLContext(sc) 44 | 45 | val rdd = sc.parallelize(Seq( 46 | Row(0, "andres"), 47 | Row(1, "andrea"), 48 | Row(2, "carlos I"), 49 | Row(3, "camilo II"), 50 | Row(4, "camila de aragon"), 51 | Row(5, "guido"), 52 | Row(6, "camilo II"), 53 | Row(7, "guido"), 54 | Row(8, "guido"), 55 | Row(9, "andrea"), 56 | Row(10, "andrea"), 57 | Row(11, "camila de aragon"))) 58 | 59 | val dataframe = sqlContext.createDataFrame(rdd, schema) 60 | 61 | dataframe 62 | 63 | } 64 | 65 | 66 | test("coverage") { 67 | 68 | List("leverstein", "nGram", "jako").map{similarity => 69 | 70 | val encoder = (new SimilarityEncoder() 71 | .setInputCol("name") 72 | .setOutputCol("nameEncoded") 73 | .setSimilarityType(similarity)) 74 | } 75 | } 76 | 77 | 78 | 79 | test("fit"){ 80 | 81 | val dataframe = generateDataFrame() 82 | 83 | val encoder = (new SimilarityEncoder() 84 | .setInputCol("name") 85 | .setOutputCol("nameEncoded") 86 | .setSimilarityType("nGram")) 87 | 88 | val encoderModel = encoder.fit(dataframe) 89 | 90 | val dataframeEncoded = encoderModel.transform(dataframe) 91 | 92 | 93 | } 94 | 95 | 96 | test("pipeline"){ 97 | 98 | import org.apache.spark.ml.{Pipeline, PipelineModel} 99 | import org.apache.spark.ml.feature.StandardScaler 100 | 101 | val dataframe = generateDataFrame() 102 | 103 | val encoder = (new SimilarityEncoder() 104 | .setInputCol("name") 105 | .setOutputCol("nameEncoded") 106 | .setSimilarityType("nGram")) 107 | 108 | val scaler = (new StandardScaler() 109 | .setInputCol("nameEncoded") 110 | .setOutputCol("scaledFeatures")) 111 | 112 | val pipeline = (new Pipeline() 113 | .setStages(Array(encoder, scaler))) 114 | 115 | val pipelineModel = pipeline.fit(dataframe) 116 | 117 | val dataframeFeatures = pipelineModel.transform(dataframe) 118 | 119 | } 120 | 121 | } 122 | 123 | 124 | -------------------------------------------------------------------------------- /src/main/scala/com/rakuten/dirty_cat/spark_utils/OpenHashMap.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | * 17 | * This code is a modified version of the original Spark 2.1 implementation. 18 | */ 19 | package com.rakuten.dirty_cat.util.collection 20 | 21 | import scala.reflect.ClassTag 22 | 23 | // import org.apache.spark.annotation.DeveloperApi 24 | 25 | /** 26 | * :: DeveloperApi :: 27 | * A fast hash map implementation for nullable keys. This hash map supports insertions and updates, 28 | * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less 29 | * space overhead. 30 | * 31 | * Under the hood, it uses our OpenHashSet implementation. 32 | */ 33 | // @DeveloperApi 34 | private[dirty_cat] 35 | class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( 36 | initialCapacity: Int) 37 | extends Iterable[(K, V)] 38 | with Serializable { 39 | 40 | def this() = this(64) 41 | 42 | protected var _keySet = new OpenHashSet[K](initialCapacity) 43 | 44 | // Init in constructor (instead of in declaration) to work around a Scala compiler specialization 45 | // bug that would generate two arrays (one for Object and one for specialized T). 46 | private var _values: Array[V] = _ 47 | _values = new Array[V](_keySet.capacity) 48 | 49 | @transient private var _oldValues: Array[V] = null 50 | 51 | // Treat the null key differently so we can use nulls in "data" to represent empty items. 52 | private var haveNullValue = false 53 | private var nullValue: V = null.asInstanceOf[V] 54 | 55 | override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size 56 | 57 | /** Tests whether this map contains a binding for a key. */ 58 | def contains(k: K): Boolean = { 59 | if (k == null) { 60 | haveNullValue 61 | } else { 62 | _keySet.getPos(k) != OpenHashSet.INVALID_POS 63 | } 64 | } 65 | 66 | /** Get the value for a given key */ 67 | def apply(k: K): V = { 68 | if (k == null) { 69 | nullValue 70 | } else { 71 | val pos = _keySet.getPos(k) 72 | if (pos < 0) { 73 | null.asInstanceOf[V] 74 | } else { 75 | _values(pos) 76 | } 77 | } 78 | } 79 | 80 | /** Set the value for a key */ 81 | def update(k: K, v: V) { 82 | if (k == null) { 83 | haveNullValue = true 84 | nullValue = v 85 | } else { 86 | val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK 87 | _values(pos) = v 88 | _keySet.rehashIfNeeded(k, grow, move) 89 | _oldValues = null 90 | } 91 | } 92 | 93 | /** 94 | * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, 95 | * set its value to mergeValue(oldValue). 96 | * 97 | * @return the newly updated value. 
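   * For example, with an `OpenHashMap[String, Int]`, `changeValue(word, 1, _ + 1)` stores 1
   * the first time a word is seen and increments the stored count on every later call.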
98 | */ 99 | def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { 100 | if (k == null) { 101 | if (haveNullValue) { 102 | nullValue = mergeValue(nullValue) 103 | } else { 104 | haveNullValue = true 105 | nullValue = defaultValue 106 | } 107 | nullValue 108 | } else { 109 | val pos = _keySet.addWithoutResize(k) 110 | if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { 111 | val newValue = defaultValue 112 | _values(pos & OpenHashSet.POSITION_MASK) = newValue 113 | _keySet.rehashIfNeeded(k, grow, move) 114 | newValue 115 | } else { 116 | _values(pos) = mergeValue(_values(pos)) 117 | _values(pos) 118 | } 119 | } 120 | } 121 | 122 | override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { 123 | var pos = -1 124 | var nextPair: (K, V) = computeNextPair() 125 | 126 | /** Get the next value we should return from next(), or null if we're finished iterating */ 127 | def computeNextPair(): (K, V) = { 128 | if (pos == -1) { // Treat position -1 as looking at the null value 129 | if (haveNullValue) { 130 | pos += 1 131 | return (null.asInstanceOf[K], nullValue) 132 | } 133 | pos += 1 134 | } 135 | pos = _keySet.nextPos(pos) 136 | if (pos >= 0) { 137 | val ret = (_keySet.getValue(pos), _values(pos)) 138 | pos += 1 139 | ret 140 | } else { 141 | null 142 | } 143 | } 144 | 145 | def hasNext: Boolean = nextPair != null 146 | 147 | def next(): (K, V) = { 148 | val pair = nextPair 149 | nextPair = computeNextPair() 150 | pair 151 | } 152 | } 153 | 154 | // The following member variables are declared as protected instead of private for the 155 | // specialization to work (specialized class extends the non-specialized one and needs access 156 | // to the "private" variables). 157 | // They also should have been val's. We use var's because there is a Scala compiler bug that 158 | // would throw illegal access error at runtime if they are declared as val's. 
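  // `grow` is invoked by OpenHashSet.rehashIfNeeded when the underlying key set is resized:
  // it stashes the current value array in _oldValues and allocates a larger one.
  // `move` then copies each value from its old slot to the slot its key was rehashed to,
  // keeping keys and values aligned.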
159 | protected var grow = (newCapacity: Int) => { 160 | _oldValues = _values 161 | _values = new Array[V](newCapacity) 162 | } 163 | 164 | protected var move = (oldPos: Int, newPos: Int) => { 165 | _values(newPos) = _oldValues(oldPos) 166 | } 167 | } 168 | 169 | -------------------------------------------------------------------------------- /python/dirty_cat_spark/feature/encoder.py: -------------------------------------------------------------------------------- 1 | from pyspark import since, keyword_only, SparkContext 2 | from pyspark.ml.param import Param, Params, TypeConverters 3 | from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasHandleInvalid 4 | from pyspark.ml.util import JavaMLReadable, JavaMLWritable 5 | from pyspark.ml.wrapper import _jvm 6 | from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper 7 | 8 | from dirty_cat_spark.utils.java_reader import CustomJavaMLReader 9 | 10 | 11 | 12 | class SimilarityEncoder(JavaEstimator, HasInputCol, HasOutputCol, 13 | HasHandleInvalid, JavaMLReadable, JavaMLWritable): 14 | """ 15 | 16 | >>> encoder = SimilarityEncoder(inputCol="names", outputCol="encoderNames") 17 | """ 18 | 19 | vocabSize = Param(Params._dummy(), "vocabSize", "", 20 | typeConverter=TypeConverters.toInt) 21 | 22 | nGramSize = Param(Params._dummy(), "nGramSize", "", 23 | typeConverter=TypeConverters.toInt) 24 | 25 | similarityType = Param(Params._dummy(), "similarityType", "", 26 | typeConverter=TypeConverters.toString) 27 | 28 | handleInvalid = Param(Params._dummy(), "handleInvalid", "", 29 | typeConverter=TypeConverters.toString) 30 | 31 | stringOrderType = Param(Params._dummy(), "stringOrderType", "", 32 | typeConverter=TypeConverters.toString) 33 | 34 | @keyword_only 35 | def __init__(self, inputCol=None, outputCol=None, 36 | nGramSize=3, similarityType="nGram", 37 | handleInvalid="keep", 38 | stringOrderType="frequencyDesc", 39 | vocabSize=100): 40 | """ 41 | __init__(self, inputCol=None, outputCol=None, 42 | nGramSize=3, similarityType="nGram", 43 | handleInvalid="keep", stringOrderType="frequencyDesc", 44 | vocabSize=100) 45 | """ 46 | super(SimilarityEncoder, self).__init__() 47 | 48 | self._java_obj = self._new_java_obj( 49 | "com.rakuten.dirty_cat.feature.SimilarityEncoder", self.uid) 50 | 51 | self._setDefault(nGramSize=3, 52 | # vocabSize=100, 53 | stringOrderType="frequencyDesc", 54 | handleInvalid="keep", 55 | similarityType="nGram") 56 | 57 | kwargs = self._input_kwargs 58 | self.setParams(**kwargs) 59 | 60 | @keyword_only 61 | def setParams(self, inputCol=None, outputCol=None, 62 | nGramSize=3, similarityType="nGram", 63 | handleInvalid="keep", 64 | stringOrderType="frequencyDesc", 65 | vocabSize=100): 66 | """ 67 | setParams(self, inputCol=None, outputCol=None, nGramSize=3, 68 | similarityType="nGram", handleInvalid="keep", 69 | stringOrderType="frequencyDesc", vocabSize=100) 70 | 71 | Set the params for the SimilarityEncoder 72 | """ 73 | kwargs = self._input_kwargs 74 | return self._set(**kwargs) 75 | 76 | 77 | def setStringOrderType(self, value): 78 | return self._set(stringOrderType=value) 79 | 80 | def setSimilarityType(self, value): 81 | return self._set(similarityType=value) 82 | 83 | def setNGramSize(self, value): 84 | return self._set(nGramSize=value) 85 | 86 | def setVocabSize(self, value): 87 | return self._set(vocabSize=value) 88 | 89 | 90 | def getStringOrderType(self): 91 | return self.getOrDefault(self.stringOrderType) 92 | 93 | def getSimilarityType(self): 94 | return 
self.getOrDefault(self.similarityType) 95 | 96 | def getNGramSize(self): 97 | return self.getOrDefault(self.nGramSize) 98 | 99 | def getVocabSize(self): 100 | return self.getOrDefault(self.vocabSize) 101 | 102 | 103 | def _create_model(self, java_model): 104 | return SimilarityEncoderModel(java_model) 105 | 106 | 107 | 108 | class SimilarityEncoderModel(JavaModel, JavaMLReadable, JavaMLWritable): 109 | """Model fitted by :py:class:`SimilarityEncoder`. """ 110 | 111 | @property 112 | def vocabularyReference(self): 113 | """ 114 | """ 115 | return self._call_java("vocabularyReference") 116 | 117 | # # @classmethod 118 | # # def from_vocabularyReference(cls, vocabularyReference, inputCol, 119 | # # outputCol=None, nGramSize=None, 120 | # # similarityType=None, handleInvalid=None, 121 | # # stringOrderType=None, vocabSize=None): 122 | # # """ 123 | # # Construct the model directly from an array of label strings, 124 | # # requires an active SparkContext. 125 | # # """ 126 | # # sc = SparkContext._active_spark_context 127 | # # java_class = sc._gateway.jvm.java.lang.String 128 | # # jVocabularyReference = SimilarityEncoderModel._new_java_array( 129 | # # vocabularyReference, java_class) 130 | # # model = SimilarityEncoderModel._create_from_java_class( 131 | # # 'dirty_cat.feature.SimilarityEncoderModel', jVocabularyReference) 132 | # # model.setInputCol(inputCol) 133 | # # if outputCol is not None: 134 | # # model.setOutputCol(outputCol) 135 | # # if nGramSize is not None: 136 | # # model.setNGramSize(nGramSize) 137 | # # if similarityType is not None: 138 | # # model.setSimilarityType(similarityType) 139 | # # if handleInvalid is not None: 140 | # # model.setHandleInvalid(handleInvalid) 141 | # # if stringOrderType is not None: 142 | # # model.setStringOrderType(stringOrderType) 143 | # # if vocabSize is not None: 144 | # # model.setVocabSize(vocabSize) 145 | # # return model 146 | 147 | 148 | # # @staticmethod 149 | # # def _from_java(java_stage): 150 | # # """ 151 | # # Given a Java object, create and return a Python wrapper of it. 152 | # # Used for ML persistence. 153 | # # Meta-algorithms such as Pipeline should override this method as a classmethod. 154 | # # """ 155 | # # # Generate a default new instance from the stage_name class. 156 | # # py_type =SimilarityEncoderModel 157 | # # if issubclass(py_type, JavaParams): 158 | # # # Load information from java_stage to the instance. 159 | # # py_stage = py_type() 160 | # # py_stage._java_obj = java_stage 161 | # # py_stage._resetUid(java_stage.uid()) 162 | # # py_stage._transfer_params_from_java() 163 | 164 | # # return py_stage 165 | 166 | # # @classmethod 167 | # # def read(cls): 168 | # # """Returns an MLReader instance for this class.""" 169 | # # return CustomJavaMLReader( 170 | # # cls, 'dirty_cat.feature.SimilarityEncoderModel') 171 | -------------------------------------------------------------------------------- /src/main/scala/com/rakuten/dirty_cat/spark_utils/BitSet.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | * 17 | * This code is a modified version of the original Spark 2.1 implementation. 18 | */ 19 | package com.rakuten.dirty_cat.util.collection 20 | 21 | import java.util.Arrays 22 | 23 | /** 24 | * A simple, fixed-size bit set implementation. This implementation is fast because it avoids 25 | * safety/bound checking. 26 | */ 27 | class BitSet(numBits: Int) extends Serializable { 28 | 29 | private val words = new Array[Long](bit2words(numBits)) 30 | private val numWords = words.length 31 | 32 | /** 33 | * Compute the capacity (number of bits) that can be represented 34 | * by this bitset. 35 | */ 36 | def capacity: Int = numWords * 64 37 | 38 | /** 39 | * Clear all set bits. 40 | */ 41 | def clear(): Unit = Arrays.fill(words, 0) 42 | 43 | /** 44 | * Set all the bits up to a given index 45 | */ 46 | def setUntil(bitIndex: Int): Unit = { 47 | val wordIndex = bitIndex >> 6 // divide by 64 48 | Arrays.fill(words, 0, wordIndex, -1) 49 | if(wordIndex < words.length) { 50 | // Set the remaining bits (note that the mask could still be zero) 51 | val mask = ~(-1L << (bitIndex & 0x3f)) 52 | words(wordIndex) |= mask 53 | } 54 | } 55 | 56 | /** 57 | * Clear all the bits up to a given index 58 | */ 59 | def clearUntil(bitIndex: Int): Unit = { 60 | val wordIndex = bitIndex >> 6 // divide by 64 61 | Arrays.fill(words, 0, wordIndex, 0) 62 | if(wordIndex < words.length) { 63 | // Clear the remaining bits 64 | val mask = -1L << (bitIndex & 0x3f) 65 | words(wordIndex) &= mask 66 | } 67 | } 68 | 69 | /** 70 | * Compute the bit-wise AND of the two sets returning the 71 | * result. 72 | */ 73 | def &(other: BitSet): BitSet = { 74 | val newBS = new BitSet(math.max(capacity, other.capacity)) 75 | val smaller = math.min(numWords, other.numWords) 76 | assert(newBS.numWords >= numWords) 77 | assert(newBS.numWords >= other.numWords) 78 | var ind = 0 79 | while( ind < smaller ) { 80 | newBS.words(ind) = words(ind) & other.words(ind) 81 | ind += 1 82 | } 83 | newBS 84 | } 85 | 86 | /** 87 | * Compute the bit-wise OR of the two sets returning the 88 | * result. 89 | */ 90 | def |(other: BitSet): BitSet = { 91 | val newBS = new BitSet(math.max(capacity, other.capacity)) 92 | assert(newBS.numWords >= numWords) 93 | assert(newBS.numWords >= other.numWords) 94 | val smaller = math.min(numWords, other.numWords) 95 | var ind = 0 96 | while( ind < smaller ) { 97 | newBS.words(ind) = words(ind) | other.words(ind) 98 | ind += 1 99 | } 100 | while( ind < numWords ) { 101 | newBS.words(ind) = words(ind) 102 | ind += 1 103 | } 104 | while( ind < other.numWords ) { 105 | newBS.words(ind) = other.words(ind) 106 | ind += 1 107 | } 108 | newBS 109 | } 110 | 111 | /** 112 | * Compute the symmetric difference by performing bit-wise XOR of the two sets returning the 113 | * result. 
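   * For example, for bit sets {1, 3} and {3, 5}, the result has exactly bits {1, 5} set.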
114 | */ 115 | def ^(other: BitSet): BitSet = { 116 | val newBS = new BitSet(math.max(capacity, other.capacity)) 117 | val smaller = math.min(numWords, other.numWords) 118 | var ind = 0 119 | while (ind < smaller) { 120 | newBS.words(ind) = words(ind) ^ other.words(ind) 121 | ind += 1 122 | } 123 | if (ind < numWords) { 124 | Array.copy( words, ind, newBS.words, ind, numWords - ind ) 125 | } 126 | if (ind < other.numWords) { 127 | Array.copy( other.words, ind, newBS.words, ind, other.numWords - ind ) 128 | } 129 | newBS 130 | } 131 | 132 | /** 133 | * Compute the difference of the two sets by performing bit-wise AND-NOT returning the 134 | * result. 135 | */ 136 | def andNot(other: BitSet): BitSet = { 137 | val newBS = new BitSet(capacity) 138 | val smaller = math.min(numWords, other.numWords) 139 | var ind = 0 140 | while (ind < smaller) { 141 | newBS.words(ind) = words(ind) & ~other.words(ind) 142 | ind += 1 143 | } 144 | if (ind < numWords) { 145 | Array.copy( words, ind, newBS.words, ind, numWords - ind ) 146 | } 147 | newBS 148 | } 149 | 150 | /** 151 | * Sets the bit at the specified index to true. 152 | * @param index the bit index 153 | */ 154 | def set(index: Int) { 155 | val bitmask = 1L << (index & 0x3f) // mod 64 and shift 156 | words(index >> 6) |= bitmask // div by 64 and mask 157 | } 158 | 159 | def unset(index: Int) { 160 | val bitmask = 1L << (index & 0x3f) // mod 64 and shift 161 | words(index >> 6) &= ~bitmask // div by 64 and mask 162 | } 163 | 164 | /** 165 | * Return the value of the bit with the specified index. The value is true if the bit with 166 | * the index is currently set in this BitSet; otherwise, the result is false. 167 | * 168 | * @param index the bit index 169 | * @return the value of the bit with the specified index 170 | */ 171 | def get(index: Int): Boolean = { 172 | val bitmask = 1L << (index & 0x3f) // mod 64 and shift 173 | (words(index >> 6) & bitmask) != 0 // div by 64 and mask 174 | } 175 | 176 | /** 177 | * Get an iterator over the set bits. 178 | */ 179 | def iterator: Iterator[Int] = new Iterator[Int] { 180 | var ind = nextSetBit(0) 181 | override def hasNext: Boolean = ind >= 0 182 | override def next(): Int = { 183 | val tmp = ind 184 | ind = nextSetBit(ind + 1) 185 | tmp 186 | } 187 | } 188 | 189 | 190 | /** Return the number of bits set to true in this BitSet. */ 191 | def cardinality(): Int = { 192 | var sum = 0 193 | var i = 0 194 | while (i < numWords) { 195 | sum += java.lang.Long.bitCount(words(i)) 196 | i += 1 197 | } 198 | sum 199 | } 200 | 201 | /** 202 | * Returns the index of the first bit that is set to true that occurs on or after the 203 | * specified starting index. If no such bit exists then -1 is returned. 
204 | * 205 | * To iterate over the true bits in a BitSet, use the following loop: 206 | * 207 | * for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) { 208 | * // operate on index i here 209 | * } 210 | * 211 | * @param fromIndex the index to start checking from (inclusive) 212 | * @return the index of the next set bit, or -1 if there is no such bit 213 | */ 214 | def nextSetBit(fromIndex: Int): Int = { 215 | var wordIndex = fromIndex >> 6 216 | if (wordIndex >= numWords) { 217 | return -1 218 | } 219 | 220 | // Try to find the next set bit in the current word 221 | val subIndex = fromIndex & 0x3f 222 | var word = words(wordIndex) >> subIndex 223 | if (word != 0) { 224 | return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word) 225 | } 226 | 227 | // Find the next set bit in the rest of the words 228 | wordIndex += 1 229 | while (wordIndex < numWords) { 230 | word = words(wordIndex) 231 | if (word != 0) { 232 | return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word) 233 | } 234 | wordIndex += 1 235 | } 236 | 237 | -1 238 | } 239 | 240 | /** Return the number of longs it would take to hold numBits. */ 241 | private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1 242 | } 243 | 244 | -------------------------------------------------------------------------------- /src/main/scala/com/rakuten/dirty_cat/spark_persistence/ReadWrite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | * 17 | * This code is a modified version of the original Spark 2.1 implementation. 18 | */ 19 | package com.rakuten.dirty_cat.persistence 20 | 21 | import com.rakuten.dirty_cat.utils.Utils 22 | 23 | import org.apache.hadoop.fs.Path 24 | import org.apache.spark.SparkContext 25 | import org.apache.spark.ml.param.{ParamPair, Params} 26 | import org.apache.spark.ml.util.{MLReader, MLWriter} 27 | import org.json4s.JsonAST.{JObject, JValue, JArray, JField} 28 | import org.json4s.JsonDSL._ 29 | import org.json4s.jackson.JsonMethods._ 30 | import org.json4s.{DefaultFormats, _} 31 | 32 | 33 | 34 | // This originates from apache-spark DefaultPramsWriter copy paste 35 | private[dirty_cat] object DefaultParamsWriter { 36 | 37 | /** 38 | * Saves metadata + Params to: path + "/metadata" 39 | * - class 40 | * - timestamp 41 | * - sparkVersion 42 | * - uid 43 | * - paramMap 44 | * - (optionally, extra metadata) 45 | * 46 | * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. 47 | * @param paramMap If given, this is saved in the "paramMap" field. 48 | * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using 49 | * [[org.apache.spark.ml.param.Param.jsonEncode()]]. 
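   * For illustration only (the values below are made up), the JSON written under path + "/metadata"
   * looks roughly like:
   * {"class":"com.rakuten.dirty_cat.feature.SimilarityEncoderModel","timestamp":1540000000000,
   *  "sparkVersion":"2.3.0","uid":"simEnc_4f3a","paramMap":{"inputCol":"name","similarityType":"nGram","nGramSize":3}}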
50 | */ 51 | def saveMetadata( 52 | instance: Params, 53 | path: String, 54 | sc: SparkContext, 55 | extraMetadata: Option[JObject] = None, 56 | paramMap: Option[JValue] = None): Unit = { 57 | 58 | val metadataPath = new Path(path, "metadata").toString 59 | val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap) 60 | sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) 61 | } 62 | 63 | /** 64 | * Helper for [[saveMetadata()]] which extracts the JSON to save. 65 | * This is useful for ensemble models which need to save metadata for many sub-models. 66 | * 67 | * @see [[saveMetadata()]] for details on what this includes. 68 | */ 69 | def getMetadataToSave( 70 | instance: Params, 71 | sc: SparkContext, 72 | extraMetadata: Option[JObject] = None, 73 | paramMap: Option[JValue] = None): String = { 74 | val uid = instance.uid 75 | val cls = instance.getClass.getName 76 | val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] 77 | val jsonParams = (paramMap.getOrElse(render(params.map{ case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v))}.toList))) 78 | 79 | val basicMetadata = ("class" -> cls) ~ 80 | ("timestamp" -> System.currentTimeMillis()) ~ 81 | ("sparkVersion" -> sc.version) ~ 82 | ("uid" -> uid) ~ 83 | ("paramMap" -> jsonParams) 84 | val metadata = extraMetadata match { 85 | case Some(jObject) => 86 | basicMetadata ~ jObject 87 | case None => 88 | basicMetadata 89 | } 90 | val metadataJson: String = compact(render(metadata)) 91 | metadataJson 92 | } 93 | } 94 | 95 | 96 | 97 | // This originates from apache-spark DefaultPramsReader copy paste 98 | private[dirty_cat] object DefaultParamsReader { 99 | 100 | /** 101 | * All info from metadata file. 102 | * 103 | * @param params paramMap, as a `JValue` 104 | * @param metadata All metadata, including the other fields 105 | * @param metadataJson Full metadata file String (for debugging) 106 | */ 107 | case class Metadata(className: String, 108 | uid: String, 109 | timestamp: Long, 110 | sparkVersion: String, 111 | params: JValue, 112 | metadata: JValue, 113 | metadataJson: String) { 114 | 115 | /** 116 | * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name. 117 | * This can be useful for getting a Param value before an instance of `Params` 118 | * is available. 119 | */ 120 | def getParamValue(paramName: String): JValue = { 121 | implicit val format = DefaultFormats 122 | params match { 123 | case JObject(pairs) => 124 | val values = pairs.filter { case (pName, jsonValue) => 125 | pName == paramName 126 | }.map(_._2) 127 | assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" + 128 | s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", ")) 129 | values.head 130 | case _ => 131 | throw new IllegalArgumentException( 132 | s"Cannot recognize JSON metadata: $metadataJson.") 133 | } 134 | } 135 | } 136 | 137 | /** 138 | * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]] 139 | * 140 | * @param expectedClassName If non empty, this is checked against the loaded metadata. 
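   *                          (For illustration: a model reader would typically pass something
   *                          like `classOf[SimilarityEncoderModel].getName` here.)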
141 | * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata 142 | */ 143 | def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = { 144 | val metadataPath = new Path(path, "metadata").toString 145 | val metadataStr = sc.textFile(metadataPath, 1).first() 146 | parseMetadata(metadataStr, expectedClassName) 147 | } 148 | 149 | /** 150 | * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]]. 151 | * This is a helper function for [[loadMetadata()]]. 152 | * 153 | * @param metadataStr JSON string of metadata 154 | * @param expectedClassName If non empty, this is checked against the loaded metadata. 155 | * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata 156 | */ 157 | def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = { 158 | val metadata = parse(metadataStr) 159 | 160 | implicit val format = DefaultFormats 161 | val className = (metadata \ "class").extract[String] 162 | val uid = (metadata \ "uid").extract[String] 163 | val timestamp = (metadata \ "timestamp").extract[Long] 164 | val sparkVersion = (metadata \ "sparkVersion").extract[String] 165 | val params = metadata \ "paramMap" 166 | if (expectedClassName.nonEmpty) { 167 | require(className == expectedClassName, s"Error loading metadata: Expected class name" + 168 | s" $expectedClassName but found class name $className") 169 | } 170 | 171 | Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr) 172 | } 173 | 174 | /** 175 | * Extract Params from metadata, and set them in the instance. 176 | * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. 177 | * TODO: Move to [[Metadata]] method 178 | */ 179 | def getAndSetParams(instance: Params, metadata: Metadata): Unit = { 180 | implicit val format = DefaultFormats 181 | metadata.params match { 182 | case JObject(pairs) => 183 | pairs.foreach { case (paramName, jsonValue) => 184 | val param = instance.getParam(paramName) 185 | val value = param.jsonDecode(compact(render(jsonValue))) 186 | instance.set(param, value) 187 | } 188 | case _ => 189 | throw new IllegalArgumentException( 190 | s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") 191 | } 192 | } 193 | 194 | /** 195 | * Load a `Params` instance from the given path, and return it. 196 | * This assumes the instance implements [[org.apache.spark.ml.util.MLReadable]]. 197 | */ 198 | def loadParamsInstance[T](path: String, sc: SparkContext): T = { 199 | val metadata = DefaultParamsReader.loadMetadata(path, sc) 200 | val cls = Utils.classForName(metadata.className) 201 | cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path) 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | -------------------------------------------------------------------------------- /src/main/scala/com/rakuten/dirty_cat/spark_utils/OpenHashSet.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | * 17 | * This code is a modified version of the original Spark 2.1 implementation. 18 | */ 19 | package com.rakuten.dirty_cat.util.collection 20 | 21 | import scala.reflect._ 22 | import com.google.common.hash.Hashing.murmur3_32 23 | 24 | /** 25 | * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never 26 | * removed. 27 | * 28 | * The underlying implementation uses Scala compiler's specialization to generate optimized 29 | * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet 30 | * while incurring much less memory overhead. This can serve as building blocks for higher level 31 | * data structures such as an optimized HashMap. 32 | * 33 | * This OpenHashSet is designed to serve as building blocks for higher level data structures 34 | * such as an optimized hash map. Compared with standard hash set implementations, this class 35 | * provides its various callbacks interfaces (e.g. allocateFunc, moveFunc) and interfaces to 36 | * retrieve the position of a key in the underlying array. 37 | * 38 | * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed 39 | * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). 40 | */ 41 | private[dirty_cat] 42 | class OpenHashSet[@specialized(Long, Int) T: ClassTag]( 43 | initialCapacity: Int, 44 | loadFactor: Double) 45 | extends Serializable { 46 | 47 | require(initialCapacity <= OpenHashSet.MAX_CAPACITY, 48 | s"Can't make capacity bigger than ${OpenHashSet.MAX_CAPACITY} elements") 49 | require(initialCapacity >= 1, "Invalid initial capacity") 50 | require(loadFactor < 1.0, "Load factor must be less than 1.0") 51 | require(loadFactor > 0.0, "Load factor must be greater than 0.0") 52 | 53 | import OpenHashSet._ 54 | 55 | def this(initialCapacity: Int) = this(initialCapacity, 0.7) 56 | 57 | def this() = this(64) 58 | 59 | // The following member variables are declared as protected instead of private for the 60 | // specialization to work (specialized class extends the non-specialized one and needs access 61 | // to the "private" variables). 62 | 63 | protected val hasher: Hasher[T] = { 64 | // It would've been more natural to write the following using pattern matching. But Scala 2.9.x 65 | // compiler has a bug when specialization is used together with this pattern matching, and 66 | // throws: 67 | // scala.tools.nsc.symtab.Types$TypeError: type mismatch; 68 | // found : scala.reflect.AnyValManifest[Long] 69 | // required: scala.reflect.ClassTag[Int] 70 | // at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298) 71 | // at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207) 72 | // ... 
73 | val mt = classTag[T] 74 | if (mt == ClassTag.Long) { 75 | (new LongHasher).asInstanceOf[Hasher[T]] 76 | } else if (mt == ClassTag.Int) { 77 | (new IntHasher).asInstanceOf[Hasher[T]] 78 | } else { 79 | new Hasher[T] 80 | } 81 | } 82 | 83 | protected var _capacity = nextPowerOf2(initialCapacity) 84 | protected var _mask = _capacity - 1 85 | protected var _size = 0 86 | protected var _growThreshold = (loadFactor * _capacity).toInt 87 | 88 | protected var _bitset = new BitSet(_capacity) 89 | 90 | def getBitSet: BitSet = _bitset 91 | 92 | // Init of the array in constructor (instead of in declaration) to work around a Scala compiler 93 | // specialization bug that would generate two arrays (one for Object and one for specialized T). 94 | protected var _data: Array[T] = _ 95 | _data = new Array[T](_capacity) 96 | 97 | /** Number of elements in the set. */ 98 | def size: Int = _size 99 | 100 | /** The capacity of the set (i.e. size of the underlying array). */ 101 | def capacity: Int = _capacity 102 | 103 | /** Return true if this set contains the specified element. */ 104 | def contains(k: T): Boolean = getPos(k) != INVALID_POS 105 | 106 | /** 107 | * Add an element to the set. If the set is over capacity after the insertion, grow the set 108 | * and rehash all elements. 109 | */ 110 | def add(k: T) { 111 | addWithoutResize(k) 112 | rehashIfNeeded(k, grow, move) 113 | } 114 | 115 | def union(other: OpenHashSet[T]): OpenHashSet[T] = { 116 | val iterator = other.iterator 117 | while (iterator.hasNext) { 118 | add(iterator.next()) 119 | } 120 | this 121 | } 122 | 123 | /** 124 | * Add an element to the set. This one differs from add in that it doesn't trigger rehashing. 125 | * The caller is responsible for calling rehashIfNeeded. 126 | * 127 | * Use (retval & POSITION_MASK) to get the actual position, and 128 | * (retval & NONEXISTENCE_MASK) == 0 for prior existence. 129 | * 130 | * @return The position where the key is placed, plus the highest order bit is set if the key 131 | * does not exists previously. 132 | */ 133 | def addWithoutResize(k: T): Int = { 134 | var pos = hashcode(hasher.hash(k)) & _mask 135 | var delta = 1 136 | while (true) { 137 | if (!_bitset.get(pos)) { 138 | // This is a new key. 139 | _data(pos) = k 140 | _bitset.set(pos) 141 | _size += 1 142 | return pos | NONEXISTENCE_MASK 143 | } else if (_data(pos) == k) { 144 | // Found an existing key. 145 | return pos 146 | } else { 147 | // quadratic probing with values increase by 1, 2, 3, ... 148 | pos = (pos + delta) & _mask 149 | delta += 1 150 | } 151 | } 152 | throw new RuntimeException("Should never reach here.") 153 | } 154 | 155 | /** 156 | * Rehash the set if it is overloaded. 157 | * @param k A parameter unused in the function, but to force the Scala compiler to specialize 158 | * this method. 159 | * @param allocateFunc Callback invoked when we are allocating a new, larger array. 160 | * @param moveFunc Callback invoked when we move the key from one position (in the old data array) 161 | * to a new position (in the new data array). 162 | */ 163 | def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { 164 | if (_size > _growThreshold) { 165 | rehash(k, allocateFunc, moveFunc) 166 | } 167 | } 168 | 169 | /** 170 | * Return the position of the element in the underlying array, or INVALID_POS if it is not found. 
171 | */ 172 | def getPos(k: T): Int = { 173 | var pos = hashcode(hasher.hash(k)) & _mask 174 | var delta = 1 175 | while (true) { 176 | if (!_bitset.get(pos)) { 177 | return INVALID_POS 178 | } else if (k == _data(pos)) { 179 | return pos 180 | } else { 181 | // quadratic probing with values increase by 1, 2, 3, ... 182 | pos = (pos + delta) & _mask 183 | delta += 1 184 | } 185 | } 186 | throw new RuntimeException("Should never reach here.") 187 | } 188 | 189 | /** Return the value at the specified position. */ 190 | def getValue(pos: Int): T = _data(pos) 191 | 192 | def iterator: Iterator[T] = new Iterator[T] { 193 | var pos = nextPos(0) 194 | override def hasNext: Boolean = pos != INVALID_POS 195 | override def next(): T = { 196 | val tmp = getValue(pos) 197 | pos = nextPos(pos + 1) 198 | tmp 199 | } 200 | } 201 | 202 | /** Return the value at the specified position. */ 203 | def getValueSafe(pos: Int): T = { 204 | assert(_bitset.get(pos)) 205 | _data(pos) 206 | } 207 | 208 | /** 209 | * Return the next position with an element stored, starting from the given position inclusively. 210 | */ 211 | def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos) 212 | 213 | /** 214 | * Double the table's size and re-hash everything. We are not really using k, but it is declared 215 | * so Scala compiler can specialize this method (which leads to calling the specialized version 216 | * of putInto). 217 | * 218 | * @param k A parameter unused in the function, but to force the Scala compiler to specialize 219 | * this method. 220 | * @param allocateFunc Callback invoked when we are allocating a new, larger array. 221 | * @param moveFunc Callback invoked when we move the key from one position (in the old data array) 222 | * to a new position (in the new data array). 223 | */ 224 | private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { 225 | val newCapacity = _capacity * 2 226 | require(newCapacity > 0 && newCapacity <= OpenHashSet.MAX_CAPACITY, 227 | s"Can't contain more than ${(loadFactor * OpenHashSet.MAX_CAPACITY).toInt} elements") 228 | allocateFunc(newCapacity) 229 | val newBitset = new BitSet(newCapacity) 230 | val newData = new Array[T](newCapacity) 231 | val newMask = newCapacity - 1 232 | 233 | var oldPos = 0 234 | while (oldPos < capacity) { 235 | if (_bitset.get(oldPos)) { 236 | val key = _data(oldPos) 237 | var newPos = hashcode(hasher.hash(key)) & newMask 238 | var i = 1 239 | var keepGoing = true 240 | // No need to check for equality here when we insert so this has one less if branch than 241 | // the similar code path in addWithoutResize. 242 | while (keepGoing) { 243 | if (!newBitset.get(newPos)) { 244 | // Inserting the key at newPos 245 | newData(newPos) = key 246 | newBitset.set(newPos) 247 | moveFunc(oldPos, newPos) 248 | keepGoing = false 249 | } else { 250 | val delta = i 251 | newPos = (newPos + delta) & newMask 252 | i += 1 253 | } 254 | } 255 | } 256 | oldPos += 1 257 | } 258 | 259 | _bitset = newBitset 260 | _data = newData 261 | _capacity = newCapacity 262 | _mask = newMask 263 | _growThreshold = (loadFactor * newCapacity).toInt 264 | } 265 | 266 | /** 267 | * Re-hash a value to deal better with hash functions that don't differ in the lower bits. 
268 | */ 269 | private def hashcode(h: Int): Int = murmur3_32().hashInt(h).asInt() 270 | 271 | private def nextPowerOf2(n: Int): Int = { 272 | val highBit = Integer.highestOneBit(n) 273 | if (highBit == n) n else highBit << 1 274 | } 275 | } 276 | 277 | 278 | private[dirty_cat] 279 | object OpenHashSet { 280 | 281 | val MAX_CAPACITY = 1 << 30 282 | val INVALID_POS = -1 283 | val NONEXISTENCE_MASK = 1 << 31 284 | val POSITION_MASK = (1 << 31) - 1 285 | 286 | /** 287 | * A set of specialized hash function implementation to avoid boxing hash code computation 288 | * in the specialized implementation of OpenHashSet. 289 | */ 290 | sealed class Hasher[@specialized(Long, Int) T] extends Serializable { 291 | def hash(o: T): Int = o.hashCode() 292 | } 293 | 294 | class LongHasher extends Hasher[Long] { 295 | override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt 296 | } 297 | 298 | class IntHasher extends Hasher[Int] { 299 | override def hash(o: Int): Int = o 300 | } 301 | 302 | private def grow1(newSize: Int) {} 303 | private def move1(oldPos: Int, newPos: Int) { } 304 | 305 | private val grow = grow1 _ 306 | private val move = move1 _ 307 | } 308 | 309 | -------------------------------------------------------------------------------- /examples/Midwest_Survey.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pyspark.sql.functions import col, when" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "data = (spark\n", 19 | " .read\n", 20 | " .option(\"header\", \"true\")\n", 21 | " .csv(\"../data/FiveThirtyEight_Midwest_Survey.csv\"))\n", 22 | "data = data.where(col('Location (Census Region)').isNotNull())" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# Encoding nulls\n", 32 | "columns_with_null = [\n", 33 | " 'Location (Census Region)',\n", 34 | " 'Gender', 'Age', \n", 35 | " 'Household Income', 'Education']\n", 36 | "for column in columns_with_null:\n", 37 | " data = data.withColumn(column, when(col(column).isNull(), \"__null\")\n", 38 | " .otherwise(col(column)))" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 4, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "root\n", 51 | " |-- RespondentID: string (nullable = true)\n", 52 | " |-- In your own words, what would you call the part of the country you live in now?: string (nullable = true)\n", 53 | " |-- Personally identification as a Midwesterner?: string (nullable = true)\n", 54 | " |-- Illinois in MW?: string (nullable = true)\n", 55 | " |-- Indiana in MW?: string (nullable = true)\n", 56 | " |-- Iowa in MW?: string (nullable = true)\n", 57 | " |-- Kansas in MW?: string (nullable = true)\n", 58 | " |-- Michigan in MW?: string (nullable = true)\n", 59 | " |-- Minnesota in MW?: string (nullable = true)\n", 60 | " |-- Missouri in MW?: string (nullable = true)\n", 61 | " |-- Nebraska in MW?: string (nullable = true)\n", 62 | " |-- North Dakota in MW?: string (nullable = true)\n", 63 | " |-- Ohio in MW?: string (nullable = true)\n", 64 | " |-- South Dakota in MW?: string (nullable = true)\n", 65 | " |-- Wisconsin in MW?: string (nullable = true)\n", 66 | " |-- Arkansas in MW?: string (nullable = 
true)\n", 67 | " |-- Colorado in MW?: string (nullable = true)\n", 68 | " |-- Kentucky in MW?: string (nullable = true)\n", 69 | " |-- Oklahoma in MW?: string (nullable = true)\n", 70 | " |-- Pennsylvania in MW?: string (nullable = true)\n", 71 | " |-- West Virginia in MW?: string (nullable = true)\n", 72 | " |-- Montana in MW?: string (nullable = true)\n", 73 | " |-- Wyoming in MW?: string (nullable = true)\n", 74 | " |-- ZIP Code: string (nullable = true)\n", 75 | " |-- Gender: string (nullable = true)\n", 76 | " |-- Age: string (nullable = true)\n", 77 | " |-- Household Income: string (nullable = true)\n", 78 | " |-- Education: string (nullable = true)\n", 79 | " |-- Location (Census Region): string (nullable = true)\n", 80 | "\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "data.printSchema()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "## Splitting data into train and test " 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 5, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "data_train, data_test = data.randomSplit([0.6, 0.4], seed=5)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "## Separating clean, and dirty columns as well a a column we will try to predict" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 6, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "target_column = 'Location (Census Region)'\n", 118 | "dirty_column = 'In your own words, what would you call the part of the country you live in now?'\n", 119 | "clean_columns = [\n", 120 | " 'Personally identification as a Midwesterner?',\n", 121 | " 'Illinois in MW?',\n", 122 | " 'Indiana in MW?',\n", 123 | " 'Kansas in MW?',\n", 124 | " 'Iowa in MW?',\n", 125 | " 'Michigan in MW?',\n", 126 | " 'Minnesota in MW?',\n", 127 | " 'Missouri in MW?',\n", 128 | " 'Nebraska in MW?',\n", 129 | " 'North Dakota in MW?',\n", 130 | " 'Ohio in MW?',\n", 131 | " 'South Dakota in MW?',\n", 132 | " 'Wisconsin in MW?',\n", 133 | " 'Arkansas in MW?',\n", 134 | " 'Colorado in MW?',\n", 135 | " 'Kentucky in MW?',\n", 136 | " 'Oklahoma in MW?',\n", 137 | " 'Pennsylvania in MW?',\n", 138 | " 'West Virginia in MW?',\n", 139 | " 'Montana in MW?',\n", 140 | " 'Wyoming in MW?',\n", 141 | " 'Gender',\n", 142 | " 'Age',\n", 143 | " 'Household Income',\n", 144 | " 'Education'\n", 145 | "]" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 7, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "from pyspark.ml import Pipeline\n", 155 | "from pyspark.ml.feature import OneHotEncoder\n", 156 | "from pyspark.ml.feature import StandardScaler, VectorAssembler\n", 157 | "from pyspark.ml.feature import StringIndexer, VectorIndexer\n", 158 | "from pyspark.ml.classification import RandomForestClassifier" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 8, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "from dirty_cat_spark.feature.encoder import SimilarityEncoder\n", 168 | "\n", 169 | "\n", 170 | "encoder_similarity = (SimilarityEncoder()\n", 171 | " .setInputCol(dirty_column)\n", 172 | " .setOutputCol(\"encoded\")\n", 173 | " .setSimilarityType(\"nGram\")\n", 174 | " .setVocabSize(200))\n", 175 | "\n", 176 | "string_indexer_dirty = (StringIndexer()\n", 177 | " .setInputCol(dirty_column)\n", 178 | " .setOutputCol(dirty_column + \"_indexed\")\n", 179 | " .setHandleInvalid(\"keep\")) \n", 
180 | "\n", 181 | "encoder_hot = (OneHotEncoder()\n", 182 | " .setInputCol(dirty_column + \"_indexed\")\n", 183 | " .setOutputCol(\"encoded\"))" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 9, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "string_indexer_clean = [(StringIndexer()\n", 193 | " .setInputCol(clean_column)\n", 194 | " .setOutputCol(clean_column + \"_indexed\")\n", 195 | " .setHandleInvalid(\"keep\")) \n", 196 | " for clean_column in clean_columns]\n", 197 | "\n", 198 | "assembler = (VectorAssembler()\n", 199 | " .setInputCols([c + \"_indexed\" for c in clean_columns] + [\"encoded\"])\n", 200 | " .setOutputCol(\"features\"))\n", 201 | "\n", 202 | "vector_indexer = (VectorIndexer()\n", 203 | " .setInputCol(\"features\")\n", 204 | " .setOutputCol(\"featuresIndexed\")\n", 205 | " .setMaxCategories(10)\n", 206 | " .setHandleInvalid(\"skip\"))\n", 207 | "\n", 208 | "scaler = (StandardScaler()\n", 209 | " .setInputCol(\"featuresIndexed\")\n", 210 | " .setOutputCol(\"scaledFeatures\")\n", 211 | " .setWithMean(False))\n", 212 | "\n", 213 | "\n", 214 | "indexed_label = StringIndexer(inputCol=target_column, \n", 215 | " outputCol=\"indexedLabel\")\n", 216 | "\n", 217 | "classifier = (RandomForestClassifier()\n", 218 | " .setFeaturesCol(\"scaledFeatures\")\n", 219 | " .setLabelCol(\"indexedLabel\")\n", 220 | " .setNumTrees(10)\n", 221 | " .setSeed(5))" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 10, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "pipeline_similarity = Pipeline(stages=string_indexer_clean + \n", 231 | " [encoder_similarity, assembler, \n", 232 | " vector_indexer, \n", 233 | " scaler, indexed_label, classifier])" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 11, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "pipeline_similarity_hot = Pipeline(stages=string_indexer_clean + \n", 243 | " [string_indexer_dirty, encoder_hot, \n", 244 | " assembler, vector_indexer, \n", 245 | " scaler, indexed_label, classifier])" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "pipeline_similarity_hot_model = pipeline_similarity_hot.fit(data_train)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "pipeline_similarity_model = pipeline_similarity.fit(data_train)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "res = pipeline_similarity_model.transform(data_test)\n", 273 | "res_hot = pipeline_similarity_hot_model.transform(data_test)" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "res_df = res.select(\"probability\", \"indexedLabel\", \"prediction\").toPandas()\n", 283 | "res_hot_df = res_hot.select(\"probability\", \"indexedLabel\", \"prediction\").toPandas()" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [ 292 | "%matplotlib inline\n", 293 | "import numpy as np\n", 294 | "from matplotlib import pyplot as plt" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "metadata": {}, 301 | "outputs": [], 302 | 
"source": [ 303 | "y_true = np.hstack(res_df.indexedLabel.values)\n", 304 | "y_pred = res_df.prediction.values\n", 305 | "\n", 306 | "accuracy = np.sum(y_pred == y_true) / y_true.shape[0]\n", 307 | "\n", 308 | "accuracy" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": null, 314 | "metadata": {}, 315 | "outputs": [], 316 | "source": [ 317 | "y_true = np.hstack(res_hot_df.indexedLabel.values)\n", 318 | "y_hot_pred = res_hot_df.prediction.values\n", 319 | "\n", 320 | "accuracy_hot = np.sum(y_hot_pred == y_true) / y_true.shape[0]\n", 321 | "\n", 322 | "accuracy_hot" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [] 338 | } 339 | ], 340 | "metadata": { 341 | "kernelspec": { 342 | "display_name": "Python 3", 343 | "language": "python", 344 | "name": "python3" 345 | }, 346 | "language_info": { 347 | "codemirror_mode": { 348 | "name": "ipython", 349 | "version": 3 350 | }, 351 | "file_extension": ".py", 352 | "mimetype": "text/x-python", 353 | "name": "python", 354 | "nbconvert_exporter": "python", 355 | "pygments_lexer": "ipython3", 356 | "version": "3.6.4" 357 | } 358 | }, 359 | "nbformat": 4, 360 | "nbformat_minor": 2 361 | } 362 | -------------------------------------------------------------------------------- /src/main/scala/com/rakuten/dirty_cat/features/SimilarityEncoder.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Rakuten Institute of Technology and Andres Hoyos-Idrobo 3 | * under one or more contributor license agreements. 4 | * See the NOTICE file distributed with this work for additional information 5 | * regarding copyright ownership. 6 | * Rakuten Institute of technology and Andres Hoyos-Idrobo licenses this file 7 | * to You under the Apache License, Version 2.0 8 | * (the "License"); you may not use this file except in compliance with 9 | * the License. You may obtain a copy of the License at 10 | * 11 | * http://www.apache.org/licenses/LICENSE-2.0 12 | * 13 | * Unless required by applicable law or agreed to in writing, software 14 | * distributed under the License is distributed on an "AS IS" BASIS, 15 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | * See the License for the specific language governing permissions and 17 | * limitations under the License. 
18 | */ 19 | 20 | package com.rakuten.dirty_cat.feature 21 | 22 | import com.rakuten.dirty_cat.persistence.{DefaultParamsReader, DefaultParamsWriter} 23 | import com.rakuten.dirty_cat.util.StringSimilarity 24 | import com.rakuten.dirty_cat.util.collection.OpenHashMap 25 | 26 | import org.json4s.JsonAST.{JObject} 27 | import org.apache.spark.ml.linalg.{Vector, Vectors} 28 | // Replace this by VectorUDT 29 | import org.apache.spark.ml.linalg.SQLDataTypes.{VectorType} 30 | 31 | import org.apache.hadoop.fs.Path 32 | 33 | import org.apache.spark.annotation.Since 34 | import org.apache.spark.ml.{Estimator, Model, Transformer} 35 | import org.apache.spark.sql.Row 36 | import org.apache.spark.sql.{Dataset, DataFrame} 37 | // 38 | import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} 39 | import org.apache.spark.ml.feature.{IndexToString, StringIndexer} 40 | 41 | import org.apache.spark.rdd.RDD 42 | import org.apache.spark.sql.functions.{col, udf} 43 | import org.apache.spark.sql.types.{StringType, StructType, StructField} 44 | 45 | import org.apache.spark.ml.util.Identifiable 46 | import org.apache.spark.ml.util.{MLWritable, MLReadable, DefaultParamsWritable, DefaultParamsReadable} 47 | import org.apache.spark.ml.util.{MLWriter, MLReader} 48 | import org.apache.spark.ml.param.shared._ 49 | import org.apache.spark.SparkException 50 | 51 | import org.apache.spark.ml.param.{ParamValidators, ParamPair, Param, Params, ParamMap, DoubleParam, IntParam} 52 | 53 | import scala.collection.JavaConverters._ 54 | import java.lang.{Double => JDouble, Integer => JInt, String => JString} 55 | import java.util.{NoSuchElementException, Map => JMap} 56 | import scala.language.implicitConversions 57 | 58 | 59 | 60 | 61 | private[feature] trait SimilarityBase extends Params with HasInputCol with HasOutputCol with HasHandleInvalid{ 62 | 63 | final val nGramSize = new IntParam(this, "nGramSize", "") 64 | 65 | final val vocabSize = new IntParam(this, "vocabSize", "Number of dimensions of the encoding") 66 | 67 | override val handleInvalid = new Param[String](this, "handleInvalid", 68 | "How to handle invalid data (unseen labels or NULL values). " + 69 | "Options are 'skip' (filter out rows with invalid data), error (throw an error), " + 70 | "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", 71 | ParamValidators.inArray(SimilarityEncoder.supportedHandleInvalids)) 72 | 73 | final val similarityType = new Param[String](this, "similarityType", "" + 74 | s"Supported options: ${SimilarityEncoder.supportedStringOrderType.mkString(", ")}.", 75 | ParamValidators.inArray(SimilarityEncoder.supportedSimilarityType)) 76 | 77 | 78 | final val stringOrderType = new Param[String](this, "stringOrderType", 79 | "How to order labels of string column. " + 80 | "The first label after ordering is assigned an index of 0. 
" + 81 | s"Supported options: ${SimilarityEncoder.supportedStringOrderType.mkString(", ")}.", 82 | ParamValidators.inArray(SimilarityEncoder.supportedStringOrderType)) 83 | 84 | final def getNGramSize: Int = $(nGramSize) 85 | 86 | final def getVocabSize: Int = $(vocabSize) 87 | 88 | final def getSimilarityType: String = $(similarityType) 89 | 90 | final def getStringOrderType: String = $(stringOrderType) 91 | 92 | 93 | setDefault(nGramSize -> 3, 94 | similarityType -> SimilarityEncoder.nGram, 95 | vocabSize -> 100, 96 | stringOrderType ->SimilarityEncoder.frequencyDesc, 97 | handleInvalid -> SimilarityEncoder.KEEP_INVALID) 98 | 99 | 100 | /** Validates and transforms the input schema. */ 101 | protected def validateAndTransformSchema(schema: StructType): StructType = { 102 | val inputColName = $(inputCol) 103 | val inputDataType = schema(inputColName).dataType 104 | require(inputDataType == StringType, 105 | s"The input column $inputColName must be a string" + 106 | s"but got $inputDataType.") 107 | val inputFields = schema.fields 108 | val outputColName = $(outputCol) 109 | require(inputFields.forall(_.name != outputColName), 110 | s"Output column $outputColName already exists.") 111 | 112 | val outputFields = inputFields :+ new StructField(outputColName, VectorType, true) 113 | 114 | StructType(outputFields) 115 | } 116 | 117 | protected def getCategories(dataset: Dataset[_]): Array[(String, Int)] = { 118 | val inputColName = $(inputCol) 119 | 120 | val labels= (dataset 121 | .select(col($(inputCol)).cast(StringType)) 122 | .na.drop(Array($(inputCol))) 123 | .groupBy($(inputCol)) 124 | .count()) 125 | 126 | // Different options to sort 127 | val vocabulary: Array[Row] = $(stringOrderType) match { 128 | case SimilarityEncoder.frequencyDesc => labels.sort(col("count").desc).take($(vocabSize)) 129 | case SimilarityEncoder.frequencyAsc => labels.sort(col("count")).take($(vocabSize)) 130 | } 131 | 132 | vocabulary.map(row => (row.getAs[String](0), row.getAs[Int](1))).toArray 133 | 134 | } 135 | } 136 | 137 | 138 | class SimilarityEncoder private[dirty_cat] (override val uid: String) extends Estimator[SimilarityEncoderModel] with SimilarityBase { 139 | 140 | def this() = this(Identifiable.randomUID("SimilarityEncoder")) 141 | 142 | def setInputCol(value: String): this.type = set(inputCol, value) 143 | 144 | def setOutputCol(value: String): this.type = set(outputCol, value) 145 | 146 | def setNGramSize(value: Int): this.type = set(nGramSize, value) 147 | 148 | def setVocabSize(value: Int): this.type = set(vocabSize, value) 149 | 150 | def setSimilarityType(value: String): this.type = set(similarityType, value) 151 | 152 | def setStringOrderType(value: String): this.type = set(stringOrderType, value) 153 | 154 | def setHandleInvalid(value: String): this.type = set(handleInvalid, value) 155 | 156 | 157 | override def copy(extra: ParamMap): SimilarityEncoder = { defaultCopy(extra) } 158 | 159 | 160 | override def transformSchema(schema: StructType): StructType = { 161 | validateAndTransformSchema(schema) 162 | } 163 | 164 | override def fit(dataset: Dataset[_]): SimilarityEncoderModel = { 165 | 166 | transformSchema(dataset.schema, logging = true) 167 | 168 | val vocabularyReference = getCategories(dataset).take($(vocabSize)) 169 | 170 | copyValues(new SimilarityEncoderModel(uid, 171 | vocabularyReference.toMap).setParent(this)) 172 | 173 | } 174 | } 175 | 176 | 177 | object SimilarityEncoder extends DefaultParamsReadable[SimilarityEncoder] { 178 | private[feature] val SKIP_INVALID: String = 
"skip" 179 | private[feature] val KEEP_INVALID: String = "keep" 180 | private[feature] val supportedHandleInvalids: Array[String] = 181 | Array(SKIP_INVALID, KEEP_INVALID) 182 | private[feature] val frequencyDesc: String = "frequencyDesc" 183 | private[feature] val frequencyAsc: String = "frequencyAsc" 184 | private[feature] val supportedStringOrderType: Array[String] = 185 | Array(frequencyDesc, frequencyAsc) 186 | private[feature] val nGram: String = "nGram" 187 | private[feature] val leverstein: String = "leverstein" 188 | private[feature] val jako: String = "jako" 189 | private[feature] val supportedSimilarityType: Array[String] = Array(nGram, leverstein, jako) 190 | 191 | override def load(path: String): SimilarityEncoder = super.load(path) 192 | } 193 | 194 | 195 | 196 | 197 | /* This encoding is an alternative to OneHotEncoder in the case of 198 | dirty categorical variables. */ 199 | class SimilarityEncoderModel private[dirty_cat] (override val uid: String, 200 | val vocabularyReference: Map[String, Int]) extends 201 | Model[SimilarityEncoderModel] with SimilarityBase with MLWritable with Serializable{ 202 | 203 | import SimilarityEncoderModel._ 204 | 205 | // only called in copy() 206 | def this(uid: String) = this(uid, null) 207 | 208 | 209 | private def int2Integer(x: Int) = java.lang.Integer.valueOf(x) 210 | 211 | private def string2String(x: String) = java.lang.String.valueOf(x) 212 | 213 | /* Java-friendly version of [[vocabularyReference]] */ 214 | def javaVocabularyReference: JMap[JString, JInt] = { 215 | vocabularyReference.map{ case (k, v) => string2String(k) -> int2Integer(v) }.asJava} 216 | 217 | 218 | def setInputCol(value: String): this.type = set(inputCol, value) 219 | 220 | def setOutputCol(value: String): this.type = set(outputCol, value) 221 | 222 | def setNGramSize(value: Int): this.type = set(nGramSize, value) 223 | 224 | def setSimilarityType(value: String): this.type = set(similarityType, value) 225 | 226 | 227 | override def transformSchema(schema: StructType): StructType = { 228 | if (schema.fieldNames.contains($(inputCol))) { 229 | validateAndTransformSchema(schema) 230 | } else { 231 | // If the input column does not exist during transformation, we skip 232 | // SimilarityEncoderModel. 233 | schema 234 | } 235 | } 236 | 237 | 238 | override def transform(dataset: Dataset[_]): DataFrame = { 239 | 240 | if (!dataset.schema.fieldNames.contains($(inputCol))) { 241 | logInfo(s"Input column ${$(inputCol)} does not exist during transformation. 
" + 242 | "Skip SimilarityEncoderModel.") 243 | return dataset.toDF 244 | } 245 | 246 | transformSchema(dataset.schema, logging = true) 247 | 248 | val inputColName = $(inputCol) 249 | val outputColName = $(outputCol) 250 | // Transformation 251 | val vocabulary = getCategories(dataset).toArray 252 | val vocabularyLabels = vocabulary.map(_._1).toSeq 253 | val vocabularyReferenceLabels = vocabularyReference.toArray.map(_._1) 254 | 255 | 256 | val similarityValues = $(similarityType) match { 257 | case SimilarityEncoder.nGram => StringSimilarity.getNGramSimilaritySeq(vocabularyLabels, vocabularyReferenceLabels, $(nGramSize)) 258 | case SimilarityEncoder.leverstein => StringSimilarity.getLevenshteinSimilaritySeq(vocabularyLabels, vocabularyReferenceLabels) 259 | case SimilarityEncoder.jako => StringSimilarity.getJaroWinklerSimilaritySeq(vocabularyLabels, vocabularyReferenceLabels) 260 | } 261 | 262 | val labelToEncode: OpenHashMap[String, Array[Double]] = { 263 | val n = vocabularyLabels.length 264 | val map = new OpenHashMap[String, Array[Double]](n) 265 | var i = 0 266 | while (i < n ){ 267 | map.update(vocabularyLabels(i), similarityValues(i).toArray) 268 | i += 1 269 | } 270 | map 271 | } 272 | 273 | val (filteredDataset, keepInvalid) = $(handleInvalid) match { 274 | case SimilarityEncoder.SKIP_INVALID => 275 | val filterer = udf { label: String => 276 | labelToEncode.contains(label) 277 | } 278 | (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false) 279 | case _ => (dataset, getHandleInvalid == SimilarityEncoder.KEEP_INVALID) 280 | } 281 | 282 | val emptyValues = Vectors.dense(Array.fill($(vocabSize))(0D)) 283 | 284 | val dirtyCatUDF = udf { label: String => 285 | if (label == null) { 286 | if (keepInvalid) { 287 | emptyValues 288 | } else { 289 | throw (new SparkException("SimilarityEncoder encountered NULL value. To handle or skip " + 290 | "NULLS, try setting SimilarityEncoder.handleInvalid.")) 291 | } 292 | } else { 293 | if (labelToEncode.contains(label)) { 294 | Vectors.dense(labelToEncode(label)) 295 | } else if (keepInvalid) { 296 | emptyValues 297 | } else { 298 | throw (new SparkException(s"Unseen label: $label. 
To handle unseen labels, " + 299 | s"set Param handleInvalid to ${SimilarityEncoder.KEEP_INVALID}.")) 300 | } 301 | } 302 | }.asNondeterministic() 303 | 304 | filteredDataset.withColumn($(outputCol), dirtyCatUDF(col($(inputCol)))) 305 | } 306 | 307 | override def copy(extra: ParamMap) = { defaultCopy(extra) } 308 | 309 | 310 | override def write: MLWriter = new SimilarityEncoderModelWriter(this) 311 | 312 | } 313 | 314 | 315 | /// Add read and write 316 | /** [[MLWriter]] instance for [[SimilarityEncoderModel]] */ 317 | object SimilarityEncoderModel extends MLReadable[SimilarityEncoderModel] { 318 | // 319 | override def read: MLReader[SimilarityEncoderModel] = new SimilarityEncoderModelReader 320 | 321 | override def load(path: String): SimilarityEncoderModel = super.load(path) 322 | 323 | private[SimilarityEncoderModel] 324 | class SimilarityEncoderModelWriter(instance: SimilarityEncoderModel) extends MLWriter { 325 | 326 | private case class Data(vocabularyReference: Map[String, Int]) 327 | 328 | override protected def saveImpl(path: String): Unit = { 329 | 330 | implicit val sparkSession = super.sparkSession 331 | implicit val sc = sparkSession.sparkContext 332 | 333 | // Save metadata and Params 334 | DefaultParamsWriter.saveMetadata(instance, path, sc) 335 | // Save model data 336 | val data = Data(instance.vocabularyReference) 337 | val dataPath = new Path(path, "data").toString 338 | 339 | sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) 340 | } 341 | } 342 | 343 | 344 | private class SimilarityEncoderModelReader extends MLReader[SimilarityEncoderModel] { 345 | 346 | private val className = classOf[SimilarityEncoderModel].getName 347 | 348 | override def load(path: String): SimilarityEncoderModel = { 349 | 350 | implicit val sc = super.sparkSession.sparkContext 351 | 352 | val metadata = DefaultParamsReader.loadMetadata(path, sc, className) 353 | 354 | val dataPath = new Path(path, "data").toString 355 | val data = sparkSession.read.parquet(dataPath).select("vocabularyReference") 356 | 357 | val vocabularyReference = data.head().getAs[Map[String, Int]](0) 358 | val model = new SimilarityEncoderModel(metadata.uid, vocabularyReference) 359 | DefaultParamsReader.getAndSetParams(model, metadata) 360 | 361 | model 362 | } 363 | } 364 | 365 | 366 | } 367 | 368 | 369 | 370 | 371 | --------------------------------------------------------------------------------