├── .coveragerc ├── .github ├── dependabot.yml └── workflows │ ├── deploy-gh-pages.yml │ ├── lint.yml │ └── python-app.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── doc ├── Makefile ├── _static │ ├── css │ │ └── project-template.css │ ├── img │ │ ├── index_api.svg │ │ ├── index_examples.svg │ │ ├── index_getting_started.svg │ │ ├── index_user_guide.svg │ │ └── logo.png │ └── js │ │ └── copybutton.js ├── _templates │ ├── class.rst │ ├── function.rst │ ├── numpydoc_docstring.py │ └── sidebar-search-bs.html ├── api.rst ├── conf.py ├── index.rst ├── make.bat ├── quick_start.rst └── user_guide.rst ├── examples ├── README.txt ├── plot_classifier.py ├── plot_template.py └── plot_transformer.py ├── pixi.lock ├── pyproject.toml └── skltemplate ├── __init__.py ├── _template.py ├── tests ├── __init__.py ├── test_common.py └── test_template.py └── utils ├── __init__.py ├── discovery.py └── tests ├── __init__.py └── test_discovery.py /.coveragerc: -------------------------------------------------------------------------------- 1 | # Configuration for coverage.py 2 | 3 | [run] 4 | branch = True 5 | source = skltemplate 6 | include = */skltemplate/* 7 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Maintain dependencies for GitHub Actions as recommended in SPEC8: 4 | # https://github.com/scientific-python/specs/pull/325 5 | # At the time of writing, release critical workflows such as 6 | # pypa/gh-action-pypi-publish should use hash-based versioning for security 7 | # reasons. This strategy may be generalized to all other github actions 8 | # in the future. 
9 | - package-ecosystem: "github-actions" 10 | directory: "/" 11 | schedule: 12 | interval: "weekly" 13 | groups: 14 | actions: 15 | patterns: 16 | - "*" 17 | reviewers: 18 | - "glemaitre" 19 | -------------------------------------------------------------------------------- /.github/workflows/deploy-gh-pages.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | deploy-gh-pages: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | with: 18 | fetch-depth: 0 19 | - uses: prefix-dev/setup-pixi@v0.8.1 20 | with: 21 | pixi-version: v0.23.0 22 | environments: doc 23 | frozen: true 24 | 25 | - name: Build documentation 26 | run: pixi run -e doc build-doc 27 | 28 | - name: Update the main gh-page website 29 | if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} 30 | uses: peaceiris/actions-gh-pages@v4.0.0 31 | with: 32 | github_token: ${{ secrets.GITHUB_TOKEN }} 33 | publish_dir: ./doc/_build/html 34 | commit_message: "[ci skip] ${{ github.event.head_commit.message }}" 35 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Linter 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: prefix-dev/setup-pixi@v0.8.1 17 | with: 18 | pixi-version: v0.23.0 19 | environments: lint 20 | frozen: true 21 | 22 | - name: Run linter 23 | run: pixi run -e lint lint 24 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | name: Unit 
Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build: 13 | strategy: 14 | matrix: 15 | os: [windows-latest, ubuntu-latest, macos-latest] # macos-12 dropped: GitHub retired that runner image, so jobs pinned to it can no longer start 16 | environment: [test] 17 | runs-on: ${{ matrix.os }} 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: prefix-dev/setup-pixi@v0.8.1 21 | with: 22 | pixi-version: v0.23.0 23 | environments: ${{ matrix.environment }} 24 | frozen: true 25 | 26 | - name: Run tests 27 | run: pixi run -e ${{ matrix.environment }} test 28 | 29 | - name: Upload coverage reports to Codecov 30 | uses: codecov/codecov-action@v4.6.0 31 | with: 32 | token: ${{ secrets.CODECOV_TOKEN }} 33 | slug: scikit-learn-contrib/project-template 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # scikit-learn specific 10 | doc/_build/ 11 | doc/auto_examples/ 12 | doc/modules/generated/ 13 | doc/datasets/generated/ 14 | 15 | # Distribution / packaging 16 | 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *,cover 53 | .hypothesis/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | 62 | # Sphinx documentation 63 | doc/_build/ 64 | doc/generated/ 65 | doc/sg_execution_times.rst 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | .pixi 71 | 72 | # General 73 | .DS_Store 74 | .AppleDouble 75 | .LSOverride 76 | 77 | # auto-generated files 78 | skltemplate/_version.py -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/psf/black 9 | rev: 23.3.0 10 | hooks: 11 | - id: black 12 | - repo: https://github.com/astral-sh/ruff-pre-commit 13 | rev: v0.0.272 14 | hooks: 15 | - id: ruff 16 | args: ["--fix", "--show-source"] 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, Vighnesh Birodkar and scikit-learn-contrib contributors 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 
9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of project-template nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | project-template - A template for scikit-learn contributions 2 | ============================================================ 3 | 4 | ![tests](https://github.com/scikit-learn-contrib/project-template/actions/workflows/python-app.yml/badge.svg) 5 | [![codecov](https://codecov.io/gh/scikit-learn-contrib/project-template/graph/badge.svg?token=L0XPWwoPLw)](https://codecov.io/gh/scikit-learn-contrib/project-template) 6 | ![doc](https://github.com/scikit-learn-contrib/project-template/actions/workflows/deploy-gh-pages.yml/badge.svg) 7 | 8 | **project-template** is a template project for [scikit-learn](https://scikit-learn.org) 9 | compatible extensions. 10 | 11 | It aids development of estimators that can be used in scikit-learn pipelines and 12 | (hyper)parameter search, while facilitating testing (including some API compliance), 13 | documentation, open source development, packaging, and continuous integration. 14 | 15 | Refer to the documentation to modify the template for your own scikit-learn 16 | contribution: https://contrib.scikit-learn.org/project-template 17 | 18 | *Thank you for cleanly contributing to the scikit-learn ecosystem!* 19 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. 
Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make 
pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | -rm -rf $(BUILDDIR)/* 51 | -rm -rf auto_examples/ 52 | -rm -rf generated/* 53 | -rm -rf modules/generated/* 54 | 55 | html: 56 | # These two lines make the build a bit more lengthy, and the 57 | # the embedding of images more robust 58 | rm -rf $(BUILDDIR)/html/_images 59 | #rm -rf _build/doctrees/ 60 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 61 | @echo 62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 63 | 64 | dirhtml: 65 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 66 | @echo 67 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 68 | 69 | singlehtml: 70 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 71 | @echo 72 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 73 | 74 | pickle: 75 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 76 | @echo 77 | @echo "Build finished; now you can process the pickle files." 78 | 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files." 83 | 84 | htmlhelp: 85 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 86 | @echo 87 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 88 | ".hhp project file in $(BUILDDIR)/htmlhelp." 
89 | 90 | qthelp: 91 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 92 | @echo 93 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 94 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 95 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/project-template.qhcp" 96 | @echo "To view the help file:" 97 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/project-template.qhc" 98 | 99 | devhelp: 100 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 101 | @echo 102 | @echo "Build finished." 103 | @echo "To view the help file:" 104 | @echo "# mkdir -p $$HOME/.local/share/devhelp/project-template" 105 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/project-template" 106 | @echo "# devhelp" 107 | 108 | epub: 109 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 110 | @echo 111 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 112 | 113 | latex: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo 116 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 117 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 118 | "(use \`make latexpdf' here to do that automatically)." 119 | 120 | latexpdf: 121 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 122 | @echo "Running LaTeX files through pdflatex..." 123 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 124 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 125 | 126 | latexpdfja: 127 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 128 | @echo "Running LaTeX files through platex and dvipdfmx..." 129 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 130 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 131 | 132 | text: 133 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 134 | @echo 135 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 
136 | 137 | man: 138 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 139 | @echo 140 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 141 | 142 | texinfo: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo 145 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 146 | @echo "Run \`make' in that directory to run these through makeinfo" \ 147 | "(use \`make info' here to do that automatically)." 148 | 149 | info: 150 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 151 | @echo "Running Texinfo files through makeinfo..." 152 | make -C $(BUILDDIR)/texinfo info 153 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 154 | 155 | gettext: 156 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 157 | @echo 158 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 159 | 160 | changes: 161 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 162 | @echo 163 | @echo "The overview file is in $(BUILDDIR)/changes." 164 | 165 | linkcheck: 166 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 167 | @echo 168 | @echo "Link check complete; look for any errors in the above output " \ 169 | "or in $(BUILDDIR)/linkcheck/output.txt." 170 | 171 | doctest: 172 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 173 | @echo "Testing of doctests in the sources finished, look at the " \ 174 | "results in $(BUILDDIR)/doctest/output.txt." 175 | 176 | xml: 177 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 178 | @echo 179 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 180 | 181 | pseudoxml: 182 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 183 | @echo 184 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 
185 | -------------------------------------------------------------------------------- /doc/_static/css/project-template.css: -------------------------------------------------------------------------------- 1 | /* Override some aspects of the pydata-sphinx-theme */ 2 | 3 | :root { 4 | /* Use softer blue from bootstrap's default info color */ 5 | --pst-color-info: 23, 162, 184; 6 | } 7 | 8 | table { 9 | width: auto; 10 | /* Override fit-content which breaks Styler user guide ipynb */ 11 | } 12 | 13 | .transparent-image { 14 | background: none !important; 15 | margin-top: 1em; 16 | margin-bottom: 1em; 17 | } 18 | 19 | /* Main index page overview cards */ 20 | 21 | .intro-card { 22 | padding: 30px 10px 20px 10px; 23 | } 24 | 25 | .intro-card .sd-card-img-top { 26 | margin: 10px; 27 | height: 52px; 28 | background: none !important; 29 | } 30 | 31 | .intro-card .sd-card-title { 32 | color: var(--pst-color-primary); 33 | font-size: var(--pst-font-size-h5); 34 | padding: 1rem 0rem 0.5rem 0rem; 35 | } 36 | 37 | .intro-card .sd-card-footer { 38 | border: none !important; 39 | } 40 | 41 | .intro-card .sd-card-footer p.sd-card-text { 42 | max-width: 220px; 43 | margin-left: auto; 44 | margin-right: auto; 45 | } 46 | 47 | .intro-card .sd-btn-secondary { 48 | background-color: #6c757d !important; 49 | border-color: #6c757d !important; 50 | } 51 | 52 | .intro-card .sd-btn-secondary:hover { 53 | background-color: #5a6268 !important; 54 | border-color: #545b62 !important; 55 | } 56 | 57 | .card, .card img { 58 | background-color: var(--pst-color-background); 59 | } 60 | -------------------------------------------------------------------------------- /doc/_static/img/index_api.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 19 | 21 | 43 | 45 | 46 | 48 | image/svg+xml 49 | 51 | 52 | 53 | 54 | 55 | 60 | 63 | 68 | 73 | 76 | 82 | 88 | 94 | 95 | 96 | 97 | 98 | 
-------------------------------------------------------------------------------- /doc/_static/img/index_examples.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 19 | 21 | 43 | 45 | 46 | 48 | image/svg+xml 49 | 51 | 52 | 53 | 54 | 55 | 60 | 63 | 69 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /doc/_static/img/index_getting_started.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 19 | 21 | 43 | 45 | 46 | 48 | image/svg+xml 49 | 51 | 52 | 53 | 54 | 55 | 60 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /doc/_static/img/index_user_guide.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 19 | 21 | 43 | 45 | 46 | 48 | image/svg+xml 49 | 51 | 52 | 53 | 54 | 55 | 60 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /doc/_static/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scikit-learn-contrib/project-template/fc9f82eb6b998bde744ef736d4e101a9610ad5fb/doc/_static/img/logo.png -------------------------------------------------------------------------------- /doc/_static/js/copybutton.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function() { 2 | /* Add a [>>>] button on the top-right corner of code samples to hide 3 | * the >>> and ... prompts and the output and thus make the code 4 | * copyable. 
*/ 5 | var div = $('.highlight-python .highlight,' + 6 | '.highlight-python3 .highlight,' + 7 | '.highlight-pycon .highlight,' + 8 | '.highlight-default .highlight') 9 | var pre = div.find('pre'); 10 | 11 | // get the styles from the current theme 12 | pre.parent().parent().css('position', 'relative'); 13 | var hide_text = 'Hide the prompts and output'; 14 | var show_text = 'Show the prompts and output'; 15 | var border_width = pre.css('border-top-width'); 16 | var border_style = pre.css('border-top-style'); 17 | var border_color = pre.css('border-top-color'); 18 | var button_styles = { 19 | 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', 20 | 'border-color': border_color, 'border-style': border_style, 21 | 'border-width': border_width, 'color': border_color, 'text-size': '75%', 22 | 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em', 23 | 'border-radius': '0 3px 0 0' 24 | } 25 | 26 | // create and add the button to all the code blocks that contain >>> 27 | div.each(function(index) { 28 | var jthis = $(this); 29 | if (jthis.find('.gp').length > 0) { 30 | var button = $('>>>'); 31 | button.css(button_styles) 32 | button.attr('title', hide_text); 33 | button.data('hidden', 'false'); 34 | jthis.prepend(button); 35 | } 36 | // tracebacks (.gt) contain bare text elements that need to be 37 | // wrapped in a span to work with .nextUntil() (see later) 38 | jthis.find('pre:has(.gt)').contents().filter(function() { 39 | return ((this.nodeType == 3) && (this.data.trim().length > 0)); 40 | }).wrap(''); 41 | }); 42 | 43 | // define the behavior of the button when it's clicked 44 | $('.copybutton').click(function(e){ 45 | e.preventDefault(); 46 | var button = $(this); 47 | if (button.data('hidden') === 'false') { 48 | // hide the code output 49 | button.parent().find('.go, .gp, .gt').hide(); 50 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); 51 | button.css('text-decoration', 'line-through'); 
52 | button.attr('title', show_text); 53 | button.data('hidden', 'true'); 54 | } else { 55 | // show the code output 56 | button.parent().find('.go, .gp, .gt').show(); 57 | button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); 58 | button.css('text-decoration', 'none'); 59 | button.attr('title', hide_text); 60 | button.data('hidden', 'false'); 61 | } 62 | }); 63 | }); 64 | -------------------------------------------------------------------------------- /doc/_templates/class.rst: -------------------------------------------------------------------------------- 1 | {{objname}} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | {% block methods %} 9 | 10 | {% if methods %} 11 | .. rubric:: Methods 12 | 13 | .. autosummary:: 14 | {% for item in methods %} 15 | {% if '__init__' not in item %} 16 | ~{{ name }}.{{ item }} 17 | {% endif %} 18 | {%- endfor %} 19 | {% endif %} 20 | {% endblock %} 21 | 22 | .. include:: {{module}}.{{objname}}.examples 23 | 24 | .. raw:: html 25 | 26 |
27 | -------------------------------------------------------------------------------- /doc/_templates/function.rst: -------------------------------------------------------------------------------- 1 | {{objname}} 2 | {{ underline }}==================== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autofunction:: {{ objname }} 7 | 8 | .. include:: {{module}}.{{objname}}.examples 9 | 10 | .. raw:: html 11 | 12 |
13 | -------------------------------------------------------------------------------- /doc/_templates/numpydoc_docstring.py: -------------------------------------------------------------------------------- 1 | {{index}} 2 | {{summary}} 3 | {{extended_summary}} 4 | {{parameters}} 5 | {{returns}} 6 | {{yields}} 7 | {{other_parameters}} 8 | {{attributes}} 9 | {{raises}} 10 | {{warns}} 11 | {{warnings}} 12 | {{see_also}} 13 | {{notes}} 14 | {{references}} 15 | {{examples}} 16 | {{methods}} -------------------------------------------------------------------------------- /doc/_templates/sidebar-search-bs.html: -------------------------------------------------------------------------------- 1 | 15 | -------------------------------------------------------------------------------- /doc/api.rst: -------------------------------------------------------------------------------- 1 | .. _api: 2 | 3 | ############# 4 | API Reference 5 | ############# 6 | 7 | This is an example of how to document the API of your own project. 8 | 9 | .. currentmodule:: skltemplate 10 | 11 | Estimator 12 | ========= 13 | 14 | .. autosummary:: 15 | :toctree: generated/ 16 | :template: class.rst 17 | 18 | TemplateEstimator 19 | 20 | Transformer 21 | =========== 22 | 23 | .. autosummary:: 24 | :toctree: generated/ 25 | :template: class.rst 26 | 27 | TemplateTransformer 28 | 29 | Predictor 30 | ========= 31 | 32 | .. autosummary:: 33 | :toctree: generated/ 34 | :template: class.rst 35 | 36 | TemplateClassifier 37 | 38 | 39 | Utilities 40 | ========= 41 | 42 | .. autosummary:: 43 | :toctree: generated/ 44 | :template: function.rst 45 | 46 | utils.discovery.all_estimators 47 | utils.discovery.all_displays 48 | utils.discovery.all_functions 49 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder.
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | import os 7 | import sys 8 | 9 | # -- Project information ----------------------------------------------------- 10 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 11 | from importlib.metadata import version as get_version 12 | 13 | project = "Scikit-learn Project Template" 14 | copyright = "2016, V. Birodkar" 15 | author = "V. Birodkar" 16 | release = get_version('skltemplate') 17 | version = ".".join(release.split(".")[:3]) 18 | 19 | # -- General configuration --------------------------------------------------- 20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 21 | 22 | # Add any Sphinx extension module names here, as strings. They can be 23 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 24 | # ones. 25 | extensions = [ 26 | "sphinx.ext.autodoc", 27 | "sphinx.ext.autosummary", 28 | "sphinx.ext.intersphinx", 29 | "sphinx_design", 30 | "sphinx-prompt", 31 | "sphinx_gallery.gen_gallery", 32 | "numpydoc", 33 | ] 34 | 35 | templates_path = ["_templates"] 36 | exclude_patterns = ["_build", "_templates", "Thumbs.db", ".DS_Store"] 37 | 38 | # The reST default role (used for this markup: `text`) to use for all 39 | # documents. 
40 | default_role = "literal" 41 | 42 | # -- Options for HTML output ------------------------------------------------- 43 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 44 | 45 | html_theme = "pydata_sphinx_theme" 46 | html_static_path = ["_static"] 47 | html_style = "css/project-template.css" 48 | html_logo = "_static/img/logo.png" 49 | # html_favicon = "_static/img/favicon.ico" 50 | html_css_files = [ 51 | "css/project-template.css", 52 | ] 53 | html_sidebars = { 54 | "quick_start": [], 55 | "user_guide": [], 56 | "auto_examples/index": [], 57 | } 58 | 59 | html_theme_options = { 60 | "external_links": [], 61 | "github_url": "https://github.com/scikit-learn-contrib/project-template", 62 | # "twitter_url": "https://twitter.com/pandas_dev", 63 | "use_edit_page_button": True, 64 | "show_toc_level": 1, 65 | # "navbar_align": "right", # For testing that the navbar items align properly 66 | } 67 | 68 | html_context = { 69 | "github_user": "scikit-learn-contrib", 70 | "github_repo": "project-template", 71 | "github_version": "master", 72 | "doc_path": "doc", 73 | } 74 | 75 | # -- Options for autodoc ------------------------------------------------------ 76 | 77 | autodoc_default_options = { 78 | "members": True, 79 | "inherited-members": True, 80 | } 81 | 82 | # generate autosummary even if no references 83 | autosummary_generate = True 84 | 85 | # -- Options for numpydoc ----------------------------------------------------- 86 | 87 | # this is needed for some reason... 
88 | # see https://github.com/numpy/numpydoc/issues/69 89 | numpydoc_show_class_members = False 90 | 91 | # -- Options for intersphinx -------------------------------------------------- 92 | 93 | intersphinx_mapping = { 94 | "python": ("https://docs.python.org/{.major}".format(sys.version_info), None), 95 | "numpy": ("https://numpy.org/doc/stable", None), 96 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 97 | "scikit-learn": ("https://scikit-learn.org/stable", None), 98 | "matplotlib": ("https://matplotlib.org/", None), 99 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), 100 | "joblib": ("https://joblib.readthedocs.io/en/latest/", None), 101 | } 102 | 103 | # -- Options for sphinx-gallery ----------------------------------------------- 104 | 105 | # Generate the plot for the gallery 106 | plot_gallery = True 107 | 108 | sphinx_gallery_conf = { 109 | "doc_module": "skltemplate", 110 | "backreferences_dir": os.path.join("generated"), 111 | "examples_dirs": "../examples", 112 | "gallery_dirs": "auto_examples", 113 | "reference_url": {"skltemplate": None}, 114 | } 115 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. project-template documentation master file, created by 2 | sphinx-quickstart on Mon Jan 18 14:44:12 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | :notoc: 7 | 8 | ############################################# 9 | Project template for `scikit-learn` extension 10 | ############################################# 11 | 12 | **Date**: |today| **Version**: |version| 13 | 14 | **Useful links**: 15 | `Source Repository `__ | 16 | `Issues & Ideas `__ | 17 | 18 | This is the documentation for the `project-template` to help at extending 19 | `scikit-learn`. 
It provides some information on how to build your own custom 20 | `scikit-learn` compatible estimators as well as a template to package them. 21 | 22 | 23 | .. grid:: 1 2 2 2 24 | :gutter: 4 25 | :padding: 2 2 0 0 26 | :class-container: sd-text-center 27 | 28 | .. grid-item-card:: Getting started 29 | :img-top: _static/img/index_getting_started.svg 30 | :class-card: intro-card 31 | :shadow: md 32 | 33 | Information regarding this template and how to modify it for your own project. 34 | 35 | +++ 36 | 37 | .. button-ref:: quick_start 38 | :ref-type: ref 39 | :click-parent: 40 | :color: secondary 41 | :expand: 42 | 43 | To the getting started guide 44 | 45 | .. grid-item-card:: User guide 46 | :img-top: _static/img/index_user_guide.svg 47 | :class-card: intro-card 48 | :shadow: md 49 | 50 | An example of narrative documentation. Here, we will explain how to create your 51 | own `scikit-learn` estimator. 52 | 53 | +++ 54 | 55 | .. button-ref:: user_guide 56 | :ref-type: ref 57 | :click-parent: 58 | :color: secondary 59 | :expand: 60 | 61 | To the user guide 62 | 63 | .. grid-item-card:: API reference 64 | :img-top: _static/img/index_api.svg 65 | :class-card: intro-card 66 | :shadow: md 67 | 68 | An example of API documentation. It shows how to use `sphinx` to 69 | automatically generate the API reference pages. 70 | 71 | +++ 72 | 73 | .. button-ref:: api 74 | :ref-type: ref 75 | :click-parent: 76 | :color: secondary 77 | :expand: 78 | 79 | To the reference guide 80 | 81 | .. grid-item-card:: Examples 82 | :img-top: _static/img/index_examples.svg 83 | :class-card: intro-card 84 | :shadow: md 85 | 86 | A set of examples. It complements the User Guide and is the right place to 87 | show how to use your compatible estimator. 88 | 89 | +++ 90 | 91 | .. button-ref:: general_examples 92 | :ref-type: ref 93 | :click-parent: 94 | :color: secondary 95 | :expand: 96 | 97 | To the gallery of examples 98 | 99 | 100 | ..
toctree:: 101 | :maxdepth: 3 102 | :hidden: 103 | :titlesonly: 104 | 105 | quick_start 106 | user_guide 107 | api 108 | auto_examples/index 109 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. 
doctest to run all doctests embedded in the documentation if enabled 40 | goto end 41 | ) 42 | 43 | if "%1" == "clean" ( 44 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 45 | del /q /s %BUILDDIR%\* 46 | goto end 47 | ) 48 | 49 | 50 | %SPHINXBUILD% 2> nul 51 | if errorlevel 9009 ( 52 | echo. 53 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 54 | echo.installed, then set the SPHINXBUILD environment variable to point 55 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 56 | echo.may add the Sphinx directory to PATH. 57 | echo. 58 | echo.If you don't have Sphinx installed, grab it from 59 | echo.http://sphinx-doc.org/ 60 | exit /b 1 61 | ) 62 | 63 | if "%1" == "html" ( 64 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 65 | if errorlevel 1 exit /b 1 66 | echo. 67 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 68 | goto end 69 | ) 70 | 71 | if "%1" == "dirhtml" ( 72 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 73 | if errorlevel 1 exit /b 1 74 | echo. 75 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 76 | goto end 77 | ) 78 | 79 | if "%1" == "singlehtml" ( 80 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 81 | if errorlevel 1 exit /b 1 82 | echo. 83 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 84 | goto end 85 | ) 86 | 87 | if "%1" == "pickle" ( 88 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 89 | if errorlevel 1 exit /b 1 90 | echo. 91 | echo.Build finished; now you can process the pickle files. 92 | goto end 93 | ) 94 | 95 | if "%1" == "json" ( 96 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 97 | if errorlevel 1 exit /b 1 98 | echo. 99 | echo.Build finished; now you can process the JSON files. 100 | goto end 101 | ) 102 | 103 | if "%1" == "htmlhelp" ( 104 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 105 | if errorlevel 1 exit /b 1 106 | echo. 
107 | echo.Build finished; now you can run HTML Help Workshop with the ^ 108 | .hhp project file in %BUILDDIR%/htmlhelp. 109 | goto end 110 | ) 111 | 112 | if "%1" == "qthelp" ( 113 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 114 | if errorlevel 1 exit /b 1 115 | echo. 116 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 117 | .qhcp project file in %BUILDDIR%/qthelp, like this: 118 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\project-template.qhcp 119 | echo.To view the help file: 120 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\project-template.ghc 121 | goto end 122 | ) 123 | 124 | if "%1" == "devhelp" ( 125 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished. 129 | goto end 130 | ) 131 | 132 | if "%1" == "epub" ( 133 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 134 | if errorlevel 1 exit /b 1 135 | echo. 136 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 137 | goto end 138 | ) 139 | 140 | if "%1" == "latex" ( 141 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 142 | if errorlevel 1 exit /b 1 143 | echo. 144 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 145 | goto end 146 | ) 147 | 148 | if "%1" == "latexpdf" ( 149 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 150 | cd %BUILDDIR%/latex 151 | make all-pdf 152 | cd %BUILDDIR%/.. 153 | echo. 154 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 155 | goto end 156 | ) 157 | 158 | if "%1" == "latexpdfja" ( 159 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 160 | cd %BUILDDIR%/latex 161 | make all-pdf-ja 162 | cd %BUILDDIR%/.. 163 | echo. 164 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 165 | goto end 166 | ) 167 | 168 | if "%1" == "text" ( 169 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 170 | if errorlevel 1 exit /b 1 171 | echo. 172 | echo.Build finished. 
The text files are in %BUILDDIR%/text. 173 | goto end 174 | ) 175 | 176 | if "%1" == "man" ( 177 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 178 | if errorlevel 1 exit /b 1 179 | echo. 180 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 181 | goto end 182 | ) 183 | 184 | if "%1" == "texinfo" ( 185 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 186 | if errorlevel 1 exit /b 1 187 | echo. 188 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 189 | goto end 190 | ) 191 | 192 | if "%1" == "gettext" ( 193 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 194 | if errorlevel 1 exit /b 1 195 | echo. 196 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 197 | goto end 198 | ) 199 | 200 | if "%1" == "changes" ( 201 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 202 | if errorlevel 1 exit /b 1 203 | echo. 204 | echo.The overview file is in %BUILDDIR%/changes. 205 | goto end 206 | ) 207 | 208 | if "%1" == "linkcheck" ( 209 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 210 | if errorlevel 1 exit /b 1 211 | echo. 212 | echo.Link check complete; look for any errors in the above output ^ 213 | or in %BUILDDIR%/linkcheck/output.txt. 214 | goto end 215 | ) 216 | 217 | if "%1" == "doctest" ( 218 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 219 | if errorlevel 1 exit /b 1 220 | echo. 221 | echo.Testing of doctests in the sources finished, look at the ^ 222 | results in %BUILDDIR%/doctest/output.txt. 223 | goto end 224 | ) 225 | 226 | if "%1" == "xml" ( 227 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 228 | if errorlevel 1 exit /b 1 229 | echo. 230 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 231 | goto end 232 | ) 233 | 234 | if "%1" == "pseudoxml" ( 235 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 236 | if errorlevel 1 exit /b 1 237 | echo. 238 | echo.Build finished. 
The pseudo-XML files are in %BUILDDIR%/pseudoxml. 239 | goto end 240 | ) 241 | 242 | :end 243 | -------------------------------------------------------------------------------- /doc/quick_start.rst: -------------------------------------------------------------------------------- 1 | .. _quick_start: 2 | 3 | ############### 4 | Getting started 5 | ############### 6 | 7 | This package serves as a skeleton package that helps with developing 8 | scikit-learn compatible contributions. 9 | 10 | Creating your own scikit-learn contribution package 11 | =================================================== 12 | 13 | Download and set up your repository 14 | ----------------------------------- 15 | 16 | To create your package, you need to clone the ``project-template`` repository: 17 | 18 | .. prompt:: bash $ 19 | 20 | git clone https://github.com/scikit-learn-contrib/project-template.git 21 | 22 | Before reinitializing your git repository, you need to make the following 23 | changes. Replace all occurrences of ``skltemplate``, ``sklearn-template``, or 24 | ``project-template`` with the name of your own project. You can find all the 25 | occurrences using the following commands: 26 | 27 | .. prompt:: bash $ 28 | 29 | git grep skltemplate 30 | git grep sklearn-template 31 | git grep project-template 32 | 33 | To remove the history of the template package, you need to remove the `.git` 34 | directory: 35 | 36 | .. prompt:: bash $ 37 | 38 | rm -rf .git 39 | 40 | Then, you need to initialize your new git repository: 41 | 42 | .. prompt:: bash $ 43 | 44 | git init 45 | git add . 46 | git commit -m 'Initial commit' 47 | 48 | Finally, create an online repository on GitHub and push your code to it: 49 | 50 | .. prompt:: bash $ 51 | 52 | git remote add origin https://github.com/your_remote/your_contribution.git 53 | git push origin main 54 | 55 | Develop your own scikit-learn estimators 56 | ---------------------------------------- 57 | 58 | ..
_check_estimator: http://scikit-learn.org/stable/modules/generated/sklearn.utils.estimator_checks.check_estimator.html#sklearn.utils.estimator_checks.check_estimator 59 | .. _`Contributor's Guide`: http://scikit-learn.org/stable/developers/ 60 | .. _PEP8: https://www.python.org/dev/peps/pep-0008/ 61 | .. _PEP257: https://www.python.org/dev/peps/pep-0257/ 62 | .. _NumPyDoc: https://github.com/numpy/numpydoc 63 | .. _doctests: https://docs.python.org/3/library/doctest.html 64 | 65 | You can modify the source files as you want. However, your custom estimators 66 | need to pass the check_estimator_ test to be scikit-learn compatible. We provide a 67 | file called `test_common.py` where we run the checks on our custom estimators. 68 | 69 | You can refer to the :ref:`user_guide` to help you create a compatible 70 | scikit-learn estimator. 71 | 72 | In any case, developers should endeavor to adhere to scikit-learn's 73 | `Contributor's Guide`_ which promotes the use of: 74 | 75 | * algorithm-specific unit tests, in addition to ``check_estimator``'s common 76 | tests; 77 | * PEP8_-compliant code; 78 | * a clearly documented API using NumpyDoc_ and PEP257_-compliant docstrings; 79 | * references to relevant scientific literature in standard citation formats; 80 | * doctests_ to provide succinct usage examples; 81 | * standalone examples to illustrate the usage, model visualisation, and 82 | benefits/benchmarks of particular algorithms; 83 | * efficient code when the need for optimization is supported by benchmarks. 84 | 85 | Managing your local and continuous integration environment 86 | ---------------------------------------------------------- 87 | 88 | Here, we set up for you a repository that uses `pixi`. The `tool.pixi` tables of the `pyproject.toml` file define 89 | the packages and tasks that we present below.
You can refer to the 90 | following documentation link to install `pixi`: https://pixi.sh/latest/#installation 91 | 92 | Once done, you can refer to the documentation to get started, but we provide the 93 | commands below to run the main tasks needed to develop your package. 94 | 95 | Edit the documentation 96 | ---------------------- 97 | 98 | .. _Sphinx: http://www.sphinx-doc.org/en/stable/ 99 | 100 | The documentation is created using Sphinx_. In addition, the examples are 101 | created using ``sphinx-gallery``. Therefore, to generate the documentation 102 | locally, you can leverage the following `pixi` task: 103 | 104 | .. prompt:: bash $ 105 | 106 | pixi run build-doc 107 | 108 | The documentation is made of: 109 | 110 | * a home page, ``doc/index.rst``; 111 | * an API documentation, ``doc/api.rst`` in which you should add all public 112 | objects for which the docstring should be exposed publicly. 113 | * a User Guide documentation, ``doc/user_guide.rst``, containing the narrative 114 | documentation of your package, to give as much intuition as possible to your 115 | users. 116 | * examples which are created in the `examples/` folder. Each example 117 | illustrates some usage of the package. The example file names should follow the 118 | `plot_*.py` pattern. 119 | 120 | Local testing 121 | ------------- 122 | 123 | To run the tests locally, you can use the following command: 124 | 125 | .. prompt:: bash $ 126 | 127 | pixi run test 128 | 129 | It will use `pytest` under the hood to run the package tests. 130 | 131 | In addition, you have a linter task to check the code consistency in terms of style: 132 | 133 | .. prompt:: bash $ 134 | 135 | pixi run lint 136 | 137 | Activating the development environment 138 | -------------------------------------- 139 | 140 | If you don't want to use the `pixi run` commands and prefer to interact directly 141 | with the usual python tools, you can activate the development environment: 142 | 143 | ..
prompt:: bash $ 144 | 145 | pixi shell -e dev 146 | 147 | This will activate an environment containing the dependencies needed to run the linters, 148 | tests, and build the documentation. So for instance, you can run the tests with: 149 | 150 | .. prompt:: bash $ 151 | 152 | pytest -vsl skltemplate 153 | 154 | In this environment, you can also use pre-commit to run the configured hooks before each git commit. You will need to initialize 155 | it with: 156 | 157 | .. prompt:: bash $ 158 | 159 | pre-commit install 160 | 161 | Set up the continuous integration 162 | --------------------------------- 163 | 164 | The project template already contains configuration files of the continuous 165 | integration system. It leverages the above pixi commands and runs on GitHub Actions. 166 | In short, it will: 167 | 168 | * run the tests on the different platforms (Linux, macOS, Windows) and upload the 169 | coverage report to codecov.io; 170 | * check the code style (linter); 171 | * build the documentation and deploy it automatically on GitHub Pages. 172 | 173 | Publish your package 174 | ==================== 175 | 176 | .. _PyPi: https://packaging.python.org/tutorials/packaging-projects/ 177 | .. _conda-forge: https://conda-forge.org/ 178 | 179 | You can make your package available through PyPi_ and conda-forge_. Refer to 180 | the associated documentation to be able to upload your packages so that 181 | they can be installed with ``pip`` and ``conda``. 182 | -------------------------------------------------------------------------------- /doc/user_guide.rst: -------------------------------------------------------------------------------- 1 | .. title:: User guide : contents 2 | 3 | .. _user_guide: 4 | 5 | ========== 6 | User Guide 7 | ========== 8 | 9 | Estimator 10 | --------- 11 | 12 | The central piece of transformers, regressors, and classifiers is 13 | :class:`sklearn.base.BaseEstimator`. All estimators in scikit-learn are derived 14 | from this class.
In more detail, this base class enables setting and getting 15 | the parameters of the estimator. It can be imported as:: 16 | 17 | >>> from sklearn.base import BaseEstimator 18 | 19 | Once imported, you can create a class which inherits from this base class:: 20 | 21 | >>> class MyOwnEstimator(BaseEstimator): 22 | ... pass 23 | 24 | Transformer 25 | ----------- 26 | 27 | Transformers are scikit-learn estimators which implement a ``transform`` method. 28 | The use case is the following: 29 | 30 | * at ``fit``, some parameters can be learned from ``X`` and ``y``; 31 | * at ``transform``, ``X`` will be transformed, using the parameters learned 32 | during ``fit``. 33 | 34 | .. _mixin: https://en.wikipedia.org/wiki/Mixin 35 | 36 | In addition, scikit-learn provides a 37 | mixin_, i.e. :class:`sklearn.base.TransformerMixin`, which 38 | implements the combination of ``fit`` and ``transform`` called ``fit_transform``. 39 | 40 | One can import the mixin class as:: 41 | 42 | >>> from sklearn.base import TransformerMixin 43 | 44 | Therefore, when creating a transformer, you need to create a class which 45 | inherits from both :class:`sklearn.base.BaseEstimator` and 46 | :class:`sklearn.base.TransformerMixin`. The scikit-learn API requires ``fit`` to 47 | **return** ``self``. The reason is that it allows chaining ``fit`` and 48 | ``transform``, as done by ``fit_transform`` in :class:`sklearn.base.TransformerMixin`. The 49 | ``fit`` method is expected to have ``X`` and ``y`` as inputs. Note that 50 | ``transform`` takes only ``X`` as input and is expected to return the 51 | transformed version of ``X``:: 52 | 53 | >>> class MyOwnTransformer(TransformerMixin, BaseEstimator): 54 | ... def fit(self, X, y=None): 55 | ... return self 56 | ... def transform(self, X): 57 | ...
return X 58 | 59 | We build a basic example to show that our :class:`MyOwnTransformer` is working 60 | within a scikit-learn ``pipeline``:: 61 | 62 | >>> from sklearn.datasets import load_iris 63 | >>> from sklearn.pipeline import make_pipeline 64 | >>> from sklearn.linear_model import LogisticRegression 65 | >>> X, y = load_iris(return_X_y=True) 66 | >>> pipe = make_pipeline(MyOwnTransformer(), 67 | ... LogisticRegression(random_state=10, 68 | ... solver='lbfgs')) 69 | >>> pipe.fit(X, y) # doctest: +ELLIPSIS 70 | Pipeline(...) 71 | >>> pipe.predict(X) # doctest: +ELLIPSIS 72 | array([...]) 73 | 74 | Predictor 75 | --------- 76 | 77 | Regressor 78 | ~~~~~~~~~ 79 | 80 | Similarly, regressors are scikit-learn estimators which implement a ``predict`` 81 | method. The use case is the following: 82 | 83 | * at ``fit``, some parameters can be learned from ``X`` and ``y``; 84 | * at ``predict``, predictions will be computed using ``X`` using the parameters 85 | learned during ``fit``. 86 | 87 | In addition, scikit-learn provides a mixin_, i.e. 88 | :class:`sklearn.base.RegressorMixin`, which implements the ``score`` method 89 | which computes the :math:`R^2` score of the predictions. 90 | 91 | One can import the mixin as:: 92 | 93 | >>> from sklearn.base import RegressorMixin 94 | 95 | Therefore, we create a regressor, :class:`MyOwnRegressor` which inherits from 96 | both :class:`sklearn.base.BaseEstimator` and 97 | :class:`sklearn.base.RegressorMixin`. The method ``fit`` gets ``X`` and ``y`` 98 | as input and should return ``self``. It should implement the ``predict`` 99 | function which should output the predictions of your regressor:: 100 | 101 | >>> import numpy as np 102 | >>> class MyOwnRegressor(RegressorMixin, BaseEstimator): 103 | ... def fit(self, X, y): 104 | ... return self 105 | ... def predict(self, X): 106 | ... 
return np.mean(X, axis=1) 107 | 108 | We illustrate that this regressor is working within a scikit-learn pipeline:: 109 | 110 | >>> from sklearn.datasets import load_diabetes 111 | >>> X, y = load_diabetes(return_X_y=True) 112 | >>> pipe = make_pipeline(MyOwnTransformer(), MyOwnRegressor()) 113 | >>> pipe.fit(X, y) # doctest: +ELLIPSIS 114 | Pipeline(...) 115 | >>> pipe.predict(X) # doctest: +ELLIPSIS 116 | array([...]) 117 | 118 | Since we inherit from the :class:`sklearn.base.RegressorMixin`, we can call 119 | the ``score`` method which will return the :math:`R^2` score:: 120 | 121 | >>> pipe.score(X, y) # doctest: +ELLIPSIS 122 | -3.9... 123 | 124 | Classifier 125 | ~~~~~~~~~~ 126 | 127 | Similarly to regressors, classifiers implement ``predict``. In addition, they 128 | output the probabilities of the prediction using the ``predict_proba`` method: 129 | 130 | * at ``fit``, some parameters can be learned from ``X`` and ``y``; 131 | * at ``predict``, predictions will be computed from ``X`` using the parameters 132 | learned during ``fit``. The output corresponds to the predicted class for each sample; 133 | * ``predict_proba`` will give a 2D matrix where each column corresponds to a 134 | class and each entry is the probability of the associated class. 135 | 136 | In addition, scikit-learn provides a mixin, i.e. 137 | :class:`sklearn.base.ClassifierMixin`, which implements the ``score`` method 138 | which computes the accuracy score of the predictions. 139 | 140 | One can import this mixin as:: 141 | 142 | >>> from sklearn.base import ClassifierMixin 143 | 144 | Therefore, we create a classifier, :class:`MyOwnClassifier` which inherits 145 | from both :class:`sklearn.base.BaseEstimator` and 146 | :class:`sklearn.base.ClassifierMixin`. The method ``fit`` gets ``X`` and ``y`` 147 | as input and should return ``self``. It should implement the ``predict`` 148 | function which should output the class inferred by the classifier.
149 | ``predict_proba`` will output some probabilities instead:: 150 | 151 | >>> class MyOwnClassifier(ClassifierMixin, BaseEstimator): 152 | ... def fit(self, X, y): 153 | ... self.classes_ = np.unique(y) 154 | ... return self 155 | ... def predict(self, X): 156 | ... return np.random.randint(0, self.classes_.size, 157 | ... size=X.shape[0]) 158 | ... def predict_proba(self, X): 159 | ... pred = np.random.rand(X.shape[0], self.classes_.size) 160 | ... return pred / np.sum(pred, axis=1)[:, np.newaxis] 161 | 162 | We illustrate that this classifier is working within a scikit-learn pipeline:: 163 | 164 | >>> X, y = load_iris(return_X_y=True) 165 | >>> pipe = make_pipeline(MyOwnTransformer(), MyOwnClassifier()) 166 | >>> pipe.fit(X, y) # doctest: +ELLIPSIS 167 | Pipeline(...) 168 | 169 | Then, you can call ``predict`` and ``predict_proba``:: 170 | 171 | >>> pipe.predict(X) # doctest: +ELLIPSIS 172 | array([...]) 173 | >>> pipe.predict_proba(X) # doctest: +ELLIPSIS 174 | array([...]) 175 | 176 | Since our classifier inherits from :class:`sklearn.base.ClassifierMixin`, we 177 | can compute the accuracy by calling the ``score`` method:: 178 | 179 | >>> pipe.score(X, y) # doctest: +ELLIPSIS 180 | 0... 181 | -------------------------------------------------------------------------------- /examples/README.txt: -------------------------------------------------------------------------------- 1 | .. _general_examples: 2 | 3 | Example Gallery 4 | =============== 5 | 6 | Introductory examples.
7 | -------------------------------------------------------------------------------- /examples/plot_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================ 3 | Plotting Template Classifier 4 | ============================ 5 | 6 | An example plot of :class:`skltemplate.TemplateClassifier` 7 | """ 8 | 9 | # %% 10 | # Train our classifier on a very simple dataset 11 | from skltemplate import TemplateClassifier 12 | 13 | X = [[0, 0], [1, 1]] 14 | y = [0, 1] 15 | clf = TemplateClassifier().fit(X, y) 16 | 17 | # %% 18 | # Create a test dataset 19 | import numpy as np 20 | 21 | rng = np.random.RandomState(13) 22 | X_test = rng.rand(500, 2) 23 | 24 | # %% 25 | # Use scikit-learn to display the decision boundary 26 | from sklearn.inspection import DecisionBoundaryDisplay 27 | 28 | disp = DecisionBoundaryDisplay.from_estimator(clf, X_test) 29 | disp.ax_.scatter( 30 | X_test[:, 0], 31 | X_test[:, 1], 32 | c=clf.predict(X_test), 33 | s=20, 34 | edgecolors="k", 35 | linewidths=0.5, 36 | ) 37 | disp.ax_.set( 38 | xlabel="Feature 1", 39 | ylabel="Feature 2", 40 | title="Template Classifier Decision Boundary", 41 | ) 42 | -------------------------------------------------------------------------------- /examples/plot_template.py: -------------------------------------------------------------------------------- 1 | """ 2 | =========================== 3 | Plotting Template Estimator 4 | =========================== 5 | 6 | An example plot of :class:`skltemplate.TemplateEstimator` 7 | """ 8 | import numpy as np 9 | from matplotlib import pyplot as plt 10 | 11 | from skltemplate import TemplateEstimator 12 | 13 | X = np.arange(100).reshape(100, 1) 14 | y = np.zeros((100,)) 15 | estimator = TemplateEstimator() 16 | estimator.fit(X, y) 17 | plt.plot(estimator.predict(X)) 18 | plt.show() 19 | --------------------------------------------------------------------------------
/examples/plot_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================= 3 | Plotting Template Transformer 4 | ============================= 5 | 6 | An example plot of :class:`skltemplate.TemplateTransformer` 7 | """ 8 | import numpy as np 9 | from matplotlib import pyplot as plt 10 | 11 | from skltemplate import TemplateTransformer 12 | 13 | X = np.arange(50, dtype=np.float64).reshape(-1, 1) 14 | X /= 50 15 | estimator = TemplateTransformer() 16 | X_transformed = estimator.fit_transform(X) 17 | 18 | plt.plot(X.flatten(), label="Original Data") 19 | plt.plot(X_transformed.flatten(), label="Transformed Data") 20 | plt.title("Plots of original and transformed data") 21 | 22 | plt.legend(loc="best") 23 | plt.grid(True) 24 | plt.xlabel("Index") 25 | plt.ylabel("Value of Data") 26 | 27 | plt.show() 28 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64", "setuptools_scm[toml]>=8"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "skltemplate" 7 | dynamic = ["version"] 8 | authors = [ 9 | { name="Vighnesh Birodkar", email="vighneshbirodkar@nyu.edu" }, 10 | { name="Guillaume Lemaitre", email="g.lemaitre58@gmail.com" }, 11 | ] 12 | description = "A template for scikit-learn compatible packages."
13 | readme = "README.md" 14 | requires-python = ">=3.9" 15 | dependencies = [ 16 | "scikit-learn>=1.4.2", 17 | ] 18 | classifiers = [ 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.9", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: 3.12", 24 | "License :: OSI Approved :: BSD License", 25 | "Operating System :: POSIX", 26 | "Operating System :: Unix", 27 | "Operating System :: MacOS", 28 | "Operating System :: Microsoft :: Windows", 29 | ] 30 | 31 | [project.urls] 32 | Homepage = "https://github.com/scikit-learn-contrib/project-template" 33 | Issues = "https://github.com/scikit-learn-contrib/project-template/issues" 34 | 35 | [tool.setuptools_scm] 36 | version_file = "skltemplate/_version.py" 37 | 38 | [tool.pixi.project] 39 | channels = ["conda-forge"] 40 | platforms = ["win-64", "linux-64", "osx-64", "osx-arm64"] 41 | 42 | [tool.pixi.dependencies] 43 | python = ">=3.9" 44 | scikit-learn = ">=1.4.2" 45 | 46 | [tool.pixi.pypi-dependencies] 47 | skltemplate = { path=".", editable=true } 48 | 49 | [tool.pixi.feature.lint.dependencies] 50 | # The version below should be aligned with the one of `.pre-commit-config.yaml` 51 | black = "23.3.0" 52 | pre-commit = "3.7.1" 53 | ruff = "0.4.2" 54 | 55 | [tool.pixi.feature.lint.tasks] 56 | black = { cmd = "black --check --diff skltemplate && black --check --diff examples" } 57 | ruff = { cmd = "ruff check --output-format=full skltemplate && ruff check --output-format=full examples" } 58 | lint = { depends_on = ["black", "ruff"]} 59 | 60 | [tool.pixi.feature.test.dependencies] 61 | pytest = "*" 62 | pytest-cov = "*" 63 | 64 | [tool.pixi.feature.test.tasks] 65 | test = { cmd = "pytest -vsl --cov=skltemplate --cov-report=xml skltemplate" } 66 | 67 | [tool.pixi.feature.doc.dependencies] 68 | matplotlib = "*" 69 | numpydoc = "*" 70 | pydata-sphinx-theme = "*" 71 | setuptools-scm = ">=8" # needed for the 
versioning 72 | sphinx = "*" 73 | sphinx-design = "*" 74 | sphinx-gallery = "*" 75 | sphinx-prompt = "*" 76 | 77 | [tool.pixi.feature.doc.tasks] 78 | build-doc = { cmd = "make html", cwd = "doc" } 79 | clean-doc = { cmd = "rm -rf _build", cwd = "doc" } 80 | 81 | [tool.pixi.environments] 82 | doc = ["doc"] 83 | lint = ["lint"] 84 | test = ["test"] 85 | dev = ["doc", "lint", "test"] 86 | 87 | [tool.black] 88 | line-length = 88 89 | target_version = ['py38', 'py39', 'py310'] 90 | preview = true 91 | exclude = ''' 92 | /( 93 | \.eggs # exclude a few common directories in the 94 | | \.git # root of the project 95 | | \.vscode 96 | )/ 97 | ''' 98 | force-exclude = "skltemplate/_version.py" 99 | 100 | [tool.ruff] 101 | # max line length for black 102 | line-length = 88 103 | target-version = "py38" 104 | exclude=[ 105 | ".git", 106 | "__pycache__", 107 | "dist", 108 | "doc/_build", 109 | "doc/auto_examples", 110 | "build", 111 | "skltemplate/_version.py", 112 | ] 113 | 114 | [tool.ruff.lint] 115 | # all rules can be found here: https://beta.ruff.rs/docs/rules/ 116 | select = ["E", "F", "W", "I"] 117 | ignore=[ 118 | # space before : (needed for how black formats slicing) 119 | "E203", 120 | # do not assign a lambda expression, use a def 121 | "E731", 122 | # do not use variables named 'l', 'O', or 'I' 123 | "E741", 124 | ] 125 | 126 | [tool.ruff.lint.per-file-ignores] 127 | # It's fine not to put the import at the top of the file in the examples 128 | # folder. 
129 | "examples/*"=["E402"] 130 | "doc/conf.py"=["E402"] 131 | "doc/_templates/numpydoc_docstring.py"=["F821", "W292"] 132 | 133 | [tool.pytest.ini_options] 134 | addopts = "--doctest-modules --color=yes" 135 | doctest_optionflags = "NORMALIZE_WHITESPACE" 136 | -------------------------------------------------------------------------------- /skltemplate/__init__.py: -------------------------------------------------------------------------------- 1 | # Authors: scikit-learn-contrib developers 2 | # License: BSD 3 clause 3 | 4 | from ._template import TemplateClassifier, TemplateEstimator, TemplateTransformer 5 | from ._version import __version__ 6 | 7 | __all__ = [ 8 | "TemplateEstimator", 9 | "TemplateClassifier", 10 | "TemplateTransformer", 11 | "__version__", 12 | ] 13 | -------------------------------------------------------------------------------- /skltemplate/_template.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a module to be used as a reference for building other modules. 3 | """ 4 | 5 | # Authors: scikit-learn-contrib developers 6 | # License: BSD 3 clause 7 | 8 | import numpy as np 9 | from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, _fit_context 10 | from sklearn.metrics import euclidean_distances 11 | from sklearn.utils.multiclass import check_classification_targets 12 | from sklearn.utils.validation import check_is_fitted 13 | 14 | 15 | class TemplateEstimator(BaseEstimator): 16 | """A template estimator to be used as a reference implementation. 17 | 18 | For more information regarding how to build your own estimator, read more 19 | in the :ref:`user_guide`. 20 | 21 | Parameters 22 | ---------- 23 | demo_param : str, default='demo_param' 24 | A parameter used for demonstration of how to pass and store parameters. 25 | 26 | Attributes 27 | ---------- 28 | is_fitted_ : bool 29 | A boolean indicating whether the estimator has been fitted.
30 | 31 | n_features_in_ : int 32 | Number of features seen during :term:`fit`. 33 | 34 | feature_names_in_ : ndarray of shape (`n_features_in_`,) 35 | Names of features seen during :term:`fit`. Defined only when `X` 36 | has feature names that are all strings. 37 | 38 | Examples 39 | -------- 40 | >>> from skltemplate import TemplateEstimator 41 | >>> import numpy as np 42 | >>> X = np.arange(100).reshape(100, 1) 43 | >>> y = np.zeros((100, )) 44 | >>> estimator = TemplateEstimator() 45 | >>> estimator.fit(X, y) 46 | TemplateEstimator() 47 | """ 48 | 49 | # This dictionary defines the expected types of the parameters. 50 | # It is used to validate the parameters within the `_fit_context` decorator. 51 | _parameter_constraints = { 52 | "demo_param": [str], 53 | } 54 | 55 | def __init__(self, demo_param="demo_param"): 56 | self.demo_param = demo_param 57 | 58 | @_fit_context(prefer_skip_nested_validation=True) 59 | def fit(self, X, y): 60 | """A reference implementation of a fitting function. 61 | 62 | Parameters 63 | ---------- 64 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 65 | The training input samples. 66 | 67 | y : array-like, shape (n_samples,) or (n_samples, n_outputs) 68 | The target values (class labels in classification, real numbers in 69 | regression). 70 | 71 | Returns 72 | ------- 73 | self : object 74 | Returns self. 75 | """ 76 | # `_validate_data` is defined in the `BaseEstimator` class. 77 | # It allows us to: 78 | # - run different checks on the input data; 79 | # - define some attributes associated to the input data: `n_features_in_` and 80 | # `feature_names_in_`. 81 | X, y = self._validate_data(X, y, accept_sparse=True) 82 | self.is_fitted_ = True 83 | # `fit` should always return `self` 84 | return self 85 | 86 | def predict(self, X): 87 | """A reference implementation of a predicting function. 88 | 89 | Parameters 90 | ---------- 91 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 92 | The input samples.
93 | 94 | Returns 95 | ------- 96 | y : ndarray, shape (n_samples,) 97 | Returns an array of ones. 98 | """ 99 | # Check if fit has been called 100 | check_is_fitted(self) 101 | # We need to set reset=False because we don't want to overwrite `n_features_in_` 102 | # and `feature_names_in_`, but only check that the shape is consistent. 103 | X = self._validate_data(X, accept_sparse=True, reset=False) 104 | return np.ones(X.shape[0], dtype=np.int64) 105 | 106 | 107 | # Note that the mixin class should always be on the left of `BaseEstimator` to ensure 108 | # the MRO works as expected. 109 | class TemplateClassifier(ClassifierMixin, BaseEstimator): 110 | """An example classifier which implements a 1-NN algorithm. 111 | 112 | For more information regarding how to build your own classifier, read more 113 | in the :ref:`User Guide <user_guide>`. 114 | 115 | Parameters 116 | ---------- 117 | demo_param : str, default='demo' 118 | A parameter used for demonstration of how to pass and store parameters. 119 | 120 | Attributes 121 | ---------- 122 | X_ : ndarray, shape (n_samples, n_features) 123 | The input passed during :meth:`fit`. 124 | 125 | y_ : ndarray, shape (n_samples,) 126 | The labels passed during :meth:`fit`. 127 | 128 | classes_ : ndarray, shape (n_classes,) 129 | The classes seen at :meth:`fit`. 130 | 131 | n_features_in_ : int 132 | Number of features seen during :term:`fit`. 133 | 134 | feature_names_in_ : ndarray of shape (`n_features_in_`,) 135 | Names of features seen during :term:`fit`. Defined only when `X` 136 | has feature names that are all strings.
137 | 138 | Examples 139 | -------- 140 | >>> from sklearn.datasets import load_iris 141 | >>> from skltemplate import TemplateClassifier 142 | >>> X, y = load_iris(return_X_y=True) 143 | >>> clf = TemplateClassifier().fit(X, y) 144 | >>> clf.predict(X) 145 | array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 146 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 147 | 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 148 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 149 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 150 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 151 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) 152 | """ 153 | 154 | # This dictionary declares the constraints (here, the type) of each parameter. 155 | # It is used to validate parameters within the `_fit_context` decorator. 156 | _parameter_constraints = { 157 | "demo_param": [str], 158 | } 159 | 160 | def __init__(self, demo_param="demo"): 161 | self.demo_param = demo_param 162 | 163 | @_fit_context(prefer_skip_nested_validation=True) 164 | def fit(self, X, y): 165 | """A reference implementation of a fitting function for a classifier. 166 | 167 | Parameters 168 | ---------- 169 | X : array-like, shape (n_samples, n_features) 170 | The training input samples. 171 | 172 | y : array-like, shape (n_samples,) 173 | The target values. An array of int. 174 | 175 | Returns 176 | ------- 177 | self : object 178 | Returns self. 179 | """ 180 | # `_validate_data` is defined in the `BaseEstimator` class. 181 | # It allows us to: 182 | # - run different checks on the input data; 183 | # - define some attributes associated with the input data: `n_features_in_` and 184 | # `feature_names_in_`.
185 | X, y = self._validate_data(X, y) 186 | # We need to make sure that we have a classification task 187 | check_classification_targets(y) 188 | 189 | # A classifier should always store the classes seen during `fit` 190 | self.classes_ = np.unique(y) 191 | 192 | # Store the training data to predict later 193 | self.X_ = X 194 | self.y_ = y 195 | 196 | # Return the classifier 197 | return self 198 | 199 | def predict(self, X): 200 | """A reference implementation of a prediction for a classifier. 201 | 202 | Parameters 203 | ---------- 204 | X : array-like, shape (n_samples, n_features) 205 | The input samples. 206 | 207 | Returns 208 | ------- 209 | y : ndarray, shape (n_samples,) 210 | The label for each sample is the label of the closest sample 211 | seen during fit. 212 | """ 213 | # Check if fit has been called 214 | check_is_fitted(self) 215 | 216 | # Input validation 217 | # We need to set reset=False because we don't want to overwrite `n_features_in_` 218 | # and `feature_names_in_`, but only check that the shape is consistent. 219 | X = self._validate_data(X, reset=False) 220 | 221 | closest = np.argmin(euclidean_distances(X, self.X_), axis=1) 222 | return self.y_[closest] 223 | 224 | 225 | # Note that the mixin class should always be on the left of `BaseEstimator` to ensure 226 | # the MRO works as expected. 227 | class TemplateTransformer(TransformerMixin, BaseEstimator): 228 | """An example transformer that returns the element-wise square root. 229 | 230 | For more information regarding how to build your own transformer, read more 231 | in the :ref:`User Guide <user_guide>`. 232 | 233 | Parameters 234 | ---------- 235 | demo_param : str, default='demo' 236 | A parameter used for demonstration of how to pass and store parameters. 237 | 238 | Attributes 239 | ---------- 240 | n_features_in_ : int 241 | Number of features seen during :term:`fit`. 242 | 243 | feature_names_in_ : ndarray of shape (`n_features_in_`,) 244 | Names of features seen during :term:`fit`.
Defined only when `X` 245 | has feature names that are all strings. 246 | """ 247 | 248 | # This dictionary declares the constraints (here, the type) of each parameter. 249 | # It is used to validate parameters within the `_fit_context` decorator. 250 | _parameter_constraints = { 251 | "demo_param": [str], 252 | } 253 | 254 | def __init__(self, demo_param="demo"): 255 | self.demo_param = demo_param 256 | 257 | @_fit_context(prefer_skip_nested_validation=True) 258 | def fit(self, X, y=None): 259 | """A reference implementation of a fitting function for a transformer. 260 | 261 | Parameters 262 | ---------- 263 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 264 | The training input samples. 265 | 266 | y : None 267 | There is no need for a target in a transformer, yet the pipeline API 268 | requires this parameter. 269 | 270 | Returns 271 | ------- 272 | self : object 273 | Returns self. 274 | """ 275 | X = self._validate_data(X, accept_sparse=True) 276 | 277 | # Return the transformer 278 | return self 279 | 280 | def transform(self, X): 281 | """A reference implementation of a transform function. 282 | 283 | Parameters 284 | ---------- 285 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 286 | The input samples. 287 | 288 | Returns 289 | ------- 290 | X_transformed : array, shape (n_samples, n_features) 291 | The array containing the element-wise square roots of the values 292 | in ``X``. 293 | """ 294 | # Since this is a stateless transformer, we should not call `check_is_fitted`. 295 | # The common tests check for this specifically. 296 | 297 | # Input validation 298 | # We need to set reset=False because we don't want to overwrite `n_features_in_` 299 | # and `feature_names_in_`, but only check that the shape is consistent.
300 | X = self._validate_data(X, accept_sparse=True, reset=False) 301 | return np.sqrt(X) 302 | 303 | def _more_tags(self): 304 | # This is a quick example to show the tags API: 305 | # https://scikit-learn.org/dev/developers/develop.html#estimator-tags 306 | # Here, our transformer does not do any operation in `fit` and only validates 307 | # the parameters. Thus, it is stateless. 308 | return {"stateless": True} 309 | -------------------------------------------------------------------------------- /skltemplate/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Authors: scikit-learn-contrib developers 2 | # License: BSD 3 clause 3 | -------------------------------------------------------------------------------- /skltemplate/tests/test_common.py: -------------------------------------------------------------------------------- 1 | """This file shows how to write tests based on the scikit-learn common tests.""" 2 | 3 | # Authors: scikit-learn-contrib developers 4 | # License: BSD 3 clause 5 | 6 | from sklearn.utils.estimator_checks import parametrize_with_checks 7 | 8 | from skltemplate.utils.discovery import all_estimators 9 | 10 | 11 | # parametrize_with_checks generates one test per (estimator, check) pair, which is more 12 | # fine-grained than check_estimator 13 | @parametrize_with_checks([est() for _, est in all_estimators()]) 14 | def test_estimators(estimator, check, request): 15 | """Check the compatibility with the scikit-learn API.""" 16 | check(estimator) 17 | -------------------------------------------------------------------------------- /skltemplate/tests/test_template.py: -------------------------------------------------------------------------------- 1 | """This file shows how to write tests for the template classes.""" 2 | import numpy as np 3 | import pytest 4 | from sklearn.datasets import load_iris 5 | from sklearn.utils._testing import assert_allclose, assert_array_equal 6 | 7 | from skltemplate
import TemplateClassifier, TemplateEstimator, TemplateTransformer 8 | 9 | # Authors: scikit-learn-contrib developers 10 | # License: BSD 3 clause 11 | 12 | 13 | @pytest.fixture 14 | def data(): # iris dataset as an (X, y) tuple 15 | return load_iris(return_X_y=True) 16 | 17 | 18 | def test_template_estimator(data): 19 | """Check the internals and behaviour of `TemplateEstimator`.""" 20 | est = TemplateEstimator() 21 | assert est.demo_param == "demo_param" 22 | 23 | est.fit(*data) 24 | assert hasattr(est, "is_fitted_") 25 | 26 | X = data[0] 27 | y_pred = est.predict(X) 28 | assert_array_equal(y_pred, np.ones(X.shape[0], dtype=np.int64)) 29 | 30 | 31 | def test_template_transformer(data): 32 | """Check the internals and behaviour of `TemplateTransformer`.""" 33 | X, y = data 34 | trans = TemplateTransformer() 35 | assert trans.demo_param == "demo" 36 | 37 | trans.fit(X) 38 | assert trans.n_features_in_ == X.shape[1] 39 | 40 | X_trans = trans.transform(X) 41 | assert_allclose(X_trans, np.sqrt(X)) 42 | 43 | X_trans = trans.fit_transform(X) # `fit_transform` is inherited from `TransformerMixin` 44 | assert_allclose(X_trans, np.sqrt(X)) 45 | 46 | 47 | def test_template_classifier(data): 48 | """Check the internals and behaviour of `TemplateClassifier`.""" 49 | X, y = data 50 | clf = TemplateClassifier() 51 | assert clf.demo_param == "demo" 52 | 53 | clf.fit(X, y) 54 | assert hasattr(clf, "classes_") 55 | assert hasattr(clf, "X_") 56 | assert hasattr(clf, "y_") 57 | 58 | y_pred = clf.predict(X) 59 | assert y_pred.shape == (X.shape[0],) 60 | -------------------------------------------------------------------------------- /skltemplate/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Authors: scikit-learn-contrib developers 2 | # License: BSD 3 clause 3 | -------------------------------------------------------------------------------- /skltemplate/utils/discovery.py: -------------------------------------------------------------------------------- 1 | """ 2 | The :mod:`skltemplate.utils.discovery` module
includes utilities to discover 3 | objects (i.e. estimators, displays, functions) from the `skltemplate` package. 4 | """ 5 | 6 | # Adapted from scikit-learn 7 | # Authors: scikit-learn-contrib developers 8 | # License: BSD 3 clause 9 | 10 | import inspect 11 | import pkgutil 12 | from importlib import import_module 13 | from operator import itemgetter 14 | from pathlib import Path 15 | 16 | from sklearn.base import ( 17 | BaseEstimator, 18 | ClassifierMixin, 19 | ClusterMixin, 20 | RegressorMixin, 21 | TransformerMixin, 22 | ) 23 | from sklearn.utils._testing import ignore_warnings 24 | 25 | _MODULE_TO_IGNORE = {"tests"} # modules whose dotted name contains any of these parts are skipped 26 | 27 | 28 | def all_estimators(type_filter=None): 29 | """Get a list of all estimators from `skltemplate`. 30 | 31 | This function crawls the package and gets all classes that inherit 32 | from `BaseEstimator`. Classes that are defined in test modules are not 33 | included. 34 | 35 | Parameters 36 | ---------- 37 | type_filter : {"classifier", "regressor", "cluster", "transformer"} \ 38 | or list of such str, default=None 39 | Which kind of estimators should be returned. If None, no filter is 40 | applied and all estimators are returned. Possible values are 41 | 'classifier', 'regressor', 'cluster' and 'transformer' to get 42 | estimators only of these specific types, or a list of these to 43 | get the estimators that fit at least one of the types. 44 | 45 | Returns 46 | ------- 47 | estimators : list of tuples 48 | List of (name, class), where ``name`` is the class name as string 49 | and ``class`` is the actual type of the class.
50 | 51 | Examples 52 | -------- 53 | >>> from skltemplate.utils.discovery import all_estimators 54 | >>> estimators = all_estimators() 55 | >>> type(estimators) 56 | <class 'list'> 57 | """ 58 | 59 | def is_abstract(c): 60 | if not (hasattr(c, "__abstractmethods__")): 61 | return False 62 | if not len(c.__abstractmethods__): 63 | return False 64 | return True 65 | 66 | all_classes = [] 67 | root = str(Path(__file__).parent.parent) # skltemplate package 68 | # Ignore FutureWarning (used for deprecations) raised at import time and while 69 | # walking packages 70 | with ignore_warnings(category=FutureWarning): 71 | for _, module_name, _ in pkgutil.walk_packages( 72 | path=[root], prefix="skltemplate." 73 | ): 74 | module_parts = module_name.split(".") 75 | if any(part in _MODULE_TO_IGNORE for part in module_parts): 76 | continue 77 | module = import_module(module_name) 78 | classes = inspect.getmembers(module, inspect.isclass) 79 | classes = [ 80 | (name, est_cls) for name, est_cls in classes if not name.startswith("_") 81 | ] 82 | 83 | all_classes.extend(classes) 84 | 85 | all_classes = set(all_classes) # drop duplicate (name, class) pairs 86 | 87 | estimators = [ 88 | c 89 | for c in all_classes 90 | if (issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator") 91 | ] 92 | # get rid of abstract base classes 93 | estimators = [c for c in estimators if not is_abstract(c[1])] 94 | 95 | if type_filter is not None: 96 | if not isinstance(type_filter, list): 97 | type_filter = [type_filter] 98 | else: 99 | type_filter = list(type_filter) # copy 100 | filtered_estimators = [] 101 | filters = { 102 | "classifier": ClassifierMixin, 103 | "regressor": RegressorMixin, 104 | "transformer": TransformerMixin, 105 | "cluster": ClusterMixin, 106 | } 107 | for name, mixin in filters.items(): 108 | if name in type_filter: 109 | type_filter.remove(name) # entries left over afterwards are reported as invalid 110 | filtered_estimators.extend( 111 | [est for est in estimators if issubclass(est[1], mixin)] 112 | ) 113 | estimators = filtered_estimators 114 | if type_filter: 115 | raise ValueError(
116 | "Parameter type_filter must be 'classifier', " 117 | "'regressor', 'transformer', 'cluster' or " 118 | "None, got" 119 | f" {repr(type_filter)}." 120 | ) 121 | 122 | # drop duplicates, sort for reproducibility 123 | # itemgetter(0) restricts the sort key to the name: comparing the classes themselves 124 | # (on a name tie) would raise a TypeError 125 | return sorted(set(estimators), key=itemgetter(0)) 126 | 127 | 128 | def all_displays(): 129 | """Get a list of all displays from `skltemplate`. 130 | 131 | Returns 132 | ------- 133 | displays : list of tuples 134 | List of (name, class), where ``name`` is the display class name as 135 | string and ``class`` is the actual type of the class. 136 | 137 | Examples 138 | -------- 139 | >>> from skltemplate.utils.discovery import all_displays 140 | >>> displays = all_displays() 141 | """ 142 | all_classes = [] 143 | root = str(Path(__file__).parent.parent) # skltemplate package 144 | # Ignore FutureWarning (used for deprecations) raised at import time and while 145 | # walking packages 146 | with ignore_warnings(category=FutureWarning): 147 | for _, module_name, _ in pkgutil.walk_packages( 148 | path=[root], prefix="skltemplate."
149 | ): 150 | module_parts = module_name.split(".") 151 | if any(part in _MODULE_TO_IGNORE for part in module_parts): 152 | continue 153 | module = import_module(module_name) 154 | classes = inspect.getmembers(module, inspect.isclass) 155 | classes = [ 156 | (name, display_class) 157 | for name, display_class in classes 158 | if not name.startswith("_") and name.endswith("Display") 159 | ] 160 | all_classes.extend(classes) 161 | 162 | return sorted(set(all_classes), key=itemgetter(0)) 163 | 164 | 165 | def _is_checked_function(item): # True for public functions defined in a `skltemplate.*` submodule (except `*estimator_checks`) 166 | if not inspect.isfunction(item): 167 | return False 168 | 169 | if item.__name__.startswith("_"): 170 | return False 171 | 172 | mod = item.__module__ 173 | if not mod.startswith("skltemplate.") or mod.endswith("estimator_checks"): 174 | return False 175 | 176 | return True 177 | 178 | 179 | def all_functions(): 180 | """Get a list of all functions from `skltemplate`. 181 | 182 | Returns 183 | ------- 184 | functions : list of tuples 185 | List of (name, function), where ``name`` is the function name as 186 | string and ``function`` is the actual function. 187 | 188 | Examples 189 | -------- 190 | >>> from skltemplate.utils.discovery import all_functions 191 | >>> functions = all_functions() 192 | """ 193 | all_functions = [] # NOTE(review): shadows the enclosing function's name inside its body 194 | root = str(Path(__file__).parent.parent) # skltemplate package 195 | # Ignore FutureWarning (used for deprecations) raised at import time and while 196 | # walking packages 197 | with ignore_warnings(category=FutureWarning): 198 | for _, module_name, _ in pkgutil.walk_packages( 199 | path=[root], prefix="skltemplate."
200 | ): 201 | module_parts = module_name.split(".") 202 | if any(part in _MODULE_TO_IGNORE for part in module_parts): 203 | continue 204 | 205 | module = import_module(module_name) 206 | functions = inspect.getmembers(module, _is_checked_function) # keeps only public skltemplate functions 207 | functions = [ 208 | (func.__name__, func) # keyed by the function's own name, not the attribute it is bound to 209 | for name, func in functions 210 | if not name.startswith("_") 211 | ] 212 | all_functions.extend(functions) 213 | 214 | # drop duplicates, sort for reproducibility 215 | # itemgetter(0) restricts the sort key to the name: comparing the functions themselves 216 | # (on a name tie) would raise a TypeError 217 | return sorted(set(all_functions), key=itemgetter(0)) 218 | -------------------------------------------------------------------------------- /skltemplate/utils/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Authors: scikit-learn-contrib developers 2 | # License: BSD 3 clause 3 | -------------------------------------------------------------------------------- /skltemplate/utils/tests/test_discovery.py: -------------------------------------------------------------------------------- 1 | # Authors: scikit-learn-contrib developers 2 | # License: BSD 3 clause 3 | 4 | import pytest 5 | 6 | from skltemplate.utils.discovery import all_displays, all_estimators, all_functions 7 | 8 | 9 | def test_all_estimators(): 10 | estimators = all_estimators() 11 | assert len(estimators) == 3 12 | 13 | estimators = all_estimators(type_filter="classifier") 14 | assert len(estimators) == 1 15 | 16 | estimators = all_estimators(type_filter=["classifier", "transformer"]) 17 | assert len(estimators) == 2 18 | 19 | err_msg = "Parameter type_filter must be" 20 | with pytest.raises(ValueError, match=err_msg): 21 | all_estimators(type_filter="xxxx") 22 | 23 | 24 | def test_all_displays(): 25 | displays = all_displays() 26 | assert len(displays) == 0 27 | 28 | 29 | def test_all_functions(): 30 | functions = all_functions() 31 | assert len(functions) == 3 32 |
--------------------------------------------------------------------------------