12 |
13 | {!../README.md!}
14 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_0var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('simple_matrix'),
9 | endog='y',
10 | exog=[],
11 | output='wide',
12 | output_options={'round': 5}
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_1var_without_const.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('collinear_matrix'),
9 | endog='y',
10 | exog=['x1'],
11 | output='long',
12 | add_constant=false
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_0var_regression_long_chol.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('simple_matrix'),
9 | endog='y',
10 | exog=[],
11 | output='long',
12 | output_options={'round': 5},
13 | method='chol'
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_0var_regression_long_fwl.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('simple_matrix'),
9 | endog='y',
10 | exog=[],
11 | output='long',
12 | output_options={'round': 5},
13 | method='fwl'
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_1var_regression_long_fwl.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('simple_matrix'),
9 | endog='y',
10 | exog=['xa'],
11 | output='long',
12 | output_options={'round': 5},
13 | method='fwl'
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_1var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('simple_matrix'),
9 | endog='y',
10 | exog=['xa'],
11 | output='wide',
12 | output_options={'round': 5}
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_wide_format_options.sql:
--------------------------------------------------------------------------------
1 | with base as (
2 |
3 | select
4 | fooconstant_term_bar,
5 | "fooxa_bar",
6 | fooxb_bar
7 | from
8 | {{ ref("wide_format_options") }}
9 |
10 | )
11 |
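12 | /* This test passes as long as the query runs without error: selecting the columns by name confirms that the prefix, suffix, and constant name options were applied. */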
13 | select * from base where false
14 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_2var_without_const.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('collinear_matrix'),
9 | endog='y',
10 | exog=['x1', 'x2'],
11 | output='long',
12 | add_constant=false
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_1var_regression_long_chol.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('simple_matrix'),
9 | endog='y',
10 | exog=['xa'],
11 | output='long',
12 | output_options={'round': 5},
13 | method='chol'
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_2var_regression_long.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('simple_matrix'),
9 | endog='y',
10 | exog=['xa', 'xb'],
11 | output='long',
12 | output_options={'round': 5}
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_2var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('simple_matrix'),
9 | endog='y',
10 | exog=['xa', 'xb'],
11 | output='wide',
12 | output_options={'round': 5}
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_3var_without_const.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('collinear_matrix'),
9 | endog='y',
10 | exog=['x1', 'x2', 'x3'],
11 | output='long',
12 | add_constant=false
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_3var_regression_long.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('simple_matrix'),
9 | endog='y',
10 | exog=['xa', 'xb', 'xc'],
11 | output='long',
12 | output_options={'round': 5}
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_regression_fwl.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('collinear_matrix'),
9 | endog='y',
10 | exog=['x1', 'x2', 'x3', 'x4', 'x5'],
11 | output='long',
12 | method='fwl'
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_4var_regression_long.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('simple_matrix'),
9 | endog='y',
10 | exog=['xa', 'xb', 'xc', 'xd'],
11 | output='long',
12 | output_options={'round': 5}
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_4var_without_const.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('collinear_matrix'),
9 | endog='y',
10 | exog=['x1', 'x2', 'x3', 'x4'],
11 | output='long',
12 | add_constant=false
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_1var_without_const_ridge.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('collinear_matrix'),
9 | endog='y',
10 | exog=['x1'],
11 | alpha=2.0,
12 | output='long',
13 | add_constant=false
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_5var_without_const.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('collinear_matrix'),
9 | endog='y',
10 | exog=['x1', 'x2', 'x3', 'x4', 'x5'],
11 | output='long',
12 | add_constant=false
13 | )
14 | }} as linreg
15 |
--------------------------------------------------------------------------------
/dbt_project.yml:
--------------------------------------------------------------------------------
1 | name: "dbt_linreg"
2 | version: "0.3.1"
3 |
4 | # dbt >=1.2 is required because of modules.itertools.
5 | require-dbt-version: [">=1.2.0", "<2.0.0"]
6 |
7 | config-version: 2
8 |
9 | target-path: "target"
10 | clean-targets: ["target", "dbt_modules", "dbt_packages"]
11 | macro-paths: ["macros"]
12 | log-path: "logs"
13 | profile: "dbt_linreg_profile"
14 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_ridge_regression_fwl.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('collinear_matrix'),
9 | endog='y',
10 | exog=['x1', 'x2', 'x3', 'x4', 'x5'],
11 | output='long',
12 | alpha=0.01,
13 | method='fwl'
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_3var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table",
4 | tags=["skip-postgres"]
5 | )
6 | }}
7 | select * from {{
8 | dbt_linreg.ols(
9 | table=ref('simple_matrix'),
10 | endog='y',
11 | exog=['xa', 'xb', 'xc'],
12 | output='wide',
13 | output_options={'round': 5}
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_regression_chol.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table",
4 | tags=["skip-postgres"]
5 | )
6 | }}
7 | select * from {{
8 | dbt_linreg.ols(
9 | table=ref('collinear_matrix'),
10 | endog='y',
11 | exog=['x1', 'x2', 'x3', 'x4', 'x5'],
12 | output='long',
13 | method='chol'
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_4var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table",
4 | tags=["skip-postgres"]
5 | )
6 | }}
7 | select * from {{
8 | dbt_linreg.ols(
9 | table=ref('simple_matrix'),
10 | endog='y',
11 | exog=['xa', 'xb', 'xc', 'xd'],
12 | output='wide',
13 | output_options={'round': 5}
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_5var_regression_long.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table",
4 | tags=["skip-postgres"]
5 | )
6 | }}
7 | select * from {{
8 | dbt_linreg.ols(
9 | table=ref('simple_matrix'),
10 | endog='y',
11 | exog=['xa', 'xb', 'xc', 'xd', 'xe'],
12 | output='long',
13 | output_options={'round': 5}
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_5var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table",
4 | tags=["skip-postgres"]
5 | )
6 | }}
7 | select * from {{
8 | dbt_linreg.ols(
9 | table=ref('simple_matrix'),
10 | endog='y',
11 | exog=['xa', 'xb', 'xc', 'xd', 'xe'],
12 | output='wide',
13 | output_options={'round': 5}
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/groups_matrix_regression_fwl.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select * from {{
7 | dbt_linreg.ols(
8 | table=ref('groups_matrix'),
9 | endog='y',
10 | exog=['x1', 'x2', 'x3'],
11 | group_by=['gb_var'],
12 | output='long',
13 | method='fwl'
14 | )
15 | }} as linreg
16 | order by gb_var, variable_name
17 |
--------------------------------------------------------------------------------
/integration_tests/models/perfectly_multicollinear_model.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | with base as (
7 | select
8 | y,
9 | xa,
10 | xa as xb
11 | from {{ ref('simple_matrix') }}
12 | )
13 |
14 | select * from {{
15 | dbt_linreg.ols(
16 | table='base',
17 | endog='y',
18 | exog=['xa', 'xb']
19 | )
20 | }} as linreg
21 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_8var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="view",
4 | tags=["perftest", "skip-postgres"],
5 | enabled=False,
6 | )
7 | }}
8 | select * from {{
9 | dbt_linreg.ols(
10 | table=ref('simple_matrix'),
11 | endog='y',
12 | exog=['xa', 'xb', 'xc', 'xd', 'xe', 'xf', 'xg', 'xh'],
13 | output='wide',
14 | )
15 | }} as linreg
16 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_ridge_regression_chol.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table",
4 | tags=["skip-postgres"]
5 | )
6 | }}
7 | select * from {{
8 | dbt_linreg.ols(
9 | table=ref('collinear_matrix'),
10 | endog='y',
11 | exog=['x1', 'x2', 'x3', 'x4', 'x5'],
12 | output='long',
13 | alpha=0.01,
14 | method='chol'
15 | )
16 | }} as linreg
17 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_5var_without_const_ridge.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table",
4 | tags=["skip-postgres"]
5 | )
6 | }}
7 | select * from {{
8 | dbt_linreg.ols(
9 | table=ref('collinear_matrix'),
10 | endog='y',
11 | exog=['x1', 'x2', 'x3', 'x4', 'x5'],
12 | alpha=1.0,
13 | output='long',
14 | add_constant=false
15 | )
16 | }} as linreg
17 |
--------------------------------------------------------------------------------
/integration_tests/models/simple_10var_regression_long.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="view",
4 | enabled=False,
5 | tags=["skip-postgres"]
6 | )
7 | }}
8 | select * from {{
9 | dbt_linreg.ols(
10 | table=ref('simple_matrix'),
11 | endog='y',
12 | exog=['xa', 'xb', 'xc', 'xd', 'xe', 'xf', 'xg', 'xh', 'xi', 'xj'],
13 | output='long',
14 | output_options={'round': 5}
15 | )
16 | }} as linreg
17 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_regression_chol_unoptimized.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table",
4 | tags=["skip-postgres", "skip-clickhouse"]
5 | )
6 | }}
7 | select * from {{
8 | dbt_linreg.ols(
9 | table=ref('collinear_matrix'),
10 | endog='y',
11 | exog=['x1', 'x2', 'x3', 'x4', 'x5'],
12 | output='long',
13 | method='chol',
14 | method_options={'subquery_optimization': False}
15 | )
16 | }} as linreg
17 |
--------------------------------------------------------------------------------
/docs/src/css/extra.css:
--------------------------------------------------------------------------------
1 | [data-md-color-scheme="light"] img[src$="#only-dark"],
2 | [data-md-color-scheme="light"] img[src$="#gh-dark-mode-only"] {
3 | display: none; /* Hide dark images in light mode */
4 | }
5 |
6 | [data-md-color-scheme="dark"] img[src$="#only-light"],
7 | [data-md-color-scheme="dark"] img[src$="#gh-light-mode-only"] {
8 | display: none; /* Hide light images in dark mode */
9 | }
10 |
11 | img[src$="#readme-logo"] {
12 | display: none;
13 | }
14 |
--------------------------------------------------------------------------------
/integration_tests/models/collinear_matrix_ridge_regression_chol_unoptimized.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table",
4 | tags=["skip-postgres", "skip-clickhouse"]
5 | )
6 | }}
7 | select * from {{
8 | dbt_linreg.ols(
9 | table=ref('collinear_matrix'),
10 | endog='y',
11 | exog=['x1', 'x2', 'x3', 'x4', 'x5'],
12 | output='long',
13 | alpha=0.01,
14 | method='chol',
15 | method_options={'subquery_optimization': False}
16 | )
17 | }} as linreg
18 |
--------------------------------------------------------------------------------
/integration_tests/models/groups_matrix_regression_chol_optimized.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table",
4 | tags=["skip-postgres"]
5 | )
6 | }}
7 | select * from {{
8 | dbt_linreg.ols(
9 | table=ref('groups_matrix'),
10 | endog='y',
11 | exog=['x1', 'x2', 'x3'],
12 | group_by=['gb_var'],
13 | output='long',
14 | method='chol',
15 | method_options={'subquery_optimization': True}
16 | )
17 | }} as linreg
18 | order by gb_var, variable_name
19 |
--------------------------------------------------------------------------------
/integration_tests/models/wide_format_options.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select
7 | *
8 | from {{
9 | dbt_linreg.ols(
10 | table=ref('simple_matrix'),
11 | endog='y',
12 | exog=['"xa"', 'xb'],
13 | output='wide',
14 | output_options={
15 | 'variable_column_prefix': 'foo',
16 | 'variable_column_suffix': '_bar',
17 | 'constant_name': 'constant_term'
18 | }
19 | )
20 | }} as linreg
21 |
--------------------------------------------------------------------------------
/integration_tests/models/groups_matrix_regression_chol_unoptimized.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table",
4 | tags=["skip-postgres", "skip-clickhouse"]
5 | )
6 | }}
7 | select * from {{
8 | dbt_linreg.ols(
9 | table=ref('groups_matrix'),
10 | endog='y',
11 | exog=['x1', 'x2', 'x3'],
12 | group_by=['gb_var'],
13 | output='long',
14 | method='chol',
15 | method_options={'subquery_optimization': False}
16 | )
17 | }} as linreg
18 | order by gb_var, variable_name
19 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_1var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select
6 | 10.0 as const,
7 | 5.0 as xa
8 | )
9 |
10 | select
11 | expected.const as expected_const,
12 | base.const as actual_const,
13 | expected.xa as expected_xa,
14 | base.xa as actual_xa
15 | from {{ ref('simple_1var_regression_wide') }} as base, expected
16 | where not (
17 | abs(base.const - expected.const) <= {{ var("_test_precision_simple_matrix") }}
18 | and abs(base.xa - expected.xa) <= {{ var("_test_precision_simple_matrix") }}
19 | )
20 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: docs
2 | on:
3 | push:
4 | branches:
5 | - main
6 | jobs:
7 | build:
8 | name: Deploy docs
9 | runs-on: ubuntu-latest
10 | steps:
11 | - name: Checkout main
12 | uses: actions/checkout@v1
13 | - name: Install dependencies
14 | run: sudo apt-get update
15 | - name: Deploy docs
16 | uses: mhausenblas/mkdocs-deploy-gh-pages@1.26
17 | env:
18 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
19 | CONFIG_FILE: docs/mkdocs.yml
20 | REQUIREMENTS: docs/requirements.txt
21 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |
3 | - repo: https://github.com/pre-commit/pre-commit-hooks
4 | rev: v2.3.0
5 | hooks:
6 | - id: check-yaml
7 | - id: end-of-file-fixer
8 | - id: trailing-whitespace
9 |
10 | - repo: https://github.com/charliermarsh/ruff-pre-commit
11 | rev: v0.6.2
12 | hooks:
13 | - id: ruff
14 |
15 | - repo: https://github.com/koalaman/shellcheck-precommit
16 | rev: v0.8.0
17 | hooks:
18 | - id: shellcheck
19 | args: [-x, run]
20 |
21 | - repo: https://github.com/rhysd/actionlint
22 | rev: v1.6.26
23 | hooks:
24 | - id: actionlint
25 |
--------------------------------------------------------------------------------
/integration_tests/selectors.yml:
--------------------------------------------------------------------------------
1 | selectors:
2 | - name: duckdb-selector
3 | definition: 'fqn:*'
4 | - name: postgres-selector
5 | # Postgres runs into memory and performance issues on some of these queries,
6 | # so models tagged skip-postgres are excluded. TODO: make these queries more performant on Postgres.
7 | definition:
8 | union:
9 | - 'fqn:*'
10 | - exclude:
11 | - '@tag:skip-postgres'
12 | - name: clickhouse-selector
13 | # ClickHouse struggles with the unoptimized chol method, so models tagged skip-clickhouse are excluded.
14 | definition:
15 | union:
16 | - 'fqn:*'
17 | - exclude:
18 | - '@tag:skip-clickhouse'
19 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_0var_regression_long_fwl.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'const' as variable_name, 10.0 as coefficient
6 |
7 | )
8 |
9 | select
10 | coalesce(base.variable_name, expected.variable_name) as variable_name,
11 | expected.coefficient as expected_coefficient,
12 | base.coefficient as actual_coefficient
13 | from {{ ref('simple_0var_regression_long_fwl') }} as base
14 | full outer join expected
15 | on base.variable_name = expected.variable_name
16 | where
17 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }}
18 | or base.coefficient is null
19 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_0var_regression_long_chol.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'const' as variable_name, 10.0 as coefficient
6 |
7 | )
8 |
9 | select
10 | coalesce(base.variable_name, expected.variable_name) as variable_name,
11 | expected.coefficient as expected_coefficient,
12 | base.coefficient as actual_coefficient
13 | from {{ ref('simple_0var_regression_long_chol') }} as base
14 | full outer join expected
15 | on base.variable_name = expected.variable_name
16 | where
17 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }}
18 | or base.coefficient is null
19 |
--------------------------------------------------------------------------------
/integration_tests/dbt_project.yml:
--------------------------------------------------------------------------------
1 | name: "dbt_linreg_tests"
2 | version: "0.3.1"
3 |
4 | require-dbt-version: [">=1.0.0", "<2.0.0"]
5 |
6 | config-version: 2
7 |
8 | target-path: "target"
9 | clean-targets: ["target", "dbt_modules", "dbt_packages"]
10 | macro-paths: ["macros"]
11 | log-path: "logs"
12 |
13 | vars:
14 | _test_precision_simple_matrix: '{{ "10e-8" if target.name == "clickhouse" else 0.0 }}'
15 | _test_precision_collinear_matrix: '{{ "10e-6" if target.name == "clickhouse" else "10e-7" }}'
16 |
17 | models:
18 | +materialized: table
19 |
20 | tests:
21 | +store_failures: true
22 |
23 | # During dev only!
24 | profile: "dbt_linreg_profile"
25 |
--------------------------------------------------------------------------------
/integration_tests/profiles/profiles.yml:
--------------------------------------------------------------------------------
1 | dbt_linreg_profile:
2 | target: duckdb
3 | outputs:
4 | duckdb:
5 | type: duckdb
6 | path: dbt.duckdb
7 | postgres:
8 | type: postgres
9 | user: '{{ env_var("POSTGRES_USER") }}'
10 | password: '{{ env_var("POSTGRES_PASSWORD") }}'
11 | host: '{{ env_var("POSTGRES_HOST", "localhost") }}'
12 | port: '{{ env_var("POSTGRES_PORT", "5432") | as_number }}'
13 | dbname: '{{ env_var("POSTGRES_DB", "dbt_linreg") }}'
14 | schema: '{{ env_var("POSTGRES_SCHEMA", "public") }}'
15 | clickhouse:
16 | type: clickhouse
17 | port: 8123
18 | schema: dbt_linreg
19 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_1var_regression_long_chol.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'const' as variable_name, 10.0 as coefficient
6 | union all
7 | select 'xa' as variable_name, 5.0 as coefficient
8 |
9 | )
10 |
11 | select
12 | coalesce(base.variable_name, expected.variable_name) as variable_name,
13 | expected.coefficient as expected_coefficient,
14 | base.coefficient as actual_coefficient
15 | from {{ ref('simple_1var_regression_long_chol') }} as base
16 | full outer join expected
17 | on base.variable_name = expected.variable_name
18 | where
19 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }}
20 | or base.coefficient is null
21 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_1var_regression_long_fwl.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'const' as variable_name, 10.0 as coefficient
6 | union all
7 | select 'xa' as variable_name, 5.0 as coefficient
8 |
9 | )
10 |
11 | select
12 | coalesce(base.variable_name, expected.variable_name) as variable_name,
13 | expected.coefficient as expected_coefficient,
14 | base.coefficient as actual_coefficient
15 | from {{ ref('simple_1var_regression_long_fwl') }} as base
16 | full outer join expected
17 | on base.variable_name = expected.variable_name
18 | where
19 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }}
20 | or base.coefficient is null
21 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_2var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select
6 | 10.0 as const,
7 | 5.0 as xa,
8 | 7.0 as xb
9 | )
10 |
11 | select
12 | expected.const as expected_const,
13 | base.const as actual_const,
14 | expected.xa as expected_xa,
15 | base.xa as actual_xa,
16 | expected.xb as expected_xb,
17 | base.xb as actual_xb
18 | from {{ ref('simple_2var_regression_wide') }} as base, expected
19 | where not (
20 | abs(base.const - expected.const) <= {{ var("_test_precision_simple_matrix") }}
21 | and abs(base.xa - expected.xa) <= {{ var("_test_precision_simple_matrix") }}
22 | and abs(base.xb - expected.xb) <= {{ var("_test_precision_simple_matrix") }}
23 | )
24 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_2var_regression_long.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'const' as variable_name, 10.0 as coefficient
6 | union all
7 | select 'xa' as variable_name, 5.0 as coefficient
8 | union all
9 | select 'xb' as variable_name, 7.0 as coefficient
10 |
11 | )
12 |
13 | select
14 | coalesce(base.variable_name, expected.variable_name) as variable_name,
15 | expected.coefficient as expected_coefficient,
16 | base.coefficient as actual_coefficient
17 | from {{ ref('simple_2var_regression_long') }} as base
18 | full outer join expected
19 | on base.variable_name = expected.variable_name
20 | where
21 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }}
22 | or base.coefficient is null
23 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_3var_regression_long.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'const' as variable_name, 10.0 as coefficient
6 | union all
7 | select 'xa' as variable_name, 5.0 as coefficient
8 | union all
9 | select 'xb' as variable_name, 7.0 as coefficient
10 | union all
11 | select 'xc' as variable_name, 9.0 as coefficient
12 |
13 | )
14 |
15 | select
16 | coalesce(base.variable_name, expected.variable_name) as variable_name,
17 | expected.coefficient as expected_coefficient,
18 | base.coefficient as actual_coefficient
19 | from {{ ref('simple_3var_regression_long') }} as base
20 | full outer join expected
21 | on base.variable_name = expected.variable_name
22 | where
23 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }}
24 | or base.coefficient is null
25 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_1var_without_const_ridge.sql:
--------------------------------------------------------------------------------
1 | /* Ridge regression coefficients do not match the reference values exactly.
2 | Instead, the test compares log magnitudes and signs, allowing a relative deviation on the order of 0.01%. */
3 | {% set THRESHOLD = 0.0001 %}
4 | with
5 |
6 | expected as (
7 |
8 | select 'x1' as variable_name, 21.78558328301129 as coefficient
9 |
10 | )
11 |
12 | select
13 | coalesce(base.variable_name, expected.variable_name) as variable_name,
14 | expected.coefficient as expected_coefficient,
15 | base.coefficient as actual_coefficient
16 | from {{ ref('collinear_matrix_1var_without_const_ridge') }} as base
17 | full outer join expected
18 | on base.variable_name = expected.variable_name
19 | where
20 | abs(log(abs(base.coefficient)) - log(abs(expected.coefficient))) > {{ THRESHOLD }}
21 | or sign(base.coefficient) != sign(expected.coefficient)
22 | or base.coefficient is null
23 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_3var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select
6 | 10.0 as const,
7 | 5.0 as xa,
8 | 7.0 as xb,
9 | 9.0 as xc
10 | )
11 |
12 | select
13 | expected.const as expected_const,
14 | base.const as actual_const,
15 | expected.xa as expected_xa,
16 | base.xa as actual_xa,
17 | expected.xb as expected_xb,
18 | base.xb as actual_xb,
19 | expected.xc as expected_xc,
20 | base.xc as actual_xc
21 | from {{ ref('simple_3var_regression_wide') }} as base, expected
22 | where not (
23 | abs(base.const - expected.const) <= {{ var("_test_precision_simple_matrix") }}
24 | and abs(base.xa - expected.xa) <= {{ var("_test_precision_simple_matrix") }}
25 | and abs(base.xb - expected.xb) <= {{ var("_test_precision_simple_matrix") }}
26 | and abs(base.xc - expected.xc) <= {{ var("_test_precision_simple_matrix") }}
27 | )
28 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "dbt_linreg"
3 | requires-python = "~=3.11"
4 | description = "dbt_linreg dbt package"
5 | version = "0.3.1"
6 | readme = "README.md"
7 | authors = [{name = "Daniel Reeves"}]
8 |
9 | [project.optional-dependencies]
10 | python-dev = [
11 | "pandas>=2.2.3",
12 | "pre-commit>=4.0.1",
13 | "pyyaml>=6.0.2",
14 | "rich-click>=1.8.5",
15 | "ruff>=0.8.4",
16 | "statsmodels>=0.14.4",
17 | "tabulate>=0.9.0",
18 | ]
19 | clickhouse = [
20 | "dbt-core<1.9.0",
21 | "dbt-clickhouse",
22 | ]
23 | duckdb = [
24 | "dbt-core<1.9.0",
25 | "dbt-duckdb",
26 | "duckdb>=1.1.3",
27 | ]
28 | postgres = [
29 | "dbt-core<1.9.0",
30 | "dbt-postgres",
31 | ]
32 |
33 | [tool.ruff]
34 | line-length = 120
35 |
36 | [tool.ruff.lint]
37 | select = ["F", "E", "W", "I001"]
38 |
39 | [tool.ruff.lint.isort]
40 | lines-after-imports = 2
41 | force-single-line = true
42 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_4var_regression_long.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'const' as variable_name, 10.0 as coefficient
6 | union all
7 | select 'xa' as variable_name, 5.0 as coefficient
8 | union all
9 | select 'xb' as variable_name, 7.0 as coefficient
10 | union all
11 | select 'xc' as variable_name, 9.0 as coefficient
12 | union all
13 | select 'xd' as variable_name, 11.0 as coefficient
14 |
15 | )
16 |
17 | select
18 | coalesce(base.variable_name, expected.variable_name) as variable_name,
19 | expected.coefficient as expected_coefficient,
20 | base.coefficient as actual_coefficient
21 | from {{ ref('simple_4var_regression_long') }} as base
22 | full outer join expected
23 | on base.variable_name = expected.variable_name
24 | where
25 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }}
26 | or base.coefficient is null
27 |
--------------------------------------------------------------------------------
/macros/linear_regression/ols_impl_special/_ols_0var.sql:
--------------------------------------------------------------------------------
1 | {% macro _ols_0var(table,
2 | endog,
3 | exog,
4 | add_constant=True,
5 | output=None,
6 | output_options=None,
7 | group_by=None,
8 | alpha=None) -%}
9 | (with _dbt_linreg_final_coefs as (
10 | select
11 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) }}
12 | avg({{ endog }}) as x0_coef
13 | from {{ table }}
14 | {%- if group_by %}
15 | group by
16 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }}
17 | {%- endif %}
18 | )
19 | {{
20 | dbt_linreg.final_select(
21 | exog=[],
22 | exog_aliased=[],
23 | add_constant=add_constant,
24 | group_by=group_by,
25 | output=output,
26 | output_options=output_options,
27 | calculate_standard_error=False
28 | )
29 | }}
30 | )
31 | {%- endmacro %}
32 |
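The macro above computes the single coefficient as `avg(endog)`. This works because the least-squares estimate for an intercept-only model is the sample mean. A small Python sketch of that fact, for illustration only: `intercept_only_ols`, `sse`, and the sample data are hypothetical and not part of dbt_linreg.

```python
def intercept_only_ols(y):
    """Closed-form OLS for y = b0 + e: b0 = (1'1)^-1 1'y = sum(y) / n."""
    return sum(y) / len(y)


def sse(y, b0):
    """Sum of squared residuals for a constant prediction b0."""
    return sum((v - b0) ** 2 for v in y)


y = [3.0, 5.0, 10.0]  # hypothetical data
b0 = intercept_only_ols(y)

# Moving the constant away from the mean in either direction increases the error.
print(b0, sse(y, b0), sse(y, b0 + 0.1), sse(y, b0 - 0.1))
```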
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_5var_regression_long.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'const' as variable_name, 10.0 as coefficient
6 | union all
7 | select 'xa' as variable_name, 5.0 as coefficient
8 | union all
9 | select 'xb' as variable_name, 7.0 as coefficient
10 | union all
11 | select 'xc' as variable_name, 9.0 as coefficient
12 | union all
13 | select 'xd' as variable_name, 11.0 as coefficient
14 | union all
15 | select 'xe' as variable_name, 13.0 as coefficient
16 |
17 | )
18 |
19 | select
20 | coalesce(base.variable_name, expected.variable_name) as variable_name,
21 | expected.coefficient as expected_coefficient,
22 | base.coefficient as actual_coefficient
23 | from {{ ref('simple_5var_regression_long') }} as base
24 | full outer join expected
25 | on base.variable_name = expected.variable_name
26 | where
27 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }}
28 | or base.coefficient is null
29 |
--------------------------------------------------------------------------------
/integration_tests/models/long_format_options.sql:
--------------------------------------------------------------------------------
1 | {{
2 | config(
3 | materialized="table"
4 | )
5 | }}
6 | select
7 | true as strip_quotes, *
8 | from {{
9 | dbt_linreg.ols(
10 | table=ref('simple_matrix'),
11 | endog='y',
12 | exog=['"xa"', 'xb'],
13 | output='long',
14 | output_options={
15 | 'constant_name': 'constant_term',
16 | 'variable_column_name': 'vname',
17 | 'coefficient_column_name': 'co',
18 | 'strip_quotes': True
19 | }
20 | )
21 | }} as linreg1
22 |
23 | union all
24 |
25 | select
26 | false as strip_quotes, *
27 | from {{
28 | dbt_linreg.ols(
29 | table=ref('simple_matrix'),
30 | endog='y',
31 | exog=['"xa"', 'xb'],
32 | output='long',
33 | output_options={
34 | 'constant_name': 'constant_term',
35 | 'variable_column_name': 'vname',
36 | 'coefficient_column_name': 'co',
37 | 'strip_quotes': False
38 | }
39 | )
40 | }} as linreg2
41 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_1var_without_const.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'x1' as variable_name, 30.500076644845674 as coefficient, 0.8396121329329627 as standard_error, 36.326388636502585 as t_statistic
6 |
7 | )
8 |
9 | select
10 | coalesce(base.variable_name, expected.variable_name) as variable_name,
11 | expected.coefficient as expected_coefficient,
12 | base.coefficient as actual_coefficient
13 | from {{ ref('collinear_matrix_1var_without_const') }} as base
14 | full outer join expected
15 | on base.variable_name = expected.variable_name
16 | where
17 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }}
18 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }}
19 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }}
20 | or base.coefficient is null
21 | or base.standard_error is null
22 | or base.t_statistic is null
23 | or expected.coefficient is null
24 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_4var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select
6 | 10.0 as const,
7 | 5.0 as xa,
8 | 7.0 as xb,
9 | 9.0 as xc,
10 | 11.0 as xd
11 | )
12 |
13 | select
14 | expected.const as expected_const,
15 | base.const as actual_const,
16 | expected.xa as expected_xa,
17 | base.xa as actual_xa,
18 | expected.xb as expected_xb,
19 | base.xb as actual_xb,
20 | expected.xc as expected_xc,
21 | base.xc as actual_xc,
22 | expected.xd as expected_xd,
23 | base.xd as actual_xd
24 | from {{ ref('simple_4var_regression_wide') }} as base, expected
25 | where not (
26 | abs(base.const - expected.const) <= {{ var("_test_precision_simple_matrix") }}
27 | and abs(base.xa - expected.xa) <= {{ var("_test_precision_simple_matrix") }}
28 | and abs(base.xb - expected.xb) <= {{ var("_test_precision_simple_matrix") }}
29 | and abs(base.xc - expected.xc) <= {{ var("_test_precision_simple_matrix") }}
30 | and abs(base.xd - expected.xd) <= {{ var("_test_precision_simple_matrix") }}
31 | )
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2023 Daniel Reeves
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining
6 | a copy of this software and associated documentation files (the
7 | 'Software'), to deal in the Software without restriction, including
8 | without limitation the rights to use, copy, modify, merge, publish,
9 | distribute, sublicense, and/or sell copies of the Software, and to
10 | permit persons to whom the Software is furnished to do so, subject to
11 | the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be
14 | included in all copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
21 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
22 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
23 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_2var_without_const.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'x1' as variable_name, 63.18154691334764 as coefficient, 0.4056389914380657 as standard_error, 155.75807120848344 as t_statistic
6 | union all
7 | select 'x2' as variable_name, 55.39820150046505 as coefficient, 0.2738669097295638 as standard_error, 202.2814715190283 as t_statistic
8 |
9 | )
10 |
11 | select
12 | coalesce(base.variable_name, expected.variable_name) as variable_name,
13 | expected.coefficient as expected_coefficient,
14 | base.coefficient as actual_coefficient
15 | from {{ ref('collinear_matrix_2var_without_const') }} as base
16 | full outer join expected
17 | on base.variable_name = expected.variable_name
18 | where
19 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }}
20 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }}
21 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }}
22 | or base.coefficient is null
23 | or base.standard_error is null
24 | or base.t_statistic is null
25 | or expected.coefficient is null
26 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_ridge_regression_fwl.sql:
--------------------------------------------------------------------------------
1 | /* Ridge regression coefficients do not match exactly.
2 | Instead, a threshold of no more than 0.01% deviation is enforced. */
3 | {% set THRESHOLD = 0.0001 %}
4 | with
5 |
6 | expected as (
7 |
8 | select 'const' as variable_name, 20.7548151107157 as coefficient
9 | union all
10 | select 'x1' as variable_name, 9.784064449021356 as coefficient
11 | union all
12 | select 'x2' as variable_name, 6.315640539781496 as coefficient
13 | union all
14 | select 'x3' as variable_name, 19.578696589513562 as coefficient
15 | union all
16 | select 'x4' as variable_name, 3.736823845978248 as coefficient
17 | union all
18 | select 'x5' as variable_name, 13.323547772767592 as coefficient
19 |
20 | )
21 |
22 | select
23 | coalesce(base.variable_name, expected.variable_name) as variable_name
24 | from {{ ref('collinear_matrix_ridge_regression_fwl') }} as base
25 | full outer join expected
26 | on base.variable_name = expected.variable_name
27 | where
28 | abs(log(abs(base.coefficient)) - log(abs(expected.coefficient))) > {{ THRESHOLD }}
29 | or sign(base.coefficient) != sign(expected.coefficient)
30 | or base.coefficient is null
31 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_simple_5var_regression_wide.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select
6 | 10.0 as const,
7 | 5.0 as xa,
8 | 7.0 as xb,
9 | 9.0 as xc,
10 | 11.0 as xd,
11 | 13.0 as xe
12 | )
13 |
14 | select
15 | expected.const as expected_const,
16 | base.const as actual_const,
17 | expected.xa as expected_xa,
18 | base.xa as actual_xa,
19 | expected.xb as expected_xb,
20 | base.xb as actual_xb,
21 | expected.xc as expected_xc,
22 | base.xc as actual_xc,
23 | expected.xd as expected_xd,
24 | base.xd as actual_xd,
25 | expected.xe as expected_xe,
26 | base.xe as actual_xe
27 | from {{ ref('simple_5var_regression_wide') }} as base, expected
28 | where not (
29 | abs(base.const - expected.const) <= {{ var("_test_precision_simple_matrix") }}
30 | and abs(base.xa - expected.xa) <= {{ var("_test_precision_simple_matrix") }}
31 | and abs(base.xb - expected.xb) <= {{ var("_test_precision_simple_matrix") }}
32 | and abs(base.xc - expected.xc) <= {{ var("_test_precision_simple_matrix") }}
33 | and abs(base.xd - expected.xd) <= {{ var("_test_precision_simple_matrix") }}
34 | and abs(base.xe - expected.xe) <= {{ var("_test_precision_simple_matrix") }}
35 | )
36 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_5var_without_const_ridge.sql:
--------------------------------------------------------------------------------
1 | /* Ridge regression coefficients do not match exactly.
2 | Instead, a threshold of no more than 0.01% deviation is enforced. */
3 | {% set THRESHOLD = 0.0001 %}
4 | with
5 |
6 | expected as (
7 |
8 | select 'x1' as variable_name, 9.44760329057758 as coefficient
9 | union all
10 | select 'x2' as variable_name, 3.5049555562844787 as coefficient
11 | union all
12 | select 'x3' as variable_name, 20.753357497835637 as coefficient
13 | union all
14 | select 'x4' as variable_name, 3.522584853991104 as coefficient
15 | union all
16 | select 'x5' as variable_name, 16.31725550368597 as coefficient
17 |
18 | )
19 |
20 | select
21 | coalesce(base.variable_name, expected.variable_name) as variable_name,
22 | expected.coefficient as expected_coefficient,
23 | base.coefficient as actual_coefficient
24 | from {{ ref('collinear_matrix_5var_without_const_ridge') }} as base
25 | full outer join expected
26 | on base.variable_name = expected.variable_name
27 | where
28 | abs(log(abs(base.coefficient)) - log(abs(expected.coefficient))) > {{ THRESHOLD }}
29 | or sign(base.coefficient) != sign(expected.coefficient)
30 | or base.coefficient is null
31 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_ridge_regression_chol.sql:
--------------------------------------------------------------------------------
1 | /* Ridge regression coefficients do not match exactly.
2 | Instead, a threshold of no more than 0.01% deviation is enforced. */
3 | {% set THRESHOLD = 0.0001 %}
4 | with
5 |
6 | expected as (
7 |
8 | select 'const' as variable_name, 20.7548151107157 as coefficient
9 | union all
10 | select 'x1' as variable_name, 9.784064449021356 as coefficient
11 | union all
12 | select 'x2' as variable_name, 6.315640539781496 as coefficient
13 | union all
14 | select 'x3' as variable_name, 19.578696589513562 as coefficient
15 | union all
16 | select 'x4' as variable_name, 3.736823845978248 as coefficient
17 | union all
18 | select 'x5' as variable_name, 13.323547772767592 as coefficient
19 |
20 | )
21 |
22 | select
23 | coalesce(base.variable_name, expected.variable_name) as variable_name,
24 | expected.coefficient as expected_coefficient,
25 | base.coefficient as actual_coefficient
26 | from {{ ref('collinear_matrix_ridge_regression_chol') }} as base
27 | full outer join expected
28 | on base.variable_name = expected.variable_name
29 | where
30 | abs(log(abs(base.coefficient)) - log(abs(expected.coefficient))) > {{ THRESHOLD }}
31 | or sign(base.coefficient) != sign(expected.coefficient)
32 | or base.coefficient is null
33 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_regression_fwl.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'const' as variable_name, 19.757104885315176 as coefficient
6 | union all
7 | select 'x1' as variable_name, 9.90708767581426 as coefficient
8 | union all
9 | select 'x2' as variable_name, 6.187473206056227 as coefficient
10 | union all
11 | select 'x3' as variable_name, 19.66874583168642 as coefficient
12 | union all
13 | select 'x4' as variable_name, 3.7192417102253468 as coefficient
14 | union all
15 | select 'x5' as variable_name, 13.444273483323244 as coefficient
16 |
17 | )
18 |
19 | select
20 | coalesce(base.variable_name, expected.variable_name) as variable_name,
21 | expected.coefficient as expected_coefficient,
22 | base.coefficient as actual_coefficient
23 | from {{ ref('collinear_matrix_regression_fwl') }} as base
24 | full outer join expected
25 | on base.variable_name = expected.variable_name
26 | where
27 | {% if target.name == "clickhouse" %}
28 | {# This has poor precision for Clickhouse; not much I can do about it. #}
29 | abs(base.coefficient - expected.coefficient) > 0.1
30 | {% else %}
31 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }}
32 | {% endif %}
33 | or base.coefficient is null
34 | or expected.coefficient is null
35 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_ridge_regression_chol_unoptimized.sql:
--------------------------------------------------------------------------------
1 | /* Ridge regression coefficients do not match exactly.
2 | Instead, a threshold of no more than 0.01% deviation is enforced. */
3 | {% set THRESHOLD = 0.0001 %}
4 | with
5 |
6 | expected as (
7 |
8 | select 'const' as variable_name, 20.7548151107157 as coefficient
9 | union all
10 | select 'x1' as variable_name, 9.784064449021356 as coefficient
11 | union all
12 | select 'x2' as variable_name, 6.315640539781496 as coefficient
13 | union all
14 | select 'x3' as variable_name, 19.578696589513562 as coefficient
15 | union all
16 | select 'x4' as variable_name, 3.736823845978248 as coefficient
17 | union all
18 | select 'x5' as variable_name, 13.323547772767592 as coefficient
19 |
20 | )
21 |
22 | select
23 | coalesce(base.variable_name, expected.variable_name) as variable_name,
24 | expected.coefficient as expected_coefficient,
25 | base.coefficient as actual_coefficient
26 | from {{ ref('collinear_matrix_ridge_regression_chol_unoptimized') }} as base
27 | full outer join expected
28 | on base.variable_name = expected.variable_name
29 | where
30 | abs(log(abs(base.coefficient)) - log(abs(expected.coefficient))) > {{ THRESHOLD }}
31 | or sign(base.coefficient) != sign(expected.coefficient)
32 | or base.coefficient is null
33 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_3var_without_const.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'x1' as variable_name, 20.090207982897063 as coefficient, 0.5196176972417176 as standard_error, 38.6634406209445 as t_statistic
6 | union all
7 | select 'x2' as variable_name, -16.533211090826203 as coefficient, 0.7481701784700665 as standard_error, -22.098195793682894 as t_statistic
8 | union all
9 | select 'x3' as variable_name, 35.00389104686492 as coefficient, 0.351617515124373 as standard_error, 99.55104493154575 as t_statistic
10 |
11 | )
12 |
13 | select
14 | coalesce(base.variable_name, expected.variable_name) as variable_name,
15 | expected.coefficient as expected_coefficient,
16 | base.coefficient as actual_coefficient
17 | from {{ ref('collinear_matrix_3var_without_const') }} as base
18 | full outer join expected
19 | on base.variable_name = expected.variable_name
20 | where
21 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }}
22 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }}
23 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }}
24 | or base.coefficient is null
25 | or base.standard_error is null
26 | or base.t_statistic is null
27 | or expected.coefficient is null
28 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_long_format_options.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | base as (
4 |
5 | select strip_quotes, vname, co
6 | from {{ ref("long_format_options") }}
7 |
8 | ),
9 |
10 | find_unstripped_quotes as (
11 |
12 | select
13 | cast(max(cast(vname = '"xa"' as integer)) as boolean) as should_be_true,
14 | cast(max(cast(vname = 'xa' as integer)) as boolean) as should_be_false
15 | from base
16 | where not strip_quotes
17 |
18 | ),
19 |
20 | dodge_unstripped_quotes as (
21 |
22 | select
23 | cast(max(cast(vname = 'xa' as integer)) as boolean) as should_be_true,
24 | cast(max(cast(vname = '"xa"' as integer)) as boolean) as should_be_false
25 | from base
26 | where strip_quotes
27 |
28 | ),
29 |
30 | coef_col_name as (
31 |
32 | select
33 | cast(max(cast(vname = 'constant_term' as integer)) as boolean) as should_be_true,
34 | cast(max(cast(vname = 'const' as integer)) as boolean) as should_be_false
35 | from base
36 |
37 | )
38 |
39 | select 'find_unstripped_quotes' as test_case
40 | from find_unstripped_quotes
41 | where should_be_false or not should_be_true
42 |
43 | union all
44 |
45 | select 'dodge_unstripped_quotes' as test_case
46 | from dodge_unstripped_quotes
47 | where should_be_false or not should_be_true
48 |
49 | union all
50 |
51 | select 'coef_col_name' as test_case
52 | from coef_col_name
53 | where should_be_false or not should_be_true
54 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_4var_without_const.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'x1' as variable_name, 20.587532776354163 as coefficient, 0.5176259827853541 as standard_error, 39.772989496339235 as t_statistic
6 | union all
7 | select 'x2' as variable_name, -20.41001520357013 as coefficient, 0.8103907603637923 as standard_error, -25.185399688426696 as t_statistic
8 | union all
9 | select 'x3' as variable_name, 35.084935774341524 as coefficient, 0.34920588221192245 as standard_error, 100.4706322588505 as t_statistic
10 | union all
11 | select 'x4' as variable_name, 1.8960558858899716 as coefficient, 0.1583538085466205 as standard_error, 11.973541421529871 as t_statistic
12 |
13 | )
14 |
15 | select
16 | coalesce(base.variable_name, expected.variable_name) as variable_name,
17 | expected.coefficient as expected_coefficient,
18 | base.coefficient as actual_coefficient
19 | from {{ ref('collinear_matrix_4var_without_const') }} as base
20 | full outer join expected
21 | on base.variable_name = expected.variable_name
22 | where
23 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }}
24 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }}
25 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }}
26 | or base.coefficient is null
27 | or base.standard_error is null
28 | or base.t_statistic is null
29 | or expected.coefficient is null
30 |
--------------------------------------------------------------------------------
/docs/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: dbt_linreg
2 | site_url: https://dwreeves.github.io/dbt_linreg/
3 | site_description: Docs for dbt_linreg
4 | site_author: Daniel Reeves
5 | repo_url: https://github.com/dwreeves/dbt_linreg/
6 | repo_name: dbt_linreg
7 | docs_dir: src
8 | nav:
9 | - Home: index.md
10 | theme:
11 | name: material
12 | palette:
13 | - media: "(prefers-color-scheme: dark)"
14 | scheme: slate
15 | primary: black
16 | accent: light blue
17 | toggle:
18 | icon: material/lightbulb-outline
19 | name: Switch to light mode
20 | - media: "(prefers-color-scheme: light)"
21 | scheme: default
22 | primary: white
23 | accent: light blue
24 | toggle:
25 | icon: material/lightbulb
26 | name: Switch to dark mode
27 | logo: img/dbt-linreg-logo.png
28 | favicon: img/dbt-linreg-favicon.png
29 | font:
30 | text: Open Sans
31 | code: Roboto Mono
32 | plugins:
33 | - macros
34 | markdown_extensions:
35 | - admonition
36 | - pymdownx.tabbed
37 | - pymdownx.keys
38 | - pymdownx.details
39 | - pymdownx.inlinehilite
40 | - pymdownx.superfences
41 | - markdown_include.include:
42 | base_path: docs
43 | - sane_lists
44 | extra_css:
45 | - css/extra.css
46 | extra:
47 | social:
48 | - icon: fontawesome/brands/github
49 | link: https://github.com/dwreeves/dbt_linreg
50 | - icon: fontawesome/brands/linkedin
51 | link: https://www.linkedin.com/in/daniel-reeves-27700545/
52 | - icon: fontawesome/brands/twitter
53 | link: https://twitter.com/mueblesfeos
54 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_5var_without_const.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'x1' as variable_name, 11.392300499659957 as coefficient, 0.5240533254061608 as standard_error, 21.73881921430515 as t_statistic
6 | union all
7 | select 'x2' as variable_name, 2.333060182571783 as coefficient, 0.9201150492406911 as standard_error, 2.5356178931070636 as t_statistic
8 | union all
9 | select 'x3' as variable_name, 21.895814737788875 as coefficient, 0.44810399169425286 as standard_error, 48.8632441210849 as t_statistic
10 | union all
11 | select 'x4' as variable_name, 3.4480236159406785 as coefficient, 0.1504072830205524 as standard_error, 22.92457882820424 as t_statistic
12 | union all
13 | select 'x5' as variable_name, 15.766951731565559 as coefficient, 0.37297028350495787 as standard_error, 42.274015997727524 as t_statistic
14 |
15 | )
16 |
17 | select
18 | coalesce(base.variable_name, expected.variable_name) as variable_name,
19 | expected.coefficient as expected_coefficient,
20 | base.coefficient as actual_coefficient
21 | from {{ ref('collinear_matrix_5var_without_const') }} as base
22 | full outer join expected
23 | on base.variable_name = expected.variable_name
24 | where
25 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }}
26 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }}
27 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }}
28 | or base.coefficient is null
29 | or base.standard_error is null
30 | or base.t_statistic is null
31 | or expected.coefficient is null
32 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_groups_matrix_regression_fwl.sql:
--------------------------------------------------------------------------------
1 | /* Grouped regression coefficients do not match exactly.
2 | Instead, an absolute deviation of no more than 0.0001 is enforced (0.1 on ClickHouse). */
3 | {% set THRESHOLD = 0.0001 %}
4 | with
5 |
6 | expected as (
7 |
8 | select 'a' as gb_var, 'const' as variable_name, -0.06563066041472207 as coefficient
9 | union all
10 | select 'a' as gb_var, 'x1' as variable_name, 0.9905419281557593 as coefficient
11 | union all
12 | select 'a' as gb_var, 'x2' as variable_name, 4.948221700496285 as coefficient
13 | union all
14 | select 'a' as gb_var, 'x3' as variable_name, 0.031234030051974747 as coefficient
15 | union all
16 | select 'b' as gb_var, 'const' as variable_name, 2.0117130483709955 as coefficient
17 | union all
18 | select 'b' as gb_var, 'x1' as variable_name, 2.996331112245573 as coefficient
19 | union all
20 | select 'b' as gb_var, 'x2' as variable_name, 9.019683491736044 as coefficient
21 | union all
22 | select 'b' as gb_var, 'x3' as variable_name, 0.016151316166848173 as coefficient
23 |
24 | )
25 |
26 | select
27 | coalesce(base.variable_name, expected.variable_name) as variable_name,
28 | expected.coefficient as expected_coefficient,
29 | base.coefficient as actual_coefficient
30 | from {{ ref('groups_matrix_regression_fwl') }} as base
31 | full outer join expected
32 | on
33 | base.gb_var = expected.gb_var
34 | and base.variable_name = expected.variable_name
35 | where
36 | {% if target.name == "clickhouse" %}
37 | {# This has poor precision for Clickhouse; not much I can do about it. #}
38 | abs(base.coefficient - expected.coefficient) > 0.1
39 | {% else %}
40 | abs(base.coefficient - expected.coefficient) > {{ THRESHOLD }}
41 | {% endif %}
42 | or base.coefficient is null
43 | or expected.coefficient is null
44 |
--------------------------------------------------------------------------------
/run:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -eo pipefail
4 |
5 | export DBT_PROFILES_DIR=./integration_tests/profiles
6 | export DBT_PROJECT_DIR=./integration_tests
7 | export DO_NOT_TRACK=1
8 |
9 | if [ -f .env ]; then
10 | # shellcheck disable=SC2002,SC2046
11 | export $(cat .env | xargs)
12 | fi
13 |
14 | function dbt {
15 | uv run dbt "${@}"
16 | }
17 |
18 | function setup {
19 | uv sync --all-extras
20 | uvx pre-commit install
21 | }
22 |
23 | function docker-run-clickhouse {
24 | docker run \
25 | -d \
26 | -p 8123:8123 \
27 | --name dbt-linreg-clickhouse \
28 | --ulimit nofile=262144:262144 \
29 | clickhouse/clickhouse-server
30 | }
31 |
32 | function test {
33 | local target="${1-"duckdb"}"
34 |
35 | if [ -z "${GITHUB_ACTIONS}" ] && [ "${target}" = "postgres" ];
36 | then
37 | createdb "${POSTGRES_DB-"dbt_linreg"}" || true
38 | fi
39 |
40 | if [ "${target}" = "clickhouse" ];
41 | then
42 | docker-run-clickhouse || true
43 | fi
44 |
45 | if [ -z "${GITHUB_ACTIONS}" ] && [ "${target}" = "duckdb" ];
46 | then
47 | rm -f dbt.duckdb
48 | fi
49 |
50 | uv run scripts.py gen-test-cases
51 | dbt deps --target "${target}"
52 | dbt seed --target "${target}"
53 | dbt run --target "${target}" --selector "${target}-selector"
54 | dbt test --target "${target}" --selector "${target}-selector" --store-failures
55 | }
56 |
57 | function lint {
58 | uv run pre-commit run -a
59 | }
60 |
61 | function docs:deploy {
62 | # shellcheck disable=SC2046
63 | uv run $(xargs -I{} echo --with {} < docs/requirements.txt) mkdocs gh-deploy -f docs/mkdocs.yml
64 | }
65 |
66 | function help {
67 | echo "$0 "
68 | echo "Tasks:"
69 | compgen -A function | cat -n
70 | }
71 |
72 | TIMEFORMAT=$'\nTask completed in %3lR'
73 | time "${@:-help}"
74 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: tests
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches:
8 | - main
9 | jobs:
10 | pre-commit:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v4
14 | - uses: actions/setup-python@v4
15 | - uses: pre-commit/action@v3.0.0
16 | integration-tests:
17 | runs-on: ubuntu-latest
18 | strategy:
19 | matrix:
20 | dbt_core: [1.5.*, 1.8.*]
21 | db_target: [duckdb, postgres, clickhouse]
22 | services:
23 | postgres:
24 | image: ${{ (matrix.db_target == 'postgres') && 'postgres' || '' }}
25 | env:
26 | POSTGRES_USER: postgres
27 | POSTGRES_PASSWORD: postgres
28 | POSTGRES_DB: dbt_linreg
29 | ports:
30 | - 5432:5432
31 | options: >-
32 | --health-cmd pg_isready
33 | --health-interval 10s
34 | --health-timeout 5s
35 | --health-retries 5
36 | steps:
37 | - uses: actions/checkout@v4
38 | - name: Install uv
39 | uses: astral-sh/setup-uv@v5
40 | - name: Setup
41 | run: |
42 | sudo apt-get update
43 | sudo apt-get install
44 | chmod +x ./run
45 | uv venv
46 | uv sync --extra python-dev
47 | uv pip install -U "dbt-core==$DBT_CORE_VERSION" "dbt-${DBT_TARGET}==$DBT_CORE_VERSION"
48 | env:
49 | UV_NO_SYNC: true
50 | DO_NOT_TRACK: 1
51 | DBT_CORE_VERSION: ${{ matrix.dbt_core }}
52 | DBT_TARGET: ${{ matrix.db_target }}
53 | - name: Test
54 | run: ./run test "${DBT_TARGET}"
55 | env:
56 | UV_NO_SYNC: true
57 | DO_NOT_TRACK: 1
58 | DBT_TARGET: ${{ matrix.db_target }}
59 | POSTGRES_HOST: localhost
60 | POSTGRES_USER: postgres
61 | POSTGRES_PASSWORD: postgres
62 | POSTGRES_DB: dbt_linreg
63 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_regression_chol.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'const' as variable_name, 19.757104885315176 as coefficient, 2.992803142237603 as standard_error, 6.601538406078909 as t_statistic
6 | union all
7 | select 'x1' as variable_name, 9.90708767581426 as coefficient, 0.5692826957191374 as standard_error, 17.402755696445837 as t_statistic
8 | union all
9 | select 'x2' as variable_name, 6.187473206056227 as coefficient, 1.0880807259333622 as standard_error, 5.686593888287631 as t_statistic
10 | union all
11 | select 'x3' as variable_name, 19.66874583168642 as coefficient, 0.5601379212447676 as standard_error, 35.11411223146169 as t_statistic
12 | union all
13 | select 'x4' as variable_name, 3.7192417102253468 as coefficient, 0.15560940177101745 as standard_error, 23.901137514160553 as t_statistic
14 | union all
15 | select 'x5' as variable_name, 13.444273483323244 as coefficient, 0.5121595119107619 as standard_error, 26.250168493728488 as t_statistic
16 |
17 | )
18 |
19 | select
20 | coalesce(base.variable_name, expected.variable_name) as variable_name,
21 | expected.coefficient as expected_coefficient,
22 | base.coefficient as actual_coefficient,
23 | expected.standard_error as expected_standard_error,
24 | base.standard_error as actual_standard_error,
25 | expected.t_statistic as expected_t_statistic,
26 | base.t_statistic as actual_t_statistic
27 | from {{ ref('collinear_matrix_regression_chol') }} as base
28 | full outer join expected
29 | on base.variable_name = expected.variable_name
30 | where
31 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }}
32 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }}
33 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }}
34 | or base.coefficient is null
35 | or base.standard_error is null
36 | or base.t_statistic is null
37 | or expected.coefficient is null
38 |
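
The test above follows the usual dbt singular-test pattern: the query returns the offending rows, so an empty result means the test passes. A minimal Python sketch of the same comparison logic, for illustration only; `failing_rows` and the sample values are hypothetical and not part of the package:

```python
COLUMNS = ("coefficient", "standard_error", "t_statistic")


def failing_rows(actual, expected, precision):
    """Return the variable names that a tolerance-based test would flag.

    `actual` and `expected` map variable_name -> {column: value}.
    An empty result corresponds to a passing test.
    """
    failures = []
    # The union of keys plays the role of the full outer join on variable_name.
    for name in sorted(set(actual) | set(expected)):
        a, e = actual.get(name), expected.get(name)
        if a is None or e is None:
            failures.append(name)  # row is missing on one side of the join
            continue
        for col in COLUMNS:
            # Flag nulls and any difference larger than the allowed precision.
            if a.get(col) is None or abs(a[col] - e[col]) > precision:
                failures.append(name)
                break
    return failures


# Hypothetical sample data, not taken from the seed files.
expected = {"x1": {"coefficient": 2.0, "standard_error": 0.5, "t_statistic": 4.0}}
close = {"x1": {"coefficient": 2.0000001, "standard_error": 0.5, "t_statistic": 4.0}}
far = {"x1": {"coefficient": 2.1, "standard_error": 0.5, "t_statistic": 4.0}}

print(failing_rows(close, expected, 1e-5))
print(failing_rows(far, expected, 1e-5))
print(failing_rows({}, expected, 1e-5))
```

The first call stays within tolerance, while the second and third flag `x1` (a value that is too far off, and a missing row, respectively).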
--------------------------------------------------------------------------------
/integration_tests/tests/test_collinear_matrix_regression_chol_unoptimized.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'const' as variable_name, 19.757104885315176 as coefficient, 2.992803142237603 as standard_error, 6.601538406078909 as t_statistic
6 | union all
7 | select 'x1' as variable_name, 9.90708767581426 as coefficient, 0.5692826957191374 as standard_error, 17.402755696445837 as t_statistic
8 | union all
9 | select 'x2' as variable_name, 6.187473206056227 as coefficient, 1.0880807259333622 as standard_error, 5.686593888287631 as t_statistic
10 | union all
11 | select 'x3' as variable_name, 19.66874583168642 as coefficient, 0.5601379212447676 as standard_error, 35.11411223146169 as t_statistic
12 | union all
13 | select 'x4' as variable_name, 3.7192417102253468 as coefficient, 0.15560940177101745 as standard_error, 23.901137514160553 as t_statistic
14 | union all
15 | select 'x5' as variable_name, 13.444273483323244 as coefficient, 0.5121595119107619 as standard_error, 26.250168493728488 as t_statistic
16 |
17 | )
18 |
19 | select
20 | coalesce(base.variable_name, expected.variable_name) as variable_name,
21 | expected.coefficient as expected_coefficient,
22 | base.coefficient as actual_coefficient,
23 | expected.standard_error as expected_standard_error,
24 | base.standard_error as actual_standard_error,
25 | expected.t_statistic as expected_t_statistic,
26 | base.t_statistic as actual_t_statistic
27 | from {{ ref('collinear_matrix_regression_chol_unoptimized') }} as base
28 | full outer join expected
29 | on base.variable_name = expected.variable_name
30 | where
31 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }}
32 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }}
33 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }}
34 | or base.coefficient is null
35 | or base.standard_error is null
36 | or base.t_statistic is null
37 | or expected.coefficient is null
38 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | ### `0.3.1`
4 |
5 | - Fix bug in `vars:` implementation of method options.
6 |
7 | ### `0.3.0`
8 |
9 | - Official support for Clickhouse!
10 | - Rename `format=` and `format_options=` to `output=` and `output_options=` to make the API consistent with **dbt_pca**.
11 | - Allow for setting method and output options globally with `vars:`
12 |
13 | ### `0.2.6`
14 |
15 | - Fix bug with `group_by` on multiple variables; contributed by [@svkohler](https://github.com/dwreeves/dbt_linreg/issues/21).
16 |
17 | ### `0.2.5`
18 |
19 | - Fix bug where `exog` and `group_by` did not handle `str` inputs e.g. `exog="x"`.
20 | - Fix bug where `group_by` for `method='fwl'` with exactly 1 exog variable did not work. (Explanation: `method='fwl'` dispatches to a different macro for the special case of 1 exog variable, and `group_by` was not implemented correctly here.)
21 | - Fix bug where `safe` mode did not work for `method='chol'`.
22 | - Improved docs by hiding everything except `ols()`, improved description of `ols()` macro, and added missing arg.
23 |
24 | ### `0.2.4`
25 |
26 | - Fix minor incompatibility with Redshift; contributed by [@steelcd](https://github.com/steelcd).
27 |
28 | ### `0.2.3`
29 |
30 | - Added Postgres support in integration tests + fixed bugs that prevented Postgres from working.
31 |
32 | ### `0.2.2`
33 |
34 | - Added dbt documentation of the `ols()` macro.
35 |
36 | ### `0.2.1`
37 |
38 | - Added `.dbtignore`.
39 |
40 | ### `0.2.0`
41 |
42 | - Add `chol` method to `dbt_linreg.ols()`, and also set as the default method. (This method is significantly faster than `fwl`, and has a few other benefits.)
43 | - Add standard error column in `long` format for `chol` method.
44 |
45 | ### `0.1.2`
46 |
47 | - Added the ability to turn off/on the constant term with `add_constant: bool = True` kwarg.
48 | - Fixed error that occurred when rendering a 1-variable ridge regression.
49 |
50 | ### `0.1.1`
51 |
52 | - Fixed namespacing issue with CTEs: all CTEs generated by `dbt_linreg` now start with `_dbt_linreg_`, to reduce namespace conflicts with user generated SQL.
53 | - Locked the dbt-core version requirement to `>=1.2.0` (for now) because one of this package's dependencies (`modules.itertools.combinations`) is not available prior to `1.2.0`.
54 |
55 | ### `0.1.0`
56 |
57 | - Initial release
58 |
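
The `0.2.0` entry above introduces the `chol` method. As a rough illustration of the underlying idea, and not of the package's SQL implementation, OLS coefficients can be obtained by solving the normal equations `(X'X) b = X'y` with a Cholesky factorization followed by forward and back substitution. The function names `cholesky` and `ols_chol` below are hypothetical, and the sketch assumes `X'X` is positive definite (no perfect collinearity):

```python
import math


def cholesky(a):
    """Return lower-triangular L such that L @ L.T == a (a symmetric positive definite)."""
    n = len(a)
    low = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(low[i][k] * low[j][k] for k in range(j))
            if i == j:
                low[i][j] = math.sqrt(a[i][i] - s)
            else:
                low[i][j] = (a[i][j] - s) / low[j][j]
    return low


def ols_chol(rows, y, add_constant=True):
    """Solve (X'X) b = X'y for b using a Cholesky factorization of X'X."""
    x = [([1.0] if add_constant else []) + list(r) for r in rows]
    p = len(x[0])
    xtx = [[sum(r[i] * r[j] for r in x) for j in range(p)] for i in range(p)]
    xty = [sum(r[i] * t for r, t in zip(x, y)) for i in range(p)]
    low = cholesky(xtx)
    # Forward substitution: solve L z = X'y.
    z = [0.0] * p
    for i in range(p):
        z[i] = (xty[i] - sum(low[i][k] * z[k] for k in range(i))) / low[i][i]
    # Back substitution: solve L' b = z.
    b = [0.0] * p
    for i in reversed(range(p)):
        b[i] = (z[i] - sum(low[k][i] * b[k] for k in range(i + 1, p))) / low[i][i]
    return b


# Data generated exactly from y = 1 + 2 * x, so the fit should recover both terms.
coefs = ols_chol([[0.0], [1.0], [2.0], [3.0]], [1.0, 3.0, 5.0, 7.0])
print([round(c, 6) for c in coefs])
```

The first returned element is the constant term when `add_constant` is true, followed by one coefficient per feature column.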
--------------------------------------------------------------------------------
/integration_tests/tests/test_groups_matrix_regression_chol_optimized.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'a' as gb_var, 'const' as variable_name, -0.06563066041472207 as coefficient, 0.053945103940799474 as standard_error, -1.2166194078844779 as t_statistic
6 | union all
7 | select 'a' as gb_var, 'x1' as variable_name, 0.9905419281557593 as coefficient, 0.015209571618398615 as standard_error, 65.12622136954383 as t_statistic
8 | union all
9 | select 'a' as gb_var, 'x2' as variable_name, 4.948221700496285 as coefficient, 0.02906881854690807 as standard_error, 170.2243829590593 as t_statistic
10 | union all
11 | select 'a' as gb_var, 'x3' as variable_name, 0.031234030051974747 as coefficient, 0.014337008978330493 as standard_error, 2.178559705108859 as t_statistic
12 | union all
13 | select 'b' as gb_var, 'const' as variable_name, 2.0117130483709955 as coefficient, 0.035587045398501334 as standard_error, 56.529364150464545 as t_statistic
14 | union all
15 | select 'b' as gb_var, 'x1' as variable_name, 2.996331112245573 as coefficient, 0.006731681784764358 as standard_error, 445.1088462064698 as t_statistic
16 | union all
17 | select 'b' as gb_var, 'x2' as variable_name, 9.019683491736044 as coefficient, 0.008744674914389008 as standard_error, 1031.4486907791759 as t_statistic
18 | union all
19 | select 'b' as gb_var, 'x3' as variable_name, 0.016151316166848173 as coefficient, 0.0072206704541224525 as standard_error, 2.2368166875178472 as t_statistic
20 |
21 | )
22 |
23 | select
24 | coalesce(base.variable_name, expected.variable_name) as variable_name,
25 | expected.coefficient as expected_coefficient,
26 | base.coefficient as actual_coefficient,
27 | expected.standard_error as expected_standard_error,
28 | base.standard_error as actual_standard_error,
29 | expected.t_statistic as expected_t_statistic,
30 | base.t_statistic as actual_t_statistic
31 | from {{ ref('groups_matrix_regression_chol_optimized') }} as base
32 | full outer join expected
33 | on
34 | base.gb_var = expected.gb_var
35 | and base.variable_name = expected.variable_name
36 | where
37 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }}
38 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }}
39 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }}
40 | or base.coefficient is null
41 | or base.standard_error is null
42 | or base.t_statistic is null
43 | or expected.coefficient is null
44 |
--------------------------------------------------------------------------------
/integration_tests/tests/test_groups_matrix_regression_chol_unoptimized.sql:
--------------------------------------------------------------------------------
1 | with
2 |
3 | expected as (
4 |
5 | select 'a' as gb_var, 'const' as variable_name, -0.06563066041472207 as coefficient, 0.053945103940799474 as standard_error, -1.2166194078844779 as t_statistic
6 | union all
7 | select 'a' as gb_var, 'x1' as variable_name, 0.9905419281557593 as coefficient, 0.015209571618398615 as standard_error, 65.12622136954383 as t_statistic
8 | union all
9 | select 'a' as gb_var, 'x2' as variable_name, 4.948221700496285 as coefficient, 0.02906881854690807 as standard_error, 170.2243829590593 as t_statistic
10 | union all
11 | select 'a' as gb_var, 'x3' as variable_name, 0.031234030051974747 as coefficient, 0.014337008978330493 as standard_error, 2.178559705108859 as t_statistic
12 | union all
13 | select 'b' as gb_var, 'const' as variable_name, 2.0117130483709955 as coefficient, 0.035587045398501334 as standard_error, 56.529364150464545 as t_statistic
14 | union all
15 | select 'b' as gb_var, 'x1' as variable_name, 2.996331112245573 as coefficient, 0.006731681784764358 as standard_error, 445.1088462064698 as t_statistic
16 | union all
17 | select 'b' as gb_var, 'x2' as variable_name, 9.019683491736044 as coefficient, 0.008744674914389008 as standard_error, 1031.4486907791759 as t_statistic
18 | union all
19 | select 'b' as gb_var, 'x3' as variable_name, 0.016151316166848173 as coefficient, 0.0072206704541224525 as standard_error, 2.2368166875178472 as t_statistic
20 |
21 | )
22 |
23 | select
24 | coalesce(base.variable_name, expected.variable_name) as variable_name,
25 | expected.coefficient as expected_coefficient,
26 | base.coefficient as actual_coefficient,
27 | expected.standard_error as expected_standard_error,
28 | base.standard_error as actual_standard_error,
29 | expected.t_statistic as expected_t_statistic,
30 | base.t_statistic as actual_t_statistic
31 | from {{ ref('groups_matrix_regression_chol_unoptimized') }} as base
32 | full outer join expected
33 | on
34 | base.gb_var = expected.gb_var
35 | and base.variable_name = expected.variable_name
36 | where
37 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }}
38 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }}
39 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }}
40 | or base.coefficient is null
41 | or base.standard_error is null
42 | or base.t_statistic is null
43 | or expected.coefficient is null
44 |
--------------------------------------------------------------------------------
/macros/linear_regression/ols_impl_special/_ols_1var.sql:
--------------------------------------------------------------------------------
1 | {% macro _ols_1var(table,
2 | endog,
3 | exog,
4 | add_constant=True,
5 | output=None,
6 | output_options=None,
7 | group_by=None,
8 | alpha=None) -%}
9 | {%- set exog_aliased = ['x1'] %}
10 | (with
11 | {%- if alpha %}
12 | _dbt_linreg_cmeans as (
13 | select
14 | {{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }}
15 | avg({{ endog }}) as y,
16 | avg({{ exog[0] }}) as x1,
17 | count(*) as ct
18 | from
19 | {{ table }}
20 | {%- if group_by %}
21 | group by
22 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }}
23 | {%- endif %}
24 | ),
25 | {%- endif %}
26 | _dbt_linreg_base as (
27 | select
28 | {{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }}
29 | {%- if alpha and add_constant %}
30 | b.{{ endog }} - _dbt_linreg_cmeans.y as y,
31 | b.{{ exog[0] }} - _dbt_linreg_cmeans.x1 as x1,
32 | {%- else %}
33 | {{ endog }} as y,
34 | b.{{ exog[0] }} as x1,
35 | {%- endif %}
36 | false as fake
37 | from
38 | {{ table }} as b
39 | {%- if alpha %}
40 | {%- if add_constant %}
41 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_cmeans') | indent(2) }}
42 | {%- endif %}
43 | union all
44 | select
45 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }}
46 | 0 as y,
47 | pow(x1 * ct, 0.5) as x1,
48 | true as fake
49 | from _dbt_linreg_cmeans
50 | {%- endif %}
51 | ),
52 | _dbt_linreg_final_coefs as (
53 | select
54 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }}
55 | {%- if add_constant %}
56 | avg({{ dbt_linreg._filter_and_center_if_alpha('b.y', alpha) }})
57 | - avg({{ dbt_linreg._filter_and_center_if_alpha('b.x1', alpha) }}) * {{ dbt_linreg.regress('b.y', 'b.x1') }}
58 | as x0_coef,
59 | {%- endif %}
60 | {{ dbt_linreg.regress('b.y', 'b.x1', add_constant=add_constant) }} as x1_coef
61 | from _dbt_linreg_base as b
62 | {%- if alpha and add_constant %}
63 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_cmeans') | indent(2) }}
64 | {%- endif %}
65 | {%- if group_by %}
66 | group by
67 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }}
68 | {%- endif %}
69 | )
70 | {{
71 | dbt_linreg.final_select(
72 | exog=exog,
73 | exog_aliased=['x1'],
74 | add_constant=add_constant,
75 | group_by=group_by,
76 | output=output,
77 | output_options=output_options,
78 | calculate_standard_error=False
79 | )
80 | }}
81 | )
82 | {%- endmacro %}
83 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | integration_tests/package-lock.yml
2 | dbt.duckdb
3 | dbt.duckdb.wal
4 | .user.yml
5 | docs/site/
6 | integration_tests/seeds/*.csv
7 | integration_tests/package-lock.yml
8 |
9 | # Created by https://www.toptal.com/developers/gitignore/api/dbt,osx,windows,visualstudiocode,python
10 | # Edit at https://www.toptal.com/developers/gitignore?templates=dbt,osx,windows,visualstudiocode,python
11 |
12 | ### dbt ###
13 | target/
14 | dbt_modules/
15 | dbt_packages/
16 | logs/
17 |
18 | ### OSX ###
19 | # General
20 | .DS_Store
21 | .AppleDouble
22 | .LSOverride
23 |
24 | # Icon must end with two \r
25 | Icon
26 |
27 |
28 | # Thumbnails
29 | ._*
30 |
31 | # Files that might appear in the root of a volume
32 | .DocumentRevisions-V100
33 | .fseventsd
34 | .Spotlight-V100
35 | .TemporaryItems
36 | .Trashes
37 | .VolumeIcon.icns
38 | .com.apple.timemachine.donotpresent
39 |
40 | # Directories potentially created on remote AFP share
41 | .AppleDB
42 | .AppleDesktop
43 | Network Trash Folder
44 | Temporary Items
45 | .apdisk
46 |
47 | ### Python ###
48 | # Byte-compiled / optimized / DLL files
49 | __pycache__/
50 | *.py[cod]
51 | *$py.class
52 |
53 | # C extensions
54 | *.so
55 |
56 | # Distribution / packaging
57 | .Python
58 | build/
59 | develop-eggs/
60 | dist/
61 | downloads/
62 | eggs/
63 | .eggs/
64 | lib/
65 | lib64/
66 | parts/
67 | sdist/
68 | var/
69 | wheels/
70 | share/python-wheels/
71 | *.egg-info/
72 | .installed.cfg
73 | *.egg
74 | MANIFEST
75 |
76 | # PyInstaller
77 | # Usually these files are written by a python script from a template
78 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
79 | *.manifest
80 | *.spec
81 |
82 | # Installer logs
83 | pip-log.txt
84 | pip-delete-this-directory.txt
85 |
86 | # Unit test / coverage reports
87 | htmlcov/
88 | .tox/
89 | .nox/
90 | .coverage
91 | .coverage.*
92 | .cache
93 | nosetests.xml
94 | coverage.xml
95 | *.cover
96 | *.py,cover
97 | .hypothesis/
98 | .pytest_cache/
99 | cover/
100 |
101 | # Translations
102 | *.mo
103 | *.pot
104 |
105 | # Django stuff:
106 | *.log
107 | local_settings.py
108 | db.sqlite3
109 | db.sqlite3-journal
110 |
111 | # Flask stuff:
112 | instance/
113 | .webassets-cache
114 |
115 | # Scrapy stuff:
116 | .scrapy
117 |
118 | # Sphinx documentation
119 | docs/_build/
120 |
121 | # PyBuilder
122 | .pybuilder/
123 |
124 | # Jupyter Notebook
125 | .ipynb_checkpoints
126 |
127 | # IPython
128 | profile_default/
129 | ipython_config.py
130 |
131 | # pyenv
132 | # For a library or package, you might want to ignore these files since the code is
133 | # intended to run in multiple environments; otherwise, check them in:
134 | # .python-version
135 |
136 | # pipenv
137 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
138 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
139 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
140 | # install all needed dependencies.
141 | #Pipfile.lock
142 |
143 | # poetry
144 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
145 | # This is especially recommended for binary packages to ensure reproducibility, and is more
146 | # commonly ignored for libraries.
147 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
148 | #poetry.lock
149 |
150 | # pdm
151 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
152 | #pdm.lock
153 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
154 | # in version control.
155 | # https://pdm.fming.dev/#use-with-ide
156 | .pdm.toml
157 |
158 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
159 | __pypackages__/
160 |
161 | # Celery stuff
162 | celerybeat-schedule
163 | celerybeat.pid
164 |
165 | # SageMath parsed files
166 | *.sage.py
167 |
168 | # Environments
169 | .env
170 | .venv
171 | env/
172 | venv/
173 | ENV/
174 | env.bak/
175 | venv.bak/
176 |
177 | # Spyder project settings
178 | .spyderproject
179 | .spyproject
180 |
181 | # Rope project settings
182 | .ropeproject
183 |
184 | # mkdocs documentation
185 | /site
186 |
187 | # mypy
188 | .mypy_cache/
189 | .dmypy.json
190 | dmypy.json
191 |
192 | # Pyre type checker
193 | .pyre/
194 |
195 | # pytype static type analyzer
196 | .pytype/
197 |
198 | # Cython debug symbols
199 | cython_debug/
200 |
201 | # PyCharm
202 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
203 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
204 | # and can be added to the global gitignore or merged into this file. For a more nuclear
205 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
206 | #.idea/
207 |
208 | ### Python Patch ###
209 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
210 | poetry.toml
211 |
212 | # ruff
213 | .ruff_cache/
214 |
215 | ### VisualStudioCode ###
216 | .vscode/*
217 | !.vscode/settings.json
218 | !.vscode/tasks.json
219 | !.vscode/launch.json
220 | !.vscode/extensions.json
221 | !.vscode/*.code-snippets
222 |
223 | # Local History for Visual Studio Code
224 | .history/
225 |
226 | # Built Visual Studio Code Extensions
227 | *.vsix
228 |
229 | ### VisualStudioCode Patch ###
230 | # Ignore all local history of files
231 | .history
232 | .ionide
233 |
234 | ### Windows ###
235 | # Windows thumbnail cache files
236 | Thumbs.db
237 | Thumbs.db:encryptable
238 | ehthumbs.db
239 | ehthumbs_vista.db
240 |
241 | # Dump file
242 | *.stackdump
243 |
244 | # Folder config file
245 | [Dd]esktop.ini
246 |
247 | # Recycle Bin used on file shares
248 | $RECYCLE.BIN/
249 |
250 | # Windows Installer files
251 | *.cab
252 | *.msi
253 | *.msix
254 | *.msm
255 | *.msp
256 |
257 | # Windows shortcuts
258 | *.lnk
259 |
260 | # End of https://www.toptal.com/developers/gitignore/api/dbt,osx,windows,visualstudiocode,python
261 |
--------------------------------------------------------------------------------
/macros/linear_regression/ols.sql:
--------------------------------------------------------------------------------
1 | {% macro ols(table,
2 | endog=none,
3 | exog=none,
4 | x=none,
5 | y=none,
6 | add_constant=true,
7 | output='wide',
8 | output_options=none,
9 | format=none,
10 | format_options=none,
11 | group_by=none,
12 | alpha=none,
13 | method=none,
14 | method_options=none) -%}
15 |
16 | {#############################################################################
17 |
18 | This function does 3 things:
19 |
20 | 1. Resolves and casts polymorphic inputs.
21 | 2. Validates inputs.
22 | 3. Dispatches the appropriate call.
23 |
24 | The actual calculations occur elsewhere in the code, depending on the
25 | implementation chosen.
26 |
27 | #############################################################################}
28 |
29 | {# Format the variables, and cast strings to lists #}
30 | {# ----------------------------------------------- #}
31 |
32 | {% if x is not none and exog is none %}
33 | {% set exog = x %}
34 | {% elif x is not none and exog is not none %}
35 | {{ exceptions.raise_compiler_error(
36 | "Please specify either `exog` (preferred) or `x`, not both."
37 | " `x` is just an alias for `exog`."
38 | ) }}
39 | {% endif %}
40 |
41 | {% if format_options is not none and output_options is not none %}
42 | {{ exceptions.raise_compiler_error(
43 | "`format_options` is deprecated and is another name for `output_options`."
44 | " Please only set the `output_options`."
45 | ) }}
46 | {% endif %}
47 |
48 |
49 | {% if format is not none and output is not none %}
50 | {{ exceptions.raise_compiler_error(
51 | "`format` is deprecated and is another name for `output`."
52 | " Please only set the `output`."
53 | ) }}
54 | {% elif format is not none and output is none %}
55 | {% set output = format %}
56 | {% endif %}
57 |
58 | {% if output_options is none %}
59 | {% if format_options is not none %}
60 | {% set output_options = format_options %}
61 | {% else %}
62 | {% set output_options = {} %}
63 | {% endif %}
64 | {% endif %}
65 |
66 | {% if method_options is none %}
67 | {% set method_options = {} %}
68 | {% endif %}
69 |
70 | {% if y is not none and endog is none %}
71 | {% set endog = y %}
72 | {% elif y is not none and endog is not none %}
73 | {{ exceptions.raise_compiler_error(
74 | "Please specify either `endog` (preferred) or `y`, not both."
75 | " `y` is just an alias for `endog`."
76 | ) }}
77 | {% endif %}
78 |
79 | {% if exog is not iterable %}
80 | {% if exog is none %}
81 | {% set exog = [] %}
82 | {% else %}
83 | {% set exog = [exog] %}
84 | {% endif %}
85 | {% elif exog is string %}
86 | {% set exog = [exog] %}
87 | {% endif %}
88 |
89 | {% if group_by is not iterable %}
90 | {% if group_by is none %}
91 | {% set group_by = [] %}
92 | {% else %}
93 | {% set group_by = [group_by] %}
94 | {% endif %}
95 | {% elif group_by is string %}
96 | {% set group_by = [group_by] %}
97 | {% endif %}
98 |
99 | {% if alpha is not iterable and alpha is not none %}
100 | {% set alpha = [alpha] * (exog | length) %}
101 | {% endif %}
102 |
103 | {% if method is none %}
104 | {% set method = 'chol' %}
105 | {% endif %}
106 |
107 | {# Check for user input errors #}
108 | {# --------------------------- #}
109 |
110 | {% if endog is none %}
111 | {{ exceptions.raise_compiler_error(
112 | "`endog` is not allowed to be None."
113 | " Please specify a target variable / y-variable / endogenous variable for"
114 | " the linear regression."
115 | ) }}
116 | {% endif %}
117 |
118 | {% if not exog and not add_constant %}
119 | {{ exceptions.raise_compiler_error(
120 | "Cannot run dbt_linreg.ols() because there are no exogenous variables"
121 | " / features to regress!"
122 | ) }}
123 | {% endif %}
124 |
125 | {% for i in range((exog | length)) %}
126 | {% for j in range(i, (exog | length)) %}
127 | {% if i != j %}
128 | {% if exog[i] == exog[j] %}
129 | {% if not alpha %}
130 | {{ exceptions.raise_compiler_error(
131 | "Duplicate variables are not allowed without regularization, as"
132 | " that will lead to multicollinearity. Duplicate variable is: "
133 | ~ exog[i] ~ ", which occurs at positions " ~ i ~ " and " ~ j ~ "."
134 | ) }}
135 | {% else %}
136 | {% do log(
137 | "Note: exog variable " ~ exog[i] ~ " is duplicated at positions "
138 | ~ i ~ " and " ~ j ~ "."
139 | ) %}
140 | {% endif %}
141 | {% endif %}
142 | {% endif %}
143 | {% endfor %}
144 | {% endfor %}
145 |
146 | {% if format is not none and format not in ['wide', 'long'] %}
147 | {{ exceptions.raise_compiler_error(
148 | "Format must be either 'wide' or 'long'; received " ~ format ~ "."
149 | " Also, `format=` is deprecated; it is suggested you use `output=` instead."
150 | ) }}
151 | {% endif %}
152 |
153 | {% if output not in ['wide', 'long'] %}
154 | {{ exceptions.raise_compiler_error(
155 | "Output must be either 'wide' or 'long'; received " ~ output ~ "."
156 | ) }}
157 | {% endif %}
158 |
159 | {% if alpha is not none and (alpha | length) != (exog | length) %}
160 | {{ exceptions.raise_compiler_error(
161 | "The number of values passed in `alpha` must be equivalent to"
162 | " the number of columns in `exog`."
163 | " Received " ~ (exog | length) ~ " exog variables and " ~ (alpha | length) ~
164 | " alpha parameters."
165 | " Note that the constant term cannot be penalized."
166 | ) }}
167 | {% endif %}
168 |
169 | {% if method == 'chol' %}
170 | {{ return(
171 | dbt_linreg._ols_chol(
172 | table=table,
173 | endog=endog,
174 | exog=exog,
175 | add_constant=add_constant,
176 | output=output,
177 | output_options=output_options,
178 | group_by=group_by,
179 | alpha=alpha,
180 | method_options=method_options
181 | )
182 | ) }}
183 | {% elif method == 'fwl' %}
184 | {{ return(
185 | dbt_linreg._ols_fwl(
186 | table=table,
187 | endog=endog,
188 | exog=exog,
189 | add_constant=add_constant,
190 | output=output,
191 | output_options=output_options,
192 | group_by=group_by,
193 | alpha=alpha,
194 | method_options=method_options
195 | )
196 | ) }}
197 | {% else %}
198 | {{ exceptions.raise_compiler_error(
199 | "Invalid method specified. The only valid methods are 'chol' and 'fwl'."
200 | ) }}
201 | {% endif %}
202 |
203 | {% endmacro %}
204 |
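
The first half of the `ols()` macro above normalizes polymorphic inputs: `exog` and `group_by` may be `none`, a single string, or a list, and a scalar `alpha` is repeated once per exogenous variable. A minimal Python sketch of that casting logic, for illustration only; the helper names `as_list` and `broadcast_alpha` are hypothetical and do not exist in the package:

```python
def as_list(value):
    """Cast None -> [], a string or other scalar -> [value], and an iterable -> list."""
    if value is None:
        return []
    if isinstance(value, str):
        # Strings are iterable, so they need an explicit check (as in the macro).
        return [value]
    try:
        return list(value)
    except TypeError:
        return [value]


def broadcast_alpha(alpha, exog):
    """Repeat a scalar penalty once per exog column; validate list lengths."""
    if alpha is None:
        return None
    if isinstance(alpha, (int, float)):
        return [alpha] * len(exog)
    alpha = list(alpha)
    if len(alpha) != len(exog):
        raise ValueError(
            f"Received {len(exog)} exog variables and {len(alpha)} alpha parameters."
        )
    return alpha


exog = as_list("x")
print(exog)
print(as_list(None))
print(broadcast_alpha(0.5, ["xa", "xb", "xc"]))
```

Note that the constant term is not part of `exog`, which is consistent with the macro's error message stating that the constant cannot be penalized.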
--------------------------------------------------------------------------------
/macros/schema.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | macros:
4 | - name: ols
5 | description: |-
6 | **dbt_linreg** is an easy way to perform linear regression and ridge regression in SQL with OLS.
7 |
8 | The `dbt_linreg.ols()` macro is the core, high-level API for the **dbt_linreg** package. This macro will calculate and output the coefficients of a linear regression specified by the user. The regression can also be L2 regularized using the `alpha` argument, i.e. ridge regression is also supported.
9 |
10 | Here is an example of a dbt model that selects from a dbt model called `simple_matrix`, and runs a regression on `y` using feature columns `xa`, `xb`, and `xc`:
11 |
12 | {% raw %}
13 | ```sql
14 | {{
15 | config(
16 | materialized="table"
17 | )
18 | }}
19 | select * from {{
20 | dbt_linreg.ols(
21 | table=ref('simple_matrix'),
22 | endog='y',
23 | exog=['xa', 'xb', 'xc'],
24 | output='long',
25 | output_options={'round': 5}
26 | )
27 | }}
28 | ```
29 | {% endraw %}
30 |
31 | You may also select from a CTE; in this case, just pass a string referencing the CTE:
32 |
33 | {% raw %}
34 | ```sql
35 | {{
36 | config(
37 | materialized="table"
38 | )
39 | }}
40 | with my_data as (
41 | select * from {{ ref('simple_matrix') }}
42 | )
43 | select * from {{
44 | dbt_linreg.ols(
45 | table='my_data',
46 | endog='y',
47 | exog=['xa', 'xb', 'xc'],
48 | output='long',
49 | output_options={'round': 5}
50 | )
51 | }}
52 | ```
53 | {% endraw %}
54 |
55 | The macro renders a subquery, inclusive of parentheses.
56 |
57 | Please see the README / full documentation for more information: [https://dwreeves.github.io/dbt_linreg/](https://dwreeves.github.io/dbt_linreg/)
58 | arguments:
59 | - name: table
60 | type: string
61 | description: Name of table or CTE to pull the data from. You can use `ref()` or `source()` here if you'd like.
62 | - name: endog
63 | type: string
64 | description: The endogenous variable / y variable / target variable of the regression. (You can also specify `y=...` instead of `endog=...` if you prefer.)
65 | - name: exog
66 | type: string or list of strings
67 | description: The exogenous variables / X variables / features of the regression. (You can also specify `x=...` instead of `exog=...` if you prefer.)
68 | - name: add_constant
69 | type: boolean
70 | description: 'If true, a constant term is added automatically to the regression. (Default: `true`)'
71 | - name: output
72 | type: string
73 | description: |-
74 | Either "wide" or "long" format for coefficients. See **Outputs and output options** in the README for more.
75 | - If `wide`, the variables span the columns with their original variable names, and the coefficients fill a single row.
76 | - If `long`, the coefficients are in a single column called `coefficient`, and the variable names are in a single column called `variable_name`.
77 | - name: output_options
78 | type: dict
79 | description: See **Outputs and output options** section in the README for more.
80 | - name: group_by
81 | type: string or list of strings
82 | description: If specified, the regression will be grouped by these variables, and individual regressions will run on each group.
83 | - name: alpha
84 | type: number or list of numbers
85 | description: If not null, the regression will be run as a ridge regression with a penalty of `alpha`. See **Notes** section in the README for more information.
86 | - name: method
87 | type: string
88 | description: The method used to calculate the regression. Only `chol` and `fwl` are valid inputs for now. See **Methods and method options** in the README for more.
89 | - name: method_options
90 | type: dict
91 | description: Options specific to the estimation method. See **Methods and method options** in the README for more.
92 | # Everything down here is just for intermediary calculations or helper functions.
93 | # There is no point to showing these in the docs.
94 | # The truly curious can just look at the source code.
95 | #
96 | # Please generate the below with the following command:
97 | # >>> python scripts.py gen-hide-macros-yaml
98 | - name: _alias_exog
99 | docs:
100 | show: false
101 | - name: _alias_gb_cols
102 | docs:
103 | show: false
104 | - name: _cell_or_alias
105 | docs:
106 | show: false
107 | - name: _cholesky_decomposition
108 | docs:
109 | show: false
110 | - name: _filter_and_center_if_alpha
111 | docs:
112 | show: false
113 | - name: _filter_if_alpha
114 | docs:
115 | show: false
116 | - name: _format_wide_variable_column
117 | docs:
118 | show: false
119 | - name: _forward_substitution
120 | docs:
121 | show: false
122 | - name: _gb_cols
123 | docs:
124 | show: false
125 | - name: _get_method_option
126 | docs:
127 | show: false
128 | - name: _get_output_option
129 | docs:
130 | show: false
131 | - name: _join_on_groups
132 | docs:
133 | show: false
134 | - name: _maybe_round
135 | docs:
136 | show: false
137 | - name: _ols_0var
138 | docs:
139 | show: false
140 | - name: _ols_1var
141 | docs:
142 | show: false
143 | - name: _ols_chol
144 | docs:
145 | show: false
146 | - name: _ols_fwl
147 | docs:
148 | show: false
149 | - name: _orth_x_intercept
150 | docs:
151 | show: false
152 | - name: _orth_x_slope
153 | docs:
154 | show: false
155 | - name: _regress_or_alias
156 | docs:
157 | show: false
158 | - name: _safe_sqrt
159 | docs:
160 | show: false
161 | - name: _strip_quotes
162 | docs:
163 | show: false
164 | - name: _traverse_intercepts
165 | docs:
166 | show: false
167 | - name: _traverse_slopes
168 | docs:
169 | show: false
170 | - name: _unalias_gb_cols
171 | docs:
172 | show: false
173 | - name: bigquery___safe_sqrt
174 | docs:
175 | show: false
176 | - name: clickhouse___cell_or_alias
177 | docs:
178 | show: false
179 | - name: default___cell_or_alias
180 | docs:
181 | show: false
182 | - name: default___maybe_round
183 | docs:
184 | show: false
185 | - name: default___regress_or_alias
186 | docs:
187 | show: false
188 | - name: default___safe_sqrt
189 | docs:
190 | show: false
191 | - name: default__regress
192 | docs:
193 | show: false
194 | - name: duckdb___cell_or_alias
195 | docs:
196 | show: false
197 | - name: duckdb___regress_or_alias
198 | docs:
199 | show: false
200 | - name: final_select
201 | docs:
202 | show: false
203 | - name: postgres___maybe_round
204 | docs:
205 | show: false
206 | - name: redshift___maybe_round
207 | docs:
208 | show: false
209 | - name: regress
210 | docs:
211 | show: false
212 | - name: snowflake___cell_or_alias
213 | docs:
214 | show: false
215 | - name: snowflake___regress_or_alias
216 | docs:
217 | show: false
218 | - name: snowflake__regress
219 | docs:
220 | show: false
221 |
--------------------------------------------------------------------------------
/macros/linear_regression/ols_impl_fwl/_ols_impl_fwl.sql:
--------------------------------------------------------------------------------
1 | {% macro _traverse_slopes(step, x) -%}
2 | {%- set li = [] %}
3 | {%- for i in x %}
4 | {%- for j in x %}
5 | {%- set remaining = x.copy() %}
6 | {%- if i != j %}
7 | {%- do remaining.remove(i) %}
8 | {%- do remaining.remove(j) %}
9 | {%- set comb = modules.itertools.combinations(remaining, step - 1) %}
10 | {%- for c in comb %}
11 | {%- do li.append((i, j, c)) %}
12 | {%- endfor %}
13 | {%- endif %}
14 | {%- endfor %}
15 | {%- endfor %}
16 | {{ return(li) }}
17 | {% endmacro %}
18 |
19 | {% macro _traverse_intercepts(step, x) -%}
20 | {%- set li = [] %}
21 | {%- for i in x %}
22 | {%- set remaining = x.copy() %}
23 | {%- do remaining.remove(i) %}
24 | {%- set comb = modules.itertools.combinations(remaining, step) %}
25 | {%- for c in comb %}
26 | {%- set ortho = [] %}
27 | {%- if c %}
28 | {%- for b in c %}
29 | {%- set _c = (c | list) %}
30 | {%- do _c.remove(b) %}
31 | {%- do ortho.append([b] + (modules.itertools.combinations(_c, step - 1) | list)) %}
32 | {%- endfor %}
33 | {%- endif %}
34 | {%- do li.append((i, ortho)) %}
35 | {%- endfor %}
36 | {%- endfor %}
37 | {{ return(li) }}
38 | {% endmacro %}
39 |
40 | {% macro _filter_if_alpha(i, alpha) %}
41 | {% if alpha %}
42 | {{ return('case when not fake then ' ~ i ~ ' end') }}
43 | {% else %}
44 | {{ return(i) }}
45 | {% endif %}
46 | {% endmacro %}
47 |
48 | {% macro _filter_and_center_if_alpha(i, alpha, base_prefix='') %}
49 | {% if alpha %}
50 | {{ return('case when not fake then ' ~ base_prefix ~ i ~ ' + _dbt_linreg_cmeans.' ~ i ~ ' end') }}
51 | {% else %}
52 | {{ return(i) }}
53 | {% endif %}
54 | {% endmacro %}
55 |
56 | {% macro _orth_x_slope(x, o) -%}
57 | {%- if o %}
58 | {{ return(x ~ '_' ~ (o | join(''))) }}
59 | {%- else %}
60 | {{ return(x) }}
61 | {%- endif %}
62 | {% endmacro %}
63 |
64 | {% macro _orth_x_intercept(x, o) -%}
65 | {%- set li = [] %}
66 | {%- for c in o %}
67 | {%- do li.append(c[0]) %}
68 | {%- endfor %}
69 | {{ return(x ~ '_' ~ (li | join(''))) }}
70 | {% endmacro %}
71 |
72 | {% macro _regress_or_alias(y, x, add_constant=True) %}
73 | {{ return(
74 | adapter.dispatch('_regress_or_alias', 'dbt_linreg')
75 | (y, x, add_constant=add_constant)
76 | ) }}
77 | {% endmacro %}
78 |
79 | {# Some but not all query engines let a select reference aliases defined earlier in the same select.
80 | Reusing the alias keeps the compiled SQL cleaner, and for large regressions can
81 | slightly improve the query planner speed (albeit not the execution). #}
82 | {% macro default___regress_or_alias(y, x, add_constant=True) %}
83 | {{ return(dbt_linreg.regress(y, x, add_constant=add_constant)) }}
84 | {% endmacro %}
85 |
86 | {% macro snowflake___regress_or_alias(y, x, add_constant=True) %}
87 | {{ return(y ~ '_' ~ x ~ '_coef') }}
88 | {% endmacro %}
89 |
90 | {% macro duckdb___regress_or_alias(y, x, add_constant=True) %}
91 | {{ return(y ~ '_' ~ x ~ '_coef') }}
92 | {% endmacro %}
93 |
94 |
95 | {% macro _ols_fwl(table,
96 | endog,
97 | exog,
98 | add_constant=True,
99 | output=None,
100 | output_options=None,
101 | group_by=None,
102 | alpha=None,
103 | method_options=None) -%}
104 | {%- if (exog | length) == 0 %}
105 | {% do log('Note: exog was empty; running regression on constant term only.') %}
106 | {{ return(dbt_linreg._ols_0var(
107 | table=table,
108 | endog=endog,
109 | exog=exog,
110 | add_constant=add_constant,
111 | output=output,
112 | output_options=output_options,
113 | group_by=group_by,
114 | alpha=alpha
115 | )) }}
116 | {%- elif (exog | length) == 1 %}
117 | {{ return(dbt_linreg._ols_1var(
118 | table=table,
119 | endog=endog,
120 | exog=exog,
121 | add_constant=add_constant,
122 | output=output,
123 | output_options=output_options,
124 | group_by=group_by,
125 | alpha=alpha
126 | )) }}
127 | {%- endif %}
128 | {%- set exog_aliased = dbt_linreg._alias_exog(exog) %}
129 | (with
130 | {%- if alpha %}
131 | _dbt_linreg_cmeans as (
132 | select
133 | {{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }}
134 | avg({{ endog }}) as y,
135 | {%- for i in exog_aliased %}
136 | avg({{ exog[loop.index0] }}) as {{ i }},
137 | {%- endfor %}
138 | count(*) as ct
139 | from
140 | {{ table }}
141 | {%- if group_by %}
142 | group by
143 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }}
144 | {%- endif %}
145 | ),
146 | {%- endif %}
147 | _dbt_linreg_step0 as (
148 | select
149 | {{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }}
150 | {%- if alpha and add_constant %}
151 | b.{{ endog }} - _dbt_linreg_cmeans.y as y,
152 | {%- for i in exog_aliased %}
153 | b.{{ exog[loop.index0] }} - _dbt_linreg_cmeans.{{ i }} as {{ i }},
154 | {%- endfor %}
155 | {%- else %}
156 | {{ endog }} as y,
157 | {%- for i in exog_aliased %}
158 | b.{{ exog[loop.index0] }} as {{ i }},
159 | {%- endfor %}
160 | {%- endif %}
161 | false as fake
162 | from
163 | {{ table }} as b
164 | {%- if alpha %}
165 | {%- if add_constant %}
166 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_cmeans') | indent(2) }}
167 | {%- endif %}
168 | {%- for i in exog_aliased %}
169 | {%- set i_idx = loop.index0 %}
170 | union all
171 | select
172 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }}
173 | 0 as y,
174 | {%- for j in exog_aliased %}
175 | {%- if i == j %}
176 | pow({{ alpha[i_idx] }} * ct, 0.5) as {{ j }},
177 | {%- else %}
178 | 0 as {{ j }},
179 | {%- endif %}
180 | {%- endfor %}
181 | true as fake
182 | from _dbt_linreg_cmeans as cmeans
183 | {%- endfor %}
184 | {%- endif %}
185 | ),
186 | {% for step in range(1, (exog | length)) %}
187 | _dbt_linreg_step{{ step }} as (
188 | with
189 | __dbt_linreg_coefs{{ step }} as (
190 | select
191 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(6) }}
192 | {#- Slope terms #}
193 | {%- for _y, _x, _o in dbt_linreg._traverse_slopes(step, exog_aliased) %}
194 | {%- set _c = dbt_linreg._orth_x_slope(_x, _o) %}
195 | {{ dbt_linreg.regress(_y, _c, add_constant=add_constant) }} as {{ _y }}_{{ _c }}_coef,
196 | {%- endfor %}
197 | {#- Constant terms #}
198 | {%- if add_constant %}
199 | {%- for _y, _o in dbt_linreg._traverse_intercepts(step, exog_aliased) %}
200 | avg({{ dbt_linreg._filter_if_alpha(_y, alpha) }})
201 | {%- for _yi, _xi in _o %}
202 | {%- set _ci = dbt_linreg._orth_x_slope(_yi, _xi) %}
203 | - avg({{ dbt_linreg._filter_if_alpha(_yi, alpha) }}) * {{ dbt_linreg._regress_or_alias(_y, _ci) }}
204 | {%- endfor %}
205 | as {{ dbt_linreg._orth_x_intercept(_y, _o) }}_const
206 | {%- if not loop.last -%}
207 | ,
208 | {%- endif -%}
209 | {%- endfor %}
210 | {%- endif %}
211 | from _dbt_linreg_step{{ step - 1 }}
212 | {%- if group_by %}
213 | group by
214 | {{ dbt_linreg._gb_cols(group_by) | indent(6) }}
215 | {%- endif %}
216 | )
217 | select
218 | {{ dbt_linreg._gb_cols(group_by, prefix='b', trailing_comma=True) | indent(4) }}
219 | y,
220 | {%- for i in exog_aliased %}
221 | {{ i }},
222 | {%- endfor %}
223 | fake,
224 | {%- for _y, _o in dbt_linreg._traverse_intercepts(step, exog_aliased) %}
225 | {{ _y }}
226 | {%- for _yi, _xi in _o %}
227 | {%- set _ci = dbt_linreg._orth_x_slope(_yi, _xi) %}
228 | - {{ _y }}_{{ _ci }}_coef * {{ _yi }}
229 | {%- endfor %}
230 | {%- set _c = dbt_linreg._orth_x_intercept(_y, _o) %}
231 | {%- if add_constant %}
232 | - {{ _c }}_const
233 | {%- endif %}
234 | as {{ _c }}
235 | {%- if not loop.last -%}
236 | ,
237 | {%- endif %}
238 | {%- endfor %}
239 | from _dbt_linreg_step0 as b
240 | {{ dbt_linreg._join_on_groups(group_by, 'b', '__dbt_linreg_coefs'~step) | indent(2) }}
241 | ),
242 | {%- if loop.last %}
243 | _dbt_linreg_final_coefs as (
244 | select
245 | {%- if add_constant %}
246 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }}
247 | avg({{ dbt_linreg._filter_and_center_if_alpha('y', alpha, base_prefix='b.') }})
248 | {%- for _x, _o in dbt_linreg._traverse_intercepts(step, exog_aliased) %}
249 | - avg({{ dbt_linreg._filter_and_center_if_alpha(_x, alpha, base_prefix='b.') }}) * {{ dbt_linreg.regress('b.y', dbt_linreg._orth_x_intercept('b.' ~ _x, _o)) }}
250 | {%- endfor %}
251 | as x0_coef,
252 | {%- endif %}
253 | {%- for _x, _o in dbt_linreg._traverse_intercepts(step, exog_aliased) %}
254 | {{ dbt_linreg.regress('b.y', dbt_linreg._orth_x_intercept(_x, _o), add_constant=add_constant) }} as {{ _x }}_coef
255 | {%- if not loop.last -%}
256 | ,
257 | {%- endif %}
258 | {%- endfor %}
259 | from _dbt_linreg_step{{ step }} as b
260 | {%- if alpha and add_constant %}
261 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_cmeans') | indent(2) }}
262 | {%- endif %}
263 | {%- if group_by %}
264 | group by
265 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }}
266 | {%- endif %}
267 | )
268 | {%- endif %}
269 | {%- endfor %}
270 | {{
271 | dbt_linreg.final_select(
272 | exog=exog,
273 | exog_aliased=exog_aliased,
274 | add_constant=add_constant,
275 | group_by=group_by,
276 | output=output,
277 | output_options=output_options,
278 | calculate_standard_error=False
279 | )
280 | }}
281 | )
282 | {%- endmacro %}
283 |
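The FWL macros in this file build every coefficient out of chains of univariate regressions. As a rough illustration only, here is a minimal Python sketch of the two-regressor case, assuming an intercept and no ridge penalty. The helper names `slope`, `residualize`, and `fwl_two_var` are hypothetical and do not exist in dbt_linreg.

```python
def slope(y, x, add_constant=True):
    """Univariate slope, mirroring the two branches of the regress() macro."""
    n = len(x)
    if add_constant:
        # covar_pop(x, y) / var_pop(x)
        mx, my = sum(x) / n, sum(y) / n
        cov = sum((a - mx) * (b - my) for a, b in zip(x, y)) / n
        var = sum((a - mx) ** 2 for a in x) / n
        return cov / var
    # sum(x * y) / sum(x * x)
    return sum(a * b for a, b in zip(x, y)) / sum(a * a for a in x)


def residualize(target, on):
    """Residuals of `target` after regressing it on `on` with an intercept."""
    n = len(on)
    b = slope(target, on)
    a = sum(target) / n - b * sum(on) / n
    return [t - a - b * o for t, o in zip(target, on)]


def fwl_two_var(y, x1, x2):
    """Coefficients of y ~ 1 + x1 + x2 via the Frisch-Waugh-Lovell theorem."""
    # Each slope is y regressed on that regressor with the other partialled out.
    b1 = slope(y, residualize(x1, x2))
    b2 = slope(y, residualize(x2, x1))
    n = len(y)
    # The intercept follows from the means once the slopes are known.
    b0 = sum(y) / n - b1 * sum(x1) / n - b2 * sum(x2) / n
    return b0, b1, b2


# Illustrative data generated exactly from y = 1 + 2*x1 - 3*x2.
x1 = [1.0, 2.0, 3.0, 4.0, 5.0]
x2 = [2.0, 1.0, 4.0, 3.0, 6.0]
y = [1.0 + 2.0 * a - 3.0 * b for a, b in zip(x1, x2)]
b0, b1, b2 = fwl_two_var(y, x1, x2)
```

With more than two regressors the same residualizing step is repeated once per additional variable, which appears to be the role of the `_dbt_linreg_step{n}` CTEs.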
--------------------------------------------------------------------------------
/macros/linear_regression/utils/utils.sql:
--------------------------------------------------------------------------------
1 | {###############################################################################
2 | ## Simple univariate regression.
3 | ###############################################################################}
4 |
5 | {% macro regress(y, x, add_constant=True) %}
6 | {{ return(
7 | adapter.dispatch('regress', 'dbt_linreg')
8 | (y, x, add_constant=add_constant)
9 | ) }}
10 | {% endmacro %}
11 |
12 | {% macro default__regress(y, x, add_constant=True) -%}
13 | {%- if add_constant -%}
14 | covar_pop({{ x }}, {{ y }}) / var_pop({{ x }})
15 | {%- else -%}
16 | sum({{ x }} * {{ y }}) / sum({{ x }} * {{ x }})
17 | {%- endif -%}
18 | {%- endmacro %}
19 |
20 | {% macro snowflake__regress(y, x, add_constant=True) -%}
21 | {%- if add_constant -%}
22 | regr_slope({{ y }}, {{ x }})
23 | {%- else -%}
24 | sum({{ x }} * {{ y }}) / sum({{ x }} * {{ x }})
25 | {%- endif -%}
26 | {%- endmacro %}
27 |
28 | {###############################################################################
29 | ## Final select
30 | ###############################################################################}
31 |
32 | {# Every OLS method ends with a "_dbt_linreg_final_coefs" CTE with a common
33 | interface. This interface can then be transformed in a standard way using the
34 | final_select() macro, which formats the output for the user. #}
35 | {% macro final_select(exog=none,
36 | exog_aliased=none,
37 | group_by=none,
38 | add_constant=true,
39 | output=none,
40 | output_options=none,
41 | calculate_standard_error=false) -%}
42 | {%- if output == 'long' %}
43 | {%- if add_constant %}
44 | select
45 | {{ dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }}
46 | {{ dbt.string_literal(dbt_linreg._get_output_option('constant_name', output_options, 'const')) }} as {{ dbt_linreg._get_output_option('variable_column_name', output_options, 'variable_name') }},
47 | {{ dbt_linreg._maybe_round('x0_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('coefficient_column_name', output_options, 'coefficient') }}{% if calculate_standard_error %},
48 | {{ dbt_linreg._maybe_round('x0_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('standard_error_column_name', output_options, 'standard_error') }},
49 | {{ dbt_linreg._maybe_round('x0_coef/x0_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('t_statistic_column_name', output_options, 't_statistic') }}
50 | {%- endif %}
51 | from _dbt_linreg_final_coefs as b
52 | {%- if calculate_standard_error %}
53 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_stderrs') }}
54 | {%- endif %}
55 | {%- if exog_aliased %}
56 | union all
57 | {%- endif %}
58 | {%- endif %}
59 | {%- for i in exog_aliased %}
60 | select
61 | {{ dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }}
62 | {{ dbt.string_literal(dbt_linreg._strip_quotes(exog[loop.index0], output_options)) }} as {{ dbt_linreg._get_output_option('variable_column_name', output_options, 'variable_name') }},
63 | {{ dbt_linreg._maybe_round(i~'_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('coefficient_column_name', output_options, 'coefficient') }}{% if calculate_standard_error %},
64 | {{ dbt_linreg._maybe_round(i~'_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('standard_error_column_name', output_options, 'standard_error') }},
65 | {{ dbt_linreg._maybe_round(i~'_coef/'~i~'_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('t_statistic_column_name', output_options, 't_statistic') }}
66 | {%- endif %}
67 | from _dbt_linreg_final_coefs as b
68 | {%- if calculate_standard_error %}
69 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_stderrs') }}
70 | {%- endif %}
71 | {%- if not loop.last %}
72 | union all
73 | {%- endif %}
74 | {%- endfor %}
75 | {%- elif output == 'wide' %}
76 | select
77 | {{ dbt_linreg._unalias_gb_cols(group_by) | indent(2) }}
78 | {%- if add_constant -%}
79 | {{ dbt_linreg._maybe_round('x0_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._format_wide_variable_column(dbt_linreg._get_output_option('constant_name', output_options, 'const'), output_options) }}
80 | {%- if exog_aliased -%}
81 | ,
82 | {%- endif -%}
83 | {%- endif -%}
84 | {%- for i in exog_aliased %}
85 | {{ dbt_linreg._maybe_round(i~'_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._format_wide_variable_column(exog[loop.index0], output_options) }}
86 | {%- if not loop.last -%}
87 | ,
88 | {%- endif %}
89 | {%- endfor %}
90 | from _dbt_linreg_final_coefs
91 | {%- else %}
92 | {#- Fallback option (which should never happen!) is to just select star. #}
93 | select * from _dbt_linreg_final_coefs
94 | {%- endif %}
95 | {%- endmacro %}
96 |
97 | {###############################################################################
98 | ## Misc.
99 | ###############################################################################}
100 |
101 | {# Users can pass columns such as '"foo"', with the double quotes included.
102 | In this situation, we want to strip the double quotes when presenting
103 | outputs in a long format. #}
104 | {% macro _strip_quotes(x, output_options) -%}
105 | {% if dbt_linreg._get_output_option('strip_quotes', output_options, True) %}
106 | {% if x[0] == '"' and x[-1] == '"' and (x | length) > 1 %}
107 | {{ return(x[1:-1]) }}
108 | {% endif %}
109 | {% endif %}
110 | {{ return(x)}}
111 | {%- endmacro %}
112 |
113 | {% macro _format_wide_variable_column(x, output_options) -%}
114 | {% if x[0] == '"' and x[-1] == '"' and (x | length) > 1 %}
115 | {% set _add_quotes = True %}
116 | {% set x = x[1:-1] %}
117 | {% else %}
118 | {% set _add_quotes = False %}
119 | {% endif %}
120 | {% if dbt_linreg._get_output_option('variable_column_prefix', output_options) %}
121 | {% set x = dbt_linreg._get_output_option('variable_column_prefix', output_options) ~ x %}
122 | {% endif %}
123 | {% if dbt_linreg._get_output_option('variable_column_suffix', output_options) %}
124 | {% set x = x ~ dbt_linreg._get_output_option('variable_column_suffix', output_options) %}
125 | {% endif %}
126 | {% if _add_quotes %}
127 | {% set x = '"' ~ x ~ '"' %}
128 | {% endif %}
129 | {{ return(x)}}
130 | {%- endmacro %}
131 |
132 | {# To ensure no namespace conflicts, f"gb{index}" is used in group by
133 | statements instead of the actual column names. This macro adds aliases. #}
134 | {% macro _alias_gb_cols(group_by) -%}
135 | {%- if group_by %}
136 | {%- for gb in group_by %}
137 | {{ gb }} as gb{{ loop.index }},
138 | {%- endfor %}
139 | {%- endif %}
140 | {%- endmacro %}
141 |
142 | {# This macro reverses gb column aliases at the end of an OLS query. #}
143 | {% macro _unalias_gb_cols(group_by, prefix=None) -%}
144 | {%- if group_by %}
145 | {%- for gb in group_by %}
146 | {%- if prefix %}
147 | {{ prefix }}.gb{{ loop.index }} as {{ gb }},
148 | {%- else %}
149 | gb{{ loop.index }} as {{ gb }},
150 | {%- endif %}
151 | {%- endfor %}
152 | {%- endif %}
153 | {%- endmacro %}
154 |
155 | {# Round the final coefficient if the user specifies the `round` format
156 | option. Otherwise, keep as is. #}
157 |
158 | {% macro _maybe_round(x, round_) %}
159 | {{ return(
160 | adapter.dispatch('_maybe_round', 'dbt_linreg')(x, round_)
161 | ) }}
162 | {% endmacro %}
163 |
164 | {% macro default___maybe_round(x, round_) %}
165 | {% if round_ is not none %}
166 | {{ return('round(' ~ x ~ ', ' ~ round_ ~ ')') }}
167 | {% else %}
168 | {{ return(x) }}
169 | {% endif %}
170 | {% endmacro %}
171 |
172 | {% macro postgres___maybe_round(x, round_) %}
173 | {% if round_ is not none %}
174 | {{ return('round((' ~ x ~ ')::numeric, ' ~ round_ ~ ')') }}
175 | {% else %}
176 | {{ return('(' ~ x ~ ')::numeric') }}
177 | {% endif %}
178 | {% endmacro %}
179 |
180 | {% macro redshift___maybe_round(x, round_) %}
181 | {% if round_ is not none %}
182 | {{ return('round(' ~ x ~ ', ' ~ round_ ~ ')') }}
183 | {% else %}
184 | {{ return(x) }}
185 | {% endif %}
186 | {% endmacro %}
187 |
188 | {# Alias and write group by columns in a standard way. #}
189 | {% macro _gb_cols(group_by, trailing_comma=False, prefix=None) -%}
190 | {%- if group_by %}
191 | {%- for gb in group_by %}
192 | {%- if prefix %}
193 | {{ prefix }}.gb{{ loop.index }}
194 | {%- else %}
195 | gb{{ loop.index }}
196 | {%- endif %}
197 | {%- if (not loop.last) or trailing_comma -%}
198 | ,
199 | {%- endif %}
200 | {%- endfor %}
201 | {%- endif %}
202 | {%- endmacro %}
203 |
204 | {# Take exog and generate a list containing 'x1', 'x2', etc. #}
205 | {% macro _alias_exog(x) -%}
206 | {% set li = [] %}
207 | {% for i in x %}
208 | {% do li.append('x' ~ loop.index) %}
209 | {% endfor %}
210 | {{ return(li) }}
211 | {%- endmacro %}
212 |
213 | {# Join on gb1, gb2 etc. from a table to another table.
214 | If there is no group by column, assume `join_to` is just 1 row.
215 | And in that case, just do a cross join. #}
216 | {% macro _join_on_groups(group_by, join_from, join_to) -%}
217 | {%- if not group_by %}
218 | cross join {{ join_to }}
219 | {%- else %}
220 | inner join {{ join_to }}
221 | on
222 | {%- for _ in group_by %}
223 | {{ join_from }}.gb{{ loop.index }} = {{ join_to }}.gb{{ loop.index }}
224 | {% if not loop.last -%}
225 | and
226 | {%- endif %}
227 | {%- endfor %}
228 | {%- endif %}
229 | {%- endmacro %}
230 |
231 | {% macro _get_output_option(field, output_options, default=none) %}
232 | {{ return(output_options.get(field, var("dbt_linreg", {}).get("output_options", {}).get(field, default))) }}
233 | {% endmacro %}
234 |
235 | {% macro _get_method_option(method, field, method_options, default=none) %}
236 | {{ return(method_options.get(field, var("dbt_linreg", {}).get("method_options", {}).get(method, {}).get(field, default))) }}
237 | {% endmacro %}
238 |
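The `_get_output_option` and `_get_method_option` macros resolve an option in three steps: the value passed in the macro call, then the project-level `dbt_linreg` var, then a hard-coded default. A minimal Python sketch of that lookup order follows. The `project_vars` dict is a hypothetical stand-in for dbt's `var()` and is not part of the package.

```python
def get_output_option(field, output_options, project_vars, default=None):
    """Resolve an output option: call-level value, then project var, then default."""
    project_level = project_vars.get("dbt_linreg", {}).get("output_options", {})
    return output_options.get(field, project_level.get(field, default))


# Hypothetical project-level configuration (what `vars:` in dbt_project.yml might hold).
project_vars = {"dbt_linreg": {"output_options": {"round": 3, "constant_name": "intercept"}}}

call_wins = get_output_option("round", {"round": 5}, project_vars)
project_fallback = get_output_option("constant_name", {}, project_vars, "const")
hard_default = get_output_option("variable_column_name", {}, project_vars, "variable_name")
```

Passing the default as the final argument, rather than applying a Jinja `default` filter to the result, matters because the lookup returns `none` (not an undefined value) when nothing is configured.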
--------------------------------------------------------------------------------
/macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql:
--------------------------------------------------------------------------------
1 | {# In some warehouses, an expression can reference column aliases created
2 | earlier in the same select.
3 | If that's not available, the previously built expression is taken from the dict. #}
4 |
5 | {% macro _cell_or_alias(i, j, d, prefix=none, isa=none) %}
6 | {% if isa is not none %}
7 | {% if isa %}
8 | {{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
9 | {% else %}
10 | {{ return(d[(i, j)]) }}
11 | {% endif %}
12 | {% endif %}
13 | {{ return(
14 | adapter.dispatch('_cell_or_alias', 'dbt_linreg')
15 | (i, j, d, prefix, isa)
16 | ) }}
17 | {% endmacro %}
18 |
19 | {% macro default___cell_or_alias(i, j, d, prefix=none, isa=none) %}
20 | {{ return(d[(i, j)]) }}
21 | {% endmacro %}
22 |
23 | {% macro snowflake___cell_or_alias(i, j, d, prefix=none, isa=none) %}
24 | {{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
25 | {% endmacro %}
26 |
27 | {% macro duckdb___cell_or_alias(i, j, d, prefix=none, isa=none) %}
28 | {{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
29 | {% endmacro %}
30 |
31 | {% macro clickhouse___cell_or_alias(i, j, d, prefix=none, isa=none) %}
32 | {{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
33 | {% endmacro %}
34 |
35 | {% macro _safe_sqrt(x, safe=True) %}
36 | {{ return(
37 | adapter.dispatch('_safe_sqrt', 'dbt_linreg')
38 | (x, safe)
39 | ) }}
40 | {% endmacro %}
41 |
42 | {% macro default___safe_sqrt(x, safe=True) %}
43 | {% if safe %}
44 | {{ return('case when ('~x~') >= 0 then sqrt('~x~') end') }}
45 | {% endif %}
46 | {{ return('sqrt('~x~')') }}
47 | {% endmacro %}
48 |
49 | {% macro bigquery___safe_sqrt(x, safe=True) %}
50 | {% if safe %}
51 | {{ return('safe.sqrt('~x~')') }}
52 | {% endif %}
53 | {{ return('sqrt('~x~')') }}
54 | {% endmacro %}
55 |
56 | {% macro _cholesky_decomposition(li, subquery_optimization=true, safe=true, isa=none) %}
57 | {% set d = {} %}
58 | {% for i in li %}
59 | {% for j in range(li[0], i + 1) %}
60 | {% if i == li[0] and j == li[0] %}
61 | {% do d.update({(i, j): dbt_linreg._safe_sqrt(x='x'~i~'x'~j, safe=safe)}) %}
62 | {% else %}
63 | {% set ns = namespace() %}
64 | {% set ns.s = 'x'~j~'x'~i %}
65 | {% for k in range(li[0], j) %}
66 | {% if subquery_optimization and i != j %}
67 | {% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, isa=isa)~'*i'~j~'j'~k %}
68 | {% else %}
69 | {% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, isa=isa)~'*'~dbt_linreg._cell_or_alias(i=j, j=k, d=d, isa=isa) %}
70 | {% endif %}
71 | {% endfor %}
72 | {% if i == j %}
73 | {% do d.update({(i, j): dbt_linreg._safe_sqrt(x=ns.s, safe=safe)}) %}
74 | {% else %}
75 | {% if safe %}
76 | {% do d.update({(i, j): '('~ns.s~')/nullif('~dbt_linreg._cell_or_alias(i=j, j=j, d=d, isa=isa) ~ ', 0)'}) %}
77 | {% else %}
78 | {% do d.update({(i, j): '('~ns.s~')/'~dbt_linreg._cell_or_alias(i=j, j=j, d=d, isa=isa)}) %}
79 | {% endif %}
80 | {% endif %}
81 | {% endif %}
82 | {% endfor %}
83 | {% endfor %}
84 | {{ return(d) }}
85 | {% endmacro %}
86 |
87 | {% macro _forward_substitution(li, safe=true, isa=none) %}
88 | {% set d = {} %}
89 | {% for i, j in modules.itertools.combinations_with_replacement(li, 2) %}
90 | {% set ns = namespace() %}
91 | {% if i == j %}
92 | {% set ns.numerator = '1' %}
93 | {% else %}
94 | {% set ns.numerator = '(' %}
95 | {% for k in range(i, j) %}
96 | {% set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, prefix="inv_", isa=isa) %}
97 | {% endfor %}
98 | {% set ns.numerator = ns.numerator~')' %}
99 | {% endif %}
100 | {% if safe %}
101 | {% do d.update({(i, j): '('~ns.numerator~'/nullif(i'~j~'j'~j~', 0))'}) %}
102 | {% else %}
103 | {% do d.update({(i, j): '('~ns.numerator~'/i'~j~'j'~j~')'}) %}
104 | {% endif %}
105 | {% endfor %}
106 | {{ return(d) }}
107 | {% endmacro %}
108 |
109 | {% macro _ols_chol(table,
110 | endog,
111 | exog,
112 | add_constant=True,
113 | output=None,
114 | output_options=None,
115 | group_by=None,
116 | alpha=None,
117 | method_options=None) -%}
118 | {%- if (exog | length) == 0 %}
119 | {% do log('Note: exog was empty; running regression on constant term only.') %}
120 | {{ return(dbt_linreg._ols_0var(
121 | table=table,
122 | endog=endog,
123 | exog=exog,
124 | add_constant=add_constant,
125 | output=output,
126 | output_options=output_options,
127 | group_by=group_by,
128 | alpha=alpha
129 | )) }}
130 | {%- endif %}
131 | {%- set subquery_optimization = dbt_linreg._get_method_option('chol', 'subquery_optimization', method_options, true) %}
132 | {%- set safe_mode = dbt_linreg._get_method_option('chol', 'safe', method_options, true) %}
133 | {% set isa = dbt_linreg._get_method_option('chol', 'intra_select_aliasing', method_options) %}
134 | {%- set calculate_standard_error = dbt_linreg._get_output_option('calculate_standard_error', output_options, (not alpha) and output == 'long') %}
135 | {%- if alpha and calculate_standard_error %}
136 | {% do log(
137 | 'Warning: Standard errors are NOT designed to take into account ridge regression regularization.'
138 | ) %}
139 | {%- endif %}
140 | {%- if add_constant %}
141 | {% set xmin = 0 %}
142 | {%- else %}
143 | {% set xmin = 1 %}
144 | {%- endif %}
145 | {%- set xcols = (range(xmin, (exog | length) + 1) | list) %}
146 | {%- set upto = (xcols | length) %}
147 | {%- set exog_aliased = dbt_linreg._alias_exog(exog) %}
148 | (with
149 | _dbt_linreg_base as (
150 | select
151 | {{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }}
152 | {{ endog }} as y,
153 | {%- if add_constant %}
154 | 1 as x0,
155 | {%- endif %}
156 | {%- for i in range(1, (exog | length) + 1) %}
157 | b.{{ exog[loop.index0] }} as x{{ i }}
158 | {%- if not loop.last -%}
159 | ,
160 | {%- endif %}
161 | {%- endfor %}
162 | from
163 | {{ table }} as b
164 | ),
165 | _dbt_linreg_xtx as (
166 | select
167 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }}
168 | {%- for i, j in modules.itertools.combinations_with_replacement(xcols, 2) %}
169 | {%- if alpha and i == j and i > 0 %}
170 | sum(b.x{{ i }} * b.x{{ j }} + {{ alpha[i-1] }}) as x{{ i }}x{{ j }}
171 | {%- else %}
172 | sum(b.x{{ i }} * b.x{{ j }}) as x{{ i }}x{{ j }}
173 | {%- endif %}
174 | {%- if not loop.last -%}
175 | ,
176 | {%- endif %}
177 | {%- endfor %}
178 | from _dbt_linreg_base as b
179 | {%- if group_by %}
180 | group by
181 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }}
182 | {%- endif %}
183 | ),
184 | _dbt_linreg_chol as (
185 |
186 | {%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_mode, isa=isa) %}
187 | {%- if subquery_optimization %}
188 | {%- for i in (xcols | reverse) %}
189 | select
190 | *,
191 | {%- for j in range(xmin, i + 1) %}
192 | {{ d[(i, j)] }} as i{{ i }}j{{ j }}
193 | {%- if not loop.last -%}
194 | ,
195 | {%- endif %}
196 | {%- endfor %}
197 | {%- if not loop.last %}
198 | from (
199 | {%- else %}
200 | from _dbt_linreg_xtx{% for close_ct in range(upto - 1) %}) as ic{{ close_ct }}{% endfor %}
201 | {%- endif %}
202 | {%- endfor %}
203 | {%- else %}
204 | select
205 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }}
206 | {%- for k, v in d.items() %}
207 | {{ v }} as {{ 'i'~k[0]~'j'~k[1] }}
208 | {%- if not loop.last -%}
209 | ,
210 | {%- endif %}
211 | {%- endfor %}
212 | from _dbt_linreg_xtx
213 | {%- endif %}
214 | ),
215 | _dbt_linreg_inverse_chol as (
216 | {#- The optimal way to calculate is to do one diagonal at a time. #}
217 | {%- set d = dbt_linreg._forward_substitution(li=xcols, safe=safe_mode, isa=isa) %}
218 | {%- if subquery_optimization %}
219 | {%- for gap in (range(0, upto) | reverse) %}
220 | select *,
221 | {%- for j in range(gap + xmin, upto + xmin) %}
222 | {%- set i = j - gap %}
223 | {{ d[(i, j)] }} as inv_i{{ i }}j{{ j }}
224 | {%- if not loop.last -%}
225 | ,
226 | {%- endif %}
227 | {%- endfor %}
228 | {%- if not loop.last %}
229 | from (
230 | {%- else %}
231 | from _dbt_linreg_chol{% for close_ct in range(upto - 1) %}) as ic{{ close_ct }}{% endfor %}
232 | {%- endif %}
233 | {%- endfor %}
234 | {%- else %}
235 | select
236 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }}
237 | {%- for k, v in d.items() %}
238 | {{ v }} as inv_{{ 'i'~k[0]~'j'~k[1] }}
239 | {%- if not loop.last -%}
240 | ,
241 | {%- endif %}
242 | {%- endfor %}
243 | from _dbt_linreg_chol
244 | {%- endif %}
245 | ),
246 | _dbt_linreg_inverse_xtx as (
247 | select
248 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }}
249 | {%- for i, j in modules.itertools.combinations_with_replacement(xcols, 2) %}
250 | {%- if not add_constant %}
251 | {%- set upto = upto + 1 %}
252 | {%- endif %}
253 | {%- for k in range(j, upto) %}
254 | inv_i{{ i }}j{{ k }} * inv_i{{ j }}j{{ k }}{%- if not loop.last %} + {% endif -%}
255 | {%- endfor %}
256 | as inv_x{{ i }}x{{ j }}
257 | {%- if not loop.last -%}
258 | ,
259 | {%- endif %}
260 | {%- endfor %}
261 | from _dbt_linreg_inverse_chol
262 | ),
263 | _dbt_linreg_final_coefs as (
264 | select
265 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True, prefix='b') | indent(4) }}
266 | {%- for x1 in xcols %}
267 | sum((
268 | {%- for x2 in xcols %}
269 | {%- if x2 > x1 %}
270 | b.x{{ x2 }} * inv_x{{ x1 }}x{{ x2 }}
271 | {%- else %}
272 | b.x{{ x2 }} * inv_x{{ x2 }}x{{ x1 }}
273 | {%- endif %}
274 | {%- if not loop.last %} + {% endif -%}
275 | {%- endfor %}
276 | ) * b.y) as x{{ x1 }}_coef
277 | {%- if not loop.last -%}
278 | ,
279 | {%- endif %}
280 | {%- endfor %}
281 | from
282 | _dbt_linreg_base as b
283 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_inverse_xtx') | indent(2) }}
284 | {%- if group_by %}
285 | group by
286 | {{ dbt_linreg._gb_cols(group_by, prefix='b') | indent(4) }}
287 | {%- endif %}
288 | ){%- if calculate_standard_error %},
289 | _dbt_linreg_resid as (
290 | select
291 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True, prefix='b') | indent(4) }}
292 | avg(pow(y
293 | {%- for x in xcols %}
294 | - x{{ x }} * x{{ x }}_coef
295 | {%- endfor %}
296 | , 2)) as resid_square_mean,
297 | count(*) as n
298 | from
299 | _dbt_linreg_base as b
300 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_final_coefs') | indent(2) }}
301 | {%- if group_by %}
302 | group by
303 | {{ dbt_linreg._gb_cols(group_by, prefix='b') | indent(2) }}
304 | {%- endif %}
305 | ),
306 | _dbt_linreg_stderrs as (
307 | select
308 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True, prefix='b') | indent(4) }}
309 | {%- for x in xcols %}
310 | sqrt(inv_x{{ x }}x{{ x }} * resid_square_mean * n / (n - {{ upto }})) as x{{ x }}_stderr
311 | {%- if not loop.last -%}
312 | ,
313 | {%- endif %}
314 | {%- endfor %}
315 | from
316 | _dbt_linreg_resid as b
317 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_inverse_xtx') | indent(2) }}
318 | )
319 | {%- endif %}
320 | {{
321 | dbt_linreg.final_select(
322 | exog=exog,
323 | exog_aliased=exog_aliased,
324 | add_constant=add_constant,
325 | group_by=group_by,
326 | output=output,
327 | output_options=output_options,
328 | calculate_standard_error=calculate_standard_error
329 | )
330 | }})
331 | {% endmacro %}
332 |
--------------------------------------------------------------------------------
/scripts.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is used for generation of CSV files for integration test cases,
3 | and also for manual verification + generation of test case values.
4 | """
5 | import json
6 | import os
7 | import os.path as op
8 | import warnings
9 | from typing import NamedTuple
10 | from typing import Optional
11 | from typing import Protocol
12 |
13 | import numpy as np
14 | import pandas as pd
15 | import rich_click as click
16 | import statsmodels.api as sm
17 | import yaml
18 | from tabulate import tabulate
19 |
20 |
21 | # Suppress iteritems warning
22 | warnings.simplefilter("ignore", category=FutureWarning)
23 |
24 | # No scientific notation
25 | np.set_printoptions(suppress=True)
26 |
27 |
28 | DIR = op.dirname(__file__)
29 |
30 | DEFAULT_SIZE = 10_000
31 | DEFAULT_SEED = 382479347
32 |
33 |
34 | class TestCase(NamedTuple):
35 | df: pd.DataFrame
36 | x_cols: list[str]
37 | y_col: str
38 | group: Optional[str] = None
39 |
40 |
41 | class TestCaseCallable(Protocol):
42 | def __call__(self, size: int, seed: int) -> TestCase:
43 | pass
44 |
45 |
46 | def gram_schmidt(df: pd.DataFrame):
47 | q = pd.DataFrame(index=df.index)
48 | for c, v in df.items():
49 | v_new = v.copy()
50 | for _, u in q.items():
51 | v_new -= u * u.dot(v) / u.dot(u)
52 | q[c] = v_new / np.linalg.norm(v_new)
53 | return q
54 |
55 |
56 | def simple_matrix(size: int = DEFAULT_SIZE, seed: int = DEFAULT_SEED) -> TestCase:
57 | # Gram Schmidt makes any matrix into the simplest test case because
58 | # orthogonalization guarantees round and predictable coefficients.
59 | #
60 | # That said, we also want to cover test cases unorthogonalized.
61 | # Otherwise, it kind of beats the point of writing all that multiple
62 | # regression logic using the FWL theorem.
63 | #
64 | # So although it is a good and clean test case, it can't be the only one.
65 | rs = np.random.RandomState(seed=seed)
66 | df = pd.DataFrame(index=range(size))
67 | df["const"] = 1
68 |
69 | coefficients = pd.Series({
70 | "const": 10,
71 | "xa": 5,
72 | "xb": 7,
73 | "xc": 9,
74 | "xd": 11,
75 | "xe": 13,
76 | "xf": 15,
77 | "xg": 17,
78 | "xh": 19,
79 | "xi": 21,
80 | "xj": 23
81 | })
82 |
83 | for c in coefficients.index:
84 | if c == "const":
85 | continue
86 | df[c] = rs.normal(0, 1, size=size)
87 | df["epsilon"] = rs.normal(0, 10, size=size)
88 |
89 | feature_cols = list(coefficients.index)
90 | x_cols = [i for i in feature_cols if i != "const"]
91 | non_const_cols = x_cols + ["epsilon"]
92 |
93 | # Center the non-constant columns
94 | # This is kinda like orthogonalizing w/r/t constant term.
95 | # So doing this here means we don't need to Gram Schmidt the constant.
96 | for c in non_const_cols:
97 | df[c] -= df[c].mean()
98 |
99 | df[non_const_cols] = gram_schmidt(df[non_const_cols])
100 |
101 | df["y"] = df[feature_cols].dot(coefficients) + df["epsilon"]
102 |
103 | return TestCase(df=df, y_col="y", x_cols=x_cols)
104 |
105 |
106 | def collinear_matrix(size: int = DEFAULT_SIZE, seed: int = DEFAULT_SEED) -> TestCase:
107 | rs = np.random.RandomState(seed=seed)
108 | df = pd.DataFrame(index=range(size))
109 | df["const"] = 1
110 | df["x1"] = 2 + rs.normal(0, 1, size=size)
111 | df["x2"] = 1 - df["x1"] + rs.normal(0, 3, size=size)
112 | df["x3"] = 3 + 2 * df["x2"] + rs.normal(0, 1, size=size)
113 | df["x4"] = -3 + 0.5 * (df["x1"] * df["x3"]) + rs.normal(0, 1, size=size)
114 | df["x5"] = 4 + 0.5 * np.sin(3 * df["x2"]) + rs.normal(0, 1, size=size)
115 | df["epsilon"] = rs.normal(0, 4, size=size)
116 |
117 | coefficients = pd.Series({
118 | "const": 20,
119 | "x1": 5,
120 | "x2": 7,
121 | "x3": 9,
122 | "x4": 11,
123 | "x5": 13
124 | })
125 |
126 | x_cols = list(coefficients.index)
127 |
128 | # coefficients will not exactly match due to OVB
129 | df["y"] = (
130 | df[coefficients.index].dot(coefficients)
131 | + (df["x3"] + np.sin(df["x1"])) ** 2
132 | + df["epsilon"]
133 | )
134 |
135 | return TestCase(df=df, y_col="y", x_cols=x_cols)
136 |
137 |
138 | def groups_matrix(size: int = DEFAULT_SIZE, seed: int = DEFAULT_SEED) -> TestCase:
139 | rs = np.random.RandomState(seed=seed)
140 | size1 = size // 2
141 | size2 = size - size1
142 |
143 | df1 = pd.DataFrame(index=range(size1))
144 | df1["gb_var"] = "a"
145 | df1["const"] = 1
146 | df1["x1"] = 2 + rs.normal(0, 1, size=size1)
147 | df1["x2"] = 1 - df1["x1"] + rs.normal(0, 3, size=size1)
148 | df1["x3"] = 3 + 2 * df1["x2"] + rs.normal(0, 1, size=size1)
149 | df1["y"] = 1 * df1["x1"] + 2 * df1["x2"] + 3 * df1["x2"] + rs.normal(0, 1, size=size1)
150 |
151 | df2 = pd.DataFrame(index=range(size2))
152 | df2["gb_var"] = "b"
153 | df2["const"] = 1
154 | df2["x1"] = 6 + rs.normal(0, 3, size=size2)
155 | df2["x2"] = 3 + df2["x1"] + rs.normal(0, 3, size=size2)
156 | df2["x3"] = -1 - df2["x2"] + rs.normal(0, 2, size=size2)
157 | df2["y"] = 2 + 3 * df2["x1"] + 4 * df2["x2"] + 5 * df2["x2"] + rs.normal(0, 1, size=size2)
158 |
159 | df = pd.concat([df1, df2], axis=0).reset_index()
160 |
161 | return TestCase(
162 | df=df,
163 | y_col="y",
164 | x_cols=["const", "x1", "x2", "x3"],
165 | group="gb_var"
166 | )
167 |
168 |
169 | ALL_TEST_CASES: dict[str, TestCaseCallable] = {
170 | "simple_matrix": simple_matrix,
171 | "collinear_matrix": collinear_matrix,
172 | "groups_matrix": groups_matrix
173 | }
174 |
175 |
176 | def click_option_seed(**kwargs):
177 | return click.option(
178 | "--seed", "-s",
179 | default=DEFAULT_SEED,
180 | show_default=True,
181 | help="Seed used to generate data.",
182 | **kwargs
183 | )
184 |
185 |
186 | def click_option_size(**kwargs):
187 | return click.option(
188 | "--size", "-n",
189 | default=DEFAULT_SIZE,
190 | show_default=True,
191 | help="Number of rows to generate.",
192 | **kwargs
193 | )
194 |
195 |
196 | @click.group("main", context_settings=dict(help_option_names=["-h", "--help"]))
197 | def cli():
198 | """CLI for manually testing the code base."""
199 |
200 |
201 | @cli.command("regress")
202 | @click.option("--table", "-t",
203 | required=True,
204 | type=click.Choice(ALL_TEST_CASES.keys()),
205 | help="Table to regress against.")
206 | @click.option("--const/--no-const",
207 | default=True,
208 | type=click.BOOL,
209 | show_default=True,
210 | help="If true, add the constant term.")
211 | @click.option("--columns", "-c",
212 | default=None,
213 | type=click.INT,
214 | show_default=True,
215 | help="Number of columns to regress.")
216 | @click.option("--alpha", "-a",
217 | default=None,
218 | type=click.FLOAT,
219 | show_default=True,
220 | help="Alpha for the regression.")
221 | @click_option_size()
222 | @click_option_seed()
223 | def regress(table: str, const: bool, columns: int, alpha: float, size: int, seed: int):
224 | """
225 | Run regression on integration test cases.
226 |
227 | Use me for either manual verification of test cases, or for generating new
228 | test cases. (All numeric values for test cases were generated using this
229 | CLI.)
230 | """
231 | callback = ALL_TEST_CASES[table]
232 |
233 | click.echo(click.style("=" * 80, fg="blue"))
234 | click.echo(
235 | click.style("Test case: ", fg="blue", bold=True)
236 | +
237 | click.style(table, fg="blue")
238 | )
239 | click.echo(click.style("=" * 80, fg="blue"))
240 |
241 | test_case = callback(size, seed)
242 |
243 | if columns is None:
244 | x_cols = test_case.x_cols
245 | else:
246 | # K plus Constant (1)
247 | x_cols = test_case.x_cols[:columns+1]
248 |
249 | if not const:
250 | x_cols = [i for i in x_cols if i != "const"]
251 |
252 | def _run_model(cond=None):
253 | if cond is None:
254 | cond = slice(None)
255 | y = test_case.df.loc[cond, test_case.y_col]
256 | x_mat = test_case.df.loc[cond, x_cols]
257 | if alpha:
258 | if const:
259 | alpha_arr = [0, *([alpha] * (len(x_mat.columns) - 1))]
260 | else:
261 | alpha_arr = [alpha] * len(x_mat.columns)
262 | model = sm.OLS(
263 | y,
264 | x_mat
265 | ).fit_regularized(L1_wt=0, alpha=alpha_arr)
266 | else:
267 | model = sm.OLS(y, x_mat).fit()
268 | res_df = pd.DataFrame(index=x_mat.columns)
269 | res_df["coef"] = model.params
270 | res_df["stderr"] = model.bse
271 | res_df["tstat"] = res_df["coef"] / res_df["stderr"]
272 | click.echo(
273 | tabulate(
274 | res_df,
275 | headers=["column name", "coef", "stderr", "tstat"],
276 | disable_numparse=True,
277 | tablefmt="psql",
278 | )
279 | )
280 |
281 | if test_case.group:
282 | for c in test_case.df[test_case.group].unique():
283 | click.echo(click.style(f"{test_case.group} - {c}", fg="green"))
284 | _run_model(cond=(test_case.df[test_case.group] == c))
285 | else:
286 | _run_model()
287 |
288 |
289 | def echo_table_name(s: str):
290 | click.echo(click.style("=" * 80, fg="green"))
291 | click.echo(
292 | click.style("Table: ", fg="green", bold=True)
293 | +
294 | click.style(s, fg="green")
295 | )
296 | click.echo(click.style("=" * 80, fg="green"))
297 |
298 |
299 | @cli.command("gen-test-cases")
300 | @click.option("--table", "-t", "tables",
301 | multiple=True,
302 | default=None,
303 | show_default=True,
304 | help="Generate a specific table. If None, generate all tables.")
305 | @click_option_size()
306 | @click_option_seed()
307 | @click.option("--skip-if-exists", is_flag=True,
308 | help="Skip if the file exists. Otherwise, overwrite.")
309 | def gen_test_cases(tables: list[str], size: int, seed: int, skip_if_exists: bool):
310 | """Generate integration test cases (CSV files)."""
311 | if not tables:
312 | tables = ALL_TEST_CASES
313 | for table_name in tables:
314 | file_name = f"{DIR}/integration_tests/seeds/{table_name}.csv"
315 | if skip_if_exists and op.exists(file_name):
316 | click.echo("File " + click.style(file_name, fg="blue") + " already exists; skipping.")
317 | continue
318 |
319 | callback = ALL_TEST_CASES[table_name]
320 |
321 | echo_table_name(table_name)
322 |
323 | test_case = callback(size, seed)
324 | y = test_case.df[test_case.y_col]
325 | x_mat = test_case.df[test_case.x_cols]
326 |
327 | click.echo()
328 | li = []
329 | for i in range(1, len(x_mat.columns) + 1):
330 | model = sm.OLS(
331 | y,
332 | sm.add_constant(x_mat.iloc[:, :i])
333 | ).fit()
334 | params = model.params.rename(f"{i}-var").reindex(x_mat.columns)
335 | params = params.apply(
336 | lambda s: "{:.5f}".format(s)
337 | if pd.notna(s)
338 | else None
339 | )
340 | expand_by = params.apply(lambda s: len(s) if s is not None else 0).max()
341 | params = params.where(
342 | pd.notna(params),
343 | click.style("-" * expand_by, fg="bright_black")
344 | )
345 |
346 | params.apply(lambda s: len(s)).max()
347 | li.append(params)
348 | coefs = pd.concat(li, axis=1)
349 | click.echo(
350 | tabulate(
351 | coefs,
352 | headers=coefs.columns,
353 | disable_numparse=True,
354 | tablefmt="psql",
355 | )
356 | )
357 |
358 |
359 | all_cols = [test_case.y_col, *test_case.x_cols]
360 | if test_case.group:
361 | all_cols.append(test_case.group)
362 |
363 | test_case.df[all_cols].to_csv(file_name, index=False)
364 | click.echo(
365 | click.style(f"Wrote DataFrame to file {file_name!r}", fg="yellow")
366 | )
367 | click.echo("")
368 | click.echo(click.style("Done!", fg="green"))
369 |
370 |
371 | @cli.command("gen-hide-macros-yaml")
372 | @click.option("--parse/--no-parse", is_flag=True, default=True)
373 | def gen_hide_args_yaml(parse: bool) -> None:
374 | """Generates the YAML that hides the macros from the docs.
375 |
376 | Requires the `manifest.json` to be available.
377 | (`dbt parse --profiles-dir ./integration_tests/profiles`)
378 |
379 | Recommended to `| pbcopy` this command, then paste in `macros/schema.yml`.
380 |
381 | This is not enforced during CICD, beware!
382 | """
383 |
384 | if parse:
385 | from dbt.cli.main import dbtRunner
386 | os.environ["DO_NOT_TRACK"] = "1"
387 | dbtRunner().invoke(
388 | [
389 | "parse",
390 | "--profiles-dir", op.join(op.dirname(__file__), "integration_tests", "profiles"),
391 | "--project-dir", op.dirname(__file__)
392 | ]
393 | )
394 |
395 | exclude_from_hiding = ["ols"]
396 | with open("target/manifest.json") as f:
397 | manifest = json.load(f)
398 |
399 | macros = [
400 | data["name"] for fqn, data
401 | in manifest["macros"].items()
402 | if data.get("package_name", "") == "dbt_linreg"
403 | and data.get("name") not in exclude_from_hiding
404 | ]
405 |
406 | out = [
407 | {"name": macro, "docs": {"show": False}}
408 | for macro in sorted(macros)
409 | ]
410 |
411 | print(" " + yaml.safe_dump(out, sort_keys=False).replace("\n", "\n "))
412 |
413 |
414 | if __name__ == "__main__":
415 | cli()
416 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | Linear regression in any SQL dialect, powered by dbt.
9 |
10 |
11 |
12 |
13 |
14 |
15 | # Overview
16 |
17 | **dbt_linreg** is an easy way to perform linear regression and ridge regression in SQL (Snowflake, DuckDB, Clickhouse, and more) with OLS using dbt's Jinja2 templating.
18 |
19 | Reasons to use **dbt_linreg**:
20 |
21 | - 📈 **Linear regression in pure SQL:** With the power of Jinja2 templating and some behind-the-scenes math tricks, it is possible to implement multiple and multivariate regression in pure SQL. Most SQL engines (even OLAP engines) do not have a multiple regression implementation of OLS, so this fills a valuable niche. **`dbt_linreg` implements true OLS, not an approximation!**
22 | - 📱 **Simple interface:** Just define a `table=` (which works with `ref()`, `source()`, and CTEs), a y-variable with `endog=`, your x-variables in a list with `exog=...`, and you're all set! Note that the API is loosely based on Statsmodels's naming conventions.
23 | - 🤖 **Support for ridge regression:** Just pass in `alpha=scalar` or `alpha=[scalar1, scalar2, ...]` to regularize your regressions. (Note: regressors are not automatically standardized.)
24 | - 🤸 **Flexibility:** Tons of formatting options available to return coefficients the way you want.
25 | - 🤗 **User friendly:** The API provides comprehensive feedback on input errors.
26 | - 💪 **Durable and tested:** Everything in this code base is tested against equivalent regressions performed in Statsmodels with high-precision assertions (between 10e-6 and 10e-7, depending on the database engine).
27 |
28 | _Note: If you enjoy this project, you may also enjoy my other dbt machine learning project, [**dbt_pca**](https://github.com/dwreeves/dbt_pca)._ 😊
29 |
30 | # Installation
31 |
32 | dbt-core `>=1.2.0` is required to install `dbt_linreg`.
33 |
34 | Add this to the `packages:` list in your dbt project's `packages.yml`:
35 |
36 | ```yaml
37 | - package: "dwreeves/dbt_linreg"
38 | version: "0.3.1"
39 | ```
40 |
41 | The full file will look something like this:
42 |
43 | ```yaml
44 | packages:
45 | # ...
46 | # Other packages here
47 | # ...
48 | - package: "dwreeves/dbt_linreg"
49 | version: "0.3.1"
50 | ```
51 |
52 | # Examples
53 |
54 | ### Simple example
55 |
56 | The following example runs a linear regression of `y` on 3 columns, `xa + xb + xc`, using data in the dbt model named `simple_matrix`. It outputs the data in "long" format, and rounds the coefficients to 5 decimal places:
57 |
58 | ```sql
59 | {{
60 | config(
61 | materialized="table"
62 | )
63 | }}
64 | select * from {{
65 | dbt_linreg.ols(
66 | table=ref('simple_matrix'),
67 | endog='y',
68 | exog=['xa', 'xb', 'xc'],
69 | output='long',
70 | output_options={'round': 5}
71 | )
72 | }} as linreg
73 | ```
74 |
75 | Output:
76 |
77 | |variable_name|coefficient|standard_error|t_statistic|
78 | |---|---|---|---|
79 | |const|10.0|0.00462|2163.27883|
80 | |xa|5.0|0.46226|10.81639|
81 | |xb|7.0|0.46226|15.14295|
82 | |xc|9.0|0.46226|19.46951|
83 |
84 | Note: `simple_matrix` is one of the test cases, so you can try this yourself! Standard errors are constant across `xa`, `xb`, `xc`, because `simple_matrix` is orthonormal.
85 |
86 | ### Complex example
87 |
88 | The following hypothetical example shows multiple ridge regressions (one per `product_id`) on a table that is preprocessed substantially. After the fact, predictions are run, and the R-squared of each regression is calculated at the end.
89 |
90 | This example shows that, although `dbt_linreg` does not implement everything for you, the OLS implementation does most of the hard work. This gives you the freedom to do things you've never been able to do before in SQL!
91 |
92 | ```sql
93 | {{
94 | config(
95 | materialized="table"
96 | )
97 | }}
98 | with
99 |
100 | preprocessed_data as (
101 |
102 | select
103 | product_id,
104 | price,
105 | log(price) as log_price,
106 | epoch(time) as t,
107 | sin(epoch(time)*pi()*2 / (60*60*24*365)) as sin_t,
108 | cos(epoch(time)*pi()*2 / (60*60*24*365)) as cos_t
109 | from
110 | {{ ref('prices') }}
111 |
112 | ),
113 |
114 | preprocessed_and_normalized_data as (
115 |
116 | select
117 | product_id,
118 | price,
119 | log_price,
120 | (t - avg(t) over ()) / (stddev(t) over ()) as t_norm,
121 | (sin_t - avg(sin_t) over ()) / (stddev(sin_t) over ()) as sin_t_norm,
122 | (cos_t - avg(cos_t) over ()) / (stddev(cos_t) over ()) as cos_t_norm
123 | from
124 | preprocessed_data
125 |
126 | ),
127 |
128 | coefficients as (
129 |
130 | select * from {{
131 | dbt_linreg.ols(
132 | table='preprocessed_and_normalized_data',
133 | endog='log_price',
134 | exog=['t_norm', 'sin_t_norm', 'cos_t_norm'],
135 | group_by=['product_id'],
136 | alpha=0.0001
137 | )
138 | }}
139 |
140 | ),
141 |
142 | predict as (
143 |
144 | select
145 | d.product_id,
147 | d.price,
148 | exp(
149 | c.const
150 | + d.t_norm * c.t_norm
151 | + d.sin_t_norm * c.sin_t_norm
152 | + d.cos_t_norm * c.cos_t_norm) as predicted_price
153 | from
154 | preprocessed_and_normalized_data as d
155 | join
156 | coefficients as c
157 | on
158 | d.product_id = c.product_id
159 |
160 | )
161 |
162 | select
163 | product_id,
164 | pow(corr(predicted_price, price), 2) as r_squared
165 | from
166 | predict
167 | group by
168 | product_id
169 | ```
170 |
171 | # Supported Databases
172 |
173 | **dbt_linreg** should work with most SQL databases, but so far, testing has been done for the following database tools:
174 |
175 | | Database | Supported | Precision asserted in CI\* | Supported since version |
176 | |----------------|-----------|----------------------------|-------------------------|
177 | | **Snowflake** | ✅ | _n/a_ | 0.1.0 |
178 | | **DuckDB** | ✅ | 10e-7 | 0.1.0 |
179 | | **Postgres**† | ✅ | 10e-7 | 0.2.3 |
180 | | **Redshift** | ✅ | _n/a_ | 0.2.4 |
181 | | **Clickhouse** | ✅ | 10e-6 | 0.3.0 |
182 |
183 | If **dbt_linreg** does not work in your database tool, please let me know in a bug report.
184 |
185 | > _\* Precision is for test cases using the **collinear_matrix** for unregularized regressions, in comparison to the output of the same regression in the Python package Statsmodels using `sm.OLS().fit(method="pinv")`. For example, coefficients for unregularized regressions performed in DuckDB are asserted to be within 10e-7 of Statsmodels._
186 |
187 | > _† Minimal support for Postgres. Postgres is syntactically supported, but is not performant under certain circumstances._
188 |
189 | # API
190 |
191 | The only function available in the public API is the `dbt_linreg.ols()` macro.
192 |
193 | Using Python typing notation, the full API for `dbt_linreg.ols()` looks like this:
194 |
195 | ```python
196 | def ols(
197 | table: str,
198 | endog: str,
199 | exog: Union[str, list[str]],
200 | add_constant: bool = True,
201 | output: Literal['wide', 'long'] = 'wide',
202 | output_options: Optional[dict[str, Any]] = None,
203 | group_by: Optional[Union[str, list[str]]] = None,
204 | alpha: Optional[Union[float, list[float]]] = None,
205 | method: Literal['chol', 'fwl'] = 'chol',
206 | method_options: Optional[dict[str, Any]] = None
207 | ):
208 | ...
209 | ```
210 |
211 | Where:
212 |
213 | - **table**: Name of table or CTE to pull the data from. You can use `ref()` or `source()` here if you'd like.
214 | - **endog**: The endogenous variable / y variable / target variable of the regression. (You can also specify `y=...` instead of `endog=...` if you prefer.)
215 | - **exog**: The exogenous variables / X variables / features of the regression. (You can also specify `x=...` instead of `exog=...` if you prefer.)
216 | - **add_constant**: If true, a constant term is added automatically to the regression.
217 | - **output**: Either "wide" or "long" output format for coefficients. See **Outputs and output options** for more.
218 | - If `wide`, the variables span the columns with their original variable names, and the coefficients fill a single row.
219 | - If `long`, the coefficients are in a single column called `coefficient`, and the variable names are in a single column called `variable_name`.
220 | - **output_options**: See **Outputs and output options** section for more.
221 | - **group_by**: If specified, the regression will be grouped by these variables, and individual regressions will run on each group.
222 | - **alpha**: If not null, the regression will be run as a ridge regression with a penalty of `alpha`. See **Notes** section for more information.
223 | - **method**: The method used to calculate the regression. See **Methods and method options** for more.
224 | - **method_options**: Options specific to the estimation method. See **Methods and method options** for more.
225 |
226 | # Outputs and output options
227 |
228 | Outputs can be returned either in `output='long'` or `output='wide'`.
229 |
230 | All outputs have their own output options, which can be passed into the `output_options=` arg as a dict, e.g. `output_options={'foo': 'bar'}`.
231 |
232 | `output=` and `output_options=` were formerly named `format=` and `format_options=` respectively.
233 | The old names have been deprecated to make **dbt_linreg**'s API more consistent with **dbt_pca**'s API.
234 |
235 | ### Options for `output='long'`
236 |
237 | - **round** (`int`; default = `None`): If not None, round all coefficients to `round` number of digits.
238 | - **constant_name** (`string`; default = `'const'`): String name that refers to constant term.
239 | - **variable_column_name** (`string`; default = `'variable_name'`): Column name storing strings of variable names.
240 | - **coefficient_column_name** (`string`; default = `'coefficient'`): Column name storing model coefficients.
241 | - **strip_quotes** (`bool`; default = `True`): If true, strip outer quotes from column names if provided; if false, always use string literals.
242 |
243 | These options are available for `output='long'` only when `method='chol'`:
244 |
245 | - **calculate_standard_error** (`bool`; default = `True if not alpha else False`): If true, provide the standard error in the output.
246 | - **standard_error_column_name** (`string`; default = `'standard_error'`): Column name storing the standard error for the parameter.
247 | - **t_statistic_column_name** (`string`; default = `'t_statistic'`): Column name storing the t-statistic for the parameter.
248 |
249 | ### Options for `output='wide'`
250 |
251 | - **round** (`int`; default = `None`): If not None, round all coefficients to `round` number of digits.
252 | - **constant_name** (`string`; default = `'const'`): String name that refers to constant term.
253 | - **variable_column_prefix** (`string`; default = `None`): If not None, prefix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)
254 | - **variable_column_suffix** (`string`; default = `None`): If not None, suffix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)
255 |
256 | ## Setting output options globally
257 |
258 | Output options can be set globally via `vars`, e.g. in your `dbt_project.yml`:
259 |
260 | ```yaml
261 | # dbt_project.yml
262 | vars:
263 | dbt_linreg:
264 | output_options:
265 | round: 5
266 | ```
267 |
268 | Output options passed via `ols()` always take precedence over globally set output options.
269 |
270 | # Methods and method options
271 |
272 | There are currently two valid methods for calculating regression coefficients:
273 |
274 | - `chol`: Uses Cholesky decomposition to calculate the pseudo-inverse.
275 | - `fwl`: Uses a "Frisch-Waugh-Lovell" approach, which consists of calculating univariate regressions to get multiple regression coefficients.
276 |
277 | ## `chol` method
278 |
279 | **👍 This is the suggested method (and the default) for calculating regressions!**
280 |
281 | This method calculates regression coefficients using the Moore-Penrose pseudo-inverse, and the inverse of **X'X** is calculated using Cholesky decomposition, hence it is referred to as `chol`.
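
To make the linear algebra concrete, here is a minimal NumPy sketch of the same idea outside of SQL. It is an illustration under stated assumptions, not the macro's generated query: the helper names `forward_substitution_inverse` and `ols_chol` are hypothetical, and the input `x` is assumed to already contain a constant column if one is wanted.

```python
import numpy as np


def forward_substitution_inverse(lower: np.ndarray) -> np.ndarray:
    """Invert a lower-triangular matrix L by forward substitution."""
    k = lower.shape[0]
    inv = np.zeros_like(lower, dtype=float)
    for j in range(k):
        inv[j, j] = 1.0 / lower[j, j]
        for i in range(j + 1, k):
            # Row i of L times column j of L^-1 must equal 0 when i > j.
            inv[i, j] = -(lower[i, j:i] @ inv[j:i, j]) / lower[i, i]
    return inv


def ols_chol(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (coefficients, standard errors) for y ~ x (hypothetical helper)."""
    xtx = x.T @ x
    lower = np.linalg.cholesky(xtx)          # X'X = L L'
    lower_inv = forward_substitution_inverse(lower)
    xtx_inv = lower_inv.T @ lower_inv        # (L L')^-1 = L'^-1 L^-1
    coefs = xtx_inv @ (x.T @ y)              # b = (X'X)^-1 X'y
    resid = y - x @ coefs
    n, k = x.shape
    sigma2 = resid @ resid / (n - k)         # residual variance with n - k degrees of freedom
    stderrs = np.sqrt(np.diag(xtx_inv) * sigma2)
    return coefs, stderrs
```

If **X** is perfectly multicollinear, `np.linalg.cholesky` raises an error (or a diagonal entry of `L` is zero), which is roughly the situation the `safe` option below is meant to handle in SQL.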
282 |
283 | ### Options for `method='chol'`
284 |
285 | Specify these in a dict using the `method_options=` kwarg:
286 |
287 | - **safe** (`bool`; default: `True`): If True, returns null coefficients instead of an error when X is perfectly multicollinear. If False, a negative value may be passed into a SQRT() or a divide by zero may occur, and most SQL engines will raise an error when this happens.
288 | - **subquery_optimization** (`bool`; default = `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened.
289 | - **intra_select_aliasing** (`bool`; default = `[depends on db]`): If True, within a single select statement, column aliases are used to refer to other columns created during that select. This can significantly reduce the text of a SQL query, but not all SQL engines support this. By default, for all databases officially supported by **dbt_linreg**, the best option is already selected. For unsupported databases, the default is `False` for broad compatibility, so if you are running **dbt_linreg** in an officially unsupported database engine which supports this feature, you may want to modify this option globally in your `vars` to be `true`.
290 |
291 | ## `fwl` method
292 |
293 | **This method is generally not recommended.**
294 |
295 | The coefficient of a simple univariate regression is just `covar_pop(y, x) / var_pop(x)`.
296 |
297 | The multiple regression implementation uses a technique described in section `3.2.3 Multiple Regression from Simple Univariate Regression` of TEoSL ([source](https://hastie.su.domains/Papers/ESLII.pdf#page=71)). Econometricians know this as the Frisch-Waugh-Lovell theorem, hence the method is referred to as `fwl` internally in the code base.
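
As a rough illustration of the theorem (not necessarily the exact sequence of steps the macro generates), the NumPy sketch below recovers each multiple regression slope using only univariate projections. The function names `univariate_coef`, `residualize`, and `fwl_coefs` are hypothetical, and a constant term is assumed.

```python
import numpy as np


def univariate_coef(y: np.ndarray, x: np.ndarray) -> float:
    """Slope of a simple regression with intercept: covar_pop(y, x) / var_pop(x)."""
    return float(np.mean((x - x.mean()) * (y - y.mean())) / np.var(x))


def residualize(target: np.ndarray, others: list[np.ndarray]) -> np.ndarray:
    """Residual of `target` after regressing it on a constant and `others`.

    Uses successive orthogonalization: every step is a univariate projection.
    """
    basis: list[np.ndarray] = []
    for col in others:
        z = col - col.mean()                 # orthogonalize against the constant
        for b in basis:
            z = z - b * (b @ z) / (b @ b)    # univariate regression of z on b
        basis.append(z)
    r = target - target.mean()
    for b in basis:
        r = r - b * (b @ r) / (b @ b)
    return r


def fwl_coefs(y: np.ndarray, columns: list[np.ndarray]) -> list[float]:
    """Multiple regression slopes of y on `columns` (constant included implicitly)."""
    coefs = []
    for k, col in enumerate(columns):
        others = [c for i, c in enumerate(columns) if i != k]
        # FWL: the slope on x_k equals the univariate slope of y on the part
        # of x_k that the other regressors cannot explain.
        coefs.append(univariate_coef(y, residualize(col, others)))
    return coefs
```

The intercept can then be recovered as the mean of `y` minus the sum of each slope times the mean of its column.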
298 |
299 | Ridge regression is implemented using the augmentation technique described in Exercise 12 of Chapter 3 of TEoSL ([source](https://hastie.su.domains/Papers/ESLII.pdf#page=115)).
300 |
301 | There are a few reasons why this method is discouraged over the `chol` method:
302 |
303 | - 🐌 It tends to be much slower in OLAP systems, and struggles to handle a large number of columns efficiently.
304 | - 📊 It does not calculate standard errors.
305 | - 😕 For ridge regression, coefficients are not accurate; they tend to be off by a magnitude of ~0.01%.
306 | - ⚠️ It does not work in all databases because it relies on `COVAR_POP`.
307 |
308 | So when should you use `fwl`? The main use case is in OLTP systems (e.g. Postgres) for unregularized coefficient estimation. Long story short, the `chol` method relies on subquery optimization to be more performant than `fwl`; however, OLTP systems do not benefit at all from subquery optimization. This means that `fwl` is slightly more performant in this context.
309 |
310 | ## Setting method options globally
311 |
312 | Method options can be set globally via `vars`, e.g. in your `dbt_project.yml`. Each `method` gets its own config, e.g. the `dbt_linreg: method_options: chol: ...` namespace only applies to the `chol` method. Here is an example:
313 |
314 | ```yaml
315 | # dbt_project.yml
316 | vars:
317 | dbt_linreg:
318 | method_options:
319 | chol:
320 | intra_select_aliasing: true
321 | ```
322 |
323 | Method options passed via `ols()` always take precedence over globally set method options.
324 |
325 | # Notes
326 |
327 | - ⚠️ **If your coefficients are null, it does not mean dbt_linreg is broken; it most likely means your feature columns are perfectly multicollinear.** If you are 100% sure that is not the issue, please file a bug report with a minimal reproducible example.
328 |
329 | - Regularization is implemented using nearly the same approach as Statsmodels; the only difference is that the constant term can never be regularized. This means:
330 | - A scalar input (e.g. `alpha=0.01`) will apply an alpha of `0.01` to all features.
331 | - An array input (e.g. `alpha=[0.01, 0.02, 0.03, 0.04, 0.05]`) will apply an alpha of `0.01` to the first column, `0.02` to the second column, etc.
332 | - `alpha` is equivalent to what TEoSL refers to as "lambda," times the sample size N. That is to say: `α ≡ λ * N`.
333 | - (Of course, you can regularize the constant term by DIYing your own constant term and doing `add_constant=false`.)
334 |
335 | - Regularization as currently implemented for the `chol` method tends to be very slow in OLTP systems (e.g. Postgres), but is very performant in OLAP systems (e.g. Snowflake, DuckDB, BigQuery, Redshift). As dbt is more commonly used in OLAP contexts, the code base is optimized for the OLAP use case.
336 | - That said, it may be possible to make regularization in OLTP more performant (e.g. with augmentation of the design matrix), so PRs are welcome.
337 |
338 | - Regression coefficients in Postgres are always `numeric` types.
339 |
340 | ## Possible future features
341 |
342 | Some things that could happen in the future:
343 |
344 | - Weighted least squares (WLS)
345 | - Efficient multivariate regression (i.e. multiple endogenous vectors sharing a single design matrix)
346 | - P-values
347 | - Heteroskedasticity robust standard errors
348 | - Recursive CTE implementations + long formatted inputs
349 |
350 | Note that although I maintain this library (as of writing in 2025), I do not actively update it much with new features, so the items on this wish list are unlikely to be implemented unless I personally need them or someone else contributes them.
351 |
352 | # FAQ
353 |
354 | ### How does this work?
355 |
356 | See the **Methods and method options** section for a full breakdown of each linear regression implementation.
357 |
358 | All approaches were validated using Statsmodels `sm.OLS()`.
359 |
360 | ### BigQuery (or other database) has linear regression implemented natively. Why should I use `dbt_linreg` over that?
361 |
362 | You don't have to use this. Most warehouses don't support multiple regression out of the box, so this package fills a niche for the databases that don't.
363 |
364 | That said, even in BigQuery, it may be very useful to extract coefficients within a query instead of generating a separate `MODEL` object through a DDL statement, for a few reasons. Even in more black-box predictive contexts, being able to predict in the same `SELECT` statement as training can be convenient. Additionally, BigQuery does not expose model coefficients to users, and this can be a dealbreaker in many contexts where you care about your coefficients as measurements, not as predictive model parameters. Lastly, `group_by` is akin to estimating parameters for multiple linear regressions at once.
365 |
366 | Overall, I would say this is pretty different from what BigQuery's `CREATE MODEL` is doing; use whatever makes sense for your use case! But keep in mind that for large numbers of variables, a native implementation of linear regression will be noticeably more efficient than this implementation.
367 |
368 | ### Why is L2 regularization / ridge regression supported, but not L1 regularization / LASSO?
369 |
370 | There is no closed-form solution to L1 regularization, which makes it very hard to implement in raw SQL. L2 regularization has a closed-form solution and can be implemented using a pre-processing trick.
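
For reference, the closed form mentioned above can be written out as follows. This is the textbook ridge estimator rather than a description of this package's internals, and $\kappa$ is a generic nonnegative penalty weight introduced only for this illustration:

$$
\hat{\beta}_{\text{ridge}}
= \arg\min_{\beta} \; \lVert y - X\beta \rVert_2^2 + \kappa \lVert \beta \rVert_2^2
= (X^\top X + \kappa I)^{-1} X^\top y
$$

One standard pre-processing formulation appends rows to the data so that plain OLS on the augmented data gives the same estimate:

$$
\tilde{X} = \begin{bmatrix} X \\ \sqrt{\kappa}\, I \end{bmatrix},
\quad
\tilde{y} = \begin{bmatrix} y \\ 0 \end{bmatrix},
\quad
\tilde{X}^\top \tilde{X} = X^\top X + \kappa I,
\quad
\tilde{X}^\top \tilde{y} = X^\top y
$$

The L1 penalty $\kappa \lVert \beta \rVert_1$ is not differentiable at zero, so no analogous formula exists and LASSO is typically fit iteratively.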
371 |
372 | ### Is the `group_by=[...]` argument like categorical variables / one-hot encodings?
373 |
374 | No. The `group_by` argument runs a separate linear regression within each group, and each individual partition is its own `y` vector and `X` matrix. This is _not_ a replacement for dummy variables.
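
A minimal sketch of the difference, assuming a hypothetical `sales_by_store` model with `store_id`, `revenue`, and `ad_spend` columns (none of these names come from the package):

```sql
-- Hypothetical model: sales_by_store, store_id, revenue and ad_spend are placeholders.
-- One regression of revenue on ad_spend is fit per distinct store_id.
select * from {{
  dbt_linreg.ols(
    table=ref('sales_by_store'),
    endog='revenue',
    exog=['ad_spend'],
    group_by=['store_id'],
    output='long'
  )
}} as linreg
```

Here every `store_id` gets its own constant and its own `ad_spend` coefficient. A regression with store dummy variables would instead fit a single `ad_spend` coefficient shared across all stores and shift only the intercept per store.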
375 |
376 | ### Why aren't categorical variables / one-hot encodings supported?
377 |
378 | I opt to leave out dummy variable support because it's tricky, and I want to keep the API clean while I mull over how best to implement it at the highest level.
379 |
380 | Note that you couldn't simply add categorical variables in the same list as numeric variables because Jinja2 templating is not natively aware of the types you're feeding through it, nor does Jinja2 know the values that a string variable can take on. The way you would actually implement categorical variables is with `group by` trickery (i.e. center both y and X by categorical variable group means), although I am not sure how to do that efficiently for more than one categorical variable column.
381 |
382 | If you'd like to regress on a categorical variable, for now you'll need to do your own feature engineering, e.g. `(foo = 'bar')::int as foo_bar, (foo = 'baz')::int as foo_baz`.
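
As a sketch of that feature engineering step, an upstream model could build the indicator columns before the regression runs. The model name `my_features` and the columns `y`, `x1`, and `foo` are hypothetical:

```sql
-- models/my_features_encoded.sql (hypothetical file name)
-- Builds 0/1 indicator columns for two levels of the categorical column foo.
select
  y,
  x1,
  (foo = 'bar')::int as foo_bar,
  (foo = 'baz')::int as foo_baz
from {{ ref('my_features') }}
```

A downstream model would then pass `table=ref('my_features_encoded')` and `exog=['x1', 'foo_bar', 'foo_baz']` to `dbt_linreg.ols`. If `foo` only takes the values `'bar'` and `'baz'` and a constant is included, keep just one of the two indicators; including both makes the columns perfectly multicollinear, which leads to null coefficients. The `::int` cast is not portable to every warehouse, so some dialects will need `cast(... as ...)` instead.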
383 |
384 | ### Why are there no p-values?
385 |
386 | This is something that might happen in the future. P-values would require a lookup on a dimension table, which is a significant amount of work to manage nicely.
387 |
388 | In the meantime, you can implement this yourself: create a dimension table of t-statistic intervals, then left join your t-statistics to it on a half-open interval to look up a p-value.
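
A rough sketch of such a lookup is below. Every name in it is an assumption for illustration: `linreg_results` stands for a model holding long-format output with a `t_statistic` column, and `t_to_p_lookup` stands for a seed you would create yourself with `t_lower`, `t_upper`, and `p_value` columns.

```sql
-- Hypothetical names: linreg_results (long-format regression output) and
-- t_to_p_lookup (a user-built seed of half-open intervals [t_lower, t_upper)).
select
  r.*,
  l.p_value
from {{ ref('linreg_results') }} as r
left join {{ ref('t_to_p_lookup') }} as l
  on abs(r.t_statistic) >= l.t_lower
  and abs(r.t_statistic) < l.t_upper
```

The half-open interval ensures each t-statistic matches at most one row, provided the intervals in the lookup table do not overlap. The last interval needs a sufficiently large `t_upper` so that extreme t-statistics still match. A lookup keyed only on the t-statistic implicitly assumes one fixed reference distribution (for example, the normal approximation for large samples); small samples would also need a degrees-of-freedom column in the join.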
389 |
390 | # Trademark & Copyright
391 |
392 | dbt is a trademark of dbt Labs.
393 |
394 | This package is **unaffiliated** with dbt Labs.
395 |
--------------------------------------------------------------------------------