├── .travis.yml ├── spark-transformers-logo.png ├── adapters-2.0 ├── src │ ├── main │ │ └── java │ │ │ └── com │ │ │ └── flipkart │ │ │ └── fdp │ │ │ └── ml │ │ │ ├── utils │ │ │ └── Constants.java │ │ │ ├── adapter │ │ │ ├── AbstractModelInfoAdapter.java │ │ │ ├── ModelInfoAdapter.java │ │ │ ├── StringMergeInfoAdapter.java │ │ │ ├── StringSanitizerModelInfoAdapter.java │ │ │ ├── PopularWordsEstimatorModelInfoAdapter.java │ │ │ ├── HashingTFModelInfoAdapter.java │ │ │ ├── BucketizerModelInfoAdapter.java │ │ │ ├── PipelineModelInfoAdapter.java │ │ │ ├── VectorAssemblerModelAdapter.java │ │ │ ├── ChiSqSelectorModelInfoAdapter.java │ │ │ ├── CountVectorizerModelInfoAdapter.java │ │ │ ├── RegexTokenizerModelInfoAdapter.java │ │ │ ├── MinMaxScalerModelInfoAdapter.java │ │ │ ├── StandardScalerModelInfoAdapter.java │ │ │ ├── StringIndexerModelInfoAdapter.java │ │ │ ├── DecisionTreeRegressionModelInfoAdapter.java │ │ │ ├── LogisticRegressionModelInfoAdapter.java │ │ │ ├── LogisticRegressionModelInfoAdapter1.java │ │ │ ├── CommonAddressFeaturesModelInfoAdapter.java │ │ │ ├── DecisionTreeClassificationModelInfoAdapter.java │ │ │ └── GradientBoostClassificationModelInfoAdapter.java │ │ │ └── export │ │ │ └── ModelExporter.java │ └── test │ │ └── java │ │ └── com │ │ └── flipkart │ │ └── fdp │ │ └── ml │ │ ├── adapter │ │ ├── SparkTestBase.java │ │ ├── LogisticRegressionBridgeTest.java │ │ ├── LogisticRegression1BridgeTest.java │ │ ├── StringMergeBridgeTest.java │ │ ├── GradientBoostClassificationModelTest.java │ │ ├── DecisionTreeRegressionModelBridgeTest.java │ │ ├── StringIndexerBridgeTest.java │ │ ├── RegexTokenizerBridgeTest.java │ │ ├── StringSanitizerBridgeTest.java │ │ └── DecisionTreeClassificationModelBridgeTest.java │ │ └── export │ │ ├── LogisticRegressionExporterTest.java │ │ └── LogisticRegression1ExporterTest.java └── pom.xml ├── adapters-1.6 ├── src │ ├── main │ │ └── java │ │ │ └── com │ │ │ └── flipkart │ │ │ └── fdp │ │ │ └── ml │ │ │ ├── utils │ │ │ └── 
Constants.java │ │ │ └── adapter │ │ │ ├── ModelInfoAdapter.java │ │ │ ├── Log1PScalerModelInfoAdapter.java │ │ │ ├── AbstractModelInfoAdapter.java │ │ │ ├── AlgebraicTransformModelInfoAdapter.java │ │ │ ├── CustomOneHotEncoderModelInfoAdapter.java │ │ │ ├── HashingTFModelInfoAdapter.java │ │ │ ├── BucketizerModelInfoAdapter.java │ │ │ ├── IfZeroVectorModelInfoAdapter.java │ │ │ ├── PipelineModelInfoAdapter.java │ │ │ ├── VectorAssemblerModelAdapter.java │ │ │ ├── ChiSqSelectorModelInfoAdapter.java │ │ │ ├── CountVectorizerModelInfoAdapter.java │ │ │ ├── ProbabilityTransformModelInfoAdapter.java │ │ │ ├── FillNAValuesTransformerModelInfoAdapter.java │ │ │ ├── VectorBinarizerModelAdapter.java │ │ │ ├── RegexTokenizerModelInfoAdapter.java │ │ │ ├── MinMaxScalerModelInfoAdapter.java │ │ │ ├── StandardScalerModelInfoAdapter.java │ │ │ ├── StringIndexerModelInfoAdapter.java │ │ │ ├── DecisionTreeRegressionModelInfoAdapter.java │ │ │ ├── LogisticRegressionModelInfoAdapter.java │ │ │ ├── LogisticRegressionModelInfoAdapter1.java │ │ │ ├── DecisionTreeClassificationModelInfoAdapter.java │ │ │ ├── OneHotEncoderModelInfoAdapter.java │ │ │ └── GradientBoostClassificationModelInfoAdapter.java │ └── test │ │ └── java │ │ └── com │ │ └── flipkart │ │ └── fdp │ │ └── ml │ │ ├── adapter │ │ ├── SparkTestBase.java │ │ ├── LogisticRegressionBridgeTest.java │ │ ├── LogisticRegression1BridgeTest.java │ │ └── RegexTokenizerBridgeTest.java │ │ └── export │ │ ├── LogisticRegressionExporterTest.java │ │ └── LogisticRegression1ExporterTest.java └── pom.xml ├── models-info ├── src │ └── main │ │ └── java │ │ └── com │ │ └── flipkart │ │ └── fdp │ │ └── ml │ │ ├── modelinfo │ │ ├── StringMergeModelInfo.java │ │ ├── StringSanitizerModelInfo.java │ │ ├── AbstractModelInfo.java │ │ ├── VectorAssemblerModelInfo.java │ │ ├── AlgebraicTransformModelInfo.java │ │ ├── Log1PScalerModelInfo.java │ │ ├── HashingTFModelInfo.java │ │ ├── BucketizerModelInfo.java │ │ ├── PipelineModelInfo.java │ │ ├── 
OneHotEncoderModelInfo.java │ │ ├── ChiSqSelectorModelInfo.java │ │ ├── ProbabilityTransformModelInfo.java │ │ ├── ModelInfo.java │ │ ├── IfZeroVectorModelInfo.java │ │ ├── PopularWordsEstimatorModelInfo.java │ │ ├── CountVectorizerModelInfo.java │ │ ├── MinMaxScalerModelInfo.java │ │ ├── StandardScalerModelInfo.java │ │ ├── RegexTokenizerModelInfo.java │ │ ├── VectorBinarizerModelInfo.java │ │ ├── StringIndexerModelInfo.java │ │ ├── FillNAValuesTransformerModelInfo.java │ │ ├── LogisticRegressionModelInfo.java │ │ ├── CommonAddressFeaturesModelInfo.java │ │ ├── GradientBoostModelInfo.java │ │ ├── RandomForestModelInfo.java │ │ └── DecisionTreeModelInfo.java │ │ ├── transformer │ │ ├── Transformer.java │ │ ├── StringMergeTransformer.java │ │ ├── PopularWordsEstimatorTransformer.java │ │ ├── StringSanitizerTransformer.java │ │ ├── Log1PScalerTransformer.java │ │ ├── ChiSqSelectorTransformer.java │ │ ├── ProbabilityTransformTransformer.java │ │ ├── OneHotEncoderTransformer.java │ │ ├── AlgebraicTransformTransformer.java │ │ ├── PipelineModelTransformer.java │ │ ├── FillNAValuesTransformer.java │ │ ├── VectorBinarizerTranformer.java │ │ ├── HashingTFTransformer.java │ │ ├── IfZeroVectorTransformer.java │ │ ├── StringIndexerTransformer.java │ │ ├── LogisticRegressionTransformer.java │ │ ├── BucketizerTransformer.java │ │ ├── GradientBoostClassificationTransformer.java │ │ ├── MinMaxScalerTransformer.java │ │ ├── StandardScalerTransformer.java │ │ ├── RegexTokenizerTransformer.java │ │ ├── VectorAssemblerTransformer.java │ │ └── CountVectorizerTransformer.java │ │ ├── importer │ │ ├── SerializationConstants.java │ │ └── ModelImporter.java │ │ └── utils │ │ └── PipelineUtils.java └── pom.xml ├── .gitignore ├── custom-transformer-2.0 ├── src │ └── main │ │ └── scala │ │ └── com │ │ └── flipkart │ │ └── transformer │ │ ├── common │ │ ├── HasRawInputCol.scala │ │ ├── HasOutputCol.scala │ │ └── HasInputCol.scala │ │ └── ml │ │ ├── StringSanitizer.scala │ │ └── 
StringMerge.scala └── pom.xml ├── custom-transformer ├── src │ └── main │ │ └── scala │ │ └── com │ │ └── flipkart │ │ └── fdp │ │ └── ml │ │ ├── FillNAValuesTransformer.scala │ │ └── Log1PScaler.scala └── pom.xml └── README.md /.travis.yml: -------------------------------------------------------------------------------- 1 | language: java 2 | jdk: 3 | - oraclejdk8 4 | install: mvn -q clean compile 5 | script: mvn -q test 6 | -------------------------------------------------------------------------------- /spark-transformers-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flipkart-incubator/spark-transformers/HEAD/spark-transformers-logo.png -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/utils/Constants.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.utils; 2 | 3 | public class Constants { 4 | public static final String SUPPORTED_SPARK_VERSION_PREFIX = "2.0"; 5 | public static final String LIBRARY_VERSION = "1.0"; 6 | } 7 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/utils/Constants.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.utils; 2 | 3 | import java.io.Serializable; 4 | 5 | public class Constants implements Serializable { 6 | public static final String SUPPORTED_SPARK_VERSION_PREFIX = "1.6"; 7 | public static final String LIBRARY_VERSION = "1.0"; 8 | } 9 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/StringMergeModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import 
com.flipkart.fdp.ml.transformer.StringMergeTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | 6 | public class StringMergeModelInfo extends AbstractModelInfo{ 7 | @Override 8 | public Transformer getTransformer() { 9 | return new StringMergeTransformer(this); 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | models-info/target/ 3 | adapters-1.6/target/ 4 | adapters-2.0/target/ 5 | adapters-2.0/spark-warehouse/ 6 | target/ 7 | 8 | #intellij specific files 9 | .idea/ 10 | *.iml 11 | *.ipr 12 | *.iws 13 | # Mobile Tools for Java (J2ME) 14 | .mtj.tmp/ 15 | 16 | # Package Files # 17 | *.jar 18 | *.war 19 | *.ear 20 | 21 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 22 | hs_err_pid* 23 | -------------------------------------------------------------------------------- /custom-transformer-2.0/src/main/scala/com/flipkart/transformer/common/HasRawInputCol.scala: -------------------------------------------------------------------------------- 1 | package com.flipkart.transformer.common 2 | 3 | import org.apache.spark.ml.param.{Param, Params} 4 | 5 | /** 6 | * Created by gaurav.prasad on 08/11/16. 
7 | */ 8 | trait HasRawInputCol extends Params { 9 | val rawInputCol: Param[String] = new Param[String](this, "rawInputCol", "Raw words.") 10 | 11 | def getRawInputCol = $(rawInputCol) 12 | } 13 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/StringSanitizerModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.StringSanitizerTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | 6 | public class StringSanitizerModelInfo extends AbstractModelInfo { 7 | @Override 8 | public Transformer getTransformer() { 9 | return new StringSanitizerTransformer(this); 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /custom-transformer-2.0/src/main/scala/com/flipkart/transformer/common/HasOutputCol.scala: -------------------------------------------------------------------------------- 1 | package com.flipkart.transformer.common 2 | 3 | import org.apache.spark.ml.param.{Param, Params} 4 | 5 | /** 6 | * Created by gaurav.prasad on 08/11/16. 7 | */ 8 | trait HasOutputCol extends Params { 9 | val outputCol: Param[String] = new Param[String](this, "outputCol", "Will have the fraction of common words.") 10 | 11 | def getOutputCol = $(outputCol) 12 | } 13 | -------------------------------------------------------------------------------- /custom-transformer-2.0/src/main/scala/com/flipkart/transformer/common/HasInputCol.scala: -------------------------------------------------------------------------------- 1 | package com.flipkart.transformer.common 2 | 3 | import org.apache.spark.ml.param.{Param, Params} 4 | 5 | /** 6 | * Created by gaurav.prasad on 08/11/16. 
7 | */ 8 | /* 9 | * Transformers and Estimators 10 | * */ 11 | trait HasInputCol extends Params { 12 | val inputCol: Param[String] = new Param[String](this, "inputCol", "Input should have sanitized split words.") 13 | 14 | def getInputCol = $(inputCol) 15 | } 16 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/AbstractModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import lombok.Getter; 4 | import lombok.Setter; 5 | 6 | import java.io.Serializable; 7 | import java.util.LinkedHashSet; 8 | import java.util.Set; 9 | 10 | @Getter 11 | @Setter 12 | public abstract class AbstractModelInfo implements ModelInfo , Serializable { 13 | private Set inputKeys = new LinkedHashSet<>(); 14 | private Set outputKeys = new LinkedHashSet<>(); 15 | } 16 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/VectorAssemblerModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.Transformer; 4 | import com.flipkart.fdp.ml.transformer.VectorAssemblerTransformer; 5 | 6 | /** 7 | * Represents information for VectorAssembler model 8 | *

9 | * Created by rohan.shetty on 28/03/16. 10 | */ 11 | public class VectorAssemblerModelInfo extends AbstractModelInfo { 12 | @Override 13 | public Transformer getTransformer() { 14 | return new VectorAssemblerTransformer(this); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/AlgebraicTransformModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.AlgebraicTransformTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Created by shubhranshu.shekhar on 18/08/16. 9 | */ 10 | @Data 11 | public class AlgebraicTransformModelInfo extends AbstractModelInfo{ 12 | private double[] coefficients; 13 | 14 | @Override 15 | public Transformer getTransformer() { 16 | return new AlgebraicTransformTransformer(this); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/AbstractModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.ModelInfo; 4 | 5 | 6 | public abstract class AbstractModelInfoAdapter implements ModelInfoAdapter { 7 | 8 | @Override 9 | public T adapt(F from) { 10 | return getModelInfo(from); 11 | } 12 | 13 | /** 14 | * @param from source object in spark's mllib 15 | * @return returns the corresponding {@link ModelInfo} object that represents the model information 16 | */ 17 | abstract T getModelInfo(F from); 18 | 19 | } 20 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/Log1PScalerModelInfo.java: 
-------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.Log1PScalerTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a LogScaler model 9 | */ 10 | @Data 11 | public class Log1PScalerModelInfo extends AbstractModelInfo { 12 | 13 | /** 14 | * @return an corresponding {@link Log1PScalerTransformer} for this model info 15 | */ 16 | @Override 17 | public Transformer getTransformer() { 18 | return new Log1PScalerTransformer(this); 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/HashingTFModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.HashingTFTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a HashingTF model 9 | */ 10 | @Data 11 | public class HashingTFModelInfo extends AbstractModelInfo { 12 | private int numFeatures; 13 | 14 | /** 15 | * @return an corresponding {@link HashingTFTransformer} for this model info 16 | */ 17 | @Override 18 | public Transformer getTransformer() { 19 | return new HashingTFTransformer(this); 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/BucketizerModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.BucketizerTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a Bucketizer model 9 | */ 10 
| @Data 11 | public class BucketizerModelInfo extends AbstractModelInfo { 12 | 13 | private double[] splits; 14 | 15 | /** 16 | * @return an corresponding {@link BucketizerTransformer} for this model info 17 | */ 18 | @Override 19 | public Transformer getTransformer() { 20 | return new BucketizerTransformer(this); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/PipelineModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.PipelineModelTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a pipeline model 9 | */ 10 | @Data 11 | public class PipelineModelInfo extends AbstractModelInfo { 12 | 13 | private ModelInfo stages[]; 14 | 15 | /** 16 | * @return an corresponding {@link PipelineModelTransformer} for this model info 17 | */ 18 | @Override 19 | public Transformer getTransformer() { 20 | return new PipelineModelTransformer(this); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/OneHotEncoderModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.OneHotEncoderTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a one hot encoder model 9 | */ 10 | @Data 11 | public class OneHotEncoderModelInfo extends AbstractModelInfo { 12 | 13 | private int numTypes; 14 | 15 | /** 16 | * @return an corresponding {@link OneHotEncoderTransformer} for this model info 17 | */ 18 | @Override 19 | public Transformer getTransformer() { 20 | return 
new OneHotEncoderTransformer(this); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/ChiSqSelectorModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.ChiSqSelectorTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a ChiSqSelector model 9 | */ 10 | @Data 11 | public class ChiSqSelectorModelInfo extends AbstractModelInfo { 12 | private int[] selectedFeatures; 13 | 14 | /** 15 | * @return an corresponding {@link ChiSqSelectorTransformer} for this model info 16 | */ 17 | @Override 18 | public Transformer getTransformer() { 19 | return new ChiSqSelectorTransformer(this); 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/ProbabilityTransformModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | /** 4 | * Created by shubhranshu.shekhar on 18/08/16. 
5 | */ 6 | import com.flipkart.fdp.ml.transformer.ProbabilityTransformTransformer; 7 | import com.flipkart.fdp.ml.transformer.Transformer; 8 | import lombok.Data; 9 | 10 | @Data 11 | public class ProbabilityTransformModelInfo extends AbstractModelInfo { 12 | private double actualClickProportion; 13 | private double underSampledClickProportion; 14 | private int probIndex; 15 | 16 | @Override 17 | public Transformer getTransformer() { 18 | return new ProbabilityTransformTransformer(this); 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/ModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.Transformer; 4 | 5 | import java.io.Serializable; 6 | 7 | /** 8 | * This interface represents information of a model. The implementors of this class should capture 9 | * the information(coefficients) of a model and the corresponding {@link Transformer} would use that 10 | * information for prediction/transformation 11 | */ 12 | public interface ModelInfo extends Serializable { 13 | 14 | /** 15 | * @return {@link Transformer} that will use the information(coefficients) of this model 16 | * to transform input 17 | */ 18 | Transformer getTransformer(); 19 | } 20 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/IfZeroVectorModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.IfZeroVectorTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a LogScaler model 9 | */ 10 | @Data 11 | public class IfZeroVectorModelInfo extends AbstractModelInfo { 12 | 
13 | private String thenSetValue; 14 | 15 | private String elseSetCol; 16 | 17 | /** 18 | * @return an corresponding {@link IfZeroVectorTransformer} for this model info 19 | */ 20 | @Override 21 | public Transformer getTransformer() { 22 | return new IfZeroVectorTransformer(this); 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/PopularWordsEstimatorModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.PopularWordsEstimatorTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | 6 | import java.util.HashSet; 7 | 8 | public class PopularWordsEstimatorModelInfo extends AbstractModelInfo { 9 | private HashSet popularWords; 10 | 11 | @Override 12 | public Transformer getTransformer() { 13 | return new PopularWordsEstimatorTransformer(this); 14 | } 15 | 16 | public HashSet getPopularWords() { 17 | return popularWords; 18 | } 19 | 20 | public void setPopularWords(HashSet popularWords) { 21 | this.popularWords = popularWords; 22 | } 23 | } -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/CountVectorizerModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.CountVectorizerTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a CountVectorizer model 9 | */ 10 | @Data 11 | public class CountVectorizerModelInfo extends AbstractModelInfo { 12 | 13 | private double minTF; 14 | private String[] vocabulary; 15 | 16 | /** 17 | * @return an corresponding {@link CountVectorizerTransformer} for this model info 18 | */ 19 | @Override 
20 | public Transformer getTransformer() { 21 | return new CountVectorizerTransformer(this); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/MinMaxScalerModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.MinMaxScalerTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a MinMaxScaler model 9 | */ 10 | 11 | @Data 12 | public class MinMaxScalerModelInfo extends AbstractModelInfo { 13 | private double[] originalMin, originalMax; 14 | private double min, max; 15 | 16 | /** 17 | * @return an corresponding {@link MinMaxScalerTransformer} for this model info 18 | */ 19 | @Override 20 | public Transformer getTransformer() { 21 | return new MinMaxScalerTransformer(this); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/StandardScalerModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.StandardScalerTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a StandardScaler model 9 | */ 10 | 11 | @Data 12 | public class StandardScalerModelInfo extends AbstractModelInfo { 13 | private double[] std, mean; 14 | private boolean withStd, withMean; 15 | 16 | /** 17 | * @return an corresponding {@link StandardScalerTransformer} for this model info 18 | */ 19 | @Override 20 | public Transformer getTransformer() { 21 | return new StandardScalerTransformer(this); 22 | } 23 | } 24 | 
-------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/Transformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import java.io.Serializable; 4 | import java.util.Map; 5 | import java.util.Set; 6 | 7 | /** 8 | * This interface represents a capability of a class to transform the input using a suitable model 9 | * representation captured in {@link com.flipkart.fdp.ml.modelinfo.ModelInfo}. 10 | */ 11 | public interface Transformer extends Serializable { 12 | 13 | /** 14 | * @param input values as map of (String, Object) for the transformation 15 | * similar to the lines of a dataframe. 16 | */ 17 | public void transform(Map input); 18 | 19 | public Set getInputKeys(); 20 | 21 | public Set getOutputKeys(); 22 | } 23 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/RegexTokenizerModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.RegexTokenizerTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a RegexTokenizer model 9 | */ 10 | 11 | @Data 12 | public class RegexTokenizerModelInfo extends AbstractModelInfo { 13 | private int minTokenLength; 14 | private boolean gaps, toLowercase; 15 | private String pattern; 16 | 17 | /** 18 | * @return an corresponding {@link RegexTokenizerTransformer} for this model info 19 | */ 20 | @Override 21 | public Transformer getTransformer() { 22 | return new RegexTokenizerTransformer(this); 23 | } 24 | } 25 | -------------------------------------------------------------------------------- 
/models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/VectorBinarizerModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | 4 | 5 | import com.flipkart.fdp.ml.transformer.Transformer; 6 | import com.flipkart.fdp.ml.transformer.VectorBinarizerTranformer; 7 | import lombok.Data; 8 | 9 | /** 10 | * Represents information for a Vector Binarizer model 11 | * Created by karan.verma on 09/11/16. 12 | */ 13 | @Data 14 | public class VectorBinarizerModelInfo extends AbstractModelInfo { 15 | 16 | private double threshold; 17 | 18 | /** 19 | * @return an corresponding {@link com.flipkart.fdp.ml.transformer.VectorBinarizerTranformer} for this model info 20 | */ 21 | @Override 22 | public Transformer getTransformer() { 23 | return new VectorBinarizerTranformer(this); 24 | } 25 | 26 | } 27 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/StringIndexerModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.StringIndexerTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | import java.util.HashMap; 8 | import java.util.Map; 9 | 10 | /** 11 | * Represents information for StringIndexer model 12 | */ 13 | @Data 14 | public class StringIndexerModelInfo extends AbstractModelInfo { 15 | 16 | private Map labelToIndex = new HashMap(); 17 | private boolean failOnUnseenValues = true; 18 | 19 | /** 20 | * @return an corresponding {@link StringIndexerTransformer} for this model info 21 | */ 22 | @Override 23 | public Transformer getTransformer() { 24 | return new StringIndexerTransformer(this); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- 
/models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/FillNAValuesTransformerModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.FillNAValuesTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | import java.util.HashMap; 8 | import java.util.Map; 9 | 10 | 11 | @Data 12 | public class FillNAValuesTransformerModelInfo extends AbstractModelInfo { 13 | 14 | //TODO: types are inferred during deserialization. Integers are being inferred as doubles. Verification is needed if it is a problem 15 | private Map naValuesMap = new HashMap<>(); 16 | /** 17 | * @return an corresponding {@link FillNAValuesTransformer} for this model info 18 | */ 19 | @Override 20 | public Transformer getTransformer() { 21 | return new FillNAValuesTransformer(this); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/ModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.ModelInfo; 4 | 5 | /** 6 | * Transforms a model ( eg spark's models in MlLib ) to {@link ModelInfo} object 7 | * that can be exported via {@link com.flipkart.fdp.ml.export.ModelExporter} 8 | */ 9 | public interface ModelInfoAdapter { 10 | 11 | /** 12 | * @param from source object in spark's mllib 13 | * @return returns the corresponding {@link ModelInfo} object that represents the model information 14 | */ 15 | T adapt(F from); 16 | 17 | /** 18 | * @return Get the source class which is being adapted from. 19 | */ 20 | Class getSource(); 21 | 22 | /** 23 | * @return Get the target adaptor class which is being adapted to. 
24 | */ 25 | Class getTarget(); 26 | } 27 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/LogisticRegressionModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.LogisticRegressionTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | /** 8 | * Represents information for a Logistic Regression model 9 | */ 10 | @Data 11 | public class LogisticRegressionModelInfo extends AbstractModelInfo { 12 | private double[] weights; 13 | private double intercept; 14 | private int numClasses; 15 | private int numFeatures; 16 | private double threshold; 17 | 18 | private String probabilityKey = "probability"; 19 | 20 | /** 21 | * @return an corresponding {@link LogisticRegressionTransformer} for this model info 22 | */ 23 | @Override 24 | public Transformer getTransformer() { 25 | return new LogisticRegressionTransformer(this); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/importer/SerializationConstants.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.importer; 2 | 3 | import java.io.Serializable; 4 | import java.nio.charset.Charset; 5 | 6 | /** 7 | * Class holding constants used in serialization 8 | */ 9 | public class SerializationConstants implements Serializable { 10 | public static final Charset CHARSET = Charset.forName("UTF-8"); 11 | //key to identify type in serialized format 12 | public static final String TYPE_IDENTIFIER = "_class"; 13 | //key to identify model info payload in serialized format 14 | public static final String MODEL_INFO_IDENTIFIER = "_model_info"; 15 | //key to identify the spark version it was imported from 16 | public 
static final String SPARK_VERSION = "_spark_version"; 17 | //key to identify the exporter library version it was exported with 18 | public static final String EXPORTER_LIBRARY_VERSION = "_exporter_library_version"; 19 | } 20 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/CommonAddressFeaturesModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.CommonAddressFeaturesTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | import java.util.HashSet; 8 | import java.util.List; 9 | 10 | @Data 11 | public class CommonAddressFeaturesModelInfo extends AbstractModelInfo { 12 | private String mergedAddressParam; 13 | private String sanitizedAddressParam; 14 | 15 | private String numWordsParam; 16 | private String numCommasParam; 17 | private String numericPresentParam; 18 | private String addressLengthParam; 19 | private String favouredStartColParam; 20 | private String unfavouredStartColParam; 21 | 22 | private HashSet favourableStarts; 23 | private HashSet unFavourableStarts; 24 | 25 | @Override 26 | public Transformer getTransformer() { 27 | return new CommonAddressFeaturesTransformer(this); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/StringMergeInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.StringMergeModelInfo; 4 | import com.flipkart.transformer.ml.StringMerge; 5 | 6 | import java.util.LinkedHashSet; 7 | import java.util.Set; 8 | 9 | public class StringMergeInfoAdapter extends AbstractModelInfoAdapter { 10 | @Override 11 | StringMergeModelInfo 
getModelInfo(StringMerge from) { 12 | StringMergeModelInfo modelInfo = new StringMergeModelInfo(); 13 | 14 | Set inputKeys = new LinkedHashSet<>(); 15 | inputKeys.add(from.getInputCol1()); 16 | inputKeys.add(from.getInputCol2()); 17 | modelInfo.setInputKeys(inputKeys); 18 | 19 | Set outputKeys = new LinkedHashSet<>(); 20 | outputKeys.add(from.getOutputCol()); 21 | modelInfo.setOutputKeys(outputKeys); 22 | return modelInfo; 23 | } 24 | 25 | @Override 26 | public Class getSource() { 27 | return StringMerge.class; 28 | } 29 | 30 | @Override 31 | public Class getTarget() { 32 | return StringMergeModelInfo.class; 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/StringSanitizerModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.StringSanitizerModelInfo; 4 | import com.flipkart.transformer.ml.StringSanitizer; 5 | 6 | import java.util.LinkedHashSet; 7 | import java.util.Set; 8 | 9 | public class StringSanitizerModelInfoAdapter extends AbstractModelInfoAdapter { 10 | @Override 11 | StringSanitizerModelInfo getModelInfo(StringSanitizer from) { 12 | StringSanitizerModelInfo modelInfo = new StringSanitizerModelInfo(); 13 | 14 | Set inputKeys = new LinkedHashSet<>(); 15 | inputKeys.add(from.getInputCol()); 16 | modelInfo.setInputKeys(inputKeys); 17 | 18 | Set outputKeys = new LinkedHashSet<>(); 19 | outputKeys.add(from.getOutputCol()); 20 | modelInfo.setOutputKeys(outputKeys); 21 | return modelInfo; 22 | } 23 | 24 | @Override 25 | public Class getSource() { 26 | return StringSanitizer.class; 27 | } 28 | 29 | @Override 30 | public Class getTarget() { 31 | return StringSanitizerModelInfo.class; 32 | } 33 | } 34 | -------------------------------------------------------------------------------- 
/adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/ModelInfoAdapter.java:
--------------------------------------------------------------------------------
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.ModelInfo;
import org.apache.spark.sql.DataFrame;

import java.io.Serializable;

/**
 * Transforms a model ( eg spark's models in MlLib ) to {@link com.flipkart.fdp.ml.modelinfo.ModelInfo} object
 * that can be exported via {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
// NOTE(review): the type parameter list was stripped in extraction; restored from
// the method signatures (T adapt(F from, ...)) and the 1.6 AbstractModelInfoAdapter.
public interface ModelInfoAdapter<F, T extends ModelInfo> extends Serializable {

    /**
     * @param from source object in spark's mllib
     * @param df   Data frame that is used for training is required for some models as state information is being stored as column metadata by spark models
     * @return returns the corresponding {@link ModelInfo} object that represents the model information
     */
    T adapt(F from, DataFrame df);

    /**
     * @return Get the source class which is being adapted from.
     */
    Class getSource();

    /**
     * @return Get the target adaptor class which is being adapted to.
     */
    Class getTarget();
}
--------------------------------------------------------------------------------
/models-info/src/main/java/com/flipkart/fdp/ml/transformer/StringMergeTransformer.java:
--------------------------------------------------------------------------------
package com.flipkart.fdp.ml.transformer;

import com.flipkart.fdp.ml.modelinfo.StringMergeModelInfo;

import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/**
 * Transforms input for a StringMerge representation captured by
 * {@link StringMergeModelInfo}: concatenates the two configured input strings
 * with a single space and trims the result.
 */
public class StringMergeTransformer implements Transformer {
    private StringMergeModelInfo modelInfo;

    public StringMergeTransformer(final StringMergeModelInfo modelInfo) {
        this.modelInfo = modelInfo;
    }

    @Override
    public void transform(Map<String, Object> input) {
        // Input keys are a LinkedHashSet, so iteration order matches the order
        // in which the adapter registered inputCol1 and inputCol2.
        Iterator<String> iterator = modelInfo.getInputKeys().iterator();

        String input1 = (String) input.get(iterator.next());
        String input2 = (String) input.get(iterator.next());
        input.put(modelInfo.getOutputKeys().iterator().next(), transformInput(input1, input2));
    }

    private String transformInput(String input1, String input2) {
        return (input1 + " " + input2).trim();
    }

    @Override
    public Set<String> getInputKeys() {
        return modelInfo.getInputKeys();
    }

    @Override
    public Set<String> getOutputKeys() {
        return modelInfo.getOutputKeys();
    }
}
--------------------------------------------------------------------------------
/models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/GradientBoostModelInfo.java:
--------------------------------------------------------------------------------
package com.flipkart.fdp.ml.modelinfo;

import com.flipkart.fdp.ml.transformer.GradientBoostClassificationTransformer;
import com.flipkart.fdp.ml.transformer.RandomForestTransformer;
import com.flipkart.fdp.ml.transformer.Transformer;
import lombok.Data;

import java.util.ArrayList;
import java.util.List;
| /** 12 | * Represents information for a Random Forest model 13 | */ 14 | 15 | @Data 16 | public class GradientBoostModelInfo extends AbstractModelInfo { 17 | 18 | private boolean regression; 19 | private int numFeatures; 20 | private List trees = new ArrayList<>(); 21 | private List treeWeights = new ArrayList<>(); 22 | 23 | private String probabilityKey = "probability"; 24 | private String rawPredictionKey = "rawPrediction"; 25 | 26 | /** 27 | * @return an corresponding {@link RandomForestTransformer} for this model info 28 | */ 29 | @Override 30 | public Transformer getTransformer() { 31 | return new GradientBoostClassificationTransformer(this); 32 | } 33 | 34 | public boolean isClassification() { 35 | return !regression; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /adapters-2.0/src/test/java/com/flipkart/fdp/ml/adapter/SparkTestBase.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import org.apache.spark.api.java.JavaSparkContext; 4 | import org.apache.spark.sql.Row; 5 | import org.apache.spark.sql.RowFactory; 6 | import org.apache.spark.sql.SparkSession; 7 | import org.junit.After; 8 | import org.junit.Before; 9 | 10 | import java.io.IOException; 11 | 12 | /** 13 | * Base class for test that need to create and use a spark context. 14 | */ 15 | public class SparkTestBase { 16 | protected transient SparkSession spark; 17 | protected transient JavaSparkContext jsc; 18 | protected static final double EPSILON = 1.0e-6; 19 | 20 | @Before 21 | public void setUp() throws IOException { 22 | spark = SparkSession.builder() 23 | .master("local[2]") 24 | .appName(getClass().getSimpleName()) 25 | .getOrCreate(); 26 | jsc = new JavaSparkContext(spark.sparkContext()); 27 | } 28 | 29 | @After 30 | public void tearDown() { 31 | spark.stop(); 32 | spark = null; 33 | } 34 | 35 | /** 36 | * An alias for RowFactory.create. 
37 | */ 38 | public Row cr(Object... values) { 39 | return RowFactory.create(values); 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/modelinfo/RandomForestModelInfo.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.modelinfo; 2 | 3 | import com.flipkart.fdp.ml.transformer.RandomForestTransformer; 4 | import com.flipkart.fdp.ml.transformer.Transformer; 5 | import lombok.Data; 6 | 7 | import java.util.ArrayList; 8 | import java.util.List; 9 | 10 | /** 11 | * Represents information for a Random Forest model 12 | */ 13 | 14 | @Data 15 | public class RandomForestModelInfo extends AbstractModelInfo { 16 | 17 | private boolean regression; 18 | private int numFeatures; 19 | private int numClasses; 20 | private List trees = new ArrayList<>(); 21 | //Weights are currently not being used while prediction as it is not implemented in spark-mllib itself as of now. Keeping this as a placeholder for now. 
22 | private List treeWeights = new ArrayList<>(); 23 | 24 | private String probabilityKey = "probability"; 25 | private String rawPredictionKey = "rawPrediction"; 26 | 27 | /** 28 | * @return an corresponding {@link RandomForestTransformer} for this model info 29 | */ 30 | @Override 31 | public Transformer getTransformer() { 32 | return new RandomForestTransformer(this); 33 | } 34 | 35 | public boolean isClassification() { 36 | return !regression; 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/PopularWordsEstimatorModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.PopularWordsEstimatorModelInfo; 4 | import org.apache.spark.ml.PopularWordsModel; 5 | 6 | import java.util.Arrays; 7 | import java.util.HashSet; 8 | import java.util.LinkedHashSet; 9 | import java.util.Set; 10 | 11 | public class PopularWordsEstimatorModelInfoAdapter extends AbstractModelInfoAdapter { 12 | 13 | @Override 14 | PopularWordsEstimatorModelInfo getModelInfo(PopularWordsModel from) { 15 | PopularWordsEstimatorModelInfo modelInfo = new PopularWordsEstimatorModelInfo(); 16 | modelInfo.setPopularWords(new HashSet<>(Arrays.asList(from.popularWords()))); 17 | 18 | Set inputKeys = new LinkedHashSet<>(); 19 | inputKeys.add(from.getInputCol()); 20 | modelInfo.setInputKeys(inputKeys); 21 | 22 | Set outputKeys = new LinkedHashSet<>(); 23 | outputKeys.add(from.getOutputCol()); 24 | modelInfo.setOutputKeys(outputKeys); 25 | 26 | return modelInfo; 27 | } 28 | 29 | @Override 30 | public Class getSource() { 31 | return PopularWordsModel.class; 32 | } 33 | 34 | @Override 35 | public Class getTarget() { 36 | return PopularWordsEstimatorModelInfo.class; 37 | } 38 | } 39 | -------------------------------------------------------------------------------- 
/models-info/pom.xml: -------------------------------------------------------------------------------- 1 | 4 | 4.0.0 5 | 6 | 7 | 8 | com.flipkart.fdp.ml 9 | spark-transformers 10 | 0.4.0 11 | .. 12 | 13 | 14 | models-info 15 | jar 16 | Custom Models and transformers 17 | 18 | 19 | 20 | com.google.code.gson 21 | gson 22 | 2.5 23 | 24 | 25 | 26 | org.apache.commons 27 | commons-lang3 28 | 3.3.2 29 | 30 | 31 | 32 | com.github.fommil.netlib 33 | core 34 | 1.1.2 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/PopularWordsEstimatorTransformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.PopularWordsEstimatorModelInfo; 4 | 5 | import java.util.*; 6 | 7 | public class PopularWordsEstimatorTransformer implements Transformer { 8 | private PopularWordsEstimatorModelInfo modelInfo; 9 | 10 | public PopularWordsEstimatorTransformer(final PopularWordsEstimatorModelInfo modelInfo) { 11 | this.modelInfo = modelInfo; 12 | } 13 | 14 | public double predict(final String[] words) { 15 | return getMatchedWordsCount(modelInfo.getPopularWords(), words) / words.length; 16 | } 17 | 18 | private double getMatchedWordsCount(HashSet popularWords, String[] words) { 19 | double count = 0.0; 20 | for (String word : words) { 21 | if (popularWords.contains(word)) { 22 | count++; 23 | } 24 | } 25 | return count; 26 | } 27 | 28 | @Override 29 | public void transform(Map input) { 30 | String key = modelInfo.getInputKeys().iterator().next(); 31 | String[] inp = (String[]) input.get(key); 32 | input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp)); 33 | } 34 | 35 | @Override 36 | public Set getInputKeys() { 37 | return modelInfo.getInputKeys(); 38 | } 39 | 40 | @Override 41 | public Set getOutputKeys() { 42 | return modelInfo.getOutputKeys(); 43 
| } 44 | } -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/StringSanitizerTransformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.StringSanitizerModelInfo; 4 | 5 | import java.util.Arrays; 6 | import java.util.List; 7 | import java.util.Map; 8 | import java.util.Set; 9 | 10 | public class StringSanitizerTransformer implements Transformer { 11 | private StringSanitizerModelInfo modelInfo; 12 | 13 | public StringSanitizerTransformer(StringSanitizerModelInfo modelInfo) { 14 | this.modelInfo = modelInfo; 15 | } 16 | 17 | @Override 18 | public void transform(Map input) { 19 | String key = modelInfo.getInputKeys().iterator().next(); 20 | String inp = (String) input.get(key); 21 | input.put(modelInfo.getOutputKeys().iterator().next(), transformInput(inp)); 22 | } 23 | 24 | private String[] transformInput(String input) { 25 | String s = input.toLowerCase() 26 | .replaceAll("\\P{Print}", " ") 27 | .replaceAll("[^0-9a-zA-Z ]", " ") 28 | .replaceAll("\\d{10}", " ") 29 | .replaceAll("\\d{6}", " ") 30 | .trim() 31 | .replaceAll("\\s+", " "); 32 | String[] split = s.split(" "); 33 | return split; 34 | } 35 | 36 | @Override 37 | public Set getInputKeys() { 38 | return modelInfo.getInputKeys(); 39 | } 40 | 41 | @Override 42 | public Set getOutputKeys() { 43 | return modelInfo.getOutputKeys(); 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/Log1PScalerTransformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.Log1PScalerModelInfo; 4 | 5 | import java.util.Map; 6 | import java.util.Set; 7 | 8 | /** 9 | * Transforms input/ 
predicts for a LogScaler model representation
 * captured by {@link Log1PScalerModelInfo}.
 */
public class Log1PScalerTransformer implements Transformer {
    private final Log1PScalerModelInfo modelInfo;

    public Log1PScalerTransformer(Log1PScalerModelInfo modelInfo) {
        this.modelInfo = modelInfo;
    }

    @Override
    public void transform(Map<String, Object> input) {
        double[] inp = (double[]) input.get(modelInfo.getInputKeys().iterator().next());
        input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp));
    }

    // Element-wise log(1 + x); log1p is numerically stable for small x.
    private double[] predict(double[] inp) {
        double[] output = new double[inp.length];
        for (int i = 0; i < inp.length; i++) {
            output[i] = Math.log1p(inp[i]);
        }
        return output;
    }

    @Override
    public Set<String> getInputKeys() {
        return modelInfo.getInputKeys();
    }

    @Override
    public Set<String> getOutputKeys() {
        return modelInfo.getOutputKeys();
    }

}
--------------------------------------------------------------------------------
/adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/HashingTFModelInfoAdapter.java:
--------------------------------------------------------------------------------
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.HashingTFModelInfo;
import org.apache.spark.ml.feature.HashingTF;

import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Transforms Spark's {@link HashingTF} in MlLib to {@link HashingTFModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
public class HashingTFModelInfoAdapter extends AbstractModelInfoAdapter<HashingTF, HashingTFModelInfo> {
    @Override
    public HashingTFModelInfo getModelInfo(final HashingTF from) {
        final HashingTFModelInfo modelInfo = new HashingTFModelInfo();
        modelInfo.setNumFeatures(from.getNumFeatures());

        Set<String> inputKeys = new LinkedHashSet<String>();
        inputKeys.add(from.getInputCol());
        modelInfo.setInputKeys(inputKeys);

        Set<String> outputKeys = new LinkedHashSet<String>();
        outputKeys.add(from.getOutputCol());
        modelInfo.setOutputKeys(outputKeys);

        return modelInfo;
    }

    @Override
    public Class getSource() {
        return HashingTF.class;
    }

    @Override
    public Class getTarget() {
        return HashingTFModelInfo.class;
    }
}
--------------------------------------------------------------------------------
/adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/Log1PScalerModelInfoAdapter.java:
--------------------------------------------------------------------------------
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.Log1PScaler;
import com.flipkart.fdp.ml.modelinfo.Log1PScalerModelInfo;
import org.apache.spark.sql.DataFrame;

import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Transforms {@link Log1PScaler} in MlLib to {@link Log1PScalerModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
public class Log1PScalerModelInfoAdapter extends AbstractModelInfoAdapter<Log1PScaler, Log1PScalerModelInfo> {

    @Override
    public Log1PScalerModelInfo getModelInfo(final Log1PScaler from, DataFrame df) {
        Log1PScalerModelInfo modelInfo = new Log1PScalerModelInfo();

        Set<String> inputKeys = new LinkedHashSet<String>();
        inputKeys.add(from.getInputCol());
        modelInfo.setInputKeys(inputKeys);

        Set<String> outputKeys = new LinkedHashSet<String>();
        outputKeys.add(from.getOutputCol());
        modelInfo.setOutputKeys(outputKeys);

        return modelInfo;
    }

    @Override
    public Class getSource() {
        return Log1PScaler.class;
    }

    @Override
    public Class getTarget() {
        return Log1PScalerModelInfo.class;
    }
}
--------------------------------------------------------------------------------
/adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/BucketizerModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.BucketizerModelInfo; 4 | import org.apache.spark.ml.feature.Bucketizer; 5 | 6 | import java.util.LinkedHashSet; 7 | import java.util.Set; 8 | 9 | /** 10 | * Transforms Spark's {@link Bucketizer} in MlLib to {@link BucketizerModelInfo} object 11 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 12 | */ 13 | public class BucketizerModelInfoAdapter extends AbstractModelInfoAdapter { 14 | 15 | @Override 16 | public BucketizerModelInfo getModelInfo(final Bucketizer from) { 17 | final BucketizerModelInfo modelInfo = new BucketizerModelInfo(); 18 | modelInfo.setSplits(from.getSplits()); 19 | 20 | Set inputKeys = new LinkedHashSet(); 21 | inputKeys.add(from.getInputCol()); 22 | modelInfo.setInputKeys(inputKeys); 23 | 24 | Set outputKeys = new LinkedHashSet(); 25 | outputKeys.add(from.getOutputCol()); 26 | modelInfo.setOutputKeys(outputKeys); 27 | 28 | return modelInfo; 29 | } 30 | 31 | @Override 32 | public Class getSource() { 33 | return Bucketizer.class; 34 | } 35 | 36 | @Override 37 | public Class getTarget() { 38 | return BucketizerModelInfo.class; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/AbstractModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.ModelInfo; 4 | import com.flipkart.fdp.ml.utils.Constants; 5 | import org.apache.commons.lang3.StringUtils; 6 | import org.apache.spark.sql.DataFrame; 7 | 8 | import java.io.Serializable; 9 | 10 | 11 | public abstract class AbstractModelInfoAdapter implements ModelInfoAdapter, Serializable { 12 
| 13 | private void preConditions(DataFrame df) { 14 | if (null != df) { 15 | if (!StringUtils.startsWith(df.sqlContext().sparkContext().version(), Constants.SUPPORTED_SPARK_VERSION_PREFIX)) { 16 | throw new UnsupportedOperationException("Only spark version " + Constants.SUPPORTED_SPARK_VERSION_PREFIX + " is supported by this version of the library"); 17 | } 18 | } 19 | } 20 | 21 | @Override 22 | public T adapt(F from, DataFrame df) { 23 | preConditions(df); 24 | return getModelInfo(from, df); 25 | } 26 | 27 | /** 28 | * @param from source object in spark's mllib 29 | * @param df Data frame that is used for training is required for some models as state information is being stored as column metadata by spark models 30 | * @return returns the corresponding {@link ModelInfo} object that represents the model information 31 | */ 32 | abstract T getModelInfo(F from, DataFrame df); 33 | 34 | } 35 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/AlgebraicTransformModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.AlgebraicTransform; 4 | import com.flipkart.fdp.ml.modelinfo.AlgebraicTransformModelInfo; 5 | import org.apache.spark.sql.DataFrame; 6 | 7 | import java.util.LinkedHashSet; 8 | import java.util.Set; 9 | 10 | /** 11 | * Created by shubhranshu.shekhar on 18/08/16. 
12 | */ 13 | public class AlgebraicTransformModelInfoAdapter extends AbstractModelInfoAdapter { 14 | @Override 15 | public AlgebraicTransformModelInfo getModelInfo(final AlgebraicTransform from, DataFrame df) { 16 | AlgebraicTransformModelInfo modelInfo = new AlgebraicTransformModelInfo(); 17 | modelInfo.setCoefficients(from.getCoefficients()); 18 | 19 | Set inputKeys = new LinkedHashSet(); 20 | inputKeys.add(from.getInputCol()); 21 | modelInfo.setInputKeys(inputKeys); 22 | 23 | Set outputKeys = new LinkedHashSet(); 24 | outputKeys.add(from.getOutputCol()); 25 | modelInfo.setOutputKeys(outputKeys); 26 | return modelInfo; 27 | } 28 | 29 | @Override 30 | public Class getSource() { 31 | return AlgebraicTransform.class; 32 | } 33 | 34 | @Override 35 | public Class getTarget() { 36 | return AlgebraicTransformModelInfo.class; 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/CustomOneHotEncoderModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.CustomOneHotEncoderModel; 4 | import com.flipkart.fdp.ml.modelinfo.OneHotEncoderModelInfo; 5 | import org.apache.spark.sql.DataFrame; 6 | 7 | import java.util.LinkedHashSet; 8 | import java.util.Set; 9 | 10 | /** 11 | * Created by shubhranshu.shekhar on 21/06/16. 
12 | */ 13 | public class CustomOneHotEncoderModelInfoAdapter extends AbstractModelInfoAdapter { 14 | 15 | @Override 16 | public OneHotEncoderModelInfo getModelInfo(final CustomOneHotEncoderModel from, DataFrame df) { 17 | OneHotEncoderModelInfo modelInfo = new OneHotEncoderModelInfo(); 18 | 19 | modelInfo.setNumTypes(from.vectorSize()); 20 | 21 | Set inputKeys = new LinkedHashSet(); 22 | inputKeys.add(from.getInputCol()); 23 | modelInfo.setInputKeys(inputKeys); 24 | 25 | Set outputKeys = new LinkedHashSet(); 26 | outputKeys.add(from.getOutputCol()); 27 | modelInfo.setOutputKeys(outputKeys); 28 | 29 | return modelInfo; 30 | } 31 | 32 | @Override 33 | public Class getSource() { 34 | return CustomOneHotEncoderModel.class; 35 | } 36 | 37 | @Override 38 | public Class getTarget() { 39 | return OneHotEncoderModelInfo.class; 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /adapters-1.6/src/test/java/com/flipkart/fdp/ml/adapter/SparkTestBase.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import org.apache.spark.SparkConf; 4 | import org.apache.spark.SparkContext; 5 | import org.apache.spark.api.java.JavaSparkContext; 6 | import org.apache.spark.sql.Row; 7 | import org.apache.spark.sql.RowFactory; 8 | import org.apache.spark.sql.SQLContext; 9 | import org.junit.After; 10 | import org.junit.Before; 11 | import org.slf4j.Logger; 12 | import org.slf4j.LoggerFactory; 13 | 14 | /** 15 | * Base class for test that need to create and use a spark context. 
16 | */ 17 | public class SparkTestBase { 18 | private static final Logger LOG = LoggerFactory.getLogger(SparkTestBase.class); 19 | protected JavaSparkContext sc; 20 | protected SQLContext sqlContext; 21 | public static final double EPSILON = 1.0e-6; 22 | 23 | @Before 24 | public void setup() { 25 | SparkConf sparkConf = new SparkConf(); 26 | String master = "local[2]"; 27 | sparkConf.setMaster(master); 28 | sparkConf.setAppName("Local Spark Unit Test"); 29 | sc = new JavaSparkContext(new SparkContext(sparkConf)); 30 | sqlContext = new SQLContext(sc); 31 | } 32 | 33 | @After 34 | public void tearDown() { 35 | sc.close(); 36 | sqlContext = null; 37 | } 38 | 39 | /** 40 | * An alias for RowFactory.create. 41 | */ 42 | public Row cr(Object... values) { 43 | return RowFactory.create(values); 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /custom-transformer/src/main/scala/com/flipkart/fdp/ml/FillNAValuesTransformer.scala: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml 2 | 3 | import org.apache.spark.ml.Transformer 4 | import org.apache.spark.ml.param.{Param, ParamMap} 5 | import org.apache.spark.ml.util.Identifiable 6 | import org.apache.spark.sql.DataFrame 7 | import org.apache.spark.sql.types.StructType 8 | 9 | 10 | class FillNAValuesTransformer(override val uid: String) extends Transformer { 11 | 12 | val naValueMap: Param[java.util.Map[String, Any]] = new Param[java.util.Map[String, Any]](this, "naValueMap", "column name to default value map in case value is NA"); 13 | setDefault(naValueMap -> new java.util.HashMap[String, Any]()) 14 | 15 | def this() { 16 | this(Identifiable.randomUID("FillNAValuesTransformer")) 17 | } 18 | 19 | def getNAValueMap: java.util.Map[String, Any] = $(naValueMap) 20 | 21 | def setNAValueMap(columnToNAValueMap: java.util.Map[String, Any]) = { 22 | if(! 
columnToNAValueMap.isEmpty) { 23 | $(naValueMap).putAll(columnToNAValueMap); 24 | } 25 | } 26 | 27 | override def transform(dataFrame: DataFrame): DataFrame = { 28 | dataFrame.na.fill($(naValueMap)); 29 | } 30 | 31 | override def copy(extra: ParamMap): FillNAValuesTransformer = { 32 | val copied = new FillNAValuesTransformer(uid) 33 | copyValues(copied, extra) 34 | } 35 | 36 | override def transformSchema(schema: StructType): StructType = { 37 | //This Transformer does not change the schema of df 38 | return schema; 39 | } 40 | 41 | } 42 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/PipelineModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.ModelInfoAdapterFactory; 4 | import com.flipkart.fdp.ml.modelinfo.ModelInfo; 5 | import com.flipkart.fdp.ml.modelinfo.PipelineModelInfo; 6 | import lombok.extern.slf4j.Slf4j; 7 | import org.apache.spark.ml.PipelineModel; 8 | import org.apache.spark.ml.Transformer; 9 | 10 | /** 11 | * Transforms Spark's {@link PipelineModel} to {@link PipelineModelInfo} object 12 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 13 | */ 14 | @Slf4j 15 | public class PipelineModelInfoAdapter extends AbstractModelInfoAdapter { 16 | @Override 17 | public PipelineModelInfo getModelInfo(final PipelineModel from) { 18 | final PipelineModelInfo modelInfo = new PipelineModelInfo(); 19 | final ModelInfo stages[] = new ModelInfo[from.stages().length]; 20 | for (int i = 0; i < from.stages().length; i++) { 21 | Transformer sparkModel = from.stages()[i]; 22 | stages[i] = ModelInfoAdapterFactory.getAdapter(sparkModel.getClass()).adapt(sparkModel); 23 | } 24 | modelInfo.setStages(stages); 25 | return modelInfo; 26 | } 27 | 28 | @Override 29 | public Class getSource() { 30 | return PipelineModel.class; 31 | } 32 
| 33 | @Override 34 | public Class getTarget() { 35 | return PipelineModelInfo.class; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/HashingTFModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.HashingTFModelInfo; 4 | import org.apache.spark.ml.feature.HashingTF; 5 | import org.apache.spark.sql.DataFrame; 6 | 7 | import java.util.LinkedHashSet; 8 | import java.util.Set; 9 | 10 | /** 11 | * Transforms Spark's {@link HashingTF} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.HashingTFModelInfo} object 12 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 13 | */ 14 | public class HashingTFModelInfoAdapter extends AbstractModelInfoAdapter { 15 | @Override 16 | public HashingTFModelInfo getModelInfo(final HashingTF from, DataFrame df) { 17 | final HashingTFModelInfo modelInfo = new HashingTFModelInfo(); 18 | modelInfo.setNumFeatures(from.getNumFeatures()); 19 | 20 | Set inputKeys = new LinkedHashSet(); 21 | inputKeys.add(from.getInputCol()); 22 | modelInfo.setInputKeys(inputKeys); 23 | 24 | Set outputKeys = new LinkedHashSet(); 25 | outputKeys.add(from.getOutputCol()); 26 | modelInfo.setOutputKeys(outputKeys); 27 | 28 | return modelInfo; 29 | } 30 | 31 | @Override 32 | public Class getSource() { 33 | return HashingTF.class; 34 | } 35 | 36 | @Override 37 | public Class getTarget() { 38 | return HashingTFModelInfo.class; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/VectorAssemblerModelAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.VectorAssemblerModelInfo; 
4 | import org.apache.spark.ml.feature.VectorAssembler; 5 | 6 | import java.util.Arrays; 7 | import java.util.LinkedHashSet; 8 | import java.util.Set; 9 | 10 | /** 11 | * Transforms Spark's {@link VectorAssembler} in MlLib to {@link VectorAssemblerModelInfo} object 12 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 13 | *

14 | * Created by rohan.shetty on 28/03/16. 15 | */ 16 | public class VectorAssemblerModelAdapter extends AbstractModelInfoAdapter { 17 | 18 | @Override 19 | VectorAssemblerModelInfo getModelInfo(VectorAssembler from) { 20 | VectorAssemblerModelInfo vectorAssemblerModelInfo = new VectorAssemblerModelInfo(); 21 | 22 | vectorAssemblerModelInfo.setInputKeys(new LinkedHashSet<>(Arrays.asList(from.getInputCols()))); 23 | 24 | Set outputKeys = new LinkedHashSet(); 25 | outputKeys.add(from.getOutputCol()); 26 | vectorAssemblerModelInfo.setOutputKeys(outputKeys); 27 | 28 | return vectorAssemblerModelInfo; 29 | } 30 | 31 | @Override 32 | public Class getSource() { 33 | return VectorAssembler.class; 34 | } 35 | 36 | @Override 37 | public Class getTarget() { 38 | return VectorAssemblerModelInfo.class; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/BucketizerModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.BucketizerModelInfo; 4 | import org.apache.spark.ml.feature.Bucketizer; 5 | import org.apache.spark.sql.DataFrame; 6 | 7 | import java.util.LinkedHashSet; 8 | import java.util.Set; 9 | 10 | /** 11 | * Transforms Spark's {@link Bucketizer} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.BucketizerModelInfo} object 12 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 13 | */ 14 | public class BucketizerModelInfoAdapter extends AbstractModelInfoAdapter { 15 | 16 | @Override 17 | public BucketizerModelInfo getModelInfo(final Bucketizer from, final DataFrame df) { 18 | final BucketizerModelInfo modelInfo = new BucketizerModelInfo(); 19 | modelInfo.setSplits(from.getSplits()); 20 | 21 | Set inputKeys = new LinkedHashSet(); 22 | inputKeys.add(from.getInputCol()); 23 | modelInfo.setInputKeys(inputKeys); 
24 | 25 | Set outputKeys = new LinkedHashSet(); 26 | outputKeys.add(from.getOutputCol()); 27 | modelInfo.setOutputKeys(outputKeys); 28 | return modelInfo; 29 | } 30 | 31 | @Override 32 | public Class getSource() { 33 | return Bucketizer.class; 34 | } 35 | 36 | @Override 37 | public Class getTarget() { 38 | return BucketizerModelInfo.class; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/ChiSqSelectorModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.ChiSqSelectorModelInfo; 4 | import org.apache.spark.ml.feature.ChiSqSelectorModel; 5 | 6 | import java.util.LinkedHashSet; 7 | import java.util.Set; 8 | 9 | /** 10 | * Transforms Spark's {@link ChiSqSelectorModel} in MlLib to {@link ChiSqSelectorModelInfo} object 11 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 12 | */ 13 | public class ChiSqSelectorModelInfoAdapter extends AbstractModelInfoAdapter { 14 | 15 | @Override 16 | public ChiSqSelectorModelInfo getModelInfo(final ChiSqSelectorModel from) { 17 | ChiSqSelectorModelInfo modelInfo = new ChiSqSelectorModelInfo(); 18 | modelInfo.setSelectedFeatures(from.selectedFeatures()); 19 | 20 | Set inputKeys = new LinkedHashSet(); 21 | inputKeys.add(from.getFeaturesCol()); 22 | modelInfo.setInputKeys(inputKeys); 23 | 24 | Set outputKeys = new LinkedHashSet(); 25 | outputKeys.add(from.getOutputCol()); 26 | modelInfo.setOutputKeys(outputKeys); 27 | 28 | return modelInfo; 29 | } 30 | 31 | @Override 32 | public Class getSource() { 33 | return ChiSqSelectorModel.class; 34 | } 35 | 36 | @Override 37 | public Class getTarget() { 38 | return ChiSqSelectorModelInfo.class; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- 
/models-info/src/main/java/com/flipkart/fdp/ml/transformer/ChiSqSelectorTransformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.ChiSqSelectorModelInfo; 4 | 5 | import java.util.Map; 6 | import java.util.Set; 7 | 8 | /** 9 | * Transforms input/ predicts for a ChiSqSelectorModel model representation 10 | * captured by {@link com.flipkart.fdp.ml.modelinfo.ChiSqSelectorModelInfo}. 11 | */ 12 | public class ChiSqSelectorTransformer implements Transformer { 13 | 14 | private final ChiSqSelectorModelInfo modelInfo; 15 | 16 | public ChiSqSelectorTransformer(final ChiSqSelectorModelInfo modelInfo) { 17 | this.modelInfo = modelInfo; 18 | } 19 | 20 | public double[] predict(double[] inp) { 21 | double[] output = new double[modelInfo.getSelectedFeatures().length]; 22 | int count = 0; 23 | for (int featureIndex : modelInfo.getSelectedFeatures()) { 24 | output[count++] = inp[featureIndex]; 25 | } 26 | return output; 27 | } 28 | 29 | @Override 30 | public void transform(Map input) { 31 | double[] inp = (double[]) input.get(modelInfo.getInputKeys().iterator().next()); 32 | input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp)); 33 | } 34 | 35 | @Override 36 | public Set getInputKeys() { 37 | return modelInfo.getInputKeys(); 38 | } 39 | 40 | @Override 41 | public Set getOutputKeys() { 42 | return modelInfo.getOutputKeys(); 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/ProbabilityTransformTransformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.ProbabilityTransformModelInfo; 4 | 5 | import java.util.Map; 6 | import java.util.Set; 7 | 8 | /** 9 | * Created by shubhranshu.shekhar on 18/08/16. 
10 | */ 11 | public class ProbabilityTransformTransformer implements Transformer { 12 | private final ProbabilityTransformModelInfo modelInfo; 13 | 14 | public ProbabilityTransformTransformer(final ProbabilityTransformModelInfo modelInfo) { 15 | this.modelInfo = modelInfo; 16 | } 17 | 18 | public double predict(final double input) { 19 | double p1 = modelInfo.getActualClickProportion(); 20 | double r1 = modelInfo.getUnderSampledClickProportion(); 21 | double probIndex = modelInfo.getProbIndex();//not used because in the map LR only fills prob wrt positive class 22 | 23 | double encoding = (input *p1/r1) / ((input *p1/r1) + ((1-input) *(1-p1)/(1-r1))); 24 | return encoding; 25 | } 26 | 27 | @Override 28 | public void transform(Map input) { 29 | double inp = (double) input.get(modelInfo.getInputKeys().iterator().next()); 30 | input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp)); 31 | } 32 | 33 | @Override 34 | public Set getInputKeys() { 35 | return modelInfo.getInputKeys(); 36 | } 37 | 38 | @Override 39 | public Set getOutputKeys() { 40 | return modelInfo.getOutputKeys(); 41 | } 42 | 43 | } 44 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/OneHotEncoderTransformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.OneHotEncoderModelInfo; 4 | 5 | import java.util.Arrays; 6 | import java.util.Map; 7 | import java.util.Set; 8 | 9 | /** 10 | * Transforms input/ predicts for a OneHotEncoder model representation 11 | * captured by {@link com.flipkart.fdp.ml.modelinfo.OneHotEncoderModelInfo}. 
12 | */ 13 | public class OneHotEncoderTransformer implements Transformer { 14 | 15 | private final OneHotEncoderModelInfo modelInfo; 16 | 17 | public OneHotEncoderTransformer(final OneHotEncoderModelInfo modelInfo) { 18 | this.modelInfo = modelInfo; 19 | } 20 | 21 | public double[] predict(final double input) { 22 | int size = modelInfo.getNumTypes(); 23 | 24 | final double encoding[] = new double[size]; 25 | Arrays.fill(encoding, 0.0); 26 | 27 | if ((int) input < size) { 28 | encoding[((int) input)] = 1.0; 29 | } 30 | return encoding; 31 | } 32 | 33 | @Override 34 | public void transform(Map input) { 35 | double inp = (double) input.get(modelInfo.getInputKeys().iterator().next()); 36 | input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp)); 37 | } 38 | 39 | @Override 40 | public Set getInputKeys() { 41 | return modelInfo.getInputKeys(); 42 | } 43 | 44 | @Override 45 | public Set getOutputKeys() { 46 | return modelInfo.getOutputKeys(); 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/IfZeroVectorModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.IfZeroVector; 4 | import com.flipkart.fdp.ml.modelinfo.IfZeroVectorModelInfo; 5 | import org.apache.spark.sql.DataFrame; 6 | 7 | import java.util.LinkedHashSet; 8 | import java.util.Set; 9 | 10 | /** 11 | * Transforms {@link IfZeroVector} to {@link IfZeroVectorModelInfo} object 12 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 13 | */ 14 | public class IfZeroVectorModelInfoAdapter extends AbstractModelInfoAdapter { 15 | 16 | @Override 17 | public IfZeroVectorModelInfo getModelInfo(final IfZeroVector from, DataFrame df) { 18 | IfZeroVectorModelInfo modelInfo = new IfZeroVectorModelInfo(); 19 | 20 | Set inputKeys = new LinkedHashSet(); 21 | 
inputKeys.add(from.getInputCol()); 22 | modelInfo.setInputKeys(inputKeys); 23 | 24 | Set outputKeys = new LinkedHashSet(); 25 | outputKeys.add(from.getOutputCol()); 26 | modelInfo.setOutputKeys(outputKeys); 27 | 28 | modelInfo.setThenSetValue(from.getThenSetValue()); 29 | modelInfo.setElseSetCol(from.getElseSetCol()); 30 | 31 | return modelInfo; 32 | } 33 | 34 | @Override 35 | public Class getSource() { 36 | return IfZeroVector.class; 37 | } 38 | 39 | @Override 40 | public Class getTarget() { 41 | return IfZeroVectorModelInfo.class; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/PipelineModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.ModelInfoAdapterFactory; 4 | import com.flipkart.fdp.ml.modelinfo.ModelInfo; 5 | import com.flipkart.fdp.ml.modelinfo.PipelineModelInfo; 6 | import lombok.extern.slf4j.Slf4j; 7 | import org.apache.spark.ml.PipelineModel; 8 | import org.apache.spark.ml.Transformer; 9 | import org.apache.spark.sql.DataFrame; 10 | 11 | /** 12 | * Transforms Spark's {@link PipelineModel} to {@link PipelineModelInfo} object 13 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 14 | */ 15 | @Slf4j 16 | public class PipelineModelInfoAdapter extends AbstractModelInfoAdapter { 17 | @Override 18 | public PipelineModelInfo getModelInfo(final PipelineModel from, final DataFrame df) { 19 | final PipelineModelInfo modelInfo = new PipelineModelInfo(); 20 | final ModelInfo stages[] = new ModelInfo[from.stages().length]; 21 | for (int i = 0; i < from.stages().length; i++) { 22 | Transformer sparkModel = from.stages()[i]; 23 | stages[i] = ModelInfoAdapterFactory.getAdapter(sparkModel.getClass()).adapt(sparkModel, df); 24 | } 25 | modelInfo.setStages(stages); 26 | return modelInfo; 27 | } 28 | 29 | 
@Override 30 | public Class getSource() { 31 | return PipelineModel.class; 32 | } 33 | 34 | @Override 35 | public Class getTarget() { 36 | return PipelineModelInfo.class; 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/CountVectorizerModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.CountVectorizerModelInfo; 4 | import org.apache.spark.ml.feature.CountVectorizerModel; 5 | 6 | import java.util.LinkedHashSet; 7 | import java.util.Set; 8 | 9 | /** 10 | * Transforms Spark's {@link CountVectorizerModel} in MlLib to {@link CountVectorizerModelInfo} object 11 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 12 | */ 13 | public class CountVectorizerModelInfoAdapter extends AbstractModelInfoAdapter { 14 | @Override 15 | public CountVectorizerModelInfo getModelInfo(final CountVectorizerModel from) { 16 | final CountVectorizerModelInfo modelInfo = new CountVectorizerModelInfo(); 17 | modelInfo.setMinTF(from.getMinTF()); 18 | modelInfo.setVocabulary(from.vocabulary()); 19 | 20 | Set inputKeys = new LinkedHashSet(); 21 | inputKeys.add(from.getInputCol()); 22 | modelInfo.setInputKeys(inputKeys); 23 | 24 | Set outputKeys = new LinkedHashSet(); 25 | outputKeys.add(from.getOutputCol()); 26 | modelInfo.setOutputKeys(outputKeys); 27 | 28 | return modelInfo; 29 | } 30 | 31 | @Override 32 | public Class getSource() { 33 | return CountVectorizerModel.class; 34 | } 35 | 36 | @Override 37 | public Class getTarget() { 38 | return CountVectorizerModelInfo.class; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/VectorAssemblerModelAdapter.java: 
-------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.VectorAssemblerModelInfo; 4 | import org.apache.spark.ml.feature.VectorAssembler; 5 | import org.apache.spark.sql.DataFrame; 6 | 7 | import java.util.Arrays; 8 | import java.util.LinkedHashSet; 9 | import java.util.Set; 10 | 11 | /** 12 | * Transforms Spark's {@link VectorAssembler} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.VectorAssemblerModelInfo} object 13 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 14 | *

15 | * Created by rohan.shetty on 28/03/16. 16 | */ 17 | public class VectorAssemblerModelAdapter extends AbstractModelInfoAdapter { 18 | 19 | @Override 20 | VectorAssemblerModelInfo getModelInfo(VectorAssembler from, DataFrame df) { 21 | VectorAssemblerModelInfo vectorAssemblerModelInfo = new VectorAssemblerModelInfo(); 22 | 23 | vectorAssemblerModelInfo.setInputKeys(new LinkedHashSet<>(Arrays.asList(from.getInputCols()))); 24 | 25 | Set outputKeys = new LinkedHashSet(); 26 | outputKeys.add(from.getOutputCol()); 27 | vectorAssemblerModelInfo.setOutputKeys(outputKeys); 28 | 29 | return vectorAssemblerModelInfo; 30 | } 31 | 32 | @Override 33 | public Class getSource() { 34 | return VectorAssembler.class; 35 | } 36 | 37 | @Override 38 | public Class getTarget() { 39 | return VectorAssemblerModelInfo.class; 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/AlgebraicTransformTransformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.AlgebraicTransformModelInfo; 4 | 5 | import java.util.Map; 6 | import java.util.Set; 7 | 8 | /** 9 | * Created by shubhranshu.shekhar on 18/08/16. 
10 | */ 11 | public class AlgebraicTransformTransformer implements Transformer { 12 | private final AlgebraicTransformModelInfo modelInfo; 13 | 14 | public AlgebraicTransformTransformer(final AlgebraicTransformModelInfo modelInfo) { 15 | this.modelInfo = modelInfo; 16 | } 17 | 18 | public double predict(final double input) { 19 | double[] coeff = modelInfo.getCoefficients(); 20 | if(coeff.length == 0){ 21 | return 0.0; 22 | } 23 | else{ 24 | double sum = coeff[0]; 25 | double mul = input; 26 | for(int i = 1; i < coeff.length; i++){ 27 | sum = sum + (coeff[i] * mul); 28 | mul = mul * input; 29 | } 30 | return sum; 31 | } 32 | } 33 | 34 | @Override 35 | public void transform(Map input) { 36 | double inp = (double) input.get(modelInfo.getInputKeys().iterator().next()); 37 | input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp)); 38 | } 39 | 40 | @Override 41 | public Set getInputKeys() { 42 | return modelInfo.getInputKeys(); 43 | } 44 | 45 | @Override 46 | public Set getOutputKeys() { 47 | return modelInfo.getOutputKeys(); 48 | } 49 | 50 | } 51 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/ChiSqSelectorModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.ChiSqSelectorModelInfo; 4 | import org.apache.spark.ml.feature.ChiSqSelectorModel; 5 | import org.apache.spark.sql.DataFrame; 6 | 7 | import java.util.LinkedHashSet; 8 | import java.util.Set; 9 | 10 | /** 11 | * Transforms Spark's {@link ChiSqSelectorModel} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.ChiSqSelectorModelInfo} object 12 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 13 | */ 14 | public class ChiSqSelectorModelInfoAdapter extends AbstractModelInfoAdapter { 15 | 16 | @Override 17 | public ChiSqSelectorModelInfo 
getModelInfo(final ChiSqSelectorModel from, DataFrame df) { 18 | ChiSqSelectorModelInfo modelInfo = new ChiSqSelectorModelInfo(); 19 | modelInfo.setSelectedFeatures(from.selectedFeatures()); 20 | 21 | Set inputKeys = new LinkedHashSet(); 22 | inputKeys.add(from.getFeaturesCol()); 23 | modelInfo.setInputKeys(inputKeys); 24 | 25 | Set outputKeys = new LinkedHashSet(); 26 | outputKeys.add(from.getOutputCol()); 27 | modelInfo.setOutputKeys(outputKeys); 28 | 29 | return modelInfo; 30 | } 31 | 32 | @Override 33 | public Class getSource() { 34 | return ChiSqSelectorModel.class; 35 | } 36 | 37 | @Override 38 | public Class getTarget() { 39 | return ChiSqSelectorModelInfo.class; 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/PipelineModelTransformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.PipelineModelInfo; 4 | import com.flipkart.fdp.ml.utils.PipelineUtils; 5 | 6 | import java.util.Map; 7 | import java.util.Set; 8 | 9 | /** 10 | * Transforms input/ predicts for a Pipeline model representation 11 | * captured by {@link com.flipkart.fdp.ml.modelinfo.PipelineModelInfo}. 
12 | */ 13 | public class PipelineModelTransformer implements Transformer { 14 | 15 | private final Transformer transformers[]; 16 | private Set extractedInputColumns; 17 | private Set extractedOutputColumns; 18 | 19 | public PipelineModelTransformer(final PipelineModelInfo modelInfo) { 20 | transformers = new Transformer[modelInfo.getStages().length]; 21 | for (int i = 0; i < transformers.length; i++) { 22 | transformers[i] = modelInfo.getStages()[i].getTransformer(); 23 | } 24 | extractedInputColumns = PipelineUtils.extractRequiredInputColumns(transformers); 25 | extractedOutputColumns = PipelineUtils.extractRequiredOutputColumns(transformers); 26 | } 27 | 28 | @Override 29 | public void transform(final Map input) { 30 | for (Transformer transformer : transformers) { 31 | transformer.transform(input); 32 | } 33 | } 34 | 35 | @Override 36 | public Set getInputKeys() { 37 | return extractedInputColumns; 38 | } 39 | 40 | @Override 41 | public Set getOutputKeys() { 42 | return extractedOutputColumns; 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/RegexTokenizerModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.RegexTokenizerModelInfo; 4 | import org.apache.spark.ml.feature.RegexTokenizer; 5 | 6 | import java.util.LinkedHashSet; 7 | import java.util.Set; 8 | 9 | /** 10 | * Transforms Spark's {@link RegexTokenizer} in MlLib to {@link RegexTokenizerModelInfo} object 11 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 12 | */ 13 | public class RegexTokenizerModelInfoAdapter extends AbstractModelInfoAdapter { 14 | 15 | @Override 16 | public RegexTokenizerModelInfo getModelInfo(final RegexTokenizer from) { 17 | final RegexTokenizerModelInfo modelInfo = new RegexTokenizerModelInfo(); 18 | 
modelInfo.setMinTokenLength(from.getMinTokenLength()); 19 | modelInfo.setGaps(from.getGaps()); 20 | modelInfo.setPattern(from.getPattern()); 21 | modelInfo.setToLowercase(from.getToLowercase()); 22 | 23 | Set inputKeys = new LinkedHashSet(); 24 | inputKeys.add(from.getInputCol()); 25 | modelInfo.setInputKeys(inputKeys); 26 | 27 | Set outputKeys = new LinkedHashSet(); 28 | outputKeys.add(from.getOutputCol()); 29 | modelInfo.setOutputKeys(outputKeys); 30 | 31 | return modelInfo; 32 | } 33 | 34 | @Override 35 | public Class getSource() { 36 | return RegexTokenizer.class; 37 | } 38 | 39 | @Override 40 | public Class getTarget() { 41 | return RegexTokenizerModelInfo.class; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/FillNAValuesTransformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.FillNAValuesTransformerModelInfo; 4 | 5 | import java.util.Map; 6 | import java.util.Set; 7 | 8 | /** 9 | * Transforms input/ predicts for a {@link FillNAValuesTransformerModelInfo} model representation 10 | * captured by {@link FillNAValuesTransformerModelInfo}. 
11 | */ 12 | public class FillNAValuesTransformer implements Transformer { 13 | private final FillNAValuesTransformerModelInfo modelInfo; 14 | 15 | public FillNAValuesTransformer(FillNAValuesTransformerModelInfo modelInfo) { 16 | this.modelInfo = modelInfo; 17 | } 18 | 19 | @Override 20 | public void transform(Map input) { 21 | for(Map.Entry entry : modelInfo.getNaValuesMap().entrySet()) { 22 | if( isNA(input.get(entry.getKey()))) { 23 | input.put(entry.getKey(), entry.getValue()); 24 | } 25 | } 26 | } 27 | 28 | private boolean isNA(Object data) { 29 | if( null == data) { 30 | return true; 31 | } 32 | if( data instanceof Double) { 33 | return ((Double)data).isNaN(); 34 | } 35 | if( data instanceof Float) { 36 | return ((Float)data).isNaN(); 37 | } 38 | return false; 39 | } 40 | 41 | @Override 42 | public Set getInputKeys() { 43 | return modelInfo.getInputKeys(); 44 | } 45 | 46 | @Override 47 | public Set getOutputKeys() { 48 | return modelInfo.getOutputKeys(); 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/MinMaxScalerModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.MinMaxScalerModelInfo; 4 | import org.apache.spark.ml.feature.MinMaxScalerModel; 5 | 6 | import java.util.LinkedHashSet; 7 | import java.util.Set; 8 | 9 | /** 10 | * Transforms Spark's {@link MinMaxScalerModel} in MlLib to {@link MinMaxScalerModelInfo} object 11 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 12 | */ 13 | public class MinMaxScalerModelInfoAdapter extends AbstractModelInfoAdapter { 14 | @Override 15 | public MinMaxScalerModelInfo getModelInfo(final MinMaxScalerModel from) { 16 | final MinMaxScalerModelInfo modelInfo = new MinMaxScalerModelInfo(); 17 | 
modelInfo.setOriginalMax(from.originalMax().toArray()); 18 | modelInfo.setOriginalMin(from.originalMin().toArray()); 19 | modelInfo.setMax(from.getMax()); 20 | modelInfo.setMin(from.getMin()); 21 | 22 | Set inputKeys = new LinkedHashSet(); 23 | inputKeys.add(from.getInputCol()); 24 | modelInfo.setInputKeys(inputKeys); 25 | 26 | Set outputKeys = new LinkedHashSet(); 27 | outputKeys.add(from.getOutputCol()); 28 | modelInfo.setOutputKeys(outputKeys); 29 | 30 | return modelInfo; 31 | } 32 | 33 | @Override 34 | public Class getSource() { 35 | return MinMaxScalerModel.class; 36 | } 37 | 38 | @Override 39 | public Class getTarget() { 40 | return MinMaxScalerModelInfo.class; 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /custom-transformer-2.0/src/main/scala/com/flipkart/transformer/ml/StringSanitizer.scala: -------------------------------------------------------------------------------- 1 | package com.flipkart.transformer.ml 2 | 3 | import org.apache.spark.annotation.DeveloperApi 4 | import org.apache.spark.ml.UnaryTransformer 5 | import org.apache.spark.ml.param.ParamMap 6 | import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} 7 | import org.apache.spark.sql.types._ 8 | 9 | /** 10 | * StringSanitizer: Removes non printable character, spaces, sequence of digits etc 11 | */ 12 | class StringSanitizer(override val uid: String) extends UnaryTransformer[String, Seq[String], StringSanitizer] with DefaultParamsWritable { 13 | def this() = this(Identifiable.randomUID("StringSanitizer")) 14 | 15 | override protected def validateInputType(inputType: DataType): Unit = { 16 | require(inputType == StringType, s"Input type must be string type but got $inputType.") 17 | } 18 | 19 | override protected def outputDataType: DataType = new ArrayType(StringType, true) 20 | 21 | override def copy(extra: ParamMap): StringSanitizer = defaultCopy(extra) 22 | 23 | @DeveloperApi 24 | override def 
transformSchema(schema: StructType): StructType = { 25 | StructType(schema.fields :+ StructField(getOutputCol, StringType)) 26 | } 27 | 28 | override protected def createTransformFunc: (String) => Seq[String] = { originStr => 29 | originStr.toLowerCase 30 | .replaceAll("\\P{Print}", " ") 31 | .replaceAll("[^0-9a-zA-Z ]", " ") 32 | .replaceAll("\\d{10}", " ") 33 | .replaceAll("\\d{6}", " ") 34 | .trim 35 | .replaceAll("""\s+""", " ") 36 | .split(" ") 37 | } 38 | } 39 | 40 | object StringSanitizer extends DefaultParamsReadable[StringSanitizer] -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/CountVectorizerModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.CountVectorizerModelInfo; 4 | import org.apache.spark.ml.feature.CountVectorizerModel; 5 | import org.apache.spark.sql.DataFrame; 6 | 7 | import java.util.LinkedHashSet; 8 | import java.util.Set; 9 | 10 | /** 11 | * Transforms Spark's {@link CountVectorizerModel} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.CountVectorizerModelInfo} object 12 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 13 | */ 14 | public class CountVectorizerModelInfoAdapter extends AbstractModelInfoAdapter { 15 | @Override 16 | public CountVectorizerModelInfo getModelInfo(final CountVectorizerModel from, final DataFrame df) { 17 | final CountVectorizerModelInfo modelInfo = new CountVectorizerModelInfo(); 18 | modelInfo.setMinTF(from.getMinTF()); 19 | modelInfo.setVocabulary(from.vocabulary()); 20 | 21 | Set inputKeys = new LinkedHashSet(); 22 | inputKeys.add(from.getInputCol()); 23 | modelInfo.setInputKeys(inputKeys); 24 | 25 | Set outputKeys = new LinkedHashSet(); 26 | outputKeys.add(from.getOutputCol()); 27 | modelInfo.setOutputKeys(outputKeys); 28 | 29 | return modelInfo; 30 | } 
31 | 32 | @Override 33 | public Class getSource() { 34 | return CountVectorizerModel.class; 35 | } 36 | 37 | @Override 38 | public Class getTarget() { 39 | return CountVectorizerModelInfo.class; 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/StandardScalerModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.StandardScalerModelInfo; 4 | import org.apache.spark.ml.feature.StandardScalerModel; 5 | 6 | import java.util.LinkedHashSet; 7 | import java.util.Set; 8 | 9 | /** 10 | * Transforms Spark's {@link StandardScalerModel} in MlLib to {@link StandardScalerModelInfo} object 11 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 12 | */ 13 | public class StandardScalerModelInfoAdapter extends AbstractModelInfoAdapter { 14 | @Override 15 | public StandardScalerModelInfo getModelInfo(final StandardScalerModel from) { 16 | final StandardScalerModelInfo modelInfo = new StandardScalerModelInfo(); 17 | modelInfo.setMean(from.mean().toArray()); 18 | modelInfo.setStd(from.std().toArray()); 19 | modelInfo.setWithMean(from.getWithMean()); 20 | modelInfo.setWithStd(from.getWithStd()); 21 | 22 | Set inputKeys = new LinkedHashSet(); 23 | inputKeys.add(from.getInputCol()); 24 | modelInfo.setInputKeys(inputKeys); 25 | 26 | Set outputKeys = new LinkedHashSet(); 27 | outputKeys.add(from.getOutputCol()); 28 | modelInfo.setOutputKeys(outputKeys); 29 | 30 | return modelInfo; 31 | } 32 | 33 | @Override 34 | public Class getSource() { 35 | return StandardScalerModel.class; 36 | } 37 | 38 | @Override 39 | public Class getTarget() { 40 | return StandardScalerModelInfo.class; 41 | } 42 | } 43 | -------------------------------------------------------------------------------- 
/models-info/src/main/java/com/flipkart/fdp/ml/transformer/VectorBinarizerTranformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.VectorBinarizerModelInfo; 4 | 5 | import java.util.Map; 6 | import java.util.Set; 7 | 8 | /** 9 | * Created by karan.verma on 09/11/16. 10 | */ 11 | 12 | 13 | public class VectorBinarizerTranformer implements Transformer { 14 | private final VectorBinarizerModelInfo modelInfo; 15 | 16 | public VectorBinarizerTranformer(final VectorBinarizerModelInfo modelInfo) { 17 | this.modelInfo = modelInfo; 18 | } 19 | 20 | @Override 21 | public void transform(Map input) { 22 | Object value = input.get(modelInfo.getInputKeys().iterator().next()); 23 | double[] inp = (double[])value; 24 | input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp, modelInfo.getThreshold())); 25 | } 26 | 27 | private double[] predict(double[] inp, double threshold) { 28 | 29 | if(inp == null || inp.length == 0) { 30 | return null; 31 | } 32 | double[] output = new double[inp.length]; 33 | 34 | for(int i = 0; i < inp.length; i++) { 35 | double currentValue = inp[i]; 36 | if (currentValue > threshold) { 37 | output[i] = 1.0; 38 | } else { 39 | output[i] = 0.0; 40 | } 41 | } 42 | return output; 43 | } 44 | 45 | @Override 46 | public Set getInputKeys() { 47 | return modelInfo.getInputKeys(); 48 | } 49 | 50 | @Override 51 | public Set getOutputKeys() { 52 | return modelInfo.getOutputKeys(); 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/ProbabilityTransformModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.ProbabilityTransformModelInfo; 4 | import com.flipkart.fdp.ml.ProbabilityTransformModel; 
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.ProbabilityTransformModel;
import com.flipkart.fdp.ml.modelinfo.ProbabilityTransformModelInfo;
import org.apache.spark.sql.DataFrame;

import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Adapts a {@link ProbabilityTransformModel} into a
 * {@link ProbabilityTransformModelInfo} that can be exported through
 * {@link com.flipkart.fdp.ml.export.ModelExporter}.
 *
 * Created by shubhranshu.shekhar on 18/08/16.
 */
public class ProbabilityTransformModelInfoAdapter extends AbstractModelInfoAdapter {

    @Override
    public ProbabilityTransformModelInfo getModelInfo(final ProbabilityTransformModel from, DataFrame df) {
        final ProbabilityTransformModelInfo modelInfo = new ProbabilityTransformModelInfo();

        // Copy the calibration parameters of the probability transform.
        modelInfo.setActualClickProportion(from.getActualClickProportion());
        modelInfo.setUnderSampledClickProportion(from.getUnderSampledClickProportion());
        modelInfo.setProbIndex(from.getProbIndex());

        modelInfo.setInputKeys(singleColumnSet(from.getInputCol()));
        modelInfo.setOutputKeys(singleColumnSet(from.getOutputCol()));
        return modelInfo;
    }

    // Builds an insertion-ordered set holding exactly one column name.
    private static Set singleColumnSet(final String column) {
        final Set keys = new LinkedHashSet();
        keys.add(column);
        return keys;
    }

    @Override
    public Class getSource() {
        return ProbabilityTransformModel.class;
    }

    @Override
    public Class getTarget() {
        return ProbabilityTransformModelInfo.class;
    }
}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.FillNAValuesTransformer;
import com.flipkart.fdp.ml.modelinfo.FillNAValuesTransformerModelInfo;
import org.apache.spark.sql.DataFrame;

import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/**
 * Transforms {@link FillNAValuesTransformer} to {@link FillNAValuesTransformerModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
public class FillNAValuesTransformerModelInfoAdapter extends AbstractModelInfoAdapter {

    @Override
    public FillNAValuesTransformerModelInfo getModelInfo(final FillNAValuesTransformer from, DataFrame df) {

        final FillNAValuesTransformerModelInfo modelInfo = new FillNAValuesTransformerModelInfo();

        // Fetch the NA-value map once instead of invoking the getter three times.
        final Map naValueMap = from.getNAValueMap();
        modelInfo.setNaValuesMap(naValueMap);

        // NA filling happens in place, so the input and output key sets are
        // identical: every column with a configured NA replacement value.
        modelInfo.setInputKeys(new LinkedHashSet(naValueMap.keySet()));
        modelInfo.setOutputKeys(new LinkedHashSet(naValueMap.keySet()));

        return modelInfo;
    }

    @Override
    public Class getSource() {
        return FillNAValuesTransformer.class;
    }

    @Override
    public Class getTarget() {
        return FillNAValuesTransformerModelInfo.class;
    }
}

package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.VectorBinarizerModelInfo;
import org.apache.spark.ml.feature.VectorBinarizer;

import org.apache.spark.sql.DataFrame;

import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;


/**
 * Transforms {@link org.apache.spark.ml.feature.VectorBinarizer} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.VectorBinarizerModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 *
 * Created by karan.verma on 9/11/16.
 */
public class VectorBinarizerModelAdapter extends AbstractModelInfoAdapter {

    // Was package-private, inconsistent with every sibling adapter; made
    // public so the overriding method does not narrow usability.
    @Override
    public VectorBinarizerModelInfo getModelInfo(VectorBinarizer from, DataFrame df) {

        VectorBinarizerModelInfo vectorBinarizerModelInfo = new VectorBinarizerModelInfo();

        vectorBinarizerModelInfo.setInputKeys(new LinkedHashSet<>(Arrays.asList(from.getInputCol())));

        Set outputKeys = new LinkedHashSet();
        outputKeys.add(from.getOutputCol());
        vectorBinarizerModelInfo.setOutputKeys(outputKeys);

        // Threshold above which a vector component is mapped to 1.0.
        vectorBinarizerModelInfo.setThreshold(from.getThreshold());

        return vectorBinarizerModelInfo;
    }

    @Override
    public Class getSource() {
        return VectorBinarizer.class;
    }

    @Override
    public Class getTarget() {
        return VectorBinarizerModelInfo.class;
    }
}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.RegexTokenizerModelInfo;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.sql.DataFrame;

import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Transforms Spark's {@link RegexTokenizer} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.RegexTokenizerModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
public class RegexTokenizerModelInfoAdapter extends AbstractModelInfoAdapter {

    @Override
    public RegexTokenizerModelInfo getModelInfo(final RegexTokenizer from, final DataFrame df) {
        final RegexTokenizerModelInfo modelInfo = new RegexTokenizerModelInfo();

        // Tokenization parameters copied verbatim from the Spark transformer.
        modelInfo.setMinTokenLength(from.getMinTokenLength());
        modelInfo.setGaps(from.getGaps());
        modelInfo.setPattern(from.getPattern());
        modelInfo.setToLowercase(from.getToLowercase());

        modelInfo.setInputKeys(columnSet(from.getInputCol()));
        modelInfo.setOutputKeys(columnSet(from.getOutputCol()));

        return modelInfo;
    }

    // Wraps a single column name in an insertion-ordered set.
    private static Set columnSet(final String columnName) {
        final Set keys = new LinkedHashSet();
        keys.add(columnName);
        return keys;
    }

    @Override
    public Class getSource() {
        return RegexTokenizer.class;
    }

    @Override
    public Class getTarget() {
        return RegexTokenizerModelInfo.class;
    }
}
package com.flipkart.fdp.ml.transformer;

import com.flipkart.fdp.ml.modelinfo.HashingTFModelInfo;

import java.util.Map;
import java.util.Set;

/**
 * Transforms input/ predicts for a HashingTF model representation
 * captured by {@link com.flipkart.fdp.ml.modelinfo.HashingTFModelInfo}.
 */
public class HashingTFTransformer implements Transformer {

    private final HashingTFModelInfo modelInfo;

    public HashingTFTransformer(final HashingTFModelInfo modelInfo) {
        this.modelInfo = modelInfo;
    }

    /**
     * Builds the term-frequency encoding: each term is hashed into one of
     * {@code numFeatures} buckets (non-negative modulo, matching Spark's
     * hashing scheme) and that bucket's count is incremented.
     *
     * @param terms tokens to hash; null is treated as no terms
     * @return a dense count vector of length {@code numFeatures}
     */
    public double[] predict(final String[] terms) {
        // Java arrays are zero-initialized; the previous explicit
        // Arrays.fill(encoding, 0.0) was redundant and has been removed.
        final double[] encoding = new double[modelInfo.getNumFeatures()];
        if (terms == null) {
            // Defensive: a missing input yields an all-zero encoding
            // instead of a NullPointerException.
            return encoding;
        }
        for (final String term : terms) {
            int index = term.hashCode() % modelInfo.getNumFeatures();
            // care for negative values: Java's % keeps the dividend's sign
            if (index < 0) {
                index += modelInfo.getNumFeatures();
            }
            encoding[index] += 1.0;
        }
        return encoding;
    }

    @Override
    public void transform(Map input) {
        String[] inp = (String[]) input.get(modelInfo.getInputKeys().iterator().next());
        input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp));
    }

    @Override
    public Set getInputKeys() {
        return modelInfo.getInputKeys();
    }

    @Override
    public Set getOutputKeys() {
        return modelInfo.getOutputKeys();
    }

}
MinMaxScalerModelInfo getModelInfo(final MinMaxScalerModel from, final DataFrame df) { 17 | final MinMaxScalerModelInfo modelInfo = new MinMaxScalerModelInfo(); 18 | modelInfo.setOriginalMax(from.originalMax().toArray()); 19 | modelInfo.setOriginalMin(from.originalMin().toArray()); 20 | modelInfo.setMax(from.getMax()); 21 | modelInfo.setMin(from.getMin()); 22 | 23 | Set inputKeys = new LinkedHashSet(); 24 | inputKeys.add(from.getInputCol()); 25 | modelInfo.setInputKeys(inputKeys); 26 | 27 | Set outputKeys = new LinkedHashSet(); 28 | outputKeys.add(from.getOutputCol()); 29 | modelInfo.setOutputKeys(outputKeys); 30 | 31 | return modelInfo; 32 | } 33 | 34 | @Override 35 | public Class getSource() { 36 | return MinMaxScalerModel.class; 37 | } 38 | 39 | @Override 40 | public Class getTarget() { 41 | return MinMaxScalerModelInfo.class; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /adapters-1.6/src/main/java/com/flipkart/fdp/ml/adapter/StandardScalerModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.StandardScalerModelInfo; 4 | import org.apache.spark.ml.feature.StandardScalerModel; 5 | import org.apache.spark.sql.DataFrame; 6 | 7 | import java.util.LinkedHashSet; 8 | import java.util.Set; 9 | 10 | /** 11 | * Transforms Spark's {@link StandardScalerModel} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.StandardScalerModelInfo} object 12 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 13 | */ 14 | public class StandardScalerModelInfoAdapter extends AbstractModelInfoAdapter { 15 | @Override 16 | public StandardScalerModelInfo getModelInfo(final StandardScalerModel from, final DataFrame df) { 17 | final StandardScalerModelInfo modelInfo = new StandardScalerModelInfo(); 18 | modelInfo.setMean(from.mean().toArray()); 19 | 
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.StandardScalerModelInfo;
import org.apache.spark.ml.feature.StandardScalerModel;
import org.apache.spark.sql.DataFrame;

import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Transforms Spark's {@link StandardScalerModel} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.StandardScalerModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
public class StandardScalerModelInfoAdapter extends AbstractModelInfoAdapter {

    @Override
    public StandardScalerModelInfo getModelInfo(final StandardScalerModel from, final DataFrame df) {
        final StandardScalerModelInfo modelInfo = new StandardScalerModelInfo();

        // Fitted statistics and the centering/scaling switches.
        modelInfo.setMean(from.mean().toArray());
        modelInfo.setStd(from.std().toArray());
        modelInfo.setWithMean(from.getWithMean());
        modelInfo.setWithStd(from.getWithStd());

        final Set inputColumns = new LinkedHashSet();
        inputColumns.add(from.getInputCol());
        modelInfo.setInputKeys(inputColumns);

        final Set outputColumns = new LinkedHashSet();
        outputColumns.add(from.getOutputCol());
        modelInfo.setOutputKeys(outputColumns);

        return modelInfo;
    }

    @Override
    public Class getSource() {
        return StandardScalerModel.class;
    }

    @Override
    public Class getTarget() {
        return StandardScalerModelInfo.class;
    }
}
package com.flipkart.fdp.ml.modelinfo;

import com.flipkart.fdp.ml.transformer.DecisionTreeTransformer;
import com.flipkart.fdp.ml.transformer.Transformer;
import lombok.Data;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Represents information for a Decision Tree model. This class has been specifically designed to not contain any type hierarchy
 * for internal node/ leaf node, continuous/categorical split, regression/classification.
 * This has been done to keep serialization and deserialization of these objects simple.
 * Most of the json serializers (jackson, gson) do not handle type hierarchies well during deserialization.
 */
@Data
public class DecisionTreeModelInfo extends AbstractModelInfo {
    // Root of the serialized tree; traversal starts here.
    private DecisionNode root;
    // Map key under which the transformer publishes the class probability.
    private String probabilityKey = "probability";
    // Map key under which the transformer publishes raw prediction scores.
    private String rawPredictionKey = "rawPrediction";

    /**
     * @return an corresponding {@link DecisionTreeTransformer} for this model info
     */
    @Override
    public Transformer getTransformer() {
        return new DecisionTreeTransformer(this);
    }

    /**
     * Flat node record covering every node kind (internal/leaf,
     * continuous/categorical split) via flags rather than subclasses,
     * so JSON round-trips need no polymorphic handling.
     */
    @Data
    public static class DecisionNode implements Serializable {
        // Index of the feature this node splits on (meaningful for internal nodes).
        private int feature;
        // True when this node is a leaf and `prediction` is terminal.
        private boolean leaf;
        // Split threshold for continuous splits.
        private double threshold;
        // Predicted value at this node.
        private double prediction;
        // Impurity statistics per class; element type erased in this view —
        // presumably Double, TODO confirm against the adapter that fills it.
        private List impurityStats = new ArrayList<>();
        // Category values routed to the left child for categorical splits.
        private Set leftCategories = new HashSet<>();
        // True for a continuous (threshold) split, false for categorical.
        private boolean continuousSplit;

        DecisionNode leftNode;
        DecisionNode rightNode;
    }
}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.StringIndexerModelInfo;
import org.apache.spark.ml.feature.StringIndexerModel;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/**
 * Transforms Spark's {@link StringIndexerModel} in MlLib to {@link StringIndexerModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
public class StringIndexerModelInfoAdapter extends AbstractModelInfoAdapter {

    @Override
    public StringIndexerModelInfo getModelInfo(final StringIndexerModel from) {
        final StringIndexerModelInfo modelInfo = new StringIndexerModelInfo();

        // Invert the label array into a label -> index lookup table; the
        // array position of each label is its numeric index.
        final String[] labels = from.labels();
        final Map labelToIndex = new HashMap();
        for (int position = 0; position < labels.length; position++) {
            labelToIndex.put(labels[position], (double) position);
        }
        modelInfo.setLabelToIndex(labelToIndex);

        final Set inputColumns = new LinkedHashSet();
        inputColumns.add(from.getInputCol());
        modelInfo.setInputKeys(inputColumns);

        final Set outputColumns = new LinkedHashSet();
        outputColumns.add(from.getOutputCol());
        modelInfo.setOutputKeys(outputColumns);

        return modelInfo;
    }

    @Override
    public Class getSource() {
        return StringIndexerModel.class;
    }

    @Override
    public Class getTarget() {
        return StringIndexerModelInfo.class;
    }
}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.StringIndexerModelInfo;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.DataFrame;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

/**
 * Transforms Spark's {@link StringIndexerModel} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.StringIndexerModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
public class StringIndexerModelInfoAdapter extends AbstractModelInfoAdapter {

    @Override
    public StringIndexerModelInfo getModelInfo(final StringIndexerModel from, DataFrame df) {
        final StringIndexerModelInfo modelInfo = new StringIndexerModelInfo();
        modelInfo.setLabelToIndex(buildLabelIndex(from.labels()));
        modelInfo.setInputKeys(columnSet(from.getInputCol()));
        modelInfo.setOutputKeys(columnSet(from.getOutputCol()));
        return modelInfo;
    }

    // Maps each label to its array position as a double-valued index.
    private static Map buildLabelIndex(final String[] labels) {
        final Map labelToIndex = new HashMap();
        for (int i = 0; i < labels.length; i++) {
            labelToIndex.put(labels[i], (double) i);
        }
        return labelToIndex;
    }

    // Single-element, insertion-ordered key set for the given column.
    private static Set columnSet(final String column) {
        final Set keys = new LinkedHashSet();
        keys.add(column);
        return keys;
    }

    @Override
    public Class getSource() {
        return StringIndexerModel.class;
    }

    @Override
    public Class getTarget() {
        return StringIndexerModelInfo.class;
    }
}
11 | */ 12 | public class IfZeroVectorTransformer implements Transformer { 13 | 14 | private final IfZeroVectorModelInfo modelInfo; 15 | 16 | public IfZeroVectorTransformer(IfZeroVectorModelInfo modelInfo) { 17 | this.modelInfo = modelInfo; 18 | } 19 | 20 | @Override 21 | public void transform(Map input) { 22 | Object value = input.get(modelInfo.getInputKeys().iterator().next()); 23 | double[] inp = (value == null)? null: (double[])value; 24 | String elseSetColValue = (String)input.get(modelInfo.getElseSetCol()); 25 | input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp, modelInfo.getThenSetValue(), elseSetColValue)); 26 | } 27 | 28 | private String predict(double[] inp, String thenSetValue, String elseSetColValue) { 29 | if(inp == null || inp.length == 0) { 30 | return thenSetValue; 31 | } 32 | boolean allZero = true; 33 | for(int i=0; i getInputKeys() { 45 | return modelInfo.getInputKeys(); 46 | } 47 | 48 | @Override 49 | public Set getOutputKeys() { 50 | return modelInfo.getOutputKeys(); 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /adapters-2.0/src/main/java/com/flipkart/fdp/ml/adapter/DecisionTreeRegressionModelInfoAdapter.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import java.util.LinkedHashSet; 4 | import java.util.Set; 5 | 6 | import org.apache.spark.ml.regression.DecisionTreeRegressionModel; 7 | import org.apache.spark.ml.tree.Node; 8 | 9 | import com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo; 10 | import com.flipkart.fdp.ml.utils.DecisionNodeAdapterUtils; 11 | 12 | import lombok.extern.slf4j.Slf4j; 13 | 14 | 15 | /** 16 | * Transforms Spark's {@link org.apache.spark.ml.regression.DecisionTreeRegressionModel} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo} object 17 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 18 | */ 19 | 
package com.flipkart.fdp.ml.adapter;

import java.util.LinkedHashSet;
import java.util.Set;

import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.Node;

import com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo;
import com.flipkart.fdp.ml.utils.DecisionNodeAdapterUtils;

import lombok.extern.slf4j.Slf4j;


/**
 * Transforms Spark's {@link org.apache.spark.ml.regression.DecisionTreeRegressionModel} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
@Slf4j
public class DecisionTreeRegressionModelInfoAdapter
        extends AbstractModelInfoAdapter {

    /**
     * Converts the fitted Spark regression tree into the serializable
     * {@link DecisionTreeModelInfo} form.
     *
     * @param decisionTreeModel fitted Spark decision-tree regression model
     * @return model info carrying the converted tree plus input/output keys
     */
    public DecisionTreeModelInfo getModelInfo(final DecisionTreeRegressionModel decisionTreeModel) {
        final DecisionTreeModelInfo treeInfo = new DecisionTreeModelInfo();

        // Recursively adapt the Spark node hierarchy into flat DecisionNodes.
        Node rootNode = decisionTreeModel.rootNode();
        treeInfo.setRoot( DecisionNodeAdapterUtils.adaptNode(rootNode));

        // Both the feature vector column and the label column are recorded
        // as inputs, in insertion order.
        final Set inputKeys = new LinkedHashSet();
        inputKeys.add(decisionTreeModel.getFeaturesCol());
        inputKeys.add(decisionTreeModel.getLabelCol());
        treeInfo.setInputKeys(inputKeys);

        final Set outputKeys = new LinkedHashSet();
        outputKeys.add(decisionTreeModel.getPredictionCol());
        treeInfo.setOutputKeys(outputKeys);

        return treeInfo;
    }

    @Override
    public Class getSource() {
        return DecisionTreeRegressionModel.class;
    }

    @Override
    public Class getTarget() {
        return DecisionTreeModelInfo.class;
    }
}
22 | 23 | ## Tech Talk on spark transformers 24 | [Presentation link](https://docs.google.com/presentation/d/1tfcV0jnoTWhFonY_dzrBO1Qk5ar-jY5yoqWcRQQe6v4/edit?usp=sharing) 25 | 26 | ## Bugs and Feedback 27 | 28 | For bugs, questions and discussions please use the [Github Issues](https://github.com/flipkart-incubator/spark-transformers/issues). 29 | 30 | 31 | ## LICENSE 32 | Spark-Transformers is licensed under : The Apache Software License, Version 2.0. Here is a copy of the license (http://www.apache.org/licenses/LICENSE-2.0.txt) 33 | 34 | [maven]:http://search.maven.org/#search%7Cga%7C1%7Ccom.flipkart.fdp.ml 35 | [maven img]:https://img.shields.io/maven-central/v/com.flipkart.fdp.ml/spark-transformers.svg 36 | 37 | [travis]:https://travis-ci.org/flipkart-incubator/spark-transformers 38 | [travis img]:https://img.shields.io/travis/flipkart-incubator/spark-transformers.svg 39 | 40 | [release]:https://github.com/flipkart-incubator/spark-transformers/releases 41 | [release img]:https://img.shields.io/badge/release-0.1-green.svg 42 | 43 | [license]:http://www.apache.org/licenses/LICENSE-2.0.txt 44 | [license img]:https://img.shields.io/badge/License-Apache%202-blue.svg 45 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/StringIndexerTransformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.StringIndexerModelInfo; 4 | 5 | import java.util.Map; 6 | import java.util.Set; 7 | 8 | /** 9 | * Transforms input/ predicts for a String Indexer model representation 10 | * captured by {@link com.flipkart.fdp.ml.modelinfo.StringIndexerModelInfo}. 
package com.flipkart.fdp.ml.transformer;

import com.flipkart.fdp.ml.modelinfo.StringIndexerModelInfo;

import java.util.Map;
import java.util.Set;

/**
 * Transforms input/ predicts for a String Indexer model representation
 * captured by {@link com.flipkart.fdp.ml.modelinfo.StringIndexerModelInfo}.
 */
public class StringIndexerTransformer implements Transformer {

    private final StringIndexerModelInfo modelInfo;
    // Largest index among known labels; unseen labels map to maxIndex + 1
    // when the model is configured not to fail on them.
    private final double maxIndex;

    public StringIndexerTransformer(final StringIndexerModelInfo modelInfo) {
        this.modelInfo = modelInfo;
        double largest = 0.0;
        for (Object indexValue : modelInfo.getLabelToIndex().values()) {
            largest = Math.max(largest, (Double) indexValue);
        }
        this.maxIndex = largest;
    }

    /**
     * Looks up the numeric index of a label. Unknown labels either raise
     * (failOnUnseenValues) or are assigned one past the largest known index.
     */
    public double predict(final String s) {
        Double index = (Double) modelInfo.getLabelToIndex().get(s);
        if (null == index) {
            if (modelInfo.isFailOnUnseenValues()) {
                throw new RuntimeException("Unseen label :" + s);
            }
            // handling unseen value: returning maxIndex + 1
            index = maxIndex + 1;
        }
        return index;
    }

    @Override
    public void transform(Map input) {
        final Object rawValue = input.get(modelInfo.getInputKeys().iterator().next());
        // A missing input is indexed as the empty string, as before.
        final String label = (null != rawValue) ? rawValue.toString() : "";
        input.put(modelInfo.getOutputKeys().iterator().next(), predict(label));
    }

    @Override
    public Set getInputKeys() {
        return modelInfo.getInputKeys();
    }

    @Override
    public Set getOutputKeys() {
        return modelInfo.getOutputKeys();
    }

}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo;
import com.flipkart.fdp.ml.utils.DecisionNodeAdapterUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.Node;
import org.apache.spark.sql.DataFrame;

import java.util.LinkedHashSet;
import java.util.Set;


/**
 * Transforms Spark's {@link org.apache.spark.ml.regression.DecisionTreeRegressionModel} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
@Slf4j
public class DecisionTreeRegressionModelInfoAdapter
        extends AbstractModelInfoAdapter {

    public DecisionTreeModelInfo getModelInfo(final DecisionTreeRegressionModel decisionTreeModel, final DataFrame df) {
        final DecisionTreeModelInfo treeInfo = new DecisionTreeModelInfo();

        // Recursively convert the Spark node hierarchy into flat DecisionNodes.
        final Node sparkRoot = decisionTreeModel.rootNode();
        treeInfo.setRoot(DecisionNodeAdapterUtils.adaptNode(sparkRoot));

        // Record both feature and label columns as inputs, in insertion order.
        final Set inputColumns = new LinkedHashSet();
        inputColumns.add(decisionTreeModel.getFeaturesCol());
        inputColumns.add(decisionTreeModel.getLabelCol());
        treeInfo.setInputKeys(inputColumns);

        final Set outputColumns = new LinkedHashSet();
        outputColumns.add(decisionTreeModel.getPredictionCol());
        treeInfo.setOutputKeys(outputColumns);

        return treeInfo;
    }

    @Override
    public Class getSource() {
        return DecisionTreeRegressionModel.class;
    }

    @Override
    public Class getTarget() {
        return DecisionTreeModelInfo.class;
    }
}
package com.flipkart.transformer.ml

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, _}

/**
 * StringMerge: concatenates the two configured input string columns with a
 * single space between them, trims the result, and writes it to the output
 * column.
 */
class StringMerge(override val uid: String) extends Transformer with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("StringMerge"))

  final val inputCol1: Param[String] = new Param[String](this, "inputCol1", "input first column name")
  final val inputCol2: Param[String] = new Param[String](this, "inputCol2", "input second column name")
  final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")

  final def getInputCol1: String = $(inputCol1)
  final def getInputCol2: String = $(inputCol2)
  final def getOutputCol: String = $(outputCol)

  def setInputCol1(value: String): this.type = set(inputCol1, value)
  def setInputCol2(value: String): this.type = set(inputCol2, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  // Joins the two strings with one space and strips surrounding whitespace.
  val mergeAddress = udf((left: String, right: String) => (left + " " + right).trim)

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema)
    dataset.withColumn(getOutputCol, mergeAddress(col(getInputCol1), col(getInputCol2)))
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType =
    StructType(schema.fields :+ StructField(getOutputCol, StringType))
}

object StringMerge extends DefaultParamsReadable[StringMerge]
package com.flipkart.fdp.ml.transformer;

import com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.Set;

/**
 * Transforms input/ predicts for a Logistic Regression modelInfo representation
 * captured by {@link com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo}.
 */
public class LogisticRegressionTransformer implements Transformer {
    private static final Logger LOG = LoggerFactory.getLogger(LogisticRegressionTransformer.class);
    private final LogisticRegressionModelInfo modelInfo;

    public LogisticRegressionTransformer(final LogisticRegressionModelInfo modelInfo) {
        this.modelInfo = modelInfo;
    }

    /**
     * Sigmoid of (weights . input + intercept). The dot product is summed
     * first and the intercept added afterwards, preserving the original
     * floating-point accumulation order.
     */
    public double getProbability(final double[] input) {
        double dotProduct = 0.0;
        for (int i = 0; i < input.length; i++) {
            dotProduct += input[i] * modelInfo.getWeights()[i];
        }
        final double margin = dotProduct + modelInfo.getIntercept();
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    /** Thresholds a probability into the 1.0 / 0.0 class label. */
    public double predict(final double predictedRaw) {
        if (predictedRaw > modelInfo.getThreshold()) {
            return 1.0;
        }
        return 0.0;
    }

    @Override
    public void transform(Map input) {
        final double[] features = (double[]) input.get(modelInfo.getInputKeys().iterator().next());
        final double probability = getProbability(features);
        // Publish the probability, then the thresholded class label.
        input.put(modelInfo.getProbabilityKey(), probability);
        input.put(modelInfo.getOutputKeys().iterator().next(), predict(probability));
    }

    @Override
    public Set getInputKeys() {
        return modelInfo.getInputKeys();
    }

    @Override
    public Set getOutputKeys() {
        return modelInfo.getOutputKeys();
    }

}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.Transformer;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.junit.Test;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;

/**
 * Round-trip test: trains an MLlib logistic regression model, exports it,
 * imports it back as a {@link Transformer}, and checks that the transformer's
 * predictions match Spark's own predictions on every training point.
 */
public class LogisticRegressionBridgeTest extends SparkTestBase {

    @Test
    public void testLogisticRegression() {
        //prepare data
        String datapath = "src/test/resources/binary_classification_test.libsvm";
        JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

        //Train model in spark
        LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());

        //Export this model
        byte[] exportedModel = ModelExporter.export(lrmodel);

        //Import and get Transformer
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        //validate predictions against Spark's on the same points
        List<LabeledPoint> testPoints = trainingData.collect();
        for (LabeledPoint i : testPoints) {
            Vector v = i.features();
            double actual = lrmodel.predict(v);

            Map<String, Object> data = new HashMap<>();
            data.put("features", v.toArray());
            transformer.transform(data);
            double predicted = (double) data.get("prediction");

            assertEquals(actual, predicted, 0.01);
        }
    }
}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.Transformer;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.junit.Test;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;

/**
 * Round-trip test (Spark 1.6): trains an MLlib logistic regression model,
 * exports it, imports it back as a {@link Transformer}, and checks that the
 * transformer's predictions match Spark's own predictions.
 */
public class LogisticRegressionBridgeTest extends SparkTestBase {

    @Test
    public void testLogisticRegression() {
        //prepare data
        String datapath = "src/test/resources/binary_classification_test.libsvm";
        JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

        //Train model in spark
        LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());

        //Export this model (MLlib models need no DataFrame, hence null)
        byte[] exportedModel = ModelExporter.export(lrmodel, null);

        //Import and get Transformer
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        //validate predictions against Spark's on the same points
        List<LabeledPoint> testPoints = trainingData.collect();
        for (LabeledPoint i : testPoints) {
            Vector v = i.features();
            double actual = lrmodel.predict(v);

            Map<String, Object> data = new HashMap<>();
            data.put("features", v.toArray());
            transformer.transform(data);
            double predicted = (double) data.get("prediction");

            assertEquals(actual, predicted, EPSILON);
        }
    }
}
org.junit.Assert.assertEquals; 19 | 20 | public class LogisticRegression1BridgeTest extends SparkTestBase { 21 | 22 | @Test 23 | public void testLogisticRegression() { 24 | //prepare data 25 | String datapath = "src/test/resources/binary_classification_test.libsvm"; 26 | 27 | DataFrame trainingData = sqlContext.read().format("libsvm").load(datapath); 28 | 29 | //Train model in spark 30 | LogisticRegressionModel lrmodel = new LogisticRegression().fit(trainingData); 31 | 32 | //Export this model 33 | byte[] exportedModel = ModelExporter.export(lrmodel, trainingData); 34 | 35 | //Import and get Transformer 36 | Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel); 37 | 38 | //validate predictions 39 | List testPoints = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().collect(); 40 | for (LabeledPoint i : testPoints) { 41 | Vector v = i.features(); 42 | double actual = lrmodel.predict(v); 43 | 44 | Map data = new HashMap(); 45 | data.put("features", v.toArray()); 46 | transformer.transform(data); 47 | double predicted = (double) data.get("prediction"); 48 | 49 | assertEquals(actual, predicted, EPSILON); 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /adapters-1.6/src/test/java/com/flipkart/fdp/ml/export/LogisticRegressionExporterTest.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.export; 2 | 3 | import com.flipkart.fdp.ml.adapter.SparkTestBase; 4 | import com.flipkart.fdp.ml.importer.ModelImporter; 5 | import com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo; 6 | import org.apache.spark.api.java.JavaRDD; 7 | import org.apache.spark.mllib.classification.LogisticRegressionModel; 8 | import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; 9 | import org.apache.spark.mllib.regression.LabeledPoint; 10 | import org.apache.spark.mllib.util.MLUtils; 11 | import org.junit.Test; 12 | 13 | import 
static junit.framework.TestCase.assertEquals; 14 | 15 | ; 16 | 17 | public class LogisticRegressionExporterTest extends SparkTestBase { 18 | 19 | @Test 20 | public void shouldExportAndImportCorrectly() { 21 | String datapath = "src/test/resources/binary_classification_test.libsvm"; 22 | JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); 23 | 24 | //Train model in spark 25 | LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd()); 26 | 27 | //Export this model 28 | byte[] exportedModel = ModelExporter.export(lrmodel, null); 29 | 30 | //Import it back 31 | LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel); 32 | 33 | //check if they are exactly equal with respect to their fields 34 | //it maybe edge cases eg. order of elements in the list is changed 35 | assertEquals(lrmodel.intercept(), importedModel.getIntercept(), EPSILON); 36 | assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), EPSILON); 37 | assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), EPSILON); 38 | assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), EPSILON); 39 | for (int i = 0; i < importedModel.getNumFeatures(); i++) 40 | assertEquals(lrmodel.weights().toArray()[i], importedModel.getWeights()[i], EPSILON); 41 | 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /adapters-2.0/src/test/java/com/flipkart/fdp/ml/adapter/LogisticRegression1BridgeTest.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.export.ModelExporter; 4 | import com.flipkart.fdp.ml.importer.ModelImporter; 5 | import com.flipkart.fdp.ml.transformer.Transformer; 6 | import org.apache.spark.ml.classification.LogisticRegression; 7 | import org.apache.spark.ml.classification.LogisticRegressionModel; 8 | import 
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.mllib.classification.LogisticRegressionModel;

import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Transforms Spark's {@link LogisticRegressionModel} in MlLib to {@link LogisticRegressionModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
@Slf4j
public class LogisticRegressionModelInfoAdapter
        extends AbstractModelInfoAdapter<LogisticRegressionModel, LogisticRegressionModelInfo> {

    @Override
    public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel) {
        final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
        // Copy the linear-model parameters verbatim.
        logisticRegressionModelInfo.setWeights(sparkLRModel.weights().toArray());
        logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
        logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
        logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
        // getThreshold() returns an Option; models exported here are expected to have one set.
        logisticRegressionModelInfo.setThreshold((double) sparkLRModel.getThreshold().get());

        // MLlib models carry no column names, so use the conventional ones.
        Set<String> inputKeys = new LinkedHashSet<>();
        inputKeys.add("features");
        logisticRegressionModelInfo.setInputKeys(inputKeys);

        Set<String> outputKeys = new LinkedHashSet<>();
        outputKeys.add("prediction");
        outputKeys.add("probability");
        logisticRegressionModelInfo.setOutputKeys(outputKeys);

        return logisticRegressionModelInfo;
    }

    @Override
    public Class<LogisticRegressionModel> getSource() {
        return LogisticRegressionModel.class;
    }

    @Override
    public Class<LogisticRegressionModelInfo> getTarget() {
        return LogisticRegressionModelInfo.class;
    }

}
com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo; 4 | import lombok.extern.slf4j.Slf4j; 5 | import org.apache.spark.ml.classification.LogisticRegressionModel; 6 | 7 | import java.util.LinkedHashSet; 8 | import java.util.Set; 9 | 10 | /** 11 | * Transforms Spark's {@link LogisticRegressionModel} to {@link LogisticRegressionModelInfo} object 12 | * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter} 13 | */ 14 | @Slf4j 15 | public class LogisticRegressionModelInfoAdapter1 16 | extends AbstractModelInfoAdapter { 17 | 18 | @Override 19 | public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel) { 20 | final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo(); 21 | logisticRegressionModelInfo.setWeights(sparkLRModel.coefficients().toArray()); 22 | logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept()); 23 | logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses()); 24 | logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures()); 25 | logisticRegressionModelInfo.setThreshold(sparkLRModel.getThreshold()); 26 | 27 | Set inputKeys = new LinkedHashSet(); 28 | inputKeys.add(sparkLRModel.getFeaturesCol()); 29 | logisticRegressionModelInfo.setInputKeys(inputKeys); 30 | 31 | Set outputKeys = new LinkedHashSet(); 32 | outputKeys.add(sparkLRModel.getPredictionCol()); 33 | outputKeys.add(sparkLRModel.getProbabilityCol()); 34 | logisticRegressionModelInfo.setOutputKeys(outputKeys); 35 | 36 | return logisticRegressionModelInfo; 37 | } 38 | 39 | @Override 40 | public Class getSource() { 41 | return LogisticRegressionModel.class; 42 | } 43 | 44 | @Override 45 | public Class getTarget() { 46 | return LogisticRegressionModelInfo.class; 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/BucketizerTransformer.java: 
package com.flipkart.fdp.ml.transformer;

import com.flipkart.fdp.ml.modelinfo.BucketizerModelInfo;

import java.util.Arrays;
import java.util.Map;
import java.util.Set;

/**
 * Transforms input/ predicts for a Bucketizer model representation
 * captured by {@link com.flipkart.fdp.ml.modelinfo.BucketizerModelInfo}.
 */
public class BucketizerTransformer implements Transformer {

    private final BucketizerModelInfo modelInfo;

    public BucketizerTransformer(final BucketizerModelInfo modelInfo) {
        this.modelInfo = modelInfo;
    }

    // Maps a raw value to its bucket index via binary search over the split points.
    // Splits of length n define n-1 buckets; bucket i covers [splits[i], splits[i+1]),
    // except the last bucket which also includes its upper bound (handled by the
    // explicit `input == last` check below).
    // NOTE(review): assumes splits are sorted ascending with at least 2 entries
    // (Spark's Bucketizer contract) — the length guard below only rejects 0.
    public double predict(final double input) {
        if (modelInfo.getSplits().length <= 0) {
            throw new RuntimeException("BucketizerTransformer : splits have incorrect length : " + modelInfo.getSplits().length);
        }

        // Upper edge of the last bucket is inclusive: return index of the last bucket.
        final double last = modelInfo.getSplits()[modelInfo.getSplits().length - 1];
        if (input == last) {
            return modelInfo.getSplits().length - 2;
        }

        int idx = Arrays.binarySearch(modelInfo.getSplits(), input);
        if (idx >= 0) {
            // Exact hit on a split point: the value starts bucket `idx`.
            return idx;
        } else {
            // Not an exact split point: binarySearch returns -(insertionPoint) - 1.
            int insertPos = -idx - 1;
            if (insertPos == 0 || insertPos == modelInfo.getSplits().length) {
                // Value falls before the first split or after the last one.
                throw new RuntimeException("BucketizerTransformer : Feature value : " + input + " out of bounds : (" + modelInfo.getSplits()[0] + "," + last + ")");
            } else {
                // Value lies strictly inside bucket insertPos-1.
                return insertPos - 1;
            }
        }
    }

    /**
     * Reads the single input key from the map, buckets it, and writes the
     * bucket index (as a double) under the single output key.
     */
    @Override
    public void transform(Map input) {
        double inp = (double) input.get(modelInfo.getInputKeys().iterator().next());
        input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp));
    }

    @Override
    public Set getInputKeys() {
        return modelInfo.getInputKeys();
    }

    @Override
    public Set getOutputKeys() {
        return modelInfo.getOutputKeys();
    }
}
-------------------------------------------------------------------------------- /adapters-1.6/src/test/java/com/flipkart/fdp/ml/export/LogisticRegression1ExporterTest.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.export; 2 | 3 | import com.flipkart.fdp.ml.adapter.SparkTestBase; 4 | import com.flipkart.fdp.ml.importer.ModelImporter; 5 | import com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo; 6 | import org.apache.spark.ml.classification.LogisticRegression; 7 | import org.apache.spark.ml.classification.LogisticRegressionModel; 8 | import org.apache.spark.sql.DataFrame; 9 | import org.junit.Test; 10 | 11 | import static junit.framework.TestCase.assertEquals; 12 | 13 | public class LogisticRegression1ExporterTest extends SparkTestBase { 14 | 15 | @Test 16 | public void shouldExportAndImportCorrectly() { 17 | //prepare data 18 | String datapath = "src/test/resources/binary_classification_test.libsvm"; 19 | 20 | DataFrame trainingData = sqlContext.read().format("libsvm").load(datapath); 21 | 22 | //Train model in spark 23 | LogisticRegressionModel lrmodel = new LogisticRegression().fit(trainingData); 24 | 25 | //Export this model 26 | byte[] exportedModel = ModelExporter.export(lrmodel, trainingData); 27 | 28 | //Import it back 29 | LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel); 30 | 31 | //check if they are exactly equal with respect to their fields 32 | //it maybe edge cases eg. 
order of elements in the list is changed 33 | assertEquals(lrmodel.intercept(), importedModel.getIntercept(), EPSILON); 34 | assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), EPSILON); 35 | assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), EPSILON); 36 | assertEquals(lrmodel.getThreshold(), importedModel.getThreshold(), EPSILON); 37 | for (int i = 0; i < importedModel.getNumFeatures(); i++) 38 | assertEquals(lrmodel.weights().toArray()[i], importedModel.getWeights()[i], EPSILON); 39 | 40 | assertEquals(lrmodel.getFeaturesCol(), importedModel.getInputKeys().iterator().next()); 41 | assertEquals(lrmodel.getPredictionCol(), importedModel.getOutputKeys().iterator().next()); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /adapters-2.0/src/test/java/com/flipkart/fdp/ml/export/LogisticRegression1ExporterTest.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.export; 2 | 3 | import com.flipkart.fdp.ml.adapter.SparkTestBase; 4 | import com.flipkart.fdp.ml.importer.ModelImporter; 5 | import com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo; 6 | import org.apache.spark.ml.classification.LogisticRegression; 7 | import org.apache.spark.ml.classification.LogisticRegressionModel; 8 | import org.apache.spark.sql.Dataset; 9 | import org.apache.spark.sql.Row; 10 | import org.junit.Test; 11 | 12 | import static junit.framework.TestCase.assertEquals; 13 | 14 | public class LogisticRegression1ExporterTest extends SparkTestBase { 15 | 16 | @Test 17 | public void shouldExportAndImportCorrectly() { 18 | //prepare data 19 | String datapath = "src/test/resources/binary_classification_test.libsvm"; 20 | 21 | Dataset trainingData = spark.read().format("libsvm").load(datapath); 22 | 23 | //Train model in spark 24 | LogisticRegressionModel lrmodel = new LogisticRegression().fit(trainingData); 25 | 26 | //Export this model 27 | byte[] 
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.sql.DataFrame;

import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Transforms Spark's {@link LogisticRegressionModel} in MlLib to {@link com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}
 */
@Slf4j
public class LogisticRegressionModelInfoAdapter
        extends AbstractModelInfoAdapter<LogisticRegressionModel, LogisticRegressionModelInfo> {

    @Override
    public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
        final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
        // Copy the linear-model parameters verbatim; the DataFrame is unused for MLlib models.
        logisticRegressionModelInfo.setWeights(sparkLRModel.weights().toArray());
        logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
        logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
        logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
        // getThreshold() returns an Option; models exported here are expected to have one set.
        logisticRegressionModelInfo.setThreshold((double) sparkLRModel.getThreshold().get());

        // MLlib models carry no column names, so use the conventional ones.
        Set<String> inputKeys = new LinkedHashSet<>();
        inputKeys.add("features");
        logisticRegressionModelInfo.setInputKeys(inputKeys);

        Set<String> outputKeys = new LinkedHashSet<>();
        outputKeys.add("prediction");
        outputKeys.add("probability");
        logisticRegressionModelInfo.setOutputKeys(outputKeys);

        return logisticRegressionModelInfo;
    }

    @Override
    public Class<LogisticRegressionModel> getSource() {
        return LogisticRegressionModel.class;
    }

    @Override
    public Class<LogisticRegressionModelInfo> getTarget() {
        return LogisticRegressionModelInfo.class;
    }

}
inputColumns.addAll(t.getInputKeys()); 19 | } 20 | 21 | //remove non modifying columns of each transformer 22 | for(Transformer t : transformers) { 23 | //calculate set difference Set(outputs) - Set(inputs) 24 | Set setDifference = new HashSet<>(t.getOutputKeys()); 25 | setDifference.removeAll(t.getInputKeys()); 26 | 27 | inputColumns.removeAll(setDifference); 28 | } 29 | 30 | //Not handled cases where a transformer replaces/modifies any column that is not its input column. 31 | return inputColumns; 32 | } 33 | 34 | public static Set extractRequiredOutputColumns(Transformer[] transformers) { 35 | Set outputColumns = new HashSet<>(); 36 | 37 | //Add outputs for each transformer in the output set 38 | //traversing in reverse 39 | for(int i = transformers.length-1; i>=0; i--) { 40 | outputColumns.addAll(transformers[i].getOutputKeys()); 41 | } 42 | 43 | //remove non modifying columns of each transformer 44 | for(int i = transformers.length-1; i>=0; i--) { 45 | //calculate set difference Set(inputs) - Set(outputs) 46 | Set setDifference = new HashSet<>(transformers[i].getInputKeys()); 47 | setDifference.removeAll(transformers[i].getOutputKeys()); 48 | 49 | outputColumns.removeAll(setDifference); 50 | } 51 | 52 | //Not handled cases where a transformer replaces/modifies any column that is not its input column. 
package com.flipkart.fdp.ml.transformer;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo;
import com.flipkart.fdp.ml.modelinfo.GradientBoostModelInfo;
import com.github.fommil.netlib.BLAS;

/**
 * Transforms input/ predicts for a gradient-boosted-trees classification model
 * representation captured by {@link GradientBoostModelInfo}: each tree's
 * prediction is weighted and summed, and the sign of the sum picks the class.
 */
public class GradientBoostClassificationTransformer implements Transformer {
    private final GradientBoostModelInfo forest;
    private final List<DecisionTreeTransformer> subTransformers;
    // Tree weights copied once into a primitive array for the BLAS dot product,
    // instead of being re-boxed out of the List on every predict() call.
    private final double[] treeWeights;

    public GradientBoostClassificationTransformer(final GradientBoostModelInfo forest) {
        this.forest = forest;
        this.subTransformers = new ArrayList<>(forest.getTrees().size());
        for (final DecisionTreeModelInfo tree : forest.getTrees()) {
            subTransformers.add((DecisionTreeTransformer) tree.getTransformer());
        }
        final List<Double> modelTreeWeights = forest.getTreeWeights();
        this.treeWeights = new double[subTransformers.size()];
        for (int i = 0; i < treeWeights.length; i++) {
            treeWeights[i] = modelTreeWeights.get(i);
        }
    }

    /**
     * Predicts the binary class label (0.0 or 1.0) for the given feature vector:
     * the weighted sum of the per-tree predictions is thresholded at zero.
     */
    public double predict(final double[] input) {
        double[] treePredictions = new double[subTransformers.size()];
        int index = 0;
        for (final DecisionTreeTransformer treeTransformer : subTransformers) {
            treePredictions[index++] = treeTransformer.predict(input);
        }
        double prediction = BLAS.getInstance().ddot(subTransformers.size(), treePredictions, 1, treeWeights, 1);
        return prediction > 0.0 ? 1.0 : 0.0;
    }

    /**
     * Reads the feature array from the single input key and writes the
     * predicted label under the single output key.
     */
    @Override
    public void transform(Map input) {
        double[] inp = (double[]) input.get(forest.getInputKeys().iterator().next());
        input.put(forest.getOutputKeys().iterator().next(), predict(inp));
    }

    @Override
    public Set getInputKeys() {
        return forest.getInputKeys();
    }

    @Override
    public Set getOutputKeys() {
        return forest.getOutputKeys();
    }

}
| byte[] exportedModel = ModelExporter.export(sparkModel); 42 | 43 | // //Import and get Transformer 44 | Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel); 45 | // 46 | // //compare predictions 47 | List sparkOutput = sparkModel.transform(df).orderBy("id").select("input1", "input2", "output").collectAsList(); 48 | for (Row row : sparkOutput) { 49 | 50 | Map data = new HashMap(); 51 | data.put(sparkModel.getInputCol1(), row.get(0)); 52 | data.put(sparkModel.getInputCol2(), row.get(1)); 53 | transformer.transform(data); 54 | String actual = (String) data.get(sparkModel.getOutputCol()); 55 | 56 | assertEquals(actual, row.get(2)); 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /adapters-2.0/src/test/java/com/flipkart/fdp/ml/adapter/GradientBoostClassificationModelTest.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import static org.junit.Assert.assertEquals; 4 | 5 | import java.util.HashMap; 6 | import java.util.List; 7 | import java.util.Map; 8 | 9 | import org.apache.spark.ml.classification.GBTClassificationModel; 10 | import org.apache.spark.ml.classification.GBTClassifier; 11 | import org.apache.spark.ml.linalg.SparseVector; 12 | import org.apache.spark.sql.Dataset; 13 | import org.apache.spark.sql.Row; 14 | import org.junit.Test; 15 | 16 | import com.flipkart.fdp.ml.export.ModelExporter; 17 | import com.flipkart.fdp.ml.importer.ModelImporter; 18 | import com.flipkart.fdp.ml.transformer.Transformer; 19 | 20 | /** 21 | * 22 | * @author harshit.pandey 23 | * 24 | */ 25 | public class GradientBoostClassificationModelTest extends SparkTestBase { 26 | 27 | @Test 28 | public void testGradientBoostClassification() { 29 | // Load the data stored in LIBSVM format as a DataFrame. 
package com.flipkart.fdp.ml.adapter;

import static org.junit.Assert.assertEquals;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Test;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.Transformer;

/**
 * Round-trip test for {@link GBTClassificationModel}: fits a gradient-boosted-trees
 * classifier in Spark, exports/re-imports it, and verifies the standalone
 * {@link Transformer} reproduces Spark's prediction for each held-out row.
 *
 * @author harshit.pandey
 */
public class GradientBoostClassificationModelTest extends SparkTestBase {

    @Test
    public void testGradientBoostClassification() {
        // Load the data stored in LIBSVM format as a DataFrame.
        String datapath = "src/test/resources/binary_classification_test.libsvm";
        Dataset<Row> data = spark.read().format("libsvm").load(datapath);

        // Hold out 30% of the rows for verification.
        Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
        Dataset<Row> trainingData = splits[0];
        Dataset<Row> testData = splits[1];

        // Train a gradient-boosted-trees classification model.
        GBTClassificationModel classificationModel = new GBTClassifier().fit(trainingData);

        // Export the Spark model and bring it back as a standalone Transformer.
        Transformer transformer =
                ModelImporter.importAndGetTransformer(ModelExporter.export(classificationModel));

        List<Row> sparkOutput = classificationModel.transform(testData)
                .select("features", "prediction", "label").collectAsList();

        // The re-imported transformer must agree with Spark row by row.
        for (Row row : sparkOutput) {
            Map<String, Object> data_ = new HashMap<>();
            data_.put("features", ((SparseVector) row.get(0)).toArray());
            data_.put("label", (row.get(2)).toString());
            transformer.transform(data_);
            System.out.println(data_);
            System.out.println(data_.get("prediction")+" ,"+row.get(1));
            assertEquals((double) data_.get("prediction"), (double) row.get(1), EPSILON);
        }
    }
}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.LogisticRegressionModelInfo;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.DataFrame;

import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Adapts Spark's {@link LogisticRegressionModel} to a {@link LogisticRegressionModelInfo}
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}.
 */
@Slf4j
public class LogisticRegressionModelInfoAdapter1
        extends AbstractModelInfoAdapter {

    @Override
    public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
        final LogisticRegressionModelInfo info = new LogisticRegressionModelInfo();

        // copy the fitted parameters out of the Spark model
        info.setWeights(sparkLRModel.coefficients().toArray());
        info.setIntercept(sparkLRModel.intercept());
        info.setNumClasses(sparkLRModel.numClasses());
        info.setNumFeatures(sparkLRModel.numFeatures());
        info.setThreshold(sparkLRModel.getThreshold());
        info.setProbabilityKey(sparkLRModel.getProbabilityCol());

        // LinkedHashSet keeps column order stable for the exported model
        final Set inputKeys = new LinkedHashSet();
        inputKeys.add(sparkLRModel.getFeaturesCol());
        info.setInputKeys(inputKeys);

        final Set outputKeys = new LinkedHashSet();
        outputKeys.add(sparkLRModel.getPredictionCol());
        outputKeys.add(sparkLRModel.getProbabilityCol());
        info.setOutputKeys(outputKeys);

        return info;
    }

    @Override
    public Class getSource() {
        return LogisticRegressionModel.class;
    }

    @Override
    public Class getTarget() {
        return LogisticRegressionModelInfo.class;
    }

}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.CommonAddressFeaturesModelInfo;
import com.flipkart.transformer.ml.CommonAddressFeatures;

import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Adapts a {@link CommonAddressFeatures} transformer to a
 * {@link CommonAddressFeaturesModelInfo} exportable through
 * {@link com.flipkart.fdp.ml.export.ModelExporter}.
 */
public class CommonAddressFeaturesModelInfoAdapter extends AbstractModelInfoAdapter {
    @Override
    CommonAddressFeaturesModelInfo getModelInfo(CommonAddressFeatures from) {
        CommonAddressFeaturesModelInfo info = new CommonAddressFeaturesModelInfo();

        // vocabulary of favourable / unfavourable first words of an address
        info.setFavourableStarts(new HashSet<>(Arrays.asList(from.favourableStartWords())));
        info.setUnFavourableStarts(new HashSet<>(Arrays.asList(from.unfavourableStartWords())));

        // the two address input columns
        info.setSanitizedAddressParam(from.getInputCol());
        info.setMergedAddressParam(from.getRawInputCol());

        // names of the derived feature columns
        info.setNumWordsParam(from.getNumWordsParam());
        info.setNumCommasParam(from.getNumCommasParams());
        info.setNumericPresentParam(from.getNumericPresentParam());
        info.setAddressLengthParam(from.getAddressLengthParam());
        info.setFavouredStartColParam(from.getFavouredStartColParam());
        info.setUnfavouredStartColParam(from.getUnfavouredStartColParam());

        // LinkedHashSet: downstream consumers rely on insertion order
        Set inputKeys = new LinkedHashSet<>();
        inputKeys.add(from.getInputCol());
        inputKeys.add(from.getRawInputCol());
        info.setInputKeys(inputKeys);

        Set outputKeys = new LinkedHashSet<>();
        outputKeys.add(from.getNumWordsParam());
        outputKeys.add(from.getNumCommasParams());
        outputKeys.add(from.getNumericPresentParam());
        outputKeys.add(from.getAddressLengthParam());
        outputKeys.add(from.getFavouredStartColParam());
        outputKeys.add(from.getUnfavouredStartColParam());
        info.setOutputKeys(outputKeys);

        return info;
    }

    @Override
    public Class getSource() {
        return CommonAddressFeatures.class;
    }

    @Override
    public Class getTarget() {
        return CommonAddressFeaturesModelInfo.class;
    }
}
DecisionTreeModelInfo getModelInfo(final DecisionTreeClassificationModel decisionTreeModel) { 24 | final DecisionTreeModelInfo treeInfo = new DecisionTreeModelInfo(); 25 | 26 | Node rootNode = decisionTreeModel.rootNode(); 27 | treeInfo.setRoot(DecisionNodeAdapterUtils.adaptNode(rootNode)); 28 | 29 | final Set inputKeys = new LinkedHashSet(); 30 | inputKeys.add(decisionTreeModel.getFeaturesCol()); 31 | inputKeys.add(decisionTreeModel.getLabelCol()); 32 | treeInfo.setInputKeys(inputKeys); 33 | 34 | final Set outputKeys = new LinkedHashSet(); 35 | outputKeys.add(decisionTreeModel.getPredictionCol()); 36 | outputKeys.add(decisionTreeModel.getProbabilityCol()); 37 | outputKeys.add(decisionTreeModel.getRawPredictionCol()); 38 | treeInfo.setProbabilityKey(decisionTreeModel.getProbabilityCol()); 39 | treeInfo.setRawPredictionKey(decisionTreeModel.getRawPredictionCol()); 40 | treeInfo.setOutputKeys(outputKeys); 41 | 42 | return treeInfo; 43 | } 44 | 45 | @Override 46 | public Class getSource() { 47 | return DecisionTreeClassificationModel.class; 48 | } 49 | 50 | @Override 51 | public Class getTarget() { 52 | return DecisionTreeModelInfo.class; 53 | } 54 | } 55 | 56 | -------------------------------------------------------------------------------- /models-info/src/main/java/com/flipkart/fdp/ml/transformer/MinMaxScalerTransformer.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.transformer; 2 | 3 | import com.flipkart.fdp.ml.modelinfo.MinMaxScalerModelInfo; 4 | 5 | import java.util.Map; 6 | import java.util.Set; 7 | 8 | /** 9 | * Transforms input/ predicts for a MinMaxScaler model representation 10 | * captured by {@link com.flipkart.fdp.ml.modelinfo.MinMaxScalerModelInfo}. 
package com.flipkart.fdp.ml.transformer;

import com.flipkart.fdp.ml.modelinfo.MinMaxScalerModelInfo;

import java.util.Map;
import java.util.Set;

/**
 * Transforms input/ predicts for a MinMaxScaler model representation
 * captured by {@link com.flipkart.fdp.ml.modelinfo.MinMaxScalerModelInfo}.
 */
public class MinMaxScalerTransformer implements Transformer {
    private final MinMaxScalerModelInfo modelInfo;

    public MinMaxScalerTransformer(final MinMaxScalerModelInfo modelInfo) {
        this.modelInfo = modelInfo;
    }

    /**
     * Rescales each component of {@code input} from its original [min, max] range
     * into the model's configured [min, max] range. A constant feature (zero
     * original range) is mapped to the mid-point, matching Spark's behaviour.
     * <p>
     * NOTE: the input array is scaled <b>in place</b> and the same array is returned.
     * Made public for consistency with {@code StandardScalerTransformer#predict}.
     *
     * @param input feature vector; must have the same length as the model's min/max vectors
     * @return the same array, rescaled
     * @throws IllegalArgumentException if the vector sizes disagree
     */
    public double[] predict(final double[] input) {
        // hoist the getter calls out of the loop; also used for the size validation
        final double[] originalMax = modelInfo.getOriginalMax();
        final double[] originalMin = modelInfo.getOriginalMin();
        if (originalMax.length != originalMin.length || originalMax.length != input.length) {
            throw new IllegalArgumentException("Size of max, min and input vector are different : "
                    + originalMax.length + " , " + originalMin.length + " , " + input.length);
        }

        final double targetMin = modelInfo.getMin();
        final double scale = modelInfo.getMax() - targetMin;
        for (int i = 0; i < input.length; i++) {
            final double originalRange = originalMax[i] - originalMin[i];
            if (originalRange != 0.0) {
                // normalize to [0, 1] within the feature's original range
                input[i] = (input[i] - originalMin[i]) / originalRange;
            } else {
                // constant feature: Spark maps it to the middle of the target range
                input[i] = 0.5;
            }
            input[i] = input[i] * scale + targetMin;
        }
        return input;
    }

    @Override
    public void transform(Map input) {
        // single input key -> single output key
        double inp[] = (double[]) input.get(modelInfo.getInputKeys().iterator().next());
        input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp));
    }

    @Override
    public Set getInputKeys() {
        return modelInfo.getInputKeys();
    }

    @Override
    public Set getOutputKeys() {
        return modelInfo.getOutputKeys();
    }
}
package com.flipkart.fdp.ml.transformer;

import com.flipkart.fdp.ml.modelinfo.StandardScalerModelInfo;

import java.util.Map;
import java.util.Set;

/**
 * Transforms input/ predicts for a Standard Scalar model representation
 * captured by {@link com.flipkart.fdp.ml.modelinfo.StandardScalerModelInfo}.
 */
public class StandardScalerTransformer implements Transformer {
    private final StandardScalerModelInfo modelInfo;

    public StandardScalerTransformer(final StandardScalerModelInfo modelInfo) {
        this.modelInfo = modelInfo;
    }

    /**
     * Standardizes the vector in place: optionally centers by the mean, then
     * optionally divides by the standard deviation (zero std maps the value to 0.0).
     * Returns the same (mutated) array.
     */
    public double[] predict(final double[] input) {
        if (modelInfo.isWithMean()) {
            final double[] mean = modelInfo.getMean();
            if (input.length != mean.length) {
                throw new IllegalArgumentException("Size of input vector and mean are different : "
                        + input.length + " and " + mean.length);
            }
            for (int idx = 0; idx < input.length; idx++) {
                input[idx] -= mean[idx];
            }
        }

        if (modelInfo.isWithStd()) {
            final double[] std = modelInfo.getStd();
            if (input.length != std.length) {
                throw new IllegalArgumentException("Size of std and input vector are different : "
                        + input.length + " and " + std.length);
            }
            for (int idx = 0; idx < input.length; idx++) {
                // a feature with zero variance carries no information; map it to 0.0
                input[idx] = (std[idx] != 0.0) ? input[idx] / std[idx] : 0.0;
            }
        }
        return input;
    }

    @Override
    public void transform(Map input) {
        final Object key = modelInfo.getInputKeys().iterator().next();
        final double[] vector = (double[]) input.get(key);
        input.put(modelInfo.getOutputKeys().iterator().next(), predict(vector));
    }

    @Override
    public Set getInputKeys() {
        return modelInfo.getInputKeys();
    }

    @Override
    public Set getOutputKeys() {
        return modelInfo.getOutputKeys();
    }

}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo;
import com.flipkart.fdp.ml.utils.DecisionNodeAdapterUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.spark.ml.classification.DecisionTreeClassificationModel;
import org.apache.spark.ml.tree.Node;
import org.apache.spark.sql.DataFrame;

import java.util.LinkedHashSet;
import java.util.Set;


/**
 * Adapts Spark's {@link org.apache.spark.ml.classification.DecisionTreeClassificationModel}
 * to a {@link com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo} that can be exported
 * through {@link com.flipkart.fdp.ml.export.ModelExporter}.
 */
@Slf4j
public class DecisionTreeClassificationModelInfoAdapter
        extends AbstractModelInfoAdapter {

    public DecisionTreeModelInfo getModelInfo(final DecisionTreeClassificationModel sparkTreeModel, final DataFrame df) {
        final DecisionTreeModelInfo treeInfo = new DecisionTreeModelInfo();

        // convert the fitted spark tree recursively, starting from the root
        Node rootNode = sparkTreeModel.rootNode();
        treeInfo.setRoot(DecisionNodeAdapterUtils.adaptNode(rootNode));

        // insertion-ordered sets keep the exported key order stable
        final Set inputKeys = new LinkedHashSet();
        inputKeys.add(sparkTreeModel.getFeaturesCol());
        inputKeys.add(sparkTreeModel.getLabelCol());
        treeInfo.setInputKeys(inputKeys);

        final Set outputKeys = new LinkedHashSet();
        outputKeys.add(sparkTreeModel.getPredictionCol());
        outputKeys.add(sparkTreeModel.getProbabilityCol());
        outputKeys.add(sparkTreeModel.getRawPredictionCol());
        treeInfo.setOutputKeys(outputKeys);

        treeInfo.setProbabilityKey(sparkTreeModel.getProbabilityCol());
        treeInfo.setRawPredictionKey(sparkTreeModel.getRawPredictionCol());

        return treeInfo;
    }

    @Override
    public Class getSource() {
        return DecisionTreeClassificationModel.class;
    }

    @Override
    public Class getTarget() {
        return DecisionTreeModelInfo.class;
    }
}
package com.flipkart.fdp.ml.adapter;

import static org.junit.Assert.assertEquals;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Test;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.DecisionTreeTransformer;

/**
 * Round-trip test for {@link DecisionTreeRegressionModel}: trains a regressor in Spark,
 * exports/re-imports it and checks the standalone {@link DecisionTreeTransformer}
 * reproduces Spark's prediction on the held-out rows.
 *
 * @author harshit.pandey
 */
public class DecisionTreeRegressionModelBridgeTest extends SparkTestBase {


    @Test
    public void testDecisionTreeRegressionPrediction() {
        // Load the data stored in LIBSVM format as a DataFrame.
        String datapath = "src/test/resources/regression_test.libsvm";
        Dataset<Row> data = spark.read().format("libsvm").load(datapath);

        // Hold out 30% of the rows for verification.
        Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
        Dataset<Row> trainingData = splits[0];
        Dataset<Row> testData = splits[1];

        // Fit a decision-tree regressor on the training split.
        DecisionTreeRegressionModel regressionModel = new DecisionTreeRegressor().fit(trainingData);
        trainingData.printSchema();

        List<Row> output = regressionModel.transform(testData).select("features", "prediction").collectAsList();

        // Export the Spark model and re-import it as a standalone transformer.
        byte[] exportedModel = ModelExporter.export(regressionModel);
        DecisionTreeTransformer transformer =
                (DecisionTreeTransformer) ModelImporter.importAndGetTransformer(exportedModel);
        System.out.println(transformer);

        // The transformer must agree with Spark prediction-for-prediction.
        for (Row row : output) {
            Map<String, Object> data_ = new HashMap<>();
            data_.put("features", ((SparseVector) row.get(0)).toArray());
            transformer.transform(data_);
            System.out.println(data_);
            System.out.println(data_.get("prediction"));
            assertEquals((double)data_.get("prediction"), (double)row.get(1), EPSILON);
        }
    }

}
package com.flipkart.fdp.ml.transformer;

import com.flipkart.fdp.ml.modelinfo.RegexTokenizerModelInfo;

import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Transforms input/ predicts for a Regex Tokenizer model representation
 * captured by {@link com.flipkart.fdp.ml.modelinfo.RegexTokenizerModelInfo}.
 */
public class RegexTokenizerTransformer implements Transformer {
    private final RegexTokenizerModelInfo modelInfo;
    // Pattern is immutable and thread-safe; compile it once here instead of on every
    // predict() call (the old code also re-compiled implicitly via String.split).
    // An invalid pattern now fails fast at construction time rather than at first use.
    private final Pattern regex;

    public RegexTokenizerTransformer(final RegexTokenizerModelInfo modelInfo) {
        this.modelInfo = modelInfo;
        this.regex = Pattern.compile(modelInfo.getPattern());
    }

    /**
     * Tokenizes {@code input} with the model's regex. In "gaps" mode the pattern is
     * treated as a separator (split); otherwise every match of the pattern becomes a
     * token. Tokens shorter than minTokenLength are dropped, mirroring Spark's
     * RegexTokenizer.
     *
     * @param input raw text to tokenize
     * @return the surviving tokens, in order
     */
    public String[] predict(final String input) {
        final String targetStr = modelInfo.isToLowercase() ? input.toLowerCase() : input;
        final List<String> tokens;
        if (modelInfo.isGaps()) {
            // LinkedList: efficient element removal during the length filter below
            tokens = new LinkedList<>(Arrays.asList(regex.split(targetStr)));
        } else {
            tokens = new LinkedList<>();
            final Matcher m = regex.matcher(targetStr);
            while (m.find()) {
                tokens.add(m.group());
            }
        }
        // drop tokens below the configured minimum length
        tokens.removeIf(t -> t.length() < modelInfo.getMinTokenLength());
        return tokens.toArray(new String[0]);
    }

    @Override
    public void transform(Map input) {
        String inp = (String) input.get(modelInfo.getInputKeys().iterator().next());
        input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp));
    }

    @Override
    public Set getInputKeys() {
        return modelInfo.getInputKeys();
    }

    @Override
    public Set getOutputKeys() {
        return modelInfo.getOutputKeys();
    }

}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.Transformer;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.Test;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.apache.spark.sql.types.DataTypes.*;
import static org.junit.Assert.assertEquals;

/**
 * Bridge test for {@link StringIndexerModel}: fits the indexer in Spark, exports and
 * re-imports it, and verifies the standalone {@link Transformer} assigns the same
 * index to every label.
 *
 * Created by akshay.us on 3/2/16.
 */
public class StringIndexerBridgeTest extends SparkTestBase {

    @Test
    public void testStringIndexer() {

        // prepare data
        StructType schema = createStructType(new StructField[]{
                createStructField("id", IntegerType, false),
                createStructField("label", StringType, false)
        });
        List<Row> trainingData = Arrays.asList(
                cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
        Dataset<Row> dataset = spark.createDataFrame(trainingData, schema);

        // fit the indexer in spark
        StringIndexerModel model = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("labelIndex").fit(dataset);

        // export, then import back as a standalone Transformer
        Transformer transformer =
                ModelImporter.importAndGetTransformer(ModelExporter.export(model));

        // the transformer must assign the same index spark did, row by row
        List<Row> sparkOutput = model.transform(dataset).orderBy("id")
                .select("id", "label", "labelIndex").collectAsList();
        for (Row row : sparkOutput) {

            Map<String, Object> data = new HashMap<>();
            data.put(model.getInputCol(), (String) row.get(1));
            transformer.transform(data);
            double output = (double) data.get(model.getOutputCol());

            assertEquals(output, (double) row.get(2), 0.01);
        }

    }
}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.Transformer;
import org.apache.commons.lang.ArrayUtils;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.Test;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.apache.spark.sql.types.DataTypes.*;

/**
 * Bridge test for {@link RegexTokenizer} on Spark 1.6 DataFrames: exports the
 * configured tokenizer, re-imports it, and prints the tokenized output of the
 * standalone transformer next to Spark's.
 *
 * NOTE(review): this test only prints both outputs and never asserts their
 * equality — consider adding a real assertion.
 *
 * Created by akshay.us on 3/14/16.
 */
public class RegexTokenizerBridgeTest extends SparkTestBase {

    @Test
    public void testRegexTokenizer() {

        // prepare data
        StructType schema = createStructType(new StructField[]{
                createStructField("rawText", StringType, false),
        });
        List<Row> trainingData = Arrays.asList(
                cr("Test of tok."),
                cr("Te,st. punct")
        );
        DataFrame dataset = sqlContext.createDataFrame(trainingData, schema);

        // configure the tokenizer in spark
        RegexTokenizer sparkModel = new RegexTokenizer()
                .setInputCol("rawText")
                .setOutputCol("tokens")
                .setPattern("\\s")
                .setGaps(true)
                .setToLowercase(false)
                .setMinTokenLength(3);

        // export, then import back as a standalone Transformer
        byte[] exportedModel = ModelExporter.export(sparkModel, dataset);
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        for (Row row : sparkModel.transform(dataset).select("rawText", "tokens").collect()) {

            Map<String, Object> data = new HashMap<>();
            data.put(sparkModel.getInputCol(), row.getString(0));
            transformer.transform(data);
            String[] output = (String[]) data.get(sparkModel.getOutputCol());

            Object sparkOp = row.get(1);
            System.out.println(ArrayUtils.toString(output));
            System.out.println(row.get(1));
        }
    }

}
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.Transformer;
import org.apache.commons.lang.ArrayUtils;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.Test;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.apache.spark.sql.types.DataTypes.*;

/**
 * Bridge test for {@link RegexTokenizer} on Spark 2.0 Datasets: exports the
 * configured tokenizer, re-imports it, and prints the tokenized output of the
 * standalone transformer next to Spark's.
 *
 * NOTE(review): this test only prints both outputs and never asserts their
 * equality — consider adding a real assertion.
 *
 * Created by akshay.us on 3/14/16.
 */
public class RegexTokenizerBridgeTest extends SparkTestBase {

    @Test
    public void testRegexTokenizer() {

        // prepare data
        StructType schema = createStructType(new StructField[]{
                createStructField("rawText", StringType, false),
        });
        List<Row> trainingData = Arrays.asList(
                cr("Test of tok."),
                cr("Te,st. punct")
        );
        Dataset<Row> dataset = spark.createDataFrame(trainingData, schema);

        // configure the tokenizer in spark
        RegexTokenizer sparkModel = new RegexTokenizer()
                .setInputCol("rawText")
                .setOutputCol("tokens")
                .setPattern("\\s")
                .setGaps(true)
                .setToLowercase(false)
                .setMinTokenLength(3);

        // export, then import back as a standalone Transformer
        byte[] exportedModel = ModelExporter.export(sparkModel);
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        List<Row> pairs = sparkModel.transform(dataset).select("rawText", "tokens").collectAsList();
        for (Row row : pairs) {

            Map<String, Object> data = new HashMap<>();
            data.put(sparkModel.getInputCol(), row.getString(0));
            transformer.transform(data);
            String[] output = (String[]) data.get(sparkModel.getOutputCol());

            Object sparkOp = row.get(1);
            System.out.println(ArrayUtils.toString(output));
            System.out.println(row.get(1));
        }
    }

}
package com.flipkart.fdp.ml.export;

import com.flipkart.fdp.ml.ModelInfoAdapterFactory;
import com.flipkart.fdp.ml.importer.SerializationConstants;
import com.flipkart.fdp.ml.modelinfo.ModelInfo;
import com.flipkart.fdp.ml.modelinfo.PipelineModelInfo;
import com.flipkart.fdp.ml.utils.Constants;
import com.google.gson.Gson;

import java.util.HashMap;
import java.util.Map;

/**
 * Exports a {@link ModelInfo} object into byte[].
 * The serialization format currently being used is json
 */
public class ModelExporter {
    private static final Gson gson = new Gson();

    /**
     * Exports a Model object into byte[].
     * The serialization format currently being used is json
     *
     * @param model model info to be exported
     * @return byte[]
     */
    public static byte[] export(Object model) {
        // look up the adapter for the concrete spark model type, convert to ModelInfo,
        // then serialize to json and encode with the library's canonical charset
        return export(
                ModelInfoAdapterFactory.getAdapter(model.getClass())
                        .adapt(model)).getBytes(SerializationConstants.CHARSET);
    }

    /**
     * Exports a {@link ModelInfo} object into its json envelope.
     *
     * @param modelInfo model info to be exported of type {@link ModelInfo}
     * @return json string containing version metadata, the concrete type name and the payload
     */
    private static String export(ModelInfo modelInfo) {
        final Map<String, String> payload = new HashMap<String, String>();
        payload.put(SerializationConstants.SPARK_VERSION, Constants.SUPPORTED_SPARK_VERSION_PREFIX);
        payload.put(SerializationConstants.EXPORTER_LIBRARY_VERSION, Constants.LIBRARY_VERSION);
        payload.put(SerializationConstants.TYPE_IDENTIFIER, modelInfo.getClass().getCanonicalName());

        if (modelInfo instanceof PipelineModelInfo) {
            // plain gson serialization would lose each stage's concrete type, so every
            // stage is exported recursively and the resulting json strings are nested
            final PipelineModelInfo pipelineModelInfo = (PipelineModelInfo) modelInfo;
            final String[] stageJsons = new String[pipelineModelInfo.getStages().length];
            for (int i = 0; i < stageJsons.length; i++) {
                stageJsons[i] = export(pipelineModelInfo.getStages()[i]);
            }
            payload.put(SerializationConstants.MODEL_INFO_IDENTIFIER, gson.toJson(stageJsons));
        } else {
            payload.put(SerializationConstants.MODEL_INFO_IDENTIFIER, gson.toJson(modelInfo));
        }

        return gson.toJson(payload);
    }
}
return gson.toJson(map); 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /adapters-2.0/src/test/java/com/flipkart/fdp/ml/adapter/StringSanitizerBridgeTest.java: -------------------------------------------------------------------------------- 1 | package com.flipkart.fdp.ml.adapter; 2 | 3 | import com.flipkart.fdp.ml.export.ModelExporter; 4 | import com.flipkart.fdp.ml.importer.ModelImporter; 5 | import com.flipkart.fdp.ml.transformer.Transformer; 6 | import org.apache.spark.api.java.JavaRDD; 7 | import com.flipkart.transformer.ml.StringSanitizer; 8 | import org.apache.spark.sql.Dataset; 9 | import org.apache.spark.sql.Row; 10 | import org.apache.spark.sql.RowFactory; 11 | import org.apache.spark.sql.types.DataTypes; 12 | import org.apache.spark.sql.types.Metadata; 13 | import org.apache.spark.sql.types.StructField; 14 | import org.apache.spark.sql.types.StructType; 15 | import org.junit.Test; 16 | 17 | import java.util.Arrays; 18 | import java.util.HashMap; 19 | import java.util.List; 20 | import java.util.Map; 21 | 22 | import static org.junit.Assert.assertTrue; 23 | 24 | public class StringSanitizerBridgeTest extends SparkTestBase { 25 | @Test 26 | public void testStringSanitizer() { 27 | 28 | //prepare data 29 | JavaRDD rdd = jsc.parallelize(Arrays.asList( 30 | RowFactory.create(1, "Jyoti complex near Sananda clothes store; English Bazar; Malda;WB;India,"), 31 | RowFactory.create(2, "hallalli vinayaka tent road c/o B K vishwanath Mandya"), 32 | RowFactory.create(3, "M.sathish S/o devudu Lakshmi opticals Gokavaram bus stand Rajhamundry 9494954476") 33 | )); 34 | 35 | StructType schema = new StructType(new StructField[]{ 36 | new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), 37 | new StructField("rawText", DataTypes.StringType, false, Metadata.empty()) 38 | }); 39 | Dataset dataset = spark.createDataFrame(rdd, schema); 40 | dataset.show(); 41 | 42 | //train model in spark 43 | StringSanitizer 
package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.modelinfo.OneHotEncoderModelInfo;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeType;
import org.apache.spark.ml.attribute.BinaryAttribute;
import org.apache.spark.ml.attribute.NominalAttribute;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.sql.DataFrame;

import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Transforms Spark's {@link OneHotEncoder} in MlLib to
 * {@link com.flipkart.fdp.ml.modelinfo.OneHotEncoderModelInfo} object
 * that can be exported through {@link com.flipkart.fdp.ml.export.ModelExporter}.
 *
 * Exporting Spark's OHE is ugly; prefer {@link com.flipkart.fdp.ml.CustomOneHotEncoder}.
 */
public class OneHotEncoderModelInfoAdapter extends AbstractModelInfoAdapter<OneHotEncoder, OneHotEncoderModelInfo> {

    @Override
    public OneHotEncoderModelInfo getModelInfo(final OneHotEncoder from, DataFrame df) {
        OneHotEncoderModelInfo modelInfo = new OneHotEncoderModelInfo();
        String inputColumn = from.getInputCol();

        // Ugly but the only way to deal with Spark here: the category count has to be
        // recovered from the ML attribute metadata attached to the input column.
        // Compute the attribute once instead of re-deriving it for every branch.
        int numTypes = -1;
        Attribute attribute = Attribute.fromStructField(df.schema().apply(inputColumn));
        if (attribute.attrType() == AttributeType.Nominal()) {
            numTypes = ((NominalAttribute) attribute).values().get().length;
        } else if (attribute.attrType() == AttributeType.Binary()) {
            numTypes = ((BinaryAttribute) attribute).values().get().length;
        }

        // TODO: Since dropLast is not accessible here, we deliberately set numTypes assuming
        // Spark's default dropLast=true (one category dropped). This is the reason
        // CustomOneHotEncoder should be used instead.
        modelInfo.setNumTypes(numTypes - 1);

        Set<String> inputKeys = new LinkedHashSet<>();
        inputKeys.add(from.getInputCol());
        modelInfo.setInputKeys(inputKeys);

        Set<String> outputKeys = new LinkedHashSet<>();
        outputKeys.add(from.getOutputCol());
        modelInfo.setOutputKeys(outputKeys);

        return modelInfo;
    }

    @Override
    public Class getSource() {
        return OneHotEncoder.class;
    }

    @Override
    public Class getTarget() {
        return OneHotEncoderModelInfo.class;
    }
}
package com.flipkart.fdp.ml.adapter;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.DecisionTreeModel;

import com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo;
import com.flipkart.fdp.ml.modelinfo.GradientBoostModelInfo;

import lombok.extern.slf4j.Slf4j;

/**
 * Adapts {@link GBTClassificationModel} to {@link GradientBoostModelInfo}.
 * @author harshit.pandey
 *
 */
@Slf4j
public class GradientBoostClassificationModelInfoAdapter extends AbstractModelInfoAdapter<GBTClassificationModel, GradientBoostModelInfo> {

    private static final DecisionTreeRegressionModelInfoAdapter DECISION_TREE_ADAPTER =
            new DecisionTreeRegressionModelInfoAdapter();

    @Override
    GradientBoostModelInfo getModelInfo(final GBTClassificationModel sparkGbModel) {
        final GradientBoostModelInfo info = new GradientBoostModelInfo();

        info.setNumFeatures(sparkGbModel.numFeatures());
        info.setRegression(false); // classification flavour of the ensemble

        // Copy per-tree boosting weights, preserving order.
        final double[] rawWeights = sparkGbModel.treeWeights();
        final List<Double> weights = new ArrayList<>(rawWeights.length);
        for (int i = 0; i < rawWeights.length; i++) {
            weights.add(rawWeights[i]);
        }
        info.setTreeWeights(weights);

        // Each member tree is adapted through the regression-tree adapter.
        final DecisionTreeModel[] memberTrees = sparkGbModel.trees();
        final List<DecisionTreeModelInfo> adaptedTrees = new ArrayList<>(memberTrees.length);
        for (final DecisionTreeModel tree : memberTrees) {
            adaptedTrees.add(DECISION_TREE_ADAPTER.getModelInfo((DecisionTreeRegressionModel) tree));
        }
        info.setTrees(adaptedTrees);

        // Ordered key sets: features first, then label.
        final Set<String> inputs = new LinkedHashSet<>();
        inputs.add(sparkGbModel.getFeaturesCol());
        inputs.add(sparkGbModel.getLabelCol());
        info.setInputKeys(inputs);

        final Set<String> outputs = new LinkedHashSet<>();
        outputs.add(sparkGbModel.getPredictionCol());
        info.setOutputKeys(outputs);

        return info;
    }

    @Override
    public Class getSource() {
        return GBTClassificationModel.class;
    }

    @Override
    public Class getTarget() {
        return GradientBoostModelInfo.class;
    }

}
import org.apache.spark.sql.Row;
import org.junit.Test;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.DecisionTreeTransformer;

/**
 * Round-trip bridge test: trains a Spark decision-tree classifier, exports it through
 * {@link ModelExporter}, re-imports it through {@link ModelImporter}, and checks that the
 * standalone transformer reproduces Spark's prediction and rawPrediction for each test row.
 *
 * @author harshit.pandey
 *
 */
public class DecisionTreeClassificationModelBridgeTest extends SparkTestBase {


    @Test
    public void testDecisionTreeClassificationPrediction() {
        // Load the data stored in LIBSVM format as a DataFrame.
        String datapath = "src/test/resources/classification_test.libsvm";
        Dataset data = spark.read().format("libsvm").load(datapath);


        // Split the data into training and test sets (30% held out for testing)
        Dataset[] splits = data.randomSplit(new double[]{0.7, 0.3});
        Dataset trainingData = splits[0];
        Dataset testData = splits[1];

        // Train a DecisionTree model.
        DecisionTreeClassificationModel classifierModel = new DecisionTreeClassifier().fit(trainingData);
        trainingData.printSchema();

        // Collect Spark's own predictions on the held-out rows, then export the trained model.
        List output = classifierModel.transform(testData).select("features", "prediction","rawPrediction").collectAsList();
        byte[] exportedModel = ModelExporter.export(classifierModel);

        // Re-import the serialized model and obtain the library's standalone transformer.
        DecisionTreeTransformer transformer = (DecisionTreeTransformer) ModelImporter.importAndGetTransformer(exportedModel);

        //compare predictions
        for (Row row : output) {
            // The transformer works on a mutable map keyed by column name; results are
            // written back into the same map under "prediction" / "rawPrediction".
            Map data_ = new HashMap<>();
            double [] actualRawPrediction = ((DenseVector) row.get(2)).toArray();
            data_.put("features", ((SparseVector) row.get(0)).toArray());
            transformer.transform(data_);
            System.out.println(data_);
            System.out.println(data_.get("prediction"));
            // EPSILON comes from SparkTestBase (float comparison tolerance).
            assertEquals((double)data_.get("prediction"), (double)row.get(1), EPSILON);
            assertArrayEquals((double[]) data_.get("rawPrediction"), actualRawPrediction, EPSILON);
        }
    }
}
package com.flipkart.fdp.ml.transformer;

import com.flipkart.fdp.ml.modelinfo.VectorAssemblerModelInfo;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Transforms input/ predicts for a Vector assembler model representation
 * captured by {@link com.flipkart.fdp.ml.modelinfo.VectorAssemblerModelInfo}.
 * Scalar doubles and double[] inputs are concatenated, in input-key order,
 * into a single flat double[] vector.
 *
 * Created by rohan.shetty on 28/03/16.
 */
public class VectorAssemblerTransformer implements Transformer {
    private final VectorAssemblerModelInfo modelInfo;

    public VectorAssemblerTransformer(final VectorAssemblerModelInfo modelInfo) {
        this.modelInfo = modelInfo;
    }

    /**
     * Concatenates the given values into one flat vector.
     *
     * @param inputs values in input-key order; each must be a Double or a double[]
     * @return the assembled vector
     * @throws RuntimeException if a value is null or of an unsupported type
     */
    private double[] predict(Object[] inputs) {
        final List<Double> output = new ArrayList<>();
        for (Object input : inputs) {
            if (input == null) {
                throw new RuntimeException("Values to assemble cannot be null");
            } else if (input instanceof Double) {
                // NOTE: the previous check `double.class.equals(o.getClass())` was dead code —
                // an Object's runtime class is never the primitive class; boxed values are Double.
                output.add((Double) input);
            } else if (input instanceof double[]) {
                for (double val : (double[]) input) {
                    output.add(val);
                }
            } else {
                throw new RuntimeException("Values to assemble cannot be of type: " + input.getClass().getCanonicalName());
            }
        }
        // Unbox into the primitive result array.
        final double[] primitiveOutput = new double[output.size()];
        for (int i = 0; i < primitiveOutput.length; i++) {
            primitiveOutput[i] = output.get(i);
        }
        return primitiveOutput;
    }

    /**
     * Gathers the values for all input keys and writes the assembled vector
     * under the single output key.
     */
    @Override
    public void transform(Map input) {
        Object[] inputs = new Object[modelInfo.getInputKeys().size()];
        int i = 0;
        for (String inputKey : modelInfo.getInputKeys()) {
            inputs[i++] = input.get(inputKey);
        }
        input.put(modelInfo.getOutputKeys().iterator().next(), predict(inputs));
    }

    @Override
    public Set getInputKeys() {
        return modelInfo.getInputKeys();
    }

    @Override
    public Set getOutputKeys() {
        return modelInfo.getOutputKeys();
    }

}
package com.flipkart.fdp.ml.adapter;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.DecisionTreeModel;
import org.apache.spark.sql.DataFrame;

import com.flipkart.fdp.ml.modelinfo.DecisionTreeModelInfo;
import com.flipkart.fdp.ml.modelinfo.GradientBoostModelInfo;

import lombok.extern.slf4j.Slf4j;

/**
 * Adapts {@link GBTClassificationModel} to {@link GradientBoostModelInfo}.
 * @author harshit.pandey
 *
 */
@Slf4j
public class GradientBoostClassificationModelInfoAdapter extends AbstractModelInfoAdapter<GBTClassificationModel, GradientBoostModelInfo> {

    private static final DecisionTreeRegressionModelInfoAdapter DECISION_TREE_ADAPTER =
            new DecisionTreeRegressionModelInfoAdapter();

    @Override
    GradientBoostModelInfo getModelInfo(final GBTClassificationModel sparkGbModel, final DataFrame df) {
        final GradientBoostModelInfo boostInfo = new GradientBoostModelInfo();
        boostInfo.setNumFeatures(sparkGbModel.numFeatures());
        boostInfo.setRegression(false); // classification flavour of the ensemble

        // Per-tree boosting weights, copied in order.
        final List<Double> weightList = new ArrayList<>();
        for (final double weight : sparkGbModel.treeWeights()) {
            weightList.add(weight);
        }
        boostInfo.setTreeWeights(weightList);

        // Adapt every member tree through the regression-tree adapter,
        // passing the DataFrame along for metadata access.
        final DecisionTreeModel[] memberTrees = sparkGbModel.trees();
        final List<DecisionTreeModelInfo> adaptedTrees = new ArrayList<>(memberTrees.length);
        for (int t = 0; t < memberTrees.length; t++) {
            adaptedTrees.add(DECISION_TREE_ADAPTER.getModelInfo((DecisionTreeRegressionModel) memberTrees[t], df));
        }
        boostInfo.setTrees(adaptedTrees);

        // Ordered key sets: features first, then label.
        final Set<String> inputColumns = new LinkedHashSet<>();
        inputColumns.add(sparkGbModel.getFeaturesCol());
        inputColumns.add(sparkGbModel.getLabelCol());
        boostInfo.setInputKeys(inputColumns);

        final Set<String> outputColumns = new LinkedHashSet<>();
        outputColumns.add(sparkGbModel.getPredictionCol());
        boostInfo.setOutputKeys(outputColumns);

        return boostInfo;
    }

    @Override
    public Class getSource() {
        return GBTClassificationModel.class;
    }

    @Override
    public Class getTarget() {
        return GradientBoostModelInfo.class;
    }

}
13 | */ 14 | public class CountVectorizerTransformer implements Transformer { 15 | private final CountVectorizerModelInfo modelInfo; 16 | private final Map vocabulary; 17 | 18 | public CountVectorizerTransformer(final CountVectorizerModelInfo modelInfo) { 19 | this.modelInfo = modelInfo; 20 | vocabulary = new HashMap(); 21 | for (int i = 0; i < modelInfo.getVocabulary().length; i++) { 22 | vocabulary.put(modelInfo.getVocabulary()[i], i); 23 | } 24 | } 25 | 26 | double[] predict(final String[] input) { 27 | final Map termFrequencies = new HashMap(); 28 | final int tokenCount = input.length; 29 | for (String term : input) { 30 | if (vocabulary.containsKey(term)) { 31 | if (termFrequencies.containsKey(term)) { 32 | termFrequencies.put(term, termFrequencies.get(term) + 1); 33 | } else { 34 | termFrequencies.put(term, 1); 35 | } 36 | } else { 37 | //ignore terms not in vocabulary 38 | } 39 | } 40 | final int effectiveMinTF = (int) ((modelInfo.getMinTF() >= 1.0) ? modelInfo.getMinTF() : modelInfo.getMinTF() * tokenCount); 41 | 42 | final double[] encoding = new double[modelInfo.getVocabulary().length]; 43 | Arrays.fill(encoding, 0.0); 44 | 45 | for (final Map.Entry entry : termFrequencies.entrySet()) { 46 | //filter out terms with freq < effectiveMinTF 47 | if (entry.getValue() >= effectiveMinTF) { 48 | int position = vocabulary.get(entry.getKey()); 49 | encoding[position] = entry.getValue(); 50 | } 51 | } 52 | return encoding; 53 | } 54 | 55 | @Override 56 | public void transform(Map input) { 57 | String[] inp = (String[]) input.get(modelInfo.getInputKeys().iterator().next()); 58 | input.put(modelInfo.getOutputKeys().iterator().next(), predict(inp)); 59 | } 60 | 61 | @Override 62 | public Set getInputKeys() { 63 | return modelInfo.getInputKeys(); 64 | } 65 | 66 | @Override 67 | public Set getOutputKeys() { 68 | return modelInfo.getOutputKeys(); 69 | } 70 | 71 | } 72 | -------------------------------------------------------------------------------- 
package com.flipkart.fdp.ml.importer;

import com.flipkart.fdp.ml.modelinfo.ModelInfo;
import com.flipkart.fdp.ml.modelinfo.PipelineModelInfo;
import com.flipkart.fdp.ml.transformer.Transformer;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;

import java.io.Serializable;
import java.util.Map;

/**
 * Imports byte[] representing a model into corresponding {@link ModelInfo} object.
 * The serialization format currently being used is json.
 */
public class ModelImporter implements Serializable {
    private static final Gson gson = new Gson();


    /**
     * Imports byte[] representing a model into corresponding {@link ModelInfo} object
     * and returns the transformer for this model.
     *
     * @param serializedModelInfo byte[] representing the serialized data
     * @return transformer for the imported model of type {@link Transformer}
     */
    public static Transformer importAndGetTransformer(byte[] serializedModelInfo) {
        return importModelInfo(serializedModelInfo).getTransformer();
    }

    /**
     * Imports byte[] representing a model into corresponding {@link ModelInfo} object.
     * The serialization format currently being used is json.
     *
     * @param serializedModelInfo byte[] representing the serialized data
     * @return model info imported of type {@link ModelInfo}
     */
    public static ModelInfo importModelInfo(byte[] serializedModelInfo) {
        String data = new String(serializedModelInfo, SerializationConstants.CHARSET);
        Map<String, String> map = gson.fromJson(data, new TypeToken<Map<String, String>>() {
        }.getType());
        Class modelClass = null;
        try {
            modelClass = Class.forName(map.get(SerializationConstants.TYPE_IDENTIFIER));
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
        if (modelClass == PipelineModelInfo.class) {
            // Pipeline stages were serialized individually (gson does not encode the
            // concrete stage type), so each stage is imported recursively here.
            String[] serializedModelInfos = gson.fromJson(map.get(SerializationConstants.MODEL_INFO_IDENTIFIER), String[].class);
            ModelInfo[] modelInfos = new ModelInfo[serializedModelInfos.length];
            for (int i = 0; i < modelInfos.length; i++) {
                // BUG FIX: encode with the exporter's charset instead of the platform default,
                // so import round-trips correctly regardless of the JVM's default encoding.
                modelInfos[i] = importModelInfo(serializedModelInfos[i].getBytes(SerializationConstants.CHARSET));
            }
            PipelineModelInfo pipelineModelInfo = new PipelineModelInfo();
            pipelineModelInfo.setStages(modelInfos);
            return pipelineModelInfo;
        } else {
            return (ModelInfo) gson.fromJson(map.get(SerializationConstants.MODEL_INFO_IDENTIFIER), modelClass);
        }
    }
}
Log1PScaler(override val uid: String)
  extends Transformer {

  final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name")
  final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")

  def this() {
    this(Identifiable.randomUID("customLogScaler"))
  }

  final def getInputCol: String = $(inputCol)

  def setInputCol(value: String): this.type = set(inputCol, value)

  final def getOutputCol: String = $(outputCol)

  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Applies log1p element-wise to the input vector column and appends the
   * result as the output column.
   */
  override def transform(dataFrame: DataFrame): DataFrame = {
    transformSchema(dataFrame.schema)

    val encode = udf { inputVector: Vector =>
      inputVector match {
        case DenseVector(vs) =>
          val values = vs.clone()
          val size = values.size
          var i = 0
          // while-loop kept deliberately: hot path over vector values
          while (i < size) {
            values(i) = Math.log1p(values(i))
            i += 1
          }
          Vectors.dense(values)
        case SparseVector(size, indices, vs) =>
          // For sparse vector, the `index` array inside sparse vector object will not be changed,
          // so we can re-use it to save memory.
          val values = vs.clone()
          val nnz = values.size
          var i = 0
          while (i < nnz) {
            values(i) = Math.log1p(values(i))
            i += 1
          }
          Vectors.sparse(size, indices, values)
        case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
      }
    }
    dataFrame.withColumn($(outputCol), encode(col($(inputCol))))
  }

  /** Validates the input column type and appends the output vector column to the schema. */
  override def transformSchema(schema: StructType): StructType = {
    val inputType = schema($(inputCol)).dataType
    require(inputType.isInstanceOf[VectorUDT],
      s"Input column ${$(inputCol)} must be a vector column")
    require(!schema.fieldNames.contains($(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    // Idiomatic Scala: the last expression is the result; explicit `return` removed.
    CustomSchemaUtil.appendColumn(schema, $(outputCol), new VectorUDT)
  }

  override def copy(extra: ParamMap): Log1PScaler = {
    val copied = new Log1PScaler(uid)
    copyValues(copied, extra)
  }
}