├── .gitignore
├── streamlit_images
│   ├── floating.webp
│   ├── flying.webp
│   └── sinking.webp
├── notebooks
│   ├── __pycache__
│   │   └── common.cpython-310.pyc
│   ├── 2_snowml.ipynb
│   ├── common.py
│   ├── 1_load_data.ipynb
│   ├── 4_Pipeline.ipynb
│   └── 3_Deployment.ipynb
├── environment.yml
├── run.py
├── .pre-commit-config.yaml
├── .github
│   └── workflows
│       └── notebooks.yml
├── README.md
└── streamlit
    └── app.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | titanic.csv
--------------------------------------------------------------------------------
/streamlit_images/floating.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cromano8/Snowflake_ML_Intro/HEAD/streamlit_images/floating.webp
--------------------------------------------------------------------------------
/streamlit_images/flying.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cromano8/Snowflake_ML_Intro/HEAD/streamlit_images/flying.webp
--------------------------------------------------------------------------------
/streamlit_images/sinking.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cromano8/Snowflake_ML_Intro/HEAD/streamlit_images/sinking.webp
--------------------------------------------------------------------------------
/notebooks/__pycache__/common.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cromano8/Snowflake_ML_Intro/HEAD/notebooks/__pycache__/common.cpython-310.pyc
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: Intro_SnowML
2 | channels:
3 |   - https://repo.anaconda.com/pkgs/snowflake
4 | dependencies:
5 |   - python=3.10
6 |   - ipykernel
7 |   - nbconvert
8 |   - nbformat
9 |   - python-dotenv
10 |   - seaborn
11 |   - snowflake-snowpark-python[pandas]
12 |   - scikit-learn
13 |   - xgboost
14 |   - absl-py
15 |   - anyio
16 |   - cloudpickle
17 |   - numpy
18 |   - packaging
19 |   - pandas
20 |   - pyyaml
21 |   - typing-extensions
22 |   - snowflake-ml-python
23 |   - lightgbm
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import nbformat
2 | from nbconvert.preprocessors import ExecutePreprocessor
3 | from pathlib import Path
4 | from dotenv import load_dotenv
5 | 
6 | load_dotenv(override=True)
7 | 
8 | # Execute the notebooks in numeric order (1_, 2_, ...)
9 | notebooks = sorted(
10 |     Path("notebooks").glob("*.ipynb"), key=lambda f: int(f.name.split("_")[0])
11 | )
12 | 
13 | for notebook in notebooks:
14 |     with open(notebook) as f:
15 |         nb = nbformat.read(f, as_version=4)
16 |     ExecutePreprocessor().preprocess(nb, {"metadata": {"path": "notebooks/"}})
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/pre-commit/pre-commit-hooks
3 |     rev: v4.4.0
4 |     hooks:
5 |       - id: trailing-whitespace
6 |       - id: end-of-file-fixer
7 |       - id: check-yaml
8 |       - id: check-added-large-files
9 |       - id: requirements-txt-fixer
10 |   - repo: https://github.com/codespell-project/codespell
11 |     rev: v2.2.5
12 |     hooks:
13 |       - id: codespell
14 |   - repo: https://github.com/nbQA-dev/nbQA
15 |     rev: 1.7.0
16 |     hooks:
17 |       - id: nbqa-black
18 |       - id: nbqa-pyupgrade
19 |         args: ["--py38-plus"]
20 |       - id: nbqa-isort
21 |         args: [--profile=black]
--------------------------------------------------------------------------------
/.github/workflows/notebooks.yml:
--------------------------------------------------------------------------------
1 | name: Test Dependencies
2 | 
3 | on:
4 |   push:
5 |     branches:
6 |       - main
7 |       - dev
8 |     paths:
9 |       - '**/*.ipynb'
10 |       - 'environment.yml'
11 |       - 'run.py'
12 | 
13 | jobs:
14 |   build:
15 |     runs-on: ubuntu-latest
16 | 
17 |     steps:
18 |       - uses: actions/checkout@v3
19 | 
20 |       - name: Set up Python 3.10
21 |         uses: actions/setup-python@v3
22 |         with:
23 |           python-version: "3.10"
24 | 
25 |       - name: Download and install Micromamba
26 |         run: |
27 |           "${SHELL}" <(curl -L micro.mamba.pm/install.sh)
28 |           micromamba shell init --shell bash --root-prefix=~/micromamba
29 |           export PATH="$HOME/micromamba/bin:$PATH"
30 | 
31 |       - name: Create Micromamba environment and run Jupyter notebooks
32 |         run: |
33 |           micromamba create --file environment.yml --yes
34 |           micromamba run -n Intro_SnowML python run.py
35 |         env:
36 |           SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }}
37 |           SNOWFLAKE_USER: ${{ secrets.SNOWFLAKE_USER }}
38 |           SNOWFLAKE_PASSWORD: ${{ secrets.SNOWFLAKE_PASSWORD }}
39 |           SNOWFLAKE_DATABASE: ${{ secrets.SNOWFLAKE_DATABASE }}
40 |           SNOWFLAKE_SCHEMA: ${{ secrets.SNOWFLAKE_SCHEMA }}
41 |           SNOWFLAKE_ROLE: ${{ secrets.SNOWFLAKE_ROLE }}
42 |           SNOWFLAKE_WAREHOUSE: ${{ secrets.SNOWFLAKE_WAREHOUSE }}
43 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Snowflake for Data Science
2 | 
3 | ### Getting Started
4 | 
5 | - 🎥 **Intro Video Walkthrough:** [Snowflake for ML Intro](https://github.com/cromano8/Snowflake_ML_Intro/assets/59093254/c249ce4c-2494-49ba-8142-6aec21cc4b08)
6 | - 🎥 **End-to-End ML Ops in Snowflake:** [Live: End-to-End ML Ops in Snowflake](https://www.youtube.com/watch?v=prA014tFRwY)
7 | - 🔗 **Regular 30-Day Trial:** [Sign Up](https://signup.snowflake.com/)
8 | - 🔗 **Student/Educator 120-Day Trial:** [Sign Up (Student)](https://signup.snowflake.com/?trial=student)
9 | 
10 | Although we recorded videos, we are constantly making upgrades and additions to this repo, so the videos may differ slightly from its current contents. They cover the same core material, and we will keep uploading new videos as the repo grows.
11 | 
12 | ## Configuration Setup
13 | 
14 | 1. Create a `.env` file and populate it with your account details:
15 | 
16 |    ```plaintext
17 |    SNOWFLAKE_ACCOUNT = abc123.us-east-1
18 |    SNOWFLAKE_USER = username
19 |    SNOWFLAKE_PASSWORD = yourpassword
20 |    SNOWFLAKE_ROLE = sysadmin
21 |    SNOWFLAKE_WAREHOUSE = compute_wh
22 |    SNOWFLAKE_DATABASE = snowpark
23 |    SNOWFLAKE_SCHEMA = titanic
24 |    ```
25 | 
26 | 2. Use the `environment.yml` file to set up your Python environment for the demo:
27 |    - Examples in the terminal:
28 |      - `conda env create -f environment.yml`
29 |      - `micromamba create -f environment.yml -y`
30 | 
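With the `.env` in place, a Snowpark session can be opened from those same values. The sketch below is an illustration only — it is not a file in this repo, and the notebooks use `SnowflakeLoginOptions` instead — it just shows how the `SNOWFLAKE_*` variables map onto a session config:

```python
# Minimal sketch, assuming the .env shown above sits next to this script.
import os

from dotenv import load_dotenv
from snowflake.snowpark import Session

load_dotenv(override=True)  # pull the SNOWFLAKE_* values into os.environ

session = Session.builder.configs(
    {
        "account": os.environ["SNOWFLAKE_ACCOUNT"],
        "user": os.environ["SNOWFLAKE_USER"],
        "password": os.environ["SNOWFLAKE_PASSWORD"],
        "role": os.environ["SNOWFLAKE_ROLE"],
        "warehouse": os.environ["SNOWFLAKE_WAREHOUSE"],
        "database": os.environ["SNOWFLAKE_DATABASE"],
        "schema": os.environ["SNOWFLAKE_SCHEMA"],
    }
).create()

print(session.get_current_role(), session.get_current_warehouse())
```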
31 | Why we partner with Anaconda:
32 | 
33 | (image)
34 | 
35 | A review of the benefits of distributed hyperparameter tuning:
36 | 
37 | Local run time: 8 min 27 seconds
38 | 
39 | ![Screenshot 2024-02-05 at 10 13 50 AM](https://github.com/cromano8/Snowflake_ML_Intro/assets/59093254/7721cbc8-3fff-4fb4-9767-c3aadf1ac239)
40 | 
41 | Snowflake ML run time: 1 min 17 seconds (a 6.5x speedup, leveraging a Large warehouse)
42 | 
43 | ![Screenshot 2024-02-05 at 10 16 43 AM](https://github.com/cromano8/Snowflake_ML_Intro/assets/59093254/6b30e9b0-a47e-4558-aa98-05d71cf01802)
44 | 
45 | ## Data Processing & ML Operations
46 | 
47 | ### Load & Transform Data
48 | 
49 | Execute the `load_data` notebook to accomplish the following:
50 | 
51 | - Load the Titanic dataset from Seaborn, convert the column names to uppercase, and save it as a CSV
52 | - Upload the CSV file to a Snowflake internal stage
53 | - Create a Snowpark DataFrame from the staged CSV
54 | - Write the Snowpark DataFrame to Snowflake as a table
55 | 
56 | ### Machine Learning Operations (snowml)
57 | 
58 | In the `snowml` notebook:
59 | 
60 | - Generate a Snowpark DataFrame from the Titanic table
61 | - Validate and handle null values
62 | - Remove columns with high null counts and highly correlated columns
63 | - Adjust the Fare datatype and impute categorical nulls
64 | - One-Hot Encode Categorical Values
65 | - Split the data into Test & Train sets
66 | - Train an XGBoost classifier model
67 | - Conduct predictions on the test set
68 | - Display Accuracy, Precision, and Recall metrics
69 | 
70 | ### Advanced MLOps with Live/Batch Inference & Streamlit
71 | 
72 | Following the `load_data` steps, utilize the deployment notebook to:
73 | 
74 | - Create a Snowpark DataFrame from the Titanic table
75 | - Assess and eliminate columns with high null counts and correlated columns
76 | - Adjust the Fare datatype and handle categorical nulls
77 | - One-Hot Encode Categorical Values
78 | - Split the data into Test & Train sets
79 | - Train an XGBoost classifier model, optimizing with grid search
80 | - Display model accuracy and best parameters
81 | - Register the model in the model registry
82 | - Deploy the model as a vectorized UDF (User Defined Function)
83 | - Execute batch predictions on a table (see the sketch below)
84 | - Perform real-time predictions using Streamlit for interactive inference
85 | 
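To make the batch-prediction step concrete, here is a minimal sketch of scoring a table through the model registry. It assumes an open `session`, a model already logged under the name `titanic`, and a `test_data` table with matching feature columns (all created by the notebooks); the deployment notebook is the authoritative version:

```python
# Minimal sketch of registry-based batch inference; "titanic" and
# "test_data" are assumed to already exist (see the deployment notebook).
from snowflake.ml.registry import Registry

reg = Registry(session=session)
mv = reg.get_model("titanic").default  # default version of the registered model

input_df = session.table("test_data")  # Snowpark DataFrame to score
predictions = mv.run(X=input_df, function_name="predict")

predictions.show()
```

Because `mv.run` is given a Snowpark DataFrame here, the scoring runs inside Snowflake rather than pulling the table down locally.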
--------------------------------------------------------------------------------
/streamlit/app.py:
--------------------------------------------------------------------------------
1 | # Copy and paste this app into a Streamlit in Snowflake (SiS) app
2 | # Import python packages and add snowflake-ml-python to the package list
3 | 
4 | import base64
5 | import pandas as pd
6 | import streamlit as st
7 | import sys
8 | from snowflake.ml.registry import Registry
9 | 
10 | st.set_page_config(layout="wide")
11 | 
12 | # Get the current credentials, locally or in SiS
13 | if sys._xoptions.get("snowflake_import_directory"):
14 |     from snowflake.snowpark.context import get_active_session
15 | 
16 |     session = get_active_session()
17 | else:
18 |     from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
19 |     from snowflake.snowpark import Session
20 | 
21 |     session = Session.builder.configs(SnowflakeLoginOptions()).getOrCreate()
22 | 
23 | st.title("Will you survive the Titanic?")
24 | 
25 | 
26 | @st.cache_resource
27 | def load_images():
28 |     sink_bytes_object = session.file.get_stream(
29 |         "@ML_DATA/sinking.webp.gz", decompress=True
30 |     ).read()
31 |     sink_image64 = base64.b64encode(sink_bytes_object).decode()
32 | 
33 |     float_bytes_object = session.file.get_stream(
34 |         "@ML_DATA/floating.webp.gz", decompress=True
35 |     ).read()
36 |     float_image64 = base64.b64encode(float_bytes_object).decode()
37 | 
38 |     flying_bytes_object = session.file.get_stream(
39 |         "@ML_DATA/flying.webp.gz", decompress=True
40 |     ).read()
41 |     flying_image64 = base64.b64encode(flying_bytes_object).decode()
42 | 
43 |     return dict(sink=sink_image64, float=float_image64, fly=flying_image64)
44 | 
45 | @st.cache_resource
46 | def get_model_version():
47 |     reg = Registry(session=session)
48 |     m = reg.get_model("titanic")
49 |     return m.default
50 | 
51 | 
52 | @st.cache_data
53 | def load_data(df):
54 |     return session.table(df).to_pandas()
55 | 
56 | 
57 | test_2 = load_data("test_data")
58 | 
59 | 
60 | col1, col2, col3 = st.columns(3)
61 | 
62 | with col1:
63 |     pclass = st.selectbox("What class is your ticket?", ["FIRST", "SECOND", "THIRD"])
64 |     town = st.selectbox(
65 |         "What town did you embark from?", ["SOUTHAMPTON", "CHERBOURG", "QUEENSTOWN"]
66 |     )
67 | 
68 | with col2:
69 |     fare = st.number_input(
70 |         "What was the cost of your ticket?",
71 |         min_value=0.00,
72 |         max_value=512.00,
73 |         value=50.00,
74 |     )
75 |     who = st.selectbox("Are you a Man, Woman, or Child?", ("MAN", "WOMAN", "CHILD"))
76 | with col3:
77 |     sibsp = st.number_input(
78 |         "How many siblings/spouses are traveling with you?",
79 |         min_value=0,
80 |         max_value=8,
81 |         value=0,
82 |         step=1,
83 |     )
84 |     parch = st.number_input(
85 |         "How many parents/children are traveling with you?",
86 |         min_value=0,
87 |         max_value=6,
88 |         value=0,
89 |         step=1,
90 |     )
91 | 
92 | 
93 | data = [[sibsp, parch, pclass, who, town, fare]]
94 | 
95 | columns = ["SIBSP", "PARCH", "CLASS", "WHO", "EMBARK_TOWN", "FARE"]
96 | 
97 | titanic_df = pd.DataFrame(data, columns=columns)
98 | 
99 | cat_cols = ["CLASS", "WHO", "EMBARK_TOWN"]
100 | num_cols = ["SIBSP", "PARCH", "FARE"]
101 | 
102 | 
103 | titanic_df = pd.get_dummies(data=titanic_df, columns=cat_cols)
104 | 
105 | 
106 | # Get the one-hot columns present in the test table but missing from this single input row
107 | new_columns = list(set(test_2.columns) - set(titanic_df.columns))
108 | 
109 | # Add the missing columns to the input row and fill them with zeros
110 | for col in new_columns:
111 |     titanic_df[col] = 0
112 | 
113 | CLASS_SECOND = titanic_df["CLASS_SECOND"][0]
114 | CLASS_THIRD = titanic_df["CLASS_THIRD"][0]
115 | WHO_MAN = titanic_df["WHO_MAN"][0]
116 | WHO_WOMAN = titanic_df["WHO_WOMAN"][0]
117 | EMBARK_TOWN_QUEENSTOWN = titanic_df["EMBARK_TOWN_QUEENSTOWN"][0]
118 | EMBARK_TOWN_SOUTHAMPTON = titanic_df["EMBARK_TOWN_SOUTHAMPTON"][0]
119 | SIBSP = titanic_df["SIBSP"][0]
120 | PARCH = titanic_df["PARCH"][0]
121 | FARE = titanic_df["FARE"][0]
122 | 
123 | mv = get_model_version()
124 | 
125 | predictions = mv.run(function_name="predict_proba", X=titanic_df)
126 | surv_pred = predictions['output_feature_1'][0]
127 | 
128 | surv_pred = round(float(surv_pred) * 100, 2)
129 | 
130 | st.metric(value=f"{surv_pred}%", label="Chance of surviving the Titanic")
131 | 
132 | images = load_images()
133 | 
134 | if surv_pred > 75:
135 |     st.image(f'data:image/gif;base64,{images.get("fly")}')
136 |     st.write("Have fun!")
137 | elif surv_pred >= 40:
138 |     st.image(f'data:image/gif;base64,{images.get("float")}')
139 |     st.write("You may survive, but you are going to be cold")
140 | else:
141 |     st.image(f'data:image/gif;base64,{images.get("sink")}')
142 |     st.write("I would NOT get on this boat")
143 | 
--------------------------------------------------------------------------------
/notebooks/2_snowml.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count":
null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import warnings\n", 10 | "\n", 11 | "from snowflake.ml.modeling.impute import SimpleImputer\n", 12 | "from snowflake.ml.modeling.metrics import accuracy_score, precision_score, recall_score\n", 13 | "from snowflake.ml.modeling.preprocessing import OneHotEncoder\n", 14 | "from snowflake.ml.modeling.xgboost import XGBClassifier\n", 15 | "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", 16 | "from snowflake.snowpark import Session\n", 17 | "from snowflake.snowpark import types as T\n", 18 | "from snowflake.snowpark.functions import col\n", 19 | "\n", 20 | "warnings.simplefilter(action=\"ignore\", category=UserWarning)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "session = Session.builder.configs(SnowflakeLoginOptions()).getOrCreate()" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "titanic_df = session.table(\"titanic\")" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "titanic_df.show()" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "# Columns with null values and their respective counts\n", 57 | "{\n", 58 | " k: v\n", 59 | " for k, v in {\n", 60 | " col_name: titanic_df.where(col(col_name).is_null()).count()\n", 61 | " for col_name in titanic_df.columns\n", 62 | " }.items()\n", 63 | " if v > 0\n", 64 | "}" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "titanic_df = titanic_df.drop([\"AGE\", \"DECK\", \"ALIVE\", \"ADULT_MALE\", \"EMBARKED\"])" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "titanic_df = titanic_df.withColumn(\"FARE\", titanic_df[\"FARE\"].astype(T.FloatType()))\n", 83 | "\n", 84 | "titanic_df.show()" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "cat_cols = [\"SEX\", \"CLASS\", \"WHO\", \"EMBARK_TOWN\"]\n", 94 | "num_cols = [\"PCLASS\", \"SIBSP\", \"PARCH\", \"FARE\"]" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "impute_cat = SimpleImputer(\n", 104 | " input_cols=cat_cols,\n", 105 | " output_cols=cat_cols,\n", 106 | " strategy=\"most_frequent\",\n", 107 | " drop_input_cols=True,\n", 108 | ")\n", 109 | "\n", 110 | "titanic_df = impute_cat.fit(titanic_df).transform(titanic_df)\n", 111 | "titanic_df.show()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "OHE = OneHotEncoder(\n", 121 | " input_cols=cat_cols,\n", 122 | " output_cols=cat_cols,\n", 123 | " drop_input_cols=True,\n", 124 | " drop=\"first\",\n", 125 | " handle_unknown=\"ignore\",\n", 126 | ")\n", 127 | "\n", 128 | "titanic_df = OHE.fit(titanic_df).transform(titanic_df)\n", 129 | "titanic_df.show()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "train_df, test_df = 
titanic_df.random_split(weights=[0.8, 0.2], seed=8)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "xgb = XGBClassifier(\n", 148 | " input_cols=train_df.drop(\"SURVIVED\").columns,\n", 149 | " label_cols=\"SURVIVED\",\n", 150 | " output_cols=\"PRED_SURVIVED\",\n", 151 | ")\n", 152 | "\n", 153 | "# Train\n", 154 | "xgb.fit(train_df)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "result = xgb.predict(test_df)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "accuracy = accuracy_score(\n", 173 | " df=result, y_true_col_names=\"SURVIVED\", y_pred_col_names=\"PRED_SURVIVED\"\n", 174 | ")\n", 175 | "\n", 176 | "precision = precision_score(\n", 177 | " df=result, y_true_col_names=\"SURVIVED\", y_pred_col_names=\"PRED_SURVIVED\"\n", 178 | ")\n", 179 | "\n", 180 | "recall = recall_score(\n", 181 | " df=result, y_true_col_names=\"SURVIVED\", y_pred_col_names=\"PRED_SURVIVED\"\n", 182 | ")\n", 183 | "\n", 184 | "print(f\"Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}\")" 185 | ] 186 | } 187 | ], 188 | "metadata": { 189 | "language_info": { 190 | "name": "python" 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 2 195 | } 196 | -------------------------------------------------------------------------------- /notebooks/common.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import logging 4 | import re 5 | from functools import wraps 6 | 7 | import pandas as pd 8 | from snowflake.ml.registry import Registry 9 | from snowflake.snowpark import Session 10 | from snowflake.snowpark import functions as F 11 | from snowflake.snowpark import types as T 12 | 13 | 14 | def logger(func): 15 | @wraps(func) 16 | def wrapper(*args, **kwargs): 17 | logging.info(f"Running {func.__name__} with args: {args}, kwargs: {kwargs}") 18 | try: 19 | result = func(*args, **kwargs) 20 | logging.info(f"Finished {func.__name__} with result: {result}") 21 | except Exception as e: 22 | logging.error(f"Error occurred in {func.__name__}: {e}") 23 | raise 24 | else: 25 | return result 26 | 27 | return wrapper 28 | 29 | 30 | def convert_to_all_caps(c): 31 | """ 32 | Converts a given string to all capital letters and separates words with underscores. 33 | 34 | Args: 35 | c (str): The input string to be converted. 36 | 37 | Returns: 38 | str: The converted string with all capital letters and underscores separating words. 39 | """ 40 | return re.sub( 41 | r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])", 42 | "_", 43 | c, 44 | ).upper() 45 | 46 | 47 | def rename_columns_all_caps(df): 48 | """ 49 | Renames all columns in the DataFrame to uppercase. 50 | 51 | Args: 52 | df (snowpark.DataFrame): The input DataFrame. 53 | 54 | Returns: 55 | snowpark.DataFrame: The DataFrame with all column names converted to uppercase. 56 | """ 57 | return df.to_df([convert_to_all_caps(c) for c in df.columns]) 58 | 59 | 60 | def read_url_csv(session: Session, url: str): 61 | df_pandas = pd.read_csv(url) 62 | df = rename_columns_all_caps(session.create_dataframe(df_pandas)) 63 | return df 64 | 65 | 66 | def get_col_types(df, type): 67 | """ 68 | Returns a list of column names in a DataFrame that match the specified data type. 
69 | 70 | Args: 71 | df: The DataFrame to search for column types. 72 | type (str): The data type to filter columns by. Valid values are "string" and "numeric". 73 | 74 | Returns: 75 | list: A list of column names that match the specified data type. 76 | 77 | Raises: 78 | ValueError: If the specified type is not "string" or "numeric". 79 | """ 80 | if type == "string": 81 | return [c.name for c in df.schema if isinstance(c.datatype, (T.StringType))] 82 | elif type == "numeric": 83 | return [ 84 | c.name 85 | for c in df.schema 86 | if isinstance( 87 | c.datatype, (T.DoubleType, T.IntegerType, T.LongType, T.FloatType) 88 | ) 89 | ] 90 | else: 91 | raise ValueError(f"Invalid type: {type}") 92 | 93 | 94 | def get_next_version(reg, model_name) -> str: 95 | """ 96 | Returns the next version of a model based on the existing versions in the registry. 97 | 98 | Args: 99 | reg: The registry object that provides access to the models. 100 | model_name: The name of the model. 101 | 102 | Returns: 103 | str: The next version of the model in the format "V_". 104 | 105 | Raises: 106 | ValueError: If the version list for the model is empty or if the version format is invalid. 107 | """ 108 | models = reg.show_models() 109 | if models.empty: 110 | return "V_1" 111 | elif model_name not in models["name"].to_list(): 112 | return "V_1" 113 | max_version_number = max( 114 | [ 115 | int(version.split("_")[-1]) 116 | for version in ast.literal_eval( 117 | models.loc[models["name"] == model_name, "versions"].values[0] 118 | ) 119 | ] 120 | ) 121 | return f"V_{max_version_number + 1}" 122 | 123 | 124 | def count_all_nulls(df) -> dict: 125 | """ 126 | Counts the number of null values in each column of a DataFrame and returns a dictionary 127 | with column names as keys and the corresponding count of null values as values. 128 | 129 | Args: 130 | df: The DataFrame to count null values in. 131 | 132 | Returns: 133 | dict: A dictionary with column names as keys and the count of null values as values. 134 | """ 135 | return { 136 | k: v 137 | for k, v in { 138 | c: df.where(F.col(c).is_null()).count() for c in df.columns 139 | }.items() 140 | if v > 0 141 | } 142 | 143 | 144 | def get_version_with_highest_accuracy(reg: Registry, model_name: str): 145 | """ 146 | Returns the version name of the model with the highest accuracy. 147 | 148 | Parameters: 149 | reg (Registry): The Registry object. 150 | model_name (str): The name of the model. 151 | 152 | Returns: 153 | str: The version name of the model with the highest accuracy. 
154 | """ 155 | model_versions = reg.get_model(model_name).show_versions() 156 | model_versions["accuracy"] = model_versions["metadata"].apply( 157 | lambda x: json.loads(x).get("metrics", {}).get("Accuracy", None) 158 | ) 159 | return ( 160 | model_versions.sort_values(by="accuracy", ascending=False) 161 | .head(1)["name"] 162 | .values[0] 163 | ) 164 | 165 | 166 | def get_infer_schema(reg: Registry) -> T.StructType: 167 | m = reg.get_model("TITANIC") 168 | mv = m.default 169 | 170 | input_schema = T.StructType( 171 | [ 172 | T.StructField(input.name, input.as_snowpark_type()) 173 | for input in [ 174 | fx.get("signature").inputs 175 | for fx in mv.show_functions() 176 | if fx.get("name") == "PREDICT" 177 | ][0] 178 | ] 179 | ) 180 | return input_schema -------------------------------------------------------------------------------- /notebooks/1_load_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import snowflake.snowpark.functions as F\n", 11 | "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", 12 | "from snowflake.snowpark import Session" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "name": "stderr", 22 | "output_type": "stream", 23 | "text": [ 24 | "SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. \n" 25 | ] 26 | } 27 | ], 28 | "source": [ 29 | "session = Session.builder.configs(SnowflakeLoginOptions()).getOrCreate()" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 3, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "data": { 39 | "text/plain": [ 40 | "[Row(status='TITANIC already exists, statement succeeded.')]" 41 | ] 42 | }, 43 | "execution_count": 3, 44 | "metadata": {}, 45 | "output_type": "execute_result" 46 | } 47 | ], 48 | "source": [ 49 | "session.sql(\"CREATE DATABASE IF NOT EXISTS SNOWPARK;\").collect()\n", 50 | "session.sql(\"CREATE SCHEMA IF NOT EXISTS SNOWPARK.TITANIC;\").collect()" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "data": { 60 | "text/plain": [ 61 | "[Row(status='ML_DATA already exists, statement succeeded.')]" 62 | ] 63 | }, 64 | "execution_count": 4, 65 | "metadata": {}, 66 | "output_type": "execute_result" 67 | } 68 | ], 69 | "source": [ 70 | "session.sql(\"CREATE STAGE if not exists ML_DATA\").collect()" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 5, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "titanic = pd.read_csv(\n", 80 | " \"https://raw.githubusercontent.com/mwaskom/seaborn-data/master/titanic.csv\"\n", 81 | ")\n", 82 | "titanic.columns = [c.upper() for c in titanic.columns]\n", 83 | "titanic.to_csv(\"titanic.csv\", index=False)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 6, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/plain": [ 94 | "[PutResult(source='titanic.csv', target='titanic.csv.gz', source_size=57018, target_size=6528, source_compression='NONE', target_compression='GZIP', status='UPLOADED', message='')]" 95 | ] 96 | }, 97 | "execution_count": 6, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "session.file.put(\"titanic.csv\", 
\"@ml_data\", overwrite=True)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 7, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "name": "stdout", 113 | "output_type": "stream", 114 | "text": [ 115 | "-------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 116 | "|\"SURVIVED\" |\"PCLASS\" |\"SEX\" |\"AGE\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\"EMBARKED\" |\"CLASS\" |\"WHO\" |\"ADULT_MALE\" |\"DECK\" |\"EMBARK_TOWN\" |\"ALIVE\" |\"ALONE\" |\n", 117 | "-------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 118 | "|0 |3 |male |22.00 |1 |0 |7.2500 |S |Third |man |True |NULL |Southampton |False |False |\n", 119 | "|1 |1 |female |38.00 |1 |0 |71.2833 |C |First |woman |False |C |Cherbourg |True |False |\n", 120 | "|1 |3 |female |26.00 |0 |0 |7.9250 |S |Third |woman |False |NULL |Southampton |True |True |\n", 121 | "|1 |1 |female |35.00 |1 |0 |53.1000 |S |First |woman |False |C |Southampton |True |False |\n", 122 | "|0 |3 |male |35.00 |0 |0 |8.0500 |S |Third |man |True |NULL |Southampton |False |True |\n", 123 | "|0 |3 |male |NULL |0 |0 |8.4583 |Q |Third |man |True |NULL |Queenstown |False |True |\n", 124 | "|0 |1 |male |54.00 |0 |0 |51.8625 |S |First |man |True |E |Southampton |False |True |\n", 125 | "|0 |3 |male |2.00 |3 |1 |21.0750 |S |Third |child |False |NULL |Southampton |False |False |\n", 126 | "|1 |3 |female |27.00 |0 |2 |11.1333 |S |Third |woman |False |NULL |Southampton |True |False |\n", 127 | "|1 |2 |female |14.00 |1 |0 |30.0708 |C |Second |child |False |NULL |Cherbourg |True |False |\n", 128 | "-------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 129 | "\n" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "# Create a Snowpark DataFrame that is configured to load data from the CSV file\n", 135 | "titanic_df = (\n", 136 | " session.read.option(\"infer_schema\", True)\n", 137 | " .option(\"PARSE_HEADER\", True)\n", 138 | " .csv(\"@ml_data/titanic.csv\")\n", 139 | ")\n", 140 | "titanic_df.show()" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 8, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "-------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 153 | "|\"SURVIVED\" |\"PCLASS\" |\"AGE\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\"ADULT_MALE\" |\"DECK\" |\"ALIVE\" |\"ALONE\" |\"SEX\" |\"EMBARKED\" |\"CLASS\" |\"WHO\" |\"EMBARK_TOWN\" |\n", 154 | "-------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 155 | "|0 |3 |22.00 |1 |0 |7.2500 |True |NULL |False |False |MALE |S |THIRD |MAN |SOUTHAMPTON |\n", 156 | "|1 |1 |38.00 |1 |0 |71.2833 |False |C |True |False |FEMALE |C |FIRST |WOMAN |CHERBOURG |\n", 157 | "|1 |3 |26.00 |0 |0 |7.9250 |False |NULL |True |True |FEMALE |S |THIRD |WOMAN |SOUTHAMPTON |\n", 158 | "|1 |1 |35.00 |1 |0 |53.1000 |False |C |True |False |FEMALE |S |FIRST |WOMAN |SOUTHAMPTON |\n", 159 | "|0 |3 |35.00 |0 |0 |8.0500 |True |NULL |False |True |MALE |S |THIRD |MAN |SOUTHAMPTON |\n", 160 | 
"|0 |3 |NULL |0 |0 |8.4583 |True |NULL |False |True |MALE |Q |THIRD |MAN |QUEENSTOWN |\n", 161 | "|0 |1 |54.00 |0 |0 |51.8625 |True |E |False |True |MALE |S |FIRST |MAN |SOUTHAMPTON |\n", 162 | "|0 |3 |2.00 |3 |1 |21.0750 |False |NULL |False |False |MALE |S |THIRD |CHILD |SOUTHAMPTON |\n", 163 | "|1 |3 |27.00 |0 |2 |11.1333 |False |NULL |True |False |FEMALE |S |THIRD |WOMAN |SOUTHAMPTON |\n", 164 | "|1 |2 |14.00 |1 |0 |30.0708 |False |NULL |True |False |FEMALE |C |SECOND |CHILD |CHERBOURG |\n", 165 | "-------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 166 | "\n" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "def fix_values(column):\n", 172 | " return F.upper(F.regexp_replace(F.col(column), \"[^a-zA-Z0-9]+\", \"_\"))\n", 173 | "\n", 174 | "\n", 175 | "for col in [\"SEX\", \"EMBARKED\", \"CLASS\", \"WHO\", \"EMBARK_TOWN\"]:\n", 176 | " titanic_df = titanic_df.with_column(col, fix_values(col))\n", 177 | "\n", 178 | "titanic_df.show()" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 9, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "titanic_df.write.mode(\"overwrite\").save_as_table(\"titanic\")" 188 | ] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 3 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | "name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython3", 207 | "version": "3.10.11" 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 2 212 | } 213 | -------------------------------------------------------------------------------- /notebooks/4_Pipeline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import ast\n", 10 | "import json\n", 11 | "import warnings\n", 12 | "\n", 13 | "import pandas as pd\n", 14 | "from snowflake.ml.modeling.impute import SimpleImputer\n", 15 | "from snowflake.ml.modeling.metrics import (\n", 16 | " accuracy_score,\n", 17 | " confusion_matrix,\n", 18 | " f1_score,\n", 19 | " precision_score,\n", 20 | " recall_score,\n", 21 | ")\n", 22 | "from snowflake.ml.modeling.metrics import accuracy_score\n", 23 | "from snowflake.ml.modeling.model_selection import GridSearchCV\n", 24 | "from snowflake.ml.modeling.preprocessing import OneHotEncoder\n", 25 | "from snowflake.ml.modeling.xgboost import XGBClassifier\n", 26 | "from snowflake.ml.registry import Registry\n", 27 | "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", 28 | "from snowflake.snowpark import Session\n", 29 | "from snowflake.snowpark import functions as F\n", 30 | "from snowflake.snowpark import types as T\n", 31 | "from snowflake.snowpark.functions import col\n", 32 | "from snowflake.ml.modeling.pipeline import Pipeline\n", 33 | "\n", 34 | "from common import (\n", 35 | " get_col_types,\n", 36 | " get_next_version,\n", 37 | " get_version_with_highest_accuracy,\n", 38 | ")" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stderr", 48 | "output_type": "stream", 49 | "text": [ 50 | 
"SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. \n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "session = Session.builder.configs(SnowflakeLoginOptions()).getOrCreate()" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 6, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "/opt/anaconda3/envs/Intro_SnowML/lib/python3.10/site-packages/snowflake/ml/_internal/utils/snowpark_dataframe_utils.py:88: UserWarning: Warning: The Decimal(7, 4) data type is being automatically converted to DoubleType in the Snowpark DataFrame. This automatic conversion may lead to potential precision loss and rounding errors. If you wish to prevent this conversion, you should manually perform the necessary data type conversion.\n", 68 | " warnings.warn(\n", 69 | "/opt/anaconda3/envs/Intro_SnowML/lib/python3.10/site-packages/snowflake/snowpark/session.py:2182: UserWarning: Pandas Dataframe has non-standard index of type which will not be written. Consider changing the index to pd.RangeIndex(start=0,...,step=1) or call reset_index() to keep index as column(s)\n", 70 | " success, nchunks, nrows, ci_output = write_pandas(\n", 71 | "/opt/anaconda3/envs/Intro_SnowML/lib/python3.10/site-packages/snowflake/snowpark/session.py:2182: UserWarning: Pandas Dataframe has non-standard index of type which will not be written. Consider changing the index to pd.RangeIndex(start=0,...,step=1) or call reset_index() to keep index as column(s)\n", 72 | " success, nchunks, nrows, ci_output = write_pandas(\n", 73 | "/opt/anaconda3/envs/Intro_SnowML/lib/python3.10/site-packages/snowflake/snowpark/session.py:2182: UserWarning: Pandas Dataframe has non-standard index of type which will not be written. Consider changing the index to pd.RangeIndex(start=0,...,step=1) or call reset_index() to keep index as column(s)\n", 74 | " success, nchunks, nrows, ci_output = write_pandas(\n", 75 | "/opt/anaconda3/envs/Intro_SnowML/lib/python3.10/site-packages/snowflake/ml/_internal/utils/snowpark_dataframe_utils.py:88: UserWarning: Warning: The Decimal(7, 4) data type is being automatically converted to DoubleType in the Snowpark DataFrame. This automatic conversion may lead to potential precision loss and rounding errors. 
If you wish to prevent this conversion, you should manually perform the necessary data type conversion.\n", 76 | " warnings.warn(\n" 77 | ] 78 | }, 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 84 | "|\"SEX_MALE\" |\"CLASS_SECOND\" |\"CLASS_THIRD\" |\"WHO_MAN\" |\"WHO_WOMAN\" |\"EMBARK_TOWN_QUEENSTOWN\" |\"EMBARK_TOWN_SOUTHAMPTON\" |\"SURVIVED\" |\"PCLASS\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\"ALONE\" |\"PREDICT_PROBA_0\" |\"PREDICT_PROBA_1\" |\"OUTPUT_SURVIVED\" |\n", 85 | "-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 86 | "|1.0 |0.0 |1.0 |1.0 |0.0 |0.0 |1.0 |0 |3 |1 |0 |7.25 |False |0.9325933456420898 |0.06740663200616837 |0 |\n", 87 | "|0.0 |0.0 |0.0 |0.0 |1.0 |0.0 |0.0 |1 |1 |1 |0 |71.2833 |False |0.009830057621002196 |0.9901699423789978 |1 |\n", 88 | "|0.0 |0.0 |0.0 |0.0 |1.0 |0.0 |1.0 |1 |1 |1 |0 |53.1 |False |0.015713274478912354 |0.9842867255210876 |1 |\n", 89 | "|1.0 |0.0 |1.0 |1.0 |0.0 |0.0 |1.0 |0 |3 |0 |0 |8.05 |True |0.8683130145072937 |0.1316869705915451 |0 |\n", 90 | "|1.0 |0.0 |1.0 |0.0 |0.0 |0.0 |1.0 |0 |3 |3 |1 |21.075 |False |0.8506706953048706 |0.1493292897939682 |0 |\n", 91 | "|0.0 |0.0 |1.0 |0.0 |1.0 |0.0 |1.0 |1 |3 |0 |2 |11.1333 |False |0.2507440447807312 |0.7492559552192688 |1 |\n", 92 | "|0.0 |1.0 |0.0 |0.0 |0.0 |0.0 |0.0 |1 |2 |1 |0 |30.0708 |False |0.011385798454284668 |0.9886142015457152 |1 |\n", 93 | "|0.0 |0.0 |1.0 |0.0 |0.0 |0.0 |1.0 |1 |3 |1 |1 |16.7 |False |0.11237221956253052 |0.8876277804374695 |1 |\n", 94 | "|0.0 |0.0 |0.0 |0.0 |1.0 |0.0 |1.0 |1 |1 |0 |0 |26.55 |True |0.0649450421333313 |0.9350549578666688 |1 |\n", 95 | "|1.0 |0.0 |1.0 |1.0 |0.0 |0.0 |1.0 |0 |3 |1 |5 |31.275 |False |0.8820328712463379 |0.1179671511054039 |0 |\n", 96 | "-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 97 | "\n" 98 | ] 99 | } 100 | ], 101 | "source": [ 102 | "df = session.table(\"TITANIC\")\n", 103 | "df = df.drop([\"PASSENGER_ID\", \"AGE\", \"DECK\", \"ALIVE\", \"ADULT_MALE\", \"EMBARKED\"])\n", 104 | "train_df, test_df = df.random_split([0.8, 0.2], seed=42)\n", 105 | "train_df.write.save_as_table(\"TRAIN\", mode=\"overwrite\")\n", 106 | "test_df.write.save_as_table(\"TEST\", mode=\"overwrite\")\n", 107 | "train_df, test_df = session.table(\"train\"), session.table(\"test\")\n", 108 | "parameters = {\n", 109 | " \"n_estimators\": [100, 200, 300, 400, 500],\n", 110 | " \"learning_rate\": [0.1, 0.2, 0.3, 0.4, 0.5],\n", 111 | " \"max_depth\": list(range(3, 5, 1)),\n", 112 | " \"min_child_weight\": list(range(3, 5, 1)),\n", 113 | "}\n", 114 | "cat_cols = get_col_types(df, \"string\")\n", 115 | "pipeline = Pipeline(\n", 116 | " steps=[\n", 117 | " (\n", 118 | " \"SimpleImputer\",\n", 119 | " SimpleImputer(\n", 120 | " input_cols=cat_cols,\n", 121 | " output_cols=cat_cols,\n", 122 | " strategy=\"most_frequent\",\n", 123 | " drop_input_cols=True,\n", 124 | " ),\n", 
125 | " ),\n", 126 | " (\n", 127 | " \"OneHotEncoder\",\n", 128 | " OneHotEncoder(\n", 129 | " input_cols=cat_cols,\n", 130 | " output_cols=cat_cols,\n", 131 | " drop_input_cols=True,\n", 132 | " drop=\"first\",\n", 133 | " handle_unknown=\"ignore\",\n", 134 | " ),\n", 135 | " ),\n", 136 | " (\n", 137 | " \"GridSearchCV\",\n", 138 | " GridSearchCV(\n", 139 | " estimator=XGBClassifier(random_state=42),\n", 140 | " param_grid=parameters,\n", 141 | " n_jobs=-1,\n", 142 | " scoring=\"accuracy\",\n", 143 | " label_cols=\"SURVIVED\",\n", 144 | " ),\n", 145 | " ),\n", 146 | " ],\n", 147 | ")\n", 148 | "pipeline.fit(train_df)\n", 149 | "result_df = pipeline.predict_proba(train_df)\n", 150 | "result_df = result_df.with_column(\n", 151 | " \"OUTPUT_SURVIVED\", F.col(\"predict_proba_1\").cast(T.LongType())\n", 152 | ")\n", 153 | "result_df.show()" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 8, 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stderr", 163 | "output_type": "stream", 164 | "text": [ 165 | "DataFrame.flatten() is deprecated since 0.7.0. Use `DataFrame.join_table_function()` instead.\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "metrics = {\n", 171 | " \"Accuracy\": accuracy_score(\n", 172 | " df=result_df,\n", 173 | " y_true_col_names=\"SURVIVED\",\n", 174 | " y_pred_col_names=\"OUTPUT_SURVIVED\",\n", 175 | " ),\n", 176 | " \"Precision\": precision_score(\n", 177 | " df=result_df,\n", 178 | " y_true_col_names=\"SURVIVED\",\n", 179 | " y_pred_col_names=\"OUTPUT_SURVIVED\",\n", 180 | " ),\n", 181 | " \"Recall\": recall_score(\n", 182 | " df=result_df,\n", 183 | " y_true_col_names=\"SURVIVED\",\n", 184 | " y_pred_col_names=\"OUTPUT_SURVIVED\",\n", 185 | " ),\n", 186 | " \"F1 Score\": f1_score(\n", 187 | " df=result_df,\n", 188 | " y_true_col_names=\"SURVIVED\",\n", 189 | " y_pred_col_names=\"OUTPUT_SURVIVED\",\n", 190 | " ),\n", 191 | " \"Confusion Matrix\": confusion_matrix(\n", 192 | " df=result_df, y_true_col_name=\"SURVIVED\", y_pred_col_name=\"OUTPUT_SURVIVED\"\n", 193 | " ).tolist(),\n", 194 | " }" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 13, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "{'Accuracy': 0.887931,\n", 206 | " 'Precision': 0.8961038961038961,\n", 207 | " 'Recall': 0.7931034482758621,\n", 208 | " 'F1 Score': 0.8414634146341463,\n", 209 | " 'Confusion Matrix': [[411.0, 24.0], [54.0, 207.0]]}" 210 | ] 211 | }, 212 | "execution_count": 13, 213 | "metadata": {}, 214 | "output_type": "execute_result" 215 | } 216 | ], 217 | "source": [ 218 | "metrics" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 14, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stderr", 228 | "output_type": "stream", 229 | "text": [ 230 | "/opt/anaconda3/envs/Intro_SnowML/lib/python3.10/contextlib.py:135: UserWarning: `relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed from ==x.y.z to >=x.y, <(x+1). 
To use specific dependency versions for compatibility, reproducibility, etc., set `options={'relax_version': False}` when logging the model.\n", 231 | " return next(self.gen)\n" 232 | ] 233 | } 234 | ], 235 | "source": [ 236 | "X = train_df.drop(\"SURVIVED\").limit(100)\n", 237 | "\n", 238 | "reg = Registry(session=session)\n", 239 | "\n", 240 | "titanic_model = reg.log_model(\n", 241 | " model_name=\"TITANIC_PIPE\",\n", 242 | " version_name=get_next_version(reg, \"TITANIC_PIPE\"),\n", 243 | " model=pipeline,\n", 244 | " metrics=metrics,\n", 245 | ")" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 15, 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [ 254 | "m = reg.get_model(\"TITANIC_PIPE\")\n", 255 | "m.default_version = get_version_with_highest_accuracy(reg, \"TITANIC_PIPE\")" 256 | ] 257 | }, 258 | { 259 | "cell_type": "markdown", 260 | "metadata": {}, 261 | "source": [ 262 | "## Call pipeline from SQL" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "Show the data is not cleaned before performming inference" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 17, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "name": "stdout", 279 | "output_type": "stream", 280 | "text": [ 281 | "-------------------------------------------------------------------------------------------------------------\n", 282 | "|\"SURVIVED\" |\"PCLASS\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\"ALONE\" |\"SEX\" |\"CLASS\" |\"WHO\" |\"EMBARK_TOWN\" |\n", 283 | "-------------------------------------------------------------------------------------------------------------\n", 284 | "|1 |3 |0 |0 |7.9250 |True |FEMALE |THIRD |WOMAN |SOUTHAMPTON |\n", 285 | "|0 |3 |0 |0 |8.4583 |True |MALE |THIRD |MAN |QUEENSTOWN |\n", 286 | "|0 |1 |0 |0 |51.8625 |True |MALE |FIRST |MAN |SOUTHAMPTON |\n", 287 | "|0 |3 |0 |0 |8.0500 |True |MALE |THIRD |MAN |SOUTHAMPTON |\n", 288 | "|1 |3 |0 |0 |8.0292 |True |FEMALE |THIRD |CHILD |QUEENSTOWN |\n", 289 | "|0 |3 |0 |0 |7.8958 |True |MALE |THIRD |MAN |SOUTHAMPTON |\n", 290 | "|1 |1 |1 |0 |146.5208 |False |FEMALE |FIRST |WOMAN |CHERBOURG |\n", 291 | "|1 |2 |1 |2 |41.5792 |False |FEMALE |SECOND |CHILD |CHERBOURG |\n", 292 | "|0 |3 |1 |0 |17.8000 |False |FEMALE |THIRD |WOMAN |SOUTHAMPTON |\n", 293 | "|0 |3 |0 |0 |7.8000 |True |MALE |THIRD |MAN |SOUTHAMPTON |\n", 294 | "-------------------------------------------------------------------------------------------------------------\n", 295 | "\n" 296 | ] 297 | } 298 | ], 299 | "source": [ 300 | "test_df = session.table('TEST')\n", 301 | "test_df.show()" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "Run the pipeline" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": 20, 314 | "metadata": {}, 315 | "outputs": [ 316 | { 317 | "name": "stdout", 318 | "output_type": "stream", 319 | "text": [ 320 | "-----------------------------------------------------------------------------------------------------------------------\n", 321 | "|\"PCLASS\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\"ALONE\" |\"SEX\" |\"CLASS\" |\"WHO\" |\"EMBARK_TOWN\" |\"SURV_PRED\" |\n", 322 | "-----------------------------------------------------------------------------------------------------------------------\n", 323 | "|3 |0 |0 |7.9250 |True |FEMALE |THIRD |WOMAN |SOUTHAMPTON |0.5756063461303711 |\n", 324 | "|3 |0 |0 |8.4583 |True |MALE |THIRD |MAN |QUEENSTOWN |0.06476970762014389 |\n", 325 
| "|1 |0 |0 |51.8625 |True |MALE |FIRST |MAN |SOUTHAMPTON |0.07613715529441833 |\n", 326 | "|3 |0 |0 |8.0500 |True |MALE |THIRD |MAN |SOUTHAMPTON |0.1316869705915451 |\n", 327 | "|3 |0 |0 |8.0292 |True |FEMALE |THIRD |CHILD |QUEENSTOWN |0.5572702884674072 |\n", 328 | "|3 |0 |0 |7.8958 |True |MALE |THIRD |MAN |SOUTHAMPTON |0.030010107904672623 |\n", 329 | "|1 |1 |0 |146.5208 |False |FEMALE |FIRST |WOMAN |CHERBOURG |0.9791738986968994 |\n", 330 | "|2 |1 |2 |41.5792 |False |FEMALE |SECOND |CHILD |CHERBOURG |0.9464718103408813 |\n", 331 | "|3 |1 |0 |17.8000 |False |FEMALE |THIRD |WOMAN |SOUTHAMPTON |0.19497600197792053 |\n", 332 | "|3 |0 |0 |7.8000 |True |MALE |THIRD |MAN |SOUTHAMPTON |0.08418451994657516 |\n", 333 | "-----------------------------------------------------------------------------------------------------------------------\n", 334 | "\n" 335 | ] 336 | } 337 | ], 338 | "source": [ 339 | "#Copy this code in a snowflake worksheet or run via session.sql\n", 340 | "inference_df = session.sql('''\n", 341 | "select *, TITANIC_PIPE!predict_proba(*):PREDICT_PROBA_1\n", 342 | "as surv_pred\n", 343 | "from (\n", 344 | "select * exclude survived\n", 345 | "from test)\n", 346 | " ''')\n", 347 | "inference_df.show()" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": 22, 353 | "metadata": {}, 354 | "outputs": [ 355 | { 356 | "name": "stdout", 357 | "output_type": "stream", 358 | "text": [ 359 | "--------------------------------------------------------------------------------------------------------------\n", 360 | "|\"PCLASS\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\"ALONE\" |\"SEX\" |\"CLASS\" |\"WHO\" |\"EMBARK_TOWN\" |\"SURV_PRED\" |\n", 361 | "--------------------------------------------------------------------------------------------------------------\n", 362 | "|3 |0 |0 |7.9250 |True |FEMALE |THIRD |WOMAN |SOUTHAMPTON |1 |\n", 363 | "|3 |0 |0 |8.4583 |True |MALE |THIRD |MAN |QUEENSTOWN |0 |\n", 364 | "|1 |0 |0 |51.8625 |True |MALE |FIRST |MAN |SOUTHAMPTON |0 |\n", 365 | "|3 |0 |0 |8.0500 |True |MALE |THIRD |MAN |SOUTHAMPTON |0 |\n", 366 | "|3 |0 |0 |8.0292 |True |FEMALE |THIRD |CHILD |QUEENSTOWN |1 |\n", 367 | "|3 |0 |0 |7.8958 |True |MALE |THIRD |MAN |SOUTHAMPTON |0 |\n", 368 | "|1 |1 |0 |146.5208 |False |FEMALE |FIRST |WOMAN |CHERBOURG |1 |\n", 369 | "|2 |1 |2 |41.5792 |False |FEMALE |SECOND |CHILD |CHERBOURG |1 |\n", 370 | "|3 |1 |0 |17.8000 |False |FEMALE |THIRD |WOMAN |SOUTHAMPTON |0 |\n", 371 | "|3 |0 |0 |7.8000 |True |MALE |THIRD |MAN |SOUTHAMPTON |0 |\n", 372 | "--------------------------------------------------------------------------------------------------------------\n", 373 | "\n" 374 | ] 375 | } 376 | ], 377 | "source": [ 378 | "#Copy this code in a snowflake worksheet or run via session.sql\n", 379 | "inference_df = session.sql('''\n", 380 | "select *, TITANIC_PIPE!predict(*):OUTPUT_SURVIVED\n", 381 | "as surv_pred\n", 382 | "from (\n", 383 | "select * exclude survived\n", 384 | "from test)\n", 385 | " ''')\n", 386 | "inference_df.show()" 387 | ] 388 | } 389 | ], 390 | "metadata": { 391 | "kernelspec": { 392 | "display_name": "Intro_SnowML", 393 | "language": "python", 394 | "name": "python3" 395 | }, 396 | "language_info": { 397 | "codemirror_mode": { 398 | "name": "ipython", 399 | "version": 3 400 | }, 401 | "file_extension": ".py", 402 | "mimetype": "text/x-python", 403 | "name": "python", 404 | "nbconvert_exporter": "python", 405 | "pygments_lexer": "ipython3", 406 | "version": "3.10.14" 407 | } 408 | }, 409 | "nbformat": 4, 410 | "nbformat_minor": 2 
411 | } 412 | -------------------------------------------------------------------------------- /notebooks/3_Deployment.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import ast\n", 10 | "import json\n", 11 | "import warnings\n", 12 | "\n", 13 | "import pandas as pd\n", 14 | "from snowflake.ml.modeling.impute import SimpleImputer\n", 15 | "from snowflake.ml.modeling.metrics import accuracy_score\n", 16 | "from snowflake.ml.modeling.model_selection import GridSearchCV\n", 17 | "from snowflake.ml.modeling.preprocessing import OneHotEncoder\n", 18 | "from snowflake.ml.modeling.xgboost import XGBClassifier\n", 19 | "from snowflake.ml.modeling.lightgbm import LGBMClassifier \n", 20 | "from snowflake.ml.registry import Registry\n", 21 | "from snowflake.ml.utils.connection_params import SnowflakeLoginOptions\n", 22 | "from snowflake.snowpark import Session\n", 23 | "from snowflake.snowpark import types as T\n", 24 | "from snowflake.snowpark.functions import col\n", 25 | "\n", 26 | "warnings.simplefilter(action=\"ignore\", category=UserWarning)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stderr", 36 | "output_type": "stream", 37 | "text": [ 38 | "SnowflakeLoginOptions() is in private preview since 0.2.0. Do not use it in production. \n" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "session = Session.builder.configs(SnowflakeLoginOptions()).getOrCreate()" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "titanic_df = session.table(\"titanic\")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "-------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 65 | "|\"SURVIVED\" |\"PCLASS\" |\"AGE\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\"ADULT_MALE\" |\"DECK\" |\"ALIVE\" |\"ALONE\" |\"SEX\" |\"EMBARKED\" |\"CLASS\" |\"WHO\" |\"EMBARK_TOWN\" |\n", 66 | "-------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 67 | "|0 |3 |22.00 |1 |0 |7.2500 |True |NULL |False |False |MALE |S |THIRD |MAN |SOUTHAMPTON |\n", 68 | "|1 |1 |38.00 |1 |0 |71.2833 |False |C |True |False |FEMALE |C |FIRST |WOMAN |CHERBOURG |\n", 69 | "|1 |3 |26.00 |0 |0 |7.9250 |False |NULL |True |True |FEMALE |S |THIRD |WOMAN |SOUTHAMPTON |\n", 70 | "|1 |1 |35.00 |1 |0 |53.1000 |False |C |True |False |FEMALE |S |FIRST |WOMAN |SOUTHAMPTON |\n", 71 | "|0 |3 |35.00 |0 |0 |8.0500 |True |NULL |False |True |MALE |S |THIRD |MAN |SOUTHAMPTON |\n", 72 | "|0 |3 |NULL |0 |0 |8.4583 |True |NULL |False |True |MALE |Q |THIRD |MAN |QUEENSTOWN |\n", 73 | "|0 |1 |54.00 |0 |0 |51.8625 |True |E |False |True |MALE |S |FIRST |MAN |SOUTHAMPTON |\n", 74 | "|0 |3 |2.00 |3 |1 |21.0750 |False |NULL |False |False |MALE |S |THIRD |CHILD |SOUTHAMPTON |\n", 75 | "|1 |3 |27.00 |0 |2 |11.1333 |False |NULL |True |False |FEMALE |S |THIRD |WOMAN |SOUTHAMPTON |\n", 76 | "|1 |2 |14.00 |1 |0 |30.0708 |False |NULL |True |False |FEMALE |C |SECOND |CHILD |CHERBOURG |\n", 77 | 
"-------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 78 | "\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "titanic_df.show()" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "data": { 93 | "text/plain": [ 94 | "{'AGE': 177, 'DECK': 688, 'EMBARKED': 2, 'EMBARK_TOWN': 2}" 95 | ] 96 | }, 97 | "execution_count": 5, 98 | "metadata": {}, 99 | "output_type": "execute_result" 100 | } 101 | ], 102 | "source": [ 103 | "# Columns with null values and their respective counts\n", 104 | "{\n", 105 | " k: v\n", 106 | " for k, v in {\n", 107 | " col_name: titanic_df.where(col(col_name).is_null()).count()\n", 108 | " for col_name in titanic_df.columns\n", 109 | " }.items()\n", 110 | " if v > 0\n", 111 | "}" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 6, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "titanic_df = titanic_df.drop(\n", 121 | " [\"AGE\", \"DECK\", \"ALIVE\", \"ADULT_MALE\", \"EMBARKED\", \"SEX\", \"PCLASS\", \"ALONE\"]\n", 122 | ")" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 7, 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "------------------------------------------------------------------------------\n", 135 | "|\"SURVIVED\" |\"SIBSP\" |\"PARCH\" |\"CLASS\" |\"WHO\" |\"EMBARK_TOWN\" |\"FARE\" |\n", 136 | "------------------------------------------------------------------------------\n", 137 | "|0 |1 |0 |THIRD |MAN |SOUTHAMPTON |7.25 |\n", 138 | "|1 |1 |0 |FIRST |WOMAN |CHERBOURG |71.2833 |\n", 139 | "|1 |0 |0 |THIRD |WOMAN |SOUTHAMPTON |7.925 |\n", 140 | "|1 |1 |0 |FIRST |WOMAN |SOUTHAMPTON |53.1 |\n", 141 | "|0 |0 |0 |THIRD |MAN |SOUTHAMPTON |8.05 |\n", 142 | "|0 |0 |0 |THIRD |MAN |QUEENSTOWN |8.4583 |\n", 143 | "|0 |0 |0 |FIRST |MAN |SOUTHAMPTON |51.8625 |\n", 144 | "|0 |3 |1 |THIRD |CHILD |SOUTHAMPTON |21.075 |\n", 145 | "|1 |0 |2 |THIRD |WOMAN |SOUTHAMPTON |11.1333 |\n", 146 | "|1 |1 |0 |SECOND |CHILD |CHERBOURG |30.0708 |\n", 147 | "------------------------------------------------------------------------------\n", 148 | "\n" 149 | ] 150 | } 151 | ], 152 | "source": [ 153 | "titanic_df = titanic_df.withColumn(\"FARE\", titanic_df[\"FARE\"].astype(T.FloatType()))\n", 154 | "\n", 155 | "titanic_df.show()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 8, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "cat_cols = [\"CLASS\", \"WHO\", \"EMBARK_TOWN\"]\n", 165 | "num_cols = [\"SIBSP\", \"PARCH\", \"FARE\"]" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 9, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "name": "stdout", 175 | "output_type": "stream", 176 | "text": [ 177 | "------------------------------------------------------------------------------\n", 178 | "|\"CLASS\" |\"WHO\" |\"EMBARK_TOWN\" |\"SURVIVED\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\n", 179 | "------------------------------------------------------------------------------\n", 180 | "|THIRD |MAN |SOUTHAMPTON |0 |1 |0 |7.25 |\n", 181 | "|FIRST |WOMAN |CHERBOURG |1 |1 |0 |71.2833 |\n", 182 | "|THIRD |WOMAN |SOUTHAMPTON |1 |0 |0 |7.925 |\n", 183 | "|FIRST |WOMAN |SOUTHAMPTON |1 |1 |0 |53.1 |\n", 184 | "|THIRD |MAN |SOUTHAMPTON |0 |0 |0 |8.05 |\n", 185 | "|THIRD |MAN |QUEENSTOWN |0 |0 |0 |8.4583 
|\n", 186 | "|FIRST |MAN |SOUTHAMPTON |0 |0 |0 |51.8625 |\n", 187 | "|THIRD |CHILD |SOUTHAMPTON |0 |3 |1 |21.075 |\n", 188 | "|THIRD |WOMAN |SOUTHAMPTON |1 |0 |2 |11.1333 |\n", 189 | "|SECOND |CHILD |CHERBOURG |1 |1 |0 |30.0708 |\n", 190 | "------------------------------------------------------------------------------\n", 191 | "\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "impute_cat = SimpleImputer(\n", 197 | " input_cols=cat_cols,\n", 198 | " output_cols=cat_cols,\n", 199 | " strategy=\"most_frequent\",\n", 200 | " drop_input_cols=True,\n", 201 | ")\n", 202 | "\n", 203 | "titanic_df = impute_cat.fit(titanic_df).transform(titanic_df)\n", 204 | "titanic_df.show()" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 10, 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "name": "stderr", 214 | "output_type": "stream", 215 | "text": [ 216 | "/opt/anaconda3/envs/Intro_SnowML/lib/python3.10/site-packages/snowflake/ml/modeling/preprocessing/one_hot_encoder.py:853: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n", 217 | " state_pandas[[_CATEGORY]] = state_pandas[[_CATEGORY]].applymap(convert_to_string_excluding_nan)\n", 218 | "/opt/anaconda3/envs/Intro_SnowML/lib/python3.10/site-packages/snowflake/ml/modeling/preprocessing/one_hot_encoder.py:854: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n", 219 | " state_pandas[[_FITTED_CATEGORY]] = state_pandas[[_FITTED_CATEGORY]].applymap(convert_to_string_excluding_nan)\n", 220 | "/opt/anaconda3/envs/Intro_SnowML/lib/python3.10/site-packages/snowflake/ml/modeling/preprocessing/one_hot_encoder.py:853: FutureWarning: DataFrame.applymap has been deprecated. Use DataFrame.map instead.\n", 221 | " state_pandas[[_CATEGORY]] = state_pandas[[_CATEGORY]].applymap(convert_to_string_excluding_nan)\n", 222 | "/opt/anaconda3/envs/Intro_SnowML/lib/python3.10/site-packages/snowflake/ml/modeling/preprocessing/one_hot_encoder.py:854: FutureWarning: DataFrame.applymap has been deprecated. 
Use DataFrame.map instead.\n", 223 | " state_pandas[[_FITTED_CATEGORY]] = state_pandas[[_FITTED_CATEGORY]].applymap(convert_to_string_excluding_nan)\n" 224 | ] 225 | }, 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "--------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 231 | "|\"CLASS_SECOND\" |\"CLASS_THIRD\" |\"WHO_MAN\" |\"WHO_WOMAN\" |\"EMBARK_TOWN_QUEENSTOWN\" |\"EMBARK_TOWN_SOUTHAMPTON\" |\"SURVIVED\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\n", 232 | "--------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 233 | "|0.0 |1.0 |1.0 |0.0 |0.0 |1.0 |0 |1 |0 |7.25 |\n", 234 | "|0.0 |0.0 |0.0 |1.0 |0.0 |0.0 |1 |1 |0 |71.2833 |\n", 235 | "|0.0 |1.0 |0.0 |1.0 |0.0 |1.0 |1 |0 |0 |7.925 |\n", 236 | "|0.0 |0.0 |0.0 |1.0 |0.0 |1.0 |1 |1 |0 |53.1 |\n", 237 | "|0.0 |1.0 |1.0 |0.0 |0.0 |1.0 |0 |0 |0 |8.05 |\n", 238 | "|0.0 |1.0 |1.0 |0.0 |1.0 |0.0 |0 |0 |0 |8.4583 |\n", 239 | "|0.0 |0.0 |1.0 |0.0 |0.0 |1.0 |0 |0 |0 |51.8625 |\n", 240 | "|0.0 |1.0 |0.0 |0.0 |0.0 |1.0 |0 |3 |1 |21.075 |\n", 241 | "|0.0 |1.0 |0.0 |1.0 |0.0 |1.0 |1 |0 |2 |11.1333 |\n", 242 | "|1.0 |0.0 |0.0 |0.0 |0.0 |0.0 |1 |1 |0 |30.0708 |\n", 243 | "--------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 244 | "\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "OHE = OneHotEncoder(\n", 250 | " input_cols=cat_cols,\n", 251 | " output_cols=cat_cols,\n", 252 | " drop_input_cols=True,\n", 253 | " drop=\"first\",\n", 254 | " handle_unknown=\"ignore\",\n", 255 | ")\n", 256 | "\n", 257 | "titanic_df = OHE.fit(titanic_df).transform(titanic_df)\n", 258 | "titanic_df.show()" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 11, 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "train_df, test_df = titanic_df.random_split(weights=[0.8, 0.2], seed=8)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 12, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "parameters = {\n", 277 | " \"n_estimators\": [100, 200, 300, 400, 500],\n", 278 | " \"learning_rate\": [0.1, 0.2, 0.3, 0.4, 0.5],\n", 279 | " \"max_depth\": list(range(3, 6, 1)),\n", 280 | " \"min_child_weight\": list(range(1, 6, 1)),\n", 281 | "}" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 13, 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "data": { 291 | "text/plain": [ 292 | "{'n_estimators': [100, 200, 300, 400, 500],\n", 293 | " 'learning_rate': [0.1, 0.2, 0.3, 0.4, 0.5],\n", 294 | " 'max_depth': [3, 4, 5],\n", 295 | " 'min_child_weight': [1, 2, 3, 4, 5]}" 296 | ] 297 | }, 298 | "execution_count": 13, 299 | "metadata": {}, 300 | "output_type": "execute_result" 301 | } 302 | ], 303 | "source": [ 304 | "parameters" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 14, 310 | "metadata": {}, 311 | "outputs": [ 312 | { 313 | "data": { 314 | "text/plain": [ 315 | "[Row(status='Statement executed successfully.')]" 316 | ] 317 | }, 318 | "execution_count": 14, 319 | "metadata": {}, 320 | "output_type": "execute_result" 321 | } 322 | ], 323 | "source": [ 324 | "session.sql(\n", 325 | " f\"ALTER WAREHOUSE {session.get_current_warehouse()[1:-1]} SET WAREHOUSE_SIZE=LARGE;\"\n", 326 | ").collect()" 
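,
 "\n",
 "\n",
 "# The grid above has 5 * 5 * 3 * 5 = 375 parameter combinations, and\n",
 "# GridSearchCV's default 5-fold cross-validation turns that into 1,875\n",
 "# fits, which is why we scale the warehouse up before training."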
327 | ] 328 | }, 329 | { 330 | "cell_type": "markdown", 331 | "metadata": {}, 332 | "source": [ 333 | "Data scientists may not always have permission to resize a warehouse. In that case, they can usually switch to a pre-provisioned larger warehouse with session.use_warehouse('bigger_warehouse')." 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 16, 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "name": "stderr", 343 | "output_type": "stream", 344 | "text": [ 345 | "Package 'fastparquet' is not installed in the local environment. Your UDF might not work when the package is installed on the server but not on your local environment.\n" 346 | ] 347 | }, 348 | { 349 | "data": { 350 | "text/plain": [ 351 | "" 352 | ] 353 | }, 354 | "execution_count": 16, 355 | "metadata": {}, 356 | "output_type": "execute_result" 357 | } 358 | ], 359 | "source": [ 360 | "xgb_grid_search = GridSearchCV(\n", 361 | " estimator=XGBClassifier(),\n", 362 | " param_grid=parameters,\n", 363 | " n_jobs=-1,\n", 364 | " scoring=\"accuracy\",\n", 365 | " input_cols=train_df.drop(\"SURVIVED\").columns,\n", 366 | " label_cols=\"SURVIVED\",\n", 367 | " output_cols=\"PRED_SURVIVED\",\n", 368 | ")\n", 369 | "\n", 370 | "# Train\n", 371 | "xgb_grid_search.fit(train_df)" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 17, 377 | "metadata": {}, 378 | "outputs": [ 379 | { 380 | "name": "stderr", 381 | "output_type": "stream", 382 | "text": [ 383 | "Package 'fastparquet' is not installed in the local environment. Your UDF might not work when the package is installed on the server but not on your local environment.\n" 384 | ] 385 | }, 386 | { 387 | "data": { 388 | "text/plain": [ 389 | "" 390 | ] 391 | }, 392 | "execution_count": 17, 393 | "metadata": {}, 394 | "output_type": "execute_result" 395 | } 396 | ], 397 | "source": [ 398 | "lgbm_grid_search = GridSearchCV(\n", 399 | " estimator=LGBMClassifier(),\n", 400 | " param_grid=parameters,\n", 401 | " n_jobs=-1,\n", 402 | " scoring=\"accuracy\",\n", 403 | " input_cols=train_df.drop(\"SURVIVED\").columns,\n", 404 | " label_cols=\"SURVIVED\",\n", 405 | " output_cols=\"PRED_SURVIVED\",\n", 406 | ")\n", 407 | "\n", 408 | "# Train\n", 409 | "lgbm_grid_search.fit(train_df)" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 18, 415 | "metadata": {}, 416 | "outputs": [ 417 | { 418 | "data": { 419 | "text/plain": [ 420 | "[Row(status='Statement executed successfully.')]" 421 | ] 422 | }, 423 | "execution_count": 18, 424 | "metadata": {}, 425 | "output_type": "execute_result" 426 | } 427 | ], 428 | "source": [ 429 | "session.sql(\n", 430 | " f\"ALTER WAREHOUSE {session.get_current_warehouse()[1:-1]} SET WAREHOUSE_SIZE=XSMALL;\"\n", 431 | ").collect()" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 19, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "result = xgb_grid_search.predict(test_df)" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": 20, 446 | "metadata": {}, 447 | "outputs": [ 448 | { 449 | "name": "stdout", 450 | "output_type": "stream", 451 | "text": [ 452 | "Accuracy: 0.819149\n" 453 | ] 454 | } 455 | ], 456 | "source": [ 457 | "accuracy = accuracy_score(\n", 458 | " df=result, y_true_col_names=\"SURVIVED\", y_pred_col_names=\"PRED_SURVIVED\"\n", 459 | ")\n", 460 | "\n", 461 | "print(f\"Accuracy: {accuracy}\")" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 21, 467 | "metadata": {},
468 | "outputs": [], 469 | "source": [ 470 | "lgbm_result = lgbm_grid_search.predict(test_df)" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": 22, 476 | "metadata": {}, 477 | "outputs": [ 478 | { 479 | "name": "stdout", 480 | "output_type": "stream", 481 | "text": [ 482 | "Accuracy: 0.808511\n" 483 | ] 484 | } 485 | ], 486 | "source": [ 487 | "accuracy = accuracy_score(\n", 488 | " df=lgbm_result, y_true_col_names=\"SURVIVED\", y_pred_col_names=\"PRED_SURVIVED\"\n", 489 | ")\n", 490 | "\n", 491 | "print(f\"Accuracy: {accuracy}\")" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 23, 497 | "metadata": {}, 498 | "outputs": [ 499 | { 500 | "data": { 501 | "text/html": [ 502 | "
\n", 503 | "\n", 516 | "\n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | "
accuracylearning_ratemax_depthmin_child_weightn_estimators
940.8221480.234500
3110.8207600.533200
2160.8207600.354200
1140.8207500.243500
2990.8207400.455500
\n", 570 | "
" 571 | ], 572 | "text/plain": [ 573 | " accuracy learning_rate max_depth min_child_weight n_estimators\n", 574 | "94 0.822148 0.2 3 4 500\n", 575 | "311 0.820760 0.5 3 3 200\n", 576 | "216 0.820760 0.3 5 4 200\n", 577 | "114 0.820750 0.2 4 3 500\n", 578 | "299 0.820740 0.4 5 5 500" 579 | ] 580 | }, 581 | "execution_count": 23, 582 | "metadata": {}, 583 | "output_type": "execute_result" 584 | } 585 | ], 586 | "source": [ 587 | "# Print each combination of hyperparameters with their accuracy\n", 588 | "results = xgb_grid_search.to_sklearn().cv_results_\n", 589 | "data = {\"accuracy\": results[\"mean_test_score\"]}\n", 590 | "for i, param in enumerate(results[\"params\"]):\n", 591 | " for key, value in param.items():\n", 592 | " if key not in data:\n", 593 | " data[key] = [None] * len(results[\"params\"])\n", 594 | " data[key][i] = value\n", 595 | "\n", 596 | "# Create DataFrame\n", 597 | "hp_df = pd.DataFrame(data).sort_values(by=\"accuracy\", ascending=False)\n", 598 | "hp_df.head()" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "metadata": {}, 604 | "source": [ 605 | "# Model Registry\n" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": 24, 611 | "metadata": {}, 612 | "outputs": [], 613 | "source": [ 614 | "optimal_model = xgb_grid_search.to_sklearn().best_estimator_" 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": 25, 620 | "metadata": {}, 621 | "outputs": [], 622 | "source": [ 623 | "# create function to add one to our model versions if it already exists\n", 624 | "\n", 625 | "\n", 626 | "def check_and_update(df, model_name):\n", 627 | " if df.empty:\n", 628 | " return \"V_1\"\n", 629 | " elif df[df[\"name\"] == model_name].empty:\n", 630 | " return \"V_1\"\n", 631 | " else:\n", 632 | " # Increment model_version if df is not a pandas Series\n", 633 | " lst = sorted(ast.literal_eval(df[\"versions\"][0]))\n", 634 | " last_value = lst[-1]\n", 635 | " prefix, num = last_value.rsplit(\"_\", 1)\n", 636 | " new_last_value = f\"{prefix}_{int(num)+1}\"\n", 637 | " lst[-1] = new_last_value\n", 638 | " return new_last_value" 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": 26, 644 | "metadata": {}, 645 | "outputs": [], 646 | "source": [ 647 | "# Get sample input data to pass into the registry logging function\n", 648 | "X = train_df.drop(\"SURVIVED\").limit(100)\n", 649 | "\n", 650 | "# Create a registry and log the model\n", 651 | "# You can specify a different DB and Schema if you'd like\n", 652 | "# otherwise it uses the session context\n", 653 | "reg = Registry(session=session)\n", 654 | "\n", 655 | "reg_df = reg.show_models()\n", 656 | "\n", 657 | "# Define model name and version (use uppercase for name)\n", 658 | "model_name = \"TITANIC\"\n", 659 | "\n", 660 | "model_version = check_and_update(reg_df, model_name)\n", 661 | "\n", 662 | "titanic_model = reg.log_model(\n", 663 | " model_name=model_name,\n", 664 | " version_name=model_version,\n", 665 | " model=optimal_model,\n", 666 | " sample_input_data=X,\n", 667 | ")\n", 668 | "\n", 669 | "# Add evaluation metric\n", 670 | "titanic_model.set_metric(\n", 671 | " metric_name=\"accuracy\",\n", 672 | " value=hp_df[\"accuracy\"][0],\n", 673 | ")" 674 | ] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "execution_count": 27, 679 | "metadata": {}, 680 | "outputs": [ 681 | { 682 | "data": { 683 | "text/html": [ 684 | "
\n", 685 | "\n", 698 | "\n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | "
created_onnamedatabase_nameschema_namecommentownerdefault_version_nameversions
02024-07-23 12:49:47.511000-07:00TITANICCROMANODEMONoneSYSADMINV_1[\"V_1\"]
12024-06-18 09:23:46.548000-07:00TITANIC_PIPECROMANODEMONoneSYSADMINV_1[\"V_1\",\"V_2\",\"V_3\",\"V_4\",\"V_5\"]
\n", 737 | "
" 738 | ], 739 | "text/plain": [ 740 | " created_on name database_name schema_name \\\n", 741 | "0 2024-07-23 12:49:47.511000-07:00 TITANIC CROMANO DEMO \n", 742 | "1 2024-06-18 09:23:46.548000-07:00 TITANIC_PIPE CROMANO DEMO \n", 743 | "\n", 744 | " comment owner default_version_name versions \n", 745 | "0 None SYSADMIN V_1 [\"V_1\"] \n", 746 | "1 None SYSADMIN V_1 [\"V_1\",\"V_2\",\"V_3\",\"V_4\",\"V_5\"] " 747 | ] 748 | }, 749 | "execution_count": 27, 750 | "metadata": {}, 751 | "output_type": "execute_result" 752 | } 753 | ], 754 | "source": [ 755 | "reg.show_models()" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": 29, 761 | "metadata": {}, 762 | "outputs": [], 763 | "source": [ 764 | "hyperparameters = {\n", 765 | " k: v for k, v in optimal_model.get_params().items() if v and k != \"missing\"\n", 766 | "}\n", 767 | "titanic_model.set_metric(metric_name=\"hyperparameters\", value=hyperparameters)" 768 | ] 769 | }, 770 | { 771 | "cell_type": "code", 772 | "execution_count": 30, 773 | "metadata": {}, 774 | "outputs": [ 775 | { 776 | "data": { 777 | "text/html": [ 778 | "
\n", 779 | "\n", 792 | "\n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | "
created_onnamealiasescommentdatabase_nameschema_namemodule_nameis_default_versionfunctionsmetadatauser_data
02024-07-23 12:49:47.554000-07:00V_1[\"DEFAULT\",\"FIRST\",\"LAST\"]NoneCROMANODEMOTITANICtrue[\"PREDICT_PROBA\",\"PREDICT\"]{\"metrics\": {\"accuracy\": 0.8093617021276595, \"hyperparameters\": {\"objective\": \"binary:logistic\", \"learning_rate\": 0.2, \"max_depth\": 3, \"min_child_weight\": 4, \"n_estimators\": 500, \"n_jobs\": 3}}, \"snowpark_ml_schema_version\": \"2024-01-01\"}{}
\n", 826 | "
" 827 | ], 828 | "text/plain": [ 829 | " created_on name aliases comment \\\n", 830 | "0 2024-07-23 12:49:47.554000-07:00 V_1 [\"DEFAULT\",\"FIRST\",\"LAST\"] None \n", 831 | "\n", 832 | " database_name schema_name module_name is_default_version \\\n", 833 | "0 CROMANO DEMO TITANIC true \n", 834 | "\n", 835 | " functions \\\n", 836 | "0 [\"PREDICT_PROBA\",\"PREDICT\"] \n", 837 | "\n", 838 | " metadata \\\n", 839 | "0 {\"metrics\": {\"accuracy\": 0.8093617021276595, \"hyperparameters\": {\"objective\": \"binary:logistic\", \"learning_rate\": 0.2, \"max_depth\": 3, \"min_child_weight\": 4, \"n_estimators\": 500, \"n_jobs\": 3}}, \"snowpark_ml_schema_version\": \"2024-01-01\"} \n", 840 | "\n", 841 | " user_data \n", 842 | "0 {} " 843 | ] 844 | }, 845 | "execution_count": 30, 846 | "metadata": {}, 847 | "output_type": "execute_result" 848 | } 849 | ], 850 | "source": [ 851 | "pd.options.display.max_colwidth = 500\n", 852 | "reg.get_model(model_name).show_versions()" 853 | ] 854 | }, 855 | { 856 | "cell_type": "markdown", 857 | "metadata": {}, 858 | "source": [ 859 | "If you have multiple versions of the model, we want the UDF to be deployed as the version with the highest accuracy\n" 860 | ] 861 | }, 862 | { 863 | "cell_type": "code", 864 | "execution_count": 31, 865 | "metadata": {}, 866 | "outputs": [], 867 | "source": [ 868 | "reg_df = reg.get_model(model_name).show_versions()\n", 869 | "reg_df[\"accuracy\"] = reg_df[\"metadata\"].apply(\n", 870 | " lambda x: json.loads(x)[\"metrics\"][\"accuracy\"]\n", 871 | ")\n", 872 | "best_model = reg_df.sort_values(by=\"accuracy\", ascending=False)" 873 | ] 874 | }, 875 | { 876 | "cell_type": "code", 877 | "execution_count": 32, 878 | "metadata": {}, 879 | "outputs": [ 880 | { 881 | "data": { 882 | "text/plain": [ 883 | "'V_1'" 884 | ] 885 | }, 886 | "execution_count": 32, 887 | "metadata": {}, 888 | "output_type": "execute_result" 889 | } 890 | ], 891 | "source": [ 892 | "deployed_version = best_model[\"name\"].iloc[0]\n", 893 | "deployed_version" 894 | ] 895 | }, 896 | { 897 | "cell_type": "markdown", 898 | "metadata": {}, 899 | "source": [ 900 | "Set the default version to the deployed version (best model)" 901 | ] 902 | }, 903 | { 904 | "cell_type": "code", 905 | "execution_count": 33, 906 | "metadata": {}, 907 | "outputs": [ 908 | { 909 | "data": { 910 | "text/plain": [ 911 | "'V_1'" 912 | ] 913 | }, 914 | "execution_count": 33, 915 | "metadata": {}, 916 | "output_type": "execute_result" 917 | } 918 | ], 919 | "source": [ 920 | "m = reg.get_model(model_name)\n", 921 | "m.default = deployed_version\n", 922 | "mv = m.default\n", 923 | "mv.version_name" 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "execution_count": 34, 929 | "metadata": {}, 930 | "outputs": [ 931 | { 932 | "name": "stdout", 933 | "output_type": "stream", 934 | "text": [ 935 | "--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 936 | "|\"CLASS_SECOND\" |\"CLASS_THIRD\" |\"WHO_MAN\" |\"WHO_WOMAN\" |\"EMBARK_TOWN_QUEENSTOWN\" |\"EMBARK_TOWN_SOUTHAMPTON\" |\"SURVIVED\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\"output_feature_0\" |\"output_feature_1\" |\n", 937 | "--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 938 | "|0.0 |1.0 |1.0 |0.0 |1.0 |0.0 |0 |0 |0 |8.4583 
|0.9589072465896606 |0.041092731058597565 |\n", 939 | "|0.0 |0.0 |1.0 |0.0 |0.0 |1.0 |0 |0 |0 |51.8625 |0.8610950708389282 |0.13890492916107178 |\n", 940 | "|0.0 |1.0 |0.0 |0.0 |0.0 |1.0 |0 |3 |1 |21.075 |0.8630478382110596 |0.13695217669010162 |\n", 941 | "|1.0 |0.0 |0.0 |0.0 |0.0 |0.0 |1 |1 |0 |30.0708 |0.005812287330627441 |0.9941877126693726 |\n", 942 | "|0.0 |1.0 |1.0 |0.0 |0.0 |1.0 |0 |1 |5 |31.275 |0.8883638978004456 |0.11163610219955444 |\n", 943 | "|1.0 |0.0 |0.0 |1.0 |0.0 |1.0 |1 |0 |0 |16.0 |0.029041707515716553 |0.9709582924842834 |\n", 944 | "|0.0 |1.0 |0.0 |0.0 |0.0 |1.0 |0 |3 |1 |21.075 |0.8630478382110596 |0.13695217669010162 |\n", 945 | "|0.0 |1.0 |1.0 |0.0 |0.0 |1.0 |0 |0 |0 |7.8958 |0.9779325723648071 |0.022067449986934662 |\n", 946 | "|0.0 |0.0 |0.0 |1.0 |0.0 |0.0 |1 |1 |0 |146.5208 |0.0013046860694885254 |0.9986953139305115 |\n", 947 | "|0.0 |0.0 |1.0 |0.0 |0.0 |0.0 |0 |1 |0 |82.1708 |0.3023839592933655 |0.6976160407066345 |\n", 948 | "--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 949 | "\n" 950 | ] 951 | } 952 | ], 953 | "source": [ 954 | "remote_prediction = mv.run(test_df, function_name=\"predict_proba\")\n", 955 | "remote_prediction.show()" 956 | ] 957 | }, 958 | { 959 | "cell_type": "code", 960 | "execution_count": 35, 961 | "metadata": {}, 962 | "outputs": [], 963 | "source": [ 964 | "# To test in SQL write test data back to a table\n", 965 | "\n", 966 | "test_df.write.mode(\"overwrite\").save_as_table(\"TEST_DATA\")" 967 | ] 968 | }, 969 | { 970 | "cell_type": "markdown", 971 | "metadata": {}, 972 | "source": [ 973 | "## Add images to stage for Streamlit App\n" 974 | ] 975 | }, 976 | { 977 | "cell_type": "code", 978 | "execution_count": 36, 979 | "metadata": {}, 980 | "outputs": [ 981 | { 982 | "data": { 983 | "text/plain": [ 984 | "[PutResult(source='floating.webp', target='floating.webp.gz', source_size=205540, target_size=204784, source_compression='NONE', target_compression='GZIP', status='UPLOADED', message=''),\n", 985 | " PutResult(source='flying.webp', target='flying.webp.gz', source_size=77798, target_size=76928, source_compression='NONE', target_compression='GZIP', status='UPLOADED', message=''),\n", 986 | " PutResult(source='sinking.webp', target='sinking.webp.gz', source_size=148802, target_size=148160, source_compression='NONE', target_compression='GZIP', status='UPLOADED', message='')]" 987 | ] 988 | }, 989 | "execution_count": 36, 990 | "metadata": {}, 991 | "output_type": "execute_result" 992 | } 993 | ], 994 | "source": [ 995 | "session.file.put(\"../streamlit_images/*\", \"@ML_DATA\")" 996 | ] 997 | }, 998 | { 999 | "cell_type": "markdown", 1000 | "metadata": {}, 1001 | "source": [ 1002 | "## Calling Model from SQL " 1003 | ] 1004 | }, 1005 | { 1006 | "cell_type": "code", 1007 | "execution_count": 37, 1008 | "metadata": {}, 1009 | "outputs": [ 1010 | { 1011 | "name": "stdout", 1012 | "output_type": "stream", 1013 | "text": [ 1014 | "--------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 1015 | "|\"CLASS_SECOND\" |\"CLASS_THIRD\" |\"WHO_MAN\" |\"WHO_WOMAN\" |\"EMBARK_TOWN_QUEENSTOWN\" |\"EMBARK_TOWN_SOUTHAMPTON\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\"SURV_PRED\" |\n", 1016 | 
"--------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 1017 | "|0.0 |1.0 |1.0 |0.0 |1.0 |0.0 |0 |0 |8.4583 |9.589072465896606e-01 |\n", 1018 | "|0.0 |0.0 |1.0 |0.0 |0.0 |1.0 |0 |0 |51.8625 |8.610950708389282e-01 |\n", 1019 | "|0.0 |1.0 |0.0 |0.0 |0.0 |1.0 |3 |1 |21.075 |8.630478382110596e-01 |\n", 1020 | "|1.0 |0.0 |0.0 |0.0 |0.0 |0.0 |1 |0 |30.0708 |5.812287330627441e-03 |\n", 1021 | "|0.0 |1.0 |1.0 |0.0 |0.0 |1.0 |1 |5 |31.275 |8.883638978004456e-01 |\n", 1022 | "|1.0 |0.0 |0.0 |1.0 |0.0 |1.0 |0 |0 |16.0 |2.904170751571655e-02 |\n", 1023 | "|0.0 |1.0 |0.0 |0.0 |0.0 |1.0 |3 |1 |21.075 |8.630478382110596e-01 |\n", 1024 | "|0.0 |1.0 |1.0 |0.0 |0.0 |1.0 |0 |0 |7.8958 |9.779325723648071e-01 |\n", 1025 | "|0.0 |0.0 |0.0 |1.0 |0.0 |0.0 |1 |0 |146.5208 |1.304686069488525e-03 |\n", 1026 | "|0.0 |0.0 |1.0 |0.0 |0.0 |0.0 |1 |0 |82.1708 |3.023839592933655e-01 |\n", 1027 | "--------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 1028 | "\n" 1029 | ] 1030 | } 1031 | ], 1032 | "source": [ 1033 | "#Copy this code in a snowflake worksheet or run via session.sql\n", 1034 | "inference_df = session.sql('''\n", 1035 | "select *, TITANIC!predict_proba(*):output_feature_0\n", 1036 | "as surv_pred\n", 1037 | "from (\n", 1038 | "select * exclude survived\n", 1039 | "from test_data)\n", 1040 | " ''')\n", 1041 | "inference_df.show()" 1042 | ] 1043 | }, 1044 | { 1045 | "cell_type": "markdown", 1046 | "metadata": {}, 1047 | "source": [ 1048 | "# Calling model from a new notebook" 1049 | ] 1050 | }, 1051 | { 1052 | "cell_type": "code", 1053 | "execution_count": 38, 1054 | "metadata": {}, 1055 | "outputs": [ 1056 | { 1057 | "name": "stdout", 1058 | "output_type": "stream", 1059 | "text": [ 1060 | "--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 1061 | "|\"CLASS_SECOND\" |\"CLASS_THIRD\" |\"WHO_MAN\" |\"WHO_WOMAN\" |\"EMBARK_TOWN_QUEENSTOWN\" |\"EMBARK_TOWN_SOUTHAMPTON\" |\"SURVIVED\" |\"SIBSP\" |\"PARCH\" |\"FARE\" |\"PRED_SURVIVED\" |\n", 1062 | "--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 1063 | "|0.0 |1.0 |1.0 |0.0 |1.0 |0.0 |0 |0 |0 |8.4583 |0.041092731058597565 |\n", 1064 | "|0.0 |0.0 |1.0 |0.0 |0.0 |1.0 |0 |0 |0 |51.8625 |0.13890492916107178 |\n", 1065 | "|0.0 |1.0 |0.0 |0.0 |0.0 |1.0 |0 |3 |1 |21.075 |0.13695217669010162 |\n", 1066 | "|1.0 |0.0 |0.0 |0.0 |0.0 |0.0 |1 |1 |0 |30.0708 |0.9941877126693726 |\n", 1067 | "|0.0 |1.0 |1.0 |0.0 |0.0 |1.0 |0 |1 |5 |31.275 |0.11163610219955444 |\n", 1068 | "|1.0 |0.0 |0.0 |1.0 |0.0 |1.0 |1 |0 |0 |16.0 |0.9709582924842834 |\n", 1069 | "|0.0 |1.0 |0.0 |0.0 |0.0 |1.0 |0 |3 |1 |21.075 |0.13695217669010162 |\n", 1070 | "|0.0 |1.0 |1.0 |0.0 |0.0 |1.0 |0 |0 |0 |7.8958 |0.022067449986934662 |\n", 1071 | "|0.0 |0.0 |0.0 |1.0 |0.0 |0.0 |1 |1 |0 |146.5208 |0.9986953139305115 |\n", 1072 | "|0.0 |0.0 |1.0 |0.0 |0.0 |0.0 |0 |1 |0 |82.1708 |0.6976160407066345 |\n", 1073 | "--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", 1074 | "\n" 1075 | 
] 1076 | } 1077 | ], 1078 | "source": [ 1079 | "# Point to the registry\n", 1080 | "\n", 1081 | "reg = Registry(session=session)\n", 1082 | "\n", 1083 | "# Get the default version of your model (the version with the best accuracy in our case)\n", 1084 | "\n", 1085 | "mv = reg.get_model(\"TITANIC\").default\n", 1086 | "\n", 1087 | "remote_prediction = mv.run(test_df, function_name=\"predict_proba\")\n", 1088 | "remote_prediction.drop('\"output_feature_0\"').with_column_renamed(\n", 1089 | " '\"output_feature_1\"', \"pred_survived\"\n", 1090 | ").show()" 1091 | ] 1092 | }, 1093 | { 1094 | "cell_type": "markdown", 1095 | "metadata": {}, 1096 | "source": [ 1097 | "## To delete your model and all of its versions" 1098 | ] 1099 | }, 1100 | { 1101 | "cell_type": "code", 1102 | "execution_count": 33, 1103 | "metadata": {}, 1104 | "outputs": [], 1105 | "source": [ 1106 | "# reg.delete_model(\"TITANIC\")" 1107 | ] 1108 | } 1109 | ], 1110 | "metadata": { 1111 | "kernelspec": { 1112 | "display_name": "Intro_SnowML", 1113 | "language": "python", 1114 | "name": "python3" 1115 | }, 1116 | "language_info": { 1117 | "codemirror_mode": { 1118 | "name": "ipython", 1119 | "version": 3 1120 | }, 1121 | "file_extension": ".py", 1122 | "mimetype": "text/x-python", 1123 | "name": "python", 1124 | "nbconvert_exporter": "python", 1125 | "pygments_lexer": "ipython3", 1126 | "version": "3.10.14" 1127 | } 1128 | }, 1129 | "nbformat": 4, 1130 | "nbformat_minor": 2 1131 | } 1132 | --------------------------------------------------------------------------------
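If you only need to retire a single version rather than the whole model, the registry also supports per-version deletion. A minimal sketch, assuming the TITANIC model registered above and the same `session`; note that deleting the current default version is typically rejected, so repoint the default first:

```python
from snowflake.ml.registry import Registry

reg = Registry(session=session)
m = reg.get_model("TITANIC")

# Make sure the version being removed is not the default
# m.default = "V_2"  # hypothetical surviving version

# Remove a single version
m.delete_version("V_1")

# Or drop the model and every one of its versions at once
# reg.delete_model("TITANIC")
```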