├── sklearn_pmml
├── convert
│ ├── test
│ │ ├── __init__.py
│ │ ├── jpmml-csv-evaluator
│ │ │ ├── README.md
│ │ │ ├── pom.xml
│ │ │ └── src
│ │ │ │ └── main
│ │ │ │ └── java
│ │ │ │ └── sklearn
│ │ │ │ └── pmml
│ │ │ │ └── jpmml
│ │ │ │ └── JPMMLCSVEvaluator.java
│ │ ├── test_randomForestConverter.py
│ │ ├── test_derived_fields.py
│ │ ├── test_gradientBoostingConverter.py
│ │ ├── test_decisionTreeClassifierConverter.py
│ │ └── jpmml_test.py
│ ├── __init__.py
│ ├── random_forest.py
│ ├── features.py
│ ├── utils.py
│ ├── tree.py
│ ├── gbrt.py
│ └── model.py
├── __init__.py
└── test
│ ├── data
│ └── gradient_boosting_classifier
│ │ ├── context.pkl
│ │ ├── document.pmml
│ │ └── estimator.pkl
│ └── __init__.py
├── MANIFEST.in
├── .travis.yml
├── .gitignore
├── setup.py
├── LICENSE
├── README.md
└── examples
└── pmml
├── DecisionTreeClassifier.pmml
├── GradientBoostingClassifier.pmml
└── RandomForestClassifier.pmml
/sklearn_pmml/convert/test/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sklearn_pmml/__init__.py:
--------------------------------------------------------------------------------
1 | from sklearn_pmml.convert import *
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include LICENSE
3 | include requirements.txt
--------------------------------------------------------------------------------
/sklearn_pmml/convert/__init__.py:
--------------------------------------------------------------------------------
1 | from sklearn_pmml import pmml
2 | from sklearn_pmml.convert.features import Feature, NumericFeature, CategoricalFeature, RealNumericFeature
3 | from sklearn_pmml.convert.gbrt import *
4 | from sklearn_pmml.convert.tree import *
5 | from sklearn_pmml.convert.random_forest import *
6 | from sklearn_pmml.convert.model import *
7 | from sklearn_pmml.convert.utils import *
8 |
9 |
# Names re-exported by ``from sklearn_pmml.convert import *``; the top-level
# ``sklearn_pmml`` package relies on this star-import for its public API.
# RandomForestClassifierConverter is listed so that it is exported alongside
# the other converters.
__all__ = [
    'TransformationContext',
    'EstimatorConverter',
    'find_converter',
    'GradientBoostingConverter',
    'LogOddsEstimatorConverter',
    'DecisionTreeConverter',
    'RandomForestClassifierConverter',
    'features',
]
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/sklearn_pmml/convert/test/jpmml-csv-evaluator/README.md:
--------------------------------------------------------------------------------
1 | # About
2 | This is a simple [JPMML](http://github.com/jpmml)-based CLI evaluator for PMML models.
3 |
4 | # Notes
5 | This submodule relies on the AGPL-licensed library [jpmml-evaluator](http://github.com/jpmml/jpmml-evaluator),
6 | but it is used only for testing and is not part of the sklearn-pmml distribution.
7 | Since users will not interact with the AGPL-licensed library, I think it's OK to use it in tests.
8 |
9 | # Usage
10 | 1. Build the JAR file (make sure you have JDK8 installed):
11 | ```
12 | mvn clean package
13 | ```
14 | 2. Run with maven:
15 | ```
16 | mvn exec:java -e -q \
17 | -Dexec.mainClass=sklearn.pmml.jpmml.JPMMLCSVEvaluator \
18 | -Dexec.args="/path/to/pmml /path/to/input.csv /path/to/output.csv"
19 | ```
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "2.7"
4 | - "3.4"
5 | # - "nightly"
6 | # command to install dependencies
7 | before_install:
8 | - sudo add-apt-repository ppa:webupd8team/java -y
9 | - sudo apt-get update -qq
10 | - sudo apt-get install oracle-java8-installer
11 | - sudo apt-get install maven
12 | - export PATH=/usr/bin:$PATH
13 | - wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh
14 | - chmod +x miniconda.sh
15 | - ./miniconda.sh -b
16 | - export PATH=/home/travis/miniconda/bin:$PATH
17 | - conda update --yes conda
18 | # install the heaviest dependencies with conda to save some time
19 | - travis_retry conda install --yes python=$TRAVIS_PYTHON_VERSION pip numpy scipy scikit-learn pandas lxml
20 |
21 | install:
22 | - travis_retry pip install .
23 | # command to run tests
24 | script: python setup.py test
25 | cache: apt
26 |
--------------------------------------------------------------------------------
/sklearn_pmml/test/data/gradient_boosting_classifier/context.pkl:
--------------------------------------------------------------------------------
1 | ccopy_reg
2 | _reconstructor
3 | p1
4 | (csklearn_pmml.convert
5 | TransformationContext
6 | p2
7 | c__builtin__
8 | object
9 | p3
10 | NtRp4
11 | (dp5
12 | S'schemas'
13 | p6
14 | (dp7
15 | S'output'
16 | p8
17 | (lp9
18 | g1
19 | (csklearn_pmml.convert.features
20 | RealNumericFeature
21 | p10
22 | g3
23 | NtRp11
24 | (dp12
25 | S'_namespace'
26 | p13
27 | S''
28 | sS'_invalid_value_treatment'
29 | p14
30 | S'asIs'
31 | p15
32 | sS'_name'
33 | p16
34 | g8
35 | sbasS'input'
36 | p17
37 | (lp18
38 | g1
39 | (csklearn_pmml.convert.features
40 | IntegerNumericFeature
41 | p19
42 | g3
43 | NtRp20
44 | (dp21
45 | g13
46 | S''
47 | sg14
48 | g15
49 | sg16
50 | S'x1'
51 | p22
52 | sbag1
53 | (csklearn_pmml.convert.features
54 | StringCategoricalFeature
55 | p23
56 | g3
57 | NtRp24
58 | (dp25
59 | S'value_list'
60 | p26
61 | (lp27
62 | S'zero'
63 | p28
64 | aS'one'
65 | p29
66 | asg13
67 | S''
68 | sg14
69 | g15
70 | sg16
71 | S'x2'
72 | p30
73 | sbassb.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 |
5 | # C extensions
6 | *.so
7 |
8 | # Distribution / packaging
9 | .Python
10 | env/
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | *.egg-info/
23 | .installed.cfg
24 | *.egg
25 |
26 | # PyInstaller
27 | # Usually these files are written by a python script from a template
28 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
29 | *.manifest
30 | *.spec
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 |
36 | # Unit test / coverage reports
37 | htmlcov/
38 | .tox/
39 | .coverage
40 | .coverage.*
41 | .cache
42 | nosetests.xml
43 | coverage.xml
44 | *,cover
45 |
46 | # Translations
47 | *.mo
48 | *.pot
49 |
50 | # Django stuff:
51 | *.log
52 |
53 | # Sphinx documentation
54 | docs/_build/
55 |
56 | # PyBuilder
57 | target/
58 |
59 | #java/intellij stuff
60 | *.iml
61 | *.class
62 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, Command
2 |
3 |
class PyTest(Command):
    """Custom setup.py command that launches the ``runtests.py`` script."""

    # This command accepts no command-line options.
    user_options = []

    def initialize_options(self):
        """Required by the Command interface; there is nothing to initialise."""

    def finalize_options(self):
        """Required by the Command interface; there is nothing to finalise."""

    def run(self):
        """Run ``runtests.py`` with the current interpreter and exit with its return code."""
        from subprocess import call
        from sys import executable

        exit_status = call([executable, 'runtests.py'])
        raise SystemExit(exit_status)
18 |
# Package metadata and install-time requirements for sklearn-pmml.
setup(
    name='sklearn-pmml',
    version='0.1.2',
    packages=['sklearn_pmml', 'sklearn_pmml.convert'],
    install_requires=[
        "pyxb",
        "scikit-learn",
        "pandas",
        "scipy",
        "pytest",
        "lxml",
        # ``enum`` ships with the standard library from Python 3.4 on; the
        # enum34 backport is only needed (and only safe) on older interpreters.
        'enum34; python_version < "3.4"',
    ],
    # ``python setup.py test`` is routed to the PyTest command defined above.
    cmdclass={'test': PyTest},
    url='https://github.com/alex-pirozhenko/sklearn-pmml',
    license='MIT',
    author='Alex Pirozhenko',
    author_email='apirozhenko@pulsepoint.com',
    description='A library that allows serialization of SciKit-Learn estimators into PMML'
)
39 |
--------------------------------------------------------------------------------
/sklearn_pmml/convert/test/test_randomForestConverter.py:
--------------------------------------------------------------------------------
1 | from sklearn_pmml.convert import IntegerCategoricalFeature
2 | from sklearn_pmml.convert.test.jpmml_test import JPMMLClassificationTest, JPMMLTest, TARGET_NAME
3 | from unittest import TestCase
4 | from sklearn.ensemble import RandomForestClassifier
5 |
6 | __author__ = 'evancox'
7 |
8 |
9 | from sklearn_pmml.convert.random_forest import RandomForestClassifierConverter
10 |
11 |
class TestRandomForestClassifierParity(TestCase, JPMMLClassificationTest):
    """Random forest test case built on the JPMMLClassificationTest mixin."""

    @classmethod
    def setUpClass(cls):
        """Initialise JPMML once per class, but only when JPMMLTest reports it can run."""
        if not JPMMLTest.can_run():
            return
        JPMMLTest.init_jpmml()

    def setUp(self):
        """Create a 3-tree, depth-3 forest, prepare the data and build the converter."""
        forest_params = {'n_estimators': 3, 'max_depth': 3}
        self.model = RandomForestClassifier(**forest_params)
        # presumably init_data() (from the mixin) populates self.ctx -- verify in jpmml_test
        self.init_data()
        self.converter = RandomForestClassifierConverter(estimator=self.model, context=self.ctx)

    @property
    def output(self):
        """Output feature: an integer categorical target with the labels 0, 1 and 2."""
        labels = [0, 1, 2]
        return IntegerCategoricalFeature(name=TARGET_NAME, value_list=labels)
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2015 alex-pirozhenko
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
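1 | [![Build Status](https://travis-ci.org/alex-pirozhenko/sklearn-pmml.svg?branch=master)](https://travis-ci.org/alex-pirozhenko/sklearn-pmml)
2 | [![Join the chat at https://gitter.im/alex-pirozhenko/sklearn-pmml](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/alex-pirozhenko/sklearn-pmml?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)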
3 |
4 | # sklearn-pmml
5 |
6 | A library that allows serialization of SciKit-Learn estimators into PMML
7 |
8 | # Installation
9 | The easiest way is to use pip:
10 | ```
11 | pip install sklearn-pmml
12 | ```
13 |
14 | # Supported models
15 | - DecisionTreeClassifier
16 | - DecisionTreeRegressor
17 | - GradientBoostingClassifier
18 | - RandomForestClassifier
19 |
20 | # PMML output
21 |
22 | ## Classification
23 | Classifier converters can only operate with categorical outputs, and for each categorical output variable ```varname```
24 | the PMML output contains the following outputs:
25 | - categorical ```varname``` for the predicted label of the instance
26 | - double ```varname.label``` for the probability for a given label
27 |
28 | ## Regression
29 | For regression models, the PMML output contains the numeric response variable, named after the output variable.
30 |
--------------------------------------------------------------------------------
/sklearn_pmml/convert/test/jpmml-csv-evaluator/pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
5 | 4.0.0
6 |
7 | sklearn.pmml.jpmml
8 | jpmml-csv-evaluator
9 | 1.0-SNAPSHOT
10 |
11 |
12 |
13 |
14 | org.jpmml
15 | pmml-evaluator
16 |
17 | 1.2.5
18 |
19 |
20 |
21 | org.jpmml
22 | pmml-model
23 |
24 | 1.2.6
25 |
26 |
27 |
28 | net.sf.supercsv
29 | super-csv
30 | 2.0.1
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | org.apache.maven.plugins
42 | maven-compiler-plugin
43 | 3.0
44 |
45 | 1.7
46 | 1.7
47 |
48 |
49 |
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
/sklearn_pmml/convert/random_forest.py:
--------------------------------------------------------------------------------
1 | from sklearn_pmml.convert import CategoricalFeature
2 |
3 | __author__ = 'evancox'
4 |
5 |
6 | from sklearn.ensemble import RandomForestClassifier
7 | from sklearn_pmml.convert.model import Schema, ModelMode, ClassifierConverter
8 | from sklearn_pmml.convert.tree import DecisionTreeConverter
9 | from sklearn_pmml.convert.utils import estimator_to_converter
10 |
11 | import sklearn_pmml.pmml as pmml
12 |
13 |
class RandomForestClassifierConverter(ClassifierConverter):
    """Converter that renders a RandomForestClassifier as a PMML MiningModel."""

    def __init__(self, estimator, context):
        """
        :param estimator: the RandomForestClassifier instance to convert
        :param context: transformation context; its output schema must hold exactly one feature
        """
        super(RandomForestClassifierConverter, self).__init__(estimator, context)
        assert isinstance(estimator, RandomForestClassifier), \
            'This converter can only process RandomForestClassifier instances'
        output_schema = context.schemas[Schema.OUTPUT]
        assert len(output_schema) == 1, 'Only one-label classification is supported'

    def model(self, verification_data=None):
        """
        Assemble the classification mining model: mining schema, output, segmentation and,
        when verification data is supplied, a model verification element.
        :param verification_data: optional data handed to model_verification
        :return: MiningModel element
        """
        root = pmml.MiningModel(functionName=ModelMode.CLASSIFICATION.value)
        # each part is built and appended in turn, in this fixed order
        for build_part in (self.mining_schema, self.output, self.segmentation):
            root.append(build_part())
        if verification_data is None:
            return root
        root.append(self.model_verification(verification_data))
        return root

    def segmentation(self):
        """
        Build a segmentation holding one always-true segment per fitted tree
        :return: Segmentation element
        """
        forest = pmml.Segmentation(multipleModelMethod="weightedAverage")

        for position, tree in enumerate(self.estimator.estimators_):
            tree_converter = DecisionTreeConverter(tree, self.context, ModelMode.CLASSIFICATION)
            segment = pmml.Segment(id=position)
            segment.append(pmml.True_())
            segment.append(tree_converter._model())
            forest.append(segment)

        return forest


# Map the RandomForestClassifier estimator type to its converter in the shared registry.
estimator_to_converter[RandomForestClassifier] = RandomForestClassifierConverter
--------------------------------------------------------------------------------
/sklearn_pmml/convert/test/test_derived_fields.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from sklearn.tree import DecisionTreeClassifier
3 | from sklearn_pmml import EstimatorConverter, TransformationContext, pmml
4 | from sklearn_pmml.convert import Schema, ModelMode
5 | from sklearn_pmml.convert.features import *
6 | import numpy as np
7 |
8 | test_cases = [
9 | (
10 | [
11 | RealNumericFeature(name='f1'),
12 | ],
13 | [
14 | DerivedFeature(
15 | feature=RealNumericFeature(name='f2'),
16 | transformation=pmml.Discretize(mapMissingTo=0, defaultValue=1, field='f1'),
17 | function=np.vectorize(lambda f1: 0 if f1 is None else 1)
18 | )
19 | ],
20 | [RealNumericFeature(name='f3')],
21 |
22 | ''
23 | ''
24 | ''
25 | ''
26 | '',
27 |
28 | ''
29 | ''
30 | ''
31 | ''
32 | ''
33 | ''
34 | )
35 | ]
36 |
@pytest.mark.parametrize("input_fields,derived_fields,output_fields,expected_data_dictionary,expected_transformation_dictionary", test_cases)
def test_transformation_dictionary(input_fields, derived_fields, output_fields, expected_data_dictionary, expected_transformation_dictionary):
    """Compare the generated data and transformation dictionaries with the expected XML."""
    # the model schema is the input fields followed by the derived ones
    schemas = {
        Schema.INPUT: input_fields,
        Schema.DERIVED: derived_fields,
        Schema.MODEL: input_fields + derived_fields,
        Schema.OUTPUT: output_fields,
    }
    subject = EstimatorConverter(
        DecisionTreeClassifier(),
        context=TransformationContext(schemas),
        mode=ModelMode.CLASSIFICATION
    )

    data_dictionary_xml = subject.data_dictionary().toxml()
    assert data_dictionary_xml == expected_data_dictionary, 'Error in data dictionary generation'

    transformation_dictionary_xml = subject.transformation_dictionary().toxml()
    assert transformation_dictionary_xml == expected_transformation_dictionary,\
        'Error in transformation dictionary generation'
--------------------------------------------------------------------------------
/sklearn_pmml/convert/test/test_gradientBoostingConverter.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | from sklearn.ensemble import GradientBoostingClassifier
4 | import numpy as np
5 |
6 | from sklearn_pmml.convert.test.jpmml_test import JPMMLClassificationTest, JPMMLTest, TARGET_NAME
7 | from sklearn_pmml.convert import TransformationContext, Schema
8 | from sklearn_pmml.convert.features import *
9 | from sklearn_pmml.convert.gbrt import GradientBoostingConverter
10 |
11 |
class TestGradientBoostingClassifierConverter(TestCase):
    """Structural checks of the PMML produced for a tiny gradient boosting classifier."""

    def setUp(self):
        """Fit a 10-stage, depth-2 classifier on four points and build its converter."""
        np.random.seed(1)
        samples = [[0, 0], [0, 1], [1, 0], [1, 1]]
        labels = [0, 1, 1, 1]
        self.est = GradientBoostingClassifier(max_depth=2, n_estimators=10)
        self.est.fit(samples, labels)

        def predictors():
            # a fresh pair of feature objects per call, so INPUT and MODEL do not share instances
            return [
                IntegerNumericFeature('x1'),
                StringCategoricalFeature('x2', ['zero', 'one']),
            ]

        self.ctx = TransformationContext({
            Schema.INPUT: predictors(),
            Schema.MODEL: predictors(),
            Schema.DERIVED: [],
            Schema.OUTPUT: [IntegerCategoricalFeature('output', [0, 1])],
        })
        self.converter = GradientBoostingConverter(estimator=self.est, context=self.ctx)

    def _check_mining_model(self, document):
        """Assert that the first MiningModel has a two-field mining schema and a segmentation."""
        mm = document.MiningModel[0]
        assert mm.MiningSchema is not None, 'Missing mining schema'
        assert len(mm.MiningSchema.MiningField) == 2, 'Wrong number of mining fields'
        assert mm.Segmentation is not None, 'Missing segmentation root'

    def test_transform(self):
        """Convert without verification data and check the document structure."""
        self._check_mining_model(self.converter.pmml())

    def test_transform_with_verification(self):
        """Convert with one verification row per training point and check the document structure."""
        x2_names = ('zero', 'one')
        rows = []
        for x1 in (0, 1):
            for x2 in (0, 1):
                point = [[x1, x2]]
                proba = self.est.predict_proba(point)
                rows.append({
                    'x1': x1,
                    'x2': x2_names[x2],
                    'output#1': proba[0, 1],
                    'output#0': proba[0, 0],
                    'output': self.est.predict(point),
                })
        self._check_mining_model(self.converter.pmml(rows))
59 |
60 |
class TestGradientBoostingClassifierParity(TestCase, JPMMLClassificationTest):
    """Gradient boosting test case built on the JPMMLClassificationTest mixin."""

    @classmethod
    def setUpClass(cls):
        """Initialise JPMML once per class, but only when JPMMLTest reports it can run."""
        if not JPMMLTest.can_run():
            return
        JPMMLTest.init_jpmml()

    def setUp(self):
        """Create a 2-stage, depth-2 classifier, prepare one-label data and build the converter."""
        boosting_params = {'n_estimators': 2, 'max_depth': 2}
        self.model = GradientBoostingClassifier(**boosting_params)
        # presumably init_data_one_label() (from the mixin) populates self.ctx -- verify in jpmml_test
        self.init_data_one_label()
        self.converter = GradientBoostingConverter(estimator=self.model, context=self.ctx)
75 |
--------------------------------------------------------------------------------
/sklearn_pmml/test/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from unittest import TestCase
3 | from sklearn.base import BaseEstimator
4 |
# Prefer the C implementation available on Python 2; fall back to the standard
# module elsewhere. Only a failed import should trigger the fallback, so the
# handler is limited to ImportError instead of swallowing every exception.
try:
    import cPickle as pickle
except ImportError:
    import pickle
9 | from sklearn_pmml.convert import *
10 | from sklearn_pmml import pmml
11 |
12 |
class TestSerializationMeta(type):
    """
    Metaclass that adds one ``test_<suite>`` method to the class being created for every suite directory
    found under DATA_DIR.
    """
    TEST_DIR = os.path.dirname(__file__)
    DATA_DIR = os.path.join(TEST_DIR, 'data')
    ESTIMATOR_FILE_NAME = 'estimator.pkl'
    PMML_FILE_NAME = 'document.pmml'
    CONTEXT_FILE_NAME = 'context.pkl'

    def __new__(mcs, name, bases, d):
        """
        This method overrides default behaviour for creation of new classes. For every directory abc in data it
        creates a method called test_abc, with the body of load_and_compare function.
        :param name: name of the class being created
        :param bases: base classes of the class being created
        :param d: class namespace; the generated test methods are added to it
        :return: the new class
        """
        def gen_test(suite_name):
            def load_and_compare(self):
                # load the context.pkl, document.pmml and estimator.pkl
                suite_path = os.path.join(mcs.DATA_DIR, suite_name)
                content = os.listdir(suite_path)
                assert len(content) == 3, 'There should be exactly three files in the suite directory'
                assert mcs.ESTIMATOR_FILE_NAME in content, 'Estimator should be stored in {} file'.format(mcs.ESTIMATOR_FILE_NAME)
                assert mcs.PMML_FILE_NAME in content, 'PMML should be stored in {} file'.format(mcs.PMML_FILE_NAME)
                assert mcs.CONTEXT_FILE_NAME in content, 'Context should be stored in {} file'.format(mcs.CONTEXT_FILE_NAME)
                # pickles are binary data: they must be opened in binary mode to load on Python 3
                with open(os.path.join(suite_path, mcs.ESTIMATOR_FILE_NAME), 'rb') as est_file:
                    est = pickle.load(est_file)
                assert isinstance(est, BaseEstimator), '{} should be a trained estimator'.format(mcs.ESTIMATOR_FILE_NAME)
                with open(os.path.join(suite_path, mcs.CONTEXT_FILE_NAME), 'rb') as ctx_file:
                    ctx = pickle.load(ctx_file)
                assert isinstance(ctx, TransformationContext), '{} should be a transformation context'.format(mcs.CONTEXT_FILE_NAME)
                converter = find_converter(est)
                assert converter is not None, 'Can not find converter for {}'.format(est)
                transformed_pmml = converter(est, ctx).pmml()
                with open(os.path.join(suite_path, mcs.PMML_FILE_NAME), 'r') as pmml_file:
                    # read the document verbatim; joining readlines() with '\n' would double every newline
                    loaded_pmml = pmml.CreateFromDocument(pmml_file.read())
                self.maxDiff = None
                # make sure that the expected PMML matches the produced one
                self.assertEqual(loaded_pmml.toDOM().toprettyxml(), transformed_pmml.toDOM().toprettyxml())

            return load_and_compare

        # for every suite directory in the data dir create a corresponding test method;
        # stray files are skipped because only directories can hold a suite
        data_dir = TestSerializationMeta.DATA_DIR
        for case in sorted(os.listdir(data_dir)):
            if not os.path.isdir(os.path.join(data_dir, case)):
                continue
            test_name = 'test_{}'.format(case)
            d[test_name] = gen_test(case)
        return type.__new__(mcs, name, bases, d)
56 |
57 |
# The metaclass is applied by deriving from a base class created through it.
# A ``__metaclass__`` attribute is honoured by Python 2 only and silently
# ignored by Python 3, where no test methods would be generated at all.
class TestSerialization(TestSerializationMeta('_TestSerializationBase', (TestCase,), {})):
    """
    This is an automated tester for serializers. It uses a custom metaclass to define the test cases based on the
    content of the data directory. For the logic behind every check see load_and_compare method above.
    """
64 |
65 |
66 |
--------------------------------------------------------------------------------
/examples/pmml/DecisionTreeClassifier.pmml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/sklearn_pmml/convert/test/jpmml-csv-evaluator/src/main/java/sklearn/pmml/jpmml/JPMMLCSVEvaluator.java:
--------------------------------------------------------------------------------
1 | package sklearn.pmml.jpmml;
2 |
3 | import com.google.common.collect.Lists;
4 | import com.google.common.collect.Maps;
5 |
6 | import com.google.common.collect.Sets;
7 | import org.dmg.pmml.PMML;
8 | import org.dmg.pmml.FieldName;
9 | import org.jpmml.model.JAXBUtil;
10 | import org.jpmml.model.ImportFilter;
11 | import org.jpmml.evaluator.FieldValue;
12 | import org.jpmml.evaluator.Evaluator;
13 | import org.jpmml.evaluator.ModelEvaluator;
14 | import org.jpmml.evaluator.ModelEvaluatorFactory;
15 | import org.supercsv.io.CsvMapReader;
16 | import org.supercsv.io.CsvMapWriter;
17 | import org.supercsv.prefs.CsvPreference;
18 | import org.xml.sax.SAXException;
19 | import org.xml.sax.InputSource;
20 |
21 | import javax.xml.bind.JAXBException;
22 |
23 | import java.io.FileInputStream;
24 | import java.io.FileReader;
25 | import java.io.FileWriter;
26 | import java.io.IOException;
27 | import java.io.InputStream;
28 | import java.util.Arrays;
29 | import java.util.HashMap;
30 | import java.util.List;
31 | import java.util.Map;
32 | import java.util.Set;
33 | import java.util.logging.Level;
34 | import java.util.logging.Logger;
35 |
36 | /**
37 | * Created by evancox on 7/23/15.
38 | */
39 | public class JPMMLCSVEvaluator
40 | {
41 | private static final Logger logger = Logger.getLogger(JPMMLCSVEvaluator.class.getCanonicalName());
42 |
43 | static PMML pmmlFromXml(final InputStream is)
44 | {
45 | try
46 | {
47 | return JAXBUtil.unmarshalPMML(ImportFilter.apply(new InputSource(is)));
48 | }
49 | catch (SAXException | JAXBException e)
50 | {
51 | throw new RuntimeException("Error reading PMML.", e);
52 | }
53 | }
54 |
55 | static Evaluator evaluatorFromPmml(final PMML pmml)
56 | {
57 | ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
58 |
59 | ModelEvaluator> modelEvaluator = modelEvaluatorFactory.newModelManager(pmml);
60 |
61 | return modelEvaluator;
62 | }
63 |
64 | static Evaluator evaluatorFromXml(final InputStream is)
65 | {
66 | // Adapted from:
67 | // * https://github.com/jpmml/jpmml/blob/master/README.md
68 | // * https://github.com/jpmml/jpmml-example/blob/master/src/main/java/org/jpmml/example/CsvEvaluationExample.java
69 | return evaluatorFromPmml(pmmlFromXml(is));
70 | }
71 |
72 | static List