├── .github
│   └── workflows
│       └── base.yml
├── .gitignore
├── .zenodo.json
├── CITATION.cff
├── LICENSE
├── README.md
├── ci_tools
│   ├── .pylintrc
│   ├── check_python_version.py
│   ├── flake8-requirements.txt
│   ├── github_release.py
│   └── nox_utils.py
├── docs
│   ├── changelog.md
│   ├── examples
│   │   ├── 1_simple_1D_demo.py
│   │   ├── 2_pw_linear_demo.py
│   │   └── Readme.md
│   ├── index.md
│   └── long_description.md
├── mkdocs.yml
├── noxfile-requirements.txt
├── noxfile.py
├── pyproject.toml
├── setup.cfg
├── setup.py
├── src
│   └── m5py
│       ├── __init__.py
│       ├── export.py
│       ├── linreg_utils.py
│       ├── main.py
│       └── py.typed
└── tests
    ├── __init__.py
    └── test_main.py
/.github/workflows/base.yml:
--------------------------------------------------------------------------------
1 | # .github/workflows/base.yml
2 | name: Build
3 | on:
4 |   # this one is to trigger the workflow manually from the interface
5 |   workflow_dispatch:
6 |
7 |   push:
8 |     tags:
9 |       - '*'
10 |     branches:
11 |       - main
12 |   pull_request:
13 |     branches:
14 |       - main
15 | jobs:
16 |   # pre-job to read nox tests matrix - see https://stackoverflow.com/q/66747359/7262247
17 |   list_nox_test_sessions:
18 |     runs-on: ubuntu-latest
19 |     steps:
20 |       - uses: actions/checkout@v2
21 |       - uses: actions/setup-python@v1
22 |         with:
23 |           python-version: 3.7
24 |           architecture: x64
25 |
26 |       - name: Install noxfile requirements
27 |         shell: bash -l {0}
28 |         run: pip install -r noxfile-requirements.txt
29 |
30 |       - name: List 'tests' nox sessions
31 |         id: set-matrix
32 |         run: echo "::set-output name=matrix::$(nox -s gha_list -- tests)"
33 |     outputs:
34 |       matrix: ${{ steps.set-matrix.outputs.matrix }} # save nox sessions list to outputs
35 |
36 |   run_all_tests:
37 |     needs: list_nox_test_sessions
38 |     strategy:
39 |       fail-fast: false
40 |       matrix:
41 |         os: [ ubuntu-latest ] # , macos-latest, windows-latest]
42 |         # all nox sessions: manually > dynamically from previous job
43 |         # nox_session: ["tests-2.7", "tests-3.7"]
44 |         nox_session: ${{ fromJson(needs.list_nox_test_sessions.outputs.matrix) }}
45 |
46 |     name: ${{ matrix.os }} ${{ matrix.nox_session }} # ${{ matrix.name_suffix }}
47 |     runs-on: ${{ matrix.os }}
48 |     steps:
49 |       - uses: actions/checkout@v2
50 |
51 |       # Conda install
52 |       - name: Install conda v3.7
53 |         uses: conda-incubator/setup-miniconda@v2
54 |         with:
55 |           # auto-update-conda: true
56 |           python-version: 3.7
57 |           activate-environment: noxenv
58 |       - run: conda info
59 |         shell: bash -l {0} # so that conda works
60 |       - run: conda list
61 |         shell: bash -l {0} # so that conda works
62 |
63 |       # Nox install + run
64 |       - name: Install noxfile requirements
65 |         shell: bash -l {0} # so that conda works
66 |         run: pip install -r noxfile-requirements.txt
67 |       - run: conda list
68 |         shell: bash -l {0} # so that conda works
69 |       - run: nox -s "${{ matrix.nox_session }}"
70 |         shell: bash -l {0} # so that conda works
71 |
72 |       # Share ./docs/reports so that they can be deployed with doc in next job
73 |       - name: Share reports with other jobs
74 |         # if: matrix.nox_session == '...': not needed, if empty wont be shared
75 |         uses: actions/upload-artifact@master
76 |         with:
77 |           name: reports_dir
78 |           path: ./docs/reports
79 |
80 |   publish_release:
81 |     needs: run_all_tests
82 |     runs-on: ubuntu-latest
83 |     if: github.event_name == 'push'
84 |     steps:
85 |       - name: GitHub context to debug conditional steps
86 |         env:
87 |           GITHUB_CONTEXT: ${{ toJSON(github) }}
88 |         run: echo "$GITHUB_CONTEXT"
89 |
90 |       - uses: actions/checkout@v2
91 |         with:
92 |           fetch-depth: 0 # so that gh-deploy works
93 |
94 |       # 1) retrieve the reports generated previously
95 |       - name: Retrieve reports
96 |         uses: actions/download-artifact@master
97 |         with:
98 |           name: reports_dir
99 |           path: ./docs/reports
100 |
101 |       # Conda install
102 |       - name: Install conda v3.7
103 |         uses: conda-incubator/setup-miniconda@v2
104 |         with:
105 |           # auto-update-conda: true
106 |           python-version: 3.7
107 |           activate-environment: noxenv
108 |       - run: conda info
109 |         shell: bash -l {0} # so that conda works
110 |       - run: conda list
111 |         shell: bash -l {0} # so that conda works
112 |
113 |       # Nox install
114 |       - name: Install noxfile requirements
115 |         shell: bash -l {0} # so that conda works
116 |         run: pip install -r noxfile-requirements.txt
117 |       - run: conda list
118 |         shell: bash -l {0} # so that conda works
119 |
120 |       # 5) Run the flake8 report and badge
121 |       - name: Run flake8 analysis and generate corresponding badge
122 |         shell: bash -l {0} # so that conda works
123 |         run: nox -s flake8
124 |
125 |       # -------------- only on Ubuntu + MAIN PUSH (no pull request, no tag) -----------
126 |
127 |       # 5) Publish the doc and test reports
128 |       - name: \[not on TAG\] Publish documentation, tests and coverage reports
129 |         if: github.event_name == 'push' && startsWith(github.ref, 'refs/heads') # startsWith(matrix.os,'ubuntu')
130 |         shell: bash -l {0} # so that conda works
131 |         run: nox -s publish
132 |
133 |       # 6) Publish coverage report
134 |       - name: \[not on TAG\] Create codecov.yaml with correct paths
135 |         if: github.event_name == 'push' && startsWith(github.ref, 'refs/heads')
136 |         shell: bash
137 |         run: |
138 |           cat << EOF > codecov.yml
139 |           # codecov.yml
140 |           fixes:
141 |             - "/home/runner/work/smarie/python-m5p/::" # Correct paths
142 |           EOF
143 |       - name: \[not on TAG\] Publish coverage report
144 |         if: github.event_name == 'push' && startsWith(github.ref, 'refs/heads')
145 |         uses: codecov/codecov-action@v1
146 |         with:
147 |           files: ./docs/reports/coverage/coverage.xml
148 |
149 |       # -------------- only on Ubuntu + TAG PUSH (no pull request) -----------
150 |
151 |       # 7) Create github release and build the wheel
152 |       - name: \[TAG only\] Build wheel and create github release
153 |         if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
154 |         shell: bash -l {0} # so that conda works
155 |         run: nox -s release -- ${{ secrets.GITHUB_TOKEN }}
156 |
157 |       # 8) Publish the wheel on PyPi
158 |       - name: \[TAG only\] Deploy on PyPi
159 |         if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
160 |         uses: pypa/gh-action-pypi-publish@release/v1
161 |         with:
162 |           user: __token__
163 |           password: ${{ secrets.PYPI_API_TOKEN }}
164 |
165 |   delete-artifacts:
166 |     needs: publish_release
167 |     runs-on: ubuntu-latest
168 |     if: github.event_name == 'push'
169 |     steps:
170 |       - uses: kolpav/purge-artifacts-action@v1
171 |         with:
172 |           token: ${{ secrets.GITHUB_TOKEN }}
173 |           expire-in: 0 # Setting this to 0 will delete all artifacts
174 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | src/*/_version.py
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv*/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | # PyCharm development
133 | /.idea
134 |
135 | # OSX
136 | .DS_Store
137 |
138 | # JUnit and coverage reports
139 | docs/reports
140 |
141 | # ODSClient cache
142 | .odsclient
143 |
--------------------------------------------------------------------------------
/.zenodo.json:
--------------------------------------------------------------------------------
1 | {
2 | "title": "m5py",
3 | "description": "
An implementation of M5 (Prime) and model trees for scikit-learn.
",
4 | "language": "eng",
5 | "license": {
6 | "id": "bsd-license"
7 | },
8 | "keywords": [
9 | "python",
10 | "model",
11 | "tree",
12 | "regression",
13 | "M5",
14 | "prime",
15 | "scikit-learn",
16 | "machine learning"
17 | ],
18 | "creators": [
19 | {
20 | "orcid": "0000-0002-5929-1047",
21 | "affiliation": "Schneider Electric",
22 | "name": "Sylvain Mari\u00e9"
23 | },
24 | {
25 | "name": "Various github contributors"
26 | }
27 | ]
28 | }
29 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | # This CITATION.cff file was generated with cffinit.
2 | # Visit https://bit.ly/cffinit to generate yours today!
3 |
4 | cff-version: 1.2.0
5 | title: m5py
6 | message: 'If you use this software, please cite it as below.'
7 | type: software
8 | authors:
9 | - family-names: Marié
10 | given-names: Sylvain
11 | orcid: 'https://orcid.org/0000-0002-5929-1047'
12 | identifiers:
13 | - type: doi
14 | value: 10.5281/zenodo.10552219
15 | description: Software (Zenodo)
16 | repository-code: 'https://github.com/smarie/python-m5p'
17 | url: 'https://smarie.github.io/python-m5p/'
18 | repository-artifact: 'https://pypi.org/project/m5py/'
19 | abstract: >-
20 | An implementation of M5 (Prime) and model trees for
21 | scikit-learn.
22 | keywords:
23 | - python
24 | - model
25 | - tree
26 | - regression
27 | - M5
28 | - prime
29 | - scikit-learn
30 | - machine learning
31 | license: BSD-3-Clause
32 | doi: 10.5281/zenodo.10552219
33 | preferred-citation:
34 | type: conference
35 | url: 'https://hal.science/hal-03762155/'
36 | authors:
37 | - family-names: Marié
38 | given-names: Sylvain
39 | orcid: 'https://orcid.org/0000-0002-5929-1047'
40 | title: >-
41 | `python-m5p` - M5 Prime regression trees in python, compliant with
42 | scikit-learn
43 | conference:
44 | name: "PyCon.DE & PyData"
45 | city: "Berlin"
46 | country: "Germany"
47 | date-end: "2022-04-13"
48 | date-start: "2022-04-11"
49 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2016-2022, Sylvain Marié, Schneider Electric Industries
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # `m5py`
2 |
3 | *`scikit-learn`-compliant M5 / M5' model trees for python*
4 |
5 | [](https://pypi.python.org/pypi/m5py/) [](https://github.com/smarie/python-m5p/actions/workflows/base.yml) [](./reports/junit/report.html) [](./reports/coverage/index.html) [](https://codecov.io/gh/smarie/python-m5p) [](./reports/flake8/index.html)
6 |
7 | [](https://smarie.github.io/python-m5p/) [](https://pypi.python.org/pypi/m5py/) [](https://pepy.tech/project/m5py) [](https://pepy.tech/project/m5py) [](https://github.com/smarie/python-m5p/stargazers)
8 |
9 | **This is the readme for developers.** The documentation for users is available here: [https://smarie.github.io/python-m5p/](https://smarie.github.io/python-m5p/)
10 |
11 | ## Want to contribute?
12 |
13 | Contributions are welcome! Simply fork this project on github, commit your contributions, and create pull requests.
14 |
15 | Here is a non-exhaustive list of interesting open topics: [https://github.com/smarie/python-m5p/issues](https://github.com/smarie/python-m5p/issues)
16 |
17 | ## `nox` setup
18 |
19 | This project uses `nox` to define all lifecycle tasks. In order to be able to run those tasks, you should create a python 3.7 environment and install the requirements:
20 |
21 | ```bash
22 | >>> conda create -n noxenv python="3.7"
23 | >>> activate noxenv
24 | (noxenv) >>> pip install -r noxfile-requirements.txt
25 | ```
26 |
27 | You should then be able to list all available tasks using:
28 |
29 | ```
30 | >>> nox --list
31 | Sessions defined in \noxfile.py:
32 |
33 | * tests-2.7 -> Run the test suite, including test reports generation and coverage reports.
34 | * tests-3.5 -> Run the test suite, including test reports generation and coverage reports.
35 | * tests-3.6 -> Run the test suite, including test reports generation and coverage reports.
36 | * tests-3.8 -> Run the test suite, including test reports generation and coverage reports.
37 | * tests-3.7 -> Run the test suite, including test reports generation and coverage reports.
38 | - docs-3.7 -> Generates the doc and serves it on a local http server. Pass '-- build' to build statically instead.
39 | - publish-3.7 -> Deploy the docs+reports on github pages. Note: this rebuilds the docs
40 | - release-3.7 -> Create a release on github corresponding to the latest tag
41 | ```
42 |
43 | ## Running the tests and generating the reports
44 |
45 | This project uses `pytest`, so running `pytest` at the root folder will execute all tests on the current environment. However, it is a bit cumbersome to manage all requirements by hand; it is easier to use `nox` to run `pytest` on all supported python environments with the correct package requirements:
46 |
47 | ```bash
48 | nox
49 | ```
50 |
51 | Tests and coverage reports are automatically generated under `./docs/reports` for one of the sessions (`tests-3.7`).
52 |
53 | If you wish to execute tests on a specific environment, use explicit session names, e.g. `nox -s tests-3.6`.
54 |
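For instance, here is a minimal sketch combining both (only the `tests-3.7` session generates the reports mentioned above):

```bash
>>> nox -s tests-3.7
>>> ls docs/reports
```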
55 |
56 | ## Editing the documentation
57 |
58 | This project uses `mkdocs` to generate its documentation page. Therefore building a local copy of the doc page may be done using `mkdocs build -f docs/mkdocs.yml`. However once again things are easier with `nox`. You can easily build and serve locally a version of the documentation site using:
59 |
60 | ```bash
61 | >>> nox -s docs
62 | nox > Running session docs-3.7
63 | nox > Creating conda env in .nox\docs-3-7 with python=3.7
64 | nox > [docs] Installing requirements with pip: ['mkdocs-material', 'mkdocs', 'pymdown-extensions', 'pygments']
65 | nox > python -m pip install mkdocs-material mkdocs pymdown-extensions pygments
66 | nox > mkdocs serve -f ./docs/mkdocs.yml
67 | INFO - Building documentation...
68 | INFO - Cleaning site directory
69 | INFO - The following pages exist in the docs directory, but are not included in the "nav" configuration:
70 | - long_description.md
71 | INFO - Documentation built in 1.07 seconds
72 | INFO - Serving on http://127.0.0.1:8000
73 | INFO - Start watching changes
74 | ...
75 | ```
76 |
77 | While this is running, you can edit the files under `./docs/` and browse the automatically refreshed documentation at the local [http://127.0.0.1:8000](http://127.0.0.1:8000) page.
78 |
79 | Once you are done, simply hit `CTRL+C` to stop the session.
80 |
81 | Publishing the documentation (including tests and coverage reports) is done automatically by [the continuous integration engine](https://github.com/smarie/python-m5p/actions), using the `nox -s publish` session; this is not needed for local development.
82 |
83 | ## Packaging
84 |
85 | This project uses `setuptools_scm` to synchronise the version number. Therefore the following command should be used for development snapshots as well as official releases: `python setup.py sdist bdist_wheel`. However this is not generally needed since [the continuous integration engine](https://github.com/smarie/python-m5p/actions) does it automatically for us on git tags. For reference, this is done in the `nox -s release` session.
86 |
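For reference, here is a minimal local build sketch (assuming `setuptools_scm` and `wheel` are available in the active environment; the version number is derived from the latest git tag):

```bash
>>> python setup.py sdist bdist_wheel
>>> ls dist/
```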
87 | ### Merging pull requests with edits - memo
88 |
89 | As explained in github ('get commandline instructions'):
90 |
91 | ```bash
92 | git checkout -b <contributor>-<branch> main
93 | git pull https://github.com/<contributor>/python-m5p.git --no-commit --ff-only
94 | ```
95 |
96 | if the second step does not work, do a normal auto-merge (do not use **rebase**!):
97 |
98 | ```bash
99 | git pull https://github.com/<contributor>/python-m5p.git --no-commit
100 | ```
101 |
102 | Finally review the changes, possibly perform some modifications, and commit.
103 |
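For example (a sketch only; the commit message is a placeholder to adapt):

```bash
>>> git status
>>> git commit -m "Merge pull request #<number> with edits"
>>> git push
```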
--------------------------------------------------------------------------------
/ci_tools/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 |
3 | # Specify a configuration file.
4 | #rcfile=
5 |
6 | # Python code to execute, usually for sys.path manipulation such as
7 | # pygtk.require().
8 | # init-hook="import m5py"
9 |
10 | # Add files or directories to the blacklist. They should be base names, not
11 | # paths.
12 | ignore=
13 |
14 | # Pickle collected data for later comparisons.
15 | persistent=no
16 |
17 | # List of plugins (as comma separated values of python modules names) to load,
18 | # usually to register additional checkers.
19 | load-plugins=
20 |
21 | # Use multiple processes to speed up Pylint.
22 | # DO NOT CHANGE THIS: VALUES >1 HIDE RESULTS!!!!!
23 | jobs=1
24 |
25 | # Allow loading of arbitrary C extensions. Extensions are imported into the
26 | # active Python interpreter and may run arbitrary code.
27 | unsafe-load-any-extension=no
28 |
29 | # A comma-separated list of package or module names from where C extensions may
30 | # be loaded. Extensions are loading into the active Python interpreter and may
31 | # run arbitrary code
32 | extension-pkg-whitelist=
33 |
34 | # Allow optimization of some AST trees. This will activate a peephole AST
35 | # optimizer, which will apply various small optimizations. For instance, it can
36 | # be used to obtain the result of joining multiple strings with the addition
37 | # operator. Joining a lot of strings can lead to a maximum recursion error in
38 | # Pylint and this flag can prevent that. It has one side effect, the resulting
39 | # AST will be different than the one from reality.
40 | optimize-ast=no
41 |
42 |
43 | [MESSAGES CONTROL]
44 |
45 | # Only show warnings with the listed confidence levels. Leave empty to show
46 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
47 | confidence=
48 |
49 | # Enable the message, report, category or checker with the given id(s). You can
50 | # either give multiple identifier separated by comma (,) or put this option
51 | # multiple time. See also the "--disable" option for examples.
52 | disable=all
53 |
54 | enable=import-error,
55 | import-self,
56 | reimported,
57 | wildcard-import,
58 | misplaced-future,
59 | relative-import,
60 | deprecated-module,
61 | unpacking-non-sequence,
62 | invalid-all-object,
63 | undefined-all-variable,
64 | used-before-assignment,
65 | cell-var-from-loop,
66 | global-variable-undefined,
67 | redefined-builtin,
68 | redefine-in-handler,
69 | unused-import,
70 | unused-wildcard-import,
71 | global-variable-not-assigned,
72 | undefined-loop-variable,
73 | global-statement,
74 | global-at-module-level,
75 | bad-open-mode,
76 | redundant-unittest-assert,
77 | boolean-datetime,
78 | # Has common issues with our style due to
79 | # https://github.com/PyCQA/pylint/issues/210
80 | unused-variable
81 |
82 | # Things we'd like to enable someday:
83 | # redefined-outer-name (requires a bunch of work to clean up our code first)
84 | # undefined-variable (re-enable when pylint fixes https://github.com/PyCQA/pylint/issues/760)
85 | # no-name-in-module (giving us spurious warnings https://github.com/PyCQA/pylint/issues/73)
86 | # unused-argument (need to clean up or code a lot, e.g. prefix unused_?)
87 |
88 | # Things we'd like to try.
89 | # Procedure:
90 | # 1. Enable a bunch.
91 | # 2. See if there's spurious ones; if so disable.
92 | # 3. Record above.
93 | # 4. Remove from this list.
94 | # deprecated-method,
95 | # anomalous-unicode-escape-in-string,
96 | # anomalous-backslash-in-string,
97 | # not-in-loop,
98 | # function-redefined,
99 | # continue-in-finally,
100 | # abstract-class-instantiated,
101 | # star-needs-assignment-target,
102 | # duplicate-argument-name,
103 | # return-in-init,
104 | # too-many-star-expressions,
105 | # nonlocal-and-global,
106 | # return-outside-function,
107 | # return-arg-in-generator,
108 | # invalid-star-assignment-target,
109 | # bad-reversed-sequence,
110 | # nonexistent-operator,
111 | # yield-outside-function,
112 | # init-is-generator,
113 | # nonlocal-without-binding,
114 | # lost-exception,
115 | # assert-on-tuple,
116 | # dangerous-default-value,
117 | # duplicate-key,
118 | # useless-else-on-loop,
119 | # expression-not-assigned,
120 | # confusing-with-statement,
121 | # unnecessary-lambda,
122 | # pointless-statement,
123 | # pointless-string-statement,
124 | # unnecessary-pass,
125 | # unreachable,
126 | # eval-used,
127 | # exec-used,
128 | # bad-builtin,
129 | # using-constant-test,
130 | # deprecated-lambda,
131 | # bad-super-call,
132 | # missing-super-argument,
133 | # slots-on-old-class,
134 | # super-on-old-class,
135 | # property-on-old-class,
136 | # not-an-iterable,
137 | # not-a-mapping,
138 | # format-needs-mapping,
139 | # truncated-format-string,
140 | # missing-format-string-key,
141 | # mixed-format-string,
142 | # too-few-format-args,
143 | # bad-str-strip-call,
144 | # too-many-format-args,
145 | # bad-format-character,
146 | # format-combined-specification,
147 | # bad-format-string-key,
148 | # bad-format-string,
149 | # missing-format-attribute,
150 | # missing-format-argument-key,
151 | # unused-format-string-argument,
152 | # unused-format-string-key,
153 | # invalid-format-index,
154 | # bad-indentation,
155 | # mixed-indentation,
156 | # unnecessary-semicolon,
157 | # lowercase-l-suffix,
158 | # fixme,
159 | # invalid-encoded-data,
160 | # unpacking-in-except,
161 | # import-star-module-level,
162 | # parameter-unpacking,
163 | # long-suffix,
164 | # old-octal-literal,
165 | # old-ne-operator,
166 | # backtick,
167 | # old-raise-syntax,
168 | # print-statement,
169 | # metaclass-assignment,
170 | # next-method-called,
171 | # dict-iter-method,
172 | # dict-view-method,
173 | # indexing-exception,
174 | # raising-string,
175 | # standarderror-builtin,
176 | # using-cmp-argument,
177 | # cmp-method,
178 | # coerce-method,
179 | # delslice-method,
180 | # getslice-method,
181 | # hex-method,
182 | # nonzero-method,
183 | # oct-method,
184 | # setslice-method,
185 | # apply-builtin,
186 | # basestring-builtin,
187 | # buffer-builtin,
188 | # cmp-builtin,
189 | # coerce-builtin,
190 | # old-division,
191 | # execfile-builtin,
192 | # file-builtin,
193 | # filter-builtin-not-iterating,
194 | # no-absolute-import,
195 | # input-builtin,
196 | # intern-builtin,
197 | # long-builtin,
198 | # map-builtin-not-iterating,
199 | # range-builtin-not-iterating,
200 | # raw_input-builtin,
201 | # reduce-builtin,
202 | # reload-builtin,
203 | # round-builtin,
204 | # unichr-builtin,
205 | # unicode-builtin,
206 | # xrange-builtin,
207 | # zip-builtin-not-iterating,
208 | # logging-format-truncated,
209 | # logging-too-few-args,
210 | # logging-too-many-args,
211 | # logging-unsupported-format,
212 | # logging-not-lazy,
213 | # logging-format-interpolation,
214 | # invalid-unary-operand-type,
215 | # unsupported-binary-operation,
216 | # no-member,
217 | # not-callable,
218 | # redundant-keyword-arg,
219 | # assignment-from-no-return,
220 | # assignment-from-none,
221 | # not-context-manager,
222 | # repeated-keyword,
223 | # missing-kwoa,
224 | # no-value-for-parameter,
225 | # invalid-sequence-index,
226 | # invalid-slice-index,
227 | # too-many-function-args,
228 | # unexpected-keyword-arg,
229 | # unsupported-membership-test,
230 | # unsubscriptable-object,
231 | # access-member-before-definition,
232 | # method-hidden,
233 | # assigning-non-slot,
234 | # duplicate-bases,
235 | # inconsistent-mro,
236 | # inherit-non-class,
237 | # invalid-slots,
238 | # invalid-slots-object,
239 | # no-method-argument,
240 | # no-self-argument,
241 | # unexpected-special-method-signature,
242 | # non-iterator-returned,
243 | # protected-access,
244 | # arguments-differ,
245 | # attribute-defined-outside-init,
246 | # no-init,
247 | # abstract-method,
248 | # signature-differs,
249 | # bad-staticmethod-argument,
250 | # non-parent-init-called,
251 | # super-init-not-called,
252 | # bad-except-order,
253 | # catching-non-exception,
254 | # bad-exception-context,
255 | # notimplemented-raised,
256 | # raising-bad-type,
257 | # raising-non-exception,
258 | # misplaced-bare-raise,
259 | # duplicate-except,
260 | # broad-except,
261 | # nonstandard-exception,
262 | # binary-op-exception,
263 | # bare-except,
264 | # not-async-context-manager,
265 | # yield-inside-async-function,
266 |
267 | # ...
268 | [REPORTS]
269 |
270 | # Set the output format. Available formats are text, parseable, colorized, msvs
271 | # (visual studio) and html. You can also give a reporter class, eg
272 | # mypackage.mymodule.MyReporterClass.
273 | output-format=parseable
274 |
275 | # Put messages in a separate file for each module / package specified on the
276 | # command line instead of printing them on stdout. Reports (if any) will be
277 | # written in a file name "pylint_global.[txt|html]".
278 | files-output=no
279 |
280 | # Tells whether to display a full report or only the messages
281 | reports=no
282 |
283 | # Python expression which should return a note less than 10 (10 is the highest
284 | # note). You have access to the variables errors warning, statement which
285 | # respectively contain the number of errors / warnings messages and the total
286 | # number of statements analyzed. This is used by the global evaluation report
287 | # (RP0004).
288 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
289 |
290 | # Template used to display messages. This is a python new-style format string
291 | # used to format the message information. See doc for all details
292 | #msg-template=
293 |
294 |
295 | [LOGGING]
296 |
297 | # Logging modules to check that the string format arguments are in logging
298 | # function parameter format
299 | logging-modules=logging
300 |
301 |
302 | [FORMAT]
303 |
304 | # Maximum number of characters on a single line.
305 | max-line-length=100
306 |
307 | # Regexp for a line that is allowed to be longer than the limit.
308 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$
309 |
310 | # Allow the body of an if to be on the same line as the test if there is no
311 | # else.
312 | single-line-if-stmt=no
313 |
314 | # List of optional constructs for which whitespace checking is disabled. `dict-
315 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
316 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
317 | # `empty-line` allows space-only lines.
318 | no-space-check=trailing-comma,dict-separator
319 |
320 | # Maximum number of lines in a module
321 | max-module-lines=1000
322 |
323 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
324 | # tab).
325 | indent-string=' '
326 |
327 | # Number of spaces of indent required inside a hanging or continued line.
328 | indent-after-paren=4
329 |
330 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
331 | expected-line-ending-format=
332 |
333 |
334 | [TYPECHECK]
335 |
336 | # Tells whether missing members accessed in mixin class should be ignored. A
337 | # mixin class is detected if its name ends with "mixin" (case insensitive).
338 | ignore-mixin-members=yes
339 |
340 | # List of module names for which member attributes should not be checked
341 | # (useful for modules/projects where namespaces are manipulated during runtime
342 | # and thus existing member attributes cannot be deduced by static analysis. It
343 | # supports qualified module names, as well as Unix pattern matching.
344 | ignored-modules=
345 |
346 | # List of classes names for which member attributes should not be checked
347 | # (useful for classes with attributes dynamically set). This supports can work
348 | # with qualified names.
349 | ignored-classes=
350 |
351 | # List of members which are set dynamically and missed by pylint inference
352 | # system, and so shouldn't trigger E1101 when accessed. Python regular
353 | # expressions are accepted.
354 | generated-members=
355 |
356 |
357 | [VARIABLES]
358 |
359 | # Tells whether we should check for unused import in __init__ files.
360 | init-import=no
361 |
362 | # A regular expression matching the name of dummy variables (i.e. expectedly
363 | # not used).
364 | dummy-variables-rgx=^_|^dummy
365 |
366 | # List of additional names supposed to be defined in builtins. Remember that
367 | # you should avoid to define new builtins when possible.
368 | additional-builtins=
369 |
370 | # List of strings which can identify a callback function by name. A callback
371 | # name must start or end with one of those strings.
372 | callbacks=cb_,_cb
373 |
374 |
375 | [SIMILARITIES]
376 |
377 | # Minimum lines number of a similarity.
378 | min-similarity-lines=4
379 |
380 | # Ignore comments when computing similarities.
381 | ignore-comments=yes
382 |
383 | # Ignore docstrings when computing similarities.
384 | ignore-docstrings=yes
385 |
386 | # Ignore imports when computing similarities.
387 | ignore-imports=no
388 |
389 |
390 | [SPELLING]
391 |
392 | # Spelling dictionary name. Available dictionaries: none. To make it working
393 | # install python-enchant package.
394 | spelling-dict=
395 |
396 | # List of comma separated words that should not be checked.
397 | spelling-ignore-words=
398 |
399 | # A path to a file that contains private dictionary; one word per line.
400 | spelling-private-dict-file=
401 |
402 | # Tells whether to store unknown words to indicated private dictionary in
403 | # --spelling-private-dict-file option instead of raising a message.
404 | spelling-store-unknown-words=no
405 |
406 |
407 | [MISCELLANEOUS]
408 |
409 | # List of note tags to take in consideration, separated by a comma.
410 | notes=FIXME,XXX,TODO
411 |
412 |
413 | [BASIC]
414 |
415 | # List of builtins function names that should not be used, separated by a comma
416 | bad-functions=map,filter,input
417 |
418 | # Good variable names which should always be accepted, separated by a comma
419 | good-names=i,j,k,ex,Run,_
420 |
421 | # Bad variable names which should always be refused, separated by a comma
422 | bad-names=foo,bar,baz,toto,tutu,tata
423 |
424 | # Colon-delimited sets of names that determine each other's naming style when
425 | # the name regexes allow several styles.
426 | name-group=
427 |
428 | # Include a hint for the correct naming format with invalid-name
429 | include-naming-hint=no
430 |
431 | # Regular expression matching correct function names
432 | function-rgx=[a-z_][a-z0-9_]{2,30}$
433 |
434 | # Naming hint for function names
435 | function-name-hint=[a-z_][a-z0-9_]{2,30}$
436 |
437 | # Regular expression matching correct variable names
438 | variable-rgx=[a-z_][a-z0-9_]{2,30}$
439 |
440 | # Naming hint for variable names
441 | variable-name-hint=[a-z_][a-z0-9_]{2,30}$
442 |
443 | # Regular expression matching correct constant names
444 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
445 |
446 | # Naming hint for constant names
447 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
448 |
449 | # Regular expression matching correct attribute names
450 | attr-rgx=[a-z_][a-z0-9_]{2,30}$
451 |
452 | # Naming hint for attribute names
453 | attr-name-hint=[a-z_][a-z0-9_]{2,30}$
454 |
455 | # Regular expression matching correct argument names
456 | argument-rgx=[a-z_][a-z0-9_]{2,30}$
457 |
458 | # Naming hint for argument names
459 | argument-name-hint=[a-z_][a-z0-9_]{2,30}$
460 |
461 | # Regular expression matching correct class attribute names
462 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
463 |
464 | # Naming hint for class attribute names
465 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
466 |
467 | # Regular expression matching correct inline iteration names
468 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
469 |
470 | # Naming hint for inline iteration names
471 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
472 |
473 | # Regular expression matching correct class names
474 | class-rgx=[A-Z_][a-zA-Z0-9]+$
475 |
476 | # Naming hint for class names
477 | class-name-hint=[A-Z_][a-zA-Z0-9]+$
478 |
479 | # Regular expression matching correct module names
480 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
481 |
482 | # Naming hint for module names
483 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
484 |
485 | # Regular expression matching correct method names
486 | method-rgx=[a-z_][a-z0-9_]{2,30}$
487 |
488 | # Naming hint for method names
489 | method-name-hint=[a-z_][a-z0-9_]{2,30}$
490 |
491 | # Regular expression which should only match function or class names that do
492 | # not require a docstring.
493 | no-docstring-rgx=^_
494 |
495 | # Minimum line length for functions/classes that require docstrings, shorter
496 | # ones are exempt.
497 | docstring-min-length=-1
498 |
499 |
500 | [ELIF]
501 |
502 | # Maximum number of nested blocks for function / method body
503 | max-nested-blocks=5
504 |
505 |
506 | [IMPORTS]
507 |
508 | # Deprecated modules which should not be used, separated by a comma
509 | deprecated-modules=regsub,TERMIOS,Bastion,rexec
510 |
511 | # Create a graph of every (i.e. internal and external) dependencies in the
512 | # given file (report RP0402 must not be disabled)
513 | import-graph=
514 |
515 | # Create a graph of external dependencies in the given file (report RP0402 must
516 | # not be disabled)
517 | ext-import-graph=
518 |
519 | # Create a graph of internal dependencies in the given file (report RP0402 must
520 | # not be disabled)
521 | int-import-graph=
522 |
523 |
524 | [DESIGN]
525 |
526 | # Maximum number of arguments for function / method
527 | max-args=5
528 |
529 | # Argument names that match this expression will be ignored. Default to name
530 | # with leading underscore
531 | ignored-argument-names=_.*
532 |
533 | # Maximum number of locals for function / method body
534 | max-locals=15
535 |
536 | # Maximum number of return / yield for function / method body
537 | max-returns=6
538 |
539 | # Maximum number of branch for function / method body
540 | max-branches=12
541 |
542 | # Maximum number of statements in function / method body
543 | max-statements=50
544 |
545 | # Maximum number of parents for a class (see R0901).
546 | max-parents=7
547 |
548 | # Maximum number of attributes for a class (see R0902).
549 | max-attributes=7
550 |
551 | # Minimum number of public methods for a class (see R0903).
552 | min-public-methods=2
553 |
554 | # Maximum number of public methods for a class (see R0904).
555 | max-public-methods=20
556 |
557 | # Maximum number of boolean expressions in a if statement
558 | max-bool-expr=5
559 |
560 |
561 | [CLASSES]
562 |
563 | # List of method names used to declare (i.e. assign) instance attributes.
564 | defining-attr-methods=__init__,__new__,setUp
565 |
566 | # List of valid names for the first argument in a class method.
567 | valid-classmethod-first-arg=cls
568 |
569 | # List of valid names for the first argument in a metaclass class method.
570 | valid-metaclass-classmethod-first-arg=mcs
571 |
572 | # List of member names, which should be excluded from the protected access
573 | # warning.
574 | exclude-protected=_asdict,_fields,_replace,_source,_make
575 |
576 |
577 | [EXCEPTIONS]
578 |
579 | # Exceptions that will emit a warning when being caught. Defaults to
580 | # "Exception"
581 | overgeneral-exceptions=Exception
582 |
--------------------------------------------------------------------------------
/ci_tools/check_python_version.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | if __name__ == "__main__":
4 |     # Execute only if run as a script.
5 |     # Check the arguments
6 |     nbargs = len(sys.argv[1:])
7 |     if nbargs != 1:
8 |         raise ValueError("a mandatory argument is required: <expected_python_version>")
9 |
10 |     expected_version_str = sys.argv[1]
11 |     try:
12 |         expected_version = tuple(int(i) for i in expected_version_str.split("."))
13 |     except Exception as e:
14 |         raise ValueError("Error while parsing expected version %r: %r" % (expected_version_str, e))
15 |
16 |     if len(expected_version) < 1:
17 |         raise ValueError("At least a major is expected")
18 |
19 |     if sys.version_info[0] != expected_version[0]:
20 |         raise AssertionError("Major version does not match. Expected %r - Actual %r" % (expected_version_str, sys.version))
21 |
22 |     if len(expected_version) >= 2 and sys.version_info[1] != expected_version[1]:
23 |         raise AssertionError("Minor version does not match. Expected %r - Actual %r" % (expected_version_str, sys.version))
24 |
25 |     if len(expected_version) >= 3 and sys.version_info[2] != expected_version[2]:
26 |         raise AssertionError("Patch version does not match. Expected %r - Actual %r" % (expected_version_str, sys.version))
27 |
28 |     print("SUCCESS - Actual python version %r matches expected one %r" % (sys.version, expected_version_str))
29 |
--------------------------------------------------------------------------------
/ci_tools/flake8-requirements.txt:
--------------------------------------------------------------------------------
1 | setuptools_scm>=3,<4
2 | chardet
3 | flake8>=3.6,<4
4 | flake8-html>=0.4,<1
5 | flake8-bandit>=2.1.1,<3
6 | bandit<1.7.3
7 | flake8-bugbear>=20.1.0,<21.0.0
8 | flake8-docstrings>=1.5,<2
9 | flake8-print>=3.1.1,<4
10 | flake8-tidy-imports>=4.2.1,<5
11 | flake8-copyright==0.2.2 # Internal forked repo to fix an issue, keep specific version
12 | pydocstyle>=5.1.1,<6
13 | pycodestyle>=2.6.0,<3
14 | mccabe>=0.6.1,<1
15 | naming>=0.5.1,<1
16 | pyflakes>=2.2,<3
17 | genbadge[flake8]
18 | jinja2>=3.0.0,<3.1.0
19 |
--------------------------------------------------------------------------------
/ci_tools/github_release.py:
--------------------------------------------------------------------------------
1 | # a clone of the ruby example https://gist.github.com/valeriomazzeo/5491aee76f758f7352e2e6611ce87ec1
2 | import os
3 | from os import path
4 |
5 | import re
6 |
7 | import click
8 | from click import Path
9 | from github import Github, UnknownObjectException
10 | # from valid8 import validate not compliant with python 2.7
11 |
12 |
13 | @click.command()
14 | @click.option('-u', '--user', help='GitHub username')
15 | @click.option('-p', '--pwd', help='GitHub password')
16 | @click.option('-s', '--secret', help='GitHub access token')
17 | @click.option('-r', '--repo-slug', help='Repo slug. e.g.: apple/swift')
18 | @click.option('-cf', '--changelog-file', help='Changelog file path')
19 | @click.option('-d', '--doc-url', help='Documentation url')
20 | @click.option('-df', '--data-file', help='Data file to upload', type=Path(exists=True, file_okay=True, dir_okay=False,
21 | resolve_path=True))
22 | @click.argument('tag')
23 | def create_or_update_release(user, pwd, secret, repo_slug, changelog_file, doc_url, data_file, tag):
24 | """
25 | Creates or updates (TODO)
26 | a github release corresponding to git tag <tag>.
27 | """
28 | # 1- AUTHENTICATION
29 | if user is not None and secret is None:
30 | # using username and password
31 | # validate('user', user, instance_of=str)
32 | assert isinstance(user, str)
33 | # validate('pwd', pwd, instance_of=str)
34 | assert isinstance(pwd, str)
35 | g = Github(user, pwd)
36 | elif user is None and secret is not None:
37 | # or using an access token
38 | # validate('secret', secret, instance_of=str)
39 | assert isinstance(secret, str)
40 | g = Github(secret)
41 | else:
42 | raise ValueError("You should either provide username/password OR an access token")
43 | click.echo("Logged in as {user_name}".format(user_name=g.get_user()))
44 |
45 | # 2- CHANGELOG VALIDATION
46 | regex_pattern = "[\s\S]*[\n][#]+[\s]*(?P<title>[\S ]*%s[\S ]*)[\n]+?(?P<body>[\s\S]*?)[\n]*?(\n#|$)" % re.escape(tag)
47 | changelog_section = re.compile(regex_pattern)
48 | if changelog_file is not None:
49 | # validate('changelog_file', changelog_file, custom=os.path.exists,
50 | # help_msg="changelog file should be a valid file path")
51 | assert os.path.exists(changelog_file), "changelog file should be a valid file path"
52 | with open(changelog_file) as f:
53 | contents = f.read()
54 |
55 | match = changelog_section.match(contents)
56 | if match is None:
57 | raise ValueError("Unable to find changelog section matching regexp pattern in changelog file.")
58 | else:
59 | title = match.group('title')
60 | message = match.group('body')
61 | else:
62 | title = tag
63 | message = ''
64 |
65 | # append footer if doc url is provided
66 | message += "\n\nSee [documentation page](%s) for details." % doc_url
67 |
68 | # 3- REPOSITORY EXPLORATION
69 | # validate('repo_slug', repo_slug, instance_of=str, min_len=1, help_msg="repo_slug should be a non-empty string")
70 | assert isinstance(repo_slug, str) and len(repo_slug) > 0, "repo_slug should be a non-empty string"
71 | repo = g.get_repo(repo_slug)
72 |
73 | # -- Is there a tag with that name ?
74 | try:
75 | tag_ref = repo.get_git_ref("tags/" + tag)
76 | except UnknownObjectException:
77 | raise ValueError("No tag with name %s exists in repository %s" % (tag, repo.name))
78 |
79 | # -- Is there already a release with that tag name ?
80 | click.echo("Checking if release %s already exists in repository %s" % (tag, repo.name))
81 | try:
82 | release = repo.get_release(tag)
83 | if release is not None:
84 | raise ValueError("Release %s already exists in repository %s. Please set overwrite to True if you wish to "
85 | "update the release (Not yet supported)" % (tag, repo.name))
86 | except UnknownObjectException:
87 | # Release does not exist: we can safely create it.
88 | click.echo("Creating release %s on repo: %s" % (tag, repo.name))
89 | click.echo("Release title: '%s'" % title)
90 | click.echo("Release message:\n--\n%s\n--\n" % message)
91 | repo.create_git_release(tag=tag, name=title,
92 | message=message,
93 | draft=False, prerelease=False)
94 |
95 | # add the asset file if needed
96 | if data_file is not None:
97 | release = None
98 | while release is None:
99 | release = repo.get_release(tag)
100 | release.upload_asset(path=data_file, label=path.split(data_file)[1], content_type="application/gzip")
101 |
102 | # --- Memo ---
103 | # release.target_commitish # 'master'
104 | # release.tag_name # '0.5.0'
105 | # release.title # 'First public release'
106 | # release.body # markdown body
107 | # release.draft # False
108 | # release.prerelease # False
109 | # #
110 | # release.author
111 | # release.created_at # datetime.datetime(2018, 11, 9, 17, 49, 56)
112 | # release.published_at # datetime.datetime(2018, 11, 9, 20, 11, 10)
113 | # release.last_modified # None
114 | # #
115 | # release.id # 13928525
116 | # release.etag # 'W/"dfab7a13086d1b44fe290d5d04125124"'
117 | # release.url # 'https://api.github.com/repos/smarie/python-m5p/releases/13928525'
118 | # release.html_url # 'https://github.com/smarie/python-m5p/releases/tag/0.5.0'
119 | # release.tarball_url # 'https://api.github.com/repos/smarie/python-m5p/tarball/0.5.0'
120 | # release.zipball_url # 'https://api.github.com/repos/smarie/python-m5p/zipball/0.5.0'
121 | # release.upload_url # 'https://uploads.github.com/repos/smarie/python-m5p/releases/13928525/assets{?name,label}'
122 |
123 |
124 | if __name__ == '__main__':
125 | create_or_update_release()
126 |
--------------------------------------------------------------------------------
/ci_tools/nox_utils.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | import asyncio
4 | from collections import namedtuple
5 | from inspect import signature, isfunction
6 | import logging
7 | from pathlib import Path
8 | import shutil
9 | import subprocess
10 | import sys
11 | import os
12 |
13 | from typing import Sequence, Dict, Union, Iterable, Mapping, Any, IO, Tuple, Optional, List
14 |
15 | from makefun import wraps, remove_signature_parameters, add_signature_parameters
16 |
17 | import nox
18 | from nox.sessions import Session
19 |
20 |
21 | nox_logger = logging.getLogger("nox")
22 |
23 |
24 | PY27, PY35, PY36, PY37, PY38, PY39, PY310 = "2.7", "3.5", "3.6", "3.7", "3.8", "3.9", "3.10"
25 | DONT_INSTALL = "dont_install"
26 |
27 |
28 | def power_session(
29 | func=None,
30 | envs=None,
31 | grid_param_name="env",
32 | python=None,
33 | py=None,
34 | reuse_venv=None,
35 | name=None,
36 | venv_backend=None,
37 | venv_params=None,
38 | logsdir=None,
39 | **kwargs
40 | ):
41 | """A nox.session on steroids
42 |
43 | :param func:
44 | :param envs: a dictionary {key: dict_of_params} where key is either the python version or a tuple (python version,
45 | grid id) and all keys in the dict_of_params must be the same in all entries. The decorated function should
46 | have one parameter for each of these keys, they will be injected with the value.
47 | :param grid_param_name: when the key in `envs` is a tuple, this name will be the name of the generated parameter to
48 | iterate through the various combinations for each python version.
49 | :param python:
50 | :param py:
51 | :param reuse_venv:
52 | :param name:
53 | :param venv_backend:
54 | :param venv_params:
55 | :param logsdir:
56 | :param kwargs:
57 | :return:
58 | """
59 | if func is not None:
60 | return power_session()(func)
61 | else:
62 | def combined_decorator(f):
63 | # replace Session with PowerSession
64 | f = with_power_session(f)
65 |
66 | # open a log file for the session, use it to stream the commands stdout and stderrs,
67 | # and possibly inject the log file in the session function
68 | if logsdir is not None:
69 | f = with_logfile(logs_dir=logsdir)(f)
70 |
71 | # decorate with @nox.session and possibly @nox.parametrize to create the grid
72 | return nox_session_with_grid(python=python, py=py, envs=envs, reuse_venv=reuse_venv, name=name,
73 | grid_param_name=grid_param_name, venv_backend=venv_backend,
74 | venv_params=venv_params, **kwargs)(f)
75 |
76 | return combined_decorator
77 |
78 |
79 | def with_power_session(f=None):
80 | """ A decorator to patch the session objects in order to add all methods from Session2"""
81 |
82 | if f is not None:
83 | return with_power_session()(f)
84 |
85 | def _decorator(f):
86 | @wraps(f)
87 | def _f_wrapper(**kwargs):
88 | # patch the session arg
89 | PowerSession.patch(kwargs['session'])
90 |
91 | # finally execute the session
92 | return f(**kwargs)
93 |
94 | return _f_wrapper
95 |
96 | return _decorator
97 |
98 |
99 | class PowerSession(Session):
100 | """
101 | Our nox session improvements
102 | """
103 |
104 | # ------------ commandline runners -----------
105 |
106 | def run2(self,
107 | command: Union[Iterable[str], str],
108 | logfile: Union[bool, str, Path] = True,
109 | **kwargs):
110 | """
111 | An improvement of session.run that is able to
112 |
113 | - automatically split the provided command if it is a string
114 | - use a log file
115 |
116 | :param command:
117 | :param logfile: None/False (normal nox behaviour), or True (using nox file handler), or a file path.
118 | :param kwargs:
119 | :return:
120 | """
121 | if isinstance(command, str):
122 | command = command.split(' ')
123 |
124 | self.run(*command, logfile=logfile, **kwargs)
125 |
126 | def run_multi(self,
127 | cmds: str,
128 | logfile: Union[bool, str, Path] = True,
129 | **kwargs):
130 | """
131 | An improvement of session.run that is able to
132 |
133 | - support multiline strings
134 | - use a log file
135 |
136 | :param cmds:
137 | :param logfile: None/False (normal nox behaviour), or True (using nox file handler), or a file path.
138 | :param kwargs:
139 | :return:
140 | """
141 | for cmdline in (line for line in cmds.splitlines() if line):
142 | self.run2(cmdline, logfile=logfile, **kwargs)
143 |
144 | # ------------ requirements installers -----------
145 |
146 | def install_reqs(
147 | self,
148 | # pre wired phases
149 | setup=False,
150 | install=False,
151 | tests=False,
152 | extras=(),
153 | # custom phase
154 | phase=None,
155 | phase_reqs=None,
156 | versions_dct=None
157 | ):
158 | """
159 | A high-level helper to install requirements from the various project files
160 |
161 | - pyproject.toml "[build-system] requires" (if setup=True)
162 | - setup.cfg "[options] setup_requires" (if setup=True)
163 | - setup.cfg "[options] install_requires" (if install=True)
164 | - setup.cfg "[options] test_requires" (if tests=True)
165 | - setup.cfg "[options.extras_require] <...>" (if extras=(a tuple of extras))
166 |
167 | Two additional mechanisms are provided in order to customize how packages are installed.
168 |
169 | Conda packages
170 | --------------
171 | If the session runs on a conda environment, you can add a [tool.conda] section to your pyproject.toml. This
172 | section should contain a `conda_packages` entry containing the list of package names that should be installed
173 | using conda instead of pip.
174 |
175 | ```
176 | [tool.conda]
177 | # Declare that the following packages should be installed with conda instead of pip
178 | # Note: this includes packages declared everywhere, here and in setup.cfg
179 | conda_packages = [
180 | "setuptools",
181 | "wheel",
182 | "pip"
183 | ]
184 | ```
185 |
186 | Version constraints
187 | -------------------
188 | In addition to the version constraints in the pyproject.toml and setup.cfg, you can specify additional temporary
189 | constraints with the `versions_dct` argument, for example if you know that this executes on a specific python
190 | version that requires special care.
191 | For this, simply pass a dictionary of {'pkg_name': 'pkg_constraint'} for example {"pip": ">10"}.
192 |
193 | """
194 |
195 | # Read requirements from pyproject.toml
196 | toml_setup_reqs, toml_use_conda_for = read_pyproject_toml()
197 | if setup:
198 | self.install_any("pyproject.toml#build-system", toml_setup_reqs,
199 | use_conda_for=toml_use_conda_for, versions_dct=versions_dct)
200 |
201 | # Read test requirements from setup.cfg
202 | setup_cfg = read_setuptools_cfg()
203 | if setup:
204 | self.install_any("setup.cfg#setup_requires", setup_cfg.setup_requires,
205 | use_conda_for=toml_use_conda_for, versions_dct=versions_dct)
206 | if install:
207 | self.install_any("setup.cfg#install_requires", setup_cfg.install_requires,
208 | use_conda_for=toml_use_conda_for, versions_dct=versions_dct)
209 | if tests:
210 | self.install_any("setup.cfg#tests_requires", setup_cfg.tests_requires,
211 | use_conda_for=toml_use_conda_for, versions_dct=versions_dct)
212 |
213 | for extra in extras:
214 | self.install_any("setup.cfg#extras_require#%s" % extra, setup_cfg.extras_require[extra],
215 | use_conda_for=toml_use_conda_for, versions_dct=versions_dct)
216 |
217 | if phase is not None:
218 | self.install_any(phase, phase_reqs, use_conda_for=toml_use_conda_for, versions_dct=versions_dct)
219 |
220 | def install_any(self,
221 | phase_name: str,
222 | pkgs: Sequence[str],
223 | use_conda_for: Sequence[str] = (),
224 | versions_dct: Dict[str, str] = None,
225 | logfile: Union[bool, str, Path] = True,
226 | ):
227 | """Install the `pkgs` provided with `session.install(*pkgs)`, except for those present in `use_conda_for`"""
228 |
229 | nox_logger.debug("\nAbout to install *%s* requirements: %s.\n "
230 | "Conda pkgs are %s" % (phase_name, pkgs, use_conda_for))
231 |
232 | # use the provided versions dictionary to update the versions
233 | if versions_dct is None:
234 | versions_dct = dict()
235 | pkgs = [pkg + versions_dct.get(pkg, "") for pkg in pkgs if versions_dct.get(pkg, "") != DONT_INSTALL]
236 |
237 | # install on conda... if the session uses conda backend
238 | if not isinstance(self.virtualenv, nox.virtualenv.CondaEnv):
239 | conda_pkgs = []
240 | else:
241 | conda_pkgs = [pkg_req for pkg_req in pkgs if any(get_req_pkg_name(pkg_req) == c for c in use_conda_for)]
242 | if len(conda_pkgs) > 0:
243 | nox_logger.info("[%s] Installing requirements with conda: %s" % (phase_name, conda_pkgs))
244 | self.conda_install2(*conda_pkgs, logfile=logfile)
245 |
246 | pip_pkgs = [pkg_req for pkg_req in pkgs if pkg_req not in conda_pkgs]
247 | # safety: make sure that nothing went modified or forgotten
248 | assert set(conda_pkgs).union(set(pip_pkgs)) == set(pkgs)
249 | if len(pip_pkgs) > 0:
250 | nox_logger.info("[%s] Installing requirements with pip: %s" % (phase_name, pip_pkgs))
251 | self.install2(*pip_pkgs, logfile=logfile)
252 |
253 | def conda_install2(self,
254 | *conda_pkgs,
255 | logfile: Union[bool, str, Path] = True,
256 | **kwargs
257 | ):
258 | """
259 | Same as session.conda_install() but with support for `logfile`.
260 |
261 | :param conda_pkgs:
262 | :param logfile: None/False (normal nox behaviour), or True (using nox file handler), or a file path.
263 | :return:
264 | """
265 | return self.conda_install(*conda_pkgs, logfile=logfile, **kwargs)
266 |
267 | def install2(self,
268 | *pip_pkgs,
269 | logfile: Union[bool, str, Path] = True,
270 | **kwargs
271 | ):
272 | """
273 | Same as session.install() but with support for `logfile`.
274 |
275 | :param pip_pkgs:
276 | :param logfile: None/False (normal nox behaviour), or True (using nox file handler), or a file path.
277 | :return:
278 | """
279 | return self.install(*pip_pkgs, logfile=logfile, **kwargs)
280 |
281 | def get_session_id(self):
282 | """Return the session id"""
283 | return Path(self.bin).name
284 |
285 | @classmethod
286 | def is_power_session(cls, session: Session):
287 | return PowerSession.install2.__name__ in session.__dict__
288 |
289 | @classmethod
290 | def patch(cls, session: Session):
291 | """
292 | Add all methods from this class to the provided object.
293 | Note that we could instead have created a proper proxy... but that would be complex for not much benefit.
294 | :param session:
295 | :return:
296 | """
297 | if not cls.is_power_session(session):
298 | for m_name, m in cls.__dict__.items():
299 | if not isfunction(m):
300 | continue
301 | if m is cls.patch:
302 | continue
303 | if not hasattr(session, m_name):
304 | setattr(session.__class__, m_name, m)
305 |
306 | return True
307 |
308 |
309 | # ------------- requirements related
310 |
311 |
312 | def read_pyproject_toml():
313 | """
314 | Reads the `pyproject.toml` and returns
315 |
316 | - a list of setup requirements from [build-system] requires
317 |     - sub-list of these requirements that should be installed with conda, from [tool.conda] conda_packages
318 | """
319 | if os.path.exists("pyproject.toml"):
320 | import toml
321 | nox_logger.debug("\nA `pyproject.toml` file exists. Loading it.")
322 | pyproject = toml.load("pyproject.toml")
323 | requires = pyproject['build-system']['requires']
324 | conda_pkgs = pyproject['tool']['conda']['conda_packages']
325 | return requires, conda_pkgs
326 | else:
327 | raise FileNotFoundError("No `pyproject.toml` file exists. No dependency will be installed ...")
328 |
329 |
330 | SetupCfg = namedtuple('SetupCfg', ('setup_requires', 'install_requires', 'tests_requires', 'extras_require'))
331 |
332 |
333 | def read_setuptools_cfg():
334 | """
335 | Reads the `setup.cfg` file and extracts the various requirements lists
336 | """
337 | # see https://stackoverflow.com/a/30679041/7262247
338 | from setuptools import Distribution
339 | dist = Distribution()
340 | dist.parse_config_files()
341 | return SetupCfg(setup_requires=dist.setup_requires,
342 | install_requires=dist.install_requires,
343 | tests_requires=dist.tests_require,
344 | extras_require=dist.extras_require)
345 |
346 |
347 | def get_req_pkg_name(r):
348 | """Return the package name part of a python package requirement.
349 |
350 | For example
351 | "funcsigs;python<'3.5'" will return "funcsigs"
352 | "pytest>=3" will return "pytest"
353 | """
354 | return r.replace('<', '=').replace('>', '=').replace(';', '=').split("=")[0]
355 |
356 |
357 | # ------------- log related
358 |
359 |
360 | def with_logfile(logs_dir: Path,
361 | logfile_arg: str = "logfile",
362 | logfile_handler_arg: str = "logfilehandler"
363 | ):
364 | """ A decorator to inject a logfile"""
365 |
366 | def _decorator(f):
367 | # check the signature of f
368 | foo_sig = signature(f)
369 | needs_logfile_injection = logfile_arg in foo_sig.parameters
370 | needs_logfilehandler_injection = logfile_handler_arg in foo_sig.parameters
371 |
372 | # modify the exposed signature if needed
373 | new_sig = None
374 | if needs_logfile_injection:
375 | new_sig = remove_signature_parameters(foo_sig, logfile_arg)
376 | if needs_logfilehandler_injection:
377 | new_sig = remove_signature_parameters(foo_sig, logfile_handler_arg)
378 |
379 | @wraps(f, new_sig=new_sig)
380 | def _f_wrapper(**kwargs):
381 | # find the session arg
382 | session = kwargs['session'] # type: Session
383 |
384 | # add file handler to logger
385 | logfile = logs_dir / ("%s.log" % PowerSession.get_session_id(session))
386 | error_logfile = logfile.with_name("ERROR_%s" % logfile.name)
387 | success_logfile = logfile.with_name("SUCCESS_%s" % logfile.name)
388 | # delete old files if present
389 | for _f in (logfile, error_logfile, success_logfile):
390 | if _f.exists():
391 | _f.unlink()
392 |
393 | # add a FileHandler to the logger
394 | logfile_handler = log_to_file(logfile)
395 |
396 | # inject the log file / log file handler in the args:
397 | if needs_logfile_injection:
398 | kwargs[logfile_arg] = logfile
399 | if needs_logfilehandler_injection:
400 | kwargs[logfile_handler_arg] = logfile_handler
401 |
402 | # finally execute the session
403 | try:
404 | res = f(**kwargs)
405 | except Exception as e:
406 | # close and detach the file logger and rename as ERROR_....log
407 | remove_file_logger()
408 | logfile.rename(error_logfile)
409 | raise e
410 | else:
411 | # close and detach the file logger and rename as SUCCESS_....log
412 | remove_file_logger()
413 | logfile.rename(success_logfile)
414 | return res
415 |
416 | return _f_wrapper
417 |
418 | return _decorator
419 |
420 |
421 | def log_to_file(file_path: Union[str, Path]
422 | ):
423 | """
424 | Closes and removes all file handlers from the nox logger,
425 | and add a new one to the provided file path
426 |
427 | :param file_path:
428 | :return:
429 | """
430 | for h in list(nox_logger.handlers):
431 | if isinstance(h, logging.FileHandler):
432 | h.close()
433 | nox_logger.removeHandler(h)
434 | fh = logging.FileHandler(str(file_path), mode='w')
435 | nox_logger.addHandler(fh)
436 | return fh
437 |
438 |
439 | def get_current_logfile_handler():
440 | """
441 | Returns the current unique log file handler (see `log_to_file`)
442 | """
443 | for h in list(nox_logger.handlers):
444 | if isinstance(h, logging.FileHandler):
445 | return h
446 | return None
447 |
448 |
449 | def get_log_file_stream():
450 | """
451 | Returns the output stream for the current log file handler if any (see `log_to_file`)
452 | """
453 | h = get_current_logfile_handler()
454 | if h is not None:
455 | return h.stream
456 | return None
457 |
458 |
459 | def remove_file_logger():
460 | """
461 | Closes and detaches the current logfile handler
462 | :return:
463 | """
464 | h = get_current_logfile_handler()
465 | if h is not None:
466 | h.close()
467 | nox_logger.removeHandler(h)
468 |
469 |
470 | # ------------ environment grid / parametrization related
471 |
472 | def nox_session_with_grid(python = None,
473 | py = None,
474 | envs: Mapping[str, Mapping[str, Any]] = None,
475 | reuse_venv: Optional[bool] = None,
476 | name: Optional[str] = None,
477 | venv_backend: Any = None,
478 | venv_params: Any = None,
479 | grid_param_name: str = None,
480 | **kwargs
481 | ):
482 | """
483 |     Since nox is not yet able to define a build matrix mixing python versions and parameters in the same
484 |     parametrize call, this implements it with a dirty hack.
485 | To remove when https://github.com/theacodes/nox/pull/404 is complete
486 |
487 |     :param envs: mapping of environment id (a python version, or a (python version, parameter) tuple) to the
488 |         session keyword arguments to inject for that environment
489 | :return:
490 | """
491 | if envs is None:
492 | # Fast track default to @nox.session
493 | return nox.session(python=python, py=py, reuse_venv=reuse_venv, name=name, venv_backend=venv_backend,
494 | venv_params=venv_params, **kwargs)
495 | else:
496 | # Current limitation : session param names can be 'python' or 'py' only
497 | if py is not None or python is not None:
498 | raise ValueError("`python` session argument can not be provided both directly and through the "
499 | "`env` with `session_param_names`")
500 |
501 | # First examine the env and collect the parameter values for python
502 | all_python = []
503 | all_params = []
504 |
505 | env_contents_names = None
506 | has_parameter = None
507 | for env_id, env_params in envs.items():
508 | # consistency checks for the env_id
509 | if has_parameter is None:
510 | has_parameter = isinstance(env_id, tuple)
511 | else:
512 | if has_parameter != isinstance(env_id, tuple):
513 | raise ValueError("All keys in env should be tuples, or not be tuples. Error for %r" % env_id)
514 |
515 | # retrieve python version and parameter
516 | if not has_parameter:
517 | if env_id not in all_python:
518 | all_python.append(env_id)
519 | else:
520 | if len(env_id) != 2:
521 | raise ValueError("Only a size-2 tuple can be used as env id")
522 | py_id, param_id = env_id
523 | if py_id not in all_python:
524 | all_python.append(py_id)
525 | if param_id not in all_params:
526 | all_params.append(param_id)
527 |
528 | # consistency checks for the dict contents.
529 | if env_contents_names is None:
530 | env_contents_names = set(env_params.keys())
531 | else:
532 | if env_contents_names != set(env_params.keys()):
533 |                     raise ValueError("Environment %r parameters %r do not match parameters in the first environment: %r"
534 |                                      % (env_id, set(env_params.keys()), env_contents_names))
535 |
536 | if has_parameter and not grid_param_name:
537 | raise ValueError("You must provide a grid parameter name when the env keys are tuples.")
538 |
539 | def _decorator(f):
540 | s_name = name if name is not None else f.__name__
541 | for pyv, _param in product(all_python, all_params):
542 | if (pyv, _param) not in envs:
543 | # create a dummy folder to avoid creating a useless venv ?
544 | env_dir = Path(".nox") / ("%s-%s-%s-%s" % (s_name, pyv.replace('.', '-'), grid_param_name, _param))
545 | env_dir.mkdir(parents=True, exist_ok=True)
546 |
547 | # check the signature of f
548 | foo_sig = signature(f)
549 | missing = env_contents_names - set(foo_sig.parameters)
550 | if len(missing) > 0:
551 | raise ValueError("Session function %r does not contain environment parameter(s) %r" % (f.__name__, missing))
552 |
553 | # modify the exposed signature if needed
554 | new_sig = None
555 | if len(env_contents_names) > 0:
556 | new_sig = remove_signature_parameters(foo_sig, *env_contents_names)
557 |
558 | if has_parameter:
559 | if grid_param_name in foo_sig.parameters:
560 | raise ValueError("Internal error, this parameter has a reserved name: %r" % grid_param_name)
561 | else:
562 | new_sig = add_signature_parameters(new_sig, last=(grid_param_name,))
563 |
564 | @wraps(f, new_sig=new_sig)
565 | def _f_wrapper(**kwargs):
566 | # find the session arg
567 | session = kwargs['session'] # type: Session
568 |
569 | # get the versions to use for this environment
570 | try:
571 | if has_parameter:
572 | grid_param = kwargs.pop(grid_param_name)
573 | params_dct = envs[(session.python, grid_param)]
574 | else:
575 | params_dct = envs[session.python]
576 | except KeyError:
577 | # Skip this session, it is a dummy one
578 | nox_logger.warning(
579 | "Skipping configuration, this is not supported in python version %r" % session.python)
580 | return
581 |
582 | # inject the parameters in the args:
583 | kwargs.update(params_dct)
584 |
585 | # finally execute the session
586 | return f(**kwargs)
587 |
588 | if has_parameter:
589 | _f_wrapper = nox.parametrize(grid_param_name, all_params)(_f_wrapper)
590 |
591 | _f_wrapper = nox.session(python=all_python, reuse_venv=reuse_venv, name=name,
592 | venv_backend=venv_backend, venv_params=venv_params)(_f_wrapper)
593 | return _f_wrapper
594 |
595 | return _decorator
596 |
597 |
598 | # ----------- other goodies
599 |
600 |
601 | def rm_file(folder: Union[str, Path]
602 | ):
603 | """Since on windows Path.unlink throws permission error sometimes, os.remove is preferred."""
604 | if isinstance(folder, str):
605 | folder = Path(folder)
606 |
607 | if folder.exists():
608 | os.remove(str(folder))
609 | # Folders.site.unlink() --> possible PermissionError
610 |
611 |
612 | def rm_folder(folder: Union[str, Path]
613 | ):
614 | """Since on windows Path.unlink throws permission error sometimes, shutil is preferred."""
615 | if isinstance(folder, str):
616 | folder = Path(folder)
617 |
618 | if folder.exists():
619 | shutil.rmtree(str(folder))
620 | # Folders.site.unlink() --> possible PermissionError
621 |
622 |
623 | # --- the patch of popen able to tee to logfile --
624 |
625 |
626 | import nox.popen as nox_popen_module
627 | orig_nox_popen = nox_popen_module.popen
628 |
629 |
630 | class LogFileStreamCtx:
631 | def __init__(self, logfile_stream):
632 | self.logfile_stream = logfile_stream
633 |
634 | def __enter__(self):
635 | return self.logfile_stream
636 |
637 | def __exit__(self, exc_type, exc_val, exc_tb):
638 | pass
639 |
640 |
641 | def patched_popen(
642 | args: Sequence[str],
643 | env: Mapping[str, str] = None,
644 | silent: bool = False,
645 | stdout: Union[int, IO] = None,
646 | stderr: Union[int, IO] = subprocess.STDOUT,
647 | logfile: Union[bool, str, Path] = None,
648 | **kwargs
649 | ) -> Tuple[int, str]:
650 | """
651 | Our patch of nox.popen.popen().
652 |
653 | Current behaviour in `nox` is
654 |
655 |     - when `silent=True` (default), process err is redirected to STDOUT and process out is captured in a PIPE and sent
656 |       to the logger (which does not display it :) )
657 |
658 | - when `silent=False` (explicitly set, or when nox is run with verbose flag), process out and process err are both
659 | redirected to STDOUT.
660 |
661 | Our implementation allows us to be a little more flexible:
662 |
663 |     - if logfile is True or a string/Path, process err and process out are both TEE-ed to the logfile
664 | - at the same time, the above behaviour remains.
665 |
666 | :param args:
667 | :param env:
668 | :param silent:
669 | :param stdout:
670 | :param stderr:
671 | :param logfile: None/False (normal nox behaviour), or True (using nox file handler), or a file path.
672 | :return:
673 | """
674 | logfile_stream = get_log_file_stream()
675 |
676 | if logfile in (None, False) or (logfile is True and logfile_stream is None):
677 | # execute popen as usual
678 | return orig_nox_popen(args=args, env=env, silent=silent, stdout=stdout, stderr=stderr, **kwargs)
679 |
680 | else:
681 | # we'll need to tee the popen
682 | if logfile is True:
683 | ctx = LogFileStreamCtx
684 | else:
685 | ctx = lambda _: open(logfile, "a")
686 |
687 | with ctx(logfile_stream) as log_file_stream:
688 | if silent and stdout is not None:
689 | raise ValueError(
690 | "Can not specify silent and stdout; passing a custom stdout always silences the commands output in "
691 | "Nox's log."
692 | )
693 |
694 | shell = kwargs.get("shell", False)
695 | if shell:
696 | raise ValueError("Using shell=True is not yet supported with async streaming to log files")
697 |
698 | if stdout is not None or stderr is not subprocess.STDOUT:
699 | raise ValueError("Using custom streams is not yet supported with async popen")
700 |
701 | # old way
702 | # proc = subprocess.Popen(args, env=env, stdout=stdout, stderr=stderr)
703 |
704 | # New way: use asyncio to stream correctly
705 | # Note: if keyboard interrupts do not work we should check
706 | # https://mail.python.org/pipermail/async-sig/2017-August/000374.html maybe or the following threads.
707 |
708 | # define the async coroutines
709 | async def async_popen():
710 | process = await asyncio.create_subprocess_exec(*args, env=env, stdout=asyncio.subprocess.PIPE,
711 | stderr=asyncio.subprocess.PIPE, **kwargs)
712 |
713 | # bind the out and err streams - see https://stackoverflow.com/a/59041913/7262247
714 | # to mimic nox behaviour we only use a single capturing list
715 | outlines = []
716 | await asyncio.wait([
717 | # process out is only redirected to STDOUT if not silent
718 | _read_stream(process.stdout, lambda l: tee(l, sinklist=outlines, sinkstream=log_file_stream,
719 | quiet=silent, verbosepipe=sys.stdout)),
720 | # process err is always redirected to STDOUT (quiet=False) with a specific label
721 | _read_stream(process.stderr, lambda l: tee(l, sinklist=outlines, sinkstream=log_file_stream,
722 | quiet=False, verbosepipe=sys.stdout, label="ERR:"))
723 | ])
724 |                 return_code = await process.wait()  # make sure the process has ended and retrieve its return code
725 | return return_code, outlines
726 |
727 | # run the coroutine in the event loop
728 | loop = asyncio.get_event_loop()
729 | return_code, outlines = loop.run_until_complete(async_popen())
730 |
731 | # just in case, flush everything
732 | log_file_stream.flush()
733 | sys.stdout.flush()
734 | sys.stderr.flush()
735 |
736 | if silent:
737 | # same behaviour as in nox: this will be passed to the logger, and it will act depending on verbose flag
738 | out = "\n".join(outlines) if len(outlines) > 0 else ""
739 | else:
740 | # already written to stdout, no need to capture
741 | out = ""
742 |
743 | return return_code, out
744 |
745 |
746 | async def _read_stream(stream, callback):
747 | """Helper async coroutine to read from a stream line by line and write them in callback"""
748 | while True:
749 | line = await stream.readline()
750 | if line:
751 | callback(line)
752 | else:
753 | break
754 |
755 |
756 | def tee(linebytes, sinklist, sinkstream, verbosepipe, quiet, label=""):
757 | """
758 | Helper routine to read a line, decode it, and append it to several sinks:
759 |
760 | - an optional `sinklist` list that will receive the decoded string in its "append" method
761 |     - an optional `sinkstream` stream that will receive the decoded string in its "write" method
762 |     - an optional `verbosepipe` stream that will receive the decoded string through a print, only when quiet=False
763 | 
764 |     An optional `label` prefix is prepended when printing to `verbosepipe`.
765 | """
766 | line = linebytes.decode('utf-8').rstrip()
767 |
768 | if sinklist is not None:
769 | sinklist.append(line)
770 |
771 | if sinkstream is not None:
772 | sinkstream.write(line + "\n")
773 | sinkstream.flush()
774 |
775 | if not quiet and verbosepipe is not None:
776 | print(label, line, file=verbosepipe)
777 | verbosepipe.flush()
778 |
779 |
780 | def patch_popen():
781 | nox_popen_module.popen = patched_popen
782 |
783 | from nox.command import popen
784 | if popen is not patched_popen:
785 | nox.command.popen = patched_popen
786 |
787 | # change event loop on windows
788 | # see https://stackoverflow.com/a/44639711/7262247
789 | # and https://docs.python.org/3/library/asyncio-platforms.html#subprocess-support-on-windows
790 | if 'win32' in sys.platform:
791 | # Windows specific event-loop policy & cmd
792 | asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
793 | # cmds = [['C:/Windows/system32/HOSTNAME.EXE']]
794 |
795 | # loop = asyncio.ProactorEventLoop()
796 | # asyncio.set_event_loop(loop)
797 |
798 |
799 | patch_popen()
800 |
--------------------------------------------------------------------------------
/docs/changelog.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ### 0.3.3 - Added zenodo
4 |
5 | * Added zenodo entry
6 |
7 | ### 0.3.2 - Fixed compliance with sklearn 1.3.0
8 |
9 | * Fixed `AttributeError: 'super' object has no attribute 'fit' `.
10 | PR [#16](https://github.com/smarie/python-m5p/pull/16) by [lccatala](https://github.com/lccatala)
11 |
12 | ### 0.3.1 - Fixed compliance with sklearn 1.1.0
13 |
14 | * Fixed `TypeError: fit() got an unexpected keyword argument 'X_idx_sorted'`.
15 | PR [#11](https://github.com/smarie/python-m5p/pull/11) by [preinaj](https://github.com/preinaj)
16 |
17 | ### 0.3.0 - First public version
18 |
19 | * Initial fork from private repository + [scikit-learn/scikit-learn#13732](https://github.com/scikit-learn/scikit-learn/pull/13732). Fixes [#2](https://github.com/smarie/python-m5p/issues/2)
20 | * Two gallery examples (1D sine and pw-linear)
21 |
22 | ### 0.0.1 - Initial repo setup
23 |
24 | First CI roundtrip. Fixes [#1](https://github.com/smarie/python-m5p/issues/1)
25 |
--------------------------------------------------------------------------------
/docs/examples/1_simple_1D_demo.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple 1D example
3 | =================
4 |
5 | A 1D regression with M5P decision tree.
6 |
7 | The tree is used to fit a sine curve with additional noisy observations. As a
8 | result, it learns local linear regressions approximating the sine curve.
9 |
10 | We can see the role of pruning (Tree 2) and pruning + smoothing (Tree 3).
11 | """
12 | # %%
13 | # Import the necessary modules and libraries
14 | import numpy as np
15 | import matplotlib.pyplot as plt
16 |
17 | from m5py import M5Prime, export_text_m5
18 |
19 | # %%
20 | # Create a random dataset
21 | rng = np.random.RandomState(1)
22 | X = np.sort(5 * rng.rand(80, 1), axis=0)
23 | y = np.sin(X).ravel()
24 | y[::5] += 0.5 * (0.5 - rng.rand(16))
25 |
26 | # %%
27 | # Fit regression model
28 | regr_1 = M5Prime(use_smoothing=False, use_pruning=False)
29 | regr_1_label = "Tree 1"
30 | regr_1.fit(X, y)
31 | regr_2 = M5Prime(use_smoothing=False)
32 | regr_2_label = "Tree 2"
33 | regr_2.fit(X, y)
34 | regr_3 = M5Prime(smoothing_constant=5)
35 | regr_3_label = "Tree 3"
36 | regr_3.fit(X, y)
37 |
38 | # %%
39 | # Predict
40 | X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
41 | y_1 = regr_1.predict(X_test)
42 | y_2 = regr_2.predict(X_test)
43 | y_3 = regr_3.predict(X_test)
44 |
45 | # %%
46 | # Print the trees
47 | print("\n----- %s" % regr_1_label)
48 | print(regr_1.as_pretty_text())
49 |
50 | # %%
51 | print("\n----- %s" % regr_2_label)
52 | print(regr_2.as_pretty_text())
53 |
54 | # %%
55 | print("\n----- %s" % regr_3_label)
56 | print(export_text_m5(regr_3, out_file=None)) # equivalent to as_pretty_text
57 |
58 | # %%
59 | # Plot the results
60 | fig = plt.figure()
61 | plt.scatter(X, y, s=20, edgecolor="black",
62 | c="darkorange", label="data")
63 | plt.plot(X_test, y_1, color="cornflowerblue", label=regr_1_label, linewidth=2)
64 | plt.plot(X_test, y_2, color="yellowgreen", label=regr_2_label, linewidth=2)
65 | plt.plot(X_test, y_3, color="green", label=regr_3_label, linewidth=2)
66 | plt.xlabel("data")
67 | plt.ylabel("target")
68 | plt.title("Decision Tree Regression")
69 | plt.legend()
70 | fig
71 |
--------------------------------------------------------------------------------
/docs/examples/2_pw_linear_demo.py:
--------------------------------------------------------------------------------
1 | """
2 | Piecewise-Linear artificial example
3 | ===================================
4 |
5 | This demo reproduces the results from the "Artificial Dataset" in the M5
6 | paper, named `pw-linear` in the M5' paper.
7 | """
8 | # %%
9 | # Import the necessary modules and libraries
10 | import numpy as np
11 | import pandas as pd
12 | import seaborn as sns
13 |
14 | from sklearn.model_selection import cross_val_score
15 | from sklearn.tree import DecisionTreeRegressor, export_text
16 | from m5py import M5Prime
17 |
18 | # %%
19 | # Create a random dataset
20 | rng = np.random.RandomState(1)
21 | nb_samples = 200
22 |
23 | X1 = rng.randint(0, 2, nb_samples) * 2 - 1
24 | X2 = rng.randint(-1, 2, nb_samples)
25 | X3 = rng.randint(-1, 2, nb_samples)
26 | X4 = rng.randint(-1, 2, nb_samples)
27 | X5 = rng.randint(-1, 2, nb_samples)
28 | X6 = rng.randint(-1, 2, nb_samples)
29 | X7 = rng.randint(-1, 2, nb_samples)
30 | X8 = rng.randint(-1, 2, nb_samples)
31 | X9 = rng.randint(-1, 2, nb_samples)
32 | X10 = rng.randint(-1, 2, nb_samples)
33 |
34 | feature_names = ["X%i" % i for i in range(1, 11)]
35 | X = np.c_[X1, X2, X3, X4, X5, X6, X7, X8, X9, X10]
36 |
37 | y = np.where(
38 | X1 > 0,
39 | 3 + 3 * X2 + 2 * X3 + X4,
40 | -3 + 3 * X5 + 2 * X6 + X7
41 | ) + rng.normal(loc=0., scale=2 ** 0.5, size=nb_samples)
42 |
43 | # %%
44 | # Define regression models and evaluate them on 10-fold CV
45 | regr_0 = DecisionTreeRegressor()
46 | regr_0_label = "Tree 0"
47 | regr_0_scores = cross_val_score(regr_0, X, y, cv=10)
48 |
49 | regr_1 = M5Prime(use_smoothing=False, use_pruning=False)
50 | regr_1_label = "Tree 1"
51 | regr_1_scores = cross_val_score(regr_1, X, y, cv=10)
52 |
53 | regr_2 = M5Prime(use_smoothing=False)
54 | regr_2_label = "Tree 2"
55 | regr_2_scores = cross_val_score(regr_2, X, y, cv=10)
56 |
57 | regr_3 = M5Prime(use_smoothing=True)
58 | regr_3_label = "Tree 3"
59 | regr_3_scores = cross_val_score(regr_3, X, y, cv=10)
60 |
61 | scores = np.c_[regr_0_scores, regr_1_scores, regr_2_scores, regr_3_scores]
62 | avgs = scores.mean(axis=0)
63 | stds = scores.std(axis=0)
64 | labels = [regr_0_label, regr_1_label, regr_2_label, regr_3_label]
65 |
66 | scores_df = pd.DataFrame(data=scores, columns=labels)
67 | sns.violinplot(data=scores_df)
68 |
69 | # %%
70 | # Fit the final models and print the trees:
71 | #
72 | regr_0.fit(X, y)
73 | print("\n----- %s" % regr_0_label)
74 | print(export_text(regr_0, feature_names=feature_names))
75 |
76 | # %%
77 | regr_1.fit(X, y)
78 | print("\n----- %s" % regr_1_label)
79 | print(regr_1.as_pretty_text(feature_names=feature_names))
80 |
81 | # %%
82 | regr_2.fit(X, y)
83 | print("\n----- %s" % regr_2_label)
84 | print(regr_2.as_pretty_text(feature_names=feature_names))
85 |
86 | # %%
87 | regr_3.fit(X, y)
88 | print("\n----- %s" % regr_3_label)
89 | print(regr_3.as_pretty_text(feature_names=feature_names))
90 |
--------------------------------------------------------------------------------
/docs/examples/Readme.md:
--------------------------------------------------------------------------------
1 | # Usage examples
2 |
3 | These examples demonstrate how to use the library.
4 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # `m5py`
2 |
3 | *`scikit-learn`-compliant M5 / M5' model trees for python*
4 |
5 | [](https://pypi.python.org/pypi/m5py/) [](https://github.com/smarie/python-m5p/actions/workflows/base.yml) [](./reports/junit/report.html) [](./reports/coverage/index.html) [](https://codecov.io/gh/smarie/python-m5p) [](./reports/flake8/index.html)
6 |
7 | [](https://smarie.github.io/python-m5p/) [](https://pypi.python.org/pypi/m5py/) [](https://pepy.tech/project/m5py) [](https://pepy.tech/project/m5py) [](https://github.com/smarie/python-m5p/stargazers)[](https://zenodo.org/doi/10.5281/zenodo.10552218)
8 |
9 | In 1992 R. Quinlan introduced the M5 algorithm, a regression tree algorithm similar to CART (Breiman), with additional pruning so that leaves may contain linear models instead of constant values. The idea was to get smoother and simpler models.
10 |
11 | The algorithm was later enhanced by Wang & Witten under the name M5 Prime (aka M5', or M5P), with an implementation in the Weka toolbox.
12 |
13 | `m5py` is a python implementation leveraging `scikit-learn`'s regression tree engine.
14 |
15 |
16 | ## Citing
17 |
18 | If `m5py` helps you with your research work, don't hesitate to spread the word! To do so,
19 | please cite this [conference presentation](https://hal.science/hal-03762155/).
20 | Optionally, you can also cite this Zenodo entry
21 | [](https://zenodo.org/doi/10.5281/zenodo.10552218). Thanks!
22 |
23 | ## Installing
24 |
25 | ```bash
26 | > pip install m5py
27 | ```
28 |
29 | ## Usage
30 |
31 | See the [usage examples gallery](./generated/gallery).
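
A minimal quick-start sketch, adapted from the gallery's 1D example (the data below is synthetic, for illustration only):

```python
import numpy as np
from m5py import M5Prime

# Toy 1D regression problem: a noisy sine curve
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 0.5 * (0.5 - rng.rand(16))  # add noise to every 5th observation

# Fit an M5' model tree with smoothing enabled
model = M5Prime(use_smoothing=True)
model.fit(X, y)

# Inspect the learned tree (leaves are linear models) and predict
print(model.as_pretty_text())
y_pred = model.predict(X)
```

The gallery shows the effect of pruning and smoothing, and a comparison with scikit-learn's `DecisionTreeRegressor`.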
32 |
33 | ## Main features / benefits
34 |
35 | * The classic `M5` algorithm in python, compliant with `scikit-learn`.
36 |
37 | ## See Also
38 |
39 | * Weka's [M5P model](https://weka.sourceforge.io/doc.dev/weka/classifiers/trees/M5P.html)
40 | * LightGBM's [linear_tree](https://lightgbm.readthedocs.io/en/latest/Parameters.html#linear_tree)
41 | * [linear-tree](https://github.com/cerlymarco/linear-tree), which follows a similar procedure to LightGBM's (from
42 | [discussion](https://github.com/scikit-learn/scikit-learn/issues/13106#issuecomment-808730062))
43 |
44 | ### Others
45 |
46 | *Do you like this library ? You might also like [my other python libraries](https://github.com/smarie/OVERVIEW#python)*
47 |
48 | ## Want to contribute ?
49 |
50 | Details on the github page: [https://github.com/smarie/python-m5p](https://github.com/smarie/python-m5p)
51 |
--------------------------------------------------------------------------------
/docs/long_description.md:
--------------------------------------------------------------------------------
1 | # `m5py`
2 |
3 | [](https://pypi.python.org/pypi/m5py/) [](https://github.com/smarie/python-m5p/actions/workflows/base.yml) [](https://smarie.github.io/python-m5p/reports/junit/report.html) [](https://smarie.github.io/python-m5p/reports/coverage/index.html) [](https://codecov.io/gh/smarie/python-m5p) [](https://smarie.github.io/python-m5p/reports/flake8/index.html)
4 |
5 | [](https://smarie.github.io/python-m5p/) [](https://pypi.python.org/pypi/m5py/) [](https://pepy.tech/project/m5py) [](https://pepy.tech/project/m5py) [](https://github.com/smarie/python-m5p/stargazers)
6 |
7 | *`scikit-learn`-compliant M5 / M5' model trees for python*
8 |
9 | The documentation for users is available here: [https://smarie.github.io/python-m5p/](https://smarie.github.io/python-m5p/)
10 |
11 | A readme for developers is available here: [https://github.com/smarie/python-m5p](https://github.com/smarie/python-m5p)
12 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: m5py
2 | # site_description: 'A short description of my project'
3 | repo_url: https://github.com/smarie/python-m5p
4 | # docs_dir: .
5 | # site_dir: ../site
6 | # default branch is main instead of master now on github
7 | edit_uri : ./edit/main/docs
8 | nav:
9 | - Home: index.md
10 | - Usage examples: generated/gallery
11 | # - API reference: api_reference.md
12 | - Changelog: changelog.md
13 |
14 | theme: material # readthedocs mkdocs
15 |
16 | plugins:
17 | - gallery:
18 | examples_dirs: docs/examples # path to your example scripts
19 | gallery_dirs: docs/generated/gallery # where to save generated gallery
20 | filename_pattern: ".*_demo"
21 | within_subsection_order: FileNameSortKey # order according to file name
22 |
23 | - search # make sure the search plugin is still enabled
24 |
25 | markdown_extensions:
26 | - toc:
27 | permalink: true
28 |
--------------------------------------------------------------------------------
/noxfile-requirements.txt:
--------------------------------------------------------------------------------
1 | nox
2 | toml
3 | makefun
4 | setuptools_scm # used in 'release'
5 | keyring # used in 'release'
6 |
--------------------------------------------------------------------------------
/noxfile.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 | from json import dumps
3 | import logging
4 |
5 | import nox # noqa
6 | from pathlib import Path # noqa
7 | import sys
8 |
9 | # add the ci_tools folder to the python path so that we can import nox_utils.py
10 | # note that you need to "pip install -r noxfile-requirements.txt" for this file to work.
11 | sys.path.append(str(Path(__file__).parent / "ci_tools"))
12 | from nox_utils import PY37, PY36, PY35, PY38, PY39, PY310, power_session, rm_folder, rm_file, PowerSession # noqa
13 |
14 |
15 | pkg_name = "m5py"
16 | gh_org = "smarie"
17 | gh_repo = "python-m5p"
18 |
19 | ENVS = {
20 | # python 3.10 is not available on conda yet
21 | # PY310: {"coverage": False, "pkg_specs": {"pip": ">19"}},
22 | PY39: {"coverage": False, "pkg_specs": {"pip": ">19"}},
23 | # PY27: {"coverage": False, "pkg_specs": {"pip": ">10"}},
24 | # PY35: {"coverage": False, "pkg_specs": {"pip": ">10"}},
25 | # PY36: {"coverage": False, "pkg_specs": {"pip": ">19"}},
26 | # IMPORTANT: this should be last so that the folder docs/reports is not deleted afterwards
27 | PY37: {"coverage": False, "pkg_specs": {"pip": ">19"}}, # , "pytest-html": "1.9.0"
28 | PY38: {"coverage": True, "pkg_specs": {"pip": ">19"}},
29 | }
30 |
31 |
32 | # set the default activated sessions, minimal for CI
33 | nox.options.sessions = ["tests", "flake8", "docs"] # , "docs", "gh_pages"
34 | nox.options.reuse_existing_virtualenvs = True # this can be done using -r
35 | # if platform.system() == "Windows": >> always use this for better control
36 | nox.options.default_venv_backend = "conda"
37 | # os.environ["NO_COLOR"] = "True" # nox.options.nocolor = True does not work
38 | # nox.options.verbose = True
39 |
40 | nox_logger = logging.getLogger("nox")
41 | # nox_logger.setLevel(logging.INFO)  NO !!!! this prevents the "verbose" nox flag from working !
42 |
43 |
44 | class Folders:
45 | root = Path(__file__).parent
46 | ci_tools = root / "ci_tools"
47 | runlogs = root / Path(nox.options.envdir or ".nox") / "_runlogs"
48 | runlogs.mkdir(parents=True, exist_ok=True)
49 | dist = root / "dist"
50 | site = root / "site"
51 | site_reports = site / "reports"
52 | reports_root = root / "docs" / "reports"
53 | test_reports = reports_root / "junit"
54 | test_xml = test_reports / "junit.xml"
55 | test_html = test_reports / "report.html"
56 | test_badge = test_reports / "junit-badge.svg"
57 | coverage_reports = reports_root / "coverage"
58 | coverage_xml = coverage_reports / "coverage.xml"
59 | coverage_intermediate_file = root / ".coverage"
60 | coverage_badge = coverage_reports / "coverage-badge.svg"
61 | flake8_reports = reports_root / "flake8"
62 | flake8_intermediate_file = root / "flake8stats.txt"
63 | flake8_badge = flake8_reports / "flake8-badge.svg"
64 |
65 |
66 | @power_session(envs=ENVS, logsdir=Folders.runlogs)
67 | def tests(session: PowerSession, coverage, pkg_specs):
68 | """Run the test suite, including test reports generation and coverage reports. """
69 |
70 | # As soon as this runs, we delete the target site and coverage files to avoid reporting wrong coverage/etc.
71 | rm_folder(Folders.site)
72 | rm_folder(Folders.reports_root)
73 | # delete the .coverage files if any (they are not supposed to be any, but just in case)
74 | rm_file(Folders.coverage_intermediate_file)
75 | rm_file(Folders.root / "coverage.xml")
76 |
77 | # CI-only dependencies
78 |     # Did we receive a flag through positional arguments ? (nox -s tests -- <flag>)
79 | # install_ci_deps = False
80 | # if len(session.posargs) == 1:
81 | # assert session.posargs[0] == "keyrings.alt"
82 | # install_ci_deps = True
83 | # elif len(session.posargs) > 1:
84 | # raise ValueError("Only a single positional argument is accepted, received: %r" % session.posargs)
85 |
86 | # uncomment and edit if you wish to uninstall something without deleting the whole env
87 | # session.run2("pip uninstall pytest-asyncio --yes")
88 |
89 | # install all requirements
90 | # session.install_reqs(phase="pip", phase_reqs=("pip",), versions_dct=pkg_specs)
91 | session.install_reqs(setup=True, install=True, tests=True, versions_dct=pkg_specs)
92 |
93 | # install CI-only dependencies
94 | # if install_ci_deps:
95 | # session.install2("keyrings.alt")
96 |
97 | # list all (conda list alone does not work correctly on github actions)
98 | # session.run2("conda list")
99 | conda_prefix = Path(session.bin)
100 | if conda_prefix.name == "bin":
101 | conda_prefix = conda_prefix.parent
102 | session.run2("conda list", env={"CONDA_PREFIX": str(conda_prefix), "CONDA_DEFAULT_ENV": session.get_session_id()})
103 |
104 | # Fail if the assumed python version is not the actual one
105 | session.run2("python ci_tools/check_python_version.py %s" % session.python)
106 |
107 | # check that it can be imported even from a different folder
108 | # Important: do not surround the command into double quotes as in the shell !
109 | # session.run('python', '-c', 'import os; os.chdir(\'./docs/\'); import %s' % pkg_name)
110 |
111 | # finally run all tests
112 | if not coverage:
113 | # install self so that it is recognized by pytest
114 | session.run2("pip install . --no-deps")
115 | # session.install(".", "--no-deps")
116 |
117 | # simple: pytest only
118 | session.run2("python -m pytest --cache-clear -v tests/")
119 | else:
120 | # install self in "develop" mode so that coverage can be measured
121 | session.run2("pip install -e . --no-deps")
122 |
123 | # coverage + junit html reports + badge generation
124 | session.install_reqs(phase="coverage",
125 | phase_reqs=["coverage", "pytest-html", "genbadge[tests,coverage]"],
126 | versions_dct=pkg_specs)
127 |
128 | # --coverage + junit html reports
129 | session.run2("coverage run --source src/{pkg_name} "
130 | "-m pytest --cache-clear --junitxml={test_xml} --html={test_html} -v tests/"
131 | "".format(pkg_name=pkg_name, test_xml=Folders.test_xml, test_html=Folders.test_html))
132 | session.run2("coverage report")
133 | session.run2("coverage xml -o {covxml}".format(covxml=Folders.coverage_xml))
134 | session.run2("coverage html -d {dst}".format(dst=Folders.coverage_reports))
135 | # delete this intermediate file, it is not needed anymore
136 | rm_file(Folders.coverage_intermediate_file)
137 |
138 | # --generates the badge for the test results and fail build if less than x% tests pass
139 | nox_logger.info("Generating badge for tests coverage")
140 | # Use our own package to generate the badge
141 | session.run2("genbadge tests -i %s -o %s -t 100" % (Folders.test_xml, Folders.test_badge))
142 | session.run2("genbadge coverage -i %s -o %s" % (Folders.coverage_xml, Folders.coverage_badge))
143 |
144 |
145 | @power_session(python=PY38, logsdir=Folders.runlogs)
146 | def flake8(session: PowerSession):
147 | """Launch flake8 qualimetry."""
148 |
149 | session.install("-r", str(Folders.ci_tools / "flake8-requirements.txt"))
150 | session.install_reqs(phase="flake8", phase_reqs=["numpy", "pandas", "scikit-learn"])
151 | session.run2("pip install . --no-deps")
152 |
153 | rm_folder(Folders.flake8_reports)
154 | Folders.flake8_reports.mkdir(parents=True, exist_ok=True)
155 | rm_file(Folders.flake8_intermediate_file)
156 |
157 | session.cd("src")
158 |
159 | # Options are set in `setup.cfg` file
160 | session.run("flake8", pkg_name, "--exit-zero", "--format=html", "--htmldir", str(Folders.flake8_reports),
161 | "--statistics", "--tee", "--output-file", str(Folders.flake8_intermediate_file))
162 | # generate our badge
163 | session.run2("genbadge flake8 -i %s -o %s" % (Folders.flake8_intermediate_file, Folders.flake8_badge))
164 | rm_file(Folders.flake8_intermediate_file)
165 |
166 |
167 | @power_session(python=[PY38])
168 | def docs(session: PowerSession):
169 | """Generates the doc and serves it on a local http server. Pass '-- build' to build statically instead."""
170 |
171 | # we need to install self for the doc gallery examples to work
172 | session.install_reqs(phase="docs", phase_reqs=["numpy", "pandas", "scikit-learn", "matplotlib", "seaborn"])
173 | session.run2("pip install . --no-deps")
174 | session.install_reqs(phase="docs", phase_reqs=[
175 | "mkdocs-material", "mkdocs", "pymdown-extensions", "pygments", "mkdocs-gallery", "pillow", "matplotlib",
176 | ])
177 |
178 | if session.posargs:
179 | # use posargs instead of "serve"
180 | session.run2("mkdocs %s" % " ".join(session.posargs))
181 | else:
182 | session.run2("mkdocs serve")
183 |
184 |
185 | @power_session(python=[PY38])
186 | def publish(session: PowerSession):
187 | """Deploy the docs+reports on github pages. Note: this rebuilds the docs"""
188 |
189 | # we need to install self for the doc gallery examples to work
190 | session.install_reqs(phase="publish", phase_reqs=["numpy", "pandas", "scikit-learn", "matplotlib", "seaborn"])
191 | session.run2("pip install . --no-deps")
192 | session.install_reqs(phase="publish", phase_reqs=[
193 | "mkdocs-material", "mkdocs", "pymdown-extensions", "pygments", "mkdocs-gallery", "pillow"
194 | ])
195 |
196 | # possibly rebuild the docs in a static way (mkdocs serve does not build locally)
197 | session.run2("mkdocs build")
198 |
199 | # check that the doc has been generated with coverage
200 | if not Folders.site_reports.exists():
201 | raise ValueError("Test reports have not been built yet. Please run 'nox -s tests-3.7' first")
202 |
203 | # publish the docs
204 | session.run2("mkdocs gh-deploy")
205 |
206 | # publish the coverage - now in github actions only
207 | # session.install_reqs(phase="codecov", phase_reqs=["codecov", "keyring"])
208 | # # keyring set https://app.codecov.io/gh// token
209 | # import keyring # (note that this import is not from the session env but the main nox env)
210 | # codecov_token = keyring.get_password("https://app.codecov.io/gh//>", "token")
211 | # # note: do not use --root nor -f ! otherwise "There was an error processing coverage reports"
212 | # session.run2('codecov -t %s -f %s' % (codecov_token, Folders.coverage_xml))
213 |
214 |
215 | @power_session(python=[PY38])
216 | def release(session: PowerSession):
217 | """Create a release on github corresponding to the latest tag"""
218 |
219 | # Get current tag using setuptools_scm and make sure this is not a dirty/dev one
220 | from setuptools_scm import get_version # (note that this import is not from the session env but the main nox env)
221 | from setuptools_scm.version import guess_next_dev_version
222 | version = []
223 |
224 | def my_scheme(version_):
225 | version.append(version_)
226 | return guess_next_dev_version(version_)
227 | current_tag = get_version(".", version_scheme=my_scheme)
228 |
229 | # create the package
230 | session.install_reqs(phase="setup.py#dist", phase_reqs=["setuptools_scm"])
231 | rm_folder(Folders.dist)
232 | session.run2("python setup.py sdist bdist_wheel")
233 |
234 | if version[0].dirty or not version[0].exact:
235 | raise ValueError("You need to execute this action on a clean tag version with no local changes.")
236 |
237 |     # Did we receive a token through positional arguments ? (nox -s release -- <gh_token>)
238 | if len(session.posargs) == 1:
239 | # Run from within github actions - no need to publish on pypi
240 | gh_token = session.posargs[0]
241 | publish_on_pypi = False
242 |
243 | elif len(session.posargs) == 0:
244 | # Run from local commandline - assume we want to manually publish on PyPi
245 | publish_on_pypi = True
246 |
247 | # keyring set https://docs.github.com/en/rest token
248 | import keyring # (note that this import is not from the session env but the main nox env)
249 | gh_token = keyring.get_password("https://docs.github.com/en/rest", "token")
250 | assert len(gh_token) > 0
251 |
252 | else:
253 | raise ValueError("Only a single positional arg is allowed for now")
254 |
255 | # publish the package on PyPi
256 | if publish_on_pypi:
257 | # keyring set https://upload.pypi.org/legacy/ your-username
258 | # keyring set https://test.pypi.org/legacy/ your-username
259 | session.install_reqs(phase="PyPi", phase_reqs=["twine"])
260 | session.run2("twine upload dist/* -u smarie") # -r testpypi
261 |
262 | # create the github release
263 | session.install_reqs(phase="release", phase_reqs=["click", "PyGithub"])
264 | session.run2("python ci_tools/github_release.py -s {gh_token} "
265 | "--repo-slug {gh_org}/{gh_repo} -cf ./docs/changelog.md "
266 | "-d https://{gh_org}.github.io/{gh_repo}/changelog {tag}"
267 | "".format(gh_token=gh_token, gh_org=gh_org, gh_repo=gh_repo, tag=current_tag))
268 |
269 |
270 | @nox.session(python=False)
271 | def gha_list(session):
272 | """(mandatory arg: ) Prints all sessions available for , for GithubActions."""
273 |
274 | # see https://stackoverflow.com/q/66747359/7262247
275 |
276 | # get the desired base session to generate the list for
277 | if len(session.posargs) != 1:
278 | raise ValueError("This session has a mandatory argument: ")
279 | session_func = globals()[session.posargs[0]]
280 |
281 | # list all sessions for this base session
282 | try:
283 | session_func.parametrize
284 | except AttributeError:
285 | sessions_list = ["%s-%s" % (session_func.__name__, py) for py in session_func.python]
286 | else:
287 | sessions_list = ["%s-%s(%s)" % (session_func.__name__, py, param)
288 | for py, param in product(session_func.python, session_func.parametrize)]
289 |
290 | # print the list so that it can be caught by GHA.
291 | # Note that json.dumps is optional since this is a list of string.
292 | # However it is to remind us that GHA expects a well-formatted json list of strings.
293 | print(dumps(sessions_list))
294 |
295 |
296 | # if __name__ == '__main__':
297 | # # allow this file to be executable for easy debugging in any IDE
298 | # nox.run(globals())
299 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=39.2",
4 | "setuptools_scm",
5 | "wheel"
6 | ]
7 | build-backend = "setuptools.build_meta"
8 |
9 | # pip: no ! does not work in old python 2.7 and not recommended here
10 | # https://setuptools.readthedocs.io/en/latest/userguide/quickstart.html#basic-use
11 |
12 | [tool.conda]
13 | # Declare that the following packages should be installed with conda instead of pip
14 | # Note: this includes packages declared everywhere, here and in setup.cfg
15 | conda_packages = [
16 | "setuptools",
17 | "wheel",
18 | "pip",
19 | "scikit-learn",
20 | "pandas",
21 | "numpy",
22 | "matplotlib",
23 | "seaborn"
24 | ]
25 | # pytest: not with conda ! does not work in old python 2.7 and 3.5
26 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | # See https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
2 | # And this great example : https://github.com/Kinto/kinto/blob/master/setup.cfg
3 | [metadata]
4 | name = m5py
5 | description = An implementation of M5 (Prime) and model trees for scikit-learn.
6 | description_file = README.md
7 | license = BSD 3-Clause
8 | long_description = file: docs/long_description.md
9 | long_description_content_type=text/markdown
10 | keywords = model tree regression M5 prime scikit learn machine learning
11 | author = Sylvain MARIE
12 | maintainer = Sylvain MARIE
13 | url = https://github.com/smarie/python-m5p
14 | # download_url = https://github.com/smarie/python-m5p/tarball/main >> do it in the setup.py to get the right version
15 | classifiers =
16 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
17 | Development Status :: 5 - Production/Stable
18 | Intended Audience :: Developers
19 | Intended Audience :: Science/Research
20 | License :: OSI Approved :: BSD License
21 | Topic :: Software Development :: Libraries :: Python Modules
22 | Topic :: Scientific/Engineering :: Mathematics
23 | Programming Language :: Python
24 | Programming Language :: Python :: 3
25 | # Programming Language :: Python :: 3.5
26 | # Programming Language :: Python :: 3.6
27 | Programming Language :: Python :: 3.7
28 | Programming Language :: Python :: 3.8
29 | Programming Language :: Python :: 3.9
30 |
31 | [options]
32 | # one day these will be able to come from requirement files, see https://github.com/pypa/setuptools/issues/1951. But will it be better ?
33 | setup_requires =
34 | setuptools_scm
35 | pytest-runner
36 | install_requires =
37 | # note: do not use double quotes in these, this triggers a weird bug in PyCharm in debug mode only
38 | scikit-learn
39 | # funcsigs;python_version<'3.3'
40 | # enum34;python_version<'3.4'
41 | tests_require =
42 | pytest
43 |
44 | # test_suite = tests --> no need apparently
45 | #
46 | zip_safe = False
47 | # explicitly setting zip_safe=False to avoid downloading `ply` see https://github.com/smarie/python-getversion/pull/5
48 | # and makes mypy happy see https://mypy.readthedocs.io/en/latest/installed_packages.html
49 | package_dir=
50 | =src
51 | packages = find:
52 | # see [options.packages.find] below
53 | # IMPORTANT: DO NOT set the `include_package_data` flag !! It triggers inclusion of all git-versioned files
54 | # see https://github.com/pypa/setuptools_scm/issues/190#issuecomment-351181286
55 | # include_package_data = True
56 | [options.packages.find]
57 | where=src
58 | exclude =
59 | contrib
60 | docs
61 | *tests*
62 |
63 | [options.package_data]
64 | * = py.typed, *.pyi
65 |
66 |
67 | # Optional dependencies that can be installed with e.g. $ pip install -e .[dev,test]
68 | # [options.extras_require]
69 |
70 | # -------------- Packaging -----------
71 | # [options.entry_points]
72 |
73 | # [egg_info] >> already covered by setuptools_scm
74 |
75 | [bdist_wheel]
76 | # Code is written to work on both Python 2 and Python 3.
77 | universal=1
78 |
79 | # ------------- Others -------------
80 | # In order to be able to execute 'python setup.py test'
81 | # from https://docs.pytest.org/en/latest/goodpractices.html#integrating-with-setuptools-python-setup-py-test-pytest-runner
82 | [aliases]
83 | test = pytest
84 |
85 | # pytest default configuration
86 | [tool:pytest]
87 | testpaths = tests/
88 | addopts =
89 | --verbose
90 | --doctest-modules
91 | --ignore-glob='**/_*.py'
92 |
93 | # we need the 'always' for python 2 tests to work see https://github.com/pytest-dev/pytest/issues/2917
94 | filterwarnings =
95 | always
96 | ; ignore::UserWarning
97 |
98 | # Coverage config
99 | [coverage:run]
100 | branch = True
101 | omit = *tests*
102 | # this is done in nox.py (github actions) or ci_tools/run_tests.sh (travis)
103 | # source = m5py
104 | # command_line = -m pytest --junitxml="reports/pytest_reports/pytest.xml" --html="reports/pytest_reports/pytest.html" -v m5py/tests/
105 |
106 | [coverage:report]
107 | fail_under = 40
108 | show_missing = True
109 | exclude_lines =
110 | # this line for all the python 2 not covered lines
111 | except ImportError:
112 | # we have to repeat this when exclude_lines is set
113 | pragma: no cover
114 |
115 | # Done in nox.py
116 | # [coverage:html]
117 | # directory = site/reports/coverage_reports
118 | # [coverage:xml]
119 | # output = site/reports/coverage_reports/coverage.xml
120 |
121 | [flake8]
122 | max-line-length = 120
123 | extend-ignore = D, E203 # D: Docstring errors, E203: see https://github.com/PyCQA/pycodestyle/issues/373
124 | copyright-check = True
125 | copyright-regexp = ^\#\s+Authors:\s+Sylvain MARIE \n\#\s+\+\sAll\scontributors\sto\s\n\#\n\#\s\sLicense:\s3\-clause\sBSD,\s
126 | exclude =
127 | .git
128 | .github
129 | .nox
130 | .pytest_cache
131 | ci_tools
132 | docs
133 | tests
134 | noxfile.py
135 | setup.py
136 | */_version.py
137 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | To understand this project's build structure
3 |
4 | - This project uses setuptools, so it is declared as the build system in the pyproject.toml file
5 | - We use as much as possible `setup.cfg` to store the information so that it can be read by other tools such as `tox`
6 | and `nox`. So `setup.py` contains **almost nothing** (see below)
7 | This philosophy was found after trying all other possible combinations in other projects :)
8 | A reference project that was inspiring to make this move : https://github.com/Kinto/kinto/blob/master/setup.cfg
9 |
10 | See also:
11 | https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
12 | https://packaging.python.org/en/latest/distributing.html
13 | https://github.com/pypa/sampleproject
14 | """
15 | from setuptools import setup
16 |
17 |
18 | # (1) check required versions (from https://medium.com/@daveshawley/safely-using-setup-cfg-for-metadata-1babbe54c108)
19 | import pkg_resources
20 |
21 | pkg_resources.require("setuptools>=39.2")
22 | pkg_resources.require("setuptools_scm")
23 |
24 |
25 | # (2) Generate download url using git version
26 | from setuptools_scm import get_version # noqa: E402
27 |
28 | URL = "https://github.com/smarie/python-m5p"
29 | DOWNLOAD_URL = URL + "/tarball/" + get_version()
30 |
31 |
32 | # (3) Call setup() with as little args as possible
33 | setup(
34 | download_url=DOWNLOAD_URL,
35 | use_scm_version={
36 | "write_to": "src/m5py/_version.py"
37 | }, # we can't put `use_scm_version` in setup.cfg yet unfortunately
38 | )
39 |
--------------------------------------------------------------------------------
/src/m5py/__init__.py:
--------------------------------------------------------------------------------
1 | # Authors: Sylvain MARIE
2 | # + All contributors to
3 | #
4 | # License: 3-clause BSD,
5 |
6 | from m5py.main import M5Prime
7 | from m5py.export import export_text_m5
8 |
9 | try:
10 | # -- Distribution mode --
11 | # import from _version.py generated by setuptools_scm during release
12 | from ._version import version as __version__
13 | except ImportError:
14 | # -- Source mode --
15 | # use setuptools_scm to get the current version from src using git
16 | from setuptools_scm import get_version as _gv
17 | from os import path as _path
18 | __version__ = _gv(_path.join(_path.dirname(__file__), _path.pardir))
19 |
20 | __all__ = [
21 | "__version__",
22 | # submodules
23 | "main",
24 | "export",
25 | # symbols
26 | "M5Prime",
27 | "export_text_m5"
28 | ]
29 |
--------------------------------------------------------------------------------
/src/m5py/export.py:
--------------------------------------------------------------------------------
1 | from numbers import Integral
2 |
3 | import numpy as np
4 | from io import StringIO
5 |
6 | from sklearn.tree import _tree
7 | from sklearn.tree._criterion import FriedmanMSE
8 |
9 | from m5py.main import is_leaf, ConstantLeafModel, M5Base, check_is_fitted
10 |
11 |
12 | def export_text_m5(decision_tree, out_file=None, max_depth=None,
13 | feature_names=None, class_names=None, label='all',
14 | target_name=None,
15 | # filled=False, leaves_parallel=False,
16 | impurity=True,
17 | node_ids=False, proportion=False,
18 | # rounded=False, rotate=False,
19 | special_characters=False, precision=3, **kwargs):
20 | """Export a decision tree in TXT format.
21 |
22 | Note: this should be merged with ._export.export_text
23 |
24 | Inspired by WEKA and by
25 | >>> from sklearn.tree import export_graphviz
26 |
27 | This function generates a human-readable, text representation of the
28 | decision tree, which is then written into `out_file`.
29 |
30 | The sample counts that are shown are weighted with any sample_weights that
31 | might be present.
32 |
33 | Read more in the :ref:`User Guide `.
34 |
35 | Parameters
36 | ----------
37 | decision_tree : decision tree classifier
38 | The decision tree to be exported to text.
39 |
40 |     out_file : file object or string, optional (default=None)
41 | Handle or name of the output file. If ``None``, the result is
42 | returned as a string.
43 |
44 | max_depth : int, optional (default=None)
45 | The maximum depth of the representation. If None, the tree is fully
46 | generated.
47 |
48 | feature_names : list of strings, optional (default=None)
49 | Names of each of the features.
50 |
51 | class_names : list of strings, bool or None, optional (default=None)
52 | Names of each of the target classes in ascending numerical order.
53 | Only relevant for classification and not supported for multi-output.
54 | If ``True``, shows a symbolic representation of the class name.
55 |
56 | label : {'all', 'root', 'none'}, optional (default='all')
57 | Whether to show informative labels for impurity, etc.
58 | Options include 'all' to show at every node, 'root' to show only at
59 | the top root node, or 'none' to not show at any node.
60 |
61 | target_name : optional string with the target name. If not provided, the
62 | target will not be displayed in the equations
63 |
64 | impurity : bool, optional (default=True)
65 | When set to ``True``, show the impurity at each node.
66 |
67 | node_ids : bool, optional (default=False)
68 | When set to ``True``, show the ID number on each node.
69 |
70 | proportion : bool, optional (default=False)
71 | When set to ``True``, change the display of 'values' and/or 'samples'
72 | to be proportions and percentages respectively.
73 |
74 | special_characters : bool, optional (default=False)
75 | When set to ``False``, ignore special characters for PostScript
76 | compatibility.
77 |
78 | precision : int, optional (default=3)
79 | Number of digits of precision for floating point in the values of
80 | impurity, threshold and value attributes of each node.
81 |
82 | kwargs : other keyword arguments for the linear model printer
83 |
84 | Returns
85 | -------
86 |     text : string
87 |         Text representation of the input tree.
88 | Only returned if ``out_file`` is None.
89 |
90 | Examples
91 | --------
92 | >>> from sklearn.datasets import load_iris
93 | >>> from sklearn import tree
94 |
95 | >>> clf = tree.DecisionTreeClassifier()
96 | >>> iris = load_iris()
97 |
98 | >>> clf = clf.fit(iris.data, iris.target)
99 |     >>> export_text_m5(clf, out_file='tree.txt')           # doctest: +SKIP
100 |
101 | """
102 |
103 | models = []
104 |
105 | def add_model(node_model):
106 | models.append(node_model)
107 | return len(models)
108 |
109 | def node_to_str(tree, node_id, criterion, node_models=None):
110 | """ Generates the node content string """
111 |
112 | # Should labels be shown?
113 | labels = (label == 'root' and node_id == 0) or label == 'all'
114 |
115 | # PostScript compatibility for special characters
116 | if special_characters:
117 |             characters = ['&#35;', '<SUB>', '</SUB>', '&le;', '<br/>', '>']
118 | node_string = '<'
119 | else:
120 | characters = ['#', '[', ']', '<=', '\\n', '']
121 | node_string = ''
122 |
123 | # -- If this node is not a leaf, Write the split decision criteria (x <= y)
124 | leaf = is_leaf(node_id, tree)
125 | if not leaf:
126 | if feature_names is not None:
127 | feature = feature_names[tree.feature[node_id]]
128 | else:
129 | feature = "X%s%s%s" % (characters[1],
130 | tree.feature[node_id], # feature id for the split
131 | characters[2])
132 | node_string += '%s %s %s' % (feature,
133 | characters[3], # <=
134 | round(tree.threshold[node_id], # threshold for the split
135 | precision))
136 | else:
137 | node_string += 'LEAF'
138 |
139 | # Node details - start bracket [
140 | node_string += ' %s' % characters[1]
141 |
142 | # -- Write impurity
143 | if impurity:
144 | if isinstance(criterion, FriedmanMSE):
145 | criterion = "friedman_mse"
146 | elif not isinstance(criterion, str):
147 | criterion = "impurity"
148 | if labels:
149 | node_string += '%s=' % criterion
150 | node_string += str(round(tree.impurity[node_id], precision)) + ', '
151 |
152 | # -- Write node sample count
153 | if labels:
154 | node_string += 'samples='
155 | if proportion:
156 | percent = (100. * tree.n_node_samples[node_id] /
157 | float(tree.n_node_samples[0]))
158 | node_string += str(round(percent, 1)) + '%'
159 | else:
160 | node_string += str(tree.n_node_samples[node_id])
161 |
162 | # Node details - end bracket ]
163 | node_string += '%s' % characters[2]
164 |
165 | # -- Write node class distribution / regression value
166 | if tree.n_outputs == 1:
167 | value = tree.value[node_id][0, :]
168 | else:
169 | value = tree.value[node_id]
170 |
171 | if proportion and tree.n_classes[0] != 1:
172 | # For classification this will show the proportion of samples
173 | value = value / tree.weighted_n_node_samples[node_id]
174 | if tree.n_classes[0] == 1:
175 | # Regression
176 | value_text = np.around(value, precision)
177 | elif proportion:
178 | # Classification
179 | value_text = np.around(value, precision)
180 | elif np.all(np.equal(np.mod(value, 1), 0)):
181 | # Classification without floating-point weights
182 | value_text = value.astype(int)
183 | else:
184 | # Classification with floating-point weights
185 | value_text = np.around(value, precision)
186 |
187 | # Strip whitespace
188 | value_text = str(value_text.astype('S32')).replace("b'", "'")
189 | value_text = value_text.replace("' '", ", ").replace("'", "")
190 | if tree.n_classes[0] == 1 and tree.n_outputs == 1:
191 | value_text = value_text.replace("[", "").replace("]", "")
192 | value_text = value_text.replace("\n ", characters[4])
193 |
194 | if node_models is None:
195 | node_string += ' : '
196 | if labels:
197 | node_string += 'value='
198 | else:
199 | nodemodel = node_models[node_id]
200 | model_err_val = np.around(nodemodel.error, precision)
201 | if leaf:
202 | if isinstance(nodemodel, ConstantLeafModel):
203 | # the model does not contain the value. rely on the value_text computed from tree
204 | value_text = " : %s (err=%s, params=%s)" % (value_text, model_err_val, nodemodel.n_params)
205 | else:
206 | # put the model in the stack, we'll write it later
207 | model_id = add_model(nodemodel)
208 | value_text = " : LM%s (err=%s, params=%s)" % (model_id, model_err_val, nodemodel.n_params)
209 | else:
210 | # replace the value text with error at this node and number of parameters
211 | value_text = " (err=%s, params=%s)" % (model_err_val, nodemodel.n_params)
212 |
213 | node_string += value_text
214 |
215 | # Write node majority class
216 | if (class_names is not None and
217 | tree.n_classes[0] != 1 and
218 | tree.n_outputs == 1):
219 | # Only done for single-output classification trees
220 | node_string += ', '
221 | if labels:
222 | node_string += 'class='
223 | if class_names is not True:
224 | class_name = class_names[np.argmax(value)]
225 | else:
226 | class_name = "y%s%s%s" % (characters[1],
227 | np.argmax(value),
228 | characters[2])
229 | node_string += class_name
230 |
231 | return node_string + characters[5]
232 |
233 | def recurse(tree, node_id, criterion, parent=None, depth=0, node_models=None):
234 | if node_id == _tree.TREE_LEAF:
235 | raise ValueError("Invalid node_id %s" % _tree.TREE_LEAF)
236 |
237 | # Add node with description
238 | if max_depth is None or depth <= max_depth:
239 | indent_str = ("| " * depth)
240 | if node_ids:
241 | out_file.write('%d| %s%s\n' % (node_id, indent_str, node_to_str(tree, node_id, criterion,
242 | node_models=node_models)))
243 | else:
244 | out_file.write('%s%s\n' % (indent_str, node_to_str(tree, node_id, criterion, node_models=node_models)))
245 |
246 | # Recurse on Children if needed
247 | left_child = tree.children_left[node_id]
248 | right_child = tree.children_right[node_id]
249 |
250 | # if not is_leaf(node_id, tree)
251 | if left_child != _tree.TREE_LEAF:
252 | # that means that node_id is not a leaf (see is_leaf() below.): recurse on children
253 | recurse(tree, left_child, criterion=criterion, parent=node_id, depth=depth + 1,
254 | node_models=node_models)
255 | recurse(tree, right_child, criterion=criterion, parent=node_id, depth=depth + 1,
256 | node_models=node_models)
257 |
258 | else:
259 | ranks['leaves'].append(str(node_id))
260 |             out_file.write('%d| (...)\n' % node_id)
261 |
262 | def write_models(models):
263 | for i, model in enumerate(models):
264 | out_file.write("LM%s: %s\n" % (i + 1, model.to_text(feature_names=feature_names, precision=precision,
265 | target_name=target_name, **kwargs)))
266 |
267 | # Main
268 | check_is_fitted(decision_tree, 'tree_')
269 | own_file = False
270 | return_string = False
271 | try:
272 | if isinstance(out_file, str):
273 | out_file = open(out_file, "w", encoding="utf-8")
274 | own_file = True
275 |
276 | if out_file is None:
277 | return_string = True
278 | out_file = StringIO()
279 |
280 | if isinstance(precision, Integral):
281 | if precision < 0:
282 | raise ValueError("'precision' should be greater or equal to 0."
283 | " Got {} instead.".format(precision))
284 | else:
285 | raise ValueError("'precision' should be an integer. Got {}"
286 | " instead.".format(type(precision)))
287 |
288 | # Check length of feature_names before getting into the tree node
289 | # Raise error if length of feature_names does not match
290 | # n_features_ in the decision_tree
291 | if feature_names is not None:
292 | if len(feature_names) != decision_tree.n_features_in_:
293 | raise ValueError("Length of feature_names, %d "
294 | "does not match number of features, %d"
295 | % (len(feature_names),
296 | decision_tree.n_features_in_))
297 |
298 | # The depth of each node for plotting with 'leaf' option TODO probably remove
299 | ranks = {'leaves': []}
300 |
301 | # Tree title
302 | if isinstance(decision_tree, M5Base):
303 | if hasattr(decision_tree, 'installed_smoothing_constant'):
304 | details = "pre-smoothed with constant %s" % decision_tree.installed_smoothing_constant
305 | else:
306 | if decision_tree.use_smoothing == 'installed':
307 | details = "under construction - not pre-smoothed yet"
308 | else:
309 | details = "unsmoothed - but this can be done at prediction time"
310 |
311 |             # add more info for M5P
312 | out_file.write('%s (%s):\n' % (type(decision_tree).__name__, details))
313 | else:
314 | # generic title
315 | out_file.write('%s :\n' % type(decision_tree).__name__)
316 |
317 | # some space for readability
318 | out_file.write('\n')
319 |
320 | # Now recurse the tree and add node & edge attributes
321 | if isinstance(decision_tree, _tree.Tree):
322 | recurse(decision_tree, 0, criterion="impurity")
323 | elif isinstance(decision_tree, M5Base) and hasattr(decision_tree, 'node_models'):
324 | recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion, node_models=decision_tree.node_models)
325 |
326 | # extra step: write all models
327 | out_file.write("\n")
328 | write_models(models)
329 | else:
330 | recurse(decision_tree.tree_, 0, criterion=decision_tree.criterion)
331 |
332 | # Return the text if needed
333 | if return_string:
334 | return out_file.getvalue()
335 |
336 | finally:
337 | if own_file:
338 | out_file.close()
339 |
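# Illustrative usage sketch (not part of the library API). It assumes the exporter
# above is exposed as `export_text_m5` (the name imported by m5py.main); the data
# and feature names below are hypothetical.
def _example_text_export():  # pragma: no cover - illustrative only
    import numpy as np
    from m5py.main import M5Prime

    rng = np.random.RandomState(0)
    X = rng.uniform(size=(200, 2))
    # piecewise-linear target, the typical use case for an M5 model tree
    y = np.where(X[:, 0] > 0.5, 3.0 * X[:, 1], 1.0 - X[:, 1]) + 0.01 * rng.normal(size=200)

    model = M5Prime().fit(X, y)

    # with out_file=None the text representation is returned as a string
    return export_text_m5(model, out_file=None, feature_names=["x0", "x1"])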
--------------------------------------------------------------------------------
/src/m5py/linreg_utils.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod, ABCMeta
2 |
3 | import numpy as np
4 |
5 | from sklearn.linear_model import LinearRegression
6 | from sklearn.linear_model._base import LinearModel
7 | from sklearn.preprocessing import StandardScaler
8 | from sklearn.utils.extmath import safe_sparse_dot
9 |
10 |
11 | def linreg_model_to_text(model, feature_names=None, target_name=None,
12 | precision=3, line_breaks=False):
13 | """
14 | Converts a linear regression model to a text representation.
15 |
16 |     :param model: a fitted linear model exposing `coef_` and `intercept_`
17 |     :param feature_names: optional list of feature names; defaults to "X[i]"
18 |     :param target_name: optional target name; if provided it is prepended as "<target_name> = "
19 |     :param precision: number of digits used for the coefficients
20 |     :param line_breaks: if True, each term in the sum shows in a different line
21 |     :return: a string representation of the linear model
22 | """
23 | bits = []
24 |
25 | # Template for numbers: we want scientific notation with a given precision
26 | nb_tpl = "%%0.%se" % precision
27 |
28 | # Handle multi-dimensional y (its second dim must be size 1, though)
29 | if len(model.coef_.shape) > 1:
30 | assert model.coef_.shape[0] == 1
31 | assert len(model.coef_.shape) == 2
32 | coefs = np.ravel(model.coef_) # convert to 1D
33 | assert len(model.intercept_) == 1
34 | intercept = model.intercept_.item() # extract scalar
35 | else:
36 | coefs = model.coef_ # a 1D array
37 | intercept = model.intercept_ # a scalar already
38 |
39 | # First all coefs * drivers
40 | for i, c in enumerate(coefs):
41 | var_name = ("X[%s]" % i) if feature_names is None else feature_names[i]
42 |
43 | if i == 0:
44 | # first term
45 | if c < 1:
46 | # use scientific notation
47 | product_text = (nb_tpl + " * %s") % (c, var_name)
48 | else:
49 | # use standard notation with precision
50 | c = np.round(c, precision)
51 | product_text = "%s * %s" % (c, var_name)
52 | else:
53 | # all the other terms: the sign should appear
54 | lb = '\n' if line_breaks else ""
55 | coef_abs = np.abs(c)
56 | coef_sign = '+' if np.sign(c) > 0 else '-'
57 | if coef_abs < 1:
58 | # use scientific notation
59 | product_text = (("%s%s " + nb_tpl + " * %s")
60 | % (lb, coef_sign, coef_abs, var_name))
61 | else:
62 | # use standard notation with precision
63 | coef_abs = np.round(coef_abs, precision)
64 | product_text = ("%s%s %s * %s"
65 | % (lb, coef_sign, coef_abs, var_name))
66 |
67 | bits.append(product_text)
68 |
69 | # Finally the intercept
70 | if len(bits) == 0:
71 | # intercept is the only term in the sum
72 | if intercept < 1:
73 | # use scientific notation only for small numbers (otherwise 12
74 | # would read 1.2e1 ... not friendly)
75 | constant_text = nb_tpl % intercept
76 | else:
77 | # use standard notation with precision
78 | i = np.round(intercept, precision)
79 | constant_text = "%s" % i
80 | else:
81 | # there are other terms in the sum: the sign should appear
82 | lb = '\n' if line_breaks else ""
83 | coef_abs = np.abs(intercept)
84 | coef_sign = '+' if np.sign(intercept) > 0 else '-'
85 | if coef_abs < 1:
86 | # use scientific notation
87 | constant_text = ("%s%s " + nb_tpl) % (lb, coef_sign, coef_abs)
88 | else:
89 | # use standard notation with precision
90 | coef_abs = np.round(coef_abs, precision)
91 | constant_text = "%s%s %s" % (lb, coef_sign, coef_abs)
92 |
93 | bits.append(constant_text)
94 |
95 | txt = " ".join(bits)
96 | if target_name is not None:
97 | txt = target_name + " = " + txt
98 |
99 | return txt
100 |
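# Illustrative sketch (not part of the API) of the text produced by
# `linreg_model_to_text`; the data and names are hypothetical. Small coefficients
# are printed in scientific notation, larger ones are simply rounded.
def _linreg_model_to_text_example():  # pragma: no cover - illustrative only
    X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 3.0]])
    y = 2.0 * X[:, 0] + 0.5 * X[:, 1] + 3.0
    reg = LinearRegression().fit(X, y)
    # expected to look like: "y = 2.0 * a + 5.000e-01 * b + 3.0"
    return linreg_model_to_text(reg, feature_names=["a", "b"], target_name="y")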
101 |
102 | class DeNormalizableMixIn(metaclass=ABCMeta):
103 | """
104 |     An abstract class that models supporting de-normalization should implement.
105 | """
106 | __slots__ = ()
107 |
108 | @abstractmethod
109 | def denormalize(self,
110 | x_scaler: StandardScaler = None,
111 | y_scaler: StandardScaler = None
112 | ):
113 | """
114 | Denormalizes the model, knowing that it was fit with the given
115 | x_scaler and y_scaler
116 | """
117 |
118 |
119 | class DeNormalizableLinearModelMixIn(DeNormalizableMixIn, LinearModel):
120 | """
121 | A mix-in class to add 'denormalization' capability to a linear model
122 | """
123 | def denormalize(self,
124 | x_scaler: StandardScaler = None,
125 | y_scaler: StandardScaler = None
126 | ):
127 | """
128 | De-normalizes the linear regression model.
129 | Before this function is executed,
130 | (y-y_mean)/y_scale = self.coef_.T (x-x_mean)/x_scale + self.intercept_
131 | so
132 | (y-y_mean)/y_scale = (self.coef_/x_scale).T x + (self.intercept_ - self.coef_.T x_mean/x_scale)
133 | that is
134 | (y-y_mean)/y_scale = new_coef.T x + new_intercept
135 | where
136 | * new_coef = (self.coef_/x_scale)
137 |         * new_intercept = self.intercept_ - self.coef_.T x_mean/x_scale
138 |
139 | Then going back to y
140 | y = (new_coef * y_scale).T x + (new_intercept * y_scale + y_mean)
141 |
142 | :param self:
143 | :param x_scaler:
144 | :param y_scaler:
145 | :return:
146 | """
147 | # First save old coefficients
148 | self.normalized_coef_ = self.coef_
149 | self.normalized_intercept_ = self.intercept_
150 |
151 | # denormalize coefficients to take into account the x normalization
152 | if x_scaler is not None:
153 | new_coef = self.coef_ / x_scaler.scale_
154 | new_intercept = (
155 | self.intercept_ -
156 | safe_sparse_dot(x_scaler.mean_, new_coef.T,
157 | dense_output=True)
158 | )
159 |
160 | self.coef_ = new_coef
161 | self.intercept_ = new_intercept
162 |
163 | # denormalize them further to take into account the y normalization
164 | if y_scaler is not None:
165 | new_coef = self.coef_ * y_scaler.scale_
166 | new_intercept = y_scaler.inverse_transform(
167 | np.atleast_1d(self.intercept_)
168 | )
169 | if np.isscalar(self.intercept_):
170 | new_intercept = new_intercept[0]
171 | self.coef_ = new_coef
172 | self.intercept_ = new_intercept
173 |
174 |
175 | class DeNormalizableLinearRegression(LinearRegression,
176 | DeNormalizableLinearModelMixIn):
177 | """
178 |     A de-normalizable linear regression. After `denormalize`, the old normalized
179 |     coefficients remain available through the `feature_importances_` property.
180 | """
181 | @property
182 | def feature_importances_(self):
183 | if hasattr(self, '_feature_importances_'):
184 | return self._feature_importances_
185 | else:
186 | return self.coef_
187 |
188 | def denormalize(self,
189 | x_scaler: StandardScaler = None,
190 | y_scaler: StandardScaler = None):
191 | """
192 | Denormalizes the model, and saves a copy of the old normalized
193 |         coefficients in self._feature_importances_.
194 |
195 | :param x_scaler:
196 | :param y_scaler:
197 | :return:
198 | """
199 | self._feature_importances_ = self.coef_
200 | super(DeNormalizableLinearRegression, self).denormalize(x_scaler,
201 | y_scaler)
202 |
203 |
204 | # For all other it should work too
205 | # class DeNormalizableLasso(Lasso, DeNormalizableLinearModelMixIn):
206 | # pass
207 |
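# Illustrative sketch (not part of the API): a numerical sanity-check of the
# de-normalization algebra documented above, using an x scaler only. The data
# below is hypothetical.
def _denormalize_example():  # pragma: no cover - illustrative only
    rng = np.random.RandomState(0)
    X = rng.normal(loc=5.0, scale=2.0, size=(100, 3))
    y = X @ np.array([1.0, -2.0, 0.5]) + 3.0

    x_scaler = StandardScaler().fit(X)
    reg = DeNormalizableLinearRegression().fit(x_scaler.transform(X), y)
    preds_on_scaled = reg.predict(x_scaler.transform(X))

    # after de-normalization, the same predictions are obtained directly on raw X
    reg.denormalize(x_scaler=x_scaler)
    preds_on_raw = reg.predict(X)

    assert np.allclose(preds_on_scaled, preds_on_raw)
    return reg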
--------------------------------------------------------------------------------
/src/m5py/main.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from copy import copy
3 | from logging import getLogger
4 | from warnings import warn
5 |
6 | import numpy as np
7 |
8 | from scipy.sparse import issparse
9 |
10 | from sklearn import clone
11 | from sklearn.base import RegressorMixin, is_classifier
12 | from sklearn.linear_model import LinearRegression
13 | from sklearn.metrics import mean_squared_error
14 | from sklearn.preprocessing import StandardScaler
15 | from sklearn.tree import BaseDecisionTree, _tree
16 | from sklearn.tree._classes import DTYPE
17 | from sklearn.tree._tree import DOUBLE
18 | from sklearn.utils import check_array
19 | from sklearn.utils.validation import check_is_fitted
20 | from sklearn import __version__ as sklearn_version
21 |
22 | from m5py.linreg_utils import linreg_model_to_text, DeNormalizableMixIn, DeNormalizableLinearRegression
23 |
24 | from packaging.version import Version
25 |
26 | SKLEARN_VERSION = Version(sklearn_version)
27 | SKLEARN13_OR_GREATER = SKLEARN_VERSION >= Version("1.3.0")
28 |
29 |
30 | __all__ = ["M5Base", "M5Prime"]
31 |
32 |
33 | _SmoothingDetails = namedtuple("_SmoothingDetails", ("A", "B", "C"))
34 | # Internal structure to contain the recursive smoothed coefficients details in the smoothing algorithm
35 |
36 |
37 | logger = getLogger("m5p")
38 |
39 |
40 | class M5Base(BaseDecisionTree):
41 | """
42 | M5Base. Implements base routines for generating M5 PredictionModel trees and rules.
43 |
44 | The original algorithm M5 was invented by Quinlan:
45 |
46 | - Quinlan J. R. (1992). Learning with continuous classes. Proceedings of the
47 | Australian Joint Conference on Artificial Intelligence. 343--348. World
48 | Scientific, Singapore.
49 |
50 | Yong Wang and Ian Witten made improvements and created M5':
51 |
52 | - Wang, Y and Witten, I. H. (1997). Induction of model trees for predicting
53 | continuous classes. Proceedings of the poster papers of the European
54 | Conference on Machine Learning. University of Economics, Faculty of
55 | Informatics and Statistics, Prague.
56 |
57 | Pruning and Smoothing can be activated and deactivated on top of the base
58 | model. TODO check if 'rules' should be supported too
59 |
60 | Inspired by Weka (https://github.com/bnjmn/weka) M5Base class, from Mark Hall
61 |
62 | Attributes
63 | ----------
64 | criterion : str, default="friedman_mse"
65 |         M5 suggests using the standard deviation (RMSE impurity) instead of the
66 |         variance (MSE impurity) or absolute deviation (MAE impurity) used in CART.
67 |         According to the M5' paper, all three criteria are equivalent.
68 | splitter : default="best"
69 | M5 suggests to take the feature leading to the best gain in criterion.
70 | max_depth : default None
71 | This is not used in the original M5 article, hence default=None.
72 | min_samples_split : int, default=4
73 | M5 suggests a value of 4 so that each leaf will have at least 2 samples.
74 | min_samples_leaf : int, default=2
75 |         M5' suggests adding this explicitly to avoid zero variance in leaves and
76 |         n <= p for constant models.
77 | min_weight_fraction_leaf : float, default=0.0
78 | This is not used in the original M5 article, hence default=0.0.
79 | TODO this would actually maybe be better than min_sample_leaf ?
80 | max_features : default None
81 | This is not used in the original M5 article, hence default is None.
82 | max_leaf_nodes : int, default None
83 | This is not used in the original M5 article, hence default is None.
84 | min_impurity_decrease : float, default=0.0
85 |         This is not used in the original M5 article, hence the default of 0.0.
86 | class_weight : default None
87 | This is not used (?)
88 | leaf_model : RegressorMixin
89 | The regression model used in the leaves. This instance will be cloned
90 | for each leaf.
91 | use_pruning : bool
92 | If False, pruning will be disabled.
93 | use_smoothing : bool or str {'installed', 'on_prediction'}, default None
94 |         None and True mean 'installed' by default, except if smoothing_constant
95 |         or smoothing_constant_ratio is 0.0.
96 | smoothing_constant: float, default None
97 | The smoothing constant k defined in the original M5 article, used as the
98 | weight for each parent model in the recursive weighting process.
99 | If None, the default value from the paper (k=15) will be used.
100 | smoothing_constant_ratio: float, default None
101 | An alternate way to define the smoothing constant, as a ratio of the
102 | total number of training samples. The resulting smoothing constant will
103 | be smoothing_constant_ratio * n where n is the number of samples.
104 | Note that this may not be an integer.
105 | debug_prints: bool, default False
106 | A boolean to enable debug prints
107 | ccp_alpha: float, default 0.0
108 | TODO is this relevant ? does that conflict with "use_pruning" ?
109 | random_state : None, int, or RandomState, default=None
110 | See `RegressionTree.random_state`.
111 | """
112 |
113 | def __init__(
114 | self,
115 | criterion="friedman_mse",
116 | splitter="best",
117 | max_depth=None,
118 | min_samples_split=4,
119 | min_samples_leaf=2,
120 | min_weight_fraction_leaf=0.0,
121 | max_features=None,
122 | max_leaf_nodes=None,
123 | min_impurity_decrease=0.0,
124 | class_weight=None,
125 | leaf_model=None,
126 | use_pruning=True,
127 | use_smoothing=None,
128 | smoothing_constant=None,
129 | smoothing_constant_ratio=None,
130 | debug_prints=False,
131 | ccp_alpha=0.0,
132 | random_state=None,
133 | ):
134 |
135 | # TODO the paper suggests to do this with 5% of total std
136 | # min_impurity_split = min_impurity_split_as_initial_ratio * initial_impurity
137 |
138 | super(M5Base, self).__init__(
139 | criterion=criterion,
140 | splitter=splitter,
141 | max_depth=max_depth,
142 | min_samples_split=min_samples_split,
143 | min_samples_leaf=min_samples_leaf,
144 | min_weight_fraction_leaf=min_weight_fraction_leaf,
145 | max_features=max_features,
146 | max_leaf_nodes=max_leaf_nodes,
147 | min_impurity_decrease=min_impurity_decrease,
148 | random_state=random_state,
149 | class_weight=class_weight,
150 | ccp_alpha=ccp_alpha,
151 | )
152 |
153 | # warning : if the field names are different from constructor params,
154 | # then clone(self) will not work.
155 | if leaf_model is None:
156 | # to handle case when the model is learnt on normalized data and
157 | # we wish to be able to read the model equations.
158 | leaf_model = DeNormalizableLinearRegression()
159 | self.leaf_model = leaf_model
160 |
161 | # smoothing related
162 | if smoothing_constant_ratio is not None and smoothing_constant is not None:
163 | raise ValueError("Only one of `smoothing_constant` and `smoothing_constant_ratio` should be provided")
164 | elif (smoothing_constant_ratio == 0.0 or smoothing_constant == 0) and (
165 | use_smoothing is True or use_smoothing == "installed"
166 | ):
167 | raise ValueError(
168 | "`use_smoothing` was explicitly enabled with "
169 | "pre-installed models, while smoothing "
170 | "constant/ratio are explicitly set to zero"
171 | )
172 |
173 | self.use_pruning = use_pruning
174 | self.use_smoothing = use_smoothing
175 | self.smoothing_constant = smoothing_constant
176 | self.smoothing_constant_ratio = smoothing_constant_ratio
177 |
178 | self.debug_prints = debug_prints
179 |
180 | def as_pretty_text(self, **kwargs):
181 | """
182 |         Returns a multi-line representation of this decision tree, using
183 |         `export_text_m5`.
184 |
185 | :return: a multi-line string representing this decision tree
186 | """
187 | from m5py.export import export_text_m5
188 | return export_text_m5(self, out_file=None, **kwargs)
189 |
190 | def fit(self, X, y: np.ndarray, sample_weight=None, check_input=True, X_idx_sorted="deprecated"):
191 | """Fit a M5Prime model.
192 |
193 | Parameters
194 | ----------
195 | X : numpy.ndarray
196 | y : numpy.ndarray
197 |         sample_weight : optional array of sample weights
198 |         check_input : bool, set to False to bypass input validation
199 |         X_idx_sorted : deprecated, kept for backwards compatibility
200 |
201 | Returns
202 | -------
203 |         self : the fitted estimator
204 | """
205 | # (0) smoothing default values behaviour
206 | if self.smoothing_constant_ratio == 0.0 or self.smoothing_constant == 0:
207 | self.use_smoothing = False
208 | elif self.use_smoothing is None or self.use_smoothing is True:
209 | # default behaviour
210 | if isinstance(self.leaf_model, LinearRegression):
211 | self.use_smoothing = "installed"
212 | else:
213 | self.use_smoothing = "on_prediction"
214 |
215 | # Finally make sure we are ok now
216 | if self.use_smoothing not in [False, np.bool_(False), "installed", "on_prediction"]:
217 | raise ValueError("use_smoothing: Unexpected value: %s, please report it as issue." % self.use_smoothing)
218 |
219 |
220 | # (1) Build the initial tree as usual
221 |
222 | # Get the correct fit method name based on the sklearn version used
223 | fit_method_name = "_fit" if SKLEARN13_OR_GREATER else "fit"
224 |
225 | fit_method = getattr(super(M5Base, self), fit_method_name)
226 | fit_method(X, y, sample_weight=sample_weight, check_input=check_input)
227 |
228 |
229 | if self.debug_prints:
230 | logger.debug("(debug_prints) Initial tree:")
231 | logger.debug(self.as_pretty_text(node_ids=True))
232 |
233 | # (1b) prune initial tree to take into account min impurity in splits
234 | prune_on_min_impurity(self.tree_)
235 |
236 | if self.debug_prints:
237 | logger.debug("(debug_prints) Postprocessed tree:")
238 | logger.debug(self.as_pretty_text(node_ids=True))
239 |
240 | # (2) Now prune tree and replace pruned branches with linear models
241 | # -- unfortunately we have to re-do this input validation
242 | # step, because it also converts the input to float32.
243 | if check_input:
244 | X = check_array(X, dtype=DTYPE, accept_sparse="csc")
245 | if issparse(X):
246 | X.sort_indices()
247 |
248 | if X.indices.dtype != np.intc or X.indptr.dtype != np.intc:
249 | raise ValueError("No support for np.int64 index based " "sparse matrices")
250 |
251 | # -- initialise the structure to contain the leaves and node models
252 | self.node_models = np.empty((self.tree_.node_count,), dtype=object)
253 |
254 | # -- Pruning requires to know the global deviation of the target
255 | global_std_dev = np.nanstd(y)
256 | # global_abs_dev = np.nanmean(np.abs(y))
257 |
258 | # -- Pruning requires to know the samples that reached each node.
259 | # From http://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
260 | # Retrieve the decision path of each sample.
261 | samples_to_nodes = self.decision_path(X)
262 | # * row i is to see the nodes (non-empty j) in which sample i appears.
263 | # sparse row-first (CSR) format is OK
264 | # * column j is to see the samples (non-empty i) that fall into that
265 | # node. To do that, we need to make it CSC
266 | nodes_to_samples = samples_to_nodes.tocsc()
267 |
268 | # -- execute the pruning
269 | # self.features_usage is a dict feature_idx -> nb times used, only
270 | # for used features.
271 | self.features_usage = build_models_and_get_pruning_info(
272 | self.tree_,
273 | X,
274 | y,
275 | nodes_to_samples,
276 | self.leaf_model,
277 | self.node_models,
278 | global_std_dev,
279 | use_pruning=self.use_pruning,
280 | )
281 |
282 | # -- cleanup to compress inner structures: only keep non-pruned ones
283 | self._cleanup_tree()
284 |
285 | if self.debug_prints:
286 | logger.debug("(debug_prints) Pruned tree:")
287 | logger.debug(self.as_pretty_text(node_ids=True))
288 |
289 | if self.use_smoothing == "installed":
290 | # Retrieve the NEW decision path of each sample.
291 | samples_to_nodes = self.decision_path(X)
292 | # * row i is to see the nodes (non-empty j) in which sample i
293 | # appears. sparse row-first (CSR) format is OK
294 | # * column j is to see the samples (non-empty i) that fall into
295 | # that node. To do that, we need to make it CSC
296 | nodes_to_samples = samples_to_nodes.tocsc()
297 |
298 | # default behaviour for smoothing constant and ratio
299 | smoothing_constant = self._get_smoothing_constant_to_use(X)
300 |
301 | self.install_smoothing(X, y, nodes_to_samples, smoothing_constant=smoothing_constant)
302 |
303 | if self.debug_prints:
304 | logger.debug("(debug_prints) Pruned and smoothed tree:")
305 | logger.debug(self.as_pretty_text(node_ids=True))
306 |
307 | return self
308 |
309 | def _get_smoothing_constant_to_use(self, X):
310 | """
311 | Returns the smoothing_constant to use for smoothing, based on current
312 | settings and X data
313 | """
314 | if self.smoothing_constant_ratio is not None:
315 | nb_training_samples = X.shape[0]
316 | smoothing_cstt = self.smoothing_constant_ratio * nb_training_samples
317 | if smoothing_cstt < 15:
318 | warn(
319 | "smoothing constant ratio %s is leading to an extremely "
320 | "small smoothing constant %s because nb training samples"
321 | " is %s. Clipping to 15." % (self.smoothing_constant_ratio, smoothing_cstt, nb_training_samples)
322 | )
323 | smoothing_cstt = 15
324 | else:
325 | smoothing_cstt = self.smoothing_constant
326 | if smoothing_cstt is None:
327 | smoothing_cstt = 15 # default from the original Quinlan paper
328 |
329 | return smoothing_cstt
330 |
331 | def _cleanup_tree(self):
332 | """
333 | Reduces the size of this object by removing from internal structures
334 | all items that are not used any more (all leaves that have been pruned)
335 | """
336 | old_tree = self.tree_
337 | old_node_modls = self.node_models
338 |
339 | # Get all information to create a copy of the inner tree.
340 | # Note: _tree.copy() is gone so we use the pickle way
341 | # --- Base info: nb features, nb outputs, output classes
342 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L631
343 | # [1] = (self.n_features, sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), self.n_outputs)
344 | # So these remain, we are just interested in changing the node-related arrays
345 | new_tree = _tree.Tree(*old_tree.__reduce__()[1])
346 |
347 | # --- Node info
348 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L637
349 | dct = old_tree.__getstate__().copy()
350 |
351 | # cleanup: only keep the nodes that are not undefined.
352 | # note: this is identical to
353 | # valid_nodes_indices = dct["nodes"]['left_child'] != TREE_UNDEFINED
354 | valid_nodes_indices = old_tree.children_left != _tree.TREE_UNDEFINED
355 | # valid_nodes_indices2 = old_tree.children_right != TREE_UNDEFINED
356 | # assert np.all(valid_nodes_indices == valid_nodes_indices2)
357 | new_node_count = sum(valid_nodes_indices)
358 |
359 | # create empty new structures
360 | n_shape = (new_node_count, *dct["nodes"].shape[1:])
361 | new_nodes = _empty_contig_ar(n_shape, dtype=dct["nodes"].dtype)
362 | v_shape = (new_node_count, *dct["values"].shape[1:])
363 | new_values = _empty_contig_ar(v_shape, dtype=dct["values"].dtype)
364 | m_shape = (new_node_count, *old_node_modls.shape[1:])
365 | new_node_models = _empty_contig_ar(m_shape, dtype=old_node_modls.dtype)
366 |
367 | # Fill structures while reindexing the tree and remembering the depth
368 | global next_free_id
369 | next_free_id = 0
370 |
371 | def _compress(old_node_id):
372 | """
373 |
374 | Parameters
375 | ----------
376 | old_node_id
377 |
378 | Returns
379 | -------
380 | the depth and new indices of left and right children
381 |
382 | """
383 | global next_free_id
384 | new_node_id = next_free_id
385 | next_free_id += 1
386 |
387 | # use the old tree to walk
388 | old_node = dct["nodes"][old_node_id]
389 | left_id = old_node["left_child"]
390 | right_id = old_node["right_child"]
391 |
392 | # Create the new node with a copy of the old
393 | new_nodes[new_node_id] = old_node # this is an entire row so it is probably copied already by doing so.
394 | new_values[new_node_id] = dct["values"][old_node_id]
395 | new_node_models[new_node_id] = copy(old_node_modls[old_node_id])
396 |
397 | if left_id == _tree.TREE_LEAF:
398 | # ... and right_id == _tree.TREE_LEAF
399 | # => node_id is a leaf. Nothing to do
400 | return 1, new_node_id
401 | else:
402 |
403 | left_depth, new_id_left = _compress(left_id)
404 | right_depth, new_id_right = _compress(right_id)
405 |
406 | # store the new indices
407 | new_nodes[new_node_id]["left_child"] = new_id_left
408 | new_nodes[new_node_id]["right_child"] = new_id_right
409 |
410 | return 1 + max(left_depth, right_depth), new_node_id
411 |
412 | # complete definition of the new tree
413 | dct["max_depth"] = _compress(0)[0] - 1 # root node has depth 0, not 1
414 | dct["node_count"] = new_node_count # new_nodes.shape[0]
415 | dct["nodes"] = new_nodes
416 | dct["values"] = new_values
417 | new_tree.__setstate__(dct)
418 |
419 | # Fix an issue on sklearn 0.17.1: setstate was not updating max_depth
420 | # See https://github.com/scikit-learn/scikit-learn/blob/0.17.1/sklearn/tree/_tree.pyx#L623
421 | new_tree.max_depth = dct["max_depth"]
422 |
423 | # update self fields
424 | self.tree_ = new_tree
425 | self.node_models = new_node_models
426 |
427 | def install_smoothing(self, X_train_all, y_train_all, nodes_to_samples, smoothing_constant):
428 | """
429 | Executes the smoothing procedure described in the M5 and M5P paper,
430 | "once for all". This means that all models are modified so that after
431 | this method has completed, each model in the tree is already a smoothed
432 | model.
433 |
434 | This has pros (prediction speed) and cons (the model equations are
435 | harder to read - lots of redundancy)
436 |
437 | It can only be done if leaf models are instances of `LinearRegression`
438 | """
439 | # check validity: leaf models have to support pre-computed smoothing
440 | if not isinstance(self.leaf_model, LinearRegression):
441 | raise TypeError(
442 | "`install_smoothing` is only available if leaf "
443 | "models are instances of `LinearRegression` or a "
444 | "subclass"
445 | )
446 |
447 | # Select the Error metric to compute model errors (used to compare
448 | # node with subtree for pruning)
449 | # TODO in Weka they use RMSE, but in papers they use MAE. this should be a parameter.
450 | # TODO Shall we go further and store the residuals, or a bunch of metrics? not sure
451 | err_metric = root_mean_squared_error # mean_absolute_error
452 |
453 | # --- Node info
454 | # TODO we should probably not do this once in each method, but once or give access directly (no copy)
455 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L637
456 | dct = self.tree_.__getstate__().copy()
457 | old_node_models = self.node_models
458 |
459 | m_shape = old_node_models.shape
460 | new_node_models = _empty_contig_ar(m_shape, dtype=old_node_models.dtype)
461 |
462 | def smooth__(
463 | coefs_,
464 | n_samples,
465 | features=None,
466 | parent=None, # type: _SmoothingDetails
467 | parent_features=None,
468 |             k=smoothing_constant,
469 | ):
470 | # type: (...) -> _SmoothingDetails
471 | """
472 | Smoothes the model coefficients or intercept `coefs_` (an array or
473 | a scalar), using the smoothing results at parent node.
474 |
475 | At each node we keep in memory three results A, B, C.
476 | - the new coef at each node is A + B
477 | - the recursion equations are
478 | - B(n) = (k / (n_samples(n) + k)) * A(n-1) + B(n-1)
479 | - C(n) = (n_samples(n) / (n_samples(n) + k)) * C(n-1)
480 | - A(n) = coef(n)*C(n)
481 | """
482 | if parent is None:
483 | # A0 = coef(0), C0 = 1, B0 = 0
484 | if features is None:
485 | # single scalar value
486 | return _SmoothingDetails(A=coefs_, B=0, C=1)
487 | else:
488 | # vector
489 | if len(coefs_) == 0:
490 | coefs_ = np.asarray(coefs_)
491 | if coefs_.shape[0] != len(features):
492 | raise ValueError("nb features does not match the nb " "of coefficients")
493 | return _SmoothingDetails(
494 | A=coefs_,
495 | B=np.zeros(coefs_.shape, dtype=coefs_.dtype),
496 | C=np.ones(coefs_.shape, dtype=coefs_.dtype),
497 | )
498 | else:
499 | # B(n) = k/(n_samples(n)+k)*A(n-1) + B(n-1)
500 | # C(n) = (n_samples(n) / (n_samples(n) + k)) * C(n-1)
501 | Bn = (k / (n_samples + k)) * parent.A + parent.B
502 | Cn = (n_samples / (n_samples + k)) * parent.C
503 |
504 | # A(n) = coef(n) * C(n)
505 | if features is None:
506 | # single scalar value: easy
507 | An = coefs_ * Cn
508 | return _SmoothingDetails(A=An, B=Bn, C=Cn)
509 | else:
510 | # vector of coefs: we have to 'expand' the coefs array
511 | # because the coefs at this node apply for (features)
512 | # while coefs at parent node apply for (parent_features)
513 | An = np.zeros(Cn.shape, dtype=Cn.dtype)
514 | parent_features = np.array(parent_features)
515 | features = np.array(features)
516 |
517 | # Thanks https://stackoverflow.com/a/8251757/7262247 !
518 | index = np.argsort(parent_features)
519 | sorted_parents = parent_features[index]
520 | sorted_index = np.searchsorted(sorted_parents, features)
521 |
522 | features_index = np.take(index, sorted_index, mode="clip")
523 | if np.any(parent_features[features_index] != features):
524 | # result = np.ma.array(features_index, mask=mask)
525 | raise ValueError(
526 | "Internal error - please report this."
527 | "One feature was found in the child "
528 | "node, that was not in the parent "
529 | "node."
530 | )
531 |
532 | if len(features_index) > 0:
533 | An[features_index] = coefs_ * Cn[features_index]
534 |
535 | return _SmoothingDetails(A=An, B=Bn, C=Cn)
536 |
537 | def _smooth(
538 | node_id,
539 | parent_features=None,
540 | parent_coefs: _SmoothingDetails = None,
541 | parent_intercept: _SmoothingDetails = None,
542 | ):
543 | # Gather all info on this node
544 | # --base regression tree
545 | node_info = dct["nodes"][node_id]
546 | left_id = node_info["left_child"]
547 | right_id = node_info["right_child"]
548 | # --additional model
549 | node_model = old_node_models[node_id]
550 | # --samples
551 | samples_at_this_node = get_samples_at_node(node_id, nodes_to_samples)
552 | n_samples_at_this_node = samples_at_this_node.shape[0]
553 |
554 | # Note: should be equal to tree.n_node_samples[node_id]
555 | if n_samples_at_this_node != self.tree_.n_node_samples[node_id]:
556 | raise ValueError("n_samples_at_this_node: Unexpected value, please report it as issue.")
557 |
558 | y_true_this_node = y_train_all[samples_at_this_node]
559 | X_this_node = X_train_all[samples_at_this_node, :]
560 |
561 | # (1) smooth current node
562 | parent_features = parent_features if parent_features is not None else None
563 | is_constant_leaf = False
564 | if left_id == _tree.TREE_LEAF and isinstance(node_model, ConstantLeafModel):
565 | is_constant_leaf = True
566 | node_features = ()
567 | smoothed_features = parent_features if parent_features is not None else node_features
568 | node_coefs = ()
569 | node_intercept = dct["values"][node_id]
570 |
571 | # Extract the unique scalar value
572 | if len(node_intercept) != 1:
573 |                     raise ValueError("node_intercept: Unexpected value: %s, please report it as issue." % node_intercept)
574 |
575 | node_intercept = node_intercept.item()
576 |
577 | else:
578 | # A leaf LinRegLeafModel or a split SplitNodeModel
579 | node_features = node_model.features
580 | smoothed_features = parent_features if parent_features is not None else node_features
581 | node_coefs = node_model.model.coef_
582 | node_intercept = node_model.model.intercept_
583 |
584 | # Create a new linear regression model with smoothed coefficients
585 | smoothed_sklearn_model = clone(self.leaf_model)
586 | smoothed_coefs = smooth__(
587 | node_coefs,
588 | features=node_features,
589 | n_samples=n_samples_at_this_node,
590 | parent=parent_coefs,
591 | parent_features=parent_features,
592 | )
593 | smoothed_intercept = smooth__(node_intercept, n_samples=n_samples_at_this_node, parent=parent_intercept)
594 | smoothed_sklearn_model.coef_ = smoothed_coefs.A + smoothed_coefs.B
595 | smoothed_sklearn_model.intercept_ = smoothed_intercept.A + smoothed_intercept.B
596 |
597 | # Finally update the node
598 | if is_constant_leaf:
599 | smoothed_node_model = LinRegLeafModel(smoothed_features, smoothed_sklearn_model, None)
600 | else:
601 | smoothed_node_model = copy(node_model)
602 | smoothed_node_model.features = smoothed_features
603 | smoothed_node_model.model = smoothed_sklearn_model
604 |
605 | # Remember the new smoothed model
606 | new_node_models[node_id] = smoothed_node_model
607 |
608 | if left_id == _tree.TREE_LEAF:
609 | # If this is a leaf, update the prediction error on X
610 | y_pred_this_node = smoothed_node_model.predict(X_this_node)
611 | smoothed_node_model.error = err_metric(y_true_this_node, y_pred_this_node)
612 |
613 | else:
614 | # If this is a split node - recurse on each subtree
615 | _smooth(
616 | left_id,
617 | parent_features=smoothed_features,
618 | parent_coefs=smoothed_coefs,
619 | parent_intercept=smoothed_intercept,
620 | )
621 | _smooth(
622 | right_id,
623 | parent_features=smoothed_features,
624 | parent_coefs=smoothed_coefs,
625 | parent_intercept=smoothed_intercept,
626 | )
627 |
628 | # Update the error using the same formula than the one we used
629 | # in build_models_and_get_pruning_info
630 | y_pred_children = predict_from_leaves_no_smoothing(self.tree_, new_node_models, X_this_node)
631 | err_children = err_metric(y_true_this_node, y_pred_children)
632 |
633 | # the total number of parameters is the sum of params in each
634 | # branch PLUS 1 for the split
635 | n_params_splitmodel = new_node_models[left_id].n_params + new_node_models[right_id].n_params + 1
636 | smoothed_node_model.n_params = n_params_splitmodel
637 | # do not adjust the error now, simply store the raw one
638 | # smoothed_node_model.error = compute_adjusted_error(
639 | # err_children, n_samples_at_this_node, n_params_splitmodel)
640 | smoothed_node_model.error = err_children
641 |
642 | return
643 |
644 | # smooth the whole tree
645 | _smooth(0)
646 |
647 | # use the new node models now
648 | self.node_models = new_node_models
649 |
650 | # remember the smoothing constant installed
651 | self.installed_smoothing_constant = smoothing_constant
652 | return
653 |
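    # Worked illustration of the smoothing recursion used in `smooth__` above
    # (hypothetical numbers): take k = 15, a root model whose coefficient for some
    # feature is 2.0 (so A = 2.0, B = 0, C = 1 at the root), and a child leaf
    # directly below it, reached by n_samples = 45 training samples, with
    # coefficient 1.0 for that feature:
    #   B(child) = (15 / (45 + 15)) * 2.0 + 0  = 0.5
    #   C(child) = (45 / (45 + 15)) * 1        = 0.75
    #   A(child) = 1.0 * 0.75                  = 0.75
    # The smoothed child coefficient is A + B = 1.25, i.e. the classic M5 mix
    # (n * child + k * parent) / (n + k) = (45 * 1.0 + 15 * 2.0) / 60 = 1.25.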
654 | def denormalize(self, x_scaler, y_scaler):
655 | """
656 | De-normalizes this model according to the provided x and y normalization scalers.
657 |         Currently only StandardScaler is supported.
658 |
659 | :param x_scaler: a StandardScaler or None
660 | :param y_scaler: a StandardScaler or None
661 | :return:
662 | """
663 | # perform denormalization
664 | self._denormalize_tree(x_scaler, y_scaler)
665 |
666 | def _denormalize_tree(self, x_scaler, y_scaler):
667 | """
668 | De-normalizes all models in the tree
669 | :return:
670 | """
671 | old_tree = self.tree_
672 | old_node_models = self.node_models
673 |
674 | # Get all information to create a copy of the inner tree.
675 | # Note: _tree.copy() is gone so we use the pickle way
676 | # --- Base info: nb features, nb outputs, output classes
677 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L631
678 | # [1] = (self.n_features, sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), self.n_outputs)
679 | # So these remain, we are just interested in changing the node-related arrays
680 | new_tree = _tree.Tree(*old_tree.__reduce__()[1])
681 |
682 | # --- Node info
683 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L637
684 | dct = old_tree.__getstate__().copy()
685 |
686 | # create empty new structures for nodes and models
687 | n_shape = dct["nodes"].shape
688 | new_nodes = _empty_contig_ar(n_shape, dtype=dct["nodes"].dtype)
689 | v_shape = dct["values"].shape
690 | new_values = _empty_contig_ar(v_shape, dtype=dct["values"].dtype)
691 | m_shape = old_node_models.shape
692 | new_node_models = _empty_contig_ar(m_shape, dtype=old_node_models.dtype)
693 |
694 | def _denormalize(node_id, x_scaler, y_scaler):
695 | """
696 | denormalizes the subtree below node `node_id`.
697 |
698 | - `new_nodes[node_id]` is filled with a copy of the old node
699 | (see old_node.dtype to see the various fields). If the node is
700 | a split node, the split threshold
701 | `new_nodes[node_id]['threshold']` is denormalized.
702 |
703 | - `new_values[node_id]` is filled with a denormalized copy of the
704 | old constant prediction at this node. Reminder: this constant
705 | prediction is actually used only when on the leaves now, but
706 | we keep it for ref.
707 |
708 | - `new_node_models[node_id]` is filled with a copy of the old
709 | model `old_node_models[node_id]`, that is denormalized if it
710 | is not a constant leaf
711 |
712 | :param node_id:
713 | :return: (nothing)
714 | """
715 | # use the old tree to walk
716 | old_node = dct["nodes"][node_id]
717 | left_id = old_node['left_child']
718 | right_id = old_node['right_child']
719 |
720 | # Create the new node with a copy of the old
721 | new_nodes[node_id] = old_node # this is an entire row so it is probably copied already by doing so.
722 | new_model = copy(old_node_models[node_id])
723 | new_node_models[node_id] = new_model
724 |
725 | # Create the new value by de-scaling y
726 | if y_scaler is not None:
727 | # Note: if this is a split node with a linear regression model
728 | # the value will never be used. However to preserve consistency
729 | # of the whole values structure and for debugging purposes, we
730 | # choose this safe path of denormalizing ALL.
731 | # TODO we could also do it at once outside of the recursive
732 | # calls, but we should check for side effects
733 | new_values[node_id] = y_scaler.inverse_transform(
734 | dct["values"][node_id]
735 | )
736 | else:
737 | # no denormalization: simply copy
738 | new_values[node_id] = dct["values"][node_id]
739 |
740 | if left_id == _tree.TREE_LEAF:
741 | # ... and right_id == _tree.TREE_LEAF
742 | # => node_id is a leaf
743 | if isinstance(new_model, ConstantLeafModel):
744 | # nothing to do: we already re-scaled the value
745 | return
746 | elif isinstance(new_model, LinRegLeafModel):
747 | # denormalize model
748 | new_model.denormalize(x_scaler, y_scaler)
749 | return
750 | else:
751 |                     raise TypeError("Internal error - Leaves can only be "
752 |                                     "constant or linear regression")
753 | else:
754 | # this is a split node, denormalize each subtree
755 | _denormalize(left_id, x_scaler, y_scaler)
756 | _denormalize(right_id, x_scaler, y_scaler)
757 |
758 | # denormalize the split value if needed
759 | if x_scaler is not None:
760 | split_feature = old_node['feature']
761 | # The denormalizer requires a vector with all the features,
762 | # even if we only want to denormalize one.
763 | # -- put split value in a vector where it has pos 'feature'
764 |                     old_threshold_and_zeros = np.zeros((self.n_features_in_, ), dtype=dct["nodes"]['threshold'].dtype)
765 | old_threshold_and_zeros[split_feature] = old_node['threshold']
766 | # -- denormalize the vector and retrieve value 'feature'
767 | new_nodes[node_id]['threshold'] = x_scaler.inverse_transform(old_threshold_and_zeros)[split_feature]
768 | else:
769 | # no denormalization: simply copy
770 | new_nodes[node_id]['threshold'] = old_node['threshold']
771 |
772 | if isinstance(new_model, SplitNodeModel):
773 | # denormalize model at split node too, even if it is not
774 | # always used (depending on smoothing mode)
775 | new_model.denormalize(x_scaler, y_scaler)
776 | else:
777 |                     raise TypeError("Internal error: all intermediate nodes "
778 |                                     "should be SplitNodeModel")
779 |
780 | return
781 |
782 | # denormalize the whole tree and put the result in (new_nodes,
783 | # new_values, new_node_models) recursively
784 | _denormalize(0, x_scaler, y_scaler)
785 |
786 | # complete definition of the new tree
787 | dct["nodes"] = new_nodes
788 | dct["values"] = new_values
789 | new_tree.__setstate__(dct)
790 |
791 | # update the self fields
792 | # self.features_usage
793 | self.tree_ = new_tree
794 | self.node_models = new_node_models
795 |
796 | def compress_features(self):
797 | """
798 | Compresses the model and returns a vector of required feature indices.
799 | This model input will then be X[:, features] instead of X.
800 | """
801 | used_features = sorted(self.features_usage.keys())
802 | new_features_lookup_dct = {old_feature_idx: i for i, old_feature_idx in enumerate(used_features)}
803 |
804 | if used_features == list(range(self.n_features_in_)):
805 | # NOTHING TO DO: we need all features
806 | return used_features
807 |
808 | # Otherwise, We can compress. For this we have to create a copy of the
809 | # tree because we will change its internals
810 | old_tree = self.tree_
811 | old_node_modls = self.node_models
812 |
813 | # Get all information to create a copy of the inner tree.
814 | # Note: _tree.copy() is gone so we use the pickle way
815 | # --- Base info: nb features, nb outputs, output classes
816 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L631
817 | # [1] = (self.n_features, sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), self.n_outputs)
818 | # So these remain, we are just interested in changing the node-related arrays
819 | new_tree = _tree.Tree(*old_tree.__reduce__()[1])
820 |
821 | # --- Node info
822 | # see https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L637
823 | dct = old_tree.__getstate__().copy()
824 |
825 | # create empty new structures for nodes and models
826 | n_shape = dct["nodes"].shape
827 | new_nodes = _empty_contig_ar(n_shape, dtype=dct["nodes"].dtype)
828 | m_shape = old_node_modls.shape
829 | new_node_models = _empty_contig_ar(m_shape, dtype=old_node_modls.dtype)
830 |
831 | def _compress_features(node_id):
832 | """
833 |
834 | Parameters
835 | ----------
836 | node_id
837 |
838 | Returns
839 | -------
840 | the depth and new indices of left and right children
841 |
842 | """
843 | # use the old tree to walk
844 | old_node = dct["nodes"][node_id]
845 | left_id = old_node["left_child"]
846 | right_id = old_node["right_child"]
847 |
848 | # Create the new node with a copy of the old
849 | new_nodes[node_id] = old_node # this is an entire row so it is probably copied already by doing so.
850 | new_model = copy(old_node_modls[node_id])
851 | new_node_models[node_id] = new_model
852 |
853 | if left_id == _tree.TREE_LEAF:
854 | # ... and right_id == _tree.TREE_LEAF
855 | # => node_id is a leaf
856 | if isinstance(new_model, ConstantLeafModel):
857 | # no features used
858 | return
859 | elif isinstance(new_model, LinRegLeafModel):
860 | # compress that model
861 | new_model.reindex_features(new_features_lookup_dct)
862 | return
863 | else:
864 |                     raise TypeError("Internal error - Leaves can only be " "constant or linear regression")
865 | else:
866 |
867 | _compress_features(left_id)
868 | _compress_features(right_id)
869 |
870 | # store the new split feature index in the node
871 | new_nodes[node_id]["feature"] = new_features_lookup_dct[old_node["feature"]]
872 |
873 | if not isinstance(new_model, SplitNodeModel):
874 |                     raise TypeError("Internal error: all intermediate nodes " "should be SplitNodeModel")
875 |
876 | # TODO now that split node models have a linear regression
877 | # model too, we should update them here
878 |
879 | return
880 |
881 | _compress_features(0)
882 |
883 | # complete definition of the new tree
884 | dct["nodes"] = new_nodes
885 | new_tree.__setstate__(dct)
886 |
887 | # update the self fields
888 | self.features_usage = {new_features_lookup_dct[k]: v for k, v in self.features_usage.items()}
889 | self.tree_ = new_tree
890 | self.node_models = new_node_models
891 | self.n_features_in_ = len(self.features_usage)
892 |
893 | # return the vector of used features
894 | return used_features
895 |
896 | @property
897 | def feature_importances_(self):
898 | """Return the feature importances (the higher, the more important).
899 |
900 | Returns
901 | -------
902 | feature_importances_ : array, shape = [n_features]
903 | """
904 | check_is_fitted(self, "tree_")
905 |
906 | # TODO adapt ?
907 | # features = np.array([self.features_usage[k] for k in sorted(self.features_usage.keys())], dtype=int)
908 | features = self.tree_.compute_feature_importances()
909 |
910 | return features
911 |
912 | def predict(self, X, check_input=True, smooth_predictions=None, smoothing_constant=None):
913 | """Predict class or regression value for X.
914 |
915 | For a classification model, the predicted class for each sample in X is
916 | returned. For a regression model, the predicted value based on X is
917 | returned.
918 |
919 | Parameters
920 | ----------
921 | X : numpy.ndarray, shape (n_samples, n_features)
922 | The input samples. Internally, it will be converted to
923 | ``dtype=np.float32`` and if a sparse matrix is provided
924 | to a sparse ``csr_matrix``.
925 | check_input : boolean, (default=True)
926 | Allow to bypass several input checking.
927 | Don't use this parameter unless you know what you do.
928 | smooth_predictions : boolean, (default=None)
929 |             None means "use the model's own use_smoothing configuration".
930 | smoothing_constant: int, (default=self.smoothing_constant)
931 | Smoothing constant used when smooth_predictions is True. During
932 | smoothing, child node models are recursively mixed with their
933 | parent node models, and the mix is done using a weighted sum. The
934 | weight given to the child model is the number of training samples
935 | that reached its node, while the weight given to the parent model
936 | is the smoothing constant. Therefore it can be seen as an
937 | equivalent number of samples that parent models represent when
938 | injected into the recursive weighted sum.
939 |
940 | Returns
941 | -------
942 | y : array of shape = [n_samples] or [n_samples, n_outputs]
943 | The predicted classes, or the predict values.
944 | """
945 | check_is_fitted(self, "tree_")
946 |
947 | # If this is just a constant node only check the input's shape.
948 | if self.n_features_in_ == 0:
949 | # perform the input checking manually to set ensure_min_features=0
950 | X = check_array(X, dtype=DTYPE, accept_sparse="csr", ensure_min_features=0)
951 | if issparse(X) and (X.indices.dtype != np.intc or X.indptr.dtype != np.intc):
952 | raise ValueError("No support for np.int64 index based " "sparse matrices")
953 | # skip it then
954 | check_input = False
955 |
956 | # validate and convert dtype
957 | X = self._validate_X_predict(X, check_input)
958 |
959 | # -------- This is the only change wrt parent class. TODO maybe rather replace self.tree_ with a proxy ---------
960 | if smooth_predictions is None:
961 | # Default: smooth prediction at prediction time if configured as
962 | # such in the model. Note: if self.use_smoothing == 'installed',
963 | # models are already smoothed models, so no need to smooth again
964 | smooth_predictions = self.use_smoothing == "on_prediction"
965 | else:
966 | # user provided an explicit value for smooth_predictions
967 | if not smooth_predictions and self.use_smoothing == "installed":
968 | raise ValueError(
969 | "Smoothing has been pre-installed on this "
970 | "tree, it is not anymore possible to make "
971 | "predictions without smoothing"
972 | )
973 |
974 | if smooth_predictions and smoothing_constant is None:
975 | # default parameter for smoothing is the one defined in the model
976 | # (with possible ratio)
977 | smoothing_constant = self._get_smoothing_constant_to_use(X)
978 |
979 | # Do not use the embedded tree (like in super): it has been pruned but
980 | # still has constant nodes
981 | # proba = self.tree_.predict(X)
982 | if not smooth_predictions:
983 | proba = predict_from_leaves_no_smoothing(self.tree_, self.node_models, X)
984 | else:
985 | proba = predict_from_leaves(self, X, smoothing=True, smoothing_constant=smoothing_constant)
986 | if len(proba.shape) < 2:
987 | proba = proba.reshape(-1, 1)
988 | # ------------------------------------------
989 |
990 | n_samples = X.shape[0]
991 |
992 | # Classification
993 | if is_classifier(self):
994 | if self.n_outputs_ == 1:
995 | return self.classes_.take(np.argmax(proba, axis=1), axis=0)
996 | else:
997 | predictions = np.zeros((n_samples, self.n_outputs_))
998 | for k in range(self.n_outputs_):
999 | predictions[:, k] = self.classes_[k].take(np.argmax(proba[:, k], axis=1), axis=0)
1000 | return predictions
1001 |
1002 | # Regression
1003 | else:
1004 | if self.n_outputs_ == 1:
1005 | return proba[:, 0]
1006 | else:
1007 | return proba[:, :, 0]
1008 |
1009 |
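# Illustrative usage sketch (not part of the API): the two smoothing modes
# described in `M5Base.fit` and `M5Base.predict`. The data and parameter values
# below are hypothetical, and `M5Prime` (defined further down in this module) is
# assumed to forward the `M5Base` constructor arguments.
def _m5prime_smoothing_example():  # pragma: no cover - illustrative only
    rng = np.random.RandomState(0)
    X = rng.uniform(size=(300, 2))
    y = np.where(X[:, 0] > 0.5, 3.0 * X[:, 1], -2.0 * X[:, 1]) + 0.05 * rng.normal(size=300)

    # 'installed': leaf and split models are rewritten once, at fit time
    m_installed = M5Prime(use_smoothing="installed").fit(X, y)
    y_installed = m_installed.predict(X)

    # 'on_prediction': models stay unsmoothed; they are mixed with their parents
    # at predict time, the parent weight being the smoothing constant k
    m_lazy = M5Prime(use_smoothing="on_prediction", smoothing_constant=15).fit(X, y)
    y_lazy = m_lazy.predict(X, smooth_predictions=True)

    return y_installed, y_lazy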
1010 | class ConstantLeafModel:
1011 | """
1012 | Represents the additional information about a leaf node that is not pruned.
1013 | It contains the error associated with the training samples at this node.
1014 |
1015 | Note: the constant value is not stored here, as it is already available in
1016 | the sklearn tree struct. So to know the prediction at this node, use
1017 | `tree.value[node_id]`
1018 | """
1019 |
1020 | __slots__ = ("error",)
1021 |
1022 | def __init__(self, error):
1023 | self.error = error
1024 |
1025 | @property
1026 | def n_params(self) -> int:
1027 | """
1028 | Returns the number of parameters used by this model, including the
1029 | constant driver
1030 | """
1031 | return 1
1032 |
1033 | @staticmethod
1034 | def predict_cstt(tree, node_id, n_samples):
1035 | """
1036 | This static method is a helper to get an array of constant predictions
1037 | associated with this node.
1038 |
1039 | Parameters
1040 | ----------
1041 | tree
1042 | node_id
1043 | n_samples
1044 |
1045 | Returns
1046 | -------
1047 |
1048 | """
1049 | cstt_prediction = tree.value[node_id]
1050 |
1051 | # cstt_prediction can be multioutput so it is an array. replicate it
1052 | from numpy import matlib # note: np.matlib not available in 1.10.x
1053 |
1054 | return matlib.repmat(cstt_prediction, n_samples, 1)
1055 |
1056 |
1057 | class LinRegNodeModel(DeNormalizableMixIn):
1058 | """
1059 | Represents the additional information about a tree node that contains a
1060 | linear regression model.
1061 |
1062 | It contains
1063 | - the features used by the model,
1064 | - the scikit learn model object itself,
1065 | - and the error for the training samples that reached this node
1066 |
1067 | """
1068 | __slots__ = ("features", "model", "error", "n_params")
1069 |
1070 | def __init__(self, features, model, error) -> None:
1071 | self.features = features
1072 | self.model = model
1073 | self.error = error
1074 | self.n_params = len(self.features) + 1
1075 |
1076 | def to_text(self, feature_names=None, target_name=None, precision=3,
1077 | line_breaks=False):
1078 | """ Returns a text representation of the linear regression model """
1079 | return linreg_model_to_text(self.model, feature_names=feature_names,
1080 | target_name=target_name,
1081 | precision=precision,
1082 | line_breaks=line_breaks)
1083 |
1084 | def reindex_features(self, new_features_lookup_dct):
1085 | """
1086 | Reindexes the required features using the provided lookup dictionary.
1087 |
1088 | Parameters
1089 | ----------
1090 |         new_features_lookup_dct : dict mapping each old feature index to its new index
1091 |
1092 | Returns
1093 | -------
1094 |
1095 | """
1096 | self.features = [new_features_lookup_dct[f] for f in self.features]
1097 |
1098 | def denormalize(self,
1099 | x_scaler: StandardScaler = None,
1100 | y_scaler: StandardScaler = None
1101 | ):
1102 | """
1103 | De-normalizes the linear model.
1104 |
1105 |         :param x_scaler: the StandardScaler that was used to scale the input features X.
1106 |         :param y_scaler: the StandardScaler that was used to scale the target y.
1107 |         :return: nothing, the inner model is modified in place.
1108 | """
1109 | # create a clone of the x scaler with only the used features
1110 | x_scaler = copy(x_scaler)
1111 | if len(self.features) > 0:
1112 | x_scaler.scale_ = x_scaler.scale_[self.features]
1113 | x_scaler.mean_ = x_scaler.mean_[self.features]
1114 | else:
1115 | # in that particular case, the above expression is not working
1116 | x_scaler.scale_ = x_scaler.scale_[0:0]
1117 | x_scaler.mean_ = x_scaler.mean_[0:0]
1118 |
1119 | # use the denormalize function on the internal model
1120 |         # TODO what if the internal model does not support denormalization?
1121 | self.model.denormalize(x_scaler, y_scaler)
1122 |
1123 | def predict(self, X: np.ndarray):
1124 | """Performs a prediction for X, only using the required features."""
1125 | if len(self.features) < 1 and isinstance(self.model, LinearRegression):
1126 | # unfortunately LinearRegression models do not like it when no
1127 | # features are needed: their input validation requires at least 1.
1128 | # so we do it ourselves.
1129 | return self.model.intercept_ * np.ones((X.shape[0], 1))
1130 | else:
1131 | return self.model.predict(X[:, self.features])
1132 |
1133 |
1134 | class LinRegLeafModel(LinRegNodeModel):
1135 | """
1136 | Represents the additional information about a leaf node with a linear
1137 | regression model
1138 | """
1139 | pass
1140 |
1141 |
1142 | class SplitNodeModel(LinRegNodeModel):
1143 | """
1144 | Represents the additional information about a split node, with a linear
1145 | regression model.
1146 | """
1147 | __slots__ = ("n_params",)
1148 |
1149 |     def __init__(self, n_params, error, features, model):
1150 |         super(SplitNodeModel, self).__init__(features=features, model=model, error=error)
1151 |         self.n_params = n_params  # set AFTER super(), which would otherwise overwrite it
1152 |
1153 |
1154 | PRUNING_MULTIPLIER = 2
1155 | # see https://github.com/bnjmn/weka/blob/master/weka/src/main/java/weka/classifiers/trees/m5/RuleNode.java#L124
1156 | # TODO check why they use 2 instead of 1 (from the article) ??
1157 |
1158 |
1159 | def root_mean_squared_error(*args, **kwargs):
1160 | return np.sqrt(mean_squared_error(*args, **kwargs))
1161 |
1162 |
1163 | def prune_on_min_impurity(tree):
1164 | """
1165 |     Edits the given tree in place so as to prune sub-branches that do not
1166 |     respect the minimum impurity criterion. Nothing is returned: the tree is
1167 |     modified in place.
1168 | 
1169 |     The paper suggests doing this with a 5% threshold, but min_impurity_split is
1170 |     not available in scikit-learn 0.17 (and the criterion is not the std but the
1171 |     MSE, so the ratio has to be squared).
1172 | 
1173 |     Parameters
1174 |     ----------
1175 |     tree : Tree
1176 |         The scikit-learn tree structure to prune.
1177 | 
1178 | """
1179 | left_children = tree.children_left
1180 | right_children = tree.children_right
1181 | impurities = tree.impurity
1182 |
1183 |     # The paper suggests doing this with 5%, but min_impurity_split is not
1184 |     # available in 0.17 (and the criterion is not the std but the MSE, so we square the ratio)
1185 | # TODO adapt the formula to criterion used and/or generalize with min_impurity_split_as_initial_ratio
1186 | root_impurity = impurities[0]
1187 | impurity_threshold = root_impurity * (0.05 ** 2)
1188 |
1189 | def stop_on_min_impurity(node_id):
1190 | # note: in the paper that is 0.05 but criterion is on std. Here
1191 | # impurity is mse so a squared equivalent of std.
1192 | left_id = left_children[node_id]
1193 | right_id = right_children[node_id]
1194 | if left_id != _tree.TREE_LEAF: # a split node
1195 | if impurities[node_id] < impurity_threshold:
1196 | # stop here, this will be a leaf
1197 | prune_children(node_id, tree)
1198 | else:
1199 | stop_on_min_impurity(left_id)
1200 | stop_on_min_impurity(right_id)
1201 |
1202 | stop_on_min_impurity(0)
1203 |
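# Illustrative note on the threshold above (hypothetical numbers): if the root
# node has an MSE impurity of 4.0, the threshold is 4.0 * 0.05**2 = 0.01, i.e.
# split nodes whose standard deviation falls below 5% of the root standard
# deviation are collapsed into leaves.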
1204 |
1205 | def build_models_and_get_pruning_info(
1206 | tree, X_train_all, y_train_all, nodes_to_samples, leaf_model, node_models, global_std_dev, use_pruning, node_id=0
1207 | ):
1208 | """
1209 |
1210 | Parameters
1211 | ----------
1212 | tree : Tree
1213 | A tree that will be pruned on the way
1214 |     X_train_all : 2D array of all training inputs.
1215 |     y_train_all : array of all training targets.
1216 |     nodes_to_samples : csc_matrix of node-to-sample indicators, see `get_samples_at_node`.
1217 |     leaf_model : estimator, the unfitted model cloned and fitted at each candidate node.
1218 |     node_models : array, filled in place with the fitted node models.
1219 |     global_std_dev : float, the standard deviation of the whole training target.
1220 |     use_pruning : bool, whether to prune subtrees.
1221 |     node_id : int, the id of the subtree root (0 for the whole tree).
1222 |
1223 | Returns
1224 | -------
1225 | a dictionary where the key is the feature index and the value is
1226 | the number of samples where this feature is used
1227 | """
1228 |
1229 | # Select the Error metric to compute model errors (used to compare node
1230 | # with subtree for pruning)
1231 | # TODO in Weka they use RMSE, but in papers they use MAE. could be a param
1232 | # TODO Shall we go further and store the residuals, or a bunch of metrics?
1233 | err_metric = root_mean_squared_error # mean_absolute_error
1234 |
1235 | # Get the samples associated with this node
1236 | samples_at_this_node = get_samples_at_node(node_id, nodes_to_samples)
1237 | n_samples_at_this_node = samples_at_this_node.shape[0]
1238 | y_true_this_node = y_train_all[samples_at_this_node]
1239 |
1240 | # Is this a leaf node or a split node ?
1241 | left_node = tree.children_left[node_id] # this way we do not have to query it again in the else.
1242 | if left_node == _tree.TREE_LEAF:
1243 | # --> Current node is a LEAF. See is_leaf(node_id, tree) if you have doubts <--
1244 |
1245 | # -- create a linear model for this node
1246 | # leaves do not have the right to use any features since they have no
1247 | # subtree: keep the constant prediction
1248 | y_pred_this_node = ConstantLeafModel.predict_cstt(tree, node_id, y_true_this_node.shape[0])
1249 | err_this_node = err_metric(y_true_this_node, y_pred_this_node)
1250 |
1251 | # TODO when use_pruning = False, should we allow all features to be
1252 | # used instead of having to stick to the M5 rule of "only use a
1253 | # feature if the subtree includes a split with this feature" ?
1254 | # OR alternate proposal: should we transform the boolean use_pruning
1255 | # into a use_pruning_max_nb integer to say for example "only 2 level
1256 | # of pruning" ?
1257 |
1258 | # -- store the model information
1259 | node_models[node_id] = ConstantLeafModel(err_this_node)
1260 |
1261 | # -- return an empty dict for features used
1262 | return dict() # np.array([], dtype=int)
1263 |
1264 | else:
1265 | # --> Current node is a SPLIT <--
1266 | right_node = tree.children_right[node_id]
1267 |
1268 | # (1) prune left and right subtree and get some information
1269 | features_l = build_models_and_get_pruning_info(
1270 | tree, X_train_all, y_train_all, nodes_to_samples, leaf_model,
1271 | node_models, global_std_dev, use_pruning, node_id=left_node
1272 | )
1273 | features_r = build_models_and_get_pruning_info(
1274 | tree, X_train_all, y_train_all, nodes_to_samples, leaf_model,
1275 | node_models, global_std_dev, use_pruning, node_id=right_node
1276 | )
1277 |
1278 | # (2) select only the samples that reach this node
1279 | X_this_node = X_train_all[samples_at_this_node, :]
1280 |
1281 | # (3) Create a model for this node
1282 | # -- fit a linear regression model taking into account only the
1283 | # features used in the subtrees + the split one
1284 | # TODO should we normalize=True (variance scaling) or use a whole
1285 | # pipeline here ? Not sure..
1286 | # estimators = [('scale', StandardScaler()), ('clf', model_type())];
1287 | # pipe = Pipeline(estimators)
1288 | # skmodel_this_node = leaf_model_type(**leaf_model_params)
1289 | skmodel_this_node = clone(leaf_model)
1290 |
1291 | # -- old - we used to only store the array of features
1292 | # selected_features = np.union1d(features_l, features_r)
1293 | # selected_features = np.union1d(selected_features, tree.feature[node_id])
1294 | # -- new - we also gather the nb samples where this feature is used
1295 | selected_features_dct = features_l
1296 | for feature_id, n_samples_feat_used in features_r.items():
1297 | if feature_id in selected_features_dct.keys():
1298 | selected_features_dct[feature_id] += n_samples_feat_used
1299 | else:
1300 | selected_features_dct[feature_id] = n_samples_feat_used
1301 | selected_features_dct[tree.feature[node_id]] = n_samples_at_this_node
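        # Illustrative example (hypothetical numbers): if the left subtree returned
        # {0: 10} and the right subtree returned {0: 5, 2: 5}, the merged dict is
        # {0: 15, 2: 5}; the feature used by this split then gets the full count
        # of samples reaching this node.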
1302 | # -- use only the selected features, in the natural integer order
1303 | selected_features = sorted(selected_features_dct.keys())
1304 |
1305 | X_train_this_node = X_this_node[:, selected_features]
1306 | skmodel_this_node.fit(X_train_this_node, y_true_this_node)
1307 | # -- predict and compute error
1308 | y_pred_this_node = skmodel_this_node.predict(X_train_this_node)
1309 | err_this_node = err_metric(y_true_this_node, y_pred_this_node)
1310 | # -- create the object
1311 | # TODO the paper suggest to perform recursive feature elimination in
1312 | # this model until adjusted_err_model is minimal, is it same in Weka ?
1313 | model_this_node = LinRegLeafModel(selected_features, skmodel_this_node, err_this_node)
1314 |
1315 | # (4) compute adj error criterion = ERR * (n+v)/(n-v) for both models
1316 | adjusted_err_model = compute_adjusted_error(err_this_node, n_samples_at_this_node, model_this_node.n_params)
1317 |
1318 | # (5) predict and compute adj error for the combination of child models
1319 | # -- Note: this is recursive so the leaves may contain linear models
1320 | # already
1321 | y_pred_children = predict_from_leaves_no_smoothing(tree, node_models, X_this_node)
1322 | err_children = err_metric(y_true_this_node, y_pred_children)
1323 |
1324 |         # TODO the Weka implementation (below) differs from the paper, which suggests
1325 |         # a weighted sum of adjusted errors. This may be an equivalent formulation, to be checked.
1326 | # the total number of parameters if we do not prune, is the sum of
1327 | # params in each branch PLUS 1 for the split
1328 | n_params_splitmodel = node_models[left_node].n_params + node_models[right_node].n_params + 1
1329 | adjusted_err_children = compute_adjusted_error(err_children, n_samples_at_this_node, n_params_splitmodel)
1330 |
1331 | # (6) compare and either prune at this node or keep the subtrees
1332 | std_dev_this_node = np.nanstd(y_true_this_node)
1333 | # note: the first criterion is now already checked before that
1334 | # function call, in `prune_on_min_impurity`
1335 | if use_pruning and (
1336 | # TODO these parameters should be in the constructor
1337 | # see also 2 comments about min_impurity_split_as_initial_ratio
1338 | std_dev_this_node < (global_std_dev * 0.05)
1339 | or (adjusted_err_model <= adjusted_err_children)
1340 | or (adjusted_err_model < (global_std_dev * 0.00001)) # so this means very good R² already
1341 | ):
1342 | # Choose model for this node rather than subtree model
1343 | # -- prune from this node on
1344 | removed_nodes = prune_children(node_id, tree)
1345 | node_models[removed_nodes] = _tree.TREE_UNDEFINED
1346 |
1347 | # store the model information
1348 | node_models[node_id] = model_this_node
1349 |
1350 | # update and return the features used
1351 | selected_features_dct = {k: n_samples_at_this_node for k in selected_features_dct.keys()}
1352 | return selected_features_dct
1353 |
1354 | else:
1355 | # The subtrees are good or we do not want pruning: keep them.
1356 | # This node will remain a split, and will only contain the digest
1357 | # about the subtree
1358 | node_models[node_id] = SplitNodeModel(
1359 | n_params_splitmodel, err_children, selected_features, skmodel_this_node
1360 | )
1361 |
1362 | # return the features used
1363 | return selected_features_dct
1364 |
1365 |
1366 | def compute_adjusted_error(err, n_samples, n_parameters, multiplier=PRUNING_MULTIPLIER):
1367 | """
1368 |     Return the penalized error obtained from `err`, the prediction error at a
1369 |     given tree node, by multiplying it by ``(n + multiplier*p)/(n - p)``, where
1370 |     `n` is the number of samples at this node and `p` is the number of
1371 |     parameters of the model. According to the original M5 paper, this is meant
1372 |     to penalize complex models used for few samples.
1373 | 
1374 |     Note that if ``n_samples <= n_parameters`` the denominator is zero or
1375 |     negative. In this case an arbitrarily high penalization is used: ``10 * err``.
1376 | 
1377 |     Parameters
1378 |     ----------
1379 |     err : float
1380 |         The model error for a given tree node. Typically the RMSE (or MAE).
1381 |     n_samples : int
1382 |         The number of samples at this node.
1383 |     n_parameters : int
1384 |         The number of parameters of the model at this node.
1385 |     multiplier : float
1386 |         The weight of `n_parameters` in the numerator. The M5 paper uses 1;
1387 |         this implementation defaults to ``PRUNING_MULTIPLIER`` (2), as in Weka.
1388 | 
1389 |     Returns
1390 |     -------
1391 |     penalized_err : float
1392 |         The penalized error ``err * (n + multiplier*p)/(n - p)``, or ``10 * err`` if ``n <= p``.
1393 |     """
1394 | #
1395 | if n_samples <= n_parameters:
1396 | # denominator is zero or negative: use a large factor so as to penalize
1397 | # a lot this overly complex model.
1398 | factor = 10.0 # Caution says Yong in his code
1399 | else:
1400 | factor = (n_samples + multiplier * n_parameters) / (n_samples - n_parameters)
1401 |
1402 | return err * factor
1403 |
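# Worked example for compute_adjusted_error (illustrative numbers only): with
# err=1.0, n_samples=20, n_parameters=3 and the default multiplier of 2, the
# factor is (20 + 2*3) / (20 - 3) = 26/17 ~= 1.53. With only n_samples=4 for the
# same model, the factor would be (4 + 2*3) / (4 - 3) = 10, a heavy penalty.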
1404 |
1405 | def predict_from_leaves_no_smoothing(tree, node_models, X):
1406 | """
1407 |     Returns the predictions for all samples in X, using the appropriate leaf
1408 |     model for each sample.
1409 |
1410 | This function
1411 | - uses tree.apply(X) to know in which tree leaf each sample in X goes
1412 | - then for each of the leaves that are actually touched, uses
1413 | node_models[leaf_id] to predict, for the X reaching that leaf
1414 |
1415 | This function assumes that node_models contains non-empty LinRegLeafModel
1416 | entries for all leaf nodes that will be reached by the samples X.
1417 |
1418 | Parameters
1419 | ----------
1420 | tree : Tree
1421 | The tree object.
1422 | node_models : array-like
1423 | The array containing node models
1424 | X : 2D array-like
1425 | The sample data.
1426 |
1427 | Returns
1428 | -------
1429 | y_predicted : 1D array-like
1430 | The prediction vector.
1431 | """
1432 | # **** This does the job, but we have one execution of model.predict() per
1433 | # sample: probably not efficient
1434 | # sample_ids_to_leaf_node_ids = tree.apply(X)
1435 | # model_and_x = np.concatenate((node_models[leaf_node_ids].reshape(-1, 1), X), axis=1)
1436 | # def pred(m_and_x):
1437 | # return m_and_x[0].model.predict(m_and_x[1:].reshape(1,-1))[0]
1438 | # y_predicted = np.array(list(map(pred, model_and_x)))
1439 |
1440 | # **** This should be faster because the number of calls to predict() is
1441 | # equal to the number of leaves touched
1442 | sample_ids_to_leaf_node_ids = tree.apply(X)
1443 | y_predicted = -np.ones(sample_ids_to_leaf_node_ids.shape, dtype=DOUBLE)
1444 |
1445 | # -- find the unique list of leaves touched
1446 | leaf_node_ids, inverse_idx = np.unique(sample_ids_to_leaf_node_ids, return_inverse=True)
1447 |
1448 | # -- for each leaf perform the prediction for samples reaching that leaf
1449 | for leaf_node_id in leaf_node_ids:
1450 | # get the indices of the samples reaching that leaf
1451 | sample_indices = np.nonzero(sample_ids_to_leaf_node_ids == leaf_node_id)[0]
1452 |
1453 | # predict
1454 | node_model = node_models[leaf_node_id]
1455 | if isinstance(node_model, LinRegLeafModel):
1456 | y_predicted[sample_indices] = np.ravel(node_model.predict(X[sample_indices, :]))
1457 | else:
1458 | # isinstance(node_model, ConstantLeafModel)
1459 | y_predicted[sample_indices] = tree.value[leaf_node_id]
1460 |
1461 | # **** Recursive strategy: not used anymore
1462 | # left_node = tree.children_left[node_id] # this way we do not have to query it again in the else.
1463 | # if left_node == _tree.TREE_LEAF:
1464 | # # --> Current node is a LEAF. See is_leaf(node_id, tree) <--
1465 | # y_predicted = node_models[node_id].model.predict()
1466 | # else:
1467 | # # --> Current node is a SPLIT <--
1468 | # right_node = tree.children_right[node_id]
1469 | #
1470 | #
1471 | # samples_at_this_node = get_samples_at_node(node_id, nodes_to_samples)
1472 | # y_true_this_node = y_train_all[samples_at_this_node]
1473 | # # As surprising as it may seem, in numpy [samples_at_this_node, selected_features] does something else.
1474 | # X_train_this_node = X_train_all[samples_at_this_node, :][:, selected_features]
1475 | #
1476 | # X_left
1477 |
1478 | return y_predicted
1479 |
1480 |
1481 | def predict_from_leaves(m5p, X, smoothing=True, smoothing_constant=15):
1482 | """
1483 | Predicts using the M5P tree, without using the compiled sklearn
1484 | `tree.apply` subroutine.
1485 |
1486 |     The main purpose of this function is to apply smoothing to an M5P model
1487 |     tree on which smoothing has not been pre-installed: for example, to use
1488 |     the same model both with and without smoothing for comparison purposes,
1489 |     or for models whose leaves are not linear models and for which no
1490 |     pre-installation method exists.
1491 |
1492 | Note: this method is slower than `predict_from_leaves_no_smoothing` when
1493 | `smoothing=False`.
1494 |
1495 | Parameters
1496 | ----------
1497 | m5p : M5Prime
1498 | The model to use for prediction
1499 | X : array-like
1500 | The input data
1501 | smoothing : bool
1502 | Whether to apply smoothing
1503 | smoothing_constant : int
1504 |         The smoothing constant `k`, used as the weight of the parent node
1505 |         model's prediction (k=15 in the original articles).
1506 | 
1507 | 
1508 |     Returns
1509 |     -------
1510 |     y_predicted : 2D numpy array of shape (n_samples, 1) containing the predictions.
1511 | """
1512 | # validate and converts dtype just in case this was directly called
1513 | # e.g. in unit tests
1514 | X = m5p._validate_X_predict(X, check_input=True)
1515 |
1516 | tree = m5p.tree_
1517 | node_models = m5p.node_models
1518 | nb_samples = X.shape[0]
1519 | y_predicted = -np.ones((nb_samples, 1), dtype=DOUBLE)
1520 |
1521 | # sample_ids_to_leaf_node_ids = tree.apply(X)
1522 | def smooth_predictions(ancestor_nodes, X_at_node, y_pred_at_node):
1523 | # note: y_predicted_at_node can be a constant
1524 | current_node_model_id = ancestor_nodes[-1]
1525 | for _i, parent_model_id in enumerate(reversed(ancestor_nodes[:-1])):
1526 | # warning: this is the nb of TRAINING samples at this node
1527 | node_nb_train_samples = tree.n_node_samples[current_node_model_id]
1528 | parent_model = node_models[parent_model_id]
1529 | parent_predictions = parent_model.predict(X_at_node)
1530 | y_pred_at_node = (node_nb_train_samples * y_pred_at_node + smoothing_constant * parent_predictions) / (
1531 | node_nb_train_samples + smoothing_constant
1532 | )
1533 | current_node_model_id = parent_model_id
1534 |
1535 | return y_pred_at_node
1536 |
1537 | def apply_prediction(node_id, ids=None, ancestor_nodes=None):
1538 | first_call = False
1539 | if ids is None:
1540 | ids = slice(None)
1541 | first_call = True
1542 | if smoothing:
1543 | if ancestor_nodes is None:
1544 | ancestor_nodes = [node_id]
1545 | else:
1546 | ancestor_nodes.append(node_id)
1547 |
1548 | left_id = tree.children_left[node_id]
1549 | if left_id == _tree.TREE_LEAF:
1550 | # ... and tree.children_right[node_id] == _tree.TREE_LEAF
1551 | # LEAF node: predict
1552 | # predict
1553 | node_model = node_models[node_id]
1554 | # assert (ids == (sample_ids_to_leaf_node_ids == node_id)).all()
1555 | if isinstance(node_model, LinRegLeafModel):
1556 | X_at_node = X[ids, :]
1557 | predictions = node_model.predict(X_at_node)
1558 | else:
1559 | # isinstance(node_model, ConstantLeafModel)
1560 | predictions = tree.value[node_id]
1561 | if smoothing:
1562 | X_at_node = X[ids, :]
1563 |
1564 | # finally apply smoothing
1565 | if smoothing:
1566 | y_predicted[ids] = smooth_predictions(ancestor_nodes, X_at_node, predictions)
1567 | else:
1568 | y_predicted[ids] = predictions
1569 | else:
1570 | right_id = tree.children_right[node_id]
1571 | # non-leaf node: split samples and recurse
1572 | left_group = np.zeros(nb_samples, dtype=bool)
1573 | left_group[ids] = X[ids, tree.feature[node_id]] <= tree.threshold[node_id]
1574 | right_group = (~left_group) if first_call else (ids & (~left_group))
1575 |
1576 | # important: copy ancestor_nodes BEFORE calling anything, otherwise
1577 | # it will be modified
1578 | apply_prediction(
1579 | left_id, ids=left_group, ancestor_nodes=(ancestor_nodes.copy() if ancestor_nodes is not None else None)
1580 | )
1581 | apply_prediction(right_id, ids=right_group, ancestor_nodes=ancestor_nodes)
1582 |
1583 | # recurse to fill all predictions
1584 | apply_prediction(0)
1585 |
1586 | return y_predicted
1587 |
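# Smoothing recap (illustrative, mirroring the blending loop above): walking up
# from the leaf towards the root, the prediction is repeatedly combined with the
# parent node model's prediction as
#     y = (n * y + k * y_parent) / (n + k)
# where n is the number of TRAINING samples at the child node and k is the
# smoothing constant. For instance (hypothetical numbers), with n=30, k=15,
# y=2.0 at the leaf and y_parent=5.0, the smoothed value is
# (30*2.0 + 15*5.0) / (30 + 15) = 3.0.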
1588 |
1589 | def prune_children(node_id, tree):
1590 | """
1591 | Prunes the children of node_id in the given `tree`.
1592 |
1593 | Inspired by https://github.com/shenwanxiang/sklearn-post-prune-tree/blob/master/tree_prune.py#L122
1594 |
1595 |     IMPORTANT: this relies on the fact that `children_left` and `children_right`
1596 |     are modifiable (and are the only things we need to modify to fix the tree).
1597 |     This seems to be the case as of now.
1598 |
1599 | Parameters
1600 | ----------
1601 | node_id : int
1602 | tree : Tree
1603 |
1604 | Returns
1605 | -------
1606 | removed_nodes : List[int]
1607 | A list of removed nodes
1608 | """
1609 |
1610 | def _prune_below(_node_id):
1611 | left_child = tree.children_left[_node_id]
1612 | right_child = tree.children_right[_node_id]
1613 | if left_child == _tree.TREE_LEAF:
1614 |             # _node_id is a leaf: left_ & right_child say "leaf". Nothing to do.
1615 | return []
1616 | else:
1617 | # Make sure everything is pruned below: children should be leaves
1618 | removed_l = _prune_below(left_child)
1619 | removed_r = _prune_below(right_child)
1620 |
1621 | # -- First declare that they are not leaves anymore but they do not exist at all
1622 | for child in [left_child, right_child]:
1623 |
1624 | if tree.children_left[child] != _tree.TREE_LEAF:
1625 | raise ValueError("Unexpected children_left: %s, please report it as issue." % child)
1626 |
1627 | if tree.children_right[child] != _tree.TREE_LEAF:
1628 | raise ValueError("Unexpected children_right: %s, please report it as issue." % child)
1629 |
1630 | tree.children_left[child] = _tree.TREE_UNDEFINED
1631 | tree.children_right[child] = _tree.TREE_UNDEFINED
1632 |
1633 | # -- Then declare that current node is a leaf
1634 | tree.children_left[_node_id] = _tree.TREE_LEAF
1635 | tree.children_right[_node_id] = _tree.TREE_LEAF
1636 |
1637 | # Return the list of removed nodes
1638 | return removed_l + removed_r + [left_child, right_child]
1639 |
1640 | # Note: we do not change the node count here, as we'll clean all later.
1641 | # true_node_count = tree.node_count - sum(tree.children_left == _tree.TREE_UNDEFINED)
1642 | # tree.node_count -= 2*len(nodes_to_remove)
1643 |
1644 | return _prune_below(node_id)
1645 |
1646 |
1647 | def is_leaf(node_id, tree):
1648 | """
1649 |     Returns True if the node with id `node_id` is a leaf in tree `tree`.
1650 |     It is not actually used in this file because we always need the left child
1651 | node id in our code.
1652 |
1653 | But it is kept here as an easy way to remember how it works.
1654 |
1655 | Parameters
1656 | ----------
1657 | node_id : int
1658 | tree : Tree
1659 |
1660 | Returns
1661 | -------
1662 | _is_leaf : bool
1663 | A boolean flag, True if this node is a leaf.
1664 | """
1665 | if node_id == _tree.TREE_LEAF or node_id == _tree.TREE_UNDEFINED:
1666 | raise ValueError("Invalid node_id %s" % node_id)
1667 |
1668 | return tree.children_left[node_id] == _tree.TREE_LEAF
1669 |
1670 |
1671 | def get_samples_at_node(node_id, nodes_to_samples):
1672 | """
1673 | Return an array containing the ids of the samples for node `node_id`.
1674 |
1675 | This method requires the user to
1676 | - first compute the decision path for the sample matrix X
1677 | - then convert it to a csc
1678 |
1679 | >>> samples_to_nodes = estimator.decision_path(X) # returns a Scipy compressed sparse row matrix (CSR)
1680 | >>> nodes_to_samples = samples_to_nodes.tocsc() # we need the column equivalent (CSC)
1681 | >>> samples = get_samples_at_node(node_id, nodes_to_samples)
1682 |
1683 | Parameters
1684 | ----------
1685 | node_id : int
1686 | The node for which the list of samples is queried.
1687 | nodes_to_samples : csc_matrix
1688 |         A boolean matrix in Compressed Sparse Column (CSC) format, where rows are
1689 |         the samples and columns are the tree nodes. Entry (i, j) is 1 when sample
1690 |         i goes through node j on its decision path.
1691 |
1692 | Returns
1693 | -------
1694 |     samples : 1D array of the indices of the samples that reach node `node_id`.
1695 | """
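    # Illustrative example (hypothetical numbers): for a CSC matrix with
    # indptr = [0, 2, 5, ...], the samples reaching node 1 are indices[2:5],
    # i.e. the row (sample) ids stored for column 1.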
1696 | return nodes_to_samples.indices[nodes_to_samples.indptr[node_id] : nodes_to_samples.indptr[node_id + 1]]
1697 |
1698 |
1699 | class M5Prime(M5Base, RegressorMixin):
1700 | """An M5' (M five prime) model tree regressor.
1701 |
1702 | The original M5 algorithm was invented by R. Quinlan [1]_. Y. Wang [2]_ made
1703 | improvements and named the resulting algorithm M5 Prime.
1704 | This implementation was inspired by Weka (https://github.com/bnjmn/weka)
1705 | M5Prime class, from Mark Hall.
1706 |
1707 | See also
1708 | --------
1709 | M5Base
1710 |
1711 | References
1712 | ----------
1713 | .. [1] Ross J. Quinlan, "Learning with Continuous Classes", 5th Australian
1714 | Joint Conference on Artificial Intelligence, pp343-348, 1992.
1715 | .. [2] Y. Wang and I. H. Witten, "Induction of model trees for predicting
1716 | continuous classes", Poster papers of the 9th European Conference
1717 | on Machine Learning, 1997.
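
    Examples
    --------
    A minimal, illustrative sketch on synthetic data (the exact tree obtained
    will depend on the data and on the constructor parameters):

    >>> import numpy as np
    >>> from m5py import M5Prime
    >>> rng = np.random.RandomState(0)
    >>> X = rng.uniform(size=(100, 2))
    >>> y = 3 * X[:, 0] + np.sin(X[:, 1])
    >>> model = M5Prime(use_smoothing=False)
    >>> _ = model.fit(X, y)
    >>> y_pred = model.predict(X)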
1718 | """
1719 |
1720 |
1721 | def _empty_contig_ar(shape, dtype):
1722 | """Return an empty contiguous array with given shape and dtype."""
1723 | return np.ascontiguousarray(np.empty(shape, dtype=dtype))
1724 |
--------------------------------------------------------------------------------
/src/m5py/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smarie/python-m5p/8016d3cca3f263f56088fa8878d0f99c6fd81640/src/m5py/py.typed
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/smarie/python-m5p/8016d3cca3f263f56088fa8878d0f99c6fd81640/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from sklearn.metrics import mean_squared_error
4 |
5 | from m5py import M5Prime
6 | from m5py.main import predict_from_leaves
7 |
8 |
9 | @pytest.mark.parametrize("smoothing_constant", [75, 100], ids="smoothing_constant={}".format)
10 | def test_m5p_smoothing(smoothing_constant):
11 | """Tests that the M5P smoothing feature works correctly."""
12 |
13 | # use a given random seed so as to enforce deterministic output
14 |     # (and 100 gives a particularly interesting result with constant leaves, that's why we use it)
15 | np.random.seed(100)
16 | debug_prints = False
17 |
18 | x = np.random.uniform(-20, -2, (500, 1))
19 | piece_1_mask = x < -12
20 | piece_2_mask = ~piece_1_mask
21 | piece_2_train_indices = np.argsort(x, axis=0)[-5:] # last 5 samples
22 | # training_mask = np.ma.mask_or(piece_1_mask, piece_2_train_mask)
23 | training_mask = piece_1_mask.copy()
24 | training_mask[piece_2_train_indices] = True
25 | test_mask = ~training_mask
26 | X_train = x[training_mask].reshape(-1, 1)
27 | X_test = x[test_mask].reshape(-1, 1)
28 |
29 | # create a y that is piecewise linear
30 | def piece1(x):
31 | return 12.3 * x + 52.1
32 |
33 | y = piece1(x) + 5 * np.random.randn(*x.shape)
34 |
35 | z = 0
36 |
37 | def piece2(x):
38 | return 5.3 * x + z
39 |
40 | z = piece1(-12) - piece2(-12)
41 |
42 | y[piece_2_mask] -= piece1(x[piece_2_mask])
43 | y[piece_2_mask] += piece2(x[piece_2_mask])
44 | y_train = y[training_mask].reshape(-1, 1)
45 | y_test = y[test_mask].reshape(-1, 1)
46 | # plt.plot(X_train, y_train, 'xb')
47 | # plt.plot(X_test, y_test, 'xg')
48 | # plt.draw()
49 | # plt.pause(0.001)
50 |
51 | # Fit M5P without smoothing
52 | # print("Without smoothing: ")
53 | model_no_smoothing = M5Prime(use_smoothing=False, debug_prints=False)
54 | model_no_smoothing.fit(X_train, y_train)
55 | # print(export_text_m5(model_no_smoothing, out_file=None, node_ids=True))
56 |
57 | # Fit M5P with smoothing
58 | # print("With smoothing: ")
59 | model_smoothing = M5Prime(use_smoothing=True, debug_prints=debug_prints, smoothing_constant=smoothing_constant)
60 | model_smoothing.fit(X_train, y_train)
61 | # print(export_text_m5(model_smoothing, out_file=None, node_ids=True))
62 |
63 | # Predict
64 | # --no smoothing: compare the fast and slow methods
65 | y_pred_no_smoothing = model_no_smoothing.predict(x).reshape(-1, 1)
66 | y_pred_no_smoothing_slow = predict_from_leaves(model_no_smoothing, x, smoothing=False).reshape(-1, 1)
67 |
68 | # --smoothing: compare pre-computed and late methods
69 | y_pred_installed_smoothing = model_smoothing.predict(x).reshape(-1, 1)
70 | y_pred_late_smoothing = model_no_smoothing.predict(
71 | x, smooth_predictions=True, smoothing_constant=smoothing_constant
72 | ).reshape(-1, 1)
73 |
74 | # plt.plot(x, y_pred_no_smoothing, '.r', label="no smoothing")
75 | # # plt.plot(x, y_pred_no_smoothing_slow, '.r', label="no smoothing slow")
76 | # plt.plot(x, y_pred_late_smoothing, '.m', label="smoothing (on prediction)")
77 | # plt.plot(x, y_pred_installed_smoothing, '.g', label="smoothing (installed models)")
78 | # plt.legend()
79 | # plt.draw()
80 | # plt.pause(0.001)
81 |
82 | # make sure both ways to smooth give the same output
83 | np.testing.assert_array_equal(y_pred_no_smoothing, y_pred_no_smoothing_slow)
84 | np.testing.assert_array_almost_equal(y_pred_late_smoothing, y_pred_installed_smoothing, decimal=4)
85 |
86 | # compare performances without/with smoothing
87 | mse_no_smoothing = mean_squared_error(y_test, y_pred_no_smoothing[test_mask])
88 | mse_smoothing = mean_squared_error(y_test, y_pred_installed_smoothing[test_mask])
89 | print("M5P MSE: %.4f (no smoothing) %.4f (smoothing)" % (mse_no_smoothing, mse_smoothing))
90 |
91 |     # simple assert: smoothing improves performance (in most cases but not always - that's why we fixed our random seed)
92 | assert mse_smoothing < mse_no_smoothing
93 |
94 | # plt.close('all')
95 |
96 |
97 | # def test_boston_housing():
98 | # """
99 | # A copy of the Scikit Learn Gradient Boosting regression example, with the M5P in addition
100 | #
101 | # http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_regression.html
102 | # :return:
103 | # """
104 | # import matplotlib.pyplot as plt
105 | # plt.ion() # comment to debug
106 | #
107 | # # #############################################################################
108 | # # Load data
109 | # boston = datasets.load_boston()
110 | # X, y = shuffle(boston.data, boston.target, random_state=13)
111 | # X = X.astype(np.float32)
112 | # offset = int(X.shape[0] * 0.9)
113 | # X_train, y_train = X[:offset], y[:offset]
114 | # X_test, y_test = X[offset:], y[offset:]
115 | #
116 | # # #############################################################################
117 | # # Fit regression model
118 | # params = {'n_estimators': 500, 'max_depth': 4, 'min_samples_split': 2,
119 | # 'learning_rate': 0.01, 'loss': 'ls'}
120 | # clf = ensemble.GradientBoostingRegressor(**params)
121 | # clf.fit(X_train, y_train)
122 | # y_predicted = clf.predict(X_test)
123 | # mse = mean_squared_error(y_test, y_predicted)
124 | #     print("Gradient Boosting MSE: %.4f" % mse)
125 | #
126 | # # #############################################################################
127 | # # Plot predictions
128 | # plt.figure(figsize=(18, 12))
129 | # plt.subplot(2, 3, 1)
130 | # plt.title('Predictions on test set (RMSE = {:2f})'.format(np.sqrt(mse)))
131 | # plt.plot(y_test, y_predicted, '.')
132 | # plt.xlabel('true y')
133 | # plt.ylabel('predicted_y')
134 | #
135 | # # #############################################################################
136 | # # Plot training deviance
137 | #
138 | # # compute test set deviance
139 | # test_score = np.zeros((params['n_estimators'],), dtype=np.float64)
140 | #
141 | # for i, y_pred in enumerate(clf.staged_predict(X_test)):
142 | # test_score[i] = clf.loss_(y_test, y_pred)
143 | #
144 | # plt.subplot(2, 3, 2)
145 | # plt.title('Deviance')
146 | # plt.plot(np.arange(params['n_estimators']) + 1, clf.train_score_, 'b-',
147 | # label='Training Set Deviance')
148 | # plt.plot(np.arange(params['n_estimators']) + 1, test_score, 'r-',
149 | # label='Test Set Deviance')
150 | # plt.legend(loc='upper right')
151 | # plt.xlabel('Boosting Iterations')
152 | # plt.ylabel('Deviance')
153 | #
154 | # # #############################################################################
155 | # # Plot feature importance
156 | # feature_importance = clf.feature_importances_
157 | # # make importances relative to max importance
158 | # feature_importance = 100.0 * (feature_importance / feature_importance.max())
159 | # sorted_idx = np.argsort(feature_importance)
160 | # pos = np.arange(sorted_idx.shape[0]) + .5
161 | # plt.subplot(2, 3, 3)
162 | # plt.barh(pos, feature_importance[sorted_idx], align='center')
163 | # plt.yticks(pos, boston.feature_names[sorted_idx])
164 | # plt.xlabel('Relative Importance')
165 | # plt.title('Variable Importance')
166 | #
167 | # # ------- M5P
168 | # # #############################################################################
169 | # # Fit regression model
170 | # params = {}
171 | # clf = M5Prime(**params)
172 | # clf.fit(X_train, y_train)
173 | #
174 | # # Print the tree
175 | # print(export_text_m5(clf, out_file=None))
176 | # print(export_text_m5(clf, out_file=None, feature_names=boston.feature_names))
177 | #
178 | # # Predict
179 | # y_predicted = clf.predict(X_test)
180 | # mse = mean_squared_error(y_test, y_predicted)
181 | # print("M5P MSE: %.4f" % mse)
182 | #
183 | # # #############################################################################
184 | # # Plot predictions
185 | # plt.subplot(2, 3, 4)
186 | # plt.title('Predictions on test set (RMSE = {:2f})'.format(np.sqrt(mse)))
187 | # plt.plot(y_test, y_predicted, '.')
188 | # plt.xlabel('true y')
189 | # plt.ylabel('predicted_y')
190 | #
191 | # # Compress the tree (features-wise)
192 | # idx = clf.compress_features()
193 | # # Print the tree
194 | # print(export_text_m5(clf, out_file=None))
195 | # print(export_text_m5(clf, out_file=None, feature_names=boston.feature_names))
196 | #
197 | # # Predict
198 | # y_predicted2 = clf.predict(X_test[:, idx])
199 | # mse2 = mean_squared_error(y_test, y_predicted2)
200 | #     print("M5P2 MSE: %.4f" % mse2)
201 | #
202 | # # #############################################################################
203 | # # Plot predictions
204 | # plt.subplot(2, 3, 5)
205 | # plt.title('Predictions on test set (RMSE = {:2f})'.format(np.sqrt(mse2)))
206 | # plt.plot(y_test, y_predicted2, '.')
207 | # plt.xlabel('true y')
208 | # plt.ylabel('predicted_y')
209 | #
210 | # # #############################################################################
211 | # # Plot feature importance
212 | # feature_importance = clf.feature_importances_
213 | # # make importances relative to max importance
214 | # feature_importance = 100.0 * (feature_importance / feature_importance.max())
215 | # sorted_idx = np.argsort(feature_importance)
216 | # pos = np.arange(sorted_idx.shape[0]) + .5
217 | # plt.subplot(2, 3, 6)
218 | # plt.barh(pos, feature_importance[sorted_idx], align='center')
219 | # # do not forget that we now work on reindexed features
220 | # plt.yticks(pos, boston.feature_names[idx][sorted_idx])
221 | # plt.xlabel('Relative Importance')
222 | # plt.title('Variable Importance')
223 | #
224 | # plt.close('all')
225 | #
226 | #
227 | # @pytest.mark.parametrize("use_smoothing", [None, # None (default) = True = 'installed'
228 | # False, 'on_prediction'])
229 | # def test_default_smoothing_modes(use_smoothing):
230 | # """ Tests with default constant/ratio, depending on smoothing mode """
231 | #
232 | # half_x = np.random.random(25)
233 | # X = np.r_[half_x, -half_x]
234 | # Y = np.r_[2.8 * half_x + 5, -0.2 * half_x + 5]
235 | #
236 | # X = X.reshape(-1, 1)
237 | # Y = Y.reshape(-1, 1)
238 | #
239 | # regr = M5Prime(use_smoothing=use_smoothing)
240 | # regr.fit(X, Y)
241 | # print(export_text_m5(regr, out_file=None, node_ids=True))
242 | # preds = regr.predict(X)
243 | #
244 | # import matplotlib.pyplot as plt
245 | # plt.ion() # comment to debug
246 | # plt.plot(X, Y, 'x', label='true')
247 | # plt.plot(X, preds, 'x', label='prediction')
248 | # plt.close('all')
249 |
--------------------------------------------------------------------------------