├── .github └── workflows │ └── maven.yml ├── LICENSE.txt ├── README.md ├── pmml-sparkml-evaluator ├── pom.xml └── src │ └── main │ └── java │ └── org │ └── jpmml │ └── sparkml │ └── evaluator │ └── SparkMLFunctionRegistry.java ├── pmml-sparkml-example ├── pom.xml └── src │ └── main │ └── java │ └── org │ └── jpmml │ └── sparkml │ └── example │ └── Main.java ├── pmml-sparkml-lightgbm ├── pom.xml └── src │ ├── main │ ├── java │ │ └── org │ │ │ └── jpmml │ │ │ └── sparkml │ │ │ └── lightgbm │ │ │ ├── BoosterUtil.java │ │ │ ├── LightGBMClassificationModelConverter.java │ │ │ └── LightGBMRegressionModelConverter.java │ └── resources │ │ └── META-INF │ │ └── sparkml2pmml.properties │ └── test │ ├── java │ └── org │ │ └── jpmml │ │ └── sparkml │ │ └── lightgbm │ │ └── testing │ │ └── LightGBMTest.java │ └── resources │ ├── LightGBMAudit.scala │ ├── LightGBMAuditNA.scala │ ├── LightGBMAuto.scala │ ├── LightGBMAutoNA.scala │ ├── LightGBMIris.scala │ ├── csv │ ├── Audit.csv │ ├── AuditNA.csv │ ├── Auto.csv │ ├── AutoNA.csv │ ├── Iris.csv │ ├── LightGBMAudit.csv │ ├── LightGBMAuditNA.csv │ ├── LightGBMAuto.csv │ ├── LightGBMAutoNA.csv │ └── LightGBMIris.csv │ ├── pipeline │ ├── LightGBMAudit.zip │ ├── LightGBMAuditNA.zip │ ├── LightGBMAuto.zip │ ├── LightGBMAutoNA.zip │ └── LightGBMIris.zip │ └── schema │ ├── Audit.json │ ├── AuditNA.json │ ├── Auto.json │ ├── AutoNA.json │ └── Iris.json ├── pmml-sparkml-xgboost ├── pom.xml └── src │ ├── main │ ├── java │ │ └── org │ │ │ └── jpmml │ │ │ └── sparkml │ │ │ └── xgboost │ │ │ ├── BoosterUtil.java │ │ │ ├── XGBoostClassificationModelConverter.java │ │ │ └── XGBoostRegressionModelConverter.java │ └── resources │ │ └── META-INF │ │ └── sparkml2pmml.properties │ └── test │ ├── java │ └── org │ │ └── jpmml │ │ └── sparkml │ │ └── xgboost │ │ └── testing │ │ └── XGBoostTest.java │ └── resources │ ├── XGBoostAudit.scala │ ├── XGBoostAuditNA.scala │ ├── XGBoostAuto.scala │ ├── XGBoostAutoNA.scala │ ├── XGBoostHousing.scala │ ├── 
XGBoostIris.scala │ ├── csv │ ├── Audit.csv │ ├── AuditNA.csv │ ├── Auto.csv │ ├── AutoNA.csv │ ├── Housing.csv │ ├── Iris.csv │ ├── XGBoostAudit.csv │ ├── XGBoostAuditNA.csv │ ├── XGBoostAuto.csv │ ├── XGBoostAutoNA.csv │ ├── XGBoostHousing.csv │ └── XGBoostIris.csv │ ├── pipeline │ ├── XGBoostAudit.zip │ ├── XGBoostAuditNA.zip │ ├── XGBoostAuto.zip │ ├── XGBoostAutoNA.zip │ ├── XGBoostHousing.zip │ └── XGBoostIris.zip │ └── schema │ ├── Audit.json │ ├── AuditNA.json │ ├── Auto.json │ ├── AutoNA.json │ ├── Housing.json │ └── Iris.json ├── pmml-sparkml ├── pom.xml └── src │ ├── main │ ├── java │ │ └── org │ │ │ └── jpmml │ │ │ └── sparkml │ │ │ ├── AliasExpression.java │ │ │ ├── ArchiveUtil.java │ │ │ ├── AssociationRulesModelConverter.java │ │ │ ├── BinarizedCategoricalFeature.java │ │ │ ├── ClassificationModelConverter.java │ │ │ ├── ClusteringModelConverter.java │ │ │ ├── ConverterFactory.java │ │ │ ├── DatasetUtil.java │ │ │ ├── DocumentFeature.java │ │ │ ├── ExpressionTranslator.java │ │ │ ├── FeatureConverter.java │ │ │ ├── HasSparkMLOptions.java │ │ │ ├── ItemSetFeature.java │ │ │ ├── MatrixUtil.java │ │ │ ├── ModelConverter.java │ │ │ ├── MultiFeatureConverter.java │ │ │ ├── PMMLBuilder.java │ │ │ ├── PipelineModelUtil.java │ │ │ ├── PredictionModelConverter.java │ │ │ ├── ProbabilisticClassificationModelConverter.java │ │ │ ├── RegexKey.java │ │ │ ├── RegressionModelConverter.java │ │ │ ├── SparkMLEncoder.java │ │ │ ├── SparkSessionUtil.java │ │ │ ├── TermFeature.java │ │ │ ├── TermUtil.java │ │ │ ├── TransformerConverter.java │ │ │ ├── VectorUtil.java │ │ │ ├── WeightedTermFeature.java │ │ │ ├── feature │ │ │ ├── BinarizerConverter.java │ │ │ ├── BucketizerConverter.java │ │ │ ├── ChiSqSelectorModelConverter.java │ │ │ ├── ColumnPrunerConverter.java │ │ │ ├── CountVectorizerModelConverter.java │ │ │ ├── IDFModelConverter.java │ │ │ ├── ImputerModelConverter.java │ │ │ ├── IndexToStringConverter.java │ │ │ ├── InteractionConverter.java │ │ │ ├── 
InvalidCategoryTransformerConverter.java │ │ │ ├── MaxAbsScalerModelConverter.java │ │ │ ├── MinMaxScalerModelConverter.java │ │ │ ├── NGramConverter.java │ │ │ ├── OneHotEncoderModelConverter.java │ │ │ ├── PCAModelConverter.java │ │ │ ├── RFormulaModelConverter.java │ │ │ ├── RegexTokenizerConverter.java │ │ │ ├── SQLTransformerConverter.java │ │ │ ├── SparseToDenseTransformerConverter.java │ │ │ ├── StandardScalerModelConverter.java │ │ │ ├── StopWordsRemoverConverter.java │ │ │ ├── StringIndexerModelConverter.java │ │ │ ├── TokenizerConverter.java │ │ │ ├── VectorAssemblerConverter.java │ │ │ ├── VectorAttributeRewriterConverter.java │ │ │ ├── VectorIndexerModelConverter.java │ │ │ ├── VectorSizeHintConverter.java │ │ │ └── VectorSlicerConverter.java │ │ │ ├── model │ │ │ ├── DecisionTreeClassificationModelConverter.java │ │ │ ├── DecisionTreeRegressionModelConverter.java │ │ │ ├── FPGrowthModelConverter.java │ │ │ ├── GBTClassificationModelConverter.java │ │ │ ├── GBTRegressionModelConverter.java │ │ │ ├── GeneralizedLinearRegressionModelConverter.java │ │ │ ├── HasFeatureImportances.java │ │ │ ├── HasPredictionModelOptions.java │ │ │ ├── HasRegressionTableOptions.java │ │ │ ├── HasTreeOptions.java │ │ │ ├── KMeansModelConverter.java │ │ │ ├── LinearModelUtil.java │ │ │ ├── LinearRegressionModelConverter.java │ │ │ ├── LinearSVCModelConverter.java │ │ │ ├── LogisticRegressionModelConverter.java │ │ │ ├── MultilayerPerceptronClassificationModelConverter.java │ │ │ ├── NaiveBayesModelConverter.java │ │ │ ├── RandomForestClassificationModelConverter.java │ │ │ ├── RandomForestRegressionModelConverter.java │ │ │ ├── RegressionTableUtil.java │ │ │ └── TreeModelUtil.java │ │ │ ├── testing │ │ │ ├── SparkMLEncoderBatch.java │ │ │ └── SparkMLEncoderBatchTest.java │ │ │ └── visitors │ │ │ └── TreeModelCompactor.java │ ├── resources │ │ └── META-INF │ │ │ └── sparkml2pmml.properties │ └── scala │ │ └── org │ │ └── jpmml │ │ └── sparkml │ │ └── feature │ │ ├── 
InvalidCategoryTransformer.scala │ │ └── SparseToDenseTransformer.scala │ └── test │ ├── java │ └── org │ │ └── jpmml │ │ └── sparkml │ │ ├── AliasExpressionTest.java │ │ ├── ExpressionTranslatorTest.java │ │ ├── PMMLBuilderTest.java │ │ ├── RegexKeyTest.java │ │ ├── SparkMLTest.java │ │ ├── TermUtilTest.java │ │ ├── feature │ │ ├── InvalidCategoryTransformerTest.java │ │ ├── SQLTransformerConverterTest.java │ │ └── SparseToDenseTransformerTest.java │ │ └── testing │ │ ├── AssociationRulesTest.java │ │ ├── ClassificationTest.java │ │ ├── ClusteringTest.java │ │ ├── RegressionTest.java │ │ ├── SimpleSparkMLEncoderBatchTest.java │ │ └── SparkMLAlgorithms.java │ └── resources │ ├── common.py │ ├── csv │ ├── Audit.csv │ ├── Auto.csv │ ├── DecisionTreeAudit.csv │ ├── DecisionTreeAuto.csv │ ├── DecisionTreeHousing.csv │ ├── DecisionTreeIris.csv │ ├── DecisionTreeSentiment.csv │ ├── GBTAudit.csv │ ├── GBTAuto.csv │ ├── GLMAudit.csv │ ├── GLMAuto.csv │ ├── GLMHousing.csv │ ├── GLMSentiment.csv │ ├── GLMVisit.csv │ ├── Housing.csv │ ├── Iris.csv │ ├── KMeansIris.csv │ ├── LinearRegressionAuto.csv │ ├── LinearRegressionHousing.csv │ ├── LinearSVCSentiment.csv │ ├── LogisticRegressionAudit.csv │ ├── LogisticRegressionIris.csv │ ├── ModelChainAudit.csv │ ├── ModelChainAuto.csv │ ├── ModelChainIris.csv │ ├── NaiveBayesAudit.csv │ ├── NaiveBayesIris.csv │ ├── NeuralNetworkAudit.csv │ ├── NeuralNetworkIris.csv │ ├── RandomForestAudit.csv │ ├── RandomForestAuto.csv │ ├── RandomForestHousing.csv │ ├── RandomForestIris.csv │ ├── RandomForestSentiment.csv │ ├── Sentiment.csv │ ├── Shopping.csv │ └── Visit.csv │ ├── main.py │ ├── pipeline │ ├── DecisionTreeAudit.zip │ ├── DecisionTreeAuto.zip │ ├── DecisionTreeHousing.zip │ ├── DecisionTreeIris.zip │ ├── DecisionTreeSentiment.zip │ ├── FPGrowthShopping.zip │ ├── GBTAudit.zip │ ├── GBTAuto.zip │ ├── GLMAudit.zip │ ├── GLMAuto.zip │ ├── GLMHousing.zip │ ├── GLMSentiment.zip │ ├── GLMVisit.zip │ ├── KMeansIris.zip │ ├── 
LinearRegressionAuto.zip │ ├── LinearRegressionHousing.zip │ ├── LinearSVCSentiment.zip │ ├── LogisticRegressionAudit.zip │ ├── LogisticRegressionIris.zip │ ├── ModelChainAudit.zip │ ├── ModelChainAuto.zip │ ├── ModelChainIris.zip │ ├── NaiveBayesAudit.zip │ ├── NaiveBayesIris.zip │ ├── NeuralNetworkAudit.zip │ ├── NeuralNetworkIris.zip │ ├── RandomForestAudit.zip │ ├── RandomForestAuto.zip │ ├── RandomForestHousing.zip │ ├── RandomForestIris.zip │ └── RandomForestSentiment.zip │ └── schema │ ├── Audit.json │ ├── Auto.json │ ├── Housing.json │ ├── Iris.json │ ├── Sentiment.json │ ├── Shopping.json │ └── Visit.json └── pom.xml /.github/workflows/maven.yml: -------------------------------------------------------------------------------- 1 | name: maven 2 | 3 | on: 4 | push: 5 | branches: [ '3.0.X', master ] 6 | 7 | jobs: 8 | build: 9 | 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | java: [ 11, 17 ] 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: actions/setup-java@v4 18 | with: 19 | distribution: 'zulu' 20 | java-version: ${{ matrix.java }} 21 | cache: 'maven' 22 | - run: mvn -B package --file pom.xml 23 | -------------------------------------------------------------------------------- /pmml-sparkml-evaluator/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | 5 | 6 | org.jpmml 7 | jpmml-sparkml 8 | 3.1-SNAPSHOT 9 | 10 | 11 | org.jpmml 12 | pmml-sparkml-evaluator 13 | jar 14 | 15 | JPMML Spark ML JPMML-Evaluator integration 16 | JPMML Apache Spark ML JPMML-Evaluator integration 17 | 18 | 19 | 20 | GNU Affero General Public License (AGPL) version 3.0 21 | http://www.gnu.org/licenses/agpl-3.0.html 22 | repo 23 | 24 | 25 | 26 | 27 | 28 | org.jpmml 29 | pmml-evaluator 30 | provided 31 | 32 | 33 | org.jpmml 34 | pmml-evaluator-testing 35 | provided 36 | 37 | 38 | 39 | org.junit.jupiter 40 | junit-jupiter-api 41 | 42 | 43 | 44 | 
-------------------------------------------------------------------------------- /pmml-sparkml-evaluator/src/main/java/org/jpmml/sparkml/evaluator/SparkMLFunctionRegistry.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>. 
18 | */ 19 | package org.jpmml.sparkml.evaluator; 20 | 21 | import java.util.Collections; 22 | import java.util.Map; 23 | import java.util.Objects; 24 | import java.util.function.Predicate; 25 | 26 | import org.jpmml.evaluator.Function; 27 | import org.jpmml.evaluator.FunctionRegistry; 28 | 29 | /** 30 | * @see FunctionRegistry 31 | */ 32 | public class SparkMLFunctionRegistry { 33 | 34 | private SparkMLFunctionRegistry(){ 35 | } 36 | 37 | static 38 | public void publish(String name){ 39 | publish(key -> Objects.equals(name, key)); 40 | } 41 | 42 | static 43 | public void publishAll(){ 44 | publish(key -> true); 45 | } 46 | 47 | static 48 | private void publish(Predicate predicate){ 49 | (SparkMLFunctionRegistry.functions.entrySet()).stream() 50 | .filter(entry -> predicate.test(entry.getKey())) 51 | .forEach(entry -> FunctionRegistry.putFunction(entry.getKey(), entry.getValue())); 52 | 53 | (SparkMLFunctionRegistry.functionClazzes.entrySet()).stream() 54 | .filter(entry -> predicate.test(entry.getKey())) 55 | .forEach(entry -> FunctionRegistry.putFunction(entry.getKey(), entry.getValue())); 56 | } 57 | 58 | private static final Map functions = Collections.emptyMap(); 59 | private static final Map> functionClazzes = Collections.emptyMap(); 60 | } -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | 5 | 6 | org.jpmml 7 | jpmml-sparkml 8 | 3.1-SNAPSHOT 9 | 10 | 11 | org.jpmml 12 | pmml-sparkml-lightgbm 13 | jar 14 | 15 | JPMML Spark ML LightGBM converter 16 | JPMML Apache Spark ML LightGBM to PMML converter 17 | 18 | 19 | 20 | GNU Affero General Public License (AGPL) version 3.0 21 | http://www.gnu.org/licenses/agpl-3.0.html 22 | repo 23 | 24 | 25 | 26 | 27 | 28 | org.jpmml 29 | pmml-sparkml 30 | 31 | 32 | 33 | org.jpmml 34 | pmml-evaluator-testing 35 | provided 36 | 37 | 38 | 39 | org.jpmml 40 | 
pmml-lightgbm 41 | 42 | 43 | 44 | com.microsoft.azure 45 | synapseml-lightgbm_2.12 46 | provided 47 | 48 | 49 | 50 | org.apache.spark 51 | spark-core_2.12 52 | provided 53 | 54 | 55 | org.apache.spark 56 | spark-mllib_2.12 57 | provided 58 | 59 | 60 | 61 | org.junit.jupiter 62 | junit-jupiter-api 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/main/java/org/jpmml/sparkml/lightgbm/BoosterUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>. 
18 | */ 19 | package org.jpmml.sparkml.lightgbm; 20 | 21 | import java.io.IOException; 22 | import java.io.StringReader; 23 | import java.util.LinkedHashMap; 24 | import java.util.List; 25 | import java.util.Map; 26 | 27 | import com.google.common.io.CharStreams; 28 | import com.microsoft.azure.synapse.ml.lightgbm.LightGBMModelMethods; 29 | import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster; 30 | import org.apache.spark.ml.Model; 31 | import org.apache.spark.ml.param.shared.HasPredictionCol; 32 | import org.dmg.pmml.mining.MiningModel; 33 | import org.jpmml.converter.Schema; 34 | import org.jpmml.lightgbm.GBDT; 35 | import org.jpmml.lightgbm.HasLightGBMOptions; 36 | import org.jpmml.lightgbm.LightGBMUtil; 37 | import org.jpmml.sparkml.ModelConverter; 38 | import scala.Option; 39 | 40 | public class BoosterUtil { 41 | 42 | private BoosterUtil(){ 43 | } 44 | 45 | static 46 | public , M extends Model & HasPredictionCol & LightGBMModelMethods> MiningModel encodeModel(C converter, Schema schema){ 47 | M model = converter.getModel(); 48 | 49 | GBDT gbdt = BoosterUtil.getGBDT(model); 50 | 51 | Integer bestIteration = model.getBoosterBestIteration(); 52 | if(bestIteration < 0){ 53 | bestIteration = null; 54 | } 55 | 56 | Map options = new LinkedHashMap<>(); 57 | options.put(HasLightGBMOptions.OPTION_COMPACT, converter.getOption(HasLightGBMOptions.OPTION_COMPACT, Boolean.TRUE)); 58 | options.put(HasLightGBMOptions.OPTION_NUM_ITERATION, converter.getOption(HasLightGBMOptions.OPTION_NUM_ITERATION, bestIteration)); 59 | 60 | Schema lgbmSchema = gbdt.toLightGBMSchema(schema); 61 | 62 | MiningModel miningModel = gbdt.encodeModel(options, lgbmSchema); 63 | 64 | return miningModel; 65 | } 66 | 67 | static 68 | private & LightGBMModelMethods> GBDT getGBDT(M model){ 69 | LightGBMBooster booster = model.getLightGBMBooster(); 70 | 71 | Option modelStr = booster.modelStr(); 72 | if(modelStr.isEmpty()){ 73 | throw new IllegalArgumentException(); 74 | } 75 | 76 | 
String string = modelStr.get(); 77 | 78 | try(StringReader reader = new StringReader(string)){ 79 | List lines = CharStreams.readLines(reader); 80 | 81 | return LightGBMUtil.loadGBDT(lines.iterator()); 82 | } catch(IOException ioe){ 83 | throw new RuntimeException(ioe); 84 | } 85 | } 86 | } -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/main/java/org/jpmml/sparkml/lightgbm/LightGBMClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.lightgbm; 20 | 21 | import com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassificationModel; 22 | import org.dmg.pmml.mining.MiningModel; 23 | import org.dmg.pmml.regression.RegressionModel; 24 | import org.jpmml.converter.Schema; 25 | import org.jpmml.converter.mining.MiningModelUtil; 26 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 27 | 28 | public class LightGBMClassificationModelConverter extends ProbabilisticClassificationModelConverter { 29 | 30 | public LightGBMClassificationModelConverter(LightGBMClassificationModel model){ 31 | super(model); 32 | } 33 | 34 | @Override 35 | public int getNumberOfClasses(){ 36 | int numberOfClasses = super.getNumberOfClasses(); 37 | 38 | if(numberOfClasses == 1){ 39 | return 2; 40 | } 41 | 42 | return numberOfClasses; 43 | } 44 | 45 | @Override 46 | public MiningModel encodeModel(Schema schema){ 47 | LightGBMClassificationModel model = getModel(); 48 | 49 | MiningModel miningModel = BoosterUtil.encodeModel(this, schema); 50 | 51 | RegressionModel regressionModel = (RegressionModel)MiningModelUtil.getFinalModel(miningModel); 52 | regressionModel.setOutput(null); 53 | 54 | return miningModel; 55 | } 56 | } -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/main/java/org/jpmml/sparkml/lightgbm/LightGBMRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.lightgbm; 20 | 21 | import com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressionModel; 22 | import org.dmg.pmml.mining.MiningModel; 23 | import org.jpmml.converter.Schema; 24 | import org.jpmml.sparkml.RegressionModelConverter; 25 | 26 | public class LightGBMRegressionModelConverter extends RegressionModelConverter { 27 | 28 | public LightGBMRegressionModelConverter(LightGBMRegressionModel model){ 29 | super(model); 30 | } 31 | 32 | @Override 33 | public MiningModel encodeModel(Schema schema){ 34 | LightGBMRegressionModel model = getModel(); 35 | 36 | return BoosterUtil.encodeModel(this, schema); 37 | } 38 | } -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/main/resources/META-INF/sparkml2pmml.properties: -------------------------------------------------------------------------------- 1 | com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassificationModel = org.jpmml.sparkml.lightgbm.LightGBMClassificationModelConverter 2 | com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressionModel = org.jpmml.sparkml.lightgbm.LightGBMRegressionModelConverter 3 | -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/java/org/jpmml/sparkml/lightgbm/testing/LightGBMTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it 
and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.lightgbm.testing; 20 | 21 | import java.util.LinkedHashMap; 22 | import java.util.List; 23 | import java.util.Map; 24 | import java.util.function.Predicate; 25 | 26 | import com.google.common.base.Equivalence; 27 | import org.jpmml.converter.testing.Datasets; 28 | import org.jpmml.converter.testing.OptionsUtil; 29 | import org.jpmml.evaluator.ResultField; 30 | import org.jpmml.evaluator.testing.PMMLEquivalence; 31 | import org.jpmml.lightgbm.HasLightGBMOptions; 32 | import org.jpmml.sparkml.testing.SparkMLEncoderBatch; 33 | import org.jpmml.sparkml.testing.SparkMLEncoderBatchTest; 34 | import org.junit.jupiter.api.AfterAll; 35 | import org.junit.jupiter.api.BeforeAll; 36 | import org.junit.jupiter.api.Test; 37 | 38 | public class LightGBMTest extends SparkMLEncoderBatchTest implements Datasets { 39 | 40 | public LightGBMTest(){ 41 | super(new PMMLEquivalence(1e-14, 1e-14)); 42 | } 43 | 44 | @Override 45 | public SparkMLEncoderBatch createBatch(String algorithm, String dataset, Predicate columnFilter, Equivalence equivalence){ 46 | columnFilter = columnFilter.and(SparkMLEncoderBatchTest.excludePredictionFields()); 47 | 48 | SparkMLEncoderBatch result = new SparkMLEncoderBatch(algorithm, dataset, columnFilter, equivalence){ 49 | 50 | @Override 51 | public LightGBMTest getArchiveBatchTest(){ 52 | return LightGBMTest.this; 
53 | } 54 | 55 | @Override 56 | public List> getOptionsMatrix(){ 57 | Map options = new LinkedHashMap<>(); 58 | 59 | options.put(HasLightGBMOptions.OPTION_COMPACT, new Boolean[]{false, true}); 60 | 61 | return OptionsUtil.generateOptionsMatrix(options); 62 | } 63 | }; 64 | 65 | return result; 66 | } 67 | 68 | @Test 69 | public void evaluateLightGBMAudit() throws Exception { 70 | evaluate("LightGBM", AUDIT); 71 | } 72 | 73 | @Test 74 | public void evaluateLightGBMAuditNA() throws Exception { 75 | evaluate("LightGBM", AUDIT_NA); 76 | } 77 | 78 | @Test 79 | public void evaluateLightGBMAuto() throws Exception { 80 | evaluate("LightGBM", AUTO); 81 | } 82 | 83 | @Test 84 | public void evaluateLightGBMAutoNA() throws Exception { 85 | evaluate("LightGBM", AUTO_NA); 86 | } 87 | 88 | @Test 89 | public void evaluateLightGBMIris() throws Exception { 90 | evaluate("LightGBM", IRIS); 91 | } 92 | 93 | @BeforeAll 94 | static 95 | public void createSparkSession(){ 96 | SparkMLEncoderBatchTest.createSparkSession(); 97 | } 98 | 99 | @AfterAll 100 | static 101 | public void destroySparkSession(){ 102 | SparkMLEncoderBatchTest.destroySparkSession(); 103 | } 104 | } -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/LightGBMAudit.scala: -------------------------------------------------------------------------------- 1 | import java.io.File 2 | 3 | import com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier 4 | import org.apache.spark.ml.Pipeline 5 | import org.apache.spark.ml.feature._ 6 | import org.apache.spark.ml.linalg.Vector 7 | import org.apache.spark.sql.functions.{lit, udf} 8 | import org.apache.spark.sql.types.StringType 9 | import org.jpmml.sparkml.{DatasetUtil, PipelineModelUtil} 10 | 11 | var df = DatasetUtil.loadCsv(spark, new File("csv/Audit.csv")) 12 | df = DatasetUtil.castColumn(df, "Adjusted", StringType) 13 | 14 | DatasetUtil.storeSchema(df, new File("schema/Audit.json")) 15 | 16 | val 
cat_cols = Array("Education", "Employment", "Gender", "Marital", "Occupation") 17 | val cont_cols = Array("Age", "Hours", "Income") 18 | 19 | val labelIndexer = new StringIndexer().setInputCol("Adjusted").setOutputCol("idx_Adjusted") 20 | 21 | val indexer = new StringIndexer().setInputCols(cat_cols).setOutputCols(cat_cols.map(cat_col => "idx_" + cat_col)) 22 | val assembler = new VectorAssembler().setInputCols(indexer.getOutputCols ++ cont_cols).setOutputCol("featureVector") 23 | 24 | val classifier = new LightGBMClassifier().setObjective("binary").setNumIterations(101).setLabelCol(labelIndexer.getOutputCol).setFeaturesCol(assembler.getOutputCol) 25 | 26 | val pipeline = new Pipeline().setStages(Array(labelIndexer, indexer, assembler, classifier)) 27 | val pipelineModel = pipeline.fit(df) 28 | 29 | PipelineModelUtil.storeZip(pipelineModel, new File("pipeline/LightGBMAudit.zip")) 30 | 31 | val predLabel = udf{ (value: Float) => value.toInt.toString } 32 | val vectorToColumn = udf{ (vec: Vector, index: Int) => vec(index) } 33 | 34 | var lgbDf = pipelineModel.transform(df) 35 | lgbDf = lgbDf.selectExpr("prediction", "probability") 36 | lgbDf = lgbDf.withColumn("Adjusted", predLabel(lgbDf("prediction"))).drop("prediction") 37 | lgbDf = lgbDf.withColumn("probability(0)", vectorToColumn(lgbDf("probability"), lit(0))).withColumn("probability(1)", vectorToColumn(lgbDf("probability"), lit(1))).drop("probability").drop("probability") 38 | 39 | DatasetUtil.storeCsv(lgbDf, new File("csv/LightGBMAudit.csv")) 40 | -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/LightGBMAuditNA.scala: -------------------------------------------------------------------------------- 1 | import java.io.File 2 | 3 | import com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier 4 | import org.apache.spark.ml.Pipeline 5 | import org.apache.spark.ml.feature._ 6 | import org.apache.spark.ml.linalg.Vector 7 | import 
org.apache.spark.sql.functions.{lit, udf} 8 | import org.apache.spark.sql.types.StringType 9 | import org.jpmml.sparkml.{DatasetUtil, PipelineModelUtil} 10 | import org.jpmml.sparkml.feature.InvalidCategoryTransformer 11 | 12 | var df = DatasetUtil.loadCsv(spark, new File("csv/AuditNA.csv")) 13 | df = DatasetUtil.castColumn(df, "Adjusted", StringType) 14 | 15 | DatasetUtil.storeSchema(df, new File("schema/AuditNA.json")) 16 | 17 | val cat_cols = Array("Education", "Employment", "Gender", "Marital", "Occupation") 18 | val cont_cols = Array("Age", "Hours", "Income") 19 | 20 | val labelIndexer = new StringIndexer().setInputCol("Adjusted").setOutputCol("idx_Adjusted") 21 | 22 | val indexer = new StringIndexer().setInputCols(cat_cols).setOutputCols(cat_cols.map(cat_col => "idx_" + cat_col)).setHandleInvalid("keep") 23 | val indexTransformer = new InvalidCategoryTransformer().setInputCols(indexer.getOutputCols).setOutputCols(cat_cols.map(cat_col => "idxTransformed_" + cat_col)) 24 | 25 | val assembler = new VectorAssembler().setInputCols(indexTransformer.getOutputCols ++ cont_cols).setOutputCol("featureVector").setHandleInvalid("keep") 26 | 27 | val classifier = new LightGBMClassifier().setObjective("binary").setNumIterations(101).setLabelCol(labelIndexer.getOutputCol).setFeaturesCol(assembler.getOutputCol) 28 | 29 | val pipeline = new Pipeline().setStages(Array(labelIndexer, indexer, indexTransformer, assembler, classifier)) 30 | val pipelineModel = pipeline.fit(df) 31 | 32 | PipelineModelUtil.storeZip(pipelineModel, new File("pipeline/LightGBMAuditNA.zip")) 33 | 34 | val predLabel = udf{ (value: Float) => value.toInt.toString } 35 | val vectorToColumn = udf{ (vec: Vector, index: Int) => vec(index) } 36 | 37 | var lgbDf = pipelineModel.transform(df) 38 | lgbDf = lgbDf.selectExpr("prediction", "probability") 39 | lgbDf = lgbDf.withColumn("Adjusted", predLabel(lgbDf("prediction"))).drop("prediction") 40 | lgbDf = lgbDf.withColumn("probability(0)", 
// Trains a LightGBM regressor on the Auto dataset, and stores the input schema,
// the fitted pipeline and the predictions as test resources.
// Expects a SparkSession named `spark` to be in scope.

import java.io.File

import com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature._
import org.apache.spark.sql.types.StringType
import org.jpmml.sparkml.{DatasetUtil, PipelineModelUtil}

// Load the dataset; the "origin" column is cast to string type
val autoDf = DatasetUtil.castColumn(DatasetUtil.loadCsv(spark, new File("csv/Auto.csv")), "origin", StringType)

DatasetUtil.storeSchema(autoDf, new File("schema/Auto.json"))

val categoricalCols = Array("cylinders", "model_year", "origin")
val continuousCols = Array("acceleration", "displacement", "horsepower", "weight")

// Index the categorical columns into "idx_"-prefixed columns
val indexedCols = categoricalCols.map(name => s"idx_$name")
val indexer = new StringIndexer().setInputCols(categoricalCols).setOutputCols(indexedCols)

// Indexed categorical columns come first, continuous columns last
val featureCols = indexer.getOutputCols ++ continuousCols
val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("featureVector")

val regressor = new LightGBMRegressor().setNumIterations(101).setLabelCol("mpg").setFeaturesCol(assembler.getOutputCol)

val pipelineModel = new Pipeline().setStages(Array(indexer, assembler, regressor)).fit(autoDf)

PipelineModelUtil.storeZip(pipelineModel, new File("pipeline/LightGBMAuto.zip"))

// Keep only the prediction, renamed after the label column
val predictionDf = pipelineModel.transform(autoDf).selectExpr("prediction as mpg")

DatasetUtil.storeCsv(predictionDf, new File("csv/LightGBMAuto.csv"))
// Trains a LightGBM regressor on the AutoNA dataset, and stores the input schema,
// the fitted pipeline and the predictions as test resources.
// Expects a SparkSession named `spark` to be in scope.

import java.io.File

import com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature._
import org.jpmml.sparkml.{DatasetUtil, PipelineModelUtil}
import org.jpmml.sparkml.feature.InvalidCategoryTransformer

val autoNaDf = DatasetUtil.loadCsv(spark, new File("csv/AutoNA.csv"))

DatasetUtil.storeSchema(autoNaDf, new File("schema/AutoNA.json"))

val categoricalCols = Array("cylinders", "model_year", "origin")
val continuousCols = Array("acceleration", "displacement", "horsepower", "weight")

// Index the categorical columns, keeping invalid values
val indexedCols = categoricalCols.map(name => s"idx_$name")
val indexer = new StringIndexer().setInputCols(categoricalCols).setOutputCols(indexedCols).setHandleInvalid("keep")

// Post-process the indexed columns with InvalidCategoryTransformer
val transformedCols = categoricalCols.map(name => s"idxTransformed_$name")
val indexTransformer = new InvalidCategoryTransformer().setInputCols(indexer.getOutputCols).setOutputCols(transformedCols)

// Transformed categorical columns come first, continuous columns last
val featureCols = indexTransformer.getOutputCols ++ continuousCols
val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("featureVector").setHandleInvalid("keep")

val regressor = new LightGBMRegressor().setNumIterations(101).setLabelCol("mpg").setFeaturesCol(assembler.getOutputCol)

val pipelineModel = new Pipeline().setStages(Array(indexer, indexTransformer, assembler, regressor)).fit(autoNaDf)

PipelineModelUtil.storeZip(pipelineModel, new File("pipeline/LightGBMAutoNA.zip"))

// Keep only the prediction, renamed after the label column
val predictionDf = pipelineModel.transform(autoNaDf).selectExpr("prediction as mpg")

DatasetUtil.storeCsv(predictionDf, new File("csv/LightGBMAutoNA.csv"))
import java.io.File 2 | 3 | import com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier 4 | import org.apache.spark.ml.Pipeline 5 | import org.apache.spark.ml.feature._ 6 | import org.apache.spark.ml.linalg.Vector 7 | import org.apache.spark.sql.functions.{lit, udf} 8 | import org.apache.spark.sql.types.StringType 9 | import org.jpmml.sparkml.{DatasetUtil, PipelineModelUtil} 10 | 11 | var df = DatasetUtil.loadCsv(spark, new File("csv/Iris.csv")) 12 | 13 | DatasetUtil.storeSchema(df, new File("schema/Iris.json")) 14 | 15 | val labelIndexer = new StringIndexer().setInputCol("Species").setOutputCol("idx_Species") 16 | val labelIndexerModel = labelIndexer.fit(df) 17 | 18 | val assembler = new VectorAssembler().setInputCols(Array("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width")).setOutputCol("featureVector") 19 | 20 | val classifier = new LightGBMClassifier().setObjective("multiclass").setNumIterations(17).setLabelCol(labelIndexer.getOutputCol).setFeaturesCol(assembler.getOutputCol) 21 | 22 | val pipeline = new Pipeline().setStages(Array(labelIndexer, assembler, classifier)) 23 | val pipelineModel = pipeline.fit(df) 24 | 25 | PipelineModelUtil.storeZip(pipelineModel, new File("pipeline/LightGBMIris.zip")) 26 | 27 | val predLabel = udf{ (value: Double) => labelIndexerModel.labels(value.toInt) } 28 | val vectorToColumn = udf{ (vec: Vector, index: Int) => vec(index) } 29 | 30 | var lgbDf = pipelineModel.transform(df) 31 | lgbDf = lgbDf.selectExpr("prediction", "probability") 32 | lgbDf = lgbDf.withColumn("Species", predLabel(lgbDf("prediction"))).drop("prediction") 33 | lgbDf = lgbDf.withColumn("probability(setosa)", vectorToColumn(lgbDf("probability"), lit(0))).withColumn("probability(versicolor)", vectorToColumn(lgbDf("probability"), lit(1))).withColumn("probability(virginica)", vectorToColumn(lgbDf("probability"), lit(2))).drop("probability").drop("probability") 34 | 35 | DatasetUtil.storeCsv(lgbDf, new File("csv/LightGBMIris.csv")) 36 | 
-------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/pipeline/LightGBMAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml-lightgbm/src/test/resources/pipeline/LightGBMAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/pipeline/LightGBMAuditNA.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml-lightgbm/src/test/resources/pipeline/LightGBMAuditNA.zip -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/pipeline/LightGBMAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml-lightgbm/src/test/resources/pipeline/LightGBMAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/pipeline/LightGBMAutoNA.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml-lightgbm/src/test/resources/pipeline/LightGBMAutoNA.zip -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/pipeline/LightGBMIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml-lightgbm/src/test/resources/pipeline/LightGBMIris.zip 
-------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/schema/Audit.json: -------------------------------------------------------------------------------- 1 | {"type":"struct","fields":[{"name":"Age","type":"integer","nullable":true,"metadata":{}},{"name":"Employment","type":"string","nullable":true,"metadata":{}},{"name":"Education","type":"string","nullable":true,"metadata":{}},{"name":"Marital","type":"string","nullable":true,"metadata":{}},{"name":"Occupation","type":"string","nullable":true,"metadata":{}},{"name":"Income","type":"double","nullable":true,"metadata":{}},{"name":"Gender","type":"string","nullable":true,"metadata":{}},{"name":"Deductions","type":"integer","nullable":true,"metadata":{}},{"name":"Hours","type":"integer","nullable":true,"metadata":{}},{"name":"Adjusted","type":"string","nullable":true,"metadata":{}}]} -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/schema/AuditNA.json: -------------------------------------------------------------------------------- 1 | {"type":"struct","fields":[{"name":"Age","type":"integer","nullable":true,"metadata":{}},{"name":"Employment","type":"string","nullable":true,"metadata":{}},{"name":"Education","type":"string","nullable":true,"metadata":{}},{"name":"Marital","type":"string","nullable":true,"metadata":{}},{"name":"Occupation","type":"string","nullable":true,"metadata":{}},{"name":"Income","type":"double","nullable":true,"metadata":{}},{"name":"Gender","type":"string","nullable":true,"metadata":{}},{"name":"Deductions","type":"integer","nullable":true,"metadata":{}},{"name":"Hours","type":"integer","nullable":true,"metadata":{}},{"name":"Adjusted","type":"string","nullable":true,"metadata":{}}]} -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/schema/Auto.json: 
-------------------------------------------------------------------------------- 1 | {"type":"struct","fields":[{"name":"cylinders","type":"integer","nullable":true,"metadata":{}},{"name":"displacement","type":"double","nullable":true,"metadata":{}},{"name":"horsepower","type":"integer","nullable":true,"metadata":{}},{"name":"weight","type":"integer","nullable":true,"metadata":{}},{"name":"acceleration","type":"double","nullable":true,"metadata":{}},{"name":"model_year","type":"integer","nullable":true,"metadata":{}},{"name":"mpg","type":"double","nullable":true,"metadata":{}},{"name":"origin","type":"string","nullable":true,"metadata":{}}]} -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/schema/AutoNA.json: -------------------------------------------------------------------------------- 1 | {"type":"struct","fields":[{"name":"cylinders","type":"integer","nullable":true,"metadata":{}},{"name":"displacement","type":"integer","nullable":true,"metadata":{}},{"name":"horsepower","type":"integer","nullable":true,"metadata":{}},{"name":"weight","type":"integer","nullable":true,"metadata":{}},{"name":"acceleration","type":"double","nullable":true,"metadata":{}},{"name":"model_year","type":"integer","nullable":true,"metadata":{}},{"name":"origin","type":"integer","nullable":true,"metadata":{}},{"name":"mpg","type":"double","nullable":true,"metadata":{}}]} -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/schema/Iris.json: -------------------------------------------------------------------------------- 1 | 
{"type":"struct","fields":[{"name":"Sepal_Length","type":"double","nullable":true,"metadata":{}},{"name":"Sepal_Width","type":"double","nullable":true,"metadata":{}},{"name":"Petal_Length","type":"double","nullable":true,"metadata":{}},{"name":"Petal_Width","type":"double","nullable":true,"metadata":{}},{"name":"Species","type":"string","nullable":true,"metadata":{}}]} -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | 5 | 6 | org.jpmml 7 | jpmml-sparkml 8 | 3.1-SNAPSHOT 9 | 10 | 11 | org.jpmml 12 | pmml-sparkml-xgboost 13 | jar 14 | 15 | JPMML Spark ML XGBoost converter 16 | JPMML Apache Spark ML XGBoost to PMML converter 17 | 18 | 19 | 20 | GNU Affero General Public License (AGPL) version 3.0 21 | http://www.gnu.org/licenses/agpl-3.0.html 22 | repo 23 | 24 | 25 | 26 | 27 | 28 | org.jpmml 29 | pmml-sparkml 30 | 31 | 32 | 33 | org.jpmml 34 | pmml-evaluator-testing 35 | provided 36 | 37 | 38 | 39 | org.jpmml 40 | pmml-xgboost 41 | 42 | 43 | 44 | ml.dmlc 45 | xgboost4j-spark_2.12 46 | provided 47 | 48 | 49 | 50 | org.apache.spark 51 | spark-core_2.12 52 | provided 53 | 54 | 55 | org.apache.spark 56 | spark-mllib_2.12 57 | provided 58 | 59 | 60 | 61 | org.junit.jupiter 62 | junit-jupiter-api 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/main/java/org/jpmml/sparkml/xgboost/BoosterUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later 
version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>. 18 | */ 19 | package org.jpmml.sparkml.xgboost; 20 | 21 | import java.io.File; 22 | import java.io.FileInputStream; 23 | import java.io.InputStream; 24 | import java.util.LinkedHashMap; 25 | import java.util.Map; 26 | 27 | import com.google.common.io.MoreFiles; 28 | import com.google.common.io.RecursiveDeleteOption; 29 | import ml.dmlc.xgboost4j.scala.Booster; 30 | import ml.dmlc.xgboost4j.scala.spark.params.GeneralParams; 31 | import org.apache.spark.ml.Model; 32 | import org.apache.spark.ml.param.shared.HasPredictionCol; 33 | import org.dmg.pmml.mining.MiningModel; 34 | import org.jpmml.converter.Schema; 35 | import org.jpmml.sparkml.ModelConverter; 36 | import org.jpmml.xgboost.HasXGBoostOptions; 37 | import org.jpmml.xgboost.Learner; 38 | import org.jpmml.xgboost.XGBoostUtil; 39 | 40 | /** 41 | * Helper for encoding XGBoost4J-Spark boosters as PMML mining models. 42 | */ 43 | public class BoosterUtil { 44 | 45 | private BoosterUtil(){ 46 | } 47 | 48 | /** 49 | * Encodes a native booster as a PMML mining model. 50 | * 51 | * The booster is written to a temporary file and read back using {@link XGBoostUtil#loadLearner(InputStream)}. 52 | * The temporary file is removed even when saving or loading fails. 53 | * 54 | * @param converter The converter that supplies the Spark ML model and the conversion options. 55 | * @param booster The booster to encode. 56 | * @param schema The schema; it is translated using {@link Learner#toXGBoostSchema(Schema)} before encoding. 57 | * 58 | * @return The encoded mining model. 59 | * 60 | * @throws RuntimeException If creating, writing, reading or deleting the temporary file fails. 61 | */ 62 | static 63 | public <M extends Model<M> & HasPredictionCol & GeneralParams, C extends ModelConverter<M>> MiningModel encodeBooster(C converter, Booster booster, Schema schema){ 64 | M model = converter.getModel(); 65 | 66 | Learner learner; 67 | 68 | try { 69 | File tmpBoosterFile = File.createTempFile("Booster", ".json"); 70 | 71 | try { 72 | booster.saveModel(tmpBoosterFile.getAbsolutePath()); 73 | 74 | try(InputStream is = new FileInputStream(tmpBoosterFile)){ 75 | learner = XGBoostUtil.loadLearner(is); 76 | } 77 | } finally { 78 | // Clean up on the failure path too, so that temporary files do not accumulate 79 | MoreFiles.deleteRecursively(tmpBoosterFile.toPath(), RecursiveDeleteOption.ALLOW_INSECURE); 80 | } 81 | } catch(Exception e){ 82 | throw new RuntimeException(e); 83 | } 84 | 85 | // A NaN "missing" parameter is passed on as null (the default value of OPTION_MISSING) 86 | Float missing = model.getMissing();
87 | if(missing.isNaN()){ 88 | missing = null; 89 | } 90 | 91 | Map<String, Object> options = new LinkedHashMap<>(); 92 | options.put(HasXGBoostOptions.OPTION_MISSING, converter.getOption(HasXGBoostOptions.OPTION_MISSING, missing)); 93 | options.put(HasXGBoostOptions.OPTION_COMPACT, converter.getOption(HasXGBoostOptions.OPTION_COMPACT, false)); 94 | options.put(HasXGBoostOptions.OPTION_INPUT_FLOAT, converter.getOption(HasXGBoostOptions.OPTION_INPUT_FLOAT, null)); 95 | options.put(HasXGBoostOptions.OPTION_NUMERIC, converter.getOption(HasXGBoostOptions.OPTION_NUMERIC, true)); 96 | options.put(HasXGBoostOptions.OPTION_PRUNE, converter.getOption(HasXGBoostOptions.OPTION_PRUNE, false)); 97 | options.put(HasXGBoostOptions.OPTION_NTREE_LIMIT, converter.getOption(HasXGBoostOptions.OPTION_NTREE_LIMIT, null)); 98 | 99 | Schema xgbSchema = learner.toXGBoostSchema(schema); 100 | 101 | return learner.encodeModel(options, xgbSchema); 102 | } 103 | } -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/main/java/org/jpmml/sparkml/xgboost/XGBoostClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>.
18 | */ 19 | package org.jpmml.sparkml.xgboost; 20 | 21 | import ml.dmlc.xgboost4j.scala.Booster; 22 | import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel; 23 | import org.dmg.pmml.mining.MiningModel; 24 | import org.dmg.pmml.regression.RegressionModel; 25 | import org.jpmml.converter.Schema; 26 | import org.jpmml.converter.mining.MiningModelUtil; 27 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 28 | 29 | /** 30 | * Converter for {@link XGBoostClassificationModel}. 31 | */ 32 | public class XGBoostClassificationModelConverter extends ProbabilisticClassificationModelConverter<XGBoostClassificationModel> { 33 | 34 | public XGBoostClassificationModelConverter(XGBoostClassificationModel model){ 35 | super(model); 36 | } 37 | 38 | /** 39 | * Encodes the native booster using {@link BoosterUtil#encodeBooster}, and clears the output of the final regression model of the resulting mining model. 40 | */ 41 | @Override 42 | public MiningModel encodeModel(Schema schema){ 43 | XGBoostClassificationModel model = getModel(); 44 | 45 | Booster booster = model.nativeBooster(); 46 | 47 | MiningModel miningModel = BoosterUtil.encodeBooster(this, booster, schema); 48 | 49 | // NOTE(review): presumably the output fields are registered by the superclass instead - confirm 50 | RegressionModel regressionModel = (RegressionModel)MiningModelUtil.getFinalModel(miningModel); 51 | regressionModel.setOutput(null); 52 | 53 | return miningModel; 54 | } 55 | } -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/main/java/org/jpmml/sparkml/xgboost/XGBoostRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>. 18 | */ 19 | package org.jpmml.sparkml.xgboost; 20 | 21 | import java.util.Collections; 22 | import java.util.List; 23 | 24 | import ml.dmlc.xgboost4j.scala.Booster; 25 | import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel; 26 | import org.dmg.pmml.MiningFunction; 27 | import org.dmg.pmml.Model; 28 | import org.dmg.pmml.OpType; 29 | import org.dmg.pmml.OutputField; 30 | import org.dmg.pmml.mining.MiningModel; 31 | import org.jpmml.converter.ContinuousFeature; 32 | import org.jpmml.converter.DerivedOutputField; 33 | import org.jpmml.converter.Label; 34 | import org.jpmml.converter.ModelUtil; 35 | import org.jpmml.converter.ScalarLabel; 36 | import org.jpmml.converter.Schema; 37 | import org.jpmml.sparkml.PredictionModelConverter; 38 | import org.jpmml.sparkml.SparkMLEncoder; 39 | import org.jpmml.sparkml.model.HasPredictionModelOptions; 40 | 41 | /** 42 | * Converter for {@link XGBoostRegressionModel}. 43 | */ 44 | public class XGBoostRegressionModelConverter extends PredictionModelConverter<XGBoostRegressionModel> { 45 | 46 | public XGBoostRegressionModelConverter(XGBoostRegressionModel model){ 47 | super(model); 48 | } 49 | 50 | @Override 51 | public MiningFunction getMiningFunction(){ 52 | return MiningFunction.REGRESSION; 53 | } 54 | 55 | /** 56 | * Encodes the native booster using {@link BoosterUtil#encodeBooster}. 57 | */ 58 | @Override 59 | public MiningModel encodeModel(Schema schema){ 60 | XGBoostRegressionModel model = getModel(); 61 | 62 | Booster booster = model.nativeBooster(); 63 | 64 | return BoosterUtil.encodeBooster(this, booster, schema); 65 | } 66 | 67 | /** 68 | * Declares the prediction column as a continuous predicted field that has the data type of the label, and maps the prediction column to a {@link ContinuousFeature} using {@link SparkMLEncoder#putOnlyFeature}. 69 | * 70 | * @return An empty list. 71 | */ 72 | @Override 73 | public List<OutputField> registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder){ 74 | XGBoostRegressionModel model = getModel(); 75 | 76 | ScalarLabel scalarLabel = (ScalarLabel)label; 77 | 78 | String predictionCol = model.getPredictionCol(); 79 | 80 | // OPTION_KEEP_PREDICTIONCOL defaults to true 81 | Boolean keepPredictionCol = (Boolean)getOption(HasPredictionModelOptions.OPTION_KEEP_PREDICTIONCOL,
Boolean.TRUE); 82 | 83 | OutputField predictedOutputField = ModelUtil.createPredictedField(predictionCol, OpType.CONTINUOUS, scalarLabel.getDataType()); 84 | 85 | DerivedOutputField predictedField = encoder.createDerivedField(pmmlModel, predictedOutputField, keepPredictionCol); 86 | 87 | encoder.putOnlyFeature(predictionCol, new ContinuousFeature(encoder, predictedField)); 88 | 89 | return Collections.emptyList(); 90 | } 91 | } -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/main/resources/META-INF/sparkml2pmml.properties: -------------------------------------------------------------------------------- 1 | ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel = org.jpmml.sparkml.xgboost.XGBoostClassificationModelConverter 2 | ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel = org.jpmml.sparkml.xgboost.XGBoostRegressionModelConverter 3 | -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/XGBoostAudit.scala: -------------------------------------------------------------------------------- 1 | import java.io.File 2 | 3 | import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier 4 | import org.apache.spark.ml.Pipeline 5 | import org.apache.spark.ml.feature._ 6 | import org.apache.spark.ml.linalg.Vector 7 | import org.apache.spark.sql.functions.{lit, udf} 8 | import org.apache.spark.sql.types.StringType 9 | import org.jpmml.sparkml.{DatasetUtil, PipelineModelUtil} 10 | import org.jpmml.sparkml.feature.SparseToDenseTransformer 11 | 12 | var df = DatasetUtil.loadCsv(spark, new File("csv/Audit.csv")) 13 | df = DatasetUtil.castColumn(df, "Adjusted", StringType) 14 | 15 | DatasetUtil.storeSchema(df, new File("schema/Audit.json")) 16 | 17 | val cat_cols = Array("Education", "Employment", "Gender", "Marital", "Occupation") 18 | val cont_cols = Array("Age", "Hours", "Income") 19 | 20 | val labelIndexer = new
StringIndexer().setInputCol("Adjusted").setOutputCol("idx_Adjusted") 21 | 22 | val indexer = new StringIndexer().setInputCols(cat_cols).setOutputCols(cat_cols.map(cat_col => "idx_" + cat_col)) 23 | val ohe = new OneHotEncoder().setHandleInvalid("keep").setDropLast(true).setInputCols(indexer.getOutputCols).setOutputCols(cat_cols.map(cat_col => "ohe_" + cat_col)) 24 | val assembler = new VectorAssembler().setInputCols(ohe.getOutputCols ++ cont_cols).setOutputCol("featureVector") 25 | 26 | val sparse2dense = new SparseToDenseTransformer().setInputCol(assembler.getOutputCol).setOutputCol("denseFeatureVec") 27 | 28 | val classifier = new XGBoostClassifier(Map("objective" -> "binary:logistic", "num_round" -> 101)).setLabelCol(labelIndexer.getOutputCol).setFeaturesCol(sparse2dense.getOutputCol) 29 | 30 | val pipeline = new Pipeline().setStages(Array(labelIndexer, indexer, ohe, assembler, sparse2dense, classifier)) 31 | val pipelineModel = pipeline.fit(df) 32 | 33 | PipelineModelUtil.storeZip(pipelineModel, new File("pipeline/XGBoostAudit.zip")) 34 | 35 | val predLabel = udf{ (value: Float) => value.toInt.toString } 36 | val vectorToColumn = udf{ (vec: Vector, index: Int) => vec(index).toFloat } 37 | 38 | var xgbDf = pipelineModel.transform(df) 39 | xgbDf = xgbDf.selectExpr("prediction", "probability") 40 | xgbDf = xgbDf.withColumn("Adjusted", predLabel(xgbDf("prediction"))).drop("prediction") 41 | xgbDf = xgbDf.withColumn("probability(0)", vectorToColumn(xgbDf("probability"), lit(0))).withColumn("probability(1)", vectorToColumn(xgbDf("probability"), lit(1))).drop("probability").drop("probability") 42 | 43 | DatasetUtil.storeCsv(xgbDf, new File("csv/XGBoostAudit.csv")) 44 | -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/XGBoostAuditNA.scala: -------------------------------------------------------------------------------- 1 | import java.io.File 2 | 3 | import 
ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier 4 | import org.apache.spark.ml.Pipeline 5 | import org.apache.spark.ml.feature._ 6 | import org.apache.spark.ml.linalg.Vector 7 | import org.apache.spark.ml.param.ParamMap 8 | import org.apache.spark.sql.functions.{lit, udf} 9 | import org.apache.spark.sql.types.StringType 10 | import org.apache.spark.ml.util.MLWritable 11 | import org.jpmml.sparkml.{ArchiveUtil, DatasetUtil, PipelineModelUtil} 12 | import org.jpmml.sparkml.feature.{InvalidCategoryTransformer, SparseToDenseTransformer} 13 | 14 | var df = DatasetUtil.loadCsv(spark, new File("csv/AuditNA.csv")) 15 | df = DatasetUtil.castColumn(df, "Adjusted", StringType) 16 | 17 | DatasetUtil.storeSchema(df, new File("schema/AuditNA.json")) 18 | 19 | val cat_cols = Array("Education", "Employment", "Gender", "Marital", "Occupation") 20 | val cont_cols = Array("Age", "Hours", "Income") 21 | 22 | val labelIndexer = new StringIndexer().setInputCol("Adjusted").setOutputCol("idx_Adjusted") 23 | 24 | val indexer = new StringIndexer().setInputCols(cat_cols).setOutputCols(cat_cols.map(cat_col => "idx_" + cat_col)).setHandleInvalid("keep") 25 | val indexTransformer = new InvalidCategoryTransformer().setInputCols(indexer.getOutputCols).setOutputCols(cat_cols.map(cat_col => "idxTransformed_" + cat_col)) 26 | 27 | val assembler = new VectorAssembler().setInputCols(indexTransformer.getOutputCols ++ cont_cols).setOutputCol("featureVector").setHandleInvalid("keep") 28 | 29 | val sparse2dense = new SparseToDenseTransformer().setInputCol(assembler.getOutputCol).setOutputCol("denseFeatureVec") 30 | 31 | val classifier = new XGBoostClassifier(Map("objective" -> "binary:logistic", "num_round" -> 101)).setLabelCol(labelIndexer.getOutputCol).setFeaturesCol(sparse2dense.getOutputCol).setFeatureTypes(Array("c", "c", "c", "c", "c", "q", "q", "q"))//.setHandleInvalid("keep").setMissing(Float.NaN) 32 | 33 | val pipeline = new Pipeline().setStages(Array(labelIndexer, indexer, indexTransformer, 
assembler, sparse2dense, classifier)) 34 | val pipelineModel = pipeline.fit(df) 35 | 36 | PipelineModelUtil.storeZip(pipelineModel, new File("pipeline/XGBoostAuditNA.zip")) 37 | 38 | val predLabel = udf{ (value: Float) => value.toInt.toString } 39 | val vectorToColumn = udf{ (vec: Vector, index: Int) => vec(index).toFloat } 40 | 41 | var xgbDf = pipelineModel.transform(df) 42 | xgbDf = xgbDf.selectExpr("prediction", "probability") 43 | xgbDf = xgbDf.withColumn("Adjusted", predLabel(xgbDf("prediction"))).drop("prediction") 44 | xgbDf = xgbDf.withColumn("probability(0)", vectorToColumn(xgbDf("probability"), lit(0))).withColumn("probability(1)", vectorToColumn(xgbDf("probability"), lit(1))).drop("probability").drop("probability") 45 | 46 | DatasetUtil.storeCsv(xgbDf, new File("csv/XGBoostAuditNA.csv")) 47 | -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/XGBoostAuto.scala: -------------------------------------------------------------------------------- 1 | import java.io.File 2 | 3 | import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor 4 | import org.apache.spark.ml.Pipeline 5 | import org.apache.spark.ml.feature._ 6 | import org.apache.spark.sql.types.{FloatType, StringType} 7 | import org.jpmml.sparkml.{DatasetUtil, PipelineModelUtil} 8 | import org.jpmml.sparkml.feature.SparseToDenseTransformer 9 | 10 | var df = DatasetUtil.loadCsv(spark, new File("csv/Auto.csv")) 11 | df = DatasetUtil.castColumn(df, "origin", StringType) 12 | 13 | DatasetUtil.storeSchema(df, new File("schema/Auto.json")) 14 | 15 | val cat_cols = Array("cylinders", "model_year", "origin") 16 | val cont_cols = Array("acceleration", "displacement", "horsepower", "weight") 17 | 18 | val indexer = new StringIndexer().setInputCols(cat_cols).setOutputCols(cat_cols.map(cat_col => "idx_" + cat_col)) 19 | val ohe = new 
OneHotEncoder().setHandleInvalid("keep").setDropLast(false).setInputCols(indexer.getOutputCols).setOutputCols(cat_cols.map(cat_col => "ohe_" + cat_col)) 20 | val assembler = new VectorAssembler().setInputCols(ohe.getOutputCols ++ cont_cols).setOutputCol("featureVector") 21 | 22 | val sparse2dense = new SparseToDenseTransformer().setInputCol(assembler.getOutputCol).setOutputCol("denseFeatureVec") 23 | 24 | val regressor = new XGBoostRegressor(Map("objective" -> "reg:squarederror", "num_round" -> 101)).setLabelCol("mpg").setFeaturesCol(sparse2dense.getOutputCol) 25 | 26 | val pipeline = new Pipeline().setStages(Array(indexer, ohe, assembler, sparse2dense, regressor)) 27 | val pipelineModel = pipeline.fit(df) 28 | 29 | PipelineModelUtil.storeZip(pipelineModel, new File("pipeline/XGBoostAuto.zip")) 30 | 31 | var xgbDf = pipelineModel.transform(df) 32 | xgbDf = xgbDf.selectExpr("prediction as mpg") 33 | xgbDf = DatasetUtil.castColumn(xgbDf, "mpg", FloatType) 34 | 35 | DatasetUtil.storeCsv(xgbDf, new File("csv/XGBoostAuto.csv")) 36 | -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/XGBoostAutoNA.scala: -------------------------------------------------------------------------------- 1 | import java.io.File 2 | 3 | import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor 4 | import org.apache.spark.ml.Pipeline 5 | import org.apache.spark.ml.feature._ 6 | import org.apache.spark.ml.param.ParamMap 7 | import org.apache.spark.ml.util.MLWritable 8 | import org.apache.spark.sql.types.FloatType 9 | import org.jpmml.sparkml.{ArchiveUtil, DatasetUtil, PipelineModelUtil} 10 | import org.jpmml.sparkml.feature.InvalidCategoryTransformer 11 | 12 | var df = DatasetUtil.loadCsv(spark, new File("csv/AutoNA.csv")) 13 | 14 | DatasetUtil.storeSchema(df, new File("schema/AutoNA.json")) 15 | 16 | val cat_cols = Array("cylinders", "model_year", "origin") 17 | val cont_cols = Array("acceleration", "displacement", 
"horsepower", "weight") 18 | 19 | val indexer = new StringIndexer().setInputCols(cat_cols).setOutputCols(cat_cols.map(cat_col => "idx_" + cat_col)).setHandleInvalid("keep") 20 | val indexTransformer = new InvalidCategoryTransformer().setInputCols(indexer.getOutputCols).setOutputCols(cat_cols.map(cat_col => "idxTransformed_" + cat_col)) 21 | 22 | val assembler = new VectorAssembler().setInputCols(indexTransformer.getOutputCols ++ cont_cols).setOutputCol("featureVector").setHandleInvalid("keep") 23 | 24 | val regressor = new XGBoostRegressor(Map("objective" -> "reg:squarederror", "num_round" -> 101, "num_workers" -> 1, "tree_method" -> "hist")).setLabelCol("mpg").setFeaturesCol(assembler.getOutputCol).setFeatureTypes(Array("c", "c", "c", "q", "q", "q", "q")) 25 | 26 | val pipeline = new Pipeline().setStages(Array(indexer, indexTransformer, assembler, regressor)) 27 | val pipelineModel = pipeline.fit(df) 28 | 29 | PipelineModelUtil.storeZip(pipelineModel, new File("pipeline/XGBoostAutoNA.zip")) 30 | 31 | var xgbDf = pipelineModel.transform(df) 32 | xgbDf = xgbDf.selectExpr("prediction as mpg") 33 | xgbDf = DatasetUtil.castColumn(xgbDf, "mpg", FloatType) 34 | 35 | DatasetUtil.storeCsv(xgbDf, new File("csv/XGBoostAutoNA.csv")) 36 | -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/XGBoostHousing.scala: -------------------------------------------------------------------------------- 1 | import java.io.File 2 | 3 | import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor 4 | import org.apache.spark.ml.Pipeline 5 | import org.apache.spark.ml.feature._ 6 | import org.apache.spark.sql.types.FloatType 7 | import org.jpmml.sparkml.{DatasetUtil, PipelineModelUtil} 8 | 9 | var df = DatasetUtil.loadCsv(spark, new File("csv/Housing.csv")) 10 | 11 | DatasetUtil.storeSchema(df, new File("schema/Housing.json")) 12 | 13 | val cat_cols = Array("CHAS", "RAD", "TAX") 14 | val cont_cols = Array("CRIM", "ZN", "INDUS", 
"NOX", "RM", "AGE", "DIS", "PTRATIO", "B", "LSTAT") 15 | 16 | val assembler = new VectorAssembler().setInputCols(cat_cols ++ cont_cols).setOutputCol("featureVector") 17 | val indexer = new VectorIndexer().setInputCol(assembler.getOutputCol).setOutputCol("catFeatureVector") 18 | 19 | val regressor = new XGBoostRegressor(Map("objective" -> "reg:squarederror", "num_round" -> 101)).setMissing(-1).setLabelCol("MEDV").setFeaturesCol(indexer.getOutputCol) 20 | 21 | val pipeline = new Pipeline().setStages(Array(assembler, indexer, regressor)) 22 | val pipelineModel = pipeline.fit(df) 23 | 24 | PipelineModelUtil.storeZip(pipelineModel, new File("pipeline/XGBoostHousing.zip")) 25 | 26 | var xgbDf = pipelineModel.transform(df) 27 | xgbDf = xgbDf.selectExpr("prediction as MEDV") 28 | xgbDf = DatasetUtil.castColumn(xgbDf, "MEDV", FloatType) 29 | 30 | DatasetUtil.storeCsv(xgbDf, new File("csv/XGBoostHousing.csv")) 31 | -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/XGBoostIris.scala: -------------------------------------------------------------------------------- 1 | import java.io.File 2 | 3 | import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier 4 | import org.apache.spark.ml.Pipeline 5 | import org.apache.spark.ml.feature._ 6 | import org.apache.spark.ml.linalg.Vector 7 | import org.apache.spark.sql.functions.{lit, udf} 8 | import org.apache.spark.sql.types.{DataType, StringType, StructType} 9 | import org.jpmml.sparkml.{DatasetUtil, PipelineModelUtil} 10 | 11 | var df = DatasetUtil.loadCsv(spark, new File("csv/Iris.csv")) 12 | 13 | val schema = df.schema 14 | val floatSchema = DataType.fromJson(schema.json.replaceAll("double", "float")) 15 | 16 | df = DatasetUtil.castColumns(df, floatSchema.asInstanceOf[StructType]) 17 | 18 | DatasetUtil.storeSchema(df, new File("schema/Iris.json")) 19 | 20 | val labelIndexer = new StringIndexer().setInputCol("Species").setOutputCol("idx_Species") 21 | val 
labelIndexerModel = labelIndexer.fit(df) 22 | 23 | val assembler = new VectorAssembler().setInputCols(Array("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width")).setOutputCol("featureVector") 24 | 25 | val classifier = new XGBoostClassifier(Map("objective" -> "multi:softprob", "num_class" -> 3)).setLabelCol(labelIndexer.getOutputCol).setFeaturesCol(assembler.getOutputCol) 26 | 27 | val pipeline = new Pipeline().setStages(Array(labelIndexer, assembler, classifier)) 28 | val pipelineModel = pipeline.fit(df) 29 | 30 | PipelineModelUtil.storeZip(pipelineModel, new File("pipeline/XGBoostIris.zip")) 31 | 32 | val predLabel = udf{ (value: Float) => labelIndexerModel.labels(value.toInt) } 33 | val vectorToColumn = udf{ (vec: Vector, index: Int) => vec(index).toFloat } 34 | 35 | var xgbDf = pipelineModel.transform(df) 36 | xgbDf = xgbDf.selectExpr("prediction", "probability") 37 | xgbDf = xgbDf.withColumn("Species", predLabel(xgbDf("prediction"))).drop("prediction") 38 | xgbDf = xgbDf.withColumn("probability(setosa)", vectorToColumn(xgbDf("probability"), lit(0))).withColumn("probability(versicolor)", vectorToColumn(xgbDf("probability"), lit(1))).withColumn("probability(virginica)", vectorToColumn(xgbDf("probability"), lit(2))).drop("probability").drop("probability") 39 | 40 | DatasetUtil.storeCsv(xgbDf, new File("csv/XGBoostIris.csv")) 41 | -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostAuditNA.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostAuditNA.zip -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostAutoNA.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostAutoNA.zip -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostHousing.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostHousing.zip -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostIris.zip -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/schema/Audit.json: -------------------------------------------------------------------------------- 1 | 
{"type":"struct","fields":[{"name":"Age","type":"integer","nullable":true,"metadata":{}},{"name":"Employment","type":"string","nullable":true,"metadata":{}},{"name":"Education","type":"string","nullable":true,"metadata":{}},{"name":"Marital","type":"string","nullable":true,"metadata":{}},{"name":"Occupation","type":"string","nullable":true,"metadata":{}},{"name":"Income","type":"double","nullable":true,"metadata":{}},{"name":"Gender","type":"string","nullable":true,"metadata":{}},{"name":"Deductions","type":"integer","nullable":true,"metadata":{}},{"name":"Hours","type":"integer","nullable":true,"metadata":{}},{"name":"Adjusted","type":"string","nullable":true,"metadata":{}}]} -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/schema/AuditNA.json: -------------------------------------------------------------------------------- 1 | {"type":"struct","fields":[{"name":"Age","type":"integer","nullable":true,"metadata":{}},{"name":"Employment","type":"string","nullable":true,"metadata":{}},{"name":"Education","type":"string","nullable":true,"metadata":{}},{"name":"Marital","type":"string","nullable":true,"metadata":{}},{"name":"Occupation","type":"string","nullable":true,"metadata":{}},{"name":"Income","type":"double","nullable":true,"metadata":{}},{"name":"Gender","type":"string","nullable":true,"metadata":{}},{"name":"Deductions","type":"integer","nullable":true,"metadata":{}},{"name":"Hours","type":"integer","nullable":true,"metadata":{}},{"name":"Adjusted","type":"string","nullable":true,"metadata":{}}]} -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/schema/Auto.json: -------------------------------------------------------------------------------- 1 | 
{"type":"struct","fields":[{"name":"cylinders","type":"integer","nullable":true,"metadata":{}},{"name":"displacement","type":"double","nullable":true,"metadata":{}},{"name":"horsepower","type":"integer","nullable":true,"metadata":{}},{"name":"weight","type":"integer","nullable":true,"metadata":{}},{"name":"acceleration","type":"double","nullable":true,"metadata":{}},{"name":"model_year","type":"integer","nullable":true,"metadata":{}},{"name":"mpg","type":"double","nullable":true,"metadata":{}},{"name":"origin","type":"string","nullable":true,"metadata":{}}]} -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/schema/AutoNA.json: -------------------------------------------------------------------------------- 1 | {"type":"struct","fields":[{"name":"cylinders","type":"integer","nullable":true,"metadata":{}},{"name":"displacement","type":"integer","nullable":true,"metadata":{}},{"name":"horsepower","type":"integer","nullable":true,"metadata":{}},{"name":"weight","type":"integer","nullable":true,"metadata":{}},{"name":"acceleration","type":"double","nullable":true,"metadata":{}},{"name":"model_year","type":"integer","nullable":true,"metadata":{}},{"name":"origin","type":"integer","nullable":true,"metadata":{}},{"name":"mpg","type":"double","nullable":true,"metadata":{}}]} -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/schema/Housing.json: -------------------------------------------------------------------------------- 1 | 
{"type":"struct","fields":[{"name":"CRIM","type":"double","nullable":true,"metadata":{}},{"name":"ZN","type":"double","nullable":true,"metadata":{}},{"name":"INDUS","type":"double","nullable":true,"metadata":{}},{"name":"CHAS","type":"integer","nullable":true,"metadata":{}},{"name":"NOX","type":"double","nullable":true,"metadata":{}},{"name":"RM","type":"double","nullable":true,"metadata":{}},{"name":"AGE","type":"double","nullable":true,"metadata":{}},{"name":"DIS","type":"double","nullable":true,"metadata":{}},{"name":"RAD","type":"integer","nullable":true,"metadata":{}},{"name":"TAX","type":"double","nullable":true,"metadata":{}},{"name":"PTRATIO","type":"double","nullable":true,"metadata":{}},{"name":"B","type":"double","nullable":true,"metadata":{}},{"name":"LSTAT","type":"double","nullable":true,"metadata":{}},{"name":"MEDV","type":"double","nullable":true,"metadata":{}}]} -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/schema/Iris.json: -------------------------------------------------------------------------------- 1 | {"type":"struct","fields":[{"name":"Species","type":"string","nullable":true,"metadata":{}},{"name":"Sepal_Length","type":"float","nullable":true,"metadata":{}},{"name":"Sepal_Width","type":"float","nullable":true,"metadata":{}},{"name":"Petal_Length","type":"float","nullable":true,"metadata":{}},{"name":"Petal_Width","type":"float","nullable":true,"metadata":{}}]} -------------------------------------------------------------------------------- /pmml-sparkml/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | 5 | 6 | org.jpmml 7 | jpmml-sparkml 8 | 3.1-SNAPSHOT 9 | 10 | 11 | org.jpmml 12 | pmml-sparkml 13 | jar 14 | 15 | JPMML Spark ML converter 16 | JPMML Apache Spark ML to PMML converter 17 | 18 | 19 | 20 | GNU Affero General Public License (AGPL) version 3.0 21 | http://www.gnu.org/licenses/agpl-3.0.html 22 
| repo 23 | 24 | 25 | 26 | 27 | 28 | org.jpmml 29 | pmml-converter 30 | 31 | 32 | org.jpmml 33 | pmml-converter-testing 34 | 35 | 36 | 37 | org.apache.spark 38 | spark-core_2.12 39 | provided 40 | 41 | 42 | org.apache.spark 43 | spark-mllib_2.12 44 | provided 45 | 46 | 47 | 48 | org.jpmml 49 | pmml-evaluator-testing 50 | provided 51 | 52 | 53 | 54 | org.junit.jupiter 55 | junit-jupiter-api 56 | 57 | 58 | 59 | org.apache.hadoop 60 | hadoop-client 61 | 62 | 63 | 64 | 65 | 66 | 67 | org.apache.maven.plugins 68 | maven-jar-plugin 69 | 3.4.2 70 | 71 | 72 | 73 | JPMML-SparkML library 74 | ${project.version} 75 | 76 | 77 | 78 | 79 | 80 | org.apache.maven.plugins 81 | maven-javadoc-plugin 82 | 83 | 84 | 88 | **/*Converter.java 89 | 90 | 91 | 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/AliasExpression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Objects; 22 | 23 | import org.dmg.pmml.Expression; 24 | import org.dmg.pmml.HasExpression; 25 | import org.dmg.pmml.PMMLObject; 26 | import org.dmg.pmml.Visitor; 27 | import org.dmg.pmml.VisitorAction; 28 | import org.jpmml.model.visitors.ExpressionFilterer; 29 | 30 | public class AliasExpression extends Expression implements HasExpression { 31 | 32 | private String name = null; 33 | 34 | private Expression expression = null; 35 | 36 | 37 | public AliasExpression(String name, Expression expression){ 38 | setName(name); 39 | setExpression(expression); 40 | } 41 | 42 | public String getName(){ 43 | return this.name; 44 | } 45 | 46 | public AliasExpression setName(String name){ 47 | this.name = Objects.requireNonNull(name); 48 | 49 | return this; 50 | } 51 | 52 | @Override 53 | public Expression requireExpression(){ 54 | 55 | if(this.expression == null){ 56 | throw new IllegalStateException(); 57 | } 58 | 59 | return this.expression; 60 | } 61 | 62 | @Override 63 | public Expression getExpression(){ 64 | return this.expression; 65 | } 66 | 67 | @Override 68 | public AliasExpression setExpression(Expression expression){ 69 | this.expression = Objects.requireNonNull(expression); 70 | 71 | return this; 72 | } 73 | 74 | @Override 75 | public VisitorAction accept(Visitor visitor){ 76 | VisitorAction status = visitor.visit(this); 77 | 78 | if(status == VisitorAction.CONTINUE){ 79 | visitor.pushParent(this); 80 | 81 | status = PMMLObject.traverse(visitor, getExpression()); 82 | 83 | visitor.popParent(); 84 | } // End if 85 | 86 | if(status == VisitorAction.TERMINATE){ 87 | return VisitorAction.TERMINATE; 88 | } 89 | 90 | return VisitorAction.CONTINUE; 91 | } 92 | 93 | static 94 | public Expression unwrap(Expression expression){ 95 | expression = unwrapInternal(expression); 96 | 97 | ExpressionFilterer filterer = new ExpressionFilterer(){ 98 | 99 | @Override 100 | public Expression filter(Expression 
expression){ 101 | return unwrapInternal(expression); 102 | } 103 | }; 104 | filterer.applyTo(expression); 105 | 106 | return expression; 107 | } 108 | 109 | static 110 | private Expression unwrapInternal(Expression expression){ 111 | 112 | while(expression instanceof AliasExpression){ 113 | AliasExpression aliasExpression = (AliasExpression)expression; 114 | 115 | expression = aliasExpression.getExpression(); 116 | } 117 | 118 | return expression; 119 | } 120 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/AssociationRulesModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.apache.spark.ml.Model; 22 | import org.apache.spark.ml.param.shared.HasPredictionCol; 23 | import org.dmg.pmml.MiningFunction; 24 | 25 | abstract 26 | public class AssociationRulesModelConverter & HasPredictionCol> extends ModelConverter { 27 | 28 | public AssociationRulesModelConverter(T model){ 29 | super(model); 30 | } 31 | 32 | @Override 33 | public MiningFunction getMiningFunction(){ 34 | return MiningFunction.ASSOCIATION_RULES; 35 | } 36 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/BinarizedCategoricalFeature.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.List; 22 | import java.util.Objects; 23 | 24 | import org.jpmml.converter.BinaryFeature; 25 | import org.jpmml.converter.CategoricalFeature; 26 | import org.jpmml.converter.ContinuousFeature; 27 | import org.jpmml.converter.Feature; 28 | import org.jpmml.converter.PMMLEncoder; 29 | import org.jpmml.model.ToStringHelper; 30 | 31 | public class BinarizedCategoricalFeature extends Feature { 32 | 33 | private List binaryFeatures = null; 34 | 35 | 36 | public BinarizedCategoricalFeature(PMMLEncoder encoder, CategoricalFeature categoricalFeature, List binaryFeatures){ 37 | super(encoder, categoricalFeature.getName(), categoricalFeature.getDataType()); 38 | 39 | setBinaryFeatures(binaryFeatures); 40 | } 41 | 42 | @Override 43 | public ContinuousFeature toContinuousFeature(){ 44 | throw new UnsupportedOperationException(); 45 | } 46 | 47 | @Override 48 | public int hashCode(){ 49 | return (31 * super.hashCode()) + Objects.hashCode(this.getBinaryFeatures()); 50 | } 51 | 52 | @Override 53 | public boolean equals(Object object){ 54 | 55 | if(object instanceof BinarizedCategoricalFeature){ 56 | BinarizedCategoricalFeature that = (BinarizedCategoricalFeature)object; 57 | 58 | return super.equals(that) && Objects.equals(this.getBinaryFeatures(), that.getBinaryFeatures()); 59 | } 60 | 61 | return false; 62 | } 63 | 64 | @Override 65 | protected ToStringHelper toStringHelper(){ 66 | return super.toStringHelper() 67 | .add("binaryFeatures", getBinaryFeatures()); 68 | } 69 | 70 | public List getBinaryFeatures(){ 71 | return this.binaryFeatures; 72 | } 73 | 74 | private void setBinaryFeatures(List binaryFeatures){ 75 | 76 | if(binaryFeatures != null && binaryFeatures.isEmpty()){ 77 | throw new IllegalArgumentException(); 78 | } 79 | 80 | this.binaryFeatures = Objects.requireNonNull(binaryFeatures); 81 | } 82 | } -------------------------------------------------------------------------------- 
/pmml-sparkml/src/main/java/org/jpmml/sparkml/ClusteringModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Collections; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.Model; 25 | import org.apache.spark.ml.param.shared.HasFeaturesCol; 26 | import org.apache.spark.ml.param.shared.HasPredictionCol; 27 | import org.dmg.pmml.DataType; 28 | import org.dmg.pmml.FieldRef; 29 | import org.dmg.pmml.MiningFunction; 30 | import org.dmg.pmml.OpType; 31 | import org.dmg.pmml.OutputField; 32 | import org.dmg.pmml.ResultFeature; 33 | import org.jpmml.converter.DerivedOutputField; 34 | import org.jpmml.converter.Feature; 35 | import org.jpmml.converter.FieldNameUtil; 36 | import org.jpmml.converter.IndexFeature; 37 | import org.jpmml.converter.Label; 38 | import org.jpmml.converter.LabelUtil; 39 | import org.jpmml.converter.ModelUtil; 40 | 41 | abstract 42 | public class ClusteringModelConverter & HasFeaturesCol & HasPredictionCol> extends ModelConverter { 43 | 44 | public ClusteringModelConverter(T model){ 45 | super(model); 46 | } 47 | 48 | abstract 49 | public int 
getNumberOfClusters(); 50 | 51 | @Override 52 | public MiningFunction getMiningFunction(){ 53 | return MiningFunction.CLUSTERING; 54 | } 55 | 56 | @Override 57 | public List getFeatures(SparkMLEncoder encoder){ 58 | T model = getModel(); 59 | 60 | String featuresCol = model.getFeaturesCol(); 61 | 62 | return encoder.getFeatures(featuresCol); 63 | } 64 | 65 | @Override 66 | public List registerOutputFields(Label label, org.dmg.pmml.Model pmmlModel, SparkMLEncoder encoder){ 67 | T model = getModel(); 68 | 69 | List clusters = LabelUtil.createTargetCategories(getNumberOfClusters()); 70 | 71 | String predictionCol = model.getPredictionCol(); 72 | 73 | OutputField pmmlPredictedOutputField = ModelUtil.createPredictedField(FieldNameUtil.create("pmml", predictionCol), OpType.CATEGORICAL, DataType.STRING) 74 | .setFinalResult(false); 75 | 76 | DerivedOutputField pmmlPredictedField = encoder.createDerivedField(pmmlModel, pmmlPredictedOutputField, true); 77 | 78 | OutputField predictedOutputField = new OutputField(predictionCol, OpType.CATEGORICAL, DataType.INTEGER) 79 | .setResultFeature(ResultFeature.TRANSFORMED_VALUE) 80 | .setExpression(new FieldRef(pmmlPredictedField)); 81 | 82 | DerivedOutputField predictedField = encoder.createDerivedField(pmmlModel, predictedOutputField, true); 83 | 84 | encoder.putOnlyFeature(predictionCol, new IndexFeature(encoder, predictedField, clusters)); 85 | 86 | return Collections.emptyList(); 87 | } 88 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/HasSparkMLOptions.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of 
the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.jpmml.converter.HasOptions; 22 | 23 | /** 24 | * @see TransformerConverter#getOption(String, Object) 25 | */ 26 | public interface HasSparkMLOptions extends HasOptions { 27 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/ItemSetFeature.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.dmg.pmml.Field; 22 | import org.jpmml.converter.ContinuousFeature; 23 | import org.jpmml.converter.Feature; 24 | 25 | public class ItemSetFeature extends Feature { 26 | 27 | public ItemSetFeature(SparkMLEncoder encoder, Field field){ 28 | super(encoder, field.requireName(), field.requireDataType()); 29 | } 30 | 31 | @Override 32 | public ContinuousFeature toContinuousFeature(){ 33 | throw new UnsupportedOperationException(); 34 | } 35 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/MatrixUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.linalg.Matrix; 25 | 26 | public class MatrixUtil { 27 | 28 | private MatrixUtil(){ 29 | } 30 | 31 | public void checkColumns(int columns, Matrix matrix){ 32 | 33 | if(matrix.numCols() != columns){ 34 | throw new IllegalArgumentException("Expected " + columns + " column(s), got " + matrix.numCols() + " column(s)"); 35 | } 36 | } 37 | 38 | static 39 | public void checkRows(int rows, Matrix matrix){ 40 | 41 | if(matrix.numRows() != rows){ 42 | throw new IllegalArgumentException("Expected " + rows + " row(s), got " + matrix.numRows() + " row(s)"); 43 | } 44 | } 45 | 46 | static 47 | public List getRow(Matrix matrix, int row){ 48 | List result = new ArrayList<>(); 49 | 50 | for(int column = 0; column < matrix.numCols(); column++){ 51 | result.add(matrix.apply(row, column)); 52 | } 53 | 54 | return result; 55 | } 56 | 57 | static 58 | public List getColumn(Matrix matrix, int column){ 59 | List result = new ArrayList<>(); 60 | 61 | for(int row = 0; row < matrix.numRows(); row++){ 62 | result.add(matrix.apply(row, column)); 63 | } 64 | 65 | return result; 66 | } 67 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/MultiFeatureConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.apache.spark.ml.Transformer; 22 | import org.apache.spark.ml.param.shared.HasInputCol; 23 | import org.apache.spark.ml.param.shared.HasInputCols; 24 | import org.apache.spark.ml.param.shared.HasOutputCol; 25 | import org.apache.spark.ml.param.shared.HasOutputCols; 26 | 27 | abstract 28 | public class MultiFeatureConverter extends FeatureConverter { 29 | 30 | public MultiFeatureConverter(T transformer){ 31 | super(transformer); 32 | } 33 | 34 | @Override 35 | protected InOutMode getInputMode(){ 36 | T transformer = getTransformer(); 37 | 38 | return getInputMode(transformer); 39 | } 40 | 41 | @Override 42 | public InOutMode getOutputMode(){ 43 | return getInputMode(); 44 | } 45 | 46 | static 47 | public String formatName(T transformer, int index){ 48 | 49 | if(transformer.isSet(transformer.outputCols())){ 50 | return transformer.getOutputCols()[index]; 51 | } // End if 52 | 53 | if(index != 0){ 54 | throw new IllegalArgumentException(); 55 | } 56 | 57 | return transformer.getOutputCol(); 58 | } 59 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/ProbabilisticClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published 
by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.classification./*Probabilistic*/ClassificationModel; 25 | import org.apache.spark.ml.linalg.Vector; 26 | import org.apache.spark.ml.param.shared.HasProbabilityCol; 27 | import org.dmg.pmml.DataType; 28 | import org.dmg.pmml.Model; 29 | import org.dmg.pmml.OutputField; 30 | import org.jpmml.converter.CategoricalLabel; 31 | import org.jpmml.converter.ContinuousFeature; 32 | import org.jpmml.converter.Feature; 33 | import org.jpmml.converter.FieldNameUtil; 34 | import org.jpmml.converter.Label; 35 | import org.jpmml.converter.ModelUtil; 36 | 37 | abstract 38 | public class ProbabilisticClassificationModelConverter & HasProbabilityCol> extends ClassificationModelConverter { 39 | 40 | public ProbabilisticClassificationModelConverter(T model){ 41 | super(model); 42 | } 43 | 44 | @Override 45 | public List registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder){ 46 | T model = getModel(); 47 | 48 | List result = super.registerOutputFields(label, pmmlModel, encoder); 49 | 50 | CategoricalLabel categoricalLabel = (CategoricalLabel)label; 51 | 52 | String probabilityCol = model.getProbabilityCol(); 53 | 54 | result = new ArrayList<>(result); 55 | 56 | List features = new ArrayList<>(); 57 | 58 | for(int i = 0; i < categoricalLabel.size(); i++){ 59 | Object value = categoricalLabel.getValue(i); 60 | 
61 | OutputField probabilityField = ModelUtil.createProbabilityField(FieldNameUtil.create(probabilityCol, value), DataType.DOUBLE, value); 62 | 63 | result.add(probabilityField); 64 | 65 | features.add(new ContinuousFeature(encoder, probabilityField)); 66 | } 67 | 68 | // XXX 69 | encoder.putFeatures(probabilityCol, features); 70 | 71 | return result; 72 | } 73 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/RegexKey.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Objects; 22 | import java.util.function.Predicate; 23 | import java.util.regex.Matcher; 24 | import java.util.regex.Pattern; 25 | 26 | public class RegexKey implements Predicate { 27 | 28 | private Pattern pattern = null; 29 | 30 | 31 | public RegexKey(Pattern pattern){ 32 | setPattern(pattern); 33 | } 34 | 35 | @Override 36 | public boolean test(String string){ 37 | Pattern pattern = getPattern(); 38 | 39 | Matcher matcher = pattern.matcher(string); 40 | 41 | return matcher.matches(); 42 | } 43 | 44 | @Override 45 | public int hashCode(){ 46 | return hashCode(getPattern()); 47 | } 48 | 49 | @Override 50 | public boolean equals(Object object){ 51 | 52 | if(object instanceof RegexKey){ 53 | RegexKey that = (RegexKey)object; 54 | 55 | return equals(this.getPattern(), that.getPattern()); 56 | } 57 | 58 | return false; 59 | } 60 | 61 | public Pattern getPattern(){ 62 | return this.pattern; 63 | } 64 | 65 | private void setPattern(Pattern pattern){ 66 | this.pattern = pattern; 67 | } 68 | 69 | static 70 | private int hashCode(Pattern pattern){ 71 | return Objects.hash(pattern.pattern(), pattern.flags()); 72 | } 73 | 74 | static 75 | private boolean equals(Pattern left, Pattern right){ 76 | return Objects.equals(left.pattern(), right.pattern()) && Objects.equals(left.flags(), right.flags()); 77 | } 78 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/RegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Collections; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.linalg.Vector; 25 | import org.apache.spark.ml.regression.RegressionModel; 26 | import org.dmg.pmml.MiningFunction; 27 | import org.dmg.pmml.Model; 28 | import org.dmg.pmml.OpType; 29 | import org.dmg.pmml.OutputField; 30 | import org.jpmml.converter.ContinuousFeature; 31 | import org.jpmml.converter.DerivedOutputField; 32 | import org.jpmml.converter.Label; 33 | import org.jpmml.converter.ModelUtil; 34 | import org.jpmml.converter.ScalarLabel; 35 | import org.jpmml.sparkml.model.HasPredictionModelOptions; 36 | 37 | abstract 38 | public class RegressionModelConverter> extends PredictionModelConverter { 39 | 40 | public RegressionModelConverter(T model){ 41 | super(model); 42 | } 43 | 44 | @Override 45 | public MiningFunction getMiningFunction(){ 46 | return MiningFunction.REGRESSION; 47 | } 48 | 49 | @Override 50 | public List registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder){ 51 | T model = getModel(); 52 | 53 | ScalarLabel scalarLabel = (ScalarLabel)label; 54 | 55 | String predictionCol = model.getPredictionCol(); 56 | 57 | Boolean keepPredictionCol = (Boolean)getOption(HasPredictionModelOptions.OPTION_KEEP_PREDICTIONCOL, Boolean.TRUE); 58 | 59 | OutputField predictedOutputField = ModelUtil.createPredictedField(predictionCol, OpType.CONTINUOUS, scalarLabel.getDataType()); 60 | 61 | DerivedOutputField predictedField = encoder.createDerivedField(pmmlModel, predictedOutputField, 
keepPredictionCol); 62 | 63 | encoder.putOnlyFeature(predictionCol, new ContinuousFeature(encoder, predictedField)); 64 | 65 | return Collections.emptyList(); 66 | } 67 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/SparkSessionUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.apache.spark.SparkContext; 22 | import org.apache.spark.sql.SparkSession; 23 | 24 | public class SparkSessionUtil { 25 | 26 | private SparkSessionUtil(){ 27 | } 28 | 29 | static 30 | public SparkSession createSparkSession(){ 31 | return createSparkSession("local"); 32 | } 33 | 34 | static 35 | public SparkSession createSparkSession(String master){ 36 | SparkSession.Builder builder = SparkSession.builder() 37 | .master(master) 38 | .config("spark.ui.enabled", false); 39 | 40 | SparkSession sparkSession = builder.getOrCreate(); 41 | 42 | SparkContext sparkContext = sparkSession.sparkContext(); 43 | sparkContext.setLogLevel("ERROR"); 44 | 45 | return sparkSession; 46 | } 47 | 48 | static 49 | public SparkSession destroySparkSession(SparkSession sparkSession){ 50 | sparkSession.stop(); 51 | 52 | return null; 53 | } 54 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/TermUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
/**
 * Static helpers for inspecting the characters of terms.
 */
public class TermUtil {

	// One bit per java.lang.Character general category that counts as punctuation
	private static final int PUNCTUATION_MASK =
		(1 << Character.DASH_PUNCTUATION)
		| (1 << Character.END_PUNCTUATION)
		| (1 << Character.START_PUNCTUATION)
		| (1 << Character.CONNECTOR_PUNCTUATION)
		| (1 << Character.OTHER_PUNCTUATION)
		| (1 << Character.INITIAL_QUOTE_PUNCTUATION)
		| (1 << Character.FINAL_QUOTE_PUNCTUATION);


	private TermUtil(){
	}

	/**
	 * Checks whether any whitespace-separated token of the string starts or ends with a punctuation character.
	 * Punctuation in the interior of a token does not count.
	 *
	 * @param string The string to inspect.
	 */
	static
	public boolean hasPunctuation(String string){

		for(String word : string.split("\\s+")){

			// Leading whitespace yields an empty first token
			if(word.isEmpty()){
				continue;
			}

			char head = word.charAt(0);
			char tail = word.charAt(word.length() - 1);

			if(isPunctuation(head) || isPunctuation(tail)){
				return true;
			}
		}

		return false;
	}

	/**
	 * Checks whether the general category of the character is one of the punctuation categories
	 * (dash, start, end, connector, other, initial quote, final quote).
	 *
	 * @param c The character to inspect.
	 */
	static
	public boolean isPunctuation(char c){
		int categoryBit = (1 << Character.getType(c));

		return (PUNCTUATION_MASK & categoryBit) != 0;
	}
}
15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Map; 22 | 23 | import org.apache.spark.ml.Transformer; 24 | 25 | abstract 26 | public class TransformerConverter { 27 | 28 | private T object = null; 29 | 30 | private Map options = null; 31 | 32 | 33 | public TransformerConverter(T object){ 34 | setObject(object); 35 | } 36 | 37 | public Object getOption(String key, Object defaultValue){ 38 | Map options = getOptions(); 39 | 40 | if(options != null && options.containsKey(key)){ 41 | return options.get(key); 42 | } 43 | 44 | return defaultValue; 45 | } 46 | 47 | public T getObject(){ 48 | return this.object; 49 | } 50 | 51 | private void setObject(T object){ 52 | this.object = object; 53 | } 54 | 55 | public Map getOptions(){ 56 | return this.options; 57 | } 58 | 59 | public void setOptions(Map options){ 60 | this.options = options; 61 | } 62 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/VectorUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 
15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.List; 22 | 23 | import com.google.common.primitives.Doubles; 24 | import org.apache.spark.ml.linalg.DenseVector; 25 | import org.apache.spark.ml.linalg.Vector; 26 | 27 | public class VectorUtil { 28 | 29 | private VectorUtil(){ 30 | } 31 | 32 | static 33 | public List toList(Vector vector){ 34 | DenseVector denseVector = vector.toDense(); 35 | 36 | double[] values = denseVector.values(); 37 | 38 | return Doubles.asList(values); 39 | } 40 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/WeightedTermFeature.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Objects; 22 | 23 | import org.dmg.pmml.Apply; 24 | import org.dmg.pmml.DefineFunction; 25 | import org.jpmml.converter.ExpressionUtil; 26 | import org.jpmml.converter.Feature; 27 | import org.jpmml.converter.PMMLEncoder; 28 | import org.jpmml.model.ToStringHelper; 29 | 30 | public class WeightedTermFeature extends TermFeature { 31 | 32 | private Number weight = null; 33 | 34 | 35 | public WeightedTermFeature(PMMLEncoder encoder, DefineFunction defineFunction, Feature feature, String value, Number weight){ 36 | super(encoder, defineFunction, feature, value); 37 | 38 | setWeight(weight); 39 | } 40 | 41 | @Override 42 | public Apply createApply(){ 43 | Number weight = getWeight(); 44 | 45 | Apply apply = super.createApply() 46 | .addExpressions(ExpressionUtil.createConstant(weight)); 47 | 48 | return apply; 49 | } 50 | 51 | @Override 52 | public int hashCode(){ 53 | return (31 * super.hashCode()) + Objects.hashCode(this.getWeight()); 54 | } 55 | 56 | @Override 57 | public boolean equals(Object object){ 58 | 59 | if(object instanceof WeightedTermFeature){ 60 | WeightedTermFeature that = (WeightedTermFeature)object; 61 | 62 | return super.equals(object) && Objects.equals(this.getWeight(), that.getWeight()); 63 | } 64 | 65 | return false; 66 | } 67 | 68 | @Override 69 | protected ToStringHelper toStringHelper(){ 70 | return super.toStringHelper() 71 | .add("weight", getWeight()); 72 | } 73 | 74 | public Number getWeight(){ 75 | return this.weight; 76 | } 77 | 78 | private void setWeight(Number weight){ 79 | this.weight = Objects.requireNonNull(weight); 80 | } 81 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/BinarizerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | 
* JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.Arrays; 23 | import java.util.List; 24 | 25 | import org.apache.spark.ml.feature.Binarizer; 26 | import org.dmg.pmml.Apply; 27 | import org.dmg.pmml.DataType; 28 | import org.dmg.pmml.DerivedField; 29 | import org.dmg.pmml.OpType; 30 | import org.dmg.pmml.PMMLFunctions; 31 | import org.jpmml.converter.ContinuousFeature; 32 | import org.jpmml.converter.ExpressionUtil; 33 | import org.jpmml.converter.Feature; 34 | import org.jpmml.converter.IndexFeature; 35 | import org.jpmml.sparkml.MultiFeatureConverter; 36 | import org.jpmml.sparkml.SparkMLEncoder; 37 | 38 | public class BinarizerConverter extends MultiFeatureConverter { 39 | 40 | public BinarizerConverter(Binarizer transformer){ 41 | super(transformer); 42 | } 43 | 44 | @Override 45 | public List encodeFeatures(SparkMLEncoder encoder){ 46 | Binarizer transformer = getTransformer(); 47 | 48 | Double threshold = transformer.getThreshold(); 49 | 50 | InOutMode inputMode = getInputMode(); 51 | 52 | List result = new ArrayList<>(); 53 | 54 | String[] inputCols = inputMode.getInputCols(transformer); 55 | for(int i = 0; i < inputCols.length; i++){ 56 | String inputCol = inputCols[i]; 57 | 58 | Feature feature = encoder.getOnlyFeature(inputCol); 59 | 
60 | ContinuousFeature continuousFeature = feature.toContinuousFeature(); 61 | 62 | Apply apply = new Apply(PMMLFunctions.IF) 63 | .addExpressions(ExpressionUtil.createApply(PMMLFunctions.LESSOREQUAL, continuousFeature.ref(), ExpressionUtil.createConstant(threshold))) 64 | .addExpressions(ExpressionUtil.createConstant(0d), ExpressionUtil.createConstant(1d)); 65 | 66 | DerivedField derivedField = encoder.createDerivedField(formatName(transformer, i), OpType.CATEGORICAL, DataType.DOUBLE, apply); 67 | 68 | result.add(new IndexFeature(encoder, derivedField, Arrays.asList(0d, 1d))); 69 | } 70 | 71 | return result; 72 | } 73 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/ChiSqSelectorModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Arrays; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.ChiSqSelectorModel; 25 | import org.jpmml.converter.Feature; 26 | import org.jpmml.sparkml.FeatureConverter; 27 | import org.jpmml.sparkml.SparkMLEncoder; 28 | 29 | public class ChiSqSelectorModelConverter extends FeatureConverter { 30 | 31 | public ChiSqSelectorModelConverter(ChiSqSelectorModel transformer){ 32 | super(transformer); 33 | } 34 | 35 | @Override 36 | public List encodeFeatures(SparkMLEncoder encoder){ 37 | ChiSqSelectorModel transformer = getTransformer(); 38 | 39 | int[] indices = transformer.selectedFeatures(); 40 | if(indices.length > 0){ 41 | indices = indices.clone(); 42 | 43 | Arrays.sort(indices); 44 | } 45 | 46 | return encoder.getFeatures(transformer.getFeaturesCol(), indices); 47 | } 48 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/ColumnPrunerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import org.apache.spark.ml.feature.ColumnPruner; 22 | import org.jpmml.sparkml.FeatureConverter; 23 | 24 | public class ColumnPrunerConverter extends FeatureConverter { 25 | 26 | public ColumnPrunerConverter(ColumnPruner transformer){ 27 | super(transformer); 28 | } 29 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/IDFModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.IDFModel; 25 | import org.apache.spark.ml.linalg.Vector; 26 | import org.jpmml.converter.ContinuousFeature; 27 | import org.jpmml.converter.Feature; 28 | import org.jpmml.converter.ProductFeature; 29 | import org.jpmml.converter.SchemaUtil; 30 | import org.jpmml.sparkml.FeatureConverter; 31 | import org.jpmml.sparkml.SparkMLEncoder; 32 | import org.jpmml.sparkml.TermFeature; 33 | import org.jpmml.sparkml.WeightedTermFeature; 34 | 35 | public class IDFModelConverter extends FeatureConverter { 36 | 37 | public IDFModelConverter(IDFModel transformer){ 38 | super(transformer); 39 | } 40 | 41 | @Override 42 | public List encodeFeatures(SparkMLEncoder encoder){ 43 | IDFModel transformer = getTransformer(); 44 | 45 | Vector idf = transformer.idf(); 46 | 47 | List features = encoder.getFeatures(transformer.getInputCol()); 48 | 49 | SchemaUtil.checkSize(idf.size(), features); 50 | 51 | List result = new ArrayList<>(); 52 | 53 | for(int i = 0; i < features.size(); i++){ 54 | Feature feature = features.get(i); 55 | Double weight = idf.apply(i); 56 | 57 | ProductFeature productFeature = new ProductFeature(encoder, feature, weight){ 58 | 59 | private WeightedTermFeature weightedTermFeature = null; 60 | 61 | 62 | @Override 63 | public ContinuousFeature toContinuousFeature(){ 64 | 65 | if(this.weightedTermFeature == null){ 66 | TermFeature termFeature = (TermFeature)getFeature(); 67 | Number factor = getFactor(); 68 | 69 | this.weightedTermFeature = termFeature.toWeightedTermFeature(factor); 70 | } 71 | 72 | return this.weightedTermFeature.toContinuousFeature(); 73 | } 74 | }; 75 | 76 | result.add(productFeature); 77 | } 78 | 79 | return result; 80 | } 81 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/IndexToStringConverter.java: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Arrays; 22 | import java.util.Collections; 23 | import java.util.List; 24 | 25 | import org.apache.spark.ml.feature.IndexToString; 26 | import org.dmg.pmml.DataField; 27 | import org.dmg.pmml.DataType; 28 | import org.dmg.pmml.OpType; 29 | import org.jpmml.converter.CategoricalFeature; 30 | import org.jpmml.converter.Feature; 31 | import org.jpmml.sparkml.FeatureConverter; 32 | import org.jpmml.sparkml.SparkMLEncoder; 33 | 34 | public class IndexToStringConverter extends FeatureConverter { 35 | 36 | public IndexToStringConverter(IndexToString transformer){ 37 | super(transformer); 38 | } 39 | 40 | @Override 41 | public List encodeFeatures(SparkMLEncoder encoder){ 42 | IndexToString transformer = getTransformer(); 43 | 44 | DataField dataField = encoder.createDataField(formatName(transformer), OpType.CATEGORICAL, DataType.STRING, Arrays.asList(transformer.getLabels())); 45 | 46 | return Collections.singletonList(new CategoricalFeature(encoder, dataField)); 47 | } 48 | } 
-------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/InteractionConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>.
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.Arrays; 23 | import java.util.List; 24 | 25 | import org.apache.spark.ml.feature.Interaction; 26 | import org.dmg.pmml.DataType; 27 | import org.jpmml.converter.CategoricalFeature; 28 | import org.jpmml.converter.Feature; 29 | import org.jpmml.converter.FieldNameUtil; 30 | import org.jpmml.converter.InteractionFeature; 31 | import org.jpmml.sparkml.FeatureConverter; 32 | import org.jpmml.sparkml.SparkMLEncoder; 33 | 34 | public class InteractionConverter extends FeatureConverter { 35 | 36 | public InteractionConverter(Interaction transformer){ 37 | super(transformer); 38 | } 39 | 40 | @Override 41 | public List encodeFeatures(SparkMLEncoder encoder){ 42 | Interaction transformer = getTransformer(); 43 | 44 | StringBuilder sb = new StringBuilder(); 45 | 46 | List result = new ArrayList<>(); 47 | 48 | String[] inputCols = transformer.getInputCols(); 49 | for(int i = 0; i < inputCols.length; i++){ 50 | String inputCol = inputCols[i]; 51 | 52 | List features = encoder.getFeatures(inputCol); 53 | 54 | if(features.size() == 1){ 55 | Feature feature = features.get(0); 56 | 57 | categorical: 58 | if(feature instanceof CategoricalFeature){ 59 | CategoricalFeature categoricalFeature = (CategoricalFeature)feature; 60 | 61 | String name = categoricalFeature.getName(); 62 | 63 | DataType dataType = categoricalFeature.getDataType(); 64 | switch(dataType){ 65 | case INTEGER: 66 | break; 67 | case FLOAT: 68 | case DOUBLE: 69 | break categorical; 70 | default: 71 | break; 72 | } 73 | 74 | // XXX 75 | inputCol = name; 76 | 77 | features = (List)OneHotEncoderModelConverter.encodeFeature(encoder, categoricalFeature, categoricalFeature.getValues()); 78 | } 79 | } // End if 80 | 81 | if(i == 0){ 82 | sb.append(inputCol); 83 | 84 | result = features; 85 | } else 86 | 87 | { 88 | sb.append(':').append(inputCol); 89 | 90 | List interactionFeatures = new ArrayList<>(); 91 | 92 | int 
index = 0; 93 | 94 | for(Feature left : result){ 95 | 96 | for(Feature right : features){ 97 | interactionFeatures.add(new InteractionFeature(encoder, FieldNameUtil.select(sb.toString(), index), DataType.DOUBLE, Arrays.asList(left, right))); 98 | 99 | index++; 100 | } 101 | } 102 | 103 | result = interactionFeatures; 104 | } 105 | } 106 | 107 | return result; 108 | } 109 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/MaxAbsScalerModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.MaxAbsScalerModel; 25 | import org.apache.spark.ml.linalg.Vector; 26 | import org.dmg.pmml.DataType; 27 | import org.dmg.pmml.DerivedField; 28 | import org.dmg.pmml.Expression; 29 | import org.dmg.pmml.OpType; 30 | import org.dmg.pmml.PMMLFunctions; 31 | import org.jpmml.converter.ContinuousFeature; 32 | import org.jpmml.converter.ExpressionUtil; 33 | import org.jpmml.converter.Feature; 34 | import org.jpmml.converter.SchemaUtil; 35 | import org.jpmml.converter.ValueUtil; 36 | import org.jpmml.sparkml.FeatureConverter; 37 | import org.jpmml.sparkml.SparkMLEncoder; 38 | 39 | public class MaxAbsScalerModelConverter extends FeatureConverter { 40 | 41 | public MaxAbsScalerModelConverter(MaxAbsScalerModel transformer){ 42 | super(transformer); 43 | } 44 | 45 | @Override 46 | public List encodeFeatures(SparkMLEncoder encoder){ 47 | MaxAbsScalerModel transformer = getTransformer(); 48 | 49 | Vector maxAbs = transformer.maxAbs(); 50 | 51 | List features = encoder.getFeatures(transformer.getInputCol()); 52 | 53 | SchemaUtil.checkSize(maxAbs.size(), features); 54 | 55 | List result = new ArrayList<>(); 56 | 57 | for(int i = 0, length = features.size(); i < length; i++){ 58 | Feature feature = features.get(i); 59 | 60 | double maxAbsUnzero = maxAbs.apply(i); 61 | if(maxAbsUnzero == 0d){ 62 | maxAbsUnzero = 1d; 63 | } // End if 64 | 65 | if(!ValueUtil.isOne(maxAbsUnzero)){ 66 | ContinuousFeature continuousFeature = feature.toContinuousFeature(); 67 | 68 | Expression expression = ExpressionUtil.createApply(PMMLFunctions.DIVIDE, continuousFeature.ref(), ExpressionUtil.createConstant(maxAbsUnzero)); 69 | 70 | DerivedField derivedField = encoder.createDerivedField(formatName(transformer, i, length), OpType.CONTINUOUS, DataType.DOUBLE, expression); 71 | 72 | feature = new ContinuousFeature(encoder, derivedField); 73 | } 74 | 
75 | result.add(feature); 76 | } 77 | 78 | return result; 79 | } 80 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/MinMaxScalerModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.MinMaxScalerModel; 25 | import org.apache.spark.ml.linalg.Vector; 26 | import org.dmg.pmml.DataType; 27 | import org.dmg.pmml.DerivedField; 28 | import org.dmg.pmml.Expression; 29 | import org.dmg.pmml.OpType; 30 | import org.dmg.pmml.PMMLFunctions; 31 | import org.jpmml.converter.ContinuousFeature; 32 | import org.jpmml.converter.ExpressionUtil; 33 | import org.jpmml.converter.Feature; 34 | import org.jpmml.converter.SchemaUtil; 35 | import org.jpmml.converter.ValueUtil; 36 | import org.jpmml.sparkml.FeatureConverter; 37 | import org.jpmml.sparkml.SparkMLEncoder; 38 | 39 | public class MinMaxScalerModelConverter extends FeatureConverter { 40 | 41 | public MinMaxScalerModelConverter(MinMaxScalerModel transformer){ 42 | super(transformer); 43 | } 44 | 45 | @Override 46 | public List encodeFeatures(SparkMLEncoder encoder){ 47 | MinMaxScalerModel transformer = getTransformer(); 48 | 49 | double rescaleFactor = (transformer.getMax() - transformer.getMin()); 50 | double rescaleConstant = transformer.getMin(); 51 | 52 | Vector originalMin = transformer.originalMin(); 53 | Vector originalMax = transformer.originalMax(); 54 | 55 | List features = encoder.getFeatures(transformer.getInputCol()); 56 | 57 | SchemaUtil.checkSize(Math.max(originalMin.size(), originalMax.size()), features); 58 | 59 | List result = new ArrayList<>(); 60 | 61 | for(int i = 0, length = features.size(); i < length; i++){ 62 | Feature feature = features.get(i); 63 | 64 | ContinuousFeature continuousFeature = feature.toContinuousFeature(); 65 | 66 | double min = originalMin.apply(i); 67 | double max = originalMax.apply(i); 68 | 69 | Expression expression = ExpressionUtil.createApply(PMMLFunctions.DIVIDE, ExpressionUtil.createApply(PMMLFunctions.SUBTRACT, continuousFeature.ref(), ExpressionUtil.createConstant(min)), 
ExpressionUtil.createConstant(max - min)); 70 | 71 | if(!ValueUtil.isOne(rescaleFactor)){ 72 | expression = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, expression, ExpressionUtil.createConstant(rescaleFactor)); 73 | } // End if 74 | 75 | if(!ValueUtil.isZero(rescaleConstant)){ 76 | expression = ExpressionUtil.createApply(PMMLFunctions.ADD, expression, ExpressionUtil.createConstant(rescaleConstant)); 77 | } 78 | 79 | DerivedField derivedField = encoder.createDerivedField(formatName(transformer, i, length), OpType.CONTINUOUS, DataType.DOUBLE, expression); 80 | 81 | result.add(new ContinuousFeature(encoder, derivedField)); 82 | } 83 | 84 | return result; 85 | } 86 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/NGramConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Collections; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.NGram; 25 | import org.jpmml.converter.Feature; 26 | import org.jpmml.sparkml.DocumentFeature; 27 | import org.jpmml.sparkml.FeatureConverter; 28 | import org.jpmml.sparkml.SparkMLEncoder; 29 | 30 | public class NGramConverter extends FeatureConverter { 31 | 32 | public NGramConverter(NGram transformer){ 33 | super(transformer); 34 | } 35 | 36 | @Override 37 | public List encodeFeatures(SparkMLEncoder encoder){ 38 | NGram transformer = getTransformer(); 39 | 40 | DocumentFeature documentFeature = (DocumentFeature)encoder.getOnlyFeature(transformer.getInputCol()); 41 | 42 | return Collections.singletonList(documentFeature); 43 | } 44 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/PCAModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.PCAModel; 25 | import org.apache.spark.ml.linalg.DenseMatrix; 26 | import org.dmg.pmml.Apply; 27 | import org.dmg.pmml.DataType; 28 | import org.dmg.pmml.DerivedField; 29 | import org.dmg.pmml.Expression; 30 | import org.dmg.pmml.OpType; 31 | import org.dmg.pmml.PMMLFunctions; 32 | import org.jpmml.converter.ContinuousFeature; 33 | import org.jpmml.converter.ExpressionUtil; 34 | import org.jpmml.converter.Feature; 35 | import org.jpmml.converter.ValueUtil; 36 | import org.jpmml.sparkml.FeatureConverter; 37 | import org.jpmml.sparkml.MatrixUtil; 38 | import org.jpmml.sparkml.SparkMLEncoder; 39 | 40 | public class PCAModelConverter extends FeatureConverter { 41 | 42 | public PCAModelConverter(PCAModel transformer){ 43 | super(transformer); 44 | } 45 | 46 | @Override 47 | public List encodeFeatures(SparkMLEncoder encoder){ 48 | PCAModel transformer = getTransformer(); 49 | 50 | DenseMatrix pc = transformer.pc(); 51 | 52 | List features = encoder.getFeatures(transformer.getInputCol()); 53 | 54 | MatrixUtil.checkRows(features.size(), pc); 55 | 56 | List result = new ArrayList<>(); 57 | 58 | for(int i = 0, length = transformer.getK(); i < length; i++){ 59 | Apply apply = ExpressionUtil.createApply(PMMLFunctions.SUM); 60 | 61 | for(int j = 0; j < features.size(); j++){ 62 | Feature feature = features.get(j); 63 | 64 | ContinuousFeature continuousFeature = feature.toContinuousFeature(); 65 | 66 | Expression expression = continuousFeature.ref(); 67 | 68 | Double coefficient = pc.apply(j, i); 69 | if(!ValueUtil.isOne(coefficient)){ 70 | expression = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, expression, ExpressionUtil.createConstant(coefficient)); 71 | } 72 | 73 | apply.addExpressions(expression); 74 | } 75 | 76 | DerivedField derivedField = encoder.createDerivedField(formatName(transformer, i, length), 
OpType.CONTINUOUS, DataType.DOUBLE, apply); 77 | 78 | result.add(new ContinuousFeature(encoder, derivedField)); 79 | } 80 | 81 | return result; 82 | } 83 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/RFormulaModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.PipelineModel; 24 | import org.apache.spark.ml.Transformer; 25 | import org.apache.spark.ml.feature.RFormulaModel; 26 | import org.apache.spark.ml.feature.ResolvedRFormula; 27 | import org.jpmml.converter.Feature; 28 | import org.jpmml.sparkml.ConverterFactory; 29 | import org.jpmml.sparkml.FeatureConverter; 30 | import org.jpmml.sparkml.SparkMLEncoder; 31 | import org.jpmml.sparkml.TransformerConverter; 32 | 33 | public class RFormulaModelConverter extends FeatureConverter { 34 | 35 | public RFormulaModelConverter(RFormulaModel transformer){ 36 | super(transformer); 37 | } 38 | 39 | @Override 40 | public void registerFeatures(SparkMLEncoder encoder){ 41 | RFormulaModel transformer = getTransformer(); 42 | 43 | ResolvedRFormula resolvedFormula = transformer.resolvedFormula(); 44 | 45 | String targetCol = resolvedFormula.label(); 46 | 47 | String labelCol = transformer.getLabelCol(); 48 | if(!(targetCol).equals(labelCol)){ 49 | List features = encoder.getFeatures(targetCol); 50 | 51 | encoder.putFeatures(labelCol, features); 52 | } 53 | 54 | ConverterFactory converterFactory = encoder.getConverterFactory(); 55 | 56 | PipelineModel pipelineModel = transformer.pipelineModel(); 57 | 58 | Transformer[] stages = pipelineModel.stages(); 59 | for(Transformer stage : stages){ 60 | TransformerConverter converter = converterFactory.newConverter(stage); 61 | 62 | if(converter instanceof FeatureConverter){ 63 | FeatureConverter featureConverter = (FeatureConverter)converter; 64 | 65 | featureConverter.registerFeatures(encoder); 66 | } else 67 | 68 | { 69 | throw new IllegalArgumentException("Expected a subclass of " + FeatureConverter.class.getName() + ", got " + (converter != null ? 
("class " + (converter.getClass()).getName()) : null)); 70 | } 71 | } 72 | } 73 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/RegexTokenizerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Collections; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.RegexTokenizer; 25 | import org.dmg.pmml.Apply; 26 | import org.dmg.pmml.DataType; 27 | import org.dmg.pmml.Field; 28 | import org.dmg.pmml.OpType; 29 | import org.dmg.pmml.PMMLFunctions; 30 | import org.jpmml.converter.ExpressionUtil; 31 | import org.jpmml.converter.Feature; 32 | import org.jpmml.converter.FieldNameUtil; 33 | import org.jpmml.sparkml.DocumentFeature; 34 | import org.jpmml.sparkml.FeatureConverter; 35 | import org.jpmml.sparkml.SparkMLEncoder; 36 | 37 | public class RegexTokenizerConverter extends FeatureConverter { 38 | 39 | public RegexTokenizerConverter(RegexTokenizer transformer){ 40 | super(transformer); 41 | } 42 | 43 | @Override 44 | public List encodeFeatures(SparkMLEncoder encoder){ 45 | RegexTokenizer transformer = getTransformer(); 46 | 47 | if(!transformer.getGaps()){ 48 | throw new IllegalArgumentException("Expected splitter mode, got token matching mode"); 49 | } // End if 50 | 51 | if(transformer.getMinTokenLength() != 1){ 52 | throw new IllegalArgumentException("Expected 1 as minimum token length, got " + transformer.getMinTokenLength() + " as minimum token length"); 53 | } 54 | 55 | Feature feature = encoder.getOnlyFeature(transformer.getInputCol()); 56 | 57 | Field field = feature.getField(); 58 | 59 | if(transformer.getToLowercase()){ 60 | Apply apply = ExpressionUtil.createApply(PMMLFunctions.LOWERCASE, feature.ref()); 61 | 62 | field = encoder.createDerivedField(FieldNameUtil.create(PMMLFunctions.LOWERCASE, feature), OpType.CATEGORICAL, DataType.STRING, apply); 63 | } 64 | 65 | return Collections.singletonList(new DocumentFeature(encoder, field, transformer.getPattern())); 66 | } 67 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/SparseToDenseTransformerConverter.java: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.List; 22 | 23 | import org.jpmml.converter.Feature; 24 | import org.jpmml.sparkml.FeatureConverter; 25 | import org.jpmml.sparkml.SparkMLEncoder; 26 | import org.jpmml.sparkml.feature.SparseToDenseTransformer; 27 | 28 | public class SparseToDenseTransformerConverter extends FeatureConverter { 29 | 30 | public SparseToDenseTransformerConverter(SparseToDenseTransformer transformer){ 31 | super(transformer); 32 | } 33 | 34 | @Override 35 | public List encodeFeatures(SparkMLEncoder encoder){ 36 | SparseToDenseTransformer transformer = getTransformer(); 37 | 38 | List features = encoder.getFeatures(transformer.getInputCol()); 39 | 40 | return features; 41 | } 42 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/StopWordsRemoverConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you 
can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | import java.util.regex.Pattern; 24 | 25 | import org.apache.spark.ml.feature.StopWordsRemover; 26 | import org.jpmml.converter.Feature; 27 | import org.jpmml.sparkml.DocumentFeature; 28 | import org.jpmml.sparkml.MultiFeatureConverter; 29 | import org.jpmml.sparkml.SparkMLEncoder; 30 | import org.jpmml.sparkml.TermUtil; 31 | 32 | public class StopWordsRemoverConverter extends MultiFeatureConverter { 33 | 34 | public StopWordsRemoverConverter(StopWordsRemover transformer){ 35 | super(transformer); 36 | } 37 | 38 | @Override 39 | public List encodeFeatures(SparkMLEncoder encoder){ 40 | StopWordsRemover transformer = getTransformer(); 41 | 42 | boolean caseSensitive = transformer.getCaseSensitive(); 43 | String[] stopWords = transformer.getStopWords(); 44 | 45 | InOutMode inputMode = getInputMode(); 46 | 47 | List result = new ArrayList<>(); 48 | 49 | String[] inputCols = inputMode.getInputCols(transformer); 50 | for(String inputCol : inputCols){ 51 | DocumentFeature documentFeature = (DocumentFeature)encoder.getOnlyFeature(inputCol); 52 | 53 | Pattern pattern = Pattern.compile(documentFeature.getWordSeparatorRE()); 54 | 55 | DocumentFeature.StopWordSet stopWordSet = new DocumentFeature.StopWordSet(caseSensitive); 56 | 57 | 
for(String stopWord : stopWords){ 58 | String[] stopTokens = pattern.split(stopWord); 59 | 60 | // Skip multi-token stopwords. See https://issues.apache.org/jira/browse/SPARK-18374 61 | if(stopTokens.length > 1){ 62 | continue; 63 | } // End if 64 | 65 | if(TermUtil.hasPunctuation(stopWord)){ 66 | throw new IllegalArgumentException("Punctuated stop words (" + stopWord + ") are not supported"); 67 | } 68 | 69 | stopWordSet.add(stopWord); 70 | } 71 | 72 | documentFeature.addStopWordSet(stopWordSet); 73 | 74 | result.add(documentFeature); 75 | } 76 | 77 | return result; 78 | } 79 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/TokenizerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Collections; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.Tokenizer; 25 | import org.dmg.pmml.Apply; 26 | import org.dmg.pmml.DataType; 27 | import org.dmg.pmml.DerivedField; 28 | import org.dmg.pmml.OpType; 29 | import org.dmg.pmml.PMMLFunctions; 30 | import org.jpmml.converter.ExpressionUtil; 31 | import org.jpmml.converter.Feature; 32 | import org.jpmml.converter.FieldNameUtil; 33 | import org.jpmml.sparkml.DocumentFeature; 34 | import org.jpmml.sparkml.FeatureConverter; 35 | import org.jpmml.sparkml.SparkMLEncoder; 36 | 37 | public class TokenizerConverter extends FeatureConverter { 38 | 39 | public TokenizerConverter(Tokenizer transformer){ 40 | super(transformer); 41 | } 42 | 43 | @Override 44 | public List encodeFeatures(SparkMLEncoder encoder){ 45 | Tokenizer transformer = getTransformer(); 46 | 47 | Feature feature = encoder.getOnlyFeature(transformer.getInputCol()); 48 | 49 | Apply apply = ExpressionUtil.createApply(PMMLFunctions.LOWERCASE, feature.ref()); 50 | 51 | DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create(PMMLFunctions.LOWERCASE, feature), OpType.CATEGORICAL, DataType.STRING, apply); 52 | 53 | return Collections.singletonList(new DocumentFeature(encoder, derivedField, "\\s+")); 54 | } 55 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/VectorAssemblerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.VectorAssembler; 25 | import org.jpmml.converter.Feature; 26 | import org.jpmml.sparkml.FeatureConverter; 27 | import org.jpmml.sparkml.SparkMLEncoder; 28 | 29 | public class VectorAssemblerConverter extends FeatureConverter { 30 | 31 | public VectorAssemblerConverter(VectorAssembler transformer){ 32 | super(transformer); 33 | } 34 | 35 | @Override 36 | public List encodeFeatures(SparkMLEncoder encoder){ 37 | VectorAssembler transformer = getTransformer(); 38 | 39 | List result = new ArrayList<>(); 40 | 41 | String[] inputCols = transformer.getInputCols(); 42 | for(String inputCol : inputCols){ 43 | List features = encoder.getFeatures(inputCol); 44 | 45 | result.addAll(features); 46 | } 47 | 48 | return result; 49 | } 50 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/VectorAttributeRewriterConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import org.apache.spark.ml.feature.VectorAttributeRewriter; 22 | import org.jpmml.sparkml.FeatureConverter; 23 | 24 | public class VectorAttributeRewriterConverter extends FeatureConverter { 25 | 26 | public VectorAttributeRewriterConverter(VectorAttributeRewriter transformer){ 27 | super(transformer); 28 | } 29 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/VectorSizeHintConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.feature.VectorSizeHint; 24 | import org.jpmml.converter.Feature; 25 | import org.jpmml.converter.SchemaUtil; 26 | import org.jpmml.sparkml.FeatureConverter; 27 | import org.jpmml.sparkml.SparkMLEncoder; 28 | 29 | public class VectorSizeHintConverter extends FeatureConverter { 30 | 31 | public VectorSizeHintConverter(VectorSizeHint transformer){ 32 | super(transformer); 33 | } 34 | 35 | @Override 36 | public List encodeFeatures(SparkMLEncoder encoder){ 37 | VectorSizeHint transformer = getTransformer(); 38 | 39 | int size = transformer.getSize(); 40 | 41 | List features = encoder.getFeatures(transformer.getInputCol()); 42 | 43 | SchemaUtil.checkSize(size, features); 44 | 45 | return features; 46 | } 47 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/VectorSlicerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.feature.VectorSlicer; 24 | import org.jpmml.converter.Feature; 25 | import org.jpmml.sparkml.FeatureConverter; 26 | import org.jpmml.sparkml.SparkMLEncoder; 27 | 28 | public class VectorSlicerConverter extends FeatureConverter { 29 | 30 | public VectorSlicerConverter(VectorSlicer transformer){ 31 | super(transformer); 32 | } 33 | 34 | @Override 35 | public List encodeFeatures(SparkMLEncoder encoder){ 36 | VectorSlicer transformer = getTransformer(); 37 | 38 | String[] names = transformer.getNames(); 39 | if(names != null && names.length > 0){ 40 | throw new IllegalArgumentException("Expected index mode, got name mode"); 41 | } 42 | 43 | return encoder.getFeatures(transformer.getInputCol(), transformer.getIndices()); 44 | } 45 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/DecisionTreeClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.classification.DecisionTreeClassificationModel; 22 | import org.apache.spark.ml.linalg.Vector; 23 | import org.dmg.pmml.tree.TreeModel; 24 | import org.jpmml.converter.Schema; 25 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 26 | 27 | public class DecisionTreeClassificationModelConverter extends ProbabilisticClassificationModelConverter implements HasFeatureImportances, HasTreeOptions { 28 | 29 | public DecisionTreeClassificationModelConverter(DecisionTreeClassificationModel model){ 30 | super(model); 31 | } 32 | 33 | @Override 34 | public Vector getFeatureImportances(){ 35 | DecisionTreeClassificationModel model = getModel(); 36 | 37 | return model.featureImportances(); 38 | } 39 | 40 | @Override 41 | public TreeModel encodeModel(Schema schema){ 42 | return TreeModelUtil.encodeDecisionTree(this, schema); 43 | } 44 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/DecisionTreeRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.linalg.Vector; 22 | import org.apache.spark.ml.regression.DecisionTreeRegressionModel; 23 | import org.dmg.pmml.tree.TreeModel; 24 | import org.jpmml.converter.Schema; 25 | import org.jpmml.sparkml.RegressionModelConverter; 26 | 27 | public class DecisionTreeRegressionModelConverter extends RegressionModelConverter implements HasFeatureImportances, HasTreeOptions { 28 | 29 | public DecisionTreeRegressionModelConverter(DecisionTreeRegressionModel model){ 30 | super(model); 31 | } 32 | 33 | @Override 34 | public Vector getFeatureImportances(){ 35 | DecisionTreeRegressionModel model = getModel(); 36 | 37 | return model.featureImportances(); 38 | } 39 | 40 | @Override 41 | public TreeModel encodeModel(Schema schema){ 42 | return TreeModelUtil.encodeDecisionTree(this, schema); 43 | } 44 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/GBTClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.List; 22 | 23 | import com.google.common.primitives.Doubles; 24 | import org.apache.spark.ml.classification.GBTClassificationModel; 25 | import org.apache.spark.ml.linalg.Vector; 26 | import org.dmg.pmml.DataType; 27 | import org.dmg.pmml.MiningFunction; 28 | import org.dmg.pmml.OpType; 29 | import org.dmg.pmml.mining.MiningModel; 30 | import org.dmg.pmml.mining.Segmentation; 31 | import org.dmg.pmml.regression.RegressionModel; 32 | import org.dmg.pmml.tree.TreeModel; 33 | import org.jpmml.converter.ModelUtil; 34 | import org.jpmml.converter.Schema; 35 | import org.jpmml.converter.mining.MiningModelUtil; 36 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 37 | 38 | public class GBTClassificationModelConverter extends ProbabilisticClassificationModelConverter implements HasFeatureImportances, HasTreeOptions { 39 | 40 | public GBTClassificationModelConverter(GBTClassificationModel model){ 41 | super(model); 42 | } 43 | 44 | @Override 45 | public Vector getFeatureImportances(){ 46 | GBTClassificationModel model = getModel(); 47 | 48 | return model.featureImportances(); 49 | } 50 | 51 | @Override 52 | public MiningModel encodeModel(Schema schema){ 53 | GBTClassificationModel model = getModel(); 54 | 55 | String lossType = model.getLossType(); 56 | switch(lossType){ 57 | case "logistic": 58 | break; 59 | default: 60 | throw new IllegalArgumentException("Loss function " + lossType + " is not supported"); 61 | } 62 | 63 | Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE); 64 | 65 | List treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(this, segmentSchema); 66 | 67 | MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(segmentSchema.getLabel())) 68 | .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, Segmentation.MissingPredictionTreatment.RETURN_MISSING, 
treeModels, Doubles.asList(model.treeWeights()))) 69 | .setOutput(ModelUtil.createPredictedOutput("gbtValue", OpType.CONTINUOUS, DataType.DOUBLE)); 70 | 71 | return MiningModelUtil.createBinaryLogisticClassification(miningModel, 2d, 0d, RegressionModel.NormalizationMethod.LOGIT, false, schema); 72 | } 73 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/GBTRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.List; 22 | 23 | import com.google.common.primitives.Doubles; 24 | import org.apache.spark.ml.linalg.Vector; 25 | import org.apache.spark.ml.regression.GBTRegressionModel; 26 | import org.dmg.pmml.MiningFunction; 27 | import org.dmg.pmml.mining.MiningModel; 28 | import org.dmg.pmml.mining.Segmentation; 29 | import org.dmg.pmml.tree.TreeModel; 30 | import org.jpmml.converter.ModelUtil; 31 | import org.jpmml.converter.Schema; 32 | import org.jpmml.converter.mining.MiningModelUtil; 33 | import org.jpmml.sparkml.RegressionModelConverter; 34 | 35 | public class GBTRegressionModelConverter extends RegressionModelConverter implements HasFeatureImportances, HasTreeOptions { 36 | 37 | public GBTRegressionModelConverter(GBTRegressionModel model){ 38 | super(model); 39 | } 40 | 41 | @Override 42 | public Vector getFeatureImportances(){ 43 | GBTRegressionModel model = getModel(); 44 | 45 | return model.featureImportances(); 46 | } 47 | 48 | @Override 49 | public MiningModel encodeModel(Schema schema){ 50 | GBTRegressionModel model = getModel(); 51 | 52 | List treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(this, schema); 53 | 54 | MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel())) 55 | .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels, Doubles.asList(model.treeWeights()))); 56 | 57 | return miningModel; 58 | } 59 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/HasFeatureImportances.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or 
modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.linalg.Vector; 22 | 23 | public interface HasFeatureImportances { 24 | 25 | Vector getFeatureImportances(); 26 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/HasPredictionModelOptions.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.jpmml.sparkml.HasSparkMLOptions; 22 | 23 | public interface HasPredictionModelOptions extends HasSparkMLOptions { 24 | 25 | String OPTION_KEEP_PREDICTIONCOL = "keep_predictionCol"; 26 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/HasRegressionTableOptions.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.jpmml.sparkml.HasSparkMLOptions; 22 | 23 | public interface HasRegressionTableOptions extends HasSparkMLOptions { 24 | 25 | String OPTION_REPRESENTATION = "representation"; 26 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/HasTreeOptions.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.LinkedHashMap; 22 | import java.util.Map; 23 | 24 | import org.jpmml.converter.HasNativeConfiguration; 25 | import org.jpmml.sparkml.HasSparkMLOptions; 26 | import org.jpmml.sparkml.visitors.TreeModelCompactor; 27 | 28 | public interface HasTreeOptions extends HasSparkMLOptions, HasNativeConfiguration { 29 | 30 | /** 31 | * @see TreeModelCompactor 32 | */ 33 | String OPTION_COMPACT = "compact"; 34 | 35 | String OPTION_ESTIMATE_FEATURE_IMPORTANCES = "estimate_featureImportances"; 36 | 37 | @Override 38 | default 39 | public Map getNativeConfiguration(){ 40 | Map result = new LinkedHashMap<>(); 41 | result.put(HasTreeOptions.OPTION_COMPACT, Boolean.FALSE); 42 | 43 | return result; 44 | } 45 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/KMeansModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */
package org.jpmml.sparkml.model;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.linalg.Vector;
import org.dmg.pmml.CompareFunction;
import org.dmg.pmml.ComparisonMeasure;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.SquaredEuclidean;
import org.dmg.pmml.clustering.Cluster;
import org.dmg.pmml.clustering.ClusteringModel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.clustering.ClusteringModelUtil;
import org.jpmml.sparkml.ClusteringModelConverter;
import org.jpmml.sparkml.VectorUtil;

/**
 * Converter for a Spark ML {@link KMeansModel}.
 *
 * <p>
 * The model is encoded as a center-based PMML {@link ClusteringModel} that uses
 * the squared Euclidean distance measure.
 * </p>
 */
public class KMeansModelConverter extends ClusteringModelConverter<KMeansModel> {

    public KMeansModelConverter(KMeansModel model){
        super(model);
    }

    /**
     * @return The <code>k</code> parameter of the wrapped Spark ML model.
     */
    @Override
    public int getNumberOfClusters(){
        KMeansModel model = getModel();

        return model.getK();
    }

    /**
     * @param schema The label and features of the clustering model.
     *
     * @return The clustering model that holds one {@link Cluster} per cluster center.
     */
    @Override
    public ClusteringModel encodeModel(Schema schema){
        KMeansModel model = getModel();

        List<Cluster> clusters = new ArrayList<>();

        // Cluster identifiers are the zero-based positions of the cluster centers
        Vector[] clusterCenters = model.clusterCenters();
        for(int i = 0; i < clusterCenters.length; i++){
            Cluster cluster = new Cluster(PMMLUtil.createRealArray(VectorUtil.toList(clusterCenters[i])))
                .setId(String.valueOf(i));

            clusters.add(cluster);
        }

        ComparisonMeasure comparisonMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE, new SquaredEuclidean())
            .setCompareFunction(CompareFunction.ABS_DIFF);

        return new ClusteringModel(MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, clusters.size(), ModelUtil.createMiningSchema(schema.getLabel()), comparisonMeasure, ClusteringModelUtil.createClusteringFields(schema.getFeatures()), clusters);
    }
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/LinearRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see <http://www.gnu.org/licenses/>.
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.regression.LinearRegressionModel; 22 | import org.dmg.pmml.Model; 23 | import org.jpmml.converter.Schema; 24 | import org.jpmml.sparkml.RegressionModelConverter; 25 | 26 | public class LinearRegressionModelConverter extends RegressionModelConverter implements HasRegressionTableOptions { 27 | 28 | public LinearRegressionModelConverter(LinearRegressionModel model){ 29 | super(model); 30 | } 31 | 32 | @Override 33 | public Model encodeModel(Schema schema){ 34 | LinearRegressionModel model = getModel(); 35 | 36 | return LinearModelUtil.createRegression(this, model.coefficients(), model.intercept(), schema); 37 | } 38 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/LinearSVCModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.classification.LinearSVCModel; 22 | import org.dmg.pmml.DataType; 23 | import org.dmg.pmml.Expression; 24 | import org.dmg.pmml.FieldRef; 25 | import org.dmg.pmml.Model; 26 | import org.dmg.pmml.OpType; 27 | import org.dmg.pmml.PMMLFunctions; 28 | import org.dmg.pmml.mining.MiningModel; 29 | import org.dmg.pmml.regression.RegressionModel; 30 | import org.jpmml.converter.ExpressionUtil; 31 | import org.jpmml.converter.FieldNameUtil; 32 | import org.jpmml.converter.ModelUtil; 33 | import org.jpmml.converter.Schema; 34 | import org.jpmml.converter.Transformation; 35 | import org.jpmml.converter.mining.MiningModelUtil; 36 | import org.jpmml.converter.transformations.AbstractTransformation; 37 | import org.jpmml.sparkml.ClassificationModelConverter; 38 | 39 | public class LinearSVCModelConverter extends ClassificationModelConverter implements HasRegressionTableOptions { 40 | 41 | public LinearSVCModelConverter(LinearSVCModel model){ 42 | super(model); 43 | } 44 | 45 | @Override 46 | public MiningModel encodeModel(Schema schema){ 47 | LinearSVCModel model = getModel(); 48 | 49 | Transformation transformation = new AbstractTransformation(){ 50 | 51 | @Override 52 | public String getName(String name){ 53 | return FieldNameUtil.create(PMMLFunctions.THRESHOLD, name); 54 | } 55 | 56 | @Override 57 | public Expression createExpression(FieldRef fieldRef){ 58 | return ExpressionUtil.createApply(PMMLFunctions.THRESHOLD) 59 | .addExpressions(fieldRef, ExpressionUtil.createConstant(model.getThreshold())); 60 | } 61 | }; 62 | 63 | Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE); 64 | 65 | Model linearModel = LinearModelUtil.createRegression(this, model.coefficients(), model.intercept(), segmentSchema) 66 | .setOutput(ModelUtil.createPredictedOutput("margin", OpType.CONTINUOUS, DataType.DOUBLE, transformation)); 67 | 68 | return 
MiningModelUtil.createBinaryLogisticClassification(linearModel, 1d, 0d, RegressionModel.NormalizationMethod.NONE, false, schema); 69 | } 70 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/LogisticRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.classification.LogisticRegressionModel; 22 | import org.dmg.pmml.Model; 23 | import org.jpmml.converter.CategoricalLabel; 24 | import org.jpmml.converter.Schema; 25 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 26 | 27 | public class LogisticRegressionModelConverter extends ProbabilisticClassificationModelConverter implements HasRegressionTableOptions { 28 | 29 | public LogisticRegressionModelConverter(LogisticRegressionModel model){ 30 | super(model); 31 | } 32 | 33 | @Override 34 | public Model encodeModel(Schema schema){ 35 | LogisticRegressionModel model = getModel(); 36 | 37 | CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel(); 38 | 39 | if(categoricalLabel.size() == 2){ 40 | Model linearModel = LinearModelUtil.createBinaryLogisticClassification(this, model.coefficients(), model.intercept(), schema) 41 | .setOutput(null); 42 | 43 | return linearModel; 44 | } else 45 | 46 | if(categoricalLabel.size() > 2){ 47 | Model linearModel = LinearModelUtil.createSoftmaxClassification(this, model.coefficientMatrix(), model.interceptVector(), schema) 48 | .setOutput(null); 49 | 50 | return linearModel; 51 | } else 52 | 53 | { 54 | throw new IllegalArgumentException(); 55 | } 56 | } 57 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/NaiveBayesModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */
package org.jpmml.sparkml.model;

import org.apache.spark.ml.classification.NaiveBayesModel;
import org.dmg.pmml.Model;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.ProbabilisticClassificationModelConverter;

/**
 * Converter for a Spark ML {@link NaiveBayesModel}.
 *
 * <p>
 * Only the <code>multinomial</code> model type is supported. The model is encoded via
 * <code>LinearModelUtil.createSoftmaxClassification</code>, using the <code>theta</code> matrix
 * and the <code>pi</code> vector of the Spark ML model.
 * </p>
 */
public class NaiveBayesModelConverter extends ProbabilisticClassificationModelConverter<NaiveBayesModel> implements HasRegressionTableOptions {

    public NaiveBayesModelConverter(NaiveBayesModel model){
        super(model);
    }

    /**
     * @param schema The categorical label and features of the classifier.
     *
     * @return The PMML model, with its output element cleared.
     *
     * @throws IllegalArgumentException If the model type is not <code>multinomial</code>,
     * or if the <code>thresholds</code> param is set and holds a non-zero value.
     */
    @Override
    public Model encodeModel(Schema schema){
        NaiveBayesModel model = getModel();

        String modelType = model.getModelType();
        switch(modelType){
            case "multinomial":
                break;
            default:
                throw new IllegalArgumentException("Model type " + modelType + " is not supported");
        }

        // Explicitly set thresholds are tolerated only if all of them are zero
        if(model.isSet(model.thresholds())){
            double[] thresholds = model.getThresholds();

            for(double threshold : thresholds){

                if(threshold != 0d){
                    throw new IllegalArgumentException("Non-zero thresholds are not supported");
                }
            }
        }

        Model linearModel = LinearModelUtil.createSoftmaxClassification(this, model.theta(), model.pi(), schema)
            .setOutput(null);

        return linearModel;
    }
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/RandomForestClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 |
/* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.classification.RandomForestClassificationModel; 24 | import org.apache.spark.ml.linalg.Vector; 25 | import org.dmg.pmml.MiningFunction; 26 | import org.dmg.pmml.mining.MiningModel; 27 | import org.dmg.pmml.mining.Segmentation; 28 | import org.dmg.pmml.tree.TreeModel; 29 | import org.jpmml.converter.ModelUtil; 30 | import org.jpmml.converter.Schema; 31 | import org.jpmml.converter.mining.MiningModelUtil; 32 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 33 | 34 | public class RandomForestClassificationModelConverter extends ProbabilisticClassificationModelConverter implements HasFeatureImportances, HasTreeOptions { 35 | 36 | public RandomForestClassificationModelConverter(RandomForestClassificationModel model){ 37 | super(model); 38 | } 39 | 40 | @Override 41 | public Vector getFeatureImportances(){ 42 | RandomForestClassificationModel model = getModel(); 43 | 44 | return model.featureImportances(); 45 | } 46 | 47 | @Override 48 | public MiningModel encodeModel(Schema schema){ 49 | List treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(this, schema); 50 | 51 | 
MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel())) 52 | .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels)); 53 | 54 | return miningModel; 55 | } 56 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/RandomForestRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.linalg.Vector; 24 | import org.apache.spark.ml.regression.RandomForestRegressionModel; 25 | import org.dmg.pmml.MiningFunction; 26 | import org.dmg.pmml.mining.MiningModel; 27 | import org.dmg.pmml.mining.Segmentation; 28 | import org.dmg.pmml.tree.TreeModel; 29 | import org.jpmml.converter.ModelUtil; 30 | import org.jpmml.converter.Schema; 31 | import org.jpmml.converter.mining.MiningModelUtil; 32 | import org.jpmml.sparkml.RegressionModelConverter; 33 | 34 | public class RandomForestRegressionModelConverter extends RegressionModelConverter implements HasFeatureImportances, HasTreeOptions { 35 | 36 | public RandomForestRegressionModelConverter(RandomForestRegressionModel model){ 37 | super(model); 38 | } 39 | 40 | @Override 41 | public Vector getFeatureImportances(){ 42 | RandomForestRegressionModel model = getModel(); 43 | 44 | return model.featureImportances(); 45 | } 46 | 47 | @Override 48 | public MiningModel encodeModel(Schema schema){ 49 | List treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(this, schema); 50 | 51 | MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel())) 52 | .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels)); 53 | 54 | return miningModel; 55 | } 56 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/RegressionTableUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the 
Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | import java.util.ListIterator; 24 | 25 | import org.dmg.pmml.DataType; 26 | import org.dmg.pmml.DerivedField; 27 | import org.dmg.pmml.MapValues; 28 | import org.dmg.pmml.OpType; 29 | import org.jpmml.converter.BinaryFeature; 30 | import org.jpmml.converter.ContinuousFeature; 31 | import org.jpmml.converter.ExpressionUtil; 32 | import org.jpmml.converter.Feature; 33 | import org.jpmml.converter.FieldNameUtil; 34 | import org.jpmml.converter.PMMLEncoder; 35 | 36 | public class RegressionTableUtil { 37 | 38 | private RegressionTableUtil(){ 39 | } 40 | 41 | static 42 | private MapValues createMapValues(String name, Object identifier, List features, List coefficients){ 43 | ListIterator featureIt = features.listIterator(); 44 | ListIterator coefficientIt = coefficients.listIterator(); 45 | 46 | PMMLEncoder encoder = null; 47 | 48 | List inputValues = new ArrayList<>(); 49 | List outputValues = new ArrayList<>(); 50 | 51 | while(featureIt.hasNext()){ 52 | Feature feature = featureIt.next(); 53 | Double coefficient = coefficientIt.next(); 54 | 55 | if(!(feature instanceof BinaryFeature)){ 56 | continue; 57 | } 58 | 59 | BinaryFeature binaryFeature = (BinaryFeature)feature; 60 | if(!(name).equals(binaryFeature.getName())){ 61 | continue; 62 | } 63 | 64 | featureIt.remove(); 65 | coefficientIt.remove(); 66 | 67 | if(encoder == null){ 68 | encoder = 
binaryFeature.getEncoder(); 69 | } 70 | 71 | inputValues.add(binaryFeature.getValue()); 72 | outputValues.add(coefficient); 73 | } 74 | 75 | MapValues mapValues = ExpressionUtil.createMapValues(name, inputValues, outputValues) 76 | .setDefaultValue(0d) 77 | .setDataType(DataType.DOUBLE); 78 | 79 | DerivedField derivedField = encoder.createDerivedField(identifier != null ? FieldNameUtil.create("lookup", name, identifier) : FieldNameUtil.create("lookup", name), OpType.CONTINUOUS, DataType.DOUBLE, mapValues); 80 | 81 | featureIt.add(new ContinuousFeature(encoder, derivedField)); 82 | coefficientIt.add(1d); 83 | 84 | return mapValues; 85 | } 86 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/testing/SparkMLEncoderBatchTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.testing; 20 | 21 | import java.util.function.Predicate; 22 | 23 | import com.google.common.base.Equivalence; 24 | import org.apache.spark.sql.SparkSession; 25 | import org.jpmml.converter.FieldNameUtil; 26 | import org.jpmml.converter.testing.ModelEncoderBatchTest; 27 | import org.jpmml.evaluator.ResultField; 28 | import org.jpmml.evaluator.testing.PMMLEquivalence; 29 | import org.jpmml.sparkml.SparkSessionUtil; 30 |
/**
 * Base class for Spark ML batch tests.
 *
 * <p>
 * Holds one static {@link SparkSession} for the whole class hierarchy and creates
 * {@link SparkMLEncoderBatch} instances that are bound to this test object.
 * </p>
 */
public class SparkMLEncoderBatchTest extends ModelEncoderBatchTest {

	/**
	 * Creates a test whose results are compared using <code>new PMMLEquivalence(1e-14, 1e-14)</code>.
	 */
	public SparkMLEncoderBatchTest(){
		this(new PMMLEquivalence(1e-14, 1e-14));
	}

	/**
	 * @param equivalence The equivalence that is handed over to the superclass constructor.
	 */
	public SparkMLEncoderBatchTest(Equivalence<Object> equivalence){
		super(equivalence);
	}

	/**
	 * @return The shared session; <code>null</code> until {@link #createSparkSession()} has stored one.
	 *
	 * @see #createSparkSession()
	 * @see #destroySparkSession()
	 */
	public SparkSession getSparkSession(){
		return SparkMLEncoderBatchTest.sparkSession;
	}

	/**
	 * Creates a batch whose {@link SparkMLEncoderBatch#getArchiveBatchTest()} method returns this test object.
	 */
	@Override
	public SparkMLEncoderBatch createBatch(String algorithm, String dataset, Predicate<ResultField> columnFilter, Equivalence<Object> equivalence){
		SparkMLEncoderBatch result = new SparkMLEncoderBatch(algorithm, dataset, columnFilter, equivalence){

			@Override
			public SparkMLEncoderBatchTest getArchiveBatchTest(){
				return SparkMLEncoderBatchTest.this;
			}
		};

		return result;
	}

	/**
	 * Stores the session obtained from {@link SparkSessionUtil#createSparkSession()} in the shared static field.
	 */
	static
	public void createSparkSession(){
		SparkMLEncoderBatchTest.sparkSession = SparkSessionUtil.createSparkSession();
	}

	/**
	 * Hands the shared session to <code>SparkSessionUtil.destroySparkSession</code>,
	 * and replaces the shared static field with whatever that call returns.
	 */
	static
	public void destroySparkSession(){
		SparkMLEncoderBatchTest.sparkSession = SparkSessionUtil.destroySparkSession(SparkMLEncoderBatchTest.sparkSession);
	}

	/**
	 * Same as {@link #excludePredictionFields(String)} with the column name <code>prediction</code>.
	 */
	static
	public Predicate<ResultField> excludePredictionFields(){
		return excludePredictionFields("prediction");
	}

	/**
	 * @param predictionCol The name of the prediction column.
	 *
	 * @return The result of <code>excludeFields</code> for two names:
	 * the column itself and <code>FieldNameUtil.create("pmml", predictionCol)</code>.
	 */
	static
	public Predicate<ResultField> excludePredictionFields(String predictionCol){
		return excludeFields(predictionCol, FieldNameUtil.create("pmml", predictionCol));
	}

	// Shared by all instances; written only by createSparkSession() and destroySparkSession()
	private static SparkSession sparkSession = null;
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/main/scala/org/jpmml/sparkml/feature/SparseToDenseTransformer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see .
18 | */ 19 | package org.jpmml.sparkml.feature 20 | 21 | import org.apache.spark.ml.Transformer 22 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector} 23 | import org.apache.spark.ml.linalg.SQLDataTypes.VectorType 24 | import org.apache.spark.ml.param.ParamMap 25 | import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} 26 | import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} 27 | import org.apache.spark.sql.{Dataset, Row} 28 | import org.apache.spark.sql.functions.udf 29 | import org.apache.spark.sql.types.{StructField, StructType} 30 |
/**
 * A transformer that appends a new column holding the dense form of every vector of the input column.
 *
 * The input column is left in place; the output column must not exist yet.
 */
class SparseToDenseTransformer(override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {

	def this() = this(Identifiable.randomUID("sparse2dense"))

	def setInputCol(value: String): this.type = set(inputCol, value)

	def setOutputCol(value: String): this.type = set(outputCol, value)

	override
	def copy(extra: ParamMap): SparseToDenseTransformer = defaultCopy(extra)

	/**
	 * Appends the output field to the schema.
	 *
	 * The output field copies the data type and the nullability of the input field.
	 * Fails with an IllegalArgumentException (via `require`) if the output column already exists,
	 * or if the input column is not a vector column.
	 */
	override
	def transformSchema(schema: StructType): StructType = {
		val inputColName = $(inputCol)
		val outputColName = $(outputCol)

		val inputFields = schema.fields

		require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists")

		val inputField = schema(inputColName)

		// Fail fast here, rather than later inside the UDF that only accepts vectors
		require(inputField.dataType == VectorType, s"Input column $inputColName must be a vector column, but its data type is ${inputField.dataType}")

		val outputField = new StructField(outputColName, inputField.dataType, inputField.nullable)

		StructType(inputFields :+ outputField)
	}

	/**
	 * Validates the schema of the dataset, and then adds the output column.
	 */
	override
	def transform(dataset: Dataset[_]): Dataset[Row] = {
		val inputColName = $(inputCol)
		val outputColName = $(outputCol)

		transformSchema(dataset.schema, logging = true)

		// The output field is declared nullable whenever the input field is,
		// so a missing vector must map to a missing vector instead of raising a NullPointerException
		val converter = udf { vec: Vector => if(vec != null) vec.toDense else null }

		dataset.withColumn(outputColName, converter(dataset(inputColName)))
	}
}

/**
 * Companion object that makes a saved SparseToDenseTransformer loadable.
 */
object SparseToDenseTransformer extends DefaultParamsReadable[SparseToDenseTransformer] {

	override
	def load(path: String): SparseToDenseTransformer = super.load(path)
}

-------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/AliasExpressionTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see .
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.dmg.pmml.Expression; 22 | import org.dmg.pmml.FieldRef; 23 | import org.dmg.pmml.PMMLFunctions; 24 | import org.jpmml.converter.ExpressionUtil; 25 | import org.jpmml.model.ReflectionUtil; 26 | import org.junit.jupiter.api.Test; 27 | 28 | import static org.junit.jupiter.api.Assertions.assertTrue; 29 | 30 | public class AliasExpressionTest { 31 | 32 | @Test 33 | public void unwrap(){ 34 | FieldRef fieldRef = new FieldRef("x"); 35 | 36 | Expression expression = new AliasExpression("parent", new AliasExpression("child", fieldRef)); 37 | 38 | checkExpression(fieldRef, expression); 39 | 40 | expression = new AliasExpression("parent", ExpressionUtil.createApply(PMMLFunctions.ADD, new AliasExpression("left child", fieldRef), new AliasExpression("right child", fieldRef))); 41 | 42 | checkExpression(ExpressionUtil.createApply(PMMLFunctions.ADD, fieldRef, fieldRef), expression); 43 | } 44 | 45 | static 46 | private void checkExpression(Expression expected, Expression expression){ 47 | expression = AliasExpression.unwrap(expression); 48 | 49 | assertTrue(ReflectionUtil.equals(expected, expression)); 50 | } 51 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/PMMLBuilderTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Arrays; 22 | 23 | import org.apache.spark.ml.Model; 24 | import org.apache.spark.ml.PipelineModel; 25 | import org.apache.spark.ml.classification.LogisticRegressionModel; 26 | import org.apache.spark.ml.linalg.DenseVector; 27 | import org.apache.spark.sql.types.StructType; 28 | import org.junit.jupiter.api.Test; 29 | 30 | import static org.junit.jupiter.api.Assertions.fail; 31 |
/**
 * Unit test for the constructor of {@link PMMLBuilder}.
 */
public class PMMLBuilderTest {

	/**
	 * Checks that a bare model is rejected with an {@link IllegalArgumentException},
	 * whereas a pipeline model that wraps the same model is accepted.
	 */
	@Test
	public void construct(){
		StructType schema = new StructType();

		Model<?> model = new LogisticRegressionModel("lrm", new DenseVector(new double[0]), 0d);

		try {
			new PMMLBuilder(schema, model);

			// Not reached if the constructor rejects the argument.
			// The failure raised here is not an IllegalArgumentException, so the catch block below does not hide it
			fail("Expected IllegalArgumentException");
		} catch(IllegalArgumentException iae){
			// Expected
		}

		PipelineModel pipelineModel = new PipelineModel("pm", Arrays.asList(model));

		// Any exception propagates, and fails the test
		new PMMLBuilder(schema, pipelineModel);
	}
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/RegexKeyTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version.
10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.regex.Pattern; 22 | 23 | import org.junit.jupiter.api.Test; 24 | 25 | import static org.junit.jupiter.api.Assertions.assertFalse; 26 | import static org.junit.jupiter.api.Assertions.assertTrue; 27 |
/**
 * Unit test for {@link RegexKey}.
 */
public class RegexKeyTest {

	/**
	 * Compares a key built from the regular expression <code>.*</code>
	 * with a key built from the same text compiled with {@link Pattern#LITERAL}.
	 */
	@Test
	public void compile(){
		Pattern wildcardPattern = Pattern.compile(".*");
		RegexKey wildcardKey = new RegexKey(wildcardPattern);

		Pattern literalPattern = Pattern.compile(".*", Pattern.LITERAL);
		RegexKey literalKey = new RegexKey(literalPattern);

		// The regular expression accepts both the empty string and the pattern text
		boolean wildcardOnEmpty = wildcardKey.test("");
		assertTrue(wildcardOnEmpty);

		boolean wildcardOnText = wildcardKey.test(".*");
		assertTrue(wildcardOnText);

		// The literal pattern accepts the pattern text only
		boolean literalOnEmpty = literalKey.test("");
		assertFalse(literalOnEmpty);

		boolean literalOnText = literalKey.test(".*");
		assertTrue(literalOnText);
	}
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/SparkMLTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details.
15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.apache.spark.sql.SparkSession; 22 | import org.junit.jupiter.api.AfterAll; 23 | import org.junit.jupiter.api.BeforeAll; 24 | 25 | abstract 26 | public class SparkMLTest { 27 | 28 | @BeforeAll 29 | static 30 | public void createSparkSession(){ 31 | SparkMLTest.sparkSession = SparkSessionUtil.createSparkSession(); 32 | } 33 | 34 | @AfterAll 35 | static 36 | public void destroySparkSession(){ 37 | SparkMLTest.sparkSession = SparkSessionUtil.destroySparkSession(SparkMLTest.sparkSession); 38 | } 39 | 40 | public static SparkSession sparkSession = null; 41 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/TermUtilTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.junit.jupiter.api.Test; 22 | 23 | import static org.junit.jupiter.api.Assertions.assertFalse; 24 | import static org.junit.jupiter.api.Assertions.assertTrue; 25 | 26 | public class TermUtilTest { 27 | 28 | @Test 29 | public void hasPunctuation(){ 30 | assertFalse(TermUtil.hasPunctuation("one")); 31 | assertTrue(TermUtil.hasPunctuation("one?")); 32 | 33 | assertFalse(TermUtil.hasPunctuation("one-half")); 34 | 35 | assertFalse(TermUtil.hasPunctuation("one two")); 36 | assertTrue(TermUtil.hasPunctuation("one, two")); 37 | } 38 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/feature/SparseToDenseTransformerTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Arrays; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.Pipeline; 25 | import org.apache.spark.ml.PipelineModel; 26 | import org.apache.spark.ml.PipelineStage; 27 | import org.apache.spark.ml.Transformer; 28 | import org.apache.spark.ml.linalg.DenseVector; 29 | import org.apache.spark.ml.linalg.SparseVector; 30 | import org.apache.spark.ml.linalg.Vector; 31 | import org.apache.spark.ml.linalg.VectorUDT; 32 | import org.apache.spark.sql.Dataset; 33 | import org.apache.spark.sql.Row; 34 | import org.apache.spark.sql.RowFactory; 35 | import org.apache.spark.sql.types.StructType; 36 | import org.jpmml.sparkml.SparkMLTest; 37 | import org.junit.jupiter.api.Test; 38 | 39 | import static org.junit.jupiter.api.Assertions.assertEquals; 40 | import static org.junit.jupiter.api.Assertions.assertNotNull; 41 | import static org.junit.jupiter.api.Assertions.assertTrue; 42 |
/**
 * Unit test for {@link SparseToDenseTransformer}.
 */
public class SparseToDenseTransformerTest extends SparkMLTest {

	/**
	 * Runs the transformer, as the only stage of a pipeline, on three vectors of size three
	 * (sparse, dense, sparse; one non-zero element each), and checks that
	 * the input column is kept as is and the output column holds dense vectors.
	 */
	@Test
	public void transform(){
		StructType schema = new StructType()
			.add("featureVec", new VectorUDT(), false);

		List<Row> rows = Arrays.asList(
			RowFactory.create(new SparseVector(3, new int[]{1}, new double[]{1.0})),
			RowFactory.create(new DenseVector(new double[]{0.0d, 0.0d, 1.0d})),
			RowFactory.create(new SparseVector(3, new int[]{0}, new double[]{1.0}))
		);

		Dataset<Row> ds = SparkMLTest.sparkSession.createDataFrame(rows, schema);

		Transformer transformer = new SparseToDenseTransformer()
			.setInputCol("featureVec")
			.setOutputCol("denseFeatureVec");

		Pipeline pipeline = new Pipeline()
			.setStages(new PipelineStage[]{transformer});

		PipelineModel pipelineModel = pipeline.fit(ds);

		Dataset<Row> transformedDs = pipelineModel.transform(ds);

		assertNotNull(transformedDs.col("featureVec"));
		assertNotNull(transformedDs.col("denseFeatureVec"));

		List<Row> transformedRows = transformedDs
			.select("featureVec", "denseFeatureVec")
			.collectAsList();

		// Report a missing row as an assertion failure, not as an IndexOutOfBoundsException
		assertEquals(rows.size(), transformedRows.size());

		for(int i = 0; i < transformedRows.size(); i++){
			Row transformedRow = transformedRows.get(i);

			Vector vector = (Vector)transformedRow.get(0);
			Vector denseVector = (Vector)transformedRow.get(1);

			// Only the second input vector (index 1) is dense, and stores all three elements
			assertEquals(i == 1 ? 3 : 1, vector.numActives());
			assertEquals(1, vector.numNonzeros());
			assertEquals(3, vector.size());

			assertTrue(denseVector instanceof DenseVector);

			assertEquals(3, denseVector.numActives());
			assertEquals(1, denseVector.numNonzeros());
			assertEquals(3, denseVector.size());
		}
	}
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/testing/AssociationRulesTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see .
18 | */ 19 | package org.jpmml.sparkml.testing; 20 | 21 | import java.util.function.Predicate; 22 | 23 | import com.google.common.base.Equivalence; 24 | import com.google.common.collect.Iterables; 25 | import org.apache.spark.sql.Dataset; 26 | import org.apache.spark.sql.Row; 27 | import org.dmg.pmml.Model; 28 | import org.dmg.pmml.PMML; 29 | import org.dmg.pmml.association.AssociationModel; 30 | import org.jpmml.converter.testing.Datasets; 31 | import org.jpmml.evaluator.ResultField; 32 | import org.junit.jupiter.api.Test; 33 | 34 | import static org.junit.jupiter.api.Assertions.assertTrue; 35 |
/**
 * Batch test for association rules.
 *
 * <p>
 * The test only inspects the structure of the generated PMML document.
 * </p>
 */
public class AssociationRulesTest extends SimpleSparkMLEncoderBatchTest implements SparkMLAlgorithms, Datasets {

	/**
	 * Creates a batch that is bound to this test object, and that has no verification dataset.
	 *
	 * <p>
	 * The column filter is combined with {@link #excludePredictionFields()} here,
	 * because the batch is constructed directly, without calling the superclass method.
	 * </p>
	 */
	@Override
	public SparkMLEncoderBatch createBatch(String algorithm, String dataset, Predicate<ResultField> columnFilter, Equivalence<Object> equivalence){
		columnFilter = columnFilter.and(excludePredictionFields());

		SparkMLEncoderBatch result = new SparkMLEncoderBatch(algorithm, dataset, columnFilter, equivalence){

			@Override
			public AssociationRulesTest getArchiveBatchTest(){
				return AssociationRulesTest.this;
			}

			@Override
			public Dataset<Row> getVerificationDataset(Dataset<Row> inputDataset){
				return null;
			}
		};

		return result;
	}

	/**
	 * Checks that the PMML document of the FPGrowth/Shopping batch contains exactly one model, an {@link AssociationModel}.
	 */
	@Test
	public void evaluateFPGrowthShopping() throws Exception {
		Predicate<ResultField> predicate = (resultField -> true);
		Equivalence<Object> equivalence = getEquivalence();

		try(SparkMLEncoderBatch batch = createBatch(FP_GROWTH, SHOPPING, predicate, equivalence)){
			PMML pmml = batch.getPMML();

			Model model = Iterables.getOnlyElement(pmml.getModels());

			assertTrue(model instanceof AssociationModel);
		}
	}
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/testing/ClusteringTest.java:
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.testing; 20 | 21 | import org.jpmml.converter.FieldNameUtil; 22 | import org.jpmml.converter.testing.Datasets; 23 | import org.junit.jupiter.api.Test; 24 |
/**
 * Batch test for clustering models.
 */
public class ClusteringTest extends SimpleSparkMLEncoderBatchTest implements SparkMLAlgorithms, Datasets {

	/**
	 * Evaluates the KMeans/Iris batch, leaving the field
	 * <code>FieldNameUtil.create("pmml", "cluster")</code> out of the comparison.
	 */
	@Test
	public void evaluateKMeansIris() throws Exception {
		String pmmlClusterField = FieldNameUtil.create("pmml", "cluster");

		String[] excludedFields = new String[]{pmmlClusterField};

		evaluate(K_MEANS, IRIS, excludeFields(excludedFields));
	}
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/testing/SimpleSparkMLEncoderBatchTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option)
any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.testing; 20 | 21 | import java.util.function.Predicate; 22 | 23 | import com.google.common.base.Equivalence; 24 | import org.jpmml.evaluator.ResultField; 25 | import org.junit.jupiter.api.AfterAll; 26 | import org.junit.jupiter.api.BeforeAll; 27 |
/**
 * A batch test base class that excludes the default prediction fields from every batch,
 * and that ties the shared Spark session to the JUnit class lifecycle.
 */
abstract
public class SimpleSparkMLEncoderBatchTest extends SparkMLEncoderBatchTest {

	/**
	 * Combines the column filter with {@link #excludePredictionFields()}, and then delegates to the superclass.
	 */
	@Override
	public SparkMLEncoderBatch createBatch(String algorithm, String dataset, Predicate<ResultField> columnFilter, Equivalence<Object> equivalence){
		columnFilter = columnFilter.and(excludePredictionFields());

		return super.createBatch(algorithm, dataset, columnFilter, equivalence);
	}

	/**
	 * Delegates to {@link SparkMLEncoderBatchTest#createSparkSession()} before all tests of the class.
	 */
	@BeforeAll
	static
	public void createSparkSession(){
		SparkMLEncoderBatchTest.createSparkSession();
	}

	/**
	 * Delegates to {@link SparkMLEncoderBatchTest#destroySparkSession()} after all tests of the class.
	 */
	@AfterAll
	static
	public void destroySparkSession(){
		SparkMLEncoderBatchTest.destroySparkSession();
	}
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/testing/SparkMLAlgorithms.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version.
10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.testing; 20 |
/**
 * Algorithm names that the batch tests pass as the <code>algorithm</code> argument.
 */
interface SparkMLAlgorithms {

	String DECISION_TREE = "DecisionTree";
	String FP_GROWTH = "FPGrowth";
	String GBT = "GBT";
	String GLM = "GLM";
	String K_MEANS = "KMeans";
	String LINEAR_REGRESSION = "LinearRegression";
	// Misspelled name, retained because existing tests may still refer to it; prefer LINEAR_REGRESSION
	String LINEAR_REGRESION = LINEAR_REGRESSION;
	String LINEAR_SVC = "LinearSVC";
	String LOGISTIC_REGRESSION = "LogisticRegression";
	String MODEL_CHAIN = "ModelChain";
	String NAIVE_BAYES = "NaiveBayes";
	String NEURAL_NETWORK = "NeuralNetwork";
	String RANDOM_FOREST = "RandomForest";
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/csv/KMeansIris.csv: -------------------------------------------------------------------------------- 1 | cluster 2 | 1 3 | 10 4 | 10 5 | 10 6 | 1 7 | 1 8 | 10 9 | 1 10 | 10 11 | 10 12 | 1 13 | 10 14 | 10 15 | 10 16 | 1 17 | 1 18 | 1 19 | 1 20 | 1 21 | 1 22 | 1 23 | 1 24 | 10 25 | 1 26 | 10 27 | 10 28 | 1 29 | 1 30 | 1 31 | 10 32 | 10 33 | 1 34 | 1 35 | 1 36 | 10 37 | 10 38 | 1 39 | 1 40 | 10 41 | 1 42 | 1 43 | 10 44 | 10 45 | 1 46 | 1 47 | 10 48 | 1 49 | 10 50 | 1 51 | 10 52 | 3 53 | 3 54 | 3 55 | 6 56 | 3 57 | 2 58 | 3 59 | 5 60 | 3 61 | 6 62 | 5 63 | 2 64 | 6 65 | 9 66 | 6 67 | 3 68 | 2 69 | 6 70 | 9 71 | 6 72 | 4 73 | 6 74 | 9 75 | 9 76 | 3 77 | 3 78 | 3 79 | 3 80 | 2 81 | 6 82 | 6 83 | 6 84 | 6 85 | 4 86 | 2 87 | 2 88 | 3 89 | 9 90 | 2 91 | 6 92 | 2 93 | 2 94 | 6 95 | 5 96 | 2 97 | 2 98 | 2 99 | 2 100 | 5 101 | 2 102 | 7 103 | 4 104 | 8 105 | 8 106 | 8 107 | 0 108 | 2 109 | 0 110 |
8 111 | 0 112 | 8 113 | 8 114 | 8 115 | 4 116 | 4 117 | 7 118 | 8 119 | 0 120 | 0 121 | 9 122 | 8 123 | 4 124 | 0 125 | 9 126 | 8 127 | 0 128 | 9 129 | 4 130 | 8 131 | 8 132 | 0 133 | 0 134 | 8 135 | 9 136 | 4 137 | 0 138 | 7 139 | 8 140 | 4 141 | 8 142 | 8 143 | 8 144 | 4 145 | 8 146 | 7 147 | 8 148 | 9 149 | 8 150 | 7 151 | 4 152 | -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/DecisionTreeAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/DecisionTreeAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/DecisionTreeAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/DecisionTreeAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/DecisionTreeHousing.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/DecisionTreeHousing.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/DecisionTreeIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/DecisionTreeIris.zip -------------------------------------------------------------------------------- 
/pmml-sparkml/src/test/resources/pipeline/DecisionTreeSentiment.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/DecisionTreeSentiment.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/FPGrowthShopping.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/FPGrowthShopping.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GBTAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/GBTAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GBTAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/GBTAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GLMAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/GLMAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GLMAuto.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/GLMAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GLMHousing.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/GLMHousing.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GLMSentiment.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/GLMSentiment.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GLMVisit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/GLMVisit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/KMeansIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/KMeansIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LinearRegressionAuto.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/LinearRegressionAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LinearRegressionHousing.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/LinearRegressionHousing.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LinearSVCSentiment.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/LinearSVCSentiment.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LogisticRegressionAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/LogisticRegressionAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LogisticRegressionIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/LogisticRegressionIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/ModelChainAudit.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/ModelChainAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/ModelChainAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/ModelChainAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/ModelChainIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/ModelChainIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/NaiveBayesAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/NaiveBayesAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/NaiveBayesIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/NaiveBayesIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/NeuralNetworkAudit.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/NeuralNetworkAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/NeuralNetworkIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/NeuralNetworkIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/RandomForestAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/RandomForestAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/RandomForestAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/RandomForestAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/RandomForestHousing.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/RandomForestHousing.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/RandomForestIris.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/RandomForestIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/RandomForestSentiment.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/e50d923448aaa60a716021609ac803edc5e84054/pmml-sparkml/src/test/resources/pipeline/RandomForestSentiment.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Audit.json: -------------------------------------------------------------------------------- 1 | {"fields":[{"metadata":{},"name":"Age","nullable":true,"type":"integer"},{"metadata":{},"name":"Employment","nullable":true,"type":"string"},{"metadata":{},"name":"Education","nullable":true,"type":"string"},{"metadata":{},"name":"Marital","nullable":true,"type":"string"},{"metadata":{},"name":"Occupation","nullable":true,"type":"string"},{"metadata":{},"name":"Income","nullable":true,"type":"double"},{"metadata":{},"name":"Gender","nullable":true,"type":"string"},{"metadata":{},"name":"Adjusted","nullable":true,"type":"integer"},{"metadata":{},"name":"Deductions","nullable":true,"type":"boolean"},{"metadata":{},"name":"Hours","nullable":true,"type":"double"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Auto.json: -------------------------------------------------------------------------------- 1 | 
{"fields":[{"metadata":{},"name":"displacement","nullable":true,"type":"double"},{"metadata":{},"name":"horsepower","nullable":true,"type":"integer"},{"metadata":{},"name":"weight","nullable":true,"type":"integer"},{"metadata":{},"name":"acceleration","nullable":true,"type":"double"},{"metadata":{},"name":"mpg","nullable":true,"type":"double"},{"metadata":{},"name":"cylinders","nullable":true,"type":"string"},{"metadata":{},"name":"model_year","nullable":true,"type":"string"},{"metadata":{},"name":"origin","nullable":true,"type":"string"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Housing.json: -------------------------------------------------------------------------------- 1 | {"fields":[{"metadata":{},"name":"CRIM","nullable":true,"type":"double"},{"metadata":{},"name":"ZN","nullable":true,"type":"double"},{"metadata":{},"name":"INDUS","nullable":true,"type":"double"},{"metadata":{},"name":"CHAS","nullable":true,"type":"integer"},{"metadata":{},"name":"NOX","nullable":true,"type":"double"},{"metadata":{},"name":"RM","nullable":true,"type":"double"},{"metadata":{},"name":"AGE","nullable":true,"type":"double"},{"metadata":{},"name":"DIS","nullable":true,"type":"double"},{"metadata":{},"name":"RAD","nullable":true,"type":"integer"},{"metadata":{},"name":"TAX","nullable":true,"type":"double"},{"metadata":{},"name":"PTRATIO","nullable":true,"type":"double"},{"metadata":{},"name":"B","nullable":true,"type":"double"},{"metadata":{},"name":"LSTAT","nullable":true,"type":"double"},{"metadata":{},"name":"MEDV","nullable":true,"type":"double"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Iris.json: -------------------------------------------------------------------------------- 1 | 
{"fields":[{"metadata":{},"name":"Sepal_Length","nullable":true,"type":"double"},{"metadata":{},"name":"Sepal_Width","nullable":true,"type":"double"},{"metadata":{},"name":"Petal_Length","nullable":true,"type":"double"},{"metadata":{},"name":"Petal_Width","nullable":true,"type":"double"},{"metadata":{},"name":"Species","nullable":true,"type":"string"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Sentiment.json: -------------------------------------------------------------------------------- 1 | {"fields":[{"metadata":{},"name":"Sentence","nullable":true,"type":"string"},{"metadata":{},"name":"Score","nullable":true,"type":"integer"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Shopping.json: -------------------------------------------------------------------------------- 1 | {"fields":[{"metadata":{},"name":"transaction","nullable":true,"type":"integer"},{"metadata":{},"name":"items","nullable":false,"type":{"containsNull":false,"elementType":"string","type":"array"}}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Visit.json: -------------------------------------------------------------------------------- 1 | 
{"fields":[{"metadata":{},"name":"edlevel","nullable":true,"type":"string"},{"metadata":{},"name":"age","nullable":true,"type":"integer"},{"metadata":{},"name":"outwork","nullable":true,"type":"integer"},{"metadata":{},"name":"female","nullable":true,"type":"integer"},{"metadata":{},"name":"married","nullable":true,"type":"integer"},{"metadata":{},"name":"kids","nullable":true,"type":"integer"},{"metadata":{},"name":"hhninc","nullable":true,"type":"double"},{"metadata":{},"name":"educ","nullable":true,"type":"double"},{"metadata":{},"name":"self","nullable":true,"type":"integer"},{"metadata":{},"name":"docvis","nullable":true,"type":"integer"}],"type":"struct"} --------------------------------------------------------------------------------