├── README.md
├── dbt_duckdb
│   ├── .gitignore
│   ├── Makefile
│   ├── README.md
│   ├── dutch_railway_network
│   │   ├── .sqlfluff
│   │   ├── analyses
│   │   │   ├── build_charts.py
│   │   │   └── time_based_functions
│   │   │       ├── hopping_window.py
│   │   │       ├── session_window.py
│   │   │       ├── sliding_window.py
│   │   │       ├── tumbling_window.py
│   │   │       └── utils.py
│   │   ├── data
│   │   │   ├── .gitkeep
│   │   │   └── exports
│   │   │       └── .gitkeep
│   │   ├── dbt_project.yml
│   │   ├── macros
│   │   │   └── common_columns.sql
│   │   ├── models
│   │   │   ├── exports
│   │   │   │   ├── export_province_geojson.sql
│   │   │   │   ├── export_train_services_agg.sql
│   │   │   │   └── schema.yml
│   │   │   ├── reverse_etl
│   │   │   │   ├── rep_dim_nl_municipalities.sql
│   │   │   │   ├── rep_dim_nl_provinces.sql
│   │   │   │   ├── rep_dim_nl_train_stations.sql
│   │   │   │   ├── rep_fact_train_services_daily_agg.sql
│   │   │   │   └── schema.yml
│   │   │   ├── sources.yml
│   │   │   └── transformation
│   │   │       ├── ams_traffic_v.sql
│   │   │       ├── dim_nl_municipalities.sql
│   │   │       ├── dim_nl_provinces.sql
│   │   │       ├── dim_nl_train_stations.sql
│   │   │       ├── fact_services.sql
│   │   │       └── schema.yml
│   │   ├── package-lock.yml
│   │   ├── packages.yml
│   │   ├── profiles.yml
│   │   ├── seeds
│   │   │   └── gemeente_2025.geojson
│   │   ├── snapshots
│   │   │   └── .gitkeep
│   │   └── tests
│   │       ├── test_province_municipality_relation.sql
│   │       └── test_rep_fact_services.sql
│   └── requirements.txt
├── duckdb_streamlit
│   ├── .gitignore
│   ├── Makefile
│   ├── README.md
│   ├── app.py
│   ├── constants.py
│   ├── pages
│   │   ├── closest_train_stations.py
│   │   └── railway_network_utilization.py
│   ├── requirements.txt
│   └── utils.py
├── guides
│   └── DuckDB_in_Jupyter_notebooks.ipynb
└── scikit_learn_duckdb
    ├── .gitignore
    ├── Makefile
    ├── README.md
    ├── model
    │   └── .gitkeep
    ├── predict_penguin_species.py
    └── requirements.txt
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 |
3 | In this repository we store the code we showcase in blog posts.
4 | This code is intended for learning purposes and might not be maintained or updated with new releases of the libraries used.
5 |
6 | # duckdb_streamlit
7 |
8 | The application behind the [Using DuckDB in Streamlit](https://duckdb.org/2025/03/28/using-duckdb-in-streamlit.html) blog post.
9 |
10 | # dbt_duckdb
11 |
12 | The application behind [Fully Local Data Transformation with dbt and DuckDB](https://duckdb.org/2025/04/04/dbt-duckdb.html).
13 |
--------------------------------------------------------------------------------
/dbt_duckdb/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | dutch_railway_network/target/
3 | dutch_railway_network/dbt_packages/
4 | dutch_railway_network/logs/
5 | dutch_railway_network/data/dutch_railway_network.duckdb
6 | venv_dbt_duckdb/
7 | dutch_railway_network/.user.yml
8 | dutch_railway_network/data/exports/nl_train_services_aggregate/service_year=2024/
9 | dutch_railway_network/data/exports/*.json
10 | dutch_railway_network/analyses/*.html
--------------------------------------------------------------------------------
/dbt_duckdb/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: setup-postgres force-clean-pg-container setup-python run-dbt serve-chart clean-up
2 |
3 | setup-postgres:
4 | 	docker run --name postgres-dbt-duckdb -e POSTGRES_PASSWORD=mysecretpassword -p 5466:5432 -d postgres:17.4
5 | 	docker exec -it postgres-dbt-duckdb apt-get update
6 | 	docker exec -it postgres-dbt-duckdb apt-get install -y postgresql-17-postgis-3 postgis
7 | 	docker exec -it postgres-dbt-duckdb psql -U postgres -c 'create extension postgis'
8 |
9 | setup-python:
10 | 	python -m venv venv_dbt_duckdb && \
11 | 	source venv_dbt_duckdb/bin/activate && \
12 | 	pip install -r requirements.txt
13 |
14 | run-dbt:
15 | 	export DBT_DUCKDB_PG_PWD=mysecretpassword && \
16 | 	source venv_dbt_duckdb/bin/activate && \
17 | 	cd dutch_railway_network && \
18 | 	dbt deps && \
19 | 	dbt build
20 |
21 | serve-chart:
22 | 	source venv_dbt_duckdb/bin/activate && \
23 | 	cd dutch_railway_network && \
24 | 	python analyses/build_charts.py && \
25 | 	python analyses/time_based_functions/tumbling_window.py && \
26 | 	python analyses/time_based_functions/sliding_window.py && \
27 | 	python analyses/time_based_functions/session_window.py && \
28 | 	echo 'go to http://localhost:8888' && \
29 | 	python -m http.server 8888 -d analyses
30 |
31 |
32 | force-clean-pg-container:
33 | 	docker rm -f postgres-dbt-duckdb
34 |
35 | clean-up: force-clean-pg-container
36 | 	rm -rf venv_dbt_duckdb
37 |
--------------------------------------------------------------------------------
/dbt_duckdb/README.md:
--------------------------------------------------------------------------------
1 | Prerequisites: Docker, make, Python >= 3.12.
2 |
3 | ## Local execution
4 |
5 | 1. Start PostgreSQL with `make setup-postgres`
6 | 2. Create Python virtual env with requirements: `make setup-python`
7 | 3. Build dbt models: `make run-dbt`
8 | 4. Serve chart: `make serve-chart`, the chart is served at: http://localhost:8888/charts.html
9 |
10 | ## Cleanup
11 |
12 | 1. Run `make clean-up`
13 |
14 | ## Misc
15 |
16 | - Format SQL with `sqlfluff fix models`
17 | - Generate ERD with `dbt docs generate && dbterd run -t mermaid -s schema:main_public`
18 | - Connect from DuckDB to PostgreSQL
19 | ```sql
20 | CREATE secret pg(
21 | type postgres,
22 | host '127.0.0.1',
23 | port '5466',
24 | database 'postgres',
25 | user 'postgres',
26 | password 'mysecretpassword'
27 | );
28 | ATTACH '' AS postgres_db (type postgres, schema 'main_public', secret pg);
29 | ```
30 | - Generate schema files
31 | ```bash
32 | dbt run-operation generate_model_yaml --args '{"model_names": [], "upstream_descriptions":true}'
33 | ```
34 |
35 | - If there is an issue with spatial, make sure to force [update the version](https://github.com/duckdb/duckdb-spatial/issues/508).
36 |
37 |
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/.sqlfluff:
--------------------------------------------------------------------------------
1 | [sqlfluff]
2 | templater = jinja
3 | dialect = duckdb
4 | exclude_rules = AL04, AL05, AL06, AL07, AM01, AM02, AM03, AM04, AM07, CV01,
5 |     CV02, CV03, CV04, CV05, CV06, CV07, CV08, CV09, CV10, CV11,
6 |     LT03, LT05, LT07, LT08, LT12, LT13, RF01, RF02, RF03, RF04,
7 |     RF05, RF06, ST01, ST02, ST03, ST06, ST07, ST08, TQ01, CP05
8 | ignore = templating
9 | large_file_skip_byte_limit = 0
10 | max_line_length = 0
11 |
12 |
13 | [sqlfluff:layout:type:alias_expression]
14 | spacing_before = align
15 | align_within = select_clause
16 | spacing_after = touch
17 |
18 | [sqlfluff:indentation]
19 | tab_space_size = 4
20 | indent_unit = space
21 | indented_joins = false
22 | indented_using_on = true
23 | allow_implicit_indents = true
24 | indented_on_contents = false
25 | indented_ctes = false
26 |
27 | [sqlfluff:rules:aliasing.table]
28 | aliasing.table = explicit
29 |
30 | [sqlfluff:rules:aliasing.column]
31 | aliasing.column = explicit
32 |
33 | [sqlfluff:rules:aliasing.expression]
34 | allow_scalar = True
35 |
36 | [sqlfluff:rules:ambiguous.join]
37 | fully_qualify_join_types = inner
38 |
39 | [sqlfluff:rules:ambiguous.column_references]
40 | group_by_and_order_by_style = consistent
41 |
42 | [sqlfluff:rules:capitalisation.keywords]
43 | capitalisation_policy = upper
44 |
45 | [sqlfluff:rules:capitalisation.identifiers]
46 | extended_capitalisation_policy = lower
47 | unquoted_identifiers_policy = all
48 |
49 | [sqlfluff:rules:capitalisation.functions]
50 | extended_capitalisation_policy = lower
51 |
52 | [sqlfluff:rules:capitalisation.literals]
53 | capitalisation_policy = upper
54 |
55 | [sqlfluff:rules:capitalisation.types]
56 | extended_capitalisation_policy = upper
57 |
58 | [sqlfluff:rules:jinja.padding]
59 | single_space = true
60 |
61 | [sqlfluff:rules:layout.spacing]
62 | no_trailing_whitespace = true
63 | extra_whitespace = false
64 |
65 | [sqlfluff:rules:layout.commas]
66 | line_position = trailing
67 |
68 | [sqlfluff:rules:layout.functions]
69 | no_space_after_function_name = true
70 |
71 | [sqlfluff:rules:layout.select_targets]
72 | wildcard_policy = single
73 |
74 | [sqlfluff:rules:layout.set_operators]
75 | set_operator_on_new_line = ['UNION', 'UNION ALL']
76 |
77 | [sqlfluff:rules:structure.nested_case]
78 |
79 | [sqlfluff:rules:structure.subquery]
80 | forbid_subquery_in = join
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/analyses/build_charts.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from pathlib import Path
4 |
5 | import duckdb
6 | import plotly.express as px
7 |
8 | dir_path = Path(__file__).parent.parent.absolute()
9 |
10 | logging.basicConfig(
11 |     level=logging.DEBUG,
12 |     format="{asctime} - {message}",
13 |     style="{",
14 |     datefmt="%Y-%m-%d %H:%M:%S",
15 | )
16 |
17 |
18 | def main():
19 |     with duckdb.connect() as con:
20 |         con.sql("load spatial;")
21 |
22 |         logging.info("Reading parquet")
23 |         services = (
24 |             con.read_parquet(
25 |                 f"{dir_path}/data/exports/nl_train_services_aggregate/*/*/*.parquet",
26 |                 hive_partitioning=True,
27 |             )
28 |             .filter("service_year=2024")
29 |             .filter("province_sk != 'unknown'")
30 |         )
31 |
32 |         logging.info("Getting number of rides")
33 |
34 |         province_summary_df = services.aggregate(
35 |             """
36 |             province_sk,
37 |             province_name,
38 |             sum(number_of_rides) as number_of_rides
39 |             """
40 |         ).df()
41 |
42 |         with open(f"{dir_path}/data/exports/provinces.json", "r") as f:
43 |             province_geojson = json.load(f)
44 |
45 |         logging.info("Generating map")
46 |
47 |         fig = px.choropleth_map(
48 |             province_summary_df,
49 |             geojson=province_geojson,
50 |             locations="province_sk",
51 |             featureidkey="properties.province_sk",
52 |             color="number_of_rides",
53 |             color_continuous_scale="peach",
54 |             center=dict(lat=52.20528, lon=5.5),
55 |             zoom=6.5,
56 |             height=800,
57 |             width=800,
58 |             title="Train Rides, Dutch Provinces, 2024",
59 |             labels={"number_of_rides": "Number of Rides"},
60 |             template="plotly_dark",
61 |             hover_name="province_name",
62 |         )
63 |
64 |         logging.info("Saving map")
65 |         fig.write_html(f"{dir_path}/analyses/charts.html")
66 |
67 |         logging.info("Done")
68 |
69 |
70 | if __name__ == "__main__":
71 |     main()
72 |
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/analyses/time_based_functions/hopping_window.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from funcy import log_durations
4 |
5 | from utils import get_duckdb_conn
6 |
7 |
8 | @log_durations(logging.info)
9 | def get_hopping_window(duckdb_conn):
10 |     return duckdb_conn.sql(
11 |         """
12 |         WITH time_range AS (
13 |             SELECT
14 |                 range AS window_start,
15 |                 window_start + INTERVAL '15' MINUTE AS window_end -- window size of 15 minutes
16 |             FROM range(
17 |                 '2024-01-01 00:00:00'::TIMESTAMP,
18 |                 '2025-01-01 00:00:00'::TIMESTAMP,
19 |                 INTERVAL '5' MINUTE -- hop size of 5 minutes
20 |             )
21 |         )
22 |         SELECT
23 |             window_start,
24 |             window_end,
25 |             count(service_sk) AS number_of_services
26 |         FROM ams_traffic_v
27 |         INNER JOIN time_range AS ts
28 |             ON station_service_time >= ts.window_start AND station_service_time < ts.window_end
29 |         GROUP BY ALL
30 |         ORDER BY 3 DESC, 1 ASC
31 |         LIMIT 5
32 |         """
33 |     )
34 |
35 |
36 | @log_durations(logging.info)
37 | def main():
38 |     duckdb_conn = get_duckdb_conn()
39 |     hopping_window = get_hopping_window(duckdb_conn)
40 |
41 |     hopping_window.show()  # show() prints the relation itself and returns None
42 |
43 |
44 | if __name__ == "__main__":
45 |     main()
46 |
--------------------------------------------------------------------------------
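The hopping-window query in `hopping_window.py` joins each service time to every 15-minute window that covers it, with window starts 5 minutes apart, so one event is counted in three windows. Below is a minimal pure-Python sketch of the same idea, assuming in-memory `datetime` values instead of the `ams_traffic_v` view; the function name `hopping_window_counts` and the choice of midnight as the alignment origin are assumptions made for illustration.

```python
from collections import Counter
from datetime import datetime, timedelta


def hopping_window_counts(events, size=timedelta(minutes=15), hop=timedelta(minutes=5)):
    """Count events per hopping window [start, start + size), with starts every `hop`."""
    counts = Counter()
    if not events:
        return counts
    # Assumption: window starts are aligned to multiples of `hop` since midnight
    # of the earliest event's day.
    origin = datetime.combine(min(events).date(), datetime.min.time())
    for ts in events:
        # Latest window start that is <= ts
        start = origin + ((ts - origin) // hop) * hop
        # Walk back through every earlier start whose window still covers ts
        while start > ts - size:
            counts[start] += 1
            start -= hop
    return counts
```

Unlike the SQL version, this sketch does not clip windows to a fixed date range, so events shortly after midnight also populate windows that start on the previous day.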
/dbt_duckdb/dutch_railway_network/analyses/time_based_functions/session_window.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pandas as pd
4 | import plotly.express as px
5 | from funcy import log_durations
6 |
7 | from utils import PLOTS_DIR_PATH, get_duckdb_conn
8 |
9 |
10 | @log_durations(logging.info)
11 | def get_session_window(duckdb_conn):
12 |     return duckdb_conn.sql(
13 |         """
14 |         WITH ams_daily_traffic AS (
15 |             SELECT
16 |                 service_sk,
17 |                 station_service_time,
18 |                 lag(station_service_time) OVER (
19 |                     PARTITION BY station_service_time::DATE
20 |                     ORDER BY station_service_time
21 |                 ) AS previous_service_time,
22 |                 date_diff('minute', previous_service_time, station_service_time) AS gap_minutes
23 |             FROM ams_traffic_v
24 |             WHERE hour(station_service_time) BETWEEN 6 AND 23
25 |         ), window_calculation AS (
26 |             SELECT
27 |                 service_sk,
28 |                 station_service_time,
29 |                 station_service_time::DATE AS station_service_date,
30 |                 gap_minutes,
31 |                 IF(gap_minutes >= 10 OR gap_minutes IS NULL, 1, 0) AS new_session,
32 |                 sum(new_session) OVER (
33 |                     PARTITION BY station_service_date
34 |                     ORDER BY station_service_time ROWS UNBOUNDED PRECEDING
35 |                 ) AS session_id_in_day
36 |             FROM ams_daily_traffic
37 |         )
38 |         -- Aggregate each session: start, end, largest gap, and number of services
39 |         SELECT
40 |             station_service_date,
41 |             session_id_in_day,
42 |             max(gap_minutes) AS gap_minutes,
43 |             min(station_service_time) AS window_start,
44 |             max(station_service_time) AS window_end,
45 |             count(service_sk) AS number_of_services
46 |         FROM window_calculation
47 |         GROUP BY ALL
48 |         """
49 |     )
50 |
51 |
52 | @log_durations(logging.info)
53 | def get_top5_day_with_most_gaps(duckdb_conn):
54 |     session_window = get_session_window(duckdb_conn)
55 |
56 |     return (
57 |         session_window.aggregate(
58 |             """
59 |             station_service_date,
60 |             max(ceil(date_diff('minute', window_start, window_end) / 60)) AS number_of_hours_without_gap,
61 |             count(*) AS number_of_sessions
62 |             """
63 |         )
64 |         .filter("number_of_hours_without_gap > 0")
65 |         .order("number_of_sessions desc, station_service_date")
66 |         .limit(5)
67 |     )
68 |
69 |
70 | @log_durations(logging.info)
71 | def save_session_window_px(duckdb_conn, day_with_most_gaps):
72 |     df = (
73 |         get_session_window(duckdb_conn)
74 |         .filter(f"station_service_date = '{day_with_most_gaps}'")
75 |         .df()
76 |     )
77 |     unique_gap = df["gap_minutes"].sort_values().unique()
78 |     fig = px.timeline(
79 |         df,
80 |         x_start="window_start",
81 |         x_end="window_end",
82 |         y="gap_minutes",
83 |         title=f"Session Windows on {day_with_most_gaps}",
84 |         category_orders={"gap_minutes": unique_gap},
85 |     )
86 |
87 |     fig.update_yaxes(autorange=True)
88 |
89 |     all_ticks = pd.concat([df["window_start"], df["window_end"]]).sort_values().unique()
90 |     fig.update_layout(
91 |         xaxis=dict(
92 |             tickmode="array",
93 |             tickvals=all_ticks,
94 |             tickformat="%H:%M",
95 |             tickangle=90,
96 |             tickfont=dict(size=8, family="Arial Bold"),
97 |         ),
98 |         xaxis_range=[
99 |             df["window_start"].min() - pd.Timedelta(minutes=5),
100 |             df["window_end"].max() + pd.Timedelta(minutes=5),
101 |         ],
102 |         xaxis_title="Time",
103 |         yaxis_title="Duration of Service Inactivity, in minutes",
104 |     )
105 |
106 |     for t in all_ticks:
107 |         fig.add_vline(x=t, line_width=1, line_dash="dot", line_color="gray")
108 |
109 |     fig.update_layout(yaxis={"tickvals": unique_gap, "type": "category"})
110 |
111 |     fig.update_yaxes(categoryorder="array", categoryarray=[str(v) for v in unique_gap])
112 |
113 |     fig.write_html(f"{PLOTS_DIR_PATH}/session_window.html")
114 |
115 |
116 | @log_durations(logging.info)
117 | def main():
118 |     duckdb_conn = get_duckdb_conn()
119 |     most_detected_gaps = get_top5_day_with_most_gaps(duckdb_conn)
120 |     most_detected_gaps.show()  # show() prints the relation itself and returns None
121 |     save_session_window_px(
122 |         duckdb_conn, day_with_most_gaps=most_detected_gaps.fetchone()[0]
123 |     )
124 |
125 |
126 | if __name__ == "__main__":
127 |     main()
128 |
--------------------------------------------------------------------------------
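The session-window query in `session_window.py` flags a new session whenever the gap to the previous service is at least 10 minutes (or there is no previous service) and then numbers sessions with a running sum. The sketch below reproduces that grouping in plain Python for a single day of timestamps. It is a simplification under stated assumptions: it compares elapsed time, which may differ slightly from what `date_diff('minute', ...)` reports, and the name `session_windows` is hypothetical.

```python
from datetime import datetime, timedelta


def session_windows(timestamps, max_gap=timedelta(minutes=10)):
    """Group timestamps into sessions; a gap >= max_gap starts a new session."""
    sessions = []
    for ts in sorted(timestamps):
        if sessions and ts - sessions[-1][-1] < max_gap:
            # Gap to the previous event is short enough: extend the current session
            sessions[-1].append(ts)
        else:
            # First event, or the gap reached max_gap: open a new session
            sessions.append([ts])
    # Summarize each session as (window_start, window_end, number_of_services)
    return [(s[0], s[-1], len(s)) for s in sessions]
```

Each returned tuple corresponds to one row of the SQL result for a given `station_service_date`, without the `gap_minutes` column.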
/dbt_duckdb/dutch_railway_network/analyses/time_based_functions/sliding_window.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from funcy import log_durations
4 |
5 | from utils import get_duckdb_conn
6 |
7 |
8 | @log_durations(logging.info)
9 | def get_sliding_window(duckdb_conn):
10 |     return duckdb_conn.sql(
11 |         """
12 |         SELECT
13 |             station_service_time - INTERVAL '15' MINUTE AS window_start,
14 |             station_service_time AS window_end,
15 |             count(service_sk) OVER (
16 |                 ORDER BY station_service_time
17 |                 RANGE
18 |                     BETWEEN INTERVAL '15' MINUTE PRECEDING
19 |                     AND CURRENT ROW
20 |             ) AS number_of_services
21 |         FROM ams_traffic_v
22 |         ORDER BY 3 DESC, 1
23 |         LIMIT 5
24 |         """
25 |     )
26 |
27 |
28 | @log_durations(logging.info)
29 | def main():
30 |     duckdb_conn = get_duckdb_conn()
31 |     sliding_window = get_sliding_window(duckdb_conn)
32 |
33 |     sliding_window.show()  # show() prints the relation itself and returns None
34 |
35 |
36 | if __name__ == "__main__":
37 |     main()
38 |
--------------------------------------------------------------------------------
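The sliding-window query in `sliding_window.py` uses a `RANGE` frame, so each row counts every service whose time lies in the closed interval from 15 minutes before the row's time up to the row's time, including other rows with the same timestamp. A minimal pure-Python sketch of that counting rule follows, using binary search over a sorted list; `sliding_window_counts` is a hypothetical name and the input is assumed to be a plain list of `datetime` values.

```python
from bisect import bisect_left, bisect_right
from datetime import datetime, timedelta


def sliding_window_counts(timestamps, size=timedelta(minutes=15)):
    """For each event at time t, count events in the closed interval [t - size, t]."""
    ordered = sorted(timestamps)
    result = []
    for ts in ordered:
        lo = bisect_left(ordered, ts - size)  # first event at or after window start
        hi = bisect_right(ordered, ts)        # one past the last event at time ts (peers included)
        result.append((ts - size, ts, hi - lo))
    return result
```

Sorting the result by the count in descending order and taking the first five entries would mirror the `ORDER BY 3 DESC, 1 LIMIT 5` step of the query.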
/dbt_duckdb/dutch_railway_network/analyses/time_based_functions/tumbling_window.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import plotly.express as px
4 | from funcy import log_durations
5 |
6 | from utils import PLOTS_DIR_PATH, get_duckdb_conn
7 |
8 |
9 | @log_durations(logging.info)
10 | def get_hour_tumbling_window_df(duckdb_conn):
11 |     return duckdb_conn.sql(
12 |         """
13 |         SELECT
14 |             date_trunc('hour', station_service_time) AS station_service_time_hour,
15 |             count(*) AS number_of_services
16 |         FROM ams_traffic_v
17 |         WHERE year(station_service_time) = 2024
18 |         GROUP BY ALL
19 |         ORDER BY 1
20 |         """
21 |     ).df()
22 |
23 |
24 | @log_durations(logging.info)
25 | def save_hour_tumbling_px(duckdb_conn):
26 |     fig = px.line(
27 |         get_hour_tumbling_window_df(duckdb_conn),
28 |         x="station_service_time_hour",
29 |         y="number_of_services",
30 |         title="Hourly Train Services, 2024",
31 |     )
32 |
33 |     fig.write_html(f"{PLOTS_DIR_PATH}/hour_tumbling_window.html")
34 |
35 |
36 | @log_durations(logging.info)
37 | def get_quarter_tumbling_window_df(duckdb_conn):
38 |     return duckdb_conn.sql(
39 |         """
40 |         SELECT
41 |             strftime('%H:%M', time_bucket(
42 |                 INTERVAL '15' MINUTE, -- bucket width
43 |                 station_service_time,
44 |                 INTERVAL '0' MINUTE -- offset
45 |             )) AS station_service_time_hour_quarter,
46 |             count(*) / 366 AS number_of_services -- daily average; 2024 has 366 days
47 |         FROM ams_traffic_v
48 |         WHERE year(station_service_time) = 2024
49 |         GROUP BY ALL
50 |         ORDER BY 1
51 |         """
52 |     ).df()
53 |
54 |
55 | @log_durations(logging.info)
56 | def save_quarter_tumbling_px(duckdb_conn):
57 |     fig = px.line(
58 |         get_quarter_tumbling_window_df(duckdb_conn),
59 |         x="station_service_time_hour_quarter",
60 |         y="number_of_services",
61 |         title="Average Number of Train Services, per 15 minutes, 2024",
62 |     )
63 |
64 |     fig.update_layout(xaxis={"dtick": 1})
65 |     fig.write_html(f"{PLOTS_DIR_PATH}/hour_quarter_tumbling_window.html")
66 |
67 |
68 | @log_durations(logging.info)
69 | def main():
70 |     duckdb_conn = get_duckdb_conn()
71 |     save_hour_tumbling_px(duckdb_conn)
72 |     save_quarter_tumbling_px(duckdb_conn)
73 |
74 |
75 | if __name__ == "__main__":
76 |     main()
77 |
--------------------------------------------------------------------------------
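The tumbling-window queries in `tumbling_window.py` assign each service time to exactly one fixed-width, non-overlapping bucket (`date_trunc` for hours, `time_bucket` for 15 minutes). As a rough illustration of that bucketing, here is a pure-Python sketch that floors each timestamp to the start of its bucket. The function name `tumbling_window_counts` and the midnight origin used for alignment are assumptions of this sketch, not values taken from the repository.

```python
from collections import Counter
from datetime import datetime, timedelta


def tumbling_window_counts(timestamps, width=timedelta(minutes=15)):
    """Count events per non-overlapping bucket [start, start + width)."""
    # Assumption: buckets are aligned to an arbitrary midnight origin
    origin = datetime(2000, 1, 1)
    counts = Counter()
    for ts in timestamps:
        # Floor the timestamp to the start of the bucket that contains it
        bucket_start = origin + ((ts - origin) // width) * width
        counts[bucket_start] += 1
    return counts
```

Because buckets never overlap, the counts always sum to the number of input events, which is the main contrast with the hopping and sliding variants.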
/dbt_duckdb/dutch_railway_network/analyses/time_based_functions/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 |
4 | import duckdb
5 | from funcy import log_durations
6 |
7 | logging.basicConfig(level=logging.INFO)
8 |
9 | DB_DIR_PATH = f"{Path(__file__).parent.parent.parent.absolute()}/data"
10 |
11 | PLOTS_DIR_PATH = f"{Path(__file__).parent.parent.absolute()}"
12 |
13 |
14 | @log_durations(logging.info)
15 | def get_duckdb_conn():
16 |     duckdb_conn = duckdb.connect(
17 |         f"{DB_DIR_PATH}/dutch_railway_network.duckdb",
18 |         read_only=True,
19 |     )
20 |     duckdb_conn.sql("use main_main")
21 |     return duckdb_conn
22 |
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/duckdb/duckdb-blog-examples/f8b3dece7acf4d48c86a8bc55199bc96d4f66bf4/dbt_duckdb/dutch_railway_network/data/.gitkeep
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/data/exports/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/duckdb/duckdb-blog-examples/f8b3dece7acf4d48c86a8bc55199bc96d4f66bf4/dbt_duckdb/dutch_railway_network/data/exports/.gitkeep
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/dbt_project.yml:
--------------------------------------------------------------------------------
1 |
2 | # Name your project! Project names should contain only lowercase characters
3 | # and underscores. A good package name should reflect your organization's
4 | # name or the intended use of these models
5 | name: 'dutch_railway_network'
6 | version: '1.0.0'
7 |
8 | # This setting configures which "profile" dbt uses for this project.
9 | profile: 'dutch_railway_network'
10 |
11 | flags:
12 |   send_anonymous_usage_stats: false
13 |
14 | # These configurations specify where dbt should look for different types of files.
15 | # The `model-paths` config, for example, states that models in this project can be
16 | # found in the "models/" directory. You probably won't need to change these!
17 | model-paths: ["models"]
18 | analysis-paths: ["analyses"]
19 | test-paths: ["tests"]
20 | seed-paths: ["seeds"]
21 | macro-paths: ["macros"]
22 | snapshot-paths: ["snapshots"]
23 |
24 | clean-targets: # directories to be removed by `dbt clean`
25 |   - "target"
26 |   - "dbt_packages"
27 |
28 |
29 | # Configuring models
30 | # Full documentation: https://docs.getdbt.com/docs/configuring-models
31 |
32 | # In this example config, we tell dbt to build all models in the example/
33 | # directory as views. These settings can be overridden in the individual model
34 | # files using the `{{ config(...) }}` macro.
35 | models:
36 |   dutch_railway_network:
37 |     transformation:
38 |       schema: main
39 |       +docs:
40 |         node_color: 'silver'
41 |     reverse_etl:
42 |       database: postgres_db
43 |       schema: public
44 |       +docs:
45 |         node_color: '#d5b85a'
46 |     exports:
47 |       +docs:
48 |         node_color: 'green'
49 |
50 | vars:
51 |   execution_year: 2024
52 |   execution_month: '202408'
53 |
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/macros/common_columns.sql:
--------------------------------------------------------------------------------
1 | {% macro common_columns() %}
2 | last_updated_dt: get_current_timestamp(),
3 | invocation_id: '{{ invocation_id }}'
4 | {% endmacro %}
5 |
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/exports/export_province_geojson.sql:
--------------------------------------------------------------------------------
1 | {{ config(
2 | materialized='external',
3 | location="data/exports/provinces.json"
4 | )
5 | }}
6 |
7 | WITH province_agg AS (
8 | SELECT
9 | json_group_array(
10 | json_object(
11 | 'type', 'Feature',
12 | 'properties', json_object('province_sk', province_sk),
13 | 'geometry', st_asgeojson(province_geometry)
14 | )
15 | ) AS features
16 | FROM {{ ref("dim_nl_provinces") }}
17 | )
18 | SELECT
19 | 'FeatureCollection' AS type,
20 | features
21 | FROM province_agg
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/exports/export_train_services_agg.sql:
--------------------------------------------------------------------------------
1 | {{ config(
2 | materialized='external',
3 | location="data/exports/nl_train_services_aggregate",
4 | options={"partition_by": "service_year, service_month", "overwrite": True}
5 | )
6 | }}
7 |
8 | SELECT
9 | year(service_date) AS service_year,
10 | month(service_date) AS service_month,
11 | service_type,
12 | service_company,
13 | tr_st.station_sk,
14 | tr_st.station_name,
15 | m.municipality_sk,
16 | m.municipality_name,
17 | p.province_sk,
18 | p.province_name,
19 | count(*) AS number_of_rides
20 | FROM {{ ref ("fact_services") }} AS srv
21 | INNER JOIN {{ ref("dim_nl_train_stations") }} AS tr_st
22 | ON srv.station_sk = tr_st.station_sk
23 | INNER JOIN {{ ref("dim_nl_municipalities") }} AS m
24 | ON tr_st.municipality_sk = m.municipality_sk
25 | INNER JOIN {{ ref("dim_nl_provinces") }} AS p
26 | ON m.province_sk = p.province_sk
27 | WHERE service_year = {{ var('execution_year') }}
28 | GROUP BY ALL
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/exports/schema.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | models:
4 |   - name: export_train_services_agg
5 |     description: ""
6 |     columns:
7 |       - name: service_type
8 |         data_type: varchar
9 |         description: "The service type"
10 |
11 |       - name: service_company
12 |         data_type: varchar
13 |         description: "The service company"
14 |
15 |       - name: station_sk
16 |         data_type: varchar
17 |         description: "The station surrogate key"
18 |
19 |       - name: station_name
20 |         data_type: varchar
21 |         description: "The station name"
22 |
23 |       - name: station_geo_location
24 |         data_type: geometry
25 |         description: "The station geo location"
26 |
27 |       - name: municipality_sk
28 |         data_type: varchar
29 |         description: "The municipality surrogate key"
30 |
31 |       - name: municipality_name
32 |         data_type: varchar
33 |         description: "The municipality name"
34 |
35 |       - name: municipality_geometry
36 |         data_type: varchar
37 |         description: "The municipality geometry"
38 |
39 |       - name: province_sk
40 |         data_type: varchar
41 |         description: "The province surrogate key"
42 |
43 |       - name: province_name
44 |         data_type: varchar
45 |         description: "The province name"
46 |
47 |       - name: province_geometry
48 |         data_type: varchar
49 |         description: "The province geometry"
50 |
51 |       - name: number_of_rides
52 |         data_type: bigint
53 |         description: "The number of rides in the service month"
54 |
55 |       - name: service_month
56 |         data_type: bigint
57 |         description: "The service month"
58 |
59 |       - name: service_year
60 |         data_type: bigint
61 |         description: "The service year"
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/reverse_etl/rep_dim_nl_municipalities.sql:
--------------------------------------------------------------------------------
1 | {{ config(
2 | materialized='table',
3 | post_hook = """
4 | call postgres_execute(
5 | '{{ this.database }}',
6 | '
7 | alter table {{ this.schema }}.rep_dim_nl_municipalities
8 | alter column municipality_geometry type geometry
9 | using ST_GeomFromWKB(decode(municipality_geometry, ''hex''))
10 | '
11 | )
12 | """
13 | )
14 | }}
15 |
16 | SELECT
17 | municipality_sk,
18 | municipality_name,
19 | st_ashexwkb(municipality_geometry) AS municipality_geometry,
20 | province_sk,
21 | {{ common_columns() }}
22 | FROM {{ ref("dim_nl_municipalities") }}
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/reverse_etl/rep_dim_nl_provinces.sql:
--------------------------------------------------------------------------------
1 | {{ config(
2 | materialized='table',
3 | post_hook = """
4 | call postgres_execute(
5 | '{{ this.database }}',
6 | '
7 | alter table {{ this.schema }}.rep_dim_nl_provinces
8 | alter column province_geometry type geometry
9 | using ST_GeomFromWKB(decode(province_geometry, ''hex''))
10 | '
11 | )
12 | """)
13 | }}
14 |
15 | SELECT
16 | province_sk,
17 | province_name,
18 | st_ashexwkb(province_geometry) AS province_geometry,
19 | {{ common_columns() }}
20 | FROM {{ ref("dim_nl_provinces") }}
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/reverse_etl/rep_dim_nl_train_stations.sql:
--------------------------------------------------------------------------------
1 | {{ config(
2 | materialized='table',
3 | post_hook = """
4 | call postgres_execute(
5 | '{{ this.database }}',
6 | '
7 | alter table {{ this.schema }}.rep_dim_nl_train_stations
8 | alter column station_geo_location type geometry
9 | using ST_GeomFromWKB(decode(station_geo_location, ''hex''))
10 | '
11 | )
12 | """) }}
13 |
14 | SELECT
15 | station_sk,
16 | station_code,
17 | station_name,
18 | station_type,
19 | st_ashexwkb(station_geo_location) AS station_geo_location,
20 | municipality_sk,
21 | {{ common_columns() }}
22 | FROM {{ ref("dim_nl_train_stations") }}
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/reverse_etl/rep_fact_train_services_daily_agg.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized='incremental',
4 | incremental_strategy='delete+insert',
5 | unique_key='service_date, service_type, service_company, station_sk'
6 | )
7 | }}
8 |
9 | SELECT
10 | service_date,
11 | service_type,
12 | service_company,
13 | srv.station_sk,
14 | mn.municipality_sk,
15 | province_sk,
16 | count(*) AS number_of_rides,
17 | {{ common_columns() }}
18 | FROM {{ ref ("fact_services") }} AS srv
19 | INNER JOIN {{ ref("rep_dim_nl_train_stations") }} AS tr_st
20 | ON srv.station_sk = tr_st.station_sk
21 | INNER JOIN {{ ref("rep_dim_nl_municipalities") }} AS mn
22 | ON tr_st.municipality_sk = mn.municipality_sk
23 | WHERE service_arrival_cancelled IS FALSE
24 |
25 | {% if is_incremental() %}
26 | AND srv.invocation_id = (
27 | SELECT invocation_id FROM {{ ref("fact_services") }}
28 | ORDER BY last_updated_dt DESC LIMIT 1
29 | )
30 | {% endif %}
31 | GROUP BY ALL
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/reverse_etl/schema.yml:
--------------------------------------------------------------------------------
1 |
2 | version: 2
3 |
4 | models:
5 |   - name: rep_dim_nl_provinces
6 |     description: "Dim table for NL provinces"
7 |     columns:
8 |       - name: province_sk
9 |         data_type: varchar
10 |         description: "The surrogate key"
11 |       - name: province_id
12 |         data_type: integer
13 |         description: "The primary key in the source system"
14 |       - name: province_name
15 |         data_type: varchar
16 |         description: "The province name"
17 |       - name: province_geometry
18 |         data_type: geometry
19 |         description: "The province geometry"
20 |       - name: last_updated_dt
21 |         data_type: timestamp
22 |         description: "Timestamp when the record was last updated"
23 |       - name: invocation_id
24 |         data_type: varchar
25 |         description: "The dbt invocation id"
26 |
27 |   - name: rep_dim_nl_municipalities
28 |     description: "Dim table for NL municipalities"
29 |     columns:
30 |       - name: municipality_sk
31 |         data_type: varchar
32 |         description: "The surrogate key"
33 |       - name: municipality_id
34 |         data_type: integer
35 |         description: "The primary key in the source data"
36 |       - name: province_sk
37 |         data_type: varchar
38 |         description: "The province in which the municipality is located"
39 |         tests:
40 |           - relationships:
41 |               to: ref('rep_dim_nl_provinces')
42 |               field: province_sk
43 |       - name: municipality_name
44 |         data_type: varchar
45 |         description: "The municipality name"
46 |       - name: municipality_geometry
47 |         data_type: geometry
48 |         description: "The municipality geometry"
49 |       - name: last_updated_dt
50 |         data_type: timestamp
51 |         description: "Timestamp when the record was last updated"
52 |       - name: invocation_id
53 |         data_type: varchar
54 |         description: "The dbt invocation id"
55 |
56 |   - name: rep_dim_nl_train_stations
57 |     description: "Dim table for NL train stations"
58 |     columns:
59 |       - name: station_sk
60 |         data_type: varchar
61 |         description: "The surrogate key"
62 |       - name: station_id
63 |         data_type: varchar
64 |         description: "The primary key of this table in the source data"
65 |       - name: municipality_sk
66 |         data_type: varchar
67 |         description: "The municipality in which the station is located"
68 |         tests:
69 |           - relationships:
70 |               to: ref('rep_dim_nl_municipalities')
71 |               field: municipality_sk
72 |       - name: station_code
73 |         data_type: varchar
74 |         description: "The code of the station"
75 |       - name: station_name
76 |         data_type: varchar
77 |         description: "The station name"
78 |       - name: station_type
79 |         data_type: varchar
80 |         description: "The station type"
81 |       - name: station_geo_location
82 |         data_type: geometry
83 |         description: "The station geo location"
84 |       - name: last_updated_dt
85 |         data_type: timestamp
86 |         description: "Timestamp when the record was last updated"
87 |       - name: invocation_id
88 |         data_type: varchar
89 |         description: "The dbt invocation id"
90 |
91 |   - name: rep_fact_train_services_daily_agg
92 |     columns:
93 |       - name: service_date
94 |         data_type: date
95 |         description: "The service date"
96 |       - name: service_type
97 |         data_type: varchar
98 |         description: "The service type"
99 |       - name: service_company
100 |         data_type: varchar
101 |         description: "The service company"
102 |       - name: station_sk
103 |         data_type: varchar
104 |         description: "The station sk"
105 |         tests:
106 |           - relationships:
107 |               to: ref('rep_dim_nl_train_stations')
108 |               field: station_sk
109 |       - name: municipality_sk
110 |         data_type: varchar
111 |         description: "The municipality sk"
112 |         tests:
113 |           - relationships:
114 |               to: ref('rep_dim_nl_municipalities')
115 |               field: municipality_sk
116 |       - name: province_sk
117 |         data_type: varchar
118 |         description: "The province sk"
119 |         tests:
120 |           - relationships:
121 | to: ref('rep_dim_nl_provinces')
122 | field: province_sk
123 | - name: number_of_rides
124 | data_type: integer
125 | description: "The number of rides on the service date"
126 | - name: last_updated_at
127 | data_type: timestamp
128 | description: "Timestamp when the record was last updated"
129 | - name: invocation_id
130 | data_type: varchar
131 | description: "The dbt invocation id"
132 |
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/sources.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | sources:
3 | - name: geojson_external
4 | tables:
5 | - name: nl_provinces
6 | config:
7 | external_location: "https://cartomap.github.io/nl/wgs84/provincie_2025.geojson"
8 | - name: nl_municipalities
9 | config:
10 | external_location: "seeds/gemeente_2025.geojson"
11 | - name: external_db
12 | database: external_db
13 | schema: main
14 | tables:
15 | - name: stations
16 | - name: services
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/transformation/ams_traffic_v.sql:
--------------------------------------------------------------------------------
1 | {{ config(materialized='view') }}
2 |
3 | SELECT
4 | service_sk,
5 | if(station_arrival_time IS NULL, station_departure_time, station_arrival_time) AS station_service_time
6 | FROM {{ ref("fact_services") }} AS srv
7 | INNER JOIN {{ ref("dim_nl_train_stations") }} AS st
8 | ON srv.station_sk = st.station_sk
9 | WHERE station_name = 'Amsterdam Centraal'
10 | AND (service_arrival_cancelled = FALSE OR service_departure_cancelled = FALSE)
11 |
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/transformation/dim_nl_municipalities.sql:
--------------------------------------------------------------------------------
1 | {{ config(materialized='table') }}
2 |
3 | WITH covered_by_selection AS (
4 | SELECT
5 | id AS municipality_id,
6 | statnaam AS municipality_name,
7 | geom AS municipality_geometry,
8 | dim_prov.province_sk
9 | FROM st_read({{ source("geojson_external", "nl_municipalities") }}) AS dim_mun
10 | INNER JOIN {{ ref ("dim_nl_provinces") }} AS dim_prov
11 | ON st_covers(dim_prov.province_geometry, dim_mun.geom)
12 | ),
13 | ordered_by_difference_area AS (
14 | SELECT
15 | id AS municipality_id,
16 | statnaam AS municipality_name,
17 | geom AS municipality_geometry,
18 | dim_prov.province_sk
19 | FROM st_read({{ source("geojson_external", "nl_municipalities") }}) AS dim_mun,
20 | {{ ref ("dim_nl_provinces") }} AS dim_prov
21 | WHERE NOT EXISTS (
22 | SELECT 1 FROM covered_by_selection
23 | WHERE dim_mun.id = covered_by_selection.municipality_id
24 | )
25 | QUALIFY row_number() OVER (
26 | PARTITION BY municipality_id
27 | ORDER BY st_area(st_difference(dim_mun.geom, province_geometry))
28 | ) = 1
29 | )
30 | SELECT
31 | {{ dbt_utils.generate_surrogate_key(['municipality_id']) }} AS municipality_sk,
32 | src.*,
33 | {{ common_columns() }}
34 | FROM covered_by_selection AS src
35 | UNION
36 | SELECT
37 | {{ dbt_utils.generate_surrogate_key(['municipality_id']) }} AS municipality_sk,
38 | src.*,
39 | {{ common_columns() }}
40 | FROM ordered_by_difference_area AS src
41 | UNION
42 | SELECT
43 | 'unknown' AS municipality_sk,
44 | -1 AS municipality_id,
45 | 'unknown' AS municipality_name,
46 | NULL AS municipality_geometry,
47 | 'unknown' AS province_sk,
48 | {{ common_columns() }}
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/transformation/dim_nl_provinces.sql:
--------------------------------------------------------------------------------
1 | {{ config(materialized='table') }}
2 |
3 | SELECT
4 | {{ dbt_utils.generate_surrogate_key(['id']) }} AS province_sk,
5 | id AS province_id,
6 | statnaam AS province_name,
7 | geom AS province_geometry,
8 | {{ common_columns() }}
9 | FROM st_read({{ source("geojson_external", "nl_provinces") }}) AS src
10 | UNION ALL
11 | SELECT
12 | 'unknown' AS province_sk,
13 | -1 AS province_id,
14 | 'unknown' AS province_name,
15 | NULL AS province_geometry,
16 | {{ common_columns() }}
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/transformation/dim_nl_train_stations.sql:
--------------------------------------------------------------------------------
1 | {{ config(materialized='table') }}
2 |
3 | SELECT
4 | {{ dbt_utils.generate_surrogate_key(['tr_st.code']) }} AS station_sk,
5 | tr_st.id AS station_id,
6 | tr_st.code AS station_code,
7 | tr_st.name_long AS station_name,
8 | tr_st.type AS station_type,
9 | st_point(tr_st.geo_lng, tr_st.geo_lat) AS station_geo_location,
10 | coalesce(dim_mun.municipality_sk, 'unknown') AS municipality_sk,
11 | {{ common_columns() }}
12 | FROM {{ source("external_db", "stations") }} AS tr_st
13 | LEFT JOIN {{ ref ("dim_nl_municipalities") }} AS dim_mun
14 | ON st_contains(
15 | dim_mun.municipality_geometry,
16 | st_point(tr_st.geo_lng, tr_st.geo_lat)
17 | )
18 | WHERE tr_st.country = 'NL'
19 |
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/transformation/fact_services.sql:
--------------------------------------------------------------------------------
1 | {{ config(materialized='table') }}
2 |
3 | SELECT
4 | {{ dbt_utils.generate_surrogate_key(['"Service:RDT-ID"', 'station_sk']) }} AS service_sk,
5 | "Service:Date" AS service_date,
6 | "Service:Type" AS service_type,
7 | "Service:Company" AS service_company,
8 | station_sk,
9 | "Stop:Arrival time" AS station_arrival_time,
10 | "Stop:Departure time" AS station_departure_time,
11 | if("Stop:Arrival cancelled" IS NULL, FALSE, "Stop:Arrival cancelled") AS service_arrival_cancelled,
12 | "Service:Train number" AS service_train_number,
13 | if("Stop:Departure cancelled" IS NULL, FALSE, "Stop:Departure cancelled") AS service_departure_cancelled,
14 | {{ common_columns() }}
15 | FROM {{ source("external_db", "services") }} AS srv
16 | INNER JOIN {{ ref("dim_nl_train_stations") }} AS tr_st
17 | ON srv."Stop:Station Code" = tr_st.station_code
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/models/transformation/schema.yml:
--------------------------------------------------------------------------------
1 |
2 | version: 2
3 |
4 | models:
5 | - name: dim_nl_provinces
6 | description: "Dim table for NL provinces, example of SCD2"
7 | columns:
8 | - name: province_sk
9 | description: "The surrogate key"
10 | - name: province_id
11 | description: "The primary key in the source system"
12 | tests:
13 | - unique
14 | - not_null
15 | - name: province_name
16 | description: "The province name"
17 | - name: province_geometry
18 | description: "The province geometry"
19 | - name: last_updated_dt
20 | description: "Timestamp when the record was last updated"
21 | - name: invocation_id
22 | description: "The dbt invocation id"
23 |
24 | - name: dim_nl_municipalities
25 | description: "Dim table for NL municipalities"
26 | columns:
27 | - name: municipality_sk
28 | description: "The surrogate key"
29 | - name: municipality_id
30 | description: "The primary key in the source data"
31 | tests:
32 | - unique
33 | - not_null
34 | - name: province_sk
35 | description: "The province in which the municipality is located"
36 | tests:
37 | - relationships:
38 | to: ref('dim_nl_provinces')
39 | field: province_sk
40 | - name: municipality_name
41 | description: "The municipality name"
42 | - name: municipality_geometry
43 | description: "The municipality geometry"
44 | - name: last_updated_dt
45 | description: "Timestamp when the record was last updated"
46 | - name: invocation_id
47 | description: "The dbt invocation id"
48 |
49 | - name: dim_nl_train_stations
50 | description: "Dim table for NL train stations"
51 | columns:
52 | - name: station_sk
53 | description: "The surrogate key"
54 | - name: station_id
55 | description: "The primary key of this table in the source data"
56 | tests:
57 | - unique
58 | - not_null
59 | - name: municipality_sk
60 | description: "The surrogate key of the municipality in which the station is located"
61 | tests:
62 | - relationships:
63 | to: ref('dim_nl_municipalities')
64 | field: municipality_sk
65 | - name: station_code
66 | description: "The code of the station"
67 | tests:
68 | - unique
69 | - not_null
70 | - name: station_name
71 | description: "The station name"
72 | - name: station_type
73 | description: "The station type"
74 | - name: station_geo_location
75 | description: "The station geo location"
76 | - name: last_updated_dt
77 | description: "Timestamp when the record was last updated"
78 | - name: invocation_id
79 | description: "The dbt invocation id"
80 |
81 | - name: fact_services
82 | columns:
83 | - name: service_sk
84 | description: "The surrogate key"
85 | tests:
86 | - unique
87 | - not_null
88 | - name: service_date
89 | description: "The service date"
90 | - name: service_type
91 | description: "The service type"
92 | - name: service_company
93 | description: "The service company"
94 | - name: service_train_number
95 | description: "The service train number"
96 | - name: station_sk
97 | description: "The station surrogate key"
98 | tests:
99 | - relationships:
100 | to: ref('dim_nl_train_stations')
101 | field: station_sk
102 | - name: station_arrival_time
103 | description: "The arrival time in the station"
104 | - name: station_departure_time
105 | description: "The departure time from the station"
106 | - name: service_arrival_cancelled
107 | description: "Flag if the arrival was cancelled in the station"
108 | - name: last_updated_dt
109 | description: "Timestamp when the record was last updated"
110 | - name: invocation_id
111 | description: "The dbt invocation id"
112 |
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/package-lock.yml:
--------------------------------------------------------------------------------
1 | packages:
2 | - package: dbt-labs/dbt_utils
3 | version: 1.3.0
4 | - package: dbt-labs/codegen
5 | version: 0.13.1
6 | sha1_hash: 9c459bb513316be11ab55af0e5113f17444d082e
7 |
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/packages.yml:
--------------------------------------------------------------------------------
1 | packages:
2 | # run dbt deps to install
3 | - package: dbt-labs/dbt_utils
4 | version: 1.3.0
5 | - package: dbt-labs/codegen
6 | version: 0.13.1
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/profiles.yml:
--------------------------------------------------------------------------------
1 | dutch_railway_network:
2 |
3 | outputs:
4 | dev:
5 | type: duckdb
6 | path: data/dutch_railway_network.duckdb
7 | extensions:
8 | - spatial
9 | - httpfs
10 | - postgres
11 | threads: 1
12 | attach:
13 | - path: 'https://blobs.duckdb.org/nl-railway/train_stations_and_services.duckdb'
14 | type: duckdb
15 | alias: external_db
16 | - path: "postgresql://postgres:{{ env_var('DBT_DUCKDB_PG_PWD') }}@localhost:5466/postgres"
17 | type: postgres
18 | alias: postgres_db
19 | target: dev
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/snapshots/.gitkeep:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/tests/test_province_municipality_relation.sql:
--------------------------------------------------------------------------------
1 | -- check that the province which contains the point is the same as the province associated through municipality
2 | SELECT
3 | station_sk,
4 | province_sk
5 | FROM {{ ref("dim_nl_train_stations") }} AS ts
6 | INNER JOIN {{ ref("dim_nl_municipalities") }} AS mn
7 | ON ts.municipality_sk = mn.municipality_sk
8 | WHERE ts.station_code NOT IN ('HLGH', 'EEM') -- issue with geo match
9 | EXCEPT
10 | SELECT
11 | station_sk,
12 | province_sk
13 | FROM {{ ref("dim_nl_train_stations") }} AS ts
14 | INNER JOIN {{ ref("dim_nl_provinces") }} AS p
15 | ON ST_CONTAINS(p.province_geometry, station_geo_location)
16 |
--------------------------------------------------------------------------------
/dbt_duckdb/dutch_railway_network/tests/test_rep_fact_services.sql:
--------------------------------------------------------------------------------
1 | SELECT COUNT(*)
2 | FROM {{ source("external_db", "services") }}
3 | WHERE IF("Stop:Arrival cancelled" is null, false, "Stop:Arrival cancelled") IS FALSE
4 | AND "Stop:Station Code" IN (
5 | SELECT code
6 | FROM {{ source("external_db", "stations") }}
7 | WHERE country = 'NL'
8 | )
9 | EXCEPT
10 | SELECT SUM(number_of_rides)
11 | FROM {{ ref('rep_fact_train_services_daily_agg') }}
--------------------------------------------------------------------------------
/dbt_duckdb/requirements.txt:
--------------------------------------------------------------------------------
1 | agate==1.9.1
2 | annotated-types==0.7.0
3 | attrs==25.3.0
4 | babel==2.17.0
5 | black==25.1.0
6 | certifi==2025.1.31
7 | chardet==5.2.0
8 | charset-normalizer==3.4.1
9 | click==8.1.8
10 | colorama==0.4.6
11 | daff==1.3.46
12 | dbt-adapters==1.14.3
13 | dbt-artifacts-parser==0.8.2
14 | dbt-common==1.16.0
15 | dbt-core==1.9.3
16 | dbt-duckdb==1.9.2
17 | dbt-extractor==0.5.1
18 | dbt-semantic-interfaces==0.7.4
19 | dbterd==1.18.0
20 | deepdiff==7.0.1
21 | diff_cover==9.2.4
22 | duckdb==1.2.1
23 | funcy==2.0
24 | idna==3.10
25 | importlib-metadata==6.11.0
26 | iniconfig==2.1.0
27 | isodate==0.6.1
28 | Jinja2==3.1.6
29 | jsonschema==4.23.0
30 | jsonschema-specifications==2024.10.1
31 | leather==0.4.0
32 | MarkupSafe==3.0.2
33 | mashumaro==3.14
34 | more-itertools==10.6.0
35 | msgpack==1.1.0
36 | mypy-extensions==1.0.0
37 | narwhals==1.32.0
38 | networkx==3.4.2
39 | numpy==2.2.4
40 | ordered-set==4.1.0
41 | packaging==24.2
42 | pandas==2.2.3
43 | parsedatetime==2.6
44 | pathspec==0.12.1
45 | platformdirs==4.3.7
46 | plotly==6.0.1
47 | pluggy==1.5.0
48 | protobuf==5.29.4
49 | pydantic==2.10.6
50 | pydantic_core==2.27.2
51 | Pygments==2.19.1
52 | pytest==8.3.5
53 | python-dateutil==2.9.0.post0
54 | python-slugify==8.0.4
55 | pytimeparse==1.1.8
56 | pytz==2025.2
57 | PyYAML==6.0.2
58 | referencing==0.36.2
59 | regex==2024.11.6
60 | requests==2.32.3
61 | rpds-py==0.23.1
62 | six==1.17.0
63 | snowplow-tracker==1.1.0
64 | sqlfluff==3.3.1
65 | sqlparse==0.5.3
66 | tblib==3.0.0
67 | text-unidecode==1.3
68 | tqdm==4.67.1
69 | typing_extensions==4.12.2
70 | tzdata==2025.2
71 | urllib3==2.3.0
72 | zipp==3.21.0
73 |
--------------------------------------------------------------------------------
/duckdb_streamlit/.gitignore:
--------------------------------------------------------------------------------
1 | venv_duckdb_streamlit
--------------------------------------------------------------------------------
/duckdb_streamlit/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: start-app clean-up
2 |
3 | start-app:
4 | python -m venv venv_duckdb_streamlit && \
5 | . ./venv_duckdb_streamlit/bin/activate && \
6 | pip install -r requirements.txt && \
7 | streamlit run app.py
8 |
9 | clean-up:
10 | rm -rf venv_duckdb_streamlit
11 |
--------------------------------------------------------------------------------
/duckdb_streamlit/README.md:
--------------------------------------------------------------------------------
1 | Prerequisites: `make` and Python >= 3.12.
2 |
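The prerequisites above can be checked before running `make start-app`. The following is a minimal sketch, not part of the application: the names `MIN_PYTHON`, `python_is_recent_enough` and `make_is_available` are hypothetical helpers introduced here for illustration, and the script only uses the Python standard library.

```python
import shutil
import sys

# Minimum interpreter version stated in the prerequisites (assumed to be inclusive).
MIN_PYTHON = (3, 12)


def python_is_recent_enough(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True when the (major, minor) version is at least `minimum`."""
    return tuple(version_info[:2]) >= tuple(minimum)


def make_is_available():
    """Return True when a `make` executable can be found on PATH."""
    return shutil.which("make") is not None


if __name__ == "__main__":
    # Hypothetical helper script: report whether both prerequisites are met.
    print("Python >= 3.12:", python_is_recent_enough())
    print("make on PATH:", make_is_available())
```

If either check reports `False`, `make start-app` is likely to fail before the Streamlit app starts.
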
3 | ## Local execution
4 | 1. Run `make start-app`
5 | 2. Open the address printed in the log (the app takes about a minute to start)
6 |
7 | ## Cleanup
8 |
9 | 1. Run `make clean-up`
10 |
--------------------------------------------------------------------------------
/duckdb_streamlit/app.py:
--------------------------------------------------------------------------------
1 | import plotly.express as px
2 | import streamlit as st
3 |
4 | from utils import get_duckdb_conn, get_stations_services_query
5 |
6 |
7 | def main():
8 |
9 | st.title("Analyzing Dutch Railway Data")
10 |
11 | duckdb_conn = get_duckdb_conn()
12 |
13 | # using Streamlit charts
14 | st.subheader("Number of train services in 2024")
15 | st.line_chart(
16 | duckdb_conn.sql("from services")
17 | .aggregate(
18 | """
19 | service_date: "Service:Date",
20 | service_month: monthname(service_date),
21 | service_month_id: month(service_date),
22 | num_services: count(distinct "Service:RDT-ID")
23 | """
24 | )
25 | .order("service_month_id")
26 | .df(),
27 | x="service_date",
28 | y="num_services",
29 | color="service_month",
30 | )
31 |
32 | # using Plotly charts
33 | st.plotly_chart(
34 | px.bar(
35 | get_top_5_stations_data(),
36 | x="service_month",
37 | y="num_services",
38 | color="station_name",
39 | barmode="group",
40 | title="Top 5 Busiest Train Stations 2024",
41 | labels={
42 | "service_month": "Month",
43 | "num_services": "Number of Train Trips",
44 | "station_name": "Station Name",
45 | },
46 | )
47 | )
48 |
49 |
50 | def get_top_5_stations_data():
51 | stations_query, _ = get_stations_services_query(get_duckdb_conn())
52 |
53 | return (
54 | stations_query.aggregate(
55 | """
56 | station_name,
57 | service_month: monthname(service_date),
58 | service_month_id: month(service_date),
59 | num_services: sum(num_services)
60 | """
61 | )
62 | .select(
63 | """
64 | station_name,
65 | service_month,
66 | service_month_id,
67 | num_services,
68 | rn: row_number() over (partition by service_month order by num_services desc)
69 | """
70 | )
71 | .filter("rn <= 5")
72 | .order("service_month_id, station_name")
73 | .df()
74 | )
75 |
76 |
77 | if __name__ == "__main__":
78 | main()
79 |
--------------------------------------------------------------------------------
/duckdb_streamlit/constants.py:
--------------------------------------------------------------------------------
1 | DEFAULT_LAT = 52.20528
2 | DEFAULT_LNG = 6.000556
3 |
--------------------------------------------------------------------------------
/duckdb_streamlit/pages/closest_train_stations.py:
--------------------------------------------------------------------------------
1 | import duckdb
2 | import folium
3 | import streamlit as st
4 | from streamlit_folium import st_folium
5 |
6 | from utils import get_duckdb_conn
7 |
8 | from constants import DEFAULT_LAT, DEFAULT_LNG
9 |
10 |
11 | def main():
12 | if "clicked_map" not in st.session_state:
13 | st.session_state.clicked_map = True
14 | st.session_state.clicked_location_lat = DEFAULT_LAT
15 | st.session_state.clicked_location_lng = DEFAULT_LNG
16 | st.session_state.clicked_location_zoom = 8
17 |
18 | st.subheader(
19 | f"Closest 5 train stations to [{st.session_state.clicked_location_lat:.2f}, {st.session_state.clicked_location_lng:.2f}]",
20 | anchor=False,
21 | )
22 |
23 | # display the map together with the markers
24 | user_map = st_folium(
25 | get_map(
26 | lat=st.session_state.clicked_location_lat,
27 | lng=st.session_state.clicked_location_lng,
28 | zoom=st.session_state.clicked_location_zoom,
29 | ),
30 | key="user-map",
31 | height=600,
32 | width=800
33 | )
34 |
35 | # rerun the application on click
36 | if user_map.get("last_clicked"):
37 | st.session_state.clicked_location_lat = user_map["last_clicked"]["lat"]
38 | st.session_state.clicked_location_lng = user_map["last_clicked"]["lng"]
39 | st.session_state.clicked_location_zoom = user_map["zoom"]
40 | st.rerun()
41 |
42 |
43 | def get_map(lat, lng, zoom):
44 | # create the folium map, with the center at the latitude and longitude provided as input
45 | folium_map = folium.Map(
46 | location=[
47 | lat,
48 | lng,
49 | ],
50 | zoom_start=zoom,
51 | height=600,
52 | width=800
53 | )
54 |
55 | # add a marker of blue color and user icon at the location provided
56 | folium.Marker(
57 | location=[
58 | lat,
59 | lng,
60 | ],
61 | icon=folium.Icon(icon="user", prefix="fa", color="blue"),
62 | draggable=False,
63 | ).add_to(folium_map)
64 |
65 | # get the closest train stations to the location provided
66 |
67 | duckdb_conn = get_duckdb_conn()
68 | closest_stations_detailed_query, _ = get_closest_stations_detailed_query(duckdb_conn, lat, lng)
69 |
70 | # iterate over the list of records
71 | for x in closest_stations_detailed_query.fetchall():
72 | # for each train station add a marker to the map at the location of the train station
73 | # and add to the popup the information about the train station
74 | folium.Marker(
75 | location=[x[1], x[2]],
76 | draggable=False,
77 | icon=folium.Icon(color=x[6]),
78 | popup=folium.Popup(
79 | f"""
80 | Station: {x[0]}
81 | Location: [{x[1]},{x[2]}]
82 | Distance: {x[3]} km
83 | Number of Services: {x[4]:,}
84 | Number of Cancellations: {x[5]:,}
85 | """,
86 | max_width=200,
87 | ),
88 | ).add_to(folium_map)
89 |
90 | return folium_map
91 |
92 |
93 |
94 | def get_closest_stations_query(duckdb_conn, lat, lng):
95 | stations_selection = duckdb_conn.sql("""
96 | select name_long as station_name, geo_lat, geo_lng, code
97 | from stations st
98 | where exists (
99 | select count(*)
100 | from services sv
101 | where st.code = sv."Stop:Station code"
102 | having count(*)>100
103 | )
104 | """)
105 |
106 | return (
107 | stations_selection.project(f"""
108 | code as station_code,
109 | station_name,
110 | geo_lat,
111 | geo_lng,
112 | station_geo_point: st_point(geo_lng, geo_lat),
113 | clicked_geo_point: st_point({lng}, {lat}),
114 | distance_in_m: st_distance_sphere(st_point(geo_lng, geo_lat), clicked_geo_point),
115 | distance_in_km: round(distance_in_m/1000,2)
116 | """)
117 | .order("distance_in_km")
118 | .limit(5)
119 | )
120 |
121 |
122 | def get_closest_stations_detailed_query(duckdb_conn, lat, lng):
123 | services = duckdb_conn.sql("from services").set_alias("services")
124 | closest_stations = get_closest_stations_query(duckdb_conn, lat, lng).set_alias("closest_stations")
125 |
126 | return (
127 | services.join(
128 | closest_stations,
129 | 'services."Stop:Station code" = closest_stations.station_code',
130 | )
131 | .aggregate("""
132 | station_name,
133 | geo_lat,
134 | geo_lng,
135 | distance_in_km,
136 | num_cancelled_at_departure: sum(coalesce("Stop:Departure cancelled", false)),
137 | num_cancelled_at_arrival: sum(coalesce("Stop:Arrival cancelled", false)),
138 | num_services: count(*)
139 | """)
140 | .select("""
141 | station_name,
142 | geo_lat,
143 | geo_lng,
144 | distance_in_km,
145 | num_services,
146 | num_cancellations: num_cancelled_at_arrival + num_cancelled_at_departure,
147 | color: case row_number() over (order by num_services desc)
148 | when 1 then 'darkred'
149 | when 2 then 'red'
150 | when 3 then 'orange'
151 | when 4 then 'darkgreen'
152 | when 5 then 'green'
153 | else 'green' end
154 | """)
155 | ), duckdb_conn
156 |
157 |
158 |
159 | if __name__ == "__main__":
160 | main()
161 |
--------------------------------------------------------------------------------
/duckdb_streamlit/pages/railway_network_utilization.py:
--------------------------------------------------------------------------------
1 | import plotly.express as px
2 | import streamlit as st
3 |
4 | from constants import DEFAULT_LAT, DEFAULT_LNG
5 | from utils import (
6 | get_duckdb_conn,
7 | get_stations_services_query,
8 | )
9 |
10 |
11 | def main():
12 |
13 | duckdb_conn = get_duckdb_conn()
14 |
15 | with st.expander("Show railway network utilization during the year"):
16 |
17 | st.plotly_chart(get_utilization_during_year(duckdb_conn))
18 |
19 | with st.expander("Show overall railway network utilization across the country"):
20 | st.plotly_chart(get_utilization_across_country(duckdb_conn))
21 |
22 | with st.expander(
23 | "Show animation of railway network utilization across the country"
24 | ):
25 | st.plotly_chart(get_animated_utilization_across_country(duckdb_conn))
26 |
27 |
28 | def get_utilization_during_year(duckdb_conn):
29 | heatmap_df = get_stations_services_data(duckdb_conn)
30 | heatmap_df.set_index("service_day", inplace=True)
31 |
32 | fig = px.imshow(
33 | heatmap_df.to_numpy(),
34 | x=list(heatmap_df.columns),
35 | y=list(heatmap_df.index),
36 | color_continuous_scale="viridis",
37 | text_auto=".2s",
38 | aspect="auto",
39 | )
40 | fig.update_xaxes(side="top", title="Number of train rides in 2024")
41 |
42 | return fig
43 |
44 |
45 | @st.cache_data
46 | def get_stations_services_data(_duckdb_conn):
47 | query = _duckdb_conn.sql("from services").aggregate("""
48 | service_day: dayname("Service:Date"),
49 | service_day_isodow: isodow("Service:Date"),
50 | service_month: monthname("Service:Date"),
51 | num_services: count(distinct "Service:RDT-ID")
52 | """)
53 |
54 | return (
55 | _duckdb_conn.sql(f"""
56 | pivot ({query.sql_query()})
57 | on service_month
58 | using sum(num_services)
59 | group by service_day, service_day_isodow
60 | order by service_day_isodow
61 | """)
62 | .select(
63 | "January",
64 | "February",
65 | "March",
66 | "April",
67 | "May",
68 | "June",
69 | "July",
70 | "August",
71 | "September",
72 | "October",
73 | "November",
74 | "December",
75 | "service_day",
76 | )
77 | .df()
78 | )
79 |
80 | def get_utilization_across_country(duckdb_conn):
81 | stations_query, _ = get_stations_services_query(duckdb_conn)
82 | stations_agg_df = stations_query.aggregate(
83 | "geo_lat, geo_lng, num_services: sum(num_services)"
84 | ).df()
85 |
86 | return px.density_map(
87 | stations_agg_df,
88 | lat="geo_lat",
89 | lon="geo_lng",
90 | z="num_services",
91 | radius=5,
92 | center=dict(lat=DEFAULT_LAT, lon=DEFAULT_LNG),
93 | zoom=6.5,
94 | map_style="open-street-map",
95 | color_continuous_scale="viridis",
96 | range_color=[0, 100000],
97 | width=1000,
98 | height=600,
99 | title="Railway Network Utilization 2024",
100 | )
101 |
102 |
103 | def get_animated_utilization_across_country(duckdb_conn):
104 | stations_query, _ = get_stations_services_query(duckdb_conn)
105 |
106 | stations_df = stations_query.filter("month(service_date) = 7").order("service_date").df()
107 |
108 | fig = px.density_map(
109 | stations_df,
110 | lat="geo_lat",
111 | lon="geo_lng",
112 | z="num_services",
113 | radius=7,
114 | center=dict(lat=DEFAULT_LAT, lon=DEFAULT_LNG),
115 | zoom=5,
116 | map_style="open-street-map",
117 | color_continuous_scale="viridis",
118 | range_color=[0, 700],
119 | animation_frame="service_date_format",
120 | title="Railway Network Utilization, July 2024",
121 | )
122 |
123 | fig.update_layout(
124 | width=1000,
125 | height=600,
126 | sliders=[{"currentvalue": {"prefix": None, "font": {"size": 16}}}],
127 | updatemenus=[
128 | {
129 | "buttons": [
130 | {
131 | "args": [
132 | None,
133 | {
134 | "frame": {"duration": 300, "redraw": True},
135 | "fromcurrent": True,
136 | },
137 | ],
138 | "label": "Play",
139 | "method": "animate",
140 | },
141 | {
142 | "args": [
143 | [None],
144 | {
145 | "frame": {"duration": 0, "redraw": True},
146 | "mode": "immediate",
147 | "transition": {"duration": 0},
148 | },
149 | ],
150 | "label": "Stop",
151 | "method": "animate",
152 | },
153 | {
154 | "args": [
155 | None,
156 | {
157 | "frame": {"duration": 100, "redraw": True},
158 | "fromcurrent": True,
159 | },
160 | ],
161 | "label": "Speed x 3",
162 | "method": "animate",
163 | },
164 | ],
165 | }
166 | ],
167 | )
168 |
169 | return fig
170 |
171 |
172 |
173 | if __name__ == "__main__":
174 | main()
175 |
--------------------------------------------------------------------------------
/duckdb_streamlit/requirements.txt:
--------------------------------------------------------------------------------
1 | altair==5.5.0
2 | attrs==25.3.0
3 | blinker==1.9.0
4 | branca==0.8.1
5 | cachetools==5.5.2
6 | certifi==2025.1.31
7 | charset-normalizer==3.4.1
8 | click==8.1.8
9 | duckdb==1.2.1
10 | folium==0.19.5
11 | gitdb==4.0.12
12 | GitPython==3.1.44
13 | idna==3.10
14 | Jinja2==3.1.6
15 | jsonschema==4.23.0
16 | jsonschema-specifications==2024.10.1
17 | MarkupSafe==3.0.2
18 | narwhals==1.31.0
19 | numpy==2.2.4
20 | packaging==24.2
21 | pandas==2.2.3
22 | pillow==11.1.0
23 | plotly==6.0.1
24 | protobuf==5.29.4
25 | pyarrow==19.0.1
26 | pydeck==0.9.1
27 | python-dateutil==2.9.0.post0
28 | pytz==2025.1
29 | referencing==0.36.2
30 | requests==2.32.3
31 | rpds-py==0.23.1
32 | six==1.17.0
33 | smmap==5.0.2
34 | streamlit==1.43.2
35 | streamlit_folium==0.24.0
36 | tenacity==9.0.0
37 | toml==0.10.2
38 | tornado==6.4.2
39 | typing_extensions==4.12.2
40 | tzdata==2025.2
41 | urllib3==2.3.0
42 | xyzservices==2025.1.0
43 |
--------------------------------------------------------------------------------
/duckdb_streamlit/utils.py:
--------------------------------------------------------------------------------
1 | import duckdb
2 | import streamlit as st
3 |
4 |
5 | @st.cache_resource
6 | def get_duckdb_conn():
7 |     duckdb_conn = duckdb.connect()
8 |     duckdb_conn.sql(
9 |         "attach 'https://blobs.duckdb.org/nl-railway/train_stations_and_services.duckdb' as external_db"
10 |     )
11 |     duckdb_conn.sql("use external_db")
12 |     duckdb_conn.sql("install spatial")
13 |     duckdb_conn.sql("load spatial")
14 |
15 |     return duckdb_conn
16 |
17 |
18 | def get_stations_services_query(duckdb_conn):
19 |     # create a relation for the station selection
20 |     stations_selection = duckdb_conn.sql(
21 |         "select name_long as station_name, geo_lat, geo_lng, code from stations"
22 |     ).set_alias("stations_selection")
23 |
24 |     # create a relation for the services selection
25 |     services_selection = (
26 |         duckdb_conn.sql("from services")
27 |         .aggregate(
28 |             """
29 |             station_code: "Stop:Station code",
30 |             service_date: "Service:Date",
31 |             service_date_format: strftime(service_date, '%d-%b (%A)'),
32 |             num_services: count(*)
33 |             """
34 |         )
35 |         .set_alias("services")
36 |     )
37 |
38 |     # return the relation that joins stations with services, together with the duckdb_conn
39 |     return (
40 |         (
41 |             stations_selection.join(
42 |                 services_selection, "services.station_code = stations_selection.code"
43 |             ).select(
44 |                 """
45 |                 service_date,
46 |                 service_date_format,
47 |                 station_name,
48 |                 geo_lat,
49 |                 geo_lng,
50 |                 num_services
51 |                 """
52 |             )
53 |         ),
54 |         duckdb_conn,
55 |     )
56 |
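The relational calls in `get_stations_services_query` can be hard to picture without access to the remote database. As a rough illustration only, the sketch below reproduces the same query shape (count services per station and date, then join the counts to the station attributes) in plain SQL. It uses Python's built-in `sqlite3` instead of DuckDB so that it runs without extra dependencies. The rows are made-up toy data with approximate coordinates, and the explicit `group by` is an assumption about how `.aggregate(...)` behaves when no grouping expression is passed.

```python
import sqlite3

# In-memory SQLite database standing in for the attached DuckDB file (assumption).
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    create table stations (code text, name_long text, geo_lat real, geo_lng real);
    create table services ("Stop:Station code" text, "Service:Date" text);

    -- toy rows, invented for this sketch
    insert into stations values
        ('ASD', 'Amsterdam Centraal', 52.38, 4.90),
        ('UT', 'Utrecht Centraal', 52.09, 5.11);
    insert into services values
        ('ASD', '2024-01-01'),
        ('ASD', '2024-01-01'),
        ('UT', '2024-01-01'),
        ('ASD', '2024-01-02');
    """
)

# Aggregate services per station and date, then join to the station attributes.
rows = conn.execute(
    """
    select
        svc.service_date,
        st.name_long as station_name,
        st.geo_lat,
        st.geo_lng,
        svc.num_services
    from stations as st
    join (
        select
            "Stop:Station code" as station_code,
            "Service:Date" as service_date,
            count(*) as num_services
        from services
        group by 1, 2
    ) as svc on svc.station_code = st.code
    order by svc.service_date, station_name
    """
).fetchall()

for row in rows:
    print(row)
```

Each printed tuple corresponds to one station on one service date, which is the granularity the Streamlit pages receive from the DuckDB relation. The `strftime` display column is left out here because its format specifiers differ between SQLite and DuckDB.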
--------------------------------------------------------------------------------
/guides/DuckDB_in_Jupyter_notebooks.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": []
7 | },
8 | "kernelspec": {
9 | "name": "python3",
10 | "display_name": "Python 3"
11 | },
12 | "language_info": {
13 | "name": "python"
14 | },
15 | "widgets": {
16 | "application/vnd.jupyter.widget-state+json": {
17 | "0382da2cd6204467a2469a6369fb4e94": {
18 | "model_module": "@jupyter-widgets/controls",
19 | "model_name": "FloatProgressModel",
20 | "model_module_version": "1.5.0",
21 | "state": {
22 | "_dom_classes": [],
23 | "_model_module": "@jupyter-widgets/controls",
24 | "_model_module_version": "1.5.0",
25 | "_model_name": "FloatProgressModel",
26 | "_view_count": null,
27 | "_view_module": "@jupyter-widgets/controls",
28 | "_view_module_version": "1.5.0",
29 | "_view_name": "ProgressView",
30 | "bar_style": "",
31 | "description": "",
32 | "description_tooltip": null,
33 | "layout": "IPY_MODEL_c94155f26b66480680c24d57719ab398",
34 | "max": 100,
35 | "min": 0,
36 | "orientation": "horizontal",
37 | "style": "IPY_MODEL_5ff54cd86d714982936073808df23f3c",
38 | "value": 100
39 | }
40 | },
41 | "c94155f26b66480680c24d57719ab398": {
42 | "model_module": "@jupyter-widgets/base",
43 | "model_name": "LayoutModel",
44 | "model_module_version": "1.2.0",
45 | "state": {
46 | "_model_module": "@jupyter-widgets/base",
47 | "_model_module_version": "1.2.0",
48 | "_model_name": "LayoutModel",
49 | "_view_count": null,
50 | "_view_module": "@jupyter-widgets/base",
51 | "_view_module_version": "1.2.0",
52 | "_view_name": "LayoutView",
53 | "align_content": null,
54 | "align_items": null,
55 | "align_self": null,
56 | "border": null,
57 | "bottom": null,
58 | "display": null,
59 | "flex": null,
60 | "flex_flow": null,
61 | "grid_area": null,
62 | "grid_auto_columns": null,
63 | "grid_auto_flow": null,
64 | "grid_auto_rows": null,
65 | "grid_column": null,
66 | "grid_gap": null,
67 | "grid_row": null,
68 | "grid_template_areas": null,
69 | "grid_template_columns": null,
70 | "grid_template_rows": null,
71 | "height": null,
72 | "justify_content": null,
73 | "justify_items": null,
74 | "left": null,
75 | "margin": null,
76 | "max_height": null,
77 | "max_width": null,
78 | "min_height": null,
79 | "min_width": null,
80 | "object_fit": null,
81 | "object_position": null,
82 | "order": null,
83 | "overflow": null,
84 | "overflow_x": null,
85 | "overflow_y": null,
86 | "padding": null,
87 | "right": null,
88 | "top": null,
89 | "visibility": null,
90 | "width": "auto"
91 | }
92 | },
93 | "5ff54cd86d714982936073808df23f3c": {
94 | "model_module": "@jupyter-widgets/controls",
95 | "model_name": "ProgressStyleModel",
96 | "model_module_version": "1.5.0",
97 | "state": {
98 | "_model_module": "@jupyter-widgets/controls",
99 | "_model_module_version": "1.5.0",
100 | "_model_name": "ProgressStyleModel",
101 | "_view_count": null,
102 | "_view_module": "@jupyter-widgets/base",
103 | "_view_module_version": "1.2.0",
104 | "_view_name": "StyleView",
105 | "bar_color": "black",
106 | "description_width": ""
107 | }
108 | }
109 | }
110 | }
111 | },
112 | "cells": [
113 | {
114 | "cell_type": "markdown",
115 | "source": [
116 | "# DuckDB in Jupyter Notebooks\n",
117 | "A streamlined workflow for SQL analysis with DuckDB and Jupyter"
118 | ],
119 | "metadata": {
120 | "id": "vQivFMys2vtz"
121 | }
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "source": [
126 | "## Library Import and Configuration"
127 | ],
128 | "metadata": {
129 | "id": "TxtOOY905TG5"
130 | }
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {
136 | "id": "cf49_HQa2o8h",
137 | "colab": {
138 | "base_uri": "https://localhost:8080/"
139 | },
140 | "outputId": "f797204e-7fe1-4560-de7b-b8aa8d7616dd"
141 | },
142 | "outputs": [
143 | {
144 | "output_type": "stream",
145 | "name": "stdout",
146 | "text": [
147 | "\u001b[33mWARNING: Skipping malloy as it is not installed.\u001b[0m\u001b[33m\n",
148 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.1/95.1 kB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
149 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m192.8/192.8 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
150 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.0/92.0 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
151 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.7/49.7 kB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
152 | "\u001b[?25h"
153 | ]
154 | }
155 | ],
156 | "source": [
157 | "!pip uninstall --quiet --yes malloy\n",
158 | "!pip install --quiet --upgrade duckdb\n",
159 | "!pip install --quiet jupysql==0.11.1\n",
160 | "!pip install --quiet duckdb-engine\n",
161 | "!pip install --quiet pandas\n",
162 | "!pip install --quiet matplotlib\n",
163 | "\n"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "source": [
169 | "import duckdb\n",
170 | "import pandas as pd\n",
171 | "# No need to import sqlalchemy or duckdb_engine\n",
172 | "# JupySQL will use SQLAlchemy to auto-detect the driver needed based on your connection string!\n",
173 | "\n",
174 | "# Import jupysql Jupyter extension to create SQL cells\n",
175 | "%load_ext sql"
176 | ],
177 | "metadata": {
178 | "id": "MJHaFyq_3I_5"
179 | },
180 | "execution_count": null,
181 | "outputs": []
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "source": [
186 | "We configure jupysql to return data as a Pandas dataframe and have less verbose output"
187 | ],
188 | "metadata": {
189 | "id": "javkpysP6I0W"
190 | }
191 | },
192 | {
193 | "cell_type": "code",
194 | "source": [
195 | "%config SqlMagic.autopandas = True\n",
196 | "%config SqlMagic.feedback = False\n",
197 | "%config SqlMagic.displaycon = False"
198 | ],
199 | "metadata": {
200 | "id": "CvY8OgfV3ckB"
201 | },
202 | "execution_count": null,
203 | "outputs": []
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "source": [
208 | "## Connecting to DuckDB\n",
209 | "Connect jupysql to DuckDB. You may connect either to an in-memory DuckDB database or to a file-backed database."
210 | ],
211 | "metadata": {
212 | "id": "9Xq5eXmM5bUA"
213 | }
214 | },
215 | {
216 | "cell_type": "code",
217 | "source": [
218 | "# conn = duckdb.connect()\n",
219 | "# # conn = duckdb.connect(\"file.db\")\n",
220 | "\n",
221 | "# # use the DuckDB connection\n",
222 | "# %sql conn\n",
223 | "\n",
224 | "%sql duckdb:///:memory:/\n",
225 | "%sql SET python_scan_all_frames=true"
226 | ],
227 | "metadata": {
228 | "id": "8NW45gk13eoY",
229 | "colab": {
230 | "base_uri": "https://localhost:8080/",
231 | "height": 53
232 | },
233 | "outputId": "13c6c310-7340-474b-abf2-1fbe9b7c8a42"
234 | },
235 | "execution_count": null,
236 | "outputs": [
237 | {
238 | "output_type": "execute_result",
239 | "data": {
240 | "text/plain": [
241 | "Empty DataFrame\n",
242 | "Columns: [Success]\n",
243 | "Index: []"
244 | ],
356 | "application/vnd.google.colaboratory.intrinsic+json": {
357 | "type": "dataframe",
358 | "summary": "{\n \"name\": \"get_ipython()\",\n \"rows\": 0,\n \"fields\": [\n {\n \"column\": \"Success\",\n \"properties\": {\n \"dtype\": \"boolean\",\n \"num_unique_values\": 0,\n \"samples\": [],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
359 | }
360 | },
361 | "metadata": {},
362 | "execution_count": 4
363 | }
364 | ]
365 | },
366 | {
367 | "cell_type": "markdown",
368 | "source": [
369 | "## Querying DuckDB\n",
370 | "Single line SQL queries can be run using `%sql` at the start of a line. Query results will be displayed as a Pandas DF. Note the SQL syntax highlighting!"
371 | ],
372 | "metadata": {
373 | "id": "xUPJhKPH5N6D"
374 | }
375 | },
376 | {
377 | "cell_type": "code",
378 | "source": [
379 | "%sql SELECT 'Off and flying!' as a_duckdb_column"
380 | ],
381 | "metadata": {
382 | "colab": {
383 | "base_uri": "https://localhost:8080/",
384 | "height": 81
385 | },
386 | "id": "JboVd92U43VV",
387 | "outputId": "2c9e8830-f9be-4bcd-bd98-0aceb6461f15"
388 | },
389 | "execution_count": null,
390 | "outputs": [
391 | {
392 | "output_type": "execute_result",
393 | "data": {
394 | "text/plain": [
395 | " a_duckdb_column\n",
396 | "0 Off and flying!"
397 | ],
513 | "application/vnd.google.colaboratory.intrinsic+json": {
514 | "type": "dataframe",
515 | "summary": "{\n \"name\": \"get_ipython()\",\n \"rows\": 1,\n \"fields\": [\n {\n \"column\": \"a_duckdb_column\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 1,\n \"samples\": [\n \"Off and flying!\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
516 | }
517 | },
518 | "metadata": {},
519 | "execution_count": 5
520 | }
521 | ]
522 | },
523 | {
524 | "cell_type": "markdown",
525 | "source": [
526 | "An entire Jupyter cell can be used as a SQL cell by placing `%%sql` at the start of the cell. Query results will be displayed as a Pandas DF."
527 | ],
528 | "metadata": {
529 | "id": "UbWLn9rD579W"
530 | }
531 | },
532 | {
533 | "cell_type": "code",
534 | "source": [
535 | "%%sql\n",
536 | "SELECT\n",
537 | " schema_name,\n",
538 | " function_name\n",
539 | "FROM duckdb_functions()\n",
540 | "ORDER BY ALL DESC\n",
541 | "LIMIT 5"
542 | ],
543 | "metadata": {
544 | "colab": {
545 | "base_uri": "https://localhost:8080/",
546 | "height": 206
547 | },
548 | "id": "ZEOoRI-u569E",
549 | "outputId": "64f52ece-7b92-40bc-8aa3-38b3d5b8aabb"
550 | },
551 | "execution_count": null,
552 | "outputs": [
553 | {
554 | "output_type": "execute_result",
555 | "data": {
556 | "text/plain": [
557 | " schema_name function_name\n",
558 | "0 pg_catalog shobj_description\n",
559 | "1 pg_catalog pg_typeof\n",
560 | "2 pg_catalog pg_type_is_visible\n",
561 | "3 pg_catalog pg_ts_template_is_visible\n",
562 | "4 pg_catalog pg_ts_parser_is_visible"
563 | ],
828 | "application/vnd.google.colaboratory.intrinsic+json": {
829 | "type": "dataframe",
830 | "summary": "{\n \"name\": \"get_ipython()\",\n \"rows\": 5,\n \"fields\": [\n {\n \"column\": \"schema_name\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 1,\n \"samples\": [\n \"pg_catalog\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"function_name\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 5,\n \"samples\": [\n \"pg_typeof\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
831 | }
832 | },
833 | "metadata": {},
834 | "execution_count": 6
835 | }
836 | ]
837 | },
838 | {
839 | "cell_type": "markdown",
840 | "source": [
841 | "To return query results into a Pandas dataframe for future usage, use `<<` as an assignment operator. This can be used with both the `%sql` and `%%sql` Jupyter magics."
842 | ],
843 | "metadata": {
844 | "id": "8HtKdEcs6mvC"
845 | }
846 | },
847 | {
848 | "cell_type": "code",
849 | "source": [
850 | "%sql my_df << SELECT 'Off and flying!' as a_duckdb_column\n",
851 | "my_df"
852 | ],
853 | "metadata": {
854 | "colab": {
855 | "base_uri": "https://localhost:8080/",
856 | "height": 89
857 | },
858 | "id": "GQpzinPH5GvF",
859 | "outputId": "00217e80-af08-4a92-ae8a-000edae8cdc4"
860 | },
861 | "execution_count": null,
862 | "outputs": [
863 | {
864 | "output_type": "execute_result",
865 | "data": {
866 | "text/plain": [
867 | " a_duckdb_column\n",
868 | "0 Off and flying!"
869 | ],
1040 | "application/vnd.google.colaboratory.intrinsic+json": {
1041 | "type": "dataframe",
1042 | "variable_name": "my_df",
1043 | "summary": "{\n \"name\": \"my_df\",\n \"rows\": 1,\n \"fields\": [\n {\n \"column\": \"a_duckdb_column\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 1,\n \"samples\": [\n \"Off and flying!\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
1044 | }
1045 | },
1046 | "metadata": {},
1047 | "execution_count": 7
1048 | }
1049 | ]
1050 | },
1051 | {
1052 | "cell_type": "markdown",
1053 | "source": [
1054 | "## Querying Pandas Dataframes\n",
1055 | "DuckDB is able to find and query any dataframe stored as a variable in the Jupyter notebook."
1056 | ],
1057 | "metadata": {
1058 | "id": "ZHe_uG2666Zv"
1059 | }
1060 | },
1061 | {
1062 | "cell_type": "code",
1063 | "source": [
1064 | "input_df = pd.DataFrame.from_dict({\"i\":[1, 2, 3],\n",
1065 | " \"j\":[\"one\", \"two\", \"three\"]})"
1066 | ],
1067 | "metadata": {
1068 | "id": "4qgw6C644LaB"
1069 | },
1070 | "execution_count": null,
1071 | "outputs": []
1072 | },
1073 | {
1074 | "cell_type": "code",
1075 | "source": [
1076 | "duckdb.execute('''SELECT sum(i) as total_i FROM input_df''').df()"
1077 | ],
1078 | "metadata": {
1079 | "colab": {
1080 | "base_uri": "https://localhost:8080/",
1081 | "height": 81
1082 | },
1083 | "id": "ZW8NTp8VSrq3",
1084 | "outputId": "ed0e855c-2a90-4fe0-fbdb-4599b37c46bd"
1085 | },
1086 | "execution_count": null,
1087 | "outputs": [
1088 | {
1089 | "output_type": "execute_result",
1090 | "data": {
1091 | "text/plain": [
1092 | " total_i\n",
1093 | "0 6.0"
1094 | ],
1210 | "application/vnd.google.colaboratory.intrinsic+json": {
1211 | "type": "dataframe",
1212 | "summary": "{\n \"name\": \"duckdb\",\n \"rows\": 1,\n \"fields\": [\n {\n \"column\": \"total_i\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": null,\n \"min\": 6.0,\n \"max\": 6.0,\n \"num_unique_values\": 1,\n \"samples\": [\n 6.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
1213 | }
1214 | },
1215 | "metadata": {},
1216 | "execution_count": 9
1217 | }
1218 | ]
1219 | },
1220 | {
1221 | "cell_type": "code",
1222 | "source": [
1223 | "%%sql\n",
1224 | "from duckdb_settings()\n",
1225 | "where\n",
1226 | " name = 'enable_external_access'"
1227 | ],
1228 | "metadata": {
1229 | "colab": {
1230 | "base_uri": "https://localhost:8080/",
1231 | "height": 81
1232 | },
1233 | "id": "3Asn-9I1RZRE",
1234 | "outputId": "72f5754c-6919-403d-ebc0-265a094b6a23"
1235 | },
1236 | "execution_count": null,
1237 | "outputs": [
1238 | {
1239 | "output_type": "execute_result",
1240 | "data": {
1241 | "text/plain": [
1242 | " name value \\\n",
1243 | "0 enable_external_access true \n",
1244 | "\n",
1245 | " description input_type scope \n",
1246 | "0 Allow the database to access external state (t... BOOLEAN GLOBAL "
1247 | ],
1371 | "application/vnd.google.colaboratory.intrinsic+json": {
1372 | "type": "dataframe",
1373 | "summary": "{\n \"name\": \"get_ipython()\",\n \"rows\": 1,\n \"fields\": [\n {\n \"column\": \"name\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 1,\n \"samples\": [\n \"enable_external_access\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"value\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 1,\n \"samples\": [\n \"true\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"description\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 1,\n \"samples\": [\n \"Allow the database to access external state (through e.g. loading/installing modules, COPY TO/FROM, CSV readers, pandas replacement scans, etc)\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"input_type\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 1,\n \"samples\": [\n \"BOOLEAN\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"scope\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 1,\n \"samples\": [\n \"GLOBAL\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
1374 | }
1375 | },
1376 | "metadata": {},
1377 | "execution_count": 10
1378 | }
1379 | ]
1380 | },
1381 | {
1382 | "cell_type": "markdown",
1383 | "source": [
1384 | "The dataframe being queried can be specified just like any other table in the `FROM` clause."
1385 | ],
1386 | "metadata": {
1387 | "id": "siVry2OI7HwF"
1388 | }
1389 | },
1390 | {
1391 | "cell_type": "code",
1392 | "source": [
1393 | "%sql output_df << SELECT sum(i) as total_i FROM input_df\n",
1394 | "output_df"
1395 | ],
1396 | "metadata": {
1397 | "colab": {
1398 | "base_uri": "https://localhost:8080/",
1399 | "height": 89
1400 | },
1401 | "id": "uNxSRUVu4YvY",
1402 | "outputId": "dc7a6424-88e1-415a-8895-5a5792c0d36a"
1403 | },
1404 | "execution_count": null,
1405 | "outputs": [
1406 | {
1407 | "output_type": "execute_result",
1408 | "data": {
1409 | "text/plain": [
1410 | " total_i\n",
1411 | "0 6.0"
1412 | ],
1583 | "application/vnd.google.colaboratory.intrinsic+json": {
1584 | "type": "dataframe",
1585 | "variable_name": "output_df",
1586 | "summary": "{\n \"name\": \"output_df\",\n \"rows\": 1,\n \"fields\": [\n {\n \"column\": \"total_i\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": null,\n \"min\": 6.0,\n \"max\": 6.0,\n \"num_unique_values\": 1,\n \"samples\": [\n 6.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
1587 | }
1588 | },
1589 | "metadata": {},
1590 | "execution_count": 11
1591 | }
1592 | ]
1593 | },
1594 | {
1595 | "cell_type": "markdown",
1596 | "source": [
1597 | "## Visualizing DuckDB Data\n",
1598 | "The most common way to plot datasets in Python is to load them using pandas and then use matplotlib or seaborn for plotting.\n",
1599 | "This approach requires loading all data into memory, which can be highly inefficient for large datasets.\n",
1600 | "The plotting module in JupySQL runs computations in the SQL engine.\n",
1601 | "This delegates memory management to the engine and ensures that intermediate computations do not keep eating up memory, efficiently plotting massive datasets."
1602 | ],
1603 | "metadata": {
1604 | "id": "3_yCxgNIedj5"
1605 | }
1606 | },
1607 | {
1608 | "cell_type": "markdown",
1609 | "source": [
1610 | "### Install and Load DuckDB httpfs extension\n",
1611 | "DuckDB's [httpfs extension](https://duckdb.org/docs/extensions/httpfs) allows parquet and csv files to be queried remotely over http.\n",
1612 | "These examples query a parquet file that contains historical taxi data from NYC.\n",
1613 | "Using the parquet format allows DuckDB to only pull the rows and columns into memory that are needed rather than download the entire file.\n",
1614 | "DuckDB can be used to process [local parquet files as well](https://duckdb.org/docs/data/parquet), which may be desirable if querying the entire parquet file, or running multiple queries that require large subsets of the file.\n"
1615 | ],
1616 | "metadata": {
1617 | "id": "Pr_tB-phf3U3"
1618 | }
1619 | },
1620 | {
1621 | "cell_type": "code",
1622 | "source": [
1623 | "%%sql\n",
1624 | "INSTALL httpfs;\n",
1625 | "LOAD httpfs;"
1626 | ],
1627 | "metadata": {
1628 | "id": "nfjBv8ADZOwO",
1629 | "colab": {
1630 | "base_uri": "https://localhost:8080/",
1631 | "height": 85,
1632 | "referenced_widgets": [
1633 | "0382da2cd6204467a2469a6369fb4e94",
1634 | "c94155f26b66480680c24d57719ab398",
1635 | "5ff54cd86d714982936073808df23f3c"
1636 | ]
1637 | },
1638 | "outputId": "75291df2-581b-4c0e-a0d8-875180d68141"
1639 | },
1640 | "execution_count": null,
1641 | "outputs": [
1642 | {
1643 | "output_type": "display_data",
1644 | "data": {
1645 | "text/plain": [
1646 | "FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))"
1647 | ],
1648 | "application/vnd.jupyter.widget-view+json": {
1649 | "version_major": 2,
1650 | "version_minor": 0,
1651 | "model_id": "0382da2cd6204467a2469a6369fb4e94"
1652 | }
1653 | },
1654 | "metadata": {}
1655 | },
1656 | {
1657 | "output_type": "execute_result",
1658 | "data": {
1659 | "text/plain": [
1660 | "Empty DataFrame\n",
1661 | "Columns: [Success]\n",
1662 | "Index: []"
1663 | ],
1775 | "application/vnd.google.colaboratory.intrinsic+json": {
1776 | "type": "dataframe",
1777 | "summary": "{\n \"name\": \"get_ipython()\",\n \"rows\": 0,\n \"fields\": [\n {\n \"column\": \"Success\",\n \"properties\": {\n \"dtype\": \"boolean\",\n \"num_unique_values\": 0,\n \"samples\": [],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
1778 | }
1779 | },
1780 | "metadata": {},
1781 | "execution_count": 12
1782 | }
1783 | ]
1784 | },
1785 | {
1786 | "cell_type": "markdown",
1787 | "source": [
1788 | "### Boxplot & Histogram\n",
1789 | "To create a boxplot, call `%sqlplot boxplot`, passing the name of the table and the column to plot.\n",
1790 | "In this case, the name of the table is the path to the locally stored parquet file.\n",
1791 | "\n",
1792 | "**Warning** Remote locations are not supported in the `table` argument."
1793 | ],
1794 | "metadata": {
1795 | "id": "RNceigBAgF5V"
1796 | }
1797 | },
1798 | {
1799 | "cell_type": "code",
1800 | "source": [
1801 | "from urllib.request import urlretrieve\n",
1802 | "\n",
1803 | "_ = urlretrieve(\n",
1804 | " \"https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2021-01.parquet\",\n",
1805 | " \"yellow_tripdata_2021-01.parquet\",\n",
1806 | ")\n",
1807 | "\n",
1808 | "%sqlplot boxplot --table yellow_tripdata_2021-01.parquet --column trip_distance\n",
1809 | "\n",
1810 | "# delete file\n",
1811 | "!rm yellow_tripdata_2021-01.parquet"
1812 | ],
1813 | "metadata": {
1814 | "id": "DBXg6u2hYZAt",
1815 | "colab": {
1816 | "base_uri": "https://localhost:8080/",
1817 | "height": 452
1818 | },
1819 | "outputId": "7a0f2199-8f2a-4a07-ccc5-fe0a418cf2ef"
1820 | },
1821 | "execution_count": null,
1822 | "outputs": [
1823 | {
1824 | "output_type": "display_data",
1825 | "data": {
1826 | "text/plain": [
1827 | ""
1828 | ],
1829 | "image/png": "\n"
1830 | },
1831 | "metadata": {}
1832 | }
1833 | ]
1834 | },
1835 | {
1836 | "cell_type": "markdown",
1837 | "source": [
1838 | "Now, create a query that filters out trips above the 90th percentile of `trip_distance`.\n",
1839 | "Note the use of the `--save` and `--no-execute` options.\n",
1840 | "This tells JupySQL to store the query, but skips execution. It will be referenced in the next plotting call."
1841 | ],
1842 | "metadata": {
1843 | "id": "CzZe8VgygJot"
1844 | }
1845 | },
1846 | {
1847 | "cell_type": "code",
1848 | "source": [
1849 | "%%sql --save short_trips --no-execute\n",
1850 | "SELECT *\n",
1851 | "FROM 'https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2021-01.parquet'\n",
1852 | "WHERE trip_distance < 6.3"
1853 | ],
1854 | "metadata": {
1855 | "id": "j9QKSvUWYn0K",
1856 | "colab": {
1857 | "base_uri": "https://localhost:8080/",
1858 | "height": 34
1859 | },
1860 | "outputId": "48dc5446-5229-4818-89df-3bae1da81bf3"
1861 | },
1862 | "execution_count": null,
1863 | "outputs": [
1864 | {
1865 | "output_type": "display_data",
1866 | "data": {
1867 | "text/plain": [
1868 | "Skipping execution..."
1869 | ],
1870 | "text/html": [
1871 | "Skipping execution..."
1872 | ]
1873 | },
1874 | "metadata": {}
1875 | }
1876 | ]
1877 | },
1878 | {
1879 | "cell_type": "markdown",
1880 | "source": [
1881 | "To create a histogram, call `%sqlplot histogram` and pass the name of the table, the column to plot, and the number of bins.\n",
1882 | "This uses `--with short_trips` so JupySQL uses the query defined previously and therefore only plots a subset of the data."
1883 | ],
1884 | "metadata": {
1885 | "id": "Dbw06QARgNiX"
1886 | }
1887 | },
1888 | {
1889 | "cell_type": "code",
1890 | "source": [
1891 | "%sqlplot histogram --table short_trips --column trip_distance --bins 10 --with short_trips"
1892 | ],
1893 | "metadata": {
1894 | "id": "6rgP4x-NYpcE",
1895 | "colab": {
1896 | "base_uri": "https://localhost:8080/",
1897 | "height": 490
1898 | },
1899 | "outputId": "4943cbb0-e3ea-4960-f5ac-3ffa76574144"
1900 | },
1901 | "execution_count": null,
1902 | "outputs": [
1903 | {
1904 | "output_type": "execute_result",
1905 | "data": {
1906 | "text/plain": [
1907 | ""
1908 | ]
1909 | },
1910 | "metadata": {},
1911 | "execution_count": 15
1912 | },
1913 | {
1914 | "output_type": "display_data",
1915 | "data": {
1916 | "text/plain": [
1917 | ""
1918 | ],
1919 | "image/png": "\n"
1920 | },
1921 | "metadata": {}
1922 | }
1923 | ]
1924 | },
1925 | {
1926 | "cell_type": "markdown",
1927 | "source": [
1928 | "## Summary\n",
1929 | "You now have the ability to alternate between SQL and Pandas in a simple and highly performant way! You can plot massive datasets directly through the engine (avoiding both the download of the entire file and loading all of it into Pandas in memory). Dataframes can be read as tables in SQL, and SQL results can be output into Dataframes. Happy analyzing!"
1930 | ],
1931 | "metadata": {
1932 | "id": "exzkl7g47jja"
1933 | }
1934 | }
1935 | ]
1936 | }
--------------------------------------------------------------------------------
/scikit_learn_duckdb/.gitignore:
--------------------------------------------------------------------------------
1 | venv_scikit_learn_duckdb/
2 | .tmp/
3 | model/*.dot
4 | model/*.sav
5 | __marimo__/
--------------------------------------------------------------------------------
/scikit_learn_duckdb/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: setup-python run-marimo
2 |
3 | setup-python:
4 | 	python -m venv venv_scikit_learn_duckdb && \
5 | 	source venv_scikit_learn_duckdb/bin/activate && \
6 | 	pip install -r requirements.txt
7 |
8 | run-marimo:
9 | 	source venv_scikit_learn_duckdb/bin/activate && \
10 | 	marimo edit predict_penguin_species.py
--------------------------------------------------------------------------------
/scikit_learn_duckdb/README.md:
--------------------------------------------------------------------------------
1 | Pre-requisites: make, Python >= 3.12.
2 |
3 | # Setup Python env
4 | 1. Execute `make setup-python`
5 |
6 | # Run marimo notebook
7 | **Warning** The first run takes ~30 seconds to import `scikit-learn`
8 | 1. Execute `make run-marimo`
9 |
--------------------------------------------------------------------------------
/scikit_learn_duckdb/model/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/duckdb/duckdb-blog-examples/f8b3dece7acf4d48c86a8bc55199bc96d4f66bf4/scikit_learn_duckdb/model/.gitkeep
--------------------------------------------------------------------------------
/scikit_learn_duckdb/predict_penguin_species.py:
--------------------------------------------------------------------------------
1 | import marimo
2 |
3 | __generated_with = "0.13.6"
4 | app = marimo.App(width="medium")
5 |
6 |
7 | @app.cell
8 | def _():
9 | import marimo as mo
10 | return (mo,)
11 |
12 |
13 | @app.cell
14 | def _():
15 | import logging
16 | import pickle
17 |
18 | from datetime import datetime
19 | from decimal import Decimal
20 |
21 | import duckdb
22 | import numpy as np
23 | import orjson
24 | import plotly.express as px
25 | from sklearn.ensemble import RandomForestClassifier
26 | from sklearn.model_selection import train_test_split
27 | return (
28 | Decimal,
29 | RandomForestClassifier,
30 | datetime,
31 | duckdb,
32 | np,
33 | orjson,
34 | pickle,
35 | px,
36 | train_test_split,
37 | )
38 |
39 |
40 | @app.cell
41 | def _(duckdb):
42 |     # read the CSV data from an external location, exclude records with 'NA' or null values, add an observation id, and alter the column types
43 |     def process_palmerpenguins_data(duckdb_conn):
44 |         duckdb_conn.read_csv(
45 |             "http://blobs.duckdb.org/data/penguins.csv"
46 |         ).filter("columns(*)::text != 'NA'").filter("columns(*) is not null").select(
47 |             "*, row_number() over () as observation_id"
48 |         ).to_table(
49 |             "penguins_data"
50 |         )
51 |
52 |         duckdb_conn.sql(
53 |             "alter table penguins_data alter bill_length_mm set data type decimal(5, 2)"
54 |         )
55 |         duckdb_conn.sql(
56 |             "alter table penguins_data alter bill_depth_mm set data type decimal(5, 2)"
57 |         )
58 |         duckdb_conn.sql("alter table penguins_data alter body_mass_g set data type integer")
59 |         duckdb_conn.sql(
60 |             "alter table penguins_data alter flipper_length_mm set data type integer"
61 |         )
62 |
63 |     duckdb_conn = duckdb.connect()
64 |
65 |     process_palmerpenguins_data(duckdb_conn=duckdb_conn)
66 |
67 |     return (duckdb_conn,)
68 |
69 |
70 | @app.cell
71 | def _(duckdb_conn, px):
72 |     # plot species and island
73 |     px.bar(
74 |         duckdb_conn.table("penguins_data").aggregate(
75 |             "species, island, count(*) as number_of_observations").order("island, species").df(),
76 |         x="island",
77 |         y="number_of_observations",
78 |         color="species",
79 |         title="Palmer Penguins Observations",
80 |         barmode="group",
81 |         labels={
82 |             "number_of_observations": "Number of Observations",
83 |             "island": "Island"
84 |         }
85 |     )
86 |     return
87 |
88 |
89 | @app.cell
90 | def _(duckdb_conn, px):
91 |     # plot features per species
92 |     px.scatter(
93 |         duckdb_conn.table("penguins_data").df(),
94 |         x="bill_length_mm",
95 |         y="bill_depth_mm",
96 |         size="body_mass_g",
97 |         color="species",
98 |         title="Penguins Observations, bill length and depth, per species",
99 |         labels={
100 |             "bill_length_mm": "Bill Length in mm",
101 |             "bill_depth_mm": "Bill Depth in mm"
102 |         }
103 |     )
104 |     return
105 |
106 |
107 | @app.cell
108 | def _(duckdb_conn):
109 |     # analyze the data
110 |     duckdb_conn.table("penguins_data").describe().df()
111 |     return
112 |
113 |
114 | @app.cell
115 | def _(duckdb_conn):
116 |     # instead of label encoding, we create reference tables
117 |     def process_reference_data(duckdb_conn):
118 |         for feature in ["species", "island"]:
119 |             duckdb_conn.sql(f"drop table if exists {feature}_ref")
120 |             (
121 |                 duckdb_conn.table("penguins_data")
122 |                 .select(feature)
123 |                 .unique(feature)
124 |                 .row_number(
125 |                     window_spec=f"over (order by {feature})", projected_columns=feature
126 |                 )
127 |                 .select(f"{feature}, #2 - 1 as {feature}_id")
128 |                 .to_table(f"{feature}_ref")
129 |             )
130 |             duckdb_conn.table(f"{feature}_ref").show()
131 |
132 |     process_reference_data(duckdb_conn)
133 |
134 |     return
135 |
136 |
137 | @app.cell
138 | def _(train_test_split):
139 |     def train_split_data(selection_query):
140 |         X_df = selection_query.select(
141 |             "bill_length_mm, bill_depth_mm, flipper_length_mm, body_mass_g, island_id, observation_id, species_id"
142 |         ).order("observation_id").df()
143 |         y_df = [
144 |             x[0]
145 |             for x in selection_query.order("observation_id").select("species_id").fetchall()
146 |         ]
147 |
148 |         num_test = 0.30
149 |         return train_test_split(X_df, y_df, test_size=num_test)
150 |     return (train_split_data,)
151 |
152 |
153 | @app.cell
154 | def _(RandomForestClassifier, pickle, train_split_data):
155 |     def get_model(selection_query):
156 |         X_train, X_test, y_train, y_test = train_split_data(selection_query)
157 |
158 |         model = RandomForestClassifier(n_estimators=1, max_depth=2, random_state=5)
159 |
160 |         model.fit(X_train.drop(["observation_id", "species_id"], axis=1).values, y_train)
161 |
162 |         with open("./model/penguin_model.sav", "wb") as model_file:
163 |             pickle.dump(model, model_file)
164 |
165 |         accuracy = model.score(X_test.drop(["observation_id", "species_id"], axis=1).values, y_test)
166 |         print(f"Accuracy score is: {accuracy}")
167 |     return (get_model,)
168 |
169 |
170 | @app.cell
171 | def _(duckdb_conn, get_model, pickle):
172 |     selection_query = (
173 |         duckdb_conn.table("penguins_data")
174 |         .join(duckdb_conn.table("island_ref"), condition="island")
175 |         .join(duckdb_conn.table("species_ref"), condition="species")
176 |     )
177 |
178 |     get_model(selection_query)
179 |     with open("./model/penguin_model.sav", "rb") as model_file:
180 |         model = pickle.load(model_file)
181 |     return model, selection_query
182 |
183 |
184 | @app.cell
185 | def _(duckdb_conn, model, selection_query):
186 |     # get predictions with pandas and duckdb in python
187 |
188 |     predicted_df = selection_query.select(
189 |         "bill_length_mm, bill_depth_mm, flipper_length_mm, body_mass_g, island_id, observation_id, species_id"
190 |     ).df()
191 |
192 |     predicted_df["predicted_species_id"] = model.predict(
193 |         predicted_df.drop(["observation_id", "species_id"], axis=1).values
194 |     )
195 |
196 |     (
197 |         duckdb_conn.table("predicted_df")
198 |         .select("observation_id", "species_id", "predicted_species_id")
199 |         .filter("species_id != predicted_species_id")
200 |     )
201 |
202 |     return (predicted_df,)
203 |
204 |
205 | @app.cell
206 | def _(duckdb_conn, mo, predicted_df):
207 |     _df = mo.sql(
208 |         f"""
209 |         -- directly with SQL
210 |
211 |         select observation_id, species_id, predicted_species_id
212 |         from predicted_df
213 |         where species_id != predicted_species_id
214 |         """,
215 |         engine=duckdb_conn
216 |     )
217 |     return
218 |
219 |
220 | @app.cell
221 | def _(Decimal, duckdb_conn, pickle, selection_query):
222 |     # get predictions with duckdb udf, row by row
223 |
224 |     def get_prediction_per_row(
225 |         bill_length_mm: Decimal, bill_depth_mm: Decimal, flipper_length_mm: int, body_mass_g: int, island_id: int
226 |     ) -> int:
227 |         # the model is unpickled on every call, i.e. once per row
228 |         with open("./model/penguin_model.sav", "rb") as model_file:
229 |             model = pickle.load(model_file)
230 |
231 |         features = [
232 |             bill_length_mm,
233 |             bill_depth_mm,
234 |             flipper_length_mm,
235 |             body_mass_g,
236 |             island_id,
237 |         ]
238 |         return int(
239 |             model.predict([features])[0]
240 |         )
241 |
242 |     try:
243 |         duckdb_conn.remove_function("predict_species_per_row")
244 |     except Exception:
245 |         pass
246 |     finally:
247 |         duckdb_conn.create_function(
248 |             "predict_species_per_row", get_prediction_per_row, return_type=int
249 |         )
250 |
251 |     selection_query.select(
252 |         """
253 |         observation_id,
254 |         species_id,
255 |         predict_species_per_row(
256 |             bill_length_mm,
257 |             bill_depth_mm,
258 |             flipper_length_mm,
259 |             body_mass_g,
260 |             island_id
261 |         ) as predicted_species_id
262 |         """
263 |     ).filter("species_id != predicted_species_id")
264 |     return
265 |
266 |
267 | @app.cell
268 | def _(Decimal, datetime, duckdb, duckdb_conn, np, orjson, pickle):
269 |     # get predictions with duckdb udf, full / batch style
270 |
271 |     def get_prediction_per_batch(input_data: dict[str, list[Decimal | int]]) -> np.ndarray:
272 |         """
273 |         input_data example:
274 |         {
275 |             "bill_length_mm": [40.5],
276 |             "bill_depth_mm": [41.5],
277 |             "flipper_length_mm": [250],
278 |             "body_mass_g": [3000],
279 |             "island_id": [1]
280 |         }
281 |         """
282 |         with open("./model/penguin_model.sav", "rb") as model_file:
283 |             model = pickle.load(model_file)
284 |         st_dt = datetime.now()
285 |
286 |         input_data_parsed = orjson.loads(input_data)
287 |
288 |         print(f"JSON parsing took {(datetime.now() - st_dt).total_seconds()} seconds")
289 |
290 |         st_dt = datetime.now()
291 |
292 |         input_data_converted_to_numpy = np.stack(tuple(input_data_parsed.values()), axis=1)
293 |
294 |         print(f"Converting to numpy took {(datetime.now() - st_dt).total_seconds()} seconds")
295 |
296 |         return model.predict(input_data_converted_to_numpy)
297 |
298 |     try:
299 |         duckdb_conn.remove_function("predict_species_per_batch")
300 |     except Exception:
301 |         pass
302 |     finally:
303 |         duckdb_conn.create_function(
304 |             "predict_species_per_batch",
305 |             get_prediction_per_batch,
306 |             return_type=duckdb.typing.DuckDBPyType(list[int]),
307 |         )
308 |
309 |
310 |     def get_selection_query_for_batch(selection_query):
311 |         return (
312 |             selection_query
313 |             .aggregate("""
314 |                 json_object(
315 |                     'bill_length_mm', array_agg(bill_length_mm),
316 |                     'bill_depth_mm', array_agg(bill_depth_mm),
317 |                     'flipper_length_mm', array_agg(flipper_length_mm),
318 |                     'body_mass_g', array_agg(body_mass_g),
319 |                     'island_id', array_agg(island_id)
320 |                 ) as input_data,
321 |                 struct_pack(
322 |                     observation_id := array_agg(observation_id),
323 |                     species_id := array_agg(species_id),
324 |                     predicted_species_id := predict_species_per_batch(input_data)
325 |                 ) as output_data
326 |             """)
327 |             .select("""
328 |                 unnest(output_data.observation_id) as observation_id,
329 |                 unnest(output_data.species_id) as species_id,
330 |                 unnest(output_data.predicted_species_id) as predicted_species_id
331 |             """)
332 |         )
333 |
334 |     return (get_selection_query_for_batch,)
335 |
336 |
337 | @app.cell
338 | def _(get_selection_query_for_batch, selection_query):
339 |     # mass retrieval
340 |     get_selection_query_for_batch(selection_query).filter("species_id != predicted_species_id").show()
341 |
342 |     return
343 |
344 |
345 | @app.cell
346 | def _(get_selection_query_for_batch, selection_query):
347 |     # batch style
348 |     for i in range(4):
349 |         (
350 |             get_selection_query_for_batch(
351 |                 selection_query
352 |                 .order("observation_id")
353 |                 .limit(100, offset=100*i)
354 |                 .select("*")
355 |             )
356 |             .filter("species_id != predicted_species_id")
357 |         ).show()
358 |     return
359 |
360 |
361 | @app.cell
362 | def _(duckdb_conn, selection_query):
363 |     def generate_dummy_data(duckdb_conn, selection_query):
364 |         duckdb_conn.sql("drop table if exists dummy_generated_data")
365 |         selection_query.filter("1 = 0").select(
366 |             "bill_length_mm, bill_depth_mm, flipper_length_mm, body_mass_g, island_id, observation_id, species_id, species"
367 |         ).to_table("dummy_generated_data")
368 |
369 |         for idx, rec in enumerate(
370 |             selection_query.aggregate("""
371 |                 island_id,
372 |                 species_id,
373 |                 min(bill_length_mm)::int as min_bill_length_mm,
374 |                 max(bill_length_mm)::int as max_bill_length_mm,
375 |                 min(bill_depth_mm)::int as min_bill_depth_mm,
376 |                 max(bill_depth_mm)::int as max_bill_depth_mm,
377 |                 min(flipper_length_mm) as min_flipper_length_mm,
378 |                 max(flipper_length_mm) as max_flipper_length_mm,
379 |                 min(body_mass_g) as min_body_mass_g,
380 |                 max(body_mass_g) as max_body_mass_g
381 |             """).fetchall()
382 |         ):
383 |             bill_length_range = duckdb_conn.sql(f"from range({rec[2]}, {rec[3]})").select(
384 |                 "range as bill_length"
385 |             )
386 |
387 |             bill_depth_range = duckdb_conn.sql(f"from range({rec[4]}, {rec[5]})").select(
388 |                 "range as bill_depth"
389 |             )
390 |
391 |             flipper_length_range = duckdb_conn.sql(
392 |                 f"from range({rec[6]}, {rec[7]})"
393 |             ).select("range as flipper_length")
394 |
395 |             body_mass_range = duckdb_conn.sql(f"from range({rec[8]}, {rec[9]})").select(
396 |                 "range as body_mass"
397 |             )
398 |
399 |             dummy_range = duckdb_conn.sql("from range(1,10)").set_alias(
400 |                 "dummy_range"
401 |             )
402 |
403 |             sql_query = (
404 |                 dummy_range.join(bill_length_range, condition="1 = 1")
405 |                 .join(bill_depth_range, condition="1 = 1")
406 |                 .join(flipper_length_range, condition="1 = 1")
407 |                 .join(body_mass_range, condition="1 = 1")
408 |                 .join(duckdb_conn.table("species_ref"), condition=f"species_id = {rec[1]}")
409 |                 .select(
410 |                     f"""
411 |                     bill_length + 10 ** 1/range as bill_length_mm,
412 |                     bill_depth + 10 ** 1/range as bill_depth_mm,
413 |                     flipper_length + (10 ** 1/range)::int as flipper_length_mm,
414 |                     body_mass + (10 ** 1/range)::int as body_mass_g,
415 |                     {rec[0]} as island_id,
416 |                     null as observation_id,
417 |                     species_id,
418 |                     species
419 |                     """
420 |                 )
421 |             ).sql_query()
422 |
423 |             duckdb_conn.sql(f"select * from ({sql_query}) using sample 30%").insert_into("dummy_generated_data")
424 |
425 |     generate_dummy_data(duckdb_conn, selection_query)
426 |
427 |     duckdb_conn.table("dummy_generated_data").count("*")
428 |
429 |     return
430 |
431 |
432 | @app.cell
433 | def _(duckdb_conn):
434 |     duckdb_conn.sql("select * from dummy_generated_data using sample 10%").to_table("sample_dummy_data")
435 |     return
436 |
437 |
438 | @app.cell
439 | def _(duckdb_conn, get_selection_query_for_batch):
440 |     (
441 |         get_selection_query_for_batch(duckdb_conn.table("sample_dummy_data"))
442 |         .aggregate("""
443 |             sum(if(species_id = predicted_species_id, 1, 0)) number_of_correct_predictions,
444 |             sum(if(species_id = predicted_species_id, 0, 1)) number_of_incorrect_predictions
445 |         """)
446 |     )
447 |     return
448 |
449 |
450 | @app.cell
451 | def _(duckdb_conn, model):
452 |     predicted_dummy_df = duckdb_conn.table("sample_dummy_data").select(
453 |         "bill_length_mm, bill_depth_mm, flipper_length_mm, body_mass_g, island_id, observation_id, species_id"
454 |     ).df()
455 |
456 |     predicted_dummy_df["predicted_species_id"] = model.predict(
457 |         predicted_dummy_df.drop(["observation_id", "species_id"], axis=1).values
458 |     )
459 |
460 |     (
461 |         duckdb_conn.table("predicted_dummy_df")
462 |         .select("observation_id", "species_id", "predicted_species_id")
463 |         .aggregate("""
464 |             sum(if(species_id = predicted_species_id, 1, 0)) number_of_correct_predictions,
465 |             sum(if(species_id = predicted_species_id, 0, 1)) number_of_incorrect_predictions
466 |         """)
467 |     )
468 |     return
469 |
470 |
471 | if __name__ == "__main__":
472 |     app.run()
473 |
--------------------------------------------------------------------------------
/scikit_learn_duckdb/requirements.txt:
--------------------------------------------------------------------------------
1 | anyio==4.9.0
2 | asttokens==3.0.0
3 | click==8.1.8
4 | decorator==5.2.1
5 | docutils==0.21.2
6 | duckdb==1.2.2
7 | executing==2.2.0
8 | h11==0.16.0
9 | idna==3.10
10 | ipython==9.2.0
11 | ipython_pygments_lexers==1.1.1
12 | itsdangerous==2.2.0
13 | jedi==0.19.2
14 | joblib==1.4.2
15 | marimo==0.13.6
16 | Markdown==3.8
17 | matplotlib-inline==0.1.7
18 | narwhals==1.37.1
19 | numpy==2.2.5
20 | orjson==3.10.18
21 | packaging==25.0
22 | pandas==2.2.3
23 | parso==0.8.4
24 | pexpect==4.9.0
25 | plotly==6.0.1
26 | polars==1.29.0
27 | prompt_toolkit==3.0.51
28 | psutil==7.0.0
29 | ptyprocess==0.7.0
30 | pure_eval==0.2.3
31 | pyarrow==20.0.0
32 | pycrdt==0.11.1
33 | Pygments==2.19.1
34 | pymdown-extensions==10.15
35 | python-dateutil==2.9.0.post0
36 | pytz==2025.2
37 | PyYAML==6.0.2
38 | scikit-learn==1.6.1
39 | scipy==1.15.2
40 | six==1.17.0
41 | sniffio==1.3.1
42 | sqlglot==26.16.4
43 | stack-data==0.6.3
44 | starlette==0.46.2
45 | threadpoolctl==3.6.0
46 | tomlkit==0.13.2
47 | traitlets==5.14.3
48 | typing_extensions==4.13.2
49 | tzdata==2025.2
50 | uvicorn==0.34.2
51 | wcwidth==0.2.13
52 | websockets==15.0.1
53 |
--------------------------------------------------------------------------------