├── .github
│   └── workflows
│       └── base.yml
├── .gitignore
├── .zenodo.json
├── CITATION.cff
├── LICENSE
├── README.md
├── ci_tools
│   ├── .pylintrc
│   ├── check_python_version.py
│   ├── flake8-requirements.txt
│   ├── github_release.py
│   └── nox_utils.py
├── docs
│   ├── changelog.md
│   ├── examples
│   │   ├── 1_simple_1D_demo.py
│   │   ├── 2_pw_linear_demo.py
│   │   └── Readme.md
│   ├── index.md
│   └── long_description.md
├── mkdocs.yml
├── noxfile-requirements.txt
├── noxfile.py
├── pyproject.toml
├── setup.cfg
├── setup.py
├── src
│   └── m5py
│       ├── __init__.py
│       ├── export.py
│       ├── linreg_utils.py
│       ├── main.py
│       └── py.typed
└── tests
    ├── __init__.py
    └── test_main.py
/.github/workflows/base.yml: -------------------------------------------------------------------------------- 1 | # .github/workflows/base.yml 2 | name: Build 3 | on: 4 | # this one is to trigger the workflow manually from the interface 5 | workflow_dispatch: 6 | 7 | push: 8 | tags: 9 | - '*' 10 | branches: 11 | - main 12 | pull_request: 13 | branches: 14 | - main 15 | jobs: 16 | # pre-job to read nox tests matrix - see https://stackoverflow.com/q/66747359/7262247 17 | list_nox_test_sessions: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v2 21 | - uses: actions/setup-python@v1 22 | with: 23 | python-version: 3.7 24 | architecture: x64 25 | 26 | - name: Install noxfile requirements 27 | shell: bash -l {0} 28 | run: pip install -r noxfile-requirements.txt 29 | 30 | - name: List 'tests' nox sessions 31 | id: set-matrix 32 | run: echo "::set-output name=matrix::$(nox -s gha_list -- tests)" 33 | outputs: 34 | matrix: ${{ steps.set-matrix.outputs.matrix }} # save nox sessions list to outputs 35 | 36 | run_all_tests: 37 | needs: list_nox_test_sessions 38 | strategy: 39 | fail-fast: false 40 | matrix: 41 | os: [ ubuntu-latest ] # , macos-latest, windows-latest] 42 | # all nox sessions: manually > dynamically from previous job 43 | # nox_session: ["tests-2.7", "tests-3.7"] 44 | nox_session: ${{ fromJson(needs.list_nox_test_sessions.outputs.matrix) }} 45 | 46 | name: ${{ matrix.os }} ${{ matrix.nox_session }} # ${{ matrix.name_suffix }} 47 | runs-on: ${{ matrix.os }} 48 | steps: 49 | - uses: actions/checkout@v2 50 | 51 | # Conda install 52 | - name: Install conda v3.7 53 | uses: conda-incubator/setup-miniconda@v2 54 | with: 55 | # auto-update-conda: true 56 | python-version: 3.7 57 | activate-environment: noxenv 58 | - run: conda info 59 | shell: bash -l {0} # so that conda works 60 | - run: conda list 61 | shell: bash -l {0} # so that conda works 62 | 63 | # Nox install + run 64 | - name: Install noxfile requirements 65 | shell: bash -l {0} # so that conda works 66 | run: pip install -r noxfile-requirements.txt 67 | - run: conda list 68 | shell: bash -l {0} # so that conda works 69 | - run: nox -s "${{ matrix.nox_session }}" 70 | shell: bash -l {0} # so that conda works 71 | 72 | # Share ./docs/reports so that they can be deployed with doc in next job 73 | - name: Share reports with other jobs 74 | # if: matrix.nox_session == '...': not needed, if empty won't be shared 75 | uses: actions/upload-artifact@master 76 | with: 77 | name: reports_dir 78 | path: ./docs/reports 79 | 80 | publish_release: 81 | needs: run_all_tests 82 | runs-on: ubuntu-latest 83 | if: github.event_name == 'push' 84 | steps: 85 | - name: GitHub context to debug conditional steps 86 | env: 87 | GITHUB_CONTEXT: ${{ toJSON(github) }} 88 | run: echo "$GITHUB_CONTEXT" 89 | 90 | - uses: actions/checkout@v2 91 | with: 92 | fetch-depth: 0 # so that gh-deploy works 93 | 94 | # 1) retrieve the
reports generated previously 95 | - name: Retrieve reports 96 | uses: actions/download-artifact@master 97 | with: 98 | name: reports_dir 99 | path: ./docs/reports 100 | 101 | # Conda install 102 | - name: Install conda v3.7 103 | uses: conda-incubator/setup-miniconda@v2 104 | with: 105 | # auto-update-conda: true 106 | python-version: 3.7 107 | activate-environment: noxenv 108 | - run: conda info 109 | shell: bash -l {0} # so that conda works 110 | - run: conda list 111 | shell: bash -l {0} # so that conda works 112 | 113 | # Nox install 114 | - name: Install noxfile requirements 115 | shell: bash -l {0} # so that conda works 116 | run: pip install -r noxfile-requirements.txt 117 | - run: conda list 118 | shell: bash -l {0} # so that conda works 119 | 120 | # 2) Run the flake8 report and badge 121 | - name: Run flake8 analysis and generate corresponding badge 122 | shell: bash -l {0} # so that conda works 123 | run: nox -s flake8 124 | 125 | # -------------- only on Ubuntu + MAIN PUSH (no pull request, no tag) ----------- 126 | 127 | # 3) Publish the doc and test reports 128 | - name: \[not on TAG\] Publish documentation, tests and coverage reports 129 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/heads') # startsWith(matrix.os,'ubuntu') 130 | shell: bash -l {0} # so that conda works 131 | run: nox -s publish 132 | 133 | # 4) Publish coverage report 134 | - name: \[not on TAG\] Create codecov.yaml with correct paths 135 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/heads') 136 | shell: bash 137 | run: | 138 | cat << EOF > codecov.yml 139 | # codecov.yml 140 | fixes: 141 | - "/home/runner/work/smarie/python-m5p/::" # Correct paths 142 | EOF 143 | - name: \[not on TAG\] Publish coverage report 144 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/heads') 145 | uses: codecov/codecov-action@v1 146 | with: 147 | files: ./docs/reports/coverage/coverage.xml 148 | 149 | # -------------- only on Ubuntu + TAG PUSH (no pull request) ----------- 150 | 151 | # 5) Create github release and build the wheel 152 | - name: \[TAG only\] Build wheel and create github release 153 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 154 | shell: bash -l {0} # so that conda works 155 | run: nox -s release -- ${{ secrets.GITHUB_TOKEN }} 156 | 157 | # 6) Publish the wheel on PyPi 158 | - name: \[TAG only\] Deploy on PyPi 159 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 160 | uses: pypa/gh-action-pypi-publish@release/v1 161 | with: 162 | user: __token__ 163 | password: ${{ secrets.PYPI_API_TOKEN }} 164 | 165 | delete-artifacts: 166 | needs: publish_release 167 | runs-on: ubuntu-latest 168 | if: github.event_name == 'push' 169 | steps: 170 | - uses: kolpav/purge-artifacts-action@v1 171 | with: 172 | token: ${{ secrets.GITHUB_TOKEN }} 173 | expire-in: 0 # Setting this to 0 will delete all artifacts 174 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | #
PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | src/*/_version.py 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv*/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # PyCharm development 133 | /.idea 134 | 135 | # OSX 136 | .DS_Store 137 | 138 | # JUnit and coverage reports 139 | docs/reports 140 | 141 | # ODSClient cache 142 | .odsclient 143 | -------------------------------------------------------------------------------- /.zenodo.json: -------------------------------------------------------------------------------- 1 | { 2 | "title": "m5py", 3 | "description": "
An implementation of M5 (Prime) and model trees for scikit-learn.
", 4 | "language": "eng", 5 | "license": { 6 | "id": "bsd-license" 7 | }, 8 | "keywords": [ 9 | "python", 10 | "model", 11 | "tree", 12 | "regression", 13 | "M5", 14 | "prime", 15 | "scikit-learn", 16 | "machine learning" 17 | ], 18 | "creators": [ 19 | { 20 | "orcid": "0000-0002-5929-1047", 21 | "affiliation": "Schneider Electric", 22 | "name": "Sylvain Mari\u00e9" 23 | }, 24 | { 25 | "name": "Various github contributors" 26 | } 27 | ] 28 | } 29 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: m5py 6 | message: 'If you use this software, please cite it as below.' 7 | type: software 8 | authors: 9 | - family-names: Marié 10 | given-names: Sylvain 11 | orcid: 'https://orcid.org/0000-0002-5929-1047' 12 | identifiers: 13 | - type: doi 14 | value: 10.5281/zenodo.10552219 15 | description: Software (Zenodo) 16 | repository-code: 'https://github.com/smarie/python-m5p' 17 | url: 'https://smarie.github.io/python-m5p/' 18 | repository-artifact: 'https://pypi.org/project/m5py/' 19 | abstract: >- 20 | An implementation of M5 (Prime) and model trees for 21 | scikit-learn. 22 | keywords: 23 | - python 24 | - model 25 | - tree 26 | - regression 27 | - M5 28 | - prime 29 | - scikit-learn 30 | - machine learning 31 | license: BSD-3-Clause 32 | doi: 10.5281/zenodo.10552219 33 | preferred-citation: 34 | type: conference 35 | url: 'https://hal.science/hal-03762155/' 36 | authors: 37 | - family-names: Marié 38 | given-names: Sylvain 39 | orcid: 'https://orcid.org/0000-0002-5929-1047' 40 | title: >- 41 | `python-m5p` - M5 Prime regression trees in python, compliant with 42 | scikit-learn 43 | conference: 44 | name: "PyCon.DE & PyData" 45 | city: "Berlin" 46 | country: "Germany" 47 | date-end: "2022-04-13" 48 | date-start: "2022-04-11" 49 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2016-2022, Sylvain Marié, Schneider Electric Industries 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `m5py` 2 | 3 | *`scikit-learn`-compliant M5 / M5' model trees for python* 4 | 5 | [![Python versions](https://img.shields.io/pypi/pyversions/m5py.svg)](https://pypi.python.org/pypi/m5py/) [![Build Status](https://github.com/smarie/python-m5p/actions/workflows/base.yml/badge.svg)](https://github.com/smarie/python-m5p/actions/workflows/base.yml) [![Tests Status](./reports/junit/junit-badge.svg?dummy=8484744)](./reports/junit/report.html) [![Coverage Status](./reports/coverage/coverage-badge.svg?dummy=8484744)](./reports/coverage/index.html) [![codecov](https://codecov.io/gh/smarie/python-m5p/branch/main/graph/badge.svg)](https://codecov.io/gh/smarie/python-m5p) [![Flake8 Status](./reports/flake8/flake8-badge.svg?dummy=8484744)](./reports/flake8/index.html) 6 | 7 | [![Documentation](https://img.shields.io/badge/doc-latest-blue.svg)](https://smarie.github.io/python-m5p/) [![PyPI](https://img.shields.io/pypi/v/m5py.svg)](https://pypi.python.org/pypi/m5py/) [![Downloads](https://pepy.tech/badge/m5py)](https://pepy.tech/project/m5py) [![Downloads per week](https://pepy.tech/badge/m5py/week)](https://pepy.tech/project/m5py) [![GitHub stars](https://img.shields.io/github/stars/smarie/python-m5p.svg)](https://github.com/smarie/python-m5p/stargazers) 8 | 9 | **This is the readme for developers.** The documentation for users is available here: [https://smarie.github.io/python-m5p/](https://smarie.github.io/python-m5p/) 10 | 11 | ## Want to contribute? 12 | 13 | Contributions are welcome! Simply fork this project on github, commit your contributions, and create pull requests. 14 | 15 | Here is a non-exhaustive list of interesting open topics: [https://github.com/smarie/python-m5p/issues](https://github.com/smarie/python-m5p/issues) 16 | 17 | ## `nox` setup 18 | 19 | This project uses `nox` to define all lifecycle tasks. In order to be able to run those tasks, you should create a python 3.7 environment and install the requirements: 20 | 21 | ```bash 22 | >>> conda create -n noxenv python="3.7" 23 | >>> activate noxenv 24 | (noxenv) >>> pip install -r noxfile-requirements.txt 25 | ``` 26 | 27 | You should then be able to list all available tasks using: 28 | 29 | ``` 30 | >>> nox --list 31 | Sessions defined in \noxfile.py: 32 | 33 | * tests-2.7 -> Run the test suite, including test reports generation and coverage reports. 34 | * tests-3.5 -> Run the test suite, including test reports generation and coverage reports. 35 | * tests-3.6 -> Run the test suite, including test reports generation and coverage reports. 36 | * tests-3.8 -> Run the test suite, including test reports generation and coverage reports. 37 | * tests-3.7 -> Run the test suite, including test reports generation and coverage reports. 38 | - docs-3.7 -> Generates the doc and serves it on a local http server.
Pass '-- build' to build statically instead. 39 | - publish-3.7 -> Deploy the docs+reports on github pages. Note: this rebuilds the docs 40 | - release-3.7 -> Create a release on github corresponding to the latest tag 41 | ``` 42 | 43 | ## Running the tests and generating the reports 44 | 45 | This project uses `pytest` so running `pytest` at the root folder will execute all tests on the current environment. However it is a bit cumbersome to manage all requirements by hand; it is easier to use `nox` to run `pytest` on all supported python environments with the correct package requirements: 46 | 47 | ```bash 48 | nox 49 | ``` 50 | 51 | Tests and coverage reports are automatically generated under `./docs/reports` for one of the sessions (`tests-3.7`). 52 | 53 | If you wish to execute tests on a specific environment, use explicit session names, e.g. `nox -s tests-3.6`. 54 | 55 | 56 | ## Editing the documentation 57 | 58 | This project uses `mkdocs` to generate its documentation page. Therefore building a local copy of the doc page may be done using `mkdocs build -f docs/mkdocs.yml`. However once again things are easier with `nox`. You can easily build and serve locally a version of the documentation site using: 59 | 60 | ```bash 61 | >>> nox -s docs 62 | nox > Running session docs-3.7 63 | nox > Creating conda env in .nox\docs-3-7 with python=3.7 64 | nox > [docs] Installing requirements with pip: ['mkdocs-material', 'mkdocs', 'pymdown-extensions', 'pygments'] 65 | nox > python -m pip install mkdocs-material mkdocs pymdown-extensions pygments 66 | nox > mkdocs serve -f ./docs/mkdocs.yml 67 | INFO - Building documentation... 68 | INFO - Cleaning site directory 69 | INFO - The following pages exist in the docs directory, but are not included in the "nav" configuration: 70 | - long_description.md 71 | INFO - Documentation built in 1.07 seconds 72 | INFO - Serving on http://127.0.0.1:8000 73 | INFO - Start watching changes 74 | ... 75 | ``` 76 | 77 | While this is running, you can edit the files under `./docs/` and browse the automatically refreshed documentation at the local [http://127.0.0.1:8000](http://127.0.0.1:8000) page. 78 | 79 | Once you are done, simply hit `CTRL+C` to stop the session. 80 | 81 | Publishing the documentation (including tests and coverage reports) is done automatically by [the continuous integration engine](https://github.com/smarie/python-m5p/actions), using the `nox -s publish` session; this is not needed for local development. 82 | 83 | ## Packaging 84 | 85 | This project uses `setuptools_scm` to synchronise the version number. Therefore the following command should be used for development snapshots as well as official releases: `python setup.py sdist bdist_wheel`. However this is not generally needed since [the continuous integration engine](https://github.com/smarie/python-m5p/actions) does it automatically for us on git tags. For reference, this is done in the `nox -s release` session. 86 | 87 | ### Merging pull requests with edits - memo 88 | 89 | As explained in github ('get commandline instructions'): 90 | 91 | ```bash 92 | git checkout -b <user>-<branch> main 93 | git pull https://github.com/<user>/python-m5p.git --no-commit --ff-only 94 | ``` 95 | 96 | if the second step does not work, do a normal auto-merge (do not use **rebase**!): 97 | 98 | ```bash 99 | git pull https://github.com/<user>/python-m5p.git --no-commit 100 | ``` 101 | 102 | Finally review the changes, possibly perform some modifications, and commit.
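For instance, a minimal sketch of that final review-and-commit step (plain git commands; the remote name `origin` and target branch `main` are assumptions to adapt):

```bash
git diff --cached     # review the staged changes brought in by the merge
git commit            # validate the prepared merge commit message
git push origin main  # publish the merged result
```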
103 | -------------------------------------------------------------------------------- /ci_tools/.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # Specify a configuration file. 4 | #rcfile= 5 | 6 | # Python code to execute, usually for sys.path manipulation such as 7 | # pygtk.require(). 8 | # init-hook="import m5py" 9 | 10 | # Add files or directories to the blacklist. They should be base names, not 11 | # paths. 12 | ignore= 13 | 14 | # Pickle collected data for later comparisons. 15 | persistent=no 16 | 17 | # List of plugins (as comma separated values of python modules names) to load, 18 | # usually to register additional checkers. 19 | load-plugins= 20 | 21 | # Use multiple processes to speed up Pylint. 22 | # DO NOT CHANGE THIS VALUE: >1 HIDES RESULTS!!!!! 23 | jobs=1 24 | 25 | # Allow loading of arbitrary C extensions. Extensions are imported into the 26 | # active Python interpreter and may run arbitrary code. 27 | unsafe-load-any-extension=no 28 | 29 | # A comma-separated list of package or module names from where C extensions may 30 | # be loaded. Extensions are loading into the active Python interpreter and may 31 | # run arbitrary code 32 | extension-pkg-whitelist= 33 | 34 | # Allow optimization of some AST trees. This will activate a peephole AST 35 | # optimizer, which will apply various small optimizations. For instance, it can 36 | # be used to obtain the result of joining multiple strings with the addition 37 | # operator. Joining a lot of strings can lead to a maximum recursion error in 38 | # Pylint and this flag can prevent that. It has one side effect, the resulting 39 | # AST will be different than the one from reality. 40 | optimize-ast=no 41 | 42 | 43 | [MESSAGES CONTROL] 44 | 45 | # Only show warnings with the listed confidence levels. Leave empty to show 46 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 47 | confidence= 48 | 49 | # Enable the message, report, category or checker with the given id(s). You can 50 | # either give multiple identifier separated by comma (,) or put this option 51 | # multiple time. See also the "--disable" option for examples. 52 | disable=all 53 | 54 | enable=import-error, 55 | import-self, 56 | reimported, 57 | wildcard-import, 58 | misplaced-future, 59 | relative-import, 60 | deprecated-module, 61 | unpacking-non-sequence, 62 | invalid-all-object, 63 | undefined-all-variable, 64 | used-before-assignment, 65 | cell-var-from-loop, 66 | global-variable-undefined, 67 | redefined-builtin, 68 | redefine-in-handler, 69 | unused-import, 70 | unused-wildcard-import, 71 | global-variable-not-assigned, 72 | undefined-loop-variable, 73 | global-statement, 74 | global-at-module-level, 75 | bad-open-mode, 76 | redundant-unittest-assert, 77 | boolean-datetime, 78 | # Has common issues with our style due to 79 | # https://github.com/PyCQA/pylint/issues/210 80 | unused-variable 81 | 82 | # Things we'd like to enable someday: 83 | # redefined-outer-name (requires a bunch of work to clean up our code first) 84 | # undefined-variable (re-enable when pylint fixes https://github.com/PyCQA/pylint/issues/760) 85 | # no-name-in-module (giving us spurious warnings https://github.com/PyCQA/pylint/issues/73) 86 | # unused-argument (need to clean up or code a lot, e.g. prefix unused_?) 87 | 88 | # Things we'd like to try. 89 | # Procedure: 90 | # 1. Enable a bunch. 91 | # 2. See if there's spurious ones; if so disable. 92 | # 3. Record above. 93 | # 4. Remove from this list.
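# For instance, a hypothetical trial of one checker from the list below, using
# standard pylint command-line flags:
#     pylint --rcfile=ci_tools/.pylintrc --enable=deprecated-method src/m5py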
94 | # deprecated-method, 95 | # anomalous-unicode-escape-in-string, 96 | # anomalous-backslash-in-string, 97 | # not-in-loop, 98 | # function-redefined, 99 | # continue-in-finally, 100 | # abstract-class-instantiated, 101 | # star-needs-assignment-target, 102 | # duplicate-argument-name, 103 | # return-in-init, 104 | # too-many-star-expressions, 105 | # nonlocal-and-global, 106 | # return-outside-function, 107 | # return-arg-in-generator, 108 | # invalid-star-assignment-target, 109 | # bad-reversed-sequence, 110 | # nonexistent-operator, 111 | # yield-outside-function, 112 | # init-is-generator, 113 | # nonlocal-without-binding, 114 | # lost-exception, 115 | # assert-on-tuple, 116 | # dangerous-default-value, 117 | # duplicate-key, 118 | # useless-else-on-loop, 119 | # expression-not-assigned, 120 | # confusing-with-statement, 121 | # unnecessary-lambda, 122 | # pointless-statement, 123 | # pointless-string-statement, 124 | # unnecessary-pass, 125 | # unreachable, 126 | # eval-used, 127 | # exec-used, 128 | # bad-builtin, 129 | # using-constant-test, 130 | # deprecated-lambda, 131 | # bad-super-call, 132 | # missing-super-argument, 133 | # slots-on-old-class, 134 | # super-on-old-class, 135 | # property-on-old-class, 136 | # not-an-iterable, 137 | # not-a-mapping, 138 | # format-needs-mapping, 139 | # truncated-format-string, 140 | # missing-format-string-key, 141 | # mixed-format-string, 142 | # too-few-format-args, 143 | # bad-str-strip-call, 144 | # too-many-format-args, 145 | # bad-format-character, 146 | # format-combined-specification, 147 | # bad-format-string-key, 148 | # bad-format-string, 149 | # missing-format-attribute, 150 | # missing-format-argument-key, 151 | # unused-format-string-argument, 152 | # unused-format-string-key, 153 | # invalid-format-index, 154 | # bad-indentation, 155 | # mixed-indentation, 156 | # unnecessary-semicolon, 157 | # lowercase-l-suffix, 158 | # fixme, 159 | # invalid-encoded-data, 160 | # unpacking-in-except, 161 | # import-star-module-level, 162 | # parameter-unpacking, 163 | # long-suffix, 164 | # old-octal-literal, 165 | # old-ne-operator, 166 | # backtick, 167 | # old-raise-syntax, 168 | # print-statement, 169 | # metaclass-assignment, 170 | # next-method-called, 171 | # dict-iter-method, 172 | # dict-view-method, 173 | # indexing-exception, 174 | # raising-string, 175 | # standarderror-builtin, 176 | # using-cmp-argument, 177 | # cmp-method, 178 | # coerce-method, 179 | # delslice-method, 180 | # getslice-method, 181 | # hex-method, 182 | # nonzero-method, 183 | # oct-method, 184 | # setslice-method, 185 | # apply-builtin, 186 | # basestring-builtin, 187 | # buffer-builtin, 188 | # cmp-builtin, 189 | # coerce-builtin, 190 | # old-division, 191 | # execfile-builtin, 192 | # file-builtin, 193 | # filter-builtin-not-iterating, 194 | # no-absolute-import, 195 | # input-builtin, 196 | # intern-builtin, 197 | # long-builtin, 198 | # map-builtin-not-iterating, 199 | # range-builtin-not-iterating, 200 | # raw_input-builtin, 201 | # reduce-builtin, 202 | # reload-builtin, 203 | # round-builtin, 204 | # unichr-builtin, 205 | # unicode-builtin, 206 | # xrange-builtin, 207 | # zip-builtin-not-iterating, 208 | # logging-format-truncated, 209 | # logging-too-few-args, 210 | # logging-too-many-args, 211 | # logging-unsupported-format, 212 | # logging-not-lazy, 213 | # logging-format-interpolation, 214 | # invalid-unary-operand-type, 215 | # unsupported-binary-operation, 216 | # no-member, 217 | # not-callable, 218 | # redundant-keyword-arg, 219 | # 
assignment-from-no-return, 220 | # assignment-from-none, 221 | # not-context-manager, 222 | # repeated-keyword, 223 | # missing-kwoa, 224 | # no-value-for-parameter, 225 | # invalid-sequence-index, 226 | # invalid-slice-index, 227 | # too-many-function-args, 228 | # unexpected-keyword-arg, 229 | # unsupported-membership-test, 230 | # unsubscriptable-object, 231 | # access-member-before-definition, 232 | # method-hidden, 233 | # assigning-non-slot, 234 | # duplicate-bases, 235 | # inconsistent-mro, 236 | # inherit-non-class, 237 | # invalid-slots, 238 | # invalid-slots-object, 239 | # no-method-argument, 240 | # no-self-argument, 241 | # unexpected-special-method-signature, 242 | # non-iterator-returned, 243 | # protected-access, 244 | # arguments-differ, 245 | # attribute-defined-outside-init, 246 | # no-init, 247 | # abstract-method, 248 | # signature-differs, 249 | # bad-staticmethod-argument, 250 | # non-parent-init-called, 251 | # super-init-not-called, 252 | # bad-except-order, 253 | # catching-non-exception, 254 | # bad-exception-context, 255 | # notimplemented-raised, 256 | # raising-bad-type, 257 | # raising-non-exception, 258 | # misplaced-bare-raise, 259 | # duplicate-except, 260 | # broad-except, 261 | # nonstandard-exception, 262 | # binary-op-exception, 263 | # bare-except, 264 | # not-async-context-manager, 265 | # yield-inside-async-function, 266 | 267 | # ... 268 | [REPORTS] 269 | 270 | # Set the output format. Available formats are text, parseable, colorized, msvs 271 | # (visual studio) and html. You can also give a reporter class, eg 272 | # mypackage.mymodule.MyReporterClass. 273 | output-format=parseable 274 | 275 | # Put messages in a separate file for each module / package specified on the 276 | # command line instead of printing them on stdout. Reports (if any) will be 277 | # written in a file name "pylint_global.[txt|html]". 278 | files-output=no 279 | 280 | # Tells whether to display a full report or only the messages 281 | reports=no 282 | 283 | # Python expression which should return a note less than 10 (10 is the highest 284 | # note). You have access to the variables errors warning, statement which 285 | # respectively contain the number of errors / warnings messages and the total 286 | # number of statements analyzed. This is used by the global evaluation report 287 | # (RP0004). 288 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 289 | 290 | # Template used to display messages. This is a python new-style format string 291 | # used to format the message information. See doc for all details 292 | #msg-template= 293 | 294 | 295 | [LOGGING] 296 | 297 | # Logging modules to check that the string format arguments are in logging 298 | # function parameter format 299 | logging-modules=logging 300 | 301 | 302 | [FORMAT] 303 | 304 | # Maximum number of characters on a single line. 305 | max-line-length=100 306 | 307 | # Regexp for a line that is allowed to be longer than the limit. 308 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$ 309 | 310 | # Allow the body of an if to be on the same line as the test if there is no 311 | # else. 312 | single-line-if-stmt=no 313 | 314 | # List of optional constructs for which whitespace checking is disabled. `dict- 315 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 316 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 317 | # `empty-line` allows space-only lines.
318 | no-space-check=trailing-comma,dict-separator 319 | 320 | # Maximum number of lines in a module 321 | max-module-lines=1000 322 | 323 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 324 | # tab). 325 | indent-string=' ' 326 | 327 | # Number of spaces of indent required inside a hanging or continued line. 328 | indent-after-paren=4 329 | 330 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 331 | expected-line-ending-format= 332 | 333 | 334 | [TYPECHECK] 335 | 336 | # Tells whether missing members accessed in mixin class should be ignored. A 337 | # mixin class is detected if its name ends with "mixin" (case insensitive). 338 | ignore-mixin-members=yes 339 | 340 | # List of module names for which member attributes should not be checked 341 | # (useful for modules/projects where namespaces are manipulated during runtime 342 | # and thus existing member attributes cannot be deduced by static analysis. It 343 | # supports qualified module names, as well as Unix pattern matching. 344 | ignored-modules= 345 | 346 | # List of classes names for which member attributes should not be checked 347 | # (useful for classes with attributes dynamically set). This supports can work 348 | # with qualified names. 349 | ignored-classes= 350 | 351 | # List of members which are set dynamically and missed by pylint inference 352 | # system, and so shouldn't trigger E1101 when accessed. Python regular 353 | # expressions are accepted. 354 | generated-members= 355 | 356 | 357 | [VARIABLES] 358 | 359 | # Tells whether we should check for unused import in __init__ files. 360 | init-import=no 361 | 362 | # A regular expression matching the name of dummy variables (i.e. expectedly 363 | # not used). 364 | dummy-variables-rgx=^_|^dummy 365 | 366 | # List of additional names supposed to be defined in builtins. Remember that 367 | # you should avoid to define new builtins when possible. 368 | additional-builtins= 369 | 370 | # List of strings which can identify a callback function by name. A callback 371 | # name must start or end with one of those strings. 372 | callbacks=cb_,_cb 373 | 374 | 375 | [SIMILARITIES] 376 | 377 | # Minimum lines number of a similarity. 378 | min-similarity-lines=4 379 | 380 | # Ignore comments when computing similarities. 381 | ignore-comments=yes 382 | 383 | # Ignore docstrings when computing similarities. 384 | ignore-docstrings=yes 385 | 386 | # Ignore imports when computing similarities. 387 | ignore-imports=no 388 | 389 | 390 | [SPELLING] 391 | 392 | # Spelling dictionary name. Available dictionaries: none. To make it working 393 | # install python-enchant package. 394 | spelling-dict= 395 | 396 | # List of comma separated words that should not be checked. 397 | spelling-ignore-words= 398 | 399 | # A path to a file that contains private dictionary; one word per line. 400 | spelling-private-dict-file= 401 | 402 | # Tells whether to store unknown words to indicated private dictionary in 403 | # --spelling-private-dict-file option instead of raising a message. 404 | spelling-store-unknown-words=no 405 | 406 | 407 | [MISCELLANEOUS] 408 | 409 | # List of note tags to take in consideration, separated by a comma. 
410 | notes=FIXME,XXX,TODO 411 | 412 | 413 | [BASIC] 414 | 415 | # List of builtins function names that should not be used, separated by a comma 416 | bad-functions=map,filter,input 417 | 418 | # Good variable names which should always be accepted, separated by a comma 419 | good-names=i,j,k,ex,Run,_ 420 | 421 | # Bad variable names which should always be refused, separated by a comma 422 | bad-names=foo,bar,baz,toto,tutu,tata 423 | 424 | # Colon-delimited sets of names that determine each other's naming style when 425 | # the name regexes allow several styles. 426 | name-group= 427 | 428 | # Include a hint for the correct naming format with invalid-name 429 | include-naming-hint=no 430 | 431 | # Regular expression matching correct function names 432 | function-rgx=[a-z_][a-z0-9_]{2,30}$ 433 | 434 | # Naming hint for function names 435 | function-name-hint=[a-z_][a-z0-9_]{2,30}$ 436 | 437 | # Regular expression matching correct variable names 438 | variable-rgx=[a-z_][a-z0-9_]{2,30}$ 439 | 440 | # Naming hint for variable names 441 | variable-name-hint=[a-z_][a-z0-9_]{2,30}$ 442 | 443 | # Regular expression matching correct constant names 444 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 445 | 446 | # Naming hint for constant names 447 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 448 | 449 | # Regular expression matching correct attribute names 450 | attr-rgx=[a-z_][a-z0-9_]{2,30}$ 451 | 452 | # Naming hint for attribute names 453 | attr-name-hint=[a-z_][a-z0-9_]{2,30}$ 454 | 455 | # Regular expression matching correct argument names 456 | argument-rgx=[a-z_][a-z0-9_]{2,30}$ 457 | 458 | # Naming hint for argument names 459 | argument-name-hint=[a-z_][a-z0-9_]{2,30}$ 460 | 461 | # Regular expression matching correct class attribute names 462 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 463 | 464 | # Naming hint for class attribute names 465 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 466 | 467 | # Regular expression matching correct inline iteration names 468 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 469 | 470 | # Naming hint for inline iteration names 471 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ 472 | 473 | # Regular expression matching correct class names 474 | class-rgx=[A-Z_][a-zA-Z0-9]+$ 475 | 476 | # Naming hint for class names 477 | class-name-hint=[A-Z_][a-zA-Z0-9]+$ 478 | 479 | # Regular expression matching correct module names 480 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 481 | 482 | # Naming hint for module names 483 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 484 | 485 | # Regular expression matching correct method names 486 | method-rgx=[a-z_][a-z0-9_]{2,30}$ 487 | 488 | # Naming hint for method names 489 | method-name-hint=[a-z_][a-z0-9_]{2,30}$ 490 | 491 | # Regular expression which should only match function or class names that do 492 | # not require a docstring. 493 | no-docstring-rgx=^_ 494 | 495 | # Minimum line length for functions/classes that require docstrings, shorter 496 | # ones are exempt. 497 | docstring-min-length=-1 498 | 499 | 500 | [ELIF] 501 | 502 | # Maximum number of nested blocks for function / method body 503 | max-nested-blocks=5 504 | 505 | 506 | [IMPORTS] 507 | 508 | # Deprecated modules which should not be used, separated by a comma 509 | deprecated-modules=regsub,TERMIOS,Bastion,rexec 510 | 511 | # Create a graph of every (i.e. 
internal and external) dependencies in the 512 | # given file (report RP0402 must not be disabled) 513 | import-graph= 514 | 515 | # Create a graph of external dependencies in the given file (report RP0402 must 516 | # not be disabled) 517 | ext-import-graph= 518 | 519 | # Create a graph of internal dependencies in the given file (report RP0402 must 520 | # not be disabled) 521 | int-import-graph= 522 | 523 | 524 | [DESIGN] 525 | 526 | # Maximum number of arguments for function / method 527 | max-args=5 528 | 529 | # Argument names that match this expression will be ignored. Default to name 530 | # with leading underscore 531 | ignored-argument-names=_.* 532 | 533 | # Maximum number of locals for function / method body 534 | max-locals=15 535 | 536 | # Maximum number of return / yield for function / method body 537 | max-returns=6 538 | 539 | # Maximum number of branch for function / method body 540 | max-branches=12 541 | 542 | # Maximum number of statements in function / method body 543 | max-statements=50 544 | 545 | # Maximum number of parents for a class (see R0901). 546 | max-parents=7 547 | 548 | # Maximum number of attributes for a class (see R0902). 549 | max-attributes=7 550 | 551 | # Minimum number of public methods for a class (see R0903). 552 | min-public-methods=2 553 | 554 | # Maximum number of public methods for a class (see R0904). 555 | max-public-methods=20 556 | 557 | # Maximum number of boolean expressions in a if statement 558 | max-bool-expr=5 559 | 560 | 561 | [CLASSES] 562 | 563 | # List of method names used to declare (i.e. assign) instance attributes. 564 | defining-attr-methods=__init__,__new__,setUp 565 | 566 | # List of valid names for the first argument in a class method. 567 | valid-classmethod-first-arg=cls 568 | 569 | # List of valid names for the first argument in a metaclass class method. 570 | valid-metaclass-classmethod-first-arg=mcs 571 | 572 | # List of member names, which should be excluded from the protected access 573 | # warning. 574 | exclude-protected=_asdict,_fields,_replace,_source,_make 575 | 576 | 577 | [EXCEPTIONS] 578 | 579 | # Exceptions that will emit a warning when being caught. Defaults to 580 | # "Exception" 581 | overgeneral-exceptions=Exception 582 | -------------------------------------------------------------------------------- /ci_tools/check_python_version.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | if __name__ == "__main__": 4 | # Execute only if run as a script. 5 | # Check the arguments 6 | nbargs = len(sys.argv[1:]) 7 | if nbargs != 1: 8 | raise ValueError("a mandatory argument is required: <expected_version>") 9 | 10 | expected_version_str = sys.argv[1] 11 | try: 12 | expected_version = tuple(int(i) for i in expected_version_str.split(".")) 13 | except Exception as e: 14 | raise ValueError("Error while parsing expected version %r: %r" % (expected_version_str, e)) 15 | 16 | if len(expected_version) < 1: 17 | raise ValueError("At least a major is expected") 18 | 19 | if sys.version_info[0] != expected_version[0]: 20 | raise AssertionError("Major version does not match. Expected %r - Actual %r" % (expected_version_str, sys.version)) 21 | 22 | if len(expected_version) >= 2 and sys.version_info[1] != expected_version[1]: 23 | raise AssertionError("Minor version does not match. Expected %r - Actual %r" % (expected_version_str, sys.version)) 24 | 25 | if len(expected_version) >= 3 and sys.version_info[2] != expected_version[2]: 26 | raise AssertionError("Patch version does not match.
Expected %r - Actual %r" % (expected_version_str, sys.version)) 27 | 28 | print("SUCCESS - Actual python version %r matches expected one %r" % (sys.version, expected_version_str)) 29 | -------------------------------------------------------------------------------- /ci_tools/flake8-requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools_scm>=3,<4 2 | chardet 3 | flake8>=3.6,<4 4 | flake8-html>=0.4,<1 5 | flake8-bandit>=2.1.1,<3 6 | bandit<1.7.3 7 | flake8-bugbear>=20.1.0,<21.0.0 8 | flake8-docstrings>=1.5,<2 9 | flake8-print>=3.1.1,<4 10 | flake8-tidy-imports>=4.2.1,<5 11 | flake8-copyright==0.2.2 # Internal forked repo to fix an issue, keep specific version 12 | pydocstyle>=5.1.1,<6 13 | pycodestyle>=2.6.0,<3 14 | mccabe>=0.6.1,<1 15 | naming>=0.5.1,<1 16 | pyflakes>=2.2,<3 17 | genbadge[flake8] 18 | jinja2>=3.0.0,<3.1.0 19 | -------------------------------------------------------------------------------- /ci_tools/github_release.py: -------------------------------------------------------------------------------- 1 | # a clone of the ruby example https://gist.github.com/valeriomazzeo/5491aee76f758f7352e2e6611ce87ec1 2 | import os 3 | from os import path 4 | 5 | import re 6 | 7 | import click 8 | from click import Path 9 | from github import Github, UnknownObjectException 10 | # from valid8 import validate not compliant with python 2.7 11 | 12 | 13 | @click.command() 14 | @click.option('-u', '--user', help='GitHub username') 15 | @click.option('-p', '--pwd', help='GitHub password') 16 | @click.option('-s', '--secret', help='GitHub access token') 17 | @click.option('-r', '--repo-slug', help='Repo slug. i.e.: apple/swift') 18 | @click.option('-cf', '--changelog-file', help='Changelog file path') 19 | @click.option('-d', '--doc-url', help='Documentation url') 20 | @click.option('-df', '--data-file', help='Data file to upload', type=Path(exists=True, file_okay=True, dir_okay=False, 21 | resolve_path=True)) 22 | @click.argument('tag') 23 | def create_or_update_release(user, pwd, secret, repo_slug, changelog_file, doc_url, data_file, tag): 24 | """ 25 | Creates or updates (TODO) 26 | a github release corresponding to git tag . 
27 | """ 28 | # 1- AUTHENTICATION 29 | if user is not None and secret is None: 30 | # using username and password 31 | # validate('user', user, instance_of=str) 32 | assert isinstance(user, str) 33 | # validate('pwd', pwd, instance_of=str) 34 | assert isinstance(pwd, str) 35 | g = Github(user, pwd) 36 | elif user is None and secret is not None: 37 | # or using an access token 38 | # validate('secret', secret, instance_of=str) 39 | assert isinstance(secret, str) 40 | g = Github(secret) 41 | else: 42 | raise ValueError("You should either provide username/password OR an access token") 43 | click.echo("Logged in as {user_name}".format(user_name=g.get_user())) 44 | 45 | # 2- CHANGELOG VALIDATION 46 | regex_pattern = "[\s\S]*[\n][#]+[\s]*(?P<title>[\S ]*%s[\S ]*)[\n]+?(?P<body>[\s\S]*?)[\n]*?(\n#|$)" % re.escape(tag) 47 | changelog_section = re.compile(regex_pattern) 48 | if changelog_file is not None: 49 | # validate('changelog_file', changelog_file, custom=os.path.exists, 50 | # help_msg="changelog file should be a valid file path") 51 | assert os.path.exists(changelog_file), "changelog file should be a valid file path" 52 | with open(changelog_file) as f: 53 | contents = f.read() 54 | 55 | match = changelog_section.match(contents) 56 | if match is None: 57 | raise ValueError("Unable to find changelog section matching regexp pattern in changelog file.") 58 | else: 59 | title = match.group('title') 60 | message = match.group('body') 61 | else: 62 | title = tag 63 | message = '' 64 | 65 | # append footer with a link to the documentation page 66 | message += "\n\nSee [documentation page](%s) for details." % doc_url 67 | 68 | # 3- REPOSITORY EXPLORATION 69 | # validate('repo_slug', repo_slug, instance_of=str, min_len=1, help_msg="repo_slug should be a non-empty string") 70 | assert isinstance(repo_slug, str) and len(repo_slug) > 0, "repo_slug should be a non-empty string" 71 | repo = g.get_repo(repo_slug) 72 | 73 | # -- Is there a tag with that name? 74 | try: 75 | tag_ref = repo.get_git_ref("tags/" + tag) 76 | except UnknownObjectException: 77 | raise ValueError("No tag with name %s exists in repository %s" % (tag, repo.name)) 78 | 79 | # -- Is there already a release with that tag name? 80 | click.echo("Checking if release %s already exists in repository %s" % (tag, repo.name)) 81 | try: 82 | release = repo.get_release(tag) 83 | if release is not None: 84 | raise ValueError("Release %s already exists in repository %s. Please set overwrite to True if you wish to " 85 | "update the release (Not yet supported)" % (tag, repo.name)) 86 | except UnknownObjectException: 87 | # Release does not exist: we can safely create it.
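# (with PyGithub, an UnknownObjectException from get_release means the API returned 404,
# i.e. no release exists yet for this tag)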
88 | click.echo("Creating release %s on repo: %s" % (tag, repo.name)) 89 | click.echo("Release title: '%s'" % title) 90 | click.echo("Release message:\n--\n%s\n--\n" % message) 91 | repo.create_git_release(tag=tag, name=title, 92 | message=message, 93 | draft=False, prerelease=False) 94 | 95 | # add the asset file if needed 96 | if data_file is not None: 97 | release = None 98 | while release is None: 99 | release = repo.get_release(tag) 100 | release.upload_asset(path=data_file, label=path.split(data_file)[1], content_type="application/gzip") 101 | 102 | # --- Memo --- 103 | # release.target_commitish # 'master' 104 | # release.tag_name # '0.5.0' 105 | # release.title # 'First public release' 106 | # release.body # markdown body 107 | # release.draft # False 108 | # release.prerelease # False 109 | # # 110 | # release.author 111 | # release.created_at # datetime.datetime(2018, 11, 9, 17, 49, 56) 112 | # release.published_at # datetime.datetime(2018, 11, 9, 20, 11, 10) 113 | # release.last_modified # None 114 | # # 115 | # release.id # 13928525 116 | # release.etag # 'W/"dfab7a13086d1b44fe290d5d04125124"' 117 | # release.url # 'https://api.github.com/repos/smarie/python-m5p/releases/13928525' 118 | # release.html_url # 'https://github.com/smarie/python-m5p/releases/tag/0.5.0' 119 | # release.tarball_url # 'https://api.github.com/repos/smarie/python-m5p/tarball/0.5.0' 120 | # release.zipball_url # 'https://api.github.com/repos/smarie/python-m5p/zipball/0.5.0' 121 | # release.upload_url # 'https://uploads.github.com/repos/smarie/python-m5p/releases/13928525/assets{?name,label}' 122 | 123 | 124 | if __name__ == '__main__': 125 | create_or_update_release() 126 | -------------------------------------------------------------------------------- /ci_tools/nox_utils.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import asyncio 4 | from collections import namedtuple 5 | from inspect import signature, isfunction 6 | import logging 7 | from pathlib import Path 8 | import shutil 9 | import subprocess 10 | import sys 11 | import os 12 | 13 | from typing import Sequence, Dict, Union, Iterable, Mapping, Any, IO, Tuple, Optional, List 14 | 15 | from makefun import wraps, remove_signature_parameters, add_signature_parameters 16 | 17 | import nox 18 | from nox.sessions import Session 19 | 20 | 21 | nox_logger = logging.getLogger("nox") 22 | 23 | 24 | PY27, PY35, PY36, PY37, PY38, PY39, PY310 = "2.7", "3.5", "3.6", "3.7", "3.8", "3.9", "3.10" 25 | DONT_INSTALL = "dont_install" 26 | 27 | 28 | def power_session( 29 | func=None, 30 | envs=None, 31 | grid_param_name="env", 32 | python=None, 33 | py=None, 34 | reuse_venv=None, 35 | name=None, 36 | venv_backend=None, 37 | venv_params=None, 38 | logsdir=None, 39 | **kwargs 40 | ): 41 | """A nox.session on steroids 42 | 43 | :param func: 44 | :param envs: a dictionary {key: dict_of_params} where key is either the python version or a tuple (python version, 45 | grid id) and all keys in the dict_of_params must be the same in all entries. The decorated function should 46 | have one parameter for each of these keys; they will be injected with the value. 47 | :param grid_param_name: when the key in `envs` is a tuple, this name will be the name of the generated parameter to 48 | iterate through the various combinations for each python version.
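        For illustration, a hypothetical grid (the parameter name and version constraints are made up):

            envs = {
                (PY37, "lowest"): {"sklearn_version": "==0.22"},
                (PY37, "latest"): {"sklearn_version": ""},
            }

        would run the decorated session function twice on python 3.7, injecting `sklearn_version`
        each time, and exposing the tuple ids "lowest"/"latest" as values of the generated `env` parameter.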
49 | :param python: 50 | :param py: 51 | :param reuse_venv: 52 | :param name: 53 | :param venv_backend: 54 | :param venv_params: 55 | :param logsdir: 56 | :param kwargs: 57 | :return: 58 | """ 59 | if func is not None: 60 | return power_session()(func) 61 | else: 62 | def combined_decorator(f): 63 | # replace Session with PowerSession 64 | f = with_power_session(f) 65 | 66 | # open a log file for the session, use it to stream the commands stdout and stderrs, 67 | # and possibly inject the log file in the session function 68 | if logsdir is not None: 69 | f = with_logfile(logs_dir=logsdir)(f) 70 | 71 | # decorate with @nox.session and possibly @nox.parametrize to create the grid 72 | return nox_session_with_grid(python=python, py=py, envs=envs, reuse_venv=reuse_venv, name=name, 73 | grid_param_name=grid_param_name, venv_backend=venv_backend, 74 | venv_params=venv_params, **kwargs)(f) 75 | 76 | return combined_decorator 77 | 78 | 79 | def with_power_session(f=None): 80 | """ A decorator to patch the session objects in order to add all methods from PowerSession""" 81 | 82 | if f is not None: 83 | return with_power_session()(f) 84 | 85 | def _decorator(f): 86 | @wraps(f) 87 | def _f_wrapper(**kwargs): 88 | # patch the session arg 89 | PowerSession.patch(kwargs['session']) 90 | 91 | # finally execute the session 92 | return f(**kwargs) 93 | 94 | return _f_wrapper 95 | 96 | return _decorator 97 | 98 | 99 | class PowerSession(Session): 100 | """ 101 | Our nox session improvements 102 | """ 103 | 104 | # ------------ commandline runners ----------- 105 | 106 | def run2(self, 107 | command: Union[Iterable[str], str], 108 | logfile: Union[bool, str, Path] = True, 109 | **kwargs): 110 | """ 111 | An improvement of session.run that is able to 112 | 113 | - automatically split the provided command if it is a string 114 | - use a log file 115 | 116 | :param command: 117 | :param logfile: None/False (normal nox behaviour), or True (using nox file handler), or a file path. 118 | :param kwargs: 119 | :return: 120 | """ 121 | if isinstance(command, str): 122 | command = command.split(' ') 123 | 124 | self.run(*command, logfile=logfile, **kwargs) 125 | 126 | def run_multi(self, 127 | cmds: str, 128 | logfile: Union[bool, str, Path] = True, 129 | **kwargs): 130 | """ 131 | An improvement of session.run that is able to 132 | 133 | - support multiline strings 134 | - use a log file 135 | 136 | :param cmds: 137 | :param logfile: None/False (normal nox behaviour), or True (using nox file handler), or a file path. 138 | :param kwargs: 139 | :return: 140 | """ 141 | for cmdline in (line for line in cmds.splitlines() if line): 142 | self.run2(cmdline, logfile=logfile, **kwargs) 143 | 144 | # ------------ requirements installers ----------- 145 | 146 | def install_reqs( 147 | self, 148 | # pre wired phases 149 | setup=False, 150 | install=False, 151 | tests=False, 152 | extras=(), 153 | # custom phase 154 | phase=None, 155 | phase_reqs=None, 156 | versions_dct=None 157 | ): 158 | """ 159 | A high-level helper to install requirements from the various project files 160 | 161 | - pyproject.toml "[build-system] requires" (if setup=True) 162 | - setup.cfg "[options] setup_requires" (if setup=True) 163 | - setup.cfg "[options] install_requires" (if install=True) 164 | - setup.cfg "[options] test_requires" (if tests=True) 165 | - setup.cfg "[options.extras_require] <...>" (if extras=(a tuple of extras)) 166 | 167 | Two additional mechanisms are provided in order to customize how packages are installed.
168 | 169 | Conda packages 170 | -------------- 171 | If the session runs in a conda environment, you can add a [tool.conda] section to your pyproject.toml. This 172 | section should contain a `conda_packages` entry containing the list of package names that should be installed 173 | using conda instead of pip. 174 | 175 | ``` 176 | [tool.conda] 177 | # Declare that the following packages should be installed with conda instead of pip 178 | # Note: this includes packages declared everywhere, here and in setup.cfg 179 | conda_packages = [ 180 | "setuptools", 181 | "wheel", 182 | "pip" 183 | ] 184 | ``` 185 | 186 | Version constraints 187 | ------------------- 188 | In addition to the version constraints in the pyproject.toml and setup.cfg, you can specify additional temporary 189 | constraints with the `versions_dct` argument, for example if you know that this executes on a specific python 190 | version that requires special care. 191 | For this, simply pass a dictionary of {'pkg_name': 'pkg_constraint'}, for example {"pip": ">10"}. 192 | 193 | """ 194 | 195 | # Read requirements from pyproject.toml 196 | toml_setup_reqs, toml_use_conda_for = read_pyproject_toml() 197 | if setup: 198 | self.install_any("pyproject.toml#build-system", toml_setup_reqs, 199 | use_conda_for=toml_use_conda_for, versions_dct=versions_dct) 200 | 201 | # Read test requirements from setup.cfg 202 | setup_cfg = read_setuptools_cfg() 203 | if setup: 204 | self.install_any("setup.cfg#setup_requires", setup_cfg.setup_requires, 205 | use_conda_for=toml_use_conda_for, versions_dct=versions_dct) 206 | if install: 207 | self.install_any("setup.cfg#install_requires", setup_cfg.install_requires, 208 | use_conda_for=toml_use_conda_for, versions_dct=versions_dct) 209 | if tests: 210 | self.install_any("setup.cfg#tests_requires", setup_cfg.tests_requires, 211 | use_conda_for=toml_use_conda_for, versions_dct=versions_dct) 212 | 213 | for extra in extras: 214 | self.install_any("setup.cfg#extras_require#%s" % extra, setup_cfg.extras_require[extra], 215 | use_conda_for=toml_use_conda_for, versions_dct=versions_dct) 216 | 217 | if phase is not None: 218 | self.install_any(phase, phase_reqs, use_conda_for=toml_use_conda_for, versions_dct=versions_dct) 219 | 220 | def install_any(self, 221 | phase_name: str, 222 | pkgs: Sequence[str], 223 | use_conda_for: Sequence[str] = (), 224 | versions_dct: Dict[str, str] = None, 225 | logfile: Union[bool, str, Path] = True, 226 | ): 227 | """Install the `pkgs` provided with `session.install(*pkgs)`, except for those present in `use_conda_for`""" 228 | 229 | nox_logger.debug("\nAbout to install *%s* requirements: %s.\n " 230 | "Conda pkgs are %s" % (phase_name, pkgs, use_conda_for)) 231 | 232 | # use the provided versions dictionary to update the versions 233 | if versions_dct is None: 234 | versions_dct = dict() 235 | pkgs = [pkg + versions_dct.get(pkg, "") for pkg in pkgs if versions_dct.get(pkg, "") != DONT_INSTALL] 236 | 237 | # install on conda...
if the session uses conda backend 238 | if not isinstance(self.virtualenv, nox.virtualenv.CondaEnv): 239 | conda_pkgs = [] 240 | else: 241 | conda_pkgs = [pkg_req for pkg_req in pkgs if any(get_req_pkg_name(pkg_req) == c for c in use_conda_for)] 242 | if len(conda_pkgs) > 0: 243 | nox_logger.info("[%s] Installing requirements with conda: %s" % (phase_name, conda_pkgs)) 244 | self.conda_install2(*conda_pkgs, logfile=logfile) 245 | 246 | pip_pkgs = [pkg_req for pkg_req in pkgs if pkg_req not in conda_pkgs] 247 | # safety: make sure that nothing went modified or forgotten 248 | assert set(conda_pkgs).union(set(pip_pkgs)) == set(pkgs) 249 | if len(pip_pkgs) > 0: 250 | nox_logger.info("[%s] Installing requirements with pip: %s" % (phase_name, pip_pkgs)) 251 | self.install2(*pip_pkgs, logfile=logfile) 252 | 253 | def conda_install2(self, 254 | *conda_pkgs, 255 | logfile: Union[bool, str, Path] = True, 256 | **kwargs 257 | ): 258 | """ 259 | Same as session.conda_install() but with support for `logfile`. 260 | 261 | :param conda_pkgs: 262 | :param logfile: None/False (normal nox behaviour), or True (using nox file handler), or a file path. 263 | :return: 264 | """ 265 | return self.conda_install(*conda_pkgs, logfile=logfile, **kwargs) 266 | 267 | def install2(self, 268 | *pip_pkgs, 269 | logfile: Union[bool, str, Path] = True, 270 | **kwargs 271 | ): 272 | """ 273 | Same as session.install() but with support for `logfile`. 274 | 275 | :param pip_pkgs: 276 | :param logfile: None/False (normal nox behaviour), or True (using nox file handler), or a file path. 277 | :return: 278 | """ 279 | return self.install(*pip_pkgs, logfile=logfile, **kwargs) 280 | 281 | def get_session_id(self): 282 | """Return the session id""" 283 | return Path(self.bin).name 284 | 285 | @classmethod 286 | def is_power_session(cls, session: Session): 287 | return PowerSession.install2.__name__ in session.__dict__ 288 | 289 | @classmethod 290 | def patch(cls, session: Session): 291 | """ 292 | Add all methods from this class to the provided object. 293 | Note that we could instead have created a proper proxy... but complex for not a lot of benefit. 294 | :param session: 295 | :return: 296 | """ 297 | if not cls.is_power_session(session): 298 | for m_name, m in cls.__dict__.items(): 299 | if not isfunction(m): 300 | continue 301 | if m is cls.patch: 302 | continue 303 | if not hasattr(session, m_name): 304 | setattr(session.__class__, m_name, m) 305 | 306 | return True 307 | 308 | 309 | # ------------- requirements related 310 | 311 | 312 | def read_pyproject_toml(): 313 | """ 314 | Reads the `pyproject.toml` and returns 315 | 316 | - a list of setup requirements from [build-system] requires 317 | - sub-list of these requirements that should be installed with conda, from [tool.my_conda] conda_packages 318 | """ 319 | if os.path.exists("pyproject.toml"): 320 | import toml 321 | nox_logger.debug("\nA `pyproject.toml` file exists. Loading it.") 322 | pyproject = toml.load("pyproject.toml") 323 | requires = pyproject['build-system']['requires'] 324 | conda_pkgs = pyproject['tool']['conda']['conda_packages'] 325 | return requires, conda_pkgs 326 | else: 327 | raise FileNotFoundError("No `pyproject.toml` file exists. 
No dependency will be installed ...") 328 | 329 | 330 | SetupCfg = namedtuple('SetupCfg', ('setup_requires', 'install_requires', 'tests_requires', 'extras_require')) 331 | 332 | 333 | def read_setuptools_cfg(): 334 | """ 335 | Reads the `setup.cfg` file and extracts the various requirements lists 336 | """ 337 | # see https://stackoverflow.com/a/30679041/7262247 338 | from setuptools import Distribution 339 | dist = Distribution() 340 | dist.parse_config_files() 341 | return SetupCfg(setup_requires=dist.setup_requires, 342 | install_requires=dist.install_requires, 343 | tests_requires=dist.tests_require, 344 | extras_require=dist.extras_require) 345 | 346 | 347 | def get_req_pkg_name(r): 348 | """Return the package name part of a python package requirement. 349 | 350 | For example 351 | "funcsigs;python<'3.5'" will return "funcsigs" 352 | "pytest>=3" will return "pytest" 353 | """ 354 | return r.replace('<', '=').replace('>', '=').replace(';', '=').split("=")[0] 355 | 356 | 357 | # ------------- log related 358 | 359 | 360 | def with_logfile(logs_dir: Path, 361 | logfile_arg: str = "logfile", 362 | logfile_handler_arg: str = "logfilehandler" 363 | ): 364 | """A decorator injecting a logfile (and/or its file handler) into the decorated session function""" 365 | 366 | def _decorator(f): 367 | # check the signature of f 368 | foo_sig = signature(f) 369 | needs_logfile_injection = logfile_arg in foo_sig.parameters 370 | needs_logfilehandler_injection = logfile_handler_arg in foo_sig.parameters 371 | 372 | # modify the exposed signature if needed 373 | new_sig = None 374 | if needs_logfile_injection: 375 | new_sig = remove_signature_parameters(foo_sig, logfile_arg) 376 | if needs_logfilehandler_injection: 377 | new_sig = remove_signature_parameters(new_sig if new_sig is not None else foo_sig, logfile_handler_arg)  # do not discard the previous removal 378 | 379 | @wraps(f, new_sig=new_sig) 380 | def _f_wrapper(**kwargs): 381 | # find the session arg 382 | session = kwargs['session'] # type: Session 383 | 384 | # add file handler to logger 385 | logfile = logs_dir / ("%s.log" % PowerSession.get_session_id(session)) 386 | error_logfile = logfile.with_name("ERROR_%s" % logfile.name) 387 | success_logfile = logfile.with_name("SUCCESS_%s" % logfile.name) 388 | # delete old files if present 389 | for _f in (logfile, error_logfile, success_logfile): 390 | if _f.exists(): 391 | _f.unlink() 392 | 393 | # add a FileHandler to the logger 394 | logfile_handler = log_to_file(logfile) 395 | 396 | # inject the log file / log file handler in the args: 397 | if needs_logfile_injection: 398 | kwargs[logfile_arg] = logfile 399 | if needs_logfilehandler_injection: 400 | kwargs[logfile_handler_arg] = logfile_handler 401 | 402 | # finally execute the session 403 | try: 404 | res = f(**kwargs) 405 | except Exception as e: 406 | # close and detach the file logger and rename as ERROR_....log 407 | remove_file_logger() 408 | logfile.rename(error_logfile) 409 | raise e 410 | else: 411 | # close and detach the file logger and rename as SUCCESS_....log 412 | remove_file_logger() 413 | logfile.rename(success_logfile) 414 | return res 415 | 416 | return _f_wrapper 417 | 418 | return _decorator 419 | 420 | 421 | def log_to_file(file_path: Union[str, Path] 422 | ): 423 | """ 424 | Closes and removes all file handlers from the nox logger, 425 | and adds a new one pointing to the provided file path 426 | 427 | :param file_path: 428 | :return: 429 | """ 430 | for h in list(nox_logger.handlers): 431 | if isinstance(h, logging.FileHandler): 432 | h.close() 433 | nox_logger.removeHandler(h) 434 | fh = logging.FileHandler(str(file_path), mode='w') 435 | 
nox_logger.addHandler(fh) 436 | return fh 437 | 438 | 439 | def get_current_logfile_handler(): 440 | """ 441 | Returns the current unique log file handler (see `log_to_file`) 442 | """ 443 | for h in list(nox_logger.handlers): 444 | if isinstance(h, logging.FileHandler): 445 | return h 446 | return None 447 | 448 | 449 | def get_log_file_stream(): 450 | """ 451 | Returns the output stream for the current log file handler if any (see `log_to_file`) 452 | """ 453 | h = get_current_logfile_handler() 454 | if h is not None: 455 | return h.stream 456 | return None 457 | 458 | 459 | def remove_file_logger(): 460 | """ 461 | Closes and detaches the current logfile handler 462 | :return: 463 | """ 464 | h = get_current_logfile_handler() 465 | if h is not None: 466 | h.close() 467 | nox_logger.removeHandler(h) 468 | 469 | 470 | # ------------ environment grid / parametrization related 471 | 472 | def nox_session_with_grid(python = None, 473 | py = None, 474 | envs: Mapping[str, Mapping[str, Any]] = None, 475 | reuse_venv: Optional[bool] = None, 476 | name: Optional[str] = None, 477 | venv_backend: Any = None, 478 | venv_params: Any = None, 479 | grid_param_name: str = None, 480 | **kwargs 481 | ): 482 | """ 483 | Since nox is not yet capable of defining a build matrix with python and parameters mixed in the same parametrize, 484 | this implements it with a dirty hack. 485 | To be removed when https://github.com/theacodes/nox/pull/404 is complete. 486 | 487 | :param envs: 488 | :param grid_param_name: 489 | :return: 490 | """ 491 | if envs is None: 492 | # Fast track default to @nox.session 493 | return nox.session(python=python, py=py, reuse_venv=reuse_venv, name=name, venv_backend=venv_backend, 494 | venv_params=venv_params, **kwargs) 495 | else: 496 | # Current limitation: session param names can be 'python' or 'py' only 497 | if py is not None or python is not None: 498 | raise ValueError("`python` session argument cannot be provided both directly and through the " 499 | "`envs` keys") 500 | 501 | # First examine the env and collect the parameter values for python 502 | all_python = [] 503 | all_params = [] 504 | 505 | env_contents_names = None 506 | has_parameter = None 507 | for env_id, env_params in envs.items(): 508 | # consistency checks for the env_id 509 | if has_parameter is None: 510 | has_parameter = isinstance(env_id, tuple) 511 | else: 512 | if has_parameter != isinstance(env_id, tuple): 513 | raise ValueError("Either all keys in `envs` should be tuples, or none of them. Error for %r" % env_id) 514 | 515 | # retrieve python version and parameter 516 | if not has_parameter: 517 | if env_id not in all_python: 518 | all_python.append(env_id) 519 | else: 520 | if len(env_id) != 2: 521 | raise ValueError("Only a size-2 tuple can be used as an env id") 522 | py_id, param_id = env_id 523 | if py_id not in all_python: 524 | all_python.append(py_id) 525 | if param_id not in all_params: 526 | all_params.append(param_id) 527 | 528 | # consistency checks for the dict contents. 
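# (illustration - hypothetical grid) at this point `envs` may look like:
#   {("3.7", "scenario1"): {"coverage": False}, ("3.8", "scenario2"): {"coverage": True}}
# i.e. keys are plain python versions or (python, grid_param) tuples, and every value must
# declare the same set of session parameters - which is what the checks below enforce.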
529 | if env_contents_names is None: 530 | env_contents_names = set(env_params.keys()) 531 | else: 532 | if env_contents_names != set(env_params.keys()): 533 | raise ValueError("Environment %r parameters %r do not match the parameters of the first environment: %r" 534 | % (env_id, set(env_params.keys()), env_contents_names)) 535 | 536 | if has_parameter and not grid_param_name: 537 | raise ValueError("You must provide a grid parameter name when the env keys are tuples.") 538 | 539 | def _decorator(f): 540 | s_name = name if name is not None else f.__name__ 541 | for pyv, _param in product(all_python, all_params): 542 | if (pyv, _param) not in envs: 543 | # create a dummy folder so that no useless venv is created for this non-existing combination 544 | env_dir = Path(".nox") / ("%s-%s-%s-%s" % (s_name, pyv.replace('.', '-'), grid_param_name, _param)) 545 | env_dir.mkdir(parents=True, exist_ok=True) 546 | 547 | # check the signature of f 548 | foo_sig = signature(f) 549 | missing = env_contents_names - set(foo_sig.parameters) 550 | if len(missing) > 0: 551 | raise ValueError("Session function %r does not contain environment parameter(s) %r" % (f.__name__, missing)) 552 | 553 | # modify the exposed signature if needed 554 | new_sig = None 555 | if len(env_contents_names) > 0: 556 | new_sig = remove_signature_parameters(foo_sig, *env_contents_names) 557 | 558 | if has_parameter: 559 | if grid_param_name in foo_sig.parameters: 560 | raise ValueError("The session function has a parameter with a reserved name: %r" % grid_param_name) 561 | else: 562 | new_sig = add_signature_parameters(new_sig, last=(grid_param_name,)) 563 | 564 | @wraps(f, new_sig=new_sig) 565 | def _f_wrapper(**kwargs): 566 | # find the session arg 567 | session = kwargs['session'] # type: Session 568 | 569 | # get the versions to use for this environment 570 | try: 571 | if has_parameter: 572 | grid_param = kwargs.pop(grid_param_name) 573 | params_dct = envs[(session.python, grid_param)] 574 | else: 575 | params_dct = envs[session.python] 576 | except KeyError: 577 | # Skip this session, it is a dummy one 578 | nox_logger.warning( 579 | "Skipping this configuration: it is not supported on python version %r" % session.python) 580 | return 581 | 582 | # inject the parameters in the args: 583 | kwargs.update(params_dct) 584 | 585 | # finally execute the session 586 | return f(**kwargs) 587 | 588 | if has_parameter: 589 | _f_wrapper = nox.parametrize(grid_param_name, all_params)(_f_wrapper) 590 | 591 | _f_wrapper = nox.session(python=all_python, reuse_venv=reuse_venv, name=name, 592 | venv_backend=venv_backend, venv_params=venv_params)(_f_wrapper) 593 | return _f_wrapper 594 | 595 | return _decorator 596 | 597 | 598 | # ----------- other goodies 599 | 600 | 601 | def rm_file(folder: Union[str, Path] 602 | ): 603 | """Since on windows Path.unlink sometimes throws a PermissionError, os.remove is preferred.""" 604 | if isinstance(folder, str): 605 | folder = Path(folder) 606 | 607 | if folder.exists(): 608 | os.remove(str(folder)) 609 | # Folders.site.unlink() --> possible PermissionError 610 | 611 | 612 | def rm_folder(folder: Union[str, Path] 613 | ): 614 | """Since on windows Path.unlink sometimes throws a PermissionError, shutil is preferred.""" 615 | if isinstance(folder, str): 616 | folder = Path(folder) 617 | 618 | if folder.exists(): 619 | shutil.rmtree(str(folder)) 620 | # Folders.site.unlink() --> possible PermissionError 621 | 622 | 623 | # --- the patch of popen, able to tee to a logfile -- 624 | 625 | 626 | import nox.popen as nox_popen_module 627 | orig_nox_popen = 
nox_popen_module.popen 628 | 629 | 630 | class LogFileStreamCtx: 631 | def __init__(self, logfile_stream): 632 | self.logfile_stream = logfile_stream 633 | 634 | def __enter__(self): 635 | return self.logfile_stream 636 | 637 | def __exit__(self, exc_type, exc_val, exc_tb): 638 | pass 639 | 640 | 641 | def patched_popen( 642 | args: Sequence[str], 643 | env: Mapping[str, str] = None, 644 | silent: bool = False, 645 | stdout: Union[int, IO] = None, 646 | stderr: Union[int, IO] = subprocess.STDOUT, 647 | logfile: Union[bool, str, Path] = None, 648 | **kwargs 649 | ) -> Tuple[int, str]: 650 | """ 651 | Our patch of nox.popen.popen(). 652 | 653 | Current behaviour in `nox` is 654 | 655 | - when `silent=True` (default), process err is redirected to STDOUT and process out is captured in a PIPE and sent 656 | to the logger (which does not display it :) ) 657 | 658 | - when `silent=False` (explicitly set, or when nox is run with the verbose flag), process out and process err are both 659 | redirected to STDOUT. 660 | 661 | Our implementation allows us to be a little more flexible: 662 | 663 | - if logfile is True or a string/Path, process err and process out are both TEE-ed to the logfile 664 | - at the same time, the above behaviour remains. 665 | 666 | :param args: 667 | :param env: 668 | :param silent: 669 | :param stdout: 670 | :param stderr: 671 | :param logfile: None/False (normal nox behaviour), or True (using the nox file handler), or a file path. 672 | :return: 673 | """ 674 | logfile_stream = get_log_file_stream() 675 | 676 | if logfile in (None, False) or (logfile is True and logfile_stream is None): 677 | # execute popen as usual 678 | return orig_nox_popen(args=args, env=env, silent=silent, stdout=stdout, stderr=stderr, **kwargs) 679 | 680 | else: 681 | # we'll need to tee the popen 682 | if logfile is True: 683 | ctx = LogFileStreamCtx 684 | else: 685 | ctx = lambda _: open(logfile, "a") 686 | 687 | with ctx(logfile_stream) as log_file_stream: 688 | if silent and stdout is not None: 689 | raise ValueError( 690 | "Can not specify silent and stdout; passing a custom stdout always silences the commands output in " 691 | "Nox's log." 692 | ) 693 | 694 | shell = kwargs.get("shell", False) 695 | if shell: 696 | raise ValueError("Using shell=True is not yet supported with async streaming to log files") 697 | 698 | if stdout is not None or stderr is not subprocess.STDOUT: 699 | raise ValueError("Using custom streams is not yet supported with async popen") 700 | 701 | # old way 702 | # proc = subprocess.Popen(args, env=env, stdout=stdout, stderr=stderr) 703 | 704 | # New way: use asyncio to stream correctly 705 | # Note: if keyboard interrupts do not work, we may need to check 706 | # https://mail.python.org/pipermail/async-sig/2017-August/000374.html or related threads. 
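# (illustration) the intended data flow of the tee-ing below, for a command writing to both streams:
#   process.stdout --(_read_stream)--> tee --> outlines [+ log file] [+ sys.stdout when not silent]
#   process.stderr --(_read_stream)--> tee --> outlines [+ log file] + sys.stdout (always, labelled "ERR:")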
707 | 708 | # define the async coroutines 709 | async def async_popen(): 710 | process = await asyncio.create_subprocess_exec(*args, env=env, stdout=asyncio.subprocess.PIPE, 711 | stderr=asyncio.subprocess.PIPE, **kwargs) 712 | 713 | # bind the out and err streams - see https://stackoverflow.com/a/59041913/7262247 714 | # to mimic nox behaviour we only use a single capturing list 715 | outlines = [] 716 | await asyncio.wait([ 717 | # process out is only redirected to STDOUT if not silent 718 | _read_stream(process.stdout, lambda l: tee(l, sinklist=outlines, sinkstream=log_file_stream, 719 | quiet=silent, verbosepipe=sys.stdout)), 720 | # process err is always redirected to STDOUT (quiet=False) with a specific label 721 | _read_stream(process.stderr, lambda l: tee(l, sinklist=outlines, sinkstream=log_file_stream, 722 | quiet=False, verbosepipe=sys.stdout, label="ERR:")) 723 | ]) 724 | return_code = await process.wait() # make sure the process has ended and retrieve its return code 725 | return return_code, outlines 726 | 727 | # run the coroutine in the event loop 728 | loop = asyncio.get_event_loop() 729 | return_code, outlines = loop.run_until_complete(async_popen()) 730 | 731 | # just in case, flush everything 732 | log_file_stream.flush() 733 | sys.stdout.flush() 734 | sys.stderr.flush() 735 | 736 | if silent: 737 | # same behaviour as in nox: this will be passed to the logger, and it will act depending on the verbose flag 738 | out = "\n".join(outlines) if len(outlines) > 0 else "" 739 | else: 740 | # already written to stdout, no need to capture 741 | out = "" 742 | 743 | return return_code, out 744 | 745 | 746 | async def _read_stream(stream, callback): 747 | """Helper async coroutine to read a stream line by line, passing each line to `callback`""" 748 | while True: 749 | line = await stream.readline() 750 | if line: 751 | callback(line) 752 | else: 753 | break 754 | 755 | 756 | def tee(linebytes, sinklist, sinkstream, verbosepipe, quiet, label=""): 757 | """ 758 | Helper routine to decode a line of bytes and dispatch it to several sinks: 759 | 760 | - an optional `sinklist` list that will receive the decoded string through its "append" method 761 | - an optional `sinkstream` stream that will receive the decoded string through its "write" method 762 | - an optional `verbosepipe` stream that will receive the decoded string through a print, but only when quiet=False 763 | 764 | 
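Example (illustrative, with hypothetical values):

    outlines = []
    with open("cmd.log", "a") as logf:
        tee(b"hello\n", sinklist=outlines, sinkstream=logf,
            verbosepipe=sys.stdout, quiet=False, label="OUT:")
    # -> outlines == ["hello"], "hello\n" is appended to cmd.log,
    #    and "OUT: hello" is printed to sys.stdout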
765 | """ 766 | line = linebytes.decode('utf-8').rstrip() 767 | 768 | if sinklist is not None: 769 | sinklist.append(line) 770 | 771 | if sinkstream is not None: 772 | sinkstream.write(line + "\n") 773 | sinkstream.flush() 774 | 775 | if not quiet and verbosepipe is not None: 776 | print(label, line, file=verbosepipe) 777 | verbosepipe.flush() 778 | 779 | 780 | def patch_popen(): 781 | nox_popen_module.popen = patched_popen 782 | 783 | from nox.command import popen 784 | if popen is not patched_popen: 785 | nox.command.popen = patched_popen 786 | 787 | # change event loop on windows 788 | # see https://stackoverflow.com/a/44639711/7262247 789 | # and https://docs.python.org/3/library/asyncio-platforms.html#subprocess-support-on-windows 790 | if 'win32' in sys.platform: 791 | # Windows specific event-loop policy & cmd 792 | asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) 793 | # cmds = [['C:/Windows/system32/HOSTNAME.EXE']] 794 | 795 | # loop = asyncio.ProactorEventLoop() 796 | # asyncio.set_event_loop(loop) 797 | 798 | 799 | patch_popen() 800 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ### 0.3.3 - Added zenodo 4 | 5 | * Added zenodo entry 6 | 7 | ### 0.3.2 - Fixed compliance with sklearn 1.3.0 8 | 9 | * Fixed `AttributeError: 'super' object has no attribute 'fit'`. 10 | PR [#16](https://github.com/smarie/python-m5p/pull/16) by [lccatala](https://github.com/lccatala) 11 | 12 | ### 0.3.1 - Fixed compliance with sklearn 1.1.0 13 | 14 | * Fixed `TypeError: fit() got an unexpected keyword argument 'X_idx_sorted'`. 15 | PR [#11](https://github.com/smarie/python-m5p/pull/11) by [preinaj](https://github.com/preinaj) 16 | 17 | ### 0.3.0 - First public version 18 | 19 | * Initial fork from private repository + [scikit-learn/scikit-learn#13732](https://github.com/scikit-learn/scikit-learn/pull/13732). Fixes [#2](https://github.com/smarie/python-m5p/issues/2) 20 | * Two gallery examples (1D sine and pw-linear) 21 | 22 | ### 0.0.1 - Initial repo setup 23 | 24 | First CI roundtrip. Fixes [#1](https://github.com/smarie/python-m5p/issues/1) 25 | -------------------------------------------------------------------------------- /docs/examples/1_simple_1D_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple 1D example 3 | ================= 4 | 5 | A 1D regression with an M5P decision tree. 6 | 7 | The tree is used to fit a sine curve with additional noisy observations. As a 8 | result, it learns local linear regressions approximating the sine curve. 9 | 10 | We can see the role of pruning (Tree 2) and pruning + smoothing (Tree 3). 
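More precisely (see the code below): Tree 1 is grown without pruning or smoothing, Tree 2
enables pruning only, and Tree 3 enables pruning plus smoothing with ``smoothing_constant=5``.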
11 | """ 12 | # %% 13 | # Import the necessary modules and libraries 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | 17 | from m5py import M5Prime, export_text_m5 18 | 19 | # %% 20 | # Create a random dataset 21 | rng = np.random.RandomState(1) 22 | X = np.sort(5 * rng.rand(80, 1), axis=0) 23 | y = np.sin(X).ravel() 24 | y[::5] += 0.5 * (0.5 - rng.rand(16)) 25 | 26 | # %% 27 | # Fit the regression models 28 | regr_1 = M5Prime(use_smoothing=False, use_pruning=False) 29 | regr_1_label = "Tree 1" 30 | regr_1.fit(X, y) 31 | regr_2 = M5Prime(use_smoothing=False) 32 | regr_2_label = "Tree 2" 33 | regr_2.fit(X, y) 34 | regr_3 = M5Prime(smoothing_constant=5) 35 | regr_3_label = "Tree 3" 36 | regr_3.fit(X, y) 37 | 38 | # %% 39 | # Predict 40 | X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis] 41 | y_1 = regr_1.predict(X_test) 42 | y_2 = regr_2.predict(X_test) 43 | y_3 = regr_3.predict(X_test) 44 | 45 | # %% 46 | # Print the trees 47 | print("\n----- %s" % regr_1_label) 48 | print(regr_1.as_pretty_text()) 49 | 50 | # %% 51 | print("\n----- %s" % regr_2_label) 52 | print(regr_2.as_pretty_text()) 53 | 54 | # %% 55 | print("\n----- %s" % regr_3_label) 56 | print(export_text_m5(regr_3, out_file=None)) # equivalent to as_pretty_text 57 | 58 | # %% 59 | # Plot the results 60 | fig = plt.figure() 61 | plt.scatter(X, y, s=20, edgecolor="black", 62 | c="darkorange", label="data") 63 | plt.plot(X_test, y_1, color="cornflowerblue", label=regr_1_label, linewidth=2) 64 | plt.plot(X_test, y_2, color="yellowgreen", label=regr_2_label, linewidth=2) 65 | plt.plot(X_test, y_3, color="green", label=regr_3_label, linewidth=2) 66 | plt.xlabel("data") 67 | plt.ylabel("target") 68 | plt.title("Decision Tree Regression") 69 | plt.legend() 70 | fig 71 | -------------------------------------------------------------------------------- /docs/examples/2_pw_linear_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Piecewise-Linear artificial example 3 | =================================== 4 | 5 | This demo reproduces the results from the "Artificial Dataset" section of the M5 6 | paper, named `pw-linear` in the M5' paper. 
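The generated target (see the code below) is:

    y = 3 + 3*X2 + 2*X3 + X4 + noise     when X1 = 1
    y = -3 + 3*X5 + 2*X6 + X7 + noise    when X1 = -1

with Gaussian noise of variance 2, so a model tree splitting once on X1, with one
linear model per branch, can represent it exactly (up to noise).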
7 | """ 8 | # %% 9 | # Import the necessary modules and libraries 10 | import numpy as np 11 | import pandas as pd 12 | import seaborn as sns 13 | 14 | from sklearn.model_selection import cross_val_score 15 | from sklearn.tree import DecisionTreeRegressor, export_text 16 | from m5py import M5Prime 17 | 18 | # %% 19 | # Create a random dataset 20 | rng = np.random.RandomState(1) 21 | nb_samples = 200 22 | 23 | X1 = rng.randint(0, 2, nb_samples) * 2 - 1 24 | X2 = rng.randint(-1, 2, nb_samples) 25 | X3 = rng.randint(-1, 2, nb_samples) 26 | X4 = rng.randint(-1, 2, nb_samples) 27 | X5 = rng.randint(-1, 2, nb_samples) 28 | X6 = rng.randint(-1, 2, nb_samples) 29 | X7 = rng.randint(-1, 2, nb_samples) 30 | X8 = rng.randint(-1, 2, nb_samples) 31 | X9 = rng.randint(-1, 2, nb_samples) 32 | X10 = rng.randint(-1, 2, nb_samples) 33 | 34 | feature_names = ["X%i" % i for i in range(1, 11)] 35 | X = np.c_[X1, X2, X3, X4, X5, X6, X7, X8, X9, X10] 36 | 37 | y = np.where( 38 | X1 > 0, 39 | 3 + 3 * X2 + 2 * X3 + X4, 40 | -3 + 3 * X5 + 2 * X6 + X7 41 | ) + rng.normal(loc=0., scale=2 ** 0.5, size=nb_samples) 42 | 43 | # %% 44 | # Define regression models and evaluate them on 10-fold CV 45 | regr_0 = DecisionTreeRegressor() 46 | regr_0_label = "Tree 0" 47 | regr_0_scores = cross_val_score(regr_0, X, y, cv=10) 48 | 49 | regr_1 = M5Prime(use_smoothing=False, use_pruning=False) 50 | regr_1_label = "Tree 1" 51 | regr_1_scores = cross_val_score(regr_1, X, y, cv=10) 52 | 53 | regr_2 = M5Prime(use_smoothing=False) 54 | regr_2_label = "Tree 2" 55 | regr_2_scores = cross_val_score(regr_2, X, y, cv=10) 56 | 57 | regr_3 = M5Prime(use_smoothing=True) 58 | regr_3_label = "Tree 3" 59 | regr_3_scores = cross_val_score(regr_3, X, y, cv=10) 60 | 61 | scores = np.c_[regr_0_scores, regr_1_scores, regr_2_scores, regr_3_scores] 62 | avgs = scores.mean(axis=0) 63 | stds = scores.std(axis=0) 64 | labels = [regr_0_label, regr_1_label, regr_2_label, regr_3_label] 65 | 66 | scores_df = pd.DataFrame(data=scores, columns=labels) 67 | sns.violinplot(data=scores_df) 68 | 69 | # %% 70 | # Fit the final models and print the trees: 71 | # 72 | regr_0.fit(X, y) 73 | print("\n----- %s" % regr_0_label) 74 | print(export_text(regr_0, feature_names=feature_names)) 75 | 76 | # %% 77 | regr_1.fit(X, y) 78 | print("\n----- %s" % regr_1_label) 79 | print(regr_1.as_pretty_text(feature_names=feature_names)) 80 | 81 | # %% 82 | regr_2.fit(X, y) 83 | print("\n----- %s" % regr_2_label) 84 | print(regr_2.as_pretty_text(feature_names=feature_names)) 85 | 86 | # %% 87 | regr_3.fit(X, y) 88 | print("\n----- %s" % regr_3_label) 89 | print(regr_3.as_pretty_text(feature_names=feature_names)) 90 | -------------------------------------------------------------------------------- /docs/examples/Readme.md: -------------------------------------------------------------------------------- 1 | # Usage examples 2 | 3 | These examples demonstrate how to use the library. 
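As a minimal sketch of the workflow they all follow (synthetic data, illustrative settings):

```python
import numpy as np
from m5py import M5Prime

rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(80, 1), axis=0)  # a single noise-free feature
y = np.sin(X).ravel()

model = M5Prime(use_smoothing=True)  # a scikit-learn compliant regressor
model.fit(X, y)
print(model.as_pretty_text())  # human-readable tree with its linear models
y_pred = model.predict(X)
```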
4 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # `m5py` 2 | 3 | *`scikit-learn`-compliant M5 / M5' model trees for python* 4 | 5 | [![Python versions](https://img.shields.io/pypi/pyversions/m5py.svg)](https://pypi.python.org/pypi/m5py/) [![Build Status](https://github.com/smarie/python-m5p/actions/workflows/base.yml/badge.svg)](https://github.com/smarie/python-m5p/actions/workflows/base.yml) [![Tests Status](./reports/junit/junit-badge.svg?dummy=8484744)](./reports/junit/report.html) [![Coverage Status](./reports/coverage/coverage-badge.svg?dummy=8484744)](./reports/coverage/index.html) [![codecov](https://codecov.io/gh/smarie/python-m5p/branch/main/graph/badge.svg)](https://codecov.io/gh/smarie/python-m5p) [![Flake8 Status](./reports/flake8/flake8-badge.svg?dummy=8484744)](./reports/flake8/index.html) 6 | 7 | [![Documentation](https://img.shields.io/badge/doc-latest-blue.svg)](https://smarie.github.io/python-m5p/) [![PyPI](https://img.shields.io/pypi/v/m5py.svg)](https://pypi.python.org/pypi/m5py/) [![Downloads](https://pepy.tech/badge/m5py)](https://pepy.tech/project/m5py) [![Downloads per week](https://pepy.tech/badge/m5py/week)](https://pepy.tech/project/m5py) [![GitHub stars](https://img.shields.io/github/stars/smarie/python-m5p.svg)](https://github.com/smarie/python-m5p/stargazers) [![DOI](https://zenodo.org/badge/411580595.svg)](https://zenodo.org/doi/10.5281/zenodo.10552218) 8 | 9 | In 1996, R. Quinlan introduced the M5 algorithm, a regression tree algorithm similar to CART (Breiman), with additional pruning so that leaves may contain linear models instead of constant values. The idea was to get smoother and simpler models. 10 | 11 | The algorithm was later enhanced by Wang & Witten under the name M5 Prime (aka M5', or M5P), with an implementation in the Weka toolbox. 12 | 13 | `m5py` is a python implementation leveraging `scikit-learn`'s regression tree engine. 14 | 15 | 16 | ## Citing 17 | 18 | If `m5py` helps you with your research work, don't hesitate to spread the word! For this, 19 | please cite this [conference presentation](https://hal.science/hal-03762155/). 20 | Optionally, you can also cite this Zenodo entry 21 | [![DOI](https://zenodo.org/badge/411580595.svg)](https://zenodo.org/doi/10.5281/zenodo.10552218). Thanks! 22 | 23 | ## Installing 24 | 25 | ```bash 26 | > pip install m5py 27 | ``` 28 | 29 | ## Usage 30 | 31 | See the [usage examples gallery](./generated/gallery). 32 | 33 | ## Main features / benefits 34 | 35 | * The classic `M5` algorithm in python, compliant with `scikit-learn`. 36 | 37 | ## See Also 38 | 39 | * Weka's [M5P model](https://weka.sourceforge.io/doc.dev/weka/classifiers/trees/M5P.html) 40 | * LightGBM's [linear_tree](https://lightgbm.readthedocs.io/en/latest/Parameters.html#linear_tree) 41 | * [linear-tree](https://github.com/cerlymarco/linear-tree), which follows a procedure similar to LightGBM's (from 42 | [this discussion](https://github.com/scikit-learn/scikit-learn/issues/13106#issuecomment-808730062)) 43 | 44 | ### Others 45 | 46 | *Do you like this library? You might also like [my other python libraries](https://github.com/smarie/OVERVIEW#python)* 47 | 48 | ## Want to contribute? 
49 | 50 | Details on the github page: [https://github.com/smarie/python-m5p](https://github.com/smarie/python-m5p) 51 | -------------------------------------------------------------------------------- /docs/long_description.md: -------------------------------------------------------------------------------- 1 | # `m5py` 2 | 3 | [![Python versions](https://img.shields.io/pypi/pyversions/m5py.svg)](https://pypi.python.org/pypi/m5py/) [![Build Status](https://github.com/smarie/python-m5p/actions/workflows/base.yml/badge.svg)](https://github.com/smarie/python-m5p/actions/workflows/base.yml) [![Tests Status](https://smarie.github.io/python-m5p/reports/junit/junit-badge.svg?dummy=8484744)](https://smarie.github.io/python-m5p/reports/junit/report.html) [![Coverage Status](https://smarie.github.io/python-m5p/reports/coverage/coverage-badge.svg?dummy=8484744)](https://smarie.github.io/python-m5p/reports/coverage/index.html) [![codecov](https://codecov.io/gh/smarie/python-m5p/branch/main/graph/badge.svg)](https://codecov.io/gh/smarie/python-m5p) [![Flake8 Status](https://smarie.github.io/python-m5p/reports/flake8/flake8-badge.svg?dummy=8484744)](https://smarie.github.io/python-m5p/reports/flake8/index.html) 4 | 5 | [![Documentation](https://img.shields.io/badge/doc-latest-blue.svg)](https://smarie.github.io/python-m5p/) [![PyPI](https://img.shields.io/pypi/v/m5py.svg)](https://pypi.python.org/pypi/m5py/) [![Downloads](https://pepy.tech/badge/m5py)](https://pepy.tech/project/m5py) [![Downloads per week](https://pepy.tech/badge/m5py/week)](https://pepy.tech/project/m5py) [![GitHub stars](https://img.shields.io/github/stars/smarie/python-m5p.svg)](https://github.com/smarie/python-m5p/stargazers) 6 | 7 | *`scikit-learn`-compliant M5 / M5' model trees for python* 8 | 9 | The documentation for users is available here: [https://smarie.github.io/python-m5p/](https://smarie.github.io/python-m5p/) 10 | 11 | A readme for developers is available here: [https://github.com/smarie/python-m5p](https://github.com/smarie/python-m5p) 12 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: m5py 2 | # site_description: 'A short description of my project' 3 | repo_url: https://github.com/smarie/python-m5p 4 | # docs_dir: . 
5 | # site_dir: ../site 6 | # default branch is main instead of master now on github 7 | edit_uri: ./edit/main/docs 8 | nav: 9 | - Home: index.md 10 | - Usage examples: generated/gallery 11 | # - API reference: api_reference.md 12 | - Changelog: changelog.md 13 | 14 | theme: material # readthedocs mkdocs 15 | 16 | plugins: 17 | - gallery: 18 | examples_dirs: docs/examples # path to your example scripts 19 | gallery_dirs: docs/generated/gallery # where to save generated gallery 20 | filename_pattern: ".*_demo" 21 | within_subsection_order: FileNameSortKey # order according to file name 22 | 23 | - search # make sure the search plugin is still enabled 24 | 25 | markdown_extensions: 26 | - toc: 27 | permalink: true 28 | -------------------------------------------------------------------------------- /noxfile-requirements.txt: -------------------------------------------------------------------------------- 1 | nox 2 | toml 3 | makefun 4 | setuptools_scm # used in 'release' 5 | keyring # used in 'release' 6 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | from json import dumps 3 | import logging 4 | 5 | import nox # noqa 6 | from pathlib import Path # noqa 7 | import sys 8 | 9 | # add parent folder to python path so that we can import nox_utils.py 10 | # note that you need to "pip install -r noxfile-requirements.txt" for this file to work. 11 | sys.path.append(str(Path(__file__).parent / "ci_tools")) 12 | from nox_utils import PY37, PY36, PY35, PY38, PY39, PY310, power_session, rm_folder, rm_file, PowerSession # noqa 13 | 14 | 15 | pkg_name = "m5py" 16 | gh_org = "smarie" 17 | gh_repo = "python-m5p" 18 | 19 | ENVS = { 20 | # python 3.10 is not available on conda yet 21 | # PY310: {"coverage": False, "pkg_specs": {"pip": ">19"}}, 22 | PY39: {"coverage": False, "pkg_specs": {"pip": ">19"}}, 23 | # PY27: {"coverage": False, "pkg_specs": {"pip": ">10"}}, 24 | # PY35: {"coverage": False, "pkg_specs": {"pip": ">10"}}, 25 | # PY36: {"coverage": False, "pkg_specs": {"pip": ">19"}}, 26 | # IMPORTANT: this should be last so that the folder docs/reports is not deleted afterwards 27 | PY37: {"coverage": False, "pkg_specs": {"pip": ">19"}}, # , "pytest-html": "1.9.0" 28 | PY38: {"coverage": True, "pkg_specs": {"pip": ">19"}}, 29 | } 30 | 31 | 32 | # set the default activated sessions, minimal for CI 33 | nox.options.sessions = ["tests", "flake8", "docs"] # , "docs", "gh_pages" 34 | nox.options.reuse_existing_virtualenvs = True # this can be done using -r 35 | # if platform.system() == "Windows": >> always use this for better control 36 | nox.options.default_venv_backend = "conda" 37 | # os.environ["NO_COLOR"] = "True" # nox.options.nocolor = True does not work 38 | # nox.options.verbose = True 39 | 40 | nox_logger = logging.getLogger("nox") 41 | # nox_logger.setLevel(logging.INFO) NO!!! this prevents the "verbose" nox flag from working!
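# Note (illustration): with the ENVS grid above, the `tests` session below is generated once per
# python version key, e.g. "tests-3.7", "tests-3.8" and "tests-3.9", each receiving the matching
# "coverage" and "pkg_specs" entries as keyword arguments (cf. nox_session_with_grid in nox_utils).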
42 | 43 | 44 | class Folders: 45 | root = Path(__file__).parent 46 | ci_tools = root / "ci_tools" 47 | runlogs = root / Path(nox.options.envdir or ".nox") / "_runlogs" 48 | runlogs.mkdir(parents=True, exist_ok=True) 49 | dist = root / "dist" 50 | site = root / "site" 51 | site_reports = site / "reports" 52 | reports_root = root / "docs" / "reports" 53 | test_reports = reports_root / "junit" 54 | test_xml = test_reports / "junit.xml" 55 | test_html = test_reports / "report.html" 56 | test_badge = test_reports / "junit-badge.svg" 57 | coverage_reports = reports_root / "coverage" 58 | coverage_xml = coverage_reports / "coverage.xml" 59 | coverage_intermediate_file = root / ".coverage" 60 | coverage_badge = coverage_reports / "coverage-badge.svg" 61 | flake8_reports = reports_root / "flake8" 62 | flake8_intermediate_file = root / "flake8stats.txt" 63 | flake8_badge = flake8_reports / "flake8-badge.svg" 64 | 65 | 66 | @power_session(envs=ENVS, logsdir=Folders.runlogs) 67 | def tests(session: PowerSession, coverage, pkg_specs): 68 | """Run the test suite, including test reports generation and coverage reports. """ 69 | 70 | # As soon as this runs, we delete the target site and coverage files to avoid reporting wrong coverage/etc. 71 | rm_folder(Folders.site) 72 | rm_folder(Folders.reports_root) 73 | # delete the .coverage files if any (they are not supposed to be any, but just in case) 74 | rm_file(Folders.coverage_intermediate_file) 75 | rm_file(Folders.root / "coverage.xml") 76 | 77 | # CI-only dependencies 78 | # Did we receive a flag through positional arguments ? (nox -s tests -- <flag>) 79 | # install_ci_deps = False 80 | # if len(session.posargs) == 1: 81 | # assert session.posargs[0] == "keyrings.alt" 82 | # install_ci_deps = True 83 | # elif len(session.posargs) > 1: 84 | # raise ValueError("Only a single positional argument is accepted, received: %r" % session.posargs) 85 | 86 | # uncomment and edit if you wish to uninstall something without deleting the whole env 87 | # session.run2("pip uninstall pytest-asyncio --yes") 88 | 89 | # install all requirements 90 | # session.install_reqs(phase="pip", phase_reqs=("pip",), versions_dct=pkg_specs) 91 | session.install_reqs(setup=True, install=True, tests=True, versions_dct=pkg_specs) 92 | 93 | # install CI-only dependencies 94 | # if install_ci_deps: 95 | # session.install2("keyrings.alt") 96 | 97 | # list all (conda list alone does not work correctly on github actions) 98 | # session.run2("conda list") 99 | conda_prefix = Path(session.bin) 100 | if conda_prefix.name == "bin": 101 | conda_prefix = conda_prefix.parent 102 | session.run2("conda list", env={"CONDA_PREFIX": str(conda_prefix), "CONDA_DEFAULT_ENV": session.get_session_id()}) 103 | 104 | # Fail if the assumed python version is not the actual one 105 | session.run2("python ci_tools/check_python_version.py %s" % session.python) 106 | 107 | # check that it can be imported even from a different folder 108 | # Important: do not surround the command into double quotes as in the shell ! 109 | # session.run('python', '-c', 'import os; os.chdir(\'./docs/\'); import %s' % pkg_name) 110 | 111 | # finally run all tests 112 | if not coverage: 113 | # install self so that it is recognized by pytest 114 | session.run2("pip install . 
--no-deps") 115 | # session.install(".", "--no-deps") 116 | 117 | # simple: pytest only 118 | session.run2("python -m pytest --cache-clear -v tests/") 119 | else: 120 | # install self in "develop" mode so that coverage can be measured 121 | session.run2("pip install -e . --no-deps") 122 | 123 | # coverage + junit html reports + badge generation 124 | session.install_reqs(phase="coverage", 125 | phase_reqs=["coverage", "pytest-html", "genbadge[tests,coverage]"], 126 | versions_dct=pkg_specs) 127 | 128 | # --coverage + junit html reports 129 | session.run2("coverage run --source src/{pkg_name} " 130 | "-m pytest --cache-clear --junitxml={test_xml} --html={test_html} -v tests/" 131 | "".format(pkg_name=pkg_name, test_xml=Folders.test_xml, test_html=Folders.test_html)) 132 | session.run2("coverage report") 133 | session.run2("coverage xml -o {covxml}".format(covxml=Folders.coverage_xml)) 134 | session.run2("coverage html -d {dst}".format(dst=Folders.coverage_reports)) 135 | # delete this intermediate file, it is not needed anymore 136 | rm_file(Folders.coverage_intermediate_file) 137 | 138 | # --generates the badge for the test results and fail build if less than x% tests pass 139 | nox_logger.info("Generating badge for tests coverage") 140 | # Use our own package to generate the badge 141 | session.run2("genbadge tests -i %s -o %s -t 100" % (Folders.test_xml, Folders.test_badge)) 142 | session.run2("genbadge coverage -i %s -o %s" % (Folders.coverage_xml, Folders.coverage_badge)) 143 | 144 | 145 | @power_session(python=PY38, logsdir=Folders.runlogs) 146 | def flake8(session: PowerSession): 147 | """Launch flake8 qualimetry.""" 148 | 149 | session.install("-r", str(Folders.ci_tools / "flake8-requirements.txt")) 150 | session.install_reqs(phase="flake8", phase_reqs=["numpy", "pandas", "scikit-learn"]) 151 | session.run2("pip install . --no-deps") 152 | 153 | rm_folder(Folders.flake8_reports) 154 | Folders.flake8_reports.mkdir(parents=True, exist_ok=True) 155 | rm_file(Folders.flake8_intermediate_file) 156 | 157 | session.cd("src") 158 | 159 | # Options are set in `setup.cfg` file 160 | session.run("flake8", pkg_name, "--exit-zero", "--format=html", "--htmldir", str(Folders.flake8_reports), 161 | "--statistics", "--tee", "--output-file", str(Folders.flake8_intermediate_file)) 162 | # generate our badge 163 | session.run2("genbadge flake8 -i %s -o %s" % (Folders.flake8_intermediate_file, Folders.flake8_badge)) 164 | rm_file(Folders.flake8_intermediate_file) 165 | 166 | 167 | @power_session(python=[PY38]) 168 | def docs(session: PowerSession): 169 | """Generates the doc and serves it on a local http server. Pass '-- build' to build statically instead.""" 170 | 171 | # we need to install self for the doc gallery examples to work 172 | session.install_reqs(phase="docs", phase_reqs=["numpy", "pandas", "scikit-learn", "matplotlib", "seaborn"]) 173 | session.run2("pip install . --no-deps") 174 | session.install_reqs(phase="docs", phase_reqs=[ 175 | "mkdocs-material", "mkdocs", "pymdown-extensions", "pygments", "mkdocs-gallery", "pillow", "matplotlib", 176 | ]) 177 | 178 | if session.posargs: 179 | # use posargs instead of "serve" 180 | session.run2("mkdocs %s" % " ".join(session.posargs)) 181 | else: 182 | session.run2("mkdocs serve") 183 | 184 | 185 | @power_session(python=[PY38]) 186 | def publish(session: PowerSession): 187 | """Deploy the docs+reports on github pages. 
Note: this rebuilds the docs""" 188 | 189 | # we need to install self for the doc gallery examples to work 190 | session.install_reqs(phase="publish", phase_reqs=["numpy", "pandas", "scikit-learn", "matplotlib", "seaborn"]) 191 | session.run2("pip install . --no-deps") 192 | session.install_reqs(phase="publish", phase_reqs=[ 193 | "mkdocs-material", "mkdocs", "pymdown-extensions", "pygments", "mkdocs-gallery", "pillow" 194 | ]) 195 | 196 | # possibly rebuild the docs in a static way (mkdocs serve does not build locally) 197 | session.run2("mkdocs build") 198 | 199 | # check that the doc has been generated with coverage 200 | if not Folders.site_reports.exists(): 201 | raise ValueError("Test reports have not been built yet. Please run 'nox -s tests-3.7' first") 202 | 203 | # publish the docs 204 | session.run2("mkdocs gh-deploy") 205 | 206 | # publish the coverage - now in github actions only 207 | # session.install_reqs(phase="codecov", phase_reqs=["codecov", "keyring"]) 208 | # # keyring set https://app.codecov.io/gh/<org>/<repo> token 209 | # import keyring # (note that this import is not from the session env but the main nox env) 210 | # codecov_token = keyring.get_password("https://app.codecov.io/gh/<org>/<repo>>", "token") 211 | # # note: do not use --root nor -f ! otherwise "There was an error processing coverage reports" 212 | # session.run2('codecov -t %s -f %s' % (codecov_token, Folders.coverage_xml)) 213 | 214 | 215 | @power_session(python=[PY38]) 216 | def release(session: PowerSession): 217 | """Create a release on github corresponding to the latest tag""" 218 | 219 | # Get current tag using setuptools_scm and make sure this is not a dirty/dev one 220 | from setuptools_scm import get_version # (note that this import is not from the session env but the main nox env) 221 | from setuptools_scm.version import guess_next_dev_version 222 | version = [] 223 | 224 | def my_scheme(version_): 225 | version.append(version_) 226 | return guess_next_dev_version(version_) 227 | current_tag = get_version(".", version_scheme=my_scheme) 228 | 229 | # create the package 230 | session.install_reqs(phase="setup.py#dist", phase_reqs=["setuptools_scm"]) 231 | rm_folder(Folders.dist) 232 | session.run2("python setup.py sdist bdist_wheel") 233 | 234 | if version[0].dirty or not version[0].exact: 235 | raise ValueError("You need to execute this action on a clean tag version with no local changes.") 236 | 237 | # Did we receive a token through positional arguments ? 
(nox -s release -- <token>) 238 | if len(session.posargs) == 1: 239 | # Run from within github actions - no need to publish on pypi 240 | gh_token = session.posargs[0] 241 | publish_on_pypi = False 242 | 243 | elif len(session.posargs) == 0: 244 | # Run from local commandline - assume we want to manually publish on PyPi 245 | publish_on_pypi = True 246 | 247 | # keyring set https://docs.github.com/en/rest token 248 | import keyring # (note that this import is not from the session env but the main nox env) 249 | gh_token = keyring.get_password("https://docs.github.com/en/rest", "token") 250 | assert len(gh_token) > 0 251 | 252 | else: 253 | raise ValueError("Only a single positional arg is allowed for now") 254 | 255 | # publish the package on PyPi 256 | if publish_on_pypi: 257 | # keyring set https://upload.pypi.org/legacy/ your-username 258 | # keyring set https://test.pypi.org/legacy/ your-username 259 | session.install_reqs(phase="PyPi", phase_reqs=["twine"]) 260 | session.run2("twine upload dist/* -u smarie") # -r testpypi 261 | 262 | # create the github release 263 | session.install_reqs(phase="release", phase_reqs=["click", "PyGithub"]) 264 | session.run2("python ci_tools/github_release.py -s {gh_token} " 265 | "--repo-slug {gh_org}/{gh_repo} -cf ./docs/changelog.md " 266 | "-d https://{gh_org}.github.io/{gh_repo}/changelog {tag}" 267 | "".format(gh_token=gh_token, gh_org=gh_org, gh_repo=gh_repo, tag=current_tag)) 268 | 269 | 270 | @nox.session(python=False) 271 | def gha_list(session): 272 | """(mandatory arg: <base_session_name>) Prints all sessions available for <base_session_name>, for GithubActions.""" 273 | 274 | # see https://stackoverflow.com/q/66747359/7262247 275 | 276 | # get the desired base session to generate the list for 277 | if len(session.posargs) != 1: 278 | raise ValueError("This session has a mandatory argument: <base_session_name>") 279 | session_func = globals()[session.posargs[0]] 280 | 281 | # list all sessions for this base session 282 | try: 283 | session_func.parametrize 284 | except AttributeError: 285 | sessions_list = ["%s-%s" % (session_func.__name__, py) for py in session_func.python] 286 | else: 287 | sessions_list = ["%s-%s(%s)" % (session_func.__name__, py, param) 288 | for py, param in product(session_func.python, session_func.parametrize)] 289 | 290 | # print the list so that it can be caught by GHA. 291 | # Note that json.dumps is optional since this is a list of string. 292 | # However it is to remind us that GHA expects a well-formatted json list of strings. 293 | print(dumps(sessions_list)) 294 | 295 | 296 | # if __name__ == '__main__': 297 | # # allow this file to be executable for easy debugging in any IDE 298 | # nox.run(globals()) 299 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=39.2", 4 | "setuptools_scm", 5 | "wheel" 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | # pip: no ! 
does not work in old python 2.7 and not recommended here 10 | # https://setuptools.readthedocs.io/en/latest/userguide/quickstart.html#basic-use 11 | 12 | [tool.conda] 13 | # Declare that the following packages should be installed with conda instead of pip 14 | # Note: this includes packages declared everywhere, here and in setup.cfg 15 | conda_packages = [ 16 | "setuptools", 17 | "wheel", 18 | "pip", 19 | "scikit-learn", 20 | "pandas", 21 | "numpy", 22 | "matplotlib", 23 | "seaborn" 24 | ] 25 | # pytest: not with conda ! does not work in old python 2.7 and 3.5 26 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # See https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files 2 | # And this great example : https://github.com/Kinto/kinto/blob/master/setup.cfg 3 | [metadata] 4 | name = m5py 5 | description = An implementation of M5 (Prime) and model trees for scikit-learn. 6 | description_file = README.md 7 | license = BSD 3-Clause 8 | long_description = file: docs/long_description.md 9 | long_description_content_type=text/markdown 10 | keywords = model tree regression M5 prime scikit learn machine learning 11 | author = Sylvain MARIE <sylvain.marie@se.com> 12 | maintainer = Sylvain MARIE <sylvain.marie@se.com> 13 | url = https://github.com/smarie/python-m5p 14 | # download_url = https://github.com/smarie/python-m5p/tarball/main >> do it in the setup.py to get the right version 15 | classifiers = 16 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 17 | Development Status :: 5 - Production/Stable 18 | Intended Audience :: Developers 19 | Intended Audience :: Science/Research 20 | License :: OSI Approved :: BSD License 21 | Topic :: Software Development :: Libraries :: Python Modules 22 | Topic :: Scientific/Engineering :: Mathematics 23 | Programming Language :: Python 24 | Programming Language :: Python :: 3 25 | # Programming Language :: Python :: 3.5 26 | # Programming Language :: Python :: 3.6 27 | Programming Language :: Python :: 3.7 28 | Programming Language :: Python :: 3.8 29 | Programming Language :: Python :: 3.9 30 | 31 | [options] 32 | # one day these will be able to come from requirement files, see https://github.com/pypa/setuptools/issues/1951. But will it be better ? 33 | setup_requires = 34 | setuptools_scm 35 | pytest-runner 36 | install_requires = 37 | # note: do not use double quotes in these, this triggers a weird bug in PyCharm in debug mode only 38 | scikit-learn 39 | # funcsigs;python_version<'3.3' 40 | # enum34;python_version<'3.4' 41 | tests_require = 42 | pytest 43 | 44 | # test_suite = tests --> no need apparently 45 | # 46 | zip_safe = False 47 | # explicitly setting zip_safe=False to avoid downloading `ply` see https://github.com/smarie/python-getversion/pull/5 48 | # and makes mypy happy see https://mypy.readthedocs.io/en/latest/installed_packages.html 49 | package_dir= 50 | =src 51 | packages = find: 52 | # see [options.packages.find] below 53 | # IMPORTANT: DO NOT set the `include_package_data` flag !! 
It triggers inclusion of all git-versioned files 54 | # see https://github.com/pypa/setuptools_scm/issues/190#issuecomment-351181286 55 | # include_package_data = True 56 | [options.packages.find] 57 | where=src 58 | exclude = 59 | contrib 60 | docs 61 | *tests* 62 | 63 | [options.package_data] 64 | * = py.typed, *.pyi 65 | 66 | 67 | # Optional dependencies that can be installed with e.g. $ pip install -e .[dev,test] 68 | # [options.extras_require] 69 | 70 | # -------------- Packaging ----------- 71 | # [options.entry_points] 72 | 73 | # [egg_info] >> already covered by setuptools_scm 74 | 75 | [bdist_wheel] 76 | # Code is written to work on both Python 2 and Python 3. 77 | universal=1 78 | 79 | # ------------- Others ------------- 80 | # In order to be able to execute 'python setup.py test' 81 | # from https://docs.pytest.org/en/latest/goodpractices.html#integrating-with-setuptools-python-setup-py-test-pytest-runner 82 | [aliases] 83 | test = pytest 84 | 85 | # pytest default configuration 86 | [tool:pytest] 87 | testpaths = tests/ 88 | addopts = 89 | --verbose 90 | --doctest-modules 91 | --ignore-glob='**/_*.py' 92 | 93 | # we need the 'always' for python 2 tests to work see https://github.com/pytest-dev/pytest/issues/2917 94 | filterwarnings = 95 | always 96 | ; ignore::UserWarning 97 | 98 | # Coverage config 99 | [coverage:run] 100 | branch = True 101 | omit = *tests* 102 | # this is done in nox.py (github actions) or ci_tools/run_tests.sh (travis) 103 | # source = m5py 104 | # command_line = -m pytest --junitxml="reports/pytest_reports/pytest.xml" --html="reports/pytest_reports/pytest.html" -v m5py/tests/ 105 | 106 | [coverage:report] 107 | fail_under = 40 108 | show_missing = True 109 | exclude_lines = 110 | # this line for all the python 2 not covered lines 111 | except ImportError: 112 | # we have to repeat this when exclude_lines is set 113 | pragma: no cover 114 | 115 | # Done in nox.py 116 | # [coverage:html] 117 | # directory = site/reports/coverage_reports 118 | # [coverage:xml] 119 | # output = site/reports/coverage_reports/coverage.xml 120 | 121 | [flake8] 122 | max-line-length = 120 123 | extend-ignore = D, E203 # D: Docstring errors, E203: see https://github.com/PyCQA/pycodestyle/issues/373 124 | copyright-check = True 125 | copyright-regexp = ^\#\s+Authors:\s+Sylvain MARIE <sylvain\.marie@se\.com>\n\#\s+\+\sAll\scontributors\sto\s<https://github\.com/smarie/python\-m5py>\n\#\n\#\s\sLicense:\s3\-clause\sBSD,\s<https://github\.com/smarie/python\-m5py/blob/main/LICENSE> 126 | exclude = 127 | .git 128 | .github 129 | .nox 130 | .pytest_cache 131 | ci_tools 132 | docs 133 | tests 134 | noxfile.py 135 | setup.py 136 | */_version.py 137 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | To understand this project's build structure 3 | 4 | - This project uses setuptools, so it is declared as the build system in the pyproject.toml file 5 | - We use as much as possible `setup.cfg` to store the information so that it can be read by other tools such as `tox` 6 | and `nox`. 
So `setup.py` contains **almost nothing** (see below) 7 | This philosophy was found after trying all other possible combinations in other projects :) 8 | A reference project that was inspiring to make this move : https://github.com/Kinto/kinto/blob/master/setup.cfg 9 | 10 | See also: 11 | https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files 12 | https://packaging.python.org/en/latest/distributing.html 13 | https://github.com/pypa/sampleproject 14 | """ 15 | from setuptools import setup 16 | 17 | 18 | # (1) check required versions (from https://medium.com/@daveshawley/safely-using-setup-cfg-for-metadata-1babbe54c108) 19 | import pkg_resources 20 | 21 | pkg_resources.require("setuptools>=39.2") 22 | pkg_resources.require("setuptools_scm") 23 | 24 | 25 | # (2) Generate download url using git version 26 | from setuptools_scm import get_version # noqa: E402 27 | 28 | URL = "https://github.com/smarie/python-m5p" 29 | DOWNLOAD_URL = URL + "/tarball/" + get_version() 30 | 31 | 32 | # (3) Call setup() with as little args as possible 33 | setup( 34 | download_url=DOWNLOAD_URL, 35 | use_scm_version={ 36 | "write_to": "src/m5py/_version.py" 37 | }, # we can't put `use_scm_version` in setup.cfg yet unfortunately 38 | ) 39 | -------------------------------------------------------------------------------- /src/m5py/__init__.py: -------------------------------------------------------------------------------- 1 | # Authors: Sylvain MARIE <sylvain.marie@se.com> 2 | # + All contributors to <https://github.com/smarie/python-m5p> 3 | # 4 | # License: 3-clause BSD, <https://github.com/smarie/python-m5p/blob/main/LICENSE> 5 | 6 | from m5py.main import M5Prime 7 | from m5py.export import export_text_m5 8 | 9 | try: 10 | # -- Distribution mode -- 11 | # import from _version.py generated by setuptools_scm during release 12 | from ._version import version as __version__ 13 | except ImportError: 14 | # -- Source mode -- 15 | # use setuptools_scm to get the current version from src using git 16 | from setuptools_scm import get_version as _gv 17 | from os import path as _path 18 | __version__ = _gv(_path.join(_path.dirname(__file__), _path.pardir)) 19 | 20 | __all__ = [ 21 | "__version__", 22 | # submodules 23 | "main", 24 | "export", 25 | # symbols 26 | "M5Prime", 27 | "export_text_m5" 28 | ] 29 | -------------------------------------------------------------------------------- /src/m5py/export.py: -------------------------------------------------------------------------------- 1 | from numbers import Integral 2 | 3 | import numpy as np 4 | from io import StringIO 5 | 6 | from sklearn.tree import _tree 7 | from sklearn.tree._criterion import FriedmanMSE 8 | 9 | from m5py.main import is_leaf, ConstantLeafModel, M5Base, check_is_fitted 10 | 11 | 12 | def export_text_m5(decision_tree, out_file=None, max_depth=None, 13 | feature_names=None, class_names=None, label='all', 14 | target_name=None, 15 | # filled=False, leaves_parallel=False, 16 | impurity=True, 17 | node_ids=False, proportion=False, 18 | # rounded=False, rotate=False, 19 | special_characters=False, precision=3, **kwargs): 20 | """Export a decision tree in TXT format. 21 | 22 | Note: this should be merged with ._export.export_text 23 | 24 | Inspired by WEKA and by 25 | >>> from sklearn.tree import export_graphviz 26 | 27 | This function generates a human-readable, text representation of the 28 | decision tree, which is then written into `out_file`. 
29 | 30 | The sample counts that are shown are weighted with any sample_weights that 31 | might be present. 32 | 33 | Read more in the :ref:`User Guide <tree>`. 34 | 35 | Parameters 36 | ---------- 37 | decision_tree : decision tree estimator 38 | The decision tree to be exported to text. 39 | 40 | out_file : file object or string, optional (default=None) 41 | Handle or name of the output file. If ``None``, the result is 42 | returned as a string. 43 | 44 | max_depth : int, optional (default=None) 45 | The maximum depth of the representation. If None, the tree is fully 46 | generated. 47 | 48 | feature_names : list of strings, optional (default=None) 49 | Names of each of the features. 50 | 51 | class_names : list of strings, bool or None, optional (default=None) 52 | Names of each of the target classes in ascending numerical order. 53 | Only relevant for classification and not supported for multi-output. 54 | If ``True``, shows a symbolic representation of the class name. 55 | 56 | label : {'all', 'root', 'none'}, optional (default='all') 57 | Whether to show informative labels for impurity, etc. 58 | Options include 'all' to show at every node, 'root' to show only at 59 | the top root node, or 'none' to not show at any node. 60 | 61 | target_name : optional string with the target name. If not provided, the 62 | target will not be displayed in the equations 63 | 64 | impurity : bool, optional (default=True) 65 | When set to ``True``, show the impurity at each node. 66 | 67 | node_ids : bool, optional (default=False) 68 | When set to ``True``, show the ID number on each node. 69 | 70 | proportion : bool, optional (default=False) 71 | When set to ``True``, change the display of 'values' and/or 'samples' 72 | to be proportions and percentages respectively. 73 | 74 | special_characters : bool, optional (default=False) 75 | When set to ``False``, ignore special characters for PostScript 76 | compatibility. 77 | 78 | precision : int, optional (default=3) 79 | Number of digits of precision for floating point in the values of 80 | impurity, threshold and value attributes of each node. 81 | 82 | kwargs : other keyword arguments for the linear model printer 83 | 84 | Returns 85 | ------- 86 | text : string 87 | Text representation of the input tree. 88 | Only returned if ``out_file`` is None. 89 | 90 | Examples 91 | -------- 92 | >>> from sklearn.datasets import load_iris 93 | >>> from sklearn import tree 94 | 95 | >>> clf = tree.DecisionTreeClassifier() 96 | >>> iris = load_iris() 97 | 98 | >>> clf = clf.fit(iris.data, iris.target) 99 | >>> export_text_m5(clf, out_file='tree.txt') # doctest: +SKIP 100 | 101 | """ 102 | 103 | models = [] 104 | 105 | def add_model(node_model): 106 | models.append(node_model) 107 | return len(models) 108 | 109 | def node_to_str(tree, node_id, criterion, node_models=None): 110 | """ Generates the node content string """ 111 | 112 | # Should labels be shown? 
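# ('all' -> labels at every node, 'root' -> only at node 0, anything else -> never)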
113 | labels = (label == 'root' and node_id == 0) or label == 'all' 114 | 115 | # PostScript compatibility for special characters 116 | if special_characters: 117 | characters = ['#', '<SUB>', '</SUB>', '≤', '<br/>', '>'] 118 | node_string = '<' 119 | else: 120 | characters = ['#', '[', ']', '<=', '\\n', ''] 121 | node_string = '' 122 | 123 | # -- If this node is not a leaf, Write the split decision criteria (x <= y) 124 | leaf = is_leaf(node_id, tree) 125 | if not leaf: 126 | if feature_names is not None: 127 | feature = feature_names[tree.feature[node_id]] 128 | else: 129 | feature = "X%s%s%s" % (characters[1], 130 | tree.feature[node_id], # feature id for the split 131 | characters[2]) 132 | node_string += '%s %s %s' % (feature, 133 | characters[3], # <= 134 | round(tree.threshold[node_id], # threshold for the split 135 | precision)) 136 | else: 137 | node_string += 'LEAF' 138 | 139 | # Node details - start bracket [ 140 | node_string += ' %s' % characters[1] 141 | 142 | # -- Write impurity 143 | if impurity: 144 | if isinstance(criterion, FriedmanMSE): 145 | criterion = "friedman_mse" 146 | elif not isinstance(criterion, str): 147 | criterion = "impurity" 148 | if labels: 149 | node_string += '%s=' % criterion 150 | node_string += str(round(tree.impurity[node_id], precision)) + ', ' 151 | 152 | # -- Write node sample count 153 | if labels: 154 | node_string += 'samples=' 155 | if proportion: 156 | percent = (100. * tree.n_node_samples[node_id] / 157 | float(tree.n_node_samples[0])) 158 | node_string += str(round(percent, 1)) + '%' 159 | else: 160 | node_string += str(tree.n_node_samples[node_id]) 161 | 162 | # Node details - end bracket ] 163 | node_string += '%s' % characters[2] 164 | 165 | # -- Write node class distribution / regression value 166 | if tree.n_outputs == 1: 167 | value = tree.value[node_id][0, :] 168 | else: 169 | value = tree.value[node_id] 170 | 171 | if proportion and tree.n_classes[0] != 1: 172 | # For classification this will show the proportion of samples 173 | value = value / tree.weighted_n_node_samples[node_id] 174 | if tree.n_classes[0] == 1: 175 | # Regression 176 | value_text = np.around(value, precision) 177 | elif proportion: 178 | # Classification 179 | value_text = np.around(value, precision) 180 | elif np.all(np.equal(np.mod(value, 1), 0)): 181 | # Classification without floating-point weights 182 | value_text = value.astype(int) 183 | else: 184 | # Classification with floating-point weights 185 | value_text = np.around(value, precision) 186 | 187 | # Strip whitespace 188 | value_text = str(value_text.astype('S32')).replace("b'", "'") 189 | value_text = value_text.replace("' '", ", ").replace("'", "") 190 | if tree.n_classes[0] == 1 and tree.n_outputs == 1: 191 | value_text = value_text.replace("[", "").replace("]", "") 192 | value_text = value_text.replace("\n ", characters[4]) 193 | 194 | if node_models is None: 195 | node_string += ' : ' 196 | if labels: 197 | node_string += 'value=' 198 | else: 199 | nodemodel = node_models[node_id] 200 | model_err_val = np.around(nodemodel.error, precision) 201 | if leaf: 202 | if isinstance(nodemodel, ConstantLeafModel): 203 | # the model does not contain the value. 
rely on the value_text computed from tree
204 | value_text = " : %s (err=%s, params=%s)" % (value_text, model_err_val, nodemodel.n_params)
205 | else:
206 | # put the model in the stack, we'll write it later
207 | model_id = add_model(nodemodel)
208 | value_text = " : LM%s (err=%s, params=%s)" % (model_id, model_err_val, nodemodel.n_params)
209 | else:
210 | # replace the value text with error at this node and number of parameters
211 | value_text = " (err=%s, params=%s)" % (model_err_val, nodemodel.n_params)
212 | 
213 | node_string += value_text
214 | 
215 | # Write node majority class
216 | if (class_names is not None and
217 | tree.n_classes[0] != 1 and
218 | tree.n_outputs == 1):
219 | # Only done for single-output classification trees
220 | node_string += ', '
221 | if labels:
222 | node_string += 'class='
223 | if class_names is not True:
224 | class_name = class_names[np.argmax(value)]
225 | else:
226 | class_name = "y%s%s%s" % (characters[1],
227 | np.argmax(value),
228 | characters[2])
229 | node_string += class_name
230 | 
231 | return node_string + characters[5]
232 | 
233 | def recurse(tree, node_id, criterion, parent=None, depth=0, node_models=None):
234 | if node_id == _tree.TREE_LEAF:
235 | raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF)
236 | 
237 | # Add node with description
238 | if max_depth is None or depth <= max_depth:
239 | indent_str = ("| " * depth)
240 | if node_ids:
241 | out_file.write('%d| %s%s\n' % (node_id, indent_str, node_to_str(tree, node_id, criterion,
242 | node_models=node_models)))
243 | else:
244 | out_file.write('%s%s\n' % (indent_str, node_to_str(tree, node_id, criterion, node_models=node_models)))
245 | 
246 | # Recurse on Children if needed
247 | left_child = tree.children_left[node_id]
248 | right_child = tree.children_right[node_id]
249 | 
250 | # if not is_leaf(node_id, tree)
251 | if left_child != _tree.TREE_LEAF:
252 | # that means that node_id is not a leaf (see is_leaf() in m5py.main): recurse on children
253 | recurse(tree, left_child, criterion=criterion, parent=node_id, depth=depth + 1,
254 | node_models=node_models)
255 | recurse(tree, right_child, criterion=criterion, parent=node_id, depth=depth + 1,
256 | node_models=node_models)
257 | 
258 | else:
259 | ranks['leaves'].append(str(node_id))
260 | out_file.write('%d| (...)\n' % node_id)
261 | 
262 | def write_models(models):
263 | for i, model in enumerate(models):
264 | out_file.write("LM%s: %s\n" % (i + 1, model.to_text(feature_names=feature_names, precision=precision,
265 | target_name=target_name, **kwargs)))
266 | 
267 | # Main
268 | check_is_fitted(decision_tree, 'tree_')
269 | own_file = False
270 | return_string = False
271 | try:
272 | if isinstance(out_file, str):
273 | out_file = open(out_file, "w", encoding="utf-8")
274 | own_file = True
275 | 
276 | if out_file is None:
277 | return_string = True
278 | out_file = StringIO()
279 | 
280 | if isinstance(precision, Integral):
281 | if precision < 0:
282 | raise ValueError("'precision' should be greater or equal to 0."
283 | " Got {} instead.".format(precision))
284 | else:
285 | raise ValueError("'precision' should be an integer. 
Got {}" 286 | " instead.".format(type(precision))) 287 | 288 | # Check length of feature_names before getting into the tree node 289 | # Raise error if length of feature_names does not match 290 | # n_features_ in the decision_tree 291 | if feature_names is not None: 292 | if len(feature_names) != decision_tree.n_features_in_: 293 | raise ValueError("Length of feature_names, %d " 294 | "does not match number of features, %d" 295 | % (len(feature_names), 296 | decision_tree.n_features_in_)) 297 | 298 | # The depth of each node for plotting with 'leaf' option TODO probably remove 299 | ranks = {'leaves': []} 300 | 301 | # Tree title 302 | if isinstance(decision_tree, M5Base): 303 | if hasattr(decision_tree, 'installed_smoothing_constant'): 304 | details = "pre-smoothed with constant %s" % decision_tree.installed_smoothing_constant 305 | else: 306 | if decision_tree.use_smoothing == 'installed': 307 | details = "under construction - not pre-smoothed yet" 308 | else: 309 | details = "unsmoothed - but this can be done at prediction time" 310 | 311 | # add more info or M5P 312 | out_file.write('%s (%s):\n' % (type(decision_tree).__name__, details)) 313 | else: 314 | # generic title 315 | out_file.write('%s :\n' % type(decision_tree).__name__) 316 | 317 | # some space for readability 318 | out_file.write('\n') 319 | 320 | # Now recurse the tree and add node & edge attributes 321 | if isinstance(decision_tree, _tree.Tree): 322 | recurse(decision_tree, 0, criterion="impurity") 323 | elif isinstance(decision_tree, M5Base) and hasattr(decision_tree, 'node_models'): 324 | recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion, node_models=decision_tree.node_models) 325 | 326 | # extra step: write all models 327 | out_file.write("\n") 328 | write_models(models) 329 | else: 330 | recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion) 331 | 332 | # Return the text if needed 333 | if return_string: 334 | return out_file.getvalue() 335 | 336 | finally: 337 | if own_file: 338 | out_file.close() 339 | -------------------------------------------------------------------------------- /src/m5py/linreg_utils.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABCMeta 2 | 3 | import numpy as np 4 | 5 | from sklearn.linear_model import LinearRegression 6 | from sklearn.linear_model._base import LinearModel 7 | from sklearn.preprocessing import StandardScaler 8 | from sklearn.utils.extmath import safe_sparse_dot 9 | 10 | 11 | def linreg_model_to_text(model, feature_names=None, target_name=None, 12 | precision=3, line_breaks=False): 13 | """ 14 | Converts a linear regression model to a text representation. 
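For example (illustrative sketch; exact digits may vary with floating point):

>>> import numpy as np
>>> from sklearn.linear_model import LinearRegression
>>> X = np.array([[0., 0.], [1., 0.], [0., 1.], [1., 1.]])
>>> y = 0.5 * X[:, 0] + 2. * X[:, 1] + 3.
>>> model = LinearRegression().fit(X, y)
>>> linreg_model_to_text(model, target_name="y")  # doctest: +SKIP
'y = 5.000e-01 * X[0] + 2.0 * X[1] + 3.0'

(small coefficients are printed in scientific notation, larger ones in
standard notation with `precision` digits)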
15 | 16 | :param model: 17 | :param feature_names: 18 | :param target_name: 19 | :param precision: 20 | :param line_breaks: if True, each term in the sum shows in a different line 21 | :return: 22 | """ 23 | bits = [] 24 | 25 | # Template for numbers: we want scientific notation with a given precision 26 | nb_tpl = "%%0.%se" % precision 27 | 28 | # Handle multi-dimensional y (its second dim must be size 1, though) 29 | if len(model.coef_.shape) > 1: 30 | assert model.coef_.shape[0] == 1 31 | assert len(model.coef_.shape) == 2 32 | coefs = np.ravel(model.coef_) # convert to 1D 33 | assert len(model.intercept_) == 1 34 | intercept = model.intercept_.item() # extract scalar 35 | else: 36 | coefs = model.coef_ # a 1D array 37 | intercept = model.intercept_ # a scalar already 38 | 39 | # First all coefs * drivers 40 | for i, c in enumerate(coefs): 41 | var_name = ("X[%s]" % i) if feature_names is None else feature_names[i] 42 | 43 | if i == 0: 44 | # first term 45 | if c < 1: 46 | # use scientific notation 47 | product_text = (nb_tpl + " * %s") % (c, var_name) 48 | else: 49 | # use standard notation with precision 50 | c = np.round(c, precision) 51 | product_text = "%s * %s" % (c, var_name) 52 | else: 53 | # all the other terms: the sign should appear 54 | lb = '\n' if line_breaks else "" 55 | coef_abs = np.abs(c) 56 | coef_sign = '+' if np.sign(c) > 0 else '-' 57 | if coef_abs < 1: 58 | # use scientific notation 59 | product_text = (("%s%s " + nb_tpl + " * %s") 60 | % (lb, coef_sign, coef_abs, var_name)) 61 | else: 62 | # use standard notation with precision 63 | coef_abs = np.round(coef_abs, precision) 64 | product_text = ("%s%s %s * %s" 65 | % (lb, coef_sign, coef_abs, var_name)) 66 | 67 | bits.append(product_text) 68 | 69 | # Finally the intercept 70 | if len(bits) == 0: 71 | # intercept is the only term in the sum 72 | if intercept < 1: 73 | # use scientific notation only for small numbers (otherwise 12 74 | # would read 1.2e1 ... not friendly) 75 | constant_text = nb_tpl % intercept 76 | else: 77 | # use standard notation with precision 78 | i = np.round(intercept, precision) 79 | constant_text = "%s" % i 80 | else: 81 | # there are other terms in the sum: the sign should appear 82 | lb = '\n' if line_breaks else "" 83 | coef_abs = np.abs(intercept) 84 | coef_sign = '+' if np.sign(intercept) > 0 else '-' 85 | if coef_abs < 1: 86 | # use scientific notation 87 | constant_text = ("%s%s " + nb_tpl) % (lb, coef_sign, coef_abs) 88 | else: 89 | # use standard notation with precision 90 | coef_abs = np.round(coef_abs, precision) 91 | constant_text = "%s%s %s" % (lb, coef_sign, coef_abs) 92 | 93 | bits.append(constant_text) 94 | 95 | txt = " ".join(bits) 96 | if target_name is not None: 97 | txt = target_name + " = " + txt 98 | 99 | return txt 100 | 101 | 102 | class DeNormalizableMixIn(metaclass=ABCMeta): 103 | """ 104 | An abstract class that models able to de-normalize should implement. 
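Implementations are expected to rewrite their fitted parameters (e.g.
`coef_` and `intercept_` for linear models) so that they apply to the
raw, unscaled inputs and outputs.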
105 | """
106 | __slots__ = ()
107 | 
108 | @abstractmethod
109 | def denormalize(self,
110 | x_scaler: StandardScaler = None,
111 | y_scaler: StandardScaler = None
112 | ):
113 | """
114 | Denormalizes the model, knowing that it was fit with the given
115 | x_scaler and y_scaler
116 | """
117 | 
118 | 
119 | class DeNormalizableLinearModelMixIn(DeNormalizableMixIn, LinearModel):
120 | """
121 | A mix-in class to add 'denormalization' capability to a linear model
122 | """
123 | def denormalize(self,
124 | x_scaler: StandardScaler = None,
125 | y_scaler: StandardScaler = None
126 | ):
127 | """
128 | De-normalizes the linear regression model.
129 | Before this function is executed,
130 | (y-y_mean)/y_scale = self.coef_.T <dot> (x-x_mean)/x_scale + self.intercept_
131 | so
132 | (y-y_mean)/y_scale = (self.coef_/x_scale).T <dot> x + (self.intercept_ - self.coef_.T <dot> x_mean/x_scale)
133 | that is
134 | (y-y_mean)/y_scale = new_coef.T <dot> x + new_intercept
135 | where
136 | * new_coef = (self.coef_/x_scale)
137 | * new_intercept = (self.intercept_ - self.coef_.T <dot> x_mean/x_scale)
138 | 
139 | Then going back to y
140 | y = (new_coef * y_scale).T <dot> x + (new_intercept * y_scale + y_mean)
141 | 
142 | :param self:
143 | :param x_scaler:
144 | :param y_scaler:
145 | :return:
146 | """
147 | # First save old coefficients
148 | self.normalized_coef_ = self.coef_
149 | self.normalized_intercept_ = self.intercept_
150 | 
151 | # denormalize coefficients to take into account the x normalization
152 | if x_scaler is not None:
153 | new_coef = self.coef_ / x_scaler.scale_
154 | new_intercept = (
155 | self.intercept_ -
156 | safe_sparse_dot(x_scaler.mean_, new_coef.T,
157 | dense_output=True)
158 | )
159 | 
160 | self.coef_ = new_coef
161 | self.intercept_ = new_intercept
162 | 
163 | # denormalize them further to take into account the y normalization
164 | if y_scaler is not None:
165 | new_coef = self.coef_ * y_scaler.scale_
166 | new_intercept = y_scaler.inverse_transform(
167 | np.atleast_1d(self.intercept_)
168 | )
169 | if np.isscalar(self.intercept_):
170 | new_intercept = new_intercept[0]
171 | self.coef_ = new_coef
172 | self.intercept_ = new_intercept
173 | 
174 | 
175 | class DeNormalizableLinearRegression(LinearRegression,
176 | DeNormalizableLinearModelMixIn):
177 | """
178 | A Denormalizable linear regression. The old normalized coefficients are
179 | kept in `_feature_importances_` and exposed through the `feature_importances_` property.
180 | """
181 | @property
182 | def feature_importances_(self):
183 | if hasattr(self, '_feature_importances_'):
184 | return self._feature_importances_
185 | else:
186 | return self.coef_
187 | 
188 | def denormalize(self,
189 | x_scaler: StandardScaler = None,
190 | y_scaler: StandardScaler = None):
191 | """
192 | Denormalizes the model, and saves a copy of the old normalized
193 | coefficients in self._feature_importances_.
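A typical (hypothetical) workflow, assuming `X` and `y` are the raw
training arrays:

>>> x_scaler = StandardScaler().fit(X)                      # doctest: +SKIP
>>> y_scaler = StandardScaler().fit(y.reshape(-1, 1))       # doctest: +SKIP
>>> reg = DeNormalizableLinearRegression()                  # doctest: +SKIP
>>> reg.fit(x_scaler.transform(X),
...         y_scaler.transform(y.reshape(-1, 1)))           # doctest: +SKIP
>>> reg.denormalize(x_scaler, y_scaler)                     # doctest: +SKIP
>>> # reg.coef_ and reg.intercept_ now apply to the unscaled X and y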
194 | 
195 | :param x_scaler:
196 | :param y_scaler:
197 | :return:
198 | """
199 | self._feature_importances_ = self.coef_
200 | super(DeNormalizableLinearRegression, self).denormalize(x_scaler,
201 | y_scaler)
202 | 
203 | 
204 | # For all other linear models it should work too, e.g.:
205 | # class DeNormalizableLasso(Lasso, DeNormalizableLinearModelMixIn):
206 | # pass
207 | 
--------------------------------------------------------------------------------
/src/m5py/main.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from copy import copy
3 | from logging import getLogger
4 | from warnings import warn
5 | 
6 | import numpy as np
7 | 
8 | from scipy.sparse import issparse
9 | 
10 | from sklearn import clone
11 | from sklearn.base import RegressorMixin, is_classifier
12 | from sklearn.linear_model import LinearRegression
13 | from sklearn.metrics import mean_squared_error
14 | from sklearn.preprocessing import StandardScaler
15 | from sklearn.tree import BaseDecisionTree, _tree
16 | from sklearn.tree._classes import DTYPE
17 | from sklearn.tree._tree import DOUBLE
18 | from sklearn.utils import check_array
19 | from sklearn.utils.validation import check_is_fitted
20 | from sklearn import __version__ as sklearn_version
21 | 
22 | from m5py.linreg_utils import linreg_model_to_text, DeNormalizableMixIn, DeNormalizableLinearRegression
23 | 
24 | from packaging.version import Version
25 | 
26 | SKLEARN_VERSION = Version(sklearn_version)
27 | SKLEARN13_OR_GREATER = SKLEARN_VERSION >= Version("1.3.0")
28 | 
29 | 
30 | __all__ = ["M5Base", "M5Prime"]
31 | 
32 | 
33 | _SmoothingDetails = namedtuple("_SmoothingDetails", ("A", "B", "C"))
34 | # Internal structure to contain the recursive smoothed coefficients details in the smoothing algorithm
35 | 
36 | 
37 | logger = getLogger("m5p")
38 | 
39 | 
40 | class M5Base(BaseDecisionTree):
41 | """
42 | M5Base. Implements base routines for generating M5 PredictionModel trees and rules.
43 | 
44 | The original algorithm M5 was invented by Quinlan:
45 | 
46 | - Quinlan J. R. (1992). Learning with continuous classes. Proceedings of the
47 | Australian Joint Conference on Artificial Intelligence. 343--348. World
48 | Scientific, Singapore.
49 | 
50 | Yong Wang and Ian Witten made improvements and created M5':
51 | 
52 | - Wang, Y and Witten, I. H. (1997). Induction of model trees for predicting
53 | continuous classes. Proceedings of the poster papers of the European
54 | Conference on Machine Learning. University of Economics, Faculty of
55 | Informatics and Statistics, Prague.
56 | 
57 | Pruning and Smoothing can be activated and deactivated on top of the base
58 | model. TODO check if 'rules' should be supported too
59 | 
60 | Inspired by Weka (https://github.com/bnjmn/weka) M5Base class, from Mark Hall
61 | 
62 | Attributes
63 | ----------
64 | criterion : str, default="friedman_mse"
65 | M5 suggests to use the standard deviation (RMSE impurity) instead of
66 | variance (MSE impurity) or absolute deviation (MAE impurity) as in CART.
67 | According to the M5' paper, all three criteria are equivalent.
68 | splitter : default="best"
69 | M5 suggests to take the feature leading to the best gain in criterion.
70 | max_depth : default None
71 | This is not used in the original M5 article, hence default=None.
72 | min_samples_split : int, default=4
73 | M5 suggests a value of 4 so that each leaf will have at least 2 samples.
74 | min_samples_leaf : int, default=2
75 | M5' suggests to add this explicitly to avoid zero-variance in leaves, and
76 | n<=p for constant models.
77 | min_weight_fraction_leaf : float, default=0.0
78 | This is not used in the original M5 article, hence default=0.0.
79 | TODO this would actually maybe be better than min_sample_leaf ?
80 | max_features : default None
81 | This is not used in the original M5 article, hence default is None.
82 | max_leaf_nodes : int, default None
83 | This is not used in the original M5 article, hence default is None.
84 | min_impurity_decrease : float, default=0.0
85 | This is not used in the original M5 article, hence default is 0.0.
86 | class_weight : default None
87 | This is not used (?)
88 | leaf_model : RegressorMixin
89 | The regression model used in the leaves. This instance will be cloned
90 | for each leaf.
91 | use_pruning : bool
92 | If False, pruning will be disabled.
93 | use_smoothing : bool or str {'installed', 'on_prediction'}, default None
94 | None and True mean 'installed' by default except if smoothing_constant
95 | or smoothing_constant_ratio is 0.0
96 | smoothing_constant: float, default None
97 | The smoothing constant k defined in the original M5 article, used as the
98 | weight for each parent model in the recursive weighting process.
99 | If None, the default value from the paper (k=15) will be used.
100 | smoothing_constant_ratio: float, default None
101 | An alternate way to define the smoothing constant, as a ratio of the
102 | total number of training samples. The resulting smoothing constant will
103 | be smoothing_constant_ratio * n where n is the number of samples.
104 | Note that this may not be an integer.
105 | debug_prints: bool, default False
106 | A boolean to enable debug prints
107 | ccp_alpha: float, default 0.0
108 | TODO is this relevant ? does that conflict with "use_pruning" ?
109 | random_state : None, int, or RandomState, default=None
110 | See `RegressionTree.random_state`.
111 | """
112 | 
113 | def __init__(
114 | self,
115 | criterion="friedman_mse",
116 | splitter="best",
117 | max_depth=None,
118 | min_samples_split=4,
119 | min_samples_leaf=2,
120 | min_weight_fraction_leaf=0.0,
121 | max_features=None,
122 | max_leaf_nodes=None,
123 | min_impurity_decrease=0.0,
124 | class_weight=None,
125 | leaf_model=None,
126 | use_pruning=True,
127 | use_smoothing=None,
128 | smoothing_constant=None,
129 | smoothing_constant_ratio=None,
130 | debug_prints=False,
131 | ccp_alpha=0.0,
132 | random_state=None,
133 | ):
134 | 
135 | # TODO the paper suggests to do this with 5% of total std
136 | # min_impurity_split = min_impurity_split_as_initial_ratio * initial_impurity
137 | 
138 | super(M5Base, self).__init__(
139 | criterion=criterion,
140 | splitter=splitter,
141 | max_depth=max_depth,
142 | min_samples_split=min_samples_split,
143 | min_samples_leaf=min_samples_leaf,
144 | min_weight_fraction_leaf=min_weight_fraction_leaf,
145 | max_features=max_features,
146 | max_leaf_nodes=max_leaf_nodes,
147 | min_impurity_decrease=min_impurity_decrease,
148 | random_state=random_state,
149 | class_weight=class_weight,
150 | ccp_alpha=ccp_alpha,
151 | )
152 | 
153 | # warning : if the field names are different from constructor params,
154 | # then clone(self) will not work.
155 | if leaf_model is None:
156 | # to handle case when the model is learnt on normalized data and
157 | # we wish to be able to read the model equations.
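# note: DeNormalizableLinearRegression fits and predicts exactly like
# sklearn's LinearRegression; it only adds the `denormalize` capability
# used by `M5Base.denormalize` below.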
158 | leaf_model = DeNormalizableLinearRegression() 159 | self.leaf_model = leaf_model 160 | 161 | # smoothing related 162 | if smoothing_constant_ratio is not None and smoothing_constant is not None: 163 | raise ValueError("Only one of `smoothing_constant` and `smoothing_constant_ratio` should be provided") 164 | elif (smoothing_constant_ratio == 0.0 or smoothing_constant == 0) and ( 165 | use_smoothing is True or use_smoothing == "installed" 166 | ): 167 | raise ValueError( 168 | "`use_smoothing` was explicitly enabled with " 169 | "pre-installed models, while smoothing " 170 | "constant/ratio are explicitly set to zero" 171 | ) 172 | 173 | self.use_pruning = use_pruning 174 | self.use_smoothing = use_smoothing 175 | self.smoothing_constant = smoothing_constant 176 | self.smoothing_constant_ratio = smoothing_constant_ratio 177 | 178 | self.debug_prints = debug_prints 179 | 180 | def as_pretty_text(self, **kwargs): 181 | """ 182 | Returns a multi-line representation of this decision tree, using 183 | `tree_to_text`. 184 | 185 | :return: a multi-line string representing this decision tree 186 | """ 187 | from m5py.export import export_text_m5 188 | return export_text_m5(self, out_file=None, **kwargs) 189 | 190 | def fit(self, X, y: np.ndarray, sample_weight=None, check_input=True, X_idx_sorted="deprecated"): 191 | """Fit a M5Prime model. 192 | 193 | Parameters 194 | ---------- 195 | X : numpy.ndarray 196 | y : numpy.ndarray 197 | sample_weight 198 | check_input 199 | X_idx_sorted 200 | 201 | Returns 202 | ------- 203 | 204 | """ 205 | # (0) smoothing default values behaviour 206 | if self.smoothing_constant_ratio == 0.0 or self.smoothing_constant == 0: 207 | self.use_smoothing = False 208 | elif self.use_smoothing is None or self.use_smoothing is True: 209 | # default behaviour 210 | if isinstance(self.leaf_model, LinearRegression): 211 | self.use_smoothing = "installed" 212 | else: 213 | self.use_smoothing = "on_prediction" 214 | 215 | # Finally make sure we are ok now 216 | if self.use_smoothing not in [False, np.bool_(False), "installed", "on_prediction"]: 217 | raise ValueError("use_smoothing: Unexpected value: %s, please report it as issue." % self.use_smoothing) 218 | 219 | 220 | # (1) Build the initial tree as usual 221 | 222 | # Get the correct fit method name based on the sklearn version used 223 | fit_method_name = "_fit" if SKLEARN13_OR_GREATER else "fit" 224 | 225 | fit_method = getattr(super(M5Base, self), fit_method_name) 226 | fit_method(X, y, sample_weight=sample_weight, check_input=check_input) 227 | 228 | 229 | if self.debug_prints: 230 | logger.debug("(debug_prints) Initial tree:") 231 | logger.debug(self.as_pretty_text(node_ids=True)) 232 | 233 | # (1b) prune initial tree to take into account min impurity in splits 234 | prune_on_min_impurity(self.tree_) 235 | 236 | if self.debug_prints: 237 | logger.debug("(debug_prints) Postprocessed tree:") 238 | logger.debug(self.as_pretty_text(node_ids=True)) 239 | 240 | # (2) Now prune tree and replace pruned branches with linear models 241 | # -- unfortunately we have to re-do this input validation 242 | # step, because it also converts the input to float32. 
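# (DTYPE is sklearn's internal float32 dtype for trees)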
243 | if check_input: 244 | X = check_array(X, dtype=DTYPE, accept_sparse="csc") 245 | if issparse(X): 246 | X.sort_indices() 247 | 248 | if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: 249 | raise ValueError("No support for np.int64 index based " "sparse matrices") 250 | 251 | # -- initialise the structure to contain the leaves and node models 252 | self.node_models = np.empty((self.tree_.node_count,), dtype=object) 253 | 254 | # -- Pruning requires to know the global deviation of the target 255 | global_std_dev = np.nanstd(y) 256 | # global_abs_dev = np.nanmean(np.abs(y)) 257 | 258 | # -- Pruning requires to know the samples that reached each node. 259 | # From http://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html 260 | # Retrieve the decision path of each sample. 261 | samples_to_nodes = self.decision_path(X) 262 | # * row i is to see the nodes (non-empty j) in which sample i appears. 263 | # sparse row-first (CSR) format is OK 264 | # * column j is to see the samples (non-empty i) that fall into that 265 | # node. To do that, we need to make it CSC 266 | nodes_to_samples = samples_to_nodes.tocsc() 267 | 268 | # -- execute the pruning 269 | # self.features_usage is a dict feature_idx -> nb times used, only 270 | # for used features. 271 | self.features_usage = build_models_and_get_pruning_info( 272 | self.tree_, 273 | X, 274 | y, 275 | nodes_to_samples, 276 | self.leaf_model, 277 | self.node_models, 278 | global_std_dev, 279 | use_pruning=self.use_pruning, 280 | ) 281 | 282 | # -- cleanup to compress inner structures: only keep non-pruned ones 283 | self._cleanup_tree() 284 | 285 | if self.debug_prints: 286 | logger.debug("(debug_prints) Pruned tree:") 287 | logger.debug(self.as_pretty_text(node_ids=True)) 288 | 289 | if self.use_smoothing == "installed": 290 | # Retrieve the NEW decision path of each sample. 291 | samples_to_nodes = self.decision_path(X) 292 | # * row i is to see the nodes (non-empty j) in which sample i 293 | # appears. sparse row-first (CSR) format is OK 294 | # * column j is to see the samples (non-empty i) that fall into 295 | # that node. To do that, we need to make it CSC 296 | nodes_to_samples = samples_to_nodes.tocsc() 297 | 298 | # default behaviour for smoothing constant and ratio 299 | smoothing_constant = self._get_smoothing_constant_to_use(X) 300 | 301 | self.install_smoothing(X, y, nodes_to_samples, smoothing_constant=smoothing_constant) 302 | 303 | if self.debug_prints: 304 | logger.debug("(debug_prints) Pruned and smoothed tree:") 305 | logger.debug(self.as_pretty_text(node_ids=True)) 306 | 307 | return self 308 | 309 | def _get_smoothing_constant_to_use(self, X): 310 | """ 311 | Returns the smoothing_constant to use for smoothing, based on current 312 | settings and X data 313 | """ 314 | if self.smoothing_constant_ratio is not None: 315 | nb_training_samples = X.shape[0] 316 | smoothing_cstt = self.smoothing_constant_ratio * nb_training_samples 317 | if smoothing_cstt < 15: 318 | warn( 319 | "smoothing constant ratio %s is leading to an extremely " 320 | "small smoothing constant %s because nb training samples" 321 | " is %s. Clipping to 15." 
% (self.smoothing_constant_ratio, smoothing_cstt, nb_training_samples)
322 | )
323 | smoothing_cstt = 15
324 | else:
325 | smoothing_cstt = self.smoothing_constant
326 | if smoothing_cstt is None:
327 | smoothing_cstt = 15 # default from the original Quinlan paper
328 | 
329 | return smoothing_cstt
330 | 
331 | def _cleanup_tree(self):
332 | """
333 | Reduces the size of this object by removing from internal structures
334 | all items that are not used any more (all leaves that have been pruned)
335 | """
336 | old_tree = self.tree_
337 | old_node_modls = self.node_models
338 | 
339 | # Get all information to create a copy of the inner tree.
340 | # Note: _tree.copy() is gone so we use the pickle way
341 | # --- Base info: nb features, nb outputs, output classes
342 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L631
343 | # [1] = (self.n_features, sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), self.n_outputs)
344 | # So these remain, we are just interested in changing the node-related arrays
345 | new_tree = _tree.Tree(*old_tree.__reduce__()[1])
346 | 
347 | # --- Node info
348 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L637
349 | dct = old_tree.__getstate__().copy()
350 | 
351 | # cleanup: only keep the nodes that are not undefined.
352 | # note: this is identical to
353 | # valid_nodes_indices = dct["nodes"]['left_child'] != TREE_UNDEFINED
354 | valid_nodes_indices = old_tree.children_left != _tree.TREE_UNDEFINED
355 | # valid_nodes_indices2 = old_tree.children_right != TREE_UNDEFINED
356 | # assert np.all(valid_nodes_indices == valid_nodes_indices2)
357 | new_node_count = sum(valid_nodes_indices)
358 | 
359 | # create empty new structures
360 | n_shape = (new_node_count, *dct["nodes"].shape[1:])
361 | new_nodes = _empty_contig_ar(n_shape, dtype=dct["nodes"].dtype)
362 | v_shape = (new_node_count, *dct["values"].shape[1:])
363 | new_values = _empty_contig_ar(v_shape, dtype=dct["values"].dtype)
364 | m_shape = (new_node_count, *old_node_modls.shape[1:])
365 | new_node_models = _empty_contig_ar(m_shape, dtype=old_node_modls.dtype)
366 | 
367 | # Fill structures while reindexing the tree and remembering the depth
368 | global next_free_id
369 | next_free_id = 0
370 | 
371 | def _compress(old_node_id):
372 | """
373 | 
374 | Parameters
375 | ----------
376 | old_node_id
377 | 
378 | Returns
379 | -------
380 | the depth of the subtree rooted at this node, and the new index of this node
381 | 
382 | """
383 | global next_free_id
384 | new_node_id = next_free_id
385 | next_free_id += 1
386 | 
387 | # use the old tree to walk
388 | old_node = dct["nodes"][old_node_id]
389 | left_id = old_node["left_child"]
390 | right_id = old_node["right_child"]
391 | 
392 | # Create the new node with a copy of the old
393 | new_nodes[new_node_id] = old_node # this is an entire row so it is probably copied already by doing so.
394 | new_values[new_node_id] = dct["values"][old_node_id]
395 | new_node_models[new_node_id] = copy(old_node_modls[old_node_id])
396 | 
397 | if left_id == _tree.TREE_LEAF:
398 | # ... and right_id == _tree.TREE_LEAF
399 | # => node_id is a leaf. 
Nothing to do 400 | return 1, new_node_id 401 | else: 402 | 403 | left_depth, new_id_left = _compress(left_id) 404 | right_depth, new_id_right = _compress(right_id) 405 | 406 | # store the new indices 407 | new_nodes[new_node_id]["left_child"] = new_id_left 408 | new_nodes[new_node_id]["right_child"] = new_id_right 409 | 410 | return 1 + max(left_depth, right_depth), new_node_id 411 | 412 | # complete definition of the new tree 413 | dct["max_depth"] = _compress(0)[0] - 1 # root node has depth 0, not 1 414 | dct["node_count"] = new_node_count # new_nodes.shape[0] 415 | dct["nodes"] = new_nodes 416 | dct["values"] = new_values 417 | new_tree.__setstate__(dct) 418 | 419 | # Fix an issue on sklearn 0.17.1: setstate was not updating max_depth 420 | # See https://github.com/scikit-learn/scikit-learn/blob/0.17.1/sklearn/tree/_tree.pyx#L623 421 | new_tree.max_depth = dct["max_depth"] 422 | 423 | # update self fields 424 | self.tree_ = new_tree 425 | self.node_models = new_node_models 426 | 427 | def install_smoothing(self, X_train_all, y_train_all, nodes_to_samples, smoothing_constant): 428 | """ 429 | Executes the smoothing procedure described in the M5 and M5P paper, 430 | "once for all". This means that all models are modified so that after 431 | this method has completed, each model in the tree is already a smoothed 432 | model. 433 | 434 | This has pros (prediction speed) and cons (the model equations are 435 | harder to read - lots of redundancy) 436 | 437 | It can only be done if leaf models are instances of `LinearRegression` 438 | """ 439 | # check validity: leaf models have to support pre-computed smoothing 440 | if not isinstance(self.leaf_model, LinearRegression): 441 | raise TypeError( 442 | "`install_smoothing` is only available if leaf " 443 | "models are instances of `LinearRegression` or a " 444 | "subclass" 445 | ) 446 | 447 | # Select the Error metric to compute model errors (used to compare 448 | # node with subtree for pruning) 449 | # TODO in Weka they use RMSE, but in papers they use MAE. this should be a parameter. 450 | # TODO Shall we go further and store the residuals, or a bunch of metrics? not sure 451 | err_metric = root_mean_squared_error # mean_absolute_error 452 | 453 | # --- Node info 454 | # TODO we should probably not do this once in each method, but once or give access directly (no copy) 455 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L637 456 | dct = self.tree_.__getstate__().copy() 457 | old_node_models = self.node_models 458 | 459 | m_shape = old_node_models.shape 460 | new_node_models = _empty_contig_ar(m_shape, dtype=old_node_models.dtype) 461 | 462 | def smooth__( 463 | coefs_, 464 | n_samples, 465 | features=None, 466 | parent=None, # type: _SmoothingDetails 467 | parent_features=None, 468 | k=smoothing_constant, # 469 | ): 470 | # type: (...) -> _SmoothingDetails 471 | """ 472 | Smoothes the model coefficients or intercept `coefs_` (an array or 473 | a scalar), using the smoothing results at parent node. 474 | 475 | At each node we keep in memory three results A, B, C. 
476 | - the new coef at each node is A + B 477 | - the recursion equations are 478 | - B(n) = (k / (n_samples(n) + k)) * A(n-1) + B(n-1) 479 | - C(n) = (n_samples(n) / (n_samples(n) + k)) * C(n-1) 480 | - A(n) = coef(n)*C(n) 481 | """ 482 | if parent is None: 483 | # A0 = coef(0), C0 = 1, B0 = 0 484 | if features is None: 485 | # single scalar value 486 | return _SmoothingDetails(A=coefs_, B=0, C=1) 487 | else: 488 | # vector 489 | if len(coefs_) == 0: 490 | coefs_ = np.asarray(coefs_) 491 | if coefs_.shape[0] != len(features): 492 | raise ValueError("nb features does not match the nb " "of coefficients") 493 | return _SmoothingDetails( 494 | A=coefs_, 495 | B=np.zeros(coefs_.shape, dtype=coefs_.dtype), 496 | C=np.ones(coefs_.shape, dtype=coefs_.dtype), 497 | ) 498 | else: 499 | # B(n) = k/(n_samples(n)+k)*A(n-1) + B(n-1) 500 | # C(n) = (n_samples(n) / (n_samples(n) + k)) * C(n-1) 501 | Bn = (k / (n_samples + k)) * parent.A + parent.B 502 | Cn = (n_samples / (n_samples + k)) * parent.C 503 | 504 | # A(n) = coef(n) * C(n) 505 | if features is None: 506 | # single scalar value: easy 507 | An = coefs_ * Cn 508 | return _SmoothingDetails(A=An, B=Bn, C=Cn) 509 | else: 510 | # vector of coefs: we have to 'expand' the coefs array 511 | # because the coefs at this node apply for (features) 512 | # while coefs at parent node apply for (parent_features) 513 | An = np.zeros(Cn.shape, dtype=Cn.dtype) 514 | parent_features = np.array(parent_features) 515 | features = np.array(features) 516 | 517 | # Thanks https://stackoverflow.com/a/8251757/7262247 ! 518 | index = np.argsort(parent_features) 519 | sorted_parents = parent_features[index] 520 | sorted_index = np.searchsorted(sorted_parents, features) 521 | 522 | features_index = np.take(index, sorted_index, mode="clip") 523 | if np.any(parent_features[features_index] != features): 524 | # result = np.ma.array(features_index, mask=mask) 525 | raise ValueError( 526 | "Internal error - please report this." 527 | "One feature was found in the child " 528 | "node, that was not in the parent " 529 | "node." 
530 | )
531 | 
532 | if len(features_index) > 0:
533 | An[features_index] = coefs_ * Cn[features_index]
534 | 
535 | return _SmoothingDetails(A=An, B=Bn, C=Cn)
536 | 
537 | def _smooth(
538 | node_id,
539 | parent_features=None,
540 | parent_coefs: _SmoothingDetails = None,
541 | parent_intercept: _SmoothingDetails = None,
542 | ):
543 | # Gather all info on this node
544 | # --base regression tree
545 | node_info = dct["nodes"][node_id]
546 | left_id = node_info["left_child"]
547 | right_id = node_info["right_child"]
548 | # --additional model
549 | node_model = old_node_models[node_id]
550 | # --samples
551 | samples_at_this_node = get_samples_at_node(node_id, nodes_to_samples)
552 | n_samples_at_this_node = samples_at_this_node.shape[0]
553 | 
554 | # Note: should be equal to tree.n_node_samples[node_id]
555 | if n_samples_at_this_node != self.tree_.n_node_samples[node_id]:
556 | raise ValueError("n_samples_at_this_node: Unexpected value, please report it as issue.")
557 | 
558 | y_true_this_node = y_train_all[samples_at_this_node]
559 | X_this_node = X_train_all[samples_at_this_node, :]
560 | 
561 | # (1) smooth current node
562 | parent_features = parent_features if parent_features is not None else None
563 | is_constant_leaf = False
564 | if left_id == _tree.TREE_LEAF and isinstance(node_model, ConstantLeafModel):
565 | is_constant_leaf = True
566 | node_features = ()
567 | smoothed_features = parent_features if parent_features is not None else node_features
568 | node_coefs = ()
569 | node_intercept = dct["values"][node_id]
570 | 
571 | # Extract the unique scalar value
572 | if len(node_intercept) != 1:
573 | raise ValueError("node_intercept: Unexpected value: %s, please report it as issue." % node_intercept)
574 | 
575 | node_intercept = node_intercept.item()
576 | 
577 | else:
578 | # A leaf LinRegLeafModel or a split SplitNodeModel
579 | node_features = node_model.features
580 | smoothed_features = parent_features if parent_features is not None else node_features
581 | node_coefs = node_model.model.coef_
582 | node_intercept = node_model.model.intercept_
583 | 
584 | # Create a new linear regression model with smoothed coefficients
585 | smoothed_sklearn_model = clone(self.leaf_model)
586 | smoothed_coefs = smooth__(
587 | node_coefs,
588 | features=node_features,
589 | n_samples=n_samples_at_this_node,
590 | parent=parent_coefs,
591 | parent_features=parent_features,
592 | )
593 | smoothed_intercept = smooth__(node_intercept, n_samples=n_samples_at_this_node, parent=parent_intercept)
594 | smoothed_sklearn_model.coef_ = smoothed_coefs.A + smoothed_coefs.B
595 | smoothed_sklearn_model.intercept_ = smoothed_intercept.A + smoothed_intercept.B
596 | 
597 | # Finally update the node
598 | if is_constant_leaf:
599 | smoothed_node_model = LinRegLeafModel(smoothed_features, smoothed_sklearn_model, None)
600 | else:
601 | smoothed_node_model = copy(node_model)
602 | smoothed_node_model.features = smoothed_features
603 | smoothed_node_model.model = smoothed_sklearn_model
604 | 
605 | # Remember the new smoothed model
606 | new_node_models[node_id] = smoothed_node_model
607 | 
608 | if left_id == _tree.TREE_LEAF:
609 | # If this is a leaf, update the prediction error on X
610 | y_pred_this_node = smoothed_node_model.predict(X_this_node)
611 | smoothed_node_model.error = err_metric(y_true_this_node, y_pred_this_node)
612 | 
613 | else:
614 | # If this is a split node - recurse on each subtree
615 | _smooth(
616 | left_id,
617 | parent_features=smoothed_features,
618 | parent_coefs=smoothed_coefs,
619 | parent_intercept=smoothed_intercept,
620 | )
621 | _smooth(
622 | right_id,
623 | parent_features=smoothed_features,
624 | parent_coefs=smoothed_coefs,
625 | parent_intercept=smoothed_intercept,
626 | )
627 | 
628 | # Update the error using the same formula as the one we used
629 | # in build_models_and_get_pruning_info
630 | y_pred_children = predict_from_leaves_no_smoothing(self.tree_, new_node_models, X_this_node)
631 | err_children = err_metric(y_true_this_node, y_pred_children)
632 | 
633 | # the total number of parameters is the sum of params in each
634 | # branch PLUS 1 for the split
635 | n_params_splitmodel = new_node_models[left_id].n_params + new_node_models[right_id].n_params + 1
636 | smoothed_node_model.n_params = n_params_splitmodel
637 | # do not adjust the error now, simply store the raw one
638 | # smoothed_node_model.error = compute_adjusted_error(
639 | # err_children, n_samples_at_this_node, n_params_splitmodel)
640 | smoothed_node_model.error = err_children
641 | 
642 | return
643 | 
644 | # smooth the whole tree
645 | _smooth(0)
646 | 
647 | # use the new node models now
648 | self.node_models = new_node_models
649 | 
650 | # remember the smoothing constant installed
651 | self.installed_smoothing_constant = smoothing_constant
652 | return
653 | 
654 | def denormalize(self, x_scaler, y_scaler):
655 | """
656 | De-normalizes this model according to the provided x and y normalization scalers.
657 | Currently only StandardScaler is supported.
658 | 
659 | :param x_scaler: a StandardScaler or None
660 | :param y_scaler: a StandardScaler or None
661 | :return:
662 | """
663 | # perform denormalization
664 | self._denormalize_tree(x_scaler, y_scaler)
665 | 
666 | def _denormalize_tree(self, x_scaler, y_scaler):
667 | """
668 | De-normalizes all models in the tree
669 | :return:
670 | """
671 | old_tree = self.tree_
672 | old_node_models = self.node_models
673 | 
674 | # Get all information to create a copy of the inner tree.
675 | # Note: _tree.copy() is gone so we use the pickle way
676 | # --- Base info: nb features, nb outputs, output classes
677 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L631
678 | # [1] = (self.n_features, sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), self.n_outputs)
679 | # So these remain, we are just interested in changing the node-related arrays
680 | new_tree = _tree.Tree(*old_tree.__reduce__()[1])
681 | 
682 | # --- Node info
683 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L637
684 | dct = old_tree.__getstate__().copy()
685 | 
686 | # create empty new structures for nodes and models
687 | n_shape = dct["nodes"].shape
688 | new_nodes = _empty_contig_ar(n_shape, dtype=dct["nodes"].dtype)
689 | v_shape = dct["values"].shape
690 | new_values = _empty_contig_ar(v_shape, dtype=dct["values"].dtype)
691 | m_shape = old_node_models.shape
692 | new_node_models = _empty_contig_ar(m_shape, dtype=old_node_models.dtype)
693 | 
694 | def _denormalize(node_id, x_scaler, y_scaler):
695 | """
696 | denormalizes the subtree below node `node_id`.
697 | 
698 | - `new_nodes[node_id]` is filled with a copy of the old node
699 | (see old_node.dtype to see the various fields). If the node is
700 | a split node, the split threshold
701 | `new_nodes[node_id]['threshold']` is denormalized.
702 | 
703 | - `new_values[node_id]` is filled with a denormalized copy of the
704 | old constant prediction at this node. Reminder: this constant
705 | prediction is actually only used on the leaves now, but
706 | we keep it for ref.
707 | 
708 | - `new_node_models[node_id]` is filled with a copy of the old
709 | model `old_node_models[node_id]`, that is denormalized if it
710 | is not a constant leaf
711 | 
712 | :param node_id:
713 | :return: (nothing)
714 | """
715 | # use the old tree to walk
716 | old_node = dct["nodes"][node_id]
717 | left_id = old_node['left_child']
718 | right_id = old_node['right_child']
719 | 
720 | # Create the new node with a copy of the old
721 | new_nodes[node_id] = old_node # this is an entire row so it is probably copied already by doing so.
722 | new_model = copy(old_node_models[node_id])
723 | new_node_models[node_id] = new_model
724 | 
725 | # Create the new value by de-scaling y
726 | if y_scaler is not None:
727 | # Note: if this is a split node with a linear regression model
728 | # the value will never be used. However to preserve consistency
729 | # of the whole values structure and for debugging purposes, we
730 | # choose this safe path of denormalizing ALL.
731 | # TODO we could also do it at once outside of the recursive
732 | # calls, but we should check for side effects
733 | new_values[node_id] = y_scaler.inverse_transform(
734 | dct["values"][node_id]
735 | )
736 | else:
737 | # no denormalization: simply copy
738 | new_values[node_id] = dct["values"][node_id]
739 | 
740 | if left_id == _tree.TREE_LEAF:
741 | # ... and right_id == _tree.TREE_LEAF
742 | # => node_id is a leaf
743 | if isinstance(new_model, ConstantLeafModel):
744 | # nothing to do: we already re-scaled the value
745 | return
746 | elif isinstance(new_model, LinRegLeafModel):
747 | # denormalize model
748 | new_model.denormalize(x_scaler, y_scaler)
749 | return
750 | else:
751 | raise TypeError("Internal error - Leaves can only be "
752 | "constant or linear regression")
753 | else:
754 | # this is a split node, denormalize each subtree
755 | _denormalize(left_id, x_scaler, y_scaler)
756 | _denormalize(right_id, x_scaler, y_scaler)
757 | 
758 | # denormalize the split value if needed
759 | if x_scaler is not None:
760 | split_feature = old_node['feature']
761 | # The denormalizer requires a vector with all the features,
762 | # even if we only want to denormalize one.
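# (since StandardScaler.inverse_transform computes x * scale_ + mean_, this
# is equivalent to old_node['threshold'] * x_scaler.scale_[split_feature]
# + x_scaler.mean_[split_feature])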
763 | # -- put split value in a vector where it has pos 'feature'
764 | old_threshold_and_zeros = np.zeros((self.n_features_, ), dtype=dct["nodes"]['threshold'].dtype)
765 | old_threshold_and_zeros[split_feature] = old_node['threshold']
766 | # -- denormalize the vector and retrieve value 'feature'
767 | new_nodes[node_id]['threshold'] = x_scaler.inverse_transform(old_threshold_and_zeros)[split_feature]
768 | else:
769 | # no denormalization: simply copy
770 | new_nodes[node_id]['threshold'] = old_node['threshold']
771 | 
772 | if isinstance(new_model, SplitNodeModel):
773 | # denormalize model at split node too, even if it is not
774 | # always used (depending on smoothing mode)
775 | new_model.denormalize(x_scaler, y_scaler)
776 | else:
777 | raise TypeError("Internal error: all intermediate nodes "
778 | "should be SplitNodeModel")
779 | 
780 | return
781 | 
782 | # denormalize the whole tree and put the result in (new_nodes,
783 | # new_values, new_node_models) recursively
784 | _denormalize(0, x_scaler, y_scaler)
785 | 
786 | # complete definition of the new tree
787 | dct["nodes"] = new_nodes
788 | dct["values"] = new_values
789 | new_tree.__setstate__(dct)
790 | 
791 | # update the self fields
792 | # self.features_usage
793 | self.tree_ = new_tree
794 | self.node_models = new_node_models
795 | 
796 | def compress_features(self):
797 | """
798 | Compresses the model and returns a vector of required feature indices.
799 | This model input will then be X[:, features] instead of X.
800 | """
801 | used_features = sorted(self.features_usage.keys())
802 | new_features_lookup_dct = {old_feature_idx: i for i, old_feature_idx in enumerate(used_features)}
803 | 
804 | if used_features == list(range(self.n_features_in_)):
805 | # NOTHING TO DO: we need all features
806 | return used_features
807 | 
808 | # Otherwise, we can compress. For this we have to create a copy of the
809 | # tree because we will change its internals
810 | old_tree = self.tree_
811 | old_node_modls = self.node_models
812 | 
813 | # Get all information to create a copy of the inner tree.
814 | # Note: _tree.copy() is gone so we use the pickle way
815 | # --- Base info: nb features, nb outputs, output classes
816 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L631
817 | # [1] = (self.n_features, sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), self.n_outputs)
818 | # So these remain, we are just interested in changing the node-related arrays
819 | new_tree = _tree.Tree(*old_tree.__reduce__()[1])
820 | 
821 | # --- Node info
822 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L637
823 | dct = old_tree.__getstate__().copy()
824 | 
825 | # create empty new structures for nodes and models
826 | n_shape = dct["nodes"].shape
827 | new_nodes = _empty_contig_ar(n_shape, dtype=dct["nodes"].dtype)
828 | m_shape = old_node_modls.shape
829 | new_node_models = _empty_contig_ar(m_shape, dtype=old_node_modls.dtype)
830 | 
831 | def _compress_features(node_id):
832 | """
833 | 
834 | Parameters
835 | ----------
836 | node_id
837 | 
838 | Returns
839 | -------
840 | nothing; compresses the subtree below `node_id` in place
841 | 
842 | """
843 | # use the old tree to walk
844 | old_node = dct["nodes"][node_id]
845 | left_id = old_node["left_child"]
846 | right_id = old_node["right_child"]
847 | 
848 | # Create the new node with a copy of the old
849 | new_nodes[node_id] = old_node # this is an entire row so it is probably copied already by doing so. 
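# (illustrative reminder: if used_features == [0, 3, 7] then
# new_features_lookup_dct == {0: 0, 3: 1, 7: 2}, so a node that used to
# split on feature 7 will now split on column 2 of the compressed X)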
850 | new_model = copy(old_node_modls[node_id])
851 | new_node_models[node_id] = new_model
852 | 
853 | if left_id == _tree.TREE_LEAF:
854 | # ... and right_id == _tree.TREE_LEAF
855 | # => node_id is a leaf
856 | if isinstance(new_model, ConstantLeafModel):
857 | # no features used
858 | return
859 | elif isinstance(new_model, LinRegLeafModel):
860 | # compress that model
861 | new_model.reindex_features(new_features_lookup_dct)
862 | return
863 | else:
864 | raise TypeError("Internal error - Leaves can only be " "constant or linear regression")
865 | else:
866 | 
867 | _compress_features(left_id)
868 | _compress_features(right_id)
869 | 
870 | # store the new split feature index in the node
871 | new_nodes[node_id]["feature"] = new_features_lookup_dct[old_node["feature"]]
872 | 
873 | if not isinstance(new_model, SplitNodeModel):
874 | raise TypeError("Internal error: all intermediate nodes " "should be SplitNodeModel")
875 | 
876 | # TODO now that split node models have a linear regression
877 | # model too, we should update them here
878 | 
879 | return
880 | 
881 | _compress_features(0)
882 | 
883 | # complete definition of the new tree
884 | dct["nodes"] = new_nodes
885 | new_tree.__setstate__(dct)
886 | 
887 | # update the self fields
888 | self.features_usage = {new_features_lookup_dct[k]: v for k, v in self.features_usage.items()}
889 | self.tree_ = new_tree
890 | self.node_models = new_node_models
891 | self.n_features_in_ = len(self.features_usage)
892 | 
893 | # return the vector of used features
894 | return used_features
895 | 
896 | @property
897 | def feature_importances_(self):
898 | """Return the feature importances (the higher, the more important).
899 | 
900 | Returns
901 | -------
902 | feature_importances_ : array, shape = [n_features]
903 | """
904 | check_is_fitted(self, "tree_")
905 | 
906 | # TODO adapt ?
907 | # features = np.array([self.features_usage[k] for k in sorted(self.features_usage.keys())], dtype=int)
908 | features = self.tree_.compute_feature_importances()
909 | 
910 | return features
911 | 
912 | def predict(self, X, check_input=True, smooth_predictions=None, smoothing_constant=None):
913 | """Predict class or regression value for X.
914 | 
915 | For a classification model, the predicted class for each sample in X is
916 | returned. For a regression model, the predicted value based on X is
917 | returned.
918 | 
919 | Parameters
920 | ----------
921 | X : numpy.ndarray, shape (n_samples, n_features)
922 | The input samples. Internally, it will be converted to
923 | ``dtype=np.float32`` and if a sparse matrix is provided
924 | to a sparse ``csr_matrix``.
925 | check_input : boolean, (default=True)
926 | Allows bypassing several input checks.
927 | Don't use this parameter unless you know what you are doing.
928 | smooth_predictions : boolean, (default=None)
929 | None means "use self config"
930 | smoothing_constant: int, (default=self.smoothing_constant)
931 | Smoothing constant used when smooth_predictions is True. During
932 | smoothing, child node models are recursively mixed with their
933 | parent node models, and the mix is done using a weighted sum. The
934 | weight given to the child model is the number of training samples
935 | that reached its node, while the weight given to the parent model
936 | is the smoothing constant. Therefore it can be seen as an
937 | equivalent number of samples that parent models represent when
938 | injected into the recursive weighted sum. 
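In other words, with n the number of training samples at a child
node and k this constant, the smoothed prediction at that node is
(n * child_prediction + k * parent_prediction) / (n + k), applied
recursively from the leaf up to the root.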
939 | 
940 | Returns
941 | -------
942 | y : array of shape = [n_samples] or [n_samples, n_outputs]
943 | The predicted classes, or the predicted values.
944 | """
945 | check_is_fitted(self, "tree_")
946 | 
947 | # If this is just a constant node, only check the input's shape.
948 | if self.n_features_in_ == 0:
949 | # perform the input checking manually to set ensure_min_features=0
950 | X = check_array(X, dtype=DTYPE, accept_sparse="csr", ensure_min_features=0)
951 | if issparse(X) and (X.indices.dtype != np.intc or X.indptr.dtype != np.intc):
952 | raise ValueError("No support for np.int64 index based "
953 | "sparse matrices")
954 | # skip it then
955 | check_input = False
956 | 
957 | # validate and convert dtype
958 | X = self._validate_X_predict(X, check_input)
959 | 
960 | # -------- This is the only change wrt parent class. TODO maybe rather replace self.tree_ with a proxy ---------
961 | if smooth_predictions is None:
962 | # Default: smooth prediction at prediction time if configured as
963 | # such in the model. Note: if self.use_smoothing == 'installed',
964 | # models are already smoothed models, so no need to smooth again
965 | smooth_predictions = self.use_smoothing == "on_prediction"
966 | else:
967 | # user provided an explicit value for smooth_predictions
968 | if not smooth_predictions and self.use_smoothing == "installed":
969 | raise ValueError(
970 | "Smoothing has been pre-installed on this "
971 | "tree, it is no longer possible to make "
972 | "predictions without smoothing"
973 | )
974 | 
975 | if smooth_predictions and smoothing_constant is None:
976 | # default parameter for smoothing is the one defined in the model
977 | # (with possible ratio)
978 | smoothing_constant = self._get_smoothing_constant_to_use(X)
979 | 
980 | # Do not use the embedded tree (like in super): it has been pruned but
981 | # still has constant nodes
982 | # proba = self.tree_.predict(X)
983 | if not smooth_predictions:
984 | proba = predict_from_leaves_no_smoothing(self.tree_, self.node_models, X)
985 | else:
986 | proba = predict_from_leaves(self, X, smoothing=True, smoothing_constant=smoothing_constant)
987 | if len(proba.shape) < 2:
988 | proba = proba.reshape(-1, 1)
989 | # ------------------------------------------
990 | 
991 | n_samples = X.shape[0]
992 | 
993 | # Classification
994 | if is_classifier(self):
995 | if self.n_outputs_ == 1:
996 | return self.classes_.take(np.argmax(proba, axis=1), axis=0)
997 | else:
998 | predictions = np.zeros((n_samples, self.n_outputs_))
999 | for k in range(self.n_outputs_):
1000 | predictions[:, k] = self.classes_[k].take(np.argmax(proba[:, k], axis=1), axis=0)
1001 | return predictions
1002 | 
1003 | # Regression
1004 | else:
1005 | if self.n_outputs_ == 1:
1006 | return proba[:, 0]
1007 | else:
1008 | return proba[:, :, 0]
1009 | 
1010 | 
1011 | class ConstantLeafModel:
1012 | """
1013 | Represents the additional information about a leaf node that is not pruned.
1014 | It contains the error associated with the training samples at this node.
1015 | 
1016 | Note: the constant value is not stored here, as it is already available in
So to know the prediction at this node, use 1017 | `tree.value[node_id]` 1018 | """ 1019 | 1020 | __slots__ = ("error",) 1021 | 1022 | def __init__(self, error): 1023 | self.error = error 1024 | 1025 | @property 1026 | def n_params(self) -> int: 1027 | """ 1028 | Returns the number of parameters used by this model, including the 1029 | constant driver 1030 | """ 1031 | return 1 1032 | 1033 | @staticmethod 1034 | def predict_cstt(tree, node_id, n_samples): 1035 | """ 1036 | This static method is a helper to get an array of constant predictions 1037 | associated with this node. 1038 | 1039 | Parameters 1040 | ---------- 1041 | tree 1042 | node_id 1043 | n_samples 1044 | 1045 | Returns 1046 | ------- 1047 | 1048 | """ 1049 | cstt_prediction = tree.value[node_id] 1050 | 1051 | # cstt_prediction can be multioutput so it is an array. replicate it 1052 | from numpy import matlib # note: np.matlib not available in 1.10.x 1053 | 1054 | return matlib.repmat(cstt_prediction, n_samples, 1) 1055 | 1056 | 1057 | class LinRegNodeModel(DeNormalizableMixIn): 1058 | """ 1059 | Represents the additional information about a tree node that contains a 1060 | linear regression model. 1061 | 1062 | It contains 1063 | - the features used by the model, 1064 | - the scikit learn model object itself, 1065 | - and the error for the training samples that reached this node 1066 | 1067 | """ 1068 | __slots__ = ("features", "model", "error", "n_params") 1069 | 1070 | def __init__(self, features, model, error) -> None: 1071 | self.features = features 1072 | self.model = model 1073 | self.error = error 1074 | self.n_params = len(self.features) + 1 1075 | 1076 | def to_text(self, feature_names=None, target_name=None, precision=3, 1077 | line_breaks=False): 1078 | """ Returns a text representation of the linear regression model """ 1079 | return linreg_model_to_text(self.model, feature_names=feature_names, 1080 | target_name=target_name, 1081 | precision=precision, 1082 | line_breaks=line_breaks) 1083 | 1084 | def reindex_features(self, new_features_lookup_dct): 1085 | """ 1086 | Reindexes the required features using the provided lookup dictionary. 1087 | 1088 | Parameters 1089 | ---------- 1090 | new_features_lookup_dct 1091 | 1092 | Returns 1093 | ------- 1094 | 1095 | """ 1096 | self.features = [new_features_lookup_dct[f] for f in self.features] 1097 | 1098 | def denormalize(self, 1099 | x_scaler: StandardScaler = None, 1100 | y_scaler: StandardScaler = None 1101 | ): 1102 | """ 1103 | De-normalizes the linear model. 1104 | 1105 | :param x_scaler: 1106 | :param y_scaler: 1107 | :return: 1108 | """ 1109 | # create a clone of the x scaler with only the used features 1110 | x_scaler = copy(x_scaler) 1111 | if len(self.features) > 0: 1112 | x_scaler.scale_ = x_scaler.scale_[self.features] 1113 | x_scaler.mean_ = x_scaler.mean_[self.features] 1114 | else: 1115 | # in that particular case, the above expression is not working 1116 | x_scaler.scale_ = x_scaler.scale_[0:0] 1117 | x_scaler.mean_ = x_scaler.mean_[0:0] 1118 | 1119 | # use the denormalize function on the internal model 1120 | # TODO what if internal model is not able ? 1121 | self.model.denormalize(x_scaler, y_scaler) 1122 | 1123 | def predict(self, X: np.ndarray): 1124 | """Performs a prediction for X, only using the required features.""" 1125 | if len(self.features) < 1 and isinstance(self.model, LinearRegression): 1126 | # unfortunately LinearRegression models do not like it when no 1127 | # features are needed: their input validation requires at least 1. 
1128 |             # so we do it ourselves.
1129 |             return self.model.intercept_ * np.ones((X.shape[0], 1))
1130 |         else:
1131 |             return self.model.predict(X[:, self.features])
1132 |
1133 |
1134 | class LinRegLeafModel(LinRegNodeModel):
1135 |     """
1136 |     Represents the additional information about a leaf node with a linear
1137 |     regression model.
1138 |     """
1139 |     pass
1140 |
1141 |
1142 | class SplitNodeModel(LinRegNodeModel):
1143 |     """
1144 |     Represents the additional information about a split node, with a linear
1145 |     regression model.
1146 |     """
1147 |     __slots__ = ()  # the "n_params" slot is already declared in the parent class
1148 |
1149 |     def __init__(self, n_params, error, features, model):
1150 |         super(SplitNodeModel, self).__init__(features=features, model=model, error=error)
1151 |         self.n_params = n_params  # must be set AFTER the super call, which computes its own value
1152 |
1153 |
1154 | PRUNING_MULTIPLIER = 2
1155 | # see https://github.com/bnjmn/weka/blob/master/weka/src/main/java/weka/classifiers/trees/m5/RuleNode.java#L124
1156 | # TODO check why they use 2 instead of 1 (from the article) ??
1157 |
1158 |
1159 | def root_mean_squared_error(*args, **kwargs):
1160 |     return np.sqrt(mean_squared_error(*args, **kwargs))
1161 |
1162 |
1163 | def prune_on_min_impurity(tree):
1164 |     """
1165 |     Edits the given tree so as to prune sub-branches that do not respect the min
1166 |     impurity criterion.
1167 |
1168 |     The paper suggests doing this with 5% but min_impurity_split is not
1169 |     available in 0.17 (and criterion is not std but mse so we have to square)
1170 |
1171 |     Parameters
1172 |     ----------
1173 |     tree : Tree, the tree that will be pruned in place
1174 |
1175 |     Returns
1176 |     -------
1177 |     None
1178 |     """
1179 |     left_children = tree.children_left
1180 |     right_children = tree.children_right
1181 |     impurities = tree.impurity
1182 |
1183 |     # The paper suggests doing this with 5% but min_impurity_split is not
1184 |     # available in 0.17 (and criterion is not std but mse so we have to square)
1185 |     # TODO adapt the formula to criterion used and/or generalize with min_impurity_split_as_initial_ratio
1186 |     root_impurity = impurities[0]
1187 |     impurity_threshold = root_impurity * (0.05 ** 2)
1188 |
1189 |     def stop_on_min_impurity(node_id):
1190 |         # note: in the paper that is 0.05 but criterion is on std. Here
1191 |         # impurity is mse so a squared equivalent of std.
1192 |         left_id = left_children[node_id]
1193 |         right_id = right_children[node_id]
1194 |         if left_id != _tree.TREE_LEAF:  # a split node
1195 |             if impurities[node_id] < impurity_threshold:
1196 |                 # stop here, this will be a leaf
1197 |                 prune_children(node_id, tree)
1198 |             else:
1199 |                 stop_on_min_impurity(left_id)
1200 |                 stop_on_min_impurity(right_id)
1201 |
1202 |     stop_on_min_impurity(0)
1203 |
1204 |
1205 | def build_models_and_get_pruning_info(
1206 |     tree, X_train_all, y_train_all, nodes_to_samples, leaf_model, node_models, global_std_dev, use_pruning, node_id=0
1207 | ):
1208 |     """
1209 |     Recursively fits a model for each node, prunes the tree bottom-up, and reports feature usage.
1210 |
1211 |     Parameters
1212 |     ----------
1213 |     tree : Tree
1214 |         A tree that will be pruned on the way
1215 |     X_train_all : 2D array, the full training inputs
1216 |     y_train_all : 1D array, the full training targets
1217 |     nodes_to_samples : csc_matrix, see `get_samples_at_node`
1218 |     leaf_model : the prototype regressor, cloned at each node before fitting
1219 |     node_models : array, filled in place with one model object per node
1220 |     global_std_dev : float, the std deviation of the target on the whole training set
1221 |     use_pruning : bool, whether to compare node models with their subtrees
1222 |     node_id : int, the id of the node to process (0, the root, by default)
1223 |
1224 |     Returns
1225 |     -------
1226 |     a dictionary where the key is the feature index and the value is
1227 |     the number of samples where this feature is used
1228 |     """
1229 |
1230 |     # Select the Error metric to compute model errors (used to compare node
1231 |     # with subtree for pruning)
1232 |     # TODO in Weka they use RMSE, but in papers they use MAE. could be a param
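    # (Editor's illustrative sketch, not part of the original code: following
    # the papers instead of Weka would be a one-line swap, e.g.
    #     from sklearn.metrics import mean_absolute_error
    #     err_metric = mean_absolute_error
    # since both metrics share the (y_true, y_pred) -> float signature.)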
1233 | err_metric = root_mean_squared_error # mean_absolute_error 1234 | 1235 | # Get the samples associated with this node 1236 | samples_at_this_node = get_samples_at_node(node_id, nodes_to_samples) 1237 | n_samples_at_this_node = samples_at_this_node.shape[0] 1238 | y_true_this_node = y_train_all[samples_at_this_node] 1239 | 1240 | # Is this a leaf node or a split node ? 1241 | left_node = tree.children_left[node_id] # this way we do not have to query it again in the else. 1242 | if left_node == _tree.TREE_LEAF: 1243 | # --> Current node is a LEAF. See is_leaf(node_id, tree) if you have doubts <-- 1244 | 1245 | # -- create a linear model for this node 1246 | # leaves do not have the right to use any features since they have no 1247 | # subtree: keep the constant prediction 1248 | y_pred_this_node = ConstantLeafModel.predict_cstt(tree, node_id, y_true_this_node.shape[0]) 1249 | err_this_node = err_metric(y_true_this_node, y_pred_this_node) 1250 | 1251 | # TODO when use_pruning = False, should we allow all features to be 1252 | # used instead of having to stick to the M5 rule of "only use a 1253 | # feature if the subtree includes a split with this feature" ? 1254 | # OR alternate proposal: should we transform the boolean use_pruning 1255 | # into a use_pruning_max_nb integer to say for example "only 2 level 1256 | # of pruning" ? 1257 | 1258 | # -- store the model information 1259 | node_models[node_id] = ConstantLeafModel(err_this_node) 1260 | 1261 | # -- return an empty dict for features used 1262 | return dict() # np.array([], dtype=int) 1263 | 1264 | else: 1265 | # --> Current node is a SPLIT <-- 1266 | right_node = tree.children_right[node_id] 1267 | 1268 | # (1) prune left and right subtree and get some information 1269 | features_l = build_models_and_get_pruning_info( 1270 | tree, X_train_all, y_train_all, nodes_to_samples, leaf_model, 1271 | node_models, global_std_dev, use_pruning, node_id=left_node 1272 | ) 1273 | features_r = build_models_and_get_pruning_info( 1274 | tree, X_train_all, y_train_all, nodes_to_samples, leaf_model, 1275 | node_models, global_std_dev, use_pruning, node_id=right_node 1276 | ) 1277 | 1278 | # (2) select only the samples that reach this node 1279 | X_this_node = X_train_all[samples_at_this_node, :] 1280 | 1281 | # (3) Create a model for this node 1282 | # -- fit a linear regression model taking into account only the 1283 | # features used in the subtrees + the split one 1284 | # TODO should we normalize=True (variance scaling) or use a whole 1285 | # pipeline here ? Not sure.. 
        # estimators = [('scale', StandardScaler()), ('clf', model_type())];
1287 |         # pipe = Pipeline(estimators)
1288 |         # skmodel_this_node = leaf_model_type(**leaf_model_params)
1289 |         skmodel_this_node = clone(leaf_model)
1290 |
1291 |         # -- old - we used to only store the array of features
1292 |         # selected_features = np.union1d(features_l, features_r)
1293 |         # selected_features = np.union1d(selected_features, tree.feature[node_id])
1294 |         # -- new - we also gather the nb of samples where this feature is used
1295 |         selected_features_dct = features_l
1296 |         for feature_id, n_samples_feat_used in features_r.items():
1297 |             if feature_id in selected_features_dct.keys():
1298 |                 selected_features_dct[feature_id] += n_samples_feat_used
1299 |             else:
1300 |                 selected_features_dct[feature_id] = n_samples_feat_used
1301 |         selected_features_dct[tree.feature[node_id]] = n_samples_at_this_node
1302 |         # -- use only the selected features, in the natural integer order
1303 |         selected_features = sorted(selected_features_dct.keys())
1304 |
1305 |         X_train_this_node = X_this_node[:, selected_features]
1306 |         skmodel_this_node.fit(X_train_this_node, y_true_this_node)
1307 |         # -- predict and compute error
1308 |         y_pred_this_node = skmodel_this_node.predict(X_train_this_node)
1309 |         err_this_node = err_metric(y_true_this_node, y_pred_this_node)
1310 |         # -- create the object
1311 |         # TODO the paper suggests to perform recursive feature elimination in
1312 |         # this model until adjusted_err_model is minimal, is it the same in Weka?
1313 |         model_this_node = LinRegLeafModel(selected_features, skmodel_this_node, err_this_node)
1314 |
1315 |         # (4) compute adj error criterion = ERR * (n + multiplier*v) / (n - v) for both models
1316 |         adjusted_err_model = compute_adjusted_error(err_this_node, n_samples_at_this_node, model_this_node.n_params)
1317 |
1318 |         # (5) predict and compute adj error for the combination of child models
1319 |         # -- Note: this is recursive so the leaves may contain linear models
1320 |         # already
1321 |         y_pred_children = predict_from_leaves_no_smoothing(tree, node_models, X_this_node)
1322 |         err_children = err_metric(y_true_this_node, y_pred_children)
1323 |
1324 |         # TODO the Weka implem (below) differs from the paper, which suggests a
1325 |         # weighted sum of adjusted errors. This is maybe an equivalent formulation, to check.
1326 |         # the total number of parameters, if we do not prune, is the sum of the
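        # (Editor's worked example, not part of the original code: with
        # PRUNING_MULTIPLIER = 2, a node with n = 20 samples and p = 3 model
        # parameters gets factor (20 + 2 * 3) / (20 - 3) = 26 / 17, roughly 1.53,
        # i.e. its raw error is inflated by about 53% before the comparison
        # in step (6) below.)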
1327 |         # params in each branch PLUS 1 for the split
1328 |         n_params_splitmodel = node_models[left_node].n_params + node_models[right_node].n_params + 1
1329 |         adjusted_err_children = compute_adjusted_error(err_children, n_samples_at_this_node, n_params_splitmodel)
1330 |
1331 |         # (6) compare and either prune at this node or keep the subtrees
1332 |         std_dev_this_node = np.nanstd(y_true_this_node)
1333 |         # note: the first criterion is now already checked before that
1334 |         # function call, in `prune_on_min_impurity`
1335 |         if use_pruning and (
1336 |             # TODO these parameters should be in the constructor
1337 |             # see also 2 comments about min_impurity_split_as_initial_ratio
1338 |             std_dev_this_node < (global_std_dev * 0.05)
1339 |             or (adjusted_err_model <= adjusted_err_children)
1340 |             or (adjusted_err_model < (global_std_dev * 0.00001))  # so this means very good R² already
1341 |         ):
1342 |             # Choose the model for this node rather than the subtree model
1343 |             # -- prune from this node on
1344 |             removed_nodes = prune_children(node_id, tree)
1345 |             node_models[removed_nodes] = _tree.TREE_UNDEFINED
1346 |
1347 |             # store the model information
1348 |             node_models[node_id] = model_this_node
1349 |
1350 |             # update and return the features used
1351 |             selected_features_dct = {k: n_samples_at_this_node for k in selected_features_dct.keys()}
1352 |             return selected_features_dct
1353 |
1354 |         else:
1355 |             # The subtrees are good or we do not want pruning: keep them.
1356 |             # This node will remain a split, and will only contain the digest
1357 |             # about the subtree
1358 |             node_models[node_id] = SplitNodeModel(
1359 |                 n_params_splitmodel, err_children, selected_features, skmodel_this_node
1360 |             )
1361 |
1362 |             # return the features used
1363 |             return selected_features_dct
1364 |
1365 |
1366 | def compute_adjusted_error(err, n_samples, n_parameters, multiplier=PRUNING_MULTIPLIER):
1367 |     """
1368 |     Return the penalized error obtained from `err`, the prediction error at a
1369 |     given tree node, by multiplying it by ``(n + multiplier*p) / (n - p)``. `n` is
1370 |     the number of samples at this node, `p` is the number of parameters of the
1371 |     model, and `multiplier` defaults to ``PRUNING_MULTIPLIER`` (2, as in Weka).
1372 |     According to the original M5 paper, this is to penalize complex models
1373 |     used for few samples.
1374 |
1375 |     Note that if ``n_samples <= n_parameters`` the denominator is zero or negative.
1376 |     In this case an arbitrarily high penalization factor is used: ``10 * err``.
1377 |
1378 |     Parameters
1379 |     ----------
1380 |     err : float
1381 |         The model error for a given tree node. Typically the RMSE (as in Weka) or the MAE.
1382 |
1383 |     n_samples : int
1384 |         The number of samples at this node.
1385 |
1386 |     n_parameters : int
1387 |         The number of parameters of the model at this node.
1388 |
1389 |     Returns
1390 |     -------
1391 |     penalized_error : float
1392 |         The penalized error ``err * (n + multiplier*p) / (n - p)``, or ``10 * err`` if ``n <= p``.
1393 |     """
1394 |     #
1395 |     if n_samples <= n_parameters:
1396 |         # denominator is zero or negative: use a large factor so as to heavily
1397 |         # penalize this overly complex model.
1398 |         factor = 10.0  # Caution says Yong in his code
1399 |     else:
1400 |         factor = (n_samples + multiplier * n_parameters) / (n_samples - n_parameters)
1401 |
1402 |     return err * factor
1403 |
1404 |
1405 | def predict_from_leaves_no_smoothing(tree, node_models, X):
1406 |     """
1407 |     Returns the prediction for all samples in X, using the appropriate
1408 |     leaf model.
1409 |
1410 |     This function
1411 |     - uses tree.apply(X) to know in which tree leaf each sample in X goes
1412 |     - then for each of the leaves that are actually touched, uses
1413 |       node_models[leaf_id] to predict, for the X reaching that leaf
1414 |
1415 |     This function assumes that node_models contains a LinRegLeafModel or
1416 |     ConstantLeafModel entry for every leaf node reached by the samples in X.
1417 |
1418 |     Parameters
1419 |     ----------
1420 |     tree : Tree
1421 |         The tree object.
1422 |     node_models : array-like
1423 |         The array containing node models
1424 |     X : 2D array-like
1425 |         The sample data.
1426 |
1427 |     Returns
1428 |     -------
1429 |     y_predicted : 1D array-like
1430 |         The prediction vector.
1431 |     """
1432 |     # **** This does the job, but we have one execution of model.predict() per
1433 |     # sample: probably not efficient
1434 |     # sample_ids_to_leaf_node_ids = tree.apply(X)
1435 |     # model_and_x = np.concatenate((node_models[leaf_node_ids].reshape(-1, 1), X), axis=1)
1436 |     # def pred(m_and_x):
1437 |     #     return m_and_x[0].model.predict(m_and_x[1:].reshape(1,-1))[0]
1438 |     # y_predicted = np.array(list(map(pred, model_and_x)))
1439 |
1440 |     # **** This should be faster because the number of calls to predict() is
1441 |     # equal to the number of leaves touched
1442 |     sample_ids_to_leaf_node_ids = tree.apply(X)
1443 |     y_predicted = -np.ones(sample_ids_to_leaf_node_ids.shape, dtype=DOUBLE)
1444 |
1445 |     # -- find the unique list of leaves touched
1446 |     leaf_node_ids, inverse_idx = np.unique(sample_ids_to_leaf_node_ids, return_inverse=True)
1447 |
1448 |     # -- for each leaf, perform the prediction for the samples reaching that leaf
1449 |     for leaf_node_id in leaf_node_ids:
1450 |         # get the indices of the samples reaching that leaf
1451 |         sample_indices = np.nonzero(sample_ids_to_leaf_node_ids == leaf_node_id)[0]
1452 |
1453 |         # predict
1454 |         node_model = node_models[leaf_node_id]
1455 |         if isinstance(node_model, LinRegLeafModel):
1456 |             y_predicted[sample_indices] = np.ravel(node_model.predict(X[sample_indices, :]))
1457 |         else:
1458 |             # isinstance(node_model, ConstantLeafModel)
1459 |             y_predicted[sample_indices] = tree.value[leaf_node_id]
1460 |
1461 |     # **** Recursive strategy: not used anymore
1462 |     # left_node = tree.children_left[node_id]  # this way we do not have to query it again in the else.
1463 |     # if left_node == _tree.TREE_LEAF:
1464 |     #     # --> Current node is a LEAF. See is_leaf(node_id, tree) <--
1465 |     #     y_predicted = node_models[node_id].model.predict()
1466 |     # else:
1467 |     #     # --> Current node is a SPLIT <--
1468 |     #     right_node = tree.children_right[node_id]
1469 |     #
1470 |     #
1471 |     #     samples_at_this_node = get_samples_at_node(node_id, nodes_to_samples)
1472 |     #     y_true_this_node = y_train_all[samples_at_this_node]
1473 |     #     # As surprising as it may seem, in numpy [samples_at_this_node, selected_features] does something else.
1474 |     #     X_train_this_node = X_train_all[samples_at_this_node, :][:, selected_features]
1475 |     #
1476 |     #     X_left
1477 |
1478 |     return y_predicted
1479 |
1480 |
1481 | def predict_from_leaves(m5p, X, smoothing=True, smoothing_constant=15):
1482 |     """
1483 |     Predicts using the M5P tree, without using the compiled sklearn
1484 |     `tree.apply` subroutine.
1485 |
1486 |     The main purpose of this function is to apply smoothing to an M5P model tree
1487 |     where smoothing has not been pre-installed on the models.
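    (Editor's note, summarizing the `smooth_predictions` helper below: the
    smoothing applied bottom-up along the path from leaf to root follows
    ``y_smoothed = (n * y_child + k * y_parent) / (n + k)``, where ``n`` is
    the number of training samples at the child node and ``k`` is the
    smoothing constant.)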
1488 |     For example, this makes it possible to use a model both without and with
1489 |     smoothing for comparison purposes, or to handle models whose leaves are not
1490 |     linear models and for which no pre-installation method exists.
1491 |
1492 |     Note: this method is slower than `predict_from_leaves_no_smoothing` when
1493 |     `smoothing=False`.
1494 |
1495 |     Parameters
1496 |     ----------
1497 |     m5p : M5Prime
1498 |         The model to use for prediction
1499 |     X : array-like
1500 |         The input data
1501 |     smoothing : bool
1502 |         Whether to apply smoothing
1503 |     smoothing_constant : int
1504 |         The smoothing constant `k` used as the prediction weight of the parent node model
1505 |         (k=15 in the articles).
1506 |
1507 |
1508 |     Returns
1509 |     -------
1510 |     y_predicted : array of shape (n_samples, 1), the predictions
1511 |     """
1512 |     # validate and convert dtype just in case this was directly called
1513 |     # e.g. in unit tests
1514 |     X = m5p._validate_X_predict(X, check_input=True)
1515 |
1516 |     tree = m5p.tree_
1517 |     node_models = m5p.node_models
1518 |     nb_samples = X.shape[0]
1519 |     y_predicted = -np.ones((nb_samples, 1), dtype=DOUBLE)
1520 |
1521 |     # sample_ids_to_leaf_node_ids = tree.apply(X)
1522 |     def smooth_predictions(ancestor_nodes, X_at_node, y_pred_at_node):
1523 |         # note: y_pred_at_node can be a constant
1524 |         current_node_model_id = ancestor_nodes[-1]
1525 |         for _i, parent_model_id in enumerate(reversed(ancestor_nodes[:-1])):
1526 |             # warning: this is the nb of TRAINING samples at this node
1527 |             node_nb_train_samples = tree.n_node_samples[current_node_model_id]
1528 |             parent_model = node_models[parent_model_id]
1529 |             parent_predictions = parent_model.predict(X_at_node)
1530 |             y_pred_at_node = (node_nb_train_samples * y_pred_at_node + smoothing_constant * parent_predictions) / (
1531 |                 node_nb_train_samples + smoothing_constant
1532 |             )
1533 |             current_node_model_id = parent_model_id
1534 |
1535 |         return y_pred_at_node
1536 |
1537 |     def apply_prediction(node_id, ids=None, ancestor_nodes=None):
1538 |         first_call = False
1539 |         if ids is None:
1540 |             ids = slice(None)
1541 |             first_call = True
1542 |         if smoothing:
1543 |             if ancestor_nodes is None:
1544 |                 ancestor_nodes = [node_id]
1545 |             else:
1546 |                 ancestor_nodes.append(node_id)
1547 |
1548 |         left_id = tree.children_left[node_id]
1549 |         if left_id == _tree.TREE_LEAF:
1550 |             # ... and tree.children_right[node_id] == _tree.TREE_LEAF
1551 |             # LEAF node: predict
1552 |             # fetch the model associated with this leaf
1553 |             node_model = node_models[node_id]
1554 |             # assert (ids == (sample_ids_to_leaf_node_ids == node_id)).all()
1555 |             if isinstance(node_model, LinRegLeafModel):
1556 |                 X_at_node = X[ids, :]
1557 |                 predictions = node_model.predict(X_at_node)
1558 |             else:
1559 |                 # isinstance(node_model, ConstantLeafModel)
1560 |                 predictions = tree.value[node_id]
1561 |                 if smoothing:
1562 |                     X_at_node = X[ids, :]
1563 |
1564 |             # finally apply smoothing
1565 |             if smoothing:
1566 |                 y_predicted[ids] = smooth_predictions(ancestor_nodes, X_at_node, predictions)
1567 |             else:
1568 |                 y_predicted[ids] = predictions
1569 |         else:
1570 |             right_id = tree.children_right[node_id]
1571 |             # non-leaf node: split samples and recurse
1572 |             left_group = np.zeros(nb_samples, dtype=bool)
1573 |             left_group[ids] = X[ids, tree.feature[node_id]] <= tree.threshold[node_id]
1574 |             right_group = (~left_group) if first_call else (ids & (~left_group))
1575 |
1576 |             # important: copy ancestor_nodes BEFORE calling anything, otherwise
1577 |             # it will be modified
1578 |             apply_prediction(
1579 |                 left_id, ids=left_group, ancestor_nodes=(ancestor_nodes.copy() if ancestor_nodes is not None else None)
1580 |             )
1581 |             apply_prediction(right_id, ids=right_group, ancestor_nodes=ancestor_nodes)
1582 |
1583 |     # recurse to fill all predictions
1584 |     apply_prediction(0)
1585 |
1586 |     return y_predicted
1587 |
1588 |
1589 | def prune_children(node_id, tree):
1590 |     """
1591 |     Prunes the children of node_id in the given `tree`.
1592 |
1593 |     Inspired by https://github.com/shenwanxiang/sklearn-post-prune-tree/blob/master/tree_prune.py#L122
1594 |
1595 |     IMPORTANT this relies on the fact that `children_left` and `children_right`
1596 |     are modifiable (and are the only things we need to modify to fix the
1597 |     tree). It seems to be the case as of now.
1598 |
1599 |     Parameters
1600 |     ----------
1601 |     node_id : int
1602 |     tree : Tree
1603 |
1604 |     Returns
1605 |     -------
1606 |     removed_nodes : List[int]
1607 |         A list of removed nodes
1608 |     """
1609 |
1610 |     def _prune_below(_node_id):
1611 |         left_child = tree.children_left[_node_id]
1612 |         right_child = tree.children_right[_node_id]
1613 |         if left_child == _tree.TREE_LEAF:
1614 |             # _node_id is a leaf: left_ & right_child say "leaf". Nothing to do
1615 |             return []
1616 |         else:
1617 |             # Make sure everything is pruned below: children should be leaves
1618 |             removed_l = _prune_below(left_child)
1619 |             removed_r = _prune_below(right_child)
1620 |
1621 |             # -- First declare that the children are not leaves anymore, and that they do not exist at all
1622 |             for child in [left_child, right_child]:
1623 |
1624 |                 if tree.children_left[child] != _tree.TREE_LEAF:
1625 |                     raise ValueError("Unexpected children_left: %s, please report it as an issue." % child)
1626 |
1627 |                 if tree.children_right[child] != _tree.TREE_LEAF:
1628 |                     raise ValueError("Unexpected children_right: %s, please report it as an issue." % child)
1629 |
1630 |                 tree.children_left[child] = _tree.TREE_UNDEFINED
1631 |                 tree.children_right[child] = _tree.TREE_UNDEFINED
1632 |
1633 |             # -- Then declare that the current node is a leaf
1634 |             tree.children_left[_node_id] = _tree.TREE_LEAF
1635 |             tree.children_right[_node_id] = _tree.TREE_LEAF
1636 |
1637 |             # Return the list of removed nodes
1638 |             return removed_l + removed_r + [left_child, right_child]
1639 |
1640 |     # Note: we do not change the node count here, as we'll clean all later.
1641 |     # true_node_count = tree.node_count - sum(tree.children_left == _tree.TREE_UNDEFINED)
1642 |     # tree.node_count -= 2*len(nodes_to_remove)
1643 |
1644 |     return _prune_below(node_id)
1645 |
1646 |
1647 | def is_leaf(node_id, tree):
1648 |     """
1649 |     Returns true if the node with id `node_id` is a leaf in tree `tree`.
1650 |     It is not actually used in this file because we always need the left child
1651 |     node id in our code.
1652 |
1653 |     But it is kept here as an easy way to remember how it works.
1654 |
1655 |     Parameters
1656 |     ----------
1657 |     node_id : int
1658 |     tree : Tree
1659 |
1660 |     Returns
1661 |     -------
1662 |     _is_leaf : bool
1663 |         A boolean flag, True if this node is a leaf.
1664 |     """
1665 |     if node_id == _tree.TREE_LEAF or node_id == _tree.TREE_UNDEFINED:
1666 |         raise ValueError("Invalid node_id %s" % node_id)
1667 |
1668 |     return tree.children_left[node_id] == _tree.TREE_LEAF
1669 |
1670 |
1671 | def get_samples_at_node(node_id, nodes_to_samples):
1672 |     """
1673 |     Return an array containing the ids of the samples for node `node_id`.
1674 |
1675 |     This method requires the user to
1676 |     - first compute the decision path for the sample matrix X
1677 |     - then convert it to CSC format
1678 |
1679 |     >>> samples_to_nodes = estimator.decision_path(X)  # returns a Scipy compressed sparse row matrix (CSR)
1680 |     >>> nodes_to_samples = samples_to_nodes.tocsc()  # we need the column equivalent (CSC)
1681 |     >>> samples = get_samples_at_node(node_id, nodes_to_samples)
1682 |
1683 |     Parameters
1684 |     ----------
1685 |     node_id : int
1686 |         The node for which the list of samples is queried.
1687 |     nodes_to_samples : csc_matrix
1688 |         A boolean matrix in Compressed Sparse Column (CSC) format where rows are
1689 |         the samples and columns are the tree nodes. The matrix contains a 1 when
1690 |         a sample goes through a node on its decision path.
1691 |
1692 |     Returns
1693 |     -------
1694 |     samples : array of int, the ids of the samples reaching node `node_id`
1695 |     """
1696 |     return nodes_to_samples.indices[nodes_to_samples.indptr[node_id] : nodes_to_samples.indptr[node_id + 1]]
1697 |
1698 |
1699 | class M5Prime(M5Base, RegressorMixin):
1700 |     """An M5' (M five prime) model tree regressor.
1701 |
1702 |     The original M5 algorithm was invented by R. Quinlan [1]_. Y. Wang [2]_ made
1703 |     improvements and named the resulting algorithm M5 Prime.
1704 |     This implementation was inspired by Weka (https://github.com/bnjmn/weka)
1705 |     M5Prime class, from Mark Hall.
1706 |
1707 |     See also
1708 |     --------
1709 |     M5Base
1710 |
1711 |     References
1712 |     ----------
1713 |     .. [1] Ross J. Quinlan, "Learning with Continuous Classes", 5th Australian
1714 |            Joint Conference on Artificial Intelligence, pp343-348, 1992.
1715 |     .. [2] Y. Wang and I. H. Witten, "Induction of model trees for predicting
1716 |            continuous classes", Poster papers of the 9th European Conference
1717 |            on Machine Learning, 1997.
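    Examples
    --------
    A minimal usage sketch (added by the editor; the data and settings are
    arbitrary illustrations, not from the original docstring):

    >>> import numpy as np
    >>> from m5py import M5Prime
    >>> rng = np.random.RandomState(0)
    >>> X = rng.uniform(size=(100, 2))
    >>> y = 3.0 * X[:, 0] - 2.0 * X[:, 1]
    >>> model = M5Prime(use_smoothing=False)
    >>> _ = model.fit(X, y)
    >>> model.predict(X).shape
    (100,)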
1718 |     """
1719 |
1720 |
1721 | def _empty_contig_ar(shape, dtype):
1722 |     """Return an empty contiguous array with given shape and dtype."""
1723 |     return np.ascontiguousarray(np.empty(shape, dtype=dtype))
1724 |
--------------------------------------------------------------------------------
/src/m5py/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smarie/python-m5p/8016d3cca3f263f56088fa8878d0f99c6fd81640/src/m5py/py.typed
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smarie/python-m5p/8016d3cca3f263f56088fa8878d0f99c6fd81640/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from sklearn.metrics import mean_squared_error
4 |
5 | from m5py import M5Prime
6 | from m5py.main import predict_from_leaves
7 |
8 |
9 | @pytest.mark.parametrize("smoothing_constant", [75, 100], ids="smoothing_constant={}".format)
10 | def test_m5p_smoothing(smoothing_constant):
11 |     """Tests that the M5P smoothing feature works correctly."""
12 |
13 |     # use a given random seed so as to enforce deterministic output
14 |     # (and 100 gives a particularly interesting result with constant leaves, that's why we use it)
15 |     np.random.seed(100)
16 |     debug_prints = False
17 |
18 |     x = np.random.uniform(-20, -2, (500, 1))
19 |     piece_1_mask = x < -12
20 |     piece_2_mask = ~piece_1_mask
21 |     piece_2_train_indices = np.argsort(x, axis=0)[-5:]  # last 5 samples
22 |     # training_mask = np.ma.mask_or(piece_1_mask, piece_2_train_mask)
23 |     training_mask = piece_1_mask.copy()
24 |     training_mask[piece_2_train_indices] = True
25 |     test_mask = ~training_mask
26 |     X_train = x[training_mask].reshape(-1, 1)
27 |     X_test = x[test_mask].reshape(-1, 1)
28 |
29 |     # create a y that is piecewise linear
30 |     def piece1(x):
31 |         return 12.3 * x + 52.1
32 |
33 |     y = piece1(x) + 5 * np.random.randn(*x.shape)
34 |
35 |     z = 0
36 |
37 |     def piece2(x):
38 |         return 5.3 * x + z
39 |
40 |     z = piece1(-12) - piece2(-12)
41 |
42 |     y[piece_2_mask] -= piece1(x[piece_2_mask])
43 |     y[piece_2_mask] += piece2(x[piece_2_mask])
44 |     y_train = y[training_mask].reshape(-1, 1)
45 |     y_test = y[test_mask].reshape(-1, 1)
46 |     # plt.plot(X_train, y_train, 'xb')
47 |     # plt.plot(X_test, y_test, 'xg')
48 |     # plt.draw()
49 |     # plt.pause(0.001)
50 |
51 |     # Fit M5P without smoothing
52 |     # print("Without smoothing: ")
53 |     model_no_smoothing = M5Prime(use_smoothing=False, debug_prints=False)
54 |     model_no_smoothing.fit(X_train, y_train)
55 |     # print(export_text_m5(model_no_smoothing, out_file=None, node_ids=True))
56 |
57 |     # Fit M5P with smoothing
58 |     # print("With smoothing: ")
59 |     model_smoothing = M5Prime(use_smoothing=True, debug_prints=debug_prints, smoothing_constant=smoothing_constant)
60 |     model_smoothing.fit(X_train, y_train)
61 |     # print(export_text_m5(model_smoothing, out_file=None, node_ids=True))
62 |
63 |     # Predict
64 |     # --no smoothing: compare the fast and slow methods
65 |     y_pred_no_smoothing = model_no_smoothing.predict(x).reshape(-1, 1)
66 |     y_pred_no_smoothing_slow = predict_from_leaves(model_no_smoothing, x, smoothing=False).reshape(-1, 1)
67 |
68 |     # --smoothing: compare pre-computed and late methods
69 |     y_pred_installed_smoothing = model_smoothing.predict(x).reshape(-1, 1)
70 |     y_pred_late_smoothing = model_no_smoothing.predict(
71 |         x, smooth_predictions=True, smoothing_constant=smoothing_constant
72 |     ).reshape(-1, 1)
73 |
74 |     # plt.plot(x, y_pred_no_smoothing, '.r', label="no smoothing")
75 |     # # plt.plot(x, y_pred_no_smoothing_slow, '.r', label="no smoothing slow")
76 |     # plt.plot(x, y_pred_late_smoothing, '.m', label="smoothing (on prediction)")
77 |     # plt.plot(x, y_pred_installed_smoothing, '.g', label="smoothing (installed models)")
78 |     # plt.legend()
79 |     # plt.draw()
80 |     # plt.pause(0.001)
81 |
82 |     # make sure both ways to smooth give the same output
83 |     np.testing.assert_array_equal(y_pred_no_smoothing, y_pred_no_smoothing_slow)
84 |     np.testing.assert_array_almost_equal(y_pred_late_smoothing, y_pred_installed_smoothing, decimal=4)
85 |
86 |     # compare performances without/with smoothing
87 |     mse_no_smoothing = mean_squared_error(y_test, y_pred_no_smoothing[test_mask])
88 |     mse_smoothing = mean_squared_error(y_test, y_pred_installed_smoothing[test_mask])
89 |     print("M5P MSE: %.4f (no smoothing) %.4f (smoothing)" % (mse_no_smoothing, mse_smoothing))
90 |
91 |     # simple assert: smoothing improves performance (in most cases but not always - that's why we fixed our random seed)
92 |     assert mse_smoothing < mse_no_smoothing
93 |
94 |     # plt.close('all')
95 |
96 |
97 | # def test_boston_housing():
98 | #     """
99 | #     A copy of the Scikit-Learn Gradient Boosting regression example, with the M5P in addition
100 | #
101 | #     http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regression.html
102 | #     :return:
103 | #     """
104 | #     import matplotlib.pyplot as plt
105 | #     plt.ion()  # comment to debug
106 | #
107 | #     # #############################################################################
108 | #     # Load data
109 | #     boston = datasets.load_boston()
110 | #     X, y = shuffle(boston.data, boston.target, random_state=13)
111 | #     X = X.astype(np.float32)
112 | #     offset = int(X.shape[0] * 0.9)
113 | #     X_train, y_train = X[:offset], y[:offset]
114 | #     X_test, y_test = X[offset:], y[offset:]
115 | #
116 | #     # #############################################################################
117 | #     # Fit regression model
118 | #     params = {'n_estimators': 500, 'max_depth': 4, 'min_samples_split': 2,
119 | #               'learning_rate': 0.01, 'loss': 'ls'}
120 | #     clf = ensemble.GradientBoostingRegressor(**params)
121 | #     clf.fit(X_train, y_train)
122 | #     y_predicted = clf.predict(X_test)
123 | #     mse = mean_squared_error(y_test, y_predicted)
124 | #     print("Gradient Boosting MSE: %.4f" % mse)
125 | #
126 | #     # #############################################################################
127 | #     # Plot predictions
128 | #     plt.figure(figsize=(18, 12))
129 | #     plt.subplot(2, 3, 1)
130 | #     plt.title('Predictions on test set (RMSE = {:2f})'.format(np.sqrt(mse)))
131 | #     plt.plot(y_test, y_predicted, '.')
132 | #     plt.xlabel('true y')
133 | #     plt.ylabel('predicted_y')
134 | #
135 | #     # #############################################################################
136 | #     # Plot training deviance
137 | #
138 | #     # compute test set deviance
139 | #     test_score = np.zeros((params['n_estimators'],), dtype=np.float64)
140 | #
141 | #     for i, y_pred in enumerate(clf.staged_predict(X_test)):
142 | #         test_score[i] = clf.loss_(y_test, y_pred)
143 | #
144 | #     plt.subplot(2, 3, 2)
145 | #     plt.title('Deviance')
146 | #     plt.plot(np.arange(params['n_estimators']) + 1, clf.train_score_, 'b-',
147 | #              label='Training Set Deviance')
148 | # plt.plot(np.arange(params['n_estimators']) + 1, test_score, 'r-', 149 | # label='Test Set Deviance') 150 | # plt.legend(loc='upper right') 151 | # plt.xlabel('Boosting Iterations') 152 | # plt.ylabel('Deviance') 153 | # 154 | # # ############################################################################# 155 | # # Plot feature importance 156 | # feature_importance = clf.feature_importances_ 157 | # # make importances relative to max importance 158 | # feature_importance = 100.0 * (feature_importance / feature_importance.max()) 159 | # sorted_idx = np.argsort(feature_importance) 160 | # pos = np.arange(sorted_idx.shape[0]) + .5 161 | # plt.subplot(2, 3, 3) 162 | # plt.barh(pos, feature_importance[sorted_idx], align='center') 163 | # plt.yticks(pos, boston.feature_names[sorted_idx]) 164 | # plt.xlabel('Relative Importance') 165 | # plt.title('Variable Importance') 166 | # 167 | # # ------- M5P 168 | # # ############################################################################# 169 | # # Fit regression model 170 | # params = {} 171 | # clf = M5Prime(**params) 172 | # clf.fit(X_train, y_train) 173 | # 174 | # # Print the tree 175 | # print(export_text_m5(clf, out_file=None)) 176 | # print(export_text_m5(clf, out_file=None, feature_names=boston.feature_names)) 177 | # 178 | # # Predict 179 | # y_predicted = clf.predict(X_test) 180 | # mse = mean_squared_error(y_test, y_predicted) 181 | # print("M5P MSE: %.4f" % mse) 182 | # 183 | # # ############################################################################# 184 | # # Plot predictions 185 | # plt.subplot(2, 3, 4) 186 | # plt.title('Predictions on test set (RMSE = {:2f})'.format(np.sqrt(mse))) 187 | # plt.plot(y_test, y_predicted, '.') 188 | # plt.xlabel('true y') 189 | # plt.ylabel('predicted_y') 190 | # 191 | # # Compress the tree (features-wise) 192 | # idx = clf.compress_features() 193 | # # Print the tree 194 | # print(export_text_m5(clf, out_file=None)) 195 | # print(export_text_m5(clf, out_file=None, feature_names=boston.feature_names)) 196 | # 197 | # # Predict 198 | # y_predicted2 = clf.predict(X_test[:, idx]) 199 | # mse2 = mean_squared_error(y_test, y_predicted2) 200 | # print("M5P2 MSE: %.4f" % mse) 201 | # 202 | # # ############################################################################# 203 | # # Plot predictions 204 | # plt.subplot(2, 3, 5) 205 | # plt.title('Predictions on test set (RMSE = {:2f})'.format(np.sqrt(mse2))) 206 | # plt.plot(y_test, y_predicted2, '.') 207 | # plt.xlabel('true y') 208 | # plt.ylabel('predicted_y') 209 | # 210 | # # ############################################################################# 211 | # # Plot feature importance 212 | # feature_importance = clf.feature_importances_ 213 | # # make importances relative to max importance 214 | # feature_importance = 100.0 * (feature_importance / feature_importance.max()) 215 | # sorted_idx = np.argsort(feature_importance) 216 | # pos = np.arange(sorted_idx.shape[0]) + .5 217 | # plt.subplot(2, 3, 6) 218 | # plt.barh(pos, feature_importance[sorted_idx], align='center') 219 | # # do not forget that we now work on reindexed features 220 | # plt.yticks(pos, boston.feature_names[idx][sorted_idx]) 221 | # plt.xlabel('Relative Importance') 222 | # plt.title('Variable Importance') 223 | # 224 | # plt.close('all') 225 | # 226 | # 227 | # @pytest.mark.parametrize("use_smoothing", [None, # None (default) = True = 'installed' 228 | # False, 'on_prediction']) 229 | # def test_default_smoothing_modes(use_smoothing): 230 | # """ Tests with default 
constant/ratio, depending on smoothing mode """ 231 | # 232 | # half_x = np.random.random(25) 233 | # X = np.r_[half_x, -half_x] 234 | # Y = np.r_[2.8 * half_x + 5, -0.2 * half_x + 5] 235 | # 236 | # X = X.reshape(-1, 1) 237 | # Y = Y.reshape(-1, 1) 238 | # 239 | # regr = M5Prime(use_smoothing=use_smoothing) 240 | # regr.fit(X, Y) 241 | # print(export_text_m5(regr, out_file=None, node_ids=True)) 242 | # preds = regr.predict(X) 243 | # 244 | # import matplotlib.pyplot as plt 245 | # plt.ion() # comment to debug 246 | # plt.plot(X, Y, 'x', label='true') 247 | # plt.plot(X, preds, 'x', label='prediction') 248 | # plt.close('all') 249 | --------------------------------------------------------------------------------
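Editor's addendum: a self-contained sketch of the leaf-to-root smoothing recursion implemented in `predict_from_leaves` in src/m5py/main.py above. The helper name and the example numbers are ours; `k=15` is the default taken from that function's docstring.

def smooth_along_path(leaf_prediction, path, k=15):
    """Blend a leaf prediction with its ancestors' predictions, bottom-up.

    `path` lists (n_train_samples_at_child_node, parent_prediction) pairs
    ordered from the leaf upwards, mirroring predict_from_leaves:
    y' = (n * y + k * y_parent) / (n + k) at each step.
    """
    y = leaf_prediction
    for n_train, parent_pred in path:
        y = (n_train * y + k * parent_pred) / (n_train + k)
    return y

# A leaf trained on 5 samples predicts 10.0; its parent node (50 training
# samples) predicts 6.0, and the root predicts 4.0 for the same sample:
print(smooth_along_path(10.0, [(5, 6.0), (50, 4.0)]))  # ~ 6.308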