├── .gitattributes ├── .gitignore ├── .travis.yml ├── CHANGES.txt ├── LICENSE.txt ├── MANIFEST.in ├── Makefile ├── README.md ├── appveyor.yml ├── benchmarks ├── fast.py └── pyearth_vs_earth.py ├── conda-recipe ├── bld.bat ├── build.sh ├── meta.yaml ├── move-conda-package.py └── run_test.py ├── description.md ├── doc ├── Makefile ├── README.md ├── conf.py ├── content.rst ├── earth_bibliography.bib ├── generate_figures.py ├── index.rst ├── make.bat └── xkcdify.py ├── examples ├── Profile.prof ├── README.txt ├── plot_classifier_comp.py ├── plot_derivatives.py ├── plot_feature_importance.py ├── plot_linear_function.py ├── plot_missing_data_problem.py ├── plot_multicolumn.py ├── plot_output_weight.py ├── plot_sine_wave.py ├── plot_sine_wave_2d.py ├── plot_v_function.py └── return_sympy.py ├── pyearth ├── __init__.py ├── _basis.c ├── _basis.pxd ├── _basis.pyx ├── _forward.c ├── _forward.pxd ├── _forward.pyx ├── _forward.pyxdep ├── _knot_search.c ├── _knot_search.pxd ├── _knot_search.pyx ├── _knot_search.pyxdep ├── _pruning.c ├── _pruning.pxd ├── _pruning.pyx ├── _pruning.pyxdep ├── _qr.c ├── _qr.pxd ├── _qr.pyx ├── _qr.pyxdep ├── _record.c ├── _record.pxd ├── _record.pyx ├── _record.pyxdep ├── _types.c ├── _types.pxd ├── _types.pyx ├── _util.c ├── _util.pxd ├── _util.pyx ├── _util.pyxdep ├── _version.py ├── earth.py ├── export.py └── test │ ├── __init__.py │ ├── basis │ ├── __init__.py │ ├── base.py │ ├── test_basis.py │ ├── test_constant.py │ ├── test_hinge.py │ ├── test_linear.py │ ├── test_missingness.py │ └── test_smoothed_hinge.py │ ├── earth_linvars_regress.txt │ ├── earth_regress.txt │ ├── earth_regress_missing_data.txt │ ├── earth_regress_smooth.txt │ ├── forward_regress.txt │ ├── pathological_data │ ├── issue_44.csv │ ├── issue_44.txt │ ├── issue_50.csv │ ├── issue_50.txt │ ├── issue_50_weight.csv │ └── readme.txt │ ├── record │ ├── __init__.py │ ├── test_forward_pass.py │ └── test_pruning_pass.py │ ├── test_data.csv │ ├── test_earth.py │ ├── test_export.py │ 
├── test_forward.py │ ├── test_knot_search.py │ ├── test_pruning.py │ ├── test_qr.py │ ├── test_util.py │ └── testing_utils.py ├── setup.cfg ├── setup.py └── versioneer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | pyearth/_version.py export-subst 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.exp 2 | *.lib 3 | *.o 4 | *.obj 5 | *.pyc 6 | *.pyd 7 | *.so 8 | .DS_Store 9 | .project 10 | .pydevproject 11 | .settings/org.eclipse.core.resources.prefs 12 | MANIFEST 13 | dist 14 | build 15 | dist/py-earth-0.1.0.tar.gz 16 | doc/Humor-Sans.ttf 17 | doc/_build/* 18 | doc/hinge.png 19 | doc/piecewise_linear.png 20 | doc/simple_earth_example.png 21 | doc/auto_examples 22 | examples/classifier_comp.pdf 23 | examples/comparison.csv 24 | scripts/comparison.csv 25 | py_earth.egg-info 26 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - 2.7 4 | - 3.4 5 | # Setup anaconda 6 | before_install: 7 | - wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh 8 | - chmod +x miniconda.sh 9 | - ./miniconda.sh -b 10 | - export PATH=/home/travis/miniconda2/bin:$PATH 11 | - conda update --yes conda 12 | # The next couple lines fix a crash with multiprocessing on Travis and are not specific to using Miniconda 13 | - sudo rm -rf /dev/shm 14 | - sudo ln -s /run/shm /dev/shm 15 | # Install packages 16 | install: 17 | - conda install --yes cython numpy scipy matplotlib nose dateutil pandas patsy statsmodels scikit-learn sympy 18 | - python setup.py build_ext --inplace --cythonize 19 | script: nosetests -s -v pyearth 20 | -------------------------------------------------------------------------------- 
/CHANGES.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/py-earth/b209d1916f051dbea5b142af25425df2de469c5a/CHANGES.txt -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013, Jason Rudy 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | Neither the name of Clinicast, Inc., nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include pyearth/*.c 2 | include pyearth/*.pxd 3 | include pyearth/test/*.txt 4 | include pyearth/test/*.csv 5 | include pyearth/test/pathological_data 6 | include README.md 7 | include LICENSE.txt 8 | include CHANGES.txt 9 | include versioneer.py 10 | include pyearth/_version.py 11 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PYTHON ?= python 2 | CYTHON ?= cython 3 | NOSETESTS ?= nosetests 4 | CYTHONSRC=$(wildcard pyearth/*.pyx) 5 | CSRC=$(CYTHONSRC:.pyx=.c) 6 | 7 | inplace: cython 8 | $(PYTHON) setup.py build_ext -i 9 | 10 | all: inplace 11 | 12 | cython: $(CSRC) 13 | 14 | clean: 15 | rm -f pyearth/*.c pyearth/*.so pyearth/*.pyc pyearth/test/*.pyc pyearth/test/basis/*.pyc pyearth/test/record/*.pyc 16 | 17 | %.c: %.pyx 18 | $(CYTHON) $< 19 | 20 | test: inplace 21 | $(NOSETESTS) -s pyearth 22 | 23 | test-coverage: inplace 24 | $(NOSETESTS) -s --with-coverage --cover-html --cover-html-dir=coverage --cover-package=pyearth pyearth 25 | 26 | verbose-test: inplace 27 | $(NOSETESTS) -sv pyearth 28 | 29 | conda: 30 | conda-build conda-recipe 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | py-earth [![Build Status](https://travis-ci.org/scikit-learn-contrib/py-earth.png?branch=master)](https://travis-ci.org/scikit-learn-contrib/py-earth?branch=master) 2 | ======== 3 | 4 | A Python implementation of Jerome Friedman's Multivariate Adaptive Regression Splines algorithm, 5 | in the style of scikit-learn. 
The py-earth package implements Multivariate Adaptive Regression Splines using Cython and provides an interface that is compatible with scikit-learn's Estimator, Predictor, Transformer, and Model interfaces. For more information about 6 | Multivariate Adaptive Regression Splines, see the references below. 7 | 8 | ## Now With Missing Data Support! 9 | 10 | The py-earth package now supports missingness in its predictors. Just set `allow_missing=True` when constructing an `Earth` object. 11 | 12 | ## Requesting Feedback 13 | 14 | If there are other features or improvements you'd like to see in py-earth, please send me an email or open or comment on an issue. In particular, please let me know if any of the following are important to you: 15 | 16 | 1. Improved speed 17 | 2. Exporting models to additional formats 18 | 3. Support for shared memory multiprocessing during fitting 19 | 4. Support for cyclic predictors (such as time of day) 20 | 5. Better support for categorical predictors 21 | 6. Better support for large data sets 22 | 7. Iterative reweighting during fitting 23 | 24 | ## Installation 25 | 26 | Make sure you have numpy and scikit-learn installed. 
Then do the following: 27 | 28 | ``` 29 | git clone https://github.com/scikit-learn-contrib/py-earth.git 30 | cd py-earth 31 | sudo python setup.py install 32 | ``` 33 | 34 | ## Usage 35 | ```python 36 | import numpy 37 | from pyearth import Earth 38 | from matplotlib import pyplot 39 | 40 | #Create some fake data 41 | numpy.random.seed(0) 42 | m = 1000 43 | n = 10 44 | X = 80*numpy.random.uniform(size=(m,n)) - 40 45 | y = numpy.abs(X[:,6] - 4.0) + 1*numpy.random.normal(size=m) 46 | 47 | #Fit an Earth model 48 | model = Earth() 49 | model.fit(X,y) 50 | 51 | #Print the model 52 | print(model.trace()) 53 | print(model.summary()) 54 | 55 | #Plot the model 56 | y_hat = model.predict(X) 57 | pyplot.figure() 58 | pyplot.plot(X[:,6],y,'r.') 59 | pyplot.plot(X[:,6],y_hat,'b.') 60 | pyplot.xlabel('x_6') 61 | pyplot.ylabel('y') 62 | pyplot.title('Simple Earth Example') 63 | pyplot.show() 64 | ``` 65 | 66 | ## Other Implementations 67 | 68 | I am aware of the following implementations of Multivariate Adaptive Regression Splines: 69 | 70 | 1. The R package earth (coded in C by Stephen Milborrow): http://cran.r-project.org/web/packages/earth/index.html 71 | 2. The R package mda (coded in Fortran by Trevor Hastie and Robert Tibshirani): http://cran.r-project.org/web/packages/mda/index.html 72 | 3. The Orange data mining library for Python (uses the C code from 1): http://orange.biolab.si/ 73 | 4. The xtal package (uses Fortran code written in 1991 by Jerome Friedman): http://www.ece.umn.edu/users/cherkass/ee4389/xtalpackage.html 74 | 5. MARSplines by StatSoft: http://www.statsoft.com/textbook/multivariate-adaptive-regression-splines/ 75 | 6. MARS by Salford Systems (also uses Friedman's code): http://www.salford-systems.com/products/mars 76 | 7.
ARESLab (written in Matlab by Gints Jekabsons): http://www.cs.rtu.lv/jekabsons/regression.html 77 | 78 | The R package earth was most useful to me in understanding the algorithm, particularly because of Stephen Milborrow's 79 | thorough and easy to read vignette (http://www.milbo.org/doc/earth-notes.pdf). 80 | 81 | ## References 82 | 83 | 1. Friedman, J. (1991). Multivariate adaptive regression splines. The annals of statistics, 84 | 19(1), 1–67. http://www.jstor.org/stable/10.2307/2241837 85 | 2. Stephen Milborrow. Derived from mda:mars by Trevor Hastie and Rob Tibshirani. 86 | (2012). earth: Multivariate Adaptive Regression Spline Models. R package 87 | version 3.2-3. http://CRAN.R-project.org/package=earth 88 | 3. Friedman, J. (1993). Fast MARS. Stanford University Department of Statistics, Technical Report No 110. 89 | https://statistics.stanford.edu/sites/default/files/LCS%20110.pdf 90 | 4. Friedman, J. (1991). Estimating functions of mixed ordinal and categorical variables using adaptive splines. 91 | Stanford University Department of Statistics, Technical Report No 108. 92 | http://media.salford-systems.com/library/MARS_V2_JHF_LCS-108.pdf 93 | 5. Stewart, G.W. Matrix Algorithms, Volume 1: Basic Decompositions. (1998). Society for Industrial and Applied 94 | Mathematics. 95 | 6. Bjorck, A. Numerical Methods for Least Squares Problems. (1996). Society for Industrial and Applied 96 | Mathematics. 97 | 7. Hastie, T., Tibshirani, R., & Friedman, J. The Elements of Statistical Learning (2nd Edition). (2009). 98 | Springer Series in Statistics 99 | 8. Golub, G., & Van Loan, C. Matrix Computations (3rd Edition). (1996). Johns Hopkins University Press. 100 | 101 | References 7, 2, 1, 3, and 4 contain discussions likely to be useful to users of py-earth. References 1, 2, 6, 5, 102 | 8, 3, and 4 were useful during the implementation process. 
103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | # AppVeyor.com is a Continuous Integration service to build and run tests under 2 | # Windows 3 | environment: 4 | global: 5 | # SDK v7.0 MSVC Express 2008's SetEnv.cmd script will fail if the 6 | # /E:ON and /V:ON options are not enabled in the batch script interpreter 7 | # See: http://stackoverflow.com/a/13751649/163740 8 | CMD_IN_ENV: "cmd /E:ON /V:ON /C .\\build_tools\\appveyor\\run_with_env.cmd" 9 | 10 | matrix: 11 | - PYTHON: "C:\\Python26" 12 | PYTHON_VERSION: "2.6.9" 13 | PYTHON_ARCH: "32" 14 | MINICONDA: "C:\\Miniconda" 15 | 16 | - PYTHON: "C:\\Python26-x64" 17 | PYTHON_VERSION: "2.6.9" 18 | PYTHON_ARCH: "64" 19 | MINICONDA: "C:\\Miniconda-x64" 20 | 21 | - PYTHON: "C:\\Python27" 22 | PYTHON_VERSION: "2.7.8" 23 | PYTHON_ARCH: "32" 24 | MINICONDA: "C:\\Miniconda" 25 | 26 | - PYTHON: "C:\\Python27-x64" 27 | PYTHON_VERSION: "2.7.8" 28 | PYTHON_ARCH: "64" 29 | MINICONDA: "C:\\Miniconda-x64" 30 | 31 | - PYTHON: "C:\\Python35" 32 | PYTHON_VERSION: "3.5.0" 33 | PYTHON_ARCH: "32" 34 | MINICONDA: "C:\\Miniconda35" 35 | 36 | - PYTHON: "C:\\Python35-x64" 37 | PYTHON_VERSION: "3.5.0" 38 | PYTHON_ARCH: "64" 39 | MINICONDA: "C:\\Miniconda35-x64" 40 | 41 | - PYTHON: "C:\\Python36" 42 | PYTHON_VERSION: "3.6.3" 43 | PYTHON_ARCH: "32" 44 | MINICONDA: "C:\\Miniconda36" 45 | 46 | - PYTHON: "C:\\Python36-x64" 47 | PYTHON_VERSION: "3.6.3" 48 | PYTHON_ARCH: "64" 49 | MINICONDA: "C:\\Miniconda36-x64" 50 | 51 | 52 | 53 | 54 | install: 55 | # Miniconda is pre-installed in the worker build 56 | - "SET PATH=%MINICONDA%;%MINICONDA%\\Scripts;%PATH%" 57 | - "python -m pip install -U pip" 58 | # Check that we have the expected version and architecture for Python 59 | - "python --version" 60 | - "python -c \"import struct; print(struct.calcsize('P') * 8)\"" 61 | - "pip --version" 
62 | 63 | # Remove cygwin because it clashes with conda 64 | # see http://help.appveyor.com/discussions/problems/3712-git-remote-https-seems-to-be-broken 65 | - "rmdir C:\\cygwin /s /q" 66 | 67 | # Install the build and runtime dependencies of the project. 68 | - "conda install --quiet --yes six numpy pandas sympy scipy cython nose scikit-learn wheel conda-build" 69 | - "pip install sphinx-gallery" 70 | - "python setup.py bdist_wheel bdist_wininst" 71 | - "python setup.py build_ext --inplace --cythonize" 72 | 73 | - ps: "ls" 74 | 75 | # Install the generated wheel package to test it 76 | - "pip install --pre --no-index --find-links dist/ sklearn-contrib-py-earth" 77 | 78 | # Not a .NET project, we build scikit-learn in the install step instead 79 | build: false 80 | 81 | test_script: 82 | # Change to a non-source folder to make sure we run the tests on the 83 | # # installed library. 84 | - "mkdir empty_folder" 85 | - "cd empty_folder" 86 | 87 | - "python -c \"import nose; nose.main()\" -s -v pyearth" 88 | 89 | # Move back to the project folder 90 | - "cd .." 91 | 92 | artifacts: 93 | # Archive the generated wheel package in the ci.appveyor.com build report. 94 | - path: dist\* 95 | -------------------------------------------------------------------------------- /benchmarks/fast.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pyearth import Earth 3 | from timeit import Timer 4 | 5 | # The robot arm example, as defined in: 6 | # Fast MARS, Jerome H.Friedman, Technical Report No.110, May 1993, section 6.2. 
7 | 8 | np.random.seed(2)  # fixed seed so the synthetic data set is reproducible 9 | nb_examples = 400 10 | theta1 = np.random.uniform(0, 2 * np.pi, size=nb_examples) 11 | theta2 = np.random.uniform(0, 2 * np.pi, size=nb_examples) 12 | phi = np.random.uniform(-np.pi/2, np.pi/2, size=nb_examples) 13 | l1 = np.random.uniform(0, 1, size=nb_examples) 14 | l2 = np.random.uniform(0, 1, size=nb_examples) 15 | x = l1 * np.cos(theta1) - l2 * np.cos(theta1 + theta2) * np.cos(phi) 16 | y = l1 * np.sin(theta1) - l2 * np.sin(theta1 + theta2) * np.cos(phi) 17 | z = l2 * np.sin(theta2) * np.sin(phi) 18 | d = np.sqrt(x**2 + y**2 + z**2)  # regression target: Euclidean norm of (x, y, z) 19 | 20 | 21 | inputs = np.concatenate([theta1[:, np.newaxis],  # one column per sampled parameter -> shape (nb_examples, 5) 22 | theta2[:, np.newaxis], 23 | phi[:, np.newaxis], 24 | l1[:, np.newaxis], 25 | l2[:, np.newaxis]], axis=1) 26 | outputs = d 27 | 28 | hp = dict(  # hyperparameters shared by both fits, so only the fast-mode options differ 29 | max_degree=5, 30 | minspan=1, 31 | endspan=1, 32 | max_terms=100, 33 | allow_linear=False, 34 | ) 35 | model_normal = Earth(**hp) 36 | t = Timer(lambda: model_normal.fit(inputs, outputs))  # time exactly one fit (timeit(number=1) below) 37 | duration_normal = t.timeit(number=1) 38 | print("Normal : MSE={0:.5f}, duration={1:.2f}s". 39 | format(model_normal.mse_, duration_normal)) 40 | model_fast = Earth(use_fast=True,  # same hp plus the fast-mode options (see the Fast MARS report cited in the header) 41 | fast_K=5, 42 | fast_h=1, 43 | **hp) 44 | 45 | t = Timer(lambda: model_fast.fit(inputs, outputs)) 46 | duration_fast = t.timeit(number=1) 47 | print("Fast: MSE={0:.5f}, duration={1:.2f}s". 48 | format(model_fast.mse_, duration_fast)) 49 | speedup = duration_normal / duration_fast  # > 1 means the fast fit finished sooner 50 | print("diagnostic : MSE goes from {0:.5f} to {1:.5f} but it " 51 | "is {2:.2f}x faster". 52 | format(model_normal.mse_, model_fast.mse_, speedup)) 53 | -------------------------------------------------------------------------------- /benchmarks/pyearth_vs_earth.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script randomly generates earth-style models, then randomly generates data 3 | from those models and fits earth models to those data using both the python and 4 | R implementations.
It records the sample size, m, the number of input 5 | dimensions, n, the number of forward pass iterations, the runtime, and the 6 | r^2 statistic for each fit and writes the result to a CSV file. 7 | ''' 8 | 9 | import numpy 10 | import pandas.rpy.common as com 11 | import rpy2.robjects as robjects 12 | import time 13 | import pandas 14 | from pyearth import Earth 15 | 16 | 17 | class DataGenerator(object): 18 | 19 | def __init__(self): 20 | pass 21 | 22 | def generate(self, m): 23 | pass 24 | 25 | 26 | class NoiseGenerator(DataGenerator): 27 | 28 | def __init__(self, n): 29 | self.n = n 30 | 31 | def generate(self, m): 32 | X = numpy.random.normal(size=(m, self.n)) 33 | y = numpy.random.normal(size=m) 34 | return X, y 35 | 36 | 37 | class LinearGenerator(DataGenerator): 38 | 39 | def __init__(self, n): 40 | self.n = n 41 | 42 | def generate(self, m): 43 | X = numpy.random.normal(size=(m, self.n)) 44 | beta = numpy.random.normal(size=self.n) 45 | y = numpy.dot(X, beta) + numpy.random.normal(m) 46 | return X, y 47 | 48 | 49 | class VFunctionGenerator(DataGenerator): 50 | 51 | def __init__(self, n): 52 | self.n = n 53 | 54 | def generate(self, m): 55 | X = numpy.random.normal(size=(m, self.n)) 56 | var = numpy.random.randint(self.n) 57 | y = 10 * abs(X[:, var]) + numpy.random.normal(m) 58 | return X, y 59 | 60 | 61 | class UFunctionGenerator(DataGenerator): 62 | 63 | def __init__(self, n): 64 | self.n = n 65 | 66 | def generate(self, m): 67 | X = numpy.random.normal(size=(m, self.n)) 68 | var = numpy.random.randint(self.n) 69 | y = 10 * (X[:, var] ** 2) + numpy.random.normal(m) 70 | return X, y 71 | 72 | 73 | class RandomComplexityGenerator(DataGenerator): 74 | 75 | def __init__(self, n, max_terms=10, max_degree=2): 76 | self.n = n 77 | self.max_terms = max_terms 78 | self.max_degree = max_degree 79 | 80 | def generate(self, m): 81 | X = numpy.random.normal(size=(m, self.n)) 82 | # Including the intercept 83 | num_terms = numpy.random.randint(2, self.max_terms) 84 
| coef = 10 * numpy.random.normal(size=num_terms) 85 | B = numpy.ones(shape=(m, num_terms)) 86 | B[:, 0] += coef[0] 87 | for i in range(1, num_terms): 88 | degree = numpy.random.randint(1, self.max_degree) 89 | for bf in range(degree): 90 | knot = numpy.random.normal() 91 | dir = 1 - 2 * numpy.random.binomial(1, .5) 92 | var = numpy.random.randint(0, self.n) 93 | B[:, i] *= (dir * (X[:, var] - knot)) * \ 94 | (dir * (X[:, var] - knot) > 0) 95 | y = numpy.dot(B, coef) + numpy.random.normal(size=m) 96 | return X, y 97 | 98 | 99 | def run_earth(X, y, **kwargs): 100 | ''' 101 | Run with the R package earth. 102 | Return prediction value, training time, and number 103 | of forward pass iterations. 104 | ''' 105 | r = robjects.r 106 | m, n = X.shape 107 | data = pandas.DataFrame(X) 108 | data['y'] = y 109 | r_data = com.convert_to_r_dataframe(data) 110 | r('library(earth)') 111 | r_func = ''' 112 | run <- function(data, degree=1, fast.k=0, penalty=3.0){ 113 | time = system.time(model <- earth(y~.,data=data,degree=degree,penalty=penalty))[3] 114 | forward_terms = dim(summary(model)$prune.terms)[1] 115 | y_pred = predict(model,data) 116 | return(list(y_pred, time, forward_terms, model)) 117 | } 118 | ''' 119 | r(r_func) 120 | run = r('run') 121 | r_list = run( 122 | **{'data': r_data, 123 | 'degree': kwargs['max_degree'], 124 | 'fast.k': 0, 125 | 'penalty': kwargs['penalty']}) 126 | y_pred = numpy.array(r_list[0]).reshape(m) 127 | time = r_list[1][0] 128 | forward_terms = r_list[2][0] 129 | return y_pred, time, (forward_terms - 1) / 2 130 | 131 | 132 | def run_pyearth(X, y, **kwargs): 133 | ''' 134 | Run with pyearth. 135 | Return prediction value, training time, and number of 136 | forward pass iterations. 
137 | ''' 138 | model = Earth(**kwargs) 139 | t0 = time.time() 140 | model.fit(X, y) 141 | t1 = time.time() 142 | y_pred = model.predict(X) 143 | forward_iterations = len(model.forward_trace()) - 1 144 | return y_pred, t1 - t0, forward_iterations 145 | 146 | 147 | def compare(generator_class, sample_sizes, dimensions, repetitions, **kwargs): 148 | ''' 149 | Return a data table that includes m, n, pyearth or earth, training time, 150 | and number of forward pass iterations. 151 | ''' 152 | header = [ 153 | 'm', 154 | 'n', 155 | 'pyearth', 156 | 'earth', 157 | 'time', 158 | 'forward_iterations', 159 | 'rsq'] 160 | data = [] 161 | for n in dimensions: 162 | generator = generator_class(n=n) 163 | for m in sample_sizes: 164 | for rep in range(repetitions): 165 | print n, m, rep 166 | X, y = generator.generate(m=m) 167 | y_pred_r, time_r, iter_r = run_earth(X, y, **kwargs) 168 | rsq_r = 1 - (numpy.sum((y - y_pred_r) ** 2)) / ( 169 | numpy.sum((y - numpy.mean(y)) ** 2)) 170 | data.append([m, n, 0, 1, time_r, iter_r, rsq_r]) 171 | y_pred_py, time_py, iter_py = run_pyearth(X, y, **kwargs) 172 | rsq_py = 1 - \ 173 | (numpy.sum((y - y_pred_py) ** 2)) / ( 174 | numpy.sum((y - numpy.mean(y)) ** 2)) 175 | data.append([m, n, 1, 0, time_py, iter_py, rsq_py]) 176 | return pandas.DataFrame(data, columns=header) 177 | 178 | if __name__ == '__main__': 179 | sample_sizes = [100, 200, 300, 500] 180 | dimensions = [10, 20, 30] 181 | rep = 5 182 | numpy.random.seed(1) 183 | data = compare( 184 | RandomComplexityGenerator, 185 | sample_sizes, 186 | dimensions, 187 | rep, 188 | max_degree=2, 189 | penalty=3.0) 190 | print data 191 | data.to_csv('comparison.csv') 192 | -------------------------------------------------------------------------------- /conda-recipe/bld.bat: -------------------------------------------------------------------------------- 1 | pip install sphinx_gallery 2 | "%PYTHON%" setup.py install 3 | if errorlevel 1 exit 1 4 | 
-------------------------------------------------------------------------------- /conda-recipe/build.sh: -------------------------------------------------------------------------------- 1 | pip install sphinx_gallery 2 | $PYTHON setup.py install # Python command to install the script. 3 | -------------------------------------------------------------------------------- /conda-recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: pyearth 3 | version: "0.1" 4 | 5 | source: 6 | git_url: https://github.com/scikit-learn-contrib/py-earth.git 7 | 8 | requirements: 9 | build: 10 | - python 11 | - setuptools 12 | - cython 13 | - numpy 14 | - scikit-learn 15 | - nose 16 | - six 17 | run: 18 | - python 19 | - numpy 20 | - scipy 21 | - scikit-learn 22 | - six 23 | 24 | test: 25 | requires: 26 | - numpy 27 | - scipy 28 | - scikit-learn 29 | - nose 30 | imports: 31 | - pyearth 32 | 33 | about: 34 | home: https://github.com/scikit-learn-contrib/py-earth 35 | license: BSD 36 | -------------------------------------------------------------------------------- /conda-recipe/move-conda-package.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import yaml 4 | import glob 5 | import shutil 6 | from conda_build.config import config 7 | 8 | with open(os.path.join(sys.argv[1], 'meta.yaml')) as f: 9 | name = yaml.load(f)['package']['name'] 10 | 11 | binary_package_glob = os.path.join( 12 | config.bldpkgs_dir, '{0}*.tar.bz2'.format(name)) 13 | binary_package = glob.glob(binary_package_glob)[0] 14 | 15 | shutil.move(binary_package, 'dist') 16 | -------------------------------------------------------------------------------- /conda-recipe/run_test.py: -------------------------------------------------------------------------------- 1 | import pyearth 2 | import nose 3 | import os 4 | 5 | pyearth_dir = os.path.dirname( 6 | os.path.abspath(pyearth.__file__)) 7 | 
os.chdir(pyearth_dir) 8 | nose.run(module=pyearth) 9 | -------------------------------------------------------------------------------- /description.md: -------------------------------------------------------------------------------- 1 | The py-earth package implements Multivariate Adaptive Regression Splines and provides an interface that is compatible with scikit-learn's Estimator, Predictor, Transformer, and Model interfaces. -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # Internal variables. 11 | PAPEROPT_a4 = -D latex_paper_size=a4 12 | PAPEROPT_letter = -D latex_paper_size=letter 13 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 14 | # the i18n builder cannot share the environment and doctrees with the others 15 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
16 | 17 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 18 | 19 | help: 20 | @echo "Please use \`make ' where is one of" 21 | @echo " html to make standalone HTML files" 22 | @echo " dirhtml to make HTML files named index.html in directories" 23 | @echo " singlehtml to make a single large HTML file" 24 | @echo " pickle to make pickle files" 25 | @echo " json to make JSON files" 26 | @echo " htmlhelp to make HTML files and a HTML help project" 27 | @echo " qthelp to make HTML files and a qthelp project" 28 | @echo " devhelp to make HTML files and a Devhelp project" 29 | @echo " epub to make an epub" 30 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 31 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 32 | @echo " text to make text files" 33 | @echo " man to make manual pages" 34 | @echo " texinfo to make Texinfo files" 35 | @echo " info to make Texinfo files and run them through makeinfo" 36 | @echo " gettext to make PO message catalogs" 37 | @echo " changes to make an overview of all changed/added/deprecated items" 38 | @echo " linkcheck to check all external links for integrity" 39 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 40 | 41 | clean: 42 | -rm -rf $(BUILDDIR)/* 43 | rm -rf auto_examples/ 44 | rm -rf modules/generated/* 45 | 46 | html: 47 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 48 | @echo 49 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 50 | 51 | dirhtml: 52 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 53 | @echo 54 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 55 | 56 | singlehtml: 57 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 58 | @echo 59 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 
60 | 61 | pickle: 62 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 63 | @echo 64 | @echo "Build finished; now you can process the pickle files." 65 | 66 | json: 67 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 68 | @echo 69 | @echo "Build finished; now you can process the JSON files." 70 | 71 | htmlhelp: 72 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 73 | @echo 74 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 75 | ".hhp project file in $(BUILDDIR)/htmlhelp." 76 | 77 | qthelp: 78 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 79 | @echo 80 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 81 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 82 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/py-earth.qhcp" 83 | @echo "To view the help file:" 84 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/py-earth.qhc" 85 | 86 | devhelp: 87 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 88 | @echo 89 | @echo "Build finished." 90 | @echo "To view the help file:" 91 | @echo "# mkdir -p $$HOME/.local/share/devhelp/py-earth" 92 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/py-earth" 93 | @echo "# devhelp" 94 | 95 | epub: 96 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 97 | @echo 98 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 99 | 100 | latex: 101 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 102 | @echo 103 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 104 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 105 | "(use \`make latexpdf' here to do that automatically)." 106 | 107 | latexpdf: 108 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 109 | @echo "Running LaTeX files through pdflatex..." 110 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 111 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 
112 | 113 | text: 114 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 115 | @echo 116 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 117 | 118 | man: 119 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 120 | @echo 121 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 122 | 123 | texinfo: 124 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 125 | @echo 126 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 127 | @echo "Run \`make' in that directory to run these through makeinfo" \ 128 | "(use \`make info' here to do that automatically)." 129 | 130 | info: 131 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 132 | @echo "Running Texinfo files through makeinfo..." 133 | make -C $(BUILDDIR)/texinfo info 134 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 135 | 136 | gettext: 137 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 138 | @echo 139 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 140 | 141 | changes: 142 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 143 | @echo 144 | @echo "The overview file is in $(BUILDDIR)/changes." 145 | 146 | linkcheck: 147 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 148 | @echo 149 | @echo "Link check complete; look for any errors in the above output " \ 150 | "or in $(BUILDDIR)/linkcheck/output.txt." 151 | 152 | doctest: 153 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 154 | @echo "Testing of doctests in the sources finished, look at the " \ 155 | "results in $(BUILDDIR)/doctest/output.txt." 156 | 157 | html-noplot: 158 | $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 159 | @echo 160 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 
161 | -------------------------------------------------------------------------------- /doc/README.md: -------------------------------------------------------------------------------- 1 | Py-earth documentation 2 | ---------------------- 3 | 4 | This folder contains all the necessary files for building py-earth 5 | documentation (based on Sphinx). 6 | Building the documentation requires the following packages: matplotlib, numpydoc, sphinx-gallery, sphinxcontrib-bibtex and sphinx_rtd_theme. You can install them with pip: 7 | 8 | ``` 9 | pip install matplotlib 10 | pip install numpydoc 11 | pip install sphinx-gallery 12 | pip install sphinxcontrib-bibtex 13 | pip install sphinx_rtd_theme 14 | ``` 15 | 16 | You can then generate HTML documentation by running: 17 | 18 | ``` 19 | make html 20 | ``` 21 | 22 | Other formats are supported by Sphinx; you can list the supported formats using: 23 | 24 | 25 | ``` 26 | make help 27 | ``` 28 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # py-earth documentation build configuration file, created by 4 | # sphinx-quickstart on Thu Jul 11 21:44:49 2013. 5 | # 6 | # This file is execfile()d with the current directory set to its containing dir. 7 | # 8 | # Note that not all possible configuration values are present in this 9 | # autogenerated file. 10 | # 11 | # All configuration values have a default; values that are commented out 12 | # serve to show the default. 13 | import sys 14 | import os 15 | 16 | # Automatically generate API documentation 17 | #os.system('sphinx-apidoc -o . ' + os.path.join('..','pyearth')) 18 | 19 | # If extensions (or modules to document with autodoc) are in another directory, 20 | # add these directories to sys.path here. If the directory is relative to the 21 | # documentation root, use os.path.abspath to make it absolute, like shown here.
22 | sys.path.insert(0, os.path.abspath('.')) 23 | sys.path.insert(0, os.path.abspath(os.path.join('..'))) 24 | 25 | # Create the figures 26 | import generate_figures 27 | 28 | import sphinx_gallery 29 | 30 | # -- General configuration ----------------------------------------------- 31 | 32 | # If your documentation needs a minimal Sphinx version, state it here. 33 | #needs_sphinx = '1.0' 34 | 35 | # Add any Sphinx extension module names here, as strings. They can be extensions 36 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 37 | extensions = ['sphinx.ext.doctest', 'sphinx.ext.todo', 'sphinx.ext.coverage', 38 | 'sphinx.ext.mathjax', 'sphinx.ext.ifconfig', 39 | 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'numpydoc', 'sphinx.ext.autosummary', 40 | 'sphinxcontrib.bibtex', 'sphinx_gallery.gen_gallery'] 41 | 42 | autosummary_generate = True 43 | numpydoc_show_class_members = False 44 | 45 | autodoc_default_flags = ['members', 'inherited-members'] 46 | 47 | # Add any paths that contain templates here, relative to this directory. 48 | templates_path = ['_templates'] 49 | 50 | # The suffix of source filenames. 51 | source_suffix = '.rst' 52 | 53 | # The encoding of source files. 54 | #source_encoding = 'utf-8-sig' 55 | 56 | # The master toctree document. 57 | master_doc = 'index' 58 | 59 | # General information about the project. 60 | project = u'py-earth' 61 | copyright = u'2013, Jason Rudy' 62 | 63 | # The version info for the project you're documenting, acts as replacement for 64 | # |version| and |release|, also used in various other places throughout the 65 | # built documents. 66 | # 67 | # The short X.Y version. 68 | import pyearth 69 | version = pyearth.__version__ 70 | # The full version, including alpha/beta/rc tags. 71 | release = version 72 | 73 | # The language for content autogenerated by Sphinx. Refer to documentation 74 | # for a list of supported languages. 
75 | #language = None 76 | 77 | # There are two options for replacing |today|: either, you set today to some 78 | # non-false value, then it is used: 79 | #today = '' 80 | # Else, today_fmt is used as the format for a strftime call. 81 | #today_fmt = '%B %d, %Y' 82 | 83 | # List of patterns, relative to source directory, that match files and 84 | # directories to ignore when looking for source files. 85 | exclude_patterns = ['_build'] 86 | 87 | # The reST default role (used for this markup: `text`) to use for all documents. 88 | #default_role = None 89 | 90 | # If true, '()' will be appended to :func: etc. cross-reference text. 91 | #add_function_parentheses = True 92 | 93 | # If true, the current module name will be prepended to all description 94 | # unit titles (such as .. function::). 95 | #add_module_names = True 96 | 97 | # If true, sectionauthor and moduleauthor directives will be shown in the 98 | # output. They are ignored by default. 99 | #show_authors = False 100 | 101 | # The name of the Pygments (syntax highlighting) style to use. 102 | pygments_style = 'sphinx' 103 | 104 | # A list of ignored prefixes for module index sorting. 105 | #modindex_common_prefix = [] 106 | 107 | 108 | # -- Options for HTML output --------------------------------------------- 109 | 110 | # The theme to use for HTML and HTML Help pages. See the documentation for 111 | # a list of builtin themes. 112 | html_theme = 'sphinx_rtd_theme' 113 | 114 | # Theme options are theme-specific and customize the look and feel of a theme 115 | # further. For a list of options available for each theme, see the 116 | # documentation. 117 | #html_theme_options = {} 118 | 119 | # Add any paths that contain custom themes here, relative to this directory. 120 | #html_theme_path = [] 121 | 122 | # The name for this set of Sphinx documents. If None, it defaults to 123 | # " v documentation". 124 | #html_title = None 125 | 126 | # A shorter title for the navigation bar. 
Default is the same as html_title. 127 | #html_short_title = None 128 | 129 | # The name of an image file (relative to this directory) to place at the top 130 | # of the sidebar. 131 | #html_logo = None 132 | 133 | # The name of an image file (within the static path) to use as favicon of the 134 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 135 | # pixels large. 136 | #html_favicon = None 137 | 138 | # Add any paths that contain custom static files (such as style sheets) here, 139 | # relative to this directory. They are copied after the builtin static files, 140 | # so a file named "default.css" will overwrite the builtin "default.css". 141 | html_static_path = ['_static', sphinx_gallery.glr_path_static()] 142 | 143 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 144 | # using the given strftime format. 145 | #html_last_updated_fmt = '%b %d, %Y' 146 | 147 | # If true, SmartyPants will be used to convert quotes and dashes to 148 | # typographically correct entities. 149 | #html_use_smartypants = True 150 | 151 | # Custom sidebar templates, maps document names to template names. 152 | #html_sidebars = {} 153 | 154 | # Additional templates that should be rendered to pages, maps page names to 155 | # template names. 156 | #html_additional_pages = {} 157 | 158 | # If false, no module index is generated. 159 | #html_domain_indices = True 160 | 161 | # If false, no index is generated. 162 | #html_use_index = True 163 | 164 | # If true, the index is split into individual pages for each letter. 165 | #html_split_index = False 166 | 167 | # If true, links to the reST sources are added to the pages. 168 | #html_show_sourcelink = True 169 | 170 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 171 | #html_show_sphinx = True 172 | 173 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 
174 | #html_show_copyright = True 175 | 176 | # If true, an OpenSearch description file will be output, and all pages will 177 | # contain a tag referring to it. The value of this option must be the 178 | # base URL from which the finished HTML is served. 179 | #html_use_opensearch = '' 180 | 181 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 182 | #html_file_suffix = None 183 | 184 | # Output file base name for HTML help builder. 185 | htmlhelp_basename = 'py-earth_doc' 186 | 187 | 188 | # -- Options for LaTeX output -------------------------------------------- 189 | 190 | latex_elements = { 191 | # The paper size ('letterpaper' or 'a4paper'). 192 | #'papersize': 'letterpaper', 193 | 194 | # The font size ('10pt', '11pt' or '12pt'). 195 | #'pointsize': '10pt', 196 | 197 | # Additional stuff for the LaTeX preamble. 198 | #'preamble': '', 199 | } 200 | 201 | # Grouping the document tree into LaTeX files. List of tuples 202 | # (source start file, target name, title, author, documentclass [howto/manual]). 203 | latex_documents = [ 204 | ('index', 'py-earth.tex', u'py-earth Documentation', 205 | u'Jason Rudy', 'manual'), 206 | ] 207 | 208 | # The name of an image file (relative to this directory) to place at the top of 209 | # the title page. 210 | #latex_logo = None 211 | 212 | # For "manual" documents, if this is true, then toplevel headings are parts, 213 | # not chapters. 214 | #latex_use_parts = False 215 | 216 | # If true, show page references after internal links. 217 | #latex_show_pagerefs = False 218 | 219 | # If true, show URL addresses after external links. 220 | #latex_show_urls = False 221 | 222 | # Documents to append as an appendix to all manuals. 223 | #latex_appendices = [] 224 | 225 | # If false, no module index is generated. 226 | #latex_domain_indices = True 227 | 228 | 229 | # -- Options for manual page output -------------------------------------- 230 | 231 | # One entry per manual page. 
List of tuples 232 | # (source start file, name, description, authors, manual section). 233 | man_pages = [ 234 | ('index', 'py-earth', u'py-earth Documentation', 235 | [u'Jason Rudy'], 1) 236 | ] 237 | 238 | # If true, show URL addresses after external links. 239 | #man_show_urls = False 240 | 241 | 242 | # -- Options for Texinfo output ------------------------------------------ 243 | 244 | # Grouping the document tree into Texinfo files. List of tuples 245 | # (source start file, target name, title, author, 246 | # dir menu entry, description, category) 247 | texinfo_documents = [ 248 | ('index', 'py-earth', u'py-earth Documentation', 249 | u'Jason Rudy', 'py-earth', 'One line description of project.', 250 | 'Miscellaneous'), 251 | ] 252 | 253 | # Documents to append as an appendix to all manuals. 254 | #texinfo_appendices = [] 255 | 256 | # If false, no module index is generated. 257 | #texinfo_domain_indices = True 258 | 259 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 260 | #texinfo_show_urls = 'footnote' 261 | 262 | # Create the .nojekyll file for github pages 263 | open(os.path.join('_build', 'html', '.nojekyll'), 'w').close() 264 | 265 | sphinx_gallery_conf = { 266 | # path to your examples scripts 267 | 'examples_dirs' : '../examples', 268 | # path where to save gallery generated examples 269 | 'gallery_dirs' : 'auto_examples', 270 | 'filename_pattern' : '(/plot_.*|return_.*)' 271 | } 272 | -------------------------------------------------------------------------------- /doc/content.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ------------ 3 | 4 | The py-earth package is a Python implementation of Jerome Friedman's Multivariate Adaptive 5 | Regression Splines algorithm, in the style of scikit-learn. For more information about Multivariate 6 | Adaptive Regression Splines, see below. Py-earth is written in Python and Cython. 
It 7 | provides an interface that is compatible with scikit-learn's Estimator, Predictor, Transformer, and Model 8 | interfaces. Py-earth accommodates input in the form of numpy arrays, pandas DataFrames, patsy DesignMatrix 9 | objects, or most anything that can be converted into an array of floats. Fitted models can be pickled for 10 | later use. 11 | 12 | 13 | Multivariate Adaptive Regression Splines 14 | ---------------------------------------- 15 | 16 | Multivariate adaptive regression splines, implemented by the Earth class, is a flexible 17 | regression method that automatically searches for interactions and non-linear 18 | relationships. Earth models can be thought of as linear models in a higher-dimensional 19 | basis space. Each term in an Earth model is a product of so-called "hinge functions". 20 | A hinge function is a function that's equal to its argument where that argument is greater 21 | than zero and is zero everywhere else. 22 | 23 | .. math:: 24 | \text{h}\left(x-t\right)=\left[x-t\right]_{+}=\begin{cases} 25 | x-t, & x>t\\ 26 | 0, & x\leq t 27 | \end{cases} 28 | 29 | .. image:: hinge.png 30 | 31 | An Earth model is a linear combination of basis functions, each of which is a product of one 32 | or more of the following: 33 | 34 | 1. A constant 35 | 2. Linear functions of input variables 36 | 3. Hinge functions of input variables 37 | 38 | For example, a simple piecewise linear function in one variable can be expressed 39 | as a linear combination of two hinge functions and a constant (see below). During fitting, the Earth class 40 | automatically determines which variables and basis functions to use. 41 | The algorithm has two stages. First, the 42 | forward pass searches for terms that locally minimize squared error loss on the training set. Next, a pruning pass selects a subset of those 43 | terms that produces a locally minimal generalized cross-validation (GCV) score.
The GCV 44 | score is not actually based on cross-validation, but rather is meant to approximate a true 45 | cross-validation score by penalizing model complexity. The final result is a set of basis functions 46 | that is nonlinear in the original feature space, may include interactions, and is likely to 47 | generalize well. 48 | 49 | 50 | .. math:: 51 | y=1-2\text{h}\left(1-x\right)+\frac{1}{2}\text{h}\left(x-1\right) 52 | 53 | 54 | .. image:: piecewise_linear.png 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | A Simple Earth Example 63 | ---------------------- 64 | 65 | 66 | :: 67 | 68 | import numpy 69 | from pyearth import Earth 70 | from matplotlib import pyplot 71 | 72 | #Create some fake data 73 | numpy.random.seed(0) 74 | m = 1000 75 | n = 10 76 | X = 80*numpy.random.uniform(size=(m,n)) - 40 77 | y = numpy.abs(X[:,6] - 4.0) + 1*numpy.random.normal(size=m) 78 | 79 | #Fit an Earth model 80 | model = Earth() 81 | model.fit(X,y) 82 | 83 | #Print the model 84 | print(model.trace()) 85 | print(model.summary()) 86 | 87 | #Plot the model 88 | y_hat = model.predict(X) 89 | pyplot.figure() 90 | pyplot.plot(X[:,6],y,'r.') 91 | pyplot.plot(X[:,6],y_hat,'b.') 92 | pyplot.xlabel('x_6') 93 | pyplot.ylabel('y') 94 | pyplot.title('Simple Earth Example') 95 | pyplot.show() 96 | 97 | .. image:: simple_earth_example.png 98 | 99 | Bibliography 100 | ------------ 101 | .. bibliography:: earth_bibliography.bib 102 | 103 | References :cite:`Hastie2009`, :cite:`Millborrow2012`, :cite:`Friedman1991`, :cite:`Friedman1993`, 104 | and :cite:`Friedman1991a` contain discussions likely to be useful to users of py-earth. 105 | References :cite:`Friedman1991`, :cite:`Millborrow2012`, :cite:`Bjorck1996`, :cite:`Stewart1998`, 106 | :cite:`Golub1996`, :cite:`Friedman1993`, and :cite:`Friedman1991a` were useful during the 107 | implementation process. 108 | 109 | 110 | API 111 | --- 112 | 113 | .. 
autoclass:: pyearth.Earth 114 | -------------------------------------------------------------------------------- /doc/earth_bibliography.bib: -------------------------------------------------------------------------------- 1 | @book{Bjorck1996, 2 | address = {Philadelphia}, 3 | author = {Bjorck, Ake}, 4 | isbn = {0898713609}, 5 | publisher = {Society for Industrial and Applied Mathematics}, 6 | title = {{Numerical Methods for Least Squares Problems}}, 7 | year = {1996} 8 | } 9 | @techreport{Friedman1993, 10 | author = {Friedman, Jerome H.}, 11 | institution = {Stanford University Department of Statistics}, 12 | title = {{Technical Report No. 110: Fast MARS.}}, 13 | url = {http://scholar.google.com/scholar?hl=en\&btnG=Search\&q=intitle:Fast+MARS\#0}, 14 | year = {1993} 15 | } 16 | @techreport{Friedman1991a, 17 | author = {Friedman, JH}, 18 | institution = {Stanford University Department of Statistics}, 19 | publisher = {Stanford University Department of Statistics}, 20 | title = {{Technical Report No. 
108: Estimating functions of mixed ordinal and categorical variables using adaptive splines}}, 21 | url = {http://scholar.google.com/scholar?hl=en\&btnG=Search\&q=intitle:Estimating+functions+of+mixed+ordinal+and+categorical+variables+using+adaptive+splines\#0}, 22 | year = {1991} 23 | } 24 | @article{Friedman1991, 25 | author = {Friedman, JH}, 26 | journal = {The annals of statistics}, 27 | number = {1}, 28 | pages = {1--67}, 29 | title = {{Multivariate adaptive regression splines}}, 30 | url = {http://www.jstor.org/stable/10.2307/2241837}, 31 | volume = {19}, 32 | year = {1991} 33 | } 34 | @book{Golub1996, 35 | author = {Golub, Gene and {Van Loan}, Charles}, 36 | edition = {3}, 37 | publisher = {Johns Hopkins University Press}, 38 | title = {{Matrix Computations}}, 39 | year = {1996} 40 | } 41 | @book{Hastie2009, 42 | address = {New York}, 43 | author = {Hastie, Trevor and Tibshirani, Robert and Friedman, Jerome}, 44 | edition = {2}, 45 | publisher = {Springer Science+Business Media}, 46 | title = {{Elements of Statistical Learning: Data Mining, Inference, and Prediction}}, 47 | year = {2009} 48 | } 49 | @book{Stewart1998, 50 | address = {Philadelphia}, 51 | author = {Stewart, G. 
W.}, 52 | isbn = {0898714141}, 53 | publisher = {Society for Industrial and Applied Mathematics}, 54 | title = {{Matrix Algorithms Volume 1: Basic Decompositions}}, 55 | year = {1998} 56 | } 57 | @misc{Millborrow2012, 58 | author = {Milborrow, Stephen}, 59 | publisher = {CRAN}, 60 | title = {{earth: Multivariate Adaptive Regression Spline Models}}, 61 | url = {http://cran.r-project.org/web/packages/earth/index.html}, 62 | year = {2012} 63 | } 64 | -------------------------------------------------------------------------------- /doc/generate_figures.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | mpl.use('Agg') 3 | import numpy 4 | from pyearth import Earth 5 | from matplotlib import pyplot 6 | 7 | #========================================================================= 8 | # V-Function Example 9 | #========================================================================= 10 | # Create some fake data 11 | numpy.random.seed(0) 12 | m = 1000 13 | n = 10 14 | X = 80 * numpy.random.uniform(size=(m, n)) - 40 15 | y = numpy.abs(X[:, 6] - 4.0) + 1 * numpy.random.normal(size=m) 16 | 17 | # Fit an Earth model 18 | model = Earth() 19 | model.fit(X, y) 20 | 21 | # Print the model 22 | print(model.trace()) 23 | print(model.summary()) 24 | 25 | # Plot the model 26 | y_hat = model.predict(X) 27 | pyplot.figure() 28 | pyplot.plot(X[:, 6], y, 'r.') 29 | pyplot.plot(X[:, 6], y_hat, 'b.') 30 | pyplot.xlabel('x_6') 31 | pyplot.ylabel('y') 32 | pyplot.title('Simple Earth Example') 33 | pyplot.savefig('simple_earth_example.png') 34 | 35 | #========================================================================= 36 | # Hinge plot 37 | #========================================================================= 38 | from xkcdify import XKCDify 39 | x = numpy.arange(-10, 10, .1) 40 | y = x * (x > 0) 41 | 42 | fig = pyplot.figure(figsize=(10, 5)) 43 | pyplot.plot(x, y) 44 | ax = pyplot.gca() 45 |
pyplot.title('Basic Hinge Function') 47 | pyplot.xlabel('x') 48 | pyplot.ylabel('h(x)') 49 | pyplot.annotate('x=t', (0, 0), xytext=(-30, 30), textcoords='offset points', 50 | arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2")) 51 | XKCDify(ax) 52 | pyplot.setp(ax, frame_on=False) 53 | pyplot.savefig('hinge.png') 54 | 55 | #========================================================================= 56 | # Piecewise Linear Plot 57 | #========================================================================= 58 | m = 1000 59 | x = numpy.arange(-10, 10, .1) 60 | y = 1 - 2 * (1 - x) * (x < 1) + 0.5 * (x - 1) * (x > 1) 61 | 62 | pyplot.figure(figsize=(10, 5)) 63 | pyplot.plot(x, y) 64 | ax = pyplot.gca() 65 | pyplot.xlabel('x') 66 | pyplot.ylabel('y') 67 | pyplot.title('Piecewise Linear Function') 68 | XKCDify(ax) 69 | pyplot.setp(ax, frame_on=False) 70 | pyplot.savefig('piecewise_linear.png') 71 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | Py-earth documentation 2 | ---------------------- 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | content 8 | auto_examples/index -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. 
html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. linkcheck to check all external links for integrity 37 | echo. doctest to run all doctests embedded in the documentation if enabled 38 | goto end 39 | ) 40 | 41 | if "%1" == "clean" ( 42 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 43 | del /q /s %BUILDDIR%\* 44 | goto end 45 | ) 46 | 47 | if "%1" == "html" ( 48 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 49 | if errorlevel 1 exit /b 1 50 | echo. 51 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 52 | goto end 53 | ) 54 | 55 | if "%1" == "dirhtml" ( 56 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 57 | if errorlevel 1 exit /b 1 58 | echo. 59 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 60 | goto end 61 | ) 62 | 63 | if "%1" == "singlehtml" ( 64 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 65 | if errorlevel 1 exit /b 1 66 | echo. 67 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 68 | goto end 69 | ) 70 | 71 | if "%1" == "pickle" ( 72 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 73 | if errorlevel 1 exit /b 1 74 | echo. 75 | echo.Build finished; now you can process the pickle files. 
76 | goto end 77 | ) 78 | 79 | if "%1" == "json" ( 80 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 81 | if errorlevel 1 exit /b 1 82 | echo. 83 | echo.Build finished; now you can process the JSON files. 84 | goto end 85 | ) 86 | 87 | if "%1" == "htmlhelp" ( 88 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 89 | if errorlevel 1 exit /b 1 90 | echo. 91 | echo.Build finished; now you can run HTML Help Workshop with the ^ 92 | .hhp project file in %BUILDDIR%/htmlhelp. 93 | goto end 94 | ) 95 | 96 | if "%1" == "qthelp" ( 97 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 98 | if errorlevel 1 exit /b 1 99 | echo. 100 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 101 | .qhcp project file in %BUILDDIR%/qthelp, like this: 102 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\py-earth.qhcp 103 | echo.To view the help file: 104 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\py-earth.ghc 105 | goto end 106 | ) 107 | 108 | if "%1" == "devhelp" ( 109 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 110 | if errorlevel 1 exit /b 1 111 | echo. 112 | echo.Build finished. 113 | goto end 114 | ) 115 | 116 | if "%1" == "epub" ( 117 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 118 | if errorlevel 1 exit /b 1 119 | echo. 120 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 121 | goto end 122 | ) 123 | 124 | if "%1" == "latex" ( 125 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 129 | goto end 130 | ) 131 | 132 | if "%1" == "text" ( 133 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 134 | if errorlevel 1 exit /b 1 135 | echo. 136 | echo.Build finished. The text files are in %BUILDDIR%/text. 137 | goto end 138 | ) 139 | 140 | if "%1" == "man" ( 141 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 142 | if errorlevel 1 exit /b 1 143 | echo. 
144 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 145 | goto end 146 | ) 147 | 148 | if "%1" == "texinfo" ( 149 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 150 | if errorlevel 1 exit /b 1 151 | echo. 152 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 153 | goto end 154 | ) 155 | 156 | if "%1" == "gettext" ( 157 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 158 | if errorlevel 1 exit /b 1 159 | echo. 160 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 161 | goto end 162 | ) 163 | 164 | if "%1" == "changes" ( 165 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 166 | if errorlevel 1 exit /b 1 167 | echo. 168 | echo.The overview file is in %BUILDDIR%/changes. 169 | goto end 170 | ) 171 | 172 | if "%1" == "linkcheck" ( 173 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 174 | if errorlevel 1 exit /b 1 175 | echo. 176 | echo.Link check complete; look for any errors in the above output ^ 177 | or in %BUILDDIR%/linkcheck/output.txt. 178 | goto end 179 | ) 180 | 181 | if "%1" == "doctest" ( 182 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 183 | if errorlevel 1 exit /b 1 184 | echo. 185 | echo.Testing of doctests in the sources finished, look at the ^ 186 | results in %BUILDDIR%/doctest/output.txt. 187 | goto end 188 | ) 189 | 190 | :end 191 | -------------------------------------------------------------------------------- /doc/xkcdify.py: -------------------------------------------------------------------------------- 1 | """ 2 | XKCD plot generator 3 | ------------------- 4 | Author: Jake Vanderplas 5 | 6 | This is a script that will take any matplotlib line diagram, and convert it 7 | to an XKCD-style plot. It will work for plots with line & text elements, 8 | including axes labels and titles (but not axes tick labels). 
9 | 10 | The idea for this comes from work by Damon McDougall 11 | http://www.mail-archive.com/matplotlib-users@lists.sourceforge.net/msg25499.html 12 | 13 | Copied from Jake's blog: 14 | http://jakevdp.github.com/blog/2012/10/07/xkcd-style-plots-in-matplotlib/ 15 | on 2013-03-07 16 | 17 | """ 18 | import numpy as np 19 | import pylab as pl 20 | from scipy import interpolate, signal 21 | import matplotlib.font_manager as fm 22 | 23 | 24 | # We need a special font for the code below. It can be downloaded this way: 25 | import os 26 | import urllib 27 | if not os.path.exists('Humor-Sans.ttf'): 28 | fhandle = urllib.urlopen('https://github.com/shreyankg/xkcd-desktop/raw/master/Humor-Sans.ttf') 29 | open('Humor-Sans.ttf', 'wb').write(fhandle.read()) 30 | 31 | 32 | def xkcd_line(x, y, xlim=None, ylim=None, 33 | mag=1.0, f1=30, f2=0.05, f3=15): 34 | """ 35 | Mimic a hand-drawn line from (x, y) data 36 | 37 | Parameters 38 | ---------- 39 | x, y : array_like 40 | arrays to be modified 41 | xlim, ylim : data range 42 | the assumed plot range for the modification. If not specified, 43 | they will be guessed from the data 44 | mag : float 45 | magnitude of distortions 46 | f1, f2, f3 : int, float, int 47 | filtering parameters. f1 gives the size of the window, f2 gives 48 | the high-frequency cutoff, f3 gives the size of the filter 49 | 50 | Returns 51 | ------- 52 | x, y : ndarrays 53 | The modified lines 54 | """ 55 | x = np.asarray(x) 56 | y = np.asarray(y) 57 | 58 | # get limits for rescaling 59 | if xlim is None: 60 | xlim = (x.min(), x.max()) 61 | if ylim is None: 62 | ylim = (y.min(), y.max()) 63 | 64 | if xlim[1] == xlim[0]: 65 | xlim = ylim 66 | 67 | if ylim[1] == ylim[0]: 68 | ylim = xlim 69 | 70 | # scale the data 71 | x_scaled = (x - xlim[0]) * 1. / (xlim[1] - xlim[0]) 72 | y_scaled = (y - ylim[0]) * 1. 
/ (ylim[1] - ylim[0]) 73 | 74 | # compute the total distance along the path 75 | dx = x_scaled[1:] - x_scaled[:-1] 76 | dy = y_scaled[1:] - y_scaled[:-1] 77 | dist_tot = np.sum(np.sqrt(dx * dx + dy * dy)) 78 | 79 | # number of interpolated points is proportional to the distance 80 | Nu = int(200 * dist_tot) 81 | u = np.arange(-1, Nu + 1) * 1. / (Nu - 1) 82 | 83 | # interpolate curve at sampled points 84 | k = min(3, len(x) - 1) 85 | res = interpolate.splprep([x_scaled, y_scaled], s=0, k=k) 86 | x_int, y_int = interpolate.splev(u, res[0]) 87 | 88 | # we'll perturb perpendicular to the drawn line 89 | dx = x_int[2:] - x_int[:-2] 90 | dy = y_int[2:] - y_int[:-2] 91 | dist = np.sqrt(dx * dx + dy * dy) 92 | 93 | # create a filtered perturbation 94 | coeffs = mag * np.random.normal(0, 0.01, len(x_int) - 2) 95 | b = signal.firwin(f1, f2 * dist_tot, window=('kaiser', f3)) 96 | response = signal.lfilter(b, 1, coeffs) 97 | 98 | x_int[1:-1] += response * dy / dist 99 | y_int[1:-1] += response * dx / dist 100 | 101 | # un-scale data 102 | x_int = x_int[1:-1] * (xlim[1] - xlim[0]) + xlim[0] 103 | y_int = y_int[1:-1] * (ylim[1] - ylim[0]) + ylim[0] 104 | 105 | return x_int, y_int 106 | 107 | 108 | def XKCDify(ax, mag=1.0, 109 | f1=50, f2=0.01, f3=15, 110 | bgcolor='w', 111 | xaxis_loc=None, 112 | yaxis_loc=None, 113 | xaxis_arrow='+', 114 | yaxis_arrow='+', 115 | ax_extend=0.1, 116 | expand_axes=False): 117 | """Make axis look hand-drawn 118 | 119 | This adjusts all lines, text, legends, and axes in the figure to look 120 | like xkcd plots. Other plot elements are not modified. 121 | 122 | Parameters 123 | ---------- 124 | ax : Axes instance 125 | the axes to be modified. 126 | mag : float 127 | the magnitude of the distortion 128 | f1, f2, f3 : int, float, int 129 | filtering parameters. f1 gives the size of the window, f2 gives 130 | the high-frequency cutoff, f3 gives the size of the filter 131 | xaxis_loc, yaxis_log : float 132 | The locations to draw the x and y axes. 
If not specified, they 133 | will be drawn from the bottom left of the plot 134 | xaxis_arrow, yaxis_arrow : str 135 | where to draw arrows on the x/y axes. Options are '+', '-', '+-', or '' 136 | ax_extend : float 137 | How far (fractionally) to extend the drawn axes beyond the original 138 | axes limits 139 | expand_axes : bool 140 | if True, then expand axes to fill the figure (useful if there is only 141 | a single axes in the figure) 142 | """ 143 | # Get axes aspect 144 | ext = ax.get_window_extent().extents 145 | aspect = (ext[3] - ext[1]) / (ext[2] - ext[0]) 146 | 147 | xlim = ax.get_xlim() 148 | ylim = ax.get_ylim() 149 | 150 | xspan = xlim[1] - xlim[0] 151 | yspan = ylim[1] - xlim[0] 152 | 153 | xax_lim = (xlim[0] - ax_extend * xspan, 154 | xlim[1] + ax_extend * xspan) 155 | yax_lim = (ylim[0] - ax_extend * yspan, 156 | ylim[1] + ax_extend * yspan) 157 | 158 | if xaxis_loc is None: 159 | xaxis_loc = ylim[0] 160 | 161 | if yaxis_loc is None: 162 | yaxis_loc = xlim[0] 163 | 164 | # Draw axes 165 | xaxis = pl.Line2D([xax_lim[0], xax_lim[1]], [xaxis_loc, xaxis_loc], 166 | linestyle='-', color='k') 167 | yaxis = pl.Line2D([yaxis_loc, yaxis_loc], [yax_lim[0], yax_lim[1]], 168 | linestyle='-', color='k') 169 | 170 | # Label axes3, 0.5, 'hello', fontsize=14) 171 | ax.text(xax_lim[1], xaxis_loc - 0.02 * yspan, ax.get_xlabel(), 172 | fontsize=14, ha='right', va='top', rotation=12) 173 | ax.text(yaxis_loc - 0.02 * xspan, yax_lim[1], ax.get_ylabel(), 174 | fontsize=14, ha='right', va='top', rotation=78) 175 | ax.set_xlabel('') 176 | ax.set_ylabel('') 177 | 178 | # Add title 179 | ax.text(0.5 * (xax_lim[1] + xax_lim[0]), yax_lim[1], 180 | ax.get_title(), 181 | ha='center', va='bottom', fontsize=16) 182 | ax.set_title('') 183 | 184 | Nlines = len(ax.lines) 185 | lines = [xaxis, yaxis] + [ax.lines.pop(0) for i in range(Nlines)] 186 | 187 | for line in lines: 188 | x, y = line.get_data() 189 | 190 | x_int, y_int = xkcd_line(x, y, xlim, ylim, 191 | mag, f1, f2, f3) 192 | 
193 | # create foreground and background line 194 | lw = line.get_linewidth() 195 | line.set_linewidth(2 * lw) 196 | line.set_data(x_int, y_int) 197 | 198 | # don't add background line for axes 199 | if (line is not xaxis) and (line is not yaxis): 200 | line_bg = pl.Line2D(x_int, y_int, color=bgcolor, 201 | linewidth=8 * lw) 202 | 203 | ax.add_line(line_bg) 204 | ax.add_line(line) 205 | 206 | # Draw arrow-heads at the end of axes lines 207 | arr1 = 0.03 * np.array([-1, 0, -1]) 208 | arr2 = 0.02 * np.array([-1, 0, 1]) 209 | 210 | arr1[::2] += np.random.normal(0, 0.005, 2) 211 | arr2[::2] += np.random.normal(0, 0.005, 2) 212 | 213 | x, y = xaxis.get_data() 214 | if '+' in str(xaxis_arrow): 215 | ax.plot(x[-1] + arr1 * xspan * aspect, 216 | y[-1] + arr2 * yspan, 217 | color='k', lw=2) 218 | if '-' in str(xaxis_arrow): 219 | ax.plot(x[0] - arr1 * xspan * aspect, 220 | y[0] - arr2 * yspan, 221 | color='k', lw=2) 222 | 223 | x, y = yaxis.get_data() 224 | if '+' in str(yaxis_arrow): 225 | ax.plot(x[-1] + arr2 * xspan * aspect, 226 | y[-1] + arr1 * yspan, 227 | color='k', lw=2) 228 | if '-' in str(yaxis_arrow): 229 | ax.plot(x[0] - arr2 * xspan * aspect, 230 | y[0] - arr1 * yspan, 231 | color='k', lw=2) 232 | 233 | # Change all the fonts to humor-sans. 
234 | prop = fm.FontProperties(fname='Humor-Sans.ttf', size=16) 235 | for text in ax.texts: 236 | text.set_fontproperties(prop) 237 | 238 | # modify legend 239 | leg = ax.get_legend() 240 | if leg is not None: 241 | leg.set_frame_on(False) 242 | 243 | for child in leg.get_children(): 244 | if isinstance(child, pl.Line2D): 245 | x, y = child.get_data() 246 | child.set_data(xkcd_line(x, y, mag=10, f1=100, f2=0.001)) 247 | child.set_linewidth(2 * child.get_linewidth()) 248 | if isinstance(child, pl.Text): 249 | child.set_fontproperties(prop) 250 | 251 | # Set the axis limits 252 | ax.set_xlim(xax_lim[0] - 0.1 * xspan, 253 | xax_lim[1] + 0.1 * xspan) 254 | ax.set_ylim(yax_lim[0] - 0.1 * yspan, 255 | yax_lim[1] + 0.1 * yspan) 256 | 257 | # adjust the axes 258 | ax.set_xticks([]) 259 | ax.set_yticks([]) 260 | 261 | if expand_axes: 262 | ax.figure.set_facecolor(bgcolor) 263 | ax.set_axis_off() 264 | ax.set_position([0, 0, 1, 1]) 265 | 266 | return ax 267 | -------------------------------------------------------------------------------- /examples/Profile.prof: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/py-earth/b209d1916f051dbea5b142af25425df2de469c5a/examples/Profile.prof -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | Gallery 2 | ------- 3 | 4 | This is the gallery of examples. 
5 | -------------------------------------------------------------------------------- /examples/plot_classifier_comp.py: -------------------------------------------------------------------------------- 1 | """ 2 | ====================================================== 3 | Plotting sckit-learn classifiers comparison with Earth 4 | ====================================================== 5 | 6 | This script recreates the scikit-learn classifier comparison example found at 7 | http://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html. 8 | It has been modified to include an Earth based classifier. 9 | """ 10 | 11 | # Code source: Gael Varoqueux 12 | # Andreas Mueller 13 | # Modified for Documentation merge by Jaques Grobler 14 | # License: BSD 3 clause 15 | # Modified to include pyearth by Jason Rudy 16 | 17 | import numpy as np 18 | import matplotlib.pyplot as plt 19 | from matplotlib.colors import ListedColormap 20 | from sklearn.cross_validation import train_test_split 21 | from sklearn.preprocessing import StandardScaler 22 | from sklearn.datasets import make_moons, make_circles, make_classification 23 | from sklearn.neighbors import KNeighborsClassifier 24 | from sklearn.svm import SVC 25 | from sklearn.tree import DecisionTreeClassifier 26 | from sklearn.ensemble import RandomForestClassifier 27 | from sklearn.naive_bayes import GaussianNB 28 | from sklearn.lda import LDA 29 | from sklearn.qda import QDA 30 | 31 | from sklearn.linear_model.logistic import LogisticRegression 32 | from sklearn.pipeline import Pipeline 33 | from pyearth.earth import Earth 34 | 35 | print(__doc__) 36 | 37 | h = .02 # step size in the mesh 38 | 39 | np.random.seed(1) 40 | 41 | # Combine Earth with LogisticRegression in a pipeline to do classification 42 | earth_classifier = Pipeline([('earth', Earth(max_degree=3, penalty=1.5)), 43 | ('logistic', LogisticRegression())]) 44 | 45 | names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", "Decision Tree", 46 | 
"Random Forest", "Naive Bayes", "LDA", "QDA", "Earth"] 47 | classifiers = [ 48 | KNeighborsClassifier(3), 49 | SVC(kernel="linear", C=0.025, probability=True), 50 | SVC(gamma=2, C=1, probability=True), 51 | DecisionTreeClassifier(max_depth=5), 52 | RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1), 53 | GaussianNB(), 54 | LDA(), 55 | QDA(), 56 | earth_classifier] 57 | 58 | X, y = make_classification(n_features=2, n_redundant=0, n_informative=2, 59 | random_state=1, n_clusters_per_class=1) 60 | rng = np.random.RandomState(2) 61 | X += 2 * rng.uniform(size=X.shape) 62 | linearly_separable = (X, y) 63 | 64 | datasets = [make_moons(noise=0.3, random_state=0), 65 | make_circles(noise=0.2, factor=0.5, random_state=1), 66 | linearly_separable 67 | ] 68 | 69 | figure = plt.figure(figsize=(27, 9)) 70 | i = 1 71 | # iterate over datasets 72 | for ds in datasets: 73 | # preprocess dataset, split into training and test part 74 | X, y = ds 75 | X = StandardScaler().fit_transform(X) 76 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.4) 77 | 78 | x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 79 | y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 80 | xx, yy = np.meshgrid(np.arange(x_min, x_max, h), 81 | np.arange(y_min, y_max, h)) 82 | 83 | # just plot the dataset first 84 | cm = plt.cm.RdBu 85 | cm_bright = ListedColormap(['#FF0000', '#0000FF']) 86 | ax = plt.subplot(len(datasets), len(classifiers) + 1, i) 87 | # Plot the training points 88 | ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright) 89 | # and testing points 90 | ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.6) 91 | ax.set_xlim(xx.min(), xx.max()) 92 | ax.set_ylim(yy.min(), yy.max()) 93 | ax.set_xticks(()) 94 | ax.set_yticks(()) 95 | i += 1 96 | 97 | # iterate over classifiers 98 | for name, clf in zip(names, classifiers): 99 | ax = plt.subplot(len(datasets), len(classifiers) + 1, i) 100 | clf.fit(X_train, y_train) 101 | 
score = clf.score(X_test, y_test) 102 | 103 | # Plot the decision boundary. For that, we will assign a color to each 104 | # point in the mesh [x_min, m_max]x[y_min, y_max]. 105 | try: 106 | Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1] 107 | except NotImplementedError: 108 | Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) 109 | 110 | # Put the result into a color plot 111 | Z = Z.reshape(xx.shape) 112 | ax.contourf(xx, yy, Z, cmap=cm, alpha=.8) 113 | 114 | # Plot also the training points 115 | ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright) 116 | # and testing points 117 | ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, 118 | alpha=0.6) 119 | 120 | ax.set_xlim(xx.min(), xx.max()) 121 | ax.set_ylim(yy.min(), yy.max()) 122 | ax.set_xticks(()) 123 | ax.set_yticks(()) 124 | ax.set_title(name) 125 | ax.text(xx.max() - .3, yy.min() + .3, ('%.2f' % score).lstrip('0'), 126 | size=15, horizontalalignment='right') 127 | i += 1 128 | 129 | figure.subplots_adjust(left=.02, right=.98) 130 | plt.savefig('classifier_comp.pdf', transparent=True) 131 | plt.show() 132 | -------------------------------------------------------------------------------- /examples/plot_derivatives.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================================ 3 | Plotting derivatives of simple sine function 4 | ============================================ 5 | 6 | A simple example plotting a fit of the sine function and 7 | the derivatives computed by Earth. 
8 | """ 9 | import numpy 10 | import matplotlib.pyplot as plt 11 | 12 | from pyearth import Earth 13 | 14 | # Create some fake data 15 | numpy.random.seed(2) 16 | m = 10000 17 | n = 10 18 | X = 20 * numpy.random.uniform(size=(m, n)) - 10 19 | y = 10*numpy.sin(X[:, 6]) + 0.25*numpy.random.normal(size=m) 20 | 21 | # Compute the known true derivative with respect to the predictive variable 22 | y_prime = 10*numpy.cos(X[:, 6]) 23 | 24 | # Fit an Earth model 25 | model = Earth(max_degree=2, minspan_alpha=.5, smooth=True) 26 | model.fit(X, y) 27 | 28 | # Print the model 29 | print(model.trace()) 30 | print(model.summary()) 31 | 32 | # Get the predicted values and derivatives 33 | y_hat = model.predict(X) 34 | y_prime_hat = model.predict_deriv(X, 'x6') 35 | 36 | # Plot true and predicted function values and derivatives 37 | # for the predictive variable 38 | plt.subplot(211) 39 | plt.plot(X[:, 6], y, 'r.') 40 | plt.plot(X[:, 6], y_hat, 'b.') 41 | plt.ylabel('function') 42 | plt.subplot(212) 43 | plt.plot(X[:, 6], y_prime, 'r.') 44 | plt.plot(X[:, 6], y_prime_hat[:, 0], 'b.') 45 | plt.ylabel('derivative') 46 | plt.show() 47 | -------------------------------------------------------------------------------- /examples/plot_feature_importance.py: -------------------------------------------------------------------------------- 1 | """ 2 | =========================== 3 | Plotting feature importance 4 | =========================== 5 | 6 | A simple example showing how to compute and display 7 | feature importances, it is also compared with the 8 | feature importances obtained using random forests. 9 | 10 | Feature importance is a measure of the effect of the features 11 | on the outputs. For each feature, the values go from 12 | 0 to 1 where a higher the value means that the feature will have 13 | a higher effect on the outputs. 14 | 15 | Currently three criteria are supported : 'gcv', 'rss' and 'nb_subsets'. 16 | See [1], section 12.3 for more information about the criteria. 
17 | 18 | .. [1] http://www.milbo.org/doc/earth-notes.pdf 19 | 20 | """ 21 | import numpy 22 | import matplotlib.pyplot as plt 23 | 24 | from sklearn.ensemble import RandomForestRegressor 25 | from pyearth import Earth 26 | 27 | # Create some fake data 28 | numpy.random.seed(2) 29 | m = 10000 30 | n = 10 31 | 32 | X = numpy.random.uniform(size=(m, n)) 33 | y = (10 * numpy.sin(numpy.pi * X[:, 0] * X[:, 1]) + 34 | 20 * (X[:, 2] - 0.5) ** 2 + 35 | 10 * X[:, 3] + 36 | 5 * X[:, 4] + numpy.random.uniform(size=m)) 37 | # Fit an Earth model 38 | criteria = ('rss', 'gcv', 'nb_subsets') 39 | model = Earth(max_degree=3, 40 | max_terms=10, 41 | minspan_alpha=.5, 42 | feature_importance_type=criteria, 43 | verbose=True) 44 | model.fit(X, y) 45 | rf = RandomForestRegressor() 46 | rf.fit(X, y) 47 | # Print the model 48 | print(model.trace()) 49 | print(model.summary()) 50 | print(model.summary_feature_importances(sort_by='gcv')) 51 | 52 | # Plot the feature importances 53 | importances = model.feature_importances_ 54 | importances['random_forest'] = rf.feature_importances_ 55 | criteria = criteria + ('random_forest',) 56 | idx = 1 57 | 58 | fig = plt.figure(figsize=(20, 10)) 59 | labels = ['$x_{}$'.format(i) for i in range(n)] 60 | for crit in criteria: 61 | plt.subplot(2, 2, idx) 62 | plt.bar(numpy.arange(len(labels)), 63 | importances[crit], 64 | align='center', 65 | color='red') 66 | plt.xticks(numpy.arange(len(labels)), labels) 67 | plt.title(crit) 68 | plt.ylabel('importances') 69 | idx += 1 70 | title = '$x_0,...x_9 \sim \mathcal{N}(0, 1)$\n$y= 10sin(\pi x_{0}x_{1}) + 20(x_2 - 0.5)^2 + 10x_3 + 5x_4 + Unif(0, 1)$' 71 | fig.suptitle(title, fontsize="x-large") 72 | plt.show() 73 | -------------------------------------------------------------------------------- /examples/plot_linear_function.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ====================================================== 3 | Plotting a linear function with a 
categorical variable 4 | ====================================================== 5 | 6 | Fitting a pyearth model to a linear function shows that pyearth 7 | will automatically choose a linear basis function in some cases. 8 | ''' 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | from pyearth import Earth 12 | 13 | np.random.seed(1) 14 | m = 1000 15 | n = 5 16 | 17 | X = np.random.normal(size=(m, n)) 18 | 19 | # Make X[:,1] binary 20 | X[:, 1] = np.random.binomial(1, .5, size=m) 21 | 22 | # The response is a linear function of the inputs 23 | y = 2 * X[:, 0] + 3 * X[:, 1] + np.random.normal(size=m) 24 | 25 | # Fit the earth model 26 | model = Earth().fit(X, y) 27 | 28 | # Print the model summary, showing linear terms 29 | print(model.summary()) 30 | 31 | # Plot for both values of X[:,1] 32 | y_hat = model.predict(X) 33 | plt.figure() 34 | plt.plot(X[:, 0], y, 'k.') 35 | plt.plot(X[X[:, 1] == 0, 0], y_hat[X[:, 1] == 0], 'r.', label='$x_1 = 0$') 36 | plt.plot(X[X[:, 1] == 1, 0], y_hat[X[:, 1] == 1], 'b.', label='$x_1 = 1$') 37 | plt.legend(loc='best') 38 | plt.xlabel('$x_0$') 39 | plt.show() 40 | -------------------------------------------------------------------------------- /examples/plot_missing_data_problem.py: -------------------------------------------------------------------------------- 1 | """ 2 | ================================================================ 3 | Plotting sine function with redundant predictors an missing data 4 | ================================================================ 5 | 6 | An example plotting a fit of the sine function. There are two 7 | redundant predictors, each of which has independent and random 8 | missingness. 
9 | """ 10 | import numpy 11 | import matplotlib.pyplot as plt 12 | 13 | from pyearth import Earth 14 | 15 | # Create some fake data 16 | numpy.random.seed(2) 17 | m = 10000 18 | n = 10 19 | X = 80 * numpy.random.uniform(size=(m, n)) - 40 20 | X[:, 5] = X[:, 6] + numpy.random.normal(0, .1, m) 21 | y = 100 * \ 22 | (numpy.sin((X[:, 5] + X[:, 6]) / 20) - 4.0) + \ 23 | 10 * numpy.random.normal(size=m) 24 | missing = numpy.random.binomial(1, .2, (m, n)).astype(bool) 25 | X_full = X.copy() 26 | X[missing] = None 27 | idx5 = (1 - missing[:, 5]).astype(bool) 28 | idx6 = (1 - missing[:, 6]).astype(bool) 29 | 30 | # Fit an Earth model 31 | model = Earth(max_degree=5, minspan_alpha=.5, allow_missing=True, 32 | enable_pruning=True, thresh=.001, smooth=True, verbose=2) 33 | model.fit(X, y) 34 | # Print the model 35 | print(model.summary()) 36 | 37 | # Plot the model 38 | y_hat = model.predict(X) 39 | fig = plt.figure() 40 | 41 | ax1 = fig.add_subplot(3, 2, 1) 42 | ax1.plot(X_full[idx5, 5], y[idx5], 'b.') 43 | ax1.plot(X_full[idx5, 5], y_hat[idx5], 'r.') 44 | ax1.set_xlim(-40, 40) 45 | ax1.set_title('x5 present') 46 | ax1.set_xlabel('x5') 47 | 48 | ax2 = fig.add_subplot(3, 2, 2) 49 | ax2.plot(X_full[idx6, 6], y[idx6], 'b.') 50 | ax2.plot(X_full[idx6, 6], y_hat[idx6], 'r.') 51 | ax2.set_xlim(-40, 40) 52 | ax2.set_title('x6 present') 53 | ax2.set_xlabel('x6') 54 | 55 | ax3 = fig.add_subplot(3, 2, 3, sharex=ax1) 56 | ax3.plot(X_full[~idx6, 5], y[~idx6], 'b.') 57 | ax3.plot(X_full[~idx6, 5], y_hat[~idx6], 'r.') 58 | ax3.set_title('x6 missing') 59 | ax3.set_xlabel('x5') 60 | 61 | ax4 = fig.add_subplot(3, 2, 4, sharex=ax2) 62 | ax4.plot(X_full[~idx5, 6], y[~idx5], 'b.') 63 | ax4.plot(X_full[~idx5, 6], y_hat[~idx5], 'r.') 64 | ax4.set_title('x5 missing') 65 | ax4.set_xlabel('x6') 66 | 67 | ax5 = fig.add_subplot(3, 2, 5, sharex=ax1) 68 | ax5.plot(X_full[(~idx6) & (~idx5), 5], y[(~idx6) & (~idx5)], 'b.') 69 | ax5.plot(X_full[(~idx6) & (~idx5), 5], y_hat[(~idx6) & (~idx5)], 'r.') 70 | 
ax5.set_title('both missing') 71 | ax5.set_xlabel('x5') 72 | 73 | ax6 = fig.add_subplot(3, 2, 6, sharex=ax2) 74 | ax6.plot(X_full[(~idx6) & (~idx5), 6], y[(~idx6) & (~idx5)], 'b.') 75 | ax6.plot(X_full[(~idx6) & (~idx5), 6], y_hat[(~idx6) & (~idx5)], 'r.') 76 | ax6.set_title('both missing') 77 | ax6.set_xlabel('x6') 78 | 79 | fig.tight_layout() 80 | plt.show() 81 | -------------------------------------------------------------------------------- /examples/plot_multicolumn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | =================================================================== 3 | Plotting a multicolumn regression problem that includes missingness 4 | =================================================================== 5 | 6 | An example plotting a simultaneous fit of the sine and cosine functions. 7 | There are two redundant predictors, each of which has independent and random 8 | missingness. 9 | ''' 10 | 11 | import numpy 12 | import matplotlib.pyplot as plt 13 | 14 | from pyearth import Earth 15 | 16 | # Create some fake data 17 | numpy.random.seed(2) 18 | m = 10000 19 | n = 10 20 | X = 80 * numpy.random.uniform(size=(m, n)) - 40 21 | X[:, 5] = X[:, 6] + numpy.random.normal(0, .1, m) 22 | y1 = 100 * \ 23 | (numpy.sin((X[:, 5] + X[:, 6]) / 20) - 4.0) + \ 24 | 10 * numpy.random.normal(size=m) 25 | y2 = 100 * \ 26 | (numpy.cos((X[:, 5] + X[:, 6]) / 20) - 4.0) + \ 27 | 10 * numpy.random.normal(size=m) 28 | y = numpy.concatenate([y1[:, None], y2[:, None]], axis=1) 29 | missing = numpy.random.binomial(1, .2, (m, n)).astype(bool) 30 | X_full = X.copy() 31 | X[missing] = None 32 | idx5 = (1 - missing[:, 5]).astype(bool) 33 | idx6 = (1 - missing[:, 6]).astype(bool) 34 | 35 | # Fit an Earth model 36 | model = Earth(max_degree=5, minspan_alpha=.5, allow_missing=True, 37 | enable_pruning=True, thresh=.001, smooth=True, 38 | verbose=True) 39 | model.fit(X, y) 40 | 41 | # Print the model 42 | print(model.trace()) 43 | 
print(model.summary()) 44 | 45 | # Plot the model 46 | y_hat = model.predict(X) 47 | fig = plt.figure() 48 | 49 | for j in [0, 1]: 50 | ax1 = fig.add_subplot(3, 4, 1 + 2*j) 51 | ax1.plot(X_full[idx5, 5], y[idx5, j], 'b.') 52 | ax1.plot(X_full[idx5, 5], y_hat[idx5, j], 'r.') 53 | ax1.set_xlim(-40, 40) 54 | ax1.set_title('x5 present') 55 | ax1.set_xlabel('x5') 56 | ax1.set_ylabel('sin' if j == 0 else 'cos') 57 | 58 | ax2 = fig.add_subplot(3, 4, 2 + 2*j) 59 | ax2.plot(X_full[idx6, 6], y[idx6, j], 'b.') 60 | ax2.plot(X_full[idx6, 6], y_hat[idx6, j], 'r.') 61 | ax2.set_xlim(-40, 40) 62 | ax2.set_title('x6 present') 63 | ax2.set_xlabel('x6') 64 | ax2.set_ylabel('sin' if j == 0 else 'cos') 65 | 66 | ax3 = fig.add_subplot(3, 4, 5 + 2*j, sharex=ax1) 67 | ax3.plot(X_full[~idx6, 5], y[~idx6, j], 'b.') 68 | ax3.plot(X_full[~idx6, 5], y_hat[~idx6, j], 'r.') 69 | ax3.set_title('x6 missing') 70 | ax3.set_xlabel('x5') 71 | ax3.set_ylabel('sin' if j == 0 else 'cos') 72 | 73 | ax4 = fig.add_subplot(3, 4, 6 + 2*j, sharex=ax2) 74 | ax4.plot(X_full[~idx5, 6], y[~idx5, j], 'b.') 75 | ax4.plot(X_full[~idx5, 6], y_hat[~idx5, j], 'r.') 76 | ax4.set_title('x5 missing') 77 | ax4.set_xlabel('x6') 78 | ax4.set_ylabel('sin' if j == 0 else 'cos') 79 | 80 | ax5 = fig.add_subplot(3, 4, 9 + 2*j, sharex=ax1) 81 | ax5.plot(X_full[(~idx6) & (~idx5), 5], y[(~idx6) & (~idx5), j], 'b.') 82 | ax5.plot(X_full[(~idx6) & (~idx5), 5], y_hat[(~idx6) & (~idx5), j], 'r.') 83 | ax5.set_title('both missing') 84 | ax5.set_xlabel('x5') 85 | ax5.set_ylabel('sin' if j == 0 else 'cos') 86 | 87 | ax6 = fig.add_subplot(3, 4, 10 + 2*j, sharex=ax2) 88 | ax6.plot(X_full[(~idx6) & (~idx5), 6], y[(~idx6) & (~idx5), j], 'b.') 89 | ax6.plot(X_full[(~idx6) & (~idx5), 6], y_hat[(~idx6) & (~idx5), j], 'r.') 90 | ax6.set_title('both missing') 91 | ax6.set_xlabel('x6') 92 | ax6.set_ylabel('sin' if j == 0 else 'cos') 93 | 94 | fig.tight_layout() 95 | plt.show() 96 | 
-------------------------------------------------------------------------------- /examples/plot_output_weight.py: -------------------------------------------------------------------------------- 1 | """ 2 | ================================================================= 3 | Demonstrating a use of weights in outputs with two sine functions 4 | ================================================================= 5 | 6 | Each row in the grid is a run of an earth model. 7 | Each column is an output. 8 | In each run, different weights are given to 9 | the outputs. 10 | 11 | """ 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | from pyearth import Earth 15 | 16 | # Create some fake data 17 | np.random.seed(2) 18 | m = 10000 19 | n = 10 20 | X = 80 * np.random.uniform(size=(m, n)) - 40 21 | y1 = 120 * np.abs(np.sin((X[:, 6]) / 6) - 1.0) + 15 * np.random.normal(size=m) 22 | y2 = 120 * np.abs(np.sin((X[:, 5]) / 6) - 1.0) + 15 * np.random.normal(size=m) 23 | 24 | y1 = (y1 - y1.mean()) / y1.std() 25 | y2 = (y2 - y2.mean()) / y2.std() 26 | y_mix = np.concatenate((y1[:, np.newaxis], y2[:, np.newaxis]), axis=1) 27 | 28 | alphas = [0.9, 0.8, 0.6, 0.4, 0.2, 0.1] 29 | n_plots = len(alphas) 30 | k = 1 31 | fig = plt.figure(figsize=(10, 15)) 32 | for i, alpha in enumerate(alphas): 33 | # Fit an Earth model 34 | model = Earth(max_degree=5, 35 | minspan_alpha=.05, 36 | endspan_alpha=.05, 37 | max_terms=10, 38 | check_every=1, 39 | thresh=0.) 
40 | output_weight = np.array([alpha, 1 - alpha]) 41 | model.fit(X, y_mix, output_weight=output_weight) 42 | print(model.summary()) 43 | 44 | # Plot the model 45 | y_hat = model.predict(X) 46 | 47 | mse = ((y_hat - y_mix) ** 2).mean(axis=0) 48 | ax = plt.subplot(n_plots, 2, k) 49 | ax.set_ylabel("Run {0}".format(i + 1), rotation=0, labelpad=20) 50 | plt.plot(X[:, 6], y_mix[:, 0], 'r.') 51 | plt.plot(X[:, 6], model.predict(X)[:, 0], 'b.') 52 | plt.title("MSE: {0:.3f}, Weight : {1:.1f}".format(mse[0], alpha)) 53 | plt.subplot(n_plots, 2, k + 1) 54 | plt.plot(X[:, 5], y_mix[:, 1], 'r.') 55 | plt.plot(X[:, 5], model.predict(X)[:, 1], 'b.') 56 | plt.title("MSE: {0:.3f}, Weight : {1:.1f}".format(mse[1], 1 - alpha)) 57 | k += 2 58 | plt.tight_layout() 59 | plt.show() 60 | -------------------------------------------------------------------------------- /examples/plot_sine_wave.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================= 3 | Plotting simple sine function 4 | ============================= 5 | 6 | A simple example plotting a fit of the sine function. 
7 | """ 8 | import numpy 9 | import matplotlib.pyplot as plt 10 | 11 | from pyearth import Earth 12 | 13 | # Create some fake data 14 | numpy.random.seed(2) 15 | m = 10000 16 | n = 10 17 | X = 80 * numpy.random.uniform(size=(m, n)) - 40 18 | y = 100 * \ 19 | (numpy.sin((X[:, 6])) - 4.0) + \ 20 | 10 * numpy.random.normal(size=m) 21 | 22 | # Fit an Earth model 23 | model = Earth(max_degree=3, minspan_alpha=.5, verbose=True) 24 | model.fit(X, y) 25 | 26 | # Print the model 27 | print(model.trace()) 28 | print(model.summary()) 29 | 30 | # Plot the model 31 | y_hat = model.predict(X) 32 | plt.plot(X[:, 6], y, 'r.') 33 | plt.plot(X[:, 6], y_hat, 'b.') 34 | plt.show() 35 | -------------------------------------------------------------------------------- /examples/plot_sine_wave_2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | ================================== 3 | Plotting two simple sine functions 4 | ================================== 5 | 6 | A simple example plotting a fit of two sine functions. 
7 | """ 8 | import numpy 9 | import matplotlib.pyplot as plt 10 | 11 | from pyearth import Earth 12 | 13 | # Create some fake data 14 | numpy.random.seed(2) 15 | m = 10000 16 | n = 10 17 | X = 80 * numpy.random.uniform(size=(m, n)) - 40 18 | y1 = 100 * \ 19 | numpy.abs(numpy.sin((X[:, 6]) / 10) - 4.0) + \ 20 | 10 * numpy.random.normal(size=m) 21 | 22 | y2 = 100 * \ 23 | numpy.abs(numpy.sin((X[:, 6]) / 2) - 8.0) + \ 24 | 5 * numpy.random.normal(size=m) 25 | 26 | # Fit an Earth model 27 | model = Earth(max_degree=3, minspan_alpha=.5) 28 | y_mix = numpy.concatenate((y1[:, numpy.newaxis], y2[:, numpy.newaxis]), axis=1) 29 | model.fit(X, y_mix) 30 | 31 | # Print the model 32 | print(model.trace()) 33 | print(model.summary()) 34 | 35 | # Plot the model 36 | y_hat = model.predict(X) 37 | 38 | fig = plt.figure() 39 | 40 | ax = fig.add_subplot(1, 2, 1) 41 | ax.plot(X[:, 6], y_mix[:, 0], 'r.') 42 | ax.plot(X[:, 6], model.predict(X)[:, 0], 'b.') 43 | 44 | ax = fig.add_subplot(1, 2, 2) 45 | ax.plot(X[:, 6], y_mix[:, 1], 'r.') 46 | ax.plot(X[:, 6], model.predict(X)[:, 1], 'b.') 47 | plt.show() 48 | -------------------------------------------------------------------------------- /examples/plot_v_function.py: -------------------------------------------------------------------------------- 1 | """ 2 | ==================================== 3 | Plotting the absolute value function 4 | ==================================== 5 | 6 | A simple example plotting a fit of the absolute value function. 
7 | 8 | """ 9 | 10 | import numpy 11 | import matplotlib.pyplot as plt 12 | from pyearth import Earth 13 | 14 | # Create some fake data 15 | numpy.random.seed(2) 16 | m = 1000 17 | n = 10 18 | X = 80 * numpy.random.uniform(size=(m, n)) - 40 19 | y = numpy.abs(X[:, 6] - 4.0) + 1 * numpy.random.normal(size=m) 20 | 21 | # Fit an Earth model 22 | model = Earth(max_degree=1, verbose=True) 23 | model.fit(X, y) 24 | 25 | # Print the model 26 | print(model.trace()) 27 | print(model.summary()) 28 | 29 | # Plot the model 30 | y_hat = model.predict(X) 31 | plt.figure() 32 | plt.plot(X[:, 6], y, 'r.') 33 | plt.plot(X[:, 6], y_hat, 'b.') 34 | plt.show() 35 | -------------------------------------------------------------------------------- /examples/return_sympy.py: -------------------------------------------------------------------------------- 1 | """ 2 | ===================================================== 3 | Exporting a fitted Earth models as a sympy expression 4 | ===================================================== 5 | 6 | A simple example returning a sympy expression describing the fit of a sine function computed by Earth. 
7 | 8 | """ 9 | 10 | import numpy 11 | from pyearth import Earth 12 | from pyearth import export 13 | 14 | # Create some fake data 15 | numpy.random.seed(2) 16 | m = 1000 17 | n = 10 18 | X = 10 * numpy.random.uniform(size=(m, n)) - 40 19 | y = 100 * \ 20 | (numpy.sin((X[:, 6])) - 4.0) + \ 21 | 10 * numpy.random.normal(size=m) 22 | 23 | # Fit an Earth model 24 | model = Earth(max_degree=2, minspan_alpha=.5, verbose=False) 25 | model.fit(X, y) 26 | 27 | print(model.summary()) 28 | 29 | #return sympy expression 30 | print('Resulting sympy expression:') 31 | print(export.export_sympy(model)) 32 | 33 | -------------------------------------------------------------------------------- /pyearth/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Feb 16, 2013 3 | 4 | @author: jasonrudy 5 | ''' 6 | from .earth import Earth 7 | 8 | from ._version import get_versions 9 | __version__ = get_versions()['version'] 10 | del get_versions 11 | -------------------------------------------------------------------------------- /pyearth/_basis.pxd: -------------------------------------------------------------------------------- 1 | from cpython cimport bool 2 | cimport numpy as cnp 3 | from _types cimport FLOAT_t, INT_t, INDEX_t, BOOL_t 4 | 5 | cdef class BasisFunction: 6 | '''Abstract. 
Subclasses must implement the apply and __init__ methods.''' 7 | 8 | cdef BasisFunction parent 9 | cdef dict child_map 10 | cdef list children 11 | cdef bint pruned 12 | cdef bint prunable 13 | cdef bint splittable 14 | 15 | cpdef smooth(BasisFunction self, dict knot_dict, dict translation) 16 | 17 | cpdef bint has_knot(BasisFunction self) 18 | 19 | cpdef bint is_prunable(BasisFunction self) 20 | 21 | cpdef bint is_pruned(BasisFunction self) 22 | 23 | cpdef bint is_splittable(BasisFunction self) 24 | 25 | cpdef bint make_splittable(BasisFunction self) 26 | 27 | cpdef bint make_unsplittable(BasisFunction self) 28 | 29 | cpdef list get_children(BasisFunction self) 30 | 31 | cpdef BasisFunction get_coverage(BasisFunction self, INDEX_t variable) 32 | 33 | cpdef bool has_linear(BasisFunction self, INDEX_t variable) 34 | 35 | cpdef bool linear_in(BasisFunction self, INDEX_t variable) 36 | 37 | cpdef _set_parent(BasisFunction self, BasisFunction parent) 38 | 39 | cpdef _add_child(BasisFunction self, BasisFunction child) 40 | 41 | cpdef BasisFunction get_parent(BasisFunction self) 42 | 43 | cpdef prune(BasisFunction self) 44 | 45 | cpdef unprune(BasisFunction self) 46 | 47 | cpdef knots(BasisFunction self, INDEX_t variable) 48 | 49 | cpdef INDEX_t effective_degree(BasisFunction self) 50 | 51 | cpdef apply(BasisFunction self, cnp.ndarray[FLOAT_t, ndim=2] X, 52 | cnp.ndarray[BOOL_t, ndim=2] missing, 53 | cnp.ndarray[FLOAT_t, ndim=1] b, bint recurse= ?) 
54 | 55 | cpdef cnp.ndarray[INT_t, ndim = 1] valid_knots(BasisFunction self, 56 | cnp.ndarray[FLOAT_t, ndim=1] values, 57 | cnp.ndarray[FLOAT_t, ndim=1] variable, 58 | int variable_idx, INDEX_t check_every, 59 | int endspan, int minspan, 60 | FLOAT_t minspan_alpha, INDEX_t n, 61 | cnp.ndarray[INT_t, ndim=1] workspace) 62 | 63 | cdef class RootBasisFunction(BasisFunction): 64 | cpdef bint covered(RootBasisFunction self, INDEX_t variable) 65 | 66 | cpdef bint eligible(RootBasisFunction self, INDEX_t variable) 67 | 68 | cpdef set variables(RootBasisFunction self) 69 | 70 | cpdef _smoothed_version(RootBasisFunction self, BasisFunction parent, 71 | dict knot_dict, dict translation) 72 | 73 | cpdef INDEX_t degree(RootBasisFunction self) 74 | 75 | cpdef _effective_degree(RootBasisFunction self, dict data_dict, dict missing_dict) 76 | 77 | cpdef _set_parent(RootBasisFunction self, BasisFunction parent) 78 | 79 | cpdef BasisFunction get_parent(RootBasisFunction self) 80 | 81 | cpdef apply(RootBasisFunction self, cnp.ndarray[FLOAT_t, ndim=2] X, 82 | cnp.ndarray[BOOL_t, ndim=2] missing, 83 | cnp.ndarray[FLOAT_t, ndim=1] b, bint recurse=?) 
84 | 85 | cpdef apply_deriv(RootBasisFunction self, cnp.ndarray[FLOAT_t, ndim=2] X, 86 | cnp.ndarray[BOOL_t, ndim=2] missing, 87 | cnp.ndarray[FLOAT_t, ndim=1] b, 88 | cnp.ndarray[FLOAT_t, ndim=1] j, INDEX_t var) 89 | # Declarations only: each cpdef signature below must match its definition in _basis.pyx. 90 | cdef class ConstantBasisFunction(RootBasisFunction): 91 | 92 | cpdef inline FLOAT_t eval(ConstantBasisFunction self) 93 | 94 | cpdef inline FLOAT_t eval_deriv(ConstantBasisFunction self) 95 | 96 | cdef class VariableBasisFunction(BasisFunction): 97 | cdef INDEX_t variable 98 | cdef readonly label 99 | 100 | cpdef INDEX_t degree(VariableBasisFunction self) 101 | 102 | cpdef set variables(VariableBasisFunction self) 103 | 104 | cpdef INDEX_t get_variable(VariableBasisFunction self) 105 | 106 | cdef class DataVariableBasisFunction(VariableBasisFunction): 107 | cpdef _effective_degree(DataVariableBasisFunction self, dict data_dict, dict missing_dict) 108 | 109 | cpdef bint covered(DataVariableBasisFunction self, INDEX_t variable) 110 | 111 | cpdef bint eligible(DataVariableBasisFunction self, INDEX_t variable) 112 | 113 | cpdef apply(DataVariableBasisFunction self, cnp.ndarray[FLOAT_t, ndim=2] X, 114 | cnp.ndarray[BOOL_t, ndim=2] missing, 115 | cnp.ndarray[FLOAT_t, ndim=1] b, bint recurse=?)
116 | 117 | cpdef apply_deriv(DataVariableBasisFunction self, 118 | cnp.ndarray[FLOAT_t, ndim=2] X, 119 | cnp.ndarray[BOOL_t, ndim=2] missing, 120 | cnp.ndarray[FLOAT_t, ndim=1] b, 121 | cnp.ndarray[FLOAT_t, ndim=1] j, INDEX_t var) 122 | 123 | cdef class MissingnessBasisFunction(VariableBasisFunction): 124 | cdef readonly bint complement 125 | 126 | cpdef _effective_degree(MissingnessBasisFunction self, dict data_dict, dict missing_dict) 127 | 128 | cpdef bint covered(MissingnessBasisFunction self, INDEX_t variable) 129 | 130 | cpdef bint eligible(MissingnessBasisFunction self, INDEX_t variable) 131 | 132 | 133 | 134 | 135 | 136 | cpdef apply(MissingnessBasisFunction self, cnp.ndarray[FLOAT_t, ndim=2] X, 137 | cnp.ndarray[BOOL_t, ndim=2] missing, 138 | cnp.ndarray[FLOAT_t, ndim=1] b, bint recurse=?) 139 | 140 | cpdef apply_deriv(MissingnessBasisFunction self, 141 | cnp.ndarray[FLOAT_t, ndim=2] X, 142 | cnp.ndarray[BOOL_t, ndim=2] missing, 143 | cnp.ndarray[FLOAT_t, ndim=1] b, 144 | cnp.ndarray[FLOAT_t, ndim=1] j, INDEX_t var) 145 | 146 | cpdef _smoothed_version(MissingnessBasisFunction self, BasisFunction parent, 147 | dict knot_dict, dict translation) 148 | 149 | cdef class HingeBasisFunctionBase(DataVariableBasisFunction): 150 | cdef FLOAT_t knot 151 | cdef INDEX_t knot_idx 152 | cdef bint reverse 153 | 154 | cpdef bint has_knot(HingeBasisFunctionBase self) 155 | 156 | cpdef INDEX_t get_variable(HingeBasisFunctionBase self) 157 | 158 | cpdef FLOAT_t get_knot(HingeBasisFunctionBase self) 159 | 160 | cpdef bint get_reverse(HingeBasisFunctionBase self) 161 | 162 | cpdef INDEX_t get_knot_idx(HingeBasisFunctionBase self) 163 | 164 | cdef class SmoothedHingeBasisFunction(HingeBasisFunctionBase): 165 | cdef FLOAT_t p 166 | cdef FLOAT_t r 167 | cdef FLOAT_t knot_minus 168 | cdef FLOAT_t knot_plus 169 | 170 | cpdef
_smoothed_version(SmoothedHingeBasisFunction self, 171 | BasisFunction parent, dict knot_dict, 172 | dict translation) 173 | 174 | cpdef get_knot_minus(SmoothedHingeBasisFunction self) 175 | 176 | cpdef get_knot_plus(SmoothedHingeBasisFunction self) 177 | 178 | cpdef _init_p_r(SmoothedHingeBasisFunction self) 179 | 180 | cpdef get_p(SmoothedHingeBasisFunction self) 181 | 182 | cpdef get_r(SmoothedHingeBasisFunction self) 183 | 184 | cdef class HingeBasisFunction(HingeBasisFunctionBase): 185 | 186 | cpdef _smoothed_version(HingeBasisFunction self, 187 | BasisFunction parent, 188 | dict knot_dict, dict translation) 189 | 190 | cdef class LinearBasisFunction(DataVariableBasisFunction): 191 | cpdef bool linear_in(LinearBasisFunction self, INDEX_t variable) 192 | 193 | cpdef _smoothed_version(LinearBasisFunction self, BasisFunction parent, 194 | dict knot_dict, dict translation) 195 | 196 | cdef class Basis: 197 | '''A wrapper that provides functionality related to a set of BasisFunctions 198 | with a common RootBasisFunction ancestor.
Retains the order in which 199 | BasisFunctions are added.''' 200 | 201 | cdef list order 202 | cdef readonly INDEX_t num_variables 203 | # cdef dict coverage 204 | 205 | # cpdef add_coverage(Basis self, int variable, MissingnessBasisFunction b1, \ 206 | # MissingnessBasisFunction b2) 207 | # 208 | # cpdef get_coverage(Basis self, int variable) 209 | # 210 | # cpdef bint has_coverage(Basis self, int variable) 211 | 212 | cpdef int get_num_variables(Basis self) 213 | 214 | cpdef dict anova_decomp(Basis self) 215 | 216 | cpdef smooth(Basis self, cnp.ndarray[FLOAT_t, ndim=2] X) 217 | 218 | cpdef append(Basis self, BasisFunction basis_function) 219 | 220 | cpdef INDEX_t plen(Basis self) 221 | 222 | cpdef BasisFunction get(Basis self, INDEX_t i) 223 | 224 | cpdef transform(Basis self, cnp.ndarray[FLOAT_t, ndim=2] X, 225 | cnp.ndarray[BOOL_t, ndim=2] missing, 226 | cnp.ndarray[FLOAT_t, ndim=2] B) 227 | 228 | cpdef weighted_transform(Basis self, cnp.ndarray[FLOAT_t, ndim=2] X, 229 | cnp.ndarray[BOOL_t, ndim=2] missing, 230 | cnp.ndarray[FLOAT_t, ndim=2] B, 231 | cnp.ndarray[FLOAT_t, ndim=1] weights) 232 | 233 | cpdef transform_deriv(Basis self, cnp.ndarray[FLOAT_t, ndim=2] X, 234 | cnp.ndarray[BOOL_t, ndim=2] missing, 235 | cnp.ndarray[FLOAT_t, ndim=1] b, 236 | cnp.ndarray[FLOAT_t, ndim=1] j, 237 | cnp.ndarray[FLOAT_t, ndim=2] coef, 238 | cnp.ndarray[FLOAT_t, ndim=3] J, 239 | list variables_of_interest, bool prezeroed_j=?)
240 | -------------------------------------------------------------------------------- /pyearth/_forward.pxd: -------------------------------------------------------------------------------- 1 | cimport numpy as cnp 2 | import numpy as np # NOTE(review): nothing declared in this .pxd uses np; presumably removable - confirm. 3 | from _types cimport FLOAT_t, INT_t, INDEX_t, BOOL_t 4 | from _basis cimport Basis 5 | from _record cimport ForwardPassRecord 6 | from _knot_search cimport MultipleOutcomeDependentData 7 | 8 | # cdef dict stopping_conditions 9 | # C-level attributes and method declarations of ForwardPasser; the methods are implemented in _forward.pyx. 10 | cdef class ForwardPasser: 11 | 12 | # User selected parameters 13 | cdef int endspan 14 | cdef int minspan 15 | cdef FLOAT_t endspan_alpha 16 | cdef FLOAT_t minspan_alpha 17 | cdef int max_terms 18 | cdef bint allow_linear 19 | cdef int max_degree 20 | cdef FLOAT_t thresh 21 | cdef FLOAT_t penalty 22 | cdef int check_every 23 | cdef int min_search_points 24 | cdef list xlabels 25 | cdef FLOAT_t zero_tol 26 | cdef list fast_heap 27 | cdef int use_fast 28 | cdef long fast_K 29 | cdef long fast_h 30 | cdef bint allow_missing 31 | cdef int verbose 32 | 33 | # Input data 34 | cdef cnp.ndarray X 35 | cdef cnp.ndarray missing 36 | cdef cnp.ndarray y 37 | cdef cnp.ndarray y_col_sum 38 | cdef cnp.ndarray y_row_sum 39 | cdef cnp.ndarray sample_weight 40 | cdef cnp.ndarray output_weight 41 | cdef INDEX_t m 42 | cdef INDEX_t n 43 | cdef FLOAT_t sst 44 | cdef FLOAT_t y_squared 45 | cdef FLOAT_t total_weight 46 | 47 | # Knot search data 48 | cdef MultipleOutcomeDependentData outcome 49 | cdef list predictors 50 | cdef list workings 51 | cdef INDEX_t n_outcomes 52 | 53 | # Working floating point data 54 | cdef cnp.ndarray B # Data matrix in basis space 55 | cdef cnp.ndarray B_orth # Orthogonalized version of B 56 | cdef cnp.ndarray c 57 | cdef cnp.ndarray c_sqr 58 | cdef cnp.ndarray norms 59 | cdef cnp.ndarray u 60 | cdef cnp.ndarray B_orth_times_parent_cum 61 | cdef FLOAT_t c_squared 62 | 63 | # Working integer data 64 | cdef cnp.ndarray sort_tracker 65 | cdef cnp.ndarray sorting 66 | cdef cnp.ndarray mwork 67 |
cdef cnp.ndarray linear_variables 68 | cdef int iteration_number 69 | cdef cnp.ndarray has_missing 70 | 71 | # Object construction 72 | cdef ForwardPassRecord record 73 | cdef Basis basis 74 | 75 | cpdef Basis get_basis(ForwardPasser self) 76 | 77 | cpdef init_linear_variables(ForwardPasser self) 78 | 79 | cpdef run(ForwardPasser self) 80 | 81 | cdef stop_check(ForwardPasser self) 82 | 83 | cpdef orthonormal_update(ForwardPasser self, b) 84 | 85 | cpdef orthonormal_downdate(ForwardPasser self) 86 | 87 | cdef next_pair(ForwardPasser self) 88 | # The best_knot declaration below is commented out and is not part of the interface. 89 | # cdef best_knot(ForwardPasser self, INDEX_t parent, cnp.ndarray[FLOAT_t, ndim=1] x, 90 | # INDEX_t k, cnp.ndarray[INT_t, ndim=1] candidates, 91 | # cnp.ndarray[INT_t, ndim=1] order, 92 | # FLOAT_t * mse, FLOAT_t * knot, 93 | # INDEX_t * knot_idx) 94 | -------------------------------------------------------------------------------- /pyearth/_forward.pyxdep: -------------------------------------------------------------------------------- 1 | _basis 2 | _util 3 | _record 4 | _types -------------------------------------------------------------------------------- /pyearth/_knot_search.pxd: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from _types cimport FLOAT_t, INT_t, INDEX_t, BOOL_t 3 | from _basis cimport BasisFunction 4 | from _qr cimport UpdatingQT 5 | # @cython.final: the extension types below cannot be subclassed, so Cython may call their cpdef methods without virtual dispatch. 6 | @cython.final 7 | cdef class SingleWeightDependentData: 8 | cdef readonly UpdatingQT updating_qt 9 | cdef readonly FLOAT_t[:] w 10 | cdef readonly INDEX_t m 11 | cdef readonly INDEX_t k 12 | cdef readonly INDEX_t max_terms 13 | cdef readonly FLOAT_t[:, :] Q_t 14 | cdef readonly FLOAT_t total_weight 15 | # cpdef int update_from_basis_function(SingleWeightDependentData self, BasisFunction bf, FLOAT_t[:,:] X, 16 | # BOOL_t[:,:] missing) except * 17 | cpdef int update_from_array(SingleWeightDependentData self, FLOAT_t[:] b) except * 18 | # cpdef int _update(SingleWeightDependentData self, FLOAT_t zero_tol) 19 |
cpdef downdate(SingleWeightDependentData self) 20 | cpdef reweight(SingleWeightDependentData self, FLOAT_t[:] w, FLOAT_t[:,:] B, INDEX_t k) 21 | 22 | @cython.final 23 | cdef class MultipleOutcomeDependentData: 24 | cdef list outcomes 25 | cdef list weights 26 | cpdef update_from_array(MultipleOutcomeDependentData self, FLOAT_t[:] b) 27 | cpdef downdate(MultipleOutcomeDependentData self) 28 | cpdef list sse(MultipleOutcomeDependentData self) 29 | cpdef FLOAT_t mse(MultipleOutcomeDependentData self) 30 | 31 | @cython.final 32 | cdef class SingleOutcomeDependentData: 33 | cdef readonly FLOAT_t[:] y 34 | cdef readonly SingleWeightDependentData weight 35 | cdef readonly FLOAT_t[:] theta 36 | cdef public FLOAT_t omega 37 | cdef public FLOAT_t sse_ 38 | cdef public INDEX_t m 39 | cdef public INDEX_t k 40 | cdef public INDEX_t max_terms 41 | cdef public object householder 42 | cpdef FLOAT_t sse(SingleOutcomeDependentData self) 43 | cpdef int synchronize(SingleOutcomeDependentData self) except * 44 | cpdef int update(SingleOutcomeDependentData self) except * 45 | cpdef downdate(SingleOutcomeDependentData self) 46 | 47 | 48 | @cython.final 49 | cdef class PredictorDependentData: 50 | cdef readonly FLOAT_t[:] p 51 | cdef readonly FLOAT_t[:] x 52 | cdef readonly FLOAT_t[:] candidates 53 | cdef readonly INDEX_t[:] order 54 | 55 | @cython.final 56 | cdef class KnotSearchReadOnlyData: 57 | cdef readonly PredictorDependentData predictor 58 | cdef readonly MultipleOutcomeDependentData outcome 59 | 60 | @cython.final 61 | cdef class KnotSearchState: 62 | cdef public FLOAT_t alpha 63 | cdef public FLOAT_t beta 64 | cdef public FLOAT_t lambda_ 65 | cdef public FLOAT_t mu 66 | cdef public FLOAT_t upsilon 67 | cdef public FLOAT_t phi 68 | cdef public FLOAT_t phi_next 69 | cdef public INDEX_t ord_idx 70 | cdef public INDEX_t idx 71 | cdef public FLOAT_t zeta_squared 72 | 73 | @cython.final 74 | cdef class KnotSearchWorkingData: 75 | cdef readonly FLOAT_t[:] gamma 76 | cdef readonly
FLOAT_t[:] kappa 77 | cdef readonly FLOAT_t[:] delta_kappa 78 | cdef readonly FLOAT_t[:] chi 79 | cdef readonly FLOAT_t[:] psi 80 | cdef KnotSearchState state 81 | 82 | @cython.final 83 | cdef class KnotSearchData: 84 | cdef readonly KnotSearchReadOnlyData constant 85 | cdef readonly list workings 86 | cdef public INDEX_t q 87 | # Module-level functions of _knot_search.pyx: the cdef ones are C-only, while cpdef knot_search is also callable from Python. 88 | cdef dot(FLOAT_t[:] x1, FLOAT_t[:] x2, INDEX_t q) 89 | cdef w2dot(FLOAT_t[:] w, FLOAT_t[:] x1, FLOAT_t[:] x2, INDEX_t q) 90 | cdef wdot(FLOAT_t[:] w, FLOAT_t[:] x1, FLOAT_t[:] x2, INDEX_t q) 91 | cdef inline void fast_update(PredictorDependentData predictor, SingleOutcomeDependentData outcome, 92 | KnotSearchWorkingData working, FLOAT_t[:] p, INDEX_t q, INDEX_t m ,INDEX_t r) except * 93 | cpdef tuple knot_search(KnotSearchData data, FLOAT_t[:] candidates, FLOAT_t[:] p, INDEX_t q, INDEX_t m, INDEX_t r, INDEX_t n_outcomes, int verbose) 94 | 95 | -------------------------------------------------------------------------------- /pyearth/_knot_search.pyxdep: -------------------------------------------------------------------------------- 1 | _types 2 | _util -------------------------------------------------------------------------------- /pyearth/_pruning.pxd: -------------------------------------------------------------------------------- 1 | cimport numpy as cnp 2 | from _types cimport FLOAT_t, INT_t, INDEX_t, BOOL_t 3 | from _basis cimport Basis 4 | from _record cimport PruningPassRecord 5 | 6 | cdef class PruningPasser: 7 | cdef cnp.ndarray X 8 | cdef cnp.ndarray missing 9 | cdef cnp.ndarray B 10 | cdef cnp.ndarray y 11 | cdef cnp.ndarray sample_weight 12 | cdef int verbose 13 | cdef cnp.ndarray output_weight 14 | cdef public dict feature_importance # criterion name -> per-variable importance array, filled by run(); public, so readable from Python 15 | # NOTE(review): output_weight and sst are declared here but are not assigned in PruningPasser.__init__ in _pruning.pyx; possibly unused - confirm. 16 | cdef INDEX_t m 17 | cdef INDEX_t n 18 | cdef Basis basis 19 | cdef FLOAT_t penalty 20 | cdef FLOAT_t sst 21 | cdef PruningPassRecord record 22 | 23 | cpdef run(PruningPasser self) 24 | 25 | cpdef PruningPassRecord trace(PruningPasser self) 26 |
-------------------------------------------------------------------------------- /pyearth/_pruning.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c 2 | # cython: cdivision = True 3 | # cython: boundscheck = False 4 | # cython: wraparound = False 5 | # cython: profile = False 6 | 7 | from ._record cimport PruningPassIteration 8 | from ._util cimport gcv, apply_weights_2d 9 | import numpy as np 10 | 11 | from collections import defaultdict 12 | 13 | GCV, RSS, NB_SUBSETS = "gcv", "rss", "nb_subsets" 14 | FEAT_IMP_CRITERIA = (GCV, RSS, NB_SUBSETS) 15 | 16 | cdef class PruningPasser: 17 | '''Implements the generic pruning pass as described by Friedman, 1991.''' 18 | def __init__(PruningPasser self, Basis basis, 19 | cnp.ndarray[FLOAT_t, ndim=2] X, 20 | cnp.ndarray[BOOL_t, ndim=2] missing, 21 | cnp.ndarray[FLOAT_t, ndim=2] y, 22 | cnp.ndarray[FLOAT_t, ndim=2] sample_weight, int verbose, 23 | **kwargs): 24 | self.X = X 25 | self.missing = missing 26 | self.m = self.X.shape[0] 27 | self.n = self.X.shape[1] 28 | self.y = y 29 | self.sample_weight = sample_weight 30 | self.verbose = verbose 31 | self.basis = basis 32 | self.B = np.empty(shape=(self.m, len(self.basis) + 1), dtype=np.float64) 33 | self.penalty = kwargs.get('penalty', 3.0) 34 | # sample_weight has either a single column shared by every output of y 35 | # or one column per output; run() handles both layouts. 36 | 37 | 38 | 39 | # feature importance 40 | feature_importance_criteria = kwargs.get("feature_importance_type", []) 41 | if isinstance(feature_importance_criteria, basestring): 42 | feature_importance_criteria = [feature_importance_criteria] 43 | self.feature_importance = dict() 44 | for criterion in feature_importance_criteria: 45 | self.feature_importance[criterion] = np.zeros((self.n,)) 46 | 47 | cpdef run(PruningPasser self): 48 | # This is a totally naive implementation and could potentially be made 49 | #
faster through the use of updating algorithms. It is not clear that 50 | # such optimization would be worthwhile, as the pruning pass is not the 51 | # slowest part of the algorithm. 52 | cdef INDEX_t i 53 | cdef INDEX_t j 54 | cdef long v 55 | cdef INDEX_t basis_size = len(self.basis) 56 | cdef INDEX_t pruned_basis_size = self.basis.plen() 57 | cdef FLOAT_t gcv_ 58 | cdef INDEX_t best_iteration 59 | cdef INDEX_t best_bf_to_prune 60 | cdef FLOAT_t best_gcv 61 | cdef FLOAT_t best_iteration_gcv 62 | cdef FLOAT_t best_iteration_mse 63 | cdef FLOAT_t mse, mse0, total_weight 64 | 65 | cdef cnp.ndarray[FLOAT_t, ndim = 2] B = ( 66 | self.B) 67 | cdef cnp.ndarray[FLOAT_t, ndim = 2] X = ( 68 | self.X) 69 | cdef cnp.ndarray[BOOL_t, ndim = 2] missing = ( 70 | self.missing) 71 | cdef cnp.ndarray[FLOAT_t, ndim = 2] y = ( 72 | self.y) 73 | cdef cnp.ndarray[FLOAT_t, ndim = 2] sample_weight = ( 74 | self.sample_weight) 75 | cdef cnp.ndarray[FLOAT_t, ndim = 1] weighted_y 76 | 77 | if self.verbose >= 1: 78 | print('Beginning pruning pass') 79 | 80 | # Initial solution 81 | mse = 0. 82 | mse0 = 0. 83 | total_weight = 0.
84 | for p in range(y.shape[1]): 85 | if sample_weight.shape[1] == 1: 86 | weighted_y = y[:,p] * np.sqrt(sample_weight[:,0]) 87 | self.basis.weighted_transform(X, missing, B, sample_weight[:, 0]) 88 | total_weight += np.sum(sample_weight[:,0]) 89 | mse0 += np.sum(sample_weight[:,0] * (y[:,p] - np.average(y[:,p], weights=sample_weight[:,0])) ** 2) 90 | else: 91 | weighted_y = y[:,p] * np.sqrt(sample_weight[:,p]) 92 | self.basis.weighted_transform(X, missing, B, sample_weight[:, p]) 93 | total_weight += np.sum(sample_weight[:,p]) 94 | mse0 += np.sum(sample_weight[:,p] * (y[:,p] - np.average(y[:,p], weights=sample_weight[:,p])) ** 2) 95 | # B and weighted_y now hold the sample-weighted basis and response for 96 | # output p. np.linalg.lstsq returns the residual sum of squares as a 97 | # length-1 array when B[:, 0:basis_size] has full column rank and more 98 | # rows than columns; otherwise the array is empty and the SSE is computed directly. 99 | beta, mse_ = np.linalg.lstsq(B[:, 0:(basis_size)], weighted_y)[0:2] 100 | if mse_.size: 101 | mse_ = mse_[0] 102 | else: 103 | mse_ = np.sum( 104 | (np.dot(B[:, 0:basis_size], beta) - weighted_y) ** 2) 105 | mse += mse_ 106 | 107 | # Create the record object 108 | self.record = PruningPassRecord( 109 | self.m, self.n, self.penalty, mse0 / total_weight, pruned_basis_size, mse / total_weight) 110 | gcv_ = self.record.gcv(0) 111 | best_gcv = gcv_ 112 | best_iteration = 0 113 | 114 | if self.verbose >= 1: 115 | print(self.record.partial_str(slice(-1, None, None), print_footer=False)) 116 | 117 | # init feature importance 118 | prev_best_iteration_gcv = None 119 | prev_best_iteration_mse = None 120 | 121 | # Prune basis functions sequentially 122 | for i in range(1, pruned_basis_size): 123 | first = True 124 | pruned_basis_size -= 1 125 | 126 | # Find the best basis function to prune 127 | for j in range(basis_size): 128 | bf = self.basis[j] 129 | if bf.is_pruned(): 130 | continue 131 | if not bf.is_prunable(): 132 | continue 133 | bf.prune() 134 | 135 | 136 | mse = 0.
137 | for p in range(y.shape[1]): 138 | if sample_weight.shape[1] == 1: 139 | weighted_y = y[:,p] * np.sqrt(sample_weight[:,0]) 140 | self.basis.weighted_transform(X, missing, B, sample_weight[:, 0]) 141 | else: 142 | weighted_y = y[:,p] * np.sqrt(sample_weight[:,p]) 143 | self.basis.weighted_transform(X, missing, B, sample_weight[:, p]) 144 | beta, mse_ = np.linalg.lstsq( 145 | B[:, 0:pruned_basis_size], weighted_y)[0:2] 146 | if mse_.size: 147 | mse_ = mse_[0] 148 | # An empty residual array means lstsq did not report the SSE (rank deficient or too few rows). 149 | else: 150 | mse_ = np.sum((np.dot(B[:, 0:pruned_basis_size], beta) - 151 | weighted_y) ** 2) 152 | mse += mse_ 153 | gcv_ = gcv(mse / total_weight, pruned_basis_size, self.m, self.penalty) # normalize exactly like the mse given to PruningPassRecord, so gcv_ is comparable with best_gcv 154 | 155 | if first or gcv_ <= best_iteration_gcv: 156 | best_iteration_gcv = gcv_ 157 | best_iteration_mse = mse 158 | best_bf_to_prune = j 159 | first = False 160 | bf.unprune() 161 | 162 | # Feature importance 163 | if i > 1: 164 | # having selected the best basis to prune, we compute how much 165 | # that basis decreased the mse and gcv relative to the previous mse and gcv 166 | # respectively. 167 | mse_decrease = (best_iteration_mse - prev_best_iteration_mse) 168 | gcv_decrease = (best_iteration_gcv - prev_best_iteration_gcv) 169 | variables = set() 170 | bf = self.basis[best_bf_to_prune] 171 | for v in bf.variables(): 172 | variables.add(v) 173 | for v in variables: 174 | if RSS in self.feature_importance: 175 | self.feature_importance[RSS][v] += mse_decrease 176 | if GCV in self.feature_importance: 177 | self.feature_importance[GCV][v] += gcv_decrease 178 | if NB_SUBSETS in self.feature_importance: 179 | self.feature_importance[NB_SUBSETS][v] += 1 180 | # The inner loop found the best basis function to remove for this 181 | # iteration. Now check whether this iteration is better than all 182 | # the previous ones.
183 | if best_iteration_gcv <= best_gcv: 184 | best_gcv = best_iteration_gcv 185 | best_iteration = i 186 | 187 | prev_best_iteration_gcv = best_iteration_gcv 188 | prev_best_iteration_mse = best_iteration_mse 189 | # Update the record and prune the selected basis function 190 | self.record.append(PruningPassIteration( 191 | best_bf_to_prune, pruned_basis_size, best_iteration_mse / total_weight)) 192 | self.basis[best_bf_to_prune].prune() 193 | 194 | if self.verbose >= 1: 195 | print(self.record.partial_str(slice(-1, None, None), print_header=False, print_footer=(pruned_basis_size == 1))) 196 | 197 | # Unprune the basis functions pruned after the best iteration 198 | self.record.set_selected(best_iteration) 199 | self.record.roll_back(self.basis) 200 | if self.verbose >= 1: 201 | print(self.record.final_str()) 202 | 203 | # normalize feature importance values 204 | for name, val in self.feature_importance.items(): 205 | if name == GCV: 206 | val[val < 0] = 0 # gcv can have negative feature importance corresponding to an increase of gcv, set them to zero 207 | if val.sum() > 0: 208 | val /= val.sum() 209 | self.feature_importance[name] = val 210 | 211 | cpdef PruningPassRecord trace(PruningPasser self): 212 | return self.record 213 | 214 | -------------------------------------------------------------------------------- /pyearth/_pruning.pyxdep: -------------------------------------------------------------------------------- 1 | _types -------------------------------------------------------------------------------- /pyearth/_qr.pxd: -------------------------------------------------------------------------------- 1 | from cython cimport view 2 | from _types cimport FLOAT_t, INT_t, INDEX_t, BOOL_t 3 | 4 | cdef class UpdatingQT: 5 | cdef readonly int m 6 | cdef readonly int max_n 7 | cdef readonly Householder householder 8 | cdef readonly int k 9 | cdef readonly FLOAT_t[::1, :] Q_t 10 | cdef readonly FLOAT_t zero_tol 11 | cdef readonly BOOL_t[::1] dependent_cols 12 |
cpdef void update_qt(UpdatingQT self, bint dependent) 13 | cpdef void update(UpdatingQT self, FLOAT_t[:] x) 14 | cpdef void downdate(UpdatingQT self) 15 | cpdef void reset(UpdatingQT self) 16 | 17 | cdef class Householder: 18 | cdef readonly int k 19 | cdef readonly int m 20 | cdef readonly int max_n 21 | cdef readonly FLOAT_t[::1, :] V 22 | cdef readonly FLOAT_t[::1, :] T 23 | cdef readonly FLOAT_t[::1] tau 24 | cdef readonly FLOAT_t[::1] beta 25 | cdef readonly FLOAT_t[::1, :] work 26 | cdef readonly FLOAT_t zero_tol 27 | cpdef void downdate(Householder self) 28 | cpdef void reset(Householder self) 29 | cpdef bint update_from_column(Householder self, FLOAT_t[:] c) 30 | cpdef bint update_v_t(Householder self) 31 | cpdef void left_apply(Householder self, FLOAT_t[::1, :] C) 32 | cpdef void left_apply_transpose(Householder self, FLOAT_t[::1, :] C) 33 | cpdef void right_apply(Householder self, FLOAT_t[::1, :] C) 34 | cpdef void right_apply_transpose(Householder self, FLOAT_t[::1, :] C) 35 | -------------------------------------------------------------------------------- /pyearth/_qr.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c 2 | # cython: cdivision = True 3 | # cython: boundscheck = False 4 | # cython: wraparound = False 5 | # cython: profile = False 6 | import numpy as np 7 | from scipy.linalg.cython_lapack cimport dlarfg, dlarft, dlarfb 8 | from scipy.linalg.cython_blas cimport dcopy 9 | from libc.math cimport fabs 10 | from _types import BOOL, FLOAT 11 | # UpdatingQT keeps Q_t so that its row space matches the span of the vectors passed to update(); the orthogonalization is done by a Householder block reflector. 12 | cdef class UpdatingQT: 13 | def __init__(UpdatingQT self, int m, int max_n, Householder householder, 14 | int k, FLOAT_t[::1, :] Q_t, FLOAT_t zero_tol, BOOL_t[::1] dependent_cols): 15 | self.m = m 16 | self.max_n = max_n 17 | self.householder = householder 18 | self.k = k 19 | self.Q_t = Q_t 20 | self.zero_tol = zero_tol 21 | self.dependent_cols = dependent_cols 22 | 23 | @classmethod 24 | def alloc(cls, int m, int max_n, FLOAT_t
zero_tol): 25 | cdef Householder householder = Householder.alloc(m, max_n, zero_tol) 26 | cdef FLOAT_t[::1, :] Q_t = np.empty(shape=(max_n, m), dtype=FLOAT, order='F') 27 | cdef BOOL_t[::1] dependent_cols = np.empty(shape=max_n, dtype=BOOL, order='F') 28 | return cls(m, max_n, householder, 0, Q_t, zero_tol, dependent_cols) 29 | 30 | cpdef void update_qt(UpdatingQT self, bint dependent): 31 | # Assume that householder has already been updated and now Q_t needs to be updated 32 | # accordingly 33 | 34 | # Zero out the new row of Q_t 35 | cdef FLOAT_t zero = 0. 36 | cdef int zero_int = 0 37 | cdef int N = self.m 38 | cdef FLOAT_t * y = &(self.Q_t[self.k, 0]) 39 | cdef int incy = self.max_n 40 | dcopy(&N, &zero, &zero_int, y, &incy) 41 | 42 | if not dependent: 43 | 44 | # Place a one in the right place 45 | # In general self.householder.k <= self.k + 1. 46 | # They are not necessarily equal. 47 | self.Q_t[self.k, self.householder.k - 1] = 1. 48 | 49 | # Apply the householder transformation 50 | self.householder.right_apply_transpose(self.Q_t[self.k:self.k+1, :]) 51 | 52 | self.k += 1 53 | 54 | 55 | cpdef void update(UpdatingQT self, FLOAT_t[:] x): 56 | # Updates householder, then calls 57 | # update_qt 58 | 59 | # The Householder will detect if the new vector is linearly dependent on the previous 60 | # ones (within numerical precision specified by zero_tol). 61 | cdef bint dependent 62 | dependent = self.householder.update_from_column(x) 63 | 64 | # Mark the column as independent or dependent. This information will be needed if the 65 | # column is ever downdated, since we then need to not downdate householder 66 | self.dependent_cols[self.k] = dependent 67 | 68 | # If linear dependence was detected, the householder will have failed to update 69 | # (as it should). In that case, we want a row of zeros in our Q_t matrix because 70 | # the row space of Q_t should be the same as the span of all the x vectors passed to update.
71 | # A row of zeros makes this possible while still having self.k match the relevant dimension 72 | # of Q_t. The update_qt method takes care of adding the zeros if dependent. Note this means 73 | # that in general self.householder.k <= self.k. They are not necessarily equal. 74 | self.update_qt(dependent) 75 | 76 | 77 | 78 | cpdef void downdate(UpdatingQT self): 79 | self.k -= 1 80 | if not self.dependent_cols[self.k]: 81 | self.householder.downdate() 82 | 83 | cpdef void reset(UpdatingQT self): 84 | self.householder.reset() 85 | self.k = 0 86 | 87 | cdef class Householder: 88 | # Compact WY block reflector H = I - V * T * V**T of the accumulated reflectors; update_v_t rebuilds T with LAPACK dlarft. 89 | def __init__(Householder self, int k, int m, int max_n, 90 | FLOAT_t[::1, :] V, FLOAT_t[::1, :] T, FLOAT_t[::1] tau, 91 | FLOAT_t[::1] beta, FLOAT_t[::1, :] work, FLOAT_t zero_tol): 92 | self.k = k 93 | self.m = m 94 | self.max_n = max_n 95 | self.V = V 96 | self.T = T 97 | self.tau = tau 98 | self.beta = beta 99 | self.work = work 100 | self.zero_tol = zero_tol 101 | 102 | @classmethod 103 | def alloc(cls, int m, int max_n, FLOAT_t zero_tol): 104 | cdef int k = 0 105 | cdef FLOAT_t[::1, :] V = np.empty(shape=(m, max_n), dtype=FLOAT, order='F') 106 | cdef FLOAT_t[::1, :] T = np.empty(shape=(max_n, max_n), dtype=FLOAT, order='F') 107 | cdef FLOAT_t[::1] tau = np.empty(shape=max_n, dtype=FLOAT, order='F') 108 | cdef FLOAT_t[::1] beta = np.empty(shape=max_n, dtype=FLOAT, order='F') 109 | cdef FLOAT_t[::1, :] work = np.empty(shape=(m, max_n), dtype=FLOAT, order='F') 110 | return cls(k, m, max_n, V, T, tau, beta, work, zero_tol) 111 | 112 | cpdef void downdate(Householder self): 113 | self.k -= 1 114 | 115 | cpdef void reset(Householder self): 116 | self.k = 0 117 | 118 | cpdef bint update_from_column(Householder self, FLOAT_t[:] c): 119 | # Copies c, applies self, then updates V and T 120 | 121 | # Copy c into V 122 | cdef int N = self.m 123 | cdef FLOAT_t * x = &(c[0]) 124 | cdef int incx = c.strides[0] // c.itemsize 125 | cdef FLOAT_t * y = &(self.V[0, self.k]) 126 | cdef int incy = 1
127 | dcopy(&N, x, &incx, y, &incy) 128 | 129 | # Apply self to new column in V 130 | self.left_apply_transpose(self.V[:, self.k:self.k+1]) 131 | 132 | # Update V and T (increments k) 133 | return self.update_v_t() 134 | 135 | 136 | cpdef bint update_v_t(Householder self): 137 | # Assume relevant data has been copied into self.V correctly, as by 138 | # update_from_column. Update V and T appropriately. 139 | cdef int n = self.m - self.k 140 | cdef FLOAT_t alpha = self.V[self.k, self.k] 141 | cdef FLOAT_t* x = &(self.V[(self.k + 1), self.k]) 142 | cdef int incx = self.V.strides[0] // self.V.itemsize 143 | cdef FLOAT_t tau 144 | cdef FLOAT_t beta 145 | cdef bint dependent 146 | 147 | # Compute the householder reflection 148 | dlarfg(&n, &alpha, x, &incx, &tau) 149 | beta = alpha 150 | 151 | # If beta is very close to zero, the new column was linearly 152 | # dependent on the previous columns. In that case, it's best if 153 | # we just pretend this never happened. Note that this means k 154 | # will not be incremented. UpdatingQT knows how to handle this 155 | # case, and will be informed by the return value. 156 | dependent = fabs(beta) < self.zero_tol 157 | if dependent: 158 | return dependent 159 | 160 | # Add the new householder reflection to the 161 | # block reflector 162 | # TODO: Currently requires recalculating all of T 163 | # Could be updated to use BLAS instead to calculate 164 | # just the new column of T. I'm not sure how to 165 | # do this or whether it would be faster. 166 | self.V[self.k, self.k] = 1. 167 | self.V[:self.k, self.k] = 0.
168 | self.tau[self.k] = tau 169 | self.beta[self.k] = alpha 170 | cdef char direct = 'F' 171 | cdef char storev = 'C' 172 | n = self.m 173 | cdef int k = self.k + 1 174 | cdef FLOAT_t * V = &(self.V[0,0]) 175 | cdef int ldv = self.m 176 | cdef FLOAT_t * T = &(self.T[0,0]) 177 | cdef FLOAT_t * tau_arg = &(self.tau[0]) 178 | cdef int ldt = self.max_n 179 | dlarft(&direct, &storev, &n, &k, V, &ldv, tau_arg, T, &ldt) 180 | 181 | self.k += 1 182 | # Always False here: the dependent case already returned above. 183 | return dependent 184 | 185 | cpdef void left_apply(Householder self, FLOAT_t[::1, :] C): 186 | cdef char side = 'L' 187 | cdef char trans = 'N' 188 | cdef char direct = 'F' 189 | cdef char storev = 'C' 190 | cdef int M = C.shape[0] 191 | cdef int N = C.shape[1] 192 | cdef int K = self.k 193 | cdef FLOAT_t * V = &(self.V[0, 0]) 194 | cdef int ldv = self.m 195 | cdef FLOAT_t * T = &(self.T[0, 0]) 196 | cdef int ldt = self.max_n 197 | cdef FLOAT_t * C_arg = &(C[0, 0]) 198 | cdef int ldc = C.strides[1] // C.itemsize 199 | cdef FLOAT_t * work = &(self.work[0,0]) 200 | cdef int ldwork = self.m 201 | # Apply the block reflector from the left (C := H * C) with LAPACK dlarfb. 202 | dlarfb(&side, &trans, &direct, &storev, &M, &N, &K, 203 | V, &ldv, T, &ldt, C_arg, &ldc, work, &ldwork) 204 | 205 | cpdef void left_apply_transpose(Householder self, FLOAT_t[::1, :] C): 206 | cdef char side = 'L' 207 | cdef char trans = 'T' 208 | cdef char direct = 'F' 209 | cdef char storev = 'C' 210 | cdef int M = C.shape[0] 211 | cdef int N = C.shape[1] 212 | cdef int K = self.k 213 | cdef FLOAT_t * V = &(self.V[0, 0]) 214 | cdef int ldv = self.m 215 | cdef FLOAT_t * T = &(self.T[0, 0]) 216 | cdef int ldt = self.max_n 217 | cdef FLOAT_t * C_arg = &(C[0, 0]) 218 | cdef int ldc = C.strides[1] // C.itemsize 219 | cdef FLOAT_t * work = &(self.work[0,0]) 220 | cdef int ldwork = self.m 221 | 222 | dlarfb(&side, &trans, &direct, &storev, &M, &N, &K, 223 | V, &ldv, T, &ldt, C_arg, &ldc, work, &ldwork) 224 | 225 | cpdef void right_apply(Householder
self, FLOAT_t[::1, :] C): 226 | cdef char side = 'R' 227 | cdef char trans = 'N' 228 | cdef char direct = 'F' 229 | cdef char storev = 'C' 230 | cdef int M = C.shape[0] 231 | cdef int N = C.shape[1] 232 | cdef int K = self.k 233 | cdef FLOAT_t * V = &(self.V[0, 0]) 234 | cdef int ldv = self.m 235 | cdef FLOAT_t * T = &(self.T[0, 0]) 236 | cdef int ldt = self.max_n 237 | cdef FLOAT_t * C_arg = &(C[0, 0]) 238 | cdef int ldc = C.strides[1] // C.itemsize 239 | cdef FLOAT_t * work = &(self.work[0,0]) 240 | cdef int ldwork = self.m 241 | 242 | dlarfb(&side, &trans, &direct, &storev, &M, &N, &K, 243 | V, &ldv, T, &ldt, C_arg, &ldc, work, &ldwork) 244 | 245 | cpdef void right_apply_transpose(Householder self, FLOAT_t[::1, :] C): 246 | cdef char side = 'R' 247 | cdef char trans = 'T' 248 | cdef char direct = 'F' 249 | cdef char storev = 'C' 250 | cdef int M = C.shape[0] 251 | cdef int N = C.shape[1] 252 | cdef int K = self.k 253 | cdef FLOAT_t * V = &(self.V[0, 0]) 254 | cdef int ldv = self.m 255 | cdef FLOAT_t * T = &(self.T[0, 0]) 256 | cdef int ldt = self.max_n 257 | cdef FLOAT_t * C_arg = &(C[0, 0]) 258 | cdef int ldc = C.strides[1] // C.itemsize 259 | cdef FLOAT_t * work = &(self.work[0,0]) 260 | cdef int ldwork = self.m 261 | 262 | dlarfb(&side, &trans, &direct, &storev, &M, &N, &K, 263 | V, &ldv, T, &ldt, C_arg, &ldc, work, &ldwork) 264 | # 265 | -------------------------------------------------------------------------------- /pyearth/_qr.pyxdep: -------------------------------------------------------------------------------- 1 | _types -------------------------------------------------------------------------------- /pyearth/_record.pxd: -------------------------------------------------------------------------------- 1 | cimport numpy as cnp 2 | from _types cimport FLOAT_t, INT_t, INDEX_t, BOOL_t 3 | from _basis cimport Basis 4 | 5 | cdef class Record: 6 | cdef list iterations 7 | cdef int num_samples 8 | cdef int num_variables 9 | cdef FLOAT_t penalty 10 | cdef
FLOAT_t sst # Sum of squares total 11 | 12 | cpdef append(Record self, Iteration iteration) 13 | 14 | cpdef FLOAT_t mse(Record self, INDEX_t iteration) 15 | 16 | cpdef FLOAT_t rsq(Record self, INDEX_t iteration) 17 | 18 | cpdef FLOAT_t gcv(Record self, INDEX_t iteration) 19 | 20 | cpdef FLOAT_t grsq(Record self, INDEX_t iteration) 21 | 22 | cdef class PruningPassRecord(Record): 23 | cdef readonly INDEX_t selected 24 | 25 | cpdef set_selected(PruningPassRecord self, INDEX_t selected) 26 | 27 | cpdef INDEX_t get_selected(PruningPassRecord self) 28 | 29 | cpdef roll_back(PruningPassRecord self, Basis basis) 30 | 31 | cdef class ForwardPassRecord(Record): 32 | cdef readonly int stopping_condition 33 | 34 | cdef list xlabels 35 | 36 | cpdef set_stopping_condition(ForwardPassRecord self, int stopping_condition) 37 | 38 | cdef class Iteration: 39 | cdef FLOAT_t mse 40 | cdef INDEX_t size 41 | 42 | cpdef FLOAT_t get_mse(Iteration self) 43 | 44 | cpdef INDEX_t get_size(Iteration self) 45 | 46 | cdef class PruningPassIteration(Iteration): 47 | cdef INDEX_t pruned 48 | 49 | cpdef INDEX_t get_pruned(PruningPassIteration self) 50 | 51 | cdef class FirstPruningPassIteration(PruningPassIteration): 52 | pass 53 | 54 | cdef class ForwardPassIteration(Iteration): 55 | cdef INDEX_t parent 56 | cdef INDEX_t variable 57 | cdef FLOAT_t knot 58 | cdef int code 59 | cdef bint no_candidates 60 | 61 | cpdef set_no_candidates(ForwardPassIteration self, bint value) 62 | 63 | cpdef no_further_candidates(ForwardPassIteration self) 64 | 65 | cdef class FirstForwardPassIteration(ForwardPassIteration): 66 | cpdef INDEX_t get_size(FirstForwardPassIteration self) 67 | -------------------------------------------------------------------------------- /pyearth/_record.pyxdep: -------------------------------------------------------------------------------- 1 | _types -------------------------------------------------------------------------------- /pyearth/_types.pxd: 
-------------------------------------------------------------------------------- 1 | cimport numpy as cnp 2 | ctypedef cnp.float64_t FLOAT_t 3 | ctypedef cnp.int_t INT_t 4 | ctypedef cnp.intp_t INDEX_t 5 | ctypedef cnp.uint8_t BOOL_t 6 | -------------------------------------------------------------------------------- /pyearth/_types.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | FLOAT = np.float64 3 | INT = int  # np.int was only an alias of the builtin int; it was deprecated in NumPy 1.20 and removed in 1.24 4 | INDEX = np.intp 5 | BOOL = np.uint8 6 | -------------------------------------------------------------------------------- /pyearth/_util.pxd: -------------------------------------------------------------------------------- 1 | cimport numpy as cnp 2 | from _types cimport FLOAT_t, INT_t, INDEX_t, BOOL_t 3 | 4 | cdef FLOAT_t log2(FLOAT_t x) 5 | 6 | cpdef apply_weights_2d(cnp.ndarray[FLOAT_t, ndim=2] B, 7 | cnp.ndarray[FLOAT_t, ndim=1] weights) 8 | 9 | cpdef apply_weights_slice(cnp.ndarray[FLOAT_t, ndim=2] B, 10 | cnp.ndarray[FLOAT_t, ndim=1] weights, INDEX_t column) 11 | 12 | cpdef apply_weights_1d(cnp.ndarray[FLOAT_t, ndim=1] y, 13 | cnp.ndarray[FLOAT_t, ndim=1] weights) 14 | 15 | cpdef FLOAT_t gcv(FLOAT_t mse, 16 | FLOAT_t basis_size, FLOAT_t data_size, 17 | FLOAT_t penalty) 18 | 19 | cpdef FLOAT_t gcv_adjust(FLOAT_t basis_size, FLOAT_t data_size, FLOAT_t penalty) 20 | 21 | cpdef str_pad(string, length) 22 | 23 | cpdef ascii_table(header, data, print_header=?, print_footer=?)
24 | -------------------------------------------------------------------------------- /pyearth/_util.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c 2 | # cython: cdivision = True 3 | # cython: boundscheck = False 4 | # cython: wraparound = False 5 | # cython: profile = False 6 | 7 | import numpy as np 8 | from libc.math cimport sqrt, log 9 | 10 | cdef FLOAT_t log2(FLOAT_t x): 11 | return log(x) / log(2.0) 12 | 13 | cpdef apply_weights_2d(cnp.ndarray[FLOAT_t, ndim=2] B, 14 | cnp.ndarray[FLOAT_t, ndim=1] weights): 15 | cdef INDEX_t i 16 | cdef INDEX_t j 17 | cdef INDEX_t m = B.shape[0] 18 | cdef INDEX_t n = B.shape[1] 19 | for i in range(m): 20 | for j in range(n): 21 | B[i, j] *= sqrt(weights[i]) 22 | 23 | cpdef apply_weights_slice(cnp.ndarray[FLOAT_t, ndim=2] B, 24 | cnp.ndarray[FLOAT_t, ndim=1] weights, 25 | INDEX_t column): 26 | cdef INDEX_t i 27 | cdef INDEX_t j  # unused in this function 28 | cdef INDEX_t m = B.shape[0] 29 | cdef INDEX_t n = B.shape[1]  # unused in this function 30 | for i in range(m): 31 | B[i, column] *= sqrt(weights[i]) 32 | 33 | cpdef apply_weights_1d(cnp.ndarray[FLOAT_t, ndim=1] y, 34 | cnp.ndarray[FLOAT_t, ndim=1] weights): 35 | cdef INDEX_t i 36 | cdef INDEX_t m = y.shape[0] 37 | for i in range(m): 38 | y[i] *= sqrt(weights[i]) 39 | 40 | cpdef FLOAT_t gcv(FLOAT_t mse, FLOAT_t basis_size, FLOAT_t data_size, 41 | FLOAT_t penalty): 42 | return mse * gcv_adjust(basis_size, data_size, penalty) 43 | 44 | cpdef FLOAT_t gcv_adjust(FLOAT_t basis_size, FLOAT_t data_size, 45 | FLOAT_t penalty): 46 | cdef FLOAT_t effective_parameters 47 | effective_parameters = basis_size + penalty * (basis_size - 1) / 2.0 48 | return 1.0 / ( ( (1.0 - (effective_parameters / data_size)) ** 2 ) )  # with cdivision=True this is a plain C division, so it yields (IEEE) inf rather than raising when effective_parameters == data_size 49 | 50 | cpdef str_pad(string, length): 51 | if len(string) >= length: 52 | return string[0:length] 53 | pad = length - len(string) 54 | return string + ' ' * pad 55 | 56 | cpdef ascii_table(header, data, print_header=True, print_footer=True): 57 | ''' 58 |
header - list of strings representing the header row 59 | data - list of lists of strings representing data rows 60 | ''' 61 | m = len(data) 62 | n = len(header) 63 | column_widths = [len(head) for head in header] 64 | for i, row in enumerate(data): 65 | for j, col in enumerate(row): 66 | if len(col) > column_widths[j]: 67 | column_widths[j] = len(col) 68 | 69 | for j in range(n): 70 | column_widths[j] += 1 71 | 72 | result = '' 73 | if print_header: 74 | for j, col_width in enumerate(column_widths): 75 | result += '-' * col_width + '-' 76 | result += '\n' 77 | for j, head in enumerate(header): 78 | result += str_pad(head, column_widths[j]) + ' ' 79 | result += '\n' 80 | for j, col_width in enumerate(column_widths): 81 | result += '-' * col_width + '-' 82 | # result += '\n' 83 | result += '\n' 84 | for i, row in enumerate(data): 85 | if i > 0: 86 | result += '\n' 87 | for j, item in enumerate(row): 88 | result += str_pad(item, column_widths[j]) + ' ' 89 | 90 | if print_footer: 91 | result += '\n' 92 | for j, col_width in enumerate(column_widths): 93 | result += '-' * col_width + '-' 94 | return result 95 | -------------------------------------------------------------------------------- /pyearth/_util.pyxdep: -------------------------------------------------------------------------------- 1 | _types -------------------------------------------------------------------------------- /pyearth/export.py: -------------------------------------------------------------------------------- 1 | def export_python_function(earth_model): 2 | """ 3 | Exports model as a pure python function, with no numpy/scipy/sklearn dependencies.
4 | :param earth_model: Trained pyearth model 5 | :return: A function that accepts an iterable of examples and returns a list with one value per example: the sum of every unpruned term's accessor, built from coef_[0] (i.e. only the first output column is exported) 6 | """ 7 | i = 0 8 | accessors = [] 9 | for bf in earth_model.basis_: 10 | if not bf.is_pruned(): 11 | accessors.append(bf.func_factory(earth_model.coef_[0, i])) 12 | i += 1  # i only advances for unpruned basis functions, so it indexes coef_ by unpruned term 13 | 14 | def func(example_iterator): 15 | return [sum(accessor(row) for accessor in accessors) for row in example_iterator] 16 | return func 17 | 18 | 19 | def export_python_string(earth_model, function_name="model"): 20 | """ 21 | Exports model as a string that evaluates as python code, with no numpy/scipy/sklearn dependencies. 22 | :param earth_model: Trained pyearth model 23 | :param function_name: string, optional, will be the name of the function in the returned string 24 | :return: string, when executed (either by writing to a file, or using `exec`), will define a python 25 | generator function that accepts an iterator over examples and yields, per example, the sum of the term accessors (built from coef_[0], i.e. the first output column only) 26 | """ 27 | i = 0 28 | accessors = [] 29 | for bf in earth_model.basis_: 30 | if not bf.is_pruned(): 31 | accessors.append(bf.func_string_factory(earth_model.coef_[0, i])) 32 | i += 1 33 | 34 | return """def {:s}(example_iterator): 35 | accessors = [{:s}] 36 | for x in example_iterator: 37 | yield sum(accessor(x) for accessor in accessors) 38 | """.format(function_name, ",\n\t\t".join(accessors)) 39 | 40 | def export_sympy_term_expressions(earth_model): 41 | """ 42 | Construct a list of sympy expressions for all non-pruned terms in the model. 43 | 44 | :param earth_model: Trained pyearth model 45 | :return: a list of sympy expressions representing terms in the model. These 46 | expressions are the symbolic equivalent of the Earth.transform method.
47 | 48 | """ 49 | from sympy import Symbol, Add, Mul, Max, RealNumber, Piecewise, Pow, And, nan, Function, Not 50 | from ._basis import LinearBasisFunction, HingeBasisFunction, SmoothedHingeBasisFunction, \ 51 | MissingnessBasisFunction, ConstantBasisFunction, VariableBasisFunction 52 | 53 | Missing = Function('Missing') 54 | NaNProtect = Function('NaNProtect') 55 | 56 | def linear_bf_to_factor(bf, bf_var): 57 | return bf_var 58 | 59 | def smoothed_hinge_bf_to_factor(bf, bf_var): 60 | knot = RealNumber(bf.get_knot()) 61 | knot_minus = RealNumber(bf.get_knot_minus()) 62 | knot_plus = RealNumber(bf.get_knot_plus()) 63 | r = RealNumber(bf.get_r()) 64 | p = RealNumber(bf.get_p()) 65 | if bf.get_reverse(): 66 | lower_p = (-(bf_var - knot)), (bf_var <= knot_minus) 67 | upper_p = (0, bf_var >= knot_plus) 68 | left_exp = Mul(p, Pow((bf_var - knot_plus), 2)) 69 | right_exp = Mul(r, Pow((bf_var - knot_plus), 3)) 70 | middle_b = And(knot_minus < bf_var, bf_var < knot_plus) 71 | middle_exp = (Add(left_exp, right_exp), middle_b) 72 | piecewise = Piecewise(lower_p, upper_p, middle_exp) 73 | factor = piecewise 74 | else: 75 | lower_p = (0, bf_var <= knot_minus) 76 | upper_p = (bf_var - knot, bf_var >= knot_plus) 77 | left_exp = Mul(p, Pow((bf_var - knot_minus), 2)) 78 | right_exp = Mul(r, Pow((bf_var - knot_minus), 3)) 79 | middle_b = And(knot_minus < bf_var, bf_var < knot_plus) 80 | middle_exp = (Add(left_exp, right_exp), middle_b) 81 | piecewise = Piecewise(lower_p, upper_p, middle_exp) 82 | factor = piecewise 83 | return factor 84 | 85 | def hinge_bf_to_factor(bf, bf_var): 86 | knot = bf.get_knot() 87 | if bf.get_reverse(): 88 | factor = Max(0, RealNumber(knot) - bf_var) 89 | else: 90 | factor = Max(0, bf_var - RealNumber(knot)) 91 | return factor 92 | 93 | def missingness_bf_to_factor(bf, bf_var): 94 | # The factor is the indicator Missing(bf_var); when bf.complement is set it is negated, 95 | # so the term is active only where the value is present.
96 | if bf.complement: 97 | return Not(Missing(bf_var)) 98 | else: 99 | return Missing(bf_var) 100 | 101 | def constant_bf_to_factor(bf, bf_var): 102 | return RealNumber(1) 103 | 104 | def protect_from_nan(label, missables): 105 | return NaNProtect(Symbol(label)) if label in missables else Symbol(label) 106 | 107 | def dont_protect_from_nan(label, missables): 108 | return Symbol(label) 109 | 110 | bf_to_factor_dispatcher = {LinearBasisFunction: linear_bf_to_factor, 111 | SmoothedHingeBasisFunction: smoothed_hinge_bf_to_factor, 112 | HingeBasisFunction: hinge_bf_to_factor, 113 | MissingnessBasisFunction: missingness_bf_to_factor, 114 | ConstantBasisFunction: constant_bf_to_factor} 115 | 116 | nan_protect_dispatch = {LinearBasisFunction: protect_from_nan, 117 | SmoothedHingeBasisFunction: protect_from_nan, 118 | HingeBasisFunction: protect_from_nan, 119 | MissingnessBasisFunction: dont_protect_from_nan, 120 | ConstantBasisFunction: protect_from_nan} 121 | 122 | def bf_to_factor(bf, missables): 123 | ''' 124 | Convert a BasisFunction to a factor of a term.
125 | ''' 126 | if isinstance(bf, VariableBasisFunction): 127 | bf_var = nan_protect_dispatch[bf.__class__](bf.label, missables) 128 | 129 | else: 130 | bf_var = None 131 | return bf_to_factor_dispatcher[bf.__class__](bf, bf_var) 132 | 133 | def missingness_bf_get_missables(bf): 134 | bf_var = bf.label 135 | return set([bf_var]) 136 | 137 | def non_missable(bf): 138 | return set() 139 | 140 | bf_get_missables_dispatcher = {LinearBasisFunction: non_missable, 141 | SmoothedHingeBasisFunction: non_missable, 142 | HingeBasisFunction: non_missable, 143 | MissingnessBasisFunction: missingness_bf_get_missables, 144 | ConstantBasisFunction: non_missable} 145 | 146 | def get_missables(bf): 147 | missables = bf_get_missables_dispatcher[bf.__class__](bf) 148 | parent = bf.get_parent() 149 | if parent is None: 150 | return missables 151 | else: 152 | missables.update(get_missables(parent)) 153 | 154 | return missables 155 | 156 | def bf_to_term(bf, missables): 157 | ''' 158 | Convert a BasisFunction to a term (without coefficient). 159 | ''' 160 | term = bf_to_factor(bf, missables) 161 | parent = bf.get_parent() 162 | if parent is None: 163 | return term 164 | else: 165 | return Mul(term, bf_to_term(parent, missables)) 166 | 167 | return [bf_to_term(bf, get_missables(bf)) for bf in earth_model.basis_.piter()] 168 | 169 | 170 | def export_sympy(earth_model, columns=None): 171 | """ 172 | Constructs a sympy expression or list of sympy expressions from a trained earth model. 173 | 174 | :param earth_model: Trained pyearth model 175 | :param columns: The index or indices of the output columns for which expressions are to 176 | be constructed. If an integer is used, a sympy expression is returned. If indices 177 | are given then a list of sympy expressions is returned. If columns is None, it is treated 178 | as if columns=0 for models with only one output column or as columns=slice(None) for more than 179 | one output column.
180 | :return: a sympy expression or list of sympy expressions equivalent to the Earth.predict method for 181 | the selected output columns. 182 | 183 | """ 184 | # Set a sane default for columns 185 | if columns is None: 186 | if earth_model.coef_.shape[0] == 1: 187 | columns = 0 188 | else: 189 | columns = slice(None) 190 | 191 | # Get basis function terms 192 | terms = export_sympy_term_expressions(earth_model) 193 | 194 | # Handle column choice 195 | coefs = earth_model.coef_[columns] 196 | if len(coefs.shape) == 1:  # an integer column index selects a single (1-D) row of coef_ 197 | unwrap = True 198 | coefs = [coefs] 199 | n_cols = 1 200 | else: 201 | unwrap = False 202 | n_cols = coefs.shape[0] 203 | 204 | # Combine coefficients with terms for each output column 205 | result = [sum([coefs[i][j] * term for j, term in enumerate(terms)]) for i in range(n_cols)] 206 | 207 | if unwrap: 208 | # Result should be an expression rather than a list of expressions. 209 | result = result[0] 210 | return result 211 | 212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /pyearth/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/py-earth/b209d1916f051dbea5b142af25425df2de469c5a/pyearth/test/__init__.py -------------------------------------------------------------------------------- /pyearth/test/basis/__init__.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | from nose.tools import assert_true, assert_false, assert_equal 4 | 5 | import numpy 6 | 7 | from pyearth._basis import Basis, ConstantBasisFunction, HingeBasisFunction, \ 8 | LinearBasisFunction, SmoothedHingeBasisFunction 9 | 10 | numpy.random.seed(0) 11 | -------------------------------------------------------------------------------- /pyearth/test/basis/base.py: -------------------------------------------------------------------------------- 1 | import os
2 | import numpy 3 | 4 | 5 | class BaseContainer(object): 6 | filename = os.path.join(os.path.dirname(__file__), '../test_data.csv') 7 | data = numpy.genfromtxt(filename, delimiter=',', skip_header=1) 8 | X = numpy.array(data[:, 0:5]) 9 | y = numpy.array(data[:, 5]) 10 | -------------------------------------------------------------------------------- /pyearth/test/basis/test_basis.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy 3 | 4 | from nose.tools import assert_equal, assert_true 5 | 6 | from .base import BaseContainer 7 | from pyearth._basis import (HingeBasisFunction, SmoothedHingeBasisFunction, 8 | ConstantBasisFunction, LinearBasisFunction, Basis) 9 | 10 | 11 | class Container(BaseContainer): 12 | 13 | def __init__(self): 14 | super(Container, self).__init__() 15 | self.basis = Basis(self.X.shape[1]) 16 | self.parent = ConstantBasisFunction() 17 | self.bf1 = HingeBasisFunction(self.parent, 1.0, 10, 1, False) 18 | self.bf2 = HingeBasisFunction(self.parent, 1.0, 4, 2, True) 19 | self.bf3 = HingeBasisFunction(self.bf2, 1.0, 4, 3, True) 20 | self.bf4 = LinearBasisFunction(self.parent, 2) 21 | self.bf5 = HingeBasisFunction(self.parent, 1.5, 8, 2, True) 22 | self.basis.append(self.parent) 23 | self.basis.append(self.bf1) 24 | self.basis.append(self.bf2) 25 | self.basis.append(self.bf3) 26 | self.basis.append(self.bf4) 27 | self.basis.append(self.bf5) 28 | 29 | 30 | def test_anova_decomp(): 31 | cnt = Container() 32 | anova = cnt.basis.anova_decomp() 33 | assert_equal(set(anova[frozenset([1])]), set([cnt.bf1])) 34 | assert_equal(set(anova[frozenset([2])]), set([cnt.bf2, cnt.bf4, 35 | cnt.bf5])) 36 | assert_equal(set(anova[frozenset([2, 3])]), set([cnt.bf3])) 37 | assert_equal(set(anova[frozenset()]), set([cnt.parent])) 38 | assert_equal(len(anova), 4) 39 | 40 | 41 | def test_smooth_knots(): 42 | cnt = Container() 43 | mins = [0.0, -1.0, 0.1, 0.2] 44 | maxes = [2.5, 3.5, 3.0, 2.0] 45 | knots = 
cnt.basis.smooth_knots(mins, maxes) 46 | assert_equal(knots[cnt.bf1], (0.0, 2.25)) 47 | assert_equal(knots[cnt.bf2], (0.55, 1.25)) 48 | assert_equal(knots[cnt.bf3], (0.6, 1.5)) 49 | assert_true(cnt.bf4 not in knots) 50 | assert_equal(knots[cnt.bf5], (1.25, 2.25)) 51 | 52 | 53 | def test_smooth(): 54 | cnt = Container() 55 | X = numpy.random.uniform(-2.0, 4.0, size=(20, 4)) 56 | smooth_basis = cnt.basis.smooth(X) 57 | for bf, smooth_bf in zip(cnt.basis, smooth_basis): 58 | if type(bf) is HingeBasisFunction: 59 | assert_true(type(smooth_bf) is SmoothedHingeBasisFunction) 60 | elif type(bf) is ConstantBasisFunction: 61 | assert_true(type(smooth_bf) is ConstantBasisFunction) 62 | elif type(bf) is LinearBasisFunction: 63 | assert_true(type(smooth_bf) is LinearBasisFunction) 64 | else: 65 | raise AssertionError('Basis function is of an unexpected type.') 66 | assert_true(type(smooth_bf) in {SmoothedHingeBasisFunction, 67 | ConstantBasisFunction, 68 | LinearBasisFunction}) 69 | if bf.has_knot(): 70 | assert_equal(bf.get_knot(), smooth_bf.get_knot()) 71 | 72 | 73 | def test_add(): 74 | cnt = Container() 75 | assert_equal(len(cnt.basis), 6) 76 | 77 | 78 | def test_pickle_compat(): 79 | cnt = Container() 80 | basis_copy = pickle.loads(pickle.dumps(cnt.basis)) 81 | assert_true(cnt.basis == basis_copy) 82 | -------------------------------------------------------------------------------- /pyearth/test/basis/test_constant.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy 3 | 4 | from nose.tools import assert_true, assert_false 5 | 6 | from .base import BaseContainer 7 | from pyearth._types import BOOL 8 | from pyearth._basis import ConstantBasisFunction 9 | 10 | 11 | class Container(BaseContainer): 12 | 13 | def __init__(self): 14 | super(Container, self).__init__() 15 | self.bf = ConstantBasisFunction() 16 | 17 | 18 | def test_apply(): 19 | cnt = Container() 20 | m, _ = cnt.X.shape 21 | missing = 
numpy.zeros_like(cnt.X, dtype=BOOL) 22 | B = numpy.empty(shape=(m, 10)) 23 | assert_false(numpy.all(B[:, 0] == 1)) 24 | cnt.bf.apply(cnt.X, missing, B[:, 0]) 25 | assert_true(numpy.all(B[:, 0] == 1)) 26 | 27 | 28 | def test_deriv(): 29 | cnt = Container() 30 | m, _ = cnt.X.shape 31 | missing = numpy.zeros_like(cnt.X, dtype=BOOL) 32 | b = numpy.empty(shape=m) 33 | j = numpy.empty(shape=m) 34 | cnt.bf.apply_deriv(cnt.X, missing, b, j, 1) 35 | assert_true(numpy.all(b == 1)) 36 | assert_true(numpy.all(j == 0)) 37 | 38 | 39 | def test_pickle_compatibility(): 40 | cnt = Container() 41 | bf_copy = pickle.loads(pickle.dumps(cnt.bf)) 42 | assert_true(cnt.bf == bf_copy) 43 | 44 | 45 | def test_smoothed_version(): 46 | cnt = Container() 47 | smoothed = cnt.bf._smoothed_version(None, {}, {}) 48 | assert_true(type(smoothed) is ConstantBasisFunction) 49 | -------------------------------------------------------------------------------- /pyearth/test/basis/test_hinge.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy 3 | 4 | from nose.tools import assert_equal, assert_true 5 | 6 | from .base import BaseContainer 7 | from pyearth._types import BOOL 8 | from pyearth._basis import (HingeBasisFunction, SmoothedHingeBasisFunction, 9 | ConstantBasisFunction) 10 | 11 | 12 | class Container(BaseContainer): 13 | 14 | def __init__(self): 15 | super(Container, self).__init__() 16 | self.parent = ConstantBasisFunction() 17 | self.bf = HingeBasisFunction(self.parent, 1.0, 10, 1, False) 18 | 19 | 20 | def test_getters(): 21 | cnt = Container() 22 | assert not cnt.bf.get_reverse() 23 | assert cnt.bf.get_knot() == 1.0 24 | assert cnt.bf.get_variable() == 1 25 | assert cnt.bf.get_knot_idx() == 10 26 | assert cnt.bf.get_parent() == cnt.parent 27 | 28 | 29 | def test_apply(): 30 | cnt = Container() 31 | m, _ = cnt.X.shape 32 | missing = numpy.zeros_like(cnt.X, dtype=BOOL) 33 | B = numpy.ones(shape=(m, 10)) 34 | cnt.bf.apply(cnt.X, 
missing, B[:, 0]) 35 | numpy.testing.assert_almost_equal( 36 | B[:, 0], 37 | (cnt.X[:, 1] - 1.0) * (cnt.X[:, 1] > 1.0) 38 | ) 39 | 40 | 41 | def test_apply_deriv(): 42 | cnt = Container() 43 | m, _ = cnt.X.shape 44 | missing = numpy.zeros_like(cnt.X, dtype=BOOL) 45 | b = numpy.empty(shape=m) 46 | j = numpy.empty(shape=m) 47 | cnt.bf.apply_deriv(cnt.X, missing, b, j, 1) 48 | numpy.testing.assert_almost_equal( 49 | (cnt.X[:, 1] - 1.0) * (cnt.X[:, 1] > 1.0), 50 | b 51 | ) 52 | numpy.testing.assert_almost_equal(1.0 * (cnt.X[:, 1] > 1.0), j) 53 | 54 | 55 | def test_degree(): 56 | cnt = Container() 57 | assert_equal(cnt.bf.degree(), 1) 58 | 59 | 60 | def test_pickle_compatibility(): 61 | cnt = Container() 62 | bf_copy = pickle.loads(pickle.dumps(cnt.bf)) 63 | assert_true(cnt.bf == bf_copy) 64 | 65 | 66 | def test_smoothed_version(): 67 | cnt = Container() 68 | knot_dict = {cnt.bf: (.5, 1.5)} 69 | translation = {cnt.parent: cnt.parent._smoothed_version(None, {}, {})} 70 | smoothed = cnt.bf._smoothed_version(cnt.parent, knot_dict, 71 | translation) 72 | 73 | assert_true(type(smoothed) is SmoothedHingeBasisFunction) 74 | assert_true(translation[cnt.parent] is smoothed.get_parent()) 75 | assert_equal(smoothed.get_knot_minus(), 0.5) 76 | assert_equal(smoothed.get_knot_plus(), 1.5) 77 | assert_equal(smoothed.get_knot(), cnt.bf.get_knot()) 78 | assert_equal(smoothed.get_variable(), cnt.bf.get_variable()) 79 | -------------------------------------------------------------------------------- /pyearth/test/basis/test_linear.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy 3 | 4 | from nose.tools import assert_equal, assert_true 5 | 6 | from .base import BaseContainer 7 | from pyearth._types import BOOL 8 | from pyearth._basis import LinearBasisFunction, ConstantBasisFunction 9 | 10 | 11 | class Container(BaseContainer): 12 | 13 | def __init__(self): 14 | super(Container, self).__init__() 15 | self.parent = 
ConstantBasisFunction() 16 | self.bf = LinearBasisFunction(self.parent, 1) 17 | 18 | 19 | def test_apply(): 20 | cnt = Container() 21 | m, n = cnt.X.shape 22 | missing = numpy.zeros_like(cnt.X, dtype=BOOL) 23 | B = numpy.ones(shape=(m, 10)) 24 | cnt.bf.apply(cnt.X, missing, B[:, 0]) 25 | assert_true(numpy.all(B[:, 0] == cnt.X[:, 1])) 26 | 27 | 28 | def test_apply_deriv(): 29 | cnt = Container() 30 | m, _ = cnt.X.shape 31 | missing = numpy.zeros_like(cnt.X, dtype=BOOL) 32 | b = numpy.empty(shape=m) 33 | j = numpy.empty(shape=m) 34 | cnt.bf.apply_deriv(cnt.X, missing, b, j, 1) 35 | numpy.testing.assert_almost_equal(b, cnt.X[:, 1]) 36 | numpy.testing.assert_almost_equal(j, 1.0) 37 | 38 | 39 | def test_degree(): 40 | cnt = Container() 41 | assert_equal(cnt.bf.degree(), 1) 42 | 43 | 44 | def test_pickle_compatibility(): 45 | cnt = Container() 46 | bf_copy = pickle.loads(pickle.dumps(cnt.bf)) 47 | assert_true(cnt.bf == bf_copy) 48 | 49 | 50 | def test_smoothed_version(): 51 | cnt = Container() 52 | translation = {cnt.parent: cnt.parent._smoothed_version(None, {}, {})} 53 | smoothed = cnt.bf._smoothed_version(cnt.parent, {}, translation) 54 | assert_true(isinstance(smoothed, LinearBasisFunction)) 55 | assert_equal(smoothed.get_variable(), cnt.bf.get_variable()) 56 | -------------------------------------------------------------------------------- /pyearth/test/basis/test_missingness.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy 3 | 4 | from nose.tools import assert_equal, assert_true 5 | 6 | from .base import BaseContainer 7 | from pyearth._types import BOOL 8 | from pyearth._basis import ( 9 | HingeBasisFunction, ConstantBasisFunction, MissingnessBasisFunction) 10 | 11 | 12 | class Container(BaseContainer): 13 | 14 | def __init__(self): 15 | super(Container, self).__init__() 16 | self.parent = ConstantBasisFunction() 17 | self.bf = MissingnessBasisFunction(self.parent, 1, True) 18 | self.child = 
HingeBasisFunction(self.bf, 1.0, 10, 1, False) 19 | 20 | # 21 | # def test_getters(): 22 | # cnt = Container() 23 | # assert not cnt.bf.get_reverse() 24 | # assert cnt.bf.get_knot() == 1.0 25 | # assert cnt.bf.get_variable() == 1 26 | # assert cnt.bf.get_knot_idx() == 10 27 | # assert cnt.bf.get_parent() == cnt.parent 28 | 29 | 30 | def test_apply(): 31 | cnt = Container() 32 | m, _ = cnt.X.shape 33 | missing = numpy.zeros_like(cnt.X, dtype=BOOL) 34 | missing[1, 1] = True 35 | B = numpy.ones(shape=(m, 10)) 36 | X = cnt.X.copy() 37 | X[1, 1] = None 38 | cnt.bf.apply(cnt.X, missing, B[:, 0]) 39 | numpy.testing.assert_almost_equal( 40 | B[:, 0], 41 | 1 - missing[:, 1] 42 | ) 43 | cnt.child.apply(cnt.X, missing, B[:, 0]) 44 | expected = (cnt.X[:, 1] - 1.0) * (cnt.X[:, 1] > 1.0) 45 | expected[1] = 0.0 46 | numpy.testing.assert_almost_equal( 47 | B[:, 0], 48 | expected 49 | ) 50 | 51 | 52 | # def test_apply_deriv(): 53 | # cnt = Container() 54 | # m, _ = cnt.X.shape 55 | # missing = numpy.zeros_like(cnt.X, dtype=BOOL) 56 | # b = numpy.empty(shape=m) 57 | # j = numpy.empty(shape=m) 58 | # cnt.bf.apply_deriv(cnt.X, missing, b, j, 1) 59 | # numpy.testing.assert_almost_equal( 60 | # (cnt.X[:, 1] - 1.0) * (cnt.X[:, 1] > 1.0), 61 | # b 62 | # ) 63 | # numpy.testing.assert_almost_equal(1.0 * (cnt.X[:, 1] > 1.0), j) 64 | 65 | 66 | def test_degree(): 67 | cnt = Container() 68 | assert_equal(cnt.bf.degree(), 1) 69 | 70 | 71 | def test_pickle_compatibility(): 72 | cnt = Container() 73 | bf_copy = pickle.loads(pickle.dumps(cnt.bf)) 74 | assert_true(cnt.bf == bf_copy) 75 | 76 | # 77 | # def test_smoothed_version(): 78 | # cnt = Container() 79 | # knot_dict = {cnt.bf: (.5, 1.5)} 80 | # translation = {cnt.parent: cnt.parent._smoothed_version(None, {}, {})} 81 | # smoothed = cnt.bf._smoothed_version(cnt.parent, knot_dict, 82 | # translation) 83 | # 84 | # assert_true(type(smoothed) is SmoothedHingeBasisFunction) 85 | # assert_true(translation[cnt.parent] is smoothed.get_parent()) 86 | # 
assert_equal(smoothed.get_knot_minus(), 0.5) 87 | # assert_equal(smoothed.get_knot_plus(), 1.5) 88 | # assert_equal(smoothed.get_knot(), cnt.bf.get_knot()) 89 | # assert_equal(smoothed.get_variable(), cnt.bf.get_variable()) 90 | -------------------------------------------------------------------------------- /pyearth/test/basis/test_smoothed_hinge.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy 3 | 4 | from nose.tools import assert_equal 5 | 6 | from .base import BaseContainer 7 | from pyearth._types import BOOL 8 | from pyearth._basis import SmoothedHingeBasisFunction, ConstantBasisFunction 9 | 10 | 11 | class Container(BaseContainer): 12 | 13 | def __init__(self): 14 | super(Container, self).__init__() 15 | self.parent = ConstantBasisFunction() 16 | self.bf1 = SmoothedHingeBasisFunction(self.parent, 17 | 1.0, 0.0, 3.0, 10, 1, 18 | False) 19 | self.bf2 = SmoothedHingeBasisFunction(self.parent, 20 | 1.0, 0.0, 3.0, 10, 1, 21 | True) 22 | 23 | 24 | def test_getters(): 25 | cnt = Container() 26 | assert not cnt.bf1.get_reverse() 27 | assert cnt.bf2.get_reverse() 28 | assert cnt.bf1.get_knot() == 1.0 29 | assert cnt.bf1.get_variable() == 1 30 | assert cnt.bf1.get_knot_idx() == 10 31 | assert cnt.bf1.get_parent() == cnt.parent 32 | assert cnt.bf1.get_knot_minus() == 0.0 33 | assert cnt.bf1.get_knot_plus() == 3.0 34 | 35 | 36 | def test_pickle_compatibility(): 37 | cnt = Container() 38 | bf_copy = pickle.loads(pickle.dumps(cnt.bf1)) 39 | assert_equal(cnt.bf1, bf_copy) 40 | 41 | 42 | def test_smoothed_version(): 43 | cnt = Container() 44 | translation = {cnt.parent: cnt.parent._smoothed_version(None, {}, {})} 45 | smoothed = cnt.bf1._smoothed_version(cnt.parent, {}, translation) 46 | assert_equal(cnt.bf1, smoothed) 47 | 48 | 49 | def test_degree(): 50 | cnt = Container() 51 | assert_equal(cnt.bf1.degree(), 1) 52 | assert_equal(cnt.bf2.degree(), 1) 53 | 54 | 55 | def test_p_r(): 56 | cnt = Container() 
57 | pplus = (2 * 3.0 + 0.0 - 3 * 1.0) / ((3.0 - 0.0)**2) 58 | rplus = (2 * 1.0 - 3.0 - 0.0) / ((3.0 - 0.0)**3) 59 | pminus = (3 * 1.0 - 2 * 0.0 - 3.0) / ((0.0 - 3.0)**2) 60 | rminus = (0.0 + 3.0 - 2 * 1.0) / ((0.0 - 3.0)**3) 61 | assert_equal(cnt.bf1.get_p(), pplus) 62 | assert_equal(cnt.bf1.get_r(), rplus) 63 | assert_equal(cnt.bf2.get_p(), pminus) 64 | assert_equal(cnt.bf2.get_r(), rminus) 65 | 66 | 67 | def test_apply(): 68 | cnt = Container() 69 | m, _ = cnt.X.shape 70 | missing = numpy.zeros_like(cnt.X, dtype=BOOL) 71 | B = numpy.ones(shape=(m, 10)) 72 | cnt.bf1.apply(cnt.X, missing, B[:, 0]) 73 | cnt.bf2.apply(cnt.X, missing, B[:, 1]) 74 | pplus = (2 * 3.0 + 0.0 - 3 * 1.0) / ((3.0 - 0.0)**2) 75 | rplus = (2 * 1.0 - 3.0 - 0.0) / ((3.0 - 0.0)**3) 76 | pminus = (3 * 1.0 - 2 * 0.0 - 3.0) / ((0.0 - 3.0)**2) 77 | rminus = (0.0 + 3.0 - 2 * 1.0) / ((0.0 - 3.0)**3) 78 | c1 = numpy.ones(m) 79 | c1[cnt.X[:, 1] <= 0.0] = 0.0 80 | c1[(cnt.X[:, 1] > 0.0) & (cnt.X[:, 1] < 3.0)] = ( 81 | pplus * ((cnt.X[(cnt.X[:, 1] > 0.0) & ( 82 | cnt.X[:, 1] < 3.0), 1] - 0.0)**2) + 83 | rplus * ((cnt.X[(cnt.X[:, 1] > 0.0) & ( 84 | cnt.X[:, 1] < 3.0), 1] - 0.0)**3)) 85 | c1[cnt.X[:, 1] >= 3.0] = cnt.X[cnt.X[:, 1] >= 3.0, 1] - 1.0 86 | c2 = numpy.ones(m) 87 | c2[cnt.X[:, 1] >= 3.0] = 0.0 88 | c2.flat[cnt.X[:, 1] <= 0.0] = -1 * (cnt.X[cnt.X[:, 1] <= 0.0] - 1.0) 89 | c2[(cnt.X[:, 1] > 0.0) & (cnt.X[:, 1] < 3.0)] = ( 90 | pminus * ((cnt.X[(cnt.X[:, 1] > 0.0) & 91 | (cnt.X[:, 1] < 3.0), 1] - 3.0)**2) + 92 | rminus * ((cnt.X[(cnt.X[:, 1] > 0.0) & 93 | (cnt.X[:, 1] < 3.0), 1] - 3.0)**3) 94 | ) 95 | numpy.testing.assert_almost_equal(B[:, 0], c1) 96 | numpy.testing.assert_almost_equal(B[:, 1], c2) 97 | 98 | 99 | def test_apply_deriv(): 100 | cnt = Container() 101 | m, _ = cnt.X.shape 102 | missing = numpy.zeros_like(cnt.X, dtype=BOOL) 103 | pplus = (2 * 3.0 + 0.0 - 3 * 1.0) / ((3.0 - 0.0)**2) 104 | rplus = (2 * 1.0 - 3.0 - 0.0) / ((3.0 - 0.0)**3) 105 | pminus = (3 * 1.0 - 2 * 0.0 - 3.0) / ((0.0 - 
3.0)**2) 106 | rminus = (0.0 + 3.0 - 2 * 1.0) / ((0.0 - 3.0)**3) 107 | c1 = numpy.ones(m) 108 | c1[cnt.X[:, 1] <= 0.0] = 0.0 109 | c1[(cnt.X[:, 1] > 0.0) & (cnt.X[:, 1] < 3.0)] = ( 110 | pplus * ((cnt.X[(cnt.X[:, 1] > 0.0) & 111 | (cnt.X[:, 1] < 3.0), 1] - 0.0)**2) + 112 | rplus * ((cnt.X[(cnt.X[:, 1] > 0.0) & 113 | (cnt.X[:, 1] < 3.0), 1] - 0.0)**3)) 114 | c1[cnt.X[:, 1] >= 3.0] = cnt.X[cnt.X[:, 1] >= 3.0, 1] - 1.0 115 | c2 = numpy.ones(m) 116 | c2[cnt.X[:, 1] >= 3.0] = 0.0 117 | c2.flat[cnt.X[:, 1] <= 0.0] = -1 * (cnt.X[cnt.X[:, 1] <= 0.0] - 1.0) 118 | c2[(cnt.X[:, 1] > 0.0) & (cnt.X[:, 1] < 3.0)] = ( 119 | pminus * ((cnt.X[(cnt.X[:, 1] > 0.0) & 120 | (cnt.X[:, 1] < 3.0), 1] - 3.0)**2) + 121 | rminus * ((cnt.X[(cnt.X[:, 1] > 0.0) & 122 | (cnt.X[:, 1] < 3.0), 1] - 3.0)**3) 123 | ) 124 | b1 = numpy.empty(shape=m) 125 | j1 = numpy.empty(shape=m) 126 | b2 = numpy.empty(shape=m) 127 | j2 = numpy.empty(shape=m) 128 | cp1 = numpy.ones(m) 129 | cp1[cnt.X[:, 1] <= 0.0] = 0.0 130 | cp1[(cnt.X[:, 1] > 0.0) & (cnt.X[:, 1] < 3.0)] = ( 131 | 2.0 * pplus * ((cnt.X[(cnt.X[:, 1] > 0.0) & 132 | (cnt.X[:, 1] < 3.0), 1] - 0.0)) + 133 | 3.0 * rplus * ((cnt.X[(cnt.X[:, 1] > 0.0) & 134 | (cnt.X[:, 1] < 3.0), 1] - 0.0)**2) 135 | ) 136 | cp1[cnt.X[:, 1] >= 3.0] = 1.0 137 | cp2 = numpy.ones(m) 138 | cp2[cnt.X[:, 1] >= 3.0] = 0.0 139 | cp2[cnt.X[:, 1] <= 0.0] = -1.0 140 | cp2[(cnt.X[:, 1] > 0.0) & (cnt.X[:, 1] < 3.0)] = ( 141 | 2.0 * pminus * ((cnt.X[(cnt.X[:, 1] > 0.0) & 142 | (cnt.X[:, 1] < 3.0), 1] - 3.0)) + 143 | 3.0 * rminus * ((cnt.X[(cnt.X[:, 1] > 0.0) & 144 | (cnt.X[:, 1] < 3.0), 1] - 3.0)**2)) 145 | cnt.bf1.apply_deriv(cnt.X, missing, b1, j1, 1) 146 | cnt.bf2.apply_deriv(cnt.X, missing, b2, j2, 1) 147 | numpy.testing.assert_almost_equal(b1, c1) 148 | numpy.testing.assert_almost_equal(b2, c2) 149 | numpy.testing.assert_almost_equal(j1, cp1) 150 | numpy.testing.assert_almost_equal(j2, cp2) 151 | -------------------------------------------------------------------------------- 
/pyearth/test/earth_linvars_regress.txt: -------------------------------------------------------------------------------- 1 | 0.151785846961 -------------------------------------------------------------------------------- /pyearth/test/earth_regress.txt: -------------------------------------------------------------------------------- 1 | 0.219175992389 -------------------------------------------------------------------------------- /pyearth/test/earth_regress_missing_data.txt: -------------------------------------------------------------------------------- 1 | 0.178314473548 -------------------------------------------------------------------------------- /pyearth/test/earth_regress_smooth.txt: -------------------------------------------------------------------------------- 1 | 0.209074299412 -------------------------------------------------------------------------------- /pyearth/test/forward_regress.txt: -------------------------------------------------------------------------------- 1 | (Intercept) 2 | h(x1-0.906045) 3 | h(0.906045-x1) 4 | h(x2-1.03441) 5 | h(1.03441-x2) 6 | h(x8-1.38526) 7 | h(1.38526-x8) 8 | h(x2+0.837678) 9 | h(-0.837678-x2) 10 | x5 11 | x6 12 | h(x6+0.470033) 13 | h(-0.470033-x6) 14 | h(x9-0.527004) 15 | h(0.527004-x9) 16 | h(x6-0.86352) 17 | h(0.86352-x6) 18 | h(x6-0.67057) 19 | h(0.67057-x6) 20 | h(x6-0.950088) 21 | h(0.950088-x6) 22 | x7 23 | h(x5-1.09075) 24 | h(1.09075-x5) 25 | h(x5-0.821406) 26 | h(0.821406-x5) 27 | h(x5+1.10291) 28 | h(-1.10291-x5) 29 | h(x5+0.649338) 30 | h(-0.649338-x5) 31 | h(x1+0.144567) 32 | h(-0.144567-x1) 33 | h(x5+0.88778) 34 | h(-0.88778-x5) 35 | h(x5+0.46072) 36 | h(-0.46072-x5) 37 | h(x6-0.49869) 38 | h(0.49869-x6) 39 | h(x5-0.314817) 40 | h(0.314817-x5) 41 | h(x5-0.449712) 42 | h(0.449712-x5) 43 | h(x5-0.160928) 44 | h(0.160928-x5) 45 | h(x5-0.403265) 46 | h(0.403265-x5) 47 | h(x5-0.681595) 48 | h(0.681595-x5) 49 | h(x6+0.198399) 50 | h(-0.198399-x6) 51 | h(x7+1.47184) 52 | h(-1.47184-x7) 53 | h(x7-1.04797) 
54 | h(1.04797-x7) 55 | x3 56 | h(x7+0.187184) 57 | h(-0.187184-x7) 58 | 59 | ------------------------------------------------------------------ 60 | iter parent var knot mse terms gcv rsq grsq 61 | ------------------------------------------------------------------ 62 | 0 - - - 1.373069 1 1.401 0.000 0.000 63 | 1 0 1 11 0.815066 3 0.884 0.406 0.369 64 | 2 0 2 64 0.714179 5 0.826 0.480 0.411 65 | 3 0 8 98 0.684218 7 0.845 0.502 0.397 66 | 4 0 2 82 0.645618 9 0.853 0.530 0.391 67 | 5 0 5 -1 0.639630 10 0.875 0.534 0.375 68 | 6 0 6 -1 0.632280 11 0.896 0.540 0.360 69 | 7 0 6 59 0.616184 13 0.939 0.551 0.330 70 | 8 0 9 56 0.591386 15 0.972 0.569 0.306 71 | 9 0 6 95 0.585666 17 1.041 0.573 0.257 72 | 10 0 6 75 0.573308 19 1.106 0.582 0.211 73 | 11 0 6 0 0.561442 21 1.179 0.591 0.158 74 | 12 0 7 -1 0.556672 22 1.222 0.595 0.128 75 | 13 0 5 69 0.553707 24 1.331 0.597 0.050 76 | 14 0 5 75 0.532222 26 1.407 0.612 -0.004 77 | 15 0 5 76 0.529160 28 1.546 0.615 -0.104 78 | 16 0 5 66 0.520594 30 1.690 0.621 -0.206 79 | 17 0 1 98 0.489226 32 1.775 0.644 -0.267 80 | 18 0 5 88 0.469607 34 1.917 0.658 -0.368 81 | 19 0 5 30 0.464368 36 2.148 0.662 -0.533 82 | 20 0 6 51 0.461568 38 2.439 0.664 -0.741 83 | 21 0 5 43 0.458250 40 2.794 0.666 -0.994 84 | 22 0 5 65 0.454438 42 3.232 0.669 -1.307 85 | 23 0 5 36 0.440007 44 3.697 0.680 -1.639 86 | 24 0 5 63 0.435169 46 4.386 0.683 -2.131 87 | 25 0 5 17 0.428080 48 5.270 0.688 -2.762 88 | 26 0 6 99 0.421638 50 6.484 0.693 -3.628 89 | 27 0 7 38 0.417944 52 8.256 0.696 -4.893 90 | 28 0 7 51 0.414224 54 10.893 0.698 -6.776 91 | 29 0 3 -1 0.411539 55 12.702 0.700 -8.067 92 | 30 0 7 2 0.408069 57 18.136 0.703 -11.946 93 | ------------------------------------------------------------------ -------------------------------------------------------------------------------- /pyearth/test/pathological_data/issue_44.csv: -------------------------------------------------------------------------------- 1 | x,y 2 | -3.7373737373737375,1.7991218849952073 3 | 
-0.50505050505050519,-0.91104074841082883 4 | 2.3232323232323235,1.7489750942002336 5 | -9.3939393939393945,0.93746398352941285 6 | -1.9191919191919189,0.37906368380023697 7 | -4.5454545454545459,2.0671407475807846 8 | -2.3232323232323235,2.8903231496524264 9 | -2.1212121212121207,0.95421401743510981 10 | 5.1515151515151523,-0.28835274505737712 11 | 6.5656565656565631,2.4349616731309003 12 | 8.3838383838383841,2.0015189624872649 13 | -4.3434343434343443,2.0839886262250777 14 | -5.1515151515151523,4.0035426047149203 15 | 4.9494949494949498,1.0301181262384298 16 | 10.0,2.0990912639096395 17 | 5.9595959595959584,1.4837216429432181 18 | -4.1414141414141419,3.1115350771447017 19 | 5.5555555555555554,-0.49432035860384549 20 | 9.7979797979797993,2.2570162633565838 21 | 2.9292929292929286,0.19922300098316129 22 | -4.7474747474747474,2.2411649234982853 23 | 8.1818181818181817,3.8319133058008834 24 | 0.10101010101010033,-1.1655342778993465 25 | -7.3737373737373728,1.5936860907276955 26 | 0.50505050505050519,2.7824955819904238 27 | 4.7474747474747474,0.37375342310356208 28 | -7.7777777777777777,1.3319391790079955 29 | 0.70707070707070763,2.7447709715653392 30 | 0.90909090909090795,1.2593630603714561 31 | -2.9292929292929286,1.4730041828646854 32 | 1.5151515151515156,1.9517355390216784 33 | 7.5757575757575744,2.5204906997403671 34 | 3.737373737373737,0.82513781953866294 35 | 1.7171717171717162,1.5521287384776057 36 | 3.9393939393939386,0.41909788678935839 37 | 8.787878787878789,0.83357749049159235 38 | 2.5252525252525242,0.97158776816516124 39 | 3.333333333333333,2.796727214906336 40 | 1.9191919191919189,1.5770991785524595 41 | 7.3737373737373728,2.812545924932186 42 | 8.9898989898989896,2.3399323489969777 43 | -8.1818181818181817,0.22251629284196264 44 | -6.3636363636363633,0.65508003936260206 45 | -2.5252525252525246,1.1781202228906005 46 | 7.7777777777777777,1.1361770487748726 47 | -6.9696969696969697,0.49658600871033542 48 | 7.171717171717173,1.2870810571346567 49 | 
-1.1111111111111107,-0.92633728764534862 50 | -1.5151515151515156,-0.30285092335205799 51 | -6.1616161616161618,2.2600448186850697 52 | 9.1919191919191903,2.0528078501883256 53 | 9.5959595959595987,0.36719227705824703 54 | -5.5555555555555554,1.5554954537126158 55 | 4.3434343434343443,-1.2471684340825608 56 | -0.10101010101010033,1.6055254760143964 57 | 6.9696969696969688,1.9399379619451884 58 | 4.545454545454545,0.86004045773468585 59 | -9.7979797979797958,2.8053907657933528 60 | -6.7676767676767682,-0.2256450948986943 61 | -3.9393939393939386,3.489343493164955 62 | -5.9595959595959593,2.5238295345305537 63 | -1.3131313131313131,0.23652412699589609 64 | -1.7171717171717178,-1.6686836205834381 65 | -9.1919191919191938,-0.55397857306368992 66 | -4.9494949494949498,1.0493231114754551 67 | -0.30303030303030282,-1.0878104111406399 68 | -9.5959595959595987,-0.55743666312264351 69 | -0.90909090909090995,1.8446458659377187 70 | 5.3535353535353529,0.59548081377866657 71 | -7.1717171717171713,-0.46557755831965858 72 | 0.30303030303030282,0.95943699775297564 73 | -5.7575757575757578,1.8185702673284407 74 | 1.1111111111111107,2.1419245861795222 75 | 2.7272727272727262,-1.3093855134880783 76 | -5.3535353535353538,1.6946451925769861 77 | 3.1313131313131315,0.32016804097713369 78 | -3.333333333333333,2.2084032577385755 79 | 9.3939393939393945,2.5342019954452182 80 | -7.9797979797979792,-0.48297953030589919 81 | 8.5858585858585847,1.7297261905989889 82 | -2.7272727272727275,-0.9646357159778064 83 | -8.3838383838383841,-0.38537725554562052 84 | 6.7676767676767682,3.3939003736405828 85 | 6.1616161616161627,-0.077379083453575012 86 | 6.3636363636363633,2.3902476491581792 87 | -8.5858585858585847,-0.19404572628702388 88 | 2.1212121212121207,1.3602855988564588 89 | -3.5353535353535355,0.93983489736420145 90 | -10.0,2.3799875617530448 91 | -8.7878787878787872,-0.52460655837426085 92 | -7.5757575757575761,-0.28301040599130989 93 | 5.7575757575757578,2.206172182722522 94 | 
4.1414141414141419,1.401220881354313 95 | 7.9797979797979792,1.5959670945663329 96 | -6.5656565656565657,-0.45038856153592982 97 | -3.1313131313131315,2.1499146748068823 98 | -0.70707070707070763,-0.86787179551703342 99 | 1.3131313131313131,0.1139238155075194 100 | -8.9898989898989896,0.19646334195105306 101 | 3.5353535353535346,1.1596893484229858 102 | -------------------------------------------------------------------------------- /pyearth/test/pathological_data/issue_44.txt: -------------------------------------------------------------------------------- 1 | Earth Model 2 | ------------------------------------- 3 | Basis Function Pruned Coefficient 4 | ------------------------------------- 5 | (Intercept) No 1.13135 6 | x Yes None 7 | ------------------------------------- 8 | MSE: 1.6432, GCV: 1.6765, RSQ: -0.0000, GRSQ: -0.0000 -------------------------------------------------------------------------------- /pyearth/test/pathological_data/issue_50.csv: -------------------------------------------------------------------------------- 1 | x,y 2 | 24.040128,-125.66847999999999 3 | 40.17048,-131.51711999999998 4 | 41.10344,-133.78176 5 | 41.972224,-133.5096 6 | 44.720591999999996,-133.60592 7 | 45.825248,-138.48239999999998 8 | 50.725696,-142.98031999999998 9 | -------------------------------------------------------------------------------- /pyearth/test/pathological_data/issue_50.txt: -------------------------------------------------------------------------------- 1 | Earth Model 2 | ------------------------------------- 3 | Basis Function Pruned Coefficient 4 | ------------------------------------- 5 | (Intercept) No -134.994 6 | h(x-44.7206) No -1.42112 7 | h(44.7206-x) No 0.461161 8 | h(x-40.1705) Yes None 9 | h(40.1705-x) Yes None 10 | ------------------------------------- 11 | MSE: 1191.2154, GCV: 6485.5061, RSQ: 0.9507, GRSQ: 0.8028 -------------------------------------------------------------------------------- 
/pyearth/test/pathological_data/issue_50_weight.csv: -------------------------------------------------------------------------------- 1 | sample_weight 2 | 1062.711799884992 3 | 871.4544949444627 4 | 839.5063272968255 5 | 954.8189050502953 6 | 827.1006200543013 7 | 1244.6183968904068 8 | 802.2168136248937 -------------------------------------------------------------------------------- /pyearth/test/pathological_data/readme.txt: -------------------------------------------------------------------------------- 1 | Pathological Data Sets 2 | 3 | The data sets contained in this folder have revealed bugs in the past. They now serve as 4 | regression tests to prevent the same bugs from arising again in the future. 5 | 6 | issue_44: 7 | This data set caused a segfault during fitting due to the use of a negative index in Cython 8 | with wraparound = False. 9 | 10 | issue_50: 11 | This data set exposed a bug that occurred when using the sample_weight parameter. The problem 12 | was that the apply methods of the BasisFunctions were not applying the weights, and the 13 | next_pair method of the ForwardPasser class assumed they were. Now next_pair applies the 14 | weights after calling apply. The same data set exposed issue 51, in which user-specified 15 | endspans were not used. This test case covers both issues. 
-------------------------------------------------------------------------------- /pyearth/test/record/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/py-earth/b209d1916f051dbea5b142af25425df2de469c5a/pyearth/test/record/__init__.py -------------------------------------------------------------------------------- /pyearth/test/record/test_forward_pass.py: -------------------------------------------------------------------------------- 1 | from pyearth._record import ForwardPassRecord, ForwardPassIteration 2 | from pyearth._util import gcv 3 | from ..testing_utils import assert_list_almost_equal 4 | 5 | 6 | num_samples = 1000 7 | num_variables = 10 8 | penalty = 3.0 9 | sst = 100.0 10 | varnames = ['x' + str(i) for i in range(num_variables)] 11 | record = ForwardPassRecord(num_samples, num_variables, 12 | penalty, sst, varnames) 13 | record.append(ForwardPassIteration(0, 3, 3, 63.0, 3)) 14 | record.append(ForwardPassIteration(0, 3, 14, 34.0, 5)) 15 | record.append(ForwardPassIteration(3, 6, 12, 18.0, 7)) 16 | mses = [sst, 63.0, 34.0, 18.0] 17 | sizes = [1, 3, 5, 7] 18 | 19 | 20 | def test_statistics(): 21 | mses = [record.mse(i) for i in range(len(record))] 22 | mses_ = [mses[i] for i in range(len(record))] 23 | gcvs = [record.gcv(i) for i in range(len(record))] 24 | gcvs_ = [gcv(mses[i], sizes[i], num_samples, penalty) 25 | for i in range(len(record))] 26 | rsqs = [record.rsq(i) for i in range(len(record))] 27 | rsqs_ = [1 - (mses[i] / sst) 28 | for i in range(len(record))] 29 | grsqs = [record.grsq(i) for i in range(len(record))] 30 | grsqs_ = [1 - (record.gcv(i) / gcv(sst, 1, num_samples, penalty)) 31 | for i in range(len(record))] 32 | assert_list_almost_equal(mses, mses_) 33 | assert_list_almost_equal(gcvs, gcvs_) 34 | assert_list_almost_equal(rsqs, rsqs_) 35 | assert_list_almost_equal(grsqs, grsqs_) 36 | 
-------------------------------------------------------------------------------- /pyearth/test/record/test_pruning_pass.py: -------------------------------------------------------------------------------- 1 | from pyearth._record import PruningPassRecord, PruningPassIteration 2 | from pyearth._util import gcv 3 | from ..testing_utils import assert_list_almost_equal 4 | 5 | 6 | num_samples = 1000 7 | num_variables = 10 8 | penalty = 3.0 9 | sst = 100.0 10 | record = PruningPassRecord(num_samples, num_variables, 11 | penalty, sst, 7, 18.0) 12 | record.append(PruningPassIteration(2, 6, 25.0)) 13 | record.append(PruningPassIteration(1, 5, 34.0)) 14 | record.append(PruningPassIteration(3, 4, 87.0)) 15 | mses = [18.0, 25.0, 34.0, 87.0] 16 | sizes = [7, 6, 5, 4] 17 | 18 | 19 | def test_statistics(): 20 | mses = [record.mse(i) for i in range(len(record))] 21 | mses_ = [mses[i] for i in range(len(record))] 22 | gcvs = [record.gcv(i) for i in range(len(record))] 23 | gcvs_ = [gcv(mses[i], sizes[i], num_samples, penalty) 24 | for i in range(len(record))] 25 | rsqs = [record.rsq(i) for i in range(len(record))] 26 | rsqs_ = [1 - (mses[i] / sst) 27 | for i in range(len(record))] 28 | grsqs = [record.grsq(i) for i in range(len(record))] 29 | grsqs_ = [1 - (record.gcv(i) / gcv(sst, 1, num_samples, penalty)) 30 | for i in range(len(record))] 31 | assert_list_almost_equal(mses, mses_) 32 | assert_list_almost_equal(gcvs, gcvs_) 33 | assert_list_almost_equal(rsqs, rsqs_) 34 | assert_list_almost_equal(grsqs, grsqs_) 35 | -------------------------------------------------------------------------------- /pyearth/test/test_data.csv: -------------------------------------------------------------------------------- 1 | ,x1,x2,x3,x4,y 2 | 0,43.599490214200372,168.01752905539175,0.54966247787870914,5.3532239261827685,1.1297150506376397 3 | 1,42.036780208748901,665.30797852674027,0.2046486340378425,7.192709663506637,1.2713369428242232 4 | 
2,29.965467367452312,561.56026190867692,0.6211338327692949,6.2914209427703902,1.4850975083257041 5 | 3,13.457994534493356,964.65939760630852,0.18443986564691528,8.8533514781667346,1.4953000345937923 6 | 4,85.397529263948883,933.06293121890087,0.846561485357468,1.7964547700906099,1.4631020455675685 7 | 5,50.524609012170394,232.31757947498625,0.42812232759738944,1.9653091566061256,1.1007702782292474 8 | 6,12.715997170127746,1100.523659094282,0.22601200060423587,2.0694568430998297,1.5197173280758225 9 | 7,22.03062070705597,697.14978338305832,0.46778748458230024,3.0174322626496535,1.5033445601444178 10 | 8,64.040672521491487,914.82020234450215,0.50523672001854913,4.8689265111855935,1.4331168743198677 11 | 9,79.363745444157701,1073.1748771988184,0.16229859850231387,8.0075234660715626,1.1432487611454192 12 | 10,96.455108008925521,942.49145512022062,0.88952006394614491,4.4161365267110968,1.4562484984016864 13 | 11,56.714412762770927,824.11483997236769,0.436747263026799,8.7655918499710026,1.414510544111284 14 | 12,53.560417349765629,1683.72388440954,0.54420816014810214,1.820949222750248,1.5124095787831815 15 | 13,36.634240167502043,1515.6370663391237,0.40627504304795081,1.2720236589484959,1.5113725053014444 16 | 14,24.717723899735333,235.35264229889736,0.99385201142127289,10.705803133771735,1.4655130013848647 17 | 15,80.02583511325868,1108.8091147978175,0.76495986045168152,2.6922544658417795,1.4767263034461056 18 | 16,29.302323181943979,981.79412176895062,0.3566242811227589,1.4567896524540005,1.4873012780322354 19 | 17,98.315344535721266,846.67353972446074,0.50400043937915262,4.2354131753469879,1.3443519921907898 20 | 18,25.974475274713704,757.69792530898212,0.83201689963450087,8.3674705628711301,1.5296175850152098 21 | 19,37.921056694163745,146.92919419462237,0.79740493903259357,3.6938879759630119,1.2577693248791448 22 | 20,58.268488852903452,167.40444477138499,0.66220201926832578,4.8752342587328146,1.0868547000197621 23 | 
21,49.707379873451259,803.46557407006787,0.35087190130164048,6.509779053244193,1.3962687049291784 24 | 22,97.29106898748249,309.89810937521673,0.31325852842971358,1.4179770981531006,0.78429139704278905 25 | 23,73.839975862090029,1199.7944712348387,0.21463574639995364,5.1675344019533709,1.2915519005124272 26 | 24,64.384193414819507,1206.278241921543,0.17047713348377069,9.8165223574854856,1.2673762639361186 27 | 25,77.800815980846821,344.49507542459531,0.86891662643207879,8.4877787827685438,1.3165122722672742 28 | 26,79.858565409908095,1013.287786221621,0.22083791506146189,10.184586465896006,1.2280106041472256 29 | 27,59.208457766637636,691.28771196765126,0.26377852928001977,10.13915477165132,1.2568334102951819 30 | 28,41.973546130682074,1008.1357886623166,0.60844215780504962,9.2624982843415786,1.5024742715099091 31 | 29,62.356318456213074,414.34567303306642,0.59125735265354584,5.8926616699485912,1.3215571295292234 32 | 30,54.790778009037666,1268.4203030312453,0.24581116390413071,2.8662714558648021,1.3968435575605214 33 | 31,11.058314779616108,573.3746240396913,0.010250039399590904,7.2935972282958215,0.48848512011765227 34 | 32,29.517230521974248,431.62028327847486,0.095288052260803724,3.8375580746525464,0.948291052430141 35 | 33,21.492438413069525,592.2340141355902,0.47140985746967901,6.4949682108840801,1.493964996060547 36 | 34,84.511311557470165,1740.5210094161948,0.048868086229351348,3.3211825143616394,0.78860895398958175 37 | 35,64.331142724823522,389.4400333896524,0.8701458933690015,3.1740242630502591,1.3831880046821949 38 | 36,74.175503888344892,1192.4564185889287,0.79888551150462961,1.312475622652616,1.4930896449101998 39 | 37,22.957402874434518,1276.7630450157669,0.087562509001122835,1.3058948319431931,1.368261016889893 40 | 38,35.713493369956076,1089.1481873897621,0.052222727384160228,1.6566366620637754,1.0101147923717795 41 | 39,4.3501186059715646,771.19312047180301,0.66842704601605774,2.9802711181616961,1.5623576569691324 42 | 
40,87.626645170415912,832.02671712936331,0.61964146370996465,3.9042674293498836,1.4024404901819352 43 | 41,61.525457346427629,1683.5894947019638,0.44801471652408964,3.0704983825331427,1.4894072650971031 44 | 42,42.536729190538644,851.9327531051822,0.50761748667196138,6.2573324981095793,1.4727508306732644 45 | 43,4.2429191882546125,394.26335861902976,0.45022717961737324,8.0797117061892081,1.5468981531588779 46 | 44,77.758790830553366,1395.1610647361344,0.50277875250922144,10.567475933072476,1.460393868002404 47 | 45,33.187371253987585,912.1873275789045,0.74837399934103632,9.5783951843951396,1.5222195816705415 48 | 46,41.446005513291148,1512.0986444671671,0.44457094175070921,8.1574708428635425,1.5092202298531809 49 | 47,0.84483635788197287,166.83346202488673,0.94045049354208854,2.0213794823119349,1.5654116764743224 50 | 48,66.196779043970764,588.40317224341504,0.20084234377544541,4.8833418211475914,1.0601903754929445 51 | 49,92.590159650849927,1057.2797254435752,0.91671461195788251,8.0226424081972496,1.4755551186558762 52 | 50,50.121645168708696,952.42141882934163,0.21882078632937274,1.0466352273504989,1.3347814583500877 53 | 51,41.783624400271577,1046.0107958571264,0.87656186842332706,7.7758484022790944,1.5252569521812418 54 | 52,84.233304610347119,1638.6668835657422,0.94111587363362226,9.1601915192439396,1.5162307488602251 55 | 53,13.139460547944381,691.9282666193011,0.2085024945255225,9.6738153715652349,1.4799705457454786 56 | 54,79.233021930625796,695.80262118445432,0.56914053205563542,9.0513154025570159,1.373325112134129 57 | 55,83.316124473397622,461.85405604761195,0.7306802005261156,9.0341416415944895,1.3287501970473858 58 | 56,4.3129036353362737,300.76807022306298,0.43508364985882764,3.208812218607441,1.537849657136791 59 | 57,88.062206616568076,1301.3655484690632,0.713007293546981,8.6985965699270196,1.4761731652796777 60 | 58,33.074702616940478,499.12785574456325,0.62634395920102637,5.1447924952538919,1.4653918027366943 61 | 
59,93.678107223418593,1175.192081886714,0.3869003430123682,9.5511966080419057,1.3676095804869888 62 | 60,38.079258371620803,416.9556655484522,0.78165942566608138,5.7220858741580143,1.4544863401972903 63 | 61,25.948040440427711,1256.7866259336176,0.98048507899877846,3.4625358896778984,1.5497421662642135 64 | 62,79.024380259404523,1371.0360436135713,0.12003979217741401,9.3890259455900065,1.1231452260421266 65 | 63,46.173890475653167,331.14371718490827,0.53634347309606856,3.9651275079462902,1.3164476345244571 66 | 64,17.57241932567306,255.10941472150532,0.19889909182949916,5.2958861905641381,1.2374029218521734 67 | 65,64.181176362883164,268.40879494821309,0.71573058777096521,2.0174481096338206,1.2483635075582606 68 | 66,17.559907090882188,1290.9659113904681,0.88526367441048914,10.831383610430724,1.5554324559039934 69 | 67,65.290695070121458,906.10727496154846,0.087787554884994456,6.5895079469727671,0.88349792606034916 70 | 68,18.115022940864645,966.00540791086814,0.60893014893439157,9.3321249012260417,1.540010225121146 71 | 69,32.375058225614652,1741.0966877926312,0.99495025996779318,9.2615081943979671,1.5521094914523137 72 | 70,69.41525454251834,1324.2073630319194,0.8752688098046626,10.515101483457324,1.5109773403955418 73 | 71,85.226292659755117,1172.7191296517713,0.073498384714436593,7.2338471429583864,0.79103662938713226 74 | 72,35.321389178551854,394.85070848243845,0.89316618656704339,9.2489048120667938,1.4709741466986235 75 | 73,52.663842385283331,882.76747093432482,0.48930128812712559,3.60720019719452,1.4494708771682183 76 | 74,40.263792313314738,839.01992696466721,0.015056134248370046,9.9290588307871577,0.30401221390383193 77 | 75,2.9149317847511735,734.83202953690954,0.09587331469277216,2.969643203500385,1.5294442092384655 78 | 76,94.242061239714317,732.18275455239154,0.72900080182147697,2.2616010862902147,1.3960352705744583 79 | 77,25.128292170720933,1565.274520614033,0.90329615048901546,3.87635258215888,1.5530259502567871 80 | 
78,96.711313151229632,1217.0896077711188,0.50300677353432655,4.7800240864855548,1.414118748214515 81 | 79,33.73682942964583,1680.8952215019413,0.69333382858922288,6.3410170007333724,1.5418562287011579 82 | 80,4.3489697936277434,530.06979251124858,0.94317072266720825,6.568601662432533,1.5620976675187246 83 | 81,35.751823235523204,141.1021036455474,0.25354826722492296,3.5422669015053678,0.78571091574943486 84 | 82,17.068552603402608,689.23268976152656,0.20264905134772382,1.6290154018084104,1.4491942594415783 85 | 83,9.7960395886258596,937.13175163872779,0.069660590412445433,10.351535633514883,1.4218482165198603 86 | 84,20.609763517058365,884.0140824788291,0.41466825710663691,3.9133771843298231,1.5146325379560737 87 | 85,32.352851482004418,471.33548120389071,0.28026828777019874,2.1710503569658961,1.3306112118673223 88 | 86,68.925187477754406,1178.5798517678015,0.31152940334298462,9.802989561574968,1.3852319082884261 89 | 87,38.637572509456085,1123.9564075595154,0.52316932617519518,4.2505680361012992,1.5051826484433741 90 | 88,92.979042306516263,1511.3094300825874,0.84112294801990795,10.542886365563277,1.4977834382425299 91 | 89,0.24824262680964715,929.60333613098521,0.93791311906460373,2.3442675982313901,1.5705116078990755 92 | 90,24.499413841416974,950.74709217600264,0.67842150982188265,4.3275995761694919,1.5328314042906843 93 | 91,30.636467159000723,1393.0290357968324,0.032853938228447621,8.8579786252172905,0.98089726746506534 94 | 92,42.85402814602093,1153.2312807768033,0.70550302490679973,1.9846862942141765,1.5181733452638007 95 | 93,90.384616487488657,1469.0165275332865,0.37538143725374984,2.9479432849585923,1.4083347621930655 96 | 94,6.280543805583938,629.8159444446452,0.71319196060205603,9.1318171617989297,1.5568149793762627 97 | 95,32.658830121723945,485.53547264576639,0.32739418364654083,10.643668363057484,1.3681645174548087 98 | 96,9.6090700177030914,390.61115652266483,0.69423911895276047,2.3976349904776564,1.5353763989585334 99 | 
97,26.662589262979463,1437.7544447801224,0.30061178385836629,6.9701654887125475,1.5091848395920116 100 | 98,57.281228776535677,559.29964147614896,0.24873428889263027,3.8967549066201626,1.1802018969155548 101 | 99,87.594007242403123,156.30754331038852,0.09163641211230289,4.4120995915425212,0.16207038454185166 102 | -------------------------------------------------------------------------------- /pyearth/test/test_export.py: -------------------------------------------------------------------------------- 1 | from pyearth._basis import (Basis, ConstantBasisFunction, HingeBasisFunction, 2 | LinearBasisFunction) 3 | from pyearth.export import export_python_function, export_python_string,\ 4 | export_sympy 5 | from nose.tools import assert_almost_equal 6 | import numpy 7 | import six 8 | from pyearth import Earth 9 | from pyearth._types import BOOL 10 | from pyearth.test.testing_utils import if_pandas,\ 11 | if_sympy 12 | from itertools import product 13 | from numpy.testing.utils import assert_array_almost_equal 14 | 15 | numpy.random.seed(0) 16 | 17 | basis = Basis(10) 18 | constant = ConstantBasisFunction() 19 | basis.append(constant) 20 | bf1 = HingeBasisFunction(constant, 0.1, 10, 1, False, 'x1') 21 | bf2 = HingeBasisFunction(constant, 0.1, 10, 1, True, 'x1') 22 | bf3 = LinearBasisFunction(bf1, 2, 'x2') 23 | basis.append(bf1) 24 | basis.append(bf2) 25 | basis.append(bf3) 26 | X = numpy.random.normal(size=(100, 10)) 27 | missing = numpy.zeros_like(X, dtype=BOOL) 28 | B = numpy.empty(shape=(100, 4), dtype=numpy.float64) 29 | basis.transform(X, missing, B) 30 | beta = numpy.random.normal(size=4) 31 | y = numpy.empty(shape=100, dtype=numpy.float64) 32 | y[:] = numpy.dot(B, beta) + numpy.random.normal(size=100) 33 | beta2 = numpy.random.normal(size=4) 34 | y2 = numpy.empty(shape=100, dtype=numpy.float64) 35 | y2[:] = numpy.dot(B, beta2) + numpy.random.normal(size=100) 36 | Y = numpy.concatenate([y[:, None], y2[:, None]], axis=1) 37 | default_params = {"penalty": 1} 38 | 
39 | 40 | def test_export_python_function(): 41 | for smooth in (True, False): 42 | model = Earth(penalty=1, smooth=smooth, max_degree=2).fit(X, y) 43 | export_model = export_python_function(model) 44 | for exp_pred, model_pred in zip(model.predict(X), export_model(X)): 45 | assert_almost_equal(exp_pred, model_pred) 46 | 47 | 48 | def test_export_python_string(): 49 | for smooth in (True, False): 50 | model = Earth(penalty=1, smooth=smooth, max_degree=2).fit(X, y) 51 | export_model = export_python_string(model, 'my_test_model') 52 | six.exec_(export_model, globals()) 53 | for exp_pred, model_pred in zip(model.predict(X), my_test_model(X)): 54 | assert_almost_equal(exp_pred, model_pred) 55 | 56 | @if_pandas 57 | @if_sympy 58 | def test_export_sympy(): 59 | import pandas as pd 60 | from sympy.utilities.lambdify import lambdify 61 | from sympy.printing.lambdarepr import NumPyPrinter 62 | 63 | class PyEarthNumpyPrinter(NumPyPrinter): 64 | def _print_Max(self, expr): 65 | return 'maximum(' + ','.join(self._print(i) for i in expr.args) + ')' 66 | 67 | def _print_NaNProtect(self, expr): 68 | return 'where(isnan(' + ','.join(self._print(a) for a in expr.args) + '), 0, ' \ 69 | + ','.join(self._print(a) for a in expr.args) + ')' 70 | 71 | def _print_Missing(self, expr): 72 | return 'isnan(' + ','.join(self._print(a) for a in expr.args) + ').astype(float)' 73 | 74 | for smooth, n_cols, allow_missing in product((True, False), (1, 2), (True, False)): 75 | X_df = pd.DataFrame(X.copy(), columns=['x_%d' % i for i in range(X.shape[1])]) 76 | y_df = pd.DataFrame(Y[:, :n_cols]) 77 | if allow_missing: 78 | # Randomly remove some values so that the fitted model contains MissingnessBasisFunctions 79 | X_df['x_1'][numpy.random.binomial(n=1, p=.1, size=X_df.shape[0]).astype(bool)] = numpy.nan 80 | 81 | model = Earth(allow_missing=allow_missing, smooth=smooth, max_degree=2).fit(X_df, y_df) 82 | expressions = export_sympy(model) if n_cols > 1 else [export_sympy(model)] 83 | module_dict = 
{'select': numpy.select, 'less_equal': numpy.less_equal, 'isnan': numpy.isnan, 84 | 'greater_equal':numpy.greater_equal, 'logical_and': numpy.logical_and, 'less': numpy.less, 85 | 'logical_not':numpy.logical_not, "greater": numpy.greater, 'maximum':numpy.maximum, 86 | 'Missing': lambda x: numpy.isnan(x).astype(float), 87 | 'NaNProtect': lambda x: numpy.where(numpy.isnan(x), 0, x), 'nan': numpy.nan, 88 | 'float': float, 'where': numpy.where 89 | } 90 | 91 | for i, expression in enumerate(expressions): 92 | # The lambdified functions for smoothed basis functions only work with modules='numpy' and 93 | # for regular basis functions with modules={'Max':numpy.maximum}. This is a confusing situation 94 | func = lambdify(X_df.columns, expression, printer=PyEarthNumpyPrinter, modules=module_dict) 95 | y_pred_sympy = func(*[X_df.loc[:,var] for var in X_df.columns]) 96 | 97 | y_pred = model.predict(X_df)[:,i] if n_cols > 1 else model.predict(X_df) 98 | assert_array_almost_equal(y_pred, y_pred_sympy) 99 | -------------------------------------------------------------------------------- /pyearth/test/test_forward.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Feb 16, 2013 3 | 4 | @author: jasonrudy 5 | ''' 6 | 7 | import os 8 | import numpy 9 | 10 | from nose.tools import assert_equal 11 | 12 | from pyearth._forward import ForwardPasser 13 | from pyearth._basis import (Basis, ConstantBasisFunction, 14 | HingeBasisFunction, LinearBasisFunction) 15 | from pyearth._types import BOOL 16 | 17 | numpy.random.seed(0) 18 | basis = Basis(10) 19 | constant = ConstantBasisFunction() 20 | basis.append(constant) 21 | bf1 = HingeBasisFunction(constant, 0.1, 10, 1, False, 'x1') 22 | bf2 = HingeBasisFunction(constant, 0.1, 10, 1, True, 'x1') 23 | bf3 = LinearBasisFunction(bf1, 2, 'x2') 24 | basis.append(bf1) 25 | basis.append(bf2) 26 | basis.append(bf3) 27 | X = numpy.random.normal(size=(100, 10)) 28 | missing = 
numpy.zeros_like(X).astype(BOOL) 29 | B = numpy.empty(shape=(100, 4), dtype=numpy.float64) 30 | basis.transform(X, missing, B) 31 | beta = numpy.random.normal(size=4) 32 | y = numpy.empty(shape=100, dtype=numpy.float64) 33 | y[:] = numpy.dot(B, beta) + numpy.random.normal(size=100) 34 | sample_weight = numpy.ones((X.shape[0], 1)) 35 | 36 | 37 | def test_run(): 38 | 39 | forwardPasser = ForwardPasser(X, missing, y[:, numpy.newaxis], 40 | sample_weight, 41 | max_terms=1000, penalty=1) 42 | 43 | forwardPasser.run() 44 | res = str(forwardPasser.get_basis()) + \ 45 | '\n' + str(forwardPasser.trace()) 46 | filename = os.path.join(os.path.dirname(__file__), 47 | 'forward_regress.txt') 48 | # with open(filename, 'w') as fl: 49 | # fl.write(res) 50 | with open(filename, 'r') as fl: 51 | prev = fl.read() 52 | assert_equal(res, prev) 53 | -------------------------------------------------------------------------------- /pyearth/test/test_knot_search.py: -------------------------------------------------------------------------------- 1 | from pyearth._knot_search import (MultipleOutcomeDependentData, 2 | KnotSearchWorkingData, 3 | PredictorDependentData, 4 | KnotSearchReadOnlyData, 5 | KnotSearchData, 6 | knot_search, 7 | SingleWeightDependentData, 8 | SingleOutcomeDependentData) 9 | from nose.tools import assert_equal 10 | import numpy as np 11 | from numpy.testing.utils import assert_almost_equal, assert_array_equal 12 | from scipy.linalg import qr 13 | 14 | 15 | def test_outcome_dependent_data(): 16 | np.random.seed(10) 17 | m = 1000 18 | max_terms = 100 19 | y = np.random.normal(size=m) 20 | w = np.random.normal(size=m) ** 2 21 | weight = SingleWeightDependentData.alloc(w, m, max_terms, 1e-16) 22 | data = SingleOutcomeDependentData.alloc(y, weight, m, max_terms) 23 | 24 | # Test updating 25 | B = np.empty(shape=(m, max_terms)) 26 | for k in range(max_terms): 27 | b = np.random.normal(size=m) 28 | B[:, k] = b 29 | code = weight.update_from_array(b) 30 | if k >= 99: 31 | 1 + 
1 32 | data.update() 33 | assert_equal(code, 0) 34 | assert_almost_equal( 35 | np.dot(weight.Q_t[:k + 1, :], np.transpose(weight.Q_t[:k + 1, :])), 36 | np.eye(k + 1)) 37 | assert_equal(weight.update_from_array(b), -1) 38 | # data.update(1e-16) 39 | 40 | # Test downdating 41 | q = np.array(weight.Q_t).copy() 42 | theta = np.array(data.theta[:max_terms]).copy() 43 | weight.downdate() 44 | data.downdate() 45 | weight.update_from_array(b) 46 | data.update() 47 | assert_almost_equal(q, np.array(weight.Q_t)) 48 | assert_almost_equal(theta, np.array(data.theta[:max_terms])) 49 | assert_almost_equal( 50 | np.array(data.theta[:max_terms]), np.dot(weight.Q_t, w * y)) 51 | wB = B * w[:, None] 52 | Q, _ = qr(wB, pivoting=False, mode='economic') 53 | assert_almost_equal(np.abs(np.dot(weight.Q_t, Q)), np.eye(max_terms)) 54 | 55 | # Test that reweighting works 56 | assert_equal(data.k, max_terms) 57 | w2 = np.random.normal(size=m) ** 2 58 | weight.reweight(w2, B, max_terms) 59 | data.synchronize() 60 | assert_equal(data.k, max_terms) 61 | w2B = B * w2[:, None] 62 | Q2, _ = qr(w2B, pivoting=False, mode='economic') 63 | assert_almost_equal(np.abs(np.dot(weight.Q_t, Q2)), np.eye(max_terms)) 64 | assert_almost_equal( 65 | np.array(data.theta[:max_terms]), np.dot(weight.Q_t, w2 * y)) 66 | 67 | 68 | def test_knot_candidates(): 69 | np.random.seed(10) 70 | m = 1000 71 | x = np.random.normal(size=m) 72 | p = np.random.normal(size=m) 73 | p[np.random.binomial(p=.1, n=1, size=m) == 1] = 0. 74 | x[np.random.binomial(p=.1, n=1, size=m) == 1] = 0. 
75 | predictor = PredictorDependentData.alloc(x) 76 | candidates, candidates_idx = predictor.knot_candidates( 77 | p, 5, 10, 0, 0, set()) 78 | assert_array_equal(candidates, x[candidates_idx]) 79 | assert_equal(len(candidates), len(set(candidates))) 80 | # print candidates, np.sum(x==0) 81 | # print candidates_idx 82 | 83 | 84 | def slow_knot_search(p, x, B, candidates, outcomes): 85 | # Brute force, utterly un-optimized knot search with no fast update. 86 | # Use only for testing the actual knot search function. 87 | # This version allows #for multiple outcome columns. 88 | best_e = float('inf') 89 | best_k = 0 90 | best_knot = float('inf') 91 | for k, knot in enumerate(candidates): 92 | # Formulate the linear system for this candidate 93 | X = np.concatenate( 94 | [B, (p * np.maximum(x - knot, 0.0))[:, None]], axis=1) 95 | 96 | # Solve the system for each y and w 97 | e_squared = 0.0 98 | for y, w in outcomes: 99 | # Solve the system 100 | beta = np.linalg.lstsq(w[:, None] * X, w * y)[0] 101 | 102 | # Compute the error 103 | r = w * (y - np.dot(X, beta)) 104 | e_squared += np.dot(r, r) 105 | # Compute loss 106 | e = e_squared # / np.sum(w ** 2) 107 | 108 | # Compare to the best error 109 | if e < best_e: 110 | best_e = e 111 | best_k = k 112 | best_knot = knot 113 | return best_knot, best_k, best_e 114 | 115 | 116 | def generate_problem(m, q, r, n_outcomes, shared_weight): 117 | # Generate some problem data 118 | x = np.random.normal(size=m) 119 | B = np.random.normal(size=(m, q)) 120 | p = B[:, 1] 121 | knot = x[int(m / 2)] 122 | candidates = np.array(sorted( 123 | [knot] + 124 | list(x[np.random.randint(low=0, high=m, size=r - 1)])))[::-1] 125 | 126 | # These data need to be generated for each outcome 127 | outcomes = [] 128 | if shared_weight: 129 | w = np.random.normal(size=m) ** 2 130 | # w = w * 0. + 1. 
131 | for _ in range(n_outcomes): 132 | beta = np.random.normal(size=q + 1) 133 | y = (np.dot( 134 | np.concatenate([B, (p * np.maximum(x - knot, 0.0))[:, None]], 135 | axis=1), 136 | beta) + 0.01 * np.random.normal(size=m)) 137 | if not shared_weight: 138 | w = np.random.normal(size=m) ** 2 139 | # w = w * 0. + 1. 140 | outcomes.append((y, w)) 141 | 142 | return x, B, p, knot, candidates, outcomes 143 | 144 | 145 | def form_inputs(x, B, p, knot, candidates, y, w): 146 | # Formulate the inputs for the fast version 147 | m, q = B.shape 148 | max_terms = q + 2 149 | workings = [] 150 | n_outcomes = w.shape[1] 151 | for _ in range(n_outcomes): 152 | working = KnotSearchWorkingData.alloc(max_terms) 153 | workings.append(working) 154 | outcome = MultipleOutcomeDependentData.alloc( 155 | y, w, m, n_outcomes, max_terms, 1e-16) 156 | for j in range(B.shape[1]): 157 | outcome.update_from_array(B[:, j]) 158 | predictor = PredictorDependentData.alloc(x) 159 | constant = KnotSearchReadOnlyData(predictor, outcome) 160 | return KnotSearchData(constant, workings, q) 161 | 162 | 163 | def test_knot_search(): 164 | seed = 10 165 | np.random.seed(seed) 166 | m = 100 167 | q = 5 168 | r = 10 169 | n_outcomes = 3 170 | 171 | # Generate some problem data 172 | x, B, p, knot, candidates, outcomes = generate_problem( 173 | m, q, r, n_outcomes, False) 174 | y = np.concatenate([y_[:, None] for y_, _ in outcomes], axis=1) 175 | w = np.concatenate([w_[:, None] for _, w_ in outcomes], axis=1) 176 | 177 | # Formulate the inputs for the fast version 178 | data = form_inputs(x, B, p, knot, candidates, y, w) 179 | 180 | # Get the answer using the slow version 181 | best_knot, best_k, best_e = slow_knot_search(p, x, B, candidates, outcomes) 182 | 183 | # Test the test 184 | assert_almost_equal(best_knot, knot) 185 | assert_equal(r, len(candidates)) 186 | assert_equal(m, B.shape[0]) 187 | assert_equal(q, B.shape[1]) 188 | assert_equal(len(outcomes), n_outcomes) 189 | 190 | # Run fast knot search 
and compare results to slow knot search 191 | fast_best_knot, fast_best_k, fast_best_e = knot_search(data, candidates, 192 | p, q, m, r, 193 | len(outcomes), 0) 194 | assert_almost_equal(fast_best_knot, best_knot) 195 | assert_equal(candidates[fast_best_k], candidates[best_k]) 196 | assert_almost_equal(fast_best_e, best_e) 197 | -------------------------------------------------------------------------------- /pyearth/test/test_pruning.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class Test(object): 4 | 5 | def __init__(self): 6 | pass 7 | 8 | def test(self): 9 | pass 10 | 11 | if __name__ == '__main__': 12 | import nose 13 | nose.run(argv=[__file__, '-s', '-v']) 14 | -------------------------------------------------------------------------------- /pyearth/test/test_qr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on Jan 28, 2016 3 | 4 | @author: jason 5 | ''' 6 | import numpy as np 7 | from pyearth._qr import UpdatingQT 8 | 9 | 10 | def test_updating_qt(): 11 | np.random.seed(0) 12 | m = 10 13 | n = 3 14 | 15 | X = np.random.normal(size=(n, m)).T 16 | u = UpdatingQT.alloc(m, n, 1e-14) 17 | 18 | # u2 = UpdatingQT(m, n) 19 | Q = np.linalg.qr(X, mode='reduced')[0] 20 | u.update(X[:, 0]) 21 | 22 | # u2.update(X[:,0]) 23 | u.update(X[:, 1]) 24 | # u2.update(X[:,1]) 25 | u.update(X[:, 2]) 26 | # u2.update(X[:,2]) 27 | 28 | # assert np.max(np.abs(np.abs(u2.Q_t) - np.abs(Q.T))) < .0000000000001 29 | assert np.max(np.abs(np.abs(u.Q_t) - np.abs(Q.T))) < .0000000000001 30 | 31 | X2 = X.copy() 32 | X2[:, 2] = np.random.normal(size=m) 33 | 34 | u.downdate() 35 | u.update(X2[:, 2]) 36 | 37 | Q2 = np.linalg.qr(X2, mode='reduced')[0] 38 | assert np.max(np.abs(np.abs(u.Q_t) - np.abs(Q2.T))) < .0000000000001 39 | 40 | 41 | def test_updating_qr_with_linear_dependence(): 42 | np.random.seed(0) 43 | m = 10 44 | n = 5 45 | assert n >= 3 46 | X = np.random.normal(size=(n, m)).T 
47 | X[:, 2] = X[:, 0] + 3 * X[:, 1] 48 | 49 | Q = np.linalg.qr(X, mode='reduced')[0] 50 | u = UpdatingQT.alloc(m, n, 1e-14) 51 | u2 = UpdatingQT.alloc(m, n, 1e-14) 52 | 53 | u.update(X[:, 0]) 54 | u2.update(X[:, 0]) 55 | # u2.update(X[:,0]) 56 | u.update(X[:, 1]) 57 | u2.update(X[:, 1]) 58 | # u2.update(X[:,1]) 59 | u.update(X[:, 2]) 60 | 61 | assert np.max( 62 | np.abs(np.abs(u.Q_t[:2, :]) - np.abs(Q[:, :2].T))) < .0000000000001 63 | assert np.max(np.abs(u.Q_t[2, :])) == 0. 64 | 65 | # Make sure you can downdate a dependent column safely 66 | u.downdate() 67 | u.update(X[:, 2]) 68 | assert np.max( 69 | np.abs(np.abs(u.Q_t[:2, :]) - np.abs(Q[:, :2].T))) < .0000000000001 70 | assert np.max(np.abs(u.Q_t[2, :])) == 0. 71 | 72 | for j in range(3, n): 73 | u.update(X[:, j]) 74 | u2.update(X[:, j]) 75 | 76 | # Q_t is orthonormal except for its zero column 77 | Q_nonzero = np.concatenate([u.Q_t[:2, :].T, u.Q_t[3:, :].T], axis=1) 78 | np.testing.assert_array_almost_equal( 79 | np.dot(Q_nonzero.T, Q_nonzero), np.eye(n - 1)) 80 | 81 | # Q_t.T is in the column space of X 82 | b = np.linalg.lstsq(X, u.Q_t.T)[0] 83 | Q_hat = np.dot(X, b) 84 | np.testing.assert_array_almost_equal(Q_hat, u.Q_t.T) 85 | 86 | # X is in the column space of Q_t.T 87 | a = np.linalg.lstsq(u.Q_t.T, X)[0] 88 | X_hat = np.dot(u.Q_t.T, a) 89 | np.testing.assert_array_almost_equal(X_hat, X) 90 | 91 | # u and u2 should have the same householder 92 | np.testing.assert_array_almost_equal( 93 | u.householder.V[:, :u.householder.k], 94 | u2.householder.V[:, :u2.householder.k]) 95 | np.testing.assert_array_almost_equal( 96 | u.householder.T[:u.householder.k, :u.householder.k], 97 | u2.householder.T[:u2.householder.k, :u2.householder.k]) 98 | 99 | # u should have one more column than u2 100 | assert u.k == u2.k + 1 101 | -------------------------------------------------------------------------------- /pyearth/test/test_util.py: -------------------------------------------------------------------------------- 1 | 2 
| 3 | class TestUtil(object): 4 | 5 | def __init__(self): 6 | pass 7 | 8 | def test(self): 9 | pass 10 | 11 | if __name__ == '__main__': 12 | import nose 13 | nose.run(argv=[__file__, '-s', '-v']) 14 | -------------------------------------------------------------------------------- /pyearth/test/testing_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import wraps 3 | from nose import SkipTest 4 | from nose.tools import assert_almost_equal 5 | from distutils.version import LooseVersion 6 | import sys 7 | 8 | def if_environ_has(var_name): 9 | # Test decorator that skips test if environment variable is not defined 10 | def if_environ(func): 11 | @wraps(func) 12 | def run_test(*args, **kwargs): 13 | if var_name in os.environ: 14 | return func(*args, **kwargs) 15 | else: 16 | raise SkipTest('Only run if %s environment variable is ' 17 | 'defined.' % var_name) 18 | return run_test 19 | return if_environ 20 | 21 | def if_platform_not_win_32(func): 22 | @wraps(func) 23 | def run_test(*args, **kwargs): 24 | if sys.platform == 'win32': 25 | raise SkipTest('Skip for 32 bit Windows platforms.') 26 | else: 27 | return func(*args, **kwargs) 28 | return run_test 29 | 30 | def if_sklearn_version_greater_than_or_equal_to(min_version): 31 | ''' 32 | Test decorator that skips test unless sklearn version is greater than or 33 | equal to min_version. 34 | ''' 35 | def _if_sklearn_version(func): 36 | @wraps(func) 37 | def run_test(*args, **kwargs): 38 | import sklearn 39 | if LooseVersion(sklearn.__version__) < LooseVersion(min_version): 40 | raise SkipTest('sklearn version less than %s' % 41 | str(min_version)) 42 | else: 43 | return func(*args, **kwargs) 44 | return run_test 45 | return _if_sklearn_version 46 | 47 | 48 | def if_statsmodels(func): 49 | """Test decorator that skips test if statsmodels not installed. 
""" 50 | 51 | @wraps(func) 52 | def run_test(*args, **kwargs): 53 | try: 54 | import statsmodels 55 | except ImportError: 56 | raise SkipTest('statsmodels not available.') 57 | else: 58 | return func(*args, **kwargs) 59 | return run_test 60 | 61 | 62 | def if_pandas(func): 63 | """Test decorator that skips test if pandas not installed. """ 64 | 65 | @wraps(func) 66 | def run_test(*args, **kwargs): 67 | try: 68 | import pandas 69 | except ImportError: 70 | raise SkipTest('pandas not available.') 71 | else: 72 | return func(*args, **kwargs) 73 | return run_test 74 | 75 | def if_sympy(func): 76 | """ Test decorator that skips test if sympy not installed """ 77 | 78 | @wraps(func) 79 | def run_test(*args, **kwargs): 80 | try: 81 | from sympy import Symbol, Add, Mul, Max, RealNumber, Piecewise, sympify, Pow, And, lambdify 82 | except ImportError: 83 | raise SkipTest('sympy not available.') 84 | else: 85 | return func(*args, **kwargs) 86 | return run_test 87 | 88 | 89 | 90 | def if_patsy(func): 91 | """Test decorator that skips test if patsy not installed. 
""" 92 | 93 | @wraps(func) 94 | def run_test(*args, **kwargs): 95 | try: 96 | import patsy 97 | except ImportError: 98 | raise SkipTest('patsy not available.') 99 | else: 100 | return func(*args, **kwargs) 101 | return run_test 102 | 103 | 104 | def assert_list_almost_equal(list1, list2): 105 | for el1, el2 in zip(list1, list2): 106 | assert_almost_equal(el1, el2) 107 | 108 | 109 | def assert_list_almost_equal_value(list, value): 110 | for el in list: 111 | assert_almost_equal(el, value) 112 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [versioneer] 2 | VCS = git 3 | style = pep440 4 | versionfile_source = pyearth/_version.py 5 | versionfile_build = pyearth/_version.py 6 | tag_prefix= 7 | parentdir_prefix = py-earth 8 | 9 | [metadata] 10 | description-file = description.md 11 | license_file = LICENSE.txt -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension, find_packages 2 | import sys 3 | import codecs 4 | import versioneer 5 | 6 | # Determine whether to use Cython 7 | if '--cythonize' in sys.argv: 8 | cythonize_switch = True 9 | del sys.argv[sys.argv.index('--cythonize')] 10 | else: 11 | cythonize_switch = False 12 | 13 | def get_ext_modules(): 14 | import numpy 15 | # Find all includes 16 | local_inc = 'pyearth' 17 | numpy_inc = numpy.get_include() 18 | 19 | # Set up the ext_modules for Cython or not, depending 20 | if cythonize_switch: 21 | from Cython.Build import cythonize 22 | ext_modules = cythonize( 23 | [Extension( 24 | "pyearth._util", ["pyearth/_util.pyx"], include_dirs=[numpy_inc]), 25 | Extension( 26 | "pyearth._basis", 27 | ["pyearth/_basis.pyx"], 28 | include_dirs=[numpy_inc]), 29 | Extension( 30 | "pyearth._record", 31 | ["pyearth/_record.pyx"], 32 | 
include_dirs=[numpy_inc]), 33 | Extension( 34 | "pyearth._pruning", 35 | ["pyearth/_pruning.pyx"], 36 | include_dirs=[local_inc, 37 | numpy_inc]), 38 | Extension( 39 | "pyearth._forward", 40 | ["pyearth/_forward.pyx"], 41 | include_dirs=[local_inc, 42 | numpy_inc]), 43 | Extension( 44 | "pyearth._knot_search", 45 | ["pyearth/_knot_search.pyx"], 46 | include_dirs=[local_inc, 47 | numpy_inc]), 48 | Extension( 49 | "pyearth._qr", 50 | ["pyearth/_qr.pyx"], 51 | include_dirs=[local_inc, 52 | numpy_inc]), 53 | Extension( 54 | "pyearth._types", 55 | ["pyearth/_types.pyx"], 56 | include_dirs=[local_inc, 57 | numpy_inc]) 58 | ]) 59 | else: 60 | ext_modules = [Extension( 61 | "pyearth._util", ["pyearth/_util.c"], include_dirs=[numpy_inc]), 62 | Extension( 63 | "pyearth._basis", 64 | ["pyearth/_basis.c"], 65 | include_dirs=[numpy_inc]), 66 | Extension( 67 | "pyearth._record", 68 | ["pyearth/_record.c"], 69 | include_dirs=[numpy_inc]), 70 | Extension( 71 | "pyearth._pruning", 72 | ["pyearth/_pruning.c"], 73 | include_dirs=[local_inc, 74 | numpy_inc]), 75 | Extension( 76 | "pyearth._forward", 77 | ["pyearth/_forward.c"], 78 | include_dirs=[local_inc, 79 | numpy_inc]), 80 | Extension( 81 | "pyearth._knot_search", 82 | ["pyearth/_knot_search.c"], 83 | include_dirs=[local_inc, 84 | numpy_inc]), 85 | Extension( 86 | "pyearth._qr", 87 | ["pyearth/_qr.c"], 88 | include_dirs=[local_inc, 89 | numpy_inc]), 90 | Extension( 91 | "pyearth._types", 92 | ["pyearth/_types.c"], 93 | include_dirs=[local_inc, 94 | numpy_inc]) 95 | ] 96 | return ext_modules 97 | 98 | def setup_package(): 99 | # Create a dictionary of arguments for setup 100 | setup_args = { 101 | 'name': 'sklearn-contrib-py-earth', 102 | 'version': versioneer.get_version(), 103 | 'author': 'Jason Rudy', 104 | 'author_email': 'jcrudy@gmail.com', 105 | 'packages': find_packages(), 106 | 'license': 'LICENSE.txt', 107 | 'download_url': 'https://github.com/scikit-learn-contrib/py-earth/archive/0.1.tar.gz', 108 | 'description': 109 | 
'A Python implementation of Jerome Friedman\'s Multivariate Adaptive Regression Splines.', 110 | 'long_description': codecs.open('README.md', mode='r', encoding='utf-8').read(), 111 | 'classifiers': ['Intended Audience :: Developers', 112 | 'Intended Audience :: Science/Research', 113 | 'License :: OSI Approved :: BSD License', 114 | 'Development Status :: 3 - Alpha', 115 | 'Operating System :: MacOS', 116 | 'Operating System :: Microsoft :: Windows', 117 | 'Operating System :: POSIX', 118 | 'Operating System :: Unix', 119 | 'Programming Language :: Cython', 120 | 'Programming Language :: Python', 121 | 'Programming Language :: Python :: 2', 122 | 'Programming Language :: Python :: 2.6', 123 | 'Programming Language :: Python :: 2.7', 124 | 'Programming Language :: Python :: 3', 125 | 'Programming Language :: Python :: 3.4', 126 | 'Programming Language :: Python :: 3.5', 127 | 'Programming Language :: Python :: 3.6', 128 | 'Topic :: Scientific/Engineering', 129 | 'Topic :: Software Development'], 130 | 'install_requires': [ 131 | 'scipy >= 0.16', 132 | 'scikit-learn >= 0.16', 133 | 'six' 134 | ], 135 | 'extras_require': {'docs': ['sphinx_gallery'], 136 | 'dev': ['cython'], 137 | 'export': ['sympy'], 138 | 'all_tests': ['pandas', 'statsmodels', 'patsy', 'sympy']}, 139 | 'setup_requires': ['numpy'], 140 | 'include_package_data': True 141 | } 142 | 143 | # Add the build_ext command only if cythonizing 144 | if cythonize_switch: 145 | from Cython.Distutils import build_ext 146 | setup_args['cmdclass'] = versioneer.get_cmdclass({'build_ext': build_ext}) 147 | else: 148 | setup_args['cmdclass'] = versioneer.get_cmdclass() 149 | 150 | def is_special_command(): 151 | special_list = ('--help-commands', 152 | 'egg_info', 153 | '--version', 154 | 'clean') 155 | return ('--help' in sys.argv[1:] or 156 | sys.argv[1] in special_list) 157 | 158 | if len(sys.argv) >= 2 and is_special_command(): 159 | setup(**setup_args) 160 | else: 161 | setup_args['ext_modules'] = 
get_ext_modules() 162 | setup(**setup_args) 163 | 164 | if __name__ == "__main__": 165 | setup_package() 166 | --------------------------------------------------------------------------------