├── packages.txt ├── requirements.txt ├── .astro └── config.yaml ├── .dockerignore ├── include ├── .DS_Store ├── rexmex.zip ├── pytorch_tabnet.zip ├── images │ ├── streamlit.jpg │ ├── apache_airflow.jpg │ ├── weathersource_tile.png │ ├── sagemaker_studio_lab.jpg │ └── weathersource_getdata.png ├── state.json └── streamlit_app.py ├── setup ├── snowpark_container_build.zip ├── Dockerfile ├── environment.yml ├── update-domain-input.json ├── app-image-config-input.json ├── create-domain-input.json └── build.sh ├── docker-compose.override.yml.TEMP ├── dags ├── snowpark_connection.py ├── model_eval.py ├── mlops_monthly_pipeline.py ├── mlops_setup_pipeline.py ├── airflow_incremental_pipeline.py ├── feature_engineering.py ├── airflow_setup_pipeline.py ├── ingest.py ├── station_train_predict.py ├── elt.py ├── mlops_tasks.py ├── airflow_tasks.py ├── cdc.py └── mlops_pipeline.py ├── environment.yml ├── jupyter_env.yml ├── airflow_settings.yml ├── Dockerfile ├── tests └── dags │ └── test_dag_integrity.py ├── .gitignore ├── 02_Data_Marketplace.ipynb ├── README.md ├── 02_Data_Science-ARIMA-Baseline.ipynb ├── 06_Streamlit_App.ipynb ├── 00_Setup.ipynb ├── 01_Ingest.ipynb └── 05_Airflow_Pipeline.ipynb /packages.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.astro/config.yaml: -------------------------------------------------------------------------------- 1 | project: 2 | name: citibike_ml 3 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .astro 2 | .git 3 | .env 4 | airflow_settings.yaml 5 | pod-config.yml 6 | logs/ 7 | penv/ 8 | -------------------------------------------------------------------------------- /include/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Snowflake-Labs/sfguide-citibike-ml-snowpark-python/HEAD/include/.DS_Store -------------------------------------------------------------------------------- /include/rexmex.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Snowflake-Labs/sfguide-citibike-ml-snowpark-python/HEAD/include/rexmex.zip -------------------------------------------------------------------------------- /include/pytorch_tabnet.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Snowflake-Labs/sfguide-citibike-ml-snowpark-python/HEAD/include/pytorch_tabnet.zip -------------------------------------------------------------------------------- /include/images/streamlit.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Snowflake-Labs/sfguide-citibike-ml-snowpark-python/HEAD/include/images/streamlit.jpg -------------------------------------------------------------------------------- /include/images/apache_airflow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Snowflake-Labs/sfguide-citibike-ml-snowpark-python/HEAD/include/images/apache_airflow.jpg 
-------------------------------------------------------------------------------- /setup/snowpark_container_build.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Snowflake-Labs/sfguide-citibike-ml-snowpark-python/HEAD/setup/snowpark_container_build.zip -------------------------------------------------------------------------------- /include/images/weathersource_tile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Snowflake-Labs/sfguide-citibike-ml-snowpark-python/HEAD/include/images/weathersource_tile.png -------------------------------------------------------------------------------- /include/images/sagemaker_studio_lab.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Snowflake-Labs/sfguide-citibike-ml-snowpark-python/HEAD/include/images/sagemaker_studio_lab.jpg -------------------------------------------------------------------------------- /include/images/weathersource_getdata.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Snowflake-Labs/sfguide-citibike-ml-snowpark-python/HEAD/include/images/weathersource_getdata.png -------------------------------------------------------------------------------- /setup/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM continuumio/miniconda3:4.10.3 2 | 3 | COPY environment.yml . 4 | COPY ../include/snowflake_snowpark_python-0.4.0-py3-none-any.whl /tmp 5 | RUN conda update conda && \ 6 | conda env create -f environment.yml #&& \ 7 | -------------------------------------------------------------------------------- /setup/environment.yml: -------------------------------------------------------------------------------- 1 | name: snowpark 2 | channels: 3 | - conda 4 | dependencies: 5 | - python=3.8 6 | - pip 7 | - pip: 8 | - '/tmp/snowflake_snowpark_python-0.4.0-py3-none-any.whl[pandas]' 9 | - pytorch-tabnet 10 | - matplotlib 11 | - seaborn 12 | - rexmex 13 | - awscli 14 | - boto3 15 | - ipykernel 16 | -------------------------------------------------------------------------------- /docker-compose.override.yml.TEMP: -------------------------------------------------------------------------------- 1 | version: '2' 2 | services: 3 | jupyter: 4 | image: fletchjeff/vhol-citibike:v0.4 5 | networks: 6 | - airflow 7 | volumes: 8 | - ${PWD}:/code/ 9 | ports: 10 | - 8888:8888 11 | command: bash -c "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/code --ip 0.0.0.0 --no-browser --allow-root --NotebookApp.token=''" 12 | -------------------------------------------------------------------------------- /setup/update-domain-input.json: -------------------------------------------------------------------------------- 1 | { 2 | "DomainId": "SAGEMAKER_DOMAIN", 3 | "DefaultUserSettings": { 4 | "KernelGatewayAppSettings": { 5 | "CustomImages": [ 6 | { 7 | "ImageName": "snowpark040", 8 | "AppImageConfigName": "snowpark-config" 9 | } 10 | ] 11 | } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /dags/snowpark_connection.py: -------------------------------------------------------------------------------- 1 | def snowpark_connect(state_file='./include/state.json'): 2 | import snowflake.snowpark as snp 3 | import json 4 | 5 | with open(state_file) as sdf: 6 | state_dict = json.load(sdf) 7 | 
8 | session=None 9 | session = snp.Session.builder.configs(state_dict["connection_parameters"]).create() 10 | session.use_warehouse(state_dict['compute_parameters']['default_warehouse']) 11 | return session, state_dict 12 | -------------------------------------------------------------------------------- /setup/app-image-config-input.json: -------------------------------------------------------------------------------- 1 | { 2 | "AppImageConfigName": "snowpark-config", 3 | "KernelGatewayImageConfig": { 4 | "KernelSpecs": [ 5 | { 6 | "Name": "conda-env-snowpark-py", 7 | "DisplayName": "Python [conda env: snowpark]" 8 | } 9 | ], 10 | "FileSystemConfig": { 11 | "MountPath": "/root", 12 | "DefaultUid": 0, 13 | "DefaultGid": 0 14 | } 15 | } 16 | } -------------------------------------------------------------------------------- /setup/create-domain-input.json: -------------------------------------------------------------------------------- 1 | { 2 | "DomainName": "domain-with-custom-conda-env", 3 | "VpcId": "", 4 | "SubnetIds": [ 5 | "" 6 | ], 7 | "DefaultUserSettings": { 8 | "ExecutionRole": "", 9 | "KernelGatewayAppSettings": { 10 | "CustomImages": [ 11 | { 12 | "ImageName": "conda-env-kernel", 13 | "AppImageConfigName": "conda-env-kernel-config" 14 | } 15 | ] 16 | } 17 | }, 18 | "AuthMode": "IAM" 19 | } -------------------------------------------------------------------------------- /dags/model_eval.py: -------------------------------------------------------------------------------- 1 | def eval_model_func(input_data: str, 2 | y_true_name: str, 3 | y_score_name: str): 4 | import pandas as pd 5 | from rexmex import RatingMetricSet, ScoreCard 6 | 7 | metric_set = RatingMetricSet() 8 | score_card = ScoreCard(metric_set) 9 | 10 | df = pd.read_json(input_data) 11 | df.rename(columns={y_true_name: 'y_true', y_score_name:'y_score'}, inplace=True) 12 | 13 | df = score_card.generate_report(df).reset_index() 14 | 15 | return df.to_json(orient='records', lines=False) 16 | -------------------------------------------------------------------------------- /include/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "connection_parameters": {"user": "", 3 | "account": ".", 4 | "role": "ACCOUNTADMIN" 5 | }, 6 | "compute_parameters" : {"default_warehouse": "XSMALL_WH", 7 | "task_warehouse": "XSMALL_WH", 8 | "load_warehouse": "LARGE_WH", 9 | "fe_warehouse": "XXLARGE_WH", 10 | "train_warehouse": "XXLARGE_WH" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: snowpark_0110 2 | channels: 3 | - https://repo.anaconda.com/pkgs/snowflake/ 4 | dependencies: 5 | - python=3.8 6 | - pip 7 | - pip: 8 | - snowflake-snowpark-python[pandas]==0.11.0 9 | # - './include/snowflake_snowpark_python-0.6.0-py3-none-any.whl[pandas]' 10 | - protobuf==3.20.0 11 | - scikit-learn==1.0.2 12 | - scipy==1.7.1 13 | - cloudpickle==2.0.0 14 | - torch==1.10.2 15 | - statsmodels==0.13.0 16 | - matplotlib==3.5.0 17 | - ipykernel==6.13.0 18 | - pytorch-tabnet==3.1.1 19 | - seaborn==0.11.2 20 | - rexmex==0.1.0 21 | - streamlit==1.8.1 22 | -------------------------------------------------------------------------------- /jupyter_env.yml: -------------------------------------------------------------------------------- 1 | name: snowpark_0110 2 | channels: 3 | - https://repo.anaconda.com/pkgs/snowflake/ 4 | dependencies: 5 | - python=3.8 6 | - pip 7 
| - pip: 8 | - snowflake-snowpark-python[pandas]==0.11.0 9 | # - './include/snowflake_snowpark_python-0.6.0-py3-none-any.whl[pandas]' 10 | - protobuf==3.20.0 11 | - scikit-learn==1.0.2 12 | - scipy==1.7.1 13 | - cloudpickle==2.0.0 14 | - torch==1.10.2 15 | - statsmodels==0.13.0 16 | - matplotlib==3.5.0 17 | - jupyter==1.0.0 18 | - pytorch-tabnet==3.1.1 19 | - seaborn==0.11.2 20 | - rexmex==0.1.0 21 | - streamlit==1.8.1 22 | -------------------------------------------------------------------------------- /airflow_settings.yml: -------------------------------------------------------------------------------- 1 | # This file allows you to configure Airflow Connections, Pools, and Variables in a single place for local development only. 2 | # NOTE: If putting a dict in conn_extra, please wrap in single quotes. 3 | 4 | # For more information, refer to our docs: https://docs.astronomer.io/develop-project#configure-airflow_settingsyaml-local-development-only 5 | # For issues or questions, reach out to: https://support.astronomer.io 6 | 7 | airflow: 8 | connections: 9 | - conn_id: 10 | conn_type: 11 | conn_host: 12 | conn_schema: 13 | conn_login: 14 | conn_password: 15 | conn_port: 16 | conn_extra: 17 | pools: 18 | - pool_name: 19 | pool_slot: 20 | pool_description: 21 | variables: 22 | - variable_name: 23 | variable_value: 24 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # If you are using an M1 mac, the docker build process takes ages. There is a 2 | # prebuilt image available to make this process go faster. To use the standard 3 | # build process, comment the line below and uncomment the rest of the document. 4 | 5 | # FROM fletchjeff/vhol-citibike:v0.3 6 | 7 | FROM quay.io/astronomer/astro-runtime:5.0.2 8 | USER root 9 | RUN apt-get -y update \ 10 | && apt-get -y upgrade \ 11 | && apt-get install build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libsqlite3-dev libreadline-dev libffi-dev curl libbz2-dev wget -y \ 12 | && wget https://www.python.org/ftp/python/3.8.12/Python-3.8.12.tar.xz \ 13 | && tar -xf Python-3.8.12.tar.xz \ 14 | && mv Python-3.8.12 /opt/Python3.8.12 15 | WORKDIR /opt/Python3.8.12/ 16 | RUN ./configure \ 17 | #--enable-optimizations --enable-shared 18 | && make \ 19 | && make altinstall \ 20 | && ldconfig /opt/Python3.8.12 \ 21 | && pip3.8 install snowflake-snowpark-python[pandas]==0.11.0 jupyterlab 22 | USER astro 23 | WORKDIR /usr/local/airflow 24 | 25 | -------------------------------------------------------------------------------- /setup/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | REGION= 4 | ACCOUNT_ID= 5 | IMAGE_NAME=conda-env-kernel 6 | ROLE_ARN=arn:aws:iam::${ACCOUNT_ID}:role/RoleName 7 | 8 | aws --region ${REGION} ecr create-repository --repository-name smstudio-custom 9 | 10 | aws --region ${REGION} ecr get-login-password | docker login --username AWS --password-stdin ${ACCOUNT_ID}.dkr.ecr.${REGION}.amazonaws.com/smstudio-custom 11 | 12 | docker build . 
-t ${IMAGE_NAME} -t ${ACCOUNT_ID}.dkr.ecr.${REGION}.amazonaws.com/smstudio-custom:${IMAGE_NAME} 13 | 14 | docker push ${ACCOUNT_ID}.dkr.ecr.${REGION}.amazonaws.com/smstudio-custom:${IMAGE_NAME} 15 | 16 | aws --region ${REGION} sagemaker create-image --image-name ${IMAGE_NAME} --role-arn ${ROLE_ARN} 17 | 18 | aws --region ${REGION} sagemaker create-image-version --image-name ${IMAGE_NAME} \ 19 | --base-image "${ACCOUNT_ID}.dkr.ecr.${REGION}.amazonaws.com/smstudio-custom:${IMAGE_NAME}" 20 | 21 | aws --region ${REGION} sagemaker describe-image-version --image-name ${IMAGE_NAME} 22 | 23 | aws --region ${REGION} sagemaker create-app-image-config --cli-input-json file://app-image-config-input.json 24 | 25 | #If you don't have a Sagemaker domain already... 26 | #aws --region ${REGION} sagemaker create-domain --cli-input-json file://create-domain-input.json 27 | 28 | #!!!! update the file update-domain-input.json with your sagemaker domain id !!!! 29 | 30 | aws --region ${REGION} sagemaker update-domain --cli-input-json file://update-domain-input.json 31 | -------------------------------------------------------------------------------- /tests/dags/test_dag_integrity.py: -------------------------------------------------------------------------------- 1 | """Test the validity of all DAGs. This test ensures that all Dags have tags, retries set to two, and no import errors. Feel free to add and remove tests.""" 2 | 3 | import os 4 | import logging 5 | from contextlib import contextmanager 6 | import pytest 7 | from airflow.models import DagBag 8 | 9 | @contextmanager 10 | def suppress_logging(namespace): 11 | logger = logging.getLogger(namespace) 12 | old_value = logger.disabled 13 | logger.disabled = True 14 | try: 15 | yield 16 | finally: 17 | logger.disabled = old_value 18 | 19 | def get_import_errors(): 20 | """ 21 | Generate a tuple for import errors in the dag bag 22 | """ 23 | with suppress_logging('airflow') : 24 | dag_bag = DagBag(include_examples=False) 25 | 26 | def strip_path_prefix(path): 27 | return os.path.relpath(path ,os.environ.get('AIRFLOW_HOME')) 28 | 29 | # we prepend "(None,None)" to ensure that a test object is always created even if it's a no op. 
30 | return [(None,None)] +[ ( strip_path_prefix(k) , v.strip() ) for k,v in dag_bag.import_errors.items()] 31 | 32 | def get_dags(): 33 | """ 34 | Generate a tuple of dag_id, in the DagBag 35 | """ 36 | with suppress_logging('airflow') : 37 | dag_bag = DagBag(include_examples=False) 38 | 39 | def strip_path_prefix(path): 40 | return os.path.relpath(path ,os.environ.get('AIRFLOW_HOME')) 41 | return [ (k,v,strip_path_prefix(v.fileloc)) for k,v in dag_bag.dags.items()] 42 | 43 | @pytest.mark.parametrize("rel_path,rv", get_import_errors(), ids=[x[0] for x in get_import_errors()]) 44 | def test_file_imports(rel_path,rv): 45 | """ Test for import errors on a file """ 46 | if rel_path and rv : 47 | raise Exception(f"{rel_path} failed to import with message \n {rv}") 48 | 49 | 50 | 51 | APPROVED_TAGS = {} 52 | 53 | @pytest.mark.parametrize("dag_id,dag,fileloc", get_dags(), ids=[x[2] for x in get_dags()]) 54 | def test_dag_tags(dag_id,dag, fileloc): 55 | """ 56 | test if a DAG is tagged and if those TAGs are in the approved list 57 | """ 58 | assert dag.tags, f"{dag_id} in {fileloc} has no tags" 59 | if APPROVED_TAGS: 60 | assert not set(dag.tags) - APPROVED_TAGS 61 | 62 | 63 | 64 | @pytest.mark.parametrize("dag_id,dag, fileloc", get_dags(), ids=[x[2] for x in get_dags()]) 65 | def test_dag_retries(dag_id,dag, fileloc): 66 | """ 67 | test if a DAG has retries set 68 | """ 69 | assert dag.default_args.get('retries', None) >= 2, f"{dag_id} in {fileloc} does not have retries set to at least 2." 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | penv/ 132 | 133 | include/state.json 134 | -------------------------------------------------------------------------------- /dags/mlops_monthly_pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | from dags.mlops_tasks import incremental_elt_task 3 | from dags.mlops_tasks import generate_feature_table_task 4 | from dags.mlops_tasks import generate_forecast_table_task 5 | from dags.mlops_tasks import bulk_train_predict_task 6 | from dags.mlops_tasks import eval_station_models_task 7 | from dags.mlops_tasks import flatten_tables_task 8 | 9 | def citibikeml_monthly_taskflow(files_to_download:list, run_date:str): 10 | """ 11 | End to end Snowflake ML Demo 12 | """ 13 | import uuid 14 | import json 15 | 16 | with open('./include/state.json') as sdf: 17 | state_dict = json.load(sdf) 18 | 19 | model_id = str(uuid.uuid1()).replace('-', '_') 20 | 21 | state_dict.update({'model_id': model_id}) 22 | state_dict.update({'run_date': run_date}) 23 | state_dict.update({'weather_database_name': 'WEATHER_NYC'}) 24 | state_dict.update({'load_table_name': 'RAW_', 25 | 'trips_table_name': 'TRIPS', 26 | 'load_stage_name': 'LOAD_STAGE', 27 | 'model_stage_name': 'MODEL_STAGE', 28 | 'weather_table_name': state_dict['weather_database_name']+'.ONPOINT_ID.HISTORY_DAY', 29 | 'weather_view_name': 'WEATHER_NYC_VW', 30 | 'holiday_table_name': 'HOLIDAYS', 31 | 'clone_table_name': 'CLONE_'+model_id, 32 | 'feature_table_name' : 'FEATURE_'+model_id, 33 | 'pred_table_name': 'PRED_'+model_id, 34 | 'eval_table_name': 'EVAL_'+model_id, 35 | 'forecast_table_name': 'FORECAST_'+model_id, 36 | 'forecast_steps': 30, 37 | 'train_udf_name': 'station_train_predict_udtf', 38 | 'train_func_name': 'StationTrainPredictFunc', 39 | 'eval_udf_name': 'eval_model_output_udf', 40 | 'eval_func_name': 'eval_model_func' 41 | }) 42 | 43 | #Task order - monthlyl incremental 44 | incr_state_dict = incremental_elt_task(state_dict, files_to_download) 45 | feature_state_dict = generate_feature_table_task(incr_state_dict, incr_state_dict, incr_state_dict) 46 | forecast_state_dict = generate_forecast_table_task(incr_state_dict, incr_state_dict, incr_state_dict) 47 | pred_state_dict = bulk_train_predict_task(feature_state_dict, feature_state_dict, forecast_state_dict) 48 | eval_state_dict = eval_station_models_task(pred_state_dict, pred_state_dict, run_date) 49 | state_dict = flatten_tables_task(pred_state_dict, eval_state_dict) 50 | 51 | return state_dict 52 | -------------------------------------------------------------------------------- /dags/mlops_setup_pipeline.py: -------------------------------------------------------------------------------- 1 | from dags.mlops_tasks import snowpark_database_setup 2 | from dags.mlops_tasks import initial_bulk_load_task 3 | from dags.mlops_tasks import materialize_holiday_task 4 | from dags.mlops_tasks import subscribe_to_weather_data_task 5 | from dags.mlops_tasks import 
create_weather_view_task 6 | from dags.mlops_tasks import deploy_model_udf_task 7 | from dags.mlops_tasks import deploy_eval_udf_task 8 | from dags.mlops_tasks import generate_feature_table_task 9 | from dags.mlops_tasks import generate_forecast_table_task 10 | from dags.mlops_tasks import bulk_train_predict_task 11 | from dags.mlops_tasks import eval_station_models_task 12 | from dags.mlops_tasks import flatten_tables_task 13 | 14 | def citibikeml_setup_taskflow(run_date:str): 15 | """ 16 | End to end Snowflake ML Demo 17 | """ 18 | import uuid 19 | import json 20 | 21 | with open('./include/state.json') as sdf: 22 | state_dict = json.load(sdf) 23 | 24 | model_id = str(uuid.uuid1()).replace('-', '_') 25 | 26 | state_dict.update({'model_id': model_id}) 27 | state_dict.update({'run_date': run_date}) 28 | state_dict.update({'weather_database_name': 'WEATHER_NYC'}) 29 | state_dict.update({'load_table_name': 'RAW_', 30 | 'trips_table_name': 'TRIPS', 31 | 'load_stage_name': 'LOAD_STAGE', 32 | 'model_stage_name': 'MODEL_STAGE', 33 | 'weather_table_name': state_dict['weather_database_name']+'.ONPOINT_ID.HISTORY_DAY', 34 | 'weather_view_name': 'WEATHER_NYC_VW', 35 | 'holiday_table_name': 'HOLIDAYS', 36 | 'clone_table_name': 'CLONE_'+model_id, 37 | 'feature_table_name' : 'FEATURE_'+model_id, 38 | 'pred_table_name': 'PRED_'+model_id, 39 | 'eval_table_name': 'EVAL_'+model_id, 40 | 'forecast_table_name': 'FORECAST_'+model_id, 41 | 'forecast_steps': 30, 42 | 'train_udf_name': 'station_train_predict_udtf', 43 | 'train_func_name': 'StationTrainPredictFunc', 44 | 'eval_udf_name': 'eval_model_output_udf', 45 | 'eval_func_name': 'eval_model_func' 46 | }) 47 | 48 | #Task order - one-time setup 49 | setup_state_dict = snowpark_database_setup(state_dict) 50 | load_state_dict = initial_bulk_load_task(setup_state_dict) 51 | holiday_state_dict = materialize_holiday_task(setup_state_dict) 52 | subscribe_state_dict = subscribe_to_weather_data_task(setup_state_dict) 53 | weather_state_dict = create_weather_view_task(subscribe_state_dict) 54 | model_udf_state_dict = deploy_model_udf_task(setup_state_dict) 55 | eval_udf_state_dict = deploy_eval_udf_task(setup_state_dict) 56 | feature_state_dict = generate_feature_table_task(load_state_dict, holiday_state_dict, weather_state_dict) 57 | foecast_state_dict = generate_forecast_table_task(load_state_dict, holiday_state_dict, weather_state_dict) 58 | pred_state_dict = bulk_train_predict_task(model_udf_state_dict, feature_state_dict, foecast_state_dict) 59 | eval_state_dict = eval_station_models_task(eval_udf_state_dict, pred_state_dict, run_date) 60 | state_dict = flatten_tables_task(pred_state_dict, eval_state_dict) 61 | 62 | return state_dict 63 | -------------------------------------------------------------------------------- /dags/airflow_incremental_pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | from datetime import datetime, timedelta 3 | 4 | from airflow.decorators import dag, task 5 | from dags.airflow_tasks import snowpark_database_setup 6 | from dags.airflow_tasks import incremental_elt_task 7 | from dags.airflow_tasks import initial_bulk_load_task 8 | from dags.airflow_tasks import materialize_holiday_task 9 | from dags.airflow_tasks import deploy_model_udf_task 10 | from dags.airflow_tasks import deploy_eval_udf_task 11 | from dags.airflow_tasks import generate_feature_table_task 12 | from dags.airflow_tasks import generate_forecast_table_task 13 | from dags.airflow_tasks import bulk_train_predict_task 14 | 
from dags.airflow_tasks import eval_station_models_task 15 | from dags.airflow_tasks import flatten_tables_task 16 | 17 | default_args = { 18 | 'owner': 'airflow', 19 | 'depends_on_past': False, 20 | 'email_on_failure': False, 21 | 'email_on_retry': False, 22 | 'retries': 1, 23 | 'retry_delay': timedelta(minutes=5) 24 | } 25 | 26 | #local_airflow_path = '/usr/local/airflow/' 27 | 28 | @dag(default_args=default_args, schedule_interval=None, start_date=datetime(2020, 4, 1), catchup=False, tags=['monthly']) 29 | def citibikeml_monthly_taskflow(files_to_download:list, run_date:str): 30 | """ 31 | End to end Snowpark / Astronomer ML Demo 32 | """ 33 | import uuid 34 | import json 35 | 36 | with open('./include/state.json') as sdf: 37 | state_dict = json.load(sdf) 38 | 39 | model_id = str(uuid.uuid1()).replace('-', '_') 40 | 41 | state_dict.update({'model_id': model_id}) 42 | state_dict.update({'run_date': run_date}) 43 | state_dict.update({'weather_database_name': 'WEATHER_NYC'}) 44 | state_dict.update({'load_table_name': 'RAW_', 45 | 'trips_table_name': 'TRIPS', 46 | 'load_stage_name': 'LOAD_STAGE', 47 | 'model_stage_name': 'MODEL_STAGE', 48 | 'weather_table_name': state_dict['weather_database_name']+'.ONPOINT_ID.HISTORY_DAY', 49 | 'weather_view_name': 'WEATHER_NYC_VW', 50 | 'holiday_table_name': 'HOLIDAYS', 51 | 'clone_table_name': 'CLONE_'+model_id, 52 | 'feature_table_name' : 'FEATURE_'+model_id, 53 | 'pred_table_name': 'PRED_'+model_id, 54 | 'eval_table_name': 'EVAL_'+model_id, 55 | 'forecast_table_name': 'FORECAST_'+model_id, 56 | 'forecast_steps': 30, 57 | 'train_udf_name': 'station_train_predict_udf', 58 | 'train_func_name': 'station_train_predict_func', 59 | 'eval_udf_name': 'eval_model_output_udf', 60 | 'eval_func_name': 'eval_model_func' 61 | }) 62 | 63 | incr_state_dict = incremental_elt_task(state_dict, files_to_download) 64 | feature_state_dict = generate_feature_table_task(incr_state_dict, incr_state_dict, incr_state_dict) 65 | forecast_state_dict = generate_forecast_table_task(incr_state_dict, incr_state_dict, incr_state_dict) 66 | pred_state_dict = bulk_train_predict_task(feature_state_dict, feature_state_dict, forecast_state_dict) 67 | eval_state_dict = eval_station_models_task(pred_state_dict, pred_state_dict, run_date) 68 | state_dict = flatten_tables_task(pred_state_dict, eval_state_dict) 69 | 70 | return state_dict 71 | 72 | run_date='2020_02_01' 73 | files_to_download = ['202001-citibike-tripdata.csv.zip'] 74 | 75 | state_dict = citibikeml_monthly_taskflow(files_to_download=files_to_download, 76 | run_date=run_date) 77 | -------------------------------------------------------------------------------- /dags/feature_engineering.py: -------------------------------------------------------------------------------- 1 | 2 | def generate_holiday_df(session, holiday_table_name:str): 3 | from snowflake.snowpark import functions as F 4 | import pandas as pd 5 | from pandas.tseries.holiday import USFederalHolidayCalendar 6 | from datetime import timedelta, datetime 7 | 8 | cal = USFederalHolidayCalendar() 9 | 10 | #generate a feature of 20 years worth of US holiday days. 
11 | start_date = datetime.strptime('2013-01-01', '%Y-%m-%d') 12 | end_date = start_date+timedelta(days=365*20) 13 | 14 | holiday_df = pd.DataFrame(cal.holidays(start=start_date, end=end_date), columns=['DATE']) 15 | holiday_df['DATE'] = holiday_df['DATE'].dt.strftime('%Y-%m-%d') 16 | 17 | session.create_dataframe(holiday_df) \ 18 | .with_column("HOLIDAY", F.lit(1))\ 19 | .write\ 20 | .save_as_table(holiday_table_name, mode="overwrite", table_type="temporary") 21 | 22 | return session.table(holiday_table_name) 23 | 24 | def generate_weather_df(session, weather_table_name): 25 | from snowflake.snowpark import functions as F 26 | return session.table(weather_table_name)\ 27 | .filter(F.col('POSTAL_CODE') == '10007')\ 28 | .select(F.col('DATE_VALID_STD').alias('DATE'), 29 | F.col('TOT_PRECIPITATION_MM').alias('PRECIP'), 30 | F.round(F.col('AVG_TEMPERATURE_FEELSLIKE_2M_C'), 2).alias('TEMP'))\ 31 | .sort('DATE', ascending=True) 32 | 33 | def generate_features(session, input_df, holiday_table_name, weather_table_name): 34 | import snowflake.snowpark as snp 35 | from snowflake.snowpark import functions as F 36 | 37 | #start_date, end_date = input_df.select(F.min('STARTTIME'), F.max('STARTTIME')).collect()[0][0:2] 38 | 39 | #check if features are already materialized (or in a temp table) 40 | holiday_df = session.table(holiday_table_name) 41 | try: 42 | _ = holiday_df.columns 43 | except: 44 | holiday_df = generate_holiday_df(session, holiday_table_name) 45 | 46 | weather_df = session.table(weather_table_name)[['DATE','TEMP']] 47 | try: 48 | _ = weather_df.columns 49 | except: 50 | weather_df = generate_weather_df(session, weather_table_name)[['DATE','PRECIP','TEMP']] 51 | 52 | feature_df = input_df.select(F.to_date(F.col('STARTTIME')).alias('DATE'), 53 | F.col('START_STATION_ID').alias('STATION_ID'))\ 54 | .replace({'NULL': None}, subset=['STATION_ID'])\ 55 | .group_by(F.col('STATION_ID'), F.col('DATE'))\ 56 | .count() 57 | 58 | #Impute missing values for lag columns using mean of the previous period. 
59 | mean_1 = round(feature_df.sort('DATE').limit(1).select(F.mean('COUNT')).collect()[0][0]) 60 | mean_7 = round(feature_df.sort('DATE').limit(7).select(F.mean('COUNT')).collect()[0][0]) 61 | mean_90 = round(feature_df.sort('DATE').limit(90).select(F.mean('COUNT')).collect()[0][0]) 62 | mean_365 = round(feature_df.sort('DATE').limit(365).select(F.mean('COUNT')).collect()[0][0]) 63 | 64 | date_win = snp.Window.order_by('DATE') 65 | 66 | feature_df = feature_df.with_column('LAG_1', F.lag('COUNT', offset=1, default_value=mean_1) \ 67 | .over(date_win)) \ 68 | .with_column('LAG_7', F.lag('COUNT', offset=7, default_value=mean_7) \ 69 | .over(date_win)) \ 70 | .with_column('LAG_90', F.lag('COUNT', offset=90, default_value=mean_90) \ 71 | .over(date_win)) \ 72 | .with_column('LAG_365', F.lag('COUNT', offset=365, default_value=mean_365) \ 73 | .over(date_win)) \ 74 | .join(holiday_df, 'DATE', join_type='left').na.fill({'HOLIDAY':0}) \ 75 | .join(weather_df, 'DATE', 'inner') \ 76 | .na.drop() 77 | 78 | return feature_df 79 | -------------------------------------------------------------------------------- /dags/airflow_setup_pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | from datetime import datetime, timedelta 3 | 4 | from airflow.decorators import dag, task 5 | from dags.airflow_tasks import snowpark_database_setup 6 | from dags.airflow_tasks import incremental_elt_task 7 | from dags.airflow_tasks import initial_bulk_load_task 8 | from dags.airflow_tasks import materialize_holiday_task 9 | from dags.airflow_tasks import subscribe_to_weather_data_task 10 | from dags.airflow_tasks import create_weather_view_task 11 | from dags.airflow_tasks import deploy_model_udf_task 12 | from dags.airflow_tasks import deploy_eval_udf_task 13 | from dags.airflow_tasks import generate_feature_table_task 14 | from dags.airflow_tasks import generate_forecast_table_task 15 | from dags.airflow_tasks import bulk_train_predict_task 16 | from dags.airflow_tasks import eval_station_models_task 17 | from dags.airflow_tasks import flatten_tables_task 18 | 19 | default_args = { 20 | 'owner': 'airflow', 21 | 'depends_on_past': False, 22 | 'email_on_failure': False, 23 | 'email_on_retry': False, 24 | 'retries': 1, 25 | 'retry_delay': timedelta(minutes=5) 26 | } 27 | 28 | #local_airflow_path = '/usr/local/airflow/' 29 | 30 | @dag(default_args=default_args, schedule_interval=None, start_date=datetime(2020, 3, 1), catchup=False, tags=['setup']) 31 | def citibikeml_setup_taskflow(run_date:str): 32 | """ 33 | Setup initial Snowpark / Astronomer ML Demo 34 | """ 35 | import uuid 36 | import json 37 | 38 | with open('./include/state.json') as sdf: 39 | state_dict = json.load(sdf) 40 | 41 | model_id = str(uuid.uuid1()).replace('-', '_') 42 | 43 | state_dict.update({'model_id': model_id}) 44 | state_dict.update({'run_date': run_date}) 45 | state_dict.update({'weather_database_name': 'WEATHER_NYC'}) 46 | state_dict.update({'load_table_name': 'RAW_', 47 | 'trips_table_name': 'TRIPS', 48 | 'load_stage_name': 'LOAD_STAGE', 49 | 'model_stage_name': 'MODEL_STAGE', 50 | 'weather_table_name': state_dict['weather_database_name']+'.ONPOINT_ID.HISTORY_DAY', 51 | 'weather_view_name': 'WEATHER_NYC_VW', 52 | 'holiday_table_name': 'HOLIDAYS', 53 | 'clone_table_name': 'CLONE_'+model_id, 54 | 'feature_table_name' : 'FEATURE_'+model_id, 55 | 'pred_table_name': 'PRED_'+model_id, 56 | 'eval_table_name': 'EVAL_'+model_id, 57 | 'forecast_table_name': 'FORECAST_'+model_id, 58 | 'forecast_steps': 30, 59 | 
'train_udf_name': 'station_train_predict_udf', 60 | 'train_func_name': 'station_train_predict_func', 61 | 'eval_udf_name': 'eval_model_output_udf', 62 | 'eval_func_name': 'eval_model_func' 63 | }) 64 | 65 | #Task order - one-time setup 66 | setup_state_dict = snowpark_database_setup(state_dict) 67 | load_state_dict = initial_bulk_load_task(setup_state_dict) 68 | holiday_state_dict = materialize_holiday_task(setup_state_dict) 69 | subscribe_state_dict = subscribe_to_weather_data_task(setup_state_dict) 70 | weather_state_dict = create_weather_view_task(subscribe_state_dict) 71 | model_udf_state_dict = deploy_model_udf_task(setup_state_dict) 72 | eval_udf_state_dict = deploy_eval_udf_task(setup_state_dict) 73 | feature_state_dict = generate_feature_table_task(load_state_dict, holiday_state_dict, weather_state_dict) 74 | foecast_state_dict = generate_forecast_table_task(load_state_dict, holiday_state_dict, weather_state_dict) 75 | pred_state_dict = bulk_train_predict_task(model_udf_state_dict, feature_state_dict, foecast_state_dict) 76 | eval_state_dict = eval_station_models_task(eval_udf_state_dict, pred_state_dict, run_date) 77 | state_dict = flatten_tables_task(pred_state_dict, eval_state_dict) 78 | 79 | return state_dict 80 | 81 | run_date='2020_01_01' 82 | 83 | state_dict = citibikeml_setup_taskflow(run_date=run_date) 84 | -------------------------------------------------------------------------------- /dags/ingest.py: -------------------------------------------------------------------------------- 1 | def incremental_elt(session, 2 | state_dict:dict, 3 | files_to_ingest:list, 4 | download_base_url, 5 | use_prestaged=False) -> str: 6 | 7 | import dags.elt as ELT 8 | from datetime import datetime 9 | 10 | load_stage_name=state_dict['load_stage_name'] 11 | load_table_name=state_dict['load_table_name'] 12 | trips_table_name=state_dict['trips_table_name'] 13 | 14 | if use_prestaged: 15 | print("Skipping extract. 
Using provided bucket for pre-staged files.") 16 | 17 | schema1_download_files = list() 18 | schema2_download_files = list() 19 | schema2_start_date = datetime.strptime('202102', "%Y%m") 20 | 21 | for file_name in files_to_ingest: 22 | file_start_date = datetime.strptime(file_name.split("-")[0], "%Y%m") 23 | if file_start_date < schema2_start_date: 24 | schema1_download_files.append(file_name.replace('.zip','.gz')) 25 | else: 26 | schema2_download_files.append(file_name.replace('.zip','.gz')) 27 | 28 | 29 | load_stage_names = {'schema1':load_stage_name+'/schema1/', 'schema2':load_stage_name+'/schema2/'} 30 | files_to_load = {'schema1': schema1_download_files, 'schema2': schema2_download_files} 31 | else: 32 | print("Extracting files from public location.") 33 | load_stage_names, files_to_load = ELT.extract_trips_to_stage(session=session, 34 | files_to_download=files_to_ingest, 35 | download_base_url=download_base_url, 36 | load_stage_name=load_stage_name) 37 | 38 | files_to_load['schema1']=[file+'.gz' for file in files_to_load['schema1']] 39 | files_to_load['schema2']=[file+'.gz' for file in files_to_load['schema2']] 40 | 41 | 42 | print("Loading files to raw.") 43 | stage_table_names = ELT.load_trips_to_raw(session=session, 44 | files_to_load=files_to_load, 45 | load_stage_names=load_stage_names, 46 | load_table_name=load_table_name) 47 | 48 | print("Transforming records to trips table.") 49 | trips_table_name = ELT.transform_trips(session=session, 50 | stage_table_names=stage_table_names, 51 | trips_table_name=trips_table_name) 52 | return trips_table_name 53 | 54 | def bulk_elt(session, 55 | state_dict:dict, 56 | download_base_url, 57 | use_prestaged=False) -> str: 58 | 59 | #import dags.elt as ELT 60 | from dags.ingest import incremental_elt 61 | 62 | import pandas as pd 63 | from datetime import datetime 64 | 65 | #Create a list of filenames to download based on date range 66 | #For files like 201306-citibike-tripdata.zip 67 | date_range1 = pd.period_range(start=datetime.strptime("201306", "%Y%m"), 68 | end=datetime.strptime("201612", "%Y%m"), 69 | freq='M').strftime("%Y%m") 70 | file_name_end1 = '-citibike-tripdata.zip' 71 | files_to_extract = [date+file_name_end1 for date in date_range1.to_list()] 72 | 73 | #For files like 201701-citibike-tripdata.csv.zip 74 | date_range2 = pd.period_range(start=datetime.strptime("201701", "%Y%m"), 75 | end=datetime.strptime("201912", "%Y%m"), 76 | freq='M').strftime("%Y%m") 77 | 78 | file_name_end2 = '-citibike-tripdata.csv.zip' 79 | 80 | files_to_extract = files_to_extract + [date+file_name_end2 for date in date_range2.to_list()] 81 | 82 | trips_table_name = incremental_elt(session=session, 83 | state_dict=state_dict, 84 | files_to_ingest=files_to_extract, 85 | use_prestaged=use_prestaged, 86 | download_base_url=download_base_url) 87 | 88 | return trips_table_name 89 | -------------------------------------------------------------------------------- /02_Data_Marketplace.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "304d0e14", 6 | "metadata": {}, 7 | "source": [ 8 | "## Snowflake Data Marketplace \n", 9 | "\n", 10 | "Later in this hands-on-lab we will be using weather data from the Snowflake Data Marketplace for input to our forecasting models. 
\n", 11 | "\n", 12 | "Weather Source is a leading provider of global weather and climate data and the OnPoint Product Suite provides businesses with the necessary weather and climate data to quickly generate meaningful and actionable insights for a wide range of use cases across industries.\n", 13 | "\n", 14 | "Weather Source powers a majority of Fortune companies who use their data to quantify the impact of weather and climate on various KPIs including footfall traffic, product sales and demand, supply chain and logistics, advertising and more. " 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "id": "172b95c0", 20 | "metadata": {}, 21 | "source": [ 22 | "### 1. Browse the Marketplace\n", 23 | "Log in to the Snowflake UI with the `Jack` user and password created in step 00-Setup. \n", 24 | " \n", 25 | "Click on [Marketplace](https://app.snowflake.com/marketplace) on the left side bar.\n", 26 | "\n", 27 | "You will see many different types of data ranging from finance and trading to COVID statistics and geospatial datasets. \n", 28 | " \n", 29 | "Search for `Snowpark` in the search bar at the top. And select the tile named [Snowpark for Python - Hands-on-Lab - Weather Data](https://app.snowflake.com/marketplace/listing/GZSOZ1LLE9)." 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "5c844e2f", 35 | "metadata": {}, 36 | "source": [ 37 | "" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "d03a12ee", 43 | "metadata": {}, 44 | "source": [ 45 | "This dataset is provided by Weather Source for the Snowpark hands-on-lab and provides OnPoint Historical Weather Data in daily format for New York City ZIP Code - 10007. \n", 46 | " \n", 47 | "The sample data is in Celsius and covers the time period from June 1, 2013 to present. The data is updated daily and includes the following supported weather parameters: precipitation, temperature, wind speed & direction and humidity." 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "id": "ee1ee90e", 53 | "metadata": {}, 54 | "source": [ 55 | "- Click on `Get Data` on the right side. \n", 56 | "- Enter `WEATHER_NYC` for the database name. \n", 57 | "- Select `PUBLIC` for the role access if not already selected. \n", 58 | "- Read the terms and conditions and click on `Get Data`." 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "id": "1080a3ec", 64 | "metadata": {}, 65 | "source": [ 66 | "" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "35b231df", 72 | "metadata": {}, 73 | "source": [ 74 | "Click on Data on the left side bar and you will see a new WEATHER database has been created. \n", 75 | "\n", 76 | "Note: The weather data has a different prefix, depending on the region used. 
\n", 77 | " \n", 78 | "You can capture the listing ID programatically:" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "d17b63d2", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "from dags.snowpark_connection import snowpark_connect\n", 89 | "session, state_dict = snowpark_connect()\n", 90 | "\n", 91 | "weather_listing_id = session.sql(\"SHOW SHARES LIKE '%WEATHERSOURCE_SNOWFLAKE_SNOWPARK_TILE_SNOWFLAKE_SECURE_SHARE%'\").collect()[0][2]\n", 92 | "weather_listing_id" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "id": "3e160891", 98 | "metadata": {}, 99 | "source": [ 100 | "After accepting the subsription terms the shared data can also be programatically subscribed to via the following commands:" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "id": "f32e039e", 106 | "metadata": {}, 107 | "source": [ 108 | "We will save the listing ID in the state.json" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "7c6cd370", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "weather_database_name = 'WEATHER_NYC'\n", 119 | "state_dict['weather_listing_id'] = weather_listing_id\n", 120 | "state_dict['weather_database_name'] = weather_database_name\n", 121 | "\n", 122 | "import json\n", 123 | "with open('./include/state.json', 'w') as sdf:\n", 124 | " json.dump(state_dict, sdf)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "f41e65ba", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "session.close()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "id": "b7f526f4", 140 | "metadata": {}, 141 | "source": [ 142 | "Alternatively you can search the entire marketplace by parsing the json output" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "919d9148", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "#_ = session.sql(\"SHOW AVAILABLE LISTINGS IN DATA EXCHANGE SNOWFLAKE_DATA_MARKETPLACE;\").collect()\n", 153 | "#listing = session.sql(\"SELECT \\\"share_name\\\" FROM TABLE(result_scan()) WHERE PARSE_JSON(\\\"metadata\\\"):title LIKE '%Snowpark%Hands%Weather%'; \").collect()[0]['share_name']" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "22033d60", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [] 163 | } 164 | ], 165 | "metadata": { 166 | "kernelspec": { 167 | "display_name": "snowpark_0110:Python", 168 | "language": "python", 169 | "name": "conda-env-snowpark_0110-py" 170 | }, 171 | "language_info": { 172 | "codemirror_mode": { 173 | "name": "ipython", 174 | "version": 3 175 | }, 176 | "file_extension": ".py", 177 | "mimetype": "text/x-python", 178 | "name": "python", 179 | "nbconvert_exporter": "python", 180 | "pygments_lexer": "ipython3", 181 | "version": "3.8.13" 182 | } 183 | }, 184 | "nbformat": 4, 185 | "nbformat_minor": 5 186 | } 187 | -------------------------------------------------------------------------------- /dags/station_train_predict.py: -------------------------------------------------------------------------------- 1 | class StationTrainPredictFunc: 2 | 3 | def __init__(self): 4 | self.scanned_first_row = False 5 | self.historical_column_names = None 6 | self.target_column = None 7 | self.cutpoint = None 8 | self.max_epochs = None 9 | self.forecast_column_names = None 10 | self.forecast_data = None 11 | self.lag_values = None 12 | self.date = 
[] 13 | self.count = [] 14 | self.lag_1 = [] 15 | self.lag_7 = [] 16 | self.lag_90 = [] 17 | self.lag_365 = [] 18 | self.holiday = [] 19 | self.precip = [] 20 | self.temp = [] 21 | 22 | def process(self, date, count, lag_1, lag_7, lag_90, lag_365, holiday, precip, temp, historical_column_names, target_column, cutpoint, max_epochs, forecast_column_names, forecast_data, lag_values): 23 | if not self.scanned_first_row: 24 | self.historical_column_names = historical_column_names 25 | self.target_column = target_column 26 | self.cutpoint = int(cutpoint) 27 | self.max_epochs = int(max_epochs) 28 | self.forecast_column_names = forecast_column_names 29 | self.forecast_data = forecast_data 30 | self.lag_values = lag_values 31 | self.scanned_first_row = True 32 | self.date.append(date) 33 | self.count.append(count) 34 | self.lag_1.append(lag_1) 35 | self.lag_7.append(lag_7) 36 | self.lag_90.append(lag_90) 37 | self.lag_365.append(lag_365) 38 | self.holiday.append(holiday) 39 | self.precip.append(precip) 40 | self.temp.append(temp) 41 | yield None 42 | 43 | def end_partition(self): 44 | from torch import tensor 45 | import pandas as pd 46 | from pytorch_tabnet.tab_model import TabNetRegressor 47 | from datetime import timedelta 48 | import numpy as np 49 | 50 | feature_columns = self.historical_column_names.copy() 51 | feature_columns.remove('DATE') 52 | feature_columns.remove(self.target_column) 53 | 54 | df = pd.DataFrame(zip(self.date, self.count, self.lag_1, self.lag_7, self.lag_90, self.lag_365, self.holiday, self.precip, self.temp), columns=[*self.historical_column_names]) 55 | df['DATE'] = pd.to_datetime(df['DATE']) 56 | df = df.sort_values(by='DATE', ascending = True) 57 | 58 | forecast_steps = len(self.forecast_data) 59 | 60 | y_valid = df[self.target_column][-self.cutpoint:].values.reshape(-1, 1) 61 | X_valid = df[feature_columns][-self.cutpoint:].values 62 | y_train = df[self.target_column][:-self.cutpoint].values.reshape(-1, 1) 63 | X_train = df[feature_columns][:-self.cutpoint].values 64 | 65 | model = TabNetRegressor() 66 | 67 | model.fit( 68 | X_train, y_train, 69 | eval_set=[(X_valid, y_valid)], 70 | max_epochs=self.max_epochs, 71 | patience=100, 72 | batch_size=128, 73 | virtual_batch_size=64, 74 | num_workers=0, 75 | drop_last=True) 76 | 77 | df['PRED'] = model.predict(tensor(df[feature_columns].values.astype(np.float32))) 78 | 79 | if len(self.lag_values) > 0: 80 | forecast_df = pd.DataFrame(self.forecast_data, columns = self.forecast_column_names) 81 | 82 | for step in range(forecast_steps): 83 | #station_id = df.iloc[-1]['STATION_ID'] 84 | future_date = df.iloc[-1]['DATE']+timedelta(days=1) 85 | lags=[df.shift(lag-1).iloc[-1]['COUNT'] for lag in self.lag_values] 86 | forecast=forecast_df.loc[forecast_df['DATE']==future_date.strftime('%Y-%m-%d')] 87 | forecast=forecast.drop(labels='DATE', axis=1).values.tolist()[0] 88 | features=[*lags, *forecast] 89 | pred=round(model.predict(np.array([features]))[0][0]) 90 | row=[future_date, pred, *features, pred] 91 | df.loc[len(df)]=row 92 | 93 | explain_df = pd.DataFrame(model.explain(df[feature_columns].astype(float).values)[0], 94 | columns = feature_columns).add_prefix('EXPL_').round(2) 95 | df = pd.concat([df.set_index('DATE').reset_index(), explain_df], axis=1) 96 | df['DATE'] = df['DATE'].dt.strftime('%Y-%m-%d') 97 | 98 | yield ([df[:-forecast_steps].to_json(orient='records', lines=False), 99 | df[-forecast_steps:].to_json(orient='records', lines=False)],) 100 | 101 | 102 | # THIS IS THE OLD CODE FOR DOING VECTORIZED UDFs THAT WE ARE 
JUST HANGING ONTO 103 | # def station_train_predict_func(historical_data:list, 104 | # historical_column_names:list, 105 | # target_column:str, 106 | # cutpoint: int, 107 | # max_epochs: int, 108 | # forecast_data:list, 109 | # forecast_column_names:list, 110 | # lag_values:list): 111 | 112 | # from torch import tensor 113 | # import pandas as pd 114 | # from pytorch_tabnet.tab_model import TabNetRegressor 115 | # from datetime import timedelta 116 | # import numpy as np 117 | 118 | # feature_columns = historical_column_names.copy() 119 | # feature_columns.remove('DATE') 120 | # feature_columns.remove(target_column) 121 | # forecast_steps = len(forecast_data) 122 | 123 | # df = pd.DataFrame(historical_data, columns = historical_column_names) 124 | 125 | # ##In order to do train/valid split on time-based portion the input data must be sorted by date 126 | # df['DATE'] = pd.to_datetime(df['DATE']) 127 | # df = df.sort_values(by='DATE', ascending=True) 128 | 129 | # y_valid = df[target_column][-cutpoint:].values.reshape(-1, 1) 130 | # X_valid = df[feature_columns][-cutpoint:].values 131 | # y_train = df[target_column][:-cutpoint].values.reshape(-1, 1) 132 | # X_train = df[feature_columns][:-cutpoint].values 133 | 134 | # model = TabNetRegressor() 135 | 136 | # model.fit( 137 | # X_train, y_train, 138 | # eval_set=[(X_valid, y_valid)], 139 | # max_epochs=max_epochs, 140 | # patience=100, 141 | # batch_size=128, 142 | # virtual_batch_size=64, 143 | # num_workers=0, 144 | # drop_last=True) 145 | 146 | # df['PRED'] = model.predict(tensor(df[feature_columns].values)) 147 | 148 | # #Now make the multi-step forecast 149 | # if len(lag_values) > 0: 150 | # forecast_df = pd.DataFrame(forecast_data, columns = forecast_column_names) 151 | 152 | # for step in range(forecast_steps): 153 | # #station_id = df.iloc[-1]['STATION_ID'] 154 | # future_date = df.iloc[-1]['DATE']+timedelta(days=1) 155 | # lags=[df.shift(lag-1).iloc[-1]['COUNT'] for lag in lag_values] 156 | # forecast=forecast_df.loc[forecast_df['DATE']==future_date.strftime('%Y-%m-%d')] 157 | # forecast=forecast.drop(labels='DATE', axis=1).values.tolist()[0] 158 | # features=[*lags, *forecast] 159 | # pred=round(model.predict(np.array([features]))[0][0]) 160 | # row=[future_date, pred, *features, pred] 161 | # df.loc[len(df)]=row 162 | 163 | # explain_df = pd.DataFrame(model.explain(df[feature_columns].astype(float).values)[0], 164 | # columns = feature_columns).add_prefix('EXPL_').round(2) 165 | # df = pd.concat([df.set_index('DATE').reset_index(), explain_df], axis=1) 166 | # df['DATE'] = df['DATE'].dt.strftime('%Y-%m-%d') 167 | 168 | # return [df[:-forecast_steps].to_json(orient='records', lines=False), 169 | # df[-forecast_steps:].to_json(orient='records', lines=False)] 170 | -------------------------------------------------------------------------------- /include/streamlit_app.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.getcwd()+'/dags') 3 | 4 | from snowflake.snowpark import functions as F 5 | from snowpark_connection import snowpark_connect 6 | import streamlit as st 7 | import pandas as pd 8 | from datetime import timedelta, datetime 9 | from dateutil.relativedelta import * 10 | import calendar 11 | import altair as alt 12 | import requests 13 | from requests.auth import HTTPBasicAuth 14 | import time 15 | import json 16 | import logging 17 | 18 | logging.basicConfig(level=logging.WARN) 19 | logging.getLogger().setLevel(logging.WARN) 20 | 21 | def 
update_forecast_table(forecast_df, stations:list, start_date, end_date): 22 | # explainer_columns = [col for col in forecast_df.schema.names if 'EXP' in col] 23 | explainer_columns=['EXPL_LAG_1', 'EXPL_LAG_7','EXPL_LAG_90','EXPL_LAG_365','EXPL_HOLIDAY','EXPL_PRECIP','EXPL_TEMP'] 24 | explainer_columns_new=['DAY', 'DAY_OF_WEEK', 'QUARTER', 'DAY_OF_YEAR','US_HOLIDAY', 'PRECIPITATION','TEMPERATURE'] 25 | 26 | cond = "F.when" + ".when".join(["(F.col('" + c + "') == F.col('EXPLAIN'), F.lit('" + c + "'))" for c in explainer_columns]) 27 | 28 | df = forecast_df.filter((forecast_df['STATION_ID'].in_(stations)) & 29 | (F.col('DATE') >= start_date) & 30 | (F.col('DATE') <= end_date))\ 31 | .select(['STATION_ID', 32 | F.to_char(F.col('DATE')).alias('DATE'), 33 | 'PRED', 34 | 'HOLIDAY', 35 | *explainer_columns])\ 36 | .with_column('EXPLAIN', F.greatest(*explainer_columns))\ 37 | .with_column('REASON', eval(cond))\ 38 | .select(F.col('STATION_ID'), 39 | F.col('DATE'), 40 | F.col('PRED'), 41 | F.col('REASON'), 42 | F.col('EXPLAIN'), 43 | F.col('EXPL_LAG_1').alias('DAY'), 44 | F.col('EXPL_LAG_7').alias('DAY_OF_WEEK'), 45 | F.col('EXPL_LAG_90').alias('QUARTER'), 46 | F.col('EXPL_LAG_365').alias('DAY_OF_YEAR'), 47 | F.col('EXPL_HOLIDAY').alias('US_HOLIDAY'), 48 | F.col('EXPL_PRECIP').alias('PRECIPITATION'), 49 | F.col('EXPL_TEMP').alias('TEMPERATURE'), 50 | )\ 51 | .to_pandas() 52 | 53 | df['REASON'] = pd.Categorical(df['REASON']) 54 | df['REASON_CODE']=df['REASON'].cat.codes 55 | 56 | rect = alt.Chart(df).mark_rect().encode(alt.X('DATE:N'), 57 | alt.Y('STATION_ID:N'), 58 | alt.Color('REASON'), 59 | tooltip=explainer_columns_new) 60 | text = rect.mark_text(baseline='middle').encode(text='PRED:Q', color=alt.value('white')) 61 | 62 | l = alt.layer( 63 | rect, text 64 | ) 65 | 66 | st.write("### Forecast") 67 | st.altair_chart(l, use_container_width=True) 68 | 69 | return None 70 | 71 | def update_eval_table(eval_df, stations:list): 72 | df = eval_df.select('STATION_ID', F.to_char(F.col('RUN_DATE')).alias('RUN_DATE'), 'RMSE')\ 73 | .filter(eval_df['STATION_ID'].in_(stations))\ 74 | .to_pandas() 75 | 76 | data = df.pivot(index="RUN_DATE", columns="STATION_ID", values="RMSE") 77 | data = data.reset_index().melt('RUN_DATE', var_name='STATION_ID', value_name='RMSE') 78 | 79 | nearest = alt.selection(type='single', nearest=True, on='mouseover', 80 | fields=['RUN_DATE'], empty='none') 81 | 82 | line = alt.Chart(data).mark_line(interpolate='basis').encode( 83 | x='RUN_DATE:N', 84 | y='RMSE:Q', 85 | color='STATION_ID:N' 86 | ) 87 | 88 | selectors = alt.Chart(data).mark_point().encode( 89 | x='RUN_DATE:N', 90 | opacity=alt.value(0) 91 | ).add_selection( 92 | nearest 93 | ) 94 | 95 | points = line.mark_point().encode( 96 | opacity=alt.condition(nearest, alt.value(1), alt.value(0)) 97 | ) 98 | 99 | text = line.mark_text(align='left', dx=5, dy=-5).encode( 100 | text=alt.condition(nearest, 'RMSE:Q', alt.value(' ')) 101 | ) 102 | 103 | rules = alt.Chart(data).mark_rule(color='gray').encode( 104 | x='RUN_DATE:N', 105 | ).transform_filter( 106 | nearest 107 | ) 108 | 109 | l = alt.layer( 110 | line, selectors, points, rules, text 111 | ).properties( 112 | width=600, height=300 113 | ) 114 | st.write("### Model Monitor") 115 | st.altair_chart(l, use_container_width=True) 116 | 117 | return None 118 | 119 | def trigger_ingest(download_file_name, run_date): 120 | dag_url='http://localhost:8080/api/v1/dags/citibikeml_monthly_taskflow/dagRuns' 121 | json_payload = {"conf": {"files_to_download": [download_file_name], "run_date": 
run_date}} 122 | 123 | response = requests.post(dag_url, 124 | json=json_payload, 125 | auth = HTTPBasicAuth('admin', 'admin')) 126 | 127 | run_id = json.loads(response.text)['dag_run_id'] 128 | #run_id = 'manual__2022-04-07T15:02:29.166108+00:00' 129 | 130 | state=json.loads(requests.get(dag_url+'/'+run_id, auth=HTTPBasicAuth('admin', 'admin')).text)['state'] 131 | 132 | st.snow() 133 | 134 | with st.spinner('Ingesting file: '+download_file_name): 135 | while state != 'success': 136 | time.sleep(5) 137 | state=json.loads(requests.get(dag_url+'/'+run_id, auth=HTTPBasicAuth('admin', 'admin')).text)['state'] 138 | st.success('Ingested file: '+download_file_name+' State: '+str(state)) 139 | 140 | #Main Body 141 | session, state_dict = snowpark_connect('./include/state.json') 142 | forecast_df = session.table('FLAT_FORECAST') 143 | eval_df = session.table('FLAT_EVAL') 144 | trips_df = session.table('TRIPS') 145 | 146 | st.header('Citibike Forecast Application') 147 | st.write('In this application we leverage deep learning models to predict the number of trips started from '+ 148 | 'a given station each day. After selecting the stations and time range desired the application '+\ 149 | 'displays not only the forecast but also explains which features of the model were most used in making '+\ 150 | 'the prediction. Additionally users can see the historical performance of the deep learning model to '+\ 151 | 'monitor predictive capabilities over time.') 152 | 153 | last_trip_date = trips_df.select(F.to_date(F.max('STARTTIME'))).collect()[0][0] 154 | st.write('Data provided as of '+str(last_trip_date)) 155 | 156 | #Create a sidebar for input 157 | min_date=forecast_df.select(F.min('DATE')).collect()[0][0] 158 | max_date=forecast_df.select(F.max('DATE')).collect()[0][0] 159 | 160 | start_date = st.sidebar.date_input('Start Date', value=min_date, min_value=min_date, max_value=max_date) 161 | show_days = st.sidebar.number_input('Number of days to show', value=7, min_value=1, max_value=30) 162 | end_date = start_date+timedelta(days=show_days) 163 | 164 | stations_df=forecast_df.select(F.col('STATION_ID')).distinct().to_pandas() 165 | 166 | sample_stations = ["519", "497", "435", "402", "426", "285", "293"] 167 | 168 | stations = st.sidebar.multiselect('Choose stations', stations_df['STATION_ID'], sample_stations) 169 | if not stations: 170 | stations = stations_df['STATION_ID'] 171 | 172 | update_forecast_table(forecast_df, stations, start_date, end_date) 173 | 174 | update_eval_table(eval_df, stations) 175 | 176 | 177 | next_ingest = last_trip_date+relativedelta(months=+1) 178 | next_ingest = next_ingest.replace(day=1) 179 | 180 | if next_ingest <= datetime.strptime("2016-12-01", "%Y-%m-%d").date(): 181 | download_file_name=next_ingest.strftime('%Y%m')+'-citibike-tripdata.zip' 182 | else: 183 | download_file_name=next_ingest.strftime('%Y%m')+'-citibike-tripdata.zip' 184 | 185 | run_date = next_ingest+relativedelta(months=+1) 186 | run_date = run_date.strftime('%Y_%m_%d') 187 | 188 | st.write('Next ingest for '+str(next_ingest)) 189 | 190 | st.button('Run Ingest Taskflow', on_click=trigger_ingest, args=(download_file_name, run_date)) 191 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Citibike Machine Learning Hands-on-Lab with Snowpark Python 2 | 3 | Streamlit App 4 | 5 | 6 | ### Requirements: 7 | - Create a [Snowflake Trial Account](https://signup.snowflake.com/) 8 | - 
Edition: **Enterprise** 9 | - Cloud Provider: **AWS** <- Must be AWS for this HOL 10 | - Region: **US West (Oregon)**. 11 | Note down the account identifier and region to be used in the hands-on-lab. The account identifier and region can be found in the confirmation email from Snowflake Trial setup. Alternatively, the account identifier can be found in the URL of the Snowflake console when logged-in. (Example account: xxNNNNN, region: us-west-1. 12 | - An [Amazon SageMaker Studio Lab Account](https://studiolab.sagemaker.aws/). Do this ahead of time as sign-up may take up to 24 hours due to backlog. Sagemaker 13 | 14 | - Optional: Docker runtime environment such as [Docker Desktop](https://www.docker.com/products/docker-desktop/) will be used for running and managing Apache Airflow DAGs. Alternatively, if you do not have Docker, you will be able to create and run the ML Ops pipeline from Python, albeit without all the benefits of Airflow. 15 | Apache Airflow DAG 16 | 17 | ### Example Use-Case 18 | In this example we use the [Citibike dataset](https://ride.citibikenyc.com/system-data). Citibike is a bicycle sharing system in New York City. Everyday users choose from 20,000 bicycles at over 1000 stations around New York City. 19 | 20 | To ensure customer satisfaction Citibike needs to predict how many bicycles will be needed at each station. Maintenance teams from Citibike will check each station and repair or replace bicycles. Additionally, the team will relocate bicycles between stations based on predicted demand. The operations team needs an application to show how many bicycles will be needed at a given station on a given day. 21 | 22 | For this demo flow we will assume that the organization has the following **policies and processes** : 23 | - **Dev Tools**: Each user can develop in their tool of choice (ie. VS Code, IntelliJ, Pycharm, Eclipse, etc.). Snowpark Python makes it possible to use any environment where they have a python kernel. 24 | - **Data Governance**: To preserve customer privacy no data can be stored locally. The ingest system may store data temporarily but it must be assumed that, in production, the ingest system will not preserve intermediate data products between runs. Snowpark Python allows the user to push-down all operations to Snowflake and bring the code to the data. 25 | - **Automation**: Although the ML team can use any IDE or notebooks for development purposes the final product must be python code at the end of the work stream. Well-documented, modularized code is necessary for good ML operations and to interface with the company's CI/CD and orchestration tools. 26 | - **Compliance**: Any ML models must be traceable back to the original data set used for training. The business needs to be able to easily remove specific user data from training datasets and retrain models. 27 | 28 | 29 | ### Setup Steps: 30 | 31 | - Login to your [Snowflake Trial account](https://app.snowflake.com/) with the admin credentials that were created with the account in one browser tab (a role with ORGADMIN privileges). Keep this tab open during the hands-on-lab. 32 | - Click on the Billing on the left side panel. 33 | - Click on [Terms and Billing](https://app.snowflake.com/terms-and-billing). 34 | - Read and accept terms to continue with the hands-on-lab. 35 | - Login to [SageMaker Studio Lab](https://studiolab.sagemaker.aws/) in another browser tab. 
36 | - Create a Runtime if there isn't one already 37 | - Click on Start Runtime 38 | - Click on Open Project 39 | - Select Git -> Clone Git Repository and enter the following: 40 | Repository URL: https://github.com/Snowflake-Labs/sfguide-citibike-ml-snowpark-python. 41 | - Select Yes when prompted to create a conda environment. 42 | - A terminal will open and create the environment. When it is done, run `conda activate snowpark_0110` in the terminal window. 43 | - When opening notebooks, be sure to select the "snowpark_0110" kernel. 44 | 45 | ### Alternative Client 46 | 47 | As an alternative to SageMaker Studio Lab, this hands-on-lab can be run in Jupyter or any other notebook on a local system, or anywhere a Python 3.8 kernel can be installed. 48 | 49 | _**Note:** The `astro` CLI setup that is part of this repo can also run a version of Jupyter Lab on your local system, which might be quicker if you don't want to go through the conda install listed below. To use this included version of Jupyter, rename the `docker-compose.override.yml.TEMP` file to `docker-compose.override.yml` before running the `astro dev start` command detailed in the Airflow section below._ 50 | 51 | - Install Miniconda 52 | - For MacOS Intel Chip run: 53 | ```bash 54 | curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/Downloads/miniconda.sh 55 | sh ~/Downloads/miniconda.sh -b -p $HOME/miniconda 56 | ~/miniconda/bin/conda init 57 | conda update conda 58 | cat ~/.bash_profile >> ~/.zshrc 59 | . ~/.zshrc 60 | ``` 61 | If a shell other than the Mac default zsh is used, you must re-source the profile for that shell or open a new terminal window to pick up the miniconda shell/path changes. 62 | 63 | - For MacOS M1 Chip run: 64 | ```bash 65 | curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/Downloads/miniconda.sh 66 | sh ~/Downloads/miniconda.sh -b -p $HOME/miniconda 67 | ~/miniconda/bin/conda init 68 | conda update conda 69 | cat ~/.bash_profile >> ~/.zshrc 70 | . ~/.zshrc 71 | ``` 72 | If a shell other than the Mac default zsh is used, you must re-source the profile for that shell or open a new terminal window to pick up the miniconda shell/path changes. 73 | 74 | - For Microsoft Windows systems, download [miniconda for Windows](https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Windows-x86_64.exe) and install it manually. 75 | 76 | - Install git (if not already installed) on your local system. 77 | ```bash 78 | conda install git 79 | ``` 80 | - Create a Python kernel environment. Snowpark for Python is currently supported on **Python 3.8 only**. Clone this repository and create an environment. On Mac OS run: 81 | ``` 82 | mkdir ~/Desktop/snowpark-python 83 | cd ~/Desktop/snowpark-python 84 | git clone https://github.com/Snowflake-Labs/sfguide-citibike-ml-snowpark-python 85 | cd sfguide-citibike-ml-snowpark-python 86 | conda env create -f jupyter_env.yml 87 | conda activate snowpark_0110 88 | jupyter notebook 89 | ``` 90 | 91 | ## Automation with Airflow 92 | 93 | This process can be run automatically with Apache Airflow using the Astronomer command-line tools. [Astronomer](https://docs.astronomer.io/astro/cli/get-started#step-1-install-the-astro-cli) provides an easy way to deploy Apache Airflow instances in the cloud, and the CLI tooling lets you develop and test DAGs locally before deploying into production. 94 | 95 | _**Note:** This will require a Docker process running on the local machine, e.g.
dockerd, Docker Desktop, Colima etc._ 96 | 97 | * First step is to install [astro CLI](https://docs.astronomer.io/astro/cli/get-started#step-1-install-the-astro-cli). 98 | * Next clone the this repo locally if you have not done so already. 99 | * If you have been working with SageMaker, you will need to copy the final `state.json` file you created while working through the Notebooks locally. Place this file in the `include` directory and overwrite the existing file. 100 | * If you want to use a local instance of Jupyter, simple rename the file `docker-compose.override.yml.TEMP` to `docker-compose.override.yml` before moving to the next step. 101 | * Start up the *astro* instance by running `astro dev start` in the repo directory. 102 | * After a few mins you will have an Airflow instance running at http://localhost:8080. If you renamed the `docker-compose.override.yml.TEMP` to `docker-compose.override.yml` you will also have a version of Jupyter running at http://localhost:8888 103 | 104 | _**Note:** The `Dockerfile` file for this project has been modified to make things run quicker, see [the Dockerfile](Dockerfile) for details._ 105 | 106 | ![This is an image](include/images/apache_airflow.jpg) 107 | -------------------------------------------------------------------------------- /02_Data_Science-ARIMA-Baseline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Citibike ML\n", 8 | "In this example notebook we show a baseline model using ARIMA. This provides a starting point for comparison." 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "### 1. Load the Credentials\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "from dags.snowpark_connection import snowpark_connect\n", 25 | "session, state_dict = snowpark_connect()" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "### 2. 
Load Data" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import snowflake.snowpark as snp\n", 42 | "from snowflake.snowpark import functions as F\n", 43 | "from snowflake.snowpark import types as T\n", 44 | "\n", 45 | "import pandas as pd\n", 46 | "from sklearn.metrics import mean_squared_error\n", 47 | "import matplotlib.pyplot as plt\n", 48 | "import seaborn as sns\n", 49 | "from statsmodels.tsa.stattools import adfuller\n", 50 | "from statsmodels.graphics.tsaplots import plot_acf\n", 51 | "from statsmodels.tsa.arima.model import ARIMA" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "trips_table_name = state_dict['trips_table_name']" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "snowdf = session.table(trips_table_name)\n", 70 | "\n", 71 | "top_stations = snowdf.filter(F.col('START_STATION_ID').is_not_null()) \\\n", 72 | " .groupBy('START_STATION_ID') \\\n", 73 | " .count() \\\n", 74 | " .sort('COUNT', ascending=False) \\\n", 75 | " .toPandas()['START_STATION_ID'].values.tolist()\n", 76 | "\n", 77 | "df = snowdf.filter(F.col('START_STATION_ID') == top_stations[0]) \\\n", 78 | " .withColumn('DATE', \n", 79 | " F.call_builtin('DATE_TRUNC', ('DAY', F.col('STARTTIME')))) \\\n", 80 | " .groupBy('DATE') \\\n", 81 | " .count() \\\n", 82 | " .sort('DATE').toPandas()\n", 83 | "\n", 84 | "plt.figure(figsize=(15, 8))\n", 85 | "ax = sns.lineplot(x='DATE', y='COUNT', data=df)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "%matplotlib inline\n", 95 | "plt.style.use('seaborn-darkgrid')\n", 96 | "plt.rc(\"figure\", figsize=(10, 7))\n", 97 | "plot_acf(df['COUNT'], lags=400)\n", 98 | "plt.xlabel('Lags', fontsize=12)\n", 99 | "plt.ylabel('Autocorrelation', fontsize=12)\n", 100 | "plt.title('Autocorrelation of Trip Count Seasonality', fontsize=14)\n", 101 | "plt.show()" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "We can definitely see the strong annual seasonality. Lets look closer at the daily and weekly lag." 
109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "plt.rc(\"figure\", figsize=(10, 7))\n", 118 | "plot_acf(df['COUNT'], lags=[1, 7, 30, 60, 90, 365])\n", 119 | "plt.xlabel('Lags', fontsize=12)\n", 120 | "plt.ylabel('Autocorrelation', fontsize=12)\n", 121 | "plt.title('Autocorrelation of Trip Count Seasonality', fontsize=14)\n", 122 | "plt.show()" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "### Baseline Model\n", 130 | "Lets build a baseline with ARIMA since we already have statsmodels imported" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "from statsmodels.tsa.arima.model import ARIMA\n", 140 | "model=ARIMA(df['COUNT'],order=(1,1,1))\n", 141 | "history=model.fit()\n", 142 | "df['HISTORY']=history.predict(start=0, end=len(df))\n", 143 | "plt.figure(figsize=(15, 8))\n", 144 | "df1 = pd.melt(df, id_vars=['DATE'], value_vars=['COUNT', 'HISTORY'])\n", 145 | "ax = sns.lineplot(x='DATE', y='value', hue='variable', data=df1)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "print(\"P-Value = \", adfuller(df['COUNT'].dropna(), autolag = 'AIC')[1])" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "With a p-value greater than .05 we know that the trend is non-trivial." 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "rolling_mean = df['COUNT'].rolling(window = 7).mean()\n", 171 | "df['STATIONARY'] = rolling_mean - rolling_mean.shift()\n", 172 | "ax1 = plt.subplot()\n", 173 | "df['STATIONARY'].plot(title='Differenced');\n", 174 | "ax2 = plt.subplot()\n", 175 | "df['COUNT'].plot(title='original')" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "print(\"P-Value = \", adfuller(df['STATIONARY'].dropna(), autolag = 'AIC')[1])" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "By differencing with the rolling mean we can stationarize the series. In order to account for this trend we can create an exogenous signal." 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "# import numpy as np\n", 201 | "# exog = pd.DataFrame(pd.PeriodIndex(df['DATE'], freq='D')).set_index('DATE')\n", 202 | "# exog['f1'] = np.sin(2 * np.pi * exog.index.dayofyear / 365.25)\n", 203 | "# exog['f2'] = np.cos(2 * np.pi * exog.index.dayofyear / 365.25)\n", 204 | "# exog['f3'] = np.sin(4 * np.pi * exog.index.dayofyear / 365.25)\n", 205 | "# exog['f4'] = np.cos(4 * np.pi * exog.index.dayofyear / 365.25)\n", 206 | "# exog = exog.reset_index()\n", 207 | "# exog = exog.drop(columns=['DATE'])" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "...or use we can use the `trend=` flag in ARIMA." 
215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "model=ARIMA(endog=df['COUNT'][:-365], trend='ct', order=(1,0,1))\n", 224 | "history=model.fit()\n", 225 | "df['HISTORY']=history.predict(start=0, end=len(df))\n", 226 | "plt.figure(figsize=(15, 8))\n", 227 | "df1 = pd.melt(df[:-365], id_vars=['DATE'], value_vars=['COUNT', 'HISTORY'])\n", 228 | "ax = sns.lineplot(x='DATE', y='value', hue='variable', data=df1)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "from sklearn.metrics import mean_squared_error\n", 238 | "dftest=pd.DataFrame()\n", 239 | "dftest['DATE']=df['DATE'][-365:]\n", 240 | "dftest['COUNT']=df['COUNT'][-365:]\n", 241 | "dftest['FORECAST']=history.forecast(steps=365)\n", 242 | "plt.figure(figsize=(15, 8))\n", 243 | "df1 = pd.melt(dftest, id_vars=['DATE'], value_vars=['COUNT', 'FORECAST'])\n", 244 | "ax = sns.lineplot(x='DATE', y='value', hue='variable', data=df1)\n", 245 | "\n", 246 | "error = mean_squared_error(dftest['COUNT'], dftest['FORECAST'])\n" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "print(\"ARIMA Error is: \"+str(error))" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [] 264 | } 265 | ], 266 | "metadata": { 267 | "authors": [ 268 | { 269 | "name": "cforbe" 270 | } 271 | ], 272 | "kernelspec": { 273 | "display_name": "snowpark_0110:Python", 274 | "language": "python", 275 | "name": "conda-env-snowpark_0110-py" 276 | }, 277 | "language_info": { 278 | "codemirror_mode": { 279 | "name": "ipython", 280 | "version": 3 281 | }, 282 | "file_extension": ".py", 283 | "mimetype": "text/x-python", 284 | "name": "python", 285 | "nbconvert_exporter": "python", 286 | "pygments_lexer": "ipython3", 287 | "version": "3.8.13" 288 | }, 289 | "msauthor": "trbye" 290 | }, 291 | "nbformat": 4, 292 | "nbformat_minor": 4 293 | } 294 | -------------------------------------------------------------------------------- /dags/elt.py: -------------------------------------------------------------------------------- 1 | def schema1_definition(): 2 | from snowflake.snowpark import types as T 3 | load_schema1 = T.StructType([T.StructField("TRIPDURATION", T.StringType()), 4 | T.StructField("STARTTIME", T.StringType()), 5 | T.StructField("STOPTIME", T.StringType()), 6 | T.StructField("START_STATION_ID", T.StringType()), 7 | T.StructField("START_STATION_NAME", T.StringType()), 8 | T.StructField("START_STATION_LATITUDE", T.StringType()), 9 | T.StructField("START_STATION_LONGITUDE", T.StringType()), 10 | T.StructField("END_STATION_ID", T.StringType()), 11 | T.StructField("END_STATION_NAME", T.StringType()), 12 | T.StructField("END_STATION_LATITUDE", T.StringType()), 13 | T.StructField("END_STATION_LONGITUDE", T.StringType()), 14 | T.StructField("BIKEID", T.StringType()), 15 | T.StructField("USERTYPE", T.StringType()), 16 | T.StructField("BIRTH_YEAR", T.StringType()), 17 | T.StructField("GENDER", T.StringType())]) 18 | return load_schema1 19 | 20 | def schema2_definition(): 21 | from snowflake.snowpark import types as T 22 | load_schema2 = T.StructType([T.StructField("ride_id", T.StringType()), 23 | T.StructField("rideable_type", T.StringType()), 24 | T.StructField("STARTTIME", T.StringType()), 25 | 
T.StructField("STOPTIME", T.StringType()), 26 | T.StructField("START_STATION_NAME", T.StringType()), 27 | T.StructField("START_STATION_ID", T.StringType()), 28 | T.StructField("END_STATION_NAME", T.StringType()), 29 | T.StructField("END_STATION_ID", T.StringType()), 30 | T.StructField("START_STATION_LATITUDE", T.StringType()), 31 | T.StructField("START_STATION_LONGITUDE", T.StringType()), 32 | T.StructField("END_STATION_LATITUDE", T.StringType()), 33 | T.StructField("END_STATION_LONGITUDE", T.StringType()), 34 | T.StructField("USERTYPE", T.StringType())]) 35 | return load_schema2 36 | 37 | def conformed_schema(): 38 | from snowflake.snowpark import types as T 39 | trips_table_schema = T.StructType([T.StructField("STARTTIME", T.StringType()), 40 | T.StructField("STOPTIME", T.StringType()), 41 | T.StructField("START_STATION_NAME", T.StringType()), 42 | T.StructField("START_STATION_ID", T.StringType()), 43 | T.StructField("END_STATION_NAME", T.StringType()), 44 | T.StructField("END_STATION_ID", T.StringType()), 45 | T.StructField("START_STATION_LATITUDE", T.StringType()), 46 | T.StructField("START_STATION_LONGITUDE", T.StringType()), 47 | T.StructField("END_STATION_LATITUDE", T.StringType()), 48 | T.StructField("END_STATION_LONGITUDE", T.StringType()), 49 | T.StructField("USERTYPE", T.StringType())]) 50 | return trips_table_schema 51 | 52 | def extract_trips_to_stage(session, files_to_download: list, download_base_url: str, load_stage_name:str): 53 | import os 54 | import requests 55 | from zipfile import ZipFile 56 | import gzip 57 | from datetime import datetime 58 | from io import BytesIO 59 | 60 | schema1_download_files = list() 61 | schema2_download_files = list() 62 | schema2_start_date = datetime.strptime('202102', "%Y%m") 63 | 64 | for file_name in files_to_download: 65 | file_start_date = datetime.strptime(file_name.split("-")[0], "%Y%m") 66 | if file_start_date < schema2_start_date: 67 | schema1_download_files.append(file_name) 68 | else: 69 | schema2_download_files.append(file_name) 70 | 71 | 72 | schema1_load_stage = load_stage_name+'/schema1/' 73 | schema1_files_to_load = list() 74 | for zip_file_name in schema1_download_files: 75 | 76 | url = download_base_url+zip_file_name 77 | 78 | print('Downloading and unzipping: '+url) 79 | r = requests.get(url) 80 | file = ZipFile(BytesIO(r.content)) 81 | csv_file_name=file.namelist()[0] 82 | file.extract(csv_file_name) 83 | file.close() 84 | 85 | print('Putting '+csv_file_name+' to stage: '+schema1_load_stage) 86 | session.file.put(local_file_name=csv_file_name, 87 | stage_location=schema1_load_stage, 88 | source_compression='NONE', 89 | overwrite=True) 90 | schema1_files_to_load.append(csv_file_name) 91 | os.remove(csv_file_name) 92 | 93 | 94 | schema2_load_stage = load_stage_name+'/schema2/' 95 | schema2_files_to_load = list() 96 | for zip_file_name in schema2_download_files: 97 | 98 | url = download_base_url+zip_file_name 99 | 100 | print('Downloading and unzipping: '+url) 101 | r = requests.get(url) 102 | file = ZipFile(BytesIO(r.content)) 103 | csv_file_name=file.namelist()[0] 104 | file.extract(csv_file_name) 105 | file.close() 106 | 107 | print('Putting '+csv_file_name+' to stage: '+schema2_load_stage) 108 | session.file.put(local_file_name=csv_file_name, 109 | stage_location=schema2_load_stage, 110 | source_compression='NONE', 111 | overwrite=True) 112 | schema2_files_to_load.append(csv_file_name) 113 | os.remove(csv_file_name) 114 | 115 | load_stage_names = {'schema1' : schema1_load_stage, 'schema2' : schema2_load_stage} 116 | 
files_to_load = {'schema1': schema1_files_to_load, 'schema2': schema2_files_to_load} 117 | 118 | return load_stage_names, files_to_load 119 | 120 | def load_trips_to_raw(session, files_to_load:dict, load_stage_names:dict, load_table_name:str): 121 | from snowflake.snowpark import functions as F 122 | from snowflake.snowpark import types as T 123 | from datetime import datetime 124 | 125 | csv_file_format_options = {"FIELD_OPTIONALLY_ENCLOSED_BY": "'\"'", "skip_header": 1} 126 | 127 | if len(files_to_load['schema1']) > 0: 128 | load_schema1 = schema1_definition() 129 | loaddf = session.read.option("SKIP_HEADER", 1)\ 130 | .option("FIELD_OPTIONALLY_ENCLOSED_BY", "\042")\ 131 | .option("COMPRESSION", "GZIP")\ 132 | .option("NULL_IF", "\\\\N")\ 133 | .option("NULL_IF", "NULL")\ 134 | .schema(load_schema1)\ 135 | .csv('@'+load_stage_names['schema1'])\ 136 | .copy_into_table(load_table_name+'schema1', 137 | files=files_to_load['schema1'], 138 | format_type_options=csv_file_format_options) 139 | 140 | if len(files_to_load['schema2']) > 0: 141 | load_schema2 = schema2_definition() 142 | loaddf = session.read.option("SKIP_HEADER", 1)\ 143 | .option("FIELD_OPTIONALLY_ENCLOSED_BY", "\042")\ 144 | .option("COMPRESSION", "GZIP")\ 145 | .option("NULL_IF", "\\\\N")\ 146 | .option("NULL_IF", "NULL")\ 147 | .schema(load_schema2)\ 148 | .csv('@'+load_stage_names['schema2'])\ 149 | .copy_into_table(load_table_name+'schema2', 150 | files=files_to_load['schema2'], 151 | format_type_options=csv_file_format_options) 152 | 153 | load_table_names = {'schema1' : load_table_name+str('schema1'), 154 | 'schema2' : load_table_name+str('schema2')} 155 | 156 | return load_table_names 157 | 158 | def transform_trips(session, stage_table_names:dict, trips_table_name:str): 159 | from snowflake.snowpark import functions as F 160 | 161 | #Change all dates to YYYY-MM-DD HH:MI:SS format 162 | date_format_match = "^([0-9]?[0-9])/([0-9]?[0-9])/([0-9][0-9][0-9][0-9]) ([0-9]?[0-9]):([0-9][0-9])(:[0-9][0-9])?.*$" 163 | date_format_repl = "\\3-\\1-\\2 \\4:\\5\\6" 164 | 165 | trips_table_schema = conformed_schema() 166 | 167 | trips_table_schema_names = [field.name for field in trips_table_schema.fields] 168 | 169 | transdf1 = session.table(stage_table_names['schema1'])[trips_table_schema_names] 170 | transdf2 = session.table(stage_table_names['schema2'])[trips_table_schema_names] 171 | 172 | transdf = transdf1.union_by_name(transdf2)\ 173 | .with_column('STARTTIME', F.regexp_replace(F.col('STARTTIME'), 174 | F.lit(date_format_match), 175 | F.lit(date_format_repl)))\ 176 | .with_column('STARTTIME', F.to_timestamp('STARTTIME'))\ 177 | .with_column('STOPTIME', F.regexp_replace(F.col('STOPTIME'), 178 | F.lit(date_format_match), 179 | F.lit(date_format_repl)))\ 180 | .with_column('STOPTIME', F.to_timestamp('STOPTIME'))\ 181 | .write.mode('overwrite').save_as_table(trips_table_name) 182 | 183 | return trips_table_name 184 | 185 | def reset_database(session, state_dict:dict, prestaged=False): 186 | _ = session.sql('CREATE OR REPLACE DATABASE '+state_dict['connection_parameters']['database']).collect() 187 | _ = session.sql('CREATE SCHEMA '+state_dict['connection_parameters']['schema']).collect() 188 | 189 | if prestaged: 190 | sql_cmd = 'CREATE OR REPLACE STAGE '+state_dict['load_stage_name']+\ 191 | ' url='+state_dict['connection_parameters']['download_base_url'] 192 | _ = session.sql(sql_cmd).collect() 193 | else: 194 | _ = session.sql('CREATE STAGE IF NOT EXISTS '+state_dict['load_stage_name']).collect() 195 | 196 | 
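    # Seed empty landing tables from the schema definitions below: a single all-NULL row is
    # created and immediately dropped with na.drop(), leaving just the table structure in place.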
load_schema1=schema1_definition() 197 | session.create_dataframe([[None]*len(load_schema1.names)], schema=load_schema1)\ 198 | .na.drop()\ 199 | .write\ 200 | .save_as_table(state_dict['load_table_name']+'schema1') 201 | 202 | load_schema2=schema2_definition() 203 | session.create_dataframe([[None]*len(load_schema2.names)], schema=load_schema2)\ 204 | .na.drop()\ 205 | .write\ 206 | .save_as_table(state_dict['load_table_name']+'schema2') 207 | 208 | -------------------------------------------------------------------------------- /dags/mlops_tasks.py: -------------------------------------------------------------------------------- 1 | 2 | def snowpark_database_setup(state_dict:dict)-> dict: 3 | import snowflake.snowpark.functions as F 4 | from dags.snowpark_connection import snowpark_connect 5 | from dags.elt import reset_database 6 | 7 | session, _ = snowpark_connect('./include/state.json') 8 | reset_database(session=session, state_dict=state_dict, prestaged=True) 9 | 10 | _ = session.sql('CREATE STAGE '+state_dict['model_stage_name']).collect() 11 | _ = session.sql('CREATE TAG model_id_tag').collect() 12 | 13 | session.close() 14 | 15 | return state_dict 16 | 17 | def incremental_elt_task(state_dict: dict, files_to_download:list)-> dict: 18 | from dags.ingest import incremental_elt 19 | from dags.snowpark_connection import snowpark_connect 20 | 21 | session, _ = snowpark_connect() 22 | 23 | print('Ingesting '+str(files_to_download)) 24 | download_base_url=state_dict['connection_parameters']['download_base_url'] 25 | 26 | _ = session.use_warehouse(state_dict['compute_parameters']['load_warehouse']) 27 | 28 | _ = incremental_elt(session=session, 29 | state_dict=state_dict, 30 | files_to_ingest=files_to_download, 31 | download_base_url=download_base_url, 32 | use_prestaged=True) 33 | 34 | #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['load_warehouse']+\ 35 | # ' SUSPEND').collect() 36 | 37 | session.close() 38 | return state_dict 39 | 40 | def initial_bulk_load_task(state_dict:dict)-> dict: 41 | from dags.ingest import bulk_elt 42 | from dags.snowpark_connection import snowpark_connect 43 | 44 | session, _ = snowpark_connect() 45 | 46 | _ = session.use_warehouse(state_dict['compute_parameters']['load_warehouse']) 47 | 48 | print('Running initial bulk ingest from '+state_dict['connection_parameters']['download_base_url']) 49 | 50 | _ = bulk_elt(session=session, 51 | state_dict=state_dict, 52 | download_base_url=state_dict['connection_parameters']['download_base_url'], 53 | use_prestaged=True) 54 | 55 | #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['load_warehouse']+\ 56 | # ' SUSPEND').collect() 57 | 58 | session.close() 59 | return state_dict 60 | 61 | def materialize_holiday_task(state_dict: dict)-> dict: 62 | from dags.snowpark_connection import snowpark_connect 63 | from dags.mlops_pipeline import materialize_holiday_table 64 | 65 | print('Materializing holiday table.') 66 | session, _ = snowpark_connect() 67 | 68 | _ = materialize_holiday_table(session=session, 69 | holiday_table_name=state_dict['holiday_table_name']) 70 | 71 | session.close() 72 | return state_dict 73 | 74 | def subscribe_to_weather_data_task(state_dict: dict)-> dict: 75 | from dags.snowpark_connection import snowpark_connect 76 | from dags.mlops_pipeline import subscribe_to_weather_data 77 | 78 | print('Subscribing to weather data') 79 | session, _ = snowpark_connect() 80 | 81 | _ = subscribe_to_weather_data(session=session, 82 | 
weather_database_name=state_dict['weather_database_name'], 83 | weather_listing_id=state_dict['weather_listing_id']) 84 | session.close() 85 | return state_dict 86 | 87 | def create_weather_view_task(state_dict: dict)-> dict: 88 | from dags.snowpark_connection import snowpark_connect 89 | from dags.mlops_pipeline import create_weather_view 90 | 91 | print('Creating weather view') 92 | session, _ = snowpark_connect() 93 | 94 | _ = create_weather_view(session=session, 95 | weather_table_name=state_dict['weather_table_name'], 96 | weather_view_name=state_dict['weather_view_name']) 97 | session.close() 98 | return state_dict 99 | 100 | def deploy_model_udf_task(state_dict:dict)-> dict: 101 | from dags.snowpark_connection import snowpark_connect 102 | from dags.mlops_pipeline import deploy_pred_train_udf 103 | 104 | print('Deploying station model') 105 | session, _ = snowpark_connect() 106 | 107 | _ = session.sql('CREATE STAGE IF NOT EXISTS ' + state_dict['model_stage_name']).collect() 108 | 109 | _ = deploy_pred_train_udf(session=session, 110 | udf_name=state_dict['train_udf_name'], 111 | function_name=state_dict['train_func_name'], 112 | model_stage_name=state_dict['model_stage_name']) 113 | session.close() 114 | return state_dict 115 | 116 | def deploy_eval_udf_task(state_dict:dict)-> dict: 117 | from dags.snowpark_connection import snowpark_connect 118 | from dags.mlops_pipeline import deploy_eval_udf 119 | 120 | print('Deploying station model') 121 | session, _ = snowpark_connect() 122 | 123 | _ = session.sql('CREATE STAGE IF NOT EXISTS ' + state_dict['model_stage_name']).collect() 124 | 125 | _ = deploy_eval_udf(session=session, 126 | udf_name=state_dict['eval_udf_name'], 127 | function_name=state_dict['eval_func_name'], 128 | model_stage_name=state_dict['model_stage_name']) 129 | session.close() 130 | return state_dict 131 | 132 | def generate_feature_table_task(state_dict:dict, 133 | holiday_state_dict:dict, 134 | weather_state_dict:dict)-> dict: 135 | from dags.snowpark_connection import snowpark_connect 136 | from dags.mlops_pipeline import create_feature_table 137 | 138 | print('Generating features for all stations.') 139 | session, _ = snowpark_connect() 140 | 141 | session.use_warehouse(state_dict['compute_parameters']['fe_warehouse']) 142 | 143 | _ = session.sql("CREATE OR REPLACE TABLE "+state_dict['clone_table_name']+\ 144 | " CLONE "+state_dict['trips_table_name']).collect() 145 | _ = session.sql("ALTER TABLE "+state_dict['clone_table_name']+\ 146 | " SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 147 | 148 | _ = create_feature_table(session, 149 | trips_table_name=state_dict['clone_table_name'], 150 | holiday_table_name=state_dict['holiday_table_name'], 151 | weather_view_name=state_dict['weather_view_name'], 152 | feature_table_name=state_dict['feature_table_name']) 153 | 154 | _ = session.sql("ALTER TABLE "+state_dict['feature_table_name']+\ 155 | " SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 156 | 157 | session.close() 158 | return state_dict 159 | 160 | def generate_forecast_table_task(state_dict:dict, 161 | holiday_state_dict:dict, 162 | weather_state_dict:dict)-> dict: 163 | from dags.snowpark_connection import snowpark_connect 164 | from dags.mlops_pipeline import create_forecast_table 165 | 166 | print('Generating forecast features.') 167 | session, _ = snowpark_connect() 168 | 169 | _ = create_forecast_table(session, 170 | trips_table_name=state_dict['trips_table_name'], 171 | holiday_table_name=state_dict['holiday_table_name'], 
172 | weather_view_name=state_dict['weather_view_name'], 173 | forecast_table_name=state_dict['forecast_table_name'], 174 | steps=state_dict['forecast_steps']) 175 | 176 | _ = session.sql("ALTER TABLE "+state_dict['forecast_table_name']+\ 177 | " SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 178 | 179 | session.close() 180 | return state_dict 181 | 182 | def bulk_train_predict_task(state_dict:dict, 183 | feature_state_dict:dict, 184 | forecast_state_dict:dict)-> dict: 185 | from dags.snowpark_connection import snowpark_connect 186 | from dags.mlops_pipeline import train_predict 187 | 188 | state_dict = feature_state_dict 189 | 190 | print('Running bulk training and forecast.') 191 | session, _ = snowpark_connect() 192 | 193 | session.use_warehouse(state_dict['compute_parameters']['train_warehouse']) 194 | 195 | pred_table_name = train_predict(session, 196 | station_train_pred_udf_name=state_dict['train_udf_name'], 197 | feature_table_name=state_dict['feature_table_name'], 198 | forecast_table_name=state_dict['forecast_table_name'], 199 | pred_table_name=state_dict['pred_table_name']) 200 | 201 | _ = session.sql("ALTER TABLE "+state_dict['pred_table_name']+\ 202 | " SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 203 | #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['train_warehouse']+\ 204 | # ' SUSPEND').collect() 205 | 206 | session.close() 207 | return state_dict 208 | 209 | def eval_station_models_task(state_dict:dict, 210 | pred_state_dict:dict, 211 | run_date:str)-> dict: 212 | 213 | from dags.snowpark_connection import snowpark_connect 214 | from dags.mlops_pipeline import evaluate_station_model 215 | 216 | print('Running eval UDF for model output') 217 | session, _ = snowpark_connect() 218 | 219 | eval_table_name = evaluate_station_model(session, 220 | run_date=run_date, 221 | eval_model_udf_name=state_dict['eval_udf_name'], 222 | pred_table_name=state_dict['pred_table_name'], 223 | eval_table_name=state_dict['eval_table_name']) 224 | 225 | _ = session.sql("ALTER TABLE "+state_dict['eval_table_name']+\ 226 | " SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 227 | session.close() 228 | return state_dict 229 | 230 | def flatten_tables_task(pred_state_dict:dict, state_dict:dict)-> dict: 231 | from dags.snowpark_connection import snowpark_connect 232 | from dags.mlops_pipeline import flatten_tables 233 | 234 | print('Flattening tables for end-user consumption.') 235 | session, _ = snowpark_connect() 236 | 237 | flat_pred_table, flat_forecast_table, flat_eval_table = flatten_tables(session, 238 | pred_table_name=state_dict['pred_table_name'], 239 | forecast_table_name=state_dict['forecast_table_name'], 240 | eval_table_name=state_dict['eval_table_name']) 241 | state_dict['flat_pred_table'] = flat_pred_table 242 | state_dict['flat_forecast_table'] = flat_forecast_table 243 | state_dict['flat_eval_table'] = flat_eval_table 244 | 245 | _ = session.sql("ALTER TABLE "+flat_pred_table+" SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 246 | _ = session.sql("ALTER TABLE "+flat_forecast_table+" SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 247 | _ = session.sql("ALTER TABLE "+flat_eval_table+" SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 248 | 249 | return state_dict 250 | -------------------------------------------------------------------------------- /dags/airflow_tasks.py: -------------------------------------------------------------------------------- 1 | 2 | 
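# NOTE: these are the same pipeline steps defined in dags/mlops_tasks.py, wrapped with Airflow's
# TaskFlow @task.virtualenv decorator so each step runs in an isolated Python 3.8 virtualenv on
# the worker; when the tasks are chained in a DAG, the returned state_dict is passed between them
# via XCom.
#
# Minimal usage sketch (illustrative only, not the DAG shipped in this repo; the DAG id, schedule,
# and state-file handling shown here are assumptions):
#
#   from airflow.decorators import dag
#   from pendulum import datetime
#   import json
#
#   @dag(schedule_interval=None, start_date=datetime(2022, 1, 1), catchup=False)
#   def citibike_setup_sketch():
#       with open('./include/state.json') as f:
#           initial_state = json.load(f)
#       state = snowpark_database_setup(initial_state)
#       state = initial_bulk_load_task(state)
#       state = materialize_holiday_task(state)
#
#   citibike_setup_sketch()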
from airflow.decorators import task 3 | 4 | @task.virtualenv(python_version=3.8) 5 | def snowpark_database_setup(state_dict:dict)-> dict: 6 | import snowflake.snowpark.functions as F 7 | from dags.snowpark_connection import snowpark_connect 8 | from dags.elt import reset_database 9 | 10 | session, _ = snowpark_connect('./include/state.json') 11 | reset_database(session=session, state_dict=state_dict, prestaged=True) 12 | 13 | _ = session.sql('CREATE STAGE '+state_dict['model_stage_name']).collect() 14 | _ = session.sql('CREATE TAG model_id_tag').collect() 15 | 16 | session.close() 17 | 18 | return state_dict 19 | 20 | @task.virtualenv(python_version=3.8) 21 | def incremental_elt_task(state_dict: dict, files_to_download:list)-> dict: 22 | from dags.ingest import incremental_elt 23 | from dags.snowpark_connection import snowpark_connect 24 | 25 | session, _ = snowpark_connect() 26 | 27 | print('Ingesting '+str(files_to_download)) 28 | download_base_url=state_dict['connection_parameters']['download_base_url'] 29 | 30 | _ = session.use_warehouse(state_dict['compute_parameters']['load_warehouse']) 31 | 32 | _ = incremental_elt(session=session, 33 | state_dict=state_dict, 34 | files_to_ingest=files_to_download, 35 | download_base_url=download_base_url, 36 | use_prestaged=True) 37 | 38 | #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['load_warehouse']+\ 39 | # ' SUSPEND').collect() 40 | 41 | session.close() 42 | return state_dict 43 | 44 | @task.virtualenv(python_version=3.8) 45 | def initial_bulk_load_task(state_dict:dict)-> dict: 46 | from dags.ingest import bulk_elt 47 | from dags.snowpark_connection import snowpark_connect 48 | 49 | session, _ = snowpark_connect() 50 | 51 | _ = session.use_warehouse(state_dict['compute_parameters']['load_warehouse']) 52 | 53 | print('Running initial bulk ingest from '+state_dict['connection_parameters']['download_base_url']) 54 | 55 | _ = bulk_elt(session=session, 56 | state_dict=state_dict, 57 | download_base_url=state_dict['connection_parameters']['download_base_url'], 58 | use_prestaged=True) 59 | 60 | #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['load_warehouse']+\ 61 | # ' SUSPEND').collect() 62 | 63 | session.close() 64 | return state_dict 65 | 66 | @task.virtualenv(python_version=3.8) 67 | def materialize_holiday_task(state_dict: dict)-> dict: 68 | from dags.snowpark_connection import snowpark_connect 69 | from dags.mlops_pipeline import materialize_holiday_table 70 | 71 | print('Materializing holiday table.') 72 | session, _ = snowpark_connect() 73 | 74 | _ = materialize_holiday_table(session=session, 75 | holiday_table_name=state_dict['holiday_table_name']) 76 | 77 | session.close() 78 | return state_dict 79 | 80 | @task.virtualenv(python_version=3.8) 81 | def subscribe_to_weather_data_task(state_dict: dict)-> dict: 82 | from dags.snowpark_connection import snowpark_connect 83 | from dags.mlops_pipeline import subscribe_to_weather_data 84 | 85 | print('Subscribing to weather data') 86 | session, _ = snowpark_connect() 87 | 88 | _ = subscribe_to_weather_data(session=session, 89 | weather_database_name=state_dict['weather_database_name'], 90 | weather_listing_id=state_dict['weather_listing_id']) 91 | session.close() 92 | return state_dict 93 | 94 | @task.virtualenv(python_version=3.8) 95 | def create_weather_view_task(state_dict: dict)-> dict: 96 | from dags.snowpark_connection import snowpark_connect 97 | from dags.mlops_pipeline import create_weather_view 98 | 99 | print('Creating 
weather view') 100 | session, _ = snowpark_connect() 101 | 102 | _ = create_weather_view(session=session, 103 | weather_table_name=state_dict['weather_table_name'], 104 | weather_view_name=state_dict['weather_view_name']) 105 | session.close() 106 | return state_dict 107 | 108 | @task.virtualenv(python_version=3.8) 109 | def deploy_model_udf_task(state_dict:dict)-> dict: 110 | from dags.snowpark_connection import snowpark_connect 111 | from dags.mlops_pipeline import deploy_pred_train_udf 112 | 113 | print('Deploying station model') 114 | session, _ = snowpark_connect() 115 | 116 | _ = session.sql('CREATE STAGE IF NOT EXISTS ' + state_dict['model_stage_name']).collect() 117 | 118 | _ = deploy_pred_train_udf(session=session, 119 | udf_name=state_dict['train_udf_name'], 120 | function_name=state_dict['train_func_name'], 121 | model_stage_name=state_dict['model_stage_name']) 122 | session.close() 123 | return state_dict 124 | 125 | @task.virtualenv(python_version=3.8) 126 | def deploy_eval_udf_task(state_dict:dict)-> dict: 127 | from dags.snowpark_connection import snowpark_connect 128 | from dags.mlops_pipeline import deploy_eval_udf 129 | 130 | print('Deploying station model') 131 | session, _ = snowpark_connect() 132 | 133 | _ = session.sql('CREATE STAGE IF NOT EXISTS ' + state_dict['model_stage_name']).collect() 134 | 135 | _ = deploy_eval_udf(session=session, 136 | udf_name=state_dict['eval_udf_name'], 137 | function_name=state_dict['eval_func_name'], 138 | model_stage_name=state_dict['model_stage_name']) 139 | session.close() 140 | return state_dict 141 | 142 | @task.virtualenv(python_version=3.8) 143 | def generate_feature_table_task(state_dict:dict, 144 | holiday_state_dict:dict, 145 | weather_state_dict:dict)-> dict: 146 | from dags.snowpark_connection import snowpark_connect 147 | from dags.mlops_pipeline import create_feature_table 148 | 149 | print('Generating features for all stations.') 150 | session, _ = snowpark_connect() 151 | 152 | session.use_warehouse(state_dict['compute_parameters']['fe_warehouse']) 153 | 154 | _ = session.sql("CREATE OR REPLACE TABLE "+state_dict['clone_table_name']+\ 155 | " CLONE "+state_dict['trips_table_name']).collect() 156 | _ = session.sql("ALTER TABLE "+state_dict['clone_table_name']+\ 157 | " SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 158 | 159 | _ = create_feature_table(session, 160 | trips_table_name=state_dict['clone_table_name'], 161 | holiday_table_name=state_dict['holiday_table_name'], 162 | weather_view_name=state_dict['weather_view_name'], 163 | feature_table_name=state_dict['feature_table_name']) 164 | 165 | _ = session.sql("ALTER TABLE "+state_dict['feature_table_name']+\ 166 | " SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 167 | 168 | session.close() 169 | return state_dict 170 | 171 | @task.virtualenv(python_version=3.8) 172 | def generate_forecast_table_task(state_dict:dict, 173 | holiday_state_dict:dict, 174 | weather_state_dict:dict)-> dict: 175 | from dags.snowpark_connection import snowpark_connect 176 | from dags.mlops_pipeline import create_forecast_table 177 | 178 | print('Generating forecast features.') 179 | session, _ = snowpark_connect() 180 | 181 | _ = create_forecast_table(session, 182 | trips_table_name=state_dict['trips_table_name'], 183 | holiday_table_name=state_dict['holiday_table_name'], 184 | weather_view_name=state_dict['weather_view_name'], 185 | forecast_table_name=state_dict['forecast_table_name'], 186 | steps=state_dict['forecast_steps']) 187 | 188 | _ = 
session.sql("ALTER TABLE "+state_dict['forecast_table_name']+\ 189 | " SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 190 | 191 | session.close() 192 | return state_dict 193 | 194 | @task.virtualenv(python_version=3.8) 195 | def bulk_train_predict_task(state_dict:dict, 196 | feature_state_dict:dict, 197 | forecast_state_dict:dict)-> dict: 198 | from dags.snowpark_connection import snowpark_connect 199 | from dags.mlops_pipeline import train_predict 200 | 201 | state_dict = feature_state_dict 202 | 203 | print('Running bulk training and forecast.') 204 | session, _ = snowpark_connect() 205 | 206 | session.use_warehouse(state_dict['compute_parameters']['train_warehouse']) 207 | 208 | pred_table_name = train_predict(session, 209 | station_train_pred_udf_name=state_dict['train_udf_name'], 210 | feature_table_name=state_dict['feature_table_name'], 211 | forecast_table_name=state_dict['forecast_table_name'], 212 | pred_table_name=state_dict['pred_table_name']) 213 | 214 | _ = session.sql("ALTER TABLE "+state_dict['pred_table_name']+\ 215 | " SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 216 | #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['train_warehouse']+\ 217 | # ' SUSPEND').collect() 218 | 219 | session.close() 220 | return state_dict 221 | 222 | @task.virtualenv(python_version=3.8) 223 | def eval_station_models_task(state_dict:dict, 224 | pred_state_dict:dict, 225 | run_date:str)-> dict: 226 | 227 | from dags.snowpark_connection import snowpark_connect 228 | from dags.mlops_pipeline import evaluate_station_model 229 | 230 | print('Running eval UDF for model output') 231 | session, _ = snowpark_connect() 232 | 233 | eval_table_name = evaluate_station_model(session, 234 | run_date=run_date, 235 | eval_model_udf_name=state_dict['eval_udf_name'], 236 | pred_table_name=state_dict['pred_table_name'], 237 | eval_table_name=state_dict['eval_table_name']) 238 | 239 | _ = session.sql("ALTER TABLE "+state_dict['eval_table_name']+\ 240 | " SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 241 | session.close() 242 | return state_dict 243 | 244 | @task.virtualenv(python_version=3.8) 245 | def flatten_tables_task(pred_state_dict:dict, state_dict:dict)-> dict: 246 | from dags.snowpark_connection import snowpark_connect 247 | from dags.mlops_pipeline import flatten_tables 248 | 249 | print('Flattening tables for end-user consumption.') 250 | session, _ = snowpark_connect() 251 | 252 | flat_pred_table, flat_forecast_table, flat_eval_table = flatten_tables(session, 253 | pred_table_name=state_dict['pred_table_name'], 254 | forecast_table_name=state_dict['forecast_table_name'], 255 | eval_table_name=state_dict['eval_table_name']) 256 | state_dict['flat_pred_table'] = flat_pred_table 257 | state_dict['flat_forecast_table'] = flat_forecast_table 258 | state_dict['flat_eval_table'] = flat_eval_table 259 | 260 | _ = session.sql("ALTER TABLE "+flat_pred_table+" SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 261 | _ = session.sql("ALTER TABLE "+flat_forecast_table+" SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 262 | _ = session.sql("ALTER TABLE "+flat_eval_table+" SET TAG model_id_tag = '"+state_dict['model_id']+"'").collect() 263 | 264 | return state_dict 265 | -------------------------------------------------------------------------------- /dags/cdc.py: -------------------------------------------------------------------------------- 1 | def schema1_spoc_str(procedure_name:str, 
interim_target_table_name:str, stream_name:str) -> str: 2 | create_spoc_sql="CREATE OR REPLACE PROCEDURE "+procedure_name+"() " + \ 3 | "RETURNS VARCHAR " + \ 4 | "LANGUAGE SQL " + \ 5 | "AS " + \ 6 | "$$ " + \ 7 | " BEGIN " + \ 8 | " MERGE INTO " + interim_target_table_name + \ 9 | " AS T USING (SELECT * FROM " + stream_name + ") \ 10 | AS S ON concat(T.BIKEID, T.STARTTIME, T.STOPTIME) = concat(S.BIKEID, S.STARTTIME, S.STOPTIME) \ 11 | WHEN MATCHED AND S.metadata$action = 'INSERT' \ 12 | AND S.metadata$isupdate \ 13 | THEN UPDATE SET T.TRIPDURATION = S.TRIPDURATION, \ 14 | T.STARTTIME = S.STARTTIME, \ 15 | T.STOPTIME = S.STOPTIME, \ 16 | T.START_STATION_ID = S.START_STATION_ID, \ 17 | T.START_STATION_NAME = S.START_STATION_NAME, \ 18 | T.START_STATION_LATITUDE = S.START_STATION_LATITUDE, \ 19 | T.START_STATION_LONGITUDE = S.START_STATION_LONGITUDE, \ 20 | T.END_STATION_ID = S.END_STATION_ID, \ 21 | T.END_STATION_NAME = S.END_STATION_NAME, \ 22 | T.END_STATION_LATITUDE = S.END_STATION_LATITUDE, \ 23 | T.END_STATION_LONGITUDE = S.END_STATION_LONGITUDE, \ 24 | T.BIKEID = S.BIKEID, \ 25 | T.USERTYPE = S.USERTYPE, \ 26 | T.BIRTH_YEAR = S.BIRTH_YEAR, \ 27 | T.GENDER = S.GENDER \ 28 | WHEN MATCHED AND S.metadata$action = 'DELETE' \ 29 | THEN DELETE \ 30 | WHEN NOT MATCHED AND S.metadata$action = 'INSERT' \ 31 | THEN INSERT (TRIPDURATION, \ 32 | STARTTIME , \ 33 | STOPTIME , \ 34 | START_STATION_ID , \ 35 | START_STATION_NAME , \ 36 | START_STATION_LATITUDE , \ 37 | START_STATION_LONGITUDE , \ 38 | END_STATION_ID , \ 39 | END_STATION_NAME , \ 40 | END_STATION_LATITUDE , \ 41 | END_STATION_LONGITUDE , \ 42 | BIKEID , \ 43 | USERTYPE , \ 44 | BIRTH_YEAR , \ 45 | GENDER) \ 46 | VALUES (S.TRIPDURATION, \ 47 | S.STARTTIME , \ 48 | S.STOPTIME , \ 49 | S.START_STATION_ID , \ 50 | S.START_STATION_NAME , \ 51 | S.START_STATION_LATITUDE , \ 52 | S.START_STATION_LONGITUDE , \ 53 | S.END_STATION_ID , \ 54 | S.END_STATION_NAME , \ 55 | S.END_STATION_LATITUDE , \ 56 | S.END_STATION_LONGITUDE , \ 57 | S.BIKEID , \ 58 | S.USERTYPE , \ 59 | S.BIRTH_YEAR , \ 60 | S.GENDER); " + \ 61 | " END; " + \ 62 | "$$" 63 | 64 | return create_spoc_sql 65 | 66 | def schema2_spoc_str(procedure_name:str, interim_target_table_name:str, stream_name:str) -> str: 67 | create_spoc_sql="CREATE OR REPLACE PROCEDURE "+procedure_name+"() " + \ 68 | "RETURNS VARCHAR " + \ 69 | "LANGUAGE SQL " + \ 70 | "AS " + \ 71 | "$$ " + \ 72 | " BEGIN " + \ 73 | " MERGE INTO " + interim_target_table_name + \ 74 | " AS T USING (SELECT * FROM " + stream_name + ") \ 75 | AS S ON T.RIDE_ID = S.RIDE_ID \ 76 | WHEN MATCHED AND S.metadata$action = 'INSERT' \ 77 | AND S.metadata$isupdate \ 78 | THEN UPDATE SET T.RIDE_ID = S.RIDE_ID, \ 79 | T.RIDEABLE_TYPE = S.RIDEABLE_TYPE, \ 80 | T.STARTTIME = S.STARTTIME, \ 81 | T.STOPTIME = S.STOPTIME, \ 82 | T.START_STATION_NAME = S.START_STATION_NAME, \ 83 | T.START_STATION_ID = S.START_STATION_ID, \ 84 | T.END_STATION_NAME = S.END_STATION_NAME, \ 85 | T.END_STATION_ID = S.END_STATION_ID, \ 86 | T.START_STATION_LATITUDE = S.START_STATION_LATITUDE, \ 87 | T.START_STATION_LONGITUDE = S.END_STATION_LATITUDE, \ 88 | T.END_STATION_LONGITUDE = S.END_STATION_LONGITUDE, \ 89 | T.USERTYPE = S.USERTYPE \ 90 | WHEN MATCHED AND S.metadata$action = 'DELETE' \ 91 | THEN DELETE \ 92 | WHEN NOT MATCHED AND S.metadata$action = 'INSERT' \ 93 | THEN INSERT (RIDE_ID, \ 94 | RIDEABLE_TYPE, \ 95 | STARTTIME, \ 96 | STOPTIME, \ 97 | START_STATION_NAME, \ 98 | START_STATION_ID, \ 99 | END_STATION_NAME, \ 100 | END_STATION_ID, \ 101 | 
START_STATION_LATITUDE, \ 102 | END_STATION_LATITUDE, \ 103 | END_STATION_LONGITUDE, \ 104 | USERTYPE) \ 105 | VALUES (S.RIDE_ID, \ 106 | S.RIDEABLE_TYPE, \ 107 | S.STARTTIME, \ 108 | S.STOPTIME, \ 109 | S.START_STATION_NAME, \ 110 | S.START_STATION_ID, \ 111 | S.END_STATION_NAME, \ 112 | S.END_STATION_ID, \ 113 | S.START_STATION_LATITUDE, \ 114 | S.END_STATION_LATITUDE, \ 115 | S.END_STATION_LONGITUDE, \ 116 | S.USERTYPE); " + \ 117 | " END; " + \ 118 | "$$" 119 | return create_spoc_sql 120 | 121 | def load_trips_from_raw_to_interim_target_cdc(session, 122 | stage_table_names:list, 123 | cdc_task_warehouse_name:str): 124 | from datetime import datetime 125 | interim_target_table_names = list() 126 | for stage_table_name in stage_table_names: 127 | schema = stage_table_name.split("_")[1] 128 | if schema == 'schema1': 129 | interim_target_table_name = 'INTERIM_schema1' 130 | stream_name = 'STREAM_schema1' 131 | task_name = 'TRIPSCDCTASK_schema1' 132 | procedure_name = 'TRIPSCDCPROC_schema1' 133 | create_processcdc_procedure_statement = schema1_spoc_str(procedure_name, 134 | interim_target_table_name, 135 | stream_name) 136 | 137 | elif schema == 'schema2': 138 | interim_target_table_name = 'INTERIM_schema2' 139 | stream_name = 'STREAM_schema2' 140 | task_name = 'TRIPSCDCTASK_schema2' 141 | procedure_name = 'TRIPSCDCPROC_schema2' 142 | create_processcdc_procedure_statement = schema2_spoc_str(procedure_name, 143 | interim_target_table_name, 144 | stream_name) 145 | 146 | #outside the if else condition but still inside the for loop 147 | interim_target_table_names.append(interim_target_table_name) 148 | create_stream_sql ='CREATE OR REPLACE STREAM ' + stream_name + \ 149 | ' ON TABLE ' + stage_table_name + \ 150 | ' APPEND_ONLY = FALSE SHOW_INITIAL_ROWS = TRUE' 151 | 152 | create_interim_target_table_sql = 'CREATE OR REPLACE TABLE ' + interim_target_table_name +\ 153 | ' LIKE ' + stage_table_name 154 | create_task_statement = "CREATE OR REPLACE TASK " + task_name + \ 155 | " WAREHOUSE='" + cdc_task_warehouse_name +"'"+ \ 156 | " SCHEDULE = '1 minute'"+ \ 157 | " WHEN SYSTEM$STREAM_HAS_DATA('" + stream_name + "')"+\ 158 | " AS CALL " + procedure_name + "()" 159 | resume_task_statement = "ALTER TASK " + task_name + " RESUME" 160 | 161 | _ = session.sql(create_stream_sql).collect() 162 | _ = session.sql(create_interim_target_table_sql).collect() 163 | _ = session.sql(create_processcdc_procedure_statement).collect() 164 | _ = session.sql(create_task_statement).collect() 165 | _ = session.sql(resume_task_statement).collect() 166 | 167 | return interim_target_table_names 168 | 169 | 170 | 171 | def cdc_elt(session, load_stage_name, files_to_download, download_base_url, load_table_name, trips_table_name) -> str: 172 | from citibike_ml import elt as ELT 173 | 174 | load_stage_name, files_to_load = ELT.extract_trips_to_stage(session=session, 175 | files_to_download=files_to_download, 176 | download_base_url=download_base_url, 177 | load_stage_name=load_stage_name) 178 | stage_table_names = ELT.load_trips_to_raw(session, 179 | files_to_load=files_to_load, 180 | load_stage_name=load_stage_name, 181 | load_table_name=load_table_name) 182 | 183 | interim_target_table_names = load_trips_from_raw_to_target_cdc(session, files_to_load, load_table_name, cdc_target_table_name, stream_name, cdc_task_warehouse_name, procedure_name, full_task_name) 184 | 185 | trips_table_name = ELT.transform_trips(session=session, 186 | stage_table_names=interim_target_table_names, 187 | trips_table_name=trips_table_name) 188 | 
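    # Return the conformed trips table name; the interim tables themselves are kept current by the
    # stream-triggered MERGE tasks created above.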
return trips_table_name 189 | 190 | -------------------------------------------------------------------------------- /06_Streamlit_App.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2e21e38c", 6 | "metadata": {}, 7 | "source": [ 8 | "## Streamlit Application\n", 9 | "In this section of the hands-on-lab, we will utilize Streamlit with Snowpark's Python client-side Dataframe API to create a visual front-end application for the Citibike operations team to consume the insights from the ML forecast." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "id": "a66099ab", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%%writefile include/streamlit_app.py\n", 20 | "import sys, os\n", 21 | "sys.path.append(os.getcwd()+'/dags')\n", 22 | "\n", 23 | "from snowflake.snowpark import functions as F\n", 24 | "from snowpark_connection import snowpark_connect\n", 25 | "import streamlit as st\n", 26 | "import pandas as pd\n", 27 | "from datetime import timedelta, datetime\n", 28 | "from dateutil.relativedelta import *\n", 29 | "import calendar\n", 30 | "import altair as alt\n", 31 | "import requests\n", 32 | "from requests.auth import HTTPBasicAuth\n", 33 | "import time \n", 34 | "import json\n", 35 | "import logging\n", 36 | "\n", 37 | "logging.basicConfig(level=logging.WARN)\n", 38 | "logging.getLogger().setLevel(logging.WARN)\n", 39 | "\n", 40 | "def update_forecast_table(forecast_df, stations:list, start_date, end_date):\n", 41 | "# explainer_columns = [col for col in forecast_df.schema.names if 'EXP' in col]\n", 42 | " explainer_columns=['EXPL_LAG_1', 'EXPL_LAG_7','EXPL_LAG_90','EXPL_LAG_365','EXPL_HOLIDAY','EXPL_PRECIP','EXPL_TEMP']\n", 43 | " explainer_columns_new=['DAY', 'DAY_OF_WEEK', 'QUARTER', 'DAY_OF_YEAR','US_HOLIDAY', 'PRECIPITATION','TEMPERATURE']\n", 44 | "\n", 45 | " cond = \"F.when\" + \".when\".join([\"(F.col('\" + c + \"') == F.col('EXPLAIN'), F.lit('\" + c + \"'))\" for c in explainer_columns])\n", 46 | "\n", 47 | " df = forecast_df.filter((forecast_df['STATION_ID'].in_(stations)) &\n", 48 | " (F.col('DATE') >= start_date) & \n", 49 | " (F.col('DATE') <= end_date))\\\n", 50 | " .select(['STATION_ID', \n", 51 | " F.to_char(F.col('DATE')).alias('DATE'), \n", 52 | " 'PRED', \n", 53 | " 'HOLIDAY',\n", 54 | " *explainer_columns])\\\n", 55 | " .with_column('EXPLAIN', F.greatest(*explainer_columns))\\\n", 56 | " .with_column('REASON', eval(cond))\\\n", 57 | " .select(F.col('STATION_ID'), \n", 58 | " F.col('DATE'), \n", 59 | " F.col('PRED'), \n", 60 | " F.col('REASON'), \n", 61 | " F.col('EXPLAIN'), \n", 62 | " F.col('EXPL_LAG_1').alias('DAY'),\n", 63 | " F.col('EXPL_LAG_7').alias('DAY_OF_WEEK'),\n", 64 | " F.col('EXPL_LAG_90').alias('QUARTER'),\n", 65 | " F.col('EXPL_LAG_365').alias('DAY_OF_YEAR'),\n", 66 | " F.col('EXPL_HOLIDAY').alias('US_HOLIDAY'),\n", 67 | " F.col('EXPL_PRECIP').alias('PRECIPITATION'),\n", 68 | " F.col('EXPL_TEMP').alias('TEMPERATURE'),\n", 69 | " )\\\n", 70 | " .to_pandas()\n", 71 | " \n", 72 | " df['REASON'] = pd.Categorical(df['REASON'])\n", 73 | " df['REASON_CODE']=df['REASON'].cat.codes\n", 74 | " \n", 75 | " rect = alt.Chart(df).mark_rect().encode(alt.X('DATE:N'), \n", 76 | " alt.Y('STATION_ID:N'), \n", 77 | " alt.Color('REASON'),\n", 78 | " tooltip=explainer_columns_new)\n", 79 | " text = rect.mark_text(baseline='middle').encode(text='PRED:Q', color=alt.value('white'))\n", 80 | "\n", 81 | " l = alt.layer(\n", 82 | 
" rect, text\n", 83 | " )\n", 84 | "\n", 85 | " st.write(\"### Forecast\")\n", 86 | " st.altair_chart(l, use_container_width=True)\n", 87 | " \n", 88 | " return None\n", 89 | "\n", 90 | "def update_eval_table(eval_df, stations:list):\n", 91 | " df = eval_df.select('STATION_ID', F.to_char(F.col('RUN_DATE')).alias('RUN_DATE'), 'RMSE')\\\n", 92 | " .filter(eval_df['STATION_ID'].in_(stations))\\\n", 93 | " .to_pandas()\n", 94 | "\n", 95 | " data = df.pivot(index=\"RUN_DATE\", columns=\"STATION_ID\", values=\"RMSE\")\n", 96 | " data = data.reset_index().melt('RUN_DATE', var_name='STATION_ID', value_name='RMSE')\n", 97 | "\n", 98 | " nearest = alt.selection(type='single', nearest=True, on='mouseover',\n", 99 | " fields=['RUN_DATE'], empty='none')\n", 100 | "\n", 101 | " line = alt.Chart(data).mark_line(interpolate='basis').encode(\n", 102 | " x='RUN_DATE:N',\n", 103 | " y='RMSE:Q',\n", 104 | " color='STATION_ID:N'\n", 105 | " )\n", 106 | "\n", 107 | " selectors = alt.Chart(data).mark_point().encode(\n", 108 | " x='RUN_DATE:N',\n", 109 | " opacity=alt.value(0)\n", 110 | " ).add_selection(\n", 111 | " nearest\n", 112 | " )\n", 113 | "\n", 114 | " points = line.mark_point().encode(\n", 115 | " opacity=alt.condition(nearest, alt.value(1), alt.value(0))\n", 116 | " )\n", 117 | "\n", 118 | " text = line.mark_text(align='left', dx=5, dy=-5).encode(\n", 119 | " text=alt.condition(nearest, 'RMSE:Q', alt.value(' '))\n", 120 | " )\n", 121 | "\n", 122 | " rules = alt.Chart(data).mark_rule(color='gray').encode(\n", 123 | " x='RUN_DATE:N',\n", 124 | " ).transform_filter(\n", 125 | " nearest\n", 126 | " )\n", 127 | "\n", 128 | " l = alt.layer(\n", 129 | " line, selectors, points, rules, text\n", 130 | " ).properties(\n", 131 | " width=600, height=300\n", 132 | " )\n", 133 | " st.write(\"### Model Monitor\")\n", 134 | " st.altair_chart(l, use_container_width=True)\n", 135 | " \n", 136 | " return None\n", 137 | "\n", 138 | "def trigger_ingest(download_file_name, run_date): \n", 139 | " dag_url='http://localhost:8080/api/v1/dags/citibikeml_monthly_taskflow/dagRuns'\n", 140 | " json_payload = {\"conf\": {\"files_to_download\": [download_file_name], \"run_date\": run_date}}\n", 141 | " \n", 142 | " response = requests.post(dag_url, \n", 143 | " json=json_payload,\n", 144 | " auth = HTTPBasicAuth('admin', 'admin'))\n", 145 | "\n", 146 | " run_id = json.loads(response.text)['dag_run_id']\n", 147 | " #run_id = 'manual__2022-04-07T15:02:29.166108+00:00'\n", 148 | "\n", 149 | " state=json.loads(requests.get(dag_url+'/'+run_id, auth=HTTPBasicAuth('admin', 'admin')).text)['state']\n", 150 | "\n", 151 | " st.snow()\n", 152 | "\n", 153 | " with st.spinner('Ingesting file: '+download_file_name):\n", 154 | " while state != 'success':\n", 155 | " time.sleep(5)\n", 156 | " state=json.loads(requests.get(dag_url+'/'+run_id, auth=HTTPBasicAuth('admin', 'admin')).text)['state']\n", 157 | " st.success('Ingested file: '+download_file_name+' State: '+str(state))\n", 158 | "\n", 159 | "#Main Body \n", 160 | "session, state_dict = snowpark_connect('./include/state.json')\n", 161 | "forecast_df = session.table('FLAT_FORECAST')\n", 162 | "eval_df = session.table('FLAT_EVAL')\n", 163 | "trips_df = session.table('TRIPS')\n", 164 | "\n", 165 | "st.header('Citibike Forecast Application')\n", 166 | "st.write('In this application we leverage deep learning models to predict the number of trips started from '+\n", 167 | " 'a given station each day. 
After selecting the stations and time range desired the application '+\\\n", 168 | " 'displays not only the forecast but also explains which features of the model were most used in making '+\\\n", 169 | " 'the prediction. Additionally users can see the historical performance of the deep learning model to '+\\\n", 170 | " 'monitor predictive capabilities over time.')\n", 171 | "\n", 172 | "last_trip_date = trips_df.select(F.to_date(F.max('STARTTIME'))).collect()[0][0]\n", 173 | "st.write('Data provided as of '+str(last_trip_date))\n", 174 | "\n", 175 | "#Create a sidebar for input\n", 176 | "min_date=forecast_df.select(F.min('DATE')).collect()[0][0]\n", 177 | "max_date=forecast_df.select(F.max('DATE')).collect()[0][0]\n", 178 | "\n", 179 | "start_date = st.sidebar.date_input('Start Date', value=min_date, min_value=min_date, max_value=max_date)\n", 180 | "show_days = st.sidebar.number_input('Number of days to show', value=7, min_value=1, max_value=30)\n", 181 | "end_date = start_date+timedelta(days=show_days)\n", 182 | "\n", 183 | "stations_df=forecast_df.select(F.col('STATION_ID')).distinct().to_pandas()\n", 184 | "\n", 185 | "sample_stations = [\"519\", \"497\", \"435\", \"402\", \"426\", \"285\", \"293\"]\n", 186 | "\n", 187 | "stations = st.sidebar.multiselect('Choose stations', stations_df['STATION_ID'], sample_stations)\n", 188 | "if not stations:\n", 189 | " stations = stations_df['STATION_ID']\n", 190 | "\n", 191 | "update_forecast_table(forecast_df, stations, start_date, end_date)\n", 192 | "\n", 193 | "update_eval_table(eval_df, stations)\n", 194 | "\n", 195 | "\n", 196 | "next_ingest = last_trip_date+relativedelta(months=+1)\n", 197 | "next_ingest = next_ingest.replace(day=1) \n", 198 | "\n", 199 | "if next_ingest <= datetime.strptime(\"2016-12-01\", \"%Y-%m-%d\").date():\n", 200 | " download_file_name=next_ingest.strftime('%Y%m')+'-citibike-tripdata.zip'\n", 201 | "else:\n", 202 | " download_file_name=next_ingest.strftime('%Y%m')+'-citibike-tripdata.zip'\n", 203 | " \n", 204 | "run_date = next_ingest+relativedelta(months=+1)\n", 205 | "run_date = run_date.strftime('%Y_%m_%d')\n", 206 | "\n", 207 | "st.write('Next ingest for '+str(next_ingest))\n", 208 | "\n", 209 | "st.button('Run Ingest Taskflow', on_click=trigger_ingest, args=(download_file_name, run_date))\n" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "id": "579a1566", 215 | "metadata": {}, 216 | "source": [ 217 | "If running in SageMaker Studio Lab update the domain name from the URL in your browser. \n", 218 | "For example if the Studio Lab URL is https://**yyy9xxxxxxxxxxx**.studio.us-east-2.sagemaker.aws/studiolab/default/jupyter/lab \n", 219 | "the domain name is **yyy9xxxxxxxxxxx**. 
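The trigger_ingest helper above blocks until the DAG run reports success. A minimal sketch of the same call with a timeout and explicit failure handling, reusing the local Airflow endpoint, DAG id, and admin/admin credentials from the app; the 30-minute limit is an arbitrary choice, not something the lab specifies.

import time
import requests
from requests.auth import HTTPBasicAuth

def trigger_and_wait(download_file_name, run_date, timeout_sec=1800):
    dag_url = 'http://localhost:8080/api/v1/dags/citibikeml_monthly_taskflow/dagRuns'
    auth = HTTPBasicAuth('admin', 'admin')
    payload = {"conf": {"files_to_download": [download_file_name], "run_date": run_date}}

    # Trigger the DAG run and capture its id
    run_id = requests.post(dag_url, json=payload, auth=auth).json()['dag_run_id']

    # Poll until the run reaches a terminal state or the timeout expires
    deadline = time.time() + timeout_sec
    state = 'queued'
    while state not in ('success', 'failed') and time.time() < deadline:
        time.sleep(5)
        state = requests.get(dag_url + '/' + run_id, auth=auth).json()['state']
    return run_id, state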
):" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "fd15d51e", 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "studiolab_domain = ''\n", 230 | "\n", 231 | "# launch\n", 232 | "if studiolab_domain:\n", 233 | " studiolab_region = 'us-east-2'\n", 234 | " url = f'https://{studiolab_domain}.studio.{studiolab_region}.sagemaker.aws/studiolab/default/jupyter/proxy/6006/'\n", 235 | " \n", 236 | "else: \n", 237 | " \n", 238 | " url = f'http://127.0.0.1:6006'\n", 239 | "\n", 240 | "print(f'Wait a few seconds and then click the link below to open your Streamlit application \\n{url}\\n')\n", 241 | "\n", 242 | "!streamlit run --theme.base dark include/streamlit_app.py --server.port 6006 \\\n", 243 | " --server.address 127.0.0.1 \\\n", 244 | " --server.headless true" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "id": "dc188650", 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [] 254 | } 255 | ], 256 | "metadata": { 257 | "kernelspec": { 258 | "display_name": "Python 3 (ipykernel)", 259 | "language": "python", 260 | "name": "python3" 261 | }, 262 | "language_info": { 263 | "codemirror_mode": { 264 | "name": "ipython", 265 | "version": 3 266 | }, 267 | "file_extension": ".py", 268 | "mimetype": "text/x-python", 269 | "name": "python", 270 | "nbconvert_exporter": "python", 271 | "pygments_lexer": "ipython3", 272 | "version": "3.8.13" 273 | } 274 | }, 275 | "nbformat": 4, 276 | "nbformat_minor": 5 277 | } 278 | -------------------------------------------------------------------------------- /00_Setup.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Initial Setup or Reset" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": { 13 | "tags": [] 14 | }, 15 | "source": [ 16 | "## Warning!!!!\n", 17 | "### Running this code will delete an existing database as specified in the state dictionary below.\n", 18 | "\n", 19 | "We need a way to save state throughout the project. We will initially login as the ACCOUNTADMIN role in order to setup some additional users as well as the compute resources we will need. \n", 20 | "\n", 21 | "We will specify a couple of different compute resources which allows us to scale up and down easily. Most of the workflow can use an extra-small warehouse but for certain tasks (ie. feature engineering and model training) we may need larger compute. By specifying them in the state dictionary we can easily select the correct compute for any particular task.\n", 22 | " \n", 23 | "Update the \\, \\, \\ in the state dictionary below with the initial user that was created with your trial account.\n", 24 | "\n", 25 | "Note: If you are running the US West (Oregon) region, you don't need to add the \\." 
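Because the workflow switches warehouses for loading, feature engineering, and training, a small context manager can guarantee that a task always drops back to the default warehouse afterwards. A sketch, assuming the compute_parameters keys defined in the next cell.

from contextlib import contextmanager

@contextmanager
def scoped_warehouse(session, state_dict, warehouse_key):
    # Switch to the requested warehouse for the duration of the block,
    # then return to the default warehouse even if the block raises.
    session.use_warehouse(state_dict['compute_parameters'][warehouse_key])
    try:
        yield session
    finally:
        session.use_warehouse(state_dict['compute_parameters']['default_warehouse'])

# Example usage: run feature engineering on the larger warehouse, then fall back.
# with scoped_warehouse(session, state_dict, 'fe_warehouse'):
#     ...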
26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "state_dict = {\n", 35 | " \"connection_parameters\": {\"user\": \"\",\n", 36 | " \"account\": \".\",\n", 37 | " \"role\": \"ACCOUNTADMIN\"\n", 38 | " },\n", 39 | " \"compute_parameters\" : {\"default_warehouse\": \"XSMALL_WH\", \n", 40 | " \"task_warehouse\": \"XSMALL_WH\", \n", 41 | " \"load_warehouse\": \"LARGE_WH\", \n", 42 | " \"fe_warehouse\": \"XXLARGE_WH\",\n", 43 | " \"train_warehouse\": \"XXLARGE_WH\",\n", 44 | " \"train_warehouse_sow\": \"XXLARGE_SNOWPARKOPT_WH\" \n", 45 | " }\n", 46 | "}" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "import json\n", 56 | "with open('./include/state.json', 'w') as sdf:\n", 57 | " json.dump(state_dict, sdf)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "We will connect with username and password. In a non-demo system it is very important to use properly secured passwords with secret managers and/or oauth." 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "import snowflake.snowpark as snp\n", 74 | "import json\n", 75 | "import getpass\n", 76 | "\n", 77 | "account_admin_password = getpass.getpass('Enter password for user with ACCOUNTADMIN role access')\n", 78 | "\n", 79 | "with open('./include/state.json') as sdf:\n", 80 | " state_dict = json.load(sdf) \n", 81 | "state_dict['connection_parameters']['password'] = account_admin_password\n", 82 | "\n", 83 | "session = snp.Session.builder.configs(state_dict[\"connection_parameters\"]).create()" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "We will also use a specific AWS S3 role for accessing pre-staged files to speed up the hands-on-lab." 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "state_dict['connection_parameters']['download_base_url'] = 's3://sfquickstarts/vhol_citibike_ml_snowpark_python/data'" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "To run this without access to pre-staged files run the following cell instead of the cell above." 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "#state_dict['connection_parameters']['download_base_url'] = 'https://s3.amazonaws.com/tripdata/'" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "Create a sample user which will be used for the hands-on-lab. Normally you will have different roles (and possibly different users) for data scientists, data engineers, ML engineers, etc." 
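The lab keeps things simple by granting everything to the PUBLIC role. In a real deployment each persona would get its own role; a hedged sketch of a dedicated project role, following the same SQL pattern as the cell below (the role name here is an assumption, not something the lab uses).

session.use_role('securityadmin')

project_role = 'CITIBIKE_ML_ENGINEER'  # assumed name; the lab itself grants to PUBLIC

session.sql("CREATE ROLE IF NOT EXISTS " + project_role).collect()
session.sql("GRANT ROLE " + project_role + " TO ROLE SYSADMIN").collect()

session.use_role('sysadmin')
session.sql("GRANT CREATE DATABASE ON ACCOUNT TO ROLE " + project_role).collect()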
123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "session.use_role('securityadmin')\n", 132 | "\n", 133 | "demo_username='jack'\n", 134 | "project_role='PUBLIC'\n", 135 | "\n", 136 | "session.sql(\"CREATE USER IF NOT EXISTS \"+demo_username+\\\n", 137 | " \" LOGIN_NAME = '\"+demo_username+\"'\"+\\\n", 138 | " \" FIRST_NAME = 'SNOWPARK'\"+\\\n", 139 | " \" LAST_NAME = 'HOL'\"+\\\n", 140 | " \" EMAIL = 'jack@hol.snowpark'\"+\\\n", 141 | " \" DEFAULT_ROLE = '\"+project_role+\"'\"+\\\n", 142 | " \" MUST_CHANGE_PASSWORD = FALSE\")\\\n", 143 | " .collect()\n", 144 | "\n", 145 | "session.sql(\"GRANT ROLE \"+project_role+\" TO USER \"+demo_username).collect()\n", 146 | "\n", 147 | "session.use_role('sysadmin')\n", 148 | "session.sql(\"GRANT CREATE DATABASE ON ACCOUNT TO ROLE \"+project_role).collect()" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": { 155 | "tags": [] 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "session.use_role('securityadmin')\n", 160 | "demo_user_password=getpass.getpass('Enter a new password for the demo user '+demo_username)\n", 161 | "session.sql(\"ALTER USER \"+demo_username+\" SET PASSWORD = '\"+demo_user_password+\"'\").collect()" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "Create compute instances as specified in the state dictionary." 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "session.use_role('ACCOUNTADMIN')\n", 178 | "project_role='PUBLIC'\n", 179 | "\n", 180 | "for wh in state_dict['compute_parameters'].values():\n", 181 | " if wh != \"train_warehouse\":\n", 182 | " session.sql(\"CREATE WAREHOUSE IF NOT EXISTS \"+wh+\\\n", 183 | " \" WITH WAREHOUSE_SIZE = '\"+wh.split('_')[0]+\\\n", 184 | " \"' WAREHOUSE_TYPE = 'STANDARD' AUTO_SUSPEND = 60 AUTO_RESUME = TRUE initially_suspended = true;\")\\\n", 185 | " .collect()\n", 186 | " elif wh == \"train_warehouse\":\n", 187 | " session.sql(\"CREATE WAREHOUSE IF NOT EXISTS \"+wh+\\\n", 188 | " \" WITH WAREHOUSE_SIZE = '\"+wh.split('_')[0]+\\\n", 189 | " \"' WAREHOUSE_TYPE = 'HIGH_MEMORY' MAX_CONCURRENCY_LEVEL = 1 AUTO_SUSPEND = 60 AUTO_RESUME = TRUE initially_suspended = true;\")\\\n", 190 | " .collect()\n", 191 | " session.sql(\"GRANT USAGE ON WAREHOUSE \"+wh+\" TO ROLE \"+project_role).collect() \n", 192 | " session.sql(\"GRANT OPERATE ON WAREHOUSE \"+wh+\" TO ROLE \"+project_role).collect() \n", 193 | " \n", 194 | "session.use_role(state_dict['connection_parameters']['role'])" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "Allow users to import data shares." 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "session.use_role('ACCOUNTADMIN')\n", 211 | "session.sql(\"GRANT IMPORT SHARE ON ACCOUNT TO \"+project_role).collect()" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "Now update the state dictionary to use the non-admin account." 
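Note that the warehouse loop above iterates over compute_parameters.values(), so wh holds names such as 'XSMALL_WH' and the comparison against the key name "train_warehouse" never selects the HIGH_MEMORY branch. A sketch that branches on the dictionary key instead, keeping the SQL from the cell above otherwise unchanged.

session.use_role('ACCOUNTADMIN')
project_role = 'PUBLIC'

for wh_key, wh in state_dict['compute_parameters'].items():
    if wh_key == 'train_warehouse':
        type_clause = " WAREHOUSE_TYPE = 'HIGH_MEMORY' MAX_CONCURRENCY_LEVEL = 1"
    else:
        type_clause = " WAREHOUSE_TYPE = 'STANDARD'"

    session.sql("CREATE WAREHOUSE IF NOT EXISTS " + wh +
                " WITH WAREHOUSE_SIZE = '" + wh.split('_')[0] + "'" +
                type_clause +
                " AUTO_SUSPEND = 60 AUTO_RESUME = TRUE INITIALLY_SUSPENDED = TRUE").collect()
    session.sql("GRANT USAGE ON WAREHOUSE " + wh + " TO ROLE " + project_role).collect()
    session.sql("GRANT OPERATE ON WAREHOUSE " + wh + " TO ROLE " + project_role).collect()

session.use_role(state_dict['connection_parameters']['role'])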
219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "state_dict['connection_parameters']['user'] = demo_username\n", 228 | "state_dict['connection_parameters']['password'] = demo_user_password\n", 229 | "state_dict['connection_parameters']['role'] = project_role\n", 230 | "state_dict['connection_parameters']['database'] = 'CITIBIKEML_'+demo_username\n", 231 | "state_dict['connection_parameters']['schema'] = 'DEMO'" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "Save the updated state dictionary for project team use." 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "import json\n", 248 | "with open('./include/state.json', 'w') as sdf:\n", 249 | " json.dump(state_dict, sdf)" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "Create a python function to simplify the users' steps of starting a session." 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "%%writefile dags/snowpark_connection.py\n", 266 | "def snowpark_connect(state_file='./include/state.json'):\n", 267 | " import snowflake.snowpark as snp\n", 268 | " import json\n", 269 | " \n", 270 | " with open(state_file) as sdf:\n", 271 | " state_dict = json.load(sdf) \n", 272 | " \n", 273 | " session=None\n", 274 | " session = snp.Session.builder.configs(state_dict[\"connection_parameters\"]).create()\n", 275 | " session.use_warehouse(state_dict['compute_parameters']['default_warehouse'])\n", 276 | " return session, state_dict" 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "metadata": {}, 282 | "source": [ 283 | "Test the function that users will use." 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [ 292 | "from dags.snowpark_connection import snowpark_connect\n", 293 | "session, state_dict = snowpark_connect()" 294 | ] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "metadata": {}, 299 | "source": [ 300 | "Make sure the user has access to each compute instance." 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "session.get_current_warehouse()" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "for wh in state_dict['compute_parameters'].keys():\n", 319 | " session.use_warehouse(state_dict['compute_parameters'][wh])" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "session.get_current_warehouse()" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "Create the database and schema for this project." 
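A minimal sketch of the database and schema creation this step refers to, using the names stored in the updated state dictionary; in the lab the reset_database helper called from the later notebooks appears to handle an equivalent setup.

from dags.snowpark_connection import snowpark_connect

session, state_dict = snowpark_connect()

database = state_dict['connection_parameters']['database']
schema = state_dict['connection_parameters']['schema']

session.sql("CREATE DATABASE IF NOT EXISTS " + database).collect()
session.sql("CREATE SCHEMA IF NOT EXISTS " + database + "." + schema).collect()
session.use_database(database)
session.use_schema(schema)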
336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": { 342 | "scrolled": true 343 | }, 344 | "outputs": [], 345 | "source": [ 346 | "session.close()" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [] 355 | } 356 | ], 357 | "metadata": { 358 | "authors": [ 359 | { 360 | "name": "cforbe" 361 | } 362 | ], 363 | "kernelspec": { 364 | "display_name": "snowpark_0110:Python", 365 | "language": "python", 366 | "name": "conda-env-snowpark_0110-py" 367 | }, 368 | "language_info": { 369 | "codemirror_mode": { 370 | "name": "ipython", 371 | "version": 3 372 | }, 373 | "file_extension": ".py", 374 | "mimetype": "text/x-python", 375 | "name": "python", 376 | "nbconvert_exporter": "python", 377 | "pygments_lexer": "ipython3", 378 | "version": "3.8.13" 379 | }, 380 | "msauthor": "trbye" 381 | }, 382 | "nbformat": 4, 383 | "nbformat_minor": 4 384 | } 385 | -------------------------------------------------------------------------------- /dags/mlops_pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | def materialize_holiday_table(session, holiday_table_name:str) -> str: 3 | from dags.feature_engineering import generate_holiday_df 4 | 5 | holiday_df = generate_holiday_df(session=session, holiday_table_name=holiday_table_name) 6 | holiday_df.write.mode('overwrite').saveAsTable(holiday_table_name) 7 | 8 | return holiday_table_name 9 | 10 | def subscribe_to_weather_data(session, 11 | weather_database_name:str, 12 | weather_listing_id:str) -> str: 13 | 14 | session.sql("CREATE DATABASE IF NOT EXISTS "+weather_database_name+\ 15 | " FROM SHARE "+weather_listing_id).collect() 16 | 17 | return weather_database_name 18 | 19 | def create_weather_view(session, weather_table_name:str, weather_view_name:str) -> str: 20 | from dags.feature_engineering import generate_weather_df 21 | 22 | weather_df = generate_weather_df(session=session, weather_table_name=weather_table_name) 23 | 24 | weather_df.create_or_replace_view(weather_view_name) 25 | 26 | return weather_view_name 27 | 28 | def deploy_pred_train_udf(session, udf_name:str, function_name:str, model_stage_name:str) -> str: 29 | from dags.station_train_predict import StationTrainPredictFunc 30 | from snowflake.snowpark import types as T 31 | from snowflake.snowpark.functions import udtf 32 | 33 | session.clear_packages() 34 | session.clear_imports() 35 | dep_packages=["pandas==1.3.5", "pytorch==1.10.2", "scipy==1.7.1", "scikit-learn==1.0.2", "setuptools==58.0.4", "cloudpickle==2.0.0"] 36 | dep_imports=['./include/pytorch_tabnet.zip', 'dags'] 37 | 38 | station_train_predict_udtf = udtf(StationTrainPredictFunc, 39 | name="station_train_predict_udtf", 40 | session=session, 41 | is_permanent=True, 42 | stage_location='@'+str(model_stage_name), 43 | imports=dep_imports, 44 | packages=dep_packages, 45 | input_types=[T.DateType(), 46 | T.DecimalType(), 47 | T.DecimalType(), 48 | T.DecimalType(), 49 | T.DecimalType(), 50 | T.DecimalType(), 51 | T.DecimalType(), 52 | T.DecimalType(38, 1), 53 | T.DecimalType(38, 1), 54 | T.ArrayType(), 55 | T.StringType(), 56 | T.DecimalType(), 57 | T.DecimalType(), 58 | T.ArrayType(), 59 | T.ArrayType(), 60 | T.ArrayType()], 61 | output_schema=T.StructType([T.StructField("PRED_DATA", T.VariantType())]), 62 | replace=True) 63 | 64 | return station_train_predict_udtf.name 65 | 66 | 67 | def deploy_eval_udf(session, udf_name:str, function_name:str, 
model_stage_name:str) -> str: 68 | from dags.model_eval import eval_model_func 69 | from snowflake.snowpark import types as T 70 | 71 | session.clear_packages() 72 | session.clear_imports() 73 | dep_packages=['pandas==1.3.5', 'scikit-learn==1.0.2', "cloudpickle==2.0.0"] 74 | dep_imports=['./include/rexmex.zip', 'dags'] 75 | 76 | eval_model_output_udf = session.udf.register(eval_model_func, 77 | session=session, 78 | name=udf_name, 79 | is_permanent=True, 80 | stage_location='@'+str(model_stage_name), 81 | imports=dep_imports, 82 | packages=dep_packages, 83 | input_types=[T.StringType(), 84 | T.StringType(), 85 | T.StringType()], 86 | return_type=T.VariantType(), 87 | replace=True) 88 | return eval_model_output_udf.name 89 | 90 | def create_forecast_table(session, 91 | trips_table_name:str, 92 | holiday_table_name:str, 93 | weather_view_name:str, 94 | forecast_table_name:str, 95 | steps:int): 96 | 97 | from dags.feature_engineering import generate_holiday_df 98 | from datetime import timedelta, datetime 99 | from snowflake.snowpark import functions as F 100 | 101 | start_date = session.table(trips_table_name)\ 102 | .select(F.to_date(F.max('STARTTIME'))).collect()[0][0]+timedelta(days=1) 103 | end_date = start_date+timedelta(days=steps) 104 | 105 | #check if it tables already materialized, otherwise generate DF 106 | holiday_df = session.table(holiday_table_name) 107 | try: 108 | _ = holiday_df.columns 109 | except: 110 | holiday_df = generate_holiday_df(session, holiday_table_name) 111 | 112 | weather_df = session.table(weather_view_name) 113 | 114 | forecast_df = holiday_df.join(weather_df[['DATE','PRECIP','TEMP']], 'DATE', join_type='right')\ 115 | .na.fill({'HOLIDAY':0})\ 116 | .filter((F.col('DATE') >= start_date) &\ 117 | (F.col('DATE') <= end_date))\ 118 | .sort('DATE', ascending=True) 119 | 120 | forecast_df.write.mode('overwrite').save_as_table(forecast_table_name) 121 | 122 | return forecast_table_name 123 | 124 | 125 | def create_feature_table(session, 126 | trips_table_name:str, 127 | holiday_table_name:str, 128 | weather_view_name:str, 129 | feature_table_name:str) -> list: 130 | 131 | import snowflake.snowpark as snp 132 | from snowflake.snowpark import functions as F 133 | from dags.feature_engineering import generate_holiday_df, generate_weather_df 134 | 135 | #check if it tables already materialized, otherwise generate DF 136 | holiday_df = session.table(holiday_table_name) 137 | try: 138 | _ = holiday_df.columns 139 | except: 140 | holiday_df = generate_holiday_df(session, holiday_table_name) 141 | 142 | weather_df = session.table(weather_view_name) 143 | 144 | sid_date_window = snp.Window.partition_by(F.col('STATION_ID')).order_by(F.col('DATE').asc()) 145 | sid_window = snp.Window.partition_by(F.col('STATION_ID')) 146 | latest_date = session.table(trips_table_name).select(F.to_char(F.to_date(F.max('STARTTIME')))).collect()[0][0] 147 | 148 | feature_df = session.table(trips_table_name)\ 149 | .select(F.to_date(F.col('STARTTIME')).alias('DATE'), 150 | F.col('START_STATION_ID').alias('STATION_ID'))\ 151 | .group_by(F.col('STATION_ID'), F.col('DATE'))\ 152 | .count()\ 153 | .with_column('LAG_1', F.lag(F.col('COUNT'), offset=1).over(sid_date_window))\ 154 | .with_column('LAG_7', F.lag(F.col('COUNT'), offset=7).over(sid_date_window))\ 155 | .with_column('LAG_90', F.lag(F.col('COUNT'), offset=90).over(sid_date_window))\ 156 | .with_column('LAG_365', F.lag(F.col('COUNT'), offset=365).over(sid_date_window))\ 157 | .na.drop()\ 158 | .join(holiday_df, 'DATE', 
join_type='left').na.fill({'HOLIDAY':0})\ 159 | .join(weather_df[['DATE','PRECIP','TEMP']], 'DATE', 'inner')\ 160 | .with_column('DAY_COUNT', F.count(F.col('DATE')).over(sid_window))\ 161 | .filter(F.col('DAY_COUNT') >= 365*2)\ 162 | .with_column('MAX_DATE', F.max('DATE').over(sid_window))\ 163 | .filter(F.col('MAX_DATE') == latest_date)\ 164 | .drop(['DAY_COUNT', 'MAX_DATE']) 165 | 166 | feature_df.write.mode('overwrite').save_as_table(feature_table_name) 167 | 168 | return feature_table_name 169 | 170 | def train_predict(session, 171 | station_train_pred_udf_name:str, 172 | feature_table_name:str, 173 | forecast_table_name:str, 174 | pred_table_name:str) -> list: 175 | 176 | from snowflake.snowpark import functions as F 177 | 178 | cutpoint=365 179 | max_epochs = 10 180 | target_column = 'COUNT' 181 | lag_values=[1,7,90,365] 182 | lag_values_array = F.array_construct(*[F.lit(x) for x in lag_values]) 183 | 184 | historical_df = session.table(feature_table_name) 185 | historical_column_list = historical_df.columns 186 | historical_column_list.remove('STATION_ID') 187 | historical_column_names = F.array_construct(*[F.lit(x) for x in historical_column_list]) 188 | 189 | forecast_df = session.table(forecast_table_name) 190 | forecast_column_list = forecast_df.columns 191 | forecast_column_names = F.array_construct(*[F.lit(x) for x in forecast_column_list]) 192 | forecast_df = forecast_df.select(F.array_agg(F.array_construct(F.col('*'))).alias('FORECAST_DATA')) 193 | 194 | station_train_predict = F.table_function("station_train_predict_udtf") 195 | 196 | train_df = historical_df.join(forecast_df) 197 | udtf_input = train_df.select(train_df['STATION_ID'], station_train_predict(train_df['DATE'], \ 198 | train_df['COUNT'], \ 199 | train_df['LAG_1'], \ 200 | train_df['LAG_7'], \ 201 | train_df['LAG_90'], \ 202 | train_df['LAG_365'], \ 203 | train_df['HOLIDAY'], \ 204 | train_df['PRECIP'], \ 205 | train_df['TEMP'], \ 206 | F.lit(historical_column_names), \ 207 | F.lit(target_column), \ 208 | F.lit(cutpoint), \ 209 | F.lit(max_epochs), \ 210 | F.lit(forecast_column_names), \ 211 | train_df['FORECAST_DATA'], \ 212 | F.lit(lag_values_array)) \ 213 | .over(partition_by = 'STATION_ID')) \ 214 | .write.mode('overwrite') \ 215 | .save_as_table(pred_table_name) 216 | 217 | return pred_table_name 218 | 219 | def evaluate_station_model(session, 220 | run_date:str, 221 | eval_model_udf_name:str, 222 | pred_table_name:str, 223 | eval_table_name:str): 224 | from snowflake.snowpark import functions as F 225 | from datetime import datetime 226 | 227 | y_true_name='COUNT' 228 | y_score_name='PRED' 229 | run_date=datetime.strptime(run_date, '%Y_%m_%d').date() 230 | 231 | session.table(pred_table_name)\ 232 | .select('STATION_ID', 233 | F.call_udf(eval_model_udf_name, 234 | F.parse_json(F.col('PRED_DATA')[0]), 235 | F.lit(y_true_name), 236 | F.lit(y_score_name)).alias('EVAL_DATA'))\ 237 | .with_column('RUN_DATE', F.to_date(F.lit(run_date)))\ 238 | .write.mode('overwrite')\ 239 | .save_as_table(eval_table_name) 240 | 241 | return eval_table_name 242 | 243 | def flatten_tables(session, pred_table_name:str, forecast_table_name:str, eval_table_name:str): 244 | from snowflake.snowpark import functions as F 245 | 246 | session.table(pred_table_name)\ 247 | .select('STATION_ID', F.parse_json(F.col('PRED_DATA')[0]).alias('PRED_DATA'))\ 248 | .flatten('PRED_DATA').select('STATION_ID', F.col('VALUE').alias('PRED_DATA'))\ 249 | .select('STATION_ID', 250 | F.to_date(F.col('PRED_DATA')['DATE']).alias('DATE'), 251 | 
F.as_integer(F.col('PRED_DATA')['COUNT']).alias('COUNT'), 252 | F.as_integer(F.col('PRED_DATA')['LAG_1']).alias('LAG_1'), 253 | F.as_integer(F.col('PRED_DATA')['LAG_7']).alias('LAG_7'), 254 | F.as_integer(F.col('PRED_DATA')['LAG_90']).alias('LAG_90'), 255 | F.as_integer(F.col('PRED_DATA')['LAG_365']).alias('LAG_365'), 256 | F.as_integer(F.col('PRED_DATA')['HOLIDAY']).alias('HOLIDAY'), 257 | F.as_decimal(F.col('PRED_DATA')['PRECIP']).alias('PRECIP'), 258 | F.as_decimal(F.col('PRED_DATA')['TEMP']).alias('TEMP'), 259 | F.as_decimal(F.col('PRED_DATA')['PRED']).alias('PRED'), 260 | F.as_decimal(F.col('PRED_DATA')['EXPL_LAG_1']).alias('EXPL_LAG_1'), 261 | F.as_decimal(F.col('PRED_DATA')['EXPL_LAG_7']).alias('EXPL_LAG_7'), 262 | F.as_decimal(F.col('PRED_DATA')['EXPL_LAG_90']).alias('EXPL_LAG_90'), 263 | F.as_decimal(F.col('PRED_DATA')['EXPL_LAG_365']).alias('EXPL_LAG_365'), 264 | F.as_decimal(F.col('PRED_DATA')['EXPL_HOLIDAY']).alias('EXPL_HOLIDAY'), 265 | F.as_decimal(F.col('PRED_DATA')['EXPL_PRECIP']).alias('EXPL_PRECIP'), 266 | F.as_decimal(F.col('PRED_DATA')['EXPL_TEMP']).alias('EXPL_TEMP'))\ 267 | .write.mode('overwrite').save_as_table('flat_PRED') 268 | 269 | #forecast are in position 2 of the pred_table 270 | session.table(pred_table_name)\ 271 | .select('STATION_ID', F.parse_json(F.col('PRED_DATA')[1]).alias('PRED_DATA'))\ 272 | .flatten('PRED_DATA').select('STATION_ID', F.col('VALUE').alias('PRED_DATA'))\ 273 | .select('STATION_ID', 274 | F.to_date(F.col('PRED_DATA')['DATE']).alias('DATE'), 275 | F.as_integer(F.col('PRED_DATA')['COUNT']).alias('COUNT'), 276 | F.as_integer(F.col('PRED_DATA')['LAG_1']).alias('LAG_1'), 277 | F.as_integer(F.col('PRED_DATA')['LAG_7']).alias('LAG_7'), 278 | F.as_integer(F.col('PRED_DATA')['LAG_90']).alias('LAG_90'), 279 | F.as_integer(F.col('PRED_DATA')['LAG_365']).alias('LAG_365'), 280 | F.as_integer(F.col('PRED_DATA')['HOLIDAY']).alias('HOLIDAY'), 281 | F.as_decimal(F.col('PRED_DATA')['PRECIP']).alias('PRECIP'), 282 | F.as_decimal(F.col('PRED_DATA')['TEMP']).alias('TEMP'), 283 | F.as_decimal(F.col('PRED_DATA')['PRED']).alias('PRED'), 284 | F.as_decimal(F.col('PRED_DATA')['EXPL_LAG_1']).alias('EXPL_LAG_1'), 285 | F.as_decimal(F.col('PRED_DATA')['EXPL_LAG_7']).alias('EXPL_LAG_7'), 286 | F.as_decimal(F.col('PRED_DATA')['EXPL_LAG_90']).alias('EXPL_LAG_90'), 287 | F.as_decimal(F.col('PRED_DATA')['EXPL_LAG_365']).alias('EXPL_LAG_365'), 288 | F.as_decimal(F.col('PRED_DATA')['EXPL_HOLIDAY']).alias('EXPL_HOLIDAY'), 289 | F.as_decimal(F.col('PRED_DATA')['EXPL_PRECIP']).alias('EXPL_PRECIP'), 290 | F.as_decimal(F.col('PRED_DATA')['EXPL_TEMP']).alias('EXPL_TEMP'))\ 291 | .write.mode('overwrite').save_as_table('flat_FORECAST') 292 | 293 | session.table(eval_table_name)\ 294 | .select('RUN_DATE', 'STATION_ID', F.parse_json(F.col('EVAL_DATA')).alias('EVAL_DATA'))\ 295 | .flatten('EVAL_DATA').select('RUN_DATE', 'STATION_ID', F.col('VALUE').alias('EVAL_DATA'))\ 296 | .select('RUN_DATE', 'STATION_ID', 297 | F.as_decimal(F.col('EVAL_DATA')['mae'], 10, 2).alias('mae'), 298 | F.as_decimal(F.col('EVAL_DATA')['mape'], 10, 2).alias('mape'), 299 | F.as_decimal(F.col('EVAL_DATA')['mse'], 10, 2).alias('mse'), 300 | F.as_decimal(F.col('EVAL_DATA')['r_squared'], 10, 2).alias('r_squared'), 301 | F.as_decimal(F.col('EVAL_DATA')['rmse'], 10, 2).alias('rmse'), 302 | F.as_decimal(F.col('EVAL_DATA')['smape'], 10, 2).alias('smape'),)\ 303 | .write.mode('append').save_as_table('flat_EVAL') 304 | 305 | return 'flat_PRED', 'flat_FORECAST', 'flat_EVAL' 306 | 307 | 
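Taken together, the functions in this module form the monthly train-and-forecast pipeline. The following is a hedged sketch of how they could be chained from a notebook or script, assuming the table, stage, and UDF names have already been written into state.json (as the Airflow tasks in 05_Airflow_Pipeline.ipynb expect); the run date and the eval_table_name key are placeholders.

from dags.snowpark_connection import snowpark_connect
from dags import mlops_pipeline as mlops

session, state_dict = snowpark_connect('./include/state.json')

# Reference data and model stage
holiday_tbl = mlops.materialize_holiday_table(session, state_dict['holiday_table_name'])
mlops.subscribe_to_weather_data(session,
                                weather_database_name=state_dict['weather_database_name'],
                                weather_listing_id=state_dict['weather_listing_id'])
weather_view = mlops.create_weather_view(session,
                                         weather_table_name=state_dict['weather_table_name'],
                                         weather_view_name=state_dict['weather_view_name'])
session.sql('CREATE STAGE IF NOT EXISTS ' + state_dict['model_stage_name']).collect()

# Deploy the training/prediction UDTF and the evaluation UDF
train_udtf = mlops.deploy_pred_train_udf(session,
                                         udf_name=state_dict['train_udf_name'],
                                         function_name=state_dict['train_func_name'],
                                         model_stage_name=state_dict['model_stage_name'])
eval_udf = mlops.deploy_eval_udf(session,
                                 udf_name=state_dict['eval_udf_name'],
                                 function_name=state_dict['eval_func_name'],
                                 model_stage_name=state_dict['model_stage_name'])

# Build features and forecast inputs, then train, predict, evaluate, and flatten
feature_tbl = mlops.create_feature_table(session,
                                         trips_table_name=state_dict['trips_table_name'],
                                         holiday_table_name=holiday_tbl,
                                         weather_view_name=weather_view,
                                         feature_table_name=state_dict['feature_table_name'])
forecast_tbl = mlops.create_forecast_table(session,
                                           trips_table_name=state_dict['trips_table_name'],
                                           holiday_table_name=holiday_tbl,
                                           weather_view_name=weather_view,
                                           forecast_table_name=state_dict['forecast_table_name'],
                                           steps=state_dict['forecast_steps'])
pred_tbl = mlops.train_predict(session,
                               station_train_pred_udf_name=train_udtf,
                               feature_table_name=feature_tbl,
                               forecast_table_name=forecast_tbl,
                               pred_table_name=state_dict['pred_table_name'])
eval_tbl = mlops.evaluate_station_model(session,
                                        run_date='2020_01_01',  # placeholder run date
                                        eval_model_udf_name=eval_udf,
                                        pred_table_name=pred_tbl,
                                        eval_table_name=state_dict['eval_table_name'])
mlops.flatten_tables(session, pred_tbl, forecast_tbl, eval_tbl)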
-------------------------------------------------------------------------------- /01_Ingest.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Create Ingest Logic" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Incremental and Bulk Extract, Load and Transform\n", 15 | "We expect to get new data every month which we will incrementally load. Here we will create some functions to wrap the ELT functions from the Data Engineer." 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "from dags.snowpark_connection import snowpark_connect\n", 25 | "session, state_dict = snowpark_connect()" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "from dags import elt as ELT\n", 35 | "\n", 36 | "import snowflake.snowpark as snp\n", 37 | "import uuid \n", 38 | "\n", 39 | "state_dict.update({'download_base_url': 'https://s3.amazonaws.com/tripdata/',\n", 40 | " 'load_table_name': 'RAW_',\n", 41 | " 'trips_table_name': 'TRIPS',\n", 42 | " 'load_stage_name': 'LOAD_STAGE'\n", 43 | " })\n", 44 | "\n", 45 | "import json\n", 46 | "with open('./include/state.json', 'w') as sdf:\n", 47 | " json.dump(state_dict, sdf)\n", 48 | " \n", 49 | "ELT.reset_database(session=session, state_dict=state_dict, prestaged=False)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "First we will test the ELT functions. We pick a couple of files representing the various schema and file names." 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "file_name_end2 = '202102-citibike-tripdata.csv.zip'\n", 66 | "file_name_end1 = '201402-citibike-tripdata.zip'\n", 67 | "file_name_end3 = '202003-citibike-tripdata.csv.zip'\n", 68 | "\n", 69 | "files_to_download = [file_name_end1, file_name_end2, file_name_end3]" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "%%time\n", 79 | "load_stage_names, files_to_load = ELT.extract_trips_to_stage(session=session, \n", 80 | " files_to_download=files_to_download, \n", 81 | " download_base_url=state_dict['download_base_url'], \n", 82 | " load_stage_name=state_dict['load_stage_name'])" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "This ELT logic requires downloading data to the local system in order to unzip as well as upload the file to a stage. This can be really slow depending on network speed. Later we will provide a __bulk-load option that uses data already in gzip format in order to speed up the hands-on-lab__." 
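Because the extract step above stages gzipped CSVs under separate schema1/schema2 prefixes, it can be useful to confirm what actually landed before loading. A small sketch, assuming load_stage_names is a dictionary of stage paths keyed by schema, as the incremental ELT code later in this notebook builds it.

for schema, stage_path in load_stage_names.items():
    staged = session.sql("LIST @" + stage_path).collect()
    print(schema + ':', len(staged), 'staged files')
    for row in staged[:3]:
        print('   ', row[0], row[1])  # name and size columns from LIST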
90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "%%time\n", 99 | "\n", 100 | "files_to_load['schema1']=[file+'.gz' for file in files_to_load['schema1']]\n", 101 | "files_to_load['schema2']=[file+'.gz' for file in files_to_load['schema2']]\n", 102 | "\n", 103 | "stage_table_names = ELT.load_trips_to_raw(session=session, \n", 104 | " files_to_load=files_to_load, \n", 105 | " load_stage_names=load_stage_names, \n", 106 | " load_table_name=state_dict['load_table_name'])" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "%%time\n", 116 | "trips_table_name = ELT.transform_trips(session=session, \n", 117 | " stage_table_names=stage_table_names, \n", 118 | " trips_table_name=state_dict['trips_table_name'])" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "Since there are two separate schemas we will create two separate ingest paths. For that we will want to separate the files into two groups like the following." 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "from datetime import datetime\n", 135 | "\n", 136 | "files_to_ingest=['202004-citibike-tripdata.csv.zip', '202102-citibike-tripdata.csv.zip']\n", 137 | "schema1_download_files = list()\n", 138 | "schema2_download_files = list()\n", 139 | "schema2_start_date = datetime.strptime('202102', \"%Y%m\")\n", 140 | "\n", 141 | "for file_name in files_to_ingest:\n", 142 | " file_start_date = datetime.strptime(file_name.split(\"-\")[0], \"%Y%m\")\n", 143 | " if file_start_date < schema2_start_date:\n", 144 | " schema1_download_files.append(file_name.replace('.zip','.gz'))\n", 145 | " else:\n", 146 | " schema2_download_files.append(file_name.replace('.zip','.gz'))\n", 147 | " \n", 148 | "files_to_load = {'schema1': schema1_download_files, 'schema2': schema2_download_files}\n", 149 | "files_to_load" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "Here we create the incremental ELT function as well as a bulk load function. The bulk ingest function wraps the incremental ingest with a full set of data to bootstrap the project." 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "%%writefile dags/ingest.py\n", 166 | "def incremental_elt(session, \n", 167 | " state_dict:dict, \n", 168 | " files_to_ingest:list, \n", 169 | " download_base_url,\n", 170 | " use_prestaged=False) -> str:\n", 171 | " \n", 172 | " import dags.elt as ELT\n", 173 | " from datetime import datetime\n", 174 | "\n", 175 | " load_stage_name=state_dict['load_stage_name']\n", 176 | " load_table_name=state_dict['load_table_name']\n", 177 | " trips_table_name=state_dict['trips_table_name']\n", 178 | " \n", 179 | " if use_prestaged:\n", 180 | " print(\"Skipping extract. 
Using provided bucket for pre-staged files.\")\n", 181 | " \n", 182 | " schema1_download_files = list()\n", 183 | " schema2_download_files = list()\n", 184 | " schema2_start_date = datetime.strptime('202102', \"%Y%m\")\n", 185 | "\n", 186 | " for file_name in files_to_ingest:\n", 187 | " file_start_date = datetime.strptime(file_name.split(\"-\")[0], \"%Y%m\")\n", 188 | " if file_start_date < schema2_start_date:\n", 189 | " schema1_download_files.append(file_name.replace('.zip','.gz'))\n", 190 | " else:\n", 191 | " schema2_download_files.append(file_name.replace('.zip','.gz'))\n", 192 | " \n", 193 | " \n", 194 | " load_stage_names = {'schema1':load_stage_name+'/schema1/', 'schema2':load_stage_name+'/schema2/'}\n", 195 | " files_to_load = {'schema1': schema1_download_files, 'schema2': schema2_download_files}\n", 196 | " else:\n", 197 | " print(\"Extracting files from public location.\")\n", 198 | " load_stage_names, files_to_load = ELT.extract_trips_to_stage(session=session, \n", 199 | " files_to_download=files_to_ingest, \n", 200 | " download_base_url=download_base_url, \n", 201 | " load_stage_name=load_stage_name)\n", 202 | " \n", 203 | " files_to_load['schema1']=[file+'.gz' for file in files_to_load['schema1']]\n", 204 | " files_to_load['schema2']=[file+'.gz' for file in files_to_load['schema2']]\n", 205 | "\n", 206 | "\n", 207 | " print(\"Loading files to raw.\")\n", 208 | " stage_table_names = ELT.load_trips_to_raw(session=session, \n", 209 | " files_to_load=files_to_load, \n", 210 | " load_stage_names=load_stage_names, \n", 211 | " load_table_name=load_table_name) \n", 212 | " \n", 213 | " print(\"Transforming records to trips table.\")\n", 214 | " trips_table_name = ELT.transform_trips(session=session, \n", 215 | " stage_table_names=stage_table_names, \n", 216 | " trips_table_name=trips_table_name)\n", 217 | " return trips_table_name\n", 218 | "\n", 219 | "def bulk_elt(session, \n", 220 | " state_dict:dict,\n", 221 | " download_base_url, \n", 222 | " use_prestaged=False) -> str:\n", 223 | " \n", 224 | " #import dags.elt as ELT\n", 225 | " from dags.ingest import incremental_elt\n", 226 | " \n", 227 | " import pandas as pd\n", 228 | " from datetime import datetime\n", 229 | "\n", 230 | " #Create a list of filenames to download based on date range\n", 231 | " #For files like 201306-citibike-tripdata.zip\n", 232 | " date_range1 = pd.period_range(start=datetime.strptime(\"201306\", \"%Y%m\"), \n", 233 | " end=datetime.strptime(\"201612\", \"%Y%m\"), \n", 234 | " freq='M').strftime(\"%Y%m\")\n", 235 | " file_name_end1 = '-citibike-tripdata.zip'\n", 236 | " files_to_extract = [date+file_name_end1 for date in date_range1.to_list()]\n", 237 | "\n", 238 | " #For files like 201701-citibike-tripdata.csv.zip\n", 239 | " date_range2 = pd.period_range(start=datetime.strptime(\"201701\", \"%Y%m\"), \n", 240 | " end=datetime.strptime(\"201912\", \"%Y%m\"), \n", 241 | " freq='M').strftime(\"%Y%m\")\n", 242 | " \n", 243 | " file_name_end2 = '-citibike-tripdata.csv.zip'\n", 244 | " \n", 245 | " files_to_extract = files_to_extract + [date+file_name_end2 for date in date_range2.to_list()] \n", 246 | "\n", 247 | " trips_table_name = incremental_elt(session=session, \n", 248 | " state_dict=state_dict, \n", 249 | " files_to_ingest=files_to_extract, \n", 250 | " use_prestaged=use_prestaged,\n", 251 | " download_base_url=download_base_url)\n", 252 | " \n", 253 | " return trips_table_name\n" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "The incremental 
ELT function allows us to specify one or more files to extract, load and transform. Lets try it with a couple of examples. Start with a single file." 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "%%time\n", 270 | "from dags.ingest import incremental_elt\n", 271 | "from dags.elt import reset_database\n", 272 | "from dags.snowpark_connection import snowpark_connect\n", 273 | "\n", 274 | "session, state_dict = snowpark_connect('./include/state.json')\n", 275 | "\n", 276 | "session.use_warehouse(state_dict['compute_parameters']['fe_warehouse'])\n", 277 | "\n", 278 | "reset_database(session=session, state_dict=state_dict, prestaged=False)\n", 279 | "\n", 280 | "incremental_elt(session=session, \n", 281 | " state_dict=state_dict, \n", 282 | " files_to_ingest=['202001-citibike-tripdata.csv.zip'], \n", 283 | " download_base_url=state_dict['download_base_url'],\n", 284 | " use_prestaged=False)\n", 285 | "session.close()" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "We may need to ingest a list of multiple files." 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "%%time\n", 302 | "from dags.ingest import incremental_elt\n", 303 | "from dags.elt import reset_database\n", 304 | "from dags.snowpark_connection import snowpark_connect\n", 305 | "\n", 306 | "session, state_dict = snowpark_connect('./include/state.json')\n", 307 | "\n", 308 | "session.use_warehouse(state_dict['compute_parameters']['fe_warehouse'])\n", 309 | "\n", 310 | "reset_database(session=session, state_dict=state_dict, prestaged=False)\n", 311 | "\n", 312 | "incremental_elt(session=session, \n", 313 | " state_dict=state_dict, \n", 314 | " files_to_ingest=['202002-citibike-tripdata.csv.zip', '202102-citibike-tripdata.csv.zip'], \n", 315 | " download_base_url=state_dict['download_base_url'],\n", 316 | " use_prestaged=False)\n", 317 | "\n", 318 | "session.close()" 319 | ] 320 | }, 321 | { 322 | "cell_type": "markdown", 323 | "metadata": {}, 324 | "source": [ 325 | "These load functions will default to loading from the public citibike data set. However, we may want to be able to specify files already pre-downloaded into a different S3 bucket. The functions assume the files are in gzip format in that bucket." 
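After any of these incremental loads, a quick sanity check on the trips table confirms that the new month arrived. A minimal sketch using the table name stored in the state dictionary.

from snowflake.snowpark import functions as F
from dags.snowpark_connection import snowpark_connect

session, state_dict = snowpark_connect('./include/state.json')

trips_df = session.table(state_dict['trips_table_name'])
print('rows in trips table:   ', trips_df.count())
print('latest trip start time:', trips_df.select(F.max('STARTTIME')).collect()[0][0])

session.close()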
326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "%%time\n", 335 | "from dags.ingest import incremental_elt\n", 336 | "from dags.elt import reset_database\n", 337 | "from dags.snowpark_connection import snowpark_connect\n", 338 | "\n", 339 | "session, state_dict = snowpark_connect('./include/state.json')\n", 340 | "\n", 341 | "session.use_warehouse(state_dict['compute_parameters']['fe_warehouse'])\n", 342 | "\n", 343 | "reset_database(session=session, state_dict=state_dict, prestaged=True)\n", 344 | "\n", 345 | "incremental_elt(session=session, \n", 346 | " state_dict=state_dict, \n", 347 | " files_to_ingest=['202001-citibike-tripdata.csv.zip', '202102-citibike-tripdata.csv.zip'],\n", 348 | " download_base_url=state_dict['connection_parameters']['download_base_url'],\n", 349 | " use_prestaged=True)\n", 350 | "session.close()" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": {}, 356 | "source": [ 357 | "We could also bulk load the entire historical dataset using the following. This takes at least 30min depending on network speed to your local system. See below for an alternative." 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "# %%time\n", 367 | "# from dags.ingest import bulk_elt\n", 368 | "# from dags.elt import reset_database\n", 369 | "# from dags.snowpark_connection import snowpark_connect\n", 370 | "\n", 371 | "# session, state_dict = snowpark_connect('./include/state.json')\n", 372 | "\n", 373 | "# session.use_warehouse(state_dict['compute_parameters']['fe_warehouse'])\n", 374 | "\n", 375 | "# reset_database(session=session, state_dict=state_dict, prestaged=False)\n", 376 | "\n", 377 | "# bulk_elt(session=session, \n", 378 | "# state_dict=state_dict, \n", 379 | "# use_prestaged=False, \n", 380 | "# download_base_url='https://s3.amazonaws.com/tripdata/')\n", 381 | "# session.close()" 382 | ] 383 | }, 384 | { 385 | "cell_type": "markdown", 386 | "metadata": {}, 387 | "source": [ 388 | "For the hands-on-lab we will bulk load from a different S3 bucket where the files are already in gzip format (see below). \n", 389 | "\n", 390 | "For this project we are going back in time and pretending it is January 2020 (so that we can experience the effect of data drift during COVID lockdown). So this bulk load ingests from an existing bucket with data from June 2013 to January 2020." 
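Before kicking off the bulk load it can help to preview exactly which monthly files will be requested. This sketch mirrors the file-name construction in dags/ingest.py's bulk_elt for the date window it covers.

import pandas as pd
from datetime import datetime

# Older files use the '-citibike-tripdata.zip' suffix, newer ones '.csv.zip'
range1 = pd.period_range(start=datetime.strptime('201306', '%Y%m'),
                         end=datetime.strptime('201612', '%Y%m'),
                         freq='M').strftime('%Y%m')
range2 = pd.period_range(start=datetime.strptime('201701', '%Y%m'),
                         end=datetime.strptime('201912', '%Y%m'),
                         freq='M').strftime('%Y%m')

files = [d + '-citibike-tripdata.zip' for d in range1.to_list()] + \
        [d + '-citibike-tripdata.csv.zip' for d in range2.to_list()]

print(len(files), 'files, from', files[0], 'to', files[-1])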
391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "%%time\n", 400 | "from dags.ingest import bulk_elt\n", 401 | "from dags.elt import reset_database\n", 402 | "from dags.snowpark_connection import snowpark_connect\n", 403 | "\n", 404 | "session, state_dict = snowpark_connect()\n", 405 | "\n", 406 | "state_dict.update({'load_table_name': 'RAW_',\n", 407 | " 'trips_table_name': 'TRIPS',\n", 408 | " 'load_stage_name': 'LOAD_STAGE'\n", 409 | " })\n", 410 | "import json\n", 411 | "with open('./include/state.json', 'w') as sdf:\n", 412 | " json.dump(state_dict, sdf)\n", 413 | "\n", 414 | "reset_database(session=session, state_dict=state_dict, prestaged=True)\n", 415 | "\n", 416 | "session.use_warehouse(state_dict['compute_parameters']['fe_warehouse'])\n", 417 | "\n", 418 | "bulk_elt(session=session, \n", 419 | " state_dict=state_dict, \n", 420 | " download_base_url=state_dict['connection_parameters']['download_base_url'],\n", 421 | " use_prestaged=True)" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "session.table(state_dict['trips_table_name']).count()" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [ 439 | "session.close()" 440 | ] 441 | }, 442 | { 443 | "cell_type": "markdown", 444 | "metadata": {}, 445 | "source": [ 446 | "Without the need to download locally we ingested ~90 million records in about 30 seconds." 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": null, 452 | "metadata": {}, 453 | "outputs": [], 454 | "source": [] 455 | } 456 | ], 457 | "metadata": { 458 | "authors": [ 459 | { 460 | "name": "cforbe" 461 | } 462 | ], 463 | "kernelspec": { 464 | "display_name": "snowpark_0110:Python", 465 | "language": "python", 466 | "name": "conda-env-snowpark_0110-py" 467 | }, 468 | "language_info": { 469 | "codemirror_mode": { 470 | "name": "ipython", 471 | "version": 3 472 | }, 473 | "file_extension": ".py", 474 | "mimetype": "text/x-python", 475 | "name": "python", 476 | "nbconvert_exporter": "python", 477 | "pygments_lexer": "ipython3", 478 | "version": "3.8.13" 479 | }, 480 | "msauthor": "trbye" 481 | }, 482 | "nbformat": 4, 483 | "nbformat_minor": 4 484 | } 485 | -------------------------------------------------------------------------------- /05_Airflow_Pipeline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2e21e38c", 6 | "metadata": {}, 7 | "source": [ 8 | "## Apache Airflow (OPTIONAL) \n", 9 | "\n", 10 | "In this section of the hands-on-lab, we will utilize Snowpark's Python client-side Dataframe API as well as the Snowpark server-side runtime and Apache Airflow to create an operational pipeline. We will take the functions created by the ML Ops team and create a directed acyclic graph (DAG) of operations to run each month when new data is available. \n", 11 | "\n", 12 | "Note: This code requires the ability to run docker containers locally. If you do not have Docker Desktop you can run the same pipeline from a python kernel via the 04_ML_Ops.ipynb notebook." 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "0c266ce3", 18 | "metadata": {}, 19 | "source": [ 20 | "We will use the dev CLI from Astronomer. 
https://docs.astronomer.io/astro/cli/get-started#step-1-install-the-astro-cli\n", 21 | "\n", 22 | "Follow the instructions to install the `astro` CLI for your particular local setup." 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "ad3fbdcb", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "%%writefile dags/airflow_tasks.py\n", 33 | "\n", 34 | "from airflow.decorators import task\n", 35 | "\n", 36 | "@task.virtualenv(python_version=3.8)\n", 37 | "def snowpark_database_setup(state_dict:dict)-> dict: \n", 38 | " import snowflake.snowpark.functions as F\n", 39 | " from dags.snowpark_connection import snowpark_connect\n", 40 | " from dags.elt import reset_database\n", 41 | "\n", 42 | " session, _ = snowpark_connect('./include/state.json')\n", 43 | " reset_database(session=session, state_dict=state_dict, prestaged=True)\n", 44 | "\n", 45 | " _ = session.sql('CREATE STAGE '+state_dict['model_stage_name']).collect()\n", 46 | " _ = session.sql('CREATE TAG model_id_tag').collect()\n", 47 | "\n", 48 | " session.close()\n", 49 | "\n", 50 | " return state_dict\n", 51 | "\n", 52 | "@task.virtualenv(python_version=3.8)\n", 53 | "def incremental_elt_task(state_dict: dict, files_to_download:list)-> dict:\n", 54 | " from dags.ingest import incremental_elt\n", 55 | " from dags.snowpark_connection import snowpark_connect\n", 56 | "\n", 57 | " session, _ = snowpark_connect()\n", 58 | "\n", 59 | " print('Ingesting '+str(files_to_download))\n", 60 | " download_base_url=state_dict['connection_parameters']['download_base_url']\n", 61 | "\n", 62 | " _ = session.use_warehouse(state_dict['compute_parameters']['load_warehouse'])\n", 63 | "\n", 64 | " _ = incremental_elt(session=session, \n", 65 | " state_dict=state_dict, \n", 66 | " files_to_ingest=files_to_download,\n", 67 | " download_base_url=download_base_url,\n", 68 | " use_prestaged=True)\n", 69 | "\n", 70 | " #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['load_warehouse']+\\\n", 71 | " # ' SUSPEND').collect()\n", 72 | "\n", 73 | " session.close()\n", 74 | " return state_dict\n", 75 | "\n", 76 | "@task.virtualenv(python_version=3.8)\n", 77 | "def initial_bulk_load_task(state_dict:dict)-> dict:\n", 78 | " from dags.ingest import bulk_elt\n", 79 | " from dags.snowpark_connection import snowpark_connect\n", 80 | "\n", 81 | " session, _ = snowpark_connect()\n", 82 | "\n", 83 | " _ = session.use_warehouse(state_dict['compute_parameters']['load_warehouse'])\n", 84 | "\n", 85 | " print('Running initial bulk ingest from '+state_dict['connection_parameters']['download_base_url'])\n", 86 | " \n", 87 | " _ = bulk_elt(session=session, \n", 88 | " state_dict=state_dict, \n", 89 | " download_base_url=state_dict['connection_parameters']['download_base_url'],\n", 90 | " use_prestaged=True)\n", 91 | "\n", 92 | " #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['load_warehouse']+\\\n", 93 | " # ' SUSPEND').collect()\n", 94 | "\n", 95 | " session.close()\n", 96 | " return state_dict\n", 97 | "\n", 98 | "@task.virtualenv(python_version=3.8)\n", 99 | "def materialize_holiday_task(state_dict: dict)-> dict:\n", 100 | " from dags.snowpark_connection import snowpark_connect\n", 101 | " from dags.mlops_pipeline import materialize_holiday_table\n", 102 | "\n", 103 | " print('Materializing holiday table.')\n", 104 | " session, _ = snowpark_connect()\n", 105 | "\n", 106 | " _ = materialize_holiday_table(session=session, \n", 107 | " 
holiday_table_name=state_dict['holiday_table_name'])\n", 108 | "\n", 109 | " session.close()\n", 110 | " return state_dict\n", 111 | "\n", 112 | "@task.virtualenv(python_version=3.8)\n", 113 | "def subscribe_to_weather_data_task(state_dict: dict)-> dict:\n", 114 | " from dags.snowpark_connection import snowpark_connect\n", 115 | " from dags.mlops_pipeline import subscribe_to_weather_data\n", 116 | "\n", 117 | " print('Subscribing to weather data')\n", 118 | " session, _ = snowpark_connect()\n", 119 | "\n", 120 | " _ = subscribe_to_weather_data(session=session, \n", 121 | " weather_database_name=state_dict['weather_database_name'], \n", 122 | " weather_listing_id=state_dict['weather_listing_id'])\n", 123 | " session.close()\n", 124 | " return state_dict\n", 125 | "\n", 126 | "@task.virtualenv(python_version=3.8)\n", 127 | "def create_weather_view_task(state_dict: dict)-> dict:\n", 128 | " from dags.snowpark_connection import snowpark_connect\n", 129 | " from dags.mlops_pipeline import create_weather_view\n", 130 | "\n", 131 | " print('Creating weather view')\n", 132 | " session, _ = snowpark_connect()\n", 133 | "\n", 134 | " _ = create_weather_view(session=session,\n", 135 | " weather_table_name=state_dict['weather_table_name'],\n", 136 | " weather_view_name=state_dict['weather_view_name'])\n", 137 | " session.close()\n", 138 | " return state_dict\n", 139 | " \n", 140 | "@task.virtualenv(python_version=3.8)\n", 141 | "def deploy_model_udf_task(state_dict:dict)-> dict:\n", 142 | " from dags.snowpark_connection import snowpark_connect\n", 143 | " from dags.mlops_pipeline import deploy_pred_train_udf\n", 144 | "\n", 145 | " print('Deploying station train/predict UDF')\n", 146 | " session, _ = snowpark_connect()\n", 147 | "\n", 148 | " _ = session.sql('CREATE STAGE IF NOT EXISTS ' + state_dict['model_stage_name']).collect()\n", 149 | "\n", 150 | " _ = deploy_pred_train_udf(session=session, \n", 151 | " udf_name=state_dict['train_udf_name'],\n", 152 | " function_name=state_dict['train_func_name'],\n", 153 | " model_stage_name=state_dict['model_stage_name'])\n", 154 | " session.close()\n", 155 | " return state_dict\n", 156 | "\n", 157 | "@task.virtualenv(python_version=3.8)\n", 158 | "def deploy_eval_udf_task(state_dict:dict)-> dict:\n", 159 | " from dags.snowpark_connection import snowpark_connect\n", 160 | " from dags.mlops_pipeline import deploy_eval_udf\n", 161 | "\n", 162 | " print('Deploying model evaluation UDF')\n", 163 | " session, _ = snowpark_connect()\n", 164 | "\n", 165 | " _ = session.sql('CREATE STAGE IF NOT EXISTS ' + state_dict['model_stage_name']).collect()\n", 166 | "\n", 167 | " _ = deploy_eval_udf(session=session, \n", 168 | " udf_name=state_dict['eval_udf_name'],\n", 169 | " function_name=state_dict['eval_func_name'],\n", 170 | " model_stage_name=state_dict['model_stage_name'])\n", 171 | " session.close()\n", 172 | " return state_dict\n", 173 | "\n", 174 | "@task.virtualenv(python_version=3.8)\n", 175 | "def generate_feature_table_task(state_dict:dict, \n", 176 | " holiday_state_dict:dict, \n", 177 | " weather_state_dict:dict)-> dict:\n", 178 | " from dags.snowpark_connection import snowpark_connect\n", 179 | " from dags.mlops_pipeline import create_feature_table\n", 180 | "\n", 181 | " print('Generating features for all stations.')\n", 182 | " session, _ = snowpark_connect()\n", 183 | "\n", 184 | " session.use_warehouse(state_dict['compute_parameters']['fe_warehouse'])\n", 185 | "\n", 186 | " _ = session.sql(\"CREATE OR REPLACE TABLE \"+state_dict['clone_table_name']+\\\n", 187 | " \" CLONE 
\"+state_dict['trips_table_name']).collect()\n", 188 | " _ = session.sql(\"ALTER TABLE \"+state_dict['clone_table_name']+\\\n", 189 | " \" SET TAG model_id_tag = '\"+state_dict['model_id']+\"'\").collect()\n", 190 | "\n", 191 | " _ = create_feature_table(session, \n", 192 | " trips_table_name=state_dict['clone_table_name'], \n", 193 | " holiday_table_name=state_dict['holiday_table_name'], \n", 194 | " weather_view_name=state_dict['weather_view_name'],\n", 195 | " feature_table_name=state_dict['feature_table_name'])\n", 196 | "\n", 197 | " _ = session.sql(\"ALTER TABLE \"+state_dict['feature_table_name']+\\\n", 198 | " \" SET TAG model_id_tag = '\"+state_dict['model_id']+\"'\").collect()\n", 199 | "\n", 200 | " session.close()\n", 201 | " return state_dict\n", 202 | "\n", 203 | "@task.virtualenv(python_version=3.8)\n", 204 | "def generate_forecast_table_task(state_dict:dict, \n", 205 | " holiday_state_dict:dict, \n", 206 | " weather_state_dict:dict)-> dict: \n", 207 | " from dags.snowpark_connection import snowpark_connect\n", 208 | " from dags.mlops_pipeline import create_forecast_table\n", 209 | "\n", 210 | " print('Generating forecast features.')\n", 211 | " session, _ = snowpark_connect()\n", 212 | "\n", 213 | " _ = create_forecast_table(session, \n", 214 | " trips_table_name=state_dict['trips_table_name'],\n", 215 | " holiday_table_name=state_dict['holiday_table_name'], \n", 216 | " weather_view_name=state_dict['weather_view_name'], \n", 217 | " forecast_table_name=state_dict['forecast_table_name'],\n", 218 | " steps=state_dict['forecast_steps'])\n", 219 | "\n", 220 | " _ = session.sql(\"ALTER TABLE \"+state_dict['forecast_table_name']+\\\n", 221 | " \" SET TAG model_id_tag = '\"+state_dict['model_id']+\"'\").collect()\n", 222 | "\n", 223 | " session.close()\n", 224 | " return state_dict\n", 225 | "\n", 226 | "@task.virtualenv(python_version=3.8)\n", 227 | "def bulk_train_predict_task(state_dict:dict, \n", 228 | " feature_state_dict:dict, \n", 229 | " forecast_state_dict:dict)-> dict: \n", 230 | " from dags.snowpark_connection import snowpark_connect\n", 231 | " from dags.mlops_pipeline import train_predict\n", 232 | "\n", 233 | " state_dict = feature_state_dict\n", 234 | "\n", 235 | " print('Running bulk training and forecast.')\n", 236 | " session, _ = snowpark_connect()\n", 237 | "\n", 238 | " session.use_warehouse(state_dict['compute_parameters']['train_warehouse'])\n", 239 | "\n", 240 | " pred_table_name = train_predict(session, \n", 241 | " station_train_pred_udf_name=state_dict['train_udf_name'], \n", 242 | " feature_table_name=state_dict['feature_table_name'], \n", 243 | " forecast_table_name=state_dict['forecast_table_name'],\n", 244 | " pred_table_name=state_dict['pred_table_name'])\n", 245 | "\n", 246 | " _ = session.sql(\"ALTER TABLE \"+state_dict['pred_table_name']+\\\n", 247 | " \" SET TAG model_id_tag = '\"+state_dict['model_id']+\"'\").collect()\n", 248 | " #_ = session.sql('ALTER WAREHOUSE IF EXISTS '+state_dict['compute_parameters']['train_warehouse']+\\\n", 249 | " # ' SUSPEND').collect()\n", 250 | "\n", 251 | " session.close()\n", 252 | " return state_dict\n", 253 | "\n", 254 | "@task.virtualenv(python_version=3.8)\n", 255 | "def eval_station_models_task(state_dict:dict, \n", 256 | " pred_state_dict:dict,\n", 257 | " run_date:str)-> dict:\n", 258 | "\n", 259 | " from dags.snowpark_connection import snowpark_connect\n", 260 | " from dags.mlops_pipeline import evaluate_station_model\n", 261 | "\n", 262 | " print('Running eval UDF for model output')\n", 263 | " 
session, _ = snowpark_connect()\n", 264 | "\n", 265 | " eval_table_name = evaluate_station_model(session, \n", 266 | " run_date=run_date, \n", 267 | " eval_model_udf_name=state_dict['eval_udf_name'], \n", 268 | " pred_table_name=state_dict['pred_table_name'], \n", 269 | " eval_table_name=state_dict['eval_table_name'])\n", 270 | "\n", 271 | " _ = session.sql(\"ALTER TABLE \"+state_dict['eval_table_name']+\\\n", 272 | " \" SET TAG model_id_tag = '\"+state_dict['model_id']+\"'\").collect()\n", 273 | " session.close()\n", 274 | " return state_dict \n", 275 | "\n", 276 | "@task.virtualenv(python_version=3.8)\n", 277 | "def flatten_tables_task(pred_state_dict:dict, state_dict:dict)-> dict:\n", 278 | " from dags.snowpark_connection import snowpark_connect\n", 279 | " from dags.mlops_pipeline import flatten_tables\n", 280 | "\n", 281 | " print('Flattening tables for end-user consumption.')\n", 282 | " session, _ = snowpark_connect()\n", 283 | "\n", 284 | " flat_pred_table, flat_forecast_table, flat_eval_table = flatten_tables(session,\n", 285 | " pred_table_name=state_dict['pred_table_name'], \n", 286 | " forecast_table_name=state_dict['forecast_table_name'], \n", 287 | " eval_table_name=state_dict['eval_table_name'])\n", 288 | " state_dict['flat_pred_table'] = flat_pred_table\n", 289 | " state_dict['flat_forecast_table'] = flat_forecast_table\n", 290 | " state_dict['flat_eval_table'] = flat_eval_table\n", 291 | "\n", 292 | " _ = session.sql(\"ALTER TABLE \"+flat_pred_table+\" SET TAG model_id_tag = '\"+state_dict['model_id']+\"'\").collect()\n", 293 | " _ = session.sql(\"ALTER TABLE \"+flat_forecast_table+\" SET TAG model_id_tag = '\"+state_dict['model_id']+\"'\").collect()\n", 294 | " _ = session.sql(\"ALTER TABLE \"+flat_eval_table+\" SET TAG model_id_tag = '\"+state_dict['model_id']+\"'\").collect()\n", 295 | "\n", 296 | " return state_dict\n" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "id": "5a42735e", 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "%%writefile dags/airflow_setup_pipeline.py\n", 307 | "\n", 308 | "from datetime import datetime, timedelta\n", 309 | "\n", 310 | "from airflow.decorators import dag, task\n", 311 | "from dags.airflow_tasks import snowpark_database_setup\n", 312 | "from dags.airflow_tasks import incremental_elt_task\n", 313 | "from dags.airflow_tasks import initial_bulk_load_task\n", 314 | "from dags.airflow_tasks import materialize_holiday_task\n", 315 | "from dags.airflow_tasks import subscribe_to_weather_data_task\n", 316 | "from dags.airflow_tasks import create_weather_view_task\n", 317 | "from dags.airflow_tasks import deploy_model_udf_task\n", 318 | "from dags.airflow_tasks import deploy_eval_udf_task\n", 319 | "from dags.airflow_tasks import generate_feature_table_task\n", 320 | "from dags.airflow_tasks import generate_forecast_table_task\n", 321 | "from dags.airflow_tasks import bulk_train_predict_task\n", 322 | "from dags.airflow_tasks import eval_station_models_task \n", 323 | "from dags.airflow_tasks import flatten_tables_task\n", 324 | "\n", 325 | "default_args = {\n", 326 | " 'owner': 'airflow',\n", 327 | " 'depends_on_past': False,\n", 328 | " 'email_on_failure': False,\n", 329 | " 'email_on_retry': False,\n", 330 | " 'retries': 1,\n", 331 | " 'retry_delay': timedelta(minutes=5)\n", 332 | "}\n", 333 | "\n", 334 | "#local_airflow_path = '/usr/local/airflow/'\n", 335 | "\n", 336 | "@dag(default_args=default_args, schedule_interval=None, start_date=datetime(2020, 3, 1), catchup=False, 
tags=['setup'])\n", 337 | "def citibikeml_setup_taskflow(run_date:str):\n", 338 | " \"\"\"\n", 339 | " Setup initial Snowpark / Astronomer ML Demo\n", 340 | " \"\"\"\n", 341 | " import uuid\n", 342 | " import json\n", 343 | " \n", 344 | " with open('./include/state.json') as sdf:\n", 345 | " state_dict = json.load(sdf)\n", 346 | " \n", 347 | " model_id = str(uuid.uuid1()).replace('-', '_')\n", 348 | "\n", 349 | " state_dict.update({'model_id': model_id})\n", 350 | " state_dict.update({'run_date': run_date})\n", 351 | " state_dict.update({'weather_database_name': 'WEATHER_NYC'})\n", 352 | " state_dict.update({'load_table_name': 'RAW_',\n", 353 | " 'trips_table_name': 'TRIPS',\n", 354 | " 'load_stage_name': 'LOAD_STAGE',\n", 355 | " 'model_stage_name': 'MODEL_STAGE',\n", 356 | " 'weather_table_name': state_dict['weather_database_name']+'.ONPOINT_ID.HISTORY_DAY',\n", 357 | " 'weather_view_name': 'WEATHER_NYC_VW',\n", 358 | " 'holiday_table_name': 'HOLIDAYS',\n", 359 | " 'clone_table_name': 'CLONE_'+model_id,\n", 360 | " 'feature_table_name' : 'FEATURE_'+model_id,\n", 361 | " 'pred_table_name': 'PRED_'+model_id,\n", 362 | " 'eval_table_name': 'EVAL_'+model_id,\n", 363 | " 'forecast_table_name': 'FORECAST_'+model_id,\n", 364 | " 'forecast_steps': 30,\n", 365 | " 'train_udf_name': 'station_train_predict_udf',\n", 366 | " 'train_func_name': 'station_train_predict_func',\n", 367 | " 'eval_udf_name': 'eval_model_output_udf',\n", 368 | " 'eval_func_name': 'eval_model_func'\n", 369 | " })\n", 370 | " \n", 371 | " #Task order - one-time setup\n", 372 | " setup_state_dict = snowpark_database_setup(state_dict)\n", 373 | " load_state_dict = initial_bulk_load_task(setup_state_dict)\n", 374 | " holiday_state_dict = materialize_holiday_task(setup_state_dict)\n", 375 | " subscribe_state_dict = subscribe_to_weather_data_task(setup_state_dict)\n", 376 | " weather_state_dict = create_weather_view_task(subscribe_state_dict)\n", 377 | " model_udf_state_dict = deploy_model_udf_task(setup_state_dict)\n", 378 | " eval_udf_state_dict = deploy_eval_udf_task(setup_state_dict)\n", 379 | " feature_state_dict = generate_feature_table_task(load_state_dict, holiday_state_dict, weather_state_dict) \n", 380 | " forecast_state_dict = generate_forecast_table_task(load_state_dict, holiday_state_dict, weather_state_dict)\n", 381 | " pred_state_dict = bulk_train_predict_task(model_udf_state_dict, feature_state_dict, forecast_state_dict)\n", 382 | " eval_state_dict = eval_station_models_task(eval_udf_state_dict, pred_state_dict, run_date) \n", 383 | " state_dict = flatten_tables_task(pred_state_dict, eval_state_dict)\n", 384 | "\n", 385 | " return state_dict\n", 386 | "\n", 387 | "run_date='2020_01_01'\n", 388 | "\n", 389 | "state_dict = citibikeml_setup_taskflow(run_date=run_date)\n" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "id": "29fdd741", 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "%%writefile dags/airflow_incremental_pipeline.py\n", 400 | "\n", 401 | "from datetime import datetime, timedelta\n", 402 | "\n", 403 | "from airflow.decorators import dag, task\n", 404 | "from dags.airflow_tasks import snowpark_database_setup\n", 405 | "from dags.airflow_tasks import incremental_elt_task\n", 406 | "from dags.airflow_tasks import initial_bulk_load_task\n", 407 | "from dags.airflow_tasks import materialize_holiday_task\n", 408 | "from dags.airflow_tasks import deploy_model_udf_task\n", 409 | "from dags.airflow_tasks import deploy_eval_udf_task\n", 410 | "from 
dags.airflow_tasks import generate_feature_table_task\n", 411 | "from dags.airflow_tasks import generate_forecast_table_task\n", 412 | "from dags.airflow_tasks import bulk_train_predict_task\n", 413 | "from dags.airflow_tasks import eval_station_models_task \n", 414 | "from dags.airflow_tasks import flatten_tables_task\n", 415 | "\n", 416 | "default_args = {\n", 417 | " 'owner': 'airflow',\n", 418 | " 'depends_on_past': False,\n", 419 | " 'email_on_failure': False,\n", 420 | " 'email_on_retry': False,\n", 421 | " 'retries': 1,\n", 422 | " 'retry_delay': timedelta(minutes=5)\n", 423 | "}\n", 424 | "\n", 425 | "#local_airflow_path = '/usr/local/airflow/'\n", 426 | "\n", 427 | "@dag(default_args=default_args, schedule_interval=None, start_date=datetime(2020, 4, 1), catchup=False, tags=['monthly'])\n", 428 | "def citibikeml_monthly_taskflow(files_to_download:list, run_date:str):\n", 429 | " \"\"\"\n", 430 | " End to end Snowpark / Astronomer ML Demo\n", 431 | " \"\"\"\n", 432 | " import uuid\n", 433 | " import json\n", 434 | " \n", 435 | " with open('./include/state.json') as sdf:\n", 436 | " state_dict = json.load(sdf)\n", 437 | " \n", 438 | " model_id = str(uuid.uuid1()).replace('-', '_')\n", 439 | "\n", 440 | " state_dict.update({'model_id': model_id})\n", 441 | " state_dict.update({'run_date': run_date})\n", 442 | " state_dict.update({'weather_database_name': 'WEATHER_NYC'})\n", 443 | " state_dict.update({'load_table_name': 'RAW_',\n", 444 | " 'trips_table_name': 'TRIPS',\n", 445 | " 'load_stage_name': 'LOAD_STAGE',\n", 446 | " 'model_stage_name': 'MODEL_STAGE',\n", 447 | " 'weather_table_name': state_dict['weather_database_name']+'.ONPOINT_ID.HISTORY_DAY',\n", 448 | " 'weather_view_name': 'WEATHER_NYC_VW',\n", 449 | " 'holiday_table_name': 'HOLIDAYS',\n", 450 | " 'clone_table_name': 'CLONE_'+model_id,\n", 451 | " 'feature_table_name' : 'FEATURE_'+model_id,\n", 452 | " 'pred_table_name': 'PRED_'+model_id,\n", 453 | " 'eval_table_name': 'EVAL_'+model_id,\n", 454 | " 'forecast_table_name': 'FORECAST_'+model_id,\n", 455 | " 'forecast_steps': 30,\n", 456 | " 'train_udf_name': 'station_train_predict_udf',\n", 457 | " 'train_func_name': 'station_train_predict_func',\n", 458 | " 'eval_udf_name': 'eval_model_output_udf',\n", 459 | " 'eval_func_name': 'eval_model_func'\n", 460 | " })\n", 461 | "\n", 462 | " incr_state_dict = incremental_elt_task(state_dict, files_to_download)\n", 463 | " feature_state_dict = generate_feature_table_task(incr_state_dict, incr_state_dict, incr_state_dict) \n", 464 | " forecast_state_dict = generate_forecast_table_task(incr_state_dict, incr_state_dict, incr_state_dict)\n", 465 | " pred_state_dict = bulk_train_predict_task(feature_state_dict, feature_state_dict, forecast_state_dict)\n", 466 | " eval_state_dict = eval_station_models_task(pred_state_dict, pred_state_dict, run_date)\n", 467 | " state_dict = flatten_tables_task(pred_state_dict, eval_state_dict)\n", 468 | "\n", 469 | " return state_dict\n", 470 | "\n", 471 | "run_date='2020_02_01'\n", 472 | "files_to_download = ['202001-citibike-tripdata.csv.zip']\n", 473 | "\n", 474 | "state_dict = citibikeml_monthly_taskflow(files_to_download=files_to_download, \n", 475 | " run_date=run_date)\n" 476 | ] 477 | }, 478 | { 479 | "cell_type": "markdown", 480 | "id": "b183b8fa", 481 | "metadata": {}, 482 | "source": [ 483 | "Now open a new browser tab to the Airflow UI at http://localhost:8080, or run the next cell to open it automatically." 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "id": "5961a58b", 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [ 493 | "import webbrowser\n", 494 | "\n", 495 | "# Open the Airflow UI in the default browser\n", 496 | "url = 'http://localhost:8080'\n", 497 | "webbrowser.open(url)" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "id": "dcd1ebc5", 503 | "metadata": {}, 504 | "source": [ 505 | "Let's run the initial setup, ingest, and forecast DAG." 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "id": "4f1e1041", 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "# # This sample code can be used to trigger the Airflow pipeline from a command-line shell.\n", 516 | "# !curl -X POST 'http://localhost:8080/api/v1/dags/citibikeml_monthly_taskflow/dagRuns' \\\n", 517 | "# -H 'Content-Type: application/json' \\\n", 518 | "# --user \"admin:admin\" \\\n", 519 | "# -d '{\"conf\": {\"files_to_download\": [\"202003-citibike-tripdata.csv.zip\"], \"run_date\": \"2020_04_01\"}}'" 520 | ] 521 | }, 522 | { 523 | "cell_type": "markdown", 524 | "id": "fb0d7c97", 525 | "metadata": {}, 526 | "source": [ 527 | "Alternatively, we can trigger and monitor the DAG runs from Python via Airflow's REST API." 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": null, 533 | "id": "b84288d7", 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "import requests\n", 538 | "from requests.auth import HTTPBasicAuth\n", 539 | "import time \n", 540 | "import json\n", 541 | "\n", 542 | "dag_url='http://localhost:8080/api/v1/dags/citibikeml_setup_taskflow/dagRuns'\n", 543 | "json_payload = {\"conf\": {\"run_date\": \"2020_01_01\"}}\n", 544 | "\n", 545 | "response = requests.post(dag_url, \n", 546 | " json=json_payload,\n", 547 | " auth = HTTPBasicAuth('admin', 'admin'))\n", 548 | "\n", 549 | "run_id = json.loads(response.text)['dag_run_id']\n", 550 | "\n", 551 | "state=json.loads(requests.get(dag_url+'/'+run_id, auth=HTTPBasicAuth('admin', 'admin')).text)['state']\n", 552 | "\n", 553 | "while state not in ('success', 'failed'):\n", 554 | " print('DAG running...'+state)\n", 555 | " time.sleep(10)\n", 556 | " state=json.loads(requests.get(dag_url+'/'+run_id, auth=HTTPBasicAuth('admin', 'admin')).text)['state']" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": null, 562 | "id": "d0f3dbec", 563 | "metadata": { 564 | "pycharm": { 565 | "name": "#%%\n" 566 | } 567 | }, 568 | "outputs": [], 569 | "source": [ 570 | "import requests\n", 571 | "from requests.auth import HTTPBasicAuth\n", 572 | "import time \n", 573 | "import json\n", 574 | "\n", 575 | "dag_url='http://localhost:8080/api/v1/dags/citibikeml_monthly_taskflow/dagRuns'\n", 576 | "json_payload = {\"conf\": {\"files_to_download\": [\"202001-citibike-tripdata.csv.zip\"], \"run_date\": \"2020_02_01\"}}\n", 577 | "\n", 578 | "response = requests.post(dag_url, \n", 579 | " json=json_payload,\n", 580 | " auth = HTTPBasicAuth('admin', 'admin'))\n", 581 | "\n", 582 | "run_id = json.loads(response.text)['dag_run_id']\n", 583 | "\n", 584 | "state=json.loads(requests.get(dag_url+'/'+run_id, auth=HTTPBasicAuth('admin', 'admin')).text)['state']\n", 585 | "\n", 586 | "while state not in ('success', 'failed'):\n", 587 | " print('DAG running...'+state)\n", 588 | " time.sleep(10)\n", 589 | " state=json.loads(requests.get(dag_url+'/'+run_id, auth=HTTPBasicAuth('admin', 'admin')).text)['state']" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": null, 595 | "id": "ad097d86", 596 | "metadata": {}, 597 | "outputs": [], 598 | "source": [] 599 | } 600 | ], 601 | "metadata": { 602 | "kernelspec": { 603 | "display_name": "Python 3 (ipykernel)", 604 | 
"language": "python", 605 | "name": "python3" 606 | }, 607 | "language_info": { 608 | "codemirror_mode": { 609 | "name": "ipython", 610 | "version": 3 611 | }, 612 | "file_extension": ".py", 613 | "mimetype": "text/x-python", 614 | "name": "python", 615 | "nbconvert_exporter": "python", 616 | "pygments_lexer": "ipython3", 617 | "version": "3.8.12" 618 | } 619 | }, 620 | "nbformat": 4, 621 | "nbformat_minor": 5 622 | } 623 | --------------------------------------------------------------------------------