├── tests
    ├── __init__.py
    ├── test_multistage.py
    ├── test_sqlalchemy.py
    ├── test_pipeline.py
    ├── test_ast.py
    ├── test_logistic.py
    ├── verification.py
    ├── test_skompile.py
    ├── conftest.py
    ├── evaluators.py
    └── test_sklearn.py
├── MANIFEST.in
├── skompiler
    ├── toskast
    │   ├── sklearn
    │   │   ├── tree
    │   │   │   ├── __init__.py
    │   │   │   └── base.py
    │   │   ├── ensemble
    │   │   │   ├── __init__.py
    │   │   │   ├── forest.py
    │   │   │   ├── gradient_boosting.py
    │   │   │   └── weight_boosting.py
    │   │   ├── cluster
    │   │   │   ├── __init__.py
    │   │   │   └── k_means.py
    │   │   ├── linear_model
    │   │   │   ├── __init__.py
    │   │   │   ├── base.py
    │   │   │   └── logistic.py
    │   │   ├── decomposition
    │   │   │   ├── __init__.py
    │   │   │   └── pca.py
    │   │   ├── neural_network
    │   │   │   ├── __init__.py
    │   │   │   └── multilayer_perceptron.py
    │   │   ├── preprocessing
    │   │   │   ├── __init__.py
    │   │   │   └── data.py
    │   │   ├── common.py
    │   │   └── __init__.py
    │   ├── __init__.py
    │   ├── _common.py
    │   ├── string.py
    │   └── python.py
    ├── fromskast
    │   ├── __init__.py
    │   ├── pfa.py
    │   ├── _common.py
    │   ├── python.py
    │   ├── sympy.py
    │   ├── excel.py
    │   └── sqlalchemy.py
    ├── __init__.py
    ├── dsl.py
    ├── api.py
    └── ast.py
├── .vscode
    ├── pypi-upload.cmd
    ├── settings.json
    └── tasks.json
├── .gitignore
├── .travis.yml
├── setup.cfg
├── LICENSE
├── CHANGELOG.txt
├── setup.py
├── README.md
└── pylintrc
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE README.md CHANGELOG.txt
--------------------------------------------------------------------------------
/skompiler/toskast/sklearn/tree/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Decision trees
3 | """
4 | 
--------------------------------------------------------------------------------
/skompiler/toskast/sklearn/ensemble/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Ensemble models
3 | """
4 | 
--------------------------------------------------------------------------------
/skompiler/toskast/sklearn/cluster/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Clustering methods.
3 | """
4 | 
--------------------------------------------------------------------------------
/skompiler/toskast/sklearn/linear_model/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Linear models
3 | """
4 | 
--------------------------------------------------------------------------------
/skompiler/fromskast/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | SKompiler: Code generation from SK-AST.
3 | """
4 | 
--------------------------------------------------------------------------------
/skompiler/toskast/sklearn/decomposition/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Decomposition methods
3 | """
4 | 
--------------------------------------------------------------------------------
/skompiler/toskast/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Converters from other representations TO SKAST.
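
Each submodule exposes a translate() function that turns its source
representation (a fitted sklearn model, a Python AST, a string of code)
into a SKAST expression. A minimal usage sketch, where the fitted
four-feature model m is an assumption made for illustration:

    from skompiler.toskast.sklearn import translate as from_sklearn
    from skompiler.dsl import ident

    expr = from_sklearn(m, inputs=ident('x', 4), method='predict')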
3 | """ 4 | -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/neural_network/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multilayer perceptron 3 | """ 4 | -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocessing methods. 3 | """ 4 | -------------------------------------------------------------------------------- /.vscode/pypi-upload.cmd: -------------------------------------------------------------------------------- 1 | del /Q dist\*.tar.gz 2 | python setup.py sdist && twine upload dist\*.tar.gz 3 | -------------------------------------------------------------------------------- /tests/test_multistage.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multi-stage logic tests. 3 | """ 4 | 5 | 6 | def test_excel(): 7 | pass 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /build/ 2 | /dist/ 3 | /*.egg 4 | /*.egg-info/ 5 | /**/__pycache__/ 6 | *.pyc 7 | /.cache 8 | /venv* 9 | /.pytest_cache/ 10 | /.ipynb_checkpoints 11 | Untitled.ipynb 12 | /_* -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | - "3.6" 5 | - "3.7" 6 | - "3.8" 7 | - "3.9" 8 | dist: focal 9 | install: 10 | - pip install .[test] 11 | script: pytest 12 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [egg_info] 2 | tag_build = 3 | tag_svn_revision = false 4 | 5 | [tool:pytest] 6 | addopts = --ignore=setup.py --ignore=build --ignore=dist --doctest-modules 7 | norecursedirs=*.egg 8 | filterwarnings = 9 | ignore::UserWarning 10 | ignore::FutureWarning -------------------------------------------------------------------------------- /skompiler/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | SKompiler: Library for converting trained SKLearn models into abstract expressions suitable 3 | for further compilation into executable code in various languages. 4 | 5 | Author: Konstantin Tretyakov 6 | License: MIT 7 | ''' 8 | from .api import skompile 9 | 10 | __version__ = '0.7' 11 | -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/decomposition/pca.py: -------------------------------------------------------------------------------- 1 | """ 2 | PCA implementation 3 | """ 4 | import numpy as np 5 | from skompiler.dsl import const 6 | 7 | def pca(model, inputs): 8 | matrix = np.array(model.components_) 9 | if model.whiten: 10 | matrix /= np.sqrt(model.explained_variance_)[:, np.newaxis] 11 | if model.mean_ is not None: 12 | inputs = inputs - const(model.mean_) 13 | return const(matrix) @ inputs 14 | -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/cluster/k_means.py: -------------------------------------------------------------------------------- 1 | """ 2 | K-means implementation. 
3 | """ 4 | from skompiler.dsl import const, func, let, defn, ref, vector 5 | 6 | 7 | def k_means(cluster_centers, inputs, method): 8 | res = [] 9 | for c in cluster_centers: 10 | dx = inputs - const(c) 11 | res.append(let(defn(dx=dx), ref('dx', dx) @ ref('dx', dx))) 12 | 13 | sq_dists = vector(res) 14 | if method == 'transform': 15 | return func.Sqrt(sq_dists) 16 | elif method == 'predict': 17 | return func.ArgMax(sq_dists * -1) 18 | else: 19 | raise ValueError("Unsupported methods for KMeans: {0}".format(method)) 20 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.exclude": { 3 | "**/.git": true, // this is a default value 4 | "**/.DS_Store": true, // this is a default value 5 | "build": true, 6 | "dist": true, 7 | "*.egg-info": true, 8 | "venv": true, 9 | "**/__pycache__": true, 10 | "Untitled.ipynb": true, 11 | ".ipynb_checkpoints": true, 12 | ".pytest_cache": true 13 | }, 14 | "python.pythonPath": "${workspaceFolder}\\venv\\Scripts\\python.exe", 15 | "python.unitTest.pyTestEnabled": true, 16 | "python.linting.pylintEnabled": true, 17 | "python.linting.enabled": true, 18 | "python.linting.pylintUseMinimalCheckers": false, 19 | "python.linting.pylintPath": "${workspaceFolder}\\venv\\Scripts\\pylint.exe", 20 | // https://github.com/Microsoft/vscode-python/issues/435 21 | } -------------------------------------------------------------------------------- /tests/test_sqlalchemy.py: -------------------------------------------------------------------------------- 1 | import sqlalchemy as sa 2 | from skompiler.dsl import * 3 | 4 | def equal_queries(query1, query2): 5 | assert ''.join(query1.strip().lower().split()) == ''.join(query2.strip().lower().split()) 6 | 7 | def test_from_obj(): 8 | expr = ident('a')*ident('b') 9 | 10 | result = expr.to('sqlalchemy/sqlite', key_column='_key_', from_obj='_table_') 11 | equal_queries(result, 'select a * b as y from _table_') 12 | 13 | result = expr.to('sqlalchemy/sqlite', key_column='_key_', from_obj=sa.table('_table_', sa.column('_key_'))) 14 | equal_queries(result, 'select a * b as y from _table_') 15 | 16 | cte = sa.select([sa.column('_key_'), (sa.column('x')*2).label('a')], from_obj=sa.text('_table_')).cte('_cte_') 17 | result = expr.to('sqlalchemy/sqlite', key_column='_key_', from_obj=cte) 18 | equal_queries(result, 'with _cte_ as (select _key_, x*2 as a from _table_) select a * b as y from _cte_') 19 | -------------------------------------------------------------------------------- /tests/test_pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sklearn.datasets import load_breast_cancer 4 | from sklearn.preprocessing import StandardScaler 5 | from sklearn.decomposition import PCA 6 | from sklearn.cluster import KMeans 7 | from sklearn.pipeline import Pipeline 8 | from sklearn.neural_network import MLPClassifier 9 | 10 | from skompiler import skompile 11 | 12 | 13 | def test_random_pipeline(): 14 | m = Pipeline([('scale', StandardScaler()), 15 | ('dim_reduce', PCA(6)), 16 | ('cluster', KMeans(10)), 17 | ('classify', MLPClassifier([5, 4], 'tanh'))]) 18 | 19 | X, y = load_breast_cancer(return_X_y=True) 20 | m.fit(X, y) 21 | 22 | expr = skompile(m, 'predict_proba') 23 | 24 | pred_Y = np.asarray([expr.evaluate(x=X[i]) for i in range(len(X))]).ravel() 25 | true_Y = m.predict_proba(X)[:, 1] 26 | 27 | assert np.abs(true_Y - 
pred_Y.ravel()).max() < 1e-10 28 | -------------------------------------------------------------------------------- /skompiler/toskast/_common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Commonly useful functions. 3 | """ 4 | from skompiler.dsl import ident, vector 5 | 6 | 7 | def is_(x): 8 | return lambda self, node, **kw: x 9 | 10 | def prepare_inputs(inputs, n_features=None): 11 | if hasattr(inputs, '__next__'): 12 | # Unroll iterators 13 | inputs = [next(inputs) for i in range(n_features)] 14 | if isinstance(inputs, str): 15 | if not n_features: 16 | raise ValueError("Impossible to determine number of input variables") 17 | return ident(inputs, size=n_features) 18 | elif isinstance(inputs, list): 19 | if n_features is not None and len(inputs) != n_features: 20 | raise ValueError("The number of inputs must match the number of features in the tree") 21 | if isinstance(inputs[0], str): 22 | inputs = [ident(el) for el in inputs] 23 | return vector(inputs) 24 | else: 25 | return inputs 26 | -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/ensemble/forest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decision trees to SKAST 3 | """ 4 | from skompiler.dsl import sum_ 5 | from ..common import classifier 6 | from ..tree.base import decision_tree 7 | 8 | 9 | def random_forest_classifier(model, inputs, method="predict_proba"): 10 | """ 11 | Creates a SKAST expression corresponding to a given random forest classifier 12 | """ 13 | trees = [decision_tree(estimator.tree_, inputs, method="predict_proba", value_transform=lambda v: v/len(model.estimators_)) 14 | for estimator in model.estimators_] 15 | return classifier(sum_(trees), method) 16 | 17 | 18 | def random_forest_regressor(model, inputs): 19 | """ 20 | Creates a SKAST expression corresponding to a given random forest regressor 21 | """ 22 | 23 | return sum_([decision_tree(estimator.tree_, inputs=inputs, method="predict", value_transform=lambda v: v/len(model.estimators_)) 24 | for estimator in model.estimators_]) 25 | -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/neural_network/multilayer_perceptron.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multilayer perceptron 3 | """ 4 | from skompiler.dsl import func, const 5 | from ..common import classifier 6 | 7 | _activations = { 8 | 'identity': lambda x: x, 9 | 'tanh': lambda x: func.Sigmoid(x*2)*2 - 1, 10 | 'logistic': func.Sigmoid, 11 | 'relu': lambda x: func.Max(x, const([0] * len(x))), 12 | 'softmax': func.Softmax 13 | } 14 | 15 | def mlp(model, inputs): 16 | actns = [model.activation]*(len(model.coefs_)-1) + [model.out_activation_] 17 | outs = inputs 18 | for M, b, a in zip(model.coefs_, model.intercepts_, actns): 19 | outs = _activations[a](const(M.T) @ outs + const(b)) 20 | return outs 21 | 22 | def mlp_classifier(model, inputs, method): 23 | out = mlp(model, inputs) 24 | if model.n_outputs_ == 1 and method == 'predict': 25 | # Binary classifier 26 | return func.Step(out - 0.5) 27 | else: 28 | return classifier(out, method) 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Konstantin Tretyakov 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 
a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common useful functionality. 3 | """ 4 | from skompiler.dsl import func, let, defn, ref, repeat 5 | 6 | 7 | def classifier(probs, method): 8 | """Given a probability expression and a method name, returns the classifier output""" 9 | 10 | if method == 'predict_proba': 11 | return probs 12 | elif method == 'predict': 13 | return func.ArgMax(probs) 14 | elif method == 'predict_log_proba': 15 | return func.Log(probs) 16 | else: 17 | raise ValueError("Invalid method: {0}".format(method)) 18 | 19 | def vecsumnormalize(node, vector_dim): 20 | x = node 21 | s = func.VecSum(ref('x', x)) 22 | return let(defn(x=x), 23 | defn(s=s), 24 | ref('x', x) / repeat(ref('s', s), vector_dim)) 25 | 26 | def sklearn_softmax(node, vector_dim): 27 | x = node 28 | xmax = func.VecMax(ref('x', x)) 29 | xfix = ref('x', x) - repeat(ref('xmax', xmax), vector_dim) 30 | return let(defn(x=x), 31 | defn(xmax=xmax), 32 | defn(xfix=xfix), 33 | func.Softmax(ref('xfix', xfix))) 34 | -------------------------------------------------------------------------------- /skompiler/toskast/string.py: -------------------------------------------------------------------------------- 1 | """ 2 | String to SKAST translator. 3 | """ 4 | import ast 5 | from .python import translate as from_python 6 | 7 | 8 | def translate(expr): 9 | """Convert a given (restricted) Python code string to a SK-AST. 10 | 11 | So far we only need this functionality for debugging purposes, 12 | hence instead of implementing a full-fledged parser, we rely on 13 | Python's ast.parse. 14 | 15 | This means that if the expression contains non-Python code 16 | or Python code which we cannot translate, you will get cryptic errors. 17 | 18 | >>> expr = translate("12.4 * (X1[25.3] + Y)") 19 | >>> print(str(expr)) 20 | (12.4 * (X1[25.3] + Y)) 21 | >>> expr = translate("a=X; b=a+2; 12.4 * (a[25.3] + b + y)") 22 | Traceback (most recent call last): 23 | ... 
24 | ValueError: Subscripting named references is not supported 25 | >>> expr = translate("a=X; b=a+2; 12.4 * (a + b + y)") 26 | >>> print(str(expr)) 27 | { 28 | $a = X; 29 | $b = ($a + 2); 30 | (12.4 * (($a + $b) + y)) 31 | } 32 | """ 33 | return from_python(ast.parse(expr)) 34 | -------------------------------------------------------------------------------- /tests/test_ast.py: -------------------------------------------------------------------------------- 1 | #pylint: disable=wildcard-import,unused-wildcard-import,no-member 2 | import numpy as np 3 | from skompiler.ast import * 4 | from skompiler.dsl import * 5 | 6 | 7 | def test_dsl(): 8 | assert isinstance(ident('x'), Identifier) 9 | assert isinstance(ident('x', 1), VectorIdentifier) 10 | assert isinstance(const(1), NumberConstant) 11 | assert isinstance(const([1]), VectorConstant) 12 | assert isinstance(const([[1]]), MatrixConstant) 13 | assert isinstance(const(np.array([1], dtype='int')[0]), NumberConstant) 14 | assert isinstance(const(np.array(1)), NumberConstant) 15 | mtx = const(np.array([[1, 2]])) 16 | assert isinstance(mtx, MatrixConstant) 17 | assert len(mtx) == 1 18 | v = mtx[0] 19 | assert isinstance(v, VectorConstant) 20 | assert len(v) == 2 21 | n = v[1] 22 | assert isinstance(n, NumberConstant) 23 | assert n.value == 2 24 | ids = vector(map(ident, 'abc')) 25 | assert isinstance(ids, MakeVector) 26 | assert len(ids.elems) == 3 27 | assert isinstance(ids.elems[0], Identifier) 28 | 29 | def test_singleton(): 30 | assert Add() is Add() 31 | assert func.Add is Add() 32 | assert func.Mul is Mul() 33 | -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/linear_model/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | SKLearn linear model to SKAST. 3 | """ 4 | from skompiler.dsl import const 5 | 6 | def linear_model(coef, intercept, inputs): 7 | """ 8 | Linear regression. 9 | Depending on the shape of the coef and intercept, produces either a single-valued 10 | linear model (w @ x + b) or a multi-valued one (M @ x + b_vec) 11 | 12 | Args: 13 | 14 | coef (np.array): A vector (1D array, for single-valued model) or a matrix (2D array, for multi-valued one) for the model. 15 | intercept: a number (for single-valued) or a 1D array (for multi-valued regression). 16 | inputs: a list of AST nodes to be used as the input vector to the model or a single node, corresponding to a vector. 
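
        The resulting expression computes w . x + b in the single-valued case
        and M @ x + b (one output per row of M) in the multi-valued case.
        A minimal usage sketch, where the fitted sklearn LinearRegression m
        is an assumption made for illustration:

            from skompiler.dsl import ident
            expr = linear_model(m.coef_, m.intercept_, ident('x', m.coef_.shape[-1]))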
17 | """ 18 | 19 | single_valued = (coef.ndim == 1) 20 | if single_valued and hasattr(intercept, '__iter__'): 21 | raise ValueError("Single-valued linear model must have a single value for the intercept") 22 | elif not single_valued and (coef.ndim != 2 or intercept.ndim != 1): 23 | raise ValueError("Multi-valued linear model must have a 2D coefficient matrix and a 1D intercept vector") 24 | 25 | return const(coef) @ inputs + const(intercept) 26 | -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/preprocessing/data.py: -------------------------------------------------------------------------------- 1 | from skompiler.ast import decompose 2 | from skompiler.dsl import const, vector, iif, let, defn, func, ref 3 | 4 | def binarize(threshold, inputs): 5 | if not isinstance(inputs, list): 6 | inputs = decompose(inputs) 7 | return vector([iif(inp <= const(threshold), const(0), const(1)) for inp in inputs]) 8 | 9 | def scale(scale_, min_, inputs): 10 | return inputs * const(scale_) + const(min_) 11 | 12 | def unscale(scale_, inputs): 13 | return inputs / const(scale_) 14 | 15 | def standard_scaler(model, inputs): 16 | if model.with_mean: 17 | inputs = inputs - const(model.mean_) 18 | if model.with_std: 19 | inputs = inputs / const(model.scale_) 20 | return inputs 21 | 22 | def normalizer(norm, inputs): 23 | if norm == 'l2': 24 | norm = func.Sqrt(func.VecSum(inputs * inputs)) 25 | elif norm == 'l1': 26 | norm = func.VecSum(func.Abs(inputs)) 27 | elif norm == 'max': 28 | norm = func.VecMax(inputs) 29 | else: 30 | raise ValueError("Unknown norm {0}".format(norm)) 31 | norm_fix = iif(ref('norm', norm) == const(0), const(1), ref('norm', norm)) 32 | return let(defn(norm=norm), 33 | defn(norm_fix=norm_fix), 34 | inputs / vector([ref('norm_fix', norm_fix)]*len(inputs))) 35 | -------------------------------------------------------------------------------- /tests/test_logistic.py: -------------------------------------------------------------------------------- 1 | from sklearn.linear_model import LogisticRegression 2 | from skompiler.ast import VectorIdentifier 3 | from skompiler.toskast.sklearn.linear_model.logistic import logreg_binary, logreg_multiclass 4 | from .verification import X, y, y_bin, verify 5 | 6 | 7 | _inputs = VectorIdentifier('x', 4) # Iris table has four input columns 8 | 9 | def test_logreg_binary(): 10 | m = LogisticRegression(solver='lbfgs') 11 | m.fit(X, y_bin) 12 | 13 | for method in ['decision_function', 'predict_proba', 'predict_log_proba', 'predict']: 14 | expr = logreg_binary(m.coef_.ravel(), m.intercept_[0], _inputs, method=method) 15 | verify(m, method, expr, True) 16 | 17 | def test_logreg_multiclass_ovr(): 18 | m = LogisticRegression(solver='lbfgs', multi_class='ovr') 19 | m.fit(X, y) 20 | 21 | for method in ['decision_function', 'predict_proba', 'predict_log_proba', 'predict']: 22 | expr = logreg_multiclass(m.coef_, m.intercept_, method=method, inputs=_inputs, multi_class='ovr') 23 | verify(m, method, expr) 24 | 25 | def test_logreg_multiclass_multinomial(): 26 | m = LogisticRegression(solver='lbfgs', multi_class='multinomial') 27 | m.fit(X, y) 28 | 29 | for method in ['decision_function', 'predict_proba', 'predict_log_proba', 'predict']: 30 | expr = logreg_multiclass(m.coef_, m.intercept_, method=method, inputs=_inputs, multi_class='multinomial') 31 | verify(m, method, expr) 32 | -------------------------------------------------------------------------------- /CHANGELOG.txt: 
-------------------------------------------------------------------------------- 1 | Version 0.7 2 | ----------- 3 | 4 | - Fixed warnings from SKlearn 1.0 (PR#11, PR#14). 5 | - The internal TreeWalker class now tolerates excess input variables. Also from PR#14. 6 | 7 | Version 0.6 8 | ----------- 9 | 10 | - Removed Keras translation (to avoid dependency on Tensorflow) 11 | - Modernized dependencies. SKLearn >= 0.22 required as of this moment. 12 | 13 | Version 0.5-0.5.5 14 | ----------------- 15 | 16 | - Changed skompile call signature 17 | - Fixed Excel code generation 18 | - Changed SQLAlchemy translator 19 | - multistage=True is the default now for SQL and Excel 20 | - Improved code organization 21 | - Nicer DSL for generating SKAST, basic symbolic processing and type checking 22 | - New algorithms: AdaBoostClassifier, KMeans, PCA, MLPClassifier, MLPRegressor, 23 | Normalizer, StandardScaler, MinMaxScaler, MaxAbsScaler, Binarizer 24 | - (0.5.1) Rudimentary support for Keras MLP models. 25 | - (0.5.2) Fixes issue #1 26 | - (0.5.3) Fixes for Python 3.5, removed dependency on staticdispatch. 27 | - (0.5.4) Portable Format for Analytics (PFA) added as a target 28 | - (0.5.5) Translation to SQLAlchemy supports CTE as from_obj 29 | 30 | Version 0.4 31 | ------------- 32 | 33 | - More logical set of SK AST nodes, proper handling of Let nodes 34 | - Support for Pipeline objects 35 | 36 | Version 0.3-0.3.1 37 | ------------- 38 | 39 | - Multi-stage Excel code generation 40 | 41 | Version 0.2 42 | ------------- 43 | 44 | - Multi-stage SQL code generation 45 | 46 | Version 0.1 47 | ------------- 48 | 49 | - First prototype. Supports: 50 | * Inputs: linear models, trees, random forest and gradient boosting (partial) 51 | * Outputs: string, python, excel, sqlalchemy, sympy 52 | 53 | -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=733558 3 | // for the documentation about the tasks.json format 4 | "version": "2.0.0", 5 | "type": "shell", 6 | "isBackground": true, 7 | "options": { 8 | "env": {"PATH": "./venv/Scripts"} 9 | }, 10 | "presentation": { 11 | "panel": "dedicated" 12 | }, 13 | "tasks": [ 14 | { 15 | "label": "jupyter", 16 | "type": "shell", 17 | "command": "jupyter notebook", 18 | "problemMatcher": [] 19 | }, 20 | { 21 | "label": "shell", 22 | "type": "shell", 23 | "problemMatcher": [], 24 | "command": "ipython", 25 | "presentation": { 26 | "reveal": "always", 27 | "panel": "new", 28 | "focus": true 29 | } 30 | }, 31 | { 32 | "label": "py.test: all tests", 33 | "command": "py.test", 34 | "problemMatcher": [], 35 | "presentation": { 36 | "reveal": "always", 37 | "panel": "dedicated" 38 | }, 39 | "group": "test" 40 | }, 41 | { 42 | "label": "py.test: current file", 43 | "command": "py.test", 44 | "args": [ 45 | "${file}" 46 | ], 47 | "problemMatcher": [], 48 | "presentation": { 49 | "reveal": "always", 50 | "panel": "dedicated" 51 | }, 52 | "group": { 53 | "kind": "test", 54 | "isDefault": true 55 | } 56 | }, 57 | { 58 | "label": "Upload to PyPI", 59 | "command": ".\\.vscode\\pypi-upload.cmd", 60 | "problemMatcher": [] 61 | } 62 | ] 63 | } -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/ensemble/gradient_boosting.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decision trees to SKAST 3 | """ 4 | 
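# In rough numpy terms, the decision function assembled below is (an
# illustrative sketch, regressor case):
#
#     raw = model.loss_.get_init_raw_predictions(X, model.init_)[:, 0]
#     for stage in model.estimators_:
#         raw += model.learning_rate * stage[0].predict(X)
#
# i.e. the init estimator's raw prediction plus the learning-rate-scaled sum
# of the per-stage tree predictions. The classifier variant builds one such
# sum per class.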
from skompiler.dsl import vector, const, sum_
5 | from ..tree.base import decision_tree
6 | 
7 | 
8 | def gradient_boosting_classifier(model, inputs, method="decision_function"):
9 |     """
10 |     Creates a SKAST expression corresponding to a given gradient boosting classifier.
11 | 
12 |     At the moment we only support the model's decision_function method.
13 |     FYI: Conversion to probabilities and a prediction depends on the loss and by default
14 |     is done as np.exp(score - logsumexp(score, axis=1)[:, np.newaxis])
15 |     """
16 | 
17 |     if method != "decision_function":
18 |         raise NotImplementedError("Only decision_function is implemented for gradient boosting models so far")
19 | 
20 |     tree_exprs = [vector([decision_tree(estimator.tree_, inputs, method="predict", value_transform=lambda v: v * model.learning_rate)
21 |                           for estimator in iteration])
22 |                   for iteration in model.estimators_]
23 |     # Here we rely on the fact that DummyClassifier.predict() does not really read the input vectors.
24 |     # Consequently model.loss_.get_init_raw_predictions([[]], model.init_) kind-of-works.
25 |     return sum_(tree_exprs + [const(model.loss_.get_init_raw_predictions([[]], model.init_)[0])])
26 | 
27 | def gradient_boosting_regressor(model, inputs, method="decision_function"):
28 |     """
29 |     Creates a SKAST expression corresponding to a given GB regressor.
30 | 
31 |     The logic is mostly the same as for the classifier, except we work with scalars rather than vectors.
32 |     """
33 | 
34 |     if method != "decision_function":
35 |         raise NotImplementedError("Only decision_function is implemented for gradient boosting models so far")
36 | 
37 |     tree_exprs = [decision_tree(iteration[0].tree_, inputs, method="predict", value_transform=lambda v: v * model.learning_rate)
38 |                   for iteration in model.estimators_]
39 |     # See remark above about the hack used here.
40 |     return sum_(tree_exprs + [const(model.loss_.get_init_raw_predictions([[]], model.init_)[0,0])])
41 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | '''
2 | SKompiler: Library for converting trained SKLearn models into abstract expressions suitable
3 | for further compilation into executable code in various languages. 
4 | 5 | Author: Konstantin Tretyakov 6 | License: MIT 7 | ''' 8 | 9 | from setuptools import setup, find_packages 10 | 11 | setup(name='SKompiler', 12 | version=[ln for ln in open("skompiler/__init__.py") if ln.startswith("__version__")][0].split("'")[1], 13 | description="Library for compiling trained SKLearn models into abstract expressions " 14 | "suitable for further compilation into executable code in various languages.", 15 | long_description=open("README.md", encoding='utf-8').read(), 16 | long_description_content_type="text/markdown", 17 | classifiers=[ # Get strings from http://pypi.python.org/pypi?%3Aaction=list_classifiers 18 | 'Development Status :: 3 - Alpha', 19 | 'Programming Language :: Python :: 3.5', 20 | 'Programming Language :: Python :: 3.6', 21 | 'Programming Language :: Python :: 3.7', 22 | 'Programming Language :: Python :: 3.8', 23 | 'Programming Language :: Python :: 3.9', 24 | 'Topic :: Software Development :: Code Generators', 25 | 'License :: OSI Approved :: MIT License', 26 | 'Intended Audience :: Developers', 27 | 'Intended Audience :: Science/Research', 28 | ], 29 | keywords='sklearn datascience modelling deployment', 30 | author='Konstantin Tretyakov', 31 | author_email='konstantin.tretjakov@gmail.com', 32 | url='https://github.com/konstantint/SKompiler', 33 | license='MIT', 34 | packages=find_packages(exclude=["examples", "tests"]), 35 | include_package_data=True, 36 | zip_safe=True, 37 | install_requires=["scikit-learn >= 0.22"], 38 | extras_require={ 39 | "full": ["sympy", "sqlalchemy", "astor >= 0.6"], 40 | "test": ["sympy", "sqlalchemy", "astor >= 0.6", "pytest", "pandas", "titus2"], 41 | "dev": ["sympy", "sqlalchemy", "astor >= 0.6", "pytest", "pandas", 42 | "pylint", "jupyter", "twine", "pyyaml", "titus2"], 43 | } 44 | ) 45 | -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/ensemble/weight_boosting.py: -------------------------------------------------------------------------------- 1 | """ 2 | AdaBoost 3 | """ 4 | import numpy as np 5 | from skompiler.dsl import const, func, sum_ 6 | from ..common import classifier, sklearn_softmax 7 | from ..tree.base import decision_tree 8 | 9 | # NB: AdaboostRegressor is annoying to implement as it requires 10 | # finding a weighted median among the all estimator predictions, 11 | # which is not meaningfully implementable without adding extra special functions 12 | # Thus we only support AdaboostClassifier for now 13 | def adaboost_classifier(model, inputs, method="predict_proba"): 14 | """ 15 | Creates a SKAST expression corresponding to a given adaboost classifier. 
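
    Each base tree contributes either its clipped, row-centered log-probabilities
    (algorithm='SAMME.R') or a one-hot encoding of its predicted class scaled by
    1/(n_classes - 1) (algorithm='SAMME'), weighted by the normalized estimator
    weight; the contributions are summed and, for the predict/predict_proba
    methods, passed through argmax or sklearn's shifted softmax. A plain-numpy
    restatement of the SAMME.R term computed below (an illustrative sketch, not
    used by the code):

        import numpy as np

        def samme_r_contribution(proba, weight):
            # clip, take logs, and center each row of log-probabilities
            proba = np.clip(proba, np.finfo(proba.dtype).eps, None)
            logp = np.log(proba)
            return (logp - logp.mean(axis=1, keepdims=True)) * weight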
16 | """ 17 | divisor = model.estimator_weights_.sum() 18 | if method == 'decision_function': 19 | divisor /= (model.n_classes_ - 1) 20 | tree_exprs = [decision_tree(e.tree_, 21 | method='predict_proba' if model.algorithm == 'SAMME.R' else 'predict', 22 | inputs=inputs, 23 | value_transform=adaboost_values(model, w/divisor, method)) 24 | for e, w in zip(model.estimators_, model.estimator_weights_)] 25 | decision = sum_(tree_exprs) 26 | 27 | if method == 'decision_function': 28 | if model.n_classes_ == 2: 29 | decision = decision @ const([-1, 1]) 30 | return decision 31 | elif method == 'predict': 32 | return func.ArgMax(decision) 33 | else: 34 | return classifier(sklearn_softmax(decision, model.n_classes_), method) 35 | 36 | 37 | def adaboost_values(m, weight=1.0, method='predict_proba'): 38 | def _samme(proba): 39 | proba = np.array(proba) 40 | proba[proba < np.finfo(proba.dtype).eps] = np.finfo(proba.dtype).eps 41 | log_proba = np.log(proba) 42 | return (log_proba - (1. / m.n_classes_) * log_proba.sum(axis=1)[:, np.newaxis])*weight 43 | 44 | def _ada_predict(preds): 45 | probs = np.zeros((len(preds), m.n_classes_)) 46 | probs[np.arange(len(preds)), preds] = 1 47 | probs *= weight/(m.n_classes_-1) 48 | return probs 49 | 50 | if m.algorithm == 'SAMME.R': 51 | return _samme 52 | else: 53 | return _ada_predict 54 | -------------------------------------------------------------------------------- /skompiler/dsl.py: -------------------------------------------------------------------------------- 1 | """ 2 | A set of convenience wrapper functions for AST creation. 3 | """ 4 | #pylint: disable=protected-access 5 | import numpy as np 6 | from . import ast 7 | 8 | def const(value): 9 | ## Convenience function for creating Number, Vector or MatrixConstants. 10 | if np.isscalar(value): 11 | return ast.NumberConstant(value) 12 | elif hasattr(value, 'ndim'): 13 | if value.ndim == 0: 14 | return ast.NumberConstant(value.item()) 15 | if value.ndim == 1: 16 | return ast.VectorConstant(value) 17 | elif value.ndim == 2: 18 | return ast.MatrixConstant(value) 19 | else: 20 | raise ValueError("Only one or two-dimensional vectors are supported") 21 | elif hasattr(value, '__iter__') or hasattr(value, '__next__'): 22 | return const(np.asarray(value)) 23 | else: 24 | raise ValueError("Invalid constant: {0}".format(value)) 25 | 26 | class _FuncCreator: 27 | ## An object similar to sqlalchemy.func 28 | def __getattr__(self, attrname): 29 | if attrname != '__wrapped__': # Otherwise test discovery crashes 30 | return getattr(ast, attrname, None)() 31 | 32 | func = _FuncCreator() 33 | 34 | def vector(elems): 35 | if hasattr(elems, '__next__'): 36 | elems = list(elems) 37 | return ast.MakeVector(elems) 38 | 39 | def ident(name, size=None): 40 | return ast.VectorIdentifier(name, size) if size else ast.Identifier(name) 41 | 42 | def ref(name, to_obj=None): 43 | if to_obj is not None: 44 | try: 45 | size = len(to_obj) 46 | except ast.UnableToDecompose: 47 | size = None 48 | return ast.TypedReference(name, to_obj._dtype, size) 49 | else: 50 | return ast.Reference(name) 51 | 52 | def defn(**kw): 53 | for name, value in kw.items(): 54 | return ast.Definition(name, value) 55 | 56 | def let(*steps): 57 | return ast.Let(list(steps[:-1]), steps[-1]) 58 | 59 | def iif(test, iftrue, iffalse): 60 | return ast.IfThenElse(test, iftrue, iffalse) 61 | 62 | def repeat(node, n_times): 63 | return ast.MakeVector([node]*n_times) 64 | 65 | def mean(elems, vector_dim=None): 66 | divisor = len(elems) 67 | if vector_dim is not None: 68 | 
divisor = [divisor] * vector_dim
69 |     return sum_(elems) / const(divisor)
70 | 
71 | def sum_(elems):
72 |     return ast.LFold(func.Add, elems)
73 | 
--------------------------------------------------------------------------------
/skompiler/toskast/sklearn/linear_model/logistic.py:
--------------------------------------------------------------------------------
1 | """
2 | SKLearn logistic regression to SKAST.
3 | """
4 | from skompiler.dsl import func
5 | from .base import linear_model
6 | from ..common import classifier, vecsumnormalize, sklearn_softmax
7 | 
8 | 
9 | def logreg_binary(coef, intercept, inputs, method="predict_proba"):
10 |     """
11 |     Binary logistic regression.
12 | 
13 |     Args:
14 | 
15 |         inputs: a list of AST nodes to be used as inputs to the model.
16 | 
17 |     Kwargs:
18 | 
19 |         method (string): The sklearn method's output to emulate.
20 |             'decision_function' - The logistic regression decision function only.
21 |             'predict_proba' - The output will be the probability of class 1
22 |                 (note that it is NOT two probabilities, as in case of SKLearn's
23 |                 actual predict_proba output)
24 |             'predict_log_proba' - The log probability of class 1
25 |             'predict' - The output will be an integer 0/1, predicting the class.
26 |     """
27 |     decision = linear_model(coef, intercept, inputs)
28 | 
29 |     if method == "decision_function":
30 |         return decision
31 |     if method == "predict":
32 |         return func.Step(decision)
33 | 
34 |     return classifier(func.Sigmoid(decision), method)
35 | 
36 | 
37 | def logreg_multiclass(coef_matrix, intercept_vector, inputs='x', method="predict_proba", multi_class='ovr'):
38 |     """
39 |     Multiclass logistic regression.
40 | 
41 |     Kwargs:
42 | 
43 |         method (string): The sklearn method's output to emulate.
44 |             'decision_function' - The logistic regression decision function only.
45 |             'predict_proba' - The output will be the vector of class
46 |                 probabilities (one probability per class, as in SKLearn's
47 |                 actual predict_proba output)
48 |             'predict_log_proba' - The vector of class log-probabilities.
49 |             'predict' - The output will be an integer: the index of the predicted class.
50 | 
51 |         inputs (string or ASTNode): The name of the inputs variable to use in the formula.
52 |             Any SKAST expression could be used instead, assuming it encodes an
53 |             input vector.
54 | 
55 |         multi_class (string): The value of the "multi_class" setting used in the model. Either 'ovr' or 'multinomial'.
56 |     """
57 |     decision = linear_model(coef_matrix, intercept_vector, inputs)
58 | 
59 |     if method == "decision_function":
60 |         return decision
61 |     if method == "predict":
62 |         return func.ArgMax(decision)
63 | 
64 |     if multi_class == 'ovr':
65 |         probs = vecsumnormalize(func.Sigmoid(decision), coef_matrix.shape[0])
66 |     elif multi_class == 'multinomial':
67 |         probs = sklearn_softmax(decision, coef_matrix.shape[0])
68 |     else:
69 |         raise ValueError("Invalid value of the multi_class argument: " + multi_class)
70 | 
71 |     return classifier(probs, method)
72 | 
--------------------------------------------------------------------------------
/tests/verification.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility for basic testing of models.
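
The evaluators compile one SKAST expression per target language (python,
sympy, sqlite, excel, pfa) and run it on the Iris data; verify() checks
that every target reproduces the model's own output to within 1e-10. A
typical call, mirroring how the tests use it (model is assumed to be a
fitted classifier):

    expr = translate(model, inputs='x', method='predict_proba')
    verify(model, 'predict_proba', expr)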
3 | """ 4 | import warnings 5 | import numpy as np 6 | from sklearn.datasets import load_iris 7 | from .evaluators import PythonEval, SympyEval, SQLiteEval, ExcelEval, PFAEval 8 | 9 | # Set up evaluators 10 | X, y = load_iris(return_X_y=True) 11 | y_bin = np.array(y) 12 | y_bin[y_bin == 2] = 0 13 | 14 | _evaluators = { 15 | 'python': PythonEval(X), 16 | 'sympy': SympyEval(X, true_argmax=False), 17 | 'sympy2': SympyEval(X, true_argmax=True), 18 | 'sqlite': SQLiteEval(X, False), 19 | 'sqlite2': SQLiteEval(X, True), 20 | 'excel': ExcelEval(X), 21 | 'pfa': PFAEval(X), 22 | } 23 | 24 | def verify_one(model, method, evaluator, expr, binary_fix=False, inf_fix=False, data_preprocessing=None): 25 | X_inputs = X 26 | if data_preprocessing: 27 | X_inputs = data_preprocessing(X_inputs) 28 | true_Y = getattr(model, method)(X_inputs) 29 | if binary_fix and true_Y.ndim > 1 and true_Y.shape[1] > 1: 30 | true_Y = true_Y[:, 1] 31 | if true_Y.ndim > 1 and true_Y.shape[1] == 1: 32 | true_Y = true_Y[:, 0] 33 | pred_Y = _evaluators[evaluator](expr) 34 | if inf_fix: 35 | # Our custom SQL log function returns -FLOAT_MIN instead of -inf for log(0) 36 | pred_Y[pred_Y == np.finfo('float64').min] = -float('inf') 37 | assert (np.isinf(true_Y) == np.isinf(pred_Y)).all() 38 | assert np.abs(true_Y[~np.isinf(true_Y)] - pred_Y[~np.isinf(pred_Y)]).max() < 1e-10 39 | else: 40 | assert np.abs(pred_Y - true_Y).max() < 1e-10 41 | 42 | def verify(model, method, expr, binary_fix=False, inf_fix=False): 43 | with warnings.catch_warnings(): 44 | if method == 'predict_log_proba': # Ignore divide by zeroes encountered in log(0) 45 | warnings.simplefilter('ignore', RuntimeWarning) 46 | 47 | verify_one(model, method, 'excel', expr, binary_fix, inf_fix) 48 | verify_one(model, method, 'python', expr, binary_fix, inf_fix) 49 | verify_one(model, method, 'sympy', expr, binary_fix, inf_fix) 50 | if not binary_fix and method == 'predict' and hasattr(model, 'decision_function'): 51 | verify_one(model, method, 'sympy2', expr, binary_fix, inf_fix) # See that Sympy supports true_argmax correctly 52 | verify_one(model, method, 'sqlite', expr, binary_fix, inf_fix) 53 | verify_one(model, method, 'sqlite2', expr, binary_fix, inf_fix) 54 | 55 | warnings.simplefilter('ignore', PendingDeprecationWarning) # Those two come from the PFA evaluator 56 | warnings.simplefilter('ignore', DeprecationWarning) 57 | try: 58 | verify_one(model, method, 'pfa', expr, binary_fix, inf_fix) 59 | except NameError as e: 60 | if str(e) == "name 'inf' is not defined": 61 | # This happens because I do not know how to properly encode inf/-inf in 62 | # the PFA output. Ignore this so far. 63 | pass 64 | else: 65 | raise 66 | -------------------------------------------------------------------------------- /tests/test_skompile.py: -------------------------------------------------------------------------------- 1 | """ 2 | Smoke test for all supported models and translation types. 3 | """ 4 | import warnings 5 | from sklearn.tree import DecisionTreeRegressor 6 | from skompiler.ast import BinOp, Mul, NumberConstant, IndexedIdentifier 7 | from skompiler.dsl import ident 8 | from skompiler.toskast.sklearn import _supported_methods 9 | from skompiler import skompile 10 | from .verification import verify_one 11 | 12 | # Sympy targets not included because these take a long time to evaluate (and fail on some models, e.g. 
'sympy/c', 'sympy/js') 13 | _translate_targets = ['string', 'python/code', 'pfa/json'] 14 | _eval_targets = ['python', 'excel', 'sqlite2', 'sympy', 'pfa'] 15 | 16 | def list_supported_methods(model): 17 | if isinstance(model, DecisionTreeRegressor): 18 | return ['predict'] 19 | for cls, methods in _supported_methods.items(): 20 | if isinstance(model, cls): 21 | return methods 22 | raise ValueError("Unsupported model: {0}".format(model)) 23 | 24 | #pylint: disable=unsupported-membership-test 25 | _limit = None 26 | 27 | def test_skompile(models): 28 | # TODO: If we use NumberConstant(2), we get a failed test for RandomForestRegressor. 29 | # Could it be due to float precision issues? 30 | transformed_features = [BinOp(Mul(), IndexedIdentifier('x', i, 4), NumberConstant(2.1)) for i in range(4)] 31 | 32 | with warnings.catch_warnings(): 33 | warnings.simplefilter('ignore', RuntimeWarning) # Ignore divide by zero warning for log(0) 34 | warnings.simplefilter('ignore', PendingDeprecationWarning) # Those two come from the PFA evaluator 35 | warnings.simplefilter('ignore', DeprecationWarning) 36 | for name, model in models.items(): 37 | if _limit and name not in _limit: 38 | continue 39 | methods = list_supported_methods(model) 40 | for method in methods: 41 | if name in ['bin', 'n1', 'n2', 'n3']: # Binarizer and Normalizer want to know number of features 42 | expr = skompile(getattr(model, method), inputs=ident('x', 4)) 43 | else: 44 | expr = skompile(getattr(model, method)) 45 | 46 | print(name, model, method) 47 | for evaluator in _eval_targets: 48 | print(evaluator) 49 | try: 50 | verify_one(model, method, evaluator, expr, 51 | binary_fix=name.endswith('_bin'), inf_fix=(method == 'predict_log_proba')) 52 | except NameError as e: 53 | if evaluator == 'pfa' and str(e) == "name 'inf' is not defined": 54 | # This happens because I do not know how to properly encode inf/-inf in 55 | # the PFA output. Ignore this so far. 56 | pass 57 | else: 58 | raise 59 | for target in _translate_targets: 60 | expr.to(target) 61 | 62 | # Check that everything will work if we provide expressions instead of raw features 63 | expr = skompile(getattr(model, method), transformed_features) 64 | verify_one(model, method, 'python', expr, 65 | binary_fix=name.endswith('_bin'), inf_fix=(method == 'predict_log_proba'), 66 | data_preprocessing=lambda X: X*2.1) 67 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test fixtures. 
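
The session-scoped fixtures provide the Iris data (X, y), binarized labels
(y_bin), a one-hot encoding (y_ohe) and a models dict of 35 estimators keyed
by short names, most of them fitted on Iris. A test requests fixtures by
name; an illustrative sketch (the test name is made up):

    def test_ols_shape(models, X):
        assert models['ols'].predict(X).shape == (len(X),)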
3 | """ 4 | #pylint: disable=possibly-unused-variable,redefined-outer-name 5 | import os 6 | import numpy as np 7 | import pandas as pd 8 | from pytest import fixture 9 | from sklearn.datasets import load_iris 10 | from sklearn.linear_model import LogisticRegression, LinearRegression 11 | from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor 12 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor 13 | from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor 14 | from sklearn.ensemble import AdaBoostClassifier 15 | from sklearn.cluster import KMeans 16 | from sklearn.decomposition import PCA 17 | from sklearn.neural_network import MLPRegressor, MLPClassifier 18 | from sklearn.svm import SVR, SVC 19 | from sklearn.preprocessing import Binarizer, MinMaxScaler, MaxAbsScaler, StandardScaler, Normalizer 20 | 21 | 22 | @fixture(scope='session') 23 | def iris(): 24 | return load_iris(return_X_y=True) 25 | 26 | @fixture(scope='session') 27 | def X(iris): 28 | return iris[0] 29 | 30 | @fixture(scope='session') 31 | def y(iris): 32 | return iris[1] 33 | 34 | @fixture(scope='session') 35 | def y_bin(y): 36 | y_bin = np.array(y) 37 | y_bin[y_bin == 2] = 0 38 | return y_bin 39 | 40 | @fixture(scope='session') 41 | def y_ohe(y): 42 | return pd.get_dummies(y) 43 | 44 | def make_models(X, y, y_bin): 45 | return dict( 46 | ols=LinearRegression().fit(X, y), 47 | lr_bin=LogisticRegression().fit(X, y_bin), 48 | lr_ovr=LogisticRegression(multi_class='ovr').fit(X, y), 49 | lr_mn=LogisticRegression(solver='lbfgs', multi_class='multinomial').fit(X, y), 50 | svc=SVC(kernel='linear').fit(X, y_bin), 51 | svr=SVR(kernel='linear').fit(X, y), 52 | dtc=DecisionTreeClassifier(max_depth=4).fit(X, y), 53 | dtr=DecisionTreeRegressor(max_depth=4).fit(X, y), 54 | rfc=RandomForestClassifier(n_estimators=3, max_depth=3, random_state=1).fit(X, y), 55 | rfr=RandomForestRegressor(n_estimators=3, max_depth=3, random_state=1).fit(X, y), 56 | gbc=GradientBoostingClassifier(n_estimators=3, max_depth=3, random_state=1).fit(X, y), 57 | gbr=GradientBoostingRegressor(n_estimators=3, max_depth=3, random_state=1).fit(X, y), 58 | abc=AdaBoostClassifier(algorithm='SAMME', n_estimators=3, random_state=1).fit(X, y), 59 | abc2=AdaBoostClassifier(algorithm='SAMME.R', n_estimators=3, random_state=1).fit(X, y), 60 | abc3=AdaBoostClassifier(algorithm='SAMME', n_estimators=3, random_state=1).fit(X, y_bin), 61 | abc4=AdaBoostClassifier(algorithm='SAMME.R', n_estimators=3, random_state=1).fit(X, y_bin), 62 | km=KMeans(1).fit(X), 63 | km2=KMeans(5).fit(X), 64 | pc1=PCA(1).fit(X), 65 | pc2=PCA(2).fit(X), 66 | pc3=PCA(2, whiten=True).fit(X), 67 | mlr1=MLPRegressor([2], 'relu').fit(X, y), 68 | mlr2=MLPRegressor([2, 1], 'tanh').fit(X, y), 69 | mlr3=MLPRegressor([2, 2, 2], 'identity').fit(X, y), 70 | mlc=MLPClassifier([2, 2], 'tanh').fit(X, y), 71 | mlc_bin=MLPClassifier([2, 2], 'identity').fit(X, y_bin), 72 | bin=Binarizer(threshold=0.5), 73 | mms=MinMaxScaler().fit(X), 74 | mas=MaxAbsScaler().fit(X), 75 | ss1=StandardScaler().fit(X), 76 | ss2=StandardScaler(with_mean=False).fit(X), 77 | ss3=StandardScaler(with_std=False).fit(X), 78 | n1=Normalizer(norm='l1'), 79 | n2=Normalizer(norm='l2'), 80 | n3=Normalizer(norm='max') 81 | ) 82 | 83 | 84 | @fixture(scope='session') 85 | def models(X, y, y_bin): 86 | return make_models(X, y, y_bin) 87 | -------------------------------------------------------------------------------- /tests/evaluators.py: 
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | import pandas as pd
4 | import numpy as np
5 | import sqlalchemy as sa
6 | 
7 | from skompiler.fromskast.sqlalchemy import translate as to_sql
8 | import skompiler.fromskast.sympy as to_sympy
9 | 
10 | def _sql_log(x):
11 |     if x <= 0:
12 |         return np.finfo('float64').min
13 |     else:
14 |         return np.log(x)
15 | 
16 | class SQLiteEval:
17 |     def __init__(self, X, multistage):
18 |         self.engine = sa.create_engine("sqlite://")
19 |         self.conn = self.engine.connect()
20 |         self.conn.connection.create_function('log', 1, _sql_log)
21 |         self.conn.connection.create_function('exp', 1, math.exp)
22 |         self.conn.connection.create_function('sqrt', 1, math.sqrt)
23 |         df = pd.DataFrame(X, columns=['x{0}'.format(i+1) for i in range(X.shape[1])]).reset_index()
24 |         df.to_sql('data', self.conn)
25 |         self.multistage = multistage
26 | 
27 |     def __call__(self, expr):
28 |         query = to_sql(expr, 'sqlite', multistage=self.multistage, key_column='index')
29 |         with warnings.catch_warnings():
30 |             warnings.simplefilter('ignore', RuntimeWarning)  # divide by zero encountered in log
31 |             result = pd.read_sql(query, self.conn).values
32 |         if result.shape[1] == 1:
33 |             result = result[:, 0]
34 |         return result
35 | 
36 |     def __del__(self):
37 |         # self.conn.close()  # <-- This raises an exception for some reason
38 |         pass
39 | 
40 | class PythonEval:
41 |     def __init__(self, X):
42 |         self.X = X
43 | 
44 |     def __call__(self, expr):
45 |         fn = expr.lambdify()
46 |         result = np.asarray([fn(x=x) for x in self.X])
47 |         if result.ndim > 1 and result.shape[-1] == 1:
48 |             result = result[..., 0]
49 |         return result
50 | 
51 | class SympyEval:
52 |     def __init__(self, X, true_argmax=False):
53 |         self.X = X
54 |         self.true_argmax = true_argmax
55 | 
56 |     def __call__(self, expr):
57 |         fn = to_sympy.lambdify('x', to_sympy.translate(expr, true_argmax=self.true_argmax))
58 |         pred_Y = np.asarray([np.array([fn(x.reshape(-1, 1))]).ravel() for x in self.X])
59 |         if pred_Y.shape[-1] == 1:
60 |             pred_Y = pred_Y[..., 0]
61 |         return pred_Y
62 | 
63 | class ExcelEval:
64 |     def __init__(self, X):
65 |         self.X = X
66 | 
67 |     def _eval(self, code, n_outputs, x):
68 |         inputs = {'x{0}'.format(i+1): x_i for i, x_i in enumerate(x)}
69 |         result = code.evaluate(**inputs)
70 |         keys = list(result.keys())[-n_outputs:]
71 |         return np.asarray([result[k] for k in keys])
72 | 
73 |     def __call__(self, expr):
74 |         code = expr.to('excel', multistage=True, _max_subexpression_length=500)
75 |         # We don't know how many outputs the expression should produce just from
76 |         # the Excel result, so we use a hackish way to determine it via a separate evaluator
77 |         res = expr.lambdify()(x=self.X[0])
78 |         shape = getattr(res, 'shape', None)
79 |         n_outputs = 1 if not shape else shape[0]
80 |         result = np.asarray([self._eval(code, n_outputs, x) for x in self.X])
81 |         if result.shape[-1] == 1:
82 |             result = result[..., 0]
83 |         return result
84 | 
85 | class PFAEval:
86 |     def __init__(self, X):
87 |         self.X = X
88 | 
89 |     def __call__(self, expr):
90 |         from titus.prettypfa import PFAEngine
91 |         engine, = PFAEngine.fromJson(expr.to('pfa/json'))
92 |         result = np.asarray([engine.action({'x': x}) for x in self.X])
93 |         if result.ndim > 1 and result.shape[-1] == 1:
94 |             result = result[..., 0]
95 |         return result
96 | 
--------------------------------------------------------------------------------
/skompiler/api.py:
--------------------------------------------------------------------------------
1 | """
2 
| A convenience interface to SKompiler's functionality. 3 | Wraps around the intricacies of the various toskast/fromskast pieces. 4 | """ 5 | 6 | 7 | def skompile(*args, inputs=None): 8 | """ 9 | Creates a SKAST expression from a given bound method of a fitted SKLearn model. 10 | A shorthand notation for SKompiledModel(method, inputs) 11 | 12 | Args: 13 | 14 | args: Either a bound method of a trained model (e.g. skompile(model.predict_proba)), 15 | OR two arguments - a model and a method name (e.g. skompile(model, 'predict_proba') 16 | (which may be necessary for some models where the first option cannot be used due to metaclasses) 17 | 18 | inputs: A string or a list of strings, or a SKAST node or a list of SKAST nodes, 19 | denoting the input variable(s) to your model. 20 | A single string corresponds to a vector variable (which will be indexed to access 21 | the components). A list of strings corresponds to a vector with separately named components. 22 | You may pass the inputs as a non-keyword argument as well (the last one in *args) 23 | If not specified, the default value of 'x' is used. 24 | 25 | Returns: 26 | An instance of SKompiledModel. 27 | 28 | Examples: 29 | >>> from sklearn.datasets import load_iris 30 | >>> from sklearn.linear_model import LogisticRegression 31 | >>> X, y = load_iris(return_X_y=True) 32 | >>> m = LogisticRegression().fit(X, y) 33 | >>> print(skompile(m.predict)) 34 | argmax((([[...]] m@v x) + [ ...])) 35 | >>> print(skompile(m, 'predict')) 36 | argmax((([[...]] m@v x) + [ ...])) 37 | >>> print(skompile(m.predict, 'y')) 38 | argmax((([[...]] m@v y) + [ ...])) 39 | >>> print(skompile(m.predict, ['x','y','z','w'])) 40 | argmax((([[...]] m@v [x, y, z, w]) + [ ...])) 41 | >>> print(skompile(m, 'predict', 'y')) 42 | argmax((([[...]] m@v y) + [ ...])) 43 | >>> print(skompile(m, 'predict', ['x','y','z','w'])) 44 | argmax((([[...]] m@v [x, y, z, w]) + [ ...])) 45 | >>> from skompiler.ast import VectorIdentifier, Identifier 46 | >>> print(skompile(m, 'predict', VectorIdentifier('y', 4))) 47 | argmax((([[...]] m@v y) + [ ...])) 48 | >>> print(skompile(m, 'predict', map(Identifier, ['x','y','z','w']))) 49 | argmax((([[...]] m@v [x, y, z, w]) + [ ...])) 50 | >>> from sklearn.pipeline import Pipeline 51 | >>> p = Pipeline([('1', m)]) 52 | >>> skompile(p.predict) 53 | Traceback (most recent call last): 54 | ... 55 | ValueError: The bound method ... Please, use the skompile(m, 'predict') syntax instead. 
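
    The returned expression can be evaluated directly (expr.evaluate(x=...),
    expr.lambdify()) or translated via expr.to(target) to targets such as
    'string', 'python/code', 'excel', 'sympy', 'pfa/json' and
    'sqlalchemy/sqlite'. An illustrative sketch (the key column and table
    names are made up):

        expr = skompile(m.predict)
        sql = expr.to('sqlalchemy/sqlite', key_column='id', from_obj='data')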
56 | """ 57 | 58 | if len(args) > 3: 59 | raise ValueError("Too many arguments") 60 | elif not args: 61 | raise ValueError("Invalid arguments") 62 | elif len(args) == 3: 63 | if inputs is not None: 64 | raise ValueError("Too many arguments") 65 | model, method, inputs = args 66 | elif len(args) == 2: 67 | if hasattr(args[0], '__call__'): 68 | model, method = _get_model_and_method(args[0]) 69 | inputs = args[1] 70 | else: 71 | model, method = args 72 | else: 73 | model, method = _get_model_and_method(args[0]) 74 | if not inputs: 75 | inputs = 'x' 76 | return _translate(model, inputs, method) 77 | 78 | def _translate(model, inputs, method): 79 | if model.__class__.__module__.startswith('keras.'): 80 | if method != 'predict': 81 | raise ValueError("Only the 'predict' method is supported for Keras models") 82 | # Import here, this way we do not force everyone to install everything 83 | from .toskast.keras import translate as from_keras 84 | return from_keras(model, inputs) 85 | else: 86 | from .toskast.sklearn import translate as from_sklearn 87 | return from_sklearn(model, inputs=inputs, method=method) 88 | 89 | def _get_model_and_method(obj): 90 | if not hasattr(obj, '__call__'): 91 | raise ValueError("Please, provide a method to compile.") 92 | if not hasattr(obj, '__self__'): 93 | raise ValueError("The bound method object was probably mangled by " 94 | "SKLearn's metaclasses and cannot be passed to skompile as skompile(m.predict). " 95 | "Please, use the skompile(m, 'predict') syntax instead.") 96 | return obj.__self__, obj.__name__ 97 | -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/tree/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decision trees to SKAST 3 | """ 4 | import numpy as np 5 | from skompiler.ast import decompose 6 | from skompiler.dsl import const, iif 7 | 8 | 9 | def decision_tree(tree, inputs, method="predict", value_transform=None): 10 | """ 11 | Creates a SKAST expression corresponding to a given SKLearn Tree object. 12 | 13 | Kwargs: 14 | 15 | inputs: a list of AST nodes to be used as inputs to the model. 16 | method: 'predict' (for classifier and regressor models), 17 | 'predict_proba' or 'predict_log_proba' (for classifier models) 18 | 19 | value_transform: If not None, the tree values are processed using the given operator. 20 | This way we may propagate constant operations into trees 21 | (e.g. instead of const * decision_tree(...) you may just have a 22 | decision tree which outputs const * value) 23 | 24 | >>> from sklearn.datasets import load_iris 25 | >>> from sklearn.tree import DecisionTreeClassifier 26 | >>> from skompiler.dsl import ident, vector 27 | >>> m = DecisionTreeClassifier(max_depth=2, random_state=1).fit(*load_iris(return_X_y=True)) 28 | >>> print(decision_tree(m.tree_, ident('x', m.n_features_in_))) 29 | (if (x[3] <= 0.80...) then 0 else (if (x[3] <= 1.75) then 1 else 2)) 30 | 31 | >>> inputs = vector(map(ident, 'abcd')) 32 | >>> print(decision_tree(m.tree_, inputs, method='predict_proba')) 33 | (if (d <= 0.80...) then [1. 0. 0.] else (if (d <= 1.75) then [0... 0.90... 0.09...] else [0... 0.02... 
0.97...]))
34 |     """
35 |     v = tree.value[:, 0, :]
36 |     if v.shape[1] == 1:
37 |         # Regression model
38 |         if method != 'predict':
39 |             raise ValueError("Only the 'predict' method is supported for regression trees")
40 |         v = v[:, 0]
41 |     else:
42 |         # Classifier
43 |         if method == "predict":
44 |             v = np.argmax(v, axis=1)
45 |         else:
46 |             v = v / v.sum(axis=1)[:, np.newaxis]
47 |             if method == "predict_log_proba":
48 |                 v = np.log(v)
49 |             elif method != "predict_proba":
50 |                 raise ValueError("Invalid method: {0}".format(method))
51 |
52 |     if value_transform:
53 |         v = value_transform(v)
54 |     return TreeWalker(tree, inputs, v).walk()
55 |
56 |
57 | class TreeWalker:
58 |     """Converts a SKLearn Tree object to a SKAST expression.
59 |
60 |     >>> from sklearn.datasets import load_iris
61 |     >>> from sklearn.tree import DecisionTreeRegressor
62 |     >>> from skompiler.dsl import ident
63 |     >>> m = DecisionTreeRegressor(max_depth=2, random_state=1).fit(*load_iris(return_X_y=True))
64 |     >>> tr = TreeWalker(m.tree_, ident('x', 4))
65 |     >>> print(tr.walk())
66 |     (if (x[3] <= 0.80...) then 0.0 else (if (x[3] <= 1.75) then 1.09... else 1.97...))
67 |     >>> tr = TreeWalker(m.tree_, ident('x', 4), np.arange(m.tree_.node_count))
68 |     >>> print(tr.walk())
69 |     (if (x[3] <= 0.80...) then 1 else (if (x[3] <= 1.75) then 3 else 4))
70 |     """
71 |
72 |     def __init__(self, tree, features, node_values=None):
73 |         """
74 |         Kwargs:
75 |             node_values (list/array): A way to override the tree.value array.
76 |                 Must be a 1D or 2D array of values.
77 |         """
78 |         if not isinstance(features, list):
79 |             features = decompose(features)
80 |
81 |         self.tree = tree
82 |         self.features = features
83 |         if len(self.features) < tree.n_features:
84 |             raise ValueError(f"Incorrect number of features provided. Expected {tree.n_features}, but got {len(self.features)}")
85 |         if node_values is None:
86 |             self.values = tree.value[:, 0]
87 |             if self.values.shape[1] == 1:
88 |                 self.values = self.values[:, 0]
89 |         else:
90 |             self.values = node_values
91 |
92 |     def walk(self, node_id=0):
93 |         if node_id >= self.tree.node_count or node_id < 0:
94 |             raise ValueError("Invalid node id")
95 |         if self.tree.children_left[node_id] == -1:
96 |             if self.tree.children_right[node_id] != -1:
97 |                 raise ValueError("Invalid tree structure. Children must be either both present or both absent")
98 |
99 |             if self.values.ndim == 1:
100 |                 return const(self.values[node_id].item())
101 |             else:
102 |                 return const(self.values[node_id])
103 |         else:
104 |             ft = self.tree.feature[node_id]
105 |             if ft < 0 or ft >= len(self.features):
106 |                 raise ValueError("Invalid feature value for node {0}".format(node_id))
107 |             return iif(self.features[ft] <= const(self.tree.threshold[node_id].item()),
108 |                        self.walk(self.tree.children_left[node_id]),
109 |                        self.walk(self.tree.children_right[node_id]))
110 |
-------------------------------------------------------------------------------- /tests/test_sklearn.py: --------------------------------------------------------------------------------
1 | """
2 | Test the generic sklearn translate function on supported models.
3 | """ 4 | 5 | import warnings 6 | import numpy as np 7 | from sklearn.linear_model import LogisticRegression, LogisticRegressionCV, \ 8 | LinearRegression, Ridge, Lars, LarsCV, ElasticNet, ElasticNetCV, \ 9 | Lasso, LassoCV, LassoLars, LassoLarsCV, LassoLarsIC 10 | from sklearn.svm import SVC, SVR 11 | from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor 12 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, \ 13 | GradientBoostingClassifier, GradientBoostingRegressor 14 | from sklearn.preprocessing import Binarizer 15 | from sklearn.pipeline import Pipeline 16 | from skompiler.toskast.sklearn import translate 17 | import skompiler.fromskast.sympy as to_sympy 18 | from skompiler import skompile 19 | from skompiler.ast import VectorIdentifier 20 | from .verification import X, y, y_bin, verify as _verify 21 | from .evaluators import SQLiteEval 22 | 23 | 24 | def verify(model, methods=None, binary_fix=False, inputs='x'): 25 | methods = methods or ['decision_function', 'predict_proba', 'predict_log_proba', 'predict'] 26 | for method in methods: 27 | expr = translate(model, inputs=inputs, method=method) 28 | _verify(model, method, expr, binary_fix=binary_fix, inf_fix=method == 'predict_log_proba') 29 | 30 | def test_logreg(): 31 | m = LogisticRegression(solver='lbfgs', multi_class='ovr') 32 | verify(m.fit(X, y_bin), binary_fix=True) 33 | verify(m.fit(X, y)) 34 | verify(m.set_params(multi_class='multinomial').fit(X, y)) 35 | 36 | def test_logregcv(): 37 | m = LogisticRegressionCV(solver='lbfgs', multi_class='ovr') 38 | verify(m.fit(X, y_bin), binary_fix=True) 39 | verify(m.fit(X, y)) 40 | verify(m.set_params(multi_class='multinomial').fit(X, y)) 41 | 42 | def test_linearsvc(): 43 | m = SVC(kernel='linear') 44 | verify(m.fit(X, y_bin), ['decision_function', 'predict'], True) 45 | # Non-binary SVM not implemented so far 46 | # verify(m.fit(X, y), methods=['decision_function', 'predict']) 47 | 48 | def test_linreg(): 49 | for m in [LinearRegression(), Ridge(), Lars(), LarsCV(), ElasticNet(), ElasticNetCV(), \ 50 | Lasso(), LassoCV(), LassoLars(), LassoLarsCV(), LassoLarsIC(), SVR(kernel='linear')]: 51 | verify(m.fit(X, y), ['predict']) 52 | if not isinstance(m, SVR): 53 | verify(m.set_params(fit_intercept=False).fit(X, y), ['predict']) 54 | 55 | def test_tree(): 56 | for m_class in [DecisionTreeClassifier, DecisionTreeRegressor]: 57 | m = m_class(max_depth=3).fit(X, y) 58 | verify(m, ['predict']) 59 | with warnings.catch_warnings(): 60 | warnings.simplefilter('ignore', RuntimeWarning) 61 | if m_class == DecisionTreeClassifier: 62 | verify(m, ['predict_proba', 'predict_log_proba']) 63 | 64 | def test_rf(): 65 | for m_class in [RandomForestClassifier, RandomForestRegressor]: 66 | m = m_class(random_state=1, max_depth=3, n_estimators=3).fit(X, y) 67 | verify(m, ['predict']) 68 | if m_class == RandomForestClassifier: 69 | verify(m, ['predict_proba', 'predict_log_proba']) 70 | 71 | def test_gb(): 72 | for m_class in [GradientBoostingClassifier, GradientBoostingRegressor]: 73 | m = m_class(random_state=1, max_depth=3, n_estimators=3).fit(X, y) 74 | if m_class == GradientBoostingRegressor: 75 | verify(m, ['predict']) 76 | else: 77 | verify(m, ['decision_function']) 78 | 79 | def test_columnlist(): 80 | m = LinearRegression() 81 | m.fit(X, y) 82 | true_Y = m.predict(X) 83 | inputs = ['x{0}'.format(i+1) for i in range(X.shape[1])] 84 | expr = translate(m, inputs=inputs) 85 | ev = SQLiteEval(X, True) 86 | assert np.abs(ev(expr) - true_Y).max() < 1e-10 87 | 88 | 
fn = to_sympy.lambdify(' '.join(inputs), to_sympy.translate(expr))
89 |     pred_Y = np.asarray([fn(*x) for x in X])
90 |     assert np.abs(pred_Y - true_Y).max() < 1e-10
91 |
92 | def test_binarizer():
93 |     b = Binarizer(threshold=np.mean(X))
94 |     inputs = ['x{0}'.format(i+1) for i in range(X.shape[1])]
95 |     expr = skompile(b.transform, inputs)
96 |     assert np.all(b.transform(X) == np.asarray([expr.evaluate(x1=x[0], x2=x[1], x3=x[2], x4=x[3]) for x in X]))
97 |
98 |
99 | def make_pipeline(*args):
100 |     return Pipeline([(str(i), a) for i, a in enumerate(args)]).fit(X, y)
101 |
102 | def test_pipeline():
103 |     b1 = Binarizer(threshold=np.mean(X))
104 |     b2 = Binarizer(threshold=0.5)
105 |     m = RandomForestClassifier(10, max_depth=7, random_state=1)
106 |     inp = VectorIdentifier('x', 4)
107 |     verify(make_pipeline(b1, b2, m), ['predict', 'predict_proba', 'predict_log_proba'], inputs=inp)
108 |     verify(make_pipeline(b1, m), ['predict', 'predict_proba', 'predict_log_proba'], inputs=inp)
109 |     verify(make_pipeline(m), ['predict', 'predict_proba', 'predict_log_proba'], inputs=inp)
110 |     verify(make_pipeline(b1), ['transform'], inputs=inp)
111 |     verify(make_pipeline(b1, b2), ['transform'], inputs=inp)
112 |
-------------------------------------------------------------------------------- /skompiler/toskast/python.py: --------------------------------------------------------------------------------
1 | '''
2 | Python AST to SKAST translator.
3 |
4 | May only convert nodes of the Python AST which have a mapping in SK-AST.
5 | Useful mostly for debugging, testing and simplistic parsing:
6 |
7 |
8 | >>> expr = ast.parse("12.4 * (X1[25.3] + Y)")
9 | >>> print(str(translate(expr)))
10 | (12.4 * (X1[25.3] + Y))
11 | >>> expr = ast.parse("a=12+b; b=2*a; 12.4 * (X[25.3] + Y + 2*a*b)")
12 | >>> print(str(translate(expr)))
13 | {
14 | $a = (12 + b);
15 | $b = (2 * $a);
16 | (12.4 * ((X[25.3] + Y) + ((2 * $a) * $b)))
17 | }
18 | '''
19 | #pylint: disable=wildcard-import,unused-wildcard-import,unused-argument
20 | import ast
21 | from skompiler.ast import *
22 | from ._common import is_
23 |
24 | def translate(node):
25 |     return PythonASTProcessor()(node)
26 |
27 | _funcmap = {
28 |     'log': Log(),
29 |     'exp': Exp(),
30 |     'step': Step(),
31 |     'sqrt': Sqrt(),
32 |     'abs': Abs(),
33 | }
34 |
35 | class PythonASTProcessor:
36 |
37 |     def __call__(self, node, **kw):
38 |         cls = node.__class__.__name__
39 |         if not hasattr(self, cls):
40 |             raise NotImplementedError("No translation logic implemented for node type {0}".format(cls))
41 |         return getattr(self, cls)(node, **kw)
42 |
43 |     def Module(self, module, local_varnames=None):
44 |         # A Module may have more than one statement.
We only allow a sequence of Assign statements,
45 |         # followed by an expression
46 |         definitions = []
47 |         local_varnames = local_varnames or set()
48 |         for assn in module.body[:-1]:
49 |             if not isinstance(assn, ast.Assign):
50 |                 raise NotImplementedError("Only a sequence of assignments followed by an expression is allowed")
51 |             if len(assn.targets) != 1:
52 |                 raise NotImplementedError("Only assignment to a single variable is allowed")
53 |             if not isinstance(assn.targets[0], ast.Name):
54 |                 raise NotImplementedError("Assignment may only be done to a named variable")
55 |             varname = assn.targets[0].id
56 |             definitions.append(Definition(varname, self(assn.value, local_varnames=local_varnames)))
57 |             local_varnames.add(varname)
58 |
59 |         body = self(module.body[-1], local_varnames=local_varnames)
60 |         if local_varnames:
61 |             return Let(definitions, body)
62 |         return body
63 |
64 |     def Expr(self, expr, **kw):
65 |         return self(expr.value, **kw)
66 |
67 |     def Expression(self, expr, **kw):
68 |         return self(expr.body, **kw)
69 |
70 |     def Name(self, name, local_varnames=None):
71 |         if local_varnames and name.id in local_varnames:
72 |             return Reference(name.id)
73 |         else:
74 |             return Identifier(name.id)
75 |
76 |     def Subscript(self, sub, local_varnames=None):
77 |         if not isinstance(sub.value, ast.Name):
78 |             raise NotImplementedError("Unsupported form of subscript")
79 |         if local_varnames and sub.value.id in local_varnames:
80 |             raise ValueError("Subscripting named references is not supported")
81 |         if isinstance(sub.slice, ast.Index) and isinstance(sub.slice.value, ast.Num):
82 |             return IndexedIdentifier(id=sub.value.id,
83 |                                      index=sub.slice.value.n,
84 |                                      size=None) # This makes Sympy sad
85 |         elif isinstance(sub.slice, ast.Constant):
86 |             return IndexedIdentifier(id=sub.value.id,
87 |                                      index=sub.slice.value,
88 |                                      size=None)
89 |         else:
90 |             raise NotImplementedError("Unsupported form of subscript")
91 |
92 |     def Constant(self, const, **kw):
93 |         # Starting from Py3.8, the AST for constants is just Constant, rather than Num/Str/NameConstant
94 |         if type(const.value) not in [int, float]:
95 |             raise ValueError("Only numeric constants are supported")
96 |         return NumberConstant(const.value)
97 |
98 |     def Num(self, num, **kw):
99 |         return NumberConstant(num.n)
100 |
101 |     def UnaryOp(self, op, **kw):
102 |         return UnaryFunc(op=self(op.op, **kw),
103 |                          arg=self(op.operand, **kw))
104 |
105 |     def BinOp(self, op, **kw):
106 |         return BinOp(op=self(op.op, **kw),
107 |                      left=self(op.left, **kw),
108 |                      right=self(op.right, **kw))
109 |
110 |     def Call(self, call, **kw):
111 |         if not isinstance(call.func, ast.Name) or call.keywords or len(call.args) != 1:
112 |             raise ValueError("Only one-argument functions are supported")
113 |         if call.func.id not in _funcmap:
114 |             raise ValueError("Unsupported unary function: " + call.func.id)
115 |         return UnaryFunc(op=_funcmap[call.func.id], arg=self(call.args[0], **kw))
116 |
117 |     def IfExp(self, ifexp, **kw):
118 |         return IfThenElse(self(ifexp.test, **kw), self(ifexp.body, **kw), self(ifexp.orelse, **kw))
119 |
120 |     def Compare(self, cmp, **kw):
121 |         if len(cmp.comparators) != 1 or len(cmp.ops) != 1:
122 |             raise ValueError("Only one-element comparison expressions are supported")
123 |         return BinOp(self(cmp.ops[0], **kw), self(cmp.left, **kw), self(cmp.comparators[0], **kw))
124 |
125 |     def List(self, lst, **kw):
126 |         return MakeVector([self(el, **kw) for el in lst.elts])
127 |
128 |     Mult = is_(Mul())
129 |     Add = is_(Add())
130 |     Sub = is_(Sub())
131 |     USub = is_(USub())
132 |     LtE = is_(LtEq())
133 |     Div = is_(Div())
134 |
-------------------------------------------------------------------------------- /skompiler/fromskast/pfa.py: --------------------------------------------------------------------------------
1 | """
2 | Converter from SKAST to Portable Format for Analytics
3 | (http://dmg.org/pfa/)
4 | """
5 | import json
6 | from functools import reduce
7 | import numpy as np
8 | from skompiler import ast
9 | from ._common import ASTProcessor, tolist, not_implemented, denumpyfy
10 |
11 |
12 | def translate(expr, dialect=None):
13 |     """
14 |     Translates a given expression to PFA.
15 |
16 |     Returns:
17 |       - a Dict (if dialect is None)
18 |       - a JSON string (if dialect is 'json')
19 |       - a YAML string (if dialect is 'yaml'). This will invoke `import yaml`, so make sure PyYAML is installed.
20 |     """
21 |
22 |     ic = InputCollector()
23 |     ast.map_tree(expr, ic)
24 |     input_def = {
25 |         "type": "record",
26 |         "name": "Input",
27 |         "fields": [{"name": k, "type": ({"type": "array", "items": "double"} if v else "double")} for k, v in ic.inputs.items()]
28 |     }
29 |
30 |     if expr._dtype == ast.DTYPE_VECTOR:
31 |         output_def = {"type": "array", "items": "double"}
32 |     else:
33 |         output_def = {"type": "double"}
34 |
35 |     writer = PFAWriter()
36 |     result = {
37 |         "input": input_def,
38 |         "output": output_def,
39 |         "action": writer(expr)
40 |     }
41 |     fcn_defs = {fn: ufuncs[fn] for fn in writer.fcns}
42 |     if fcn_defs:
43 |         result["fcns"] = fcn_defs
44 |
45 |     if dialect is None:
46 |         return result
47 |     elif dialect == 'json':
48 |         return json.dumps(result)
49 |     elif dialect == 'yaml':
50 |         import yaml
51 |         return yaml.dump(result)
52 |     else:
53 |         raise ValueError("Unknown dialect: {0}".format(dialect))
54 |
55 |
56 | ufuncs = {
57 |     "usub": {"params": [{"x": "double"}],
58 |              "ret": "double",
59 |              "do": {"u-": "x"}},
60 |     "mul": {"params": [{"x": "double"}, {"y": "double"}],
61 |             "ret": "double",
62 |             "do": {"*": ["x", "y"]}},
63 |     "div": {"params": [{"x": "double"}, {"y": "double"}],
64 |             "ret": "double",
65 |             "do": {"/": ["x", "y"]}},
66 |     "max": {"params": [{"x": "double"}, {"y": "double"}],
67 |             "ret": "double",
68 |             "do": {"max": ["x", "y"]}},
69 |     "sigmoid": {"params": [{"x": "double"}],
70 |                 "ret": "double",
71 |                 "do": {"/": [1.0, {"+": [1, {"m.exp": {"u-": "x"}}]}]}},
72 |     "step": {"params": [{"x": "double"}],
73 |              "ret": "double",
74 |              "do": {"if": {"<=": ["x", 0.0]}, "then": 0.0, "else": 1.0}},
75 |     "vdot": {"params": [{"x": {"type": "array", "items": "double"}},
76 |                         {"y": {"type": "array", "items": "double"}}],
77 |              "ret": "double",
78 |              "do": {"attr":
79 |                     {"la.dot": [{"type": {"type": "array", "items": {"type": "array", "items": "double"}}, "new": ["x"]}, "y"]},
80 |                     "path": [0]}
81 |             },
82 |     # In theory this should not be needed; however, the Python PFA implementation for some reason crashes if I use m.abs without this wrapper.
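    # (Entries from this dict are emitted into the "fcns" section of the generated PFA document and
    # invoked from the action under a "u." prefix, e.g. "u.abs"; see is_fn below, which strips that
    # prefix when recording used functions.)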
83 | "abs": {"params": [{"x": "double"}], 84 | "ret": "double", 85 | "do": {"m.abs": "x"}}, 86 | } 87 | 88 | 89 | def is_fn(name, elemwise_name=None, scalar_name=None): 90 | def _fn(self, _, is_elemwise=False): 91 | fname = name 92 | if elemwise_name is not None and is_elemwise: 93 | fname = elemwise_name 94 | elif scalar_name is not None and not is_elemwise: 95 | fname = scalar_name 96 | if fname.startswith("u."): 97 | self.fcns.add(fname[2:]) 98 | if is_elemwise and elemwise_name is None: 99 | def result(*args): 100 | fcn = [{"fcn": fname}] 101 | args = list(args) 102 | if len(args) > 1: 103 | return {"a.zipmap": args + fcn} 104 | else: 105 | return {"a.map": args + fcn} 106 | return result 107 | else: 108 | return lambda *args: {fname: list(args) if len(args) > 1 else args[0]} 109 | return _fn 110 | 111 | 112 | class InputCollector: 113 | """Collects all input variable names from the expression""" 114 | 115 | def __init__(self): 116 | self.inputs = {} 117 | 118 | def __call__(self, node, _): 119 | if isinstance(node, ast.IsInput): 120 | self.inputs[node.id] = isinstance(node, ast.VectorIdentifier) or isinstance(node, ast.IndexedIdentifier) 121 | return node 122 | 123 | class PFAWriter(ASTProcessor): 124 | def __init__(self): 125 | self.fcns = set() 126 | 127 | def Identifier(self, id): 128 | return "input.{0}".format(id.id) 129 | 130 | def IndexedIdentifier(self, sub): 131 | return "input.{0}.{1}".format(sub.id, sub.index) 132 | 133 | def _number_constant(self, value): 134 | # Infinities have to be handled separately 135 | if np.isinf(value): 136 | value = float('inf') if value > 0 else -float('inf') 137 | else: 138 | value = denumpyfy(value) 139 | return value 140 | 141 | def NumberConstant(self, num): 142 | return self._number_constant(num.value) 143 | 144 | def VectorIdentifier(self, id): 145 | return "input.{0}".format(id.id) 146 | 147 | def VectorConstant(self, vec): 148 | return {'type': {'type': 'array', 'items': 'double'}, 'value': [self._number_constant(v) for v in tolist(vec.value)]} 149 | 150 | def MakeVector(self, vec): 151 | return {'type': {'type': 'array', 'items': 'double'}, 'new': [self(el) for el in vec.elems]} 152 | 153 | def MatrixConstant(self, mtx): 154 | return {'type': {'type': 'array', 'items': {'type': 'array', 'items': 'double'}}, 155 | 'value': [[self._number_constant(v) for v in tolist(row)] for row in mtx.value]} 156 | 157 | def BinOp(self, node): 158 | left = self(node.left) 159 | right = self(node.right) 160 | is_elemwise = (isinstance(node.op, ast.IsElemwise) and 161 | node.left._dtype == ast.DTYPE_VECTOR and node.right._dtype == ast.DTYPE_VECTOR) 162 | op = self(node.op, is_elemwise=is_elemwise) 163 | return op(left, right) 164 | 165 | def UnaryFunc(self, node): 166 | arg = self(node.arg) 167 | is_elemwise = isinstance(node.op, ast.IsElemwise) and node.arg._dtype == ast.DTYPE_VECTOR 168 | op = self(node.op, is_elemwise=is_elemwise) 169 | return op(arg) 170 | 171 | def IfThenElse(self, node): 172 | test, iftrue, iffalse = self(node.test), self(node.iftrue), self(node.iffalse) 173 | return {'if': test, 'then': iftrue, 'else': iffalse} 174 | 175 | VecMax = is_fn("a.max") 176 | ArgMax = is_fn("a.argmax") 177 | VecSum = is_fn("a.sum") 178 | Softmax = is_fn("m.link.softmax") 179 | 180 | MatVecProduct = is_fn("la.dot") 181 | DotProduct = is_fn("u.vdot") 182 | Exp = is_fn("m.exp") 183 | Log = is_fn("m.ln") 184 | Sqrt = is_fn("m.sqrt") 185 | Abs = is_fn("u.abs") 186 | Max = is_fn("u.max") 187 | 188 | Sigmoid = is_fn("u.sigmoid") 189 | Step = is_fn("u.step") 190 
|     Mul = is_fn("u.mul", scalar_name="*")
191 |     Div = is_fn("u.div", scalar_name="/")
192 |     Add = is_fn("+", elemwise_name="la.add")
193 |     Sub = is_fn("-", elemwise_name="la.sub")
194 |     USub = is_fn("u.usub", scalar_name="u-")
195 |     LtEq = is_fn("<=")
196 |     Eq = is_fn("==")
197 |
198 |     def Let(self, node):
199 |         result = [{'let': {defn.name: self(defn.body)}} for defn in node.defs]
200 |         result.append(self(node.body))
201 |         return {"do": result}
202 |
203 |     def Reference(self, node):
204 |         return node.name
205 |
206 |     def LFold(self, node, **kw):
207 |         # Standard implementation simply expands LFold into a sequence of BinOps and then calls itself
208 |         if not node.elems:
209 |             raise ValueError("LFold expects at least one element")
210 |         return self(reduce(lambda x, y: ast.BinOp(node.op, x, y), node.elems), **kw)
211 |
212 |     def TypedReference(self, node, **_):
213 |         return self(ast.Reference(node.name))
214 |
215 |     Definition = not_implemented
216 |
-------------------------------------------------------------------------------- /skompiler/fromskast/_common.py: --------------------------------------------------------------------------------
1 | """
2 | Base class for AST processors, plus functions useful within multiple implementations.
3 | """
4 | #pylint: disable=not-callable
5 | from itertools import count
6 | from functools import reduce
7 | from ..ast import AST_NODES, inline_definitions, IndexedIdentifier, NumberConstant, Exp, BinOp, IsElemwise, Reference
8 |
9 |
10 | class ASTProcessorMeta(type):
11 |     """A metaclass, which checks that the class defines methods for all known AST nodes.
12 |     This is useful to verify SKAST processor implementations for completeness."""
13 |
14 |     def __new__(mcs, name, bases, dct):
15 |         if name != 'ASTProcessor':
16 |             # This way the verification applies to all subclasses of ASTProcessor
17 |             unimplemented = AST_NODES.difference(dct.keys())
18 |             # Maybe the methods are implemented in one of the base classes?
19 |             for base_cls in bases:
20 |                 unimplemented.difference_update(dir(base_cls))
21 |             if unimplemented:
22 |                 raise ValueError(("Class {0} does not implement all the required ASTProcessor methods. "
" 23 | "Unimplemented methods: {1}").format(name, ', '.join(unimplemented))) 24 | return super().__new__(mcs, name, bases, dct) 25 | 26 | 27 | class ASTProcessor(object, metaclass=ASTProcessorMeta): 28 | """ 29 | The class hides the need to specify ASTProcessorMeta metaclass 30 | """ 31 | 32 | def __call__(self, node, **kw): 33 | return getattr(self, node.__class__.__name__)(node, **kw) 34 | 35 | 36 | def is_(val): 37 | return lambda self, node: val 38 | 39 | def tolist(x): 40 | if hasattr(x, 'tolist'): 41 | return x.tolist() 42 | else: 43 | return list(x) 44 | 45 | def not_implemented(self, node, *args, **kw): 46 | raise NotImplementedError("Processing of node {0} is not implemented.".format(node.__class__.__name__)) 47 | 48 | def _apply_bin_op(op_node, op, left, right): 49 | if (not isinstance(left, list) and not isinstance(right, list)) or not isinstance(op_node, IsElemwise): 50 | return op(left, right) 51 | if not isinstance(left, list) or not isinstance(right, list): 52 | raise ValueError("Elementwise operations requires both operands to be lists") 53 | if len(left) != len(right): 54 | raise ValueError("Sizes of the arguments do not match") 55 | return [op(l, r) for l, r in zip(left, right)] 56 | 57 | 58 | class StandardOps: 59 | """Common implementation for BinOp, UnaryFunc, LFold, Let and ArgMin""" 60 | 61 | def BinOp(self, node, **kw): 62 | """Most common implementation for BinOp, 63 | If the arguments are lists and the op is elemwise, applies 64 | the operation elementwise and returns a list.""" 65 | 66 | left = self(node.left, **kw) 67 | right = self(node.right, **kw) 68 | op = self(node.op, **kw) 69 | return _apply_bin_op(node.op, op, left, right) 70 | 71 | def UnaryFunc(self, node, **kw): 72 | op, arg = self(node.op, **kw), self(node.arg, **kw) 73 | if not isinstance(node.op, IsElemwise) or not isinstance(arg, list): 74 | return op(arg) 75 | else: 76 | return [op(a) for a in arg] 77 | 78 | def LFold(self, node, **kw): 79 | # Standard implementation simply expands LFold into a sequence of BinOps and then calls itself 80 | if not node.elems: 81 | raise ValueError("LFold expects at least one element") 82 | return self(reduce(lambda x, y: BinOp(node.op, x, y), node.elems), **kw) 83 | 84 | def Let(self, node, **kw): 85 | "Lazy implementation of the 'Let' node. Simply substitutes variables and proceeds as normal." 86 | return self(inline_definitions(node), **kw) 87 | 88 | Reference = Definition = not_implemented 89 | 90 | def TypedReference(self, node, **_): 91 | return self(Reference(node.name)) 92 | 93 | 94 | class VectorsAsLists: 95 | """A partial implementation of an AST processor, 96 | which assumes that: 97 | - all vectors are implemented as lists, 98 | - all element-wise operations operate element-wise on the lists, 99 | - all binary and unary operations are interpreted to lambda functions. 100 | 101 | The impementation of IfThenElse requires the class to have an 102 | _iif function, which corresponds to a unary IfThenElse. 
103 | """ 104 | 105 | _iif = not_implemented # Must implement this method in subclasses 106 | 107 | def VectorIdentifier(self, id): 108 | return [self(IndexedIdentifier(id.id, i, id.size)) for i in range(id.size)] 109 | 110 | def VectorConstant(self, vec): 111 | return [self(NumberConstant(v)) for v in tolist(vec.value)] 112 | 113 | def MatrixConstant(self, mtx): 114 | return [[self(NumberConstant(v)) for v in tolist(row)] for row in mtx.value] 115 | 116 | def MakeVector(self, vec): 117 | return [self(el) for el in vec.elems] 118 | 119 | def IfThenElse(self, node): 120 | """Implementation for IfThenElse for 'listwise' translators. 121 | Relies on the existence of self._iif function.""" 122 | 123 | test, iftrue, iffalse = self(node.test), self(node.iftrue), self(node.iffalse) 124 | if isinstance(iftrue, list): 125 | if not isinstance(iffalse, list) or len(iftrue) != len(iffalse): 126 | raise ValueError("Mixed types in IfThenElse expressions are not supported") 127 | return [self._iif(test, ift, iff) for ift, iff in zip(iftrue, iffalse)] 128 | else: 129 | if isinstance(iffalse, list): 130 | raise ValueError("Mixed types in IfThenElse expressions are not supported") 131 | return self._iif(test, iftrue, iffalse) 132 | 133 | def LFold(self, node, **kw): 134 | "If we know vectors are lists, we can improve LFold to avoid deep recursions" 135 | 136 | if not node.elems: 137 | raise ValueError("LFold expects at least one element") 138 | op = self(node.op, **kw) 139 | return reduce(lambda x, y: _apply_bin_op(node.op, op, x, y), 140 | [self(el, **kw) for el in node.elems]) 141 | 142 | VecSum = is_(lambda vec: reduce(lambda x, y: x + y, vec)) 143 | 144 | 145 | class StandardArithmetics: 146 | """A partial implementation of an AST processor, 147 | which assimes that: 148 | - all binary and unary operations are interpreted as lambda functions. 149 | - basic arithmetics and comparisons map to Python's basic arithmetics. 150 | - sigmoid can be expressed in terms of Exp as usual. 151 | """ 152 | 153 | Mul = is_(lambda x, y: x * y) 154 | Div = is_(lambda x, y: x / y) 155 | Add = is_(lambda x, y: x + y) 156 | Sub = is_(lambda x, y: x - y) 157 | USub = is_(lambda x: -x) 158 | LtEq = is_(lambda x, y: x <= y) 159 | Eq = is_(lambda x, y: x == y) 160 | 161 | def Sigmoid(self, _): 162 | return lambda x: 1/(1 + self(Exp())(-x)) 163 | 164 | 165 | def prepare_assign_to(assign_to, n_actual_targets): 166 | """Converts the value of the assign_to parameter to a list of strings, as needed. 167 | 168 | >>> prepare_assign_to('x', 1) 169 | ['x'] 170 | >>> prepare_assign_to('x', 2) 171 | ['x1', 'x2'] 172 | >>> prepare_assign_to(['x'], 1) 173 | ['x'] 174 | >>> prepare_assign_to(['a','b'], 2) 175 | ['a', 'b'] 176 | >>> prepare_assign_to(None, 3) 177 | >>> prepare_assign_to(['a'], 2) 178 | Traceback (most recent call last): 179 | ... 
180 | ValueError: The number of outputs (2) does not match the number of assign_to values (1) 181 | """ 182 | 183 | if assign_to is None: 184 | return None 185 | 186 | if isinstance(assign_to, str): 187 | if n_actual_targets == 1: 188 | return [assign_to] 189 | else: 190 | return ['{0}{1}'.format(assign_to, i+1) for i in range(n_actual_targets)] 191 | 192 | if len(assign_to) != n_actual_targets: 193 | raise ValueError(("The number of outputs ({0}) does not match the number" 194 | " of assign_to values ({1})").format(n_actual_targets, len(assign_to))) 195 | 196 | return assign_to 197 | 198 | 199 | def id_generator(template='_tmp{0}', start=1): 200 | return map(template.format, count(start)) 201 | 202 | 203 | def denumpyfy(value): 204 | if hasattr(value, 'dtype'): 205 | return value.item() 206 | else: 207 | return value 208 | -------------------------------------------------------------------------------- /skompiler/fromskast/python.py: -------------------------------------------------------------------------------- 1 | """ 2 | SKompiler: Generate Python AST expressions from SKAST. 3 | Useful for testing and evaluating expressions. 4 | 5 | Note that the expressions with a Let condition are compiled differently to 6 | those without. 7 | 8 | >>> from ..toskast import string 9 | 10 | Bare expressions without Let compile to bare Python expressions. 11 | You need to wrap them in ast.Expression to use and may eval to get a value: 12 | 13 | >>> expr = string.translate("12.4 * (b[1] + Y[0])") 14 | >>> pyast = translate(expr) 15 | >>> code = compile(ast.Expression(body=pyast), "__main__", "eval") 16 | >>> eval(code, {'b': [10, 20, 30], 'Y': [1.2]}) 17 | 262.88 18 | 19 | Let-expressions compile to ast.Module with multiple expressions, the last of which 20 | writes the result value to the __result__ global. You may only exec this code: 21 | 22 | >>> expr = string.translate("a=X[1]; b=a+1; 12.4 * (b + Y[0])") 23 | >>> pyast = translate(expr) 24 | >>> code = compile(pyast, "__main__", "exec") 25 | >>> vars = {} 26 | >>> eval(code, {'X': [10, 20, 30], 'Y': [1.2]}, vars) 27 | >>> vars['__result__'] 28 | 275.28 29 | 30 | In general, use utils.evaluate to evaluate SKAST expressions. 31 | """ 32 | import ast 33 | import sys 34 | import numpy as np 35 | from ..ast import USub, Identifier, NumberConstant, IsBoolean, merge_let_scopes, Max 36 | from ._common import ASTProcessor, StandardOps, denumpyfy 37 | 38 | 39 | _linearg = dict(lineno=1, col_offset=0) # Most Python AST nodes require these 40 | 41 | def translate(node, dialect=None): 42 | """ 43 | When dialect is None, translates the given SK AST expression to a Python AST tree. 
44 | Otherwise, further converts the tree depending on the value of dialect, to: 45 | 46 | 'code': Python source code (via the astor package) 47 | 'lambda': Executable function (via expr.lambdify) 48 | 49 | >>> from skompiler.toskast.string import translate as skast 50 | >>> expr = skast('[2*x[0], 1] if x[1] <= 3 else [12.0, 45.5]') 51 | >>> print(translate(expr, 'code')) 52 | (np.array([2 * x[0], 1]) if x[1] <= 3 else np.array([12.0, 45.5])) 53 | 54 | >>> fn = translate(expr, 'lambda') 55 | >>> fn(x=[1, 2]) 56 | array([2, 1]) 57 | """ 58 | pyast = PythonASTWriter()(merge_let_scopes(node)) 59 | if dialect is None: 60 | return pyast 61 | elif dialect == 'lambda': 62 | return lambdify(pyast) 63 | elif dialect == 'code': 64 | import astor 65 | code = astor.to_source(pyast) 66 | # Replace some internal identifiers with matching functions 67 | code = code.replace('__np__', 'np').replace('__exp__', 'np.exp').replace('__log__', 'np.log') 68 | code = code.replace('__argmax__', 'np.argmax').replace('__sum__', 'np.sum') 69 | return code 70 | else: 71 | raise ValueError("Unknown dialect: {0}".format(dialect)) 72 | 73 | def _ident(name): 74 | "Shorthand for defining methods in PythonASTWriter (see code below)" 75 | return lambda self, x: self(Identifier(name)) 76 | 77 | def _is(node): 78 | "Shorthand for defining methods in PythonASTWriter (see code below)" 79 | return lambda self, x: node 80 | 81 | class PythonASTWriter(ASTProcessor, StandardOps): 82 | """ 83 | An AST processor, which translates a given SKAST node to a Python AST. 84 | 85 | >>> import ast 86 | >>> topy = PythonASTWriter() 87 | >>> print(ast.dump(topy(Identifier('x')))) 88 | Name(id='x', ctx=Load()) 89 | """ 90 | 91 | def Identifier(self, name): 92 | return ast.Name(id=name.id, ctx=ast.Load(), **_linearg) 93 | VectorIdentifier = Identifier 94 | 95 | def IndexedIdentifier(self, sub): 96 | return ast.Subscript(value=self(Identifier(sub.id)), 97 | slice=ast.Index(value=self(NumberConstant(sub.index))), 98 | ctx=ast.Load(), **_linearg) 99 | 100 | def NumberConstant(self, num): 101 | return ast.Num(n=denumpyfy(num.value), **_linearg) 102 | 103 | def VectorConstant(self, vec): 104 | result = ast.parse('__np__.array()', mode='eval').body 105 | result.args = [ast.List(elts=[ast.Num(n=denumpyfy(el), **_linearg) for el in vec.value], 106 | ctx=ast.Load(), **_linearg)] 107 | return result 108 | 109 | def MakeVector(self, mv): 110 | result = ast.parse('__np__.array()', mode='eval').body 111 | result.args = [ast.List(elts=[self(el) for el in mv.elems], 112 | ctx=ast.Load(), **_linearg)] 113 | return result 114 | 115 | def MatrixConstant(self, mat): 116 | result = ast.parse('__np__.array()', mode='eval').body 117 | result.args = [ast.List(elts=[ast.List(elts=[ast.Num(n=denumpyfy(el), **_linearg) for el in row], 118 | ctx=ast.Load(), **_linearg) for row in mat.value], ctx=ast.Load(), **_linearg)] 119 | return result 120 | 121 | def UnaryFunc(self, node, **kw): 122 | if isinstance(node.op, USub): 123 | return ast.UnaryOp(op=self(node.op), operand=self(node.arg), **_linearg) 124 | else: 125 | return ast.Call(func=self(node.op), args=[self(node.arg)], keywords=[], **_linearg) 126 | 127 | def BinOp(self, node, **kw): 128 | op, left, right = self(node.op), self(node.left), self(node.right) 129 | if isinstance(node.op, IsBoolean): 130 | return ast.Compare(left=left, ops=[op], comparators=[right], **_linearg) 131 | elif isinstance(node.op, Max): 132 | return ast.Call(func=self(node.op), args=[self(node.left), self(node.right)], keywords=[], **_linearg) 133 
| else: 134 | return ast.BinOp(op=op, left=left, right=right, **_linearg) 135 | 136 | def IfThenElse(self, node): 137 | return ast.IfExp(test=self(node.test), body=self(node.iftrue), orelse=self(node.iffalse), **_linearg) 138 | 139 | def Let(self, node, **kw): 140 | code = [ast.Assign(targets=[ast.Name(id='_def_' + defn.name, ctx=ast.Store(), **_linearg)], 141 | value=self(defn.body), **_linearg) for defn in node.defs] 142 | # Evaluate the expression body into a "__result__" variable 143 | code.append( 144 | ast.Assign(targets=[ast.Name(id='__result__', ctx=ast.Store(), **_linearg)], 145 | value=self(node.body), **_linearg)) 146 | return ast.Module(body=code, 147 | **({} if sys.version < '3.8' else {"type_ignores": []})) 148 | 149 | def Reference(self, ref): 150 | return ast.Name(id='_def_' + ref.name, ctx=ast.Load(), **_linearg) 151 | TypedReference = Reference 152 | 153 | # Functions 154 | Exp = _ident('__exp__') 155 | Sqrt = _ident('__sqrt__') 156 | Log = _ident('__log__') 157 | Step = _ident('__step__') 158 | VecSum = _ident('__sum__') 159 | ArgMax = _ident('__argmax__') 160 | Sigmoid = _ident('__sigmoid__') 161 | Softmax = _ident('__softmax__') 162 | VecMax = _ident('__vecmax__') 163 | Max = _ident('__max__') 164 | Abs = _ident('__abs__') 165 | 166 | # Operators 167 | Mul = _is(ast.Mult()) 168 | Div = _is(ast.Div()) 169 | Add = _is(ast.Add()) 170 | Sub = _is(ast.Sub()) 171 | USub = _is(ast.USub()) 172 | DotProduct = _is(ast.MatMult()) 173 | MatVecProduct = DotProduct 174 | 175 | # Predicates 176 | LtEq = _is(ast.LtE()) 177 | Eq = _is(ast.Eq()) 178 | 179 | # ------------- Evaluation of Python AST-s --------------- # 180 | 181 | def _softmax(X): 182 | X = np.exp(X) 183 | sum_prob = np.sum(X, axis=1).reshape((-1, 1)) 184 | X /= sum_prob 185 | return X 186 | 187 | _eval_vars = { 188 | '__np__': np, 189 | '__exp__': np.exp, 190 | '__sqrt__': np.sqrt, 191 | '__log__': np.log, 192 | '__sum__': np.sum, 193 | '__argmax__': np.argmax, 194 | '__vecmax__': np.max, 195 | '__max__': np.maximum, 196 | '__abs__': np.abs, 197 | '__sigmoid__': lambda z: 1.0/(1.0 + np.exp(-z)), 198 | '__sum_normalize__': lambda x: x / np.sum(x), 199 | '__softmax__': lambda x: _softmax([x])[0, :], 200 | '__step__': lambda x: 1 if x > 0 else 0 # This is how step is implemented in LogisticRegression 201 | } 202 | 203 | def lambdify(pyast): 204 | """ 205 | Converts a given Python AST, produced by PythonASTWriter to an executable Python function. 206 | 207 | >>> from ..ast import NumberConstant, BinOp, Mul, Identifier 208 | >>> pyast = translate(BinOp(Mul(), NumberConstant(2), Identifier('x'))) 209 | >>> fn = lambdify(pyast) 210 | >>> fn(x=3.14) 211 | 6.28 212 | """ 213 | if isinstance(pyast, ast.Module): 214 | # Exec the code: 215 | code = compile(pyast, "__main__", "exec") 216 | def result(**inputs): 217 | globals_ = {} 218 | globals_.update(_eval_vars) 219 | eval(code, inputs, globals_) # pylint: disable=eval-used 220 | return globals_['__result__'] 221 | else: 222 | # Eval the code: 223 | code = compile(ast.Expression(body=pyast), "__main__", "eval") 224 | def result(**inputs): 225 | globals_ = {} 226 | globals_.update(_eval_vars) 227 | return eval(code, inputs, globals_) # pylint: disable=eval-used 228 | return result 229 | -------------------------------------------------------------------------------- /skompiler/toskast/sklearn/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | SKLearn model transformation to SKompiler's AST. 
3 | """ 4 | #pylint: disable=unused-argument 5 | from functools import singledispatch 6 | from sklearn.linear_model import LogisticRegression 7 | from sklearn.linear_model._base import LinearModel 8 | from sklearn.svm import SVC, SVR 9 | from sklearn.tree._classes import BaseDecisionTree, DecisionTreeRegressor 10 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor,\ 11 | GradientBoostingClassifier, GradientBoostingRegressor,\ 12 | AdaBoostClassifier 13 | from sklearn.cluster import KMeans 14 | from sklearn.decomposition._pca import _BasePCA 15 | from sklearn.neural_network import MLPClassifier, MLPRegressor 16 | 17 | from sklearn.preprocessing import Binarizer, MinMaxScaler, MaxAbsScaler, StandardScaler,\ 18 | Normalizer 19 | from sklearn.pipeline import Pipeline 20 | 21 | from skompiler.dsl import ident, vector 22 | 23 | from .._common import prepare_inputs 24 | from .linear_model.logistic import logreg_binary, logreg_multiclass 25 | from .linear_model.base import linear_model 26 | from .tree.base import decision_tree 27 | from .ensemble.forest import random_forest_classifier, random_forest_regressor 28 | from .ensemble.gradient_boosting import gradient_boosting_classifier, gradient_boosting_regressor 29 | from .ensemble.weight_boosting import adaboost_classifier 30 | from .cluster.k_means import k_means 31 | from .decomposition.pca import pca 32 | from .neural_network.multilayer_perceptron import mlp, mlp_classifier 33 | from .preprocessing.data import binarize, scale, unscale, standard_scaler, normalizer 34 | 35 | 36 | @singledispatch 37 | def translate(model, inputs='x', method='predict'): 38 | """ 39 | Translate a given SKLearn model to a SK AST expression. 40 | 41 | Kwargs: 42 | 43 | inputs (string or ASTNode): 44 | The name of the variable that will be used to represent 45 | the input in the resulting expression. It can be given either as a 46 | single string (in this case the variable denotes an input vector), 47 | a list of strings (in this case each element denotes the name of one input component of a vector), 48 | or an ASTNode (e.g. VectorIdentifier, or vector([ident('x'), ident('y')])) 49 | 50 | method (string): Method to be expressed. Possible options: 51 | 'predict', for all supported models. 52 | 'decision_function', for all supported models. 53 | 'predict_proba', 'predict_log_proba': for classifiers. 54 | """ 55 | raise NotImplementedError("Conversion not implemented for {0}".format(model.__class__.__name__)) 56 | 57 | _supported_methods = {} 58 | 59 | # An improved version of @translate.register decorator 60 | def register(cls, methods): 61 | def decorator(fn): 62 | @translate.register(cls) 63 | def new_fn(model, inputs='x', method='predict'): 64 | if method not in methods: 65 | raise ValueError("Method {0} is not supported (or not implemented yet) for {1}".format(method, cls.__name__)) 66 | return fn(model, inputs, method) 67 | return new_fn 68 | _supported_methods[cls] = methods 69 | return decorator 70 | 71 | 72 | @register(LogisticRegression, ['decision_function', 'predict', 'predict_proba', 'predict_log_proba']) 73 | def _(model, inputs, method): 74 | ovr = (model.multi_class in ["ovr", "warn"] or 75 | (model.multi_class == 'auto' and (model.classes_.size <= 2 or 76 | model.solver == 'liblinear'))) 77 | if model.coef_.shape[0] == 1: # Binary logreg 78 | if not ovr: 79 | raise NotImplementedError("Logistic regression with binary outcomes and multinomial outputs is not implemented") 80 | # ... 
It's not too hard, actually, just need to find the 15 minutes needed to implement it. 81 | return logreg_binary(model.coef_.ravel(), model.intercept_[0], inputs=prepare_inputs(inputs, model.coef_.shape[-1]), method=method) 82 | else: # Multiclass logreg 83 | return logreg_multiclass(model.coef_, model.intercept_, method=method, 84 | inputs=prepare_inputs(inputs, model.coef_.shape[-1]), multi_class='ovr' if ovr else 'multinomial') 85 | 86 | @register(SVC, ['decision_function', 'predict']) 87 | def _(model, inputs, method): 88 | # For linear SVC the predict and decision function logic is the same as for logreg 89 | if model.kernel != 'linear': 90 | raise NotImplementedError("Translation for nonlinear SVC not implemented") 91 | if model.decision_function_shape != 'ovr': 92 | raise NotImplementedError("Translation not implemented for one-vs-one SVC") 93 | if len(model.classes_) > 2 and method != 'predict': 94 | raise NotImplementedError("Translation not implemented for non-binary SVC") # See sklearn.utils.multiclass._ovr_decision_function 95 | if model.coef_.shape[0] == 1: # Binary 96 | return logreg_binary(model.coef_.ravel(), model.intercept_[0], inputs=prepare_inputs(inputs, model.coef_.shape[-1]), method=method) 97 | else: # Multiclass logreg 98 | return logreg_multiclass(model.coef_, model.intercept_, method=method, 99 | inputs=prepare_inputs(inputs, model.coef_.shape[-1]), multi_class='ovr') 100 | 101 | @register(SVR, ['predict']) 102 | def _(model, inputs, method): 103 | if isinstance(model, SVR) and model.kernel != 'linear': 104 | raise NotImplementedError("Nonlinear SVR not implemented") 105 | return linear_model(model.coef_.ravel(), model.intercept_.item(), prepare_inputs(inputs, model.coef_.shape[-1])) 106 | 107 | @register(LinearModel, ['predict']) 108 | def _(model, inputs, method): 109 | return linear_model(model.coef_.ravel(), model.intercept_, prepare_inputs(inputs, model.coef_.shape[-1])) 110 | 111 | @register(BaseDecisionTree, ['predict', 'predict_proba', 'predict_log_proba']) 112 | def _(model, inputs, method): 113 | if isinstance(model, DecisionTreeRegressor) and method != 'predict': 114 | raise ValueError("Method {0} is not supported for DecisionTreeRegressor".format(method)) 115 | return decision_tree(model.tree_, prepare_inputs(inputs, model.n_features_in_), method) 116 | 117 | @register(RandomForestClassifier, ['predict', 'predict_proba', 'predict_log_proba']) 118 | def _(model, inputs, method): 119 | return random_forest_classifier(model, prepare_inputs(inputs, model.n_features_in_), method) 120 | 121 | @register(RandomForestRegressor, ['predict']) 122 | def _(model, inputs, method): 123 | return random_forest_regressor(model, prepare_inputs(inputs, model.n_features_in_)) 124 | 125 | @register(GradientBoostingClassifier, ['decision_function']) 126 | def _(model, inputs, method): 127 | return gradient_boosting_classifier(model, prepare_inputs(inputs, model.n_features_in_)) 128 | 129 | @register(GradientBoostingRegressor, ['predict']) 130 | def _(model, inputs, method): 131 | return gradient_boosting_regressor(model, prepare_inputs(inputs, model.n_features_in_)) 132 | 133 | @register(AdaBoostClassifier, ['decision_function', 'predict', 'predict_proba', 'predict_log_proba']) 134 | def _(model, inputs, method): 135 | return adaboost_classifier(model, prepare_inputs(inputs, model.estimators_[0].n_features_in_), method) 136 | 137 | @register(KMeans, ['transform', 'predict']) 138 | def _(model, inputs, method): 139 | return k_means(model.cluster_centers_, 
prepare_inputs(inputs, model.cluster_centers_.shape[1]), method) 140 | 141 | @register(_BasePCA, ['transform']) 142 | def _(model, inputs, method): 143 | return pca(model, prepare_inputs(inputs, model.components_.shape[1])) 144 | 145 | @register(Pipeline, ['predict', 'predict_proba', 'decision_function', 'predict_log_proba', 'transform']) 146 | def _(model, inputs, method): 147 | if not model.steps: 148 | raise ValueError("Empty pipeline provided") 149 | # The first step in the pipeline is responsible for preparing the inputs 150 | first_method = 'transform' if len(model.steps) > 1 else method 151 | expr = translate(model.steps[0][1], inputs, first_method) 152 | for i in range(1, len(model.steps)-1): 153 | expr = translate(model.steps[i][1], expr, 'transform') 154 | if len(model.steps) > 1: 155 | expr = translate(model.steps[-1][1], expr, method) 156 | return expr 157 | 158 | @register(MLPRegressor, ['predict']) 159 | def _(model, inputs, method): 160 | return mlp(model, prepare_inputs(inputs, len(model.coefs_[0]))) 161 | 162 | @register(MLPClassifier, ['predict', 'predict_proba', 'predict_log_proba']) 163 | def _(model, inputs, method): 164 | return mlp_classifier(model, prepare_inputs(inputs, len(model.coefs_[0])), method) 165 | 166 | @register(Binarizer, ['transform']) 167 | def _(model, inputs, method): 168 | return binarize(model.threshold, prepare_inputs(inputs)) 169 | 170 | @register(MinMaxScaler, ['transform']) 171 | def _(model, inputs, method): 172 | return scale(model.scale_, model.min_, prepare_inputs(inputs, len(model.scale_))) 173 | 174 | @register(MaxAbsScaler, ['transform']) 175 | def _(model, inputs, method): 176 | return unscale(model.scale_, prepare_inputs(inputs, len(model.scale_))) 177 | 178 | @register(StandardScaler, ['transform']) 179 | def _(model, inputs, method): 180 | n = None 181 | if model.with_mean: 182 | n = len(model.mean_) 183 | elif model.with_std: 184 | n = len(model.scale_) 185 | return standard_scaler(model, prepare_inputs(inputs, n)) 186 | 187 | @register(Normalizer, ['transform']) 188 | def _(model, inputs, method): 189 | return normalizer(model.norm, prepare_inputs(inputs)) 190 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | SKompiler: Translate trained SKLearn models to executable code in other languages 2 | ================================================================================ 3 | 4 | [![Build Status](https://travis-ci.org/konstantint/SKompiler.svg?branch=master)](https://travis-ci.org/konstantint/SKompiler) 5 | 6 | The package provides a tool for transforming trained SKLearn models into other forms, such as SQL queries, Excel formulas, Portable Format for Analytics (PFA) files or Sympy expressions (which, in turn, can be translated to code in a variety of languages, such as C, Javascript, Rust, Julia, etc). 7 | 8 | Requirements 9 | ------------ 10 | 11 | - Python 3.5 or later 12 | 13 | Installation 14 | ------------ 15 | 16 | The simplest way to install the package is via `pip`: 17 | 18 | $ pip install SKompiler[full] 19 | 20 | 21 | Note that the `[full]` option includes the installations of `sympy`, `sqlalchemy` and `astor`, which are necessary if you plan to convert `SKompiler`'s expressions to `sympy` expressions (which, in turn, can be compiled to many other languages) or to SQLAlchemy expressions (which can be further translated to different SQL dialects) or to Python source code. 
If you do not need this functionality (say, you only need the raw `SKompiler` expressions or perhaps only the SQL conversions without the `sympy` ones), you may avoid the forced installation of all optional dependencies by simply writing
22 |
23 |     $ pip install SKompiler
24 |
25 | (you are free to install any of the extra dependencies via separate calls to `pip install`, of course)
26 |
27 | Usage
28 | -----
29 |
30 | ### Introductory example
31 |
32 | Let us start by walking through an introductory example. We begin by training a model on a small dataset:
33 |
34 |     from sklearn.datasets import load_iris
35 |     from sklearn.ensemble import RandomForestClassifier
36 |     X, y = load_iris(return_X_y=True)
37 |     m = RandomForestClassifier(n_estimators=3, max_depth=3).fit(X, y)
38 |
39 | Suppose we need to express the logic of `m.predict` in SQLite. Here is how we can achieve that:
40 |
41 |     from skompiler import skompile
42 |     expr = skompile(m.predict)
43 |     sql = expr.to('sqlalchemy/sqlite')
44 |
45 | Voila, the value of the `sql` variable is a query, which would compute the value of `m.predict` in pure SQL:
46 |
47 |     WITH _tmp1 AS
48 |     (SELECT .... FROM data),
49 |     _tmp2 AS
50 |     ( ... )
51 |     SELECT ... from _tmp2 ...
52 |
53 | Let us import the data into an in-memory SQLite database to test the generated query:
54 |
55 |     import sqlalchemy as sa
56 |     import pandas as pd
57 |     conn = sa.create_engine('sqlite://').connect()
58 |     df = pd.DataFrame(X, columns=['x1', 'x2', 'x3', 'x4']).reset_index()
59 |     df.to_sql('data', conn)
60 |
61 | Our database now contains the table named `data` with the primary key `index`. We need to provide this information to SKompiler to have it generate the correct query:
62 |
63 |     sql = expr.to('sqlalchemy/sqlite', key_column='index', from_obj='data')
64 |
65 | We can now query the data:
66 |
67 |     results = pd.read_sql(sql, conn)
68 |
69 | and verify that the results match:
70 |
71 |     assert (results.values.ravel() == m.predict(X).ravel()).all()
72 |
73 | Note that the generated SQL expression uses the names `x1`, `x2`, `x3` and `x4` to refer to the input variables.
74 | We could have chosen different input variable names by writing:
75 |
76 |     expr = skompile(m.predict, ['a', 'b', 'c', 'd'])
77 |
78 | ### Single-shot computation
79 |
80 | Note that the generated SQL code splits the computation into sequential steps using `with` expressions. In some cases you might want to have the whole computation "inlined" into a single expression. You can achieve this by specifying
81 | `multistage=False`:
82 |
83 |     sql = expr.to('sqlalchemy/sqlite', multistage=False)
84 |
85 | Note that in this case the resulting expression would typically be several times longer than the multistage version:
86 |
87 |     len(expr.to('sqlalchemy/sqlite'))
88 |     > 2262
89 |     len(expr.to('sqlalchemy/sqlite', multistage=False))
90 |     > 12973
91 |
92 | Why so? Because, for a typical classifier (including the one used in this example),
93 |
94 |     predict(x) = argmax(predict_proba(x))
95 |
96 | There is, however, no single `argmax` function in SQL, hence it has to be faked using the following logic:
97 |
98 |     predict(x) = if predict_proba(x)[0] == max(predict_proba(x)) then 0
99 |                  else if predict_proba(x)[1] == max(predict_proba(x)) then 1
100 |                  else 2
101 |
102 | If SKompiler is not allowed to use a separate step to store the intermediate `predict_proba` outputs, it is forced to inline the same computation verbatim multiple times. To summarize, you should probably avoid the use of `multistage=False` in most cases.
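To see how strongly this affects a particular model, it is easy to compare the two variants directly (a quick sanity check, reusing the `expr` object from the introductory example above):

    sql_multistage = expr.to('sqlalchemy/sqlite')
    sql_inlined = expr.to('sqlalchemy/sqlite', multistage=False)
    # The deeper the ensemble, the larger this ratio tends to become
    print(len(sql_inlined) / len(sql_multistage))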
103 |
104 | ### Other formats
105 |
106 | By changing the first parameter of the `.to()` call you may produce output in a variety of other formats besides SQLite:
107 |
108 | * `sqlalchemy`: raw SQLAlchemy expression (which is a dialect-independent way of representing SQL). Jokes aside, SQL is sometimes a totally valid choice for deploying models into production.
109 |
110 |   Note that generated SQL may (depending on the chosen model and method) include the functions `exp`, `log` and `sqrt`, which are not supported out of the box in SQLite. If you work with SQLite, you will need to [add them separately](https://stackoverflow.com/a/2108921/318964) via `create_function`. You can find an example of how this can be done in `tests/evaluators.py` in SKompiler's source code.
111 | * `sqlalchemy/<dialect>`: SQL string in any of the SQLAlchemy-supported dialects (`firebird`, `mssql`, `mysql`, `oracle`, `postgresql`, `sqlite`, `sybase`). This is a convenience feature for those who are too lazy to figure out how to compile raw SQLAlchemy to actual SQL.
112 | * `excel`: Excel formula. Ever tried dragging a random forest equation down along the table? Fun! Check out [this short screencast](https://www.youtube.com/watch?v=7vUfa7W0NpY) to see how it can be done.
113 |
114 |   _NB: The screencast was recorded using a previous version, where `multistage=False` was the default option_.
115 | * `pfa`: A dict with [PFA](http://dmg.org/pfa/) code.
116 | * `pfa/json` or `pfa/yaml`: PFA code as a JSON or YAML string for those who are too lazy to write `json.dumps` or `yaml.dump`. PyYAML should be installed in the latter case, of course.
117 | * `sympy`: A SymPy expression. Ever wanted to take a derivative of your model symbolically?
118 | * `sympy/<lang>`: Code in the language `<lang>`, generated via SymPy. Supported values for `<lang>` are `c`, `cxx`, `rust`, `fortran`, `js`, `r`, `julia`, `mathematica`, `octave`. Note that the quality of the generated code varies depending on the model, language and the value of the `assign_to` parameter. Again, this is just a convenience feature; you will get more control by dealing with `sympy` code printers [manually](https://www.sympy.org/scipy-2017-codegen-tutorial/).
119 |
120 |   _NB: Sympy translation does not support multistage mode at the moment, hence the resulting code will have repeated subexpressions (which can be extracted by means of Sympy itself, however)._
121 |
122 | * `python`: Python syntax tree (the same you'd get via `ast.parse`). This and the following three options are mostly useful for debugging and testing.
123 | * `python/code`: Python source code. The generated code will contain references to custom functions, such as `__argmax__`, `__sigmoid__`, etc. To execute the code you will need to provide these in the `locals` dictionary. See `skompiler.fromskast.python._eval_vars`.
124 | * `python/lambda`: a Python callable function (primarily useful for debugging and testing). Equivalent to calling `expr.lambdify()`.
125 | * `string`: a string, equivalent to `str(expr)`.
126 |
127 | ### Other models
128 |
129 | So far this has been a fun two-weekends-long project, hence translation is implemented for a limited number of models. The most basic ones (linear models, decision trees, forests, gradient boosting, PCA, KMeans, MLP, Pipeline and a couple of preprocessors) are covered, however, and this is already sufficient to compile nontrivial constructions.
For example: 130 | 131 | m = Pipeline([('scale', StandardScaler()), 132 | ('dim_reduce', PCA(6)), 133 | ('cluster', KMeans(10)), 134 | ('classify', MLPClassifier([5, 4], 'tanh'))]) 135 | 136 | Even though this particular example probably does not make much sense from a machine learning perspective, it would happily compile both to Excel and SQL forms none the less. 137 | 138 | ### How it works 139 | 140 | The `skompile` procedure translates a given method into an intermediate syntactic representation (called SKompiler AST or SKAST). This representation uses a limited number of operations so it is reasonably simple to translate it into other forms. 141 | 142 | In principle, SKAST's utility is not limited to `sklearn` models. Anything you translate into SKAST becomes automatically compileable to whatever output backends are implemented in `SKompiler`. Generating raw SKAST is quite straightforward: 143 | 144 | from skompiler.dsl import ident, const 145 | expr = const([[1,2],[3,4]]) @ ident('x', 2) + 12 146 | expr.to('sqlalchemy/sqlite', 'result') 147 | > SELECT 1 * x1 + 2 * x2 + 12 AS result1, 3 * x1 + 4 * x2 + 12 AS result2 148 | > FROM data 149 | 150 | You can use `repr(expr)` on any SKAST expression to dump its unformatted internal representation for examination or `str(expr)` to get a somewhat-formatted view of it. 151 | 152 | It is important to note, that for larger models (say, a random forest or a gradient boosted model with 500+ trees) the resulting SKAST expression tree may become deeper than Python's default recursion limit of 1000. As a result some translators may produce a `RecursionError` when processing such expressions. This can be solved by raising the system recursion limit to sufficiently high value: 153 | 154 | import sys 155 | sys.setrecursionlimit(10000) 156 | 157 | Development 158 | ----------- 159 | 160 | If you plan to develop or debug the package, consider installing it by running: 161 | 162 | $ pip install -e .[dev] 163 | 164 | from within the source distribution. This will install the package in "development mode" and include extra dependencies, useful for development. 165 | 166 | You can then run the tests by typing 167 | 168 | $ py.test 169 | 170 | at the root of the source distribution. 171 | 172 | Contributing 173 | ------------ 174 | 175 | Feel free to contribute or report issues via Github: 176 | 177 | * https://github.com/konstantint/SKompiler 178 | 179 | 180 | Copyright & License 181 | ------------------- 182 | 183 | Copyright: 2018, Konstantin Tretyakov. 184 | License: MIT 185 | -------------------------------------------------------------------------------- /skompiler/fromskast/sympy.py: -------------------------------------------------------------------------------- 1 | """ 2 | SKompiler: Generate Sympy expressions from SKAST. 3 | """ 4 | import numpy as np 5 | import sympy as sp 6 | from ..ast import IsElemwise, Mul 7 | from ._common import ASTProcessor, is_, StandardOps, StandardArithmetics 8 | 9 | 10 | def translate(node, dialect=None, true_argmax=True, assign_to='y', component=None, lambdify_inputs_str='x', **kw): 11 | """Translates SKAST to a Sympy expression and optionally generates code from it. 12 | 13 | KwArgs: 14 | dialect (string): If None, returns the Sympy expression. Otherwise translates it further to one of the supported languages. 15 | Supported values: 16 | 'c', 'cxx', 'rust', 'fortran', 'js', 'r', 'julia', 'mathematica', 'octave', 'lambda'. 
17 |             If dialect == 'lambda', the expression is lambdified (the call is then equivalent to
18 |             lambdify(lambdify_inputs_str, translate(node)))
19 | 
20 |         true_argmax (bool): When True (default), the generated expression will include a "Sympy-executable" definition of argmax(vector)
21 |             (in the form of a lengthy expression "if v[0] == max(v) then 0, else if v[1] == max(v) then 1, ...")
22 |             When False, the expression will contain just the name "argmax". This may be sufficient
23 |             if you only need Sympy as an intermediate representation before compiling into a different language.
24 | 
25 |         assign_to: This value is passed further to sympy code printers when dialect is not None.
26 |             When assign_to is not None, it specifies that the code printer should generate assignment statements,
27 |             setting values of the given variable. Otherwise, a pure expression is generated (which is not always
28 |             possible if the expression is a matrix)
29 | 
30 |         component (int): If the result is a vector, return only the specified component of it.
31 | 
32 |         lambdify_inputs_str (str): The string specifying lambdified function inputs (when dialect == 'lambda')
33 | 
34 |         **kw: Other arguments passed to the code generator (when dialect is not None)
35 | 
36 |     >>> from skompiler.toskast.string import translate as skast
37 |     >>> expr = skast('[2*x[0], 1] if x[1] <= 3 else [12.0, 45.5]')
38 |     >>> print(translate(expr, 'js'))
39 |     if (x_1 <= 3) {
40 |        y[0] = 2*x_0;
41 |     }
42 |     else {
43 |        y[0] = 12.0;
44 |     }
45 |     if (x_1 <= 3) {
46 |        y[1] = 1;
47 |     }
48 |     else {
49 |        y[1] = 45.5;
50 |     }
51 |     >>> print(translate(expr, 'js', assign_to=None, component=0))
52 |     ((x_1 <= 3) ? (
53 |        2*x_0
54 |     )
55 |     : (
56 |        12.0
57 |     ))
58 |     """
59 |     syexpr = SympyWriter(true_argmax=true_argmax)(node)
60 |     if component is not None:
61 |         syexpr = syexpr[component]
62 |     if dialect is None:
63 |         return syexpr
64 |     elif dialect == 'lambda':
65 |         return lambdify(lambdify_inputs_str, syexpr)
66 |     else:
67 |         return to_code(syexpr, dialect, assign_to=assign_to, **kw)
68 | 
69 | def _argmax(val):
70 |     "A sympy implementation of argmax"
71 | 
72 |     maxval = sp.Max(*val)
73 |     pieces = [(i, sp.Eq(val[i], maxval)) for i in range(len(val)-1)]
74 |     pieces.append((len(val)-1, True))
75 |     return sp.Piecewise(*pieces)
76 | 
77 | def _max(val):
78 |     return sp.Max(*val)
79 | 
80 | def _softmax(vec):
81 |     "A sympy implementation of sklearn's softmax"
82 |     sexp = [sp.exp(vec[i]) for i in range(len(vec))]
83 |     return sp.ImmutableMatrix([sexp[i] / sum(sexp) for i in range(len(sexp))])
84 | 
85 | 
86 | class SympyWriter(ASTProcessor, StandardOps, StandardArithmetics):
87 |     """A SK AST processor, producing a Sympy expression"""
88 | 
89 |     def __init__(self, true_argmax=False):
90 |         self.true_argmax = true_argmax
91 | 
92 |     def Identifier(self, id):
93 |         return sp.symbols(id.id)
94 | 
95 |     def VectorIdentifier(self, node):
96 |         # This is not the best option, because this way all our vector inputs
97 |         # must be 2D matrices (for purposes of lambdify as well as printing)
98 |         x = sp.MatrixSymbol(node.id, node.size, 1)
99 |         return sp.ImmutableMatrix([x[i] for i in range(node.size)])
100 | 
101 |         # NB: This alone won't work, because the Dot operator wants to have an actual matrix as input
102 |         # return sp.MatrixSymbol(node.id, node.size, 1)
103 |         #
104 |         # This (and the version with x = sp.IndexedBase('x')) is not good because it breaks code printers
105 |         # (apparently they expect all indices to have lower and upper bounds, which does not happen if you provide
106 |         # numeric
indices, or smth like that, probably a bug somewhere in sympy) 107 | # x = sp.MatrixSymbol(node.id, node.size, 1) 108 | # return sp.ImmutableMatrix([x[i] for i in range(node.size)]) 109 | 110 | def IndexedIdentifier(self, sub): 111 | if sub.size is None: # Hack to handle Python-written expressions. We do not represent them as "true" indexed values in Sympy 112 | return sp.symbols('{0}_{1}'.format(sub.id, sub.index)) 113 | else: 114 | return sp.IndexedBase(sub.id, shape=(sub.size,))[sub.index] 115 | 116 | def NumberConstant(self, num): 117 | return sp.sympify(num.value) 118 | 119 | def VectorConstant(self, vec): 120 | return sp.ImmutableMatrix(vec.value) 121 | 122 | MatrixConstant = VectorConstant 123 | 124 | def MakeVector(self, vec): 125 | return sp.ImmutableMatrix([self(el) for el in vec.elems]) 126 | 127 | def UnaryFunc(self, node, **kw): 128 | arg = self(node.arg) 129 | op = self(node.op) 130 | if isinstance(arg, sp.MatrixBase) and isinstance(node.op, IsElemwise): 131 | if arg.shape[1] != 1: 132 | raise NotImplementedError("Elementwise operations are only supported for vectors (column matrices)") 133 | return sp.ImmutableMatrix([op(arg[i]) for i in range(len(arg))]) 134 | else: 135 | return op(arg) 136 | 137 | def BinOp(self, node, **kw): 138 | left = self(node.left) 139 | right = self(node.right) 140 | op = self(node.op) 141 | if isinstance(left, sp.MatrixBase) and isinstance(right, sp.MatrixBase) and isinstance(node.op, IsElemwise): 142 | if left.shape != right.shape: 143 | raise ValueError("Shapes of the arguments do not match") 144 | if left.shape[1] != 1: 145 | raise NotImplementedError("Elementwise operations are only supported for vectors (column matrices)") 146 | return sp.ImmutableMatrix([op(left[i], right[i]) for i in range(len(left))]) 147 | else: 148 | return op(left, right) 149 | 150 | def IfThenElse(self, node): 151 | # Piecewise function with matrix output is not a Matrix itself, which breaks some of the logic 152 | # Hence this won't work in general: 153 | test, iftrue, iffalse = self(node.test), self(node.iftrue), self(node.iffalse) 154 | if hasattr(iftrue, 'shape'): 155 | if iftrue.shape != iffalse.shape: 156 | raise ValueError("Shapes of the IF branches must match") 157 | if iftrue.shape[1] != 1: 158 | raise NotImplementedError("Elementwise operations are only supported for vectors (column matrices)") 159 | return sp.ImmutableMatrix([sp.Piecewise((ift, test), (iff, True)) for ift, iff in zip(iftrue, iffalse)]) 160 | else: 161 | return sp.Piecewise((iftrue, test), (iffalse, True)) 162 | 163 | def MatVecProduct(self, _): 164 | return self(Mul()) 165 | 166 | DotProduct = is_(lambda x, y: x.dot(y)) 167 | Exp = is_(sp.exp) 168 | Log = is_(sp.log) 169 | Step = is_(sp.Heaviside) 170 | Sqrt = is_(sp.sqrt) 171 | Abs = is_(sp.Abs) 172 | VecSum = is_(sum) # Yes, we return a Python summation here 173 | Softmax = is_(_softmax) 174 | VecMax = is_(_max) 175 | Max = is_(lambda x, y: _max([x, y])) 176 | Eq = is_(sp.Eq) 177 | 178 | def ArgMax(self, _): 179 | return _argmax if self.true_argmax else sp.Function('argmax') 180 | 181 | # Utility function 182 | 183 | def _softmax(X): 184 | X = np.exp(X) 185 | sum_prob = np.sum(X, axis=1).reshape((-1, 1)) 186 | X /= sum_prob 187 | return X 188 | 189 | _lambdify_modules = ["numpy", 190 | {"Heaviside": lambda x: int(x > 0), 191 | "softmax": lambda x: _softmax([x])[0, :] 192 | }] 193 | 194 | def lambdify(sympy_inputs_str, sympy_expr): 195 | return sp.lambdify(sp.symbols(sympy_inputs_str), sympy_expr, modules=_lambdify_modules) 196 | 197 | 198 | 
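# Illustrative usage sketch (added commentary; names follow the doctests above,
# and the exact printed forms depend on the installed sympy version):
#
#     from skompiler.toskast.string import translate as skast
#     node = skast('2*x + 1')
#     translate(node)                  # -> the sympy expression 2*x + 1
#     translate(node, 'lambda')(1.0)   # -> 3.0, via lambdify() defined above
#     translate(node, 'r')             # -> R source produced by to_code() below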
199 | _ufns = {'argmax': 'argmax'}
200 | 
201 | #pylint: disable=unnecessary-lambda
202 | _code_printers = {
203 |     'c': lambda expr, **kw: sp.ccode(expr, standard='c99', user_functions=_ufns, **kw),
204 |     'cxx': lambda expr, **kw: sp.cxxcode(expr, user_functions=_ufns, **kw),
205 |     'rust': lambda expr, **kw: sp.rust_code(expr, user_functions=_ufns, **kw),
206 |     'fortran': lambda expr, **kw: sp.fcode(expr, standard=95, user_functions=_ufns, **kw),
207 |     'js': lambda expr, **kw: sp.jscode(expr, user_functions=_ufns, **kw),
208 |     'r': lambda expr, **kw: sp.rcode(expr, user_functions=_ufns, **kw),
209 |     'julia': lambda expr, **kw: sp.julia_code(expr, user_functions=_ufns, **kw),
210 |     'mathematica': lambda expr, assign_to=None, **kw: sp.mathematica_code(expr, user_functions=_ufns, **kw),
211 |     'octave': lambda expr, **kw: sp.octave_code(expr, user_functions=_ufns, **kw),
212 | }
213 | 
214 | def to_code(syexpr, dialect, assign_to='y', **kw):
215 |     '''
216 |     Shorthand for converting the resulting Sympy expression to code in various languages.
217 | 
218 | 
219 |     Args:
220 | 
221 |         dialect (str): The target language. Can be one of:
222 |             'c', 'cxx', 'rust', 'fortran', 'js', 'r', 'julia', 'mathematica', 'octave'.
223 | 
224 |     Kwargs:
225 | 
226 |         assign_to (str or sympy expression):
227 |             Passed to the sympy code printers.
228 |             If this is not None, the generated code is a set of assignments to the target variable
229 |             (or, if the result is an array, to the components of the target array).
230 |             Note that if assign_to is None, and the resulting expression is a vector (e.g. a predict_proba vector
231 |             of probabilities), sympy may be unable to generate meaningful code for many languages,
232 |             because those lack the "Matrix" datatype (hence it cannot be the result of a one-liner expression)
233 |     '''
234 |     if dialect not in _code_printers:
235 |         raise ValueError("Unknown dialect: {0}".format(dialect))
236 |     return _code_printers[dialect](syexpr, assign_to=assign_to, **kw)
--------------------------------------------------------------------------------
/skompiler/fromskast/excel.py:
--------------------------------------------------------------------------------
1 | """
2 | SKompiler: Generate Excel formulas from SKAST.
3 | """
4 | #pylint: disable=protected-access
5 | import warnings
6 | from collections import OrderedDict
7 | from itertools import product, chain, takewhile, count
8 | import re
9 | import numpy as np
10 | from ._common import ASTProcessor, is_, StandardOps, VectorsAsLists, id_generator
11 | 
12 | 
13 | def translate(node, component=None, multistage=True, assign_to=None,
14 |               multistage_subexpression_min_length=3, _max_subexpression_length=8100):
15 |     """Translates SKAST to an Excel formula (or a list of those, if the output should be a vector).
16 | 
17 |     Kwargs:
18 |         component (int or None):
19 |             If the result is a vector and only one component is required, pass its index here.
20 | 
21 |         multistage (bool):
22 |             When False, generates a single string describing the model as one long expression.
23 |             For complex models this string will be too long to be used in Excel.
24 |             When True (default), returns an ExcelCode object, which is an OrderedDict,
25 |             mapping cell names to the expressions they should correspond to.
26 | 
27 |         assign_to: A list or a generator expression producing Excel cell names
28 |             which can be filled in a multi-stage computation.
29 |             When None, a default sequence ['A1', 'B1', ...] will be used.
30 |             If you would like such a sequence, but beginning at, say, 'G3',
31 |             pass excel_range('G3:*3')
32 | 
33 |         multistage_subexpression_min_length (int):
34 |             Allows reducing the number of stages in the computation by preventing the creation
35 |             of an intermediate step whenever the corresponding expression is shorter than the given length.
36 |             E.g. suppose you would like to avoid having a short separate subexpression
37 |                G1 = MAX(A1,B1,C1)
38 |             and would rather have it inlined.
39 |             In this case, specify multistage_subexpression_min_length=14, and expressions shorter than 14 characters won't be
40 |             assigned to a separate variable.
41 |             Specifying a very large value is nearly equivalent to setting multistage=False
42 |             (the only difference is that the returned value is still an OrderedDict with a single assignment)
43 | 
44 |         _max_subexpression_length (int): Max length of a single subexpression in multistage mode.
45 |             You should not change it (used for testing internally)
46 | 
47 |     >>> from skompiler.toskast.string import translate as skast
48 |     >>> expr = skast('[2*x[0]/5, 1] if x[1] <= 3 else [12.0+y, -45.5]')
49 |     >>> print(translate(expr, multistage=False))
50 |     ['IF((x2<=3),((2*x1)/5),(12.0+y))', 'IF((x2<=3),1,(-45.5))']
51 |     >>> print(translate(expr))
52 |     A1=IF((x2<=3),((2*x1)/5),(12.0+y))
53 |     B1=IF((x2<=3),1,(-45.5))
54 |     >>> expr = skast('a=1+x; a+a+a')
55 |     >>> print(translate(expr))
56 |     A1=(1+x)
57 |     B1=((A1+A1)+A1)
58 |     >>> print(translate(expr, multistage=False))
59 |     (((1+x)+(1+x))+(1+x))
60 |     """
61 |     writer = ExcelWriter(multistage=multistage, assign_to=assign_to,
62 |                          multistage_subexpression_min_length=multistage_subexpression_min_length,
63 |                          _max_subexpression_length=_max_subexpression_length)
64 |     result = writer(node)
65 |     if component is not None:
66 |         result = result[component]
67 |     if multistage:
68 |         writer.add_named_subexpression(result)
69 |         return writer.code
70 |     else:
71 |         return result
72 | 
73 | def _sum(iterable):
74 |     return "({0})".format("+".join(iterable))
75 | 
76 | def _iif(cond, iftrue, iffalse):
77 |     if iftrue == iffalse:
78 |         return iftrue
79 |     else:
80 |         return 'IF({0},{1},{2})'.format(cond, iftrue, iffalse)
81 | 
82 | def _dotproduct(xs, ys):
83 |     return _sum('{0}*{1}'.format(x, y) for x, y in zip(xs, ys))
84 | 
85 | def _step(x):
86 |     return _iif('{0}>0'.format(x), 1, 0)
87 | 
88 | def _max(xs):
89 |     if len(xs) == 1:
90 |         return xs[0]
91 |     else:
92 |         return 'MAX({0})'.format(','.join(xs))
93 | 
94 | def _argmax(xs, maxval=None):
95 |     if not maxval:
96 |         maxval = _max(xs)
97 |     expr = str(len(xs)-1)
98 |     n = len(xs)-1
99 |     while n > 0:
100 |         n -= 1
101 |         expr = _iif('{0}={1}'.format(xs[n], maxval), str(n), expr)
102 |     return expr
103 | 
104 | def is_fmt(template):
105 |     # Auto-compacting binary operator
106 |     def auto_compacting_operator(self, _):
107 |         def fn(x, y=''):
108 |             result = template.format(x, y)
109 |             if self.multistage and len(result) > self.max_subexpression_length:
110 |                 if len(x) > len(y):
111 |                     x, y = self.add_named_subexpression(x), y
112 |                 else:
113 |                     x, y = x, self.add_named_subexpression(y)
114 |                 result = template.format(x, y)
115 |             return result
116 |         return fn
117 |     return auto_compacting_operator
118 | 
119 | def _takeuntil(value, iterable):
120 |     if value is None:
121 |         return iterable
122 |     else:
123 |         return chain(takewhile(lambda x: x != value, iterable), [value])
124 | 
125 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
126 | 
127 | def excel_column_names(start_column='A', end_column=None):
128 |     """
129 |     Generates excel column
names starting from a given one.
130 |     end_column is inclusive.
131 | 
132 |     >>> gen = excel_column_names('C')
133 |     >>> [next(gen) for _ in range(4)]
134 |     ['C', 'D', 'E', 'F']
135 |     >>> list(excel_column_names('Y', 'AC'))
136 |     ['Y', 'Z', 'AA', 'AB', 'AC']
137 |     >>> list(excel_column_names('AY', 'BC'))
138 |     ['AY', 'AZ', 'BA', 'BB', 'BC']
139 |     >>> list(excel_column_names('ZY', 'AAC'))
140 |     ['ZY', 'ZZ', 'AAA', 'AAB', 'AAC']
141 |     """
142 | 
143 |     ns = _letters
144 |     all_cols = map(''.join, chain(ns, product(ns, ns), product(ns, ns, ns)))
145 |     while next(all_cols) != start_column:
146 |         pass
147 |     return _takeuntil(end_column, chain([start_column], all_cols))
148 | 
149 | _range_re = re.compile(r'^([\*A-Z]+)(\d+):([\*A-Z]+)(\d+)$')
150 | 
151 | def excel_range(range_):
152 |     """
153 |     A convenience method for generating lists of excel cells in a row or column.
154 |     The range_ argument is an expression of the form A3:G3.
155 |     The second endpoint of the range may contain * instead of the
156 |     row or column coordinate (as in A3:*3 or A3:A*), in this case
157 |     the return value is a generator, enumerating cells in the given
158 |     row or column indefinitely.
159 |     """
160 |     matches = _range_re.match(range_.upper())
161 |     if not matches:
162 |         raise ValueError("Range must be of the form A1:B1")
163 |     lcol, lrow, rcol, rrow = matches.groups()
164 |     if lcol == '*' or lrow == '*':
165 |         raise ValueError("Only the right side of the interval may include wildcard")
166 |     if lcol == rcol:  # Fixed column
167 |         if rrow == '*':
168 |             rows = count(int(lrow))
169 |         else:
170 |             rows = range(int(lrow), int(rrow)+1)
171 |         yield from ('{0}{1}'.format(lcol, i) for i in rows)
172 |     elif lrow == rrow:  # Fixed row
173 |         cols = excel_column_names(lcol, None if rcol == '*' else rcol)
174 |         yield from ('{0}{1}'.format(col, lrow) for col in cols)
175 |     else:
176 |         raise ValueError("Only single-column or single-row ranges are supported")
177 | 
178 | def _compact_string(s, max_len=70):
179 |     if len(s) > max_len:
180 |         part = (max_len - 30) // 2
181 |         return s[:part] + ' ...{0} chars skipped... '.format(len(s)-2*part) + s[-part:]
182 |     else:
183 |         return s
184 | 
185 | _builtins = {
186 |     'IF': lambda t, a, b: a if t else b,
187 |     'MAX': max,
188 |     'EXP': np.exp,
189 |     'LOG': np.log,
190 |     'SQRT': np.sqrt,
191 |     'ABS': np.abs,
192 | }
193 | _single_comparison = re.compile(r'(?<![<>=])=')
194 | 
195 | class ExcelCode(OrderedDict):
196 |     """An OrderedDict of cell -> formula assignments, as returned by translate()."""
197 |     def __str__(self):
198 |         lines = ['{0}={1}'.format(k, _compact_string(v))
199 |                  for k, v in self.items()]
200 |         if len(lines) > 10:
201 |             lines = lines[:4] + ['  ... {0} lines skipped ...'.format(len(lines)-8)] +\
202 |                     lines[-4:]
203 |         return '\n'.join(lines)
204 | 
205 |     def to_dataframe(self):
206 |         """Converts code to a pandas dataframe,
207 |         suitable for pasting into Excel.
208 | 
209 |         The main use case for this method is:
210 | 
211 |             code.to_dataframe().to_clipboard()
212 | 
213 |         """
214 |         import pandas as pd
215 |         return pd.DataFrame([['={0}'.format(v) for v in self.values()]],
216 |                             columns=self.keys())
217 | 
218 |     def evaluate(self, **kwargs):
219 |         """Evaluates the excel code using Python's eval.
220 |         Will probably fail with MemoryError for longer strings
221 |         (because Python's ast.parse can't handle them)."""
222 |         env = OrderedDict()
223 |         env.update(_builtins)
224 |         for k, v in self.items():
225 |             expand = _single_comparison.sub('==', v)  # Excel uses '=' for comparisons
226 |             env[k] = eval(expand, kwargs, env)  #pylint: disable=eval-used
227 |         for k in _builtins:
228 |             del env[k]
229 |         return env
230 | 
231 | 
232 | class ExcelWriter(ASTProcessor, StandardOps, VectorsAsLists):
233 |     """A SK AST processor, producing an Excel expression (or a list of those)"""
234 | 
235 |     def __init__(self, multistage=False, assign_to=None, positive_infinity=float(np.finfo('float64').max),
236 |                  negative_infinity=float(np.finfo('float64').min),
237 |                  multistage_subexpression_min_length=3,
238 |                  _max_subexpression_length=8100):
239 |         self.positive_infinity = positive_infinity
240 |         self.negative_infinity = negative_infinity
241 |         self.multistage = multistage
242 |         self.multistage_subexpression_min_length = multistage_subexpression_min_length
243 |         self.max_subexpression_length = _max_subexpression_length  # In multistage mode we attempt to keep subexpressions shorter than this
244 |                                                                    # (because Excel does not allow cell values longer than 8192 chars)
245 |                                                                    # NB: this is rather ad-hoc and may not always work.
246 | 
247 |         if self.multistage:
248 |             if assign_to is None:
249 |                 warnings.warn("Value of the assign_to parameter is not provided. Will use the default ['A1', 'B1', ...]", UserWarning)
250 |                 assign_to = excel_range('A1:*1')
251 |             self.assign_to = assign_to if hasattr(assign_to, '__next__') else iter(assign_to)
252 |             self.code = ExcelCode()
253 |             self.references = [{}]
254 |             self.temp_ids = id_generator()
255 | 
256 |     def Identifier(self, id):
257 |         return id.id
258 | 
259 |     def IndexedIdentifier(self, sub):
260 |         warnings.warn("Excel does not support vector types natively. 
" 261 | "Numbers will be appended to the given feature name, " 262 | "it may not be what you intend.", UserWarning) 263 | return "{0}{1}".format(sub.id, sub.index+1) 264 | 265 | def NumberConstant(self, num): 266 | # Infinities have to be handled separately 267 | if np.isinf(num.value): 268 | val = self.positive_infinity if num.value > 0 else self.negative_infinity 269 | else: 270 | val = num.value 271 | return str(val) 272 | 273 | def _iif(self, test, ift, iff): 274 | # Auto-compacting IIF 275 | result = _iif(test, ift, iff) 276 | if self.multistage and len(result) > self.max_subexpression_length: 277 | if len(ift) > len(iff): 278 | ift, iff = self.add_named_subexpression(ift), iff 279 | else: 280 | ift, iff = ift, self.add_named_subexpression(iff) 281 | result = _iif(test, ift, iff) 282 | return result 283 | 284 | # Implement binary and unary operations 285 | Mul = is_fmt('({0}*{1})') 286 | Div = is_fmt('({0}/{1})') 287 | Add = is_fmt('({0}+{1})') 288 | Sub = is_fmt('({0}-{1})') 289 | LtEq = is_fmt('({0}<={1})') 290 | Eq = is_fmt('({0}={1})') 291 | USub = is_fmt('(-{0})') 292 | Exp = is_fmt('EXP({0})') 293 | Sqrt = is_fmt('SQRT({0})') 294 | Log = is_fmt('LOG({0})') 295 | Max = is_fmt('MAX({0},{1})') 296 | Abs = is_fmt('ABS({0})') 297 | Step = is_(_step) 298 | Sigmoid = is_fmt('(1/(1+EXP(-{0})))') 299 | DotProduct = is_(_dotproduct) 300 | VecSum = is_(_sum) 301 | VecMax = is_(_max) 302 | MatVecProduct = lambda self, _: self._matvecproduct 303 | ArgMax = lambda self, _: self._argmax 304 | Softmax = lambda self, _: self._softmax 305 | 306 | def _argmax(self, xs): 307 | xs = [self.possibly_add_named_subexpression(x) for x in xs] 308 | max_var = self.possibly_add_named_subexpression(_max(xs)) 309 | return _argmax(xs, max_var) 310 | 311 | def _matvecproduct(self, M, xs): 312 | xs = [self.possibly_add_named_subexpression(x) for x in xs] 313 | return [_sum('{0}*{1}'.format(m_i[j], xs[j]) for j in range(len(xs))) for m_i in M] 314 | 315 | def _vecsumnormalize(self, xs): 316 | xs = [self.possibly_add_named_subexpression(x) for x in xs] 317 | sum_var = self.possibly_add_named_subexpression(_sum(xs)) 318 | return ['({0}/{1})'.format(x, sum_var) for x in xs] 319 | 320 | def _softmax(self, xs): 321 | return self._vecsumnormalize(['EXP({0})'.format(x) for x in xs]) 322 | 323 | def Let(self, node, **kw): 324 | if not self.multistage: 325 | return StandardOps.Let(self, node) 326 | else: 327 | self.references.append({}) 328 | for defn in node.defs: 329 | self.add_named_subexpression(self(defn.body), defn.name) 330 | result = self(node.body) 331 | self.references.pop() 332 | return result 333 | 334 | def Reference(self, node): 335 | if not self.multistage: 336 | raise ValueError("Reference nodes are only supported in multi-stage code generation") 337 | if node.name not in self.references[-1]: 338 | raise ValueError("Undefined reference: {0}".format(node.name)) 339 | return self.references[-1][node.name] 340 | 341 | def possibly_add_named_subexpression(self, value): 342 | if self.multistage and len(value) >= self.multistage_subexpression_min_length: 343 | return self.add_named_subexpression(value) 344 | else: 345 | return value 346 | 347 | def add_named_subexpression(self, value, name=None): 348 | is_list = isinstance(value, list) 349 | if not isinstance(value, list): 350 | value = [value] 351 | if name is None: 352 | name = next(self.temp_ids) 353 | try: 354 | ref = [] 355 | for v in value: 356 | next_output = next(self.assign_to) 357 | if next_output in self.code: 358 | raise ValueError("Repeated names 
are not supported in the assign_to parameter") 359 | self.code[next_output] = v 360 | ref.append(next_output) 361 | if len(ref) == 1 and not is_list: 362 | ref = ref[0] 363 | self.references[-1][name] = ref 364 | return ref 365 | except StopIteration as ex: 366 | raise ValueError("The number of fields provided in the assign_to parameter" 367 | " is not sufficient to complete the computation.") from ex 368 | -------------------------------------------------------------------------------- /pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # A comma-separated list of package or module names from where C extensions may 4 | # be loaded. Extensions are loading into the active Python interpreter and may 5 | # run arbitrary code 6 | extension-pkg-whitelist= 7 | 8 | # Add files or directories to the blacklist. They should be base names, not 9 | # paths. 10 | ignore=CVS 11 | 12 | # Add files or directories matching the regex patterns to the blacklist. The 13 | # regex matches against base names, not paths. 14 | ignore-patterns= 15 | 16 | # Python code to execute, usually for sys.path manipulation such as 17 | # pygtk.require(). 18 | #init-hook= 19 | 20 | # Use multiple processes to speed up Pylint. 21 | jobs=1 22 | 23 | # List of plugins (as comma separated values of python modules names) to load, 24 | # usually to register additional checkers. 25 | load-plugins= 26 | 27 | # Pickle collected data for later comparisons. 28 | persistent=yes 29 | 30 | # Specify a configuration file. 31 | #rcfile= 32 | 33 | # When enabled, pylint would attempt to guess common misconfiguration and emit 34 | # user-friendly hints instead of false-positive error messages 35 | suggestion-mode=yes 36 | 37 | # Allow loading of arbitrary C extensions. Extensions are imported into the 38 | # active Python interpreter and may run arbitrary code. 39 | unsafe-load-any-extension=no 40 | 41 | 42 | [MESSAGES CONTROL] 43 | 44 | # Only show warnings with the listed confidence levels. Leave empty to show 45 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 46 | confidence= 47 | 48 | # Disable the message, report, category or checker with the given id(s). You 49 | # can either give multiple identifiers separated by comma (,) or put this 50 | # option multiple times (only on the command line, not in the configuration 51 | # file where it should appear only once).You can also use "--disable=all" to 52 | # disable everything first and then reenable specific checks. For example, if 53 | # you want to run only the similarities checker, you can use "--disable=all 54 | # --enable=similarities". If you want to run only the classes checker, but have 55 | # no Warning level messages displayed, use"--disable=all --enable=classes 56 | # --disable=W" 57 | disable=fixme, 58 | missing-docstring, 59 | invalid-name, 60 | redefined-builtin, 61 | abstract-method 62 | 63 | # Enable the message, report, category or checker with the given id(s). You can 64 | # either give multiple identifier separated by comma (,) or put this option 65 | # multiple time (only on the command line, not in the configuration file where 66 | # it should appear only once). See also the "--disable" option for examples. 67 | enable=c-extension-no-member 68 | 69 | [REPORTS] 70 | 71 | # Python expression which should return a note less than 10 (10 is the highest 72 | # note). 
You have access to the variables errors warning, statement which 73 | # respectively contain the number of errors / warnings messages and the total 74 | # number of statements analyzed. This is used by the global evaluation report 75 | # (RP0004). 76 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 77 | 78 | # Template used to display messages. This is a python new-style format string 79 | # used to format the message information. See doc for all details 80 | #msg-template= 81 | 82 | # Set the output format. Available formats are text, parseable, colorized, json 83 | # and msvs (visual studio).You can also give a reporter class, eg 84 | # mypackage.mymodule.MyReporterClass. 85 | output-format=text 86 | 87 | # Tells whether to display a full report or only the messages 88 | reports=no 89 | 90 | # Activate the evaluation score. 91 | score=yes 92 | 93 | 94 | [REFACTORING] 95 | 96 | # Maximum number of nested blocks for function / method body 97 | max-nested-blocks=5 98 | 99 | 100 | [BASIC] 101 | 102 | # Naming style matching correct argument names 103 | argument-naming-style=snake_case 104 | 105 | # Regular expression matching correct argument names. Overrides argument- 106 | # naming-style 107 | #argument-rgx= 108 | 109 | # Naming style matching correct attribute names 110 | attr-naming-style=snake_case 111 | 112 | # Regular expression matching correct attribute names. Overrides attr-naming- 113 | # style 114 | #attr-rgx= 115 | 116 | # Bad variable names which should always be refused, separated by a comma 117 | bad-names=foo, 118 | bar, 119 | baz, 120 | toto, 121 | tutu, 122 | tata 123 | 124 | # Naming style matching correct class attribute names 125 | class-attribute-naming-style=any 126 | 127 | # Regular expression matching correct class attribute names. Overrides class- 128 | # attribute-naming-style 129 | #class-attribute-rgx= 130 | 131 | # Naming style matching correct class names 132 | class-naming-style=PascalCase 133 | 134 | # Regular expression matching correct class names. Overrides class-naming-style 135 | #class-rgx= 136 | 137 | # Naming style matching correct constant names 138 | const-naming-style=UPPER_CASE 139 | 140 | # Regular expression matching correct constant names. Overrides const-naming- 141 | # style 142 | #const-rgx= 143 | 144 | # Minimum line length for functions/classes that require docstrings, shorter 145 | # ones are exempt. 146 | docstring-min-length=-1 147 | 148 | # Naming style matching correct function names 149 | function-naming-style=snake_case 150 | 151 | # Regular expression matching correct function names. Overrides function- 152 | # naming-style 153 | #function-rgx= 154 | 155 | # Good variable names which should always be accepted, separated by a comma 156 | good-names=i,j,k,ex,df,fn,id,db,log,app,X,X_train,X_test,y 157 | 158 | # Include a hint for the correct naming format with invalid-name 159 | include-naming-hint=no 160 | 161 | # Naming style matching correct inline iteration names 162 | inlinevar-naming-style=any 163 | 164 | # Regular expression matching correct inline iteration names. Overrides 165 | # inlinevar-naming-style 166 | #inlinevar-rgx= 167 | 168 | # Naming style matching correct method names 169 | method-naming-style=snake_case 170 | 171 | # Regular expression matching correct method names. 
Overrides method-naming-
172 | # style
173 | #method-rgx=
174 | 
175 | # Naming style matching correct module names
176 | module-naming-style=snake_case
177 | 
178 | # Regular expression matching correct module names. Overrides module-naming-
179 | # style
180 | #module-rgx=
181 | 
182 | # Colon-delimited sets of names that determine each other's naming style when
183 | # the name regexes allow several styles.
184 | name-group=
185 | 
186 | # Regular expression which should only match function or class names that do
187 | # not require a docstring.
188 | no-docstring-rgx=^_
189 | 
190 | # List of decorators that produce properties, such as abc.abstractproperty. Add
191 | # to this list to register other decorators that produce valid properties.
192 | property-classes=abc.abstractproperty
193 | 
194 | # Naming style matching correct variable names
195 | variable-naming-style=snake_case
196 | 
197 | # Regular expression matching correct variable names. Overrides variable-
198 | # naming-style
199 | #variable-rgx=
200 | 
201 | 
202 | [FORMAT]
203 | 
204 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
205 | expected-line-ending-format=
206 | 
207 | # Regexp for a line that is allowed to be longer than the limit.
208 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$
209 | 
210 | # Number of spaces of indent required inside a hanging or continued line.
211 | indent-after-paren=4
212 | 
213 | # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
214 | # tab).
215 | indent-string='    '
216 | 
217 | # Maximum number of characters on a single line.
218 | max-line-length=150
219 | 
220 | # Maximum number of lines in a module
221 | max-module-lines=1000
222 | 
223 | # List of optional constructs for which whitespace checking is disabled. `dict-
224 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
225 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
226 | # `empty-line` allows space-only lines.
227 | no-space-check=trailing-comma,
228 |                dict-separator,
229 |                empty-line
230 | 
231 | # Allow the body of a class to be on the same line as the declaration if body
232 | # contains single statement.
233 | single-line-class-stmt=no
234 | 
235 | # Allow the body of an if to be on the same line as the test if there is no
236 | # else.
237 | single-line-if-stmt=no
238 | 
239 | 
240 | [LOGGING]
241 | 
242 | # Logging modules to check that the string format arguments are in logging
243 | # function parameter format
244 | logging-modules=logging
245 | 
246 | 
247 | [MISCELLANEOUS]
248 | 
249 | # List of note tags to take in consideration, separated by a comma.
250 | notes=FIXME,
251 |       XXX,
252 |       TODO
253 | 
254 | 
255 | [SIMILARITIES]
256 | 
257 | # Ignore comments when computing similarities.
258 | ignore-comments=yes
259 | 
260 | # Ignore docstrings when computing similarities.
261 | ignore-docstrings=yes
262 | 
263 | # Ignore imports when computing similarities.
264 | ignore-imports=no
265 | 
266 | # Minimum lines number of a similarity.
267 | min-similarity-lines=4
268 | 
269 | 
270 | [SPELLING]
271 | 
272 | # Limits count of emitted suggestions for spelling mistakes
273 | max-spelling-suggestions=4
274 | 
275 | # Spelling dictionary name. Available dictionaries: none. To make it working
276 | # install python-enchant package.
277 | spelling-dict=
278 | 
279 | # List of comma separated words that should not be checked.
280 | spelling-ignore-words=
281 | 
282 | # A path to a file that contains private dictionary; one word per line.
283 | spelling-private-dict-file= 284 | 285 | # Tells whether to store unknown words to indicated private dictionary in 286 | # --spelling-private-dict-file option instead of raising a message. 287 | spelling-store-unknown-words=no 288 | 289 | 290 | [TYPECHECK] 291 | 292 | # List of decorators that produce context managers, such as 293 | # contextlib.contextmanager. Add to this list to register other decorators that 294 | # produce valid context managers. 295 | contextmanager-decorators=contextlib.contextmanager 296 | 297 | # List of members which are set dynamically and missed by pylint inference 298 | # system, and so shouldn't trigger E1101 when accessed. Python regular 299 | # expressions are accepted. 300 | generated-members= 301 | 302 | # Tells whether missing members accessed in mixin class should be ignored. A 303 | # mixin class is detected if its name ends with "mixin" (case insensitive). 304 | ignore-mixin-members=yes 305 | 306 | # This flag controls whether pylint should warn about no-member and similar 307 | # checks whenever an opaque object is returned when inferring. The inference 308 | # can return multiple potential results while evaluating a Python object, but 309 | # some branches might not be evaluated, which results in partial inference. In 310 | # that case, it might be useful to still emit no-member and other checks for 311 | # the rest of the inferred objects. 312 | ignore-on-opaque-inference=yes 313 | 314 | # List of class names for which member attributes should not be checked (useful 315 | # for classes with dynamically set attributes). This supports the use of 316 | # qualified names. 317 | ignored-classes=optparse.Values,thread._local,_thread._local,scoped_session,ImmutableColumnCollection,Table 318 | 319 | # List of module names for which member attributes should not be checked 320 | # (useful for modules/projects where namespaces are manipulated during runtime 321 | # and thus existing member attributes cannot be deduced by static analysis. It 322 | # supports qualified module names, as well as Unix pattern matching. 323 | ignored-modules= 324 | 325 | # Show a hint with possible names when a member name was not found. The aspect 326 | # of finding the hint is based on edit distance. 327 | missing-member-hint=yes 328 | 329 | # The minimum edit distance a name should have in order to be considered a 330 | # similar match for a missing member name. 331 | missing-member-hint-distance=1 332 | 333 | # The total number of similar names that should be taken in consideration when 334 | # showing a hint for a missing member. 335 | missing-member-max-choices=1 336 | 337 | 338 | [VARIABLES] 339 | 340 | # List of additional names supposed to be defined in builtins. Remember that 341 | # you should avoid to define new builtins when possible. 342 | additional-builtins= 343 | 344 | # Tells whether unused global variables should be treated as a violation. 345 | allow-global-unused-variables=yes 346 | 347 | # List of strings which can identify a callback function by name. A callback 348 | # name must start or end with one of those strings. 349 | callbacks=cb_, 350 | _cb 351 | 352 | # A regular expression matching the name of dummy variables (i.e. expectedly 353 | # not used). 354 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 355 | 356 | # Argument names that match this expression will be ignored. 
Default to name 357 | # with leading underscore 358 | ignored-argument-names=_.*|^ignored_|^unused_ 359 | 360 | # Tells whether we should check for unused import in __init__ files. 361 | init-import=no 362 | 363 | # List of qualified module names which can have objects that can redefine 364 | # builtins. 365 | redefining-builtins-modules=six.moves,past.builtins,future.builtins 366 | 367 | 368 | [CLASSES] 369 | 370 | # List of method names used to declare (i.e. assign) instance attributes. 371 | defining-attr-methods=__init__, 372 | __new__, 373 | setUp 374 | 375 | # List of member names, which should be excluded from the protected access 376 | # warning. 377 | exclude-protected=_asdict, 378 | _fields, 379 | _replace, 380 | _source, 381 | _make 382 | 383 | # List of valid names for the first argument in a class method. 384 | valid-classmethod-first-arg=cls 385 | 386 | # List of valid names for the first argument in a metaclass class method. 387 | valid-metaclass-classmethod-first-arg=mcs 388 | 389 | 390 | [DESIGN] 391 | 392 | # Maximum number of arguments for function / method 393 | max-args=5 394 | 395 | # Maximum number of attributes for a class (see R0902). 396 | max-attributes=7 397 | 398 | # Maximum number of boolean expressions in a if statement 399 | max-bool-expr=5 400 | 401 | # Maximum number of branch for function / method body 402 | max-branches=12 403 | 404 | # Maximum number of locals for function / method body 405 | max-locals=15 406 | 407 | # Maximum number of parents for a class (see R0901). 408 | max-parents=7 409 | 410 | # Maximum number of public methods for a class (see R0904). 411 | max-public-methods=20 412 | 413 | # Maximum number of return / yield for function / method body 414 | max-returns=6 415 | 416 | # Maximum number of statements in function / method body 417 | max-statements=50 418 | 419 | # Minimum number of public methods for a class (see R0903). 420 | min-public-methods=2 421 | 422 | 423 | [IMPORTS] 424 | 425 | # Allow wildcard imports from modules that define __all__. 426 | allow-wildcard-with-all=no 427 | 428 | # Analyse import fallback blocks. This can be used to support both Python 2 and 429 | # 3 compatible code, which means that the block might have code that exists 430 | # only in one or another interpreter, leading to false positives when analysed. 431 | analyse-fallback-blocks=no 432 | 433 | # Deprecated modules which should not be used, separated by a comma 434 | deprecated-modules=optparse,tkinter.tix 435 | 436 | # Create a graph of external dependencies in the given file (report RP0402 must 437 | # not be disabled) 438 | ext-import-graph= 439 | 440 | # Create a graph of every (i.e. internal and external) dependencies in the 441 | # given file (report RP0402 must not be disabled) 442 | import-graph= 443 | 444 | # Create a graph of internal dependencies in the given file (report RP0402 must 445 | # not be disabled) 446 | int-import-graph= 447 | 448 | # Force import order to recognize a module as part of the standard 449 | # compatibility libraries. 450 | known-standard-library= 451 | 452 | # Force import order to recognize a module as part of a third party library. 453 | known-third-party=enchant 454 | 455 | 456 | [EXCEPTIONS] 457 | 458 | # Exceptions that will emit a warning when being caught. 
Defaults to
459 | # "Exception"
460 | overgeneral-exceptions=Exception
461 | 
--------------------------------------------------------------------------------
/skompiler/fromskast/sqlalchemy.py:
--------------------------------------------------------------------------------
1 | """
2 | SKompiler: Generate SQLAlchemy expressions from SKAST.
3 | """
4 | from functools import reduce
5 | from collections import namedtuple
6 | import numpy as np
7 | import sqlalchemy as sa
8 | from sqlalchemy.sql.selectable import Join, FromGrouping
9 | from ..ast import ArgMax, VecMax, Softmax, IsElemwise, VecSum, Max, IsAtom
10 | from ._common import ASTProcessor, StandardOps, StandardArithmetics, is_, tolist,\
11 |                      not_implemented, prepare_assign_to, id_generator, denumpyfy
12 | 
13 | #pylint: disable=trailing-whitespace
14 | def translate(node, dialect=None, assign_to='y', component=None,
15 |               multistage=True, key_column='id', from_obj='data'):
16 |     """Translates SKAST to an SQLAlchemy expression (or a list of those, if the output should be a vector).
17 | 
18 |     If dialect is not None, further compiles the expression(s) to a given dialect via to_sql.
19 | 
20 |     Kwargs:
21 |         assign_to (None/string/list of str): See to_sql
22 | 
23 |         component (int): If the result is a vector and you only need one component of it, specify its index (0-based) here.
24 | 
25 |         multistage (bool): When multistage=False, the returned value is a single expression which can be selected directly from the
26 |             source data table. This, however, may make the resulting query rather long, as some functions (e.g. argmax)
27 |             require repeated computation of the same parts over and over.
28 |             The problem is solved by splitting the computation into a sequence of CTE subqueries - the "multistage" mode.
29 |             The resulting query may then look like
30 | 
31 |                 with _tmp1 as (select [probability computations] from data),
32 |                      _tmp2 as (select [argmax computation] from _tmp1),
33 |                      ...
34 |                 select [final values] from _tmpX
35 | 
36 |             Default - True
37 | 
38 |         from_obj: A string or a SQLAlchemy selectable object - the source table for the data.
39 |             In non-multistage mode this may be None - in this case the returned value is
40 |             simply 'SELECT cols'.
41 | 
42 |         key_column: A string or a sa.column object, naming the key column in the source table.
43 |             Compulsory for multistage mode.
44 | 45 | 46 | >>> from skompiler.toskast.string import translate as skast 47 | >>> expr = skast('[2*x[0], 1] if x[1] <= 3 else [12.0, 45.5]') 48 | >>> print(translate(expr, 'sqlite', multistage=False, from_obj=None)) 49 | SELECT CASE WHEN (x2 <= 3) THEN 2 * x1 ELSE 12.0 END AS y1, CASE WHEN (x2 <= 3) THEN 1 ELSE 45.5 END AS y2 50 | 51 | 52 | >>> expr = skast('x=1; y=2; x+y') 53 | >>> print(translate(expr, 'sqlite', multistage=True)) 54 | WITH _tmp1 AS 55 | (SELECT data.id AS __id__, 1 AS f1 56 | FROM data), 57 | _tmp2 AS 58 | (SELECT data.id AS __id__, 2 AS f1 59 | FROM data) 60 | SELECT _tmp1.f1 + _tmp2.f1 AS y 61 | FROM _tmp1 JOIN _tmp2 ON _tmp1.__id__ = _tmp2.__id__ 62 | >>> expr = skast('x+y') 63 | >>> stbl = sa.select([sa.column('id'), sa.column('x'), sa.column('y')], from_obj=sa.table('test')).cte('_data') 64 | >>> print(translate(expr, 'sqlite', multistage=False, from_obj=stbl)) 65 | WITH _data AS 66 | (SELECT id, x, y 67 | FROM test) 68 | SELECT x + y AS y 69 | FROM _data 70 | """ 71 | if multistage and from_obj is None: 72 | raise ValueError("from_obj must be specified in multistage mode") 73 | result = SQLAlchemyWriter(from_obj=from_obj, key_column=key_column, multistage=multistage)(node) 74 | 75 | if component is not None: 76 | result = result._replace(cols=[result.cols[component]]) 77 | 78 | assign_to = prepare_assign_to(assign_to, len(result.cols)) 79 | if assign_to is not None: 80 | result = result._replace(cols=[col.label(lbl) for col, lbl in zip(result.cols, assign_to)]) 81 | 82 | result = sa.select(result.cols, from_obj=result.from_obj) 83 | 84 | if dialect is not None: 85 | result = to_sql(result, dialect) 86 | return result 87 | 88 | 89 | def _max(xs): 90 | if len(xs) == 1: 91 | return xs[0] 92 | return reduce(greatest, xs) 93 | 94 | def _sum(iterable): 95 | "The built-in 'sum' does not work for us as we need." 96 | return reduce(lambda x, y: x+y, iterable) 97 | 98 | def _iif(cond, iftrue, iffalse): 99 | # Optimize if (...) then X else X for literal X 100 | # A lot of these occur when compiling trees 101 | if isinstance(iftrue, sa.sql.elements.BindParameter) and \ 102 | isinstance(iffalse, sa.sql.elements.BindParameter) and \ 103 | iftrue.value == iffalse.value: 104 | return iftrue 105 | return sa.case([(cond, iftrue)], else_=iffalse) 106 | 107 | def _matvecproduct(M, x): 108 | return [_sum(m_i[j] * x[j] for j in range(len(x))) for m_i in M] 109 | 110 | def _dotproduct(xs, ys): 111 | return [_sum(x * y for x, y in zip(xs, ys))] 112 | 113 | def _step(x): 114 | return _iif(x > 0, 1, 0) 115 | 116 | def extract_tables(from_obj): 117 | if isinstance(from_obj, FromGrouping): 118 | return extract_tables(from_obj.element) 119 | elif isinstance(from_obj, Join): 120 | return extract_tables(from_obj.left) + extract_tables(from_obj.right) 121 | else: 122 | return [from_obj] 123 | 124 | def _merge(tbl1, tbl2): 125 | if tbl1 is None: 126 | return tbl2 127 | elif tbl2 is None: 128 | return tbl1 129 | if tbl1 is tbl2: 130 | return tbl1 131 | # Either of the arguments may be a join clause and these 132 | # may include repeated elements. If so, we have to extract them and recombine. 
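    # Illustrative sketch (added commentary): if tbl1 is (A JOIN B) and tbl2 is
    # (B JOIN C), extract_tables() yields [A, B] and [B, C]; the set union
    # {A, B, C} is sorted by table name for determinism and re-joined pairwise
    # on the shared key column below.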
133 | all_tables = list(sorted(set(extract_tables(tbl1) + extract_tables(tbl2)), key=lambda x: x.name)) 134 | tbl1 = all_tables[0] 135 | joined = tbl1 136 | for tbl_next in all_tables[1:]: 137 | joined = joined.join(tbl_next, onclause=tbl1.key_ == tbl_next.key_) 138 | joined.key_ = tbl1.key_ 139 | return joined 140 | 141 | Result = namedtuple('Result', 'cols from_obj') 142 | 143 | class SQLAlchemyWriter(ASTProcessor, StandardOps, StandardArithmetics): 144 | """A SK AST processor, producing a SQLAlchemy "multistage" expression. 145 | The interpretation of each node is a tuple, containing a list of column expressions and a from_obj, 146 | where these columns must be queried from.""" 147 | 148 | def __init__(self, from_obj='data', key_column='id', 149 | positive_infinity=float(np.finfo('float64').max), 150 | negative_infinity=float(np.finfo('float64').min), 151 | multistage=True): 152 | self.positive_infinity = positive_infinity 153 | self.negative_infinity = negative_infinity 154 | if multistage: 155 | if isinstance(from_obj, str): 156 | from_obj = sa.table(from_obj, sa.column(key_column)) 157 | # This is a bit hackish, but quite convenient. 158 | # This way we do not have to carry around an extra "key" field in our results all the time 159 | from_obj.key_ = from_obj.columns[key_column] 160 | else: 161 | if key_column not in from_obj.columns: 162 | raise ValueError("The provided selectable does not contain the key column {0}".format(key_column)) 163 | from_obj.key_ = from_obj.columns[key_column] 164 | elif isinstance(from_obj, str): 165 | from_obj = sa.table(from_obj) 166 | self.from_obj = from_obj 167 | self.temp_ids = id_generator() 168 | self.references = [{}] 169 | self.multistage = multistage 170 | 171 | def Identifier(self, id): 172 | return Result([sa.column(id.id)], self.from_obj) 173 | 174 | def _indexed_identifier(self, id, idx): 175 | return sa.column("{0}{1}".format(id, idx+1)) 176 | 177 | def IndexedIdentifier(self, sub): 178 | return Result([self._indexed_identifier(sub.id, sub.index)], self.from_obj) 179 | 180 | def _number_constant(self, value): 181 | # Infinities have to be handled separately 182 | if np.isinf(value): 183 | value = self.positive_infinity if value > 0 else self.negative_infinity 184 | else: 185 | value = denumpyfy(value) 186 | return sa.literal(value) 187 | 188 | def NumberConstant(self, num): 189 | return Result([self._number_constant(num.value)], self.from_obj) 190 | 191 | def VectorIdentifier(self, id): 192 | return Result([self._indexed_identifier(id.id, i) for i in range(id.size)], self.from_obj) 193 | 194 | def VectorConstant(self, vec): 195 | return Result([self._number_constant(v) for v in tolist(vec.value)], self.from_obj) 196 | 197 | def MatrixConstant(self, mtx): 198 | return Result([[self._number_constant(v) for v in tolist(row)] for row in mtx.value], self.from_obj) 199 | 200 | def UnaryFunc(self, node, **kw): 201 | arg = self(node.arg) 202 | if isinstance(node.op, ArgMax): 203 | return self._argmax(arg) 204 | elif isinstance(node.op, VecMax): 205 | return self._vecmax(arg) 206 | elif isinstance(node.op, VecSum): 207 | return self._vecsum(arg) 208 | elif isinstance(node.op, Softmax): 209 | return self._softmax(arg) 210 | else: 211 | op = self(node.op) 212 | return Result([op(el) for el in arg.cols], arg.from_obj) 213 | 214 | ArgMax = VecSumNormalize = VecSum = VecMax = Softmax = not_implemented 215 | 216 | def BinOp(self, node, **kw): 217 | left, right, op = self(node.left), self(node.right), self(node.op) 218 | if not isinstance(node.op, 
IsElemwise): 219 | # MatVecProduct requires atomizing the argument, otherwise it will be repeated multiple times in the output 220 | if not isinstance(node.right, IsAtom): 221 | right = self._make_cte(right) 222 | return Result(op(left.cols, right.cols), _merge(left.from_obj, right.from_obj)) 223 | elif len(left.cols) != len(right.cols): 224 | raise ValueError("Mismatching operand dimensions in {0}".format(repr(node.op))) 225 | elif isinstance(node.op, Max): 226 | # Max is implemented as (if x > y then x else y), hence to avoid double-computation, 227 | # we save x and y in separate CTE's 228 | if not isinstance(node.left, IsAtom): 229 | left = self._make_cte(left) 230 | if not isinstance(node.right, IsAtom): 231 | right = self._make_cte(right) 232 | return Result([op(lc, rc) for lc, rc in zip(left.cols, right.cols)], _merge(left.from_obj, right.from_obj)) 233 | else: 234 | return Result([op(lc, rc) for lc, rc in zip(left.cols, right.cols)], _merge(left.from_obj, right.from_obj)) 235 | 236 | def MakeVector(self, vec): 237 | result = [] 238 | tbls = set() 239 | for el in vec.elems: 240 | el = self(el) 241 | tbls.add(el.from_obj) 242 | if len(el.cols) != 1: 243 | raise ValueError("MakeVector expects a list of scalars") 244 | result.append(el.cols[0]) 245 | tbls = list(tbls) 246 | target_table = tbls[0] 247 | for tbl in tbls[1:]: 248 | new_joined = target_table.join(tbl, onclause=target_table.key_ == tbl.key_) 249 | new_joined.key_ = target_table.key_ 250 | target_table = new_joined 251 | return Result(result, target_table) 252 | 253 | def IfThenElse(self, node): 254 | test, iftrue, iffalse = self(node.test), self(node.iftrue), self(node.iffalse) 255 | 256 | return Result([_iif(test.cols[0], ift, iff) for ift, iff in zip(iftrue.cols, iffalse.cols)], 257 | reduce(_merge, [test.from_obj, iftrue.from_obj, iffalse.from_obj])) 258 | 259 | MatVecProduct = is_(_matvecproduct) 260 | DotProduct = is_(_dotproduct) 261 | Exp = is_(sa.func.exp) 262 | Log = is_(sa.func.log) 263 | Sqrt = is_(sa.func.sqrt) 264 | Abs = is_(sa.func.abs) 265 | Step = is_(_step) 266 | Max = is_(lambda x, y: _max([x, y])) 267 | 268 | # ------ The actual "multi-stage" logic ----- 269 | def Let(self, node, **kw): 270 | if not self.multistage: 271 | return StandardOps.Let(self, node, **kw) 272 | self.references.append({}) 273 | for defn in node.defs: 274 | self.references[-1][defn.name] = self._make_cte(self(defn.body)) 275 | result = self(node.body) 276 | self.references.pop() 277 | return result 278 | 279 | def Reference(self, node): 280 | if not self.multistage: 281 | raise ValueError("References are not supported in non-multistage mode") 282 | if node.name not in self.references[-1]: 283 | raise ValueError("Undefined reference: {0}".format(node.name)) 284 | return self.references[-1][node.name] 285 | 286 | def _make_cte(self, result, col_names=None, key_label='__id__'): 287 | if not self.multistage: 288 | return result 289 | if col_names is None: 290 | col_names = ['f{0}'.format(i+1) for i in range(len(result.cols))] 291 | labeled_cols = [c.label(n) for c, n in zip(result.cols, col_names)] 292 | new_tbl = sa.select([result.from_obj.key_.label(key_label)] + labeled_cols, from_obj=result.from_obj).cte(next(self.temp_ids)) 293 | new_tbl.key_ = new_tbl.columns[key_label] 294 | new_cols = [new_tbl.columns[n] for n in col_names] 295 | return Result(new_cols, new_tbl) 296 | 297 | def _argmax(self, result): 298 | if len(result.cols) == 1: 299 | return Result([sa.literal(0)], self.from_obj) 300 | features = self._make_cte(result) 301 | 
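        # Rough shape of the SQL generated from here on in multistage mode
        # (added commentary, names purely illustrative):
        #     WITH _tmp1 AS (SELECT id, f1, ..., fN FROM data),
        #          _tmp2 AS (SELECT id, [max of f1..fN] AS _max FROM _tmp1)
        #     SELECT CASE WHEN f1 = _max THEN 0 WHEN f2 = _max THEN 1 ...
        #            ELSE N-1 END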
max_val = Result([_max(features.cols)], features.from_obj)
302 |         max_val = self._make_cte(max_val, ['_max'])
303 | 
304 |         argmax = sa.case([(col == max_val.cols[0], i)
305 |                           for i, col in enumerate(features.cols[:-1])],
306 |                          else_=len(features.cols)-1)
307 |         return Result([argmax], _merge(features.from_obj, max_val.from_obj))
308 | 
309 |     def _vecmax(self, result):
310 |         return Result([_max(result.cols)], result.from_obj)
311 | 
312 |     def _softmax(self, result):
313 |         return self._vecsumnormalize(Result([sa.func.exp(col) for col in result.cols], result.from_obj))
314 | 
315 |     def _vecsumnormalize(self, result):
316 |         features = self._make_cte(result)
317 |         sum_val = Result([_sum(features.cols)], features.from_obj)
318 |         sum_val = self._make_cte(sum_val, ['_sum'])
319 |         return Result([col/sum_val.cols[0] for col in features.cols],
320 |                       _merge(features.from_obj, sum_val.from_obj))
321 | 
322 |     def _vecsum(self, result):
323 |         return Result([_sum(result.cols)], result.from_obj)
324 | 
325 | # ------- SQLAlchemy "greatest" function
326 | # See https://docs.sqlalchemy.org/en/latest/core/compiler.html
327 | #pylint: disable=wrong-import-position,wrong-import-order
328 | from sqlalchemy.sql import expression
329 | from sqlalchemy.ext.compiler import compiles
330 | from sqlalchemy.types import Numeric
331 | 
332 | class greatest(expression.FunctionElement):
333 |     type = Numeric()
334 |     name = 'greatest'
335 | 
336 | @compiles(greatest)
337 | def default_greatest(element, compiler, **kw):
338 |     res = compiler.visit_function(element, **kw)
339 |     return res
340 | 
341 | @compiles(greatest, 'sqlite')
342 | @compiles(greatest, 'mssql')
343 | @compiles(greatest, 'oracle')
344 | def case_greatest(element, compiler, **kw):
345 |     arg1, arg2 = list(element.clauses)
346 |     return compiler.process(sa.case([(arg1 > arg2, arg1)], else_=arg2), **kw)
347 | 
348 | 
349 | # Utilities ----------------------------------
350 | import sqlalchemy.dialects
351 | #pylint: disable=wildcard-import,unused-wildcard-import
352 | from sqlalchemy.dialects import *  # Must do it in order to getattr(sqlalchemy.dialects, ...)
353 | def to_sql(sa_expr, dialect_name='sqlite'):
354 |     """
355 |     Helper function. Given a SQLAlchemy expression, returns the corresponding
356 |     SQL string in a given dialect.
357 |     """
358 | 
359 |     dialect_module = getattr(sqlalchemy.dialects, dialect_name)
360 |     return str(sa_expr.compile(dialect=dialect_module.dialect(),
361 |                                compile_kwargs={'literal_binds': True}))
362 | 
--------------------------------------------------------------------------------
/skompiler/ast.py:
--------------------------------------------------------------------------------
1 | """
2 | SKompiler: AST nodes.
3 | 
4 | The classes here describe the AST nodes of the expressions produced
5 | by SKompiler.
6 | 
7 | Notes:
8 |   - We might have relied on Python's ast.* classes, but this
9 |     would introduce unnecessary complexity and limit possibilities for
10 |     adding custom nodes for special cases.
11 | - @dataclass would be a nice tech to use here, but it would prevent 12 | compatibility with Python older than 3.6 13 | 14 | >>> expr = BinOp(Mul(), Identifier('x'), IndexedIdentifier('y', -5, 10)) 15 | >>> expr = BinOp(Add(), NumberConstant(12.2), expr) 16 | >>> print(str(expr)) 17 | (12.2 + (x * y[-5])) 18 | >>> expr = Let([Definition('z', expr)], BinOp(Add(), Reference('z'), NumberConstant(2))) 19 | >>> print(str(expr)) 20 | { 21 | $z = (12.2 + (x * y[-5])); 22 | ($z + 2) 23 | } 24 | """ 25 | #pylint: disable=protected-access,multiple-statements,too-few-public-methods,no-member 26 | from itertools import count 27 | from importlib import import_module 28 | import numpy as np 29 | 30 | 31 | # Each ASTNode registers its class name in this set 32 | AST_NODES = set([]) 33 | 34 | # This is the set of conversions supported in the node.to(...) method 35 | TRANSLATORS = { 36 | 'excel': 'skompiler.fromskast.excel:translate', 37 | 'python': 'skompiler.fromskast.python:translate', 38 | 'sqlalchemy': 'skompiler.fromskast.sqlalchemy:translate', 39 | 'sympy': 'skompiler.fromskast.sympy:translate', 40 | 'pfa': 'skompiler.fromskast.pfa:translate', 41 | 'string': str 42 | } 43 | 44 | 45 | #region ASTNode base ------------------------------------------------- 46 | 47 | # Basic type inference (NB: not an enum to keep compatibility with Python 3.3. Maybe it is lost already, though) 48 | DTYPE_SCALAR = 1 49 | DTYPE_VECTOR = 2 50 | DTYPE_MATRIX = 3 51 | DTYPE_OTHER = 42 52 | DTYPE_UNKNOWN = None 53 | class ASTTypeError(ValueError): pass 54 | class UnableToDecompose(ValueError): pass 55 | 56 | 57 | class ASTNodeCreator(type): 58 | """A metaclass, which allows us to implement our AST nodes like mutable namedtuples 59 | with a nicer declaration syntax.""" 60 | def __new__(mcs, name, bases, dct, fields='', repr=None, dtype=DTYPE_OTHER): 61 | if fields is None: 62 | return super().__new__(mcs, name, bases, dct) 63 | else: 64 | cls = super().__new__(mcs, name, bases, dct) 65 | cls._fields = fields.split() 66 | cls._template = repr 67 | cls._default_dtype = dtype 68 | AST_NODES.add(name) 69 | return cls 70 | 71 | # For Python 3.5, see https://stackoverflow.com/a/25191150/318964 72 | def __init__(cls, name, bases, dct, **_): 73 | super().__init__(name, bases, dct) 74 | 75 | 76 | _singletons = {} 77 | 78 | class ASTNode(object, metaclass=ASTNodeCreator, fields=None): 79 | """Base class for all AST nodes. 
80 | 
81 |     def __new__(cls, *_args, **_kw):
82 |         # Save some memory on singletons
83 |         if hasattr(cls, '_fields') and not cls._fields:
84 |             # Singleton node
85 |             name = cls.__name__
86 |             if _singletons.get(name, None) is None:
87 |                 _singletons[name] = super().__new__(cls)
88 |             return _singletons[name]
89 |         return super().__new__(cls)
90 | 
91 |     def __init__(self, *args, **kw):
92 |         """Make sure the constructor arguments correspond to the _fields property"""
93 | 
94 |         if len(args) > len(self._fields): raise Exception("Too many arguments")
95 |         args = dict(zip(self._fields, args))
96 |         for k in args:
97 |             if k in kw:
98 |                 raise Exception("Argument %s defined multiple times" % k)
99 |         args.update(kw)
100 |         if len(args) != len(self._fields):
101 |             raise Exception("Not enough arguments ({0}) given to {1}".format(len(args), self.__class__.__name__))
102 |         for k in self._fields:
103 |             if k not in args:
104 |                 raise Exception("Argument %s not provided" % k)
105 |             self.__dict__[k] = args[k]
106 |         self.__dict__['_dtype'] = self._compute_dtype()
107 | 
108 |     def __str__(self):
109 |         dct = {k: str(v) for k, v in vars(self).items()}
110 |         return self._template.format(**dct)
111 | 
112 |     def __repr__(self):
113 |         return self.__class__.__name__ + '(' + ', '.join(k + '=' + repr(getattr(self, k)) for k in self._fields) + ')'
114 | 
115 |     def __setattr__(self, field, value):
116 |         raise Exception("AST nodes are immutable")
117 | 
118 |     def __iter__(self):
119 |         "Iteration over node fields. Note that the node works like a dict.items(), not like a tuple"
120 |         for k in self._fields:
121 |             yield k, getattr(self, k)
122 | 
123 |     def __bool__(self):
124 |         return True
125 | 
126 |     # Convenience routines for combining AST nodes
127 |     def __len__(self):
128 |         raise UnableToDecompose()
129 | 
130 |     def __getitem__(self, idx):
131 |         raise UnableToDecompose()
132 | 
133 |     def __add__(self, other):
134 |         other = self._align_scalar(other)
135 |         return BinOp(Add(), self, other)
136 | 
137 |     def __mul__(self, other):
138 |         other = self._align_scalar(other)
139 |         return BinOp(Mul(), self, other)
140 | 
141 |     def __truediv__(self, other):
142 |         other = self._align_scalar(other)
143 |         return BinOp(Div(), self, other)
144 | 
145 |     def __sub__(self, other):
146 |         other = self._align_scalar(other)
147 |         return BinOp(Sub(), self, other)
148 | 
149 |     def __matmul__(self, other):
150 |         if self._dtype == DTYPE_MATRIX:
151 |             return BinOp(MatVecProduct(), self, other)
152 |         else:
153 |             return BinOp(DotProduct(), self, other)
154 | 
155 |     def __le__(self, other):
156 |         other = self._align_scalar(other)
157 |         return BinOp(LtEq(), self, other)
158 | 
159 |     def __eq__(self, other):
160 |         other = self._align_scalar(other)
161 |         return BinOp(Eq(), self, other)
162 | 
163 |     def __call__(self, arg, arg2=None):
164 |         if arg2 is not None:
165 |             return BinOp(self, arg, arg2)
166 |         else:
167 |             return UnaryFunc(self, arg)
168 | 
169 |     def _align_scalar(self, other):
170 |         # If a scalar is given in a binary operation,
171 |         # we try to align it in length with us.
172 |         # This is hackish and should be used with caution
173 |         if np.isscalar(other):
174 |             self_type = self._dtype
175 |             if self_type == DTYPE_SCALAR:
176 |                 return NumberConstant(other)
177 |             elif self_type == DTYPE_VECTOR:
178 |                 return VectorConstant([other]*len(self))
179 |             else:
180 |                 raise ASTTypeError("Unable to align scalar with a node of type {0}".format(self.__class__.__name__))
181 |         else:
182 |             return other
183 | 
184 |     # Type inference
185 |     def _compute_dtype(self):
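        # Editorial note: the base implementation below just reports the class-level
        # default registered by the metaclass from the `dtype` keyword of the class
        # statement; nodes whose type depends on their children override this method.
        # An illustrative sketch (not in the original source) of the resulting types:
        #   NumberConstant(1.5)._dtype    -> DTYPE_SCALAR
        #   VectorConstant([1, 2])._dtype -> DTYPE_VECTOR
        #   MatrixConstant([[1]])._dtype  -> DTYPE_MATRIX
        #   Reference('x')._dtype         -> DTYPE_UNKNOWN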
186 |         return self._default_dtype
187 | 
188 |     # Convenience routines for evaluating AST nodes
189 |     def lambdify(self):
190 |         """
191 |         Converts the SKAST expression to an executable Python function.
192 | 
193 |         >>> from .toskast.string import translate as skast
194 |         >>> skast("12.4 * (X[1] + Y)").lambdify()(**{'X': [10, 20, 30], 'Y': 1.2})
195 |         262.88
196 |         >>> skast("122.45 + 1").lambdify()()
197 |         123.45
198 |         >>> skast("x[0]").lambdify()(x=[[1]])
199 |         [1]
200 |         >>> skast("a=1; a").lambdify()()
201 |         1
202 |         >>> skast("a=b; b=b+1; c=b+b; a+b+c").lambdify()(b=1.0)
203 |         7.0
204 |         """
205 |         from .fromskast import python
206 | 
207 |         return python.lambdify(python.translate(self))
208 | 
209 |     def evaluate(self, **inputs):
210 |         "Convenience routine for evaluating expressions"
211 |         return self.lambdify()(**inputs)
212 | 
213 |     # The main convenience function, which converts the node to any of the
214 |     # supported formats
215 |     def to(self, target, *args, **kw):
216 |         """Convenience routine for converting expressions to any
217 |         supported output form.
218 |         See project documentation for detailed explanation.
219 | 
220 |         Equivalent to skompiler.fromskast.<dialect[0]>.translate(self, dialect[1], *args, **kw)
221 |         (where dialect[0] and dialect[1] denote the parts to the left and right of the '/' in the
222 |         target parameter)
223 | 
224 |         Args:
225 | 
226 |             target (str): The target value. Possible values are:
227 | 
228 |                 - 'sqlalchemy',
229 |                 - 'sqlalchemy/<dialect>', where <dialect> is one of the supported
230 |                   SQLAlchemy dialects ('firebird', 'mssql', 'mysql', 'oracle',
231 |                   'postgresql', 'sqlite', 'sybase')
232 |                 - 'sympy'
233 |                 - 'sympy/<lang>', where <lang> is any of:
234 |                   'c', 'cxx', 'rust', 'fortran', 'js', 'r', 'julia',
235 |                   'mathematica', 'octave'
236 |                 - 'excel'
237 |                 - 'pfa'
238 |                 - 'pfa/json'
239 |                 - 'pfa/yaml'
240 |                 - 'python'
241 |                 - 'python/code'
242 |                 - 'python/lambda'
243 |                 - 'string'
244 | 
245 |         Kwargs:
246 | 
247 |             *args, **kw: Extra arguments are passed to the translation function.
248 | 
249 |                 The most important for `sqlalchemy` and `sympy/<lang>` dialects are:
250 | 
251 |                 assign_to: When not None, the generated code outputs the result to a variable
252 |                     (or variables, or columns) with given name(s).
253 |                 component: When not None and the expression produces a vector, only the
254 |                     specified component of the vector is output (0-indexed)
255 | 
256 |         For more info, see documentation for
257 |         skompiler.fromskast.<dialect[0]>.translate.
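
        Illustrative calls (an editorial sketch, using only targets listed above;
        each translator also requires its optional dependencies to be installed):

            expr.to('string')                            # plain-text rendering
            expr.to('python/code')                       # Python source code
            expr.to('sqlalchemy/sqlite', assign_to='y')  # SQL for the 'sqlite' dialect
            expr.to('sympy/c')                           # C code via SymPy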
258 | 259 | '""" 260 | translator, *dialect = target.split('/') 261 | if translator not in TRANSLATORS: 262 | raise ValueError("Invalid translator: {0}".format(translator)) 263 | translator = TRANSLATORS[translator] 264 | if not hasattr(translator, '__call__'): 265 | module, callable = translator.split(':') 266 | mod = import_module(module) 267 | translator = getattr(mod, callable) 268 | return translator(self, *dialect, *args, **kw) 269 | #endregion 270 | 271 | 272 | #region AST node types --------------------------------------------- 273 | 274 | # Unary operators and functions 275 | class UnaryFunc(ASTNode, fields='op arg', repr='{op}({arg})'): 276 | def __len__(self): 277 | if isinstance(self.op, IsElemwise): 278 | return len(self.arg) 279 | elif getattr(self.op, '_out_dtype', None) == DTYPE_SCALAR: 280 | return 1 281 | elif isinstance(self.op, Softmax): 282 | return len(self.arg) 283 | else: 284 | raise UnableToDecompose() 285 | 286 | def __getitem__(self, index): 287 | if isinstance(self.op, IsElemwise): 288 | return UnaryFunc(self.op, self.arg[index]) 289 | else: 290 | raise UnableToDecompose() 291 | 292 | def _compute_dtype(self): 293 | if self.arg._dtype not in [DTYPE_UNKNOWN, DTYPE_SCALAR, DTYPE_VECTOR]: 294 | raise ASTTypeError() 295 | if isinstance(self.op, IsElemwise): 296 | return self.arg._dtype 297 | else: 298 | return self.op._out_dtype 299 | 300 | # Some unary functions distribute over vectors. We mark them as such 301 | class IsElemwise: pass 302 | class USub(ASTNode, IsElemwise, repr='-'): pass # Unary minus operator 303 | class Exp(ASTNode, IsElemwise, repr='exp'): pass 304 | class Log(ASTNode, IsElemwise, repr='log'): pass 305 | class Abs(ASTNode, IsElemwise, repr='abs'): pass 306 | class Sqrt(ASTNode, IsElemwise, repr='sqrt'): pass 307 | class Step(ASTNode, IsElemwise, repr='step'): pass # Heaviside step (x > 0) 308 | class Sigmoid(ASTNode, IsElemwise, repr='sigmoid'): pass 309 | 310 | # Some functions take vector arguments but do not distribute elementwise 311 | class VecSum(ASTNode, repr='sum'): 312 | _out_dtype = DTYPE_SCALAR 313 | class VecMax(ASTNode, repr='max'): 314 | _out_dtype = DTYPE_SCALAR 315 | class ArgMax(ASTNode, repr='argmax'): 316 | _out_dtype = DTYPE_SCALAR 317 | class Softmax(ASTNode, repr='softmax'): 318 | _out_dtype = DTYPE_VECTOR 319 | 320 | # Binary operators 321 | def _common_dtype(nodes): 322 | dtypes = {n._dtype for n in nodes if n._dtype is not DTYPE_UNKNOWN} 323 | if not dtypes: 324 | return DTYPE_UNKNOWN 325 | elif len(dtypes) == 2: 326 | raise ASTTypeError("Mismatching operand types") 327 | else: 328 | result = next(iter(dtypes)) 329 | if result not in [DTYPE_SCALAR, DTYPE_VECTOR]: 330 | raise ASTTypeError("Arguments must be scalars or vectors") 331 | return result 332 | 333 | 334 | class BinOp(ASTNode, fields='op left right', repr='({left} {op} {right})'): 335 | def __len__(self): 336 | if isinstance(self.op, IsElemwise) or isinstance(self.op, MatVecProduct): 337 | return len(self.left) 338 | else: 339 | raise UnableToDecompose() 340 | 341 | def __getitem__(self, index): 342 | if isinstance(self.op, IsElemwise): 343 | return BinOp(self.op, self.left[index], self.right[index]) 344 | elif isinstance(self.op, MatVecProduct): 345 | return BinOp(DotProduct(), self.left[index], self.right) 346 | else: 347 | raise UnableToDecompose() 348 | 349 | def _compute_dtype(self): 350 | if isinstance(self.op, IsElemwise): 351 | return _common_dtype([self.left, self.right]) 352 | else: 353 | return self.op._out_dtype 354 | 355 | class Mul(ASTNode, 
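# Editorial illustration (not part of the original source): because these operators
# are elementwise, BinOp nodes over vectors decompose per component, e.g.
#   BinOp(Mul(), VectorConstant([1, 2]), VectorConstant([3, 4]))[0]
# yields BinOp(Mul(), NumberConstant(1), NumberConstant(3)), printed as "(1 * 3)".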
356 | class Add(ASTNode, IsElemwise, repr='+'): pass
357 | class Sub(ASTNode, IsElemwise, repr='-'): pass
358 | class Div(ASTNode, IsElemwise, repr='/'): pass
359 | class Max(ASTNode, IsElemwise, repr='max'): pass
360 | class DotProduct(ASTNode, repr='v@v'):
361 |     _out_dtype = DTYPE_SCALAR
362 | class MatVecProduct(ASTNode, repr='m@v'):
363 |     _out_dtype = DTYPE_VECTOR
364 | 
365 | class LFold(ASTNode, fields='op elems'):
366 |     """
367 |     Left-associative fold of an operator over a list of arguments.
368 |     E.g. sum(xs) := LFold(Add(), xs)
369 |     This could be represented as a sequence of binary ops, but having a
370 |     dedicated operator may significantly reduce the depth of AST trees with long sums,
371 |     such as those arising in ensemble classifiers.
372 |     Deep trees are bad because you run the risk of hitting the system recursion limit
373 |     when processing them via recursive parsers (as is currently the case).
374 |     """
375 | 
376 |     def __str__(self):
377 |         return str(self.op).join(str(e) for e in self.elems)
378 | 
379 |     def __getitem__(self, idx):
380 |         return LFold(self.op, [el[idx] for el in self.elems])
381 | 
382 |     def __len__(self):
383 |         if not self.elems:
384 |             raise ASTTypeError("Empty LFolds are not allowed")
385 |         return len(self.elems[0])
386 | 
387 |     def _compute_dtype(self):
388 |         return _common_dtype(self.elems)
389 | 
390 | # Boolean binary ops
391 | class IsBoolean: pass
392 | class LtEq(ASTNode, IsElemwise, IsBoolean, repr='<='): pass
393 | class Eq(ASTNode, IsElemwise, IsBoolean, repr='=='): pass
394 | 
395 | 
396 | # IfThenElse
397 | class IfThenElse(ASTNode, fields='test iftrue iffalse', repr='(if {test} then {iftrue} else {iffalse})'):
398 |     def __len__(self):
399 |         return len(self.iftrue)
400 | 
401 |     def __getitem__(self, index):
402 |         return IfThenElse(self.test, self.iftrue[index], self.iffalse[index])
403 | 
404 |     def _compute_dtype(self):
405 |         return _common_dtype([self.iftrue, self.iffalse])
406 | 
407 | # Special function
408 | class MakeVector(ASTNode, fields='elems', dtype=DTYPE_VECTOR):
409 |     def __str__(self):
410 |         elems = ', '.join(str(e) for e in self.elems)
411 |         return '[{0}]'.format(elems)
412 | 
413 |     def __getitem__(self, idx):
414 |         return self.elems[idx]
415 | 
416 |     def __len__(self):
417 |         return len(self.elems)
418 | 
419 | # Leaf nodes
420 | class IsAtom: pass
421 | class IsInput: pass
422 | class VectorIdentifier(ASTNode, IsAtom, IsInput, fields='id size', repr='{id}', dtype=DTYPE_VECTOR):
423 |     def __getitem__(self, index):
424 |         return IndexedIdentifier(self.id, index, self.size)
425 |     def __len__(self):
426 |         return self.size
427 | 
428 | class Identifier(ASTNode, IsAtom, IsInput, fields='id', repr='{id}', dtype=DTYPE_SCALAR): pass
429 | 
430 | # Note that IndexedIdentifier is not a generic subscript operator. Its fields must contain a string id and an integer index, as well as the
431 | # total size of the vector being indexed.
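# For instance (an illustrative editorial reading): IndexedIdentifier('x', 1, 10) denotes
# component 1 of a 10-element input vector 'x', prints as "x[1]", and is exactly what
# VectorIdentifier('x', 10)[1] produces.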
432 | # This lets us "fake" vector input variables in contexts like SQL, where we interpret IndexedIdentifier("x", 1, 10) as a concatenated name "x1"
433 | class IndexedIdentifier(ASTNode, IsAtom, IsInput, fields='id index size', repr='{id}[{index}]', dtype=DTYPE_SCALAR): pass
434 | class NumberConstant(ASTNode, IsAtom, fields='value', repr='{value}', dtype=DTYPE_SCALAR): pass
435 | class VectorConstant(ASTNode, IsAtom, fields='value', repr='{value}', dtype=DTYPE_VECTOR):
436 |     def __getitem__(self, index):
437 |         return NumberConstant(self.value[index])
438 |     def __len__(self):
439 |         return len(self.value)
440 | 
441 | class MatrixConstant(ASTNode, IsAtom, fields='value', repr='{value}', dtype=DTYPE_MATRIX):
442 |     def __len__(self):
443 |         return len(self.value)
444 |     def __getitem__(self, index):
445 |         return VectorConstant(self.value[index])
446 | 
447 | # Variable definitions
448 | # NB: at the moment let-scopes do not capture outside variables,
449 | # i.e. all the references inside the definitions and body of a let expression are assumed to
450 | # refer to variables defined in this let scope
451 | class Let(ASTNode, fields='defs body'):
452 |     def __str__(self):
453 |         defs = ';\n'.join(str(d) for d in self.defs)
454 |         return '{\n' + defs + ';\n' + str(self.body) + '\n}'
455 |     def __len__(self):
456 |         return len(self.body)
457 |     def __getitem__(self, idx):
458 |         return self.body[idx]
459 |     def _compute_dtype(self):
460 |         return self.body._dtype
461 | 
462 | class Definition(ASTNode, fields='name body', repr='${name} = {body}'):
463 |     def _compute_dtype(self):
464 |         return self.body._dtype
465 | 
466 | class Reference(ASTNode, fields='name', repr='${name}', dtype=DTYPE_UNKNOWN): pass
467 | 
468 | # Sometimes it is convenient to mark the type and dimension of the referenced object
469 | class TypedReference(ASTNode, fields='name dtype size', repr='${name}'):
470 |     def _compute_dtype(self):
471 |         return self.dtype
472 |     def __len__(self):
473 |         return self.size
474 | 
475 | #endregion
476 | 
477 | 
478 | #region Node processing functions -----------------------
479 | def map_list(node_list, fn, call_on_enter=False, call_on_exit=True):
480 |     new_list = []
481 |     changed = False
482 |     for node in node_list:
483 |         new_node = map_tree(node, fn, call_on_enter, call_on_exit)
484 |         if new_node is not node:
485 |             changed = True
486 |         new_list.append(new_node)
487 |     return new_list if changed else node_list
488 | 
489 | 
490 | def map_tree(node, fn, call_on_enter=False, call_on_exit=True):
491 |     """Applies a function to each node in the tree.
492 | 
493 |     Args:
494 |         fn: function, which must accept (node, is_entering) and return a node.
495 | 
496 |     Kwargs:
497 |         call_on_enter: Invoke the function (potentially replacing the node) every time before entering the node.
498 |         call_on_exit: Invoke the function (potentially replacing the node) every time before leaving the node.
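
    For example (an editorial sketch, using the `replace` helper defined later in
    this module), upper-casing every Identifier in a tree:

        def rename(node, _entering):
            if isinstance(node, Identifier):
                return replace(node, id=node.id.upper())
            return node

        new_tree = map_tree(tree, rename)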
499 | """ 500 | if call_on_enter: node = fn(node, True) 501 | 502 | updates = {} 503 | for fname, fval in node: 504 | if isinstance(fval, ASTNode): 505 | new_val = map_tree(fval, fn, call_on_enter, call_on_exit) 506 | elif fname in ['elems', 'defs']: # Special case for MakeVector and Let nodes 507 | new_val = map_list(fval, fn, call_on_enter, call_on_exit) 508 | else: 509 | new_val = fval 510 | if new_val is not fval: 511 | updates[fname] = new_val 512 | node = replace(node, **updates) 513 | 514 | if call_on_exit: node = fn(node, False) 515 | return node 516 | 517 | 518 | def substitute_references(node, definitions): 519 | """Substitutes all references in the given expression with ones from the `definitions` dictionary. 520 | If any substitutions were made, returns a new node. Otherwise returns the same node. 521 | 522 | If references with no matches are found, raises a ValueError. 523 | 524 | Args: 525 | definitions (dict): a dictionary (name -> expr) 526 | 527 | >>> expr = BinOp(Add(), Identifier('x'), NumberConstant(2)) 528 | >>> substitute_references(expr, {}) is expr 529 | True 530 | >>> expr = replace(expr, left = Reference('x')) 531 | >>> substitute_references(expr, {}) 532 | Traceback (most recent call last): 533 | ... 534 | ValueError: Unknown variable reference: x 535 | >>> new_expr = substitute_references(expr, {'x': expr}) 536 | >>> print(new_expr) # Only one level of references is expanded 537 | (($x + 2) + 2) 538 | >>> assert new_expr is not expr 539 | >>> assert new_expr.right is expr.right 540 | """ 541 | def fn(node, _): 542 | if isinstance(node, Reference) or isinstance(node, TypedReference): 543 | if node.name not in definitions: 544 | raise ValueError("Unknown variable reference: " + node.name) 545 | else: 546 | return definitions[node.name] 547 | else: 548 | return node 549 | 550 | return map_tree(node, fn, call_on_enter=False, call_on_exit=True) 551 | 552 | 553 | def inline_definitions(let_expr): 554 | """Given a Let expression, substitutes all the definitions and returns a single 555 | evaluatable non-let expression. 556 | 557 | >>> from .toskast.string import translate as skast 558 | >>> expr = inline_definitions(skast('a=1; a')) 559 | >>> print(str(expr)) 560 | 1 561 | >>> expr = inline_definitions(skast('a=1; b=a+a+2; c=b+b+3; a+b+c+X')) 562 | >>> print(str(expr)) 563 | (((1 + ((1 + 1) + 2)) + ((((1 + 1) + 2) + ((1 + 1) + 2)) + 3)) + X) 564 | >>> expr = inline_definitions(skast('X+Y')) 565 | Traceback (most recent call last): 566 | ... 
567 |     ValueError: Let expression expected
568 |     """
569 |     if not isinstance(let_expr, Let):
570 |         raise ValueError("Let expression expected")
571 | 
572 |     let_expr = merge_let_scopes(let_expr)
573 | 
574 |     defs = {}
575 |     for defn in let_expr.defs:
576 |         defs[defn.name] = substitute_references(defn.body, defs)
577 | 
578 |     return substitute_references(let_expr.body, defs)
579 | 
580 | 
581 | # Let expression extraction
582 | def _scope_id_gen():
583 |     yield '' # Do not mangle the scope of the first let we find
584 |     yield from map('_{0}'.format, count())
585 | 
586 | class LetCollector:
587 |     def __init__(self):
588 |         self.definitions = []
589 |         self.scopes = []
590 |         self.temp_ids = _scope_id_gen()
591 | 
592 |     def __call__(self, node, entering):
593 |         if isinstance(node, Let):
594 |             if entering:
595 |                 name = next(self.temp_ids)
596 |                 self.scopes.append(name)
597 |             else:
598 |                 self.scopes.pop()
599 |                 return node.body
600 |         elif isinstance(node, Definition) and not entering:
601 |             self.definitions.append(replace(node, name='{0}{1}'.format(self.scopes[-1], node.name)))
602 |         elif (isinstance(node, Reference) or isinstance(node, TypedReference)) and not entering:
603 |             if not self.scopes:
604 |                 raise ValueError("Undefined reference: {0}".format(node.name))
605 |             return replace(node, name='{0}{1}'.format(self.scopes[-1], node.name))
606 |         return node
607 | 
608 | 
609 | def merge_let_scopes(expr):
610 |     """
611 |     Given an expression which may potentially contain Let subexpressions inside,
612 |     bubbles them all up and returns a single Let expression.
613 |     If the original expression contained no Let subexpressions, returns it as-is.
614 | 
615 |     >>> from skompiler.toskast.string import translate as skast
616 |     >>> let1 = skast("x=1; y=2; x+y") # 3
617 |     >>> let2 = skast("x=3; y=4; x+2*y") # 11
618 |     >>> let3 = skast("x=0; y=0; 3*x+y") # 20
619 |     >>> let3.defs[0].__dict__['body']=let1
620 |     >>> let3.defs[1].__dict__['body']=let2
621 |     >>> merged = merge_let_scopes(let3)
622 |     >>> print(merged)
623 |     {
624 |     $_0x = 1;
625 |     $_0y = 2;
626 |     $x = ($_0x + $_0y);
627 |     $_1x = 3;
628 |     $_1y = 4;
629 |     $y = ($_1x + (2 * $_1y));
630 |     ((3 * $x) + $y)
631 |     }
632 |     >>> print(inline_definitions(merged))
633 |     ((3 * (1 + 2)) + (3 + (2 * 4)))
634 |     """
635 |     lc = LetCollector()
636 |     expr = map_tree(expr, lc, True, True)
637 |     if not lc.definitions:
638 |         return expr
639 |     else:
640 |         return Let(lc.definitions, expr)
641 | 
642 | 
643 | # Node copying and field replacing methods
644 | # NB: We don't implement them as methods to avoid potential conflicts with
645 | # field names ("lambdify", "evaluate" and "to" are the only exceptions as they might be used more often)
646 | 
647 | def copy(node):
648 |     """
649 |     Copies the node.
650 | 
651 |     >>> o = BinOp(Add(), NumberConstant(2), NumberConstant(3))
652 |     >>> o2 = copy(o)
653 |     >>> o.__dict__['op'] = Sub()
654 |     >>> print(o, o2)
655 |     (2 - 3) (2 + 3)
656 |     """
657 |     return node.__class__(**dict(node))
658 | 
659 | def replace(node, **kw):
660 |     """Equivalent to namedtuple's _replace.
661 |     When **kw is empty, simply returns the node itself.
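    (Editorial note: map_tree above relies on this identity-preserving behaviour to
    avoid rebuilding subtrees that no substitution has touched.)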
662 | 
663 |     >>> o = BinOp(Add(), NumberConstant(2), NumberConstant(3))
664 |     >>> o2 = replace(o, op=Sub(), left=NumberConstant(3))
665 |     >>> print(o, o2)
666 |     (2 + 3) (3 - 3)
667 |     """
668 |     if kw:
669 |         vals = dict(node)
670 |         vals.update(**kw)
671 |         return node.__class__(**vals)
672 |     else:
673 |         return node
674 | 
675 | def decompose(node):
676 |     return [node[i] for i in range(len(node))] # Split a decomposable node into its per-component parts
677 | 
678 | #endregion
679 | 
--------------------------------------------------------------------------------
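Closing editorial note: the pieces above compose directly. A minimal end-to-end sketch (not part of the repository; it relies only on classes and methods defined in skompiler/ast.py, and the expected outputs follow the node templates and doctests shown above):

    >>> from skompiler.ast import BinOp, Add, Identifier, NumberConstant, VectorConstant
    >>> expr = BinOp(Add(), Identifier('x'), NumberConstant(1))
    >>> print(expr)
    (x + 1)
    >>> expr.evaluate(x=2)
    3
    >>> print(VectorConstant([1.0, 2.0]) * 3)  # the scalar is broadcast by _align_scalar
    ([1.0, 2.0] * [3, 3])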