├── pmml-sparkml-lightgbm ├── src │ ├── test │ │ ├── resources │ │ │ ├── pipeline │ │ │ │ └── .gitkeep │ │ │ ├── schema │ │ │ │ └── .gitkeep │ │ │ ├── README.md │ │ │ └── main.scala │ │ └── java │ │ │ └── org │ │ │ └── jpmml │ │ │ └── sparkml │ │ │ └── lightgbm │ │ │ └── testing │ │ │ └── LightGBMTest.java │ └── main │ │ ├── resources │ │ └── META-INF │ │ │ └── sparkml2pmml.properties │ │ └── java │ │ └── org │ │ └── jpmml │ │ └── sparkml │ │ └── lightgbm │ │ ├── LightGBMRegressionModelConverter.java │ │ ├── LightGBMClassificationModelConverter.java │ │ └── BoosterUtil.java └── pom.xml ├── pmml-sparkml-xgboost ├── src │ ├── test │ │ └── resources │ │ │ ├── pipeline │ │ │ └── .gitkeep │ │ │ ├── schema │ │ │ └── .gitkeep │ │ │ └── README.md │ └── main │ │ ├── resources │ │ └── META-INF │ │ │ └── sparkml2pmml.properties │ │ └── java │ │ └── org │ │ └── jpmml │ │ └── sparkml │ │ └── xgboost │ │ ├── XGBoostClassificationModelConverter.java │ │ └── XGBoostRegressionModelConverter.java └── pom.xml ├── pmml-sparkml ├── src │ ├── test │ │ ├── resources │ │ │ ├── requirements.txt │ │ │ ├── pipeline │ │ │ │ ├── GBTAudit.zip │ │ │ │ ├── GBTAuto.zip │ │ │ │ ├── GLMAudit.zip │ │ │ │ ├── GLMAuto.zip │ │ │ │ ├── GLMVisit.zip │ │ │ │ ├── GLMHousing.zip │ │ │ │ ├── KMeansIris.zip │ │ │ │ ├── GLMSentiment.zip │ │ │ │ ├── ModelChainAuto.zip │ │ │ │ ├── ModelChainIris.zip │ │ │ │ ├── NaiveBayesIris.zip │ │ │ │ ├── DecisionTreeAuto.zip │ │ │ │ ├── DecisionTreeIris.zip │ │ │ │ ├── FPGrowthShopping.zip │ │ │ │ ├── ModelChainAudit.zip │ │ │ │ ├── NaiveBayesAudit.zip │ │ │ │ ├── RandomForestAuto.zip │ │ │ │ ├── RandomForestIris.zip │ │ │ │ ├── DecisionTreeAudit.zip │ │ │ │ ├── DecisionTreeHousing.zip │ │ │ │ ├── LinearSVCSentiment.zip │ │ │ │ ├── NeuralNetworkAudit.zip │ │ │ │ ├── NeuralNetworkIris.zip │ │ │ │ ├── RandomForestAudit.zip │ │ │ │ ├── RandomForestHousing.zip │ │ │ │ ├── DecisionTreeSentiment.zip │ │ │ │ ├── LinearRegressionAuto.zip │ │ │ │ ├── RandomForestSentiment.zip │ │ │ │ 
├── LinearRegressionAutoNA.zip │ │ │ │ ├── LinearRegressionHousing.zip │ │ │ │ ├── LogisticRegressionAudit.zip │ │ │ │ ├── LogisticRegressionIris.zip │ │ │ │ ├── IsotonicRegressionDecrAuto.zip │ │ │ │ ├── IsotonicRegressionIncrAuto.zip │ │ │ │ ├── LinearRegressionHousingVec.zip │ │ │ │ ├── LogisticRegressionAuditNA.zip │ │ │ │ └── LogisticRegressionIrisVec.zip │ │ │ ├── schema │ │ │ │ ├── Sentiment.json │ │ │ │ ├── Shopping.json │ │ │ │ ├── Iris.json │ │ │ │ ├── Auto.json │ │ │ │ ├── AutoNA.json │ │ │ │ ├── Visit.json │ │ │ │ ├── AuditNA.json │ │ │ │ ├── Audit.json │ │ │ │ ├── IrisVec.json │ │ │ │ ├── HousingVec.json │ │ │ │ └── Housing.json │ │ │ ├── README.md │ │ │ ├── csv │ │ │ │ └── KMeansIris.csv │ │ │ ├── data.py │ │ │ └── main.scala │ │ └── java │ │ │ └── org │ │ │ └── jpmml │ │ │ └── sparkml │ │ │ ├── testing │ │ │ ├── SparkMLDatasets.java │ │ │ ├── ClusteringTest.java │ │ │ ├── SparkMLAlgorithms.java │ │ │ ├── SimpleSparkMLEncoderBatchTest.java │ │ │ └── AssociationRulesTest.java │ │ │ ├── TermUtilTest.java │ │ │ ├── SparkMLTest.java │ │ │ ├── RegexKeyTest.java │ │ │ ├── PipelineModelUtilTest.java │ │ │ ├── PMMLBuilderTest.java │ │ │ ├── AliasExpressionTest.java │ │ │ ├── SparkMLEncoderTest.java │ │ │ └── feature │ │ │ ├── VectorDisassemblerTest.java │ │ │ └── DomainTest.java │ └── main │ │ ├── java │ │ └── org │ │ │ └── jpmml │ │ │ └── sparkml │ │ │ ├── model │ │ │ ├── HasFeatureImportances.java │ │ │ ├── HasRegressionTableOptions.java │ │ │ ├── HasPredictionModelOptions.java │ │ │ ├── LinearRegressionModelConverter.java │ │ │ ├── HasTreeOptions.java │ │ │ ├── DecisionTreeRegressionModelConverter.java │ │ │ ├── DecisionTreeClassificationModelConverter.java │ │ │ ├── LogisticRegressionModelConverter.java │ │ │ ├── NaiveBayesModelConverter.java │ │ │ ├── RandomForestRegressionModelConverter.java │ │ │ ├── RandomForestClassificationModelConverter.java │ │ │ ├── GBTRegressionModelConverter.java │ │ │ ├── KMeansModelConverter.java │ │ │ ├── 
LinearSVCModelConverter.java │ │ │ └── GBTClassificationModelConverter.java │ │ │ ├── HasSparkMLOptions.java │ │ │ ├── feature │ │ │ ├── ColumnPrunerConverter.java │ │ │ ├── VectorAttributeRewriterConverter.java │ │ │ ├── VectorDisassemblerConverter.java │ │ │ ├── SparseToDenseTransformerConverter.java │ │ │ ├── NGramConverter.java │ │ │ ├── VectorSizeHintConverter.java │ │ │ ├── VectorSlicerConverter.java │ │ │ ├── ChiSqSelectorModelConverter.java │ │ │ ├── VectorAssemblerConverter.java │ │ │ ├── IndexToStringConverter.java │ │ │ ├── TokenizerConverter.java │ │ │ ├── RegexTokenizerConverter.java │ │ │ ├── IDFModelConverter.java │ │ │ ├── RFormulaModelConverter.java │ │ │ ├── BinarizerConverter.java │ │ │ ├── StopWordsRemoverConverter.java │ │ │ ├── DomainUtil.java │ │ │ ├── MaxAbsScalerModelConverter.java │ │ │ ├── PCAModelConverter.java │ │ │ ├── InteractionConverter.java │ │ │ ├── ContinuousDomainModelConverter.java │ │ │ ├── DomainModelConverter.java │ │ │ └── MinMaxScalerModelConverter.java │ │ │ ├── ScalaUtil.java │ │ │ ├── ItemSetFeature.java │ │ │ ├── VectorUtil.java │ │ │ ├── AssociationRulesModelConverter.java │ │ │ ├── SparkSessionUtil.java │ │ │ ├── TransformerConverter.java │ │ │ ├── TermUtil.java │ │ │ ├── PredictionModelConverter.java │ │ │ ├── MatrixUtil.java │ │ │ ├── RegexKey.java │ │ │ ├── MultiFeatureConverter.java │ │ │ ├── WeightedTermFeature.java │ │ │ ├── BinarizedCategoricalFeature.java │ │ │ ├── testing │ │ │ └── SparkMLEncoderBatchTest.java │ │ │ ├── ProbabilisticClassificationModelConverter.java │ │ │ ├── ClusteringModelConverter.java │ │ │ ├── AliasExpression.java │ │ │ └── DocumentFeature.java │ │ └── scala │ │ └── org │ │ └── jpmml │ │ └── sparkml │ │ └── feature │ │ ├── package.scala │ │ └── SparseToDenseTransformer.scala └── pom.xml ├── .github └── workflows │ └── maven.yml ├── pmml-sparkml-example └── src │ └── main │ └── java │ └── org │ └── jpmml │ └── sparkml │ └── example │ └── NullSplitter.java └── pmml-sparkml-evaluator ├── 
pom.xml └── src └── main └── java └── org └── jpmml └── sparkml └── evaluator └── SparkMLFunctionRegistry.java /pmml-sparkml-lightgbm/src/test/resources/pipeline/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/schema/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/pipeline/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/schema/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/requirements.txt: -------------------------------------------------------------------------------- 1 | pyspark2pmml==0.8.0 2 | -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GBTAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/GBTAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GBTAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/GBTAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GLMAudit.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/GLMAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GLMAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/GLMAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GLMVisit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/GLMVisit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GLMHousing.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/GLMHousing.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/KMeansIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/KMeansIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/GLMSentiment.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/GLMSentiment.zip -------------------------------------------------------------------------------- 
/pmml-sparkml/src/test/resources/pipeline/ModelChainAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/ModelChainAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/ModelChainIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/ModelChainIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/NaiveBayesIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/NaiveBayesIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/DecisionTreeAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/DecisionTreeAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/DecisionTreeIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/DecisionTreeIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/FPGrowthShopping.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/FPGrowthShopping.zip 
-------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/ModelChainAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/ModelChainAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/NaiveBayesAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/NaiveBayesAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/RandomForestAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/RandomForestAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/RandomForestIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/RandomForestIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/DecisionTreeAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/DecisionTreeAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/DecisionTreeHousing.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/DecisionTreeHousing.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LinearSVCSentiment.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/LinearSVCSentiment.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/NeuralNetworkAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/NeuralNetworkAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/NeuralNetworkIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/NeuralNetworkIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/RandomForestAudit.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/RandomForestAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/RandomForestHousing.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/RandomForestHousing.zip -------------------------------------------------------------------------------- 
/pmml-sparkml/src/test/resources/pipeline/DecisionTreeSentiment.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/DecisionTreeSentiment.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LinearRegressionAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/LinearRegressionAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/RandomForestSentiment.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/RandomForestSentiment.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LinearRegressionAutoNA.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/LinearRegressionAutoNA.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LinearRegressionHousing.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/LinearRegressionHousing.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LogisticRegressionAudit.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/LogisticRegressionAudit.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LogisticRegressionIris.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/LogisticRegressionIris.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/IsotonicRegressionDecrAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/IsotonicRegressionDecrAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/IsotonicRegressionIncrAuto.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/IsotonicRegressionIncrAuto.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LinearRegressionHousingVec.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/LinearRegressionHousingVec.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LogisticRegressionAuditNA.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/LogisticRegressionAuditNA.zip 
-------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/pipeline/LogisticRegressionIrisVec.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jpmml/jpmml-sparkml/HEAD/pmml-sparkml/src/test/resources/pipeline/LogisticRegressionIrisVec.zip -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Sentiment.json: -------------------------------------------------------------------------------- 1 | {"fields":[{"metadata":{},"name":"Sentence","nullable":true,"type":"string"},{"metadata":{},"name":"Score","nullable":true,"type":"integer"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Shopping.json: -------------------------------------------------------------------------------- 1 | {"fields":[{"metadata":{},"name":"transaction","nullable":true,"type":"integer"},{"metadata":{},"name":"items","nullable":false,"type":{"containsNull":false,"elementType":"string","type":"array"}}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/main/resources/META-INF/sparkml2pmml.properties: -------------------------------------------------------------------------------- 1 | ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel = org.jpmml.sparkml.xgboost.XGBoostClassificationModelConverter 2 | ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel = org.jpmml.sparkml.xgboost.XGBoostRegressionModelConverter 3 | -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/main/resources/META-INF/sparkml2pmml.properties: -------------------------------------------------------------------------------- 1 | com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassificationModel = 
org.jpmml.sparkml.lightgbm.LightGBMClassificationModelConverter 2 | com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressionModel = org.jpmml.sparkml.lightgbm.LightGBMRegressionModelConverter 3 | -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Iris.json: -------------------------------------------------------------------------------- 1 | {"fields":[{"metadata":{},"name":"Sepal_Length","nullable":true,"type":"double"},{"metadata":{},"name":"Sepal_Width","nullable":true,"type":"double"},{"metadata":{},"name":"Petal_Length","nullable":true,"type":"double"},{"metadata":{},"name":"Petal_Width","nullable":true,"type":"double"},{"metadata":{},"name":"Species","nullable":true,"type":"string"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/test/resources/README.md: -------------------------------------------------------------------------------- 1 | Launch `spark-shell`: 2 | 3 | ```bash 4 | $SPARK_HOME/bin/spark-shell --jars ../../../../pmml-sparkml-example/target/pmml-sparkml-example-executable-3.2-SNAPSHOT.jar --packages ml.dmlc:xgboost4j-spark_2.12:${xgboost4j-spark.version} 5 | ``` 6 | 7 | Load scripts: 8 | 9 | ```spark-shell 10 | :load ../../../../pmml-sparkml/src/test/resources/common.scala 11 | :load main.scala 12 | ``` 13 | -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/README.md: -------------------------------------------------------------------------------- 1 | Launch `spark-shell`: 2 | 3 | ```bash 4 | $SPARK_HOME/bin/spark-shell --jars "../../../../pmml-sparkml-example/target/pmml-sparkml-example-executable-3.2-SNAPSHOT.jar,scala-library-2.12.20.jar" --packages com.microsoft.azure:synapseml-lightgbm_2.12:${synapseml-lightgbm.version} 5 | ``` 6 | 7 | Load scripts: 8 | 9 | ```spark-shell 10 | :load 
../../../../pmml-sparkml/src/test/resources/common.scala 11 | :load main.scala 12 | ``` 13 | -------------------------------------------------------------------------------- /.github/workflows/maven.yml: -------------------------------------------------------------------------------- 1 | name: maven 2 | 3 | on: 4 | push: 5 | branches: [ '3.0.X', '3.1.X', master ] 6 | 7 | jobs: 8 | build: 9 | 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | java: [ 17, 21 ] 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - uses: actions/setup-java@v4 18 | with: 19 | distribution: 'zulu' 20 | java-version: ${{ matrix.java }} 21 | cache: 'maven' 22 | - run: mvn -B package --file pom.xml 23 | -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/README.md: -------------------------------------------------------------------------------- 1 | Run `spark-submit`: 2 | 3 | ```bash 4 | $SPARK_HOME/bin/spark-submit --jars ../../../../pmml-sparkml-example/target/pmml-sparkml-example-executable-3.2-SNAPSHOT.jar main.py 5 | ``` 6 | 7 | Launch `spark-shell`: 8 | 9 | ```bash 10 | $SPARK_HOME/bin/spark-shell --jars ../../../../pmml-sparkml-example/target/pmml-sparkml-example-executable-3.2-SNAPSHOT.jar 11 | ``` 12 | 13 | Load scripts: 14 | 15 | ```spark-shell 16 | :load common.scala 17 | :load main.scala 18 | ``` 19 | -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Auto.json: -------------------------------------------------------------------------------- 1 | 
{"fields":[{"metadata":{},"name":"displacement","nullable":true,"type":"double"},{"metadata":{},"name":"horsepower","nullable":true,"type":"integer"},{"metadata":{},"name":"weight","nullable":true,"type":"integer"},{"metadata":{},"name":"acceleration","nullable":true,"type":"double"},{"metadata":{},"name":"mpg","nullable":true,"type":"double"},{"metadata":{},"name":"cylinders","nullable":true,"type":"string"},{"metadata":{},"name":"model_year","nullable":true,"type":"string"},{"metadata":{},"name":"origin","nullable":true,"type":"string"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/AutoNA.json: -------------------------------------------------------------------------------- 1 | {"fields":[{"metadata":{},"name":"mpg","nullable":true,"type":"double"},{"metadata":{},"name":"cylinders","nullable":true,"type":"string"},{"metadata":{},"name":"model_year","nullable":true,"type":"string"},{"metadata":{},"name":"origin","nullable":true,"type":"string"},{"metadata":{},"name":"acceleration","nullable":true,"type":"double"},{"metadata":{},"name":"displacement","nullable":true,"type":"double"},{"metadata":{},"name":"horsepower","nullable":true,"type":"double"},{"metadata":{},"name":"weight","nullable":true,"type":"double"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Visit.json: -------------------------------------------------------------------------------- 1 | 
{"fields":[{"metadata":{},"name":"edlevel","nullable":true,"type":"string"},{"metadata":{},"name":"age","nullable":true,"type":"integer"},{"metadata":{},"name":"outwork","nullable":true,"type":"integer"},{"metadata":{},"name":"female","nullable":true,"type":"integer"},{"metadata":{},"name":"married","nullable":true,"type":"integer"},{"metadata":{},"name":"kids","nullable":true,"type":"integer"},{"metadata":{},"name":"hhninc","nullable":true,"type":"double"},{"metadata":{},"name":"educ","nullable":true,"type":"double"},{"metadata":{},"name":"self","nullable":true,"type":"integer"},{"metadata":{},"name":"docvis","nullable":true,"type":"integer"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/AuditNA.json: -------------------------------------------------------------------------------- 1 | {"fields":[{"metadata":{},"name":"Employment","nullable":true,"type":"string"},{"metadata":{},"name":"Education","nullable":true,"type":"string"},{"metadata":{},"name":"Marital","nullable":true,"type":"string"},{"metadata":{},"name":"Occupation","nullable":true,"type":"string"},{"metadata":{},"name":"Gender","nullable":true,"type":"string"},{"metadata":{},"name":"Deductions","nullable":true,"type":"string"},{"metadata":{},"name":"Adjusted","nullable":true,"type":"string"},{"metadata":{},"name":"Age","nullable":true,"type":"double"},{"metadata":{},"name":"Income","nullable":true,"type":"double"},{"metadata":{},"name":"Hours","nullable":true,"type":"double"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Audit.json: -------------------------------------------------------------------------------- 1 | 
{"fields":[{"metadata":{},"name":"Age","nullable":true,"type":"integer"},{"metadata":{},"name":"Employment","nullable":true,"type":"string"},{"metadata":{},"name":"Education","nullable":true,"type":"string"},{"metadata":{},"name":"Marital","nullable":true,"type":"string"},{"metadata":{},"name":"Occupation","nullable":true,"type":"string"},{"metadata":{},"name":"Income","nullable":true,"type":"double"},{"metadata":{},"name":"Gender","nullable":true,"type":"string"},{"metadata":{},"name":"Adjusted","nullable":true,"type":"integer"},{"metadata":{},"name":"Deductions","nullable":true,"type":"boolean"},{"metadata":{},"name":"Hours","nullable":true,"type":"double"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/IrisVec.json: -------------------------------------------------------------------------------- 1 | {"type":"struct","fields":[{"name":"label","type":"integer","nullable":true,"metadata":{}},{"name":"features","type":{"type":"udt","class":"org.apache.spark.ml.linalg.VectorUDT","pyClass":"pyspark.ml.linalg.VectorUDT","sqlType":{"type":"struct","fields":[{"name":"type","type":"byte","nullable":false,"metadata":{}},{"name":"size","type":"integer","nullable":true,"metadata":{}},{"name":"indices","type":{"type":"array","elementType":"integer","containsNull":false},"nullable":true,"metadata":{}},{"name":"values","type":{"type":"array","elementType":"double","containsNull":false},"nullable":true,"metadata":{}}]}},"nullable":true,"metadata":{"numFeatures":4,"ml_attr":{"num_attrs":4}}}]} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/HousingVec.json: -------------------------------------------------------------------------------- 1 | 
{"type":"struct","fields":[{"name":"label","type":"double","nullable":true,"metadata":{}},{"name":"features","type":{"type":"udt","class":"org.apache.spark.ml.linalg.VectorUDT","pyClass":"pyspark.ml.linalg.VectorUDT","sqlType":{"type":"struct","fields":[{"name":"type","type":"byte","nullable":false,"metadata":{}},{"name":"size","type":"integer","nullable":true,"metadata":{}},{"name":"indices","type":{"type":"array","elementType":"integer","containsNull":false},"nullable":true,"metadata":{}},{"name":"values","type":{"type":"array","elementType":"double","containsNull":false},"nullable":true,"metadata":{}}]}},"nullable":true,"metadata":{"numFeatures":13,"ml_attr":{"num_attrs":13}}}]} -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/schema/Housing.json: -------------------------------------------------------------------------------- 1 | {"fields":[{"metadata":{},"name":"CRIM","nullable":true,"type":"double"},{"metadata":{},"name":"ZN","nullable":true,"type":"double"},{"metadata":{},"name":"INDUS","nullable":true,"type":"double"},{"metadata":{},"name":"CHAS","nullable":true,"type":"integer"},{"metadata":{},"name":"NOX","nullable":true,"type":"double"},{"metadata":{},"name":"RM","nullable":true,"type":"double"},{"metadata":{},"name":"AGE","nullable":true,"type":"double"},{"metadata":{},"name":"DIS","nullable":true,"type":"double"},{"metadata":{},"name":"RAD","nullable":true,"type":"integer"},{"metadata":{},"name":"TAX","nullable":true,"type":"double"},{"metadata":{},"name":"PTRATIO","nullable":true,"type":"double"},{"metadata":{},"name":"B","nullable":true,"type":"double"},{"metadata":{},"name":"LSTAT","nullable":true,"type":"double"},{"metadata":{},"name":"MEDV","nullable":true,"type":"double"}],"type":"struct"} -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/HasFeatureImportances.java: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.linalg.Vector; 22 | 23 | public interface HasFeatureImportances { 24 | 25 | Vector getFeatureImportances(); 26 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/HasSparkMLOptions.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 
15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.jpmml.converter.HasOptions; 22 | 23 | /** 24 | * @see TransformerConverter#getOption(String, Object) 25 | */ 26 | public interface HasSparkMLOptions extends HasOptions { 27 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/testing/SparkMLDatasets.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.testing; 20 | 21 | import org.jpmml.converter.testing.Datasets; 22 | 23 | interface SparkMLDatasets extends Datasets { 24 | 25 | String HOUSING_VEC = HOUSING + "Vec"; 26 | String IRIS_VEC = IRIS + "Vec"; 27 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/HasRegressionTableOptions.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.jpmml.sparkml.HasSparkMLOptions; 22 | 23 | public interface HasRegressionTableOptions extends HasSparkMLOptions { 24 | 25 | String OPTION_REPRESENTATION = "representation"; 26 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/HasPredictionModelOptions.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.jpmml.sparkml.HasSparkMLOptions; 22 | 23 | public interface HasPredictionModelOptions extends HasSparkMLOptions { 24 | 25 | String OPTION_KEEP_PREDICTIONCOL = "keep_predictionCol"; 26 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/ColumnPrunerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import org.apache.spark.ml.feature.ColumnPruner; 22 | import org.jpmml.sparkml.FeatureConverter; 23 | 24 | public class ColumnPrunerConverter extends FeatureConverter { 25 | 26 | public ColumnPrunerConverter(ColumnPruner transformer){ 27 | super(transformer); 28 | } 29 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/csv/KMeansIris.csv: -------------------------------------------------------------------------------- 1 | cluster 2 | 1 3 | 10 4 | 10 5 | 10 6 | 1 7 | 1 8 | 10 9 | 1 10 | 10 11 | 10 12 | 1 13 | 10 14 | 10 15 | 10 16 | 1 17 | 1 18 | 1 19 | 1 20 | 1 21 | 1 22 | 1 23 | 1 24 | 10 25 | 1 26 | 10 27 | 10 28 | 1 29 | 1 30 | 1 31 | 10 32 | 10 33 | 1 34 | 1 35 | 1 36 | 10 37 | 10 38 | 1 39 | 1 40 | 10 41 | 1 42 | 1 43 | 10 44 | 10 45 | 1 46 | 1 47 | 10 48 | 1 49 | 10 50 | 1 51 | 10 52 | 3 53 | 3 54 | 3 55 | 6 56 | 3 57 | 2 58 | 3 59 | 5 60 | 3 61 | 6 62 | 5 63 | 2 64 | 6 65 | 9 66 | 6 67 | 3 68 | 2 69 | 6 70 | 9 71 | 6 72 | 4 73 | 6 74 | 9 75 | 9 76 | 3 77 | 3 78 | 3 79 | 3 80 | 2 81 | 6 82 | 6 83 | 6 84 | 6 85 | 4 86 | 2 87 | 2 88 | 3 89 | 9 90 | 2 91 | 6 92 | 2 93 | 2 94 | 6 95 | 5 96 | 2 97 | 2 98 | 2 99 | 2 100 | 5 101 | 2 102 | 7 103 | 4 104 | 8 105 | 8 106 | 8 107 | 0 108 | 2 109 | 0 110 | 8 111 | 0 112 | 8 113 | 8 114 | 8 115 | 4 116 | 4 117 | 7 118 | 8 119 | 0 120 | 0 121 | 9 122 | 8 123 | 4 124 | 0 125 | 9 126 | 8 127 | 0 128 | 9 129 | 4 130 | 8 131 | 8 132 | 0 133 | 0 134 | 8 135 | 9 136 | 4 137 | 0 138 | 7 139 | 8 140 | 4 141 | 8 142 | 8 143 | 8 144 | 4 145 | 8 146 | 7 147 | 8 148 | 9 149 | 8 150 | 7 151 | 4 152 | -------------------------------------------------------------------------------- /pmml-sparkml-example/src/main/java/org/jpmml/sparkml/example/NullSplitter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is 
part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.example; 20 | 21 | import java.util.Collections; 22 | import java.util.List; 23 | 24 | import com.beust.jcommander.converters.IParameterSplitter; 25 | 26 | public class NullSplitter implements IParameterSplitter { 27 | 28 | @Override 29 | public List split(String value){ 30 | return Collections.singletonList(value); 31 | } 32 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/scala/org/jpmml/sparkml/feature/package.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 
15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml 20 | 21 | /** 22 | * @groupname param Parameters 23 | * @groupprio param -3 24 | * 25 | * @groupname setParam Parameter setters 26 | * @groupprio setParam -2 27 | * 28 | * @groupname getParam Parameter getters 29 | * @groupprio getParam -1 30 | * 31 | * @groupname Ungrouped Members 32 | * @groupprio Ungrouped 0 33 | */ 34 | package object feature -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/VectorAttributeRewriterConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import org.apache.spark.ml.feature.VectorAttributeRewriter; 22 | import org.jpmml.sparkml.FeatureConverter; 23 | 24 | public class VectorAttributeRewriterConverter extends FeatureConverter { 25 | 26 | public VectorAttributeRewriterConverter(VectorAttributeRewriter transformer){ 27 | super(transformer); 28 | } 29 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/data.py: -------------------------------------------------------------------------------- 1 | from sklearn.preprocessing import LabelEncoder 2 | 3 | import math 4 | import pandas 5 | 6 | def read_csv(name): 7 | df = pandas.read_csv("csv/" + name + ".csv", na_values = ["", "NA", "N/A"]) 8 | 9 | X = df.iloc[:, :-1].values 10 | y = df.iloc[:, -1].values 11 | 12 | return (X, y) 13 | 14 | def write_libsvm(X, y, name): 15 | n_rows, n_cols = X.shape 16 | 17 | with open("libsvm/" + name + ".libsvm", "w") as file: 18 | for row in range(n_rows): 19 | cells = [] 20 | label = y[row] 21 | cells.append("{:g}".format(label)) 22 | for col in range(n_cols): 23 | value = X[row, col] 24 | if not math.isnan(value): 25 | cells.append("{}:{:g}".format(col + 1, value)) 26 | file.write(" ".join(cells) + "\n") 27 | 28 | # 29 | # Auto 30 | # 31 | 32 | auto_X, auto_y = read_csv("Auto") 33 | 34 | write_libsvm(auto_X, auto_y, "Auto") 35 | 36 | auto_X, auto_y = read_csv("AutoNA") 37 | 38 | write_libsvm(auto_X, auto_y, "AutoNA") 39 | 40 | # 41 | # Housing 42 | # 43 | 44 | housing_X, housing_y = read_csv("Housing") 45 | 46 | write_libsvm(housing_X, housing_y, "Housing") 47 | 48 | # 49 | # Iris 50 | # 51 | 52 | iris_X, iris_y = read_csv("Iris") 53 | 54 | iris_le = LabelEncoder() 55 | iris_y = iris_le.fit_transform(iris_y) 56 | 57 | write_libsvm(iris_X, iris_y, "Iris") 58 | -------------------------------------------------------------------------------- /pmml-sparkml-evaluator/pom.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | 5 | 6 | org.jpmml 7 | jpmml-sparkml 8 | 3.2-SNAPSHOT 9 | 10 | 11 | org.jpmml 12 | pmml-sparkml-evaluator 13 | jar 14 | 15 | JPMML Spark ML JPMML-Evaluator integration 16 | JPMML Apache Spark ML JPMML-Evaluator integration 17 | 18 | 19 | 20 | GNU Affero General Public License (AGPL) version 3.0 21 | http://www.gnu.org/licenses/agpl-3.0.html 22 | repo 23 | 24 | 25 | 26 | 27 | 28 | org.jpmml 29 | pmml-evaluator 30 | provided 31 | 32 | 33 | org.jpmml 34 | pmml-evaluator-testing 35 | provided 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/testing/ClusteringTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.testing; 20 | 21 | import org.jpmml.converter.FieldNameUtil; 22 | import org.junit.jupiter.api.Test; 23 | 24 | public class ClusteringTest extends SimpleSparkMLEncoderBatchTest implements SparkMLAlgorithms, SparkMLDatasets { 25 | 26 | @Test 27 | public void evaluateKMeansIris() throws Exception { 28 | String[] outputFields = {FieldNameUtil.create("pmml", "cluster")}; 29 | 30 | evaluate(K_MEANS, IRIS, excludeFields(outputFields)); 31 | } 32 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/ScalaUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import scala.collection.Seq; 25 | 26 | public class ScalaUtil { 27 | 28 | private ScalaUtil(){ 29 | } 30 | 31 | static 32 | public List seqAsJavaList(Seq seq){ 33 | List result = new ArrayList<>(); 34 | 35 | for(int i = 0, max = seq.length(); i < max; i++){ 36 | E element = seq.apply(i); 37 | 38 | result.add(element); 39 | } 40 | 41 | return result; 42 | } 43 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/ItemSetFeature.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.dmg.pmml.Field; 22 | import org.jpmml.converter.ContinuousFeature; 23 | import org.jpmml.converter.Feature; 24 | 25 | public class ItemSetFeature extends Feature { 26 | 27 | public ItemSetFeature(SparkMLEncoder encoder, Field field){ 28 | super(encoder, field.requireName(), field.requireDataType()); 29 | } 30 | 31 | @Override 32 | public ContinuousFeature toContinuousFeature(){ 33 | throw new UnsupportedOperationException(); 34 | } 35 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/VectorUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.List; 22 | 23 | import com.google.common.primitives.Doubles; 24 | import org.apache.spark.ml.linalg.DenseVector; 25 | import org.apache.spark.ml.linalg.Vector; 26 | 27 | public class VectorUtil { 28 | 29 | private VectorUtil(){ 30 | } 31 | 32 | static 33 | public List toList(Vector vector){ 34 | DenseVector denseVector = vector.toDense(); 35 | 36 | double[] values = denseVector.values(); 37 | 38 | return Doubles.asList(values); 39 | } 40 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/AssociationRulesModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.apache.spark.ml.Model; 22 | import org.apache.spark.ml.param.shared.HasPredictionCol; 23 | import org.dmg.pmml.MiningFunction; 24 | 25 | abstract 26 | public class AssociationRulesModelConverter & HasPredictionCol> extends ModelConverter { 27 | 28 | public AssociationRulesModelConverter(T model){ 29 | super(model); 30 | } 31 | 32 | @Override 33 | public MiningFunction getMiningFunction(){ 34 | return MiningFunction.ASSOCIATION_RULES; 35 | } 36 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/TermUtilTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.junit.jupiter.api.Test; 22 | 23 | import static org.junit.jupiter.api.Assertions.assertFalse; 24 | import static org.junit.jupiter.api.Assertions.assertTrue; 25 | 26 | public class TermUtilTest { 27 | 28 | @Test 29 | public void hasPunctuation(){ 30 | assertFalse(TermUtil.hasPunctuation("one")); 31 | assertTrue(TermUtil.hasPunctuation("one?")); 32 | 33 | assertFalse(TermUtil.hasPunctuation("one-half")); 34 | 35 | assertFalse(TermUtil.hasPunctuation("one two")); 36 | assertTrue(TermUtil.hasPunctuation("one, two")); 37 | } 38 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/SparkMLTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.apache.spark.sql.SparkSession; 22 | import org.junit.jupiter.api.AfterAll; 23 | import org.junit.jupiter.api.BeforeAll; 24 | 25 | abstract 26 | public class SparkMLTest { 27 | 28 | @BeforeAll 29 | static 30 | public void createSparkSession(){ 31 | SparkMLTest.sparkSession = SparkSessionUtil.createSparkSession(); 32 | } 33 | 34 | @AfterAll 35 | static 36 | public void destroySparkSession(){ 37 | SparkMLTest.sparkSession = SparkSessionUtil.destroySparkSession(SparkMLTest.sparkSession); 38 | } 39 | 40 | public static SparkSession sparkSession = null; 41 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/testing/SparkMLAlgorithms.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.testing; 20 | 21 | interface SparkMLAlgorithms { 22 | 23 | String DECISION_TREE = "DecisionTree"; 24 | String FP_GROWTH = "FPGrowth"; 25 | String GBT = "GBT"; 26 | String GLM = "GLM"; 27 | String ISOTONIC_REGRESSION = "IsotonicRegression"; 28 | String K_MEANS = "KMeans"; 29 | String LINEAR_REGRESION = "LinearRegression"; 30 | String LINEAR_SVC = "LinearSVC"; 31 | String LOGISTIC_REGRESSION = "LogisticRegression"; 32 | String MODEL_CHAIN = "ModelChain"; 33 | String NAIVE_BAYES = "NaiveBayes"; 34 | String NEURAL_NETWORK = "NeuralNetwork"; 35 | String RANDOM_FOREST = "RandomForest"; 36 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/RegexKeyTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.regex.Pattern; 22 | 23 | import org.junit.jupiter.api.Test; 24 | 25 | import static org.junit.jupiter.api.Assertions.assertFalse; 26 | import static org.junit.jupiter.api.Assertions.assertTrue; 27 | 28 | public class RegexKeyTest { 29 | 30 | @Test 31 | public void compile(){ 32 | RegexKey anyKey = new RegexKey(Pattern.compile(".*")); 33 | RegexKey dotAsteriskKey = new RegexKey(Pattern.compile(".*", Pattern.LITERAL)); 34 | 35 | assertTrue(anyKey.test("")); 36 | assertTrue(anyKey.test(".*")); 37 | 38 | assertFalse(dotAsteriskKey.test("")); 39 | assertTrue(dotAsteriskKey.test(".*")); 40 | } 41 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/VectorDisassemblerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.List; 22 | 23 | import org.jpmml.converter.Feature; 24 | import org.jpmml.sparkml.FeatureConverter; 25 | import org.jpmml.sparkml.SparkMLEncoder; 26 | 27 | public class VectorDisassemblerConverter extends FeatureConverter { 28 | 29 | public VectorDisassemblerConverter(VectorDisassembler transformer){ 30 | super(transformer); 31 | } 32 | 33 | @Override 34 | public List encodeFeatures(SparkMLEncoder encoder){ 35 | VectorDisassembler transformer = getTransformer(); 36 | 37 | String inputCol = transformer.getInputCol(); 38 | 39 | return encoder.getFeatures(inputCol); 40 | } 41 | } -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/main/java/org/jpmml/sparkml/lightgbm/LightGBMRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.lightgbm; 20 | 21 | import com.microsoft.azure.synapse.ml.lightgbm.LightGBMRegressionModel; 22 | import org.dmg.pmml.mining.MiningModel; 23 | import org.jpmml.converter.Schema; 24 | import org.jpmml.sparkml.RegressionModelConverter; 25 | 26 | public class LightGBMRegressionModelConverter extends RegressionModelConverter { 27 | 28 | public LightGBMRegressionModelConverter(LightGBMRegressionModel model){ 29 | super(model); 30 | } 31 | 32 | @Override 33 | public MiningModel encodeModel(Schema schema){ 34 | LightGBMRegressionModel model = getModel(); 35 | 36 | return BoosterUtil.encodeModel(this, schema); 37 | } 38 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/LinearRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.regression.LinearRegressionModel; 22 | import org.dmg.pmml.Model; 23 | import org.jpmml.converter.Schema; 24 | import org.jpmml.sparkml.RegressionModelConverter; 25 | 26 | public class LinearRegressionModelConverter extends RegressionModelConverter implements HasRegressionTableOptions { 27 | 28 | public LinearRegressionModelConverter(LinearRegressionModel model){ 29 | super(model); 30 | } 31 | 32 | @Override 33 | public Model encodeModel(Schema schema){ 34 | LinearRegressionModel model = getModel(); 35 | 36 | return LinearModelUtil.createRegression(this, model.coefficients(), model.intercept(), schema); 37 | } 38 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/SparseToDenseTransformerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.List; 22 | 23 | import org.jpmml.converter.Feature; 24 | import org.jpmml.sparkml.FeatureConverter; 25 | import org.jpmml.sparkml.SparkMLEncoder; 26 | 27 | public class SparseToDenseTransformerConverter extends FeatureConverter { 28 | 29 | public SparseToDenseTransformerConverter(SparseToDenseTransformer transformer){ 30 | super(transformer); 31 | } 32 | 33 | @Override 34 | public List encodeFeatures(SparkMLEncoder encoder){ 35 | SparseToDenseTransformer transformer = getTransformer(); 36 | 37 | List features = encoder.getFeatures(transformer.getInputCol()); 38 | 39 | return features; 40 | } 41 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/PipelineModelUtilTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.apache.spark.ml.PipelineModel; 22 | import org.apache.spark.ml.Transformer; 23 | import org.apache.spark.ml.feature.SQLTransformer; 24 | 25 | import static org.junit.jupiter.api.Assertions.assertArrayEquals; 26 | import static org.junit.jupiter.api.Assertions.assertEquals; 27 | 28 | public class PipelineModelUtilTest { 29 | 30 | public void create(){ 31 | Transformer identityTransformer = new SQLTransformer() 32 | .setStatement("SELECT * FROM __THIS__"); 33 | 34 | PipelineModel pipelineModel = PipelineModelUtil.create("test", new Transformer[]{identityTransformer}); 35 | 36 | assertEquals("test", pipelineModel.uid()); 37 | assertArrayEquals(new Transformer[]{identityTransformer}, pipelineModel.stages()); 38 | } 39 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/HasTreeOptions.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.LinkedHashMap; 22 | import java.util.Map; 23 | 24 | import org.jpmml.converter.HasNativeConfiguration; 25 | import org.jpmml.sparkml.HasSparkMLOptions; 26 | import org.jpmml.sparkml.visitors.TreeModelCompactor; 27 | 28 | public interface HasTreeOptions extends HasSparkMLOptions, HasNativeConfiguration { 29 | 30 | /** 31 | * @see TreeModelCompactor 32 | */ 33 | String OPTION_COMPACT = "compact"; 34 | 35 | String OPTION_ESTIMATE_FEATURE_IMPORTANCES = "estimate_featureImportances"; 36 | 37 | @Override 38 | default 39 | public Map getNativeConfiguration(){ 40 | Map result = new LinkedHashMap<>(); 41 | result.put(HasTreeOptions.OPTION_COMPACT, Boolean.FALSE); 42 | 43 | return result; 44 | } 45 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/NGramConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Collections; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.NGram; 25 | import org.jpmml.converter.Feature; 26 | import org.jpmml.sparkml.DocumentFeature; 27 | import org.jpmml.sparkml.FeatureConverter; 28 | import org.jpmml.sparkml.SparkMLEncoder; 29 | 30 | public class NGramConverter extends FeatureConverter { 31 | 32 | public NGramConverter(NGram transformer){ 33 | super(transformer); 34 | } 35 | 36 | @Override 37 | public List encodeFeatures(SparkMLEncoder encoder){ 38 | NGram transformer = getTransformer(); 39 | 40 | DocumentFeature documentFeature = (DocumentFeature)encoder.getOnlyFeature(transformer.getInputCol()); 41 | 42 | return Collections.singletonList(documentFeature); 43 | } 44 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/VectorSizeHintConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.feature.VectorSizeHint; 24 | import org.jpmml.converter.Feature; 25 | import org.jpmml.converter.SchemaUtil; 26 | import org.jpmml.sparkml.FeatureConverter; 27 | import org.jpmml.sparkml.SparkMLEncoder; 28 | 29 | public class VectorSizeHintConverter extends FeatureConverter { 30 | 31 | public VectorSizeHintConverter(VectorSizeHint transformer){ 32 | super(transformer); 33 | } 34 | 35 | @Override 36 | public List encodeFeatures(SparkMLEncoder encoder){ 37 | VectorSizeHint transformer = getTransformer(); 38 | 39 | int size = transformer.getSize(); 40 | 41 | List features = encoder.getFeatures(transformer.getInputCol()); 42 | 43 | SchemaUtil.checkSize(size, features); 44 | 45 | return features; 46 | } 47 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/VectorSlicerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.feature.VectorSlicer; 24 | import org.jpmml.converter.Feature; 25 | import org.jpmml.sparkml.FeatureConverter; 26 | import org.jpmml.sparkml.SparkMLEncoder; 27 | 28 | public class VectorSlicerConverter extends FeatureConverter { 29 | 30 | public VectorSlicerConverter(VectorSlicer transformer){ 31 | super(transformer); 32 | } 33 | 34 | @Override 35 | public List encodeFeatures(SparkMLEncoder encoder){ 36 | VectorSlicer transformer = getTransformer(); 37 | 38 | String[] names = transformer.getNames(); 39 | if(names != null && names.length > 0){ 40 | throw new IllegalArgumentException("Expected index mode, got name mode"); 41 | } 42 | 43 | return encoder.getFeatures(transformer.getInputCol(), transformer.getIndices()); 44 | } 45 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/SparkSessionUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.apache.spark.SparkContext; 22 | import org.apache.spark.sql.SparkSession; 23 | 24 | public class SparkSessionUtil { 25 | 26 | private SparkSessionUtil(){ 27 | } 28 | 29 | static 30 | public SparkSession createSparkSession(){ 31 | return createSparkSession("local"); 32 | } 33 | 34 | static 35 | public SparkSession createSparkSession(String master){ 36 | SparkSession.Builder builder = SparkSession.builder() 37 | .master(master) 38 | .config("spark.ui.enabled", false); 39 | 40 | SparkSession sparkSession = builder.getOrCreate(); 41 | 42 | SparkContext sparkContext = sparkSession.sparkContext(); 43 | sparkContext.setLogLevel("ERROR"); 44 | 45 | return sparkSession; 46 | } 47 | 48 | static 49 | public SparkSession destroySparkSession(SparkSession sparkSession){ 50 | sparkSession.stop(); 51 | 52 | return null; 53 | } 54 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/DecisionTreeRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.linalg.Vector; 22 | import org.apache.spark.ml.regression.DecisionTreeRegressionModel; 23 | import org.dmg.pmml.tree.TreeModel; 24 | import org.jpmml.converter.Schema; 25 | import org.jpmml.sparkml.RegressionModelConverter; 26 | 27 | public class DecisionTreeRegressionModelConverter extends RegressionModelConverter implements HasFeatureImportances, HasTreeOptions { 28 | 29 | public DecisionTreeRegressionModelConverter(DecisionTreeRegressionModel model){ 30 | super(model); 31 | } 32 | 33 | @Override 34 | public Vector getFeatureImportances(){ 35 | DecisionTreeRegressionModel model = getModel(); 36 | 37 | return model.featureImportances(); 38 | } 39 | 40 | @Override 41 | public TreeModel encodeModel(Schema schema){ 42 | return TreeModelUtil.encodeDecisionTree(this, schema); 43 | } 44 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/ChiSqSelectorModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Arrays; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.ChiSqSelectorModel; 25 | import org.jpmml.converter.Feature; 26 | import org.jpmml.sparkml.FeatureConverter; 27 | import org.jpmml.sparkml.SparkMLEncoder; 28 | 29 | public class ChiSqSelectorModelConverter extends FeatureConverter { 30 | 31 | public ChiSqSelectorModelConverter(ChiSqSelectorModel transformer){ 32 | super(transformer); 33 | } 34 | 35 | @Override 36 | public List encodeFeatures(SparkMLEncoder encoder){ 37 | ChiSqSelectorModel transformer = getTransformer(); 38 | 39 | int[] indices = transformer.selectedFeatures(); 40 | if(indices.length > 0){ 41 | indices = indices.clone(); 42 | 43 | Arrays.sort(indices); 44 | } 45 | 46 | return encoder.getFeatures(transformer.getFeaturesCol(), indices); 47 | } 48 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/VectorAssemblerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.VectorAssembler; 25 | import org.jpmml.converter.Feature; 26 | import org.jpmml.sparkml.FeatureConverter; 27 | import org.jpmml.sparkml.SparkMLEncoder; 28 | 29 | public class VectorAssemblerConverter extends FeatureConverter { 30 | 31 | public VectorAssemblerConverter(VectorAssembler transformer){ 32 | super(transformer); 33 | } 34 | 35 | @Override 36 | public List encodeFeatures(SparkMLEncoder encoder){ 37 | VectorAssembler transformer = getTransformer(); 38 | 39 | List result = new ArrayList<>(); 40 | 41 | String[] inputCols = transformer.getInputCols(); 42 | for(String inputCol : inputCols){ 43 | List features = encoder.getFeatures(inputCol); 44 | 45 | result.addAll(features); 46 | } 47 | 48 | return result; 49 | } 50 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/TransformerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Map; 22 | 23 | import org.apache.spark.ml.Transformer; 24 | 25 | abstract 26 | public class TransformerConverter { 27 | 28 | private T object = null; 29 | 30 | private Map options = null; 31 | 32 | 33 | public TransformerConverter(T object){ 34 | setObject(object); 35 | } 36 | 37 | public Object getOption(String key, Object defaultValue){ 38 | Map options = getOptions(); 39 | 40 | if(options != null && options.containsKey(key)){ 41 | return options.get(key); 42 | } 43 | 44 | return defaultValue; 45 | } 46 | 47 | public T getObject(){ 48 | return this.object; 49 | } 50 | 51 | private void setObject(T object){ 52 | this.object = object; 53 | } 54 | 55 | public Map getOptions(){ 56 | return this.options; 57 | } 58 | 59 | public void setOptions(Map options){ 60 | this.options = options; 61 | } 62 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/testing/SimpleSparkMLEncoderBatchTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.testing; 20 | 21 | import java.util.function.Predicate; 22 | 23 | import com.google.common.base.Equivalence; 24 | import org.jpmml.evaluator.ResultField; 25 | import org.junit.jupiter.api.AfterAll; 26 | import org.junit.jupiter.api.BeforeAll; 27 | 28 | abstract 29 | public class SimpleSparkMLEncoderBatchTest extends SparkMLEncoderBatchTest { 30 | 31 | @Override 32 | public SparkMLEncoderBatch createBatch(String algorithm, String dataset, Predicate columnFilter, Equivalence equivalence){ 33 | columnFilter = columnFilter.and(excludePredictionFields()); 34 | 35 | return super.createBatch(algorithm, dataset, columnFilter, equivalence); 36 | } 37 | 38 | @BeforeAll 39 | static 40 | public void createSparkSession(){ 41 | SparkMLEncoderBatchTest.createSparkSession(); 42 | } 43 | 44 | @AfterAll 45 | static 46 | public void destroySparkSession(){ 47 | SparkMLEncoderBatchTest.destroySparkSession(); 48 | } 49 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/DecisionTreeClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.classification.DecisionTreeClassificationModel; 22 | import org.apache.spark.ml.linalg.Vector; 23 | import org.dmg.pmml.tree.TreeModel; 24 | import org.jpmml.converter.Schema; 25 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 26 | 27 | public class DecisionTreeClassificationModelConverter extends ProbabilisticClassificationModelConverter implements HasFeatureImportances, HasTreeOptions { 28 | 29 | public DecisionTreeClassificationModelConverter(DecisionTreeClassificationModel model){ 30 | super(model); 31 | } 32 | 33 | @Override 34 | public Vector getFeatureImportances(){ 35 | DecisionTreeClassificationModel model = getModel(); 36 | 37 | return model.featureImportances(); 38 | } 39 | 40 | @Override 41 | public TreeModel encodeModel(Schema schema){ 42 | return TreeModelUtil.encodeDecisionTree(this, schema); 43 | } 44 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/TermUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | public class TermUtil { 22 | 23 | private TermUtil(){ 24 | } 25 | 26 | static 27 | public boolean hasPunctuation(String string){ 28 | String[] tokens = string.split("\\s+"); 29 | 30 | for(String token : tokens){ 31 | int length = token.length(); 32 | 33 | if(length > 0){ 34 | char first = token.charAt(0); 35 | char last = token.charAt(length - 1); 36 | 37 | if(isPunctuation(first) || isPunctuation(last)){ 38 | return true; 39 | } 40 | } 41 | } 42 | 43 | return false; 44 | } 45 | 46 | static 47 | public boolean isPunctuation(char c){ 48 | int type = Character.getType(c); 49 | 50 | switch(type){ 51 | case Character.DASH_PUNCTUATION: 52 | case Character.END_PUNCTUATION: 53 | case Character.START_PUNCTUATION: 54 | case Character.CONNECTOR_PUNCTUATION: 55 | case Character.OTHER_PUNCTUATION: 56 | case Character.INITIAL_QUOTE_PUNCTUATION: 57 | case Character.FINAL_QUOTE_PUNCTUATION: 58 | return true; 59 | default: 60 | return false; 61 | } 62 | } 63 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/IndexToStringConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 
15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Arrays; 22 | import java.util.Collections; 23 | import java.util.List; 24 | 25 | import org.apache.spark.ml.feature.IndexToString; 26 | import org.dmg.pmml.DataField; 27 | import org.dmg.pmml.DataType; 28 | import org.dmg.pmml.OpType; 29 | import org.jpmml.converter.CategoricalFeature; 30 | import org.jpmml.converter.Feature; 31 | import org.jpmml.sparkml.FeatureConverter; 32 | import org.jpmml.sparkml.SparkMLEncoder; 33 | 34 | public class IndexToStringConverter extends FeatureConverter { 35 | 36 | public IndexToStringConverter(IndexToString transformer){ 37 | super(transformer); 38 | } 39 | 40 | @Override 41 | public List encodeFeatures(SparkMLEncoder encoder){ 42 | IndexToString transformer = getTransformer(); 43 | 44 | DataField dataField = encoder.createDataField(formatName(encoder), OpType.CATEGORICAL, DataType.STRING, Arrays.asList(transformer.getLabels())); 45 | 46 | return Collections.singletonList(new CategoricalFeature(encoder, dataField)); 47 | } 48 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/PMMLBuilderTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */
package org.jpmml.sparkml;

import java.util.Arrays;

import org.apache.spark.ml.Model;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.sql.types.StructType;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;

/**
 * Tests for the {@link PMMLBuilder} constructor.
 */
public class PMMLBuilderTest {

	/**
	 * A bare model must be rejected with an {@link IllegalArgumentException},
	 * whereas the same model wrapped in a {@link PipelineModel} must be accepted.
	 */
	@Test
	public void construct(){
		StructType schema = new StructType();

		Model<?> model = new LogisticRegressionModel("lrm", new DenseVector(new double[0]), 0d);

		assertThrows(IllegalArgumentException.class, () -> new PMMLBuilder(schema, model));

		PipelineModel pipelineModel = new PipelineModel("pm", Arrays.asList(model));

		// Block lambda keeps overload resolution on the Executable variant
		assertDoesNotThrow(() -> {
			new PMMLBuilder(schema, pipelineModel);
		});
	}
}
-------------------------------------------------------------------------------- /pmml-sparkml-xgboost/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | 5 | 6 | org.jpmml 7 | jpmml-sparkml 8 | 3.2-SNAPSHOT 9 | 10 | 11 | org.jpmml 12 | pmml-sparkml-xgboost 13 | jar 14 | 15 | JPMML Spark ML XGBoost converter 16 | JPMML Apache Spark ML XGBoost to PMML converter 17 | 18 | 19 | 20 | GNU Affero General Public License (AGPL) version 3.0 21 | http://www.gnu.org/licenses/agpl-3.0.html 22 | repo 23 | 24 | 25 | 26 | 27 | 28 | org.jpmml 29 | pmml-sparkml 30 | 31 | 32 | 33 | org.jpmml 34 | pmml-evaluator-testing 35 | provided 36 | 37 | 38 | 39 |
org.jpmml 40 | pmml-xgboost 41 | 42 | 43 | 44 | ml.dmlc 45 | xgboost4j-spark_${scala.binary.version} 46 | provided 47 | 48 | 49 | 50 | org.apache.spark 51 | spark-core_${scala.binary.version} 52 | provided 53 | 54 | 55 | org.apache.spark 56 | spark-mllib_${scala.binary.version} 57 | provided 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/AliasExpressionTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import org.dmg.pmml.Expression; 22 | import org.dmg.pmml.FieldRef; 23 | import org.dmg.pmml.PMMLFunctions; 24 | import org.jpmml.converter.ExpressionUtil; 25 | import org.jpmml.model.ReflectionUtil; 26 | import org.junit.jupiter.api.Test; 27 | 28 | import static org.junit.jupiter.api.Assertions.assertTrue; 29 | 30 | public class AliasExpressionTest { 31 | 32 | @Test 33 | public void unwrap(){ 34 | FieldRef fieldRef = new FieldRef("x"); 35 | 36 | Expression expression = new AliasExpression("parent", new AliasExpression("child", fieldRef)); 37 | 38 | checkExpression(fieldRef, expression); 39 | 40 | expression = new AliasExpression("parent", ExpressionUtil.createApply(PMMLFunctions.ADD, new AliasExpression("left child", fieldRef), new AliasExpression("right child", fieldRef))); 41 | 42 | checkExpression(ExpressionUtil.createApply(PMMLFunctions.ADD, fieldRef, fieldRef), expression); 43 | } 44 | 45 | static 46 | private void checkExpression(Expression expected, Expression expression){ 47 | expression = AliasExpression.unwrap(expression); 48 | 49 | assertTrue(ReflectionUtil.equals(expected, expression)); 50 | } 51 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/PredictionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.PredictionModel; 24 | import org.apache.spark.ml.linalg.Vector; 25 | import org.apache.spark.ml.param.shared.HasFeaturesCol; 26 | import org.apache.spark.ml.param.shared.HasLabelCol; 27 | import org.apache.spark.ml.param.shared.HasPredictionCol; 28 | import org.jpmml.converter.Feature; 29 | import org.jpmml.converter.SchemaUtil; 30 | import org.jpmml.sparkml.model.HasPredictionModelOptions; 31 | 32 | abstract 33 | public class PredictionModelConverter & HasLabelCol & HasFeaturesCol & HasPredictionCol> extends ModelConverter implements HasPredictionModelOptions { 34 | 35 | public PredictionModelConverter(T model){ 36 | super(model); 37 | } 38 | 39 | @Override 40 | public List getFeatures(SparkMLEncoder encoder){ 41 | T model = getModel(); 42 | 43 | String featuresCol = model.getFeaturesCol(); 44 | 45 | List features = encoder.getFeatures(featuresCol); 46 | 47 | int numFeatures = model.numFeatures(); 48 | if(numFeatures != -1){ 49 | SchemaUtil.checkSize(numFeatures, features); 50 | } 51 | 52 | return features; 53 | } 54 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/MatrixUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.linalg.Matrix; 25 | 26 | public class MatrixUtil { 27 | 28 | private MatrixUtil(){ 29 | } 30 | 31 | public void checkColumns(int columns, Matrix matrix){ 32 | 33 | if(matrix.numCols() != columns){ 34 | throw new IllegalArgumentException("Expected " + columns + " column(s), got " + matrix.numCols() + " column(s)"); 35 | } 36 | } 37 | 38 | static 39 | public void checkRows(int rows, Matrix matrix){ 40 | 41 | if(matrix.numRows() != rows){ 42 | throw new IllegalArgumentException("Expected " + rows + " row(s), got " + matrix.numRows() + " row(s)"); 43 | } 44 | } 45 | 46 | static 47 | public List getRow(Matrix matrix, int row){ 48 | List result = new ArrayList<>(); 49 | 50 | for(int column = 0; column < matrix.numCols(); column++){ 51 | result.add(matrix.apply(row, column)); 52 | } 53 | 54 | return result; 55 | } 56 | 57 | static 58 | public List getColumn(Matrix matrix, int column){ 59 | List result = new ArrayList<>(); 60 | 61 | for(int row = 0; row < matrix.numRows(); row++){ 62 | result.add(matrix.apply(row, column)); 63 | } 64 | 65 | return result; 66 | } 67 | } -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/main/java/org/jpmml/sparkml/lightgbm/LightGBMClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 
6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */
package org.jpmml.sparkml.lightgbm;

import com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassificationModel;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sparkml.ProbabilisticClassificationModelConverter;

/**
 * Converts a SynapseML {@link LightGBMClassificationModel} into a PMML mining model.
 */
public class LightGBMClassificationModelConverter extends ProbabilisticClassificationModelConverter<LightGBMClassificationModel> {

    public LightGBMClassificationModelConverter(LightGBMClassificationModel model){
        super(model);
    }

    /**
     * Returns the superclass' class count, except that a count of one is reported as two
     * (presumably a binary model that exposes a single class — TODO confirm).
     */
    @Override
    public int getNumberOfClasses(){
        int numberOfClasses = super.getNumberOfClasses();

        if(numberOfClasses == 1){
            return 2;
        }

        return numberOfClasses;
    }

    /**
     * Encodes the booster via {@link BoosterUtil}, then clears the <code>Output</code> element
     * of the final segment, which is cast to a {@link RegressionModel}.
     */
    @Override
    public MiningModel encodeModel(Schema schema){
        MiningModel miningModel = BoosterUtil.encodeModel(this, schema);

        RegressionModel regressionModel = (RegressionModel)MiningModelUtil.getFinalModel(miningModel);
        regressionModel.setOutput(null);

        return miningModel;
    }
}
--------------------------------------------------------------------------------
/pmml-sparkml-lightgbm/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | 5 | 6 | org.jpmml 7 | jpmml-sparkml 8 | 3.2-SNAPSHOT 9 | 10 | 11 | org.jpmml 12 | pmml-sparkml-lightgbm 13 | jar 14 | 15 | JPMML Spark ML LightGBM converter 16 | JPMML Apache Spark ML LightGBM to PMML converter 17 | 18 | 19 | 20 | GNU Affero General Public License (AGPL) version 3.0 21 | http://www.gnu.org/licenses/agpl-3.0.html 22 | repo 23 | 24 | 25 | 26 | 27 | 28 | org.jpmml 29 | pmml-sparkml 30 | 31 | 32 | 33 | org.jpmml 34 | pmml-evaluator-testing 35 | provided 36 | 37 | 38 | 39 | org.jpmml 40 | pmml-lightgbm 41 | 42 | 43 | 44 | com.microsoft.azure 45 | synapseml-lightgbm_2.12 46 | provided 47 | 48 | 49 | 50 | 51 | org.scala-lang 52 | scala-library 53 | 2.12.20 54 | provided 55 | 56 | 57 | 58 | org.apache.spark 59 | spark-core_${scala.binary.version} 60 | provided 61 | 62 | 63 | org.apache.spark 64 | spark-mllib_${scala.binary.version} 65 | provided 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /pmml-sparkml-evaluator/src/main/java/org/jpmml/sparkml/evaluator/SparkMLFunctionRegistry.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 
15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.evaluator; 20 | 21 | import java.util.Collections; 22 | import java.util.Map; 23 | import java.util.Objects; 24 | import java.util.function.Predicate; 25 | 26 | import org.jpmml.evaluator.Function; 27 | import org.jpmml.evaluator.FunctionRegistry; 28 | 29 | /** 30 | * @see FunctionRegistry 31 | */ 32 | public class SparkMLFunctionRegistry { 33 | 34 | private SparkMLFunctionRegistry(){ 35 | } 36 | 37 | static 38 | public void publish(String name){ 39 | publish(key -> Objects.equals(name, key)); 40 | } 41 | 42 | static 43 | public void publishAll(){ 44 | publish(key -> true); 45 | } 46 | 47 | static 48 | private void publish(Predicate predicate){ 49 | (SparkMLFunctionRegistry.functions.entrySet()).stream() 50 | .filter(entry -> predicate.test(entry.getKey())) 51 | .forEach(entry -> FunctionRegistry.putFunction(entry.getKey(), entry.getValue())); 52 | 53 | (SparkMLFunctionRegistry.functionClazzes.entrySet()).stream() 54 | .filter(entry -> predicate.test(entry.getKey())) 55 | .forEach(entry -> FunctionRegistry.putFunction(entry.getKey(), entry.getValue())); 56 | } 57 | 58 | private static final Map functions = Collections.emptyMap(); 59 | private static final Map> functionClazzes = Collections.emptyMap(); 60 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/TokenizerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any 
later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */
package org.jpmml.sparkml.feature;

import java.util.Collections;
import java.util.List;

import org.apache.spark.ml.feature.Tokenizer;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMMLFunctions;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.sparkml.DocumentFeature;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;

/**
 * Converts a Spark ML {@link Tokenizer} into a {@link DocumentFeature} over the lowercased input string.
 */
public class TokenizerConverter extends FeatureConverter<Tokenizer> {

    public TokenizerConverter(Tokenizer transformer){
        super(transformer);
    }

    /**
     * Derives a lowercased copy of the single input feature and exposes it as a document
     * whose tokens are separated by the regex <code>\s+</code>.
     */
    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder encoder){
        Tokenizer transformer = getTransformer();

        Feature feature = encoder.getOnlyFeature(transformer.getInputCol());

        // Spark's Tokenizer lowercases its input before splitting, so the PMML pipeline does the same
        Apply apply = ExpressionUtil.createApply(PMMLFunctions.LOWERCASE, feature.ref());

        DerivedField derivedField = encoder.createDerivedField(FieldNameUtil.create(PMMLFunctions.LOWERCASE, feature), OpType.CATEGORICAL, DataType.STRING, apply);

        // NOTE(review): Spark's Tokenizer appears to split on a single whitespace character ("\\s"), which can yield
        // empty tokens for repeated whitespace; "\\s+" collapses them. Confirm the difference is intended
        return Collections.singletonList(new DocumentFeature(encoder, derivedField, "\\s+"));
    }
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/LogisticRegressionModelConverter.java:
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.classification.LogisticRegressionModel; 22 | import org.dmg.pmml.Model; 23 | import org.jpmml.converter.CategoricalLabel; 24 | import org.jpmml.converter.Schema; 25 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 26 | 27 | public class LogisticRegressionModelConverter extends ProbabilisticClassificationModelConverter implements HasRegressionTableOptions { 28 | 29 | public LogisticRegressionModelConverter(LogisticRegressionModel model){ 30 | super(model); 31 | } 32 | 33 | @Override 34 | public Model encodeModel(Schema schema){ 35 | LogisticRegressionModel model = getModel(); 36 | 37 | CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel(); 38 | 39 | if(categoricalLabel.size() == 2){ 40 | Model linearModel = LinearModelUtil.createBinaryLogisticClassification(this, model.coefficients(), model.intercept(), schema) 41 | .setOutput(null); 42 | 43 | return linearModel; 44 | } else 45 | 46 | if(categoricalLabel.size() > 2){ 47 | Model linearModel = LinearModelUtil.createSoftmaxClassification(this, 
model.coefficientMatrix(), model.interceptVector(), schema) 48 | .setOutput(null); 49 | 50 | return linearModel; 51 | } else 52 | 53 | { 54 | throw new IllegalArgumentException(); 55 | } 56 | } 57 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/NaiveBayesModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.classification.NaiveBayesModel; 22 | import org.dmg.pmml.Model; 23 | import org.jpmml.converter.Schema; 24 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 25 | 26 | public class NaiveBayesModelConverter extends ProbabilisticClassificationModelConverter implements HasRegressionTableOptions { 27 | 28 | public NaiveBayesModelConverter(NaiveBayesModel model){ 29 | super(model); 30 | } 31 | 32 | @Override 33 | public Model encodeModel(Schema schema){ 34 | NaiveBayesModel model = getModel(); 35 | 36 | String modelType = model.getModelType(); 37 | switch(modelType){ 38 | case "multinomial": 39 | break; 40 | default: 41 | throw new IllegalArgumentException("Model type " + modelType + " is not supported"); 42 | } 43 | 44 | if(model.isSet(model.thresholds())){ 45 | double[] thresholds = model.getThresholds(); 46 | 47 | for(int i = 0; i < thresholds.length; i++){ 48 | double threshold = thresholds[i]; 49 | 50 | if(threshold != 0d){ 51 | throw new IllegalArgumentException("Non-zero thresholds are not supported"); 52 | } 53 | } 54 | } 55 | 56 | Model linearModel = LinearModelUtil.createSoftmaxClassification(this, model.theta(), model.pi(), schema) 57 | .setOutput(null); 58 | 59 | return linearModel; 60 | } 61 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/RegexKey.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 
10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */
package org.jpmml.sparkml;

import java.util.Objects;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * A string predicate that tests whole-string matches against a regular expression.
 *
 * <p>
 * Two keys are equal when their patterns have the same source string and the same flags.
 * {@link Pattern} itself does not override <code>equals</code>, so this makes instances usable as map keys.
 * </p>
 */
public class RegexKey implements Predicate<String> {

    private Pattern pattern = null;


    /**
     * @throws NullPointerException If <code>pattern</code> is <code>null</code>.
     */
    public RegexKey(Pattern pattern){
        // Reject null eagerly; otherwise hashCode/equals/test would fail later with a less obvious NPE
        setPattern(Objects.requireNonNull(pattern, "pattern"));
    }

    /**
     * @return <code>true</code> if the entire string matches the pattern (see {@link Matcher#matches()}).
     */
    @Override
    public boolean test(String string){
        Pattern pattern = getPattern();

        Matcher matcher = pattern.matcher(string);

        return matcher.matches();
    }

    @Override
    public int hashCode(){
        return hashCode(getPattern());
    }

    @Override
    public boolean equals(Object object){

        if(this == object){
            return true;
        }

        if(object instanceof RegexKey){
            RegexKey that = (RegexKey)object;

            return equals(this.getPattern(), that.getPattern());
        }

        return false;
    }

    public Pattern getPattern(){
        return this.pattern;
    }

    private void setPattern(Pattern pattern){
        this.pattern = pattern;
    }

    static
    private int hashCode(Pattern pattern){
        return Objects.hash(pattern.pattern(), pattern.flags());
    }

    static
    private boolean equals(Pattern left, Pattern right){
        // Flags are primitive ints; compare them directly instead of boxing through Objects.equals
        return Objects.equals(left.pattern(), right.pattern()) && (left.flags() == right.flags());
    }
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/MultiFeatureConverter.java:
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.Transformer; 24 | import org.apache.spark.ml.param.shared.HasInputCol; 25 | import org.apache.spark.ml.param.shared.HasInputCols; 26 | import org.apache.spark.ml.param.shared.HasOutputCol; 27 | import org.apache.spark.ml.param.shared.HasOutputCols; 28 | 29 | abstract 30 | public class MultiFeatureConverter extends FeatureConverter { 31 | 32 | public MultiFeatureConverter(T transformer){ 33 | super(transformer); 34 | } 35 | 36 | protected String formatMultiName(int index, int length, SparkMLEncoder encoder){ 37 | T transformer = getTransformer(); 38 | 39 | if(transformer.isSet(transformer.outputCols())){ 40 | String[] outputCols = transformer.getOutputCols(); 41 | 42 | String outputCol = outputCols[index]; 43 | 44 | return encoder.mapOnlyFieldName(outputCol); 45 | } 46 | 47 | if(index != 0){ 48 | throw new IllegalArgumentException(); 49 | } 50 | 51 | List names = formatNames(length, encoder); 52 | 53 | return names.get(index); 54 | } 55 | 56 | @Override 57 | protected InOutMode getInputMode(){ 58 | T transformer 
= getTransformer(); 59 | 60 | return getInputMode(transformer); 61 | } 62 | 63 | @Override 64 | public InOutMode getOutputMode(){ 65 | return getInputMode(); 66 | } 67 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/RandomForestRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.linalg.Vector; 24 | import org.apache.spark.ml.regression.RandomForestRegressionModel; 25 | import org.dmg.pmml.MiningFunction; 26 | import org.dmg.pmml.mining.MiningModel; 27 | import org.dmg.pmml.mining.Segmentation; 28 | import org.dmg.pmml.tree.TreeModel; 29 | import org.jpmml.converter.ModelUtil; 30 | import org.jpmml.converter.Schema; 31 | import org.jpmml.converter.mining.MiningModelUtil; 32 | import org.jpmml.sparkml.RegressionModelConverter; 33 | 34 | public class RandomForestRegressionModelConverter extends RegressionModelConverter implements HasFeatureImportances, HasTreeOptions { 35 | 36 | public RandomForestRegressionModelConverter(RandomForestRegressionModel model){ 37 | super(model); 38 | } 39 | 40 | @Override 41 | public Vector getFeatureImportances(){ 42 | RandomForestRegressionModel model = getModel(); 43 | 44 | return model.featureImportances(); 45 | } 46 | 47 | @Override 48 | public MiningModel encodeModel(Schema schema){ 49 | List treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(this, schema); 50 | 51 | MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel())) 52 | .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels)); 53 | 54 | return miningModel; 55 | } 56 | } -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/main/java/org/jpmml/sparkml/xgboost/XGBoostClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License 
as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.xgboost; 20 | 21 | import ml.dmlc.xgboost4j.scala.Booster; 22 | import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel; 23 | import org.dmg.pmml.DataType; 24 | import org.dmg.pmml.mining.MiningModel; 25 | import org.dmg.pmml.regression.RegressionModel; 26 | import org.jpmml.converter.Schema; 27 | import org.jpmml.converter.mining.MiningModelUtil; 28 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 29 | 30 | public class XGBoostClassificationModelConverter extends ProbabilisticClassificationModelConverter { 31 | 32 | public XGBoostClassificationModelConverter(XGBoostClassificationModel model){ 33 | super(model); 34 | } 35 | 36 | @Override 37 | public int getNumberOfClasses(){ 38 | XGBoostClassificationModel model = getModel(); 39 | 40 | int numClass = model.getNumClass(); 41 | if(numClass != 0){ 42 | return numClass; 43 | } 44 | 45 | return super.getNumberOfClasses(); 46 | } 47 | 48 | @Override 49 | public DataType getDataType(){ 50 | return DataType.FLOAT; 51 | } 52 | 53 | @Override 54 | public MiningModel encodeModel(Schema schema){ 55 | XGBoostClassificationModel model = getModel(); 56 | 57 | Booster booster = model.nativeBooster(); 58 | 59 | MiningModel miningModel = BoosterUtil.encodeBooster(this, booster, schema); 60 | 61 | RegressionModel regressionModel = (RegressionModel)MiningModelUtil.getFinalModel(miningModel); 62 | regressionModel.setOutput(null); 63 | 64 
| return miningModel; 65 | } 66 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/RandomForestClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.classification.RandomForestClassificationModel; 24 | import org.apache.spark.ml.linalg.Vector; 25 | import org.dmg.pmml.MiningFunction; 26 | import org.dmg.pmml.mining.MiningModel; 27 | import org.dmg.pmml.mining.Segmentation; 28 | import org.dmg.pmml.tree.TreeModel; 29 | import org.jpmml.converter.ModelUtil; 30 | import org.jpmml.converter.Schema; 31 | import org.jpmml.converter.mining.MiningModelUtil; 32 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 33 | 34 | public class RandomForestClassificationModelConverter extends ProbabilisticClassificationModelConverter implements HasFeatureImportances, HasTreeOptions { 35 | 36 | public RandomForestClassificationModelConverter(RandomForestClassificationModel model){ 37 | super(model); 38 | } 39 | 40 | @Override 41 | public Vector getFeatureImportances(){ 42 | RandomForestClassificationModel model = getModel(); 43 | 44 | return model.featureImportances(); 45 | } 46 | 47 | @Override 48 | public MiningModel encodeModel(Schema schema){ 49 | List treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(this, schema); 50 | 51 | MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel())) 52 | .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels)); 53 | 54 | return miningModel; 55 | } 56 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/GBTRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the 
terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.List; 22 | 23 | import com.google.common.primitives.Doubles; 24 | import org.apache.spark.ml.linalg.Vector; 25 | import org.apache.spark.ml.regression.GBTRegressionModel; 26 | import org.dmg.pmml.MiningFunction; 27 | import org.dmg.pmml.mining.MiningModel; 28 | import org.dmg.pmml.mining.Segmentation; 29 | import org.dmg.pmml.tree.TreeModel; 30 | import org.jpmml.converter.ModelUtil; 31 | import org.jpmml.converter.Schema; 32 | import org.jpmml.converter.mining.MiningModelUtil; 33 | import org.jpmml.sparkml.RegressionModelConverter; 34 | 35 | /** 36 | * Converts a Spark ML gradient-boosted trees regression model to a PMML mining model. 37 | */ 38 | public class GBTRegressionModelConverter extends RegressionModelConverter implements HasFeatureImportances, HasTreeOptions { 39 | 40 | public GBTRegressionModelConverter(GBTRegressionModel model){ 41 | super(model); 42 | } 43 | 44 | @Override 45 | public Vector getFeatureImportances(){ 46 | GBTRegressionModel gbt = getModel(); 47 | 48 | return gbt.featureImportances(); 49 | } 50 | 51 | /** 52 | * Encodes the member decision trees, and combines them using the WEIGHTED_SUM multiple model method, with the Spark ML tree weights as segment weights (and the RETURN_MISSING missing prediction treatment). 53 | */ 54 | @Override 55 | public MiningModel encodeModel(Schema schema){ 56 | GBTRegressionModel gbt = getModel(); 57 | 58 | List trees = TreeModelUtil.encodeDecisionTreeEnsemble(this, schema); 59 | 60 | Segmentation segmentation = MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, Segmentation.MissingPredictionTreatment.RETURN_MISSING, trees, Doubles.asList(gbt.treeWeights())); 61 | 62 | return new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel())) 63 | .setSegmentation(segmentation); 64 | } 65 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/WeightedTermFeature.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see .
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Objects; 22 | 23 | import org.dmg.pmml.Apply; 24 | import org.dmg.pmml.DefineFunction; 25 | import org.jpmml.converter.ExpressionUtil; 26 | import org.jpmml.converter.Feature; 27 | import org.jpmml.converter.PMMLEncoder; 28 | import org.jpmml.model.ToStringHelper; 29 | 30 | /** 31 | * A term feature whose function invocation carries an additional constant weight argument. 32 | */ 33 | public class WeightedTermFeature extends TermFeature { 34 | 35 | private Number weight = null; 36 | 37 | 38 | public WeightedTermFeature(PMMLEncoder encoder, DefineFunction defineFunction, Feature feature, String value, Number weight){ 39 | super(encoder, defineFunction, feature, value); 40 | 41 | setWeight(weight); 42 | } 43 | 44 | /** 45 | * @return The function invocation of the parent class, with the weight appended as a trailing constant argument. 46 | */ 47 | @Override 48 | public Apply createApply(){ 49 | Apply apply = super.createApply(); 50 | 51 | return apply.addExpressions(ExpressionUtil.createConstant(getWeight())); 52 | } 53 | 54 | @Override 55 | public int hashCode(){ 56 | int result = super.hashCode(); 57 | 58 | result = (31 * result) + Objects.hashCode(getWeight()); 59 | 60 | return result; 61 | } 62 | 63 | @Override 64 | public boolean equals(Object object){ 65 | 66 | if(!(object instanceof WeightedTermFeature)){ 67 | return false; 68 | } 69 | 70 | WeightedTermFeature other = (WeightedTermFeature)object; 71 | 72 | return super.equals(object) && Objects.equals(getWeight(), other.getWeight()); 73 | } 74 | 75 | @Override 76 | protected ToStringHelper toStringHelper(){ 77 | return super.toStringHelper() 78 | .add("weight", getWeight()); 79 | } 80 | 81 | public Number getWeight(){ 82 | return this.weight; 83 | } 84 | 85 | /** 86 | * @throws NullPointerException If the weight is null. 87 | */ 88 | private void setWeight(Number weight){ 89 | this.weight = Objects.requireNonNull(weight); 90 | } 91 | } -------------------------------------------------------------------------------- /pmml-sparkml-xgboost/src/main/java/org/jpmml/sparkml/xgboost/XGBoostRegressionModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of
JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.xgboost; 20 | 21 | import java.util.List; 22 | 23 | import ml.dmlc.xgboost4j.scala.Booster; 24 | import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel; 25 | import org.dmg.pmml.DataType; 26 | import org.dmg.pmml.MiningFunction; 27 | import org.dmg.pmml.Model; 28 | import org.dmg.pmml.OutputField; 29 | import org.dmg.pmml.mining.MiningModel; 30 | import org.jpmml.converter.ContinuousLabel; 31 | import org.jpmml.converter.Label; 32 | import org.jpmml.converter.Schema; 33 | import org.jpmml.sparkml.PredictionModelConverter; 34 | import org.jpmml.sparkml.RegressionModelConverter; 35 | import org.jpmml.sparkml.SparkMLEncoder; 36 | 37 | /** 38 | * Converts an XGBoost4J-Spark regression model to a PMML mining model. 39 | * 40 | * Label and output field encoding is delegated to the static helper methods of RegressionModelConverter. 41 | */ 42 | public class XGBoostRegressionModelConverter extends PredictionModelConverter { 43 | 44 | public XGBoostRegressionModelConverter(XGBoostRegressionModel model){ 45 | super(model); 46 | } 47 | 48 | @Override 49 | public MiningFunction getMiningFunction(){ 50 | return MiningFunction.REGRESSION; 51 | } 52 | 53 | @Override 54 | public DataType getDataType(){ 55 | return DataType.FLOAT; 56 | } 57 | 58 | @Override 59 | public ContinuousLabel getLabel(SparkMLEncoder encoder){ 60 | return RegressionModelConverter.getLabel(this, encoder); 61 | } 62 | 63 | /** 64 | * Encodes the native booster of the XGBoost model. 65 | */ 66 | @Override 67 | public MiningModel encodeModel(Schema schema){ 68 | XGBoostRegressionModel xgbModel = getModel(); 69 | 70 | Booster nativeBooster = xgbModel.nativeBooster(); 71 | 72 | return BoosterUtil.encodeBooster(this, nativeBooster, schema); 73 | } 74 | 75 | @Override 76 | public List registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder){ 77 | return RegressionModelConverter.registerPredictionOutputField(this, label, pmmlModel, encoder); 78 | } 79 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/RegexTokenizerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see .
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Collections; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.RegexTokenizer; 25 | import org.dmg.pmml.Apply; 26 | import org.dmg.pmml.DataType; 27 | import org.dmg.pmml.Field; 28 | import org.dmg.pmml.OpType; 29 | import org.dmg.pmml.PMMLFunctions; 30 | import org.jpmml.converter.ExpressionUtil; 31 | import org.jpmml.converter.Feature; 32 | import org.jpmml.converter.FieldNameUtil; 33 | import org.jpmml.sparkml.DocumentFeature; 34 | import org.jpmml.sparkml.FeatureConverter; 35 | import org.jpmml.sparkml.SparkMLEncoder; 36 | 37 | /** 38 | * Converts a Spark ML regex tokenizer to a document feature. 39 | * 40 | * Only the splitter (gaps) mode with a minimum token length of 1 is supported. 41 | */ 42 | public class RegexTokenizerConverter extends FeatureConverter { 43 | 44 | public RegexTokenizerConverter(RegexTokenizer transformer){ 45 | super(transformer); 46 | } 47 | 48 | /** 49 | * @throws IllegalArgumentException If the tokenizer is in token matching mode, or if its minimum token length is not 1. 50 | */ 51 | @Override 52 | public List encodeFeatures(SparkMLEncoder encoder){ 53 | RegexTokenizer tokenizer = getTransformer(); 54 | 55 | boolean gaps = tokenizer.getGaps(); 56 | if(!gaps){ 57 | throw new IllegalArgumentException("Expected splitter mode, got token matching mode"); 58 | } // End if 59 | 60 | int minTokenLength = tokenizer.getMinTokenLength(); 61 | if(minTokenLength != 1){ 62 | throw new IllegalArgumentException("Expected 1 as minimum token length, got " + minTokenLength + " as minimum token length"); 63 | } 64 | 65 | Feature inputFeature = encoder.getOnlyFeature(tokenizer.getInputCol()); 66 | 67 | Field documentField = inputFeature.getField(); 68 | 69 | // Lowercasing is expressed as a new derived field, which the document feature then refers to 70 | if(tokenizer.getToLowercase()){ 71 | Apply lowercaseApply = ExpressionUtil.createApply(PMMLFunctions.LOWERCASE, inputFeature.ref()); 72 | 73 | documentField = encoder.createDerivedField(FieldNameUtil.create(PMMLFunctions.LOWERCASE, inputFeature), OpType.CATEGORICAL, DataType.STRING, lowercaseApply); 74 | } 75 | 76 | DocumentFeature documentFeature = new DocumentFeature(encoder, documentField, tokenizer.getPattern()); 77 | 78 | return Collections.singletonList(documentFeature); 79 | } 80 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/IDFModelConverter.java:
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.IDFModel; 25 | import org.apache.spark.ml.linalg.Vector; 26 | import org.jpmml.converter.ContinuousFeature; 27 | import org.jpmml.converter.Feature; 28 | import org.jpmml.converter.ProductFeature; 29 | import org.jpmml.converter.SchemaUtil; 30 | import org.jpmml.sparkml.FeatureConverter; 31 | import org.jpmml.sparkml.SparkMLEncoder; 32 | import org.jpmml.sparkml.TermFeature; 33 | import org.jpmml.sparkml.WeightedTermFeature; 34 | 35 | /** 36 | * Converts a Spark ML IDF model by scaling every input feature with the corresponding element of the IDF vector. 37 | */ 38 | public class IDFModelConverter extends FeatureConverter { 39 | 40 | public IDFModelConverter(IDFModel transformer){ 41 | super(transformer); 42 | } 43 | 44 | @Override 45 | public List encodeFeatures(SparkMLEncoder encoder){ 46 | IDFModel idfModel = getTransformer(); 47 | 48 | Vector idf = idfModel.idf(); 49 | 50 | List inputFeatures = encoder.getFeatures(idfModel.getInputCol()); 51 | 52 | SchemaUtil.checkSize(idf.size(), inputFeatures); 53 | 54 | List result = new ArrayList<>(); 55 | 56 | for(int index = 0; index < inputFeatures.size(); index++){ 57 | Feature inputFeature = inputFeatures.get(index); 58 | Double weight = idf.apply(index); 59 | 60 | result.add(createWeightedFeature(encoder, inputFeature, weight)); 61 | } 62 | 63 | return result; 64 | } 65 | 66 | /** 67 | * Wraps the feature into a product feature, whose continuous representation is lazily computed as a weighted term feature, and then cached. 68 | * NOTE(review): assumes that the wrapped feature is a TermFeature; otherwise the cast fails with a ClassCastException. 69 | */ 70 | static 71 | private ProductFeature createWeightedFeature(SparkMLEncoder encoder, Feature feature, Double weight){ 72 | return new ProductFeature(encoder, feature, weight){ 73 | 74 | private WeightedTermFeature weightedTermFeature = null; 75 | 76 | 77 | @Override 78 | public ContinuousFeature toContinuousFeature(){ 79 | 80 | if(this.weightedTermFeature == null){ 81 | TermFeature termFeature = (TermFeature)getFeature(); 82 | 83 | this.weightedTermFeature = termFeature.toWeightedTermFeature(getFactor()); 84 | } 85 | 86 | return this.weightedTermFeature.toContinuousFeature(); 87 | } 88 | }; 89 | } 90 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/testing/AssociationRulesTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see .
18 | */ 19 | package org.jpmml.sparkml.testing; 20 | 21 | import java.io.File; 22 | import java.util.List; 23 | import java.util.function.Predicate; 24 | 25 | import com.google.common.base.Equivalence; 26 | import com.google.common.collect.Iterables; 27 | import org.apache.spark.sql.Dataset; 28 | import org.apache.spark.sql.Row; 29 | import org.apache.spark.sql.SparkSession; 30 | import org.dmg.pmml.Model; 31 | import org.dmg.pmml.PMML; 32 | import org.dmg.pmml.association.AssociationModel; 33 | import org.jpmml.evaluator.ResultField; 34 | import org.junit.jupiter.api.Test; 35 | 36 | import static org.junit.jupiter.api.Assertions.assertInstanceOf; 37 | 38 | /** 39 | * Tests the conversion of FP-Growth association rules pipelines. 40 | */ 41 | public class AssociationRulesTest extends SimpleSparkMLEncoderBatchTest implements SparkMLAlgorithms, SparkMLDatasets { 42 | 43 | /** 44 | * Creates a batch whose column filter is additionally narrowed by excludePredictionFields(), and whose verification dataset loader returns null. 45 | */ 46 | @Override 47 | public SparkMLEncoderBatch createBatch(String algorithm, String dataset, Predicate columnFilter, Equivalence equivalence){ 48 | columnFilter = columnFilter.and(excludePredictionFields()); 49 | 50 | return new SparkMLEncoderBatch(algorithm, dataset, columnFilter, equivalence){ 51 | 52 | @Override 53 | public AssociationRulesTest getArchiveBatchTest(){ 54 | return AssociationRulesTest.this; 55 | } 56 | 57 | @Override 58 | protected Dataset loadVerificationDataset(SparkSession sparkSession, List tmpResources){ 59 | return null; 60 | } 61 | }; 62 | } 63 | 64 | @Test 65 | public void evaluateFPGrowthShopping() throws Exception { 66 | Predicate predicate = (resultField -> true); 67 | Equivalence equivalence = getEquivalence(); 68 | 69 | try(SparkMLEncoderBatch batch = createBatch(FP_GROWTH, SHOPPING, predicate, equivalence)){ 70 | PMML pmml = batch.getPMML(); 71 | 72 | Model onlyModel = Iterables.getOnlyElement(pmml.getModels()); 73 | 74 | assertInstanceOf(AssociationModel.class, onlyModel); 75 | } 76 | } 77 | } --------------------------------------------------------------------------------
/pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/RFormulaModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.List; 22 | 23 | import org.apache.spark.ml.PipelineModel; 24 | import org.apache.spark.ml.Transformer; 25 | import org.apache.spark.ml.feature.RFormulaModel; 26 | import org.apache.spark.ml.feature.ResolvedRFormula; 27 | import org.jpmml.converter.Feature; 28 | import org.jpmml.sparkml.ConverterFactory; 29 | import org.jpmml.sparkml.FeatureConverter; 30 | import org.jpmml.sparkml.SparkMLEncoder; 31 | import org.jpmml.sparkml.TransformerConverter; 32 | 33 | public class RFormulaModelConverter extends FeatureConverter { 34 | 35 | public RFormulaModelConverter(RFormulaModel transformer){ 36 | super(transformer); 37 | } 38 | 39 | @Override 40 | public void registerFeatures(SparkMLEncoder encoder){ 41 | RFormulaModel transformer = getTransformer(); 42 | 43 | ResolvedRFormula resolvedFormula = transformer.resolvedFormula(); 44 | 45 | String targetCol = resolvedFormula.label(); 46 | 47 | String labelCol = transformer.getLabelCol(); 48 | if(!(targetCol).equals(labelCol)){ 
49 | List features = encoder.getFeatures(targetCol); 50 | 51 | encoder.putFeatures(labelCol, features); 52 | } 53 | 54 | ConverterFactory converterFactory = encoder.getConverterFactory(); 55 | 56 | PipelineModel pipelineModel = transformer.pipelineModel(); 57 | 58 | Transformer[] stages = pipelineModel.stages(); 59 | for(Transformer stage : stages){ 60 | TransformerConverter converter = converterFactory.newConverter(stage); 61 | 62 | if(converter instanceof FeatureConverter){ 63 | FeatureConverter featureConverter = (FeatureConverter)converter; 64 | 65 | featureConverter.registerFeatures(encoder); 66 | } else 67 | 68 | { 69 | throw new IllegalArgumentException("Expected a subclass of " + FeatureConverter.class.getName() + ", got " + (converter != null ? ("class " + (converter.getClass()).getName()) : null)); 70 | } 71 | } 72 | } 73 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/BinarizedCategoricalFeature.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2018 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.List; 22 | import java.util.Objects; 23 | 24 | import org.jpmml.converter.BinaryFeature; 25 | import org.jpmml.converter.CategoricalFeature; 26 | import org.jpmml.converter.ContinuousFeature; 27 | import org.jpmml.converter.Feature; 28 | import org.jpmml.converter.PMMLEncoder; 29 | import org.jpmml.model.ToStringHelper; 30 | 31 | /** 32 | * A categorical feature that has been expanded into a non-empty list of binary features. 33 | * 34 | * The feature takes the name and the data type of the original categorical feature. 35 | */ 36 | public class BinarizedCategoricalFeature extends Feature { 37 | 38 | private List binaryFeatures = null; 39 | 40 | 41 | public BinarizedCategoricalFeature(PMMLEncoder encoder, CategoricalFeature categoricalFeature, List binaryFeatures){ 42 | super(encoder, categoricalFeature.getName(), categoricalFeature.getDataType()); 43 | 44 | setBinaryFeatures(binaryFeatures); 45 | } 46 | 47 | /** 48 | * @throws UnsupportedOperationException Always. 49 | */ 50 | @Override 51 | public ContinuousFeature toContinuousFeature(){ 52 | throw new UnsupportedOperationException(); 53 | } 54 | 55 | @Override 56 | public int hashCode(){ 57 | int result = super.hashCode(); 58 | 59 | return (31 * result) + Objects.hashCode(getBinaryFeatures()); 60 | } 61 | 62 | @Override 63 | public boolean equals(Object object){ 64 | 65 | if(!(object instanceof BinarizedCategoricalFeature)){ 66 | return false; 67 | } 68 | 69 | BinarizedCategoricalFeature other = (BinarizedCategoricalFeature)object; 70 | 71 | return super.equals(other) && Objects.equals(getBinaryFeatures(), other.getBinaryFeatures()); 72 | } 73 | 74 | @Override 75 | protected ToStringHelper toStringHelper(){ 76 | return super.toStringHelper() 77 | .add("binaryFeatures", getBinaryFeatures()); 78 | } 79 | 80 | public List getBinaryFeatures(){ 81 | return this.binaryFeatures; 82 | } 83 | 84 | /** 85 | * @throws NullPointerException If the list is null. 86 | * @throws IllegalArgumentException If the list is empty. 87 | */ 88 | private void setBinaryFeatures(List binaryFeatures){ 89 | List checkedFeatures = Objects.requireNonNull(binaryFeatures); 90 | 91 | if(checkedFeatures.isEmpty()){ 92 | throw new IllegalArgumentException(); 93 | } 94 | 95 | this.binaryFeatures = checkedFeatures; 96 | } 97 | } --------------------------------------------------------------------------------
/pmml-sparkml/src/main/java/org/jpmml/sparkml/model/KMeansModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.clustering.KMeansModel; 25 | import org.apache.spark.ml.linalg.Vector; 26 | import org.dmg.pmml.CompareFunction; 27 | import org.dmg.pmml.ComparisonMeasure; 28 | import org.dmg.pmml.MiningFunction; 29 | import org.dmg.pmml.SquaredEuclidean; 30 | import org.dmg.pmml.clustering.Cluster; 31 | import org.dmg.pmml.clustering.ClusteringModel; 32 | import org.jpmml.converter.ModelUtil; 33 | import org.jpmml.converter.PMMLUtil; 34 | import org.jpmml.converter.Schema; 35 | import org.jpmml.converter.clustering.ClusteringModelUtil; 36 | import org.jpmml.sparkml.ClusteringModelConverter; 37 | import org.jpmml.sparkml.VectorUtil; 38 | 39 | /** 40 | * Converts a Spark ML k-means model to a PMML center-based clustering model. 41 | */ 42 | public class KMeansModelConverter extends ClusteringModelConverter { 43 | 44 | public KMeansModelConverter(KMeansModel model){ 45 | super(model); 46 | } 47 | 48 | /** 49 | * Returns the number of clusters that the fitted model actually contains. 50 | * 51 | * The "k" param is only an upper bound, because Spark ML may return fewer than k clusters (e.g. when the training data contains fewer than k distinct points). 52 | * Counting the cluster centers keeps this value consistent with the clusters that are encoded by encodeModel(Schema). 53 | */ 54 | @Override 55 | public int getNumberOfClusters(){ 56 | KMeansModel model = getModel(); 57 | 58 | Vector[] clusterCenters = model.clusterCenters(); 59 | 60 | return clusterCenters.length; 61 | } 62 | 63 | /** 64 | * Encodes one cluster per cluster center (identified by its zero-based index), using the squared Euclidean distance with the absolute difference compare function as the comparison measure. 65 | * 66 | * @throws IllegalArgumentException If the model uses a distance measure other than Euclidean (e.g. cosine). 67 | */ 68 | @Override 69 | public ClusteringModel encodeModel(Schema schema){ 70 | KMeansModel model = getModel(); 71 | 72 | // The encoded comparison measure (squared Euclidean) reproduces cluster assignments only for the Euclidean distance measure 73 | String distanceMeasure = model.getDistanceMeasure(); 74 | if(!("euclidean").equals(distanceMeasure)){ 75 | throw new IllegalArgumentException("Expected euclidean distance measure, got " + distanceMeasure + " distance measure"); 76 | } 77 | 78 | List clusters = new ArrayList<>(); 79 | 80 | Vector[] clusterCenters = model.clusterCenters(); 81 | for(int i = 0; i < clusterCenters.length; i++){ 82 | Cluster cluster = new Cluster(PMMLUtil.createRealArray(VectorUtil.toList(clusterCenters[i]))) 83 | .setId(String.valueOf(i)); 84 | 85 | clusters.add(cluster); 86 | } 87 | 88 | ComparisonMeasure comparisonMeasure = new ComparisonMeasure(ComparisonMeasure.Kind.DISTANCE, new SquaredEuclidean()) 89 | .setCompareFunction(CompareFunction.ABS_DIFF); 90 | 91 | return new ClusteringModel(MiningFunction.CLUSTERING, ClusteringModel.ModelClass.CENTER_BASED, clusters.size(), ModelUtil.createMiningSchema(schema.getLabel()), comparisonMeasure, ClusteringModelUtil.createClusteringFields(schema.getFeatures()), clusters); 92 | } 93 | } -------------------------------------------------------------------------------- /pmml-sparkml/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | 5 | 6 | org.jpmml 7 | jpmml-sparkml 8 | 3.2-SNAPSHOT 9 | 10 | 11 | org.jpmml 12 | pmml-sparkml 13 | jar 14 | 15 | JPMML Spark ML converter 16 | JPMML Apache Spark ML to PMML converter 17 | 18 | 19 | 20 | GNU Affero General Public License (AGPL) version 3.0 21 | http://www.gnu.org/licenses/agpl-3.0.html 22 | repo 23 | 24 | 25 | 26 | 27 | 28 | org.jpmml 29 | pmml-converter 30 | 31 | 32 | org.jpmml 33 | pmml-converter-testing 34 | 35 | 36 | 37 | org.apache.spark 38 | spark-core_${scala.binary.version} 39 | provided 40 | 41 | 42 | org.apache.spark 43 | spark-mllib_${scala.binary.version} 44 | provided 45 | 46 | 47 | 48 | org.jpmml 49 | pmml-evaluator-testing 50 | provided 51 | 52 | 53 | 54 | org.apache.hadoop 55 | hadoop-client 56 | 57 | 58 | 59 | 60 | 61 | 62 | org.apache.maven.plugins 63 | maven-jar-plugin 64 | 65 | 66 | 67 | JPMML-SparkML library 68 |
${project.version} 69 | 70 | 71 | 72 | 73 | 74 | org.apache.maven.plugins 75 | maven-javadoc-plugin 76 | 77 | 78 | 79 | ${project.groupId} 80 | ${project.artifactId} 81 | ${project.version} 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/BinarizerConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.Arrays; 23 | import java.util.List; 24 | 25 | import org.apache.spark.ml.feature.Binarizer; 26 | import org.dmg.pmml.Apply; 27 | import org.dmg.pmml.DataType; 28 | import org.dmg.pmml.DerivedField; 29 | import org.dmg.pmml.OpType; 30 | import org.dmg.pmml.PMMLFunctions; 31 | import org.jpmml.converter.ContinuousFeature; 32 | import org.jpmml.converter.ExpressionUtil; 33 | import org.jpmml.converter.Feature; 34 | import org.jpmml.converter.IndexFeature; 35 | import org.jpmml.sparkml.MultiFeatureConverter; 36 | import org.jpmml.sparkml.SparkMLEncoder; 37 | 38 | /** 39 | * Converts a Spark ML binarizer to categorical index features with the values 0 and 1. 40 | */ 41 | public class BinarizerConverter extends MultiFeatureConverter { 42 | 43 | public BinarizerConverter(Binarizer transformer){ 44 | super(transformer); 45 | } 46 | 47 | /** 48 | * Encodes every input column as a derived field, which is 0 when the input value is less than or equal to the threshold, and 1 otherwise. 49 | * 50 | * In multi-column mode, Spark ML uses the per-column "thresholds" param when it is set, and falls back to the shared "threshold" param otherwise. 51 | * 52 | * @throws IllegalArgumentException If the number of thresholds does not match the number of input columns. 53 | */ 54 | @Override 55 | public List encodeFeatures(SparkMLEncoder encoder){ 56 | Binarizer transformer = getTransformer(); 57 | 58 | InOutMode inputMode = getInputMode(); 59 | 60 | String[] inputCols = inputMode.getInputCols(transformer); 61 | 62 | double[] thresholds = (transformer.isSet(transformer.thresholds()) ? transformer.getThresholds() : null); 63 | if(thresholds != null && thresholds.length != inputCols.length){ 64 | throw new IllegalArgumentException("Expected " + inputCols.length + " thresholds, got " + thresholds.length + " thresholds"); 65 | } 66 | 67 | List result = new ArrayList<>(); 68 | 69 | for(int i = 0, length = inputCols.length; i < length; i++){ 70 | String inputCol = inputCols[i]; 71 | 72 | Double threshold = (thresholds != null ? thresholds[i] : transformer.getThreshold()); 73 | 74 | Feature feature = encoder.getOnlyFeature(inputCol); 75 | 76 | ContinuousFeature continuousFeature = feature.toContinuousFeature(); 77 | 78 | Apply apply = new Apply(PMMLFunctions.IF) 79 | .addExpressions(ExpressionUtil.createApply(PMMLFunctions.LESSOREQUAL, continuousFeature.ref(), ExpressionUtil.createConstant(threshold))) 80 | .addExpressions(ExpressionUtil.createConstant(0d), ExpressionUtil.createConstant(1d)); 81 | 82 | DerivedField derivedField = encoder.createDerivedField(formatMultiName(i, length, encoder), OpType.CATEGORICAL, DataType.DOUBLE, apply); 83 | 84 | result.add(new IndexFeature(encoder, derivedField, Arrays.asList(0d, 1d))); 85 | } 86 | 87 | return result; 88 | } 89 | }
-------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/StopWordsRemoverConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | import java.util.regex.Pattern; 24 | 25 | import org.apache.spark.ml.feature.StopWordsRemover; 26 | import org.jpmml.converter.Feature; 27 | import org.jpmml.sparkml.DocumentFeature; 28 | import org.jpmml.sparkml.MultiFeatureConverter; 29 | import org.jpmml.sparkml.SparkMLEncoder; 30 | import org.jpmml.sparkml.TermUtil; 31 | 32 | /** 33 | * Converts a Spark ML stop words remover by adding a stop word set to the document feature of every input column. 34 | * 35 | * NOTE(review): the stop word set is added to the very feature object that is returned by encoder.getOnlyFeature(inputCol), so that feature is modified in place; confirm that this sharing is intended. 36 | */ 37 | public class StopWordsRemoverConverter extends MultiFeatureConverter { 38 | 39 | public StopWordsRemoverConverter(StopWordsRemover transformer){ 40 | super(transformer); 41 | } 42 | 43 | /** 44 | * @throws IllegalArgumentException If some single-token stop word contains punctuation. 45 | */ 46 | @Override 47 | public List encodeFeatures(SparkMLEncoder encoder){ 48 | StopWordsRemover remover = getTransformer(); 49 | 50 | boolean caseSensitive = remover.getCaseSensitive(); 51 | String[] stopWords = remover.getStopWords(); 52 | 53 | InOutMode inputMode = getInputMode(); 54 | 55 | List result = new ArrayList<>(); 56 | 57 | for(String inputCol : inputMode.getInputCols(remover)){ 58 | DocumentFeature documentFeature = (DocumentFeature)encoder.getOnlyFeature(inputCol); 59 | 60 | Pattern separatorPattern = Pattern.compile(documentFeature.getWordSeparatorRE()); 61 | 62 | DocumentFeature.StopWordSet stopWordSet = new DocumentFeature.StopWordSet(caseSensitive); 63 | 64 | for(String stopWord : stopWords){ 65 | // Skip multi-token stopwords. See https://issues.apache.org/jira/browse/SPARK-18374 66 | boolean multiToken = (separatorPattern.split(stopWord).length > 1); 67 | if(multiToken){ 68 | continue; 69 | } // End if 70 | 71 | if(TermUtil.hasPunctuation(stopWord)){ 72 | throw new IllegalArgumentException("Punctuated stop words (" + stopWord + ") are not supported"); 73 | } 74 | 75 | stopWordSet.add(stopWord); 76 | } 77 | 78 | documentFeature.addStopWordSet(stopWordSet); 79 | 80 | result.add(documentFeature); 81 | } 82 | 83 | return result; 84 | } 85 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/SparkMLEncoderTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see .
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Arrays; 22 | import java.util.Collections; 23 | import java.util.List; 24 | 25 | import org.apache.spark.sql.types.DataTypes; 26 | import org.apache.spark.sql.types.StructType; 27 | import org.dmg.pmml.DataField; 28 | import org.dmg.pmml.DataType; 29 | import org.dmg.pmml.OpType; 30 | import org.jpmml.converter.ContinuousFeature; 31 | import org.jpmml.converter.Feature; 32 | import org.jpmml.converter.FieldUtil; 33 | import org.jpmml.converter.ObjectFeature; 34 | import org.junit.jupiter.api.Test; 35 | 36 | import static org.junit.jupiter.api.Assertions.assertEquals; 37 | import static org.junit.jupiter.api.Assertions.assertInstanceOf; 38 | 39 | public class SparkMLEncoderTest { 40 | 41 | @Test 42 | public void toCategorical(){ 43 | StructType schema = new StructType() 44 | .add("x", DataTypes.IntegerType, false); 45 | 46 | ConverterFactory converterFactory = new ConverterFactory(Collections.emptyMap()); 47 | 48 | SparkMLEncoder encoder = new SparkMLEncoder(schema, converterFactory); 49 | 50 | Feature feature = encoder.getOnlyFeature("x"); 51 | 52 | assertInstanceOf(ContinuousFeature.class, feature); 53 | 54 | DataField dataField = checkField(feature, OpType.CONTINUOUS, DataType.INTEGER, Collections.emptyList()); 55 | 56 | encoder.toCategorical(feature, Arrays.asList(1, 2, 3)); 57 | 58 | // Clear feature cache 59 | encoder.removeFeatures("x"); 60 | 61 | feature = encoder.getOnlyFeature("x"); 62 | 63 | assertInstanceOf(ObjectFeature.class, feature); 64 | 65 | dataField = checkField(feature, OpType.CATEGORICAL, DataType.INTEGER, Arrays.asList(1, 2, 3)); 66 | 67 | encoder.toCategorical(feature, Arrays.asList("1.0", "2.0", "3.0")); 68 | } 69 | 70 | static 71 | private DataField checkField(Feature feature, OpType opType, DataType dataType, List values){ 72 | DataField field = (DataField)feature.getField(); 73 | 74 | assertEquals(opType, field.requireOpType()); 75 | assertEquals(dataType, 
field.requireDataType()); 76 | assertEquals(values, FieldUtil.getValues(field)); 77 | 78 | return field; 79 | } 80 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/LinearSVCModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import org.apache.spark.ml.classification.LinearSVCModel; 22 | import org.dmg.pmml.DataType; 23 | import org.dmg.pmml.Expression; 24 | import org.dmg.pmml.FieldRef; 25 | import org.dmg.pmml.Model; 26 | import org.dmg.pmml.OpType; 27 | import org.dmg.pmml.PMMLFunctions; 28 | import org.dmg.pmml.mining.MiningModel; 29 | import org.dmg.pmml.regression.RegressionModel; 30 | import org.jpmml.converter.ExpressionUtil; 31 | import org.jpmml.converter.FieldNameUtil; 32 | import org.jpmml.converter.ModelUtil; 33 | import org.jpmml.converter.Schema; 34 | import org.jpmml.converter.Transformation; 35 | import org.jpmml.converter.mining.MiningModelUtil; 36 | import org.jpmml.converter.transformations.AbstractTransformation; 37 | import org.jpmml.sparkml.ClassificationModelConverter; 38 | 39 | public class LinearSVCModelConverter extends ClassificationModelConverter implements HasRegressionTableOptions { 40 | 41 | public LinearSVCModelConverter(LinearSVCModel model){ 42 | super(model); 43 | } 44 | 45 | @Override 46 | public MiningModel encodeModel(Schema schema){ 47 | LinearSVCModel model = getModel(); 48 | 49 | Transformation transformation = new AbstractTransformation(){ 50 | 51 | @Override 52 | public String getName(String name){ 53 | return FieldNameUtil.create(PMMLFunctions.THRESHOLD, name); 54 | } 55 | 56 | @Override 57 | public Expression createExpression(FieldRef fieldRef){ 58 | return ExpressionUtil.createApply(PMMLFunctions.THRESHOLD) 59 | .addExpressions(fieldRef, ExpressionUtil.createConstant(model.getThreshold())); 60 | } 61 | }; 62 | 63 | Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE); 64 | 65 | Model linearModel = LinearModelUtil.createRegression(this, model.coefficients(), model.intercept(), segmentSchema) 66 | .setOutput(ModelUtil.createPredictedOutput("margin", OpType.CONTINUOUS, DataType.DOUBLE, transformation)); 67 | 68 | return 
MiningModelUtil.createBinaryLogisticClassification(linearModel, 1d, 0d, RegressionModel.NormalizationMethod.NONE, false, schema); 69 | } 70 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/testing/SparkMLEncoderBatchTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.testing; 20 | 21 | import java.util.function.Predicate; 22 | 23 | import com.google.common.base.Equivalence; 24 | import org.apache.spark.sql.SparkSession; 25 | import org.jpmml.converter.FieldNameUtil; 26 | import org.jpmml.converter.testing.ModelEncoderBatchTest; 27 | import org.jpmml.evaluator.ResultField; 28 | import org.jpmml.evaluator.testing.PMMLEquivalence; 29 | import org.jpmml.sparkml.SparkSessionUtil; 30 | 31 | public class SparkMLEncoderBatchTest extends ModelEncoderBatchTest { 32 | 33 | public SparkMLEncoderBatchTest(){ 34 | this(new PMMLEquivalence(1e-14, 1e-14)); 35 | } 36 | 37 | public SparkMLEncoderBatchTest(Equivalence equivalence){ 38 | super(equivalence); 39 | } 40 | 41 | /** 42 | * @see #createSparkSession() 43 | * @see #destroySparkSession() 44 | */ 45 | public SparkSession getSparkSession(){ 46 | return SparkMLEncoderBatchTest.sparkSession; 47 | } 48 | 49 | @Override 50 | public SparkMLEncoderBatch createBatch(String algorithm, String dataset, Predicate columnFilter, Equivalence equivalence){ 51 | SparkMLEncoderBatch result = new SparkMLEncoderBatch(algorithm, dataset, columnFilter, equivalence){ 52 | 53 | @Override 54 | public SparkMLEncoderBatchTest getArchiveBatchTest(){ 55 | return SparkMLEncoderBatchTest.this; 56 | } 57 | }; 58 | 59 | return result; 60 | } 61 | 62 | static 63 | public void createSparkSession(){ 64 | SparkMLEncoderBatchTest.sparkSession = SparkSessionUtil.createSparkSession(); 65 | } 66 | 67 | static 68 | public void destroySparkSession(){ 69 | SparkMLEncoderBatchTest.sparkSession = SparkSessionUtil.destroySparkSession(SparkMLEncoderBatchTest.sparkSession); 70 | } 71 | 72 | static 73 | public Predicate excludePredictionFields(){ 74 | return excludePredictionFields("prediction"); 75 | } 76 | 77 | static 78 | public Predicate excludePredictionFields(String predictionCol){ 79 | return excludeFields(predictionCol, FieldNameUtil.create("pmml", predictionCol)); 80 | } 81 | 82 | private 
static SparkSession sparkSession = null; 83 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/DomainUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */
package org.jpmml.sparkml.feature;

import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import scala.jdk.javaapi.CollectionConverters;

/**
 * Conversions between list-valued and array-valued maps, and between Java and Scala maps.
 *
 * The list/array map conversions preserve the iteration order of the source map,
 * and carry {@code null} values over as {@code null} values.
 */
public class DomainUtil {

	private DomainUtil(){
	}

	/**
	 * @param values The elements, or {@code null}.
	 * @param clazz The component type of the result array.
	 *
	 * @return A new array holding the elements, or {@code null} if {@code values} is {@code null}.
	 */
	static
	public <E> E[] toArray(List<E> values, Class<E> clazz){

		if(values == null){
			return null;
		}

		@SuppressWarnings("unchecked")
		E[] result = (E[])Array.newInstance(clazz, values.size());

		return values.toArray(result);
	}

	/**
	 * @return An insertion-ordered copy of the map, with every list value converted to an array of component type {@code clazz}.
	 */
	static
	public <K, E> Map<K, E[]> toArrayMap(Map<K, List<E>> map, Class<E> clazz){
		Collection<Map.Entry<K, List<E>>> entries = map.entrySet();

		// An explicit loop instead of Collectors#toMap, which throws NullPointerException on null values
		// and would therefore defeat the null handling of #toArray(List, Class)
		Map<K, E[]> result = new LinkedHashMap<>();

		for(Map.Entry<K, List<E>> entry : entries){
			result.put(entry.getKey(), toArray(entry.getValue(), clazz));
		}

		return result;
	}

	static
	public <K> Map<K, Object[]> toObjectArrayMap(Map<K, List<Object>> map){
		return toArrayMap(map, Object.class);
	}

	static
	public <K> Map<K, Number[]> toNumberArrayMap(Map<K, List<Number>> map){
		return toArrayMap(map, Number.class);
	}

	/**
	 * @return An insertion-ordered copy of the map, with every array value converted to a fixed-size list view.
	 */
	static
	public <K, E> Map<K, List<E>> toListMap(Map<K, E[]> map){
		Collection<Map.Entry<K, E[]>> entries = map.entrySet();

		// An explicit loop for the same reason as in #toArrayMap; additionally, Arrays#asList rejects a null array
		Map<K, List<E>> result = new LinkedHashMap<>();

		for(Map.Entry<K, E[]> entry : entries){
			E[] values = entry.getValue();

			result.put(entry.getKey(), (values != null) ? Arrays.asList(values) : null);
		}

		return result;
	}

	/**
	 * @return A Java view of the Scala immutable map.
	 */
	static
	public <K, V> Map<K, V> toJavaMap(scala.collection.immutable.Map<K, V> scalaMap){
		Map<K, V> javaMap = (Map<K, V>)CollectionConverters.asJava(scalaMap);

		return javaMap;
	}

	/**
	 * @return A Scala immutable map holding the entries of the Java map.
	 */
	static
	public <K, V> scala.collection.immutable.Map<K, V> toScalaMap(Map<K, V> javaMap){
		scala.collection.mutable.Map<K, V> scalaMap = CollectionConverters.asScala(javaMap);

		return scala.collection.immutable.Map$.MODULE$.from(scalaMap);
	}
}
-------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/MaxAbsScalerModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.MaxAbsScalerModel; 25 | import org.apache.spark.ml.linalg.Vector; 26 | import org.dmg.pmml.DataType; 27 | import org.dmg.pmml.DerivedField; 28 | import org.dmg.pmml.Expression; 29 | import org.dmg.pmml.OpType; 30 | import org.dmg.pmml.PMMLFunctions; 31 | import org.jpmml.converter.ContinuousFeature; 32 | import org.jpmml.converter.ExpressionUtil; 33 | import org.jpmml.converter.Feature; 34 | import org.jpmml.converter.SchemaUtil; 35 | import org.jpmml.converter.ValueUtil; 36 | import org.jpmml.sparkml.FeatureConverter; 37 | import org.jpmml.sparkml.SparkMLEncoder; 38 | 39 | public class MaxAbsScalerModelConverter extends FeatureConverter { 40 | 41 | public MaxAbsScalerModelConverter(MaxAbsScalerModel transformer){ 42 | super(transformer); 43 | } 44 | 45 | @Override 46 | public List encodeFeatures(SparkMLEncoder encoder){ 47 | MaxAbsScalerModel transformer = getTransformer(); 48 | 49 | Vector maxAbs = transformer.maxAbs(); 50 | 51 | List features = encoder.getFeatures(transformer.getInputCol()); 52 | 53 | SchemaUtil.checkSize(maxAbs.size(), features); 54 | 55 | List names = formatNames(features.size(), encoder); 56 | 57 | List result = new ArrayList<>(); 58 | 59 | for(int i = 0, length = features.size(); i < length; i++){ 60 | Feature feature = features.get(i); 61 | 62 | double maxAbsUnzero = maxAbs.apply(i); 63 | if(maxAbsUnzero == 0d){ 64 | maxAbsUnzero = 1d; 65 | } // End if 66 | 67 | if(!ValueUtil.isOne(maxAbsUnzero)){ 68 | ContinuousFeature continuousFeature = feature.toContinuousFeature(); 69 | 70 | Expression expression = ExpressionUtil.createApply(PMMLFunctions.DIVIDE, continuousFeature.ref(), ExpressionUtil.createConstant(maxAbsUnzero)); 71 | 72 | DerivedField derivedField = encoder.createDerivedField(names.get(i), OpType.CONTINUOUS, DataType.DOUBLE, expression); 73 | 74 | feature = new 
ContinuousFeature(encoder, derivedField); 75 | } 76 | 77 | result.add(feature); 78 | } 79 | 80 | return result; 81 | } 82 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/PCAModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.PCAModel; 25 | import org.apache.spark.ml.linalg.DenseMatrix; 26 | import org.dmg.pmml.Apply; 27 | import org.dmg.pmml.DataType; 28 | import org.dmg.pmml.DerivedField; 29 | import org.dmg.pmml.Expression; 30 | import org.dmg.pmml.OpType; 31 | import org.dmg.pmml.PMMLFunctions; 32 | import org.jpmml.converter.ContinuousFeature; 33 | import org.jpmml.converter.ExpressionUtil; 34 | import org.jpmml.converter.Feature; 35 | import org.jpmml.converter.ValueUtil; 36 | import org.jpmml.sparkml.FeatureConverter; 37 | import org.jpmml.sparkml.MatrixUtil; 38 | import org.jpmml.sparkml.SparkMLEncoder; 39 | 40 | public class PCAModelConverter extends FeatureConverter { 41 | 42 | public PCAModelConverter(PCAModel transformer){ 43 | super(transformer); 44 | } 45 | 46 | @Override 47 | public List encodeFeatures(SparkMLEncoder encoder){ 48 | PCAModel transformer = getTransformer(); 49 | 50 | DenseMatrix pc = transformer.pc(); 51 | 52 | List features = encoder.getFeatures(transformer.getInputCol()); 53 | 54 | MatrixUtil.checkRows(features.size(), pc); 55 | 56 | List names = formatNames(features.size(), encoder); 57 | 58 | List result = new ArrayList<>(); 59 | 60 | for(int i = 0, length = transformer.getK(); i < length; i++){ 61 | Apply apply = ExpressionUtil.createApply(PMMLFunctions.SUM); 62 | 63 | for(int j = 0; j < features.size(); j++){ 64 | Feature feature = features.get(j); 65 | 66 | ContinuousFeature continuousFeature = feature.toContinuousFeature(); 67 | 68 | Expression expression = continuousFeature.ref(); 69 | 70 | Double coefficient = pc.apply(j, i); 71 | if(!ValueUtil.isOne(coefficient)){ 72 | expression = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, expression, ExpressionUtil.createConstant(coefficient)); 73 | } 74 | 75 | apply.addExpressions(expression); 76 | } 77 | 78 | DerivedField derivedField = 
encoder.createDerivedField(names.get(i), OpType.CONTINUOUS, DataType.DOUBLE, apply); 79 | 80 | result.add(new ContinuousFeature(encoder, derivedField)); 81 | } 82 | 83 | return result; 84 | } 85 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/ProbabilisticClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | import java.util.function.IntFunction; 24 | 25 | import org.apache.spark.ml.classification./*Probabilistic*/ClassificationModel; 26 | import org.apache.spark.ml.linalg.Vector; 27 | import org.apache.spark.ml.param.shared.HasProbabilityCol; 28 | import org.dmg.pmml.Model; 29 | import org.dmg.pmml.OutputField; 30 | import org.jpmml.converter.CategoricalLabel; 31 | import org.jpmml.converter.ContinuousFeature; 32 | import org.jpmml.converter.Feature; 33 | import org.jpmml.converter.FieldNameUtil; 34 | import org.jpmml.converter.Label; 35 | import org.jpmml.converter.ModelUtil; 36 | 37 | abstract 38 | public class ProbabilisticClassificationModelConverter & HasProbabilityCol> extends ClassificationModelConverter { 39 | 40 | public ProbabilisticClassificationModelConverter(T model){ 41 | super(model); 42 | } 43 | 44 | @Override 45 | public List registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder){ 46 | T model = getModel(); 47 | 48 | List result = super.registerOutputFields(label, pmmlModel, encoder); 49 | 50 | CategoricalLabel categoricalLabel = (CategoricalLabel)label; 51 | 52 | String probabilityCol = model.getProbabilityCol(); 53 | 54 | IntFunction formatter = new IntFunction<>(){ 55 | 56 | @Override 57 | public String apply(int index){ 58 | Object value = categoricalLabel.getValue(index); 59 | 60 | return FieldNameUtil.create(probabilityCol, value); 61 | } 62 | }; 63 | 64 | List names = encoder.mapFieldNames(probabilityCol, categoricalLabel.size(), formatter); 65 | 66 | result = new ArrayList<>(result); 67 | 68 | List features = new ArrayList<>(); 69 | 70 | for(int i = 0; i < categoricalLabel.size(); i++){ 71 | Object value = categoricalLabel.getValue(i); 72 | 73 | OutputField probabilityField = ModelUtil.createProbabilityField(names.get(i), getDataType(), value); 74 | 75 | result.add(probabilityField); 76 | 77 | features.add(new 
ContinuousFeature(encoder, probabilityField)); 78 | } 79 | 80 | // XXX 81 | encoder.putFeatures(probabilityCol, features); 82 | 83 | return result; 84 | } 85 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/model/GBTClassificationModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.model; 20 | 21 | import java.util.List; 22 | 23 | import com.google.common.primitives.Doubles; 24 | import org.apache.spark.ml.classification.GBTClassificationModel; 25 | import org.apache.spark.ml.linalg.Vector; 26 | import org.dmg.pmml.DataType; 27 | import org.dmg.pmml.MiningFunction; 28 | import org.dmg.pmml.OpType; 29 | import org.dmg.pmml.mining.MiningModel; 30 | import org.dmg.pmml.mining.Segmentation; 31 | import org.dmg.pmml.regression.RegressionModel; 32 | import org.dmg.pmml.tree.TreeModel; 33 | import org.jpmml.converter.ModelUtil; 34 | import org.jpmml.converter.Schema; 35 | import org.jpmml.converter.mining.MiningModelUtil; 36 | import org.jpmml.sparkml.ProbabilisticClassificationModelConverter; 37 | 38 | public class GBTClassificationModelConverter extends ProbabilisticClassificationModelConverter implements HasFeatureImportances, HasTreeOptions { 39 | 40 | public GBTClassificationModelConverter(GBTClassificationModel model){ 41 | super(model); 42 | } 43 | 44 | @Override 45 | public Vector getFeatureImportances(){ 46 | GBTClassificationModel model = getModel(); 47 | 48 | return model.featureImportances(); 49 | } 50 | 51 | @Override 52 | public MiningModel encodeModel(Schema schema){ 53 | GBTClassificationModel model = getModel(); 54 | 55 | String lossType = model.getLossType(); 56 | switch(lossType){ 57 | case "logistic": 58 | break; 59 | default: 60 | throw new IllegalArgumentException("Loss function " + lossType + " is not supported"); 61 | } 62 | 63 | Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE); 64 | 65 | List treeModels = TreeModelUtil.encodeDecisionTreeEnsemble(this, segmentSchema); 66 | 67 | MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(segmentSchema.getLabel())) 68 | .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.WEIGHTED_SUM, Segmentation.MissingPredictionTreatment.RETURN_MISSING, 
treeModels, Doubles.asList(model.treeWeights()))) 69 | .setOutput(ModelUtil.createPredictedOutput("gbtValue", OpType.CONTINUOUS, DataType.DOUBLE)); 70 | 71 | return MiningModelUtil.createBinaryLogisticClassification(miningModel, 2d, 0d, RegressionModel.NormalizationMethod.LOGIT, false, schema); 72 | } 73 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/feature/VectorDisassemblerTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Arrays; 22 | import java.util.List; 23 | import java.util.stream.Collectors; 24 | 25 | import org.apache.spark.ml.Pipeline; 26 | import org.apache.spark.ml.PipelineModel; 27 | import org.apache.spark.ml.PipelineStage; 28 | import org.apache.spark.ml.Transformer; 29 | import org.apache.spark.ml.linalg.DenseVector; 30 | import org.apache.spark.ml.linalg.SparseVector; 31 | import org.apache.spark.ml.linalg.VectorUDT; 32 | import org.apache.spark.sql.Dataset; 33 | import org.apache.spark.sql.Row; 34 | import org.apache.spark.sql.RowFactory; 35 | import org.apache.spark.sql.types.StructType; 36 | import org.jpmml.sparkml.SparkMLTest; 37 | import org.junit.jupiter.api.Test; 38 | 39 | import static org.junit.jupiter.api.Assertions.assertEquals; 40 | 41 | public class VectorDisassemblerTest extends SparkMLTest { 42 | 43 | @Test 44 | public void transform(){ 45 | StructType schema = new StructType() 46 | .add("featureVec", new VectorUDT(), false); 47 | 48 | List rows = Arrays.asList( 49 | RowFactory.create(new SparseVector(3, new int[]{1}, new double[]{1.0})), 50 | RowFactory.create(new DenseVector(new double[]{0.0d, 0.0d, 1.0d})), 51 | RowFactory.create(new SparseVector(3, new int[]{0}, new double[]{1.0})) 52 | ); 53 | 54 | Dataset ds = SparkMLTest.sparkSession.createDataFrame(rows, schema); 55 | 56 | Transformer transformer = new VectorDisassembler() 57 | .setInputCol("featureVec") 58 | .setOutputCols(new String[]{"first", "second", "third"}); 59 | 60 | Pipeline pipeline = new Pipeline() 61 | .setStages(new PipelineStage[]{transformer}); 62 | 63 | PipelineModel pipelineModel = pipeline.fit(ds); 64 | 65 | Dataset transformedDs = pipelineModel.transform(ds); 66 | 67 | checkColumn(Arrays.asList(null, 0.0d, 1.0d), transformedDs.select("first")); 68 | checkColumn(Arrays.asList(1.0d, 0.0d, null), transformedDs.select("second")); 69 | checkColumn(Arrays.asList(null, 1.0d, null), 
transformedDs.select("third")); 70 | } 71 | 72 | static 73 | private void checkColumn(List expectedValues, Dataset dataset){ 74 | List rows = dataset.collectAsList(); 75 | 76 | List actualValues = rows.stream() 77 | .map(row -> (Double)row.get(0)) 78 | .collect(Collectors.toList()); 79 | 80 | assertEquals(expectedValues, actualValues); 81 | } 82 | } -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/main/java/org/jpmml/sparkml/lightgbm/BoosterUtil.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.lightgbm; 20 | 21 | import java.io.IOException; 22 | import java.io.StringReader; 23 | import java.util.LinkedHashMap; 24 | import java.util.List; 25 | import java.util.Map; 26 | 27 | import com.google.common.io.CharStreams; 28 | import com.microsoft.azure.synapse.ml.lightgbm.LightGBMModelMethods; 29 | import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster; 30 | import org.apache.spark.ml.Model; 31 | import org.apache.spark.ml.param.shared.HasPredictionCol; 32 | import org.dmg.pmml.mining.MiningModel; 33 | import org.jpmml.converter.Schema; 34 | import org.jpmml.lightgbm.GBDT; 35 | import org.jpmml.lightgbm.HasLightGBMOptions; 36 | import org.jpmml.lightgbm.LightGBMUtil; 37 | import org.jpmml.sparkml.ModelConverter; 38 | import scala.Option; 39 | 40 | public class BoosterUtil { 41 | 42 | private BoosterUtil(){ 43 | } 44 | 45 | static 46 | public , M extends Model & HasPredictionCol & LightGBMModelMethods> MiningModel encodeModel(C converter, Schema schema){ 47 | M model = converter.getModel(); 48 | 49 | GBDT gbdt = BoosterUtil.getGBDT(model); 50 | 51 | Integer bestIteration = model.getBoosterBestIteration(); 52 | if(bestIteration < 0){ 53 | bestIteration = null; 54 | } 55 | 56 | Map options = new LinkedHashMap<>(); 57 | options.put(HasLightGBMOptions.OPTION_COMPACT, converter.getOption(HasLightGBMOptions.OPTION_COMPACT, Boolean.TRUE)); 58 | options.put(HasLightGBMOptions.OPTION_NUM_ITERATION, converter.getOption(HasLightGBMOptions.OPTION_NUM_ITERATION, bestIteration)); 59 | 60 | Schema lgbmSchema = gbdt.toLightGBMSchema(schema); 61 | 62 | MiningModel miningModel = gbdt.encodeModel(options, lgbmSchema); 63 | 64 | return miningModel; 65 | } 66 | 67 | static 68 | private & LightGBMModelMethods> GBDT getGBDT(M model){ 69 | LightGBMBooster booster = model.getLightGBMBooster(); 70 | 71 | Option modelStr = booster.modelStr(); 72 | if(modelStr.isEmpty()){ 73 | throw new IllegalArgumentException(); 74 | } 75 | 76 | 
String string = modelStr.get(); 77 | 78 | try(StringReader reader = new StringReader(string)){ 79 | List lines = CharStreams.readLines(reader); 80 | 81 | return LightGBMUtil.loadGBDT(lines.iterator()); 82 | } catch(IOException ioe){ 83 | throw new RuntimeException(ioe); 84 | } 85 | } 86 | } -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/resources/main.scala: -------------------------------------------------------------------------------- 1 | import com.microsoft.azure.synapse.ml.lightgbm.{LightGBMClassifier, LightGBMRegressor} 2 | import org.apache.spark.ml.Pipeline 3 | import org.jpmml.sparkml.DatasetUtil 4 | 5 | class LightGBMTest extends SparkMLTest { 6 | 7 | override 8 | def build_classification_pipeline(label_col: String, cat_cols: Array[String], cont_cols: Array[String], cat_encoding: CategoryEncoding): Pipeline = { 9 | val labelIndexer = new StringIndexer() 10 | .setInputCol(label_col) 11 | .setOutputCol("idx_" + label_col) 12 | 13 | val features = build_features(cat_cols, cont_cols, cat_encoding, withDomain = true) 14 | 15 | var classifier = new LightGBMClassifier() 16 | .setLabelCol(labelIndexer.getOutputCol) 17 | .setFeaturesCol(features.last.asInstanceOf[HasOutputCol].getOutputCol) 18 | 19 | label_col match { 20 | case "Adjusted" => 21 | classifier = classifier 22 | .setObjective("binary") 23 | .setNumIterations(101) 24 | case "Species" => 25 | classifier = classifier 26 | .setObjective("multiclass") 27 | .setNumIterations(17) 28 | case _ => 29 | throw new IllegalArgumentException() 30 | } 31 | 32 | new Pipeline() 33 | .setStages(labelIndexer +: features :+ classifier) 34 | } 35 | 36 | override 37 | def build_regression_pipeline(label_col: String, cat_cols: Array[String], cont_cols: Array[String], cat_encoding: CategoryEncoding): Pipeline = { 38 | val features = build_features(cat_cols, cont_cols, cat_encoding, withDomain = true) 39 | 40 | val regressor = new LightGBMRegressor() 41 | 
.setLabelCol(label_col) 42 | .setFeaturesCol(features.last.asInstanceOf[HasOutputCol].getOutputCol) 43 | .setNumIterations(101) 44 | 45 | new Pipeline() 46 | .setStages(features :+ regressor) 47 | } 48 | 49 | def run_audit(): Unit = { 50 | val label_col = "Adjusted" 51 | val cat_cols = Array("Education", "Employment", "Gender", "Marital", "Occupation") 52 | val cont_cols = Array("Age", "Hours", "Income") 53 | 54 | var df = load_audit("Audit") 55 | 56 | run_classification(df, label_col, cat_cols, cont_cols, CategoryEncoding.LEGACY_DIRECT_MIXED, "LightGBM", "Audit") 57 | 58 | df = load_audit("AuditNA") 59 | 60 | run_classification(df, label_col, cat_cols, cont_cols, CategoryEncoding.MODERN_DIRECT, "LightGBM", "AuditNA") 61 | } 62 | 63 | def run_auto(): Unit = { 64 | val label_col = "mpg" 65 | val cat_cols = Array("cylinders", "model_year", "origin") 66 | val cont_cols = Array("acceleration", "displacement", "horsepower", "weight") 67 | 68 | var df = load_auto("Auto") 69 | 70 | run_regression(df, label_col, cat_cols, cont_cols, CategoryEncoding.LEGACY_DIRECT_MIXED, "LightGBM", "Auto") 71 | 72 | df = load_auto("AutoNA") 73 | 74 | run_regression(df, label_col, cat_cols, cont_cols, CategoryEncoding.MODERN_DIRECT, "LightGBM", "AutoNA") 75 | } 76 | 77 | def run_iris(): Unit = { 78 | val label_col = "Species" 79 | val cat_cols = Array[String]() 80 | val cont_cols = Array("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width") 81 | 82 | val df = load_iris("Iris") 83 | 84 | run_classification(df, label_col, cat_cols, cont_cols, null, "LightGBM", "Iris") 85 | } 86 | } 87 | 88 | val test = new LightGBMTest() 89 | test.run_audit() 90 | test.run_auto() 91 | test.run_iris() 92 | -------------------------------------------------------------------------------- /pmml-sparkml/src/test/resources/main.scala: -------------------------------------------------------------------------------- 1 | import org.apache.spark.ml.Pipeline 2 | import 
org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{StandardScaler, SQLTransformer}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.types.DataTypes
import org.jpmml.sparkml.DatasetUtil
import org.jpmml.sparkml.feature.VectorDisassembler

// Builds and runs test pipelines over vector-valued (libsvm-format) Housing and Iris datasets.
// NOTE(review): SparkMLTest, DataFrame, StringIndexer, CategoryEncoding and col are not imported here;
// presumably they are provided by a shared script that is loaded first - confirm.
class LibSVMTest extends SparkMLTest {

	override
	def load_housing(name: String): DataFrame = {
		val df = spark.read
			.format("libsvm")
			.load("libsvm/Housing.libsvm")

		df
	}

	override
	def load_iris(name: String): DataFrame = {
		val df = spark.read
			.format("libsvm")
			.load("libsvm/Iris.libsvm")

		// Spark's libsvm source loads the label column as double; cast it to integer
		df.withColumn("label", col("label").cast(DataTypes.IntegerType))
	}

	override
	def build_regression_pipeline(label_col: String, cat_cols: Array[String], cont_cols: Array[String], cat_encoding: CategoryEncoding): Pipeline = {
		val stdScaler = new StandardScaler()
			.setInputCol("features")
			.setOutputCol("scaledFeatures")
			.setWithMean(true)
			.setWithStd(true)

		val regressor = new LinearRegression()
			.setLabelCol("label")
			.setFeaturesCol(stdScaler.getOutputCol)

		new Pipeline()
			.setStages(Array(stdScaler, regressor))
	}

	override
	def build_classification_pipeline(label_col: String, cat_cols: Array[String], cont_cols: Array[String], cat_encoding: CategoryEncoding): Pipeline = {
		val labelIndexer = new StringIndexer()
			.setInputCol(label_col)
			.setOutputCol("idx_" + label_col)

		val stdScaler = new StandardScaler()
			.setInputCol("features")
			.setOutputCol("scaledFeatures")
			.setWithMean(true)
			.setWithStd(true)

		val classifier = new LogisticRegression()
			.setLabelCol(labelIndexer.getOutputCol)
			.setFeaturesCol(stdScaler.getOutputCol)

		// Splits the probability vector into one scalar column per class
		val vecDisassembler = new VectorDisassembler()
			.setInputCol(classifier.getProbabilityCol)
			.setOutputCols(Array("probSetosa", "probVersicolor", "probVirginica"))

		// Derives a decision column from the per-class probabilities
		val sqlTransformer = new SQLTransformer()
			.setStatement("""
				SELECT
					prediction,
					probability,
					probSetosa, probVersicolor, probVirginica,
					CASE
						WHEN probSetosa >= 0.5 THEN 'Setosa'
						WHEN probVersicolor >= 0.5 THEN 'Versicolor'
						WHEN probVirginica >= 0.5 THEN 'Virginica'
						ELSE '(mixed)'
					END AS SpeciesDecision
				FROM __THIS__
			""")

		new Pipeline()
			.setStages(Array(labelIndexer, stdScaler, classifier, vecDisassembler, sqlTransformer))
	}

	def run_housing(): Unit = {
		val label_col = "label"

		val df = load_housing("Housing")

		run_regression(df, label_col, null, null, null, "LinearRegression", "HousingVec")
	}

	def run_iris(): Unit = {
		val label_col = "label"

		val df = load_iris("Iris")

		run_classification(df, label_col, null, null, null, "LogisticRegression", "IrisVec")
	}
}

val test = new LibSVMTest()
test.run_housing()
test.run_iris()
--------------------------------------------------------------------------------
/pmml-sparkml/src/main/java/org/jpmml/sparkml/ClusteringModelConverter.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2017 Villu Ruusmann
 *
 * This file is part of JPMML-SparkML
 *
 * JPMML-SparkML is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * JPMML-SparkML is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Collections; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.Model; 25 | import org.apache.spark.ml.param.shared.HasFeaturesCol; 26 | import org.apache.spark.ml.param.shared.HasPredictionCol; 27 | import org.dmg.pmml.DataType; 28 | import org.dmg.pmml.FieldRef; 29 | import org.dmg.pmml.MiningFunction; 30 | import org.dmg.pmml.OpType; 31 | import org.dmg.pmml.OutputField; 32 | import org.dmg.pmml.ResultFeature; 33 | import org.jpmml.converter.DerivedOutputField; 34 | import org.jpmml.converter.Feature; 35 | import org.jpmml.converter.FieldNameUtil; 36 | import org.jpmml.converter.IndexFeature; 37 | import org.jpmml.converter.Label; 38 | import org.jpmml.converter.LabelUtil; 39 | import org.jpmml.converter.ModelUtil; 40 | 41 | abstract 42 | public class ClusteringModelConverter & HasFeaturesCol & HasPredictionCol> extends ModelConverter { 43 | 44 | public ClusteringModelConverter(T model){ 45 | super(model); 46 | } 47 | 48 | abstract 49 | public int getNumberOfClusters(); 50 | 51 | @Override 52 | public MiningFunction getMiningFunction(){ 53 | return MiningFunction.CLUSTERING; 54 | } 55 | 56 | @Override 57 | public List getFeatures(SparkMLEncoder encoder){ 58 | T model = getModel(); 59 | 60 | String featuresCol = model.getFeaturesCol(); 61 | 62 | return encoder.getFeatures(featuresCol); 63 | } 64 | 65 | @Override 66 | public List registerOutputFields(Label label, org.dmg.pmml.Model pmmlModel, SparkMLEncoder encoder){ 67 | T model = getModel(); 68 | 69 | List clusters = LabelUtil.createTargetCategories(getNumberOfClusters()); 70 | 71 | String predictionCol = model.getPredictionCol(); 72 | 73 | OutputField pmmlPredictedOutputField = ModelUtil.createPredictedField(FieldNameUtil.create("pmml", predictionCol), OpType.CATEGORICAL, DataType.STRING) 74 | 
.setFinalResult(false); 75 | 76 | DerivedOutputField pmmlPredictedField = encoder.createDerivedField(pmmlModel, pmmlPredictedOutputField, true); 77 | 78 | OutputField predictedOutputField = new OutputField(encoder.mapOnlyFieldName(predictionCol), OpType.CATEGORICAL, DataType.INTEGER) 79 | .setResultFeature(ResultFeature.TRANSFORMED_VALUE) 80 | .setExpression(new FieldRef(pmmlPredictedField)); 81 | 82 | DerivedOutputField predictedField = encoder.createDerivedField(pmmlModel, predictedOutputField, true); 83 | 84 | encoder.putOnlyFeature(predictionCol, new IndexFeature(encoder, predictedField, clusters)); 85 | 86 | return Collections.emptyList(); 87 | } 88 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/scala/org/jpmml/sparkml/feature/SparseToDenseTransformer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature 20 | 21 | import org.apache.spark.ml.Transformer 22 | import org.apache.spark.ml.linalg.{DenseVector, Vector} 23 | import org.apache.spark.ml.param.ParamMap 24 | import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} 25 | import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} 26 | import org.apache.spark.sql.{Dataset, Row} 27 | import org.apache.spark.sql.functions.udf 28 | import org.apache.spark.sql.types.{StructField, StructType} 29 | 30 | class SparseToDenseTransformer(override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { 31 | private 32 | val sparseToDenseUDF = udf(SparseToDenseTransformer.sparseToDense _) 33 | 34 | /** 35 | * @group setParam 36 | */ 37 | def setInputCol(value: String): this.type = set(inputCol, value) 38 | 39 | /** 40 | * @group setParam 41 | */ 42 | def setOutputCol(value: String): this.type = set(outputCol, value) 43 | 44 | 45 | def this() = this(Identifiable.randomUID("sparse2dense")) 46 | 47 | override 48 | def copy(extra: ParamMap): SparseToDenseTransformer = defaultCopy(extra) 49 | 50 | protected 51 | def validateParams(): Unit = { 52 | require(isDefined(inputCol) && isDefined(outputCol), "inputCol and outputCol must be defined") 53 | } 54 | 55 | override 56 | def transformSchema(schema: StructType): StructType = { 57 | validateParams() 58 | 59 | val inputColName = getInputCol 60 | val outputColName = getOutputCol 61 | 62 | val inputFields = schema.fields 63 | 64 | require(inputFields.exists(_.name == inputColName), s"Input column $inputColName not found") 65 | require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists") 66 | 67 | val inputField = schema(inputColName) 68 | val outputField = new StructField(outputColName, inputField.dataType, inputField.nullable) 69 | 70 | StructType(inputFields :+ outputField) 71 | } 72 | 73 | override 74 | def 
transform(dataset: Dataset[_]): Dataset[Row] = { 75 | val inputColName = getInputCol 76 | val outputColName = getOutputCol 77 | 78 | dataset.withColumn(outputColName, sparseToDenseUDF(dataset(inputColName))) 79 | } 80 | } 81 | 82 | object SparseToDenseTransformer extends DefaultParamsReadable[SparseToDenseTransformer] { 83 | 84 | def sparseToDense(vec: Vector): DenseVector = { 85 | if(vec != null){ 86 | vec match { 87 | case denseVec: DenseVector => denseVec 88 | case _ => vec.toDense 89 | } 90 | } else 91 | 92 | { 93 | null 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/AliasExpression.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml; 20 | 21 | import java.util.Objects; 22 | 23 | import org.dmg.pmml.Expression; 24 | import org.dmg.pmml.HasExpression; 25 | import org.dmg.pmml.PMMLObject; 26 | import org.dmg.pmml.Visitor; 27 | import org.dmg.pmml.VisitorAction; 28 | import org.jpmml.model.visitors.ExpressionFilterer; 29 | 30 | public class AliasExpression extends Expression implements HasExpression { 31 | 32 | private String name = null; 33 | 34 | private Expression expression = null; 35 | 36 | 37 | public AliasExpression(String name, Expression expression){ 38 | setName(name); 39 | setExpression(expression); 40 | } 41 | 42 | public String getName(){ 43 | return this.name; 44 | } 45 | 46 | public AliasExpression setName(String name){ 47 | this.name = Objects.requireNonNull(name); 48 | 49 | return this; 50 | } 51 | 52 | @Override 53 | public Expression requireExpression(){ 54 | 55 | if(this.expression == null){ 56 | throw new IllegalStateException(); 57 | } 58 | 59 | return this.expression; 60 | } 61 | 62 | @Override 63 | public Expression getExpression(){ 64 | return this.expression; 65 | } 66 | 67 | @Override 68 | public AliasExpression setExpression(Expression expression){ 69 | this.expression = Objects.requireNonNull(expression); 70 | 71 | return this; 72 | } 73 | 74 | @Override 75 | public VisitorAction accept(Visitor visitor){ 76 | VisitorAction status = visitor.visit(this); 77 | 78 | if(status == VisitorAction.CONTINUE){ 79 | visitor.pushParent(this); 80 | 81 | status = PMMLObject.traverse(visitor, getExpression()); 82 | 83 | visitor.popParent(); 84 | } // End if 85 | 86 | if(status == VisitorAction.TERMINATE){ 87 | return VisitorAction.TERMINATE; 88 | } 89 | 90 | return VisitorAction.CONTINUE; 91 | } 92 | 93 | static 94 | public Expression unwrap(Expression expression){ 95 | expression = unwrapInternal(expression); 96 | 97 | ExpressionFilterer filterer = new ExpressionFilterer(){ 98 | 99 | @Override 100 | public Expression filter(Expression 
expression){ 101 | return unwrapInternal(expression); 102 | } 103 | }; 104 | filterer.applyTo(expression); 105 | 106 | return expression; 107 | } 108 | 109 | static 110 | private Expression unwrapInternal(Expression expression){ 111 | 112 | while(expression instanceof AliasExpression){ 113 | AliasExpression aliasExpression = (AliasExpression)expression; 114 | 115 | expression = aliasExpression.getExpression(); 116 | } 117 | 118 | return expression; 119 | } 120 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/test/java/org/jpmml/sparkml/feature/DomainTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.io.File; 22 | import java.io.IOException; 23 | import java.lang.reflect.Method; 24 | import java.util.List; 25 | import java.util.Map; 26 | import java.util.Set; 27 | import java.util.stream.Collectors; 28 | 29 | import com.google.common.io.MoreFiles; 30 | import com.google.common.io.RecursiveDeleteOption; 31 | import org.apache.spark.ml.PipelineStage; 32 | import org.apache.spark.ml.util.MLReader; 33 | import org.apache.spark.ml.util.MLWritable; 34 | import org.apache.spark.ml.util.MLWriter; 35 | import org.apache.spark.sql.Dataset; 36 | import org.apache.spark.sql.Row; 37 | import org.jpmml.sparkml.SparkMLTest; 38 | 39 | import static org.junit.jupiter.api.Assertions.assertEquals; 40 | 41 | abstract 42 | public class DomainTest extends SparkMLTest { 43 | 44 | static 45 | protected void checkDataset(Map> expectedColumns, Dataset actualDs){ 46 | Set keys = expectedColumns.keySet(); 47 | 48 | for(String key : keys){ 49 | List expectedColumn = expectedColumns.get(key); 50 | 51 | List actualColumnRows = actualDs 52 | .select(key) 53 | .collectAsList(); 54 | 55 | List actualColumn = actualColumnRows.stream() 56 | .map(row -> row.get(0)) 57 | .collect(Collectors.toList()); 58 | 59 | assertEquals(expectedColumn, actualColumn); 60 | } 61 | } 62 | 63 | static 64 | protected S sparkClone(S stage) throws IOException { 65 | File tmpDir = createTempDir(stage); 66 | 67 | try { 68 | String path = tmpDir.getAbsolutePath(); 69 | 70 | MLWriter writer = stage.write(); 71 | 72 | writer 73 | .overwrite() 74 | .save(path); 75 | 76 | Class stageClazz = stage.getClass(); 77 | 78 | // The read method of the companion object 79 | Method readMethod = stageClazz.getDeclaredMethod("read"); 80 | 81 | @SuppressWarnings("unchecked") 82 | MLReader reader = (MLReader)readMethod.invoke(null); 83 | 84 | return reader.load(path); 85 | } catch(ReflectiveOperationException roe){ 86 | throw new RuntimeException(roe); 87 | } finally { 
88 | MoreFiles.deleteRecursively(tmpDir.toPath(), RecursiveDeleteOption.ALLOW_INSECURE); 89 | } 90 | } 91 | 92 | static 93 | private File createTempDir(PipelineStage stage) throws IOException { 94 | File tmpFile = File.createTempFile("jpmml-sparkml-" + stage.uid(), ""); 95 | 96 | if(!tmpFile.delete() || !tmpFile.mkdirs()){ 97 | throw new IOException(); 98 | } 99 | 100 | return tmpFile; 101 | } 102 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/InteractionConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.Arrays; 23 | import java.util.List; 24 | 25 | import org.apache.spark.ml.feature.Interaction; 26 | import org.dmg.pmml.DataType; 27 | import org.jpmml.converter.CategoricalFeature; 28 | import org.jpmml.converter.Feature; 29 | import org.jpmml.converter.FieldNameUtil; 30 | import org.jpmml.converter.InteractionFeature; 31 | import org.jpmml.sparkml.FeatureConverter; 32 | import org.jpmml.sparkml.SparkMLEncoder; 33 | 34 | public class InteractionConverter extends FeatureConverter { 35 | 36 | public InteractionConverter(Interaction transformer){ 37 | super(transformer); 38 | } 39 | 40 | @Override 41 | public List encodeFeatures(SparkMLEncoder encoder){ 42 | Interaction transformer = getTransformer(); 43 | 44 | StringBuilder sb = new StringBuilder(); 45 | 46 | List result = new ArrayList<>(); 47 | 48 | String[] inputCols = transformer.getInputCols(); 49 | for(int i = 0; i < inputCols.length; i++){ 50 | String inputCol = inputCols[i]; 51 | 52 | List features = encoder.getFeatures(inputCol); 53 | 54 | if(features.size() == 1){ 55 | Feature feature = features.get(0); 56 | 57 | categorical: 58 | if(feature instanceof CategoricalFeature){ 59 | CategoricalFeature categoricalFeature = (CategoricalFeature)feature; 60 | 61 | String name = categoricalFeature.getName(); 62 | 63 | DataType dataType = categoricalFeature.getDataType(); 64 | switch(dataType){ 65 | case INTEGER: 66 | break; 67 | case FLOAT: 68 | case DOUBLE: 69 | break categorical; 70 | default: 71 | break; 72 | } 73 | 74 | // XXX 75 | inputCol = name; 76 | 77 | features = (List)OneHotEncoderModelConverter.encodeFeature(encoder, categoricalFeature, categoricalFeature.getValues()); 78 | } 79 | } // End if 80 | 81 | if(i == 0){ 82 | sb.append(inputCol); 83 | 84 | result = features; 85 | } else 86 | 87 | { 88 | sb.append(':').append(inputCol); 89 | 90 | List interactionFeatures = new ArrayList<>(); 91 | 92 | int 
index = 0; 93 | 94 | for(Feature left : result){ 95 | 96 | for(Feature right : features){ 97 | interactionFeatures.add(new InteractionFeature(encoder, FieldNameUtil.select(sb.toString(), index), DataType.DOUBLE, Arrays.asList(left, right))); 98 | 99 | index++; 100 | } 101 | } 102 | 103 | result = interactionFeatures; 104 | } 105 | } 106 | 107 | return result; 108 | } 109 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/ContinuousDomainModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.Collections; 22 | import java.util.List; 23 | import java.util.Map; 24 | 25 | import org.dmg.pmml.DataField; 26 | import org.dmg.pmml.Interval; 27 | import org.dmg.pmml.OutlierTreatmentMethod; 28 | import org.jpmml.converter.ContinuousFeature; 29 | import org.jpmml.converter.Feature; 30 | import org.jpmml.converter.OutlierDecorator; 31 | import org.jpmml.sparkml.SparkMLEncoder; 32 | 33 | public class ContinuousDomainModelConverter extends DomainModelConverter { 34 | 35 | public ContinuousDomainModelConverter(ContinuousDomainModel transformer){ 36 | super(transformer); 37 | } 38 | 39 | @Override 40 | public List encodeFeatures(SparkMLEncoder encoder){ 41 | ContinuousDomainModel transformer = getTransformer(); 42 | 43 | boolean withData = transformer.getWithData(); 44 | 45 | Map dataRanges; 46 | 47 | if(withData){ 48 | dataRanges = DomainUtil.toJavaMap(transformer.getDataRanges()); 49 | } else 50 | 51 | { 52 | dataRanges = Collections.emptyMap(); 53 | } 54 | 55 | OutlierTreatmentMethod outlierTreatment = parseOutlierTreatment(transformer.getOutlierTreatment()); 56 | Number lowValue; 57 | Number highValue; 58 | 59 | switch(outlierTreatment){ 60 | case AS_MISSING_VALUES: 61 | case AS_EXTREME_VALUES: 62 | lowValue = transformer.getLowValue(); 63 | highValue = transformer.getHighValue(); 64 | break; 65 | default: 66 | lowValue = null; 67 | highValue = null; 68 | break; 69 | } 70 | 71 | DomainManager domainManager = new DomainManager(){ 72 | 73 | @Override 74 | public DataField toDataField(Feature feature){ 75 | Number[] range = dataRanges.get(feature.getName()); 76 | 77 | DataField dataField = (DataField)encoder.toContinuous(feature); 78 | 79 | if(range != null){ 80 | Interval interval = new Interval(Interval.Closure.CLOSED_CLOSED, range[0], range[1]); 81 | 82 | dataField.addIntervals(interval); 83 | } 84 | 85 | encoder.addDecorator(dataField, new OutlierDecorator(outlierTreatment, lowValue, 
highValue)); 86 | 87 | return dataField; 88 | } 89 | 90 | @Override 91 | public ContinuousFeature toFeature(DataField dataField){ 92 | return new ContinuousFeature(encoder, dataField); 93 | } 94 | }; 95 | 96 | return super.encodeFeatures(domainManager, encoder); 97 | } 98 | 99 | static 100 | private OutlierTreatmentMethod parseOutlierTreatment(String outlierTreatment){ 101 | return OutlierTreatmentMethod.fromValue(outlierTreatment); 102 | } 103 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/DomainModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2025 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.Arrays; 23 | import java.util.List; 24 | 25 | import org.dmg.pmml.DataField; 26 | import org.dmg.pmml.InvalidValueTreatmentMethod; 27 | import org.dmg.pmml.MissingValueTreatmentMethod; 28 | import org.dmg.pmml.Value; 29 | import org.jpmml.converter.Feature; 30 | import org.jpmml.converter.FieldUtil; 31 | import org.jpmml.converter.InvalidValueDecorator; 32 | import org.jpmml.converter.MissingValueDecorator; 33 | import org.jpmml.sparkml.FeatureConverter; 34 | import org.jpmml.sparkml.SparkMLEncoder; 35 | 36 | abstract 37 | public class DomainModelConverter> extends FeatureConverter { 38 | 39 | public DomainModelConverter(T transformer){ 40 | super(transformer); 41 | } 42 | 43 | protected List encodeFeatures(DomainManager domainManager, SparkMLEncoder encoder){ 44 | T transformer = getTransformer(); 45 | 46 | Object[] missingValues = transformer.getMissingValues(); 47 | 48 | MissingValueTreatmentMethod missingValueTreatment = parseMissingValueTreatment(transformer.getMissingValueTreatment()); 49 | Object missingValueReplacement = transformer.getMissingValueReplacement(); 50 | InvalidValueTreatmentMethod invalidValueTreatment = parseInvalidValueTreatment(transformer.getInvalidValueTreatment()); 51 | Object invalidValueReplacement = transformer.getInvalidValueReplacement(); 52 | 53 | List result = new ArrayList<>(); 54 | 55 | String[] inputCols = transformer.getInputCols(); 56 | for(String inputCol : inputCols){ 57 | Feature feature = encoder.getOnlyFeature(inputCol); 58 | 59 | DataField dataField = domainManager.toDataField(feature); 60 | 61 | FieldUtil.addValues(dataField, Value.Property.MISSING, Arrays.asList(missingValues)); 62 | 63 | encoder.addDecorator(dataField, new MissingValueDecorator(missingValueTreatment, missingValueReplacement)); 64 | encoder.addDecorator(dataField, new InvalidValueDecorator(invalidValueTreatment, invalidValueReplacement)); 65 | 
66 | feature = domainManager.toFeature(dataField); 67 | 68 | result.add(feature); 69 | } 70 | 71 | return result; 72 | } 73 | 74 | static 75 | protected MissingValueTreatmentMethod parseMissingValueTreatment(String missingValueTreatment){ 76 | return MissingValueTreatmentMethod.fromValue(missingValueTreatment); 77 | } 78 | 79 | static 80 | protected InvalidValueTreatmentMethod parseInvalidValueTreatment(String invalidValueTreatment){ 81 | return InvalidValueTreatmentMethod.fromValue(invalidValueTreatment); 82 | } 83 | 84 | static 85 | protected interface DomainManager { 86 | 87 | DataField toDataField(Feature feature); 88 | 89 | Feature toFeature(DataField dataField); 90 | } 91 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/feature/MinMaxScalerModelConverter.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2016 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.feature; 20 | 21 | import java.util.ArrayList; 22 | import java.util.List; 23 | 24 | import org.apache.spark.ml.feature.MinMaxScalerModel; 25 | import org.apache.spark.ml.linalg.Vector; 26 | import org.dmg.pmml.DataType; 27 | import org.dmg.pmml.DerivedField; 28 | import org.dmg.pmml.Expression; 29 | import org.dmg.pmml.OpType; 30 | import org.dmg.pmml.PMMLFunctions; 31 | import org.jpmml.converter.ContinuousFeature; 32 | import org.jpmml.converter.ExpressionUtil; 33 | import org.jpmml.converter.Feature; 34 | import org.jpmml.converter.SchemaUtil; 35 | import org.jpmml.converter.ValueUtil; 36 | import org.jpmml.sparkml.FeatureConverter; 37 | import org.jpmml.sparkml.SparkMLEncoder; 38 | 39 | public class MinMaxScalerModelConverter extends FeatureConverter { 40 | 41 | public MinMaxScalerModelConverter(MinMaxScalerModel transformer){ 42 | super(transformer); 43 | } 44 | 45 | @Override 46 | public List encodeFeatures(SparkMLEncoder encoder){ 47 | MinMaxScalerModel transformer = getTransformer(); 48 | 49 | double rescaleFactor = (transformer.getMax() - transformer.getMin()); 50 | double rescaleConstant = transformer.getMin(); 51 | 52 | Vector originalMin = transformer.originalMin(); 53 | Vector originalMax = transformer.originalMax(); 54 | 55 | List features = encoder.getFeatures(transformer.getInputCol()); 56 | 57 | SchemaUtil.checkSize(Math.max(originalMin.size(), originalMax.size()), features); 58 | 59 | List names = formatNames(features.size(), encoder); 60 | 61 | List result = new ArrayList<>(); 62 | 63 | for(int i = 0, length = features.size(); i < length; i++){ 64 | Feature feature = features.get(i); 65 | 66 | ContinuousFeature continuousFeature = feature.toContinuousFeature(); 67 | 68 | double min = originalMin.apply(i); 69 | double max = originalMax.apply(i); 70 | 71 | Expression expression = ExpressionUtil.createApply(PMMLFunctions.DIVIDE, ExpressionUtil.createApply(PMMLFunctions.SUBTRACT, continuousFeature.ref(), 
ExpressionUtil.createConstant(min)), ExpressionUtil.createConstant(max - min)); 72 | 73 | if(!ValueUtil.isOne(rescaleFactor)){ 74 | expression = ExpressionUtil.createApply(PMMLFunctions.MULTIPLY, expression, ExpressionUtil.createConstant(rescaleFactor)); 75 | } // End if 76 | 77 | if(!ValueUtil.isZero(rescaleConstant)){ 78 | expression = ExpressionUtil.createApply(PMMLFunctions.ADD, expression, ExpressionUtil.createConstant(rescaleConstant)); 79 | } 80 | 81 | DerivedField derivedField = encoder.createDerivedField(names.get(i), OpType.CONTINUOUS, DataType.DOUBLE, expression); 82 | 83 | result.add(new ContinuousFeature(encoder, derivedField)); 84 | } 85 | 86 | return result; 87 | } 88 | } -------------------------------------------------------------------------------- /pmml-sparkml-lightgbm/src/test/java/org/jpmml/sparkml/lightgbm/testing/LightGBMTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */ 19 | package org.jpmml.sparkml.lightgbm.testing; 20 | 21 | import java.util.LinkedHashMap; 22 | import java.util.List; 23 | import java.util.Map; 24 | import java.util.function.Predicate; 25 | 26 | import com.google.common.base.Equivalence; 27 | import org.jpmml.converter.testing.Datasets; 28 | import org.jpmml.converter.testing.OptionsUtil; 29 | import org.jpmml.evaluator.ResultField; 30 | import org.jpmml.evaluator.testing.PMMLEquivalence; 31 | import org.jpmml.lightgbm.HasLightGBMOptions; 32 | import org.jpmml.sparkml.testing.SparkMLEncoderBatch; 33 | import org.jpmml.sparkml.testing.SparkMLEncoderBatchTest; 34 | import org.junit.jupiter.api.AfterAll; 35 | import org.junit.jupiter.api.BeforeAll; 36 | import org.junit.jupiter.api.Disabled; 37 | import org.junit.jupiter.api.Test; 38 | 39 | @Disabled 40 | public class LightGBMTest extends SparkMLEncoderBatchTest implements Datasets { 41 | 42 | public LightGBMTest(){ 43 | super(new PMMLEquivalence(1e-14, 1e-14)); 44 | } 45 | 46 | @Override 47 | public SparkMLEncoderBatch createBatch(String algorithm, String dataset, Predicate columnFilter, Equivalence equivalence){ 48 | columnFilter = columnFilter.and(SparkMLEncoderBatchTest.excludePredictionFields()); 49 | 50 | SparkMLEncoderBatch result = new SparkMLEncoderBatch(algorithm, dataset, columnFilter, equivalence){ 51 | 52 | @Override 53 | public LightGBMTest getArchiveBatchTest(){ 54 | return LightGBMTest.this; 55 | } 56 | 57 | @Override 58 | public List> getOptionsMatrix(){ 59 | Map options = new LinkedHashMap<>(); 60 | 61 | options.put(HasLightGBMOptions.OPTION_COMPACT, new Boolean[]{false, true}); 62 | 63 | return OptionsUtil.generateOptionsMatrix(options); 64 | } 65 | }; 66 | 67 | return result; 68 | } 69 | 70 | @Test 71 | public void evaluateLightGBMAudit() throws Exception { 72 | evaluate("LightGBM", AUDIT); 73 | } 74 | 75 | @Test 76 | public void evaluateLightGBMAuditNA() throws Exception { 77 | evaluate("LightGBM", AUDIT_NA); 78 | } 79 | 80 | @Test 81 | 
public void evaluateLightGBMAuto() throws Exception { 82 | evaluate("LightGBM", AUTO); 83 | } 84 | 85 | @Test 86 | public void evaluateLightGBMAutoNA() throws Exception { 87 | evaluate("LightGBM", AUTO_NA); 88 | } 89 | 90 | @Test 91 | public void evaluateLightGBMIris() throws Exception { 92 | evaluate("LightGBM", IRIS); 93 | } 94 | 95 | @BeforeAll 96 | static 97 | public void createSparkSession(){ 98 | SparkMLEncoderBatchTest.createSparkSession(); 99 | } 100 | 101 | @AfterAll 102 | static 103 | public void destroySparkSession(){ 104 | SparkMLEncoderBatchTest.destroySparkSession(); 105 | } 106 | } -------------------------------------------------------------------------------- /pmml-sparkml/src/main/java/org/jpmml/sparkml/DocumentFeature.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2017 Villu Ruusmann 3 | * 4 | * This file is part of JPMML-SparkML 5 | * 6 | * JPMML-SparkML is free software: you can redistribute it and/or modify 7 | * it under the terms of the GNU Affero General Public License as published by 8 | * the Free Software Foundation, either version 3 of the License, or 9 | * (at your option) any later version. 10 | * 11 | * JPMML-SparkML is distributed in the hope that it will be useful, 12 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | * GNU Affero General Public License for more details. 15 | * 16 | * You should have received a copy of the GNU Affero General Public License 17 | * along with JPMML-SparkML. If not, see . 
18 | */
package org.jpmml.sparkml;

import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;

import org.dmg.pmml.Field;
import org.jpmml.converter.ObjectFeature;
import org.jpmml.model.ToStringHelper;

/**
 * An {@link ObjectFeature} that additionally carries a word separator regular expression
 * and an insertion-ordered collection of stop word sets.
 *
 * <p>
 * Equality and hashing take the word separator regular expression into account, but not the stop word sets.
 * </p>
 */
public class DocumentFeature extends ObjectFeature {

	private String wordSeparatorRE = null;

	private final Set<StopWordSet> stopWordSets = new LinkedHashSet<>();


	/**
	 * @param encoder The encoder that owns this feature.
	 * @param field The backing field; its name and data type are required to be set.
	 * @param wordSeparatorRE The word separator regular expression; must not be {@code null}.
	 */
	public DocumentFeature(SparkMLEncoder encoder, Field<?> field, String wordSeparatorRE){
		super(encoder, field.requireName(), field.requireDataType());

		setWordSeparatorRE(wordSeparatorRE);
	}

	@Override
	public int hashCode(){
		return (31 * super.hashCode()) + Objects.hashCode(this.getWordSeparatorRE());
	}

	@Override
	public boolean equals(Object object){

		if(object instanceof DocumentFeature){
			DocumentFeature that = (DocumentFeature)object;

			return super.equals(object) && Objects.equals(this.getWordSeparatorRE(), that.getWordSeparatorRE());
		}

		return false;
	}

	@Override
	protected ToStringHelper toStringHelper(){
		return super.toStringHelper()
			.add("wordSeparatorRE", getWordSeparatorRE());
	}

	public String getWordSeparatorRE(){
		return this.wordSeparatorRE;
	}

	private void setWordSeparatorRE(String wordSeparatorRE){
		this.wordSeparatorRE = Objects.requireNonNull(wordSeparatorRE);
	}

	/**
	 * Registers a stop word set with this feature.
	 *
	 * <p>
	 * NOTE(review): the set is stored in a hash-based collection, and {@link StopWordSet#hashCode()} depends on its contents.
	 * It should therefore not be modified after being added.
	 * </p>
	 *
	 * @param stopWordSet The stop word set; must not be {@code null}.
	 */
	public void addStopWordSet(StopWordSet stopWordSet){
		Set<StopWordSet> stopWordSets = getStopWordSets();

		stopWordSets.add(Objects.requireNonNull(stopWordSet));
	}

	/**
	 * @return The live, insertion-ordered collection of stop word sets (not a copy).
	 */
	public Set<StopWordSet> getStopWordSets(){
		return this.stopWordSets;
	}

	/**
	 * An insertion-ordered set of stop words together with a case sensitivity flag.
	 *
	 * <p>
	 * NOTE(review): because this class extends {@link LinkedHashSet} and adds state, {@code equals} is not symmetric with plain sets.
	 * A plain {@code Set} with the same words may report itself equal to a {@code StopWordSet}, but not the other way around.
	 * </p>
	 */
	static
	public class StopWordSet extends LinkedHashSet<String> {

		private boolean caseSensitive = false;


		public StopWordSet(boolean caseSensitive){
			setCaseSensitive(caseSensitive);
		}

		@Override
		public int hashCode(){
			return (31 * super.hashCode()) + Boolean.hashCode(this.isCaseSensitive());
		}

		@Override
		public boolean equals(Object object){

			if(object instanceof StopWordSet){
				StopWordSet that = (StopWordSet)object;

				return super.equals(object) && (this.isCaseSensitive() == that.isCaseSensitive());
			}

			return false;
		}

		public boolean isCaseSensitive(){
			return this.caseSensitive;
		}

		private void setCaseSensitive(boolean caseSensitive){
			this.caseSensitive = caseSensitive;
		}
	}
}
--------------------------------------------------------------------------------