7 | * @version 0.0.1
8 | * @license MIT
9 | */
10 |
11 | 'use strict';
12 |
/** Generate a terminal widget. */
class Termynal {
    /**
     * Construct the widget's settings.
     *
     * Each setting is resolved in order: the explicit option, then the matching
     * `data-<prefix>-*` attribute on the container, then a built-in default.
     * Because the chain uses `||`, falsy values such as `0` fall through to the next source.
     *
     * @param {(string|Node)=} container - Query selector or container element.
     * @param {Object=} options - Custom settings.
     * @param {string} options.prefix - Prefix to use for data attributes.
     * @param {number} options.startDelay - Delay before animation, in ms.
     * @param {number} options.typeDelay - Delay between each typed character, in ms.
     * @param {number} options.lineDelay - Delay between each line, in ms.
     * @param {number} options.progressLength - Number of characters displayed as progress bar.
     * @param {string} options.progressChar - Character to use for progress bar, defaults to █.
     * @param {number} options.progressPercent - Max percent of progress.
     * @param {string} options.cursor - Character to use for cursor, defaults to ▋.
     * @param {Object[]} options.lineData - Dynamically loaded line data objects.
     * @param {boolean} options.noInit - Don't initialise the animation.
     * @throws {TypeError} If no container element can be resolved.
     */
    constructor(container = '#termynal', options = {}) {
        this.container = (typeof container === 'string') ? document.querySelector(container) : container;
        if (!this.container) {
            // Fail with an explicit message rather than a generic null-property access below.
            throw new TypeError(`Termynal: container not found (${container})`);
        }
        this.pfx = `data-${options.prefix || 'ty'}`;
        // The `original*` copies let start() restore the speed after "fast" mode zeroes the delays.
        this.originalStartDelay = this.startDelay = options.startDelay
            || parseFloat(this.container.getAttribute(`${this.pfx}-startDelay`)) || 600;
        this.originalTypeDelay = this.typeDelay = options.typeDelay
            || parseFloat(this.container.getAttribute(`${this.pfx}-typeDelay`)) || 90;
        this.originalLineDelay = this.lineDelay = options.lineDelay
            || parseFloat(this.container.getAttribute(`${this.pfx}-lineDelay`)) || 1500;
        this.progressLength = options.progressLength
            || parseFloat(this.container.getAttribute(`${this.pfx}-progressLength`)) || 40;
        this.progressChar = options.progressChar
            || this.container.getAttribute(`${this.pfx}-progressChar`) || '█';
        this.progressPercent = options.progressPercent
            || parseFloat(this.container.getAttribute(`${this.pfx}-progressPercent`)) || 100;
        this.cursor = options.cursor
            || this.container.getAttribute(`${this.pfx}-cursor`) || '▋';
        this.lineData = this.lineDataToElements(options.lineData || []);
        this.loadLines();
        if (!options.noInit) this.init();
    }

    /**
     * Lay out every line (plus the control links) hidden inside the container so its
     * final size is known up front. Otherwise the container would keep growing during
     * the animation and the user's viewport would jump while they scroll.
     */
    loadLines() {
        const finish = this.generateFinish();
        finish.style.visibility = 'hidden';
        this.container.appendChild(finish);
        // Lines already present in the markup come first, then the ones built from `options.lineData`.
        this.lines = [...this.container.querySelectorAll(`[${this.pfx}]`)].concat(this.lineData);
        for (const line of this.lines) {
            line.style.visibility = 'hidden';
            this.container.appendChild(line);
        }
        const restart = this.generateRestart();
        restart.style.visibility = 'hidden';
        this.container.appendChild(restart);
        this.container.setAttribute('data-termynal', '');
    }

    /**
     * Initialise the widget: pin the container size, clear it and start the animation.
     */
    init() {
        // Freeze the size measured while all lines were laid out (see loadLines). A '0px'
        // measurement means the container is empty/unstyled: assigning '' removes the inline
        // value so CSS or the browser's `auto` sizing applies instead.
        const containerStyle = getComputedStyle(this.container);
        this.container.style.width = containerStyle.width !== '0px' ? containerStyle.width : '';
        this.container.style.minHeight = containerStyle.height !== '0px' ? containerStyle.height : '';

        this.container.setAttribute('data-termynal', '');
        this.container.innerHTML = '';
        for (const line of this.lines) {
            line.style.visibility = 'visible';
        }
        // Fire-and-forget: the animation runs asynchronously.
        this.start();
    }

    /**
     * Start the animation and render the lines depending on their data attributes:
     * `input` lines are typed character by character behind a cursor, `progress` lines
     * animate a progress bar, and any other line is appended at once.
     */
    async start() {
        this.addFinish();
        await this._wait(this.startDelay);

        for (const line of this.lines) {
            const type = line.getAttribute(this.pfx);
            const delay = line.getAttribute(`${this.pfx}-delay`) || this.lineDelay;

            if (type === 'input') {
                line.setAttribute(`${this.pfx}-cursor`, this.cursor);
                await this.type(line);
            } else if (type === 'progress') {
                await this.progress(line);
            } else {
                this.container.appendChild(line);
            }
            await this._wait(delay);

            line.removeAttribute(`${this.pfx}-cursor`);
        }
        this.addRestart();
        this.finishElement.style.visibility = 'hidden';
        // Undo any "fast" mode so a restart animates at the configured speed again.
        this.lineDelay = this.originalLineDelay;
        this.typeDelay = this.originalTypeDelay;
        this.startDelay = this.originalStartDelay;
    }

    /**
     * Build the "restart" control link; clicking it clears the container and replays the animation.
     * @returns {HTMLAnchorElement} The control link (not yet attached).
     */
    generateRestart() {
        const restart = document.createElement('a');
        restart.onclick = (e) => {
            e.preventDefault();
            this.container.innerHTML = '';
            this.init();
        };
        restart.href = '#';
        restart.setAttribute('data-terminal-control', '');
        restart.innerHTML = 'restart ↻';
        return restart;
    }

    /**
     * Build the "fast" control link and remember it as `this.finishElement`.
     * Clicking it zeroes the widget-level delays so the remaining animation completes quickly;
     * per-line `data-*-delay` / `data-*-typeDelay` attributes still take precedence.
     * @returns {HTMLAnchorElement} The control link (not yet attached).
     */
    generateFinish() {
        const finish = document.createElement('a');
        finish.onclick = (e) => {
            e.preventDefault();
            this.lineDelay = 0;
            this.typeDelay = 0;
            this.startDelay = 0;
        };
        finish.href = '#';
        finish.setAttribute('data-terminal-control', '');
        finish.innerHTML = 'fast →';
        this.finishElement = finish;
        return finish;
    }

    /** Append a fresh "restart" control to the container. */
    addRestart() {
        const restart = this.generateRestart();
        this.container.appendChild(restart);
    }

    /** Append a fresh "fast" control to the container. */
    addFinish() {
        const finish = this.generateFinish();
        this.container.appendChild(finish);
    }

    /**
     * Animate a typed line.
     * @param {Node} line - The line element to render.
     */
    async type(line) {
        // Spread splits by code point, so astral characters (e.g. emoji) are typed whole.
        const chars = [...line.textContent];
        line.textContent = '';
        this.container.appendChild(line);

        for (const char of chars) {
            // Re-read every character so a click on "fast" takes effect mid-line.
            const delay = line.getAttribute(`${this.pfx}-typeDelay`) || this.typeDelay;
            await this._wait(delay);
            line.textContent += char;
        }
    }

    /**
     * Animate a progress bar, one `progressChar` per step, stopping at `progressPercent`.
     * @param {Node} line - The line element to render.
     */
    async progress(line) {
        const steps = Math.floor(Number(line.getAttribute(`${this.pfx}-progressLength`)
            || this.progressLength));
        const progressChar = line.getAttribute(`${this.pfx}-progressChar`)
            || this.progressChar;
        const maxPercent = Number(line.getAttribute(`${this.pfx}-progressPercent`)
            || this.progressPercent);
        line.textContent = '';
        this.container.appendChild(line);

        // Count whole `progressChar` repetitions rather than UTF-16 units, so multi-unit
        // characters are never split, and stop *before* rendering a step beyond the maximum.
        for (let i = 1; i <= steps; i++) {
            const percent = Math.round((i / steps) * 100);
            if (percent > maxPercent) {
                break;
            }
            await this._wait(this.typeDelay);
            line.textContent = `${progressChar.repeat(i)} ${percent}%`;
        }
    }

    /**
     * Helper function for animation delays, called with `await`.
     * @param {number} time - Timeout, in ms.
     * @returns {Promise<void>} Resolves after `time` ms.
     */
    _wait(time) {
        return new Promise(resolve => setTimeout(resolve, time));
    }

    /**
     * Converts line data objects into line elements.
     * Each object becomes a `<span>` whose attributes come from `_attributes` and whose
     * content is `line.value` (inserted as HTML).
     *
     * @param {Object[]} lineData - Dynamically loaded lines.
     * @returns {Element[]} - Array of line elements.
     */
    lineDataToElements(lineData) {
        return lineData.map((line) => {
            const div = document.createElement('div');
            div.innerHTML = `<span ${this._attributes(line)}>${line.value || ''}</span>`;
            return div.firstElementChild;
        });
    }

    /**
     * Helper function for generating attributes string.
     * `class` is emitted as a regular class attribute, `type` as the bare prefix attribute,
     * `value` is skipped (it is the content), and every other key becomes `<prefix>-<key>`.
     *
     * @param {Object} line - Line data object.
     * @returns {string} - String of attributes.
     */
    _attributes(line) {
        // Escape characters that would terminate or corrupt a double-quoted attribute value.
        const escapeAttr = (value) => String(value).replace(/&/g, '&amp;').replace(/"/g, '&quot;');
        let attrs = '';
        for (const [prop, value] of Object.entries(line)) {
            if (prop === 'class') {
                attrs += ` class="${escapeAttr(value)}" `;
            } else if (prop === 'type') {
                attrs += `${this.pfx}="${escapeAttr(value)}" `;
            } else if (prop !== 'value') {
                attrs += `${this.pfx}-${prop}="${escapeAttr(value)}" `;
            }
        }
        return attrs;
    }
}
256 |
/**
 * HTML API: if the current script tag declares `data-termynal-container`, create one
 * Termynal per `|`-separated selector listed in it.
 * `document.currentScript` can be null (e.g. when loaded as an ES module), so it is guarded,
 * and empty selectors (e.g. from a trailing `|`) are ignored instead of making querySelector throw.
 * The surrounding block keeps `script` out of the page's global scope.
 */
{
    const script = document.currentScript;
    if (script && script.hasAttribute('data-termynal-container')) {
        script.getAttribute('data-termynal-container')
            .split('|')
            .filter((selector) => selector !== '')
            .forEach((selector) => new Termynal(selector));
    }
}
265 |
--------------------------------------------------------------------------------
/docs/user-guide.md:
--------------------------------------------------------------------------------
1 | # User Guide 📚
2 |
3 | As introduced in the [home page](index.md), **sklearn-smithy** is a tool that helps you forge scikit-learn compatible estimators with ease, and it comes in three flavours.
4 |
5 | Let's see how to use each one of them.
6 |
7 | ## Web UI 🌐
8 |
9 | TL;DR:
10 |
11 | - [x] Available at [sklearn-smithy.streamlit.app](https://sklearn-smithy.streamlit.app/){:target="_blank"}
12 | - [x] It requires no installation.
13 | - [x] Powered by [streamlit](https://streamlit.io/){:target="_blank"}
14 |
15 | The web UI is the most user-friendly, low-barrier way to interact with the tool: you access it directly from your browser, without any installation required.
16 |
17 | Once the estimator is forged, you can download the script with the code as a `.py` file, or you can copy the code directly from the browser.
18 |
19 | ??? example "Screenshot"
20 | 
21 |
22 | ## CLI ⌨️
23 |
24 | TL;DR:
25 |
26 | - [x] Available via the `smith forge` command.
27 | - [x] It requires [installation](installation.md): `python -m pip install sklearn-smithy`
28 | - [x] Powered by [typer](https://typer.tiangolo.com/){:target="_blank"}.
29 |
30 | Once the library is installed, the `smith` CLI (Command Line Interface) will be available and that is the primary way to interact with the `smithy` package.
31 |
32 | The CLI provides a main command called `forge`, which prompts a series of questions in the terminal and, based on the answers, generates the code for the estimator.
33 |
34 | ### `smith forge` example
35 |
36 | Let's see an example of how to use `smith forge` command:
37 |
38 |
39 |
40 | ```console
41 | $ smith forge
42 | # 🐍 How would you like to name the estimator?:$ MightyClassifier
43 | # 🎯 Which kind of estimator is it? (classifier, outlier, regressor, transformer, cluster, feature-selector):$ classifier
44 | # 📜 Please list the required parameters (comma-separated) []:$ alpha,beta
45 | # 📑 Please list the optional parameters (comma-separated) []:$ mu,sigma
46 | # 📶 Does the `.fit()` method support `sample_weight`? [y/N]:$ y
47 | # 📏 Is the estimator linear? [y/N]:$ N
48 | # 🎲 Should the estimator implement a `predict_proba` method? [y/N]:$ N
49 | # ❓ Should the estimator implement a `decision_function` method? [y/N]:$ y
50 | # 🧪 We are almost there... Is there any tag you want to add? (comma-separated) []:$ binary_only,non_deterministic
51 | # 📂 Where would you like to save the class? [mightyclassifier.py]:$ path/to/file.py
52 | Template forged at path/to/file.py
53 | ```
54 |
55 |
56 |
57 | Now the estimator template to be filled will be available at the specified path `path/to/file.py`.
58 |
59 |
60 |
61 | ```console
62 | $ cat path/to/file.py | head -n 5
63 | import numpy as np
64 |
65 | from sklearn.base import BaseEstimator, ClassifierMixin
66 | from sklearn.utils import check_X_y
67 | from sklearn.utils.validation import check_is_fitted, check_array
68 | ```
69 |
70 |
71 |
72 | ### Non-interactive mode
73 |
74 | As with any CLI, it is in principle possible to run it in a non-interactive way; however, this is not *fully* supported (yet) and it comes with some risks and limitations.
75 |
76 | The reason for this is that the **validation** and the parameters **interaction** happen while prompting the questions *one after the other*, meaning that the input to one prompt will determine what follows next.
77 |
78 | It is still possible to run the CLI in a non-interactive way, but it is not recommended, as it may lead to unexpected results.
79 |
80 | Let's see an example of how to run the `smith forge` command in a non-interactive way:
81 |
82 | !!! example "Non-interactive mode"
83 |
84 | ```terminal
85 | smith forge \
86 | --name MyEstimator \
87 | --estimator-type classifier \
88 | --required-params "a,b" \
89 | --optional-params "" \
90 | --no-sample-weight \
91 | --no-predict-proba \
92 | --linear \
93 | --no-decision-function \
94 | --tags "binary_only" \
95 | --output-file path/to/file.py
96 | ```
97 |
98 | Notice that all arguments must be specified: any missing one will still be prompted for, which makes the command interactive again.
99 |
100 | Secondly, nothing prevents us from running the command with contradictory arguments. Doing so can lead to two scenarios:
101 |
102 | 1. The result will be correct, however unexpected from a user point of view.
103 | For instance, calling `--estimator-type classifier` with the `--linear` and `--decision-function` flags will not create a `decision_function` method, as `LinearClassifierMixin` already takes care of it.
104 | 2. The result will be incorrect, as the arguments are contradictory.
105 |
106 | The first case is not problematic from a functional point of view, while the second will lead to a broken estimator.
107 |
108 | Our suggestion is to always use the CLI interactively, as it takes care of the proper interaction between arguments.
109 |
110 | ## TUI 💻
111 |
112 | TL;DR:
113 |
114 | - [x] Available via the `smith forge-tui` command.
115 | - [x] It requires installing [extra dependencies](installation.md#extra-dependencies): `python -m pip install "sklearn-smithy[textual]"`
116 | - [x] Powered by [textual](https://textual.textualize.io/){:target="_blank"}.
117 |
118 | If you like the CLI, but prefer a more interactive and graphical way from the comfort of your terminal, you can use the TUI (Terminal User Interface) provided by the `smith forge-tui` command.
119 |
120 | ```console
121 | $ smith forge-tui
122 | ```
123 |
124 | ```{.textual path="sksmithy/tui/_tui.py" columns="200" lines="35"}
125 | ```
126 |
--------------------------------------------------------------------------------
/docs/why.md:
--------------------------------------------------------------------------------
1 | # Why❓
2 |
3 | Writing scikit-learn compatible estimators might be harder than expected.
4 |
5 | While everyone knows about the `fit` and `predict` methods, there are other behaviours, methods and attributes that
6 | scikit-learn might be expecting from your estimator depending on:
7 |
8 | - The type of estimator you're writing.
9 | - The signature of the estimator.
10 | - The signature of the `.fit(...)` method.
11 |
12 | Scikit-learn Smithy to the rescue: this tool aims to help you craft your own estimator by asking a few
13 | questions about it, and then generating the boilerplate code.
14 |
15 | In this way you will be able to fully focus on the core implementation logic, and not on nitty-gritty details
16 | of the scikit-learn API.
17 |
18 | ## Sanity check
19 |
20 | Once the core logic is implemented, the estimator should be ready to test against the _somewhat official_
21 | [`parametrize_with_checks`](https://scikit-learn.org/dev/modules/generated/sklearn.utils.estimator_checks.parametrize_with_checks.html#sklearn.utils.estimator_checks.parametrize_with_checks){:target="_blank"}
22 | pytest compatible decorator:
23 |
24 | ```py
25 | from sklearn.utils.estimator_checks import parametrize_with_checks
26 |
27 | @parametrize_with_checks([
28 | YourAwesomeRegressor,
29 | MoreAwesomeClassifier,
30 | EvenMoreAwesomeTransformer,
31 | ])
32 | def test_sklearn_compatible_estimator(estimator, check):
33 | check(estimator)
34 | ```
35 |
36 | and it should be compatible with scikit-learn Pipeline, GridSearchCV, etc.
37 |
38 | ## Official guide
39 |
40 | Scikit-learn documentation on how to
41 | [develop estimators](https://scikit-learn.org/dev/developers/develop.html#developing-scikit-learn-estimators){:target="_blank"}.
42 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | # Project information
2 | site_name: Sklearn Smithy
3 | site_url: https://fbruzzesi.github.io/sklearn-smithy/
4 | site_author: Francesco Bruzzesi
5 | site_description: Toolkit to forge scikit-learn compatible estimators
6 |
7 | # Repository information
8 | repo_name: FBruzzesi/sklearn-smithy
9 | repo_url: https://github.com/fbruzzesi/sklearn-smithy
10 | edit_uri: edit/main/docs/
11 |
12 | # Configuration
13 | use_directory_urls: true
14 | theme:
15 | name: material
16 | font: false
17 | palette:
18 | - media: '(prefers-color-scheme: light)'
19 | scheme: default
20 | primary: teal
21 | accent: deep-orange
22 | toggle:
23 | icon: material/brightness-7
24 | name: Switch to dark mode
25 | - media: '(prefers-color-scheme: dark)'
26 | scheme: slate
27 | primary: teal
28 | accent: deep-orange
29 | toggle:
30 | icon: material/brightness-4
31 | name: Switch to light mode
32 | features:
33 | - search.suggest
34 | - search.highlight
35 | - search.share
36 | - toc.follow
37 | - content.tabs.link
38 | - content.code.annotate
39 | - content.code.copy
40 |
41 | logo: img/sksmith-logo.png
42 | favicon: img/sksmith-logo.png
43 |
44 | # Plugins
45 | plugins:
46 | - search:
47 | separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])'
48 |
49 | # Customization
50 | extra:
51 | social:
52 | - icon: fontawesome/brands/github
53 | link: https://github.com/fbruzzesi
54 | - icon: fontawesome/brands/linkedin
55 | link: https://www.linkedin.com/in/francesco-bruzzesi/
56 | - icon: fontawesome/brands/python
57 | link: https://pypi.org/project/sklearn-smithy/
58 |
59 | # Extensions
60 | markdown_extensions:
61 | - abbr
62 | - admonition
63 | - attr_list
64 | - codehilite
65 | - def_list
66 | - footnotes
67 | - md_in_html
68 | - toc:
69 | permalink: true
70 | - pymdownx.inlinehilite
71 | - pymdownx.snippets
72 | - pymdownx.superfences:
73 | custom_fences:
74 | - name: textual
75 | class: textual
76 | format: !!python/name:textual._doc.format_svg
77 | - pymdownx.details
78 | - pymdownx.tasklist:
79 | custom_checkbox: true
80 | - pymdownx.tabbed:
81 | alternate_style: true
82 | - pymdownx.highlight:
83 | anchor_linenums: true
84 | line_spans: __span
85 | pygments_lang_class: true
86 |
87 | nav:
88 | - Home 🏠: index.md
89 | - Installation ✨: installation.md
90 | - Why ❓: why.md
91 | - User Guide 📚: user-guide.md
92 | - Contributing 👏: contribute.md
93 |
94 | extra_css:
95 | - css/termynal.css
96 | - css/custom.css
97 | extra_javascript:
98 | - js/termynal.js
99 | - js/custom.js
100 |
--------------------------------------------------------------------------------
/noxfile.py:
--------------------------------------------------------------------------------
1 | import nox
2 | from nox.sessions import Session
3 |
4 | nox.options.default_venv_backend = "uv"
5 | nox.options.reuse_venv = True
6 |
7 | PYTHON_VERSIONS = ["3.10", "3.11", "3.12"]
8 |
9 |
@nox.session(python=PYTHON_VERSIONS)  # type: ignore[misc]
@nox.parametrize("pre", [False, True])
def pytest_coverage(session: Session, pre: bool) -> None:
    """Run pytest coverage across different python versions.

    Parameters
    ----------
    session
        The nox session (one per python version in `PYTHON_VERSIONS`).
    pre
        When True, `--pre` is passed to the installer so pre-release versions may be picked up.
    """
    # Package with every extra, plus the test requirements; optionally allow pre-releases.
    install_args = [".[all]", "-r", "requirements/test.txt", *(["--pre"] if pre else [])]
    session.install(*install_args)

    pytest_args = ("tests", "--cov=sksmithy", "--cov=tests", "--cov-fail-under=90", "--numprocesses=auto")
    session.run("pytest", *pytest_args)
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "sklearn-smithy"
7 | version = "0.2.0"
8 | description = "Toolkit to forge scikit-learn compatible estimators."
9 | requires-python = ">=3.10"
10 |
11 | license = {file = "LICENSE"}
12 | readme = "README.md"
13 |
14 | authors = [
15 | {name = "Francesco Bruzzesi"}
16 | ]
17 |
18 | keywords = [
19 | "python",
20 | "cli",
21 | "webui",
22 | "tui",
23 | "data-science",
24 | "machine-learning",
25 | "scikit-learn"
26 | ]
27 |
28 | dependencies = [
29 | "typer>=0.12.0",
30 | "rich>=13.0.0",
31 | "jinja2>=3.0.0",
32 | "result>=0.16.0",
33 | "ruff>=0.4.0",
34 | ]
35 |
36 | classifiers = [
37 | "Development Status :: 4 - Beta",
38 | "License :: OSI Approved :: MIT License",
39 | "Topic :: Software Development :: Libraries :: Python Modules",
40 | "Typing :: Typed",
41 | "Programming Language :: Python :: 3",
42 | "Programming Language :: Python :: 3.10",
43 | "Programming Language :: Python :: 3.11",
44 | "Programming Language :: Python :: 3.12",
45 | ]
46 |
47 | [project.urls]
48 | Repository = "https://github.com/FBruzzesi/sklearn-smithy"
49 | Issues = "https://github.com/FBruzzesi/sklearn-smithy/issues"
50 | Documentation = "https://fbruzzesi.github.io/sklearn-smithy"
51 | Website = "https://sklearn-smithy.streamlit.app/"
52 |
53 |
54 | [project.optional-dependencies]
55 | streamlit = ["streamlit>=1.34.0"]
56 | textual = ["textual[syntax]>=0.65.0"]
57 |
58 | all = [
59 | "streamlit>=1.34.0",
60 | "textual>=0.65.0",
61 | ]
62 |
63 | [project.scripts]
64 | smith = "sksmithy.__main__:cli"
65 |
66 | [tool.hatch.build.targets.sdist]
67 | only-include = ["sksmithy"]
68 |
69 | [tool.hatch.build.targets.wheel]
70 | packages = ["sksmithy"]
71 |
72 | [tool.ruff]
73 | line-length = 120
74 | target-version = "py310"
75 |
76 | [tool.ruff.lint]
77 | select = ["ALL"]
78 | ignore = [
79 | "COM812",
80 | "ISC001",
81 | "PLR0913",
82 | "FBT001",
83 | "FBT002",
84 | "S603",
85 | "S607",
86 | "D100",
87 | "D104",
88 | "D400",
89 | ]
90 |
91 | [tool.ruff.lint.per-file-ignores]
92 | "tests/*" = ["D103","S101"]
93 |
94 | [tool.ruff.lint.pydocstyle]
95 | convention = "numpy"
96 |
97 | [tool.ruff.lint.pyupgrade]
98 | keep-runtime-typing = true
99 |
100 | [tool.ruff.format]
101 | docstring-code-format = true
102 |
103 | [tool.mypy]
104 | ignore_missing_imports = true
105 | python_version = "3.10"
106 |
107 | [tool.coverage.run]
108 | source = ["sksmithy/"]
109 | omit = [
110 | "sksmithy/__main__.py",
111 | "sksmithy/_arguments.py",
112 | "sksmithy/_logger.py",
113 | "sksmithy/_prompts.py",
114 | "sksmithy/tui/__init__.py",
115 | ]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Used by streamlit deployment
2 | -e ."[streamlit]"
--------------------------------------------------------------------------------
/requirements/test.txt:
--------------------------------------------------------------------------------
1 | anyio
2 | pytest
3 | pytest-asyncio
4 | pytest-cov
5 | pytest-tornasync
6 | pytest-trio
7 | pytest-xdist
--------------------------------------------------------------------------------
/sksmithy/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib import metadata
2 |
__title__ = "sksmithy"
# Read from the installed "sklearn-smithy" distribution's metadata (the `version` declared in pyproject.toml),
# so the package must be installed (regular or editable) for this lookup to succeed.
__version__ = metadata.version("sklearn-smithy")
5 |
--------------------------------------------------------------------------------
/sksmithy/__main__.py:
--------------------------------------------------------------------------------
1 | from sksmithy.cli import cli
2 |
# Entry point for `python -m sksmithy`: delegate to the same typer app exposed as the `smith` script.
if __name__ == "__main__":
    cli()
5 |
--------------------------------------------------------------------------------
/sksmithy/_arguments.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated
2 |
3 | from typer import Option
4 |
5 | from sksmithy._callbacks import estimator_callback, linear_callback, name_callback, params_callback, tags_callback
6 | from sksmithy._models import EstimatorType
7 | from sksmithy._prompts import (
8 | PROMPT_DECISION_FUNCTION,
9 | PROMPT_ESTIMATOR,
10 | PROMPT_LINEAR,
11 | PROMPT_NAME,
12 | PROMPT_OPTIONAL,
13 | PROMPT_OUTPUT,
14 | PROMPT_PREDICT_PROBA,
15 | PROMPT_REQUIRED,
16 | PROMPT_SAMPLE_WEIGHT,
17 | PROMPT_TAGS,
18 | )
19 |
# Reusable `Annotated` aliases for the `smith forge` options. Each alias pairs the parameter type with a typer
# `Option`: `help` is rich markup shown by `--help`, `prompt` is the question asked interactively, and `callback`
# (where present) parses/validates the answer and may adjust later prompts (see `sksmithy._callbacks`).

# --- Free-text options ---------------------------------------------------------------------------------------------

name_arg = Annotated[
    str,
    Option(
        help="[bold green]Name[/bold green] of the estimator",
        prompt=PROMPT_NAME,
        callback=name_callback,
    ),
]

estimator_type_arg = Annotated[
    EstimatorType,
    Option(
        help="[bold green]Estimator type[/bold green]",
        prompt=PROMPT_ESTIMATOR,
        callback=estimator_callback,
    ),
]

required_params_arg = Annotated[
    str,
    Option(
        help="List of [italic yellow](comma-separated)[/italic yellow] [bold green]required[/bold green] parameters",
        prompt=PROMPT_REQUIRED,
        callback=params_callback,
    ),
]

optional_params_arg = Annotated[
    str,
    Option(
        help="List of [italic yellow](comma-separated)[/italic yellow] [bold green]optional[/bold green] parameters",
        prompt=PROMPT_OPTIONAL,
        callback=params_callback,
    ),
]

# --- Boolean flags -------------------------------------------------------------------------------------------------

sample_weight_arg = Annotated[
    bool,
    Option(
        help="Whether or not `.fit()` supports [bold green]`sample_weight`[/bold green]",
        prompt=PROMPT_SAMPLE_WEIGHT,
        is_flag=True,
    ),
]

linear_arg = Annotated[
    bool,
    Option(
        help="Whether or not the estimator is [bold green]linear[/bold green]",
        prompt=PROMPT_LINEAR,
        is_flag=True,
        callback=linear_callback,
    ),
]

predict_proba_arg = Annotated[
    bool,
    Option(
        help="Whether or not the estimator implements [bold green]`predict_proba`[/bold green] method",
        prompt=PROMPT_PREDICT_PROBA,
        is_flag=True,
    ),
]

decision_function_arg = Annotated[
    bool,
    Option(
        help="Whether or not the estimator implements [bold green]`decision_function`[/bold green] method",
        prompt=PROMPT_DECISION_FUNCTION,
        is_flag=True,
    ),
]

# --- Remaining free-text options -----------------------------------------------------------------------------------

tags_arg = Annotated[
    str,
    Option(
        help="List of optional extra scikit-learn [bold green]tags[/bold green]",
        prompt=PROMPT_TAGS,
        callback=tags_callback,
    ),
]

output_file_arg = Annotated[
    str,
    Option(
        help="[bold green]Destination file[/bold green] where to save the boilerplate code",
        prompt=PROMPT_OUTPUT,
    ),
]
109 |
--------------------------------------------------------------------------------
/sksmithy/_callbacks.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 | from typing import Concatenate, ParamSpec, TypeVar
3 |
4 | from result import Err, Ok, Result
5 | from typer import BadParameter, CallbackParam, Context
6 |
7 | from sksmithy._models import EstimatorType
8 | from sksmithy._parsers import check_duplicates, name_parser, params_parser, tags_parser
9 |
# Type variables for `_parse_wrapper`: T is the raw value handed to the parser, R the parsed value carried by
# `Ok(...)`, and PS captures any extra positional/keyword arguments forwarded to the parser.
T = TypeVar("T")
R = TypeVar("R")
PS = ParamSpec("PS")
13 |
14 |
15 | def _parse_wrapper(
16 | ctx: Context,
17 | param: CallbackParam,
18 | value: T,
19 | parser: Callable[Concatenate[T, PS], Result[R, str]],
20 | *args: PS.args,
21 | **kwargs: PS.kwargs,
22 | ) -> tuple[Context, CallbackParam, R]:
23 | """Wrap a parser to handle 'caching' logic.
24 |
25 | `parser` should return a Result[R, str]
26 |
27 | Parameters
28 | ----------
29 | ctx
30 | Typer context.
31 | param
32 | Callback parameter information.
33 | value
34 | Input for the parser callable.
35 | parser
36 | Parser function, it should return Result[R, str]
37 | *args
38 | Extra args for `parser`.
39 | **kwargs
40 | Extra kwargs for `parser`.
41 |
42 | Returns
43 | -------
44 | ctx : Context
45 | Typer context updated with extra information.
46 | param : CallbackParam
47 | Unchanged callback parameters.
48 | result_value : R
49 | Parsed value.
50 |
51 | Raises
52 | ------
53 | BadParameter
54 | If parser returns Err(msg)
55 | """
56 | if not ctx.obj:
57 | ctx.obj = {}
58 |
59 | if param.name in ctx.obj:
60 | return ctx, param, ctx.obj[param.name]
61 |
62 | result = parser(value, *args, **kwargs)
63 | match result:
64 | case Ok(result_value):
65 | ctx.obj[param.name] = result_value
66 | return ctx, param, result_value
67 | case Err(msg):
68 | raise BadParameter(msg)
69 |
70 |
def name_callback(ctx: Context, param: CallbackParam, value: str) -> str:
    """`name` argument callback.

    Parses the estimator name, then sets the default of the `output_file` option to `{name.lower()}.py`, so the
    later output-file prompt proposes a file named after the estimator.
    """
    _, _, name = _parse_wrapper(ctx, param, value, name_parser)

    output_file_option = next(option for option in ctx.command.params if option.name == "output_file")
    output_file_option.default = f"{name.lower()}.py"

    return name
84 |
85 |
def params_callback(ctx: Context, param: CallbackParam, value: str) -> list[str]:
    """`required_params` and `optional_params` arguments callback.

    Parses the comma-separated parameter list. For `optional_params`, it also rejects names that already appear in
    `required_params`.
    """
    ctx, param, parsed_params = _parse_wrapper(ctx, param, value, params_parser)

    if param.name != "optional_params":
        return parsed_params

    duplicates_msg = check_duplicates(required=ctx.params["required_params"], optional=parsed_params)
    if duplicates_msg:
        # Drop the cached value so the re-prompted answer is parsed afresh instead of returning the bad one.
        del ctx.obj[param.name]
        raise BadParameter(duplicates_msg)

    return parsed_params
100 |
101 |
def tags_callback(ctx: Context, param: CallbackParam, value: str) -> list[str]:
    """`tags` argument callback: parse the tags with `tags_parser` (result cached in `ctx.obj`)."""
    return _parse_wrapper(ctx, param, value, tags_parser)[-1]
106 |
107 |
def estimator_callback(ctx: Context, param: CallbackParam, estimator: EstimatorType) -> str:
    """`estimator_type` argument callback.

    It dynamically modifies the behaviour of the rest of the prompts based on its value:

    - If not classifier or regressor, turns off linear prompt.
    - If not classifier or outlier, turns off predict_proba prompt.
    - If not classifier, turns off decision_function prompt.

    The resulting value is cached in `ctx.obj`, and later invocations return the cached value.

    Parameters
    ----------
    ctx
        Typer context; `ctx.command.params` holds the options whose prompts may be disabled.
    param
        Callback parameter information; `param.name` is the cache key.
    estimator
        Selected estimator type.

    Returns
    -------
    str
        The enum's value (e.g. `"classifier"`).
    """
    if not ctx.obj:  # pragma: no cover
        ctx.obj = {}

    if param.name in ctx.obj:
        return ctx.obj[param.name]

    # Look the dependent options up by name rather than by position, so reordering the parameters of the forge
    # command cannot silently swap which prompt gets disabled.
    options_by_name = {opt.name: opt for opt in ctx.command.params}

    # For each dependent option: the estimator types for which its prompt stays enabled.
    # Tuples (not sets) keep the `==`-based comparison that `match` value patterns use.
    keep_prompt_for = {
        "linear": (EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin),
        "predict_proba": (EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin),
        "decision_function": (EstimatorType.ClassifierMixin,),
    }

    for option_name, enabled_for in keep_prompt_for.items():
        if estimator not in enabled_for:
            option = options_by_name[option_name]
            option.prompt = False  # type: ignore[attr-defined]
            option.prompt_required = False  # type: ignore[attr-defined]

    ctx.obj[param.name] = estimator.value

    return estimator.value
153 |
154 |
def linear_callback(ctx: Context, param: CallbackParam, linear: bool) -> bool:
    """`linear` argument callback.

    A linear estimator already provides `decision_function` for a classifier, so when `linear` is `True` the
    `decision_function` prompt is switched off. The received value is cached in `ctx.obj[param.name]` and
    returned; a value already cached under that key is returned as-is.
    """
    if not ctx.obj:  # pragma: no cover
        ctx.obj = {}

    if param.name in ctx.obj:  # pragma: no cover
        return ctx.obj[param.name]

    decision_opt = next(opt for opt in ctx.command.params if opt.name == "decision_function")

    # `is True` mirrors the identity semantics of a literal `case True:` pattern.
    if linear is True:
        decision_opt.prompt = False  # type: ignore[attr-defined]
        decision_opt.prompt_required = False  # type: ignore[attr-defined]

    ctx.obj[param.name] = linear

    return linear
179 |
--------------------------------------------------------------------------------
/sksmithy/_logger.py:
--------------------------------------------------------------------------------
1 | from rich.console import Console
2 | from rich.theme import Theme
3 |
# Rich theme defining the named styles "good", "warning" and "bad", usable as `style=...` when printing.
custom_theme = Theme(
    {
        "good": "bold green",
        "warning": "bold yellow",
        "bad": "bold red",
    }
)
# Module-level console using the theme above (e.g. the CLI prints the version with `style="good"`).
console = Console(theme=custom_theme)
12 |
--------------------------------------------------------------------------------
/sksmithy/_models.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
class EstimatorType(str, Enum):
    """List of possible estimator types.

    Members are named after the corresponding scikit-learn mixin class, which gives a convenient way of using
    the enum to render the jinja template with the class to import.
    """

    # Member *names* are the mixin class names (rendered in the template as `mixin`);
    # member *values* are the user-facing choices (rendered in the template as `estimator_type`).
    ClassifierMixin = "classifier"
    RegressorMixin = "regressor"
    OutlierMixin = "outlier"
    ClusterMixin = "cluster"
    TransformerMixin = "transformer"
    SelectorMixin = "feature-selector"
17 |
18 |
class TagType(str, Enum):
    """List of extra tags.

    Description of each tag is available in the dedicated section of the scikit-learn documentation:
    [estimator tags](https://scikit-learn.org/dev/developers/develop.html#estimator-tags).
    """

    # Every member name equals its value, so validating user input against either
    # `TagType.__members__` (names) or the member values gives the same answer.
    allow_nan = "allow_nan"
    array_api_support = "array_api_support"
    binary_only = "binary_only"
    multilabel = "multilabel"
    multioutput = "multioutput"
    multioutput_only = "multioutput_only"
    no_validation = "no_validation"
    non_deterministic = "non_deterministic"
    pairwise = "pairwise"
    preserves_dtype = "preserves_dtype"
    poor_score = "poor_score"
    requires_fit = "requires_fit"
    requires_positive_X = "requires_positive_X"  # noqa: N815
    requires_y = "requires_y"
    requires_positive_y = "requires_positive_y"
    # Leading-underscore tag names are kept verbatim to match the scikit-learn tag keys.
    _skip_test = "_skip_test"
    _xfail_checks = "_xfail_checks"
    stateless = "stateless"
    X_types = "X_types"
45 |
--------------------------------------------------------------------------------
/sksmithy/_parsers.py:
--------------------------------------------------------------------------------
1 | from keyword import iskeyword
2 |
3 | from result import Err, Ok, Result
4 |
5 | from sksmithy._models import TagType
6 |
7 |
def name_parser(name: str | None) -> Result[str, str]:
    """Check that `name` can be used as a python class name.

    Returns `Err(...)` when `name` is empty (or `None`), when it is not a valid python identifier, or when it
    is a python reserved keyword. Returns `Ok(name)` otherwise.
    """
    if not name:
        return Err("Name cannot be empty!")

    if not name.isidentifier():
        return Err(f"`{name}` is not a valid python class name!")

    if iskeyword(name):
        return Err(f"`{name}` is a python reserved keyword!")

    return Ok(name)
29 |
30 |
def params_parser(params: str | None) -> Result[list[str], str]:
    """Parse and validate that `params` contains valid python names.

    The parser first splits params on commas to get a list of strings. Then it returns `Err(...)` if:

    - any element in the list is not a valid python identifier
    - any element is a python reserved keyword (e.g. `class`), which cannot be used as an argument name
    - any element is repeated more than once

    Otherwise it returns `Ok(params.split(","))`, or `Ok([])` when `params` is empty or `None`.
    """
    param_list: list[str] = params.split(",") if params else []

    invalid = tuple(p for p in param_list if not p.isidentifier())
    if invalid:
        msg = f"The following parameters are invalid python identifiers: {invalid}"
        return Err(msg)

    # `str.isidentifier` accepts keywords such as `class` or `lambda`; they would render an `__init__`
    # signature that is a syntax error.
    reserved = tuple(p for p in param_list if iskeyword(p))
    if reserved:
        msg = f"The following parameters are python reserved keywords: {reserved}"
        return Err(msg)

    if len(set(param_list)) < len(param_list):
        msg = "Found repeated parameters!"
        return Err(msg)

    return Ok(param_list)
53 |
54 |
55 | def check_duplicates(required: list[str], optional: list[str]) -> str | None:
56 | """Check that there are not duplicates between required and optional params."""
57 | duplicated_params = set(required).intersection(set(optional))
58 | return (
59 | f"The following parameters are duplicated between required and optional: {duplicated_params}"
60 | if duplicated_params
61 | else None
62 | )
63 |
64 |
def tags_parser(tags: str) -> Result[list[str], str]:
    """Split `tags` on commas and validate every entry against the supported scikit-learn tags.

    Returns `Ok(list_of_tags)` (an empty list for an empty/falsy input) when every tag is a `TagType` member
    name. Otherwise returns `Err(...)` listing the unknown tags and pointing to the documentation.
    """
    requested: list[str] = tags.split(",") if tags else []

    unknown = tuple(tag for tag in requested if tag not in TagType.__members__)
    if not unknown:
        return Ok(requested)

    msg = (
        f"The following tags are not available: {unknown}."
        "\nPlease check the official documentation at "
        "https://scikit-learn.org/dev/developers/develop.html#estimator-tags"
        " to know which values are available."
    )
    return Err(msg)
87 |
--------------------------------------------------------------------------------
/sksmithy/_prompts.py:
--------------------------------------------------------------------------------
1 | from typing import Final
2 |
# User-facing prompt texts, one per input of the estimator specification.
# They are imported by the user interfaces (e.g. the streamlit app uses them as widget labels).
PROMPT_NAME: Final[str] = "🐍 How would you like to name the estimator?"
PROMPT_ESTIMATOR: Final[str] = "🎯 Which kind of estimator is it?"
PROMPT_REQUIRED: Final[str] = "📜 Please list the required parameters (comma-separated)"
PROMPT_OPTIONAL: Final[str] = "📑 Please list the optional parameters (comma-separated)"
PROMPT_SAMPLE_WEIGHT: Final[str] = "📶 Does the `.fit()` method support `sample_weight`?"
PROMPT_LINEAR: Final[str] = "📏 Is the estimator linear?"
PROMPT_PREDICT_PROBA: Final[str] = "🎲 Should the estimator implement a `predict_proba` method?"
PROMPT_DECISION_FUNCTION: Final[str] = "❓ Should the estimator implement a `decision_function` method?"
# Multi-line prompt: the tags question followed by a pointer to the scikit-learn tags documentation.
PROMPT_TAGS: Final[str] = (
    "🧪 We are almost there... Is there any tag you want to add? (comma-separated)\n"
    "To know more about tags, check the documentation at:\n"
    "https://scikit-learn.org/dev/developers/develop.html#estimator-tags"
)
PROMPT_OUTPUT: Final[str] = "📂 Where would you like to save the class?"
17 |
--------------------------------------------------------------------------------
/sksmithy/_static/description.md:
--------------------------------------------------------------------------------
1 | # Description
2 |
3 | Writing scikit-learn compatible estimators might be harder than expected.
4 |
5 | While everyone knows about the `fit` and `predict` methods, there are other behaviours, methods and attributes that
6 | scikit-learn might expect from your estimator, depending on:
7 |
8 | - The type of estimator you're writing.
9 | - The signature of the estimator.
10 | - The signature of the `.fit(...)` method.
11 |
12 | Scikit-learn Smithy to the rescue: this tool aims to help you craft your own estimator by asking a few
13 | questions about it and then generating the boilerplate code.
14 |
15 | This way, you can focus fully on the core implementation logic rather than on the nitty-gritty details
16 | of the scikit-learn API.
17 |
18 | ## Sanity check
19 |
20 | Once the core logic is implemented, the estimator should be ready to test against the _somewhat official_
21 | [`parametrize_with_checks`](https://scikit-learn.org/dev/modules/generated/sklearn.utils.estimator_checks.parametrize_with_checks.html#sklearn.utils.estimator_checks.parametrize_with_checks)
22 | pytest-compatible decorator.
23 |
24 | ## Official guide
25 |
26 | See the scikit-learn documentation on how to
27 | [develop estimators](https://scikit-learn.org/dev/developers/develop.html#developing-scikit-learn-estimators).
28 |
--------------------------------------------------------------------------------
/sksmithy/_static/template.py.jinja:
--------------------------------------------------------------------------------
{# Template for a scikit-learn compatible estimator. Variables (see `render_template`): name,
   estimator_type (enum value), mixin (enum member name), required, optional, parameters,
   linear, sample_weight, predict_proba, decision_function, tags. The output is piped through
   `ruff format`, so blank lines and indentation inside brackets are normalised afterwards. #}
{%- if estimator_type in ('classifier', 'feature-selector') %}
import numpy as np
{% endif -%}
{%- if estimator_type == 'classifier' and linear %}
from sklearn.base import BaseEstimator
from sklearn.linear_model._base import LinearClassifierMixin
{% elif estimator_type == 'regressor' and linear%}
from sklearn.base import {{ mixin }}
from sklearn.linear_model._base import LinearModel
{% elif estimator_type == 'feature-selector'%}
from sklearn.base import BaseEstimator
from sklearn.feature_selection import SelectorMixin
{% else %}
from sklearn.base import BaseEstimator, {{ mixin }}
{% endif -%}
from sklearn.utils import check_X_y
from sklearn.utils.validation import check_is_fitted, check_array

{% if sample_weight %}from sklearn.utils.validation import _check_sample_weight{% endif %}


class {{ name }}(
    {% if estimator_type == 'classifier' and linear %}
    LinearClassifierMixin, BaseEstimator
    {% elif estimator_type == 'regressor' and linear%}
    RegressorMixin, LinearModel
    {%else %}
    {{ mixin }}, BaseEstimator
    {% endif %}):
    """{{ name }} estimator.

    ...
    {% if parameters %}
    Parameters
    ----------
    {% for param in parameters %}
    {{- param }} : ...
    {% endfor -%}
    {% endif -%}
    """
    {% if required %}_required_parameters = {{ required }}{% endif -%}

    {% if parameters %}
    def __init__(
        self,
        {% for param in required %}
        {{- param }},
        {% endfor -%}
        {%- if optional -%}
        *,
        {% endif -%}
        {% for param in optional %}
        {{- param }}=...,
        {% endfor -%}
    ):

        {%for param in parameters -%}
        self.{{param}} = {{param}}
        {% endfor -%}
    {% endif %}
61 |
62 | def fit(self, X, y{% if estimator_type in ('transformer', 'feature-selector') %}=None{% endif %}{% if sample_weight %}, sample_weight=None{% endif %}):
63 | """
64 | Fit {{name}} estimator.
65 |
66 | Parameters
67 | ----------
68 | X : {array-like, sparse matrix} of shape (n_samples, n_features)
69 | Training data.
70 |
71 | {%- if transformer-%}
72 | y : None
73 | Ignored.
74 | {% else %}
75 | y : array-like of shape (n_samples,) or (n_samples, n_targets)
76 | Target values.
77 | {% endif %}
78 |
79 | {%- if sample_weight -%}
80 | sample_weight : array-like of shape (n_samples,), default=None
81 | Individual weights for each sample.
82 | {% endif %}
83 | Returns
84 | -------
85 | self : {{name}}
86 | Fitted {{name}} estimator.
87 | """
88 | {%- if estimator_type in ('transformer', 'feature-selector') %}
89 | X = check_array(X, ...) #TODO: Fill in `check_array` arguments
90 | {% else %}
91 | X, y = check_X_y(X, y, ...) #TODO: Fill in `check_X_y` arguments
92 | {% endif %}
93 | self.n_features_in_ = X.shape[1]
94 | {%- if estimator_type=='classifier'%}
95 | self.classes_ = np.unique(y)
96 | {% endif %}
97 | {%- if sample_weight %}
98 | sample_weight = _check_sample_weight(sample_weight)
99 | {% endif %}
100 |
101 | ... # TODO: Implement fit logic
102 |
103 | {%if linear -%}
104 | # For linear models, coef_ and intercept_ is all you need. `predict` is taken care of by the mixin
105 | self.coef_ = ...
106 | self.intercept_ = ...
107 | {%- endif %}
108 | {% if 'max_iter' in parameters -%}self.n_iter_ = ...{%- endif %}
109 | {% if estimator_type=='outlier' -%}self.offset_ = ...{%- endif %}
110 | {% if estimator_type=='cluster' -%}self.labels_ = ...{%- endif %}
111 | {% if estimator_type=='feature-selector'%}
112 | self.selected_features_ = ... # TODO: Indexes of selected features
113 | self.support_ = np.isin(
114 | np.arange(0, self.n_features_in_), # all_features
115 | self.selected_features_
116 | )
117 | {%- endif %}
118 |
119 | return self
120 |
{# Non-linear classifiers that asked for `decision_function` also get a `predict` derived from it
   (linear classifiers inherit both from `LinearClassifierMixin`). #}
{% if estimator_type == 'classifier' and decision_function and not linear %}
    def decision_function(self, X):
        """Confidence scores of X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data to predict.

        Returns
        -------
        Confidence scores array.
        """

        check_is_fitted(self)
        X = check_array(X, ...) #TODO: Fill in `check_array` arguments

        if X.shape[1] != self.n_features_in_:
            msg = f"X has {X.shape[1]} features but the estimator was fitted on {self.n_features_in_} features."
            raise ValueError(msg)

        y_scores = ... # TODO: Implement decision_function logic

        return y_scores

    def predict(self, X):
        """Predict X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data to predict.

        Returns
        -------
        Prediction array, with values taken from `classes_`.
        """

        check_is_fitted(self)
        X = check_array(X, ...) #TODO: Fill in `check_array` arguments

        decision = self.decision_function(X)
        # Binary: positive score -> second class; multiclass: highest score wins. Both give indices into classes_.
        indices = (decision.ravel() > 0).astype(int) if self.n_classes_ == 2 else np.argmax(decision, axis=1)
        y_pred = self.classes_[indices]
        return y_pred
{% endif %}
166 |
{# `predict_proba` is rendered only for classifiers and outlier detectors that requested it. #}
{% if estimator_type in ('classifier', 'outlier') and predict_proba == True %}
    def predict_proba(self, X):
        """Probability estimates of X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data to predict.

        Returns
        -------
        Prediction array.
        """

        check_is_fitted(self)
        X = check_array(X, ...) #TODO: Fill in `check_array` arguments

        if X.shape[1] != self.n_features_in_:
            msg = f"X has {X.shape[1]} features but the estimator was fitted on {self.n_features_in_} features."
            raise ValueError(msg)

        y_proba = ... # TODO: Implement predict_proba logic

        return y_proba
{% endif %}
192 |
{# Outlier detectors: `decision_function` is `score_samples` shifted by `offset_` (set in fit), and
   `predict` maps non-negative decisions to 1 and negative ones to -1. #}
{% if estimator_type=='outlier' %}
    def score_samples(self, X):

        check_is_fitted(self)
        X = check_array(X, ...) #TODO: Fill in `check_array` arguments

        if X.shape[1] != self.n_features_in_:
            msg = f"X has {X.shape[1]} features but the estimator was fitted on {self.n_features_in_} features."
            raise ValueError(msg)

        ... # TODO: Implement scoring function, `decision_function` and `predict` will follow

        return ...

    def decision_function(self, X):
        return self.score_samples(X) - self.offset_

    def predict(self, X):
        preds = (self.decision_function(X) >= 0).astype(int)
        preds[preds == 0] = -1
        return preds
{%- endif %}
215 |
{# Generic `predict`: always for clusterers; for regressors and classifiers unless a mixin
   (`LinearModel` / `LinearClassifierMixin`) or the decision_function-based block already provides it.
   Flags that do not apply to the estimator type are ignored, so stale values cannot drop `predict`. #}
{% if estimator_type == 'cluster' or (estimator_type == 'regressor' and not linear) or (estimator_type == 'classifier' and not linear and not decision_function) %}
    def predict(self, X):
        """Predict X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data to predict.

        Returns
        -------
        Prediction array.
        """

        check_is_fitted(self)
        X = check_array(X, ...) #TODO: Fill in `check_array` arguments

        if X.shape[1] != self.n_features_in_:
            msg = f"X has {X.shape[1]} features but the estimator was fitted on {self.n_features_in_} features."
            raise ValueError(msg)

        y_pred = ... # TODO: Implement predict logic

        return y_pred
{% endif %}
241 |
    {# Transformers only; the tag is indented so `-%}` keeps the method at class-body indentation. #}
    {% if estimator_type=='transformer' -%}
    def transform(self, X):
        """Transform X.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The data to transform.

        Returns
        -------
        Transformed array.
        """

        check_is_fitted(self)
        X = check_array(X, ...) # TODO: Fill in `check_array` arguments

        if X.shape[1] != self.n_features_in_:
            msg = f"X has {X.shape[1]} features but the estimator was fitted on {self.n_features_in_} features."
            raise ValueError(msg)

        X_ts = ... # TODO: Implement transform logic

        return X_ts
    {%- endif %}
267 |
268 | {% if estimator_type=='feature-selector' -%}
269 | def _get_support_mask(self, X):
270 | """Get the boolean mask indicating which features are selected.
271 |
272 | Returns
273 | -------
274 | support : boolean array of shape [# input features]
275 | An element is True iff its corresponding feature is selected for retention.
276 | """
277 |
278 | check_is_fitted(self)
279 | return self.support_
280 | {%- endif %}
281 |
{# One placeholder entry per selected tag; values are left for the user to fill in. #}
{% if tags %}
    def _more_tags(self):
        return {
            {%for tag in tags -%}
            "{{tag}}": ...,
            {% endfor -%}
        }
{%- endif %}
290 |
{# Classifiers expose the number of classes, derived from `classes_` set in fit. #}
{% if estimator_type == 'classifier' %}
    @property
    def n_classes_(self):
        """Number of classes."""
        return len(self.classes_)
{% endif %}
--------------------------------------------------------------------------------
/sksmithy/_static/tui.tcss:
--------------------------------------------------------------------------------
/* Generic grouping container: sized to its content, at least 10% of the viewport high. */
.container {
    height: auto;
    width: auto;
    min-height: 10vh;
}

/* Labels: fixed height of 3 rows, text right-aligned and vertically centred. */
.label {
    height: 3;
    content-align: right middle;
    width: auto;
}

Screen {
    align: center middle;
    min-width: 100vw;
}

Header {
    color: $secondary;
    text-style: bold;
}

Horizontal {
    min-height: 10vh;
    height: auto;
}

/* Text/select inputs: two per row (50% width each). */
Name, Estimator, Required, Optional {
    width: 50%;
    padding: 0 2 0 1;
    height: auto;
}

/* Boolean options, also two per row. */
SampleWeight, Linear {
    width: 50%;
    padding: 1 0 0 1;
    height: auto;
}

PredictProba, DecisionFunction {
    width: 50%;
    padding: 0 0 0 1;
    height: auto;
}

Prompt {
    padding: 0 0 0 2;
    height: auto;
}

Switch {
    height: auto;
    width: auto;
}

/* Give disabled switches a distinct background so unavailable options stand out. */
Switch:disabled {
    background: darkslategrey;
}

/* Inputs flagged `-valid` get a success-coloured border (60% strength, full strength when focused). */
Input.-valid {
    border: tall $success 60%;
}
Input.-valid:focus {
    border: tall $success;
}

/* Bottom row: 4-column grid with fixed column widths. */
ForgeRow {
    grid-size: 4 1;
    grid-gutter: 1;
    grid-columns: 45% 10% 10% 25%;
    min-height: 15vh;
    max-height: 15vh;
}

TextArea {
    min-height: 15vh;
    max-height: 100vh;
}

/* Spans two cells of the grid it is placed in. */
DestinationFile {
    column-span: 2;
    height: 100%;
}

/* Sidebar drawn on the overlay layer; its offset is animated over 200ms. */
Sidebar {
    width: 80;
    height: auto;
    background: $panel;
    transition: offset 200ms in_out_cubic;
    layer: overlay;
}

/* Keep the sidebar in view while any of its children has focus. */
Sidebar:focus-within {
    offset: 0 0 !important;
}

/* Hidden state: shifted left by its full width. */
Sidebar.-hidden {
    offset-x: -100%;
}

Sidebar Title {
    background: $boost;
    color: $secondary;
    padding: 2 0 1 0;
    border-right: vkey $background;
    dock: top;
    text-align: center;
    text-style: bold;
}

OptionGroup {
    background: $boost;
    color: $text;
    height: 1fr;
    border-right: vkey $background;
}
118 |
--------------------------------------------------------------------------------
/sksmithy/_utils.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from importlib import resources
3 | from pathlib import Path
4 | from typing import Final
5 |
6 | from jinja2 import Template
7 |
8 | from sksmithy._models import EstimatorType
9 |
# Location of the Jinja template bundled with the package (`sksmithy/_static/template.py.jinja`).
# NOTE(review): converting the resource to `Path(str(...))` presumably requires the package to be installed as
# plain files on disk (not inside a zip archive) — confirm for the supported install methods.
TEMPLATE_PATH: Final[Path] = Path(str(resources.files("sksmithy") / "_static" / "template.py.jinja"))
11 |
12 |
def render_template(
    name: str,
    estimator_type: EstimatorType,
    required: list[str],
    optional: list[str],
    linear: bool = False,
    sample_weight: bool = False,
    predict_proba: bool = False,
    decision_function: bool = False,
    tags: list[str] | None = None,
) -> str:
    """
    Render the estimator template using the provided parameters.

    This is achieved in a two-step process:

    - Render the jinja template using the input values.
    - Format the string using ruff formatter.

    !!! warning

        This function **does not** validate that arguments are necessarily compatible with each other.
        For instance, it could be possible to pass `estimator_type = EstimatorType.RegressorMixin` and
        `predict_proba = True` which makes no sense as combination, but it would not raise an error.

    Parameters
    ----------
    name
        The name of the estimator class to generate.
    estimator_type
        The type of the estimator. Its value and its name (the mixin class) are both passed to the template.
    required
        The list of required parameters.
    optional
        The list of optional parameters.
    linear
        Whether or not the estimator is linear.
    sample_weight
        Whether or not the estimator supports sample weights in `.fit()`.
    predict_proba
        Whether or not the estimator should implement `.predict_proba()` method.
    decision_function
        Whether or not the estimator should implement `.decision_function()` method.
    tags
        The list of scikit-learn extra tags.

    Returns
    -------
    str : The rendered and formatted template as a string.

    Raises
    ------
    subprocess.CalledProcessError
        If `ruff format` exits with a non-zero status (presumably when the rendered code cannot be parsed).
    FileNotFoundError
        If the `ruff` executable cannot be found.
    """
    values = {
        "name": name,
        "estimator_type": estimator_type.value,
        "mixin": estimator_type.name,
        "required": required,
        "optional": optional,
        "parameters": [*required, *optional],
        "linear": linear,
        "sample_weight": sample_weight,
        "predict_proba": predict_proba,
        "decision_function": decision_function,
        "tags": tags,
    }

    # Explicit encoding: the platform default (e.g. cp1252 on Windows) must not decide how the template is read.
    with TEMPLATE_PATH.open(mode="r", encoding="utf-8") as stream:
        template = Template(stream.read()).render(values)

    # `ruff format -` reads the source from stdin and writes the formatted code to stdout.
    return subprocess.check_output(["ruff", "format", "-"], input=template, encoding="utf-8")
81 |
--------------------------------------------------------------------------------
/sksmithy/app.py:
--------------------------------------------------------------------------------
1 | import re
2 | import time
3 | from importlib import resources
4 | from importlib.metadata import version
5 |
6 | from result import Err, Ok
7 |
8 | from sksmithy._models import EstimatorType, TagType
9 | from sksmithy._parsers import check_duplicates, name_parser, params_parser
10 | from sksmithy._prompts import (
11 | PROMPT_DECISION_FUNCTION,
12 | PROMPT_ESTIMATOR,
13 | PROMPT_LINEAR,
14 | PROMPT_NAME,
15 | PROMPT_OPTIONAL,
16 | PROMPT_PREDICT_PROBA,
17 | PROMPT_REQUIRED,
18 | PROMPT_SAMPLE_WEIGHT,
19 | )
20 | from sksmithy._utils import render_template
21 |
# Fail fast with an actionable message when the installed streamlit is older than 1.34.0.
# Each dot-separated segment of the version string is reduced to its digits (e.g. "0rc1" -> "01") before the
# tuple comparison.
# NOTE(review): that digit-stripping misreads pre-release segments, and a segment without digits would make
# `int("")` raise ValueError — confirm this is acceptable for the streamlit versions in use.
if (st_version := version("streamlit")) and tuple(int(re.sub(r"\D", "", str(v))) for v in st_version.split(".")) < (
    1,
    34,
    0,
):  # pragma: no cover
    st_import_err_msg = (
        f"streamlit>=1.34.0 is required for this module. Found version {st_version}.\nInstall it with "
        '`python -m pip install "streamlit>=1.34.0"` or `python -m pip install "sklearn-smithy[streamlit]"`'
    )
    raise ImportError(st_import_err_msg)

else:  # pragma: no cover
    import streamlit as st

# Markdown shown in the app sidebar, read from the packaged `_static/description.md`.
SIDEBAR_MSG: str = (resources.files("sksmithy") / "_static" / "description.md").read_text()
37 |
38 |
39 | def app() -> None: # noqa: C901,PLR0912,PLR0915
40 | """Streamlit App."""
41 | st.set_page_config(
42 | page_title="Smithy",
43 | page_icon="⚒️",
44 | layout="wide",
45 | menu_items={
46 | "Get Help": "https://github.com/FBruzzesi/sklearn-smithy",
47 | "Report a bug": "https://github.com/FBruzzesi/sklearn-smithy/issues/new",
48 | "About": """
49 | Forge your own scikit-learn estimator!
50 |
51 | For more information, please visit the [sklearn-smithy](https://github.com/FBruzzesi/sklearn-smithy)
52 | repository.
53 | """,
54 | },
55 | )
56 |
57 | st.title("Scikit-learn Smithy ⚒️")
58 | st.markdown("## Forge your own scikit-learn compatible estimator")
59 |
60 | with st.sidebar:
61 | st.markdown(SIDEBAR_MSG)
62 |
63 | linear = False
64 | predict_proba = False
65 | decision_function = False
66 | estimator_type: EstimatorType | None = None
67 |
68 | required_is_valid = False
69 | optional_is_valid = False
70 | msg_duplicated_params: str | None = None
71 |
72 | if "forged_template" not in st.session_state:
73 | st.session_state["forged_template"] = ""
74 |
75 | if "forge_counter" not in st.session_state:
76 | st.session_state["forge_counter"] = 0
77 |
78 | with st.container(): # name and type
79 | c11, c12 = st.columns(2)
80 |
81 | with c11: # name
82 | name_input = st.text_input(
83 | label=PROMPT_NAME,
84 | value="MightyEstimator",
85 | placeholder="MightyEstimator",
86 | help=(
87 | "It should be a valid "
88 | "[python identifier](https://docs.python.org/3/reference/lexical_analysis.html#identifiers)"
89 | ),
90 | key="name",
91 | )
92 |
93 | match name_parser(name_input):
94 | case Ok(name):
95 | pass
96 | case Err(name_error_msg):
97 | name = ""
98 | st.error(name_error_msg)
99 |
100 | with c12: # type
101 | estimator = st.selectbox(
102 | label=PROMPT_ESTIMATOR,
103 | options=tuple(e.value for e in EstimatorType),
104 | format_func=lambda v: " ".join(x.capitalize() for x in v.split("-")),
105 | index=None,
106 | key="estimator",
107 | )
108 |
109 | if estimator:
110 | estimator_type = EstimatorType(estimator)
111 |
112 | with st.container(): # params
113 | c21, c22 = st.columns(2)
114 |
115 | with c21: # required
116 | required_params = st.text_input(
117 | label=PROMPT_REQUIRED,
118 | placeholder="alpha,beta",
119 | help=(
120 | "It should be a sequence of comma-separated "
121 | "[python identifiers](https://docs.python.org/3/reference/lexical_analysis.html#identifiers)"
122 | ),
123 | key="required",
124 | )
125 |
126 | match params_parser(required_params):
127 | case Ok(required):
128 | required_is_valid = True
129 | case Err(required_err_msg):
130 | required_is_valid = False
131 | st.error(required_err_msg)
132 |
133 | with c22: # optional
134 | optional_params = st.text_input(
135 | label=PROMPT_OPTIONAL,
136 | placeholder="mu,sigma",
137 | help=(
138 | "It should be a sequence of comma-separated "
139 | "[python identifiers](https://docs.python.org/3/reference/lexical_analysis.html#identifiers)"
140 | ),
141 | key="optional",
142 | )
143 |
144 | match params_parser(optional_params):
145 | case Ok(optional):
146 | optional_is_valid = True
147 | case Err(optional_err_msg):
148 | optional_is_valid = False
149 | st.error(optional_err_msg)
150 |
151 | if required_is_valid and optional_is_valid and (msg_duplicated_params := check_duplicates(required, optional)):
152 | st.error(msg_duplicated_params)
153 |
154 | with st.container(): # sample_weight and linear
155 | c31, c32 = st.columns(2)
156 |
157 | with c31: # sample_weight
158 | sample_weight = st.toggle(
159 | PROMPT_SAMPLE_WEIGHT,
160 | help="[sample_weight](https://scikit-learn.org/dev/glossary.html#term-sample_weight)",
161 | key="sample_weight",
162 | )
163 | with c32: # linear
164 | linear = st.toggle(
165 | label=PROMPT_LINEAR,
166 | disabled=(estimator_type not in {EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin}),
167 | help="Available only if estimator is `Classifier` or `Regressor`",
168 | key="linear",
169 | )
170 |
171 | with st.container(): # predict_proba and decision_function
172 | c41, c42 = st.columns(2)
173 |
174 | with c41: # predict_proba
175 | predict_proba = st.toggle(
176 | label=PROMPT_PREDICT_PROBA,
177 | disabled=(estimator_type not in {EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin}),
178 | help=(
179 | "[predict_proba](https://scikit-learn.org/dev/glossary.html#term-predict_proba): "
180 | "Available only if estimator is `Classifier` or `Outlier`. "
181 | ),
182 | key="predict_proba",
183 | )
184 |
185 | with c42: # decision_function
186 | decision_function = st.toggle(
187 | label=PROMPT_DECISION_FUNCTION,
188 | disabled=(estimator_type != EstimatorType.ClassifierMixin) or linear,
189 | help=(
190 | "[decision_function](https://scikit-learn.org/dev/glossary.html#term-decision_function): "
191 | "Available only if estimator is `Classifier`"
192 | ),
193 | key="decision_function",
194 | )
195 |
196 | st.write("#") # empty space hack
197 |
198 | with st.container(): # forge button
199 | c51, c52, _, c54 = st.columns([2, 1, 1, 1])
200 |
201 | with (
202 | c51,
203 | st.popover(
204 | label="Additional tags",
205 | help=(
206 | "To know more about tags, check the "
207 | "[scikit-learn documentation](https://scikit-learn.org/dev/developers/develop.html#estimator-tags)"
208 | ),
209 | ),
210 | ):
211 | tags = st.multiselect(
212 | label="Select tags",
213 | options=tuple(e.value for e in TagType),
214 | help="These tags are not validated against the selected estimator type!",
215 | key="tags",
216 | )
217 |
218 | with c52:
219 | forge_btn = st.button(
220 | label="Time to forge 🛠️",
221 | type="primary",
222 | disabled=any(
223 | [
224 | not name,
225 | not estimator_type,
226 | not required_is_valid,
227 | not optional_is_valid,
228 | msg_duplicated_params,
229 | ]
230 | ),
231 | key="forge_btn",
232 | )
233 | if forge_btn:
234 | st.session_state["forge_counter"] += 1
235 | st.session_state["forged_template"] = render_template(
236 | name=name,
237 | estimator_type=estimator_type, # type: ignore[arg-type] # At this point estimator_type is never None.
238 | required=required,
239 | optional=optional,
240 | linear=linear,
241 | sample_weight=sample_weight,
242 | predict_proba=predict_proba,
243 | decision_function=decision_function,
244 | tags=tags,
245 | )
246 |
247 | with c54, st.popover(label="Download", disabled=not st.session_state["forge_counter"]):
248 | if name:
249 | file_name = st.text_input(label="Select filename", value=f"{name.lower()}.py", key="file_name")
250 |
251 | data = st.session_state["forged_template"]
252 | st.download_button(
253 | label="Confirm",
254 | type="primary",
255 | data=data,
256 | file_name=file_name,
257 | key="download_btn",
258 | )
259 |
260 | st.write("#") # empty space hack
261 |
262 | with st.container(): # code output
263 | if forge_btn:
264 | st.toast("Request submitted!")
265 | progress_text = "Forging in progress ..."
266 | progress_bar = st.progress(0, text=progress_text)
267 | # Consider using status component instead
268 | # https://docs.streamlit.io/develop/api-reference/status/st.status
269 |
270 | for percent_complete in range(100):
271 | time.sleep(0.002)
272 | progress_bar.progress(percent_complete + 1, text=progress_text)
273 |
274 | time.sleep(0.2)
275 | progress_bar.empty()
276 |
277 | if st.session_state["forge_counter"]:
278 | st.code(st.session_state["forged_template"], language="python", line_numbers=True)
279 |
280 |
# Build the app when this module is executed as a script (presumably how `streamlit run` invokes it).
if __name__ == "__main__":
    app()
283 |
--------------------------------------------------------------------------------
/sksmithy/cli.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import typer
4 |
5 | from sksmithy._arguments import (
6 | decision_function_arg,
7 | estimator_type_arg,
8 | linear_arg,
9 | name_arg,
10 | optional_params_arg,
11 | output_file_arg,
12 | predict_proba_arg,
13 | required_params_arg,
14 | sample_weight_arg,
15 | tags_arg,
16 | )
17 | from sksmithy._logger import console
18 | from sksmithy._utils import render_template
19 |
# Root Typer application exposed as `smith`; the commands below register themselves on it via `@cli.command()`.
cli = typer.Typer(
    name="smith",
    help="CLI to generate scikit-learn estimator boilerplate code.",
    rich_markup_mode="rich",
    rich_help_panel="Customization and Utils",
)
25 | )
26 |
27 |
@cli.command()
def version() -> None:
    """Display library version."""
    # Local import keeps `importlib.metadata` out of the module import path of the CLI.
    from importlib.metadata import version as dist_version

    installed = dist_version("sklearn-smithy")
    console.print(f"sklearn-smithy={installed}", style="good")
35 |
36 |
@cli.command()
def forge(
    name: name_arg,
    estimator_type: estimator_type_arg,
    required_params: required_params_arg = "",
    optional_params: optional_params_arg = "",
    sample_weight: sample_weight_arg = False,
    linear: linear_arg = False,
    predict_proba: predict_proba_arg = False,
    decision_function: decision_function_arg = False,
    tags: tags_arg = "",
    output_file: output_file_arg = "",
) -> None:
    """Generate a new shiny scikit-learn compatible estimator ✨

    Depending on the estimator type the following additional information could be required:

    * if the estimator is linear (classifier or regressor)
    * if the estimator implements `.predict_proba()` method (classifier or outlier detector)
    * if the estimator implements `.decision_function()` method (classifier only)

    Finally, the following two questions will be prompted:

    * if the estimator should have tags (To know more about tags, check the dedicated scikit-learn documentation
    at https://scikit-learn.org/dev/developers/develop.html#estimator-tags)
    * in which file the class should be saved (default is `f'{name.lower()}.py'`)
    """
    # NOTE: the docstring above is the user-facing `--help` text, implementation notes live in comments.
    forged_template = render_template(
        name=name,
        estimator_type=estimator_type,
        required=required_params,  # type: ignore[arg-type] # Callback transforms it into `list[str]`
        optional=optional_params,  # type: ignore[arg-type] # Callback transforms it into `list[str]`
        linear=linear,
        sample_weight=sample_weight,
        predict_proba=predict_proba,
        decision_function=decision_function,
        tags=tags,  # type: ignore[arg-type] # Callback transforms it into `list[str]`
    )

    destination_file = Path(output_file)
    # Create any missing parent directory so that nested output paths work.
    destination_file.parent.mkdir(parents=True, exist_ok=True)

    # The generated file is Python source: write it as UTF-8 (the default source encoding, PEP 3120) instead of
    # relying on the platform locale encoding, which may not be able to represent every character.
    destination_file.write_text(forged_template, encoding="utf-8")

    console.print(f"Template forged at {destination_file}", style="good")
83 |
84 |
@cli.command(name="forge-tui")
def forge_tui() -> None:
    """Run Terminal User Interface via Textual."""
    # Imported at call time: the TUI module (and its textual dependency) is only loaded when this command runs.
    from sksmithy.tui import ForgeTUI

    ForgeTUI().run()
91 | tui.run()
92 |
93 |
@cli.command(name="forge-webui")
def forge_webui() -> None:
    """Run Web User Interface via Streamlit."""
    import subprocess
    import sys

    # Resolve the streamlit script next to this module (`sksmithy/app.py`) rather than relative to the current
    # working directory, so the command also works when not launched from the repository root.
    app_path = Path(__file__).parent / "app.py"

    # Launch streamlit through the running interpreter: this uses the same environment in which sksmithy is
    # installed and does not require a `streamlit` executable on PATH.
    subprocess.run([sys.executable, "-m", "streamlit", "run", str(app_path)], check=True)
100 |
--------------------------------------------------------------------------------
/sksmithy/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FBruzzesi/sklearn-smithy/213aefcf64950a72cd51bd3b02b4ccb23484dada/sksmithy/py.typed
--------------------------------------------------------------------------------
/sksmithy/tui/__init__.py:
--------------------------------------------------------------------------------
1 | from sksmithy.tui._tui import ForgeTUI
2 |
3 | __all__ = ("ForgeTUI",)
4 |
--------------------------------------------------------------------------------
/sksmithy/tui/_components.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import webbrowser
3 | from importlib import resources
4 | from pathlib import Path
5 |
6 | from result import Err, Ok
7 | from textual import on
8 | from textual.app import ComposeResult
9 | from textual.containers import Container, Grid, Horizontal, ScrollableContainer
10 | from textual.widgets import Button, Collapsible, Input, Markdown, Select, Static, Switch, TextArea
11 |
12 | from sksmithy._models import EstimatorType
13 | from sksmithy._parsers import check_duplicates, name_parser, params_parser
14 | from sksmithy._prompts import (
15 | PROMPT_DECISION_FUNCTION,
16 | PROMPT_ESTIMATOR,
17 | PROMPT_LINEAR,
18 | PROMPT_NAME,
19 | PROMPT_OPTIONAL,
20 | PROMPT_OUTPUT,
21 | PROMPT_PREDICT_PROBA,
22 | PROMPT_REQUIRED,
23 | PROMPT_SAMPLE_WEIGHT,
24 | )
25 | from sksmithy._utils import render_template
26 | from sksmithy.tui._validators import NameValidator, ParamsValidator
27 |
28 | if sys.version_info >= (3, 11): # pragma: no cover
29 | from typing import Self
30 | else: # pragma: no cover
31 | from typing_extensions import Self
32 |
33 |
# Markdown shown in the TUI sidebar, loaded from the packaged `_static/description.md` resource.
# Decoded explicitly as UTF-8 so the result does not depend on the platform's locale encoding.
SIDEBAR_MSG: str = (resources.files("sksmithy") / "_static" / "description.md").read_text(encoding="utf-8")
35 |
36 |
class Prompt(Static):
    """Static text widget used to render the question label of each form component."""
39 |
40 |
class Name(Container):
    """Name input component."""

    def compose(self: Self) -> ComposeResult:
        yield Prompt(PROMPT_NAME, classes="label")
        yield Input(placeholder="MightyEstimator", id="name", validators=[NameValidator()])

    @on(Input.Changed, "#name")
    def on_input_change(self: Self, event: Input.Changed) -> None:
        """Notify about an invalid name, otherwise propose `<name lowercased>.py` as output file."""
        validation = event.validation_result
        if validation.is_valid:  # type: ignore[union-attr]
            destination = self.app.query_one("#output-file", Input)
            destination.value = f"{event.value.lower()}.py"
            return

        # Only the first failure is shown for the name field.
        self.notify(
            message=validation.failure_descriptions[0],  # type: ignore[union-attr]
            title="Invalid Name",
            severity="error",
            timeout=5,
        )
60 |
61 |
class Estimator(Container):
    """Estimator select component."""

    def compose(self: Self) -> ComposeResult:
        yield Prompt(PROMPT_ESTIMATOR, classes="label")
        yield Select(
            # Label: estimator value with "-" replaced by spaces and each word capitalized.
            options=((" ".join(x.capitalize() for x in e.value.split("-")), e.value) for e in EstimatorType),
            id="estimator",
        )

    @on(Select.Changed, "#estimator")
    def on_select_change(self: Self, event: Select.Changed) -> None:
        """Enable only the switches that apply to the selected estimator type.

        Switches that become disabled are also turned off, so a stale value cannot reach the forged template.
        `decision_function` additionally stays disabled while `linear` is on, matching the rule enforced by the
        `#linear` switch handler (a linear classifier does not expose `decision_function`).
        """
        linear = self.app.query_one("#linear", Switch)
        predict_proba = self.app.query_one("#predict_proba", Switch)
        decision_function = self.app.query_one("#decision_function", Switch)

        linear.disabled = event.value not in {"classifier", "regressor"}
        predict_proba.disabled = event.value not in {"classifier", "outlier"}

        linear.value = linear.value and (not linear.disabled)
        predict_proba.value = predict_proba.value and (not predict_proba.disabled)

        # Evaluated after `linear.value` is updated: previously a classifier re-enabled `decision_function`
        # even when `linear` was still on (e.g. when switching from regressor to classifier).
        decision_function.disabled = event.value != "classifier" or linear.value
        decision_function.value = decision_function.value and (not decision_function.disabled)
85 |
86 |
class Required(Container):
    """Required params input component."""

    def compose(self: Self) -> ComposeResult:
        yield Prompt(PROMPT_REQUIRED, classes="label")
        yield Input(placeholder="alpha,beta", id="required", validators=[ParamsValidator()])

    @on(Input.Submitted, "#required")
    def on_input_change(self: Self, event: Input.Submitted) -> None:
        """On submit, report invalid required params and any overlap with the optional ones."""
        validation = event.validation_result
        if not validation.is_valid:  # type: ignore[union-attr]
            self.notify(
                message="\n".join(validation.failure_descriptions),  # type: ignore[union-attr]
                title="Invalid Parameter",
                severity="error",
                timeout=5,
            )

        optional_raw = self.app.query_one("#optional", Input).value or ""
        # Nothing to compare unless both fields contain something.
        if not (optional_raw and event.value):
            return

        duplicates_msg = check_duplicates(event.value.split(","), optional_raw.split(","))
        if duplicates_msg:
            self.notify(
                message=duplicates_msg,
                title="Duplicate Parameter",
                severity="error",
                timeout=5,
            )
121 |
122 |
class Optional(Container):
    """Optional params input component."""

    def compose(self: Self) -> ComposeResult:
        yield Prompt(PROMPT_OPTIONAL, classes="label")
        yield Input(placeholder="mu,sigma", id="optional", validators=[ParamsValidator()])

    @on(Input.Submitted, "#optional")
    def on_optional_change(self: Self, event: Input.Submitted) -> None:
        """On submit, report invalid optional params and any overlap with the required ones."""
        validation = event.validation_result
        if not validation.is_valid:  # type: ignore[union-attr]
            self.notify(
                message="\n".join(validation.failure_descriptions),  # type: ignore[union-attr]
                title="Invalid Parameter",
                severity="error",
                timeout=5,
            )

        required_raw = self.app.query_one("#required", Input).value or ""
        # Nothing to compare unless both fields contain something.
        if not (required_raw and event.value):
            return

        duplicates_msg = check_duplicates(required_raw.split(","), event.value.split(","))
        if duplicates_msg:
            self.notify(
                message=duplicates_msg,
                title="Duplicate Parameter",
                severity="error",
                timeout=5,
            )
157 |
158 |
class SampleWeight(Container):
    """Labelled switch for the `sample_weight` option."""

    def compose(self: Self) -> ComposeResult:
        toggle = Switch(id="sample_weight")
        label = Prompt(PROMPT_SAMPLE_WEIGHT, classes="label")
        yield Horizontal(toggle, label, classes="container")
168 |
169 |
class Linear(Container):
    """linear switch component."""

    def compose(self: Self) -> ComposeResult:
        yield Horizontal(
            Switch(id="linear"),
            Prompt(PROMPT_LINEAR, classes="label"),
            classes="container",
        )

    @on(Switch.Changed, "#linear")
    def on_switch_changed(self: Self, event: Switch.Changed) -> None:
        """Keep the `decision_function` switch consistent with `linear` and the selected estimator.

        `decision_function` is only available for non-linear classifiers. The estimator type is checked too:
        previously turning `linear` off (by the user, or programmatically when another estimator type is
        selected) re-enabled `decision_function` even for non-classifier estimators.
        """
        estimator = self.app.query_one("#estimator", Select).value
        decision_function = self.app.query_one("#decision_function", Switch)

        decision_function.disabled = event.value or estimator != "classifier"
        decision_function.value = decision_function.value and (not decision_function.disabled)
185 |
186 |
class PredictProba(Container):
    """Labelled switch for the `predict_proba` option."""

    def compose(self: Self) -> ComposeResult:
        toggle = Switch(id="predict_proba")
        label = Prompt(PROMPT_PREDICT_PROBA, classes="label")
        yield Horizontal(toggle, label, classes="container")
196 |
197 |
class DecisionFunction(Container):
    """Labelled switch for the `decision_function` option."""

    def compose(self: Self) -> ComposeResult:
        toggle = Switch(id="decision_function")
        label = Prompt(PROMPT_DECISION_FUNCTION, classes="label")
        yield Horizontal(toggle, label, classes="container")
207 |
208 |
209 | class ForgeButton(Container):
210 | """forge button component."""
211 |
212 | def compose(self: Self) -> ComposeResult:
213 | yield Button(label="Forge ⚒️", id="forge-btn", variant="success")
214 |
215 | @on(Button.Pressed, "#forge-btn")
216 | def on_forge(self: Self, _: Button.Pressed) -> None: # noqa: C901
217 | errors = []
218 |
219 | name_input = self.app.query_one("#name", Input).value
220 | estimator = self.app.query_one("#estimator", Select).value
221 | required_params = self.app.query_one("#required", Input).value
222 | optional_params = self.app.query_one("#optional", Input).value
223 |
224 | sample_weight = self.app.query_one("#linear", Switch).value
225 | linear = self.app.query_one("#linear", Switch).value
226 | predict_proba = self.app.query_one("#predict_proba", Switch).value
227 | decision_function = self.app.query_one("#decision_function", Switch).value
228 |
229 | code_area = self.app.query_one("#code-area", TextArea)
230 | code_editor = self.app.query_one("#code-editor", Collapsible)
231 |
232 | match name_parser(name_input):
233 | case Ok(name):
234 | pass
235 | case Err(name_error_msg):
236 | errors.append(name_error_msg)
237 |
238 | match estimator:
239 | case str(v):
240 | estimator_type = EstimatorType(v)
241 | case Select.BLANK:
242 | errors.append("Estimator cannot be empty!")
243 |
244 | match params_parser(required_params):
245 | case Ok(required):
246 | required_is_valid = True
247 | case Err(required_err_msg):
248 | required_is_valid = False
249 | errors.append(required_err_msg)
250 |
251 | match params_parser(optional_params):
252 | case Ok(optional):
253 | optional_is_valid = True
254 |
255 | case Err(optional_err_msg):
256 | optional_is_valid = False
257 | errors.append(optional_err_msg)
258 |
259 | if required_is_valid and optional_is_valid and (msg_duplicated_params := check_duplicates(required, optional)):
260 | errors.append(msg_duplicated_params)
261 |
262 | if errors:
263 | self.notify(
264 | message="\n".join([f"- {e}" for e in errors]),
265 | title="Invalid inputs!",
266 | severity="error",
267 | timeout=5,
268 | )
269 |
270 | else:
271 | forged_template = render_template(
272 | name=name,
273 | estimator_type=estimator_type,
274 | required=required,
275 | optional=optional,
276 | linear=linear,
277 | sample_weight=sample_weight,
278 | predict_proba=predict_proba,
279 | decision_function=decision_function,
280 | tags=None,
281 | )
282 |
283 | code_area.text = forged_template
284 | code_editor.collapsed = False
285 |
286 | self.notify(
287 | message="Template forged!",
288 | title="Success!",
289 | severity="information",
290 | timeout=5,
291 | )
292 |
293 |
class SaveButton(Container):
    """Save button component."""

    def compose(self: Self) -> ComposeResult:
        yield Button(label="Save 📂", id="save-btn", variant="primary")

    @on(Button.Pressed, "#save-btn")
    def on_save(self: Self, _: Button.Pressed) -> None:
        """Write the content of the code editor to the file named in `#output-file`.

        Missing parent directories are created. An empty filename or a filesystem error (e.g. permission denied,
        path pointing to a directory) is reported with an error notification instead of crashing the app.
        """
        output_file = self.app.query_one("#output-file", Input).value

        if not output_file:
            self.notify(
                message="Outfile filename cannot be empty!",
                title="Invalid filename!",
                severity="error",
                timeout=5,
            )
            return

        destination_file = Path(output_file)
        code = self.app.query_one("#code-area", TextArea).text

        try:
            destination_file.parent.mkdir(parents=True, exist_ok=True)
            # The code is Python source: write it as UTF-8 independently of the platform locale encoding.
            destination_file.write_text(code, encoding="utf-8")
        except OSError as exc:
            self.notify(
                message=f"Could not save at {destination_file}: {exc}",
                title="Save failed!",
                severity="error",
                timeout=5,
            )
            return

        self.notify(
            message=f"Saved at {destination_file}",
            title="Success!",
            severity="information",
            timeout=5,
        )
326 |
327 |
class DestinationFile(Container):
    """Input holding the path the forged code is saved to."""

    def compose(self: Self) -> ComposeResult:
        # The prompt acts as placeholder; a valid name fills in `<name>.py` (see the `Name` component).
        destination_input = Input(placeholder=PROMPT_OUTPUT, id="output-file")
        yield destination_input
333 |
334 |
class ForgeRow(Grid):
    """Grid row laying out the forge and save buttons next to the destination file input."""
337 |
338 |
class OptionGroup(ScrollableContainer):
    """Scrollable wrapper used by the sidebar around its markdown description."""
341 |
342 |
class Sidebar(Container):
    """Side panel rendering the package description markdown."""

    def compose(self: Self) -> ComposeResult:
        description = Markdown(SIDEBAR_MSG)
        yield OptionGroup(description)

    def on_markdown_link_clicked(self: Self, event: Markdown.LinkClicked) -> None:
        """Open a clicked markdown link in a new browser tab."""
        # Relevant discussion: https://github.com/Textualize/textual/discussions/3668
        target = event.href
        webbrowser.open_new_tab(target)
350 |
--------------------------------------------------------------------------------
/sksmithy/tui/_tui.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from importlib import metadata, resources
3 | from typing import ClassVar
4 |
5 | from textual.app import App, ComposeResult
6 | from textual.containers import Container, Horizontal, ScrollableContainer
7 | from textual.reactive import reactive
8 | from textual.widgets import Button, Collapsible, Footer, Header, Rule, Static, TextArea
9 |
10 | from sksmithy.tui._components import (
11 | DecisionFunction,
12 | DestinationFile,
13 | Estimator,
14 | ForgeButton,
15 | ForgeRow,
16 | Linear,
17 | Name,
18 | Optional,
19 | PredictProba,
20 | Required,
21 | SampleWeight,
22 | SaveButton,
23 | Sidebar,
24 | )
25 |
26 | if sys.version_info >= (3, 11): # pragma: no cover
27 | from typing import Self
28 | else: # pragma: no cover
29 | from typing_extensions import Self
30 |
31 |
class ForgeTUI(App):
    """Textual app to forge scikit-learn compatible estimators."""

    CSS_PATH: ClassVar[str] = str(resources.files("sksmithy") / "_static" / "tui.tcss")
    TITLE: ClassVar[str] = "Scikit-learn Smithy ⚒️"  # type: ignore[misc]

    BINDINGS: ClassVar = [
        ("ctrl+d", "toggle_sidebar", "Description"),
        ("L", "toggle_dark", "Light/Dark mode"),
        ("F", "forge", "Forge"),
        ("ctrl+s", "save", "Save"),
        ("E", "app.quit", "Exit"),
    ]

    show_sidebar = reactive(False)  # noqa: FBT003

    def on_mount(self: Self) -> None:
        """Mount hook, intentionally a no-op.

        Textual calls `compose` itself when building the app. Calling `self.compose()` here only created a
        generator that was never consumed, so it did nothing and has been dropped; the hook is kept for
        backward compatibility.
        """

    def compose(self: Self) -> ComposeResult:
        """Create child widgets for the app."""
        yield Container(
            Header(icon=f"v{metadata.version('sklearn-smithy')}"),
            ScrollableContainer(
                Horizontal(Name(), Estimator()),
                Horizontal(Required(), Optional()),
                Horizontal(SampleWeight(), Linear()),
                Horizontal(PredictProba(), DecisionFunction()),
                Rule(),
                ForgeRow(
                    Static(),
                    ForgeButton(),
                    SaveButton(),
                    DestinationFile(),
                ),
                Rule(),
                Collapsible(
                    TextArea(
                        text="",
                        language="python",
                        theme="vscode_dark",
                        show_line_numbers=True,
                        tab_behavior="indent",
                        id="code-area",
                    ),
                    title="Code Editor",
                    collapsed=True,
                    id="code-editor",
                ),
            ),
            # Hidden at start-up; shown/hidden by `action_toggle_sidebar`.
            Sidebar(classes="-hidden"),
            Footer(),
        )

    def action_toggle_dark(self: Self) -> None:  # pragma: no cover
        """Toggle dark mode."""
        self.dark = not self.dark

    def action_toggle_sidebar(self: Self) -> None:  # pragma: no cover
        """Toggle sidebar component."""
        sidebar = self.query_one(Sidebar)
        self.set_focus(None)

        if sidebar.has_class("-hidden"):
            sidebar.remove_class("-hidden")
        else:
            # NOTE(review): focus was already cleared above, so this check looks redundant — confirm before removing.
            if sidebar.query("*:focus"):
                self.screen.set_focus(None)
            sidebar.add_class("-hidden")

    def action_forge(self: Self) -> None:
        """Press forge button."""
        forge_btn = self.query_one("#forge-btn", Button)
        forge_btn.press()

    def action_save(self: Self) -> None:
        """Press save button."""
        save_btn = self.query_one("#save-btn", Button)
        save_btn.press()
115 |
116 |
# Allow launching the TUI by executing this module directly.
if __name__ == "__main__":  # pragma: no cover
    tui = ForgeTUI()
    tui.run()
120 |
--------------------------------------------------------------------------------
/sksmithy/tui/_validators.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from typing import TypeVar
3 |
4 | from result import Err, Ok, Result
5 | from textual.validation import ValidationResult, Validator
6 |
7 | from sksmithy._parsers import name_parser, params_parser
8 |
9 | if sys.version_info >= (3, 11): # pragma: no cover
10 | from typing import Self
11 | else: # pragma: no cover
12 | from typing_extensions import Self
13 |
# NOTE(review): `T` and `R` are not referenced anywhere else in this module — candidates for removal.
T = TypeVar("T")
R = TypeVar("R")
16 |
17 |
18 | class _BaseValidator(Validator):
19 | @staticmethod
20 | def parser(value: str) -> Result[str | list[str], str]: # pragma: no cover
21 | raise NotImplementedError
22 |
23 | def validate(self: Self, value: str) -> ValidationResult:
24 | match self.parser(value):
25 | case Ok(_):
26 | return self.success()
27 | case Err(msg):
28 | return self.failure(msg)
29 |
30 |
31 | class NameValidator(_BaseValidator):
32 | @staticmethod
33 | def parser(value: str) -> Result[str, str]:
34 | return name_parser(value)
35 |
36 |
37 | class ParamsValidator(_BaseValidator):
38 | @staticmethod
39 | def parser(value: str) -> Result[list[str], str]:
40 | return params_parser(value)
41 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FBruzzesi/sklearn-smithy/213aefcf64950a72cd51bd3b02b4ccb23484dada/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from streamlit.testing.v1 import AppTest
3 |
4 | from sksmithy._models import EstimatorType
5 |
6 |
# Parameter sets for the fixtures below; each fixture simply hands out its current `request.param`.
_NAMES = ["MightyEstimator"]
_REQUIRED = [["alpha", "beta"], ["max_iter"], []]
_INVALID_REQUIRED = [
    ("a,a", "Found repeated parameters!"),
    ("a-a", "The following parameters are invalid python identifiers: ('a-a',)"),
]
_INVALID_OPTIONAL = [
    ("b,b", "Found repeated parameters!"),
    ("b b", "The following parameters are invalid python identifiers: ('b b',)"),
]
_OPTIONAL = [["mu", "sigma"], []]
_FLAGS = [True, False]
_TAGS = [["allow_nan", "binary_only"], [], None]


@pytest.fixture(params=_NAMES)
def name(request: pytest.FixtureRequest) -> str:
    """Valid estimator class name."""
    return request.param


@pytest.fixture(params=list(EstimatorType))
def estimator(request: pytest.FixtureRequest) -> EstimatorType:
    """Every available estimator type."""
    return request.param


@pytest.fixture(params=_REQUIRED)
def required(request: pytest.FixtureRequest) -> list[str]:
    """Valid required parameter lists, including the empty one."""
    return request.param


@pytest.fixture(params=_INVALID_REQUIRED)
def invalid_required(request: pytest.FixtureRequest) -> tuple[str, str]:
    """Pairs of (invalid required params string, expected error message)."""
    return request.param


@pytest.fixture(params=_INVALID_OPTIONAL)
def invalid_optional(request: pytest.FixtureRequest) -> tuple[str, str]:
    """Pairs of (invalid optional params string, expected error message)."""
    return request.param


@pytest.fixture(params=_OPTIONAL)
def optional(request: pytest.FixtureRequest) -> list[str]:
    """Valid optional parameter lists, including the empty one."""
    return request.param


@pytest.fixture(params=_FLAGS)
def sample_weight(request: pytest.FixtureRequest) -> bool:
    """Both values of the `sample_weight` flag."""
    return request.param


@pytest.fixture(params=_FLAGS)
def linear(request: pytest.FixtureRequest) -> bool:
    """Both values of the `linear` flag."""
    return request.param


@pytest.fixture(params=_FLAGS)
def predict_proba(request: pytest.FixtureRequest) -> bool:
    """Both values of the `predict_proba` flag."""
    return request.param


@pytest.fixture(params=_FLAGS)
def decision_function(request: pytest.FixtureRequest) -> bool:
    """Both values of the `decision_function` flag."""
    return request.param


@pytest.fixture(params=_TAGS)
def tags(request: pytest.FixtureRequest) -> list[str] | None:
    """Tag selections: some tags, no tags, and `None`."""
    return request.param


@pytest.fixture
def app() -> AppTest:
    """Streamlit test harness for the web UI script."""
    return AppTest.from_file("sksmithy/app.py", default_timeout=10)
75 |
--------------------------------------------------------------------------------
/tests/test_app.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from streamlit.testing.v1 import AppTest
3 |
4 | from sksmithy._models import EstimatorType
5 |
6 |
def test_smoke(app: AppTest) -> None:
    """Basic smoke test: the app runs without raising."""
    app.run()
    assert not app.exception


@pytest.mark.parametrize(
    ("name_", "err_msg"),
    [
        ("MightyEstimator", ""),
        ("not-valid-name", "`not-valid-name` is not a valid python class name!"),
        ("class", "`class` is a python reserved keyword!"),
    ],
)
def test_name(app: AppTest, name_: str, err_msg: str) -> None:
    """Test `name` text_input component: invalid names surface the parser error message."""
    app.run()
    app.text_input(key="name").input(name_).run()

    if not err_msg:
        assert not app.error
        return
    assert app.error[0].value == err_msg


def test_estimator_interaction(app: AppTest, estimator: EstimatorType) -> None:
    """Test that all toggle components interact correctly with the selected estimator."""
    app.run()
    app.selectbox(key="estimator").select(estimator.value).run()

    is_classifier = estimator == EstimatorType.ClassifierMixin
    linear_allowed = estimator in {EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin}
    proba_allowed = estimator in {EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin}

    assert (not app.toggle(key="linear").disabled) == linear_allowed
    assert (not app.toggle(key="predict_proba").disabled) == proba_allowed
    assert (not app.toggle(key="decision_function").disabled) == is_classifier

    if is_classifier:
        # Turning a classifier linear must disable `decision_function`.
        app.toggle(key="linear").set_value(True).run()

    assert app.toggle(key="decision_function").disabled


@pytest.mark.parametrize(
    ("required_", "optional_", "err_msg"),
    [
        ("a,b", "c,d", ""),
        ("a,a", "", "Found repeated parameters!"),
        ("", "b,b", "Found repeated parameters!"),
        ("a-a", "", "The following parameters are invalid python identifiers: ('a-a',)"),
        ("", "b b", "The following parameters are invalid python identifiers: ('b b',)"),
        ("a,b", "a", "The following parameters are duplicated between required and optional: {'a'}"),
    ],
)
def test_params(
    app: AppTest, name: str, estimator: EstimatorType, required_: str, optional_: str, err_msg: str
) -> None:
    """Test required and optional params interaction."""
    app.run()
    app.text_input(key="name").input(name).run()
    app.selectbox(key="estimator").select(estimator.value).run()

    app.text_input(key="required").input(required_).run()
    app.text_input(key="optional").input(optional_).run()

    if not err_msg:
        assert not app.error
        assert not app.button(key="forge_btn").disabled
        return

    assert app.error[0].value == err_msg
    # Any error disables the forge button.
    assert app.button(key="forge_btn").disabled
80 |
81 |
def test_forge(app: AppTest, name: str, estimator: EstimatorType) -> None:
    """Test forge button and all of its interactions.

    Remark that there is no way of testing `popover` or `download_button` components (yet).
    """
    app.run()
    assert app.button(key="forge_btn").disabled
    assert app.session_state["forge_counter"] == 0

    app.text_input(key="name").input(name).run()
    app.selectbox(key="estimator").select(estimator.value).run()
    assert not app.button(key="forge_btn").disabled
    assert not app.code

    app.button(key="forge_btn").click().run()
    assert app.session_state["forge_counter"] == 1
    # `app.code` is an element list and never None (`assert not app.code` above relies on its emptiness),
    # so check that a code block was actually rendered.
    assert app.code
99 |
--------------------------------------------------------------------------------
/tests/test_cli.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import pytest
4 | from typer.testing import CliRunner
5 |
6 | from sksmithy import __version__
7 | from sksmithy._models import EstimatorType
8 | from sksmithy._prompts import (
9 | PROMPT_DECISION_FUNCTION,
10 | PROMPT_ESTIMATOR,
11 | PROMPT_LINEAR,
12 | PROMPT_NAME,
13 | PROMPT_OPTIONAL,
14 | PROMPT_OUTPUT,
15 | PROMPT_PREDICT_PROBA,
16 | PROMPT_REQUIRED,
17 | PROMPT_SAMPLE_WEIGHT,
18 | PROMPT_TAGS,
19 | )
20 | from sksmithy.cli import cli
21 |
# Shared Typer test runner for all CLI tests in this module.
runner = CliRunner()


def test_version() -> None:
    """`version` exits cleanly and prints the installed package version."""
    outcome = runner.invoke(cli, ["version"])
    assert outcome.exit_code == 0

    expected = f"sklearn-smithy={__version__}"
    assert expected in outcome.stdout
29 |
30 |
@pytest.mark.parametrize("linear", ["y", "N"])
def test_forge_estimator(tmp_path: Path, name: str, estimator: EstimatorType, linear: str) -> None:
    """Tests that `forge` asks exactly the prompts relevant to each estimator type and writes the output file."""
    output_file = tmp_path / (f"{name.lower()}.py")
    assert not output_file.exists()

    asks_linear = estimator in {EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin}
    asks_predict_proba = estimator in {EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin}
    # decision_function is prompted only for non-linear classifiers.
    asks_decision_function = linear == "N" and estimator == EstimatorType.ClassifierMixin

    _input = "".join(
        [
            f"{name}\n",  # name
            f"{estimator.value}\n",  # estimator_type
            "\n",  # required params
            "\n",  # optional params
            "\n",  # sample weight
            f"{linear}\n" if asks_linear else "",  # linear
            "\n" if asks_predict_proba else "",  # predict_proba
            "\n" if asks_decision_function else "",  # decision_function
            "\n",  # tags
            f"{output_file!s}\n",  # output file
        ]
    )

    result = runner.invoke(
        app=cli,
        args=["forge"],
        input=_input,
    )

    assert result.exit_code == 0
    assert output_file.exists()

    # General prompts
    assert all(
        _prompt in result.stdout
        for _prompt in (
            PROMPT_NAME,
            PROMPT_ESTIMATOR,
            PROMPT_REQUIRED,
            PROMPT_OPTIONAL,
            PROMPT_SAMPLE_WEIGHT,
            PROMPT_TAGS,
            f"{PROMPT_OUTPUT} [{name.lower()}.py]",
        )
    )

    # Estimator type specific prompts
    assert (PROMPT_LINEAR in result.stdout) == asks_linear
    assert (PROMPT_PREDICT_PROBA in result.stdout) == asks_predict_proba
    assert (PROMPT_DECISION_FUNCTION in result.stdout) == asks_decision_function
85 |
86 |
@pytest.mark.parametrize(
    ("invalid_name", "name_err_msg"),
    [
        ("class", "Error: `class` is a python reserved keyword!"),
        ("abc-xyz", "Error: `abc-xyz` is not a valid python class name!"),
    ],
)
@pytest.mark.parametrize(
    ("invalid_required", "required_err_msg"),
    [
        ("a-b", "Error: The following parameters are invalid python identifiers: ('a-b',)"),
        ("a,a", "Error: Found repeated parameters!"),
    ],
)
@pytest.mark.parametrize(
    ("invalid_optional", "duplicated_err_msg"),
    [("a", "Error: The following parameters are duplicated between required and optional: {'a'}")],
)
@pytest.mark.parametrize(
    ("invalid_tags", "tags_err_msg"),
    [("not-a-tag,also-not-a-tag", "Error: The following tags are not available: ('not-a-tag', 'also-not-a-tag').")],
)
def test_forge_invalid_args(
    tmp_path: Path,
    name: str,
    invalid_name: str,
    name_err_msg: str,
    invalid_required: str,
    required_err_msg: str,
    invalid_optional: str,
    duplicated_err_msg: str,
    invalid_tags: str,
    tags_err_msg: str,
) -> None:
    """Tests that invalid name, params and tags are reported and re-prompted until a valid value is given."""
    output_file = tmp_path / (f"{name.lower()}.py")
    assert not output_file.exists()

    _input = "".join(
        [
            f"{invalid_name}\n",  # name, invalid attempt
            f"{name}\n",  # name, valid attempt
            "transformer\n",  # estimator_type
            f"{invalid_required}\n",  # required params, invalid attempt
            "a,b\n",  # required params, valid attempt
            f"{invalid_optional}\n",  # optional params, invalid attempt (duplicates a required one)
            "c,d\n",  # optional params, valid attempt
            "\n",  # sample_weight
            f"{invalid_tags}\n",  # tags, invalid attempt
            "binary_only\n",  # tags, valid attempt
            f"{output_file!s}\n",  # output file
        ]
    )

    # Invoke the CLI once (the previous version invoked it twice and discarded the first result).
    result = runner.invoke(app=cli, args=["forge"], input=_input)

    assert result.exit_code == 0
    assert output_file.exists()

    assert all(
        err_msg in result.stdout for err_msg in (name_err_msg, required_err_msg, duplicated_err_msg, tags_err_msg)
    )
154 | )
155 |
--------------------------------------------------------------------------------
/tests/test_parsers.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | import pytest
4 | from result import Err, Ok, is_err, is_ok
5 |
6 | from sksmithy._parsers import check_duplicates, name_parser, params_parser, tags_parser
7 |
8 |
9 | @pytest.mark.parametrize(
10 | ("name", "checker"),
11 | [
12 | ("valid_name", is_ok),
13 | ("ValidName", is_ok),
14 | ("123Invalid", is_err),
15 | ("class", is_err),
16 | ("", is_err),
17 | ],
18 | )
19 | def test_name_parser(name: str, checker: Callable) -> None:
20 | result = name_parser(name)
21 | assert checker(result)
22 |
23 |
24 | @pytest.mark.parametrize(
25 | ("params", "checker", "expected"),
26 | [
27 | (None, is_ok, []),
28 | ("a,b,c", is_ok, ["a", "b", "c"]),
29 | ("123a,b c,x", is_err, "The following parameters are invalid python identifiers: ('123a', 'b c')"),
30 | ("a,a,b", is_err, "Found repeated parameters!"),
31 | ],
32 | )
33 | def test_params_parser(params: str, checker: Callable, expected: str) -> None:
34 | result = params_parser(params)
35 | assert checker(result)
36 |
37 | match result:
38 | case Ok(value):
39 | assert value == expected
40 | case Err(msg):
41 | assert msg == expected
42 |
43 |
@pytest.mark.parametrize(
    ("required", "optional", "expected"),
    [
        (["a", "b"], ["c", "d"], None),
        ([], ["c", "d"], None),
        ([], [], None),
        (["a", "b"], ["b", "c"], "The following parameters are duplicated between required and optional: {'b'}"),
    ],
)
def test_check_duplicates(required: list[str], optional: list[str], expected: str | None) -> None:
    """Tests `check_duplicates` on disjoint, empty and overlapping parameter lists.

    Args:
        required: Required parameter names.
        optional: Optional parameter names.
        expected: `None` when the lists do not overlap, otherwise the exact error message.
    """
    result = check_duplicates(required, optional)
    assert result == expected
56 |
57 |
58 | @pytest.mark.parametrize(
59 | ("tags", "checker", "expected"),
60 | [
61 | ("allow_nan,binary_only", is_ok, ["allow_nan", "binary_only"]),
62 | ("", is_ok, []),
63 | ("some_madeup_tag", is_err, "The following tags are not available: ('some_madeup_tag',)"),
64 | ],
65 | )
66 | def test_tags_parser(tags: str, checker: Callable, expected: str) -> None:
67 | result = tags_parser(tags)
68 | assert checker(result)
69 | match result:
70 | case Ok(value):
71 | assert value == expected
72 | case Err(msg):
73 | assert msg.startswith(expected)
74 |
--------------------------------------------------------------------------------
/tests/test_render.py:
--------------------------------------------------------------------------------
1 | from sksmithy._models import EstimatorType
2 | from sksmithy._utils import render_template
3 |
4 |
def test_params(name: str, required: list[str], optional: list[str]) -> None:
    """Required and optional params are assigned to `self` and declared as expected."""
    rendered = render_template(
        name=name,
        estimator_type=EstimatorType.ClassifierMixin,
        required=required,
        optional=optional,
        sample_weight=False,
        linear=False,
        predict_proba=False,
        decision_function=False,
        tags=None,
    )
    every_param = [*required, *optional]

    for param in every_param:
        assert f"self.{param} = {param}" in rendered
    # `n_iter_` is rendered if and only if `max_iter` is one of the params.
    assert ("self.n_iter_" in rendered) == ("max_iter" in every_param)

    # Only the presence of the declaration is checked. An exact check built from an f-string
    # renders the list as "['a', 'b']", which does not match how the template writes it.
    assert ("_required_parameters = " in rendered) == bool(required)
26 |
27 |
def test_tags(name: str, tags: list[str] | None) -> None:
    """The `_more_tags` method and one entry per tag are rendered only when tags are given."""
    rendered = render_template(
        name=name,
        estimator_type=EstimatorType.ClassifierMixin,
        required=[],
        optional=[],
        sample_weight=False,
        linear=False,
        predict_proba=False,
        decision_function=False,
        tags=tags,
    )

    assert ("def _more_tags(self)" in rendered) == bool(tags)

    # With `None` or an empty list there is nothing further to check.
    for tag in tags or []:
        assert f'"{tag}": ...,' in rendered
47 |
48 |
def test_common_estimator(name: str, estimator: EstimatorType, sample_weight: bool) -> None:
    """Tests features common to every estimator type, including the `sample_weight` handling."""
    rendered = render_template(
        name=name,
        estimator_type=estimator,
        required=[],
        optional=[],
        sample_weight=sample_weight,
        linear=False,
        predict_proba=False,
        decision_function=False,
        tags=None,
    )

    assert f"class {name}" in rendered
    assert "self.n_features_in_ = X.shape[1]" in rendered
    assert ("sample_weight = _check_sample_weight(sample_weight)" in rendered) == sample_weight

    # Transformers and selectors validate X only and take `y=None`; all others validate X and y together.
    if estimator in (EstimatorType.TransformerMixin, EstimatorType.SelectorMixin):
        input_check = "X = check_array(X, ...)"
        fit_with_weight = "def fit(self, X, y=None, sample_weight=None)"
        fit_without_weight = "def fit(self, X, y=None)"
    else:
        input_check = "X, y = check_X_y(X, y, ...)"
        fit_with_weight = "def fit(self, X, y, sample_weight=None)"
        fit_without_weight = "def fit(self, X, y)"

    assert input_check in rendered
    assert (fit_with_weight in rendered) == sample_weight
    assert (fit_without_weight in rendered) == (not sample_weight)
76 |
77 |
def test_classifier(name: str, linear: bool, predict_proba: bool, decision_function: bool) -> None:
    """Tests classifier specific rendering.

    The rendered class is expected to be named after the `name` fixture, the same way
    `test_common_estimator` checks it.
    """
    estimator_type = EstimatorType.ClassifierMixin

    result = render_template(
        name=name,
        estimator_type=estimator_type,
        required=[],
        optional=[],
        sample_weight=False,
        linear=linear,
        predict_proba=predict_proba,
        decision_function=decision_function,
        tags=None,
    )

    # Classifier specific
    assert "self.classes_ = " in result
    assert "def n_classes_(self)" in result
    assert "def transform(self, X)" not in result

    # Linear
    assert (f"class {name}(LinearClassifierMixin, BaseEstimator)" in result) == linear
    assert ("self.coef_ = ..." in result) == linear
    assert ("self.intercept_ = ..." in result) == linear

    # Non-linear classifiers get their own `predict`; linear ones do not render it
    # (presumably inherited from the linear mixin — not verifiable from this test).
    assert (f"class {name}(ClassifierMixin, BaseEstimator)" in result) == (not linear)
    assert ("def predict(self, X)" in result) == (not linear)

    # Predict proba
    assert ("def predict_proba(self, X)" in result) == predict_proba

    # Decision function: rendered only when requested AND the classifier is not linear.
    assert ("def decision_function(self, X)" in result) == (decision_function and not linear)
114 |
115 |
def test_regressor(name: str, linear: bool) -> None:
    """Tests regressor specific rendering.

    The rendered class is expected to be named after the `name` fixture.
    """
    estimator_type = EstimatorType.RegressorMixin

    result = render_template(
        name=name,
        estimator_type=estimator_type,
        required=[],
        optional=[],
        sample_weight=False,
        linear=linear,
        predict_proba=False,
        decision_function=False,
        tags=None,
    )

    # Regressor specific
    assert "def transform(self, X)" not in result

    # Linear
    assert (f"class {name}(RegressorMixin, LinearModel)" in result) == linear
    assert ("self.coef_ = ..." in result) == linear
    assert ("self.intercept_ = ..." in result) == linear

    assert (f"class {name}(RegressorMixin, BaseEstimator)" in result) == (not linear)
    assert ("def predict(self, X)" in result) == (not linear)
142 |
143 |
def test_outlier(name: str, predict_proba: bool) -> None:
    """Tests outlier detector specific rendering.

    The rendered class is expected to be named after the `name` fixture.
    """
    estimator_type = EstimatorType.OutlierMixin

    result = render_template(
        name=name,
        estimator_type=estimator_type,
        required=[],
        optional=[],
        sample_weight=False,
        linear=False,
        predict_proba=predict_proba,
        decision_function=False,
        tags=None,
    )

    # Outlier specific: scoring, decision and prediction methods are always rendered.
    assert f"class {name}(OutlierMixin, BaseEstimator)" in result
    assert "self.offset_" in result
    assert "def score_samples(self, X)" in result
    assert "def decision_function(self, X)" in result
    assert "def predict(self, X)" in result

    assert "def transform(self, X)" not in result

    # Predict proba
    assert ("def predict_proba(self, X)" in result) == predict_proba
171 |
172 |
def test_transformer(name: str) -> None:
    """Tests transformer specific rendering.

    The rendered class is expected to be named after the `name` fixture.
    """
    estimator_type = EstimatorType.TransformerMixin

    result = render_template(
        name=name,
        estimator_type=estimator_type,
        required=[],
        optional=[],
        sample_weight=False,
        linear=False,
        predict_proba=False,
        decision_function=False,
        tags=None,
    )
    # Transformer specific: `transform` instead of `predict`.
    assert f"class {name}(TransformerMixin, BaseEstimator)" in result
    assert "def transform(self, X)" in result
    assert "def predict(self, X)" not in result
192 |
193 |
def test_feature_selector(name: str) -> None:
    """Tests feature selector specific rendering.

    The rendered class is expected to be named after the `name` fixture.
    """
    estimator_type = EstimatorType.SelectorMixin

    result = render_template(
        name=name,
        estimator_type=estimator_type,
        required=[],
        optional=[],
        sample_weight=False,
        linear=False,
        predict_proba=False,
        decision_function=False,
        tags=None,
    )
    # Feature selector specific: a support mask instead of `predict`.
    assert f"class {name}(SelectorMixin, BaseEstimator)" in result
    assert "def _get_support_mask(self, X)" in result
    assert "self.support_" in result
    assert "def predict(self, X)" not in result
214 |
215 |
def test_cluster(name: str) -> None:
    """Tests cluster specific rendering.

    The rendered class is expected to be named after the `name` fixture.
    """
    estimator_type = EstimatorType.ClusterMixin

    result = render_template(
        name=name,
        estimator_type=estimator_type,
        required=[],
        optional=[],
        sample_weight=False,
        linear=False,
        predict_proba=False,
        decision_function=False,
        tags=None,
    )

    # Cluster specific
    assert f"class {name}(ClusterMixin, BaseEstimator)" in result
    assert "self.labels_ = ..." in result
    assert "def predict(self, X)" in result
236 |
--------------------------------------------------------------------------------
/tests/test_tui.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import pytest
4 | from textual.widgets import Button, Input, Select, Switch
5 |
6 | from sksmithy._models import EstimatorType
7 | from sksmithy.tui import ForgeTUI
8 |
9 |
async def test_smoke() -> None:
    """Smoke test: the TUI starts under the test pilot, settles and exits with code 0."""
    async with ForgeTUI().run_test(size=None) as pilot:
        await pilot.pause()
        assert pilot is not None

        await pilot.pause()
        await pilot.exit(0)
19 |
20 |
@pytest.mark.parametrize(
    ("name_", "err_msg"),
    [
        ("MightyEstimator", ""),
        ("not-valid-name", "`not-valid-name` is not a valid python class name!"),
        ("class", "`class` is a python reserved keyword!"),
    ],
)
async def test_name(name_: str, err_msg: str) -> None:
    """The `#name` input is invalid and emits exactly one notification iff an error is expected."""
    expect_error = bool(err_msg)

    app = ForgeTUI()
    async with app.run_test(size=None) as pilot:
        name_input = pilot.app.query_one("#name", Input)
        name_input.value = name_
        await pilot.pause()

        assert (not name_input.is_valid) == expect_error

        notifications = [*pilot.app._notifications]  # noqa: SLF001
        assert len(notifications) == (1 if expect_error else 0)

        if notifications:
            assert notifications[0].message == err_msg
44 |
45 |
async def test_estimator_interaction(estimator: EstimatorType) -> None:
    """Each toggle is enabled only for the estimator types that support it.

    For classifiers, switching `#linear` on additionally disables `#decision_function`.
    """
    # Selector of each switch -> whether it should be enabled for this estimator.
    expected_enabled = {
        "#linear": estimator in {EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin},
        "#predict_proba": estimator in {EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin},
        "#decision_function": estimator == EstimatorType.ClassifierMixin,
    }

    app = ForgeTUI()
    async with app.run_test(size=None) as pilot:
        pilot.app.query_one("#estimator", Select).value = estimator.value
        await pilot.pause()

        for selector, enabled in expected_enabled.items():
            assert (not pilot.app.query_one(selector, Switch).disabled) == enabled

        if estimator == EstimatorType.ClassifierMixin:
            pilot.app.query_one("#linear", Switch).value = True
            await pilot.pause()
            assert pilot.app.query_one("#decision_function", Switch).disabled
70 |
71 |
async def test_valid_params() -> None:
    """Valid, non-overlapping required and optional params produce no notifications."""
    app = ForgeTUI()
    async with app.run_test(size=None) as pilot:
        required_input = pilot.app.query_one("#required", Input)
        optional_input = pilot.app.query_one("#optional", Input)

        # Both values are set before either input is submitted.
        required_input.value, optional_input.value = "a,b", "c,d"

        for field in (required_input, optional_input):
            await field.action_submit()
        await pilot.pause(0.01)

        assert not [*pilot.app._notifications]  # noqa: SLF001
90 |
91 |
@pytest.mark.parametrize(("required_", "optional_"), [("a,b", "a"), ("a", "a,b")])
async def test_duplicated_params(required_: str, optional_: str) -> None:
    """A parameter listed as both required and optional is reported in every notification.

    Args:
        required_: Value typed into the `#required` input.
        optional_: Value typed into the `#optional` input.
    """
    app = ForgeTUI()
    msg = "The following parameters are duplicated between required and optional: {'a'}"

    async with app.run_test(size=None) as pilot:
        required_comp = pilot.app.query_one("#required", Input)
        optional_comp = pilot.app.query_one("#optional", Input)

        required_comp.value = required_
        optional_comp.value = optional_

        await required_comp.action_submit()
        await optional_comp.action_submit()
        await pilot.pause()

        forge_btn = pilot.app.query_one("#forge-btn", Button)
        forge_btn.action_press()
        await pilot.pause()

        messages = [n.message for n in pilot.app._notifications]  # noqa: SLF001
        # `all` over an empty sequence is True, so require at least one notification explicitly;
        # otherwise the test would pass even if the duplication were never reported.
        assert messages
        assert all(msg in message for message in messages)
113 |
114 |
async def test_forge_raise() -> None:
    """Invalid params and empty fields are reported via notifications, and again when forging."""
    repeated_msg = "Found repeated parameters!"
    invalid_msg = "The following parameters are invalid python identifiers: ('b b',)"

    app = ForgeTUI()
    async with app.run_test(size=None) as pilot:
        required_input = pilot.app.query_one("#required", Input)
        optional_input = pilot.app.query_one("#optional", Input)

        required_input.value = "a,a"
        optional_input.value = "b b"

        await required_input.action_submit()
        await optional_input.action_submit()
        await pilot.pause()

        pilot.app.query_one("#forge-btn", Button).action_press()
        await pilot.pause()

        # Exactly three notifications are expected: one per submitted input, then one from forging.
        first, second, third = (n.message for n in pilot.app._notifications)  # noqa: SLF001

        assert repeated_msg in first
        assert invalid_msg in second

        # The forge notification aggregates every outstanding error.
        for expected in ("Name cannot be empty!", "Estimator cannot be empty!", repeated_msg, invalid_msg):
            assert expected in third
142 |
143 |
@pytest.mark.parametrize("use_binding", [True, False])
async def test_forge_and_save(tmp_path: Path, name: str, estimator: EstimatorType, use_binding: bool) -> None:
    """Forging then saving, via key bindings or via buttons, writes the template to the output file."""
    app = ForgeTUI()
    async with app.run_test(size=None) as pilot:
        name_input = pilot.app.query_one("#name", Input)
        estimator_select = pilot.app.query_one("#estimator", Select)
        await pilot.pause()

        output_input = pilot.app.query_one("#output-file", Input)

        name_input.value = name
        estimator_select.value = estimator.value
        await pilot.pause()

        output_file = tmp_path / (f"{name.lower()}.py")
        output_input.value = str(output_file)
        await output_input.action_submit()
        await pilot.pause()

        # Forge step
        if use_binding:
            await pilot.press("F")
        else:
            pilot.app.query_one("#forge-btn", Button).action_press()
        await pilot.pause()

        # Save step
        if use_binding:
            await pilot.press("ctrl+s")
        else:
            pilot.app.query_one("#save-btn", Button).action_press()
        await pilot.pause()

        forged_msg, saved_msg = (n.message for n in pilot.app._notifications)  # noqa: SLF001

        assert "Template forged!" in forged_msg
        assert "Saved at" in saved_msg

        assert output_file.exists()
185 |
--------------------------------------------------------------------------------