├── .github ├── dependabot.yaml └── workflows │ ├── check-typos.yaml │ ├── deploy-docs.yaml │ ├── pre-commit-update.yaml │ └── pull-request.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .streamlit └── config.toml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── Makefile ├── README.md ├── _typos.toml ├── docs ├── contribute.md ├── css │ ├── custom.css │ └── termynal.css ├── img │ ├── cli.png │ ├── sksmith-logo.png │ ├── sksmith-logo.svg │ ├── tui.png │ └── webui.png ├── index.md ├── installation.md ├── js │ ├── custom.js │ └── termynal.js ├── user-guide.md └── why.md ├── mkdocs.yml ├── noxfile.py ├── pyproject.toml ├── requirements.txt ├── requirements └── test.txt ├── sksmithy ├── __init__.py ├── __main__.py ├── _arguments.py ├── _callbacks.py ├── _logger.py ├── _models.py ├── _parsers.py ├── _prompts.py ├── _static │ ├── description.md │ ├── template.py.jinja │ └── tui.tcss ├── _utils.py ├── app.py ├── cli.py ├── py.typed └── tui │ ├── __init__.py │ ├── _components.py │ ├── _tui.py │ └── _validators.py └── tests ├── __init__.py ├── conftest.py ├── test_app.py ├── test_cli.py ├── test_parsers.py ├── test_render.py └── test_tui.py /.github/dependabot.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # Maintain dependencies for GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "monthly" 8 | -------------------------------------------------------------------------------- /.github/workflows/check-typos.yaml: -------------------------------------------------------------------------------- 1 | name: Check spelling typos 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | 11 | run-typos: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout source code 15 | uses: actions/checkout@v4 16 | 17 | - name: Check spelling 18 | uses: crate-ci/typos@master 19 | with: 20 | files: . 
21 | -------------------------------------------------------------------------------- /.github/workflows/deploy-docs.yaml: -------------------------------------------------------------------------------- 1 | name: Deploy Documentation 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | permissions: 7 | contents: write 8 | 9 | jobs: 10 | deploy: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout source code 14 | uses: actions/checkout@v4 15 | - name: Configure Git Credentials 16 | run: | 17 | git config user.name github-actions[bot] 18 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 19 | - name: Set up Python 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: "3.10" 23 | - name: Install uv 24 | run: curl -LsSf https://astral.sh/uv/install.sh | sh 25 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 26 | - uses: actions/cache@v4 27 | with: 28 | key: mkdocs-material-${{ env.cache_id }} 29 | path: .cache 30 | restore-keys: | 31 | mkdocs-material- 32 | 33 | - name: Install dependencies and deploy 34 | run: | 35 | uv pip install mkdocs-material textual --system 36 | mkdocs gh-deploy --force 37 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit-update.yaml: -------------------------------------------------------------------------------- 1 | name: Pre-commit auto-update 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: "0 0 1 * *" # Every 1st of the month at 00:00 UTC 7 | 8 | permissions: write-all 9 | 10 | jobs: 11 | auto-update: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout source code 15 | uses: actions/checkout@v4 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: "3.10" 21 | 22 | - name: pre-commit install autoupdate 23 | run: | 24 | pip install pre-commit 25 | pre-commit autoupdate 26 | 27 | - name: Commit and push changes 28 | uses: peter-evans/create-pull-request@v7 29 | with: 
30 | branch: update-pre-commit-hooks 31 | title: 'Update pre-commit hooks' 32 | commit-message: 'Update pre-commit hooks' 33 | body: | 34 | Update versions of pre-commit hooks to latest versions. 35 | -------------------------------------------------------------------------------- /.github/workflows/pull-request.yaml: -------------------------------------------------------------------------------- 1 | name: PR Checks 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout source code 14 | uses: actions/checkout@v4 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: "3.10" 19 | - name: Install uv 20 | run: curl -LsSf https://astral.sh/uv/install.sh | sh 21 | - name: Install & run linter 22 | run: | 23 | uv pip install ruff --system 24 | make lint 25 | test: 26 | strategy: 27 | matrix: 28 | os: [ubuntu-latest, macos-latest, windows-latest] 29 | python-version: ["3.10", "3.11", "3.12"] 30 | runs-on: ${{ matrix.os }} 31 | steps: 32 | - name: Checkout source code 33 | uses: actions/checkout@v4 34 | - name: Set up Python ${{ matrix.python-version }} 35 | uses: actions/setup-python@v5 36 | with: 37 | python-version: ${{ matrix.python-version }} 38 | - name: Install uv 39 | run: curl -LsSf https://astral.sh/uv/install.sh | sh 40 | - name: Install dependencies and run tests 41 | run: | 42 | uv pip install -e ".[all]" --system 43 | uv pip install -r requirements/test.txt --system 44 | make test-cov 45 | - name: Install and run mypy 46 | run: | 47 | uv pip install mypy --system 48 | mypy sksmithy tests 49 | 50 | 51 | doc-build: 52 | runs-on: ubuntu-latest 53 | steps: 54 | - name: Checkout source code 55 | uses: actions/checkout@v4 56 | - name: Set up Python 57 | uses: actions/setup-python@v5 58 | with: 59 | python-version: "3.10" 60 | - name: Install uv 61 | run: curl -LsSf https://astral.sh/uv/install.sh | sh 62 | - name: Install dependencies 
and check docs can build 63 | run: | 64 | uv pip install mkdocs-material textual --system 65 | mkdocs build -v -s 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | tmp/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include 
Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | # Cache of screenshots used in the docs 165 | .screenshot_cache -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: requirements-txt-fixer 8 | - id: check-json 9 | - id: check-yaml 10 | - id: check-ast 11 | - id: check-added-large-files 12 | - repo: https://github.com/astral-sh/ruff-pre-commit 13 | rev: v0.4.7 14 | hooks: 15 | - id: ruff-format 16 | args: [sksmithy, tests] 17 | - id: ruff 18 | args: [--fix, sksmithy, tests] 19 | - repo: https://github.com/pre-commit/mirrors-mypy 20 | rev: v1.10.0 21 | hooks: 22 | - id: mypy 23 | args: [sksmithy, tests] 24 | - repo: https://github.com/Lucas-C/pre-commit-hooks-bandit 25 | rev: v1.0.6 26 | hooks: 27 | - id: python-bandit-vulnerability-check 28 | args: [--skip, "B101",--severity-level, medium, --recursive, sksmithy] 29 | - repo: https://github.com/pre-commit/pygrep-hooks 30 | rev: v1.10.0 31 | hooks: 32 | - id: python-no-eval 33 | - repo: https://github.com/crate-ci/typos 34 | rev: v1.21.0 35 | hooks: 36 | - id: typos 37 | -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [client] 2 | showSidebarNavigation = false 3 | 4 | [logger] 5 | level = "warning" 6 | 7 | [server] 8 | runOnSave = true 9 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free 
experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 
45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | https://www.linkedin.com/in/francesco-bruzzesi/. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. 
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 
129 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Francesco Bruzzesi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | sources = sksmithy tests 2 | 3 | clean-folders: 4 | rm -rf __pycache__ */__pycache__ */**/__pycache__ \ 5 | .pytest_cache */.pytest_cache */**/.pytest_cache \ 6 | .ruff_cache */.ruff_cache */**/.ruff_cache \ 7 | .mypy_cache */.mypy_cache */**/.mypy_cache \ 8 | .screenshot_cache \ 9 | site build dist htmlcov .coverage .tox 10 | 11 | 12 | lint: 13 | ruff version 14 | ruff format $(sources) 15 | ruff check $(sources) --fix 16 | ruff clean 17 | 18 | # Requires pytest-xdist (pip install pytest-xdist) 19 | test: 20 | pytest tests -n auto 21 | 22 | # Requires pytest-cov and pytest-xdist (pip install pytest-cov pytest-xdist) 23 | test-cov: 24 | pytest tests --cov=sksmithy -n auto 25 | 26 | # Requires coverage and coverage-badge (pip install coverage coverage-badge) 27 | coverage: 28 | rm -rf .coverage 29 | (rm docs/img/coverage.svg) || (echo "No coverage.svg file found") 30 | coverage run -m pytest 31 | coverage report -m 32 | coverage-badge -o docs/img/coverage.svg 33 | 34 | types: 35 | mypy $(sources) 36 | 37 | check: lint test-cov types clean-folders 38 | 39 | docs-serve: 40 | mkdocs serve 41 | 42 | docs-deploy: 43 | mkdocs gh-deploy 44 | 45 | pypi-push: 46 | rm -rf dist 47 | hatch build 48 | hatch publish 49 | 50 | get-version : 51 | @echo $(shell grep -m 1 version pyproject.toml | tr -s ' ' | tr -d '"' | tr -d "'" | cut -d' ' -f3) 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Scikit-learn Smithy 4 | 5 | Scikit-learn smithy is a tool that helps you forge scikit-learn compatible estimators with ease. 
6 | 7 | --- 8 | 9 | [WebUI](https://sklearn-smithy.streamlit.app/) | [Documentation](https://fbruzzesi.github.io/sklearn-smithy) | [Repository](https://github.com/fbruzzesi/sklearn-smithy) | [Issue Tracker](https://github.com/fbruzzesi/sklearn-smithy/issues) 10 | 11 | --- 12 | 13 | How can you use it? 14 | 15 |
✅ Directly from the browser via a Web UI. 16 | 17 |
18 | 19 | - Available at [sklearn-smithy.streamlit.app](https://sklearn-smithy.streamlit.app/) 20 | - It requires no installation. 21 | - Powered by [streamlit](https://streamlit.io/) 22 | 23 | 24 | 25 |
26 | 27 |
✅ As a CLI (command line interface) in the terminal. 28 | 29 |
30 | 31 | - Available via the `smith forge` command. 32 | - It requires [installation](#installation): `python -m pip install sklearn-smithy` 33 | - Powered by [typer](https://typer.tiangolo.com/). 34 | 35 | 36 | 37 |
38 | 39 |
✅ As a TUI (terminal user interface) in the terminal. 40 | 41 |
42 | 43 | - Available via the `smith forge-tui` command. 44 | - It requires installing [extra dependencies](#extra-dependencies): `python -m pip install "sklearn-smithy[textual]"` 45 | - Powered by [textual](https://textual.textualize.io/). 46 | 47 | 48 | 49 |
50 | 51 | All these tools will prompt you with a series of questions regarding the estimator you want to create, and then generate the boilerplate code for you. 52 | 53 | ## Why ❓ 54 | 55 | Writing scikit-learn compatible estimators might be harder than expected. 56 | 57 | While everyone knows about `fit` and `predict`, there are other behaviours, methods and attributes that 58 | scikit-learn might be expecting from your estimator depending on: 59 | 60 | - The type of estimator you're writing. 61 | - The signature of the estimator. 62 | - The signature of the `.fit(...)` method. 63 | 64 | Scikit-learn Smithy to the rescue: this tool aims to help you craft your own estimator by asking a few 65 | questions about it, and then generating the boilerplate code. 66 | 67 | In this way you will be able to fully focus on the core implementation logic, and not on nitty-gritty details 68 | of the scikit-learn API. 69 | 70 | ### Sanity check 71 | 72 | Once the core logic is implemented, the estimator should be ready to test against the _somewhat official_ 73 | [`parametrize_with_checks`](https://scikit-learn.org/dev/modules/generated/sklearn.utils.estimator_checks.parametrize_with_checks.html#sklearn.utils.estimator_checks.parametrize_with_checks) 74 | pytest compatible decorator: 75 | 76 | ```py 77 | from sklearn.utils.estimator_checks import parametrize_with_checks 78 | 79 | @parametrize_with_checks([ 80 | YourAwesomeRegressor(), 81 | MoreAwesomeClassifier(), 82 | EvenMoreAwesomeTransformer(), 83 | ]) 84 | def test_sklearn_compatible_estimator(estimator, check): 85 | check(estimator) 86 | ``` 87 | 88 | and it should be compatible with scikit-learn Pipeline, GridSearchCV, etc. 89 | 90 | ### Official guide 91 | 92 | Scikit-learn documentation on how to 93 | [develop estimators](https://scikit-learn.org/dev/developers/develop.html#developing-scikit-learn-estimators). 
94 | 95 | ## Supported estimators 96 | 97 | The following types of scikit-learn estimator are supported: 98 | 99 | - ✅ Classifier 100 | - ✅ Regressor 101 | - ✅ Outlier Detector 102 | - ✅ Clusterer 103 | - ✅ Transformer 104 | - ✅ Feature Selector 105 | - 🚧 Meta Estimator 106 | 107 | ## Installation 108 | 109 | sklearn-smithy is available on [pypi](https://pypi.org/project/sklearn-smithy), so you can install it directly from there: 110 | 111 | ```bash 112 | python -m pip install sklearn-smithy 113 | ``` 114 | 115 | **Remark:** The minimum Python version required is 3.10. 116 | 117 | This will make the `smith` command available in your terminal, and you should be able to run the following: 118 | 119 | ```bash 120 | smith version 121 | ``` 122 | 123 | > sklearn-smithy=... 124 | 125 | ### Extra dependencies 126 | 127 | To run the TUI, you need to install the `textual` dependency as well: 128 | 129 | ```bash 130 | python -m pip install "sklearn-smithy[textual]" 131 | ``` 132 | 133 | ## User guide 📚 134 | 135 | Please refer to the dedicated [user guide](https://fbruzzesi.github.io/sklearn-smithy/user-guide/) documentation section. 
136 | 137 | ## Origin story 138 | 139 | The idea for this tool originated from [scikit-lego #660](https://github.com/koaning/scikit-lego/pull/660), which I cannot better explain than quoting the PR description itself: 140 | 141 | > So the story goes as the following: 142 | > 143 | > - The CI/CD fails for scikit-learn==1.5rc1 because of a change in the `check_estimator` internals 144 | > - In the [scikit-learn issue](https://github.com/scikit-learn/scikit-learn/issues/28966) I got a better picture of how to run test for compatible components 145 | > - In particular, [rolling your own estimator](https://scikit-learn.org/dev/developers/develop.html#rolling-your-own-estimator) suggests to use [`parametrize_with_checks`](https://scikit-learn.org/dev/modules/generated/sklearn.utils.estimator_checks.parametrize_with_checks.html#sklearn.utils.estimator_checks.parametrize_with_checks), and of course I thought "that is a great idea to avoid dealing manually with each test" 146 | > - Say no more, I enter a rabbit hole to refactor all our tests - which would be fine 147 | > - Except that these tests failures helped me figure out a few missing parts in the codebase 148 | -------------------------------------------------------------------------------- /_typos.toml: -------------------------------------------------------------------------------- 1 | [default.extend-words] 2 | arange = "arange" # numpy function -------------------------------------------------------------------------------- /docs/contribute.md: -------------------------------------------------------------------------------- 1 | # Contributing 👏 2 | 3 | ## Guidelines 💡 4 | 5 | We welcome contributions to the library! If you have a bug fix or new feature that you would like to contribute, please follow the steps below: 6 | 7 | 1. 
Check the [existing issues](https://github.com/FBruzzesi/sklearn-smithy/issues){:target="_blank"} and/or [open a new one](https://github.com/FBruzzesi/sklearn-smithy/issues/new){:target="_blank"} to discuss the problem and potential solutions. 8 | 2. [Fork the repository](https://github.com/FBruzzesi/sklearn-smithy/fork){:target="_blank"} on GitHub. 9 | 3. [Clone the repository](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository){:target="_blank"} to your local machine. 10 | 4. Create a new branch for your bug fix or feature. 11 | 5. Make your changes and test them thoroughly, making sure that they pass all current tests. 12 | 6. Commit your changes and push the branch to your fork. 13 | 7. [Open a pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request){:target="_blank"} on the main repository. 14 | 15 | ## Submitting Pull Requests 🎯 16 | 17 | When submitting a pull request, please make sure that you've followed the steps above and that your code has been thoroughly tested. Also, be sure to include a brief summary of the changes you've made and a reference to any issues that your pull request resolves. 18 | 19 | ## Code formatting 🚀 20 | 21 | **sklearn-smithy** uses [ruff](https://docs.astral.sh/ruff/){:target="_blank"} for both formatting and linting. Specific settings are declared in the pyproject.toml file. 22 | 23 | To format the code, you can run the following commands: 24 | 25 | === "with Make" 26 | 27 | ```bash 28 | make lint 29 | ``` 30 | 31 | === "without Make" 32 | 33 | ```bash 34 | ruff version 35 | ruff format sksmithy tests 36 | ruff check sksmithy tests --fix 37 | ruff clean 38 | ``` 39 | 40 | As part of the checks on pull requests, it is checked whether the code follows those standards. 
To ensure that the standard is met, it is recommended to install [pre-commit hooks](https://pre-commit.com/){:target="_blank"}: 41 | 42 | ```bash 43 | python -m pip install pre-commit 44 | pre-commit install 45 | ``` 46 | 47 | ## Developing 🐍 48 | 49 | Assuming you have already completed steps 1-4 from the above list, you should now install the library and its development dependencies in editable mode. 50 | 51 | ```bash 52 | cd sklearn-smithy 53 | pip install -e ".[all]" --no-cache-dir 54 | pre-commit install 55 | ``` 56 | 57 | Now you are ready to proceed with all the changes you want to! 58 | 59 | ## Testing 🧪 60 | 61 | Once you are done with changes, you should: 62 | 63 | - add tests for the new features in the `/tests` folder 64 | - make sure that new features do not break the existing codebase by running tests: 65 | 66 | === "with Make" 67 | 68 | ```bash 69 | make test 70 | ``` 71 | 72 | === "without Make" 73 | 74 | ```bash 75 | pytest tests -n auto 76 | ``` 77 | 78 | ## Docs 📑 79 | 80 | The documentation is generated using [mkdocs-material](https://squidfunk.github.io/mkdocs-material/){:target="_blank"}, the API part uses [mkdocstrings](https://mkdocstrings.github.io/){:target="_blank"}. 81 | 82 | If a new feature or a breaking change is developed, then we suggest updating the documentation in the `/docs` folder as well, in order to describe how this can be used from a user perspective. 
83 | -------------------------------------------------------------------------------- /docs/css/custom.css: -------------------------------------------------------------------------------- 1 | /** 2 | * custom.css 3 | * From https://github.com/tiangolo/typer/blob/master/docs/css/custom.css 4 | * 5 | * @author Sebastián Ramírez 6 | * @license MIT 7 | */ 8 | .termynal-comment { 9 | color: #4a968f; 10 | font-style: italic; 11 | display: block; 12 | } 13 | 14 | .termy [data-termynal] { 15 | white-space: pre-wrap; 16 | } 17 | 18 | a.external-link::after { 19 | /* \00A0 is a non-breaking space 20 | to make the mark be on the same line as the link 21 | */ 22 | content: "\00A0[↪]"; 23 | } 24 | 25 | a.internal-link::after { 26 | /* \00A0 is a non-breaking space 27 | to make the mark be on the same line as the link 28 | */ 29 | content: "\00A0↪"; 30 | } 31 | -------------------------------------------------------------------------------- /docs/css/termynal.css: -------------------------------------------------------------------------------- 1 | /** 2 | * termynal.js 3 | * 4 | * @author Ines Montani 5 | * @version 0.0.1 6 | * @license MIT 7 | */ 8 | 9 | :root { 10 | --color-bg: #252a33; 11 | --color-text: #eee; 12 | --color-text-subtle: #a2a2a2; 13 | } 14 | 15 | [data-termynal] { 16 | width: 750px; 17 | max-width: 100%; 18 | background: var(--color-bg); 19 | color: var(--color-text); 20 | /* font-size: 18px; */ 21 | font-size: 15px; 22 | /* font-family: 'Fira Mono', Consolas, Menlo, Monaco, 'Courier New', Courier, monospace; */ 23 | font-family: 'Roboto Mono', 'Fira Mono', Consolas, Menlo, Monaco, 'Courier New', Courier, monospace; 24 | border-radius: 4px; 25 | padding: 75px 45px 35px; 26 | position: relative; 27 | -webkit-box-sizing: border-box; 28 | box-sizing: border-box; 29 | line-height: 1.2; 30 | } 31 | 32 | [data-termynal]:before { 33 | content: ''; 34 | position: absolute; 35 | top: 15px; 36 | left: 15px; 37 | display: inline-block; 38 | width: 15px; 39 | height: 
15px; 40 | border-radius: 50%; 41 | /* A little hack to display the window buttons in one pseudo element. */ 42 | background: #d9515d; 43 | -webkit-box-shadow: 25px 0 0 #f4c025, 50px 0 0 #3ec930; 44 | box-shadow: 25px 0 0 #f4c025, 50px 0 0 #3ec930; 45 | } 46 | 47 | [data-termynal]:after { 48 | content: 'bash'; 49 | position: absolute; 50 | color: var(--color-text-subtle); 51 | top: 5px; 52 | left: 0; 53 | width: 100%; 54 | text-align: center; 55 | } 56 | 57 | a[data-terminal-control] { 58 | text-align: right; 59 | display: block; 60 | color: #aebbff; 61 | } 62 | 63 | [data-ty] { 64 | display: block; 65 | line-height: 2; 66 | } 67 | 68 | [data-ty]:before { 69 | /* Set up defaults and ensure empty lines are displayed. */ 70 | content: ''; 71 | display: inline-block; 72 | vertical-align: middle; 73 | } 74 | 75 | [data-ty="input"]:before, 76 | [data-ty-prompt]:before { 77 | margin-right: 0.75em; 78 | color: var(--color-text-subtle); 79 | } 80 | 81 | [data-ty="input"]:before { 82 | content: '$'; 83 | } 84 | 85 | [data-ty][data-ty-prompt]:before { 86 | content: attr(data-ty-prompt); 87 | } 88 | 89 | [data-ty-cursor]:after { 90 | content: attr(data-ty-cursor); 91 | font-family: monospace; 92 | margin-left: 0.5em; 93 | -webkit-animation: blink 1s infinite; 94 | animation: blink 1s infinite; 95 | } 96 | 97 | 98 | /* Cursor animation */ 99 | 100 | @-webkit-keyframes blink { 101 | 50% { 102 | opacity: 0; 103 | } 104 | } 105 | 106 | @keyframes blink { 107 | 50% { 108 | opacity: 0; 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /docs/img/cli.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FBruzzesi/sklearn-smithy/213aefcf64950a72cd51bd3b02b4ccb23484dada/docs/img/cli.png -------------------------------------------------------------------------------- /docs/img/sksmith-logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FBruzzesi/sklearn-smithy/213aefcf64950a72cd51bd3b02b4ccb23484dada/docs/img/sksmith-logo.png -------------------------------------------------------------------------------- /docs/img/sksmith-logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /docs/img/tui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FBruzzesi/sklearn-smithy/213aefcf64950a72cd51bd3b02b4ccb23484dada/docs/img/tui.png -------------------------------------------------------------------------------- /docs/img/webui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FBruzzesi/sklearn-smithy/213aefcf64950a72cd51bd3b02b4ccb23484dada/docs/img/webui.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Scikit-learn Smithy 4 | 5 | Scikit-learn smithy is a tool that helps you to forge scikit-learn compatible estimator with ease. 6 | 7 | How can you use it? 8 | 9 | - [x] Directly from the browser via our [web UI](https://sklearn-smithy.streamlit.app/){:target="_blank"} ([more info](user-guide.md/#web-ui)) 10 | - [x] As a CLI (command line interface) in your terminal via the `smith forge` command ([more info](user-guide.md/#cli)) 11 | - [x] As a TUI (terminal user interface) in your terminal via the `smith forge-tui` command ([more info](user-guide.md/#tui)) 12 | 13 | !!! info 14 | 15 | All these tools will prompt a series of questions regarding the estimator you want to create, and then it will generate the boilerplate code for you. 
16 | 17 | ## Supported estimators 18 | 19 | The following types of scikit-learn estimator are supported: 20 | 21 | - [x] Classifier 22 | - [x] Regressor 23 | - [x] Outlier Detector 24 | - [x] Clusterer 25 | - [x] Transformer 26 | - [x] Feature Selector 27 | - [ ] Meta Estimator 28 | 29 | ## Origin story 30 | 31 | The idea for this tool originated from [scikit-lego #660](https://github.com/koaning/scikit-lego/pull/660){:target="_blank"}, which I cannot better explain than quoting the PR description itself: 32 | 33 | > So the story goes as the following: 34 | > 35 | > - The CI/CD fails for scikit-learn==1.5rc1 because of a change in the `check_estimator` internals 36 | > - In the [scikit-learn issue](https://github.com/scikit-learn/scikit-learn/issues/28966){:target="_blank"} I got a better picture of how to run test for compatible components 37 | > - In particular, [rolling your own estimator](https://scikit-learn.org/dev/developers/develop.html#rolling-your-own-estimator){:target="_blank"} suggests to use [`parametrize_with_checks`](https://scikit-learn.org/dev/modules/generated/sklearn.utils.estimator_checks.parametrize_with_checks.html#sklearn.utils.estimator_checks.parametrize_with_checks){:target="_blank"}, and of course I thought "that is a great idea to avoid dealing manually with each test" 38 | > - Say no more, I enter a rabbit hole to refactor all our tests - which would be fine 39 | > - Except that these tests failures helped me figure out a few missing parts in the codebase 40 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation ✨ 2 | 3 | sklearn-smithy is available on [pypi](https://pypi.org/project/sklearn-smithy){:target="_blank"}, so you can install it directly from there: 4 | 5 | ```bash 6 | python -m pip install sklearn-smithy 7 | ``` 8 | 9 | !!! 
warning 10 | The minimum Python version required is 3.10. 11 | 12 | This will make the `smith` command available in your terminal, and you should be able to run the following: 13 | 14 | ```bash 15 | smith version 16 | ``` 17 | 18 | > sklearn-smithy=... 19 | 20 | ## Extra dependencies 21 | 22 | To run the TUI (`smith forge-tui`), you need to install the `textual` dependency as well: 23 | 24 | ```bash 25 | python -m pip install "sklearn-smithy[textual]" 26 | ``` 27 | 28 | To run the WebUI locally (`smith forge-webui`), you need to install the `streamlit` dependency as well: 29 | 30 | ```bash 31 | python -m pip install "sklearn-smithy[streamlit]" 32 | ``` 33 | 34 | ## Other installation methods 35 | 36 | === "pip + source/git" 37 | 38 | ```bash 39 | python -m pip install git+https://github.com/FBruzzesi/sklearn-smithy.git 40 | ``` 41 | 42 | === "local clone" 43 | 44 | ```bash 45 | git clone https://github.com/FBruzzesi/sklearn-smithy.git 46 | cd sklearn-smithy 47 | python -m pip install . 48 | ``` 49 | -------------------------------------------------------------------------------- /docs/js/custom.js: -------------------------------------------------------------------------------- 1 | /** 2 | * custom.js 3 | * From https://github.com/tiangolo/typer/blob/master/docs/js/custom.js 4 | * 5 | * @author Sebastián Ramírez 6 | * @license MIT 7 | */ 8 | document.querySelectorAll(".use-termynal").forEach(node => { 9 | node.style.display = "block"; 10 | new Termynal(node, { 11 | lineDelay: 500 12 | }); 13 | }); 14 | const progressLiteralStart = "---> 100%"; 15 | const promptLiteralStart = "$ "; 16 | const customPromptLiteralStart = "# "; 17 | const termynalActivateClass = "termy"; 18 | let termynals = []; 19 | 20 | function createTermynals() { 21 | document 22 | .querySelectorAll(`.${termynalActivateClass} .highlight`) 23 | .forEach(node => { 24 | const text = node.textContent; 25 | const lines = text.split("\n"); 26 | const useLines = []; 27 | let buffer = []; 28 | function 
saveBuffer() { 29 | if (buffer.length) { 30 | let isBlankSpace = true; 31 | buffer.forEach(line => { 32 | if (line) { 33 | isBlankSpace = false; 34 | } 35 | }); 36 | dataValue = {}; 37 | if (isBlankSpace) { 38 | dataValue["delay"] = 0; 39 | } 40 | if (buffer[buffer.length - 1] === "") { 41 | // A last single
<br> won't have effect 42 | // so put an additional one 43 | buffer.push(""); 44 | } 45 | const bufferValue = buffer.join("<br>
"); 46 | dataValue["value"] = bufferValue; 47 | useLines.push(dataValue); 48 | buffer = []; 49 | } 50 | } 51 | for (let line of lines) { 52 | if (line === progressLiteralStart) { 53 | saveBuffer(); 54 | useLines.push({ 55 | type: "progress" 56 | }); 57 | } else if (line.startsWith(promptLiteralStart)) { 58 | saveBuffer(); 59 | const value = line.replace(promptLiteralStart, "").trimEnd(); 60 | useLines.push({ 61 | type: "input", 62 | value: value 63 | }); 64 | } else if (line.startsWith("// ")) { 65 | saveBuffer(); 66 | const value = "💬 " + line.replace("// ", "").trimEnd(); 67 | useLines.push({ 68 | value: value, 69 | class: "termynal-comment", 70 | delay: 0 71 | }); 72 | } else if (line.startsWith(customPromptLiteralStart)) { 73 | saveBuffer(); 74 | const promptStart = line.indexOf(promptLiteralStart); 75 | if (promptStart === -1) { 76 | console.error("Custom prompt found but no end delimiter", line) 77 | } 78 | const prompt = line.slice(0, promptStart).replace(customPromptLiteralStart, "") 79 | let value = line.slice(promptStart + promptLiteralStart.length); 80 | useLines.push({ 81 | type: "input", 82 | value: value, 83 | prompt: prompt 84 | }); 85 | } else { 86 | buffer.push(line); 87 | } 88 | } 89 | saveBuffer(); 90 | const div = document.createElement("div"); 91 | node.replaceWith(div); 92 | const termynal = new Termynal(div, { 93 | lineData: useLines, 94 | noInit: true, 95 | lineDelay: 500 96 | }); 97 | termynals.push(termynal); 98 | }); 99 | } 100 | 101 | function loadVisibleTermynals() { 102 | termynals = termynals.filter(termynal => { 103 | if (termynal.container.getBoundingClientRect().top - innerHeight <= 0) { 104 | termynal.init(); 105 | return false; 106 | } 107 | return true; 108 | }); 109 | } 110 | window.addEventListener("scroll", loadVisibleTermynals); 111 | createTermynals(); 112 | loadVisibleTermynals(); 113 | -------------------------------------------------------------------------------- /docs/js/termynal.js: 
-------------------------------------------------------------------------------- 1 | /** 2 | * termynal.js 3 | * A lightweight, modern and extensible animated terminal window, using 4 | * async/await. 5 | * 6 | * @author Lines Montani 7 | * @version 0.0.1 8 | * @license MIT 9 | */ 10 | 11 | 'use strict'; 12 | 13 | /** Generate a terminal widget. */ 14 | class Termynal { 15 | /** 16 | * Construct the widget's settings. 17 | * @param {(string|Node)=} container - Query selector or container element. 18 | * @param {Object=} options - Custom settings. 19 | * @param {string} options.prefix - Prefix to use for data attributes. 20 | * @param {number} options.startDelay - Delay before animation, in ms. 21 | * @param {number} options.typeDelay - Delay between each typed character, in ms. 22 | * @param {number} options.lineDelay - Delay between each line, in ms. 23 | * @param {number} options.progressLength - Number of characters displayed as progress bar. 24 | * @param {string} options.progressChar – Character to use for progress bar, defaults to █. 25 | * @param {number} options.progressPercent - Max percent of progress. 26 | * @param {string} options.cursor – Character to use for cursor, defaults to ▋. 27 | * @param {Object[]} lineData - Dynamically loaded line data objects. 28 | * @param {boolean} options.noInit - Don't initialise the animation. 29 | */ 30 | constructor(container = '#termynal', options = {}) { 31 | this.container = (typeof container === 'string') ? 
document.querySelector(container) : container; 32 | this.pfx = `data-${options.prefix || 'ty'}`; 33 | this.originalStartDelay = this.startDelay = options.startDelay 34 | || parseFloat(this.container.getAttribute(`${this.pfx}-startDelay`)) || 600; 35 | this.originalTypeDelay = this.typeDelay = options.typeDelay 36 | || parseFloat(this.container.getAttribute(`${this.pfx}-typeDelay`)) || 90; 37 | this.originalLineDelay = this.lineDelay = options.lineDelay 38 | || parseFloat(this.container.getAttribute(`${this.pfx}-lineDelay`)) || 1500; 39 | this.progressLength = options.progressLength 40 | || parseFloat(this.container.getAttribute(`${this.pfx}-progressLength`)) || 40; 41 | this.progressChar = options.progressChar 42 | || this.container.getAttribute(`${this.pfx}-progressChar`) || '█'; 43 | this.progressPercent = options.progressPercent 44 | || parseFloat(this.container.getAttribute(`${this.pfx}-progressPercent`)) || 100; 45 | this.cursor = options.cursor 46 | || this.container.getAttribute(`${this.pfx}-cursor`) || '▋'; 47 | this.lineData = this.lineDataToElements(options.lineData || []); 48 | this.loadLines() 49 | if (!options.noInit) this.init() 50 | } 51 | 52 | loadLines() { 53 | // Load all the lines and create the container so that the size is fixed 54 | // Otherwise it would be changing and the user viewport would be constantly 55 | // moving as she/he scrolls 56 | const finish = this.generateFinish() 57 | finish.style.visibility = 'hidden' 58 | this.container.appendChild(finish) 59 | // Appends dynamically loaded lines to existing line elements. 
60 | this.lines = [...this.container.querySelectorAll(`[${this.pfx}]`)].concat(this.lineData); 61 | for (let line of this.lines) { 62 | line.style.visibility = 'hidden' 63 | this.container.appendChild(line) 64 | } 65 | const restart = this.generateRestart() 66 | restart.style.visibility = 'hidden' 67 | this.container.appendChild(restart) 68 | this.container.setAttribute('data-termynal', ''); 69 | } 70 | 71 | /** 72 | * Initialise the widget, get lines, clear container and start animation. 73 | */ 74 | init() { 75 | /** 76 | * Calculates width and height of Termynal container. 77 | * If container is empty and lines are dynamically loaded, defaults to browser `auto` or CSS. 78 | */ 79 | const containerStyle = getComputedStyle(this.container); 80 | this.container.style.width = containerStyle.width !== '0px' ? 81 | containerStyle.width : undefined; 82 | this.container.style.minHeight = containerStyle.height !== '0px' ? 83 | containerStyle.height : undefined; 84 | 85 | this.container.setAttribute('data-termynal', ''); 86 | this.container.innerHTML = ''; 87 | for (let line of this.lines) { 88 | line.style.visibility = 'visible' 89 | } 90 | this.start(); 91 | } 92 | 93 | /** 94 | * Start the animation and rener the lines depending on their data attributes. 
95 | */ 96 | async start() { 97 | this.addFinish() 98 | await this._wait(this.startDelay); 99 | 100 | for (let line of this.lines) { 101 | const type = line.getAttribute(this.pfx); 102 | const delay = line.getAttribute(`${this.pfx}-delay`) || this.lineDelay; 103 | 104 | if (type == 'input') { 105 | line.setAttribute(`${this.pfx}-cursor`, this.cursor); 106 | await this.type(line); 107 | await this._wait(delay); 108 | } 109 | 110 | else if (type == 'progress') { 111 | await this.progress(line); 112 | await this._wait(delay); 113 | } 114 | 115 | else { 116 | this.container.appendChild(line); 117 | await this._wait(delay); 118 | } 119 | 120 | line.removeAttribute(`${this.pfx}-cursor`); 121 | } 122 | this.addRestart() 123 | this.finishElement.style.visibility = 'hidden' 124 | this.lineDelay = this.originalLineDelay 125 | this.typeDelay = this.originalTypeDelay 126 | this.startDelay = this.originalStartDelay 127 | } 128 | 129 | generateRestart() { 130 | const restart = document.createElement('a') 131 | restart.onclick = (e) => { 132 | e.preventDefault() 133 | this.container.innerHTML = '' 134 | this.init() 135 | } 136 | restart.href = '#' 137 | restart.setAttribute('data-terminal-control', '') 138 | restart.innerHTML = "restart ↻" 139 | return restart 140 | } 141 | 142 | generateFinish() { 143 | const finish = document.createElement('a') 144 | finish.onclick = (e) => { 145 | e.preventDefault() 146 | this.lineDelay = 0 147 | this.typeDelay = 0 148 | this.startDelay = 0 149 | } 150 | finish.href = '#' 151 | finish.setAttribute('data-terminal-control', '') 152 | finish.innerHTML = "fast →" 153 | this.finishElement = finish 154 | return finish 155 | } 156 | 157 | addRestart() { 158 | const restart = this.generateRestart() 159 | this.container.appendChild(restart) 160 | } 161 | 162 | addFinish() { 163 | const finish = this.generateFinish() 164 | this.container.appendChild(finish) 165 | } 166 | 167 | /** 168 | * Animate a typed line. 
169 | * @param {Node} line - The line element to render. 170 | */ 171 | async type(line) { 172 | const chars = [...line.textContent]; 173 | line.textContent = ''; 174 | this.container.appendChild(line); 175 | 176 | for (let char of chars) { 177 | const delay = line.getAttribute(`${this.pfx}-typeDelay`) || this.typeDelay; 178 | await this._wait(delay); 179 | line.textContent += char; 180 | } 181 | } 182 | 183 | /** 184 | * Animate a progress bar. 185 | * @param {Node} line - The line element to render. 186 | */ 187 | async progress(line) { 188 | const progressLength = line.getAttribute(`${this.pfx}-progressLength`) 189 | || this.progressLength; 190 | const progressChar = line.getAttribute(`${this.pfx}-progressChar`) 191 | || this.progressChar; 192 | const chars = progressChar.repeat(progressLength); 193 | const progressPercent = line.getAttribute(`${this.pfx}-progressPercent`) 194 | || this.progressPercent; 195 | line.textContent = ''; 196 | this.container.appendChild(line); 197 | 198 | for (let i = 1; i < chars.length + 1; i++) { 199 | await this._wait(this.typeDelay); 200 | const percent = Math.round(i / chars.length * 100); 201 | line.textContent = `${chars.slice(0, i)} ${percent}%`; 202 | if (percent>progressPercent) { 203 | break; 204 | } 205 | } 206 | } 207 | 208 | /** 209 | * Helper function for animation delays, called with `await`. 210 | * @param {number} time - Timeout, in ms. 211 | */ 212 | _wait(time) { 213 | return new Promise(resolve => setTimeout(resolve, time)); 214 | } 215 | 216 | /** 217 | * Converts line data objects into line elements. 218 | * 219 | * @param {Object[]} lineData - Dynamically loaded lines. 220 | * @param {Object} line - Line data object. 221 | * @returns {Element[]} - Array of line elements. 
222 | */ 223 | lineDataToElements(lineData) { 224 | return lineData.map(line => { 225 | let div = document.createElement('div'); 226 | div.innerHTML = `<span ${this._attributes(line)}>${line.value || ''}</span>`; 227 | 228 | return div.firstElementChild; 229 | }); 230 | } 231 | 232 | /** 233 | * Helper function for generating attributes string. 234 | * 235 | * @param {Object} line - Line data object. 236 | * @returns {string} - String of attributes. 237 | */ 238 | _attributes(line) { 239 | let attrs = ''; 240 | for (let prop in line) { 241 | // Custom add class 242 | if (prop === 'class') { 243 | attrs += ` class=${line[prop]} ` 244 | continue 245 | } 246 | if (prop === 'type') { 247 | attrs += `${this.pfx}="${line[prop]}" ` 248 | } else if (prop !== 'value') { 249 | attrs += `${this.pfx}-${prop}="${line[prop]}" ` 250 | } 251 | } 252 | 253 | return attrs; 254 | } 255 | } 256 | 257 | /** 258 | * HTML API: If current script has container(s) specified, initialise Termynal. 259 | */ 260 | if (document.currentScript.hasAttribute('data-termynal-container')) { 261 | const containers = document.currentScript.getAttribute('data-termynal-container'); 262 | containers.split('|') 263 | .forEach(container => new Termynal(container)) 264 | } 265 | -------------------------------------------------------------------------------- /docs/user-guide.md: -------------------------------------------------------------------------------- 1 | # User Guide 📚 2 | 3 | As introduced in the [home page](index.md), **sklearn-smithy** is a tool that helps you to forge scikit-learn compatible estimators with ease, and it comes in three flavours. 4 | 5 | Let's see how to use each one of them. 6 | 7 | ## Web UI 🌐 8 | 9 | TL;DR: 10 | 11 | - [x] Available at [sklearn-smithy.streamlit.app](https://sklearn-smithy.streamlit.app/){:target="_blank"} 12 | - [x] It requires no installation.
13 | - [x] Powered by [streamlit](https://streamlit.io/){:target="_blank"} 14 | 15 | The web UI is the most user-friendly, low barrier way, to interact with the tool by accessing it directly from your browser, without any installation required. 16 | 17 | Once the estimator is forged, you can download the script with the code as a `.py` file, or you can copy the code directly from the browser. 18 | 19 | ??? example "Screenshot" 20 | ![Web UI](img/webui.png) 21 | 22 | ## CLI ⌨️ 23 | 24 | TL;DR: 25 | 26 | - [x] Available via the `smith forge` command. 27 | - [x] It requires [installation](installation.md): `python -m pip install sklearn-smithy` 28 | - [x] Powered by [typer](https://typer.tiangolo.com/){:target="_blank"}. 29 | 30 | Once the library is installed, the `smith` CLI (Command Line Interface) will be available and that is the primary way to interact with the `smithy` package. 31 | 32 | The CLI provides a main command called `forge`, which will prompt a series of question in the terminal, based on which it will generate the code for the estimator. 33 | 34 | ### `smith forge` example 35 | 36 | Let's see an example of how to use `smith forge` command: 37 | 38 |
<div class="termy"> 39 | 40 | ```console 41 | $ smith forge 42 | # 🐍 How would you like to name the estimator?:$ MightyClassifier 43 | # 🎯 Which kind of estimator is it? (classifier, outlier, regressor, transformer, cluster, feature-selector):$ classifier 44 | # 📜 Please list the required parameters (comma-separated) []:$ alpha,beta 45 | # 📑 Please list the optional parameters (comma-separated) []:$ mu,sigma 46 | # 📶 Does the `.fit()` method support `sample_weight`? [y/N]:$ y 47 | # 📏 Is the estimator linear? [y/N]:$ N 48 | # 🎲 Should the estimator implement a `predict_proba` method? [y/N]:$ N 49 | # ❓ Should the estimator implement a `decision_function` method? [y/N]:$ y 50 | # 🧪 We are almost there... Is there any tag you want to add? (comma-separated) []:$ binary_only,non_deterministic 51 | # 📂 Where would you like to save the class? [mightyclassifier.py]:$ path/to/file.py 52 | Template forged at path/to/file.py 53 | ``` 54 | 55 | </div>
56 | 57 | Now the estimator template to be filled will be available at the specified path `path/to/file.py`. 58 | 59 |
<div class="termy"> 60 | 61 | ```console 62 | $ cat path/to/file.py | head -n 5 63 | import numpy as np 64 | 65 | from sklearn.base import BaseEstimator, ClassifierMixin 66 | from sklearn.utils import check_X_y 67 | from sklearn.utils.validation import check_is_fitted, check_array 68 | ``` 69 | 70 | </div>
71 | 72 | ### Non-interactive mode 73 | 74 | As for any CLI, in principle it would be possible to run it in a non-interactive way, however this is not *fully* supported (yet) and it comes with some risks and limitations. 75 | 76 | The reason for this is that the **validation** and the parameters **interaction** happen while prompting the questions *one after the other*, meaning that the input to one prompt will determine what follows next. 77 | 78 | It is still possible to run the CLI in a non-interactive way, but it is not recommended, as it may lead to unexpected results. 79 | 80 | Let's see an example of how to run the `smith forge` command in a non-interactive way: 81 | 82 | !!! example "Non-interactive mode" 83 | 84 | ```terminal 85 | smith forge \ 86 | --name MyEstimator \ 87 | --estimator-type classifier \ 88 | --required-params "a,b" \ 89 | --optional-params "" \ 90 | --no-sample-weight \ 91 | --no-predict-proba \ 92 | --linear \ 93 | --no-decision-function \ 94 | --tags "binary_only" \ 95 | --output-file path/to/file.py 96 | ``` 97 | 98 | Notice how all arguments must be specified, otherwise they will prompt anyway, which means that the command would be interactive. 99 | 100 | Secondly, there is nothing preventing us from running the command with contradictory arguments at the same time. Operating in such a way can lead to two scenarios: 101 | 102 | 1. The result will be correct, however unexpected from a user point of view. 103 | For instance, calling `--estimator-type classifier` with `--linear` and `--decision-function` flags, will not create a `decision_function` method, as `LinearClassifierMixin` already takes care of it. 104 | 2. The result will be incorrect, as the arguments are contradictory. 105 | 106 | The first case is not problematic from a functional point of view, while the second will lead to a broken estimator. 107 | 108 | Our suggestion is to use the CLI always in an interactive way, as it will take care of the proper arguments interaction.
109 | 110 | ## TUI 💻 111 | 112 | TL;DR: 113 | 114 | - [x] Available via the `smith forge-tui` command. 115 | - [x] It requires installing [extra dependencies](installation.md#extra-dependencies): `python -m pip install "sklearn-smithy[textual]"` 116 | - [x] Powered by [textual](https://textual.textualize.io/){:target="_blank"}. 117 | 118 | If you like the CLI, but prefer a more interactive and graphical way from the comfort of your terminal, you can use the TUI (Terminal User Interface) provided by the `smith forge-tui` command. 119 | 120 | ```console 121 | $ smith forge-tui 122 | ``` 123 | 124 | ```{.textual path="sksmithy/tui/_tui.py" columns="200" lines="35"} 125 | ``` 126 | -------------------------------------------------------------------------------- /docs/why.md: -------------------------------------------------------------------------------- 1 | # Why❓ 2 | 3 | Writing scikit-learn compatible estimators might be harder than expected. 4 | 5 | While everyone knows about the `fit` and `predict` methods, there are other behaviours, methods and attributes that 6 | scikit-learn might be expecting from your estimator depending on: 7 | 8 | - The type of estimator you're writing. 9 | - The signature of the estimator. 10 | - The signature of the `.fit(...)` method. 11 | 12 | Scikit-learn Smithy to the rescue: this tool aims to help you craft your own estimator by asking a few 13 | questions about it, and then generating the boilerplate code. 14 | 15 | In this way you will be able to fully focus on the core implementation logic, and not on nitty-gritty details 16 | of the scikit-learn API.
17 | 18 | ## Sanity check 19 | 20 | Once the core logic is implemented, the estimator should be ready to test against the _somewhat official_ 21 | [`parametrize_with_checks`](https://scikit-learn.org/dev/modules/generated/sklearn.utils.estimator_checks.parametrize_with_checks.html#sklearn.utils.estimator_checks.parametrize_with_checks){:target="_blank"} 22 | pytest compatible decorator: 23 | 24 | ```py 25 | from sklearn.utils.estimator_checks import parametrize_with_checks 26 | 27 | @parametrize_with_checks([ 28 | YourAwesomeRegressor, 29 | MoreAwesomeClassifier, 30 | EvenMoreAwesomeTransformer, 31 | ]) 32 | def test_sklearn_compatible_estimator(estimator, check): 33 | check(estimator) 34 | ``` 35 | 36 | and it should be compatible with scikit-learn Pipeline, GridSearchCV, etc. 37 | 38 | ## Official guide 39 | 40 | Scikit-learn documentation on how to 41 | [develop estimators](https://scikit-learn.org/dev/developers/develop.html#developing-scikit-learn-estimators){:target="_blank"}. 42 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | # Project information 2 | site_name: Sklearn Smithy 3 | site_url: https://fbruzzesi.github.io/sklearn-smithy/ 4 | site_author: Francesco Bruzzesi 5 | site_description: Toolkit to forge scikit-learn compatible estimators 6 | 7 | # Repository information 8 | repo_name: FBruzzesi/sklearn-smithy 9 | repo_url: https://github.com/fbruzzesi/sklearn-smithy 10 | edit_uri: edit/main/docs/ 11 | 12 | # Configuration 13 | use_directory_urls: true 14 | theme: 15 | name: material 16 | font: false 17 | palette: 18 | - media: '(prefers-color-scheme: light)' 19 | scheme: default 20 | primary: teal 21 | accent: deep-orange 22 | toggle: 23 | icon: material/brightness-7 24 | name: Switch to light mode 25 | - media: '(prefers-color-scheme: dark)' 26 | scheme: slate 27 | primary: teal 28 | accent: deep-orange 29 | toggle: 30 | 
icon: material/brightness-4 31 | name: Switch to dark mode 32 | features: 33 | - search.suggest 34 | - search.highlight 35 | - search.share 36 | - toc.follow 37 | - content.tabs.link 38 | - content.code.annotate 39 | - content.code.copy 40 | 41 | logo: img/sksmith-logo.png 42 | favicon: img/sksmith-logo.png 43 | 44 | # Plugins 45 | plugins: 46 | - search: 47 | separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])' 48 | 49 | # Customization 50 | extra: 51 | social: 52 | - icon: fontawesome/brands/github 53 | link: https://github.com/fbruzzesi 54 | - icon: fontawesome/brands/linkedin 55 | link: https://www.linkedin.com/in/francesco-bruzzesi/ 56 | - icon: fontawesome/brands/python 57 | link: https://pypi.org/project/sklearn-smithy/ 58 | 59 | # Extensions 60 | markdown_extensions: 61 | - abbr 62 | - admonition 63 | - attr_list 64 | - codehilite 65 | - def_list 66 | - footnotes 67 | - md_in_html 68 | - toc: 69 | permalink: true 70 | - pymdownx.inlinehilite 71 | - pymdownx.snippets 72 | - pymdownx.superfences: 73 | custom_fences: 74 | - name: textual 75 | class: textual 76 | format: !!python/name:textual._doc.format_svg 77 | - pymdownx.details 78 | - pymdownx.tasklist: 79 | custom_checkbox: true 80 | - pymdownx.tabbed: 81 | alternate_style: true 82 | - pymdownx.highlight: 83 | anchor_linenums: true 84 | line_spans: __span 85 | pygments_lang_class: true 86 | 87 | nav: 88 | - Home 🏠: index.md 89 | - Installation ✨: installation.md 90 | - Why ❓: why.md 91 | - User Guide 📚: user-guide.md 92 | - Contributing 👏: contribute.md 93 | 94 | extra_css: 95 | - css/termynal.css 96 | - css/custom.css 97 | extra_javascript: 98 | - js/termynal.js 99 | - js/custom.js 100 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | import nox 2 | from nox.sessions import Session 3 | 4 | nox.options.default_venv_backend = "uv" 5 | nox.options.reuse_venv = 
True 6 | 7 | PYTHON_VERSIONS = ["3.10", "3.11", "3.12"] 8 | 9 | 10 | @nox.session(python=PYTHON_VERSIONS) # type: ignore[misc] 11 | @nox.parametrize("pre", [False, True]) 12 | def pytest_coverage(session: Session, pre: bool) -> None: 13 | """Run pytest coverage across different python versions.""" 14 | pkg_install = [".[all]", "-r", "requirements/test.txt"] 15 | 16 | if pre: 17 | pkg_install.append("--pre") 18 | 19 | session.install(*pkg_install) 20 | 21 | session.run("pytest", "tests", "--cov=sksmithy", "--cov=tests", "--cov-fail-under=90", "--numprocesses=auto") 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "sklearn-smithy" 7 | version = "0.2.0" 8 | description = "Toolkit to forge scikit-learn compatible estimators." 9 | requires-python = ">=3.10" 10 | 11 | license = {file = "LICENSE"} 12 | readme = "README.md" 13 | 14 | authors = [ 15 | {name = "Francesco Bruzzesi"} 16 | ] 17 | 18 | keywords = [ 19 | "python", 20 | "cli", 21 | "webui", 22 | "tui", 23 | "data-science", 24 | "machine-learning", 25 | "scikit-learn" 26 | ] 27 | 28 | dependencies = [ 29 | "typer>=0.12.0", 30 | "rich>=13.0.0", 31 | "jinja2>=3.0.0", 32 | "result>=0.16.0", 33 | "ruff>=0.4.0", 34 | ] 35 | 36 | classifiers = [ 37 | "Development Status :: 4 - Beta", 38 | "License :: OSI Approved :: MIT License", 39 | "Topic :: Software Development :: Libraries :: Python Modules", 40 | "Typing :: Typed", 41 | "Programming Language :: Python :: 3", 42 | "Programming Language :: Python :: 3.10", 43 | "Programming Language :: Python :: 3.11", 44 | "Programming Language :: Python :: 3.12", 45 | ] 46 | 47 | [project.urls] 48 | Repository = "https://github.com/FBruzzesi/sklearn-smithy" 49 | Issues = "https://github.com/FBruzzesi/sklearn-smithy/issues" 50 | 
Documentation = "https://fbruzzesi.github.io/sklearn-smithy" 51 | Website = "https://sklearn-smithy.streamlit.app/" 52 | 53 | 54 | [project.optional-dependencies] 55 | streamlit = ["streamlit>=1.34.0"] 56 | textual = ["textual[syntax]>=0.65.0"] 57 | 58 | all = [ 59 | "streamlit>=1.34.0", 60 | "textual>=0.65.0", 61 | ] 62 | 63 | [project.scripts] 64 | smith = "sksmithy.__main__:cli" 65 | 66 | [tool.hatch.build.targets.sdist] 67 | only-include = ["sksmithy"] 68 | 69 | [tool.hatch.build.targets.wheel] 70 | packages = ["sksmithy"] 71 | 72 | [tool.ruff] 73 | line-length = 120 74 | target-version = "py310" 75 | 76 | [tool.ruff.lint] 77 | select = ["ALL"] 78 | ignore = [ 79 | "COM812", 80 | "ISC001", 81 | "PLR0913", 82 | "FBT001", 83 | "FBT002", 84 | "S603", 85 | "S607", 86 | "D100", 87 | "D104", 88 | "D400", 89 | ] 90 | 91 | [tool.ruff.lint.per-file-ignores] 92 | "tests/*" = ["D103","S101"] 93 | 94 | [tool.ruff.lint.pydocstyle] 95 | convention = "numpy" 96 | 97 | [tool.ruff.lint.pyupgrade] 98 | keep-runtime-typing = true 99 | 100 | [tool.ruff.format] 101 | docstring-code-format = true 102 | 103 | [tool.mypy] 104 | ignore_missing_imports = true 105 | python_version = "3.10" 106 | 107 | [tool.coverage.run] 108 | source = ["sksmithy/"] 109 | omit = [ 110 | "sksmithy/__main__.py", 111 | "sksmithy/_arguments.py", 112 | "sksmithy/_logger.py", 113 | "sksmithy/_prompts.py", 114 | "sksmithy/tui/__init__.py", 115 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Used by streamlit deployment 2 | -e ."[streamlit]" -------------------------------------------------------------------------------- /requirements/test.txt: -------------------------------------------------------------------------------- 1 | anyio 2 | pytest 3 | pytest-asyncio 4 | pytest-cov 5 | pytest-tornasync 6 | pytest-trio 7 | pytest-xdist 
-------------------------------------------------------------------------------- /sksmithy/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import metadata 2 | 3 | __title__ = "sksmithy" 4 | __version__ = metadata.version("sklearn-smithy") 5 | -------------------------------------------------------------------------------- /sksmithy/__main__.py: -------------------------------------------------------------------------------- 1 | from sksmithy.cli import cli 2 | 3 | if __name__ == "__main__": 4 | cli() 5 | -------------------------------------------------------------------------------- /sksmithy/_arguments.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | from typer import Option 4 | 5 | from sksmithy._callbacks import estimator_callback, linear_callback, name_callback, params_callback, tags_callback 6 | from sksmithy._models import EstimatorType 7 | from sksmithy._prompts import ( 8 | PROMPT_DECISION_FUNCTION, 9 | PROMPT_ESTIMATOR, 10 | PROMPT_LINEAR, 11 | PROMPT_NAME, 12 | PROMPT_OPTIONAL, 13 | PROMPT_OUTPUT, 14 | PROMPT_PREDICT_PROBA, 15 | PROMPT_REQUIRED, 16 | PROMPT_SAMPLE_WEIGHT, 17 | PROMPT_TAGS, 18 | ) 19 | 20 | name_arg = Annotated[ 21 | str, 22 | Option( 23 | prompt=PROMPT_NAME, 24 | help="[bold green]Name[/bold green] of the estimator", 25 | callback=name_callback, 26 | ), 27 | ] 28 | 29 | estimator_type_arg = Annotated[ 30 | EstimatorType, 31 | Option( 32 | prompt=PROMPT_ESTIMATOR, 33 | help="[bold green]Estimator type[/bold green]", 34 | callback=estimator_callback, 35 | ), 36 | ] 37 | 38 | required_params_arg = Annotated[ 39 | str, 40 | Option( 41 | prompt=PROMPT_REQUIRED, 42 | help="List of [italic yellow](comma-separated)[/italic yellow] [bold green]required[/bold green] parameters", 43 | callback=params_callback, 44 | ), 45 | ] 46 | 47 | optional_params_arg = Annotated[ 48 | str, 49 | Option( 50 | 
prompt=PROMPT_OPTIONAL, 51 | help="List of [italic yellow](comma-separated)[/italic yellow] [bold green]optional[/bold green] parameters", 52 | callback=params_callback, 53 | ), 54 | ] 55 | 56 | sample_weight_arg = Annotated[ 57 | bool, 58 | Option( 59 | is_flag=True, 60 | prompt=PROMPT_SAMPLE_WEIGHT, 61 | help="Whether or not `.fit()` supports [bold green]`sample_weight`[/bold green]", 62 | ), 63 | ] 64 | 65 | linear_arg = Annotated[ 66 | bool, 67 | Option( 68 | is_flag=True, 69 | prompt=PROMPT_LINEAR, 70 | help="Whether or not the estimator is [bold green]linear[/bold green]", 71 | callback=linear_callback, 72 | ), 73 | ] 74 | 75 | predict_proba_arg = Annotated[ 76 | bool, 77 | Option( 78 | is_flag=True, 79 | prompt=PROMPT_PREDICT_PROBA, 80 | help="Whether or not the estimator implements [bold green]`predict_proba`[/bold green] method", 81 | ), 82 | ] 83 | 84 | decision_function_arg = Annotated[ 85 | bool, 86 | Option( 87 | is_flag=True, 88 | prompt=PROMPT_DECISION_FUNCTION, 89 | help="Whether or not the estimator implements [bold green]`decision_function`[/bold green] method", 90 | ), 91 | ] 92 | 93 | tags_arg = Annotated[ 94 | str, 95 | Option( 96 | prompt=PROMPT_TAGS, 97 | help="List of optional extra scikit-learn [bold green]tags[/bold green]", 98 | callback=tags_callback, 99 | ), 100 | ] 101 | 102 | output_file_arg = Annotated[ 103 | str, 104 | Option( 105 | prompt=PROMPT_OUTPUT, 106 | help="[bold green]Destination file[/bold green] where to save the boilerplate code", 107 | ), 108 | ] 109 | -------------------------------------------------------------------------------- /sksmithy/_callbacks.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import Concatenate, ParamSpec, TypeVar 3 | 4 | from result import Err, Ok, Result 5 | from typer import BadParameter, CallbackParam, Context 6 | 7 | from sksmithy._models import EstimatorType 8 | from sksmithy._parsers import 
check_duplicates, name_parser, params_parser, tags_parser 9 | 10 | T = TypeVar("T") 11 | R = TypeVar("R") 12 | PS = ParamSpec("PS") 13 | 14 | 15 | def _parse_wrapper( 16 | ctx: Context, 17 | param: CallbackParam, 18 | value: T, 19 | parser: Callable[Concatenate[T, PS], Result[R, str]], 20 | *args: PS.args, 21 | **kwargs: PS.kwargs, 22 | ) -> tuple[Context, CallbackParam, R]: 23 | """Wrap a parser to handle 'caching' logic. 24 | 25 | `parser` should return a Result[R, str] 26 | 27 | Parameters 28 | ---------- 29 | ctx 30 | Typer context. 31 | param 32 | Callback parameter information. 33 | value 34 | Input for the parser callable. 35 | parser 36 | Parser function, it should return Result[R, str] 37 | *args 38 | Extra args for `parser`. 39 | **kwargs 40 | Extra kwargs for `parser`. 41 | 42 | Returns 43 | ------- 44 | ctx : Context 45 | Typer context updated with extra information. 46 | param : CallbackParam 47 | Unchanged callback parameters. 48 | result_value : R 49 | Parsed value. 50 | 51 | Raises 52 | ------ 53 | BadParameter 54 | If parser returns Err(msg) 55 | """ 56 | if not ctx.obj: 57 | ctx.obj = {} 58 | 59 | if param.name in ctx.obj: 60 | return ctx, param, ctx.obj[param.name] 61 | 62 | result = parser(value, *args, **kwargs) 63 | match result: 64 | case Ok(result_value): 65 | ctx.obj[param.name] = result_value 66 | return ctx, param, result_value 67 | case Err(msg): 68 | raise BadParameter(msg) 69 | 70 | 71 | def name_callback(ctx: Context, param: CallbackParam, value: str) -> str: 72 | """`name` argument callback. 73 | 74 | After parsing `name`, changes the default value of `output_file` argument to `{name.lower()}.py`. 
def estimator_callback(ctx: Context, param: CallbackParam, estimator: EstimatorType) -> str:
    """`estimator_type` argument callback.

    It dynamically modifies the behaviour of the rest of the prompts based on its value:

    - If not classifier or regressor, turns off linear prompt.
    - If not classifier or outlier, turns off predict_proba prompt.
    - If not classifier, turns off decision_function prompt.

    Parameters
    ----------
    ctx
        Typer context; `ctx.obj` caches the result and `ctx.command.params` holds the options to adjust.
    param
        Callback parameter information; `param.name` is the cache key.
    estimator
        Estimator type selected by the user.

    Returns
    -------
    str
        The value of the selected estimator type (e.g. `"classifier"`).

    Raises
    ------
    ValueError
        If the command does not declare all of the `linear`, `predict_proba` and `decision_function` options.
    """
    if not ctx.obj:  # pragma: no cover
        ctx.obj = {}

    if param.name in ctx.obj:
        return ctx.obj[param.name]

    # Look the dependent options up by name: this does not rely on the order in which the command declares them.
    dependent_names = ("linear", "predict_proba", "decision_function")
    options = {opt.name: opt for opt in ctx.command.params if opt.name in dependent_names}

    missing = [option_name for option_name in dependent_names if option_name not in options]
    if missing:
        msg = f"The command is missing the following options: {missing}"
        raise ValueError(msg)

    def _disable_prompt(option_name: str) -> None:
        """Turn off prompting for the option called `option_name`."""
        option = options[option_name]
        option.prompt = False  # type: ignore[attr-defined]
        option.prompt_required = False  # type: ignore[attr-defined]

    # Tuple membership compares with `==`, like the value patterns of a `match` statement would.
    if estimator not in (EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin):
        _disable_prompt("linear")

    if estimator not in (EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin):
        _disable_prompt("predict_proba")

    if estimator != EstimatorType.ClassifierMixin:
        _disable_prompt("decision_function")

    ctx.obj[param.name] = estimator.value

    return estimator.value
class EstimatorType(str, Enum):
    """Supported estimator types.

    Each member carries two pieces of information:

    - its *name* is the name of the scikit-learn mixin class, which is passed to the jinja template as `mixin` and
      ends up in the generated import statement and class bases;
    - its *value* is the short label passed to the jinja template as `estimator_type` and shown to the user when
      choosing the kind of estimator.

    Naming the members after the mixin classes avoids keeping a separate label-to-class mapping.
    """

    ClassifierMixin = "classifier"
    RegressorMixin = "regressor"
    OutlierMixin = "outlier"
    ClusterMixin = "cluster"
    TransformerMixin = "transformer"
    SelectorMixin = "feature-selector"
def name_parser(name: str | None) -> Result[str, str]:
    """Check that `name` can be used as a python class name.

    Returns `Err(...)` with an explanatory message when `name` is:

    - empty (or `None`)
    - not a valid python identifier
    - a python reserved keyword

    In every other case `Ok(name)` is returned, with `name` unchanged.
    """
    if not name:
        return Err("Name cannot be empty!")

    if not name.isidentifier():
        return Err(f"`{name}` is not a valid python class name!")

    # Keywords such as `class` pass `isidentifier()`, hence the dedicated check.
    if iskeyword(name):
        return Err(f"`{name}` is a python reserved keyword!")

    return Ok(name)
Then it returns `Err(...)` if: 35 | 36 | - any element in the list is not a valid python identifier 37 | - any element is repeated more than once 38 | 39 | Otherwise it returns `Ok(params.split(","))`. 40 | """ 41 | param_list: list[str] = params.split(",") if params else [] 42 | invalid = tuple(p for p in param_list if not p.isidentifier()) 43 | 44 | if len(invalid) > 0: 45 | msg = f"The following parameters are invalid python identifiers: {invalid}" 46 | return Err(msg) 47 | 48 | if len(set(param_list)) < len(param_list): 49 | msg = "Found repeated parameters!" 50 | return Err(msg) 51 | 52 | return Ok(param_list) 53 | 54 | 55 | def check_duplicates(required: list[str], optional: list[str]) -> str | None: 56 | """Check that there are not duplicates between required and optional params.""" 57 | duplicated_params = set(required).intersection(set(optional)) 58 | return ( 59 | f"The following parameters are duplicated between required and optional: {duplicated_params}" 60 | if duplicated_params 61 | else None 62 | ) 63 | 64 | 65 | def tags_parser(tags: str) -> Result[list[str], str]: 66 | """Parse and validate `tags` by comparing with sklearn list. 67 | 68 | The parser first splits tags on commas to get a list of strings. Then it returns `Err(...)` if any of the tag is not 69 | in the scikit-learn supported list. 70 | 71 | Otherwise it returns `Ok(tags.split(","))` 72 | """ 73 | list_tag: list[str] = tags.split(",") if tags else [] 74 | 75 | unavailable_tags = tuple(t for t in list_tag if t not in TagType.__members__) 76 | if len(unavailable_tags): 77 | msg = ( 78 | f"The following tags are not available: {unavailable_tags}." 79 | "\nPlease check the official documentation at " 80 | "https://scikit-learn.org/dev/developers/develop.html#estimator-tags" 81 | " to know which values are available." 
82 | ) 83 | 84 | return Err(msg) 85 | 86 | return Ok(list_tag) 87 | -------------------------------------------------------------------------------- /sksmithy/_prompts.py: -------------------------------------------------------------------------------- 1 | from typing import Final 2 | 3 | PROMPT_NAME: Final[str] = "🐍 How would you like to name the estimator?" 4 | PROMPT_ESTIMATOR: Final[str] = "🎯 Which kind of estimator is it?" 5 | PROMPT_REQUIRED: Final[str] = "📜 Please list the required parameters (comma-separated)" 6 | PROMPT_OPTIONAL: Final[str] = "📑 Please list the optional parameters (comma-separated)" 7 | PROMPT_SAMPLE_WEIGHT: Final[str] = "📶 Does the `.fit()` method support `sample_weight`?" 8 | PROMPT_LINEAR: Final[str] = "📏 Is the estimator linear?" 9 | PROMPT_PREDICT_PROBA: Final[str] = "🎲 Should the estimator implement a `predict_proba` method?" 10 | PROMPT_DECISION_FUNCTION: Final[str] = "❓ Should the estimator implement a `decision_function` method?" 11 | PROMPT_TAGS: Final[str] = ( 12 | "🧪 We are almost there... Is there any tag you want to add? (comma-separated)\n" 13 | "To know more about tags, check the documentation at:\n" 14 | "https://scikit-learn.org/dev/developers/develop.html#estimator-tags" 15 | ) 16 | PROMPT_OUTPUT: Final[str] = "📂 Where would you like to save the class?" 17 | -------------------------------------------------------------------------------- /sksmithy/_static/description.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | Writing scikit-learn compatible estimators might be harder than expected. 4 | 5 | While everyone knows about the `fit` and `predict`, there are other behaviours, methods and attributes that 6 | scikit-learn might be expecting from your estimator depending on: 7 | 8 | - The type of estimator you're writing. 9 | - The signature of the estimator. 10 | - The signature of the `.fit(...)` method. 
11 | 12 | Scikit-learn Smithy to the rescue: this tool aims to help you craft your own estimator by asking a few 13 | questions about it, and then generating the boilerplate code. 14 | 15 | In this way you will be able to fully focus on the core implementation logic, and not on nitty-gritty details 16 | of the scikit-learn API. 17 | 18 | ## Sanity check 19 | 20 | Once the core logic is implemented, the estimator should be ready to test against the _somewhat official_ 21 | [`parametrize_with_checks`](https://scikit-learn.org/dev/modules/generated/sklearn.utils.estimator_checks.parametrize_with_checks.html#sklearn.utils.estimator_checks.parametrize_with_checks) 22 | pytest compatible decorator. 23 | 24 | ## Official guide 25 | 26 | Scikit-learn documentation on how to 27 | [develop estimators](https://scikit-learn.org/dev/developers/develop.html#developing-scikit-learn-estimators). 28 | -------------------------------------------------------------------------------- /sksmithy/_static/template.py.jinja: -------------------------------------------------------------------------------- 1 | {%- if estimator_type in ('classifier', 'feature-selector') %} 2 | import numpy as np 3 | {% endif -%} 4 | {%- if estimator_type == 'classifier' and linear %} 5 | from sklearn.base import BaseEstimator 6 | from sklearn.linear_model._base import LinearClassifierMixin 7 | {% elif estimator_type == 'regressor' and linear%} 8 | from sklearn.base import {{ mixin }} 9 | from sklearn.linear_model._base import LinearModel 10 | {% elif estimator_type == 'feature-selector'%} 11 | from sklearn.base import BaseEstimator 12 | from sklearn.feature_selection import SelectorMixin 13 | {% else %} 14 | from sklearn.base import BaseEstimator, {{ mixin }} 15 | {% endif -%} 16 | from sklearn.utils import check_X_y 17 | from sklearn.utils.validation import check_is_fitted, check_array 18 | 19 | {% if sample_weight %}from sklearn.utils.validation import _check_sample_weight{% endif %} 20 | 21 | 22 | class
{{ name }}( 23 | {% if estimator_type == 'classifier' and linear %} 24 | LinearClassifierMixin, BaseEstimator 25 | {% elif estimator_type == 'regressor' and linear%} 26 | RegressorMixin, LinearModel 27 | {%else %} 28 | {{ mixin }}, BaseEstimator 29 | {% endif %}): 30 | """{{ name }} estimator. 31 | 32 | ... 33 | {% if parameters %} 34 | Parameters 35 | ---------- 36 | {% for param in parameters %} 37 | {{- param }} : ... 38 | {% endfor -%} 39 | {% endif -%} 40 | """ 41 | {% if required %}_required_parameters = {{ required }}{% endif -%} 42 | 43 | {% if parameters %} 44 | def __init__( 45 | self, 46 | {% for param in required %} 47 | {{- param }}, 48 | {% endfor -%} 49 | {%- if optional -%} 50 | *, 51 | {% endif -%} 52 | {% for param in optional %} 53 | {{- param }}=..., 54 | {% endfor -%} 55 | ): 56 | 57 | {%for param in parameters -%} 58 | self.{{param}} = {{param}} 59 | {% endfor -%} 60 | {% endif %} 61 | 62 | def fit(self, X, y{% if estimator_type in ('transformer', 'feature-selector') %}=None{% endif %}{% if sample_weight %}, sample_weight=None{% endif %}): 63 | """ 64 | Fit {{name}} estimator. 65 | 66 | Parameters 67 | ---------- 68 | X : {array-like, sparse matrix} of shape (n_samples, n_features) 69 | Training data. 70 | 71 | {%- if transformer-%} 72 | y : None 73 | Ignored. 74 | {% else %} 75 | y : array-like of shape (n_samples,) or (n_samples, n_targets) 76 | Target values. 77 | {% endif %} 78 | 79 | {%- if sample_weight -%} 80 | sample_weight : array-like of shape (n_samples,), default=None 81 | Individual weights for each sample. 82 | {% endif %} 83 | Returns 84 | ------- 85 | self : {{name}} 86 | Fitted {{name}} estimator. 87 | """ 88 | {%- if estimator_type in ('transformer', 'feature-selector') %} 89 | X = check_array(X, ...) #TODO: Fill in `check_array` arguments 90 | {% else %} 91 | X, y = check_X_y(X, y, ...) 
#TODO: Fill in `check_X_y` arguments 92 | {% endif %} 93 | self.n_features_in_ = X.shape[1] 94 | {%- if estimator_type=='classifier'%} 95 | self.classes_ = np.unique(y) 96 | {% endif %} 97 | {%- if sample_weight %} 98 | sample_weight = _check_sample_weight(sample_weight) 99 | {% endif %} 100 | 101 | ... # TODO: Implement fit logic 102 | 103 | {%if linear -%} 104 | # For linear models, coef_ and intercept_ is all you need. `predict` is taken care of by the mixin 105 | self.coef_ = ... 106 | self.intercept_ = ... 107 | {%- endif %} 108 | {% if 'max_iter' in parameters -%}self.n_iter_ = ...{%- endif %} 109 | {% if estimator_type=='outlier' -%}self.offset_ = ...{%- endif %} 110 | {% if estimator_type=='cluster' -%}self.labels_ = ...{%- endif %} 111 | {% if estimator_type=='feature-selector'%} 112 | self.selected_features_ = ... # TODO: Indexes of selected features 113 | self.support_ = np.isin( 114 | np.arange(0, self.n_features_in_), # all_features 115 | self.selected_features_ 116 | ) 117 | {%- endif %} 118 | 119 | return self 120 | 121 | {% if estimator_type == 'classifier' and decision_function == True and linear == False %} 122 | def decision_function(self, X): 123 | """Confidence scores of X. 124 | 125 | Parameters 126 | ---------- 127 | X : array-like of shape (n_samples, n_features) 128 | The data to predict. 129 | 130 | Returns 131 | ------- 132 | Prediction array. 133 | """ 134 | 135 | check_is_fitted(self) 136 | X = check_array(X, ...) #TODO: Fill in `check_array` arguments 137 | 138 | if X.shape[1] != self.n_features_in_: 139 | msg = f"X has {X.shape[1]} features but the estimator was fitted on {self.n_features_in_} features." 140 | raise ValueError(msg) 141 | 142 | y_scores = ... # TODO: Implement decision_function logic 143 | 144 | return y_scores 145 | 146 | def predict(self, X): 147 | """Predict X. 148 | 149 | Parameters 150 | ---------- 151 | X : array-like of shape (n_samples, n_features) 152 | The data to predict. 
153 | 154 | Returns 155 | ------- 156 | Prediction array. 157 | """ 158 | 159 | check_is_fitted(self) 160 | X = check_array(X, ...) #TODO: Fill in `check_array` arguments 161 | 162 | decision = self.decision_function(X) 163 | y_pred = (decision.ravel() > 0).astype(int) if self.n_classes == 2 else np.argmax(decision, axis=1) 164 | return y_pred 165 | {% endif %} 166 | 167 | {% if estimator_type in ('classifier', 'outlier') and predict_proba == True %} 168 | def predict_proba(self, X): 169 | """Probability estimates of X. 170 | 171 | Parameters 172 | ---------- 173 | X : array-like of shape (n_samples, n_features) 174 | The data to predict. 175 | 176 | Returns 177 | ------- 178 | Prediction array. 179 | """ 180 | 181 | check_is_fitted(self) 182 | X = check_array(X, ...) #TODO: Fill in `check_array` arguments 183 | 184 | if X.shape[1] != self.n_features_in_: 185 | msg = f"X has {X.shape[1]} features but the estimator was fitted on {self.n_features_in_} features." 186 | raise ValueError(msg) 187 | 188 | y_proba = ... # TODO: Implement predict_proba logic 189 | 190 | return y_proba 191 | {% endif %} 192 | 193 | {% if estimator_type=='outlier' %} 194 | def score_samples(self, X): 195 | 196 | check_is_fitted(self) 197 | X = check_array(X, ...) #TODO: Fill in `check_array` arguments 198 | 199 | if X.shape[1] != self.n_features_in_: 200 | msg = f"X has {X.shape[1]} features but the estimator was fitted on {self.n_features_in_} features." 201 | raise ValueError(msg) 202 | 203 | ... # TODO: Implement scoring function, `decision_function` and `predict` will follow 204 | 205 | return ... 
206 | 207 | def decision_function(self, X): 208 | return self.score_samples(X) - self.offset_ 209 | 210 | def predict(self, X): 211 | preds = (self.decision_function(X) >= 0).astype(int) 212 | preds[preds == 0] = -1 213 | return preds 214 | {%- endif %} 215 | 216 | {% if decision_function == False and linear == False and (estimator_type in ('classifier', 'regressor', 'cluster')) %} 217 | def predict(self, X): 218 | """Predict X. 219 | 220 | Parameters 221 | ---------- 222 | X : array-like of shape (n_samples, n_features) 223 | The data to predict. 224 | 225 | Returns 226 | ------- 227 | Prediction array. 228 | """ 229 | 230 | check_is_fitted(self) 231 | X = check_array(X, ...) #TODO: Fill in `check_array` arguments 232 | 233 | if X.shape[1] != self.n_features_in_: 234 | msg = f"X has {X.shape[1]} features but the estimator was fitted on {self.n_features_in_} features." 235 | raise ValueError(msg) 236 | 237 | y_pred = ... # TODO: Implement predict logic 238 | 239 | return y_pred 240 | {% endif %} 241 | 242 | {% if estimator_type=='transformer' -%} 243 | def transform(self, X): 244 | """Transform X. 245 | 246 | Parameters 247 | ---------- 248 | X : array-like of shape (n_samples, n_features) 249 | The data to transform. 250 | 251 | Returns 252 | ------- 253 | Transformed array. 254 | """ 255 | 256 | check_is_fitted(self) 257 | X = check_array(X, ...) # TODO: Fill in `check_array` arguments 258 | 259 | if X.shape[1] != self.n_features_in_: 260 | msg = f"X has {X.shape[1]} features but the estimator was fitted on {self.n_features_in_} features." 261 | raise ValueError(msg) 262 | 263 | X_ts = ... # TODO: Implement transform logic 264 | 265 | return X_ts 266 | {%- endif %} 267 | 268 | {% if estimator_type=='feature-selector' -%} 269 | def _get_support_mask(self, X): 270 | """Get the boolean mask indicating which features are selected. 
271 | 272 | Returns 273 | ------- 274 | support : boolean array of shape [# input features] 275 | An element is True iff its corresponding feature is selected for retention. 276 | """ 277 | 278 | check_is_fitted(self) 279 | return self.support_ 280 | {%- endif %} 281 | 282 | {% if tags %} 283 | def _more_tags(self): 284 | return { 285 | {%for tag in tags -%} 286 | "{{tag}}": ..., 287 | {% endfor -%} 288 | } 289 | {%- endif %} 290 | 291 | {% if estimator_type == 'classifier' %} 292 | @property 293 | def n_classes_(self): 294 | """Number of classes.""" 295 | return len(self.classes_) 296 | {% endif %} -------------------------------------------------------------------------------- /sksmithy/_static/tui.tcss: -------------------------------------------------------------------------------- 1 | .container { 2 | height: auto; 3 | width: auto; 4 | min-height: 10vh; 5 | } 6 | 7 | .label { 8 | height: 3; 9 | content-align: right middle; 10 | width: auto; 11 | } 12 | 13 | Screen { 14 | align: center middle; 15 | min-width: 100vw; 16 | } 17 | 18 | Header { 19 | color: $secondary; 20 | text-style: bold; 21 | } 22 | 23 | Horizontal { 24 | min-height: 10vh; 25 | height: auto; 26 | } 27 | 28 | Name, Estimator, Required, Optional { 29 | width: 50%; 30 | padding: 0 2 0 1; 31 | height: auto; 32 | } 33 | 34 | SampleWeight, Linear { 35 | width: 50%; 36 | padding: 1 0 0 1; 37 | height: auto; 38 | } 39 | 40 | PredictProba, DecisionFunction { 41 | width: 50%; 42 | padding: 0 0 0 1; 43 | height: auto; 44 | } 45 | 46 | Prompt { 47 | padding: 0 0 0 2; 48 | height: auto; 49 | } 50 | 51 | Switch { 52 | height: auto; 53 | width: auto; 54 | } 55 | 56 | Switch:disabled { 57 | background: darkslategrey; 58 | } 59 | 60 | Input.-valid { 61 | border: tall $success 60%; 62 | } 63 | Input.-valid:focus { 64 | border: tall $success; 65 | } 66 | 67 | ForgeRow { 68 | grid-size: 4 1; 69 | grid-gutter: 1; 70 | grid-columns: 45% 10% 10% 25%; 71 | min-height: 15vh; 72 | max-height: 15vh; 73 | } 74 | 75 | 
def render_template(
    name: str,
    estimator_type: EstimatorType,
    required: list[str],
    optional: list[str],
    linear: bool = False,
    sample_weight: bool = False,
    predict_proba: bool = False,
    decision_function: bool = False,
    tags: list[str] | None = None,
) -> str:
    """
    Render the estimator template using the provided parameters.

    This is achieved in a two-step process:

    - Render the jinja template using the input values.
    - Format the string using ruff formatter.

    !!! warning

        This function **does not** validate that arguments are necessarily compatible with each other.
        For instance, it could be possible to pass `estimator_type = EstimatorType.RegressorMixin` and
        `predict_proba = True` which makes no sense as combination, but it would not raise an error.

    Parameters
    ----------
    name
        The name of the estimator class to generate.
    estimator_type
        The type of the estimator.
    required
        The list of required parameters.
    optional
        The list of optional parameters.
    linear
        Whether or not the estimator is linear.
    sample_weight
        Whether or not the estimator supports sample weights in `.fit()`.
    predict_proba
        Whether or not the estimator should implement `.predict_proba()` method.
    decision_function
        Whether or not the estimator should implement `.decision_function()` method.
    tags
        The list of scikit-learn extra tags.

    Returns
    -------
    str : The rendered and formatted template as a string.

    Raises
    ------
    subprocess.CalledProcessError
        If the `ruff format` command exits with a non-zero status.
    """
    values = {
        "name": name,
        # The value is the short label the template branches on, the name is the mixin class to import.
        "estimator_type": estimator_type.value,
        "mixin": estimator_type.name,
        "required": required,
        "optional": optional,
        "parameters": [*required, *optional],
        "linear": linear,
        "sample_weight": sample_weight,
        "predict_proba": predict_proba,
        "decision_function": decision_function,
        "tags": tags,
    }

    # Read with an explicit encoding, matching the one used to talk to ruff below, instead of the platform default.
    with TEMPLATE_PATH.open(mode="r", encoding="utf-8") as stream:
        rendered = Template(stream.read()).render(values)

    # `ruff format -` reads the code from stdin and writes the formatted result to stdout.
    return subprocess.check_output(["ruff", "format", "-"], input=rendered, encoding="utf-8")
from sksmithy._prompts import ( 11 | PROMPT_DECISION_FUNCTION, 12 | PROMPT_ESTIMATOR, 13 | PROMPT_LINEAR, 14 | PROMPT_NAME, 15 | PROMPT_OPTIONAL, 16 | PROMPT_PREDICT_PROBA, 17 | PROMPT_REQUIRED, 18 | PROMPT_SAMPLE_WEIGHT, 19 | ) 20 | from sksmithy._utils import render_template 21 | 22 | if (st_version := version("streamlit")) and tuple(int(re.sub(r"\D", "", str(v))) for v in st_version.split(".")) < ( 23 | 1, 24 | 34, 25 | 0, 26 | ): # pragma: no cover 27 | st_import_err_msg = ( 28 | f"streamlit>=1.34.0 is required for this module. Found version {st_version}.\nInstall it with " 29 | '`python -m pip install "streamlit>=1.34.0"` or `python -m pip install "sklearn-smithy[streamlit]"`' 30 | ) 31 | raise ImportError(st_import_err_msg) 32 | 33 | else: # pragma: no cover 34 | import streamlit as st 35 | 36 | SIDEBAR_MSG: str = (resources.files("sksmithy") / "_static" / "description.md").read_text() 37 | 38 | 39 | def app() -> None: # noqa: C901,PLR0912,PLR0915 40 | """Streamlit App.""" 41 | st.set_page_config( 42 | page_title="Smithy", 43 | page_icon="⚒️", 44 | layout="wide", 45 | menu_items={ 46 | "Get Help": "https://github.com/FBruzzesi/sklearn-smithy", 47 | "Report a bug": "https://github.com/FBruzzesi/sklearn-smithy/issues/new", 48 | "About": """ 49 | Forge your own scikit-learn estimator! 50 | 51 | For more information, please visit the [sklearn-smithy](https://github.com/FBruzzesi/sklearn-smithy) 52 | repository. 
53 | """, 54 | }, 55 | ) 56 | 57 | st.title("Scikit-learn Smithy ⚒️") 58 | st.markdown("## Forge your own scikit-learn compatible estimator") 59 | 60 | with st.sidebar: 61 | st.markdown(SIDEBAR_MSG) 62 | 63 | linear = False 64 | predict_proba = False 65 | decision_function = False 66 | estimator_type: EstimatorType | None = None 67 | 68 | required_is_valid = False 69 | optional_is_valid = False 70 | msg_duplicated_params: str | None = None 71 | 72 | if "forged_template" not in st.session_state: 73 | st.session_state["forged_template"] = "" 74 | 75 | if "forge_counter" not in st.session_state: 76 | st.session_state["forge_counter"] = 0 77 | 78 | with st.container(): # name and type 79 | c11, c12 = st.columns(2) 80 | 81 | with c11: # name 82 | name_input = st.text_input( 83 | label=PROMPT_NAME, 84 | value="MightyEstimator", 85 | placeholder="MightyEstimator", 86 | help=( 87 | "It should be a valid " 88 | "[python identifier](https://docs.python.org/3/reference/lexical_analysis.html#identifiers)" 89 | ), 90 | key="name", 91 | ) 92 | 93 | match name_parser(name_input): 94 | case Ok(name): 95 | pass 96 | case Err(name_error_msg): 97 | name = "" 98 | st.error(name_error_msg) 99 | 100 | with c12: # type 101 | estimator = st.selectbox( 102 | label=PROMPT_ESTIMATOR, 103 | options=tuple(e.value for e in EstimatorType), 104 | format_func=lambda v: " ".join(x.capitalize() for x in v.split("-")), 105 | index=None, 106 | key="estimator", 107 | ) 108 | 109 | if estimator: 110 | estimator_type = EstimatorType(estimator) 111 | 112 | with st.container(): # params 113 | c21, c22 = st.columns(2) 114 | 115 | with c21: # required 116 | required_params = st.text_input( 117 | label=PROMPT_REQUIRED, 118 | placeholder="alpha,beta", 119 | help=( 120 | "It should be a sequence of comma-separated " 121 | "[python identifiers](https://docs.python.org/3/reference/lexical_analysis.html#identifiers)" 122 | ), 123 | key="required", 124 | ) 125 | 126 | match params_parser(required_params): 127 | 
case Ok(required): 128 | required_is_valid = True 129 | case Err(required_err_msg): 130 | required_is_valid = False 131 | st.error(required_err_msg) 132 | 133 | with c22: # optional 134 | optional_params = st.text_input( 135 | label=PROMPT_OPTIONAL, 136 | placeholder="mu,sigma", 137 | help=( 138 | "It should be a sequence of comma-separated " 139 | "[python identifiers](https://docs.python.org/3/reference/lexical_analysis.html#identifiers)" 140 | ), 141 | key="optional", 142 | ) 143 | 144 | match params_parser(optional_params): 145 | case Ok(optional): 146 | optional_is_valid = True 147 | case Err(optional_err_msg): 148 | optional_is_valid = False 149 | st.error(optional_err_msg) 150 | 151 | if required_is_valid and optional_is_valid and (msg_duplicated_params := check_duplicates(required, optional)): 152 | st.error(msg_duplicated_params) 153 | 154 | with st.container(): # sample_weight and linear 155 | c31, c32 = st.columns(2) 156 | 157 | with c31: # sample_weight 158 | sample_weight = st.toggle( 159 | PROMPT_SAMPLE_WEIGHT, 160 | help="[sample_weight](https://scikit-learn.org/dev/glossary.html#term-sample_weight)", 161 | key="sample_weight", 162 | ) 163 | with c32: # linear 164 | linear = st.toggle( 165 | label=PROMPT_LINEAR, 166 | disabled=(estimator_type not in {EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin}), 167 | help="Available only if estimator is `Classifier` or `Regressor`", 168 | key="linear", 169 | ) 170 | 171 | with st.container(): # predict_proba and decision_function 172 | c41, c42 = st.columns(2) 173 | 174 | with c41: # predict_proba 175 | predict_proba = st.toggle( 176 | label=PROMPT_PREDICT_PROBA, 177 | disabled=(estimator_type not in {EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin}), 178 | help=( 179 | "[predict_proba](https://scikit-learn.org/dev/glossary.html#term-predict_proba): " 180 | "Available only if estimator is `Classifier` or `Outlier`. 
" 181 | ), 182 | key="predict_proba", 183 | ) 184 | 185 | with c42: # decision_function 186 | decision_function = st.toggle( 187 | label=PROMPT_DECISION_FUNCTION, 188 | disabled=(estimator_type != EstimatorType.ClassifierMixin) or linear, 189 | help=( 190 | "[decision_function](https://scikit-learn.org/dev/glossary.html#term-decision_function): " 191 | "Available only if estimator is `Classifier`" 192 | ), 193 | key="decision_function", 194 | ) 195 | 196 | st.write("#") # empty space hack 197 | 198 | with st.container(): # forge button 199 | c51, c52, _, c54 = st.columns([2, 1, 1, 1]) 200 | 201 | with ( 202 | c51, 203 | st.popover( 204 | label="Additional tags", 205 | help=( 206 | "To know more about tags, check the " 207 | "[scikit-learn documentation](https://scikit-learn.org/dev/developers/develop.html#estimator-tags)" 208 | ), 209 | ), 210 | ): 211 | tags = st.multiselect( 212 | label="Select tags", 213 | options=tuple(e.value for e in TagType), 214 | help="These tags are not validated against the selected estimator type!", 215 | key="tags", 216 | ) 217 | 218 | with c52: 219 | forge_btn = st.button( 220 | label="Time to forge 🛠️", 221 | type="primary", 222 | disabled=any( 223 | [ 224 | not name, 225 | not estimator_type, 226 | not required_is_valid, 227 | not optional_is_valid, 228 | msg_duplicated_params, 229 | ] 230 | ), 231 | key="forge_btn", 232 | ) 233 | if forge_btn: 234 | st.session_state["forge_counter"] += 1 235 | st.session_state["forged_template"] = render_template( 236 | name=name, 237 | estimator_type=estimator_type, # type: ignore[arg-type] # At this point estimator_type is never None. 
238 | required=required, 239 | optional=optional, 240 | linear=linear, 241 | sample_weight=sample_weight, 242 | predict_proba=predict_proba, 243 | decision_function=decision_function, 244 | tags=tags, 245 | ) 246 | 247 | with c54, st.popover(label="Download", disabled=not st.session_state["forge_counter"]): 248 | if name: 249 | file_name = st.text_input(label="Select filename", value=f"{name.lower()}.py", key="file_name") 250 | 251 | data = st.session_state["forged_template"] 252 | st.download_button( 253 | label="Confirm", 254 | type="primary", 255 | data=data, 256 | file_name=file_name, 257 | key="download_btn", 258 | ) 259 | 260 | st.write("#") # empty space hack 261 | 262 | with st.container(): # code output 263 | if forge_btn: 264 | st.toast("Request submitted!") 265 | progress_text = "Forging in progress ..." 266 | progress_bar = st.progress(0, text=progress_text) 267 | # Consider using status component instead 268 | # https://docs.streamlit.io/develop/api-reference/status/st.status 269 | 270 | for percent_complete in range(100): 271 | time.sleep(0.002) 272 | progress_bar.progress(percent_complete + 1, text=progress_text) 273 | 274 | time.sleep(0.2) 275 | progress_bar.empty() 276 | 277 | if st.session_state["forge_counter"]: 278 | st.code(st.session_state["forged_template"], language="python", line_numbers=True) 279 | 280 | 281 | if __name__ == "__main__": 282 | app() 283 | -------------------------------------------------------------------------------- /sksmithy/cli.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import typer 4 | 5 | from sksmithy._arguments import ( 6 | decision_function_arg, 7 | estimator_type_arg, 8 | linear_arg, 9 | name_arg, 10 | optional_params_arg, 11 | output_file_arg, 12 | predict_proba_arg, 13 | required_params_arg, 14 | sample_weight_arg, 15 | tags_arg, 16 | ) 17 | from sksmithy._logger import console 18 | from sksmithy._utils import render_template 19 | 20 
# Typer application: every `@cli.command()` below becomes a sub-command, and each function
# docstring is used verbatim as that command's help text.
cli = typer.Typer(
    name="smith",
    help="CLI to generate scikit-learn estimator boilerplate code.",
    rich_markup_mode="rich",
    rich_help_panel="Customization and Utils",
)


@cli.command()
def version() -> None:
    """Display library version."""
    from importlib import metadata

    __version__ = metadata.version("sklearn-smithy")
    console.print(f"sklearn-smithy={__version__}", style="good")


@cli.command()
def forge(
    name: name_arg,
    estimator_type: estimator_type_arg,
    required_params: required_params_arg = "",
    optional_params: optional_params_arg = "",
    sample_weight: sample_weight_arg = False,
    linear: linear_arg = False,
    predict_proba: predict_proba_arg = False,
    decision_function: decision_function_arg = False,
    tags: tags_arg = "",
    output_file: output_file_arg = "",
) -> None:
    """Generate a new shiny scikit-learn compatible estimator ✨

    Depending on the estimator type the following additional information could be required:

    * if the estimator is linear (classifier or regression)
    * if the estimator implements `.predict_proba()` method (classifier or outlier detector)
    * if the estimator implements `.decision_function()` method (classifier only)

    Finally, the following two questions will be prompt:

    * if the estimator should have tags (To know more about tags, check the dedicated scikit-learn documentation
        at https://scikit-learn.org/dev/developers/develop.html#estimator-tags)
    * in which file the class should be saved (default is `f'{name.lower()}.py'`)
    """
    forged_template = render_template(
        name=name,
        estimator_type=estimator_type,
        required=required_params,  # type: ignore[arg-type]  # Callback transforms it into `list[str]`
        optional=optional_params,  # type: ignore[arg-type]  # Callback transforms it into `list[str]`
        linear=linear,
        sample_weight=sample_weight,
        predict_proba=predict_proba,
        decision_function=decision_function,
        tags=tags,  # type: ignore[arg-type]  # Callback transforms it into `list[str]`
    )

    destination_file = Path(output_file)
    destination_file.parent.mkdir(parents=True, exist_ok=True)

    # Python source files are UTF-8 by default, so do not rely on the platform encoding.
    with destination_file.open(mode="w", encoding="utf-8") as destination:
        destination.write(forged_template)

    console.print(f"Template forged at {destination_file}", style="good")


@cli.command(name="forge-tui")
def forge_tui() -> None:
    """Run Terminal User Interface via Textual."""
    # Imported lazily so the other commands do not pay for (or require) the TUI import.
    from sksmithy.tui import ForgeTUI

    tui = ForgeTUI()
    tui.run()


@cli.command(name="forge-webui")
def forge_webui() -> None:
    """Run Web User Interface via Streamlit."""
    import subprocess
    from importlib import resources

    # Locate app.py inside the installed package. A path relative to the current working
    # directory only resolves when the command is launched from the repository root.
    app_path = resources.files("sksmithy") / "app.py"

    subprocess.run(["streamlit", "run", str(app_path)], check=True)
textual.widgets import Button, Collapsible, Input, Markdown, Select, Static, Switch, TextArea 11 | 12 | from sksmithy._models import EstimatorType 13 | from sksmithy._parsers import check_duplicates, name_parser, params_parser 14 | from sksmithy._prompts import ( 15 | PROMPT_DECISION_FUNCTION, 16 | PROMPT_ESTIMATOR, 17 | PROMPT_LINEAR, 18 | PROMPT_NAME, 19 | PROMPT_OPTIONAL, 20 | PROMPT_OUTPUT, 21 | PROMPT_PREDICT_PROBA, 22 | PROMPT_REQUIRED, 23 | PROMPT_SAMPLE_WEIGHT, 24 | ) 25 | from sksmithy._utils import render_template 26 | from sksmithy.tui._validators import NameValidator, ParamsValidator 27 | 28 | if sys.version_info >= (3, 11): # pragma: no cover 29 | from typing import Self 30 | else: # pragma: no cover 31 | from typing_extensions import Self 32 | 33 | 34 | SIDEBAR_MSG: str = (resources.files("sksmithy") / "_static" / "description.md").read_text() 35 | 36 | 37 | class Prompt(Static): 38 | pass 39 | 40 | 41 | class Name(Container): 42 | """Name input component.""" 43 | 44 | def compose(self: Self) -> ComposeResult: 45 | yield Prompt(PROMPT_NAME, classes="label") 46 | yield Input(placeholder="MightyEstimator", id="name", validators=[NameValidator()]) 47 | 48 | @on(Input.Changed, "#name") 49 | def on_input_change(self: Self, event: Input.Changed) -> None: 50 | if not event.validation_result.is_valid: # type: ignore[union-attr] 51 | self.notify( 52 | message=event.validation_result.failure_descriptions[0], # type: ignore[union-attr] 53 | title="Invalid Name", 54 | severity="error", 55 | timeout=5, 56 | ) 57 | else: 58 | output_file = self.app.query_one("#output-file", Input) 59 | output_file.value = f"{event.value.lower()}.py" 60 | 61 | 62 | class Estimator(Container): 63 | """Estimator select component.""" 64 | 65 | def compose(self: Self) -> ComposeResult: 66 | yield Prompt(PROMPT_ESTIMATOR, classes="label") 67 | yield Select( 68 | options=((" ".join(x.capitalize() for x in e.value.split("-")), e.value) for e in EstimatorType), 69 | id="estimator", 70 | ) 
71 | 72 | @on(Select.Changed, "#estimator") 73 | def on_select_change(self: Self, event: Select.Changed) -> None: 74 | linear = self.app.query_one("#linear", Switch) 75 | predict_proba = self.app.query_one("#predict_proba", Switch) 76 | decision_function = self.app.query_one("#decision_function", Switch) 77 | 78 | linear.disabled = event.value not in {"classifier", "regressor"} 79 | predict_proba.disabled = event.value not in {"classifier", "outlier"} 80 | decision_function.disabled = event.value not in {"classifier"} 81 | 82 | linear.value = linear.value and (not linear.disabled) 83 | predict_proba.value = predict_proba.value and (not predict_proba.disabled) 84 | decision_function.value = decision_function.value and (not decision_function.disabled) 85 | 86 | 87 | class Required(Container): 88 | """Required params input component.""" 89 | 90 | def compose(self: Self) -> ComposeResult: 91 | yield Prompt(PROMPT_REQUIRED, classes="label") 92 | yield Input(placeholder="alpha,beta", id="required", validators=[ParamsValidator()]) 93 | 94 | @on(Input.Submitted, "#required") 95 | def on_input_change(self: Self, event: Input.Submitted) -> None: 96 | if not event.validation_result.is_valid: # type: ignore[union-attr] 97 | self.notify( 98 | message="\n".join(event.validation_result.failure_descriptions), # type: ignore[union-attr] 99 | title="Invalid Parameter", 100 | severity="error", 101 | timeout=5, 102 | ) 103 | 104 | optional = self.app.query_one("#optional", Input).value or "" 105 | if ( 106 | optional 107 | and event.value 108 | and ( 109 | duplicates_result := check_duplicates( 110 | event.value.split(","), 111 | optional.split(","), 112 | ) 113 | ) 114 | ): 115 | self.notify( 116 | message=duplicates_result, 117 | title="Duplicate Parameter", 118 | severity="error", 119 | timeout=5, 120 | ) 121 | 122 | 123 | class Optional(Container): 124 | """Optional params input component.""" 125 | 126 | def compose(self: Self) -> ComposeResult: 127 | yield Prompt(PROMPT_OPTIONAL, 
classes="label") 128 | yield Input(placeholder="mu,sigma", id="optional", validators=[ParamsValidator()]) 129 | 130 | @on(Input.Submitted, "#optional") 131 | def on_optional_change(self: Self, event: Input.Submitted) -> None: 132 | if not event.validation_result.is_valid: # type: ignore[union-attr] 133 | self.notify( 134 | message="\n".join(event.validation_result.failure_descriptions), # type: ignore[union-attr] 135 | title="Invalid Parameter", 136 | severity="error", 137 | timeout=5, 138 | ) 139 | 140 | required = self.app.query_one("#required", Input).value or "" 141 | if ( 142 | required 143 | and event.value 144 | and ( 145 | duplicates_result := check_duplicates( 146 | required.split(","), 147 | event.value.split(","), 148 | ) 149 | ) 150 | ): 151 | self.notify( 152 | message=duplicates_result, 153 | title="Duplicate Parameter", 154 | severity="error", 155 | timeout=5, 156 | ) 157 | 158 | 159 | class SampleWeight(Container): 160 | """sample_weight switch component.""" 161 | 162 | def compose(self: Self) -> ComposeResult: 163 | yield Horizontal( 164 | Switch(id="sample_weight"), 165 | Prompt(PROMPT_SAMPLE_WEIGHT, classes="label"), 166 | classes="container", 167 | ) 168 | 169 | 170 | class Linear(Container): 171 | """linear switch component.""" 172 | 173 | def compose(self: Self) -> ComposeResult: 174 | yield Horizontal( 175 | Switch(id="linear"), 176 | Prompt(PROMPT_LINEAR, classes="label"), 177 | classes="container", 178 | ) 179 | 180 | @on(Switch.Changed, "#linear") 181 | def on_switch_changed(self: Self, event: Switch.Changed) -> None: 182 | decision_function = self.app.query_one("#decision_function", Switch) 183 | decision_function.disabled = event.value 184 | decision_function.value = decision_function.value and (not decision_function.disabled) 185 | 186 | 187 | class PredictProba(Container): 188 | """predict_proba switch component.""" 189 | 190 | def compose(self: Self) -> ComposeResult: 191 | yield Horizontal( 192 | Switch(id="predict_proba"), 193 | 
Prompt(PROMPT_PREDICT_PROBA, classes="label"), 194 | classes="container", 195 | ) 196 | 197 | 198 | class DecisionFunction(Container): 199 | """decision_function switch component.""" 200 | 201 | def compose(self: Self) -> ComposeResult: 202 | yield Horizontal( 203 | Switch(id="decision_function"), 204 | Prompt(PROMPT_DECISION_FUNCTION, classes="label"), 205 | classes="container", 206 | ) 207 | 208 | 209 | class ForgeButton(Container): 210 | """forge button component.""" 211 | 212 | def compose(self: Self) -> ComposeResult: 213 | yield Button(label="Forge ⚒️", id="forge-btn", variant="success") 214 | 215 | @on(Button.Pressed, "#forge-btn") 216 | def on_forge(self: Self, _: Button.Pressed) -> None: # noqa: C901 217 | errors = [] 218 | 219 | name_input = self.app.query_one("#name", Input).value 220 | estimator = self.app.query_one("#estimator", Select).value 221 | required_params = self.app.query_one("#required", Input).value 222 | optional_params = self.app.query_one("#optional", Input).value 223 | 224 | sample_weight = self.app.query_one("#linear", Switch).value 225 | linear = self.app.query_one("#linear", Switch).value 226 | predict_proba = self.app.query_one("#predict_proba", Switch).value 227 | decision_function = self.app.query_one("#decision_function", Switch).value 228 | 229 | code_area = self.app.query_one("#code-area", TextArea) 230 | code_editor = self.app.query_one("#code-editor", Collapsible) 231 | 232 | match name_parser(name_input): 233 | case Ok(name): 234 | pass 235 | case Err(name_error_msg): 236 | errors.append(name_error_msg) 237 | 238 | match estimator: 239 | case str(v): 240 | estimator_type = EstimatorType(v) 241 | case Select.BLANK: 242 | errors.append("Estimator cannot be empty!") 243 | 244 | match params_parser(required_params): 245 | case Ok(required): 246 | required_is_valid = True 247 | case Err(required_err_msg): 248 | required_is_valid = False 249 | errors.append(required_err_msg) 250 | 251 | match params_parser(optional_params): 252 | 
case Ok(optional): 253 | optional_is_valid = True 254 | 255 | case Err(optional_err_msg): 256 | optional_is_valid = False 257 | errors.append(optional_err_msg) 258 | 259 | if required_is_valid and optional_is_valid and (msg_duplicated_params := check_duplicates(required, optional)): 260 | errors.append(msg_duplicated_params) 261 | 262 | if errors: 263 | self.notify( 264 | message="\n".join([f"- {e}" for e in errors]), 265 | title="Invalid inputs!", 266 | severity="error", 267 | timeout=5, 268 | ) 269 | 270 | else: 271 | forged_template = render_template( 272 | name=name, 273 | estimator_type=estimator_type, 274 | required=required, 275 | optional=optional, 276 | linear=linear, 277 | sample_weight=sample_weight, 278 | predict_proba=predict_proba, 279 | decision_function=decision_function, 280 | tags=None, 281 | ) 282 | 283 | code_area.text = forged_template 284 | code_editor.collapsed = False 285 | 286 | self.notify( 287 | message="Template forged!", 288 | title="Success!", 289 | severity="information", 290 | timeout=5, 291 | ) 292 | 293 | 294 | class SaveButton(Container): 295 | """forge button component.""" 296 | 297 | def compose(self: Self) -> ComposeResult: 298 | yield Button(label="Save 📂", id="save-btn", variant="primary") 299 | 300 | @on(Button.Pressed, "#save-btn") 301 | def on_save(self: Self, _: Button.Pressed) -> None: 302 | output_file = self.app.query_one("#output-file", Input).value 303 | 304 | if not output_file: 305 | self.notify( 306 | message="Outfile filename cannot be empty!", 307 | title="Invalid filename!", 308 | severity="error", 309 | timeout=5, 310 | ) 311 | else: 312 | destination_file = Path(output_file) 313 | destination_file.parent.mkdir(parents=True, exist_ok=True) 314 | 315 | code = self.app.query_one("#code-area", TextArea).text 316 | 317 | with destination_file.open(mode="w") as destination: 318 | destination.write(code) 319 | 320 | self.notify( 321 | message=f"Saved at {destination_file}", 322 | title="Success!", 323 | 
severity="information", 324 | timeout=5, 325 | ) 326 | 327 | 328 | class DestinationFile(Container): 329 | """Destination file input component.""" 330 | 331 | def compose(self: Self) -> ComposeResult: 332 | yield Input(placeholder=PROMPT_OUTPUT, id="output-file") 333 | 334 | 335 | class ForgeRow(Grid): 336 | """Row grid for forge.""" 337 | 338 | 339 | class OptionGroup(ScrollableContainer): 340 | pass 341 | 342 | 343 | class Sidebar(Container): 344 | def compose(self: Self) -> ComposeResult: 345 | yield OptionGroup(Markdown(SIDEBAR_MSG)) 346 | 347 | def on_markdown_link_clicked(self: Self, event: Markdown.LinkClicked) -> None: 348 | # Relevant discussion: https://github.com/Textualize/textual/discussions/3668 349 | webbrowser.open_new_tab(event.href) 350 | -------------------------------------------------------------------------------- /sksmithy/tui/_tui.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from importlib import metadata, resources 3 | from typing import ClassVar 4 | 5 | from textual.app import App, ComposeResult 6 | from textual.containers import Container, Horizontal, ScrollableContainer 7 | from textual.reactive import reactive 8 | from textual.widgets import Button, Collapsible, Footer, Header, Rule, Static, TextArea 9 | 10 | from sksmithy.tui._components import ( 11 | DecisionFunction, 12 | DestinationFile, 13 | Estimator, 14 | ForgeButton, 15 | ForgeRow, 16 | Linear, 17 | Name, 18 | Optional, 19 | PredictProba, 20 | Required, 21 | SampleWeight, 22 | SaveButton, 23 | Sidebar, 24 | ) 25 | 26 | if sys.version_info >= (3, 11): # pragma: no cover 27 | from typing import Self 28 | else: # pragma: no cover 29 | from typing_extensions import Self 30 | 31 | 32 | class ForgeTUI(App): 33 | """Textual app to forge scikit-learn compatible estimators.""" 34 | 35 | CSS_PATH: ClassVar[str] = str(resources.files("sksmithy") / "_static" / "tui.tcss") 36 | TITLE: ClassVar[str] = "Scikit-learn Smithy ⚒️" # type: 
class ForgeTUI(App):
    """Textual app to forge scikit-learn compatible estimators."""

    CSS_PATH: ClassVar[str] = str(resources.files("sksmithy") / "_static" / "tui.tcss")
    TITLE: ClassVar[str] = "Scikit-learn Smithy ⚒️"  # type: ignore[misc]

    # (key, action, description): each action name maps to the `action_<name>` method below,
    # except "app.quit" which is not defined in this class.
    BINDINGS: ClassVar = [
        ("ctrl+d", "toggle_sidebar", "Description"),
        ("L", "toggle_dark", "Light/Dark mode"),
        ("F", "forge", "Forge"),
        ("ctrl+s", "save", "Save"),
        ("E", "app.quit", "Exit"),
    ]

    show_sidebar = reactive(False)  # noqa: FBT003

    # NOTE: there is intentionally no `on_mount` hook. The previous one only called
    # `self.compose()`; since `compose` is a generator function, that created a generator and
    # discarded it without running any of its body.

    def compose(self: Self) -> ComposeResult:
        """Create child widgets for the app."""
        yield Container(
            Header(icon=f"v{metadata.version('sklearn-smithy')}"),
            ScrollableContainer(
                Horizontal(Name(), Estimator()),
                Horizontal(Required(), Optional()),
                Horizontal(SampleWeight(), Linear()),
                Horizontal(PredictProba(), DecisionFunction()),
                Rule(),
                ForgeRow(
                    Static(),
                    ForgeButton(),
                    SaveButton(),
                    DestinationFile(),
                ),
                Rule(),
                Collapsible(
                    TextArea(
                        text="",
                        language="python",
                        theme="vscode_dark",
                        show_line_numbers=True,
                        tab_behavior="indent",
                        id="code-area",
                    ),
                    title="Code Editor",
                    collapsed=True,
                    id="code-editor",
                ),
            ),
            Sidebar(classes="-hidden"),
            Footer(),
        )

    def action_toggle_dark(self: Self) -> None:  # pragma: no cover
        """Toggle dark mode."""
        self.dark = not self.dark

    def action_toggle_sidebar(self: Self) -> None:  # pragma: no cover
        """Toggle sidebar component."""
        sidebar = self.query_one(Sidebar)
        self.set_focus(None)

        if sidebar.has_class("-hidden"):
            sidebar.remove_class("-hidden")
        else:
            # Drop focus from anything inside the sidebar before hiding it.
            if sidebar.query("*:focus"):
                self.screen.set_focus(None)
            sidebar.add_class("-hidden")

    def action_forge(self: Self) -> None:
        """Press forge button."""
        forge_btn = self.query_one("#forge-btn", Button)
        forge_btn.press()

    def action_save(self: Self) -> None:
        """Press save button."""
        save_btn = self.query_one("#save-btn", Button)
        save_btn.press()
button.""" 113 | save_btn = self.query_one("#save-btn", Button) 114 | save_btn.press() 115 | 116 | 117 | if __name__ == "__main__": # pragma: no cover 118 | tui = ForgeTUI() 119 | tui.run() 120 | -------------------------------------------------------------------------------- /sksmithy/tui/_validators.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import TypeVar 3 | 4 | from result import Err, Ok, Result 5 | from textual.validation import ValidationResult, Validator 6 | 7 | from sksmithy._parsers import name_parser, params_parser 8 | 9 | if sys.version_info >= (3, 11): # pragma: no cover 10 | from typing import Self 11 | else: # pragma: no cover 12 | from typing_extensions import Self 13 | 14 | T = TypeVar("T") 15 | R = TypeVar("R") 16 | 17 | 18 | class _BaseValidator(Validator): 19 | @staticmethod 20 | def parser(value: str) -> Result[str | list[str], str]: # pragma: no cover 21 | raise NotImplementedError 22 | 23 | def validate(self: Self, value: str) -> ValidationResult: 24 | match self.parser(value): 25 | case Ok(_): 26 | return self.success() 27 | case Err(msg): 28 | return self.failure(msg) 29 | 30 | 31 | class NameValidator(_BaseValidator): 32 | @staticmethod 33 | def parser(value: str) -> Result[str, str]: 34 | return name_parser(value) 35 | 36 | 37 | class ParamsValidator(_BaseValidator): 38 | @staticmethod 39 | def parser(value: str) -> Result[list[str], str]: 40 | return params_parser(value) 41 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FBruzzesi/sklearn-smithy/213aefcf64950a72cd51bd3b02b4ccb23484dada/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 
# Parameter sets for the invalid-input fixtures: (raw input, expected error message).
_INVALID_REQUIRED_CASES = [
    ("a,a", "Found repeated parameters!"),
    ("a-a", "The following parameters are invalid python identifiers: ('a-a',)"),
]
_INVALID_OPTIONAL_CASES = [
    ("b,b", "Found repeated parameters!"),
    ("b b", "The following parameters are invalid python identifiers: ('b b',)"),
]


@pytest.fixture(params=["MightyEstimator"])
def name(request: pytest.FixtureRequest) -> str:
    """Valid estimator class name."""
    return request.param


@pytest.fixture(params=list(EstimatorType))
def estimator(request: pytest.FixtureRequest) -> EstimatorType:
    """Every member of `EstimatorType`, one per test run."""
    return request.param


@pytest.fixture(params=[["alpha", "beta"], ["max_iter"], []])
def required(request: pytest.FixtureRequest) -> list[str]:
    """Required parameter lists: several, one, none."""
    return request.param


@pytest.fixture(params=_INVALID_REQUIRED_CASES)
def invalid_required(request: pytest.FixtureRequest) -> tuple[str, str]:
    """Invalid required-parameter input paired with its expected error message."""
    return request.param


@pytest.fixture(params=_INVALID_OPTIONAL_CASES)
def invalid_optional(request: pytest.FixtureRequest) -> tuple[str, str]:
    """Invalid optional-parameter input paired with its expected error message."""
    return request.param


@pytest.fixture(params=[["mu", "sigma"], []])
def optional(request: pytest.FixtureRequest) -> list[str]:
    """Optional parameter lists: several, none."""
    return request.param


@pytest.fixture(params=[True, False])
def sample_weight(request: pytest.FixtureRequest) -> bool:
    """Both values of the `sample_weight` flag."""
    return request.param


@pytest.fixture(params=[True, False])
def linear(request: pytest.FixtureRequest) -> bool:
    """Both values of the `linear` flag."""
    return request.param


@pytest.fixture(params=[True, False])
def predict_proba(request: pytest.FixtureRequest) -> bool:
    """Both values of the `predict_proba` flag."""
    return request.param


@pytest.fixture(params=[True, False])
def decision_function(request: pytest.FixtureRequest) -> bool:
    """Both values of the `decision_function` flag."""
    return request.param


@pytest.fixture(params=[["allow_nan", "binary_only"], [], None])
def tags(request: pytest.FixtureRequest) -> list[str] | None:
    """Tag selections: some tags, an empty list, and no list at all."""
    return request.param


@pytest.fixture()
def app() -> AppTest:
    """Fresh Streamlit `AppTest` bound to the web UI script."""
    return AppTest.from_file("sksmithy/app.py", default_timeout=10)
def test_smoke(app: AppTest) -> None:
    """Basic smoke test."""
    app.run()
    assert not app.exception


@pytest.mark.parametrize(
    ("name_", "err_msg"),
    [
        ("MightyEstimator", ""),
        ("not-valid-name", "`not-valid-name` is not a valid python class name!"),
        ("class", "`class` is a python reserved keyword!"),
    ],
)
def test_name(app: AppTest, name_: str, err_msg: str) -> None:
    """Test `name` text_input component."""
    app.run()
    app.text_input(key="name").input(name_).run()

    if err_msg:
        assert app.error[0].value == err_msg
    else:
        assert not app.error


def test_estimator_interaction(app: AppTest, estimator: EstimatorType) -> None:
    """Test that all toggle components interact correctly with the selected estimator."""
    app.run()
    app.selectbox(key="estimator").select(estimator.value).run()

    assert (not app.toggle(key="linear").disabled) == (
        estimator in {EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin}
    )
    assert (not app.toggle(key="predict_proba").disabled) == (
        estimator in {EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin}
    )
    assert (not app.toggle(key="decision_function").disabled) == (estimator == EstimatorType.ClassifierMixin)

    if estimator == EstimatorType.ClassifierMixin:
        # Turning `linear` on must disable `decision_function`, even for a classifier.
        app.toggle(key="linear").set_value(True).run()

        assert app.toggle(key="decision_function").disabled


@pytest.mark.parametrize(
    ("required_", "optional_", "err_msg"),
    [
        ("a,b", "c,d", ""),
        ("a,a", "", "Found repeated parameters!"),
        ("", "b,b", "Found repeated parameters!"),
        ("a-a", "", "The following parameters are invalid python identifiers: ('a-a',)"),
        ("", "b b", "The following parameters are invalid python identifiers: ('b b',)"),
        ("a,b", "a", "The following parameters are duplicated between required and optional: {'a'}"),
    ],
)
def test_params(
    app: AppTest, name: str, estimator: EstimatorType, required_: str, optional_: str, err_msg: str
) -> None:
    """Test required and optional params interaction."""
    app.run()
    app.text_input(key="name").input(name).run()
    app.selectbox(key="estimator").select(estimator.value).run()

    app.text_input(key="required").input(required_).run()
    app.text_input(key="optional").input(optional_).run()

    if err_msg:
        assert app.error[0].value == err_msg
        # Forge button gets disabled if any error happen
        assert app.button(key="forge_btn").disabled
    else:
        assert not app.error
        assert not app.button(key="forge_btn").disabled


def test_forge(app: AppTest, name: str, estimator: EstimatorType) -> None:
    """Test forge button and all of its interactions.

    Remark that there is no way of testing `popover` or `download_button` components (yet).
    """
    app.run()
    assert app.button(key="forge_btn").disabled
    assert app.session_state["forge_counter"] == 0

    app.text_input(key="name").input(name).run()
    app.selectbox(key="estimator").select(estimator.value).run()
    assert not app.button(key="forge_btn").disabled
    assert not app.code

    app.button(key="forge_btn").click().run()
    assert app.session_state["forge_counter"] == 1
    # `app.code` is an element list, so `is not None` would hold even with nothing rendered;
    # mirror the emptiness check above and require the code block to be present.
    assert app.code
86 | """ 87 | app.run() 88 | assert app.button(key="forge_btn").disabled 89 | assert app.session_state["forge_counter"] == 0 90 | 91 | app.text_input(key="name").input(name).run() 92 | app.selectbox(key="estimator").select(estimator.value).run() 93 | assert not app.button(key="forge_btn").disabled 94 | assert not app.code 95 | 96 | app.button(key="forge_btn").click().run() 97 | assert app.session_state["forge_counter"] == 1 98 | assert app.code is not None 99 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | from typer.testing import CliRunner 5 | 6 | from sksmithy import __version__ 7 | from sksmithy._models import EstimatorType 8 | from sksmithy._prompts import ( 9 | PROMPT_DECISION_FUNCTION, 10 | PROMPT_ESTIMATOR, 11 | PROMPT_LINEAR, 12 | PROMPT_NAME, 13 | PROMPT_OPTIONAL, 14 | PROMPT_OUTPUT, 15 | PROMPT_PREDICT_PROBA, 16 | PROMPT_REQUIRED, 17 | PROMPT_SAMPLE_WEIGHT, 18 | PROMPT_TAGS, 19 | ) 20 | from sksmithy.cli import cli 21 | 22 | runner = CliRunner() 23 | 24 | 25 | def test_version() -> None: 26 | result = runner.invoke(cli, ["version"]) 27 | assert result.exit_code == 0 28 | assert f"sklearn-smithy={__version__}" in result.stdout 29 | 30 | 31 | @pytest.mark.parametrize("linear", ["y", "N"]) 32 | def test_forge_estimator(tmp_path: Path, name: str, estimator: EstimatorType, linear: str) -> None: 33 | """Tests that prompts are correct for classifier estimator.""" 34 | output_file = tmp_path / (f"{name.lower()}.py") 35 | assert not output_file.exists() 36 | 37 | _input = "".join( 38 | [ 39 | f"{name}\n", # name 40 | f"{estimator.value}\n", # estimator_type 41 | "\n", # required params 42 | "\n", # optional params 43 | "\n", # sample weight 44 | f"{linear}\n" if estimator in {EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin} else "", 45 | "\n" if estimator in 
{EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin} else "", # predict_proba 46 | "\n" 47 | if (linear == "N" and estimator == EstimatorType.ClassifierMixin) 48 | else "", # decision_function: prompted only if not linear 49 | "\n", # tags 50 | f"{output_file!s}\n", # output file 51 | ] 52 | ) 53 | 54 | result = runner.invoke( 55 | app=cli, 56 | args=["forge"], 57 | input=_input, 58 | ) 59 | 60 | assert result.exit_code == 0 61 | assert output_file.exists() 62 | 63 | # General prompts 64 | assert all( 65 | _prompt in result.stdout 66 | for _prompt in ( 67 | PROMPT_NAME, 68 | PROMPT_ESTIMATOR, 69 | PROMPT_REQUIRED, 70 | PROMPT_OPTIONAL, 71 | PROMPT_SAMPLE_WEIGHT, 72 | PROMPT_TAGS, 73 | f"{PROMPT_OUTPUT} [{name.lower()}.py]", 74 | ) 75 | ) 76 | 77 | # Estimator type specific prompts 78 | assert (PROMPT_LINEAR in result.stdout) == ( 79 | estimator in {EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin} 80 | ) 81 | assert (PROMPT_PREDICT_PROBA in result.stdout) == ( 82 | estimator in {EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin} 83 | ) 84 | assert (PROMPT_DECISION_FUNCTION in result.stdout) == (linear == "N" and estimator == EstimatorType.ClassifierMixin) 85 | 86 | 87 | @pytest.mark.parametrize( 88 | ("invalid_name", "name_err_msg"), 89 | [ 90 | ("class", "Error: `class` is a python reserved keyword!"), 91 | ("abc-xyz", "Error: `abc-xyz` is not a valid python class name!"), 92 | ], 93 | ) 94 | @pytest.mark.parametrize( 95 | ("invalid_required", "required_err_msg"), 96 | [ 97 | ("a-b", "Error: The following parameters are invalid python identifiers: ('a-b',)"), 98 | ("a,a", "Error: Found repeated parameters!"), 99 | ], 100 | ) 101 | @pytest.mark.parametrize( 102 | ("invalid_optional", "duplicated_err_msg"), 103 | [("a", "Error: The following parameters are duplicated between required and optional: {'a'}")], 104 | ) 105 | @pytest.mark.parametrize( 106 | ("invalid_tags", "tags_err_msg"), 107 | [("not-a-tag,also-not-a-tag", "Error: The 
following tags are not available: ('not-a-tag', 'also-not-a-tag').")], 108 | ) 109 | def test_forge_invalid_args( 110 | tmp_path: Path, 111 | name: str, 112 | invalid_name: str, 113 | name_err_msg: str, 114 | invalid_required: str, 115 | required_err_msg: str, 116 | invalid_optional: str, 117 | duplicated_err_msg: str, 118 | invalid_tags: str, 119 | tags_err_msg: str, 120 | ) -> None: 121 | """Tests that error messages are raised with invalid names.""" 122 | output_file = tmp_path / (f"{name.lower()}.py") 123 | assert not output_file.exists() 124 | 125 | _input = "".join( 126 | [ 127 | f"{invalid_name}\n", # name, invalid attempt 128 | f"{name}\n", # name, valid attempt 129 | "transformer\n" # type 130 | f"{invalid_required}\n", # required params, invalid attempt 131 | "a,b\n", # required params, valid attempt 132 | f"{invalid_optional}\n", # optional params, invalid attempt 133 | "c,d\n", # optional params, valid attempt 134 | "\n", # sample_weight 135 | f"{invalid_tags}\n", # tags, invalid attempt 136 | "binary_only\n", # valid attempt 137 | f"{output_file!s}\n", 138 | ] 139 | ) 140 | 141 | result = runner.invoke( 142 | app=cli, 143 | args=["forge"], 144 | input=_input, 145 | ) 146 | 147 | result = runner.invoke(cli, ["forge"], input=_input) 148 | 149 | assert result.exit_code == 0 150 | assert output_file.exists() 151 | 152 | assert all( 153 | err_msg in result.stdout for err_msg in (name_err_msg, required_err_msg, duplicated_err_msg, tags_err_msg) 154 | ) 155 | -------------------------------------------------------------------------------- /tests/test_parsers.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | import pytest 4 | from result import Err, Ok, is_err, is_ok 5 | 6 | from sksmithy._parsers import check_duplicates, name_parser, params_parser, tags_parser 7 | 8 | 9 | @pytest.mark.parametrize( 10 | ("name", "checker"), 11 | [ 12 | ("valid_name", is_ok), 13 | ("ValidName", 
is_ok), 14 | ("123Invalid", is_err), 15 | ("class", is_err), 16 | ("", is_err), 17 | ], 18 | ) 19 | def test_name_parser(name: str, checker: Callable) -> None: 20 | result = name_parser(name) 21 | assert checker(result) 22 | 23 | 24 | @pytest.mark.parametrize( 25 | ("params", "checker", "expected"), 26 | [ 27 | (None, is_ok, []), 28 | ("a,b,c", is_ok, ["a", "b", "c"]), 29 | ("123a,b c,x", is_err, "The following parameters are invalid python identifiers: ('123a', 'b c')"), 30 | ("a,a,b", is_err, "Found repeated parameters!"), 31 | ], 32 | ) 33 | def test_params_parser(params: str, checker: Callable, expected: str) -> None: 34 | result = params_parser(params) 35 | assert checker(result) 36 | 37 | match result: 38 | case Ok(value): 39 | assert value == expected 40 | case Err(msg): 41 | assert msg == expected 42 | 43 | 44 | @pytest.mark.parametrize( 45 | ("required", "optional", "expected"), 46 | [ 47 | (["a", "b"], ["c", "d"], None), 48 | ([], ["c", "d"], None), 49 | ([], [], None), 50 | (["a", "b"], ["b", "c"], "The following parameters are duplicated between required and optional: {'b'}"), 51 | ], 52 | ) 53 | def test_check_duplicates(required: list[str], optional: list[str], expected: str) -> None: 54 | result = check_duplicates(required, optional) 55 | assert result == expected 56 | 57 | 58 | @pytest.mark.parametrize( 59 | ("tags", "checker", "expected"), 60 | [ 61 | ("allow_nan,binary_only", is_ok, ["allow_nan", "binary_only"]), 62 | ("", is_ok, []), 63 | ("some_madeup_tag", is_err, "The following tags are not available: ('some_madeup_tag',)"), 64 | ], 65 | ) 66 | def test_tags_parser(tags: str, checker: Callable, expected: str) -> None: 67 | result = tags_parser(tags) 68 | assert checker(result) 69 | match result: 70 | case Ok(value): 71 | assert value == expected 72 | case Err(msg): 73 | assert msg.startswith(expected) 74 | -------------------------------------------------------------------------------- /tests/test_render.py: 
-------------------------------------------------------------------------------- 1 | from sksmithy._models import EstimatorType 2 | from sksmithy._utils import render_template 3 | 4 | 5 | def test_params(name: str, required: list[str], optional: list[str]) -> None: 6 | """Tests params (both required and optional) render as expected.""" 7 | result = render_template( 8 | name=name, 9 | estimator_type=EstimatorType.ClassifierMixin, 10 | required=required, 11 | optional=optional, 12 | sample_weight=False, 13 | linear=False, 14 | predict_proba=False, 15 | decision_function=False, 16 | tags=None, 17 | ) 18 | 19 | assert all(f"self.{p} = {p}" in result for p in [*required, *optional]) 20 | assert ("self.n_iter_" in result) == ("max_iter" in [*required, *optional]) 21 | 22 | assert ("_required_parameters = " in result) == bool(required) 23 | # Not able to make a better assert work because of how f-strings render outer and inner strings 24 | # Here is what I tested assert (f'_required_parameters = {[f"{r}" for r in required]}' in result) == bool(required) 25 | # but still renders as "_required_parameters = ['a', 'b']" which is not how it is in the file 26 | 27 | 28 | def test_tags(name: str, tags: list[str] | None) -> None: 29 | """Tests tags render as expected.""" 30 | result = render_template( 31 | name=name, 32 | estimator_type=EstimatorType.ClassifierMixin, 33 | required=[], 34 | optional=[], 35 | sample_weight=False, 36 | linear=False, 37 | predict_proba=False, 38 | decision_function=False, 39 | tags=tags, 40 | ) 41 | 42 | assert ("def _more_tags(self)" in result) == bool(tags) 43 | 44 | if tags: 45 | for tag in tags: 46 | assert f'"{tag}": ...,' in result 47 | 48 | 49 | def test_common_estimator(name: str, estimator: EstimatorType, sample_weight: bool) -> None: 50 | """Tests common features are present for all estimators. 
Includes testing for sample_weight""" 51 | result = render_template( 52 | name=name, 53 | estimator_type=estimator, 54 | required=[], 55 | optional=[], 56 | sample_weight=sample_weight, 57 | linear=False, 58 | predict_proba=False, 59 | decision_function=False, 60 | tags=None, 61 | ) 62 | 63 | assert f"class {name}" in result 64 | assert "self.n_features_in_ = X.shape[1]" in result 65 | assert ("sample_weight = _check_sample_weight(sample_weight)" in result) == sample_weight 66 | 67 | match estimator: 68 | case EstimatorType.TransformerMixin | EstimatorType.SelectorMixin: 69 | assert "X = check_array(X, ...)" in result 70 | assert ("def fit(self, X, y=None, sample_weight=None)" in result) == (sample_weight) 71 | assert ("def fit(self, X, y=None)" in result) == (not sample_weight) 72 | case _: 73 | assert "X, y = check_X_y(X, y, ...)" in result 74 | assert ("def fit(self, X, y, sample_weight=None)" in result) == (sample_weight) 75 | assert ("def fit(self, X, y)" in result) == (not sample_weight) 76 | 77 | 78 | def test_classifier(name: str, linear: bool, predict_proba: bool, decision_function: bool) -> None: 79 | """Tests classifier specific rendering.""" 80 | estimator_type = EstimatorType.ClassifierMixin 81 | 82 | result = render_template( 83 | name=name, 84 | estimator_type=estimator_type, 85 | required=[], 86 | optional=[], 87 | sample_weight=False, 88 | linear=linear, 89 | predict_proba=predict_proba, 90 | decision_function=decision_function, 91 | tags=None, 92 | ) 93 | 94 | # Classifier specific 95 | assert "self.classes_ = " in result 96 | assert "def n_classes_(self)" in result 97 | assert "def transform(self, X)" not in result 98 | 99 | assert "def transform(self, X)" not in result 100 | 101 | # Linear 102 | assert ("class MightyEstimator(LinearClassifierMixin, BaseEstimator)" in result) == linear 103 | assert ("self.coef_ = ..." in result) == linear 104 | assert ("self.intercept_ = ..." 
in result) == linear 105 | 106 | assert ("class MightyEstimator(ClassifierMixin, BaseEstimator)" in result) == (not linear) 107 | assert ("def predict(self, X)" in result) == (not linear) 108 | 109 | # Predict proba 110 | assert ("def predict_proba(self, X)" in result) == predict_proba 111 | 112 | # Decision function 113 | assert ("def decision_function(self, X)" in result) == (decision_function and not linear) 114 | 115 | 116 | def test_regressor(name: str, linear: bool) -> None: 117 | """Tests regressor specific rendering.""" 118 | estimator_type = EstimatorType.RegressorMixin 119 | 120 | result = render_template( 121 | name=name, 122 | estimator_type=estimator_type, 123 | required=[], 124 | optional=[], 125 | sample_weight=False, 126 | linear=linear, 127 | predict_proba=False, 128 | decision_function=False, 129 | tags=None, 130 | ) 131 | 132 | # Regressor specific 133 | assert "def transform(self, X)" not in result 134 | 135 | # Linear 136 | assert ("class MightyEstimator(RegressorMixin, LinearModel)" in result) == linear 137 | assert ("self.coef_ = ..." in result) == linear 138 | assert ("self.intercept_ = ..." 
in result) == linear 139 | 140 | assert ("class MightyEstimator(RegressorMixin, BaseEstimator)" in result) == (not linear) 141 | assert ("def predict(self, X)" in result) == (not linear) 142 | 143 | 144 | def test_outlier(name: str, predict_proba: bool) -> None: 145 | """Tests outlier specific rendering.""" 146 | estimator_type = EstimatorType.OutlierMixin 147 | 148 | result = render_template( 149 | name=name, 150 | estimator_type=estimator_type, 151 | required=[], 152 | optional=[], 153 | sample_weight=False, 154 | linear=False, 155 | predict_proba=predict_proba, 156 | decision_function=False, 157 | tags=None, 158 | ) 159 | 160 | # Outlier specific 161 | assert "class MightyEstimator(OutlierMixin, BaseEstimator)" in result 162 | assert "self.offset_" in result 163 | assert "def score_samples(self, X)" in result 164 | assert "def decision_function(self, X)" in result 165 | assert "def predict(self, X)" in result 166 | 167 | assert "def transform(self, X)" not in result 168 | 169 | # Predict proba 170 | assert ("def predict_proba(self, X)" in result) == predict_proba 171 | 172 | 173 | def test_transformer(name: str) -> None: 174 | """Tests transformer specific rendering.""" 175 | estimator_type = EstimatorType.TransformerMixin 176 | 177 | result = render_template( 178 | name=name, 179 | estimator_type=estimator_type, 180 | required=[], 181 | optional=[], 182 | sample_weight=False, 183 | linear=False, 184 | predict_proba=False, 185 | decision_function=False, 186 | tags=None, 187 | ) 188 | # Transformer specific 189 | assert "class MightyEstimator(TransformerMixin, BaseEstimator)" in result 190 | assert "def transform(self, X)" in result 191 | assert "def predict(self, X)" not in result 192 | 193 | 194 | def test_feature_selector(name: str) -> None: 195 | """Tests transformer specific rendering.""" 196 | estimator_type = EstimatorType.SelectorMixin 197 | 198 | result = render_template( 199 | name=name, 200 | estimator_type=estimator_type, 201 | required=[], 202 | 
optional=[], 203 | sample_weight=False, 204 | linear=False, 205 | predict_proba=False, 206 | decision_function=False, 207 | tags=None, 208 | ) 209 | # Transformer specific 210 | assert "class MightyEstimator(SelectorMixin, BaseEstimator)" in result 211 | assert "def _get_support_mask(self, X)" in result 212 | assert "self.support_" in result 213 | assert "def predict(self, X)" not in result 214 | 215 | 216 | def test_cluster(name: str) -> None: 217 | """Tests cluster specific rendering.""" 218 | estimator_type = EstimatorType.ClusterMixin 219 | 220 | result = render_template( 221 | name=name, 222 | estimator_type=estimator_type, 223 | required=[], 224 | optional=[], 225 | sample_weight=False, 226 | linear=False, 227 | predict_proba=False, 228 | decision_function=False, 229 | tags=None, 230 | ) 231 | 232 | # Cluster specific 233 | assert "class MightyEstimator(ClusterMixin, BaseEstimator)" in result 234 | assert "self.labels_ = ..." in result 235 | assert "def predict(self, X)" in result 236 | -------------------------------------------------------------------------------- /tests/test_tui.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | from textual.widgets import Button, Input, Select, Switch 5 | 6 | from sksmithy._models import EstimatorType 7 | from sksmithy.tui import ForgeTUI 8 | 9 | 10 | async def test_smoke() -> None: 11 | """Basic smoke test.""" 12 | app = ForgeTUI() 13 | async with app.run_test(size=None) as pilot: 14 | await pilot.pause() 15 | assert pilot is not None 16 | 17 | await pilot.pause() 18 | await pilot.exit(0) 19 | 20 | 21 | @pytest.mark.parametrize( 22 | ("name_", "err_msg"), 23 | [ 24 | ("MightyEstimator", ""), 25 | ("not-valid-name", "`not-valid-name` is not a valid python class name!"), 26 | ("class", "`class` is a python reserved keyword!"), 27 | ], 28 | ) 29 | async def test_name(name_: str, err_msg: str) -> None: 30 | """Test `name` text_input 
component.""" 31 | app = ForgeTUI() 32 | async with app.run_test(size=None) as pilot: 33 | name_comp = pilot.app.query_one("#name", Input) 34 | name_comp.value = name_ 35 | await pilot.pause() 36 | 37 | assert (not name_comp.is_valid) == bool(err_msg) 38 | 39 | notifications = list(pilot.app._notifications) # noqa: SLF001 40 | assert len(notifications) == int(bool(err_msg)) 41 | 42 | if notifications: 43 | assert notifications[0].message == err_msg 44 | 45 | 46 | async def test_estimator_interaction(estimator: EstimatorType) -> None: 47 | """Test that all toggle components interact correctly with the selected estimator.""" 48 | app = ForgeTUI() 49 | async with app.run_test(size=None) as pilot: 50 | pilot.app.query_one("#estimator", Select).value = estimator.value 51 | await pilot.pause() 52 | 53 | assert (not pilot.app.query_one("#linear", Switch).disabled) == ( 54 | estimator in {EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin} 55 | ) 56 | assert (not pilot.app.query_one("#predict_proba", Switch).disabled) == ( 57 | estimator in {EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin} 58 | ) 59 | 60 | assert (not pilot.app.query_one("#decision_function", Switch).disabled) == ( 61 | estimator == EstimatorType.ClassifierMixin 62 | ) 63 | 64 | if estimator == EstimatorType.ClassifierMixin: 65 | linear = pilot.app.query_one("#linear", Switch) 66 | linear.value = True 67 | 68 | await pilot.pause() 69 | assert pilot.app.query_one("#decision_function", Switch).disabled 70 | 71 | 72 | async def test_valid_params() -> None: 73 | """Test required and optional params interaction.""" 74 | app = ForgeTUI() 75 | required_ = "a,b" 76 | optional_ = "c,d" 77 | async with app.run_test(size=None) as pilot: 78 | required_comp = pilot.app.query_one("#required", Input) 79 | optional_comp = pilot.app.query_one("#optional", Input) 80 | 81 | required_comp.value = required_ 82 | optional_comp.value = optional_ 83 | 84 | await required_comp.action_submit() 85 | await 
optional_comp.action_submit() 86 | await pilot.pause(0.01) 87 | 88 | notifications = list(pilot.app._notifications) # noqa: SLF001 89 | assert not notifications 90 | 91 | 92 | @pytest.mark.parametrize(("required_", "optional_"), [("a,b", "a"), ("a", "a,b")]) 93 | async def test_duplicated_params(required_: str, optional_: str) -> None: 94 | app = ForgeTUI() 95 | msg = "The following parameters are duplicated between required and optional: {'a'}" 96 | 97 | async with app.run_test(size=None) as pilot: 98 | required_comp = pilot.app.query_one("#required", Input) 99 | optional_comp = pilot.app.query_one("#optional", Input) 100 | 101 | required_comp.value = required_ 102 | optional_comp.value = optional_ 103 | 104 | await required_comp.action_submit() 105 | await optional_comp.action_submit() 106 | await pilot.pause() 107 | 108 | forge_btn = pilot.app.query_one("#forge-btn", Button) 109 | forge_btn.action_press() 110 | await pilot.pause() 111 | 112 | assert all(msg in n.message for n in pilot.app._notifications) # noqa: SLF001 113 | 114 | 115 | async def test_forge_raise() -> None: 116 | """Test forge button and all of its interactions.""" 117 | app = ForgeTUI() 118 | async with app.run_test(size=None) as pilot: 119 | required_comp = pilot.app.query_one("#required", Input) 120 | optional_comp = pilot.app.query_one("#optional", Input) 121 | 122 | required_comp.value = "a,a" 123 | optional_comp.value = "b b" 124 | 125 | await required_comp.action_submit() 126 | await optional_comp.action_submit() 127 | await pilot.pause() 128 | 129 | forge_btn = pilot.app.query_one("#forge-btn", Button) 130 | forge_btn.action_press() 131 | await pilot.pause() 132 | 133 | m1, m2, m3 = (n.message for n in pilot.app._notifications) # noqa: SLF001 134 | 135 | assert "Found repeated parameters!" in m1 136 | assert "The following parameters are invalid python identifiers: ('b b',)" in m2 137 | 138 | assert "Name cannot be empty!" in m3 139 | assert "Estimator cannot be empty!" 
in m3 140 | assert "Found repeated parameters!" in m3 141 | assert "The following parameters are invalid python identifiers: ('b b',)" in m3 142 | 143 | 144 | @pytest.mark.parametrize("use_binding", [True, False]) 145 | async def test_forge_and_save(tmp_path: Path, name: str, estimator: EstimatorType, use_binding: bool) -> None: 146 | """Test forge button and all of its interactions.""" 147 | app = ForgeTUI() 148 | async with app.run_test(size=None) as pilot: 149 | name_comp = pilot.app.query_one("#name", Input) 150 | estimator_comp = pilot.app.query_one("#estimator", Select) 151 | await pilot.pause() 152 | 153 | output_file_comp = pilot.app.query_one("#output-file", Input) 154 | 155 | name_comp.value = name 156 | estimator_comp.value = estimator.value 157 | 158 | await pilot.pause() 159 | 160 | output_file = tmp_path / (f"{name.lower()}.py") 161 | output_file_comp.value = str(output_file) 162 | await output_file_comp.action_submit() 163 | await pilot.pause() 164 | 165 | if use_binding: 166 | await pilot.press("F") 167 | else: 168 | forge_btn = pilot.app.query_one("#forge-btn", Button) 169 | forge_btn.action_press() 170 | await pilot.pause() 171 | 172 | if use_binding: 173 | await pilot.press("ctrl+s") 174 | else: 175 | save_btn = pilot.app.query_one("#save-btn", Button) 176 | save_btn.action_press() 177 | await pilot.pause() 178 | 179 | m1, m2 = (n.message for n in pilot.app._notifications) # noqa: SLF001 180 | 181 | assert "Template forged!" in m1 182 | assert "Saved at" in m2 183 | 184 | assert output_file.exists() 185 | --------------------------------------------------------------------------------