├── utils
│   ├── __init__.py
│   ├── charts.py
│   └── dataframes.py
├── requirements_dev.txt
├── requirements.txt
├── .pre-commit-config.yaml
├── pyproject.toml
├── Makefile
├── LICENSE
├── README.md
├── .gitignore
└── notebooks
    └── classification
        └── no_skill.ipynb

/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | ipykernel==6.29.4
2 | mypy==1.9.0
3 | pre-commit==3.7.0
4 | ruff==0.4.1
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | altair==5.3.0
2 | crepes==0.6.2
3 | numpy==1.26.4
4 | pandas==2.2.2
5 | scikit-learn==1.4.2
6 | scipy==1.13.0
7 | vegafusion[embed]==1.6.7
--------------------------------------------------------------------------------
/utils/charts.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | 
3 | import altair as alt
4 | from IPython.display import Image
5 | 
6 | 
7 | def display_static_altair_images(chart: alt.Chart) -> Image:
8 |     with BytesIO() as buffer:
9 |         chart.save(buffer, format="png")
10 |         return Image(buffer.getvalue())
11 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 |   - repo: https://github.com/pre-commit/pre-commit-hooks
5 |     rev: v3.2.0
6 |     hooks:
7 |       - id: check-yaml
8 |       - id: check-json
9 |       - id: check-docstring-first
10 |   - repo: https://github.com/astral-sh/ruff-pre-commit
11 |     rev: v0.2.1
12 |     hooks:
13 |       - id: ruff
14 |       - id: ruff-format
15 |   - repo: https://github.com/pre-commit/mirrors-mypy
16 |     rev: v1.9.0
17 |     hooks:
18 |       - id: mypy
19 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 | 
5 | [tool.hatch.build.targets.wheel]
6 | packages = ["utils"]
7 | 
8 | [project]
9 | name = "utils"
10 | version = "0.0.1"
11 | authors = [
12 |   { name="Breytner Nascimento", email="breytner.nascimento@gmail.com" },
13 | ]
14 | description = "A small example package"
15 | readme = "README.md"
16 | requires-python = ">=3.12"
17 | classifiers = [
18 |     "Programming Language :: Python :: 3",
19 |     "License :: OSI Approved :: MIT License",
20 |     "Operating System :: OS Independent",
21 | ]
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | PYTHON_VERSION = 3.12
2 | 
3 | .PHONY: help python venv
4 | help:
5 | 	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
6 | 
7 | python: ## Install and configure Python
8 | 	@sudo apt-get update
9 | 	@sudo apt update && sudo apt upgrade -y
10 | 	@sudo apt install software-properties-common -y
11 | 	@sudo add-apt-repository ppa:deadsnakes/ppa
12 | 	@sudo apt-get install python-is-python3
13 | 	@sudo apt install python$(PYTHON_VERSION)
14 | 	@sudo apt install python$(PYTHON_VERSION)-venv
15 | 	@python$(PYTHON_VERSION) -m ensurepip
16 | 	@sudo update-alternatives --install /usr/bin/python python /usr/bin/python$(PYTHON_VERSION) $(subst .,,$(PYTHON_VERSION))
17 | 
18 | venv: ## Create virtual env
19 | 	@rm -rf .venv/ && python -m venv .venv && . .venv/bin/activate; \
20 | 	pip install --upgrade pip; \
21 | 	pip install uv; \
22 | 	uv pip install \
23 | 	-r requirements.txt \
24 | 	-r requirements_dev.txt; \
25 | 	uv pip install -e .; \
26 | 	pre-commit install;
27 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Breytner Nascimento
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Data Science done right
2 | 
3 | Data science tips and tricks to enhance data analysis and predictive modeling.
4 | 
5 | Most of the tips shown here aren't directly useful for the day-to-day job, but they'll demystify many concepts around machine learning and data science. This knowledge will hopefully build up to the point where it helps you achieve greater results, as the black box around the many `pip install magic-buttons` used daily is opened.
6 | 
7 | ## Table of Contents
8 | * [Tutorials](#tutorials)
9 | * [Using this repo](#using-this-repo)
10 | * [License](#license)
11 | 
12 | ### About me
13 | #### Links
14 | * [E-mail](mailto:breytner.nascimento@gmail.com)
15 | * [GitHub](https://github.com/BreytMN)
16 | * [LinkedIn](https://www.linkedin.com/in/breytner-nascimento/)
17 | * [Portfolio](https://portfolio.breytmn.com)
18 | 
19 | ## Tutorials:
20 | 
21 | ### Classification
22 | * [Upskilling a no-skill classifier with Conformal Prediction](notebooks/classification/no_skill.ipynb) - updated 2024-05-09
23 | * ["You can't classify with linear regression"](notebooks/classification/linear_regression_classifier.ipynb) - added 2024-05-09
24 | 
25 | ### Regression
26 | * [Linear Extrapolation with non-linear algorithms](notebooks/regression/extrapolation.ipynb) - updated 2024-05-09
27 | 
28 | ## Using this repo
29 | Clone the repository:
30 | ```bash
31 | git clone git@github.com:BreytMN/datascience-done-right.git
32 | ```
33 | 
34 | I suggest using VS Code as the text editor and IDE for navigating this repository.
Open the folder inside VS Code after changing into the directory with:
35 | ```bash
36 | cd datascience-done-right
37 | code .
38 | ```
39 | 
40 | This repository is written using WSL with an Ubuntu distro running Python 3.12. If you do not have Python 3.12 installed, you can run this in the terminal (Ubuntu):
41 | ```bash
42 | make python
43 | ```
44 | This will install Python 3.12 and all dependencies needed to run it; at the end it will prompt the user to set the default version for Python.
45 | 
46 | After that you can run:
47 | ```bash
48 | make venv
49 | source .venv/bin/activate
50 | ```
51 | The first command will create a virtual environment and install all libraries needed to run any code inside this repository. The second command will activate the environment.
52 | 
53 | ## License
54 | [MIT License](LICENSE)
--------------------------------------------------------------------------------
/utils/dataframes.py:
--------------------------------------------------------------------------------
1 | from typing import Hashable, Optional, Sequence
2 | 
3 | import pandas as pd
4 | from sklearn.datasets import make_blobs
5 | 
6 | 
7 | def make_classification_df(
8 |     n_samples: Sequence[int],
9 |     features: list[str],
10 |     centers: Sequence[Sequence[int]],
11 |     cluster_std: Sequence[float],
12 |     target: str = "y",
13 |     custom_targets: Optional[Sequence[Hashable]] = None,
14 |     random_state: int = 0,
15 | ) -> pd.DataFrame:
16 |     """Make simple classification dataframe for testing purposes
17 | 
18 |     Args:
19 |         n_samples (Sequence[int]): A Sequence containing the number of points
20 |             to be generated for each target.
21 | 
22 |         features (list[str]): List of names to be used as features. The length
23 |             of this list will determine the number of features to be generated.
24 | 
25 |         centers (Sequence[Sequence[int]]): Sequences containing the centers of
26 |             each cluster. The length of the centers Sequence must match the length
27 |             of n_samples, while the length of its inner Sequences must match the
28 |             length of `features`.
29 | 
30 |         cluster_std (Sequence[float]): Standard deviation for cluster generation.
31 |             The length must match the length of `n_samples`.
32 | 
33 |         target (str, optional): Name of the target column. Defaults to "y".
34 | 
35 |         custom_targets (Sequence[Hashable], optional): List of custom labels.
36 |             Defaults to None.
37 | 
38 |         random_state (int, optional): Integer to be used in the generator for
39 |             reproducible experiments. Defaults to 0.
40 | 
41 |     Returns:
42 |         pd.DataFrame: The generated dataframe
43 |     """
44 | 
45 |     if len(n_samples) != len(centers) or len(n_samples) != len(cluster_std):
46 |         condition1 = "Length of n_samples must match length of centers"
47 |         condition2 = "Length of n_samples must match length of cluster_std"
48 |         msg = f"{condition1}. {condition2}."
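        # Guard clause: when n_samples is a sequence, make_blobs pairs each of its
        # entries with one center and one cluster_std, so all three sequences must
        # have the same length; checking here gives a clearer error than failing
        # inside scikit-learn.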
49 | raise ValueError(msg) 50 | 51 | X_, y_ = make_blobs( 52 | n_samples=n_samples, 53 | n_features=len(features), 54 | centers=centers, 55 | cluster_std=cluster_std, 56 | random_state=random_state, 57 | ) 58 | 59 | df = pd.concat( 60 | ( 61 | pd.DataFrame(X_, columns=features), 62 | pd.Series(y_, name=target), 63 | ), 64 | axis=1, 65 | ) 66 | 67 | if custom_targets is not None: 68 | df = df.assign( 69 | **{target: df[target].replace(range(len(custom_targets)), custom_targets)} 70 | ) 71 | 72 | return df 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | 
--------------------------------------------------------------------------------
/notebooks/classification/no_skill.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Upskilling a no-skill classifier with Conformal Prediction\n",
8 |     "\n",
9 |     "Whenever a scientist needs to build a model, they need to evaluate the results against certain metrics given some context.\n",
10 |     "\n",
11 |     "This has led to a culture that focuses too much on optimizing metrics instead of measuring how the model would impact the business.\n",
12 |     "\n",
13 |     "In this notebook I'll show a way to evaluate models through a different lens, using conformal prediction, which can help give you a clear picture of what your model is predicting."
14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import altair as alt\n", 23 | "import numpy as np\n", 24 | "import pandas as pd\n", 25 | "from crepes import WrapClassifier\n", 26 | "from sklearn.base import ClassifierMixin\n", 27 | "from sklearn.ensemble import RandomForestClassifier\n", 28 | "from sklearn.linear_model import LogisticRegression\n", 29 | "from sklearn.metrics import classification_report\n", 30 | "from sklearn.model_selection import train_test_split\n", 31 | "\n", 32 | "from utils.charts import display_static_altair_images\n", 33 | "from utils.dataframes import make_classification_df\n", 34 | "\n", 35 | "\n", 36 | "def calculate_coverage(\n", 37 | " X: pd.DataFrame,\n", 38 | " y: pd.Series,\n", 39 | " label: int,\n", 40 | " calibrated_conformal_classifier: WrapClassifier,\n", 41 | " alphas: list[float],\n", 42 | ") -> pd.DataFrame:\n", 43 | " X = X.reset_index(drop=True)\n", 44 | " y = y.reset_index(drop=True)\n", 45 | "\n", 46 | " sets = {\n", 47 | " alpha: calibrated_conformal_classifier.predict_set(X, confidence=1 - alpha)\n", 48 | " for alpha in alphas\n", 49 | " }\n", 50 | "\n", 51 | " count_true_label = sum(True for value in y if value == label)\n", 52 | " size = len(y)\n", 53 | "\n", 54 | " random_guessing = count_true_label / size\n", 55 | "\n", 56 | " results = []\n", 57 | "\n", 58 | " for alpha in alphas:\n", 59 | " count_coverage = 0\n", 60 | " count_sets = 0\n", 61 | "\n", 62 | " sets_aux = sets[alpha]\n", 63 | " for i, value in enumerate(y):\n", 64 | " if sets_aux[i, label]:\n", 65 | " count_sets += 1\n", 66 | " if value == label:\n", 67 | " count_coverage += 1\n", 68 | "\n", 69 | " denominator_count_sets = count_sets if count_sets > 0 else 1\n", 70 | "\n", 71 | " res = {\n", 72 | " \"alpha\": alpha,\n", 73 | " \"coverage\": count_coverage,\n", 74 | " \"% coverage (recall)\": round(count_coverage * 100 / count_true_label, 2),\n", 75 | " \"# sets containing target\": count_sets,\n", 76 | " \"% sets containing_target\": round(count_sets * 100 / size, 2),\n", 77 | " \"% sets correctly covering target (precision)\": round(\n", 78 | " count_coverage * 100 / denominator_count_sets, 2\n", 79 | " ),\n", 80 | " \"pp gain over random guessing\": round(\n", 81 | " ((count_coverage * 100) / denominator_count_sets)\n", 82 | " - (random_guessing * 100),\n", 83 | " 2,\n", 84 | " ),\n", 85 | " }\n", 86 | "\n", 87 | " results.append(res)\n", 88 | "\n", 89 | " return pd.DataFrame(results).set_index(\"alpha\")\n", 90 | "\n", 91 | "\n", 92 | "def train_and_calibrate(\n", 93 | " classifier: ClassifierMixin,\n", 94 | " X_train: pd.DataFrame,\n", 95 | " X_calib: pd.DataFrame,\n", 96 | " X_test: pd.DataFrame,\n", 97 | " y_train: pd.Series,\n", 98 | " y_calib: pd.Series,\n", 99 | " y_test: pd.Series,\n", 100 | ") -> WrapClassifier:\n", 101 | " classifier.fit(X_train, y_train)\n", 102 | " y_preds = classifier.predict(X_test)\n", 103 | "\n", 104 | " print(classification_report(y_test, y_preds, zero_division=0))\n", 105 | "\n", 106 | " conformal_classifier = WrapClassifier(classifier)\n", 107 | " conformal_classifier.calibrate(\n", 108 | " X_calib.reset_index(drop=True), y_calib.reset_index(drop=True), class_cond=True\n", 109 | " )\n", 110 | "\n", 111 | " return conformal_classifier\n", 112 | "\n", 113 | "\n", 114 | "start_alpha = 0.05\n", 115 | "end_alpha = 0.95\n", 116 | "num_alpha = int((end_alpha + start_alpha) / 0.05) - 1\n", 117 | "\n", 118 | "alphas = np.linspace(start_alpha, end_alpha, 
num_alpha)\n",
119 |     "alphas = [round(i, 2) for i in alphas]"
120 |    ]
121 |   },
122 |   {
123 |    "cell_type": "code",
124 |    "execution_count": null,
125 |    "metadata": {},
126 |    "outputs": [],
127 |    "source": [
128 |     "features = [\"a\", \"b\"]\n",
129 |     "target = \"y\"\n",
130 |     "\n",
131 |     "df_params = {\n",
132 |     "    \"n_samples\": [20000, 1000, 1000],\n",
133 |     "    \"features\": features,\n",
134 |     "    \"centers\": [(0, 0), (3, 0), (1, 2)],\n",
135 |     "    \"cluster_std\": [2, 0.5, 0.8],\n",
136 |     "    \"random_state\": 0,\n",
137 |     "}\n",
138 |     "\n",
139 |     "df = make_classification_df(**df_params)\n",
140 |     "\n",
141 |     "chart = (\n",
142 |     "    alt.Chart(df)\n",
143 |     "    .mark_point(size=10, opacity=0.5, filled=True)\n",
144 |     "    .encode(\n",
145 |     "        alt.X(\"a:Q\").scale(domain=[-8, 8], clamp=True),\n",
146 |     "        alt.Y(\"b:Q\").scale(domain=[-8, 8], clamp=True),\n",
147 |     "        alt.Color(\"y:N\"),\n",
148 |     "    )\n",
149 |     "    .properties(\n",
150 |     "        width=600,\n",
151 |     "        height=600,\n",
152 |     "    )\n",
153 |     ")\n",
154 |     "\n",
155 |     "display_static_altair_images(chart)"
156 |    ]
157 |   },
158 |   {
159 |    "cell_type": "markdown",
160 |    "metadata": {},
161 |    "source": [
162 |     "From the image above, one can already imagine that it will be very hard to correctly classify label 1, and almost impossible to do the same with label 2.\n",
163 |     "\n",
164 |     "Some would say \"upsampling\" while others would say \"downsampling\". Upsampling is objectively bad; I won't waste time explaining why creating artificial data is a bad idea. Downsampling, on the other hand, is OK, but it adds an extra layer of complexity that can't be avoided if you want correct results.\n",
165 |     "\n",
166 |     "For this case I prefer Conformal Prediction. It's a method that helps you say something like \"in this region of the feature space we expect this probability for each label\" with mathematical grounding.\n",
167 |     "\n",
168 |     "What does this mean in practice?\n",
169 |     " * for a targeted campaign this could mean \"anyone from this region is sure to be at least interested in this ad\" (remember that there are cases in which we can't even meet the whole demand for something, so if we can reduce our spending on marketing campaigns while still guaranteeing we sell the whole stock, it is a big win);\n",
170 |     " * on the other hand, if we are talking about fraud detection, we can use a cheaper model to flag anyone slightly suspicious and then send the results to a more sophisticated and expensive model, perhaps a paid API we contracted for our business, which means we don't waste too many resources trying to detect fraud in most transactions."
171 |    ]
172 |   },
173 |   {
174 |    "cell_type": "code",
175 |    "execution_count": null,
176 |    "metadata": {},
177 |    "outputs": [],
178 |    "source": [
179 |     "X, y = df.filter(features), df[target]\n",
180 |     "\n",
181 |     "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)\n",
182 |     "X_calib, X_test, y_calib, y_test = train_test_split(\n",
183 |     "    X_test, y_test, test_size=0.5, random_state=0\n",
184 |     ")\n",
185 |     "\n",
186 |     "train_data = (X_train, X_calib, X_test, y_train, y_calib, y_test)\n",
187 |     "\n",
188 |     "new_df = make_classification_df(**{**df_params, \"random_state\": 1})\n",
189 |     "new_X, new_y = new_df.filter(features), new_df[target]"
190 |    ]
191 |   },
192 |   {
193 |    "cell_type": "markdown",
194 |    "metadata": {},
195 |    "source": [
196 |     "## Logistic Regression Classifier\n",
197 |     "\n",
198 |     "Training a Logistic Regression Classifier yields a model that can't correctly classify anything beyond label 0. However, it isn't as useless a model as the metrics make it seem.\n",
199 |     "\n",
200 |     "Looking at the table calculated on the results of each set, we can see how the model is able to perform at each alpha. This gives us the ability to analyze where our model is performing correctly with virtually 100% certainty that the label is correct, while showing us at which threshold the model starts to \"fail\".\n",
201 |     "\n",
202 |     "That means even a classically bad classifier can be useful if we don't have anything better."
203 |    ]
204 |   },
205 |   {
206 |    "cell_type": "code",
207 |    "execution_count": null,
208 |    "metadata": {},
209 |    "outputs": [],
210 |    "source": [
211 |     "conformal_logistic_regression_classifier = train_and_calibrate(\n",
212 |     "    LogisticRegression(random_state=0), *train_data\n",
213 |     ")\n",
214 |     "for label in (0, 1, 2):\n",
215 |     "    print(f\"***** Results for label {label} *****\")\n",
216 |     "    display(\n",
217 |     "        calculate_coverage(\n",
218 |     "            X_test,\n",
219 |     "            y_test,\n",
220 |     "            label,\n",
221 |     "            conformal_logistic_regression_classifier,\n",
222 |     "            alphas,\n",
223 |     "        )\n",
224 |     "    )"
225 |    ]
226 |   },
227 |   {
228 |    "cell_type": "markdown",
229 |    "metadata": {},
230 |    "source": [
231 |     "### Predicting on new data\n",
232 |     "\n",
233 |     "It also keeps roughly the same results on data that follows the same distribution."
234 |    ]
235 |   },
236 |   {
237 |    "cell_type": "code",
238 |    "execution_count": null,
239 |    "metadata": {},
240 |    "outputs": [],
241 |    "source": [
242 |     "for label in (0, 1, 2):\n",
243 |     "    print(f\"***** Results for label {label} *****\")\n",
244 |     "    display(\n",
245 |     "        calculate_coverage(\n",
246 |     "            new_X,\n",
247 |     "            new_y,\n",
248 |     "            label,\n",
249 |     "            conformal_logistic_regression_classifier,\n",
250 |     "            alphas,\n",
251 |     "        )\n",
252 |     "    )"
253 |    ]
254 |   },
255 |   {
256 |    "cell_type": "markdown",
257 |    "metadata": {},
258 |    "source": [
259 |     "## Random Forest Classifier\n",
260 |     "\n",
261 |     "Of course the Random Forest Classifier performs way better than a Logistic Regression Classifier. However, it still performs quite poorly; the Conformal Prediction on top of it nevertheless gives way better control over the results.\n",
262 |     "\n",
263 |     "But hey, you do lose something on label 0: there is no region with virtually 100% correct labels! There are some cases in which the no-skill Logistic Regression Classifier can be more useful than the smarter Random Forest Classifier, considering the results of the Conformal Prediction on top of it."
264 |    ]
265 |   },
266 |   {
267 |    "cell_type": "code",
268 |    "execution_count": null,
269 |    "metadata": {},
270 |    "outputs": [],
271 |    "source": [
272 |     "conformal_random_forest_classifier = train_and_calibrate(\n",
273 |     "    RandomForestClassifier(random_state=0), *train_data\n",
274 |     ")\n",
275 |     "for label in (0, 1, 2):\n",
276 |     "    print(f\"***** Results for label {label} *****\")\n",
277 |     "    display(\n",
278 |     "        calculate_coverage(\n",
279 |     "            X_test,\n",
280 |     "            y_test,\n",
281 |     "            label,\n",
282 |     "            conformal_random_forest_classifier,\n",
283 |     "            alphas,\n",
284 |     "        )\n",
285 |     "    )"
286 |    ]
287 |   },
288 |   {
289 |    "cell_type": "markdown",
290 |    "metadata": {},
291 |    "source": [
292 |     "### Predicting on new data\n",
293 |     "\n",
294 |     "Same as with the logistic regression classifier, just so we can see that conformal prediction helps in many scenarios, while confirming that we lose something for label 0 even on new data."
295 |    ]
296 |   },
297 |   {
298 |    "cell_type": "code",
299 |    "execution_count": null,
300 |    "metadata": {},
301 |    "outputs": [],
302 |    "source": [
303 |     "for label in (0, 1, 2):\n",
304 |     "    print(f\"***** Results for label {label} *****\")\n",
305 |     "    display(\n",
306 |     "        calculate_coverage(\n",
307 |     "            new_X,\n",
308 |     "            new_y,\n",
309 |     "            label,\n",
310 |     "            conformal_random_forest_classifier,\n",
311 |     "            alphas,\n",
312 |     "        )\n",
313 |     "    )"
314 |    ]
315 |   }
316 |  ],
317 |  "metadata": {
318 |   "kernelspec": {
319 |    "display_name": ".venv",
320 |    "language": "python",
321 |    "name": "python3"
322 |   },
323 |   "language_info": {
324 |    "codemirror_mode": {
325 |     "name": "ipython",
326 |     "version": 3
327 |    },
328 |    "file_extension": ".py",
329 |    "mimetype": "text/x-python",
330 |    "name": "python",
331 |    "nbconvert_exporter": "python",
332 |    "pygments_lexer": "ipython3",
333 |    "version": "3.12.2"
334 |   }
335 |  },
336 |  "nbformat": 4,
337 |  "nbformat_minor": 2
338 | }
339 | 
--------------------------------------------------------------------------------