├── .idea ├── dbt_linreg.iml └── .gitignore ├── .python-version ├── integration_tests ├── seeds │ └── .gitkeep ├── packages.yml ├── tests │ ├── test_perfectly_multicollinear_model.sql │ ├── test_simple_0var_regression_wide.sql │ ├── test_wide_format_options.sql │ ├── test_simple_1var_regression_wide.sql │ ├── test_simple_0var_regression_long_fwl.sql │ ├── test_simple_0var_regression_long_chol.sql │ ├── test_simple_1var_regression_long_chol.sql │ ├── test_simple_1var_regression_long_fwl.sql │ ├── test_simple_2var_regression_wide.sql │ ├── test_simple_2var_regression_long.sql │ ├── test_simple_3var_regression_long.sql │ ├── test_collinear_matrix_1var_without_const_ridge.sql │ ├── test_simple_3var_regression_wide.sql │ ├── test_simple_4var_regression_long.sql │ ├── test_simple_5var_regression_long.sql │ ├── test_collinear_matrix_1var_without_const.sql │ ├── test_simple_4var_regression_wide.sql │ ├── test_collinear_matrix_2var_without_const.sql │ ├── test_collinear_matrix_ridge_regression_fwl.sql │ ├── test_simple_5var_regression_wide.sql │ ├── test_collinear_matrix_5var_without_const_ridge.sql │ ├── test_collinear_matrix_ridge_regression_chol.sql │ ├── test_collinear_matrix_regression_fwl.sql │ ├── test_collinear_matrix_ridge_regression_chol_unoptimized.sql │ ├── test_collinear_matrix_3var_without_const.sql │ ├── test_long_format_options.sql │ ├── test_collinear_matrix_4var_without_const.sql │ ├── test_collinear_matrix_5var_without_const.sql │ ├── test_groups_matrix_regression_fwl.sql │ ├── test_collinear_matrix_regression_chol.sql │ ├── test_collinear_matrix_regression_chol_unoptimized.sql │ ├── test_groups_matrix_regression_chol_optimized.sql │ └── test_groups_matrix_regression_chol_unoptimized.sql ├── models │ ├── simple_0var_regression_wide.sql │ ├── collinear_matrix_1var_without_const.sql │ ├── simple_0var_regression_long_chol.sql │ ├── simple_0var_regression_long_fwl.sql │ ├── simple_1var_regression_long_fwl.sql │ ├── simple_1var_regression_wide.sql │ ├── 
collinear_matrix_2var_without_const.sql │ ├── simple_1var_regression_long_chol.sql │ ├── simple_2var_regression_long.sql │ ├── simple_2var_regression_wide.sql │ ├── collinear_matrix_3var_without_const.sql │ ├── simple_3var_regression_long.sql │ ├── collinear_matrix_regression_fwl.sql │ ├── simple_4var_regression_long.sql │ ├── collinear_matrix_4var_without_const.sql │ ├── collinear_matrix_1var_without_const_ridge.sql │ ├── collinear_matrix_5var_without_const.sql │ ├── collinear_matrix_ridge_regression_fwl.sql │ ├── simple_3var_regression_wide.sql │ ├── collinear_matrix_regression_chol.sql │ ├── simple_4var_regression_wide.sql │ ├── simple_5var_regression_long.sql │ ├── simple_5var_regression_wide.sql │ ├── groups_matrix_regression_fwl.sql │ ├── perfectly_multicollinear_model.sql │ ├── simple_8var_regression_wide.sql │ ├── collinear_matrix_ridge_regression_chol.sql │ ├── collinear_matrix_5var_without_const_ridge.sql │ ├── simple_10var_regression_long.sql │ ├── collinear_matrix_regression_chol_unoptimized.sql │ ├── collinear_matrix_ridge_regression_chol_unoptimized.sql │ ├── groups_matrix_regression_chol_optimized.sql │ ├── wide_format_options.sql │ ├── groups_matrix_regression_chol_unoptimized.sql │ └── long_format_options.sql ├── selectors.yml ├── dbt_project.yml └── profiles │ └── profiles.yml ├── .dbtignore ├── docs ├── src │ ├── img │ │ ├── dbt-linreg-logo.png │ │ ├── dbt-linreg-favicon.png │ │ ├── dbt-linreg-banner-dark.png │ │ └── dbt-linreg-banner-light.png │ ├── index.md │ └── css │ │ └── extra.css ├── requirements.txt └── mkdocs.yml ├── .editorconfig ├── dbt_project.yml ├── .github └── workflows │ ├── docs.yml │ └── tests.yml ├── .pre-commit-config.yaml ├── pyproject.toml ├── macros ├── linear_regression │ ├── ols_impl_special │ │ ├── _ols_0var.sql │ │ └── _ols_1var.sql │ ├── ols.sql │ ├── ols_impl_fwl │ │ └── _ols_impl_fwl.sql │ ├── utils │ │ └── utils.sql │ └── ols_impl_chol │ │ └── _ols_impl_chol.sql └── schema.yml ├── LICENSE ├── run ├── CHANGELOG.md 
├── .gitignore ├── scripts.py └── README.md /.idea/dbt_linreg.iml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /integration_tests/seeds/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !dbt_linreg.iml 4 | -------------------------------------------------------------------------------- /integration_tests/packages.yml: -------------------------------------------------------------------------------- 1 | packages: 2 | - local: ../ 3 | -------------------------------------------------------------------------------- /.dbtignore: -------------------------------------------------------------------------------- 1 | * 2 | !macros/ 3 | !dbt_project.yml 4 | !README.md 5 | -------------------------------------------------------------------------------- /docs/src/img/dbt-linreg-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dwreeves/dbt_linreg/HEAD/docs/src/img/dbt-linreg-logo.png -------------------------------------------------------------------------------- /docs/src/img/dbt-linreg-favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dwreeves/dbt_linreg/HEAD/docs/src/img/dbt-linreg-favicon.png -------------------------------------------------------------------------------- /docs/src/img/dbt-linreg-banner-dark.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dwreeves/dbt_linreg/HEAD/docs/src/img/dbt-linreg-banner-dark.png -------------------------------------------------------------------------------- /docs/src/img/dbt-linreg-banner-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dwreeves/dbt_linreg/HEAD/docs/src/img/dbt-linreg-banner-light.png -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | pymdown-extensions 3 | mkdocs-material 4 | mkdocs-macros-plugin 5 | pygments 6 | markdown-include 7 | pathspec 8 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | insert_final_newline = true 6 | 7 | [*.{sql,yml,yaml}] 8 | indent_style = space 9 | indent_size = 2 10 | -------------------------------------------------------------------------------- /integration_tests/tests/test_perfectly_multicollinear_model.sql: -------------------------------------------------------------------------------- 1 | select * 2 | from {{ ref('perfectly_multicollinear_model') }} 3 | where 4 | const is not null 5 | or xa is not null 6 | or xb is not null 7 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_0var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 6 | 10.0 as const 7 | ) 8 | 9 | select base.* 10 | from {{ ref('simple_0var_regression_wide') }} as base, expected 11 | where not ( 12 | base.const = expected.const 13 | ) 14 | 
-------------------------------------------------------------------------------- /docs/src/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - title 4 | - toc 5 | - navigation 6 | --- 7 | 8 |
9 | 10 | 11 |
12 | 13 | {!../README.md!} 14 | -------------------------------------------------------------------------------- /integration_tests/models/simple_0var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('simple_matrix'), 9 | endog='y', 10 | exog=[], 11 | output='wide', 12 | output_options={'round': 5} 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_1var_without_const.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('collinear_matrix'), 9 | endog='y', 10 | exog=['x1'], 11 | output='long', 12 | add_constant=false 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/models/simple_0var_regression_long_chol.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('simple_matrix'), 9 | endog='y', 10 | exog=[], 11 | output='long', 12 | output_options={'round': 5} 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/models/simple_0var_regression_long_fwl.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('simple_matrix'), 9 | endog='y', 10 | exog=[], 11 | output='long', 12 | output_options={'round': 5} 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- 
/integration_tests/models/simple_1var_regression_long_fwl.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('simple_matrix'), 9 | endog='y', 10 | exog=['xa'], 11 | output='long', 12 | output_options={'round': 5} 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/models/simple_1var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('simple_matrix'), 9 | endog='y', 10 | exog=['xa'], 11 | output='wide', 12 | output_options={'round': 5} 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/tests/test_wide_format_options.sql: -------------------------------------------------------------------------------- 1 | with base as ( 2 | 3 | select 4 | fooconstant_term_bar, 5 | "fooxa_bar", 6 | fooxb_bar 7 | from 8 | {{ ref("wide_format_options") }} 9 | 10 | ) 11 | 12 | /* If this SQL query doesn't throw an error, it's all set. 
*/ 13 | select * from base where false 14 | -------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_2var_without_const.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('collinear_matrix'), 9 | endog='y', 10 | exog=['x1', 'x2'], 11 | output='long', 12 | add_constant=false 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/models/simple_1var_regression_long_chol.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('simple_matrix'), 9 | endog='y', 10 | exog=['xa'], 11 | output='long', 12 | output_options={'round': 5} 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/models/simple_2var_regression_long.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('simple_matrix'), 9 | endog='y', 10 | exog=['xa', 'xb'], 11 | output='long', 12 | output_options={'round': 5} 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/models/simple_2var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('simple_matrix'), 9 | endog='y', 10 | exog=['xa', 'xb'], 11 | output='wide', 12 | output_options={'round': 5} 13 | ) 14 | }} as linreg 15 | 
-------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_3var_without_const.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('collinear_matrix'), 9 | endog='y', 10 | exog=['x1', 'x2', 'x3'], 11 | output='long', 12 | add_constant=false 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/models/simple_3var_regression_long.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('simple_matrix'), 9 | endog='y', 10 | exog=['xa', 'xb', 'xc'], 11 | output='long', 12 | output_options={'round': 5} 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_regression_fwl.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('collinear_matrix'), 9 | endog='y', 10 | exog=['x1', 'x2', 'x3', 'x4', 'x5'], 11 | output='long', 12 | method='fwl' 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/models/simple_4var_regression_long.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('simple_matrix'), 9 | endog='y', 10 | exog=['xa', 'xb', 'xc', 'xd'], 11 | output='long', 12 | output_options={'round': 5} 13 | ) 14 | }} as linreg 15 | 
-------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_4var_without_const.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('collinear_matrix'), 9 | endog='y', 10 | exog=['x1', 'x2', 'x3', 'x4'], 11 | output='long', 12 | add_constant=false 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_1var_without_const_ridge.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('collinear_matrix'), 9 | endog='y', 10 | exog=['x1'], 11 | alpha=2.0, 12 | output='long', 13 | add_constant=false 14 | ) 15 | }} as linreg 16 | -------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_5var_without_const.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('collinear_matrix'), 9 | endog='y', 10 | exog=['x1', 'x2', 'x3', 'x4', 'x5'], 11 | output='long', 12 | add_constant=false 13 | ) 14 | }} as linreg 15 | -------------------------------------------------------------------------------- /dbt_project.yml: -------------------------------------------------------------------------------- 1 | name: "dbt_linreg" 2 | version: "0.3.1" 3 | 4 | # 1.2 is required because of modules.itertools. 
5 | require-dbt-version: [">=1.2.0", "<2.0.0"] 6 | 7 | config-version: 2 8 | 9 | target-path: "target" 10 | clean-targets: ["target", "dbt_modules", "dbt_packages"] 11 | macro-paths: ["macros"] 12 | log-path: "logs" 13 | profile: "dbt_linreg_profile" 14 | -------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_ridge_regression_fwl.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('collinear_matrix'), 9 | endog='y', 10 | exog=['x1', 'x2', 'x3', 'x4', 'x5'], 11 | output='long', 12 | alpha=0.01, 13 | method='fwl' 14 | ) 15 | }} as linreg 16 | -------------------------------------------------------------------------------- /integration_tests/models/simple_3var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table", 4 | tags=["skip-postgres"] 5 | ) 6 | }} 7 | select * from {{ 8 | dbt_linreg.ols( 9 | table=ref('simple_matrix'), 10 | endog='y', 11 | exog=['xa', 'xb', 'xc'], 12 | output='wide', 13 | output_options={'round': 5} 14 | ) 15 | }} as linreg 16 | -------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_regression_chol.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table", 4 | tags=["skip-postgres"] 5 | ) 6 | }} 7 | select * from {{ 8 | dbt_linreg.ols( 9 | table=ref('collinear_matrix'), 10 | endog='y', 11 | exog=['x1', 'x2', 'x3', 'x4', 'x5'], 12 | output='long', 13 | method='chol' 14 | ) 15 | }} as linreg 16 | -------------------------------------------------------------------------------- /integration_tests/models/simple_4var_regression_wide.sql: 
-------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table", 4 | tags=["skip-postgres"] 5 | ) 6 | }} 7 | select * from {{ 8 | dbt_linreg.ols( 9 | table=ref('simple_matrix'), 10 | endog='y', 11 | exog=['xa', 'xb', 'xc', 'xd'], 12 | output='wide', 13 | output_options={'round': 5} 14 | ) 15 | }} as linreg 16 | -------------------------------------------------------------------------------- /integration_tests/models/simple_5var_regression_long.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table", 4 | tags=["skip-postgres"] 5 | ) 6 | }} 7 | select * from {{ 8 | dbt_linreg.ols( 9 | table=ref('simple_matrix'), 10 | endog='y', 11 | exog=['xa', 'xb', 'xc', 'xd', 'xe'], 12 | output='long', 13 | output_options={'round': 5} 14 | ) 15 | }} as linreg 16 | -------------------------------------------------------------------------------- /integration_tests/models/simple_5var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table", 4 | tags=["skip-postgres"] 5 | ) 6 | }} 7 | select * from {{ 8 | dbt_linreg.ols( 9 | table=ref('simple_matrix'), 10 | endog='y', 11 | exog=['xa', 'xb', 'xc', 'xd', 'xe'], 12 | output='wide', 13 | output_options={'round': 5} 14 | ) 15 | }} as linreg 16 | -------------------------------------------------------------------------------- /integration_tests/models/groups_matrix_regression_fwl.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select * from {{ 7 | dbt_linreg.ols( 8 | table=ref('groups_matrix'), 9 | endog='y', 10 | exog=['x1', 'x2', 'x3'], 11 | group_by=['gb_var'], 12 | output='long', 13 | method='fwl' 14 | ) 15 | }} as linreg 16 | order by gb_var, variable_name 17 | 
-------------------------------------------------------------------------------- /integration_tests/models/perfectly_multicollinear_model.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | with base as ( 7 | select 8 | y, 9 | xa, 10 | xa as xb 11 | from {{ ref('simple_matrix') }} 12 | ) 13 | 14 | select * from {{ 15 | dbt_linreg.ols( 16 | table='base', 17 | endog='y', 18 | exog=['xa', 'xb'] 19 | ) 20 | }} as linreg 21 | -------------------------------------------------------------------------------- /integration_tests/models/simple_8var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="view", 4 | tags=["perftest", "skip-postgres"], 5 | enabled=False, 6 | ) 7 | }} 8 | select * from {{ 9 | dbt_linreg.ols( 10 | table=ref('simple_matrix'), 11 | endog='y', 12 | exog=['xa', 'xb', 'xc', 'xd', 'xe', 'xf', 'xg', 'xh'], 13 | output='wide', 14 | ) 15 | }} as linreg 16 | -------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_ridge_regression_chol.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table", 4 | tags=["skip-postgres"] 5 | ) 6 | }} 7 | select * from {{ 8 | dbt_linreg.ols( 9 | table=ref('collinear_matrix'), 10 | endog='y', 11 | exog=['x1', 'x2', 'x3', 'x4', 'x5'], 12 | output='long', 13 | alpha=0.01, 14 | method='chol' 15 | ) 16 | }} as linreg 17 | -------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_5var_without_const_ridge.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table", 4 | tags=["skip-postgres"] 5 | ) 6 | }} 7 | select * from {{ 8 | dbt_linreg.ols( 9 | table=ref('collinear_matrix'), 10 | 
endog='y', 11 | exog=['x1', 'x2', 'x3', 'x4', 'x5'], 12 | alpha=1.0, 13 | output='long', 14 | add_constant=false 15 | ) 16 | }} as linreg 17 | -------------------------------------------------------------------------------- /integration_tests/models/simple_10var_regression_long.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="view", 4 | enabled=False, 5 | tags=["skip-postgres"] 6 | ) 7 | }} 8 | select * from {{ 9 | dbt_linreg.ols( 10 | table=ref('simple_matrix'), 11 | endog='y', 12 | exog=['xa', 'xb', 'xc', 'xd', 'xe', 'xf', 'xg', 'xh', 'xi', 'xj'], 13 | output='long', 14 | output_options={'round': 5} 15 | ) 16 | }} as linreg 17 | -------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_regression_chol_unoptimized.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table", 4 | tags=["skip-postgres", "skip-clickhouse"] 5 | ) 6 | }} 7 | select * from {{ 8 | dbt_linreg.ols( 9 | table=ref('collinear_matrix'), 10 | endog='y', 11 | exog=['x1', 'x2', 'x3', 'x4', 'x5'], 12 | output='long', 13 | method='chol', 14 | method_options={'subquery_optimization': False} 15 | ) 16 | }} as linreg 17 | -------------------------------------------------------------------------------- /docs/src/css/extra.css: -------------------------------------------------------------------------------- 1 | [data-md-color-scheme="light"] img[src$="#only-dark"], 2 | [data-md-color-scheme="light"] img[src$="#gh-dark-mode-only"] { 3 | display: none; /* Hide dark images in light mode */ 4 | } 5 | 6 | [data-md-color-scheme="dark"] img[src$="#only-light"], 7 | [data-md-color-scheme="dark"] img[src$="#gh-light-mode-only"] { 8 | display: none; /* Hide light images in dark mode */ 9 | } 10 | 11 | img[src$="#readme-logo"] { 12 | display: none; 13 | } 14 | 
-------------------------------------------------------------------------------- /integration_tests/models/collinear_matrix_ridge_regression_chol_unoptimized.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table", 4 | tags=["skip-postgres", "skip-clickhouse"] 5 | ) 6 | }} 7 | select * from {{ 8 | dbt_linreg.ols( 9 | table=ref('collinear_matrix'), 10 | endog='y', 11 | exog=['x1', 'x2', 'x3', 'x4', 'x5'], 12 | output='long', 13 | alpha=0.01, 14 | method='chol', 15 | method_options={'subquery_optimization': False} 16 | ) 17 | }} as linreg 18 | -------------------------------------------------------------------------------- /integration_tests/models/groups_matrix_regression_chol_optimized.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table", 4 | tags=["skip-postgres"] 5 | ) 6 | }} 7 | select * from {{ 8 | dbt_linreg.ols( 9 | table=ref('groups_matrix'), 10 | endog='y', 11 | exog=['x1', 'x2', 'x3'], 12 | group_by=['gb_var'], 13 | output='long', 14 | method='chol', 15 | method_options={'subquery_optimization': True} 16 | ) 17 | }} as linreg 18 | order by gb_var, variable_name 19 | -------------------------------------------------------------------------------- /integration_tests/models/wide_format_options.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select 7 | * 8 | from {{ 9 | dbt_linreg.ols( 10 | table=ref('simple_matrix'), 11 | endog='y', 12 | exog=['"xa"', 'xb'], 13 | output='wide', 14 | output_options={ 15 | 'variable_column_prefix': 'foo', 16 | 'variable_column_suffix': '_bar', 17 | 'constant_name': 'constant_term' 18 | } 19 | ) 20 | }} as linreg 21 | -------------------------------------------------------------------------------- /integration_tests/models/groups_matrix_regression_chol_unoptimized.sql: 
-------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table", 4 | tags=["skip-postgres", "skip-clickhouse"] 5 | ) 6 | }} 7 | select * from {{ 8 | dbt_linreg.ols( 9 | table=ref('groups_matrix'), 10 | endog='y', 11 | exog=['x1', 'x2', 'x3'], 12 | group_by=['gb_var'], 13 | output='long', 14 | method='chol', 15 | method_options={'subquery_optimization': False} 16 | ) 17 | }} as linreg 18 | order by gb_var, variable_name 19 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_1var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 6 | 10.0 as const, 7 | 5.0 as xa 8 | ) 9 | 10 | select 11 | expected.const as expected_const, 12 | base.const as actual_const, 13 | expected.xa as expected_xa, 14 | base.xa as actual_xa 15 | from {{ ref('simple_1var_regression_wide') }} as base, expected 16 | where not ( 17 | abs(base.const - expected.const) <= {{ var("_test_precision_simple_matrix") }} 18 | and abs(base.xa - expected.xa) <= {{ var("_test_precision_simple_matrix") }} 19 | ) 20 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | jobs: 7 | build: 8 | name: Deploy docs 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout main 12 | uses: actions/checkout@v1 13 | - name: Install dependencies 14 | run: sudo apt-get update 15 | - name: Deploy docs 16 | uses: mhausenblas/mkdocs-deploy-gh-pages@1.26 17 | env: 18 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 19 | CONFIG_FILE: docs/mkdocs.yml 20 | REQUIREMENTS: docs/requirements.txt 21 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
-------------------------------------------------------------------------------- 1 | repos: 2 | 3 | - repo: https://github.com/pre-commit/pre-commit-hooks 4 | rev: v2.3.0 5 | hooks: 6 | - id: check-yaml 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | 10 | - repo: https://github.com/charliermarsh/ruff-pre-commit 11 | rev: v0.6.2 12 | hooks: 13 | - id: ruff 14 | 15 | - repo: https://github.com/koalaman/shellcheck-precommit 16 | rev: v0.8.0 17 | hooks: 18 | - id: shellcheck 19 | args: [-x, run] 20 | 21 | - repo: https://github.com/rhysd/actionlint 22 | rev: v1.6.26 23 | hooks: 24 | - id: actionlint 25 | -------------------------------------------------------------------------------- /integration_tests/selectors.yml: -------------------------------------------------------------------------------- 1 | selectors: 2 | - name: duckdb-selector 3 | definition: 'fqn:*' 4 | - name: postgres-selector 5 | # Postgres runs into memory / performance issues for some of these queries. 6 | # Resolving this and making Postgres more performant is a TODO. 7 | definition: 8 | union: 9 | - 'fqn:*' 10 | - exclude: 11 | - '@tag:skip-postgres' 12 | - name: clickhouse-selector 13 | # Clickhouse struggles with the unoptimized chol method. 
14 | definition: 15 | union: 16 | - 'fqn:*' 17 | - exclude: 18 | - '@tag:skip-clickhouse' 19 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_0var_regression_long_fwl.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'const' as variable_name, 10.0 as coefficient 6 | 7 | ) 8 | 9 | select 10 | coalesce(base.variable_name, expected.variable_name) as variable_name, 11 | expected.coefficient as expected_coefficient, 12 | base.coefficient as actual_coefficient 13 | from {{ ref('simple_0var_regression_long_fwl') }} as base 14 | full outer join expected 15 | on base.variable_name = expected.variable_name 16 | where 17 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }} 18 | or base.coefficient is null 19 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_0var_regression_long_chol.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'const' as variable_name, 10.0 as coefficient 6 | 7 | ) 8 | 9 | select 10 | coalesce(base.variable_name, expected.variable_name) as variable_name, 11 | expected.coefficient as expected_coefficient, 12 | base.coefficient as actual_coefficient 13 | from {{ ref('simple_0var_regression_long_chol') }} as base 14 | full outer join expected 15 | on base.variable_name = expected.variable_name 16 | where 17 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }} 18 | or base.coefficient is null 19 | -------------------------------------------------------------------------------- /integration_tests/dbt_project.yml: -------------------------------------------------------------------------------- 1 | name: "dbt_linreg_tests" 2 | version: "0.3.1" 3 | 4 | require-dbt-version: [">=1.0.0", 
"<2.0.0"] 5 | 6 | config-version: 2 7 | 8 | target-path: "target" 9 | clean-targets: ["target", "dbt_modules", "dbt_packages"] 10 | macro-paths: ["macros"] 11 | log-path: "logs" 12 | 13 | vars: 14 | _test_precision_simple_matrix: '{{ "10e-8" if target.name == "clickhouse" else 0.0 }}' 15 | _test_precision_collinear_matrix: '{{ "10e-6" if target.name == "clickhouse" else "10e-7" }}' 16 | 17 | models: 18 | +materialized: table 19 | 20 | tests: 21 | +store_failures: true 22 | 23 | # During dev only! 24 | profile: "dbt_linreg_profile" 25 | -------------------------------------------------------------------------------- /integration_tests/profiles/profiles.yml: -------------------------------------------------------------------------------- 1 | dbt_linreg_profile: 2 | target: duckdb 3 | outputs: 4 | duckdb: 5 | type: duckdb 6 | path: dbt.duckdb 7 | postgres: 8 | type: postgres 9 | user: '{{ env_var("POSTGRES_USER") }}' 10 | password: '{{ env_var("POSTGRES_PASSWORD") }}' 11 | host: '{{ env_var("POSTGRES_HOST", "localhost") }}' 12 | port: '{{ env_var("POSTGRES_PORT", "5432") | as_number }}' 13 | dbname: '{{ env_var("POSTGRES_DB", "dbt_linreg") }}' 14 | schema: '{{ env_var("POSTGRES_SCHEMA", "public") }}' 15 | clickhouse: 16 | type: clickhouse 17 | port: 8123 18 | schema: dbt_linreg 19 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_1var_regression_long_chol.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'const' as variable_name, 10.0 as coefficient 6 | union all 7 | select 'xa' as variable_name, 5.0 as coefficient 8 | 9 | ) 10 | 11 | select 12 | coalesce(base.variable_name, expected.variable_name) as variable_name, 13 | expected.coefficient as expected_coefficient, 14 | base.coefficient as actual_coefficient 15 | from {{ ref('simple_1var_regression_long_chol') }} as base 16 | full outer join expected 17 | on 
base.variable_name = expected.variable_name 18 | where 19 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }} 20 | or base.coefficient is null 21 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_1var_regression_long_fwl.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'const' as variable_name, 10.0 as coefficient 6 | union all 7 | select 'xa' as variable_name, 5.0 as coefficient 8 | 9 | ) 10 | 11 | select 12 | coalesce(base.variable_name, expected.variable_name) as variable_name, 13 | expected.coefficient as expected_coefficient, 14 | base.coefficient as actual_coefficient 15 | from {{ ref('simple_1var_regression_long_fwl') }} as base 16 | full outer join expected 17 | on base.variable_name = expected.variable_name 18 | where 19 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }} 20 | or base.coefficient is null 21 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_2var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 6 | 10.0 as const, 7 | 5.0 as xa, 8 | 7.0 as xb 9 | ) 10 | 11 | select 12 | expected.const as expected_const, 13 | base.const as actual_const, 14 | expected.xa as expected_xa, 15 | base.xa as actual_xa, 16 | expected.xb as expected_xb, 17 | base.xb as actual_xb 18 | from {{ ref('simple_2var_regression_wide') }} as base, expected 19 | where not ( 20 | abs(base.const - expected.const) <= {{ var("_test_precision_simple_matrix") }} 21 | and abs(base.xa - expected.xa) <= {{ var("_test_precision_simple_matrix") }} 22 | and abs(base.xb - expected.xb) <= {{ var("_test_precision_simple_matrix") }} 23 | ) 24 | 
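The simple-matrix tests above expect `const = 10`, `xa = 5` and `xb = 7`, compared with a tolerance that is `0.0` on non-ClickHouse targets. The sketch below illustrates why exact recovery is plausible when the data is noiseless. It assumes a hypothetical dataset generated as `y = 10 + 5*xa + 7*xb`; the real `simple_matrix` seed is not shown here, and `solve_normal_equations` is an illustrative helper, not part of `dbt_linreg`. It solves the OLS normal equations with exact rational arithmetic rather than with the package's SQL macros.

```python
from fractions import Fraction


def solve_normal_equations(rows, y):
    """Solve (X'X) b = X'y exactly by Gauss-Jordan elimination."""
    k = len(rows[0])
    # Build the augmented matrix [X'X | X'y] using exact rational arithmetic.
    aug = [
        [sum(Fraction(r[i]) * Fraction(r[j]) for r in rows) for j in range(k)]
        + [sum(Fraction(r[i]) * Fraction(t) for r, t in zip(rows, y))]
        for i in range(k)
    ]
    for col in range(k):
        # Pick a row with a non-zero pivot; raises StopIteration if X'X is singular.
        pivot = next(r for r in range(col, k) if aug[r][col] != 0)
        aug[col], aug[pivot] = aug[pivot], aug[col]
        pivot_value = aug[col][col]
        aug[col] = [v / pivot_value for v in aug[col]]
        # Eliminate this column from every other row.
        for r in range(k):
            if r != col:
                factor = aug[r][col]
                aug[r] = [a - factor * b for a, b in zip(aug[r], aug[col])]
    return [aug[i][k] for i in range(k)]


# Hypothetical design matrix: a constant column plus two regressors.
xa = [1, 2, 3, 4, 5]
xb = [2, 1, 4, 3, 7]
X = [[1, a, b] for a, b in zip(xa, xb)]
# Noiseless response (assumption, chosen to match the expected coefficients).
y = [10 + 5 * a + 7 * b for a, b in zip(xa, xb)]

coefficients = solve_normal_equations(X, y)
print([float(c) for c in coefficients])
```

With floating-point arithmetic in a database engine the result may differ in the last digits, which is presumably why the ClickHouse target uses a small non-zero tolerance.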
-------------------------------------------------------------------------------- /integration_tests/tests/test_simple_2var_regression_long.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'const' as variable_name, 10.0 as coefficient 6 | union all 7 | select 'xa' as variable_name, 5.0 as coefficient 8 | union all 9 | select 'xb' as variable_name, 7.0 as coefficient 10 | 11 | ) 12 | 13 | select 14 | coalesce(base.variable_name, expected.variable_name) as variable_name, 15 | expected.coefficient as expected_coefficient, 16 | base.coefficient as actual_coefficient 17 | from {{ ref('simple_2var_regression_long') }} as base 18 | full outer join expected 19 | on base.variable_name = expected.variable_name 20 | where 21 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }} 22 | or base.coefficient is null 23 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_3var_regression_long.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'const' as variable_name, 10.0 as coefficient 6 | union all 7 | select 'xa' as variable_name, 5.0 as coefficient 8 | union all 9 | select 'xb' as variable_name, 7.0 as coefficient 10 | union all 11 | select 'xc' as variable_name, 9.0 as coefficient 12 | 13 | ) 14 | 15 | select 16 | coalesce(base.variable_name, expected.variable_name) as variable_name, 17 | expected.coefficient as expected_coefficient, 18 | base.coefficient as actual_coefficient 19 | from {{ ref('simple_3var_regression_long') }} as base 20 | full outer join expected 21 | on base.variable_name = expected.variable_name 22 | where 23 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }} 24 | or base.coefficient is null 25 | 
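The long-format tests above all follow the same pattern: full outer join the model output to an `expected` CTE on `variable_name`, then return any row whose coefficient is missing or deviates by more than the tolerance. A dbt singular test passes when the query returns no rows. The sketch below mirrors that logic in plain Python for illustration only; `failing_rows` and the sample dictionaries are hypothetical and not part of the repository.

```python
def failing_rows(actual, expected, tolerance):
    """Return one tuple per variable that would make the test fail."""
    failures = []
    # Full outer join: iterate over the union of keys from both sides.
    for name in sorted(set(actual) | set(expected)):
        got = actual.get(name)
        want = expected.get(name)
        if got is None:
            # Mirrors `base.coefficient is null`: an expected row with no match.
            failures.append((name, want, got))
        elif want is not None and abs(got - want) > tolerance:
            # Mirrors `abs(base.coefficient - expected.coefficient) > tolerance`.
            failures.append((name, want, got))
        # A variable present only in `actual` is not flagged, matching the SQL,
        # where the comparison against a NULL expected value is never true.
    return failures


expected = {"const": 10.0, "xa": 5.0, "xb": 7.0, "xc": 9.0}
actual = {"const": 10.0, "xa": 5.0, "xb": 7.5}  # xb is off, xc is missing
print(failing_rows(actual, expected, tolerance=0.0))
```

Some of the collinear-matrix tests additionally include `or expected.coefficient is null`, which would also flag unexpected extra variables; the sketch does not cover that variant.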
-------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_1var_without_const_ridge.sql: -------------------------------------------------------------------------------- 1 | /* Ridge regression coefficients do not match exactly. 2 | Instead, a threshold of no more than 0.01% deviation is enforced. */ 3 | {% set THRESHOLD = 0.0001 %} 4 | with 5 | 6 | expected as ( 7 | 8 | select 'x1' as variable_name, 21.78558328301129 as coefficient 9 | 10 | ) 11 | 12 | select 13 | coalesce(base.variable_name, expected.variable_name) as variable_name, 14 | expected.coefficient as expected_coefficient, 15 | base.coefficient as actual_coefficient 16 | from {{ ref('collinear_matrix_1var_without_const_ridge') }} as base 17 | full outer join expected 18 | on base.variable_name = expected.variable_name 19 | where 20 | abs(log(abs(base.coefficient)) - log(abs(expected.coefficient))) > {{ THRESHOLD }} 21 | or sign(base.coefficient) != sign(expected.coefficient) 22 | or base.coefficient is null 23 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_3var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 6 | 10.0 as const, 7 | 5.0 as xa, 8 | 7.0 as xb, 9 | 9.0 as xc 10 | ) 11 | 12 | select 13 | expected.const as expected_const, 14 | base.const as actual_const, 15 | expected.xa as expected_xa, 16 | base.xa as actual_xa, 17 | expected.xb as expected_xb, 18 | base.xb as actual_xb, 19 | expected.xc as expected_xc, 20 | base.xc as actual_xc 21 | from {{ ref('simple_3var_regression_wide') }} as base, expected 22 | where not ( 23 | abs(base.const - expected.const) <= {{ var("_test_precision_simple_matrix") }} 24 | and abs(base.xa - expected.xa) <= {{ var("_test_precision_simple_matrix") }} 25 | and abs(base.xb - expected.xb) <= {{ var("_test_precision_simple_matrix") 
}} 26 | and abs(base.xc - expected.xc) <= {{ var("_test_precision_simple_matrix") }} 27 | ) 28 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "dbt_linreg" 3 | requires-python = "~=3.11" 4 | description = "dbt_linreg dbt package" 5 | version = "0.3.1" 6 | readme = "README.md" 7 | authors = [{name = "Daniel Reeves"}] 8 | 9 | [project.optional-dependencies] 10 | python-dev = [ 11 | "pandas>=2.2.3", 12 | "pre-commit>=4.0.1", 13 | "pyyaml>=6.0.2", 14 | "rich-click>=1.8.5", 15 | "ruff>=0.8.4", 16 | "statsmodels>=0.14.4", 17 | "tabulate>=0.9.0", 18 | ] 19 | clickhouse = [ 20 | "dbt-core<1.9.0", 21 | "dbt-clickhouse", 22 | ] 23 | duckdb = [ 24 | "dbt-core<1.9.0", 25 | "dbt-duckdb", 26 | "duckdb>=1.1.3", 27 | ] 28 | postgres = [ 29 | "dbt-core<1.9.0", 30 | "dbt-postgres", 31 | ] 32 | 33 | [tool.ruff] 34 | line-length = 120 35 | 36 | [tool.ruff.lint] 37 | select = ["F", "E", "W", "I001"] 38 | 39 | [tool.ruff.lint.isort] 40 | lines-after-imports = 2 41 | force-single-line = true 42 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_4var_regression_long.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'const' as variable_name, 10.0 as coefficient 6 | union all 7 | select 'xa' as variable_name, 5.0 as coefficient 8 | union all 9 | select 'xb' as variable_name, 7.0 as coefficient 10 | union all 11 | select 'xc' as variable_name, 9.0 as coefficient 12 | union all 13 | select 'xd' as variable_name, 11.0 as coefficient 14 | 15 | ) 16 | 17 | select 18 | coalesce(base.variable_name, expected.variable_name) as variable_name, 19 | expected.coefficient as expected_coefficient, 20 | base.coefficient as actual_coefficient 21 | from {{ ref('simple_4var_regression_long') }} as base 22 | full 
outer join expected 23 | on base.variable_name = expected.variable_name 24 | where 25 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }} 26 | or base.coefficient is null 27 | -------------------------------------------------------------------------------- /macros/linear_regression/ols_impl_special/_ols_0var.sql: -------------------------------------------------------------------------------- 1 | {% macro _ols_0var(table, 2 | endog, 3 | exog, 4 | add_constant=True, 5 | output=None, 6 | output_options=None, 7 | group_by=None, 8 | alpha=None) -%} 9 | (with _dbt_linreg_final_coefs as ( 10 | select 11 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) }} 12 | avg({{ endog }}) as x0_coef 13 | from {{ table }} 14 | {%- if group_by %} 15 | group by 16 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }} 17 | {%- endif %} 18 | ) 19 | {{ 20 | dbt_linreg.final_select( 21 | exog=[], 22 | exog_aliased=[], 23 | add_constant=add_constant, 24 | group_by=group_by, 25 | output=output, 26 | output_options=output_options, 27 | calculate_standard_error=False 28 | ) 29 | }} 30 | ) 31 | {%- endmacro %} 32 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_5var_regression_long.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'const' as variable_name, 10.0 as coefficient 6 | union all 7 | select 'xa' as variable_name, 5.0 as coefficient 8 | union all 9 | select 'xb' as variable_name, 7.0 as coefficient 10 | union all 11 | select 'xc' as variable_name, 9.0 as coefficient 12 | union all 13 | select 'xd' as variable_name, 11.0 as coefficient 14 | union all 15 | select 'xe' as variable_name, 13.0 as coefficient 16 | 17 | ) 18 | 19 | select 20 | coalesce(base.variable_name, expected.variable_name) as variable_name, 21 | expected.coefficient as expected_coefficient, 22 | base.coefficient as 
actual_coefficient 23 | from {{ ref('simple_5var_regression_long') }} as base 24 | full outer join expected 25 | on base.variable_name = expected.variable_name 26 | where 27 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_simple_matrix") }} 28 | or base.coefficient is null 29 | -------------------------------------------------------------------------------- /integration_tests/models/long_format_options.sql: -------------------------------------------------------------------------------- 1 | {{ 2 | config( 3 | materialized="table" 4 | ) 5 | }} 6 | select 7 | true as strip_quotes, * 8 | from {{ 9 | dbt_linreg.ols( 10 | table=ref('simple_matrix'), 11 | endog='y', 12 | exog=['"xa"', 'xb'], 13 | output='long', 14 | output_options={ 15 | 'constant_name': 'constant_term', 16 | 'variable_column_name': 'vname', 17 | 'coefficient_column_name': 'co', 18 | 'strip_quotes': True 19 | } 20 | ) 21 | }} as linreg1 22 | 23 | union all 24 | 25 | select 26 | false as strip_quotes, * 27 | from {{ 28 | dbt_linreg.ols( 29 | table=ref('simple_matrix'), 30 | endog='y', 31 | exog=['"xa"', 'xb'], 32 | output='long', 33 | output_options={ 34 | 'constant_name': 'constant_term', 35 | 'variable_column_name': 'vname', 36 | 'coefficient_column_name': 'co', 37 | 'strip_quotes': False 38 | } 39 | ) 40 | }} as linreg2 41 | -------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_1var_without_const.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'x1' as variable_name, 30.500076644845674 as coefficient, 0.8396121329329627 as standard_error, 36.326388636502585 as t_statistic 6 | 7 | ) 8 | 9 | select 10 | coalesce(base.variable_name, expected.variable_name) as variable_name, 11 | expected.coefficient as expected_coefficient, 12 | base.coefficient as actual_coefficient 13 | from {{ ref('collinear_matrix_1var_without_const') }} 
as base 14 | full outer join expected 15 | on base.variable_name = expected.variable_name 16 | where 17 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }} 18 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }} 19 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }} 20 | or base.coefficient is null 21 | or base.standard_error is null 22 | or base.t_statistic is null 23 | or expected.coefficient is null 24 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_4var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 6 | 10.0 as const, 7 | 5.0 as xa, 8 | 7.0 as xb, 9 | 9.0 as xc, 10 | 11.0 as xd 11 | ) 12 | 13 | select 14 | expected.const as expected_const, 15 | base.const as actual_const, 16 | expected.xa as expected_xa, 17 | base.xa as actual_xa, 18 | expected.xb as expected_xb, 19 | base.xb as actual_xb, 20 | expected.xc as expected_xc, 21 | base.xc as actual_xc, 22 | expected.xd as expected_xd, 23 | base.xd as actual_xd 24 | from {{ ref('simple_4var_regression_wide') }} as base, expected 25 | where not ( 26 | abs(base.const - expected.const) <= {{ var("_test_precision_simple_matrix") }} 27 | and abs(base.xa - expected.xa) <= {{ var("_test_precision_simple_matrix") }} 28 | and abs(base.xb - expected.xb) <= {{ var("_test_precision_simple_matrix") }} 29 | and abs(base.xc - expected.xc) <= {{ var("_test_precision_simple_matrix") }} 30 | and abs(base.xd - expected.xd) <= {{ var("_test_precision_simple_matrix") }} 31 | ) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2023 Daniel Reeves 4 | 5 | 
Permission is hereby granted, free of charge, to any person obtaining 6 | a copy of this software and associated documentation files (the 7 | 'Software'), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 21 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 22 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
23 | -------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_2var_without_const.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'x1' as variable_name, 63.18154691334764 as coefficient, 0.4056389914380657 as standard_error, 155.75807120848344 as t_statistic 6 | union all 7 | select 'x2' as variable_name, 55.39820150046505 as coefficient, 0.2738669097295638 as standard_error, 202.2814715190283 as t_statistic 8 | 9 | ) 10 | 11 | select 12 | coalesce(base.variable_name, expected.variable_name) as variable_name, 13 | expected.coefficient as expected_coefficient, 14 | base.coefficient as actual_coefficient 15 | from {{ ref('collinear_matrix_2var_without_const') }} as base 16 | full outer join expected 17 | on base.variable_name = expected.variable_name 18 | where 19 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }} 20 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }} 21 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }} 22 | or base.coefficient is null 23 | or base.standard_error is null 24 | or base.t_statistic is null 25 | or expected.coefficient is null 26 | -------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_ridge_regression_fwl.sql: -------------------------------------------------------------------------------- 1 | /* Ridge regression coefficients do not match exactly. 2 | Instead, a threshold of no more than 0.01% deviation is enforced. 
*/ 3 | {% set THRESHOLD = 0.0001 %} 4 | with 5 | 6 | expected as ( 7 | 8 | select 'const' as variable_name, 20.7548151107157 as coefficient 9 | union all 10 | select 'x1' as variable_name, 9.784064449021356 as coefficient 11 | union all 12 | select 'x2' as variable_name, 6.315640539781496 as coefficient 13 | union all 14 | select 'x3' as variable_name, 19.578696589513562 as coefficient 15 | union all 16 | select 'x4' as variable_name, 3.736823845978248 as coefficient 17 | union all 18 | select 'x5' as variable_name, 13.323547772767592 as coefficient 19 | 20 | ) 21 | 22 | select 23 | coalesce(base.variable_name, expected.variable_name) as variable_name 24 | from {{ ref('collinear_matrix_ridge_regression_fwl') }} as base 25 | full outer join expected 26 | on base.variable_name = expected.variable_name 27 | where 28 | abs(log(abs(base.coefficient)) - log(abs(expected.coefficient))) > {{ THRESHOLD }} 29 | or sign(base.coefficient) != sign(expected.coefficient) 30 | or base.coefficient is null 31 | -------------------------------------------------------------------------------- /integration_tests/tests/test_simple_5var_regression_wide.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 6 | 10.0 as const, 7 | 5.0 as xa, 8 | 7.0 as xb, 9 | 9.0 as xc, 10 | 11.0 as xd, 11 | 13.0 as xe 12 | ) 13 | 14 | select 15 | expected.const as expected_const, 16 | base.const as actual_const, 17 | expected.xa as expected_xa, 18 | base.xa as actual_xa, 19 | expected.xb as expected_xb, 20 | base.xb as actual_xb, 21 | expected.xc as expected_xc, 22 | base.xc as actual_xc, 23 | expected.xd as expected_xd, 24 | base.xd as actual_xd, 25 | expected.xe as expected_xe, 26 | base.xe as actual_xe 27 | from {{ ref('simple_5var_regression_wide') }} as base, expected 28 | where not ( 29 | abs(base.const - expected.const) <= {{ var("_test_precision_simple_matrix") }} 30 | and abs(base.xa - expected.xa) <= {{ 
var("_test_precision_simple_matrix") }} 31 | and abs(base.xb - expected.xb) <= {{ var("_test_precision_simple_matrix") }} 32 | and abs(base.xc - expected.xc) <= {{ var("_test_precision_simple_matrix") }} 33 | and abs(base.xd - expected.xd) <= {{ var("_test_precision_simple_matrix") }} 34 | and abs(base.xe - expected.xe) <= {{ var("_test_precision_simple_matrix") }} 35 | ) 36 | -------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_5var_without_const_ridge.sql: -------------------------------------------------------------------------------- 1 | /* Ridge regression coefficients do not match exactly. 2 | Instead, a threshold of no more than 0.01% deviation is enforced. */ 3 | {% set THRESHOLD = 0.0001 %} 4 | with 5 | 6 | expected as ( 7 | 8 | select 'x1' as variable_name, 9.44760329057758 as coefficient 9 | union all 10 | select 'x2' as variable_name, 3.5049555562844787 as coefficient 11 | union all 12 | select 'x3' as variable_name, 20.753357497835637 as coefficient 13 | union all 14 | select 'x4' as variable_name, 3.522584853991104 as coefficient 15 | union all 16 | select 'x5' as variable_name, 16.31725550368597 as coefficient 17 | 18 | ) 19 | 20 | select 21 | coalesce(base.variable_name, expected.variable_name) as variable_name, 22 | expected.coefficient as expected_coefficient, 23 | base.coefficient as actual_coefficient 24 | from {{ ref('collinear_matrix_5var_without_const_ridge') }} as base 25 | full outer join expected 26 | on base.variable_name = expected.variable_name 27 | where 28 | abs(log(abs(base.coefficient)) - log(abs(expected.coefficient))) > {{ THRESHOLD }} 29 | or sign(base.coefficient) != sign(expected.coefficient) 30 | or base.coefficient is null 31 | -------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_ridge_regression_chol.sql: 
-------------------------------------------------------------------------------- 1 | /* Ridge regression coefficients do not match exactly. 2 | Instead, a threshold of no more than 0.01% deviation is enforced. */ 3 | {% set THRESHOLD = 0.0001 %} 4 | with 5 | 6 | expected as ( 7 | 8 | select 'const' as variable_name, 20.7548151107157 as coefficient 9 | union all 10 | select 'x1' as variable_name, 9.784064449021356 as coefficient 11 | union all 12 | select 'x2' as variable_name, 6.315640539781496 as coefficient 13 | union all 14 | select 'x3' as variable_name, 19.578696589513562 as coefficient 15 | union all 16 | select 'x4' as variable_name, 3.736823845978248 as coefficient 17 | union all 18 | select 'x5' as variable_name, 13.323547772767592 as coefficient 19 | 20 | ) 21 | 22 | select 23 | coalesce(base.variable_name, expected.variable_name) as variable_name, 24 | expected.coefficient as expected_coefficient, 25 | base.coefficient as actual_coefficient 26 | from {{ ref('collinear_matrix_ridge_regression_chol') }} as base 27 | full outer join expected 28 | on base.variable_name = expected.variable_name 29 | where 30 | abs(log(abs(base.coefficient)) - log(abs(expected.coefficient))) > {{ THRESHOLD }} 31 | or sign(base.coefficient) != sign(expected.coefficient) 32 | or base.coefficient is null 33 | -------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_regression_fwl.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'const' as variable_name, 19.757104885315176 as coefficient 6 | union all 7 | select 'x1' as variable_name, 9.90708767581426 as coefficient 8 | union all 9 | select 'x2' as variable_name, 6.187473206056227 as coefficient 10 | union all 11 | select 'x3' as variable_name, 19.66874583168642 as coefficient 12 | union all 13 | select 'x4' as variable_name, 3.7192417102253468 as coefficient 14 | union all 15 | 
select 'x5' as variable_name, 13.444273483323244 as coefficient 16 | 17 | ) 18 | 19 | select 20 | coalesce(base.variable_name, expected.variable_name) as variable_name, 21 | expected.coefficient as expected_coefficient, 22 | base.coefficient as actual_coefficient 23 | from {{ ref('collinear_matrix_regression_fwl') }} as base 24 | full outer join expected 25 | on base.variable_name = expected.variable_name 26 | where 27 | {% if target.name == "clickhouse" %} 28 | {# This has poor precision for Clickhouse; not much I can do about it. #} 29 | abs(base.coefficient - expected.coefficient) > 0.1 30 | {% else %} 31 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }} 32 | {% endif %} 33 | or base.coefficient is null 34 | or expected.coefficient is null 35 | -------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_ridge_regression_chol_unoptimized.sql: -------------------------------------------------------------------------------- 1 | /* Ridge regression coefficients do not match exactly. 2 | Instead, a threshold of no more than 0.01% deviation is enforced. 
*/ 3 | {% set THRESHOLD = 0.0001 %} 4 | with 5 | 6 | expected as ( 7 | 8 | select 'const' as variable_name, 20.7548151107157 as coefficient 9 | union all 10 | select 'x1' as variable_name, 9.784064449021356 as coefficient 11 | union all 12 | select 'x2' as variable_name, 6.315640539781496 as coefficient 13 | union all 14 | select 'x3' as variable_name, 19.578696589513562 as coefficient 15 | union all 16 | select 'x4' as variable_name, 3.736823845978248 as coefficient 17 | union all 18 | select 'x5' as variable_name, 13.323547772767592 as coefficient 19 | 20 | ) 21 | 22 | select 23 | coalesce(base.variable_name, expected.variable_name) as variable_name, 24 | expected.coefficient as expected_coefficient, 25 | base.coefficient as actual_coefficient 26 | from {{ ref('collinear_matrix_ridge_regression_chol_unoptimized') }} as base 27 | full outer join expected 28 | on base.variable_name = expected.variable_name 29 | where 30 | abs(log(abs(base.coefficient)) - log(abs(expected.coefficient))) > {{ THRESHOLD }} 31 | or sign(base.coefficient) != sign(expected.coefficient) 32 | or base.coefficient is null 33 | -------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_3var_without_const.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'x1' as variable_name, 20.090207982897063 as coefficient, 0.5196176972417176 as standard_error, 38.6634406209445 as t_statistic 6 | union all 7 | select 'x2' as variable_name, -16.533211090826203 as coefficient, 0.7481701784700665 as standard_error, -22.098195793682894 as t_statistic 8 | union all 9 | select 'x3' as variable_name, 35.00389104686492 as coefficient, 0.351617515124373 as standard_error, 99.55104493154575 as t_statistic 10 | 11 | ) 12 | 13 | select 14 | coalesce(base.variable_name, expected.variable_name) as variable_name, 15 | expected.coefficient as expected_coefficient, 16 | 
base.coefficient as actual_coefficient 17 | from {{ ref('collinear_matrix_3var_without_const') }} as base 18 | full outer join expected 19 | on base.variable_name = expected.variable_name 20 | where 21 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }} 22 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }} 23 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }} 24 | or base.coefficient is null 25 | or base.standard_error is null 26 | or base.t_statistic is null 27 | or expected.coefficient is null 28 | -------------------------------------------------------------------------------- /integration_tests/tests/test_long_format_options.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | base as ( 4 | 5 | select strip_quotes, vname, co 6 | from {{ ref("long_format_options") }} 7 | 8 | ), 9 | 10 | find_unstripped_quotes as ( 11 | 12 | select 13 | cast(max(cast(vname = '"xa"' as integer)) as boolean) as should_be_true, 14 | cast(max(cast(vname = 'xa' as integer)) as boolean) as should_be_false 15 | from base 16 | where not strip_quotes 17 | 18 | ), 19 | 20 | dodge_unstripped_quotes as ( 21 | 22 | select 23 | cast(max(cast(vname = 'xa' as integer)) as boolean) as should_be_true, 24 | cast(max(cast(vname = '"xa"' as integer)) as boolean) as should_be_false 25 | from base 26 | where strip_quotes 27 | 28 | ), 29 | 30 | coef_col_name as ( 31 | 32 | select 33 | cast(max(cast(vname = 'constant_term' as integer)) as boolean) as should_be_true, 34 | cast(max(cast(vname = 'const' as integer)) as boolean) as should_be_false 35 | from base 36 | 37 | ) 38 | 39 | select 'find_unstripped_quotes' as test_case 40 | from find_unstripped_quotes 41 | where should_be_false or not should_be_true 42 | 43 | union all 44 | 45 | select 'dodge_unstripped_quotes' as test_case 46 | from dodge_unstripped_quotes 47 
| where should_be_false or not should_be_true 48 | 49 | union all 50 | 51 | select 'coef_col_name' as test_case 52 | from coef_col_name 53 | where should_be_false or not should_be_true 54 | -------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_4var_without_const.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'x1' as variable_name, 20.587532776354163 as coefficient, 0.5176259827853541 as standard_error, 39.772989496339235 as t_statistic 6 | union all 7 | select 'x2' as variable_name, -20.41001520357013 as coefficient, 0.8103907603637923 as standard_error, -25.185399688426696 as t_statistic 8 | union all 9 | select 'x3' as variable_name, 35.084935774341524 as coefficient, 0.34920588221192245 as standard_error, 100.4706322588505 as t_statistic 10 | union all 11 | select 'x4' as variable_name, 1.8960558858899716 as coefficient, 0.1583538085466205 as standard_error, 11.973541421529871 as t_statistic 12 | 13 | ) 14 | 15 | select 16 | coalesce(base.variable_name, expected.variable_name) as variable_name, 17 | expected.coefficient as expected_coefficient, 18 | base.coefficient as actual_coefficient 19 | from {{ ref('collinear_matrix_4var_without_const') }} as base 20 | full outer join expected 21 | on base.variable_name = expected.variable_name 22 | where 23 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }} 24 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }} 25 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }} 26 | or base.coefficient is null 27 | or base.standard_error is null 28 | or base.t_statistic is null 29 | or expected.coefficient is null 30 | -------------------------------------------------------------------------------- /docs/mkdocs.yml: 
-------------------------------------------------------------------------------- 1 | site_name: dbt_linreg 2 | site_url: https://dwreeves.github.io/dbt_linreg/ 3 | site_description: Docs for dbt_linreg 4 | site_author: Daniel Reeves 5 | repo_url: https://github.com/dwreeves/dbt_linreg/ 6 | repo_name: dbt_linreg 7 | docs_dir: src 8 | nav: 9 | - Home: index.md 10 | theme: 11 | name: material 12 | palette: 13 | - media: "(prefers-color-scheme: dark)" 14 | scheme: slate 15 | primary: black 16 | accent: light blue 17 | toggle: 18 | icon: material/lightbulb-outline 19 | name: Switch to light mode 20 | - media: "(prefers-color-scheme: light)" 21 | scheme: default 22 | primary: white 23 | accent: light blue 24 | toggle: 25 | icon: material/lightbulb 26 | name: Switch to dark mode 27 | logo: img/dbt-linreg-logo.png 28 | favicon: img/dbt-linreg-favicon.png 29 | font: 30 | text: Open Sans 31 | code: Roboto Mono 32 | plugins: 33 | - macros 34 | markdown_extensions: 35 | - admonition 36 | - pymdownx.tabbed 37 | - pymdownx.keys 38 | - pymdownx.details 39 | - pymdownx.inlinehilite 40 | - pymdownx.superfences 41 | - markdown_include.include: 42 | base_path: docs 43 | - sane_lists 44 | extra_css: 45 | - css/extra.css 46 | extra: 47 | social: 48 | - icon: fontawesome/brands/github 49 | link: https://github.com/dwreeves/dbt_linreg 50 | - icon: fontawesome/brands/linkedin 51 | link: https://www.linkedin.com/in/daniel-reeves-27700545/ 52 | - icon: fontawesome/brands/twitter 53 | link: https://twitter.com/mueblesfeos 54 | -------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_5var_without_const.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'x1' as variable_name, 11.392300499659957 as coefficient, 0.5240533254061608 as standard_error, 21.73881921430515 as t_statistic 6 | union all 7 | select 'x2' as variable_name, 
2.333060182571783 as coefficient, 0.9201150492406911 as standard_error, 2.5356178931070636 as t_statistic 8 | union all 9 | select 'x3' as variable_name, 21.895814737788875 as coefficient, 0.44810399169425286 as standard_error, 48.8632441210849 as t_statistic 10 | union all 11 | select 'x4' as variable_name, 3.4480236159406785 as coefficient, 0.1504072830205524 as standard_error, 22.92457882820424 as t_statistic 12 | union all 13 | select 'x5' as variable_name, 15.766951731565559 as coefficient, 0.37297028350495787 as standard_error, 42.274015997727524 as t_statistic 14 | 15 | ) 16 | 17 | select 18 | coalesce(base.variable_name, expected.variable_name) as variable_name, 19 | expected.coefficient as expected_coefficient, 20 | base.coefficient as actual_coefficient 21 | from {{ ref('collinear_matrix_5var_without_const') }} as base 22 | full outer join expected 23 | on base.variable_name = expected.variable_name 24 | where 25 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }} 26 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }} 27 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }} 28 | or base.coefficient is null 29 | or base.standard_error is null 30 | or base.t_statistic is null 31 | or expected.coefficient is null 32 | -------------------------------------------------------------------------------- /integration_tests/tests/test_groups_matrix_regression_fwl.sql: -------------------------------------------------------------------------------- 1 | /* Ridge regression coefficients do not match exactly. 2 | Instead, a threshold of no more than 0.01% deviation is enforced. 
*/ 3 | {% set THRESHOLD = 0.0001 %} 4 | with 5 | 6 | expected as ( 7 | 8 | select 'a' as gb_var, 'const' as variable_name, -0.06563066041472207 as coefficient 9 | union all 10 | select 'a' as gb_var, 'x1' as variable_name, 0.9905419281557593 as coefficient 11 | union all 12 | select 'a' as gb_var, 'x2' as variable_name, 4.948221700496285 as coefficient 13 | union all 14 | select 'a' as gb_var, 'x3' as variable_name, 0.031234030051974747 as coefficient 15 | union all 16 | select 'b' as gb_var, 'const' as variable_name, 2.0117130483709955 as coefficient 17 | union all 18 | select 'b' as gb_var, 'x1' as variable_name, 2.996331112245573 as coefficient 19 | union all 20 | select 'b' as gb_var, 'x2' as variable_name, 9.019683491736044 as coefficient 21 | union all 22 | select 'b' as gb_var, 'x3' as variable_name, 0.016151316166848173 as coefficient 23 | 24 | ) 25 | 26 | select 27 | coalesce(base.variable_name, expected.variable_name) as variable_name, 28 | expected.coefficient as expected_coefficient, 29 | base.coefficient as actual_coefficient 30 | from {{ ref('groups_matrix_regression_fwl') }} as base 31 | full outer join expected 32 | on 33 | base.gb_var = expected.gb_var 34 | and base.variable_name = expected.variable_name 35 | where 36 | {% if target.name == "clickhouse" %} 37 | {# This has poor precision for Clickhouse; not much I can do about it. 
#} 38 | abs(base.coefficient - expected.coefficient) > 0.1 39 | {% else %} 40 | abs(base.coefficient - expected.coefficient) > {{ THRESHOLD }} 41 | {% endif %} 42 | or base.coefficient is null 43 | or expected.coefficient is null 44 | -------------------------------------------------------------------------------- /run: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eo pipefail 4 | 5 | export DBT_PROFILES_DIR=./integration_tests/profiles 6 | export DBT_PROJECT_DIR=./integration_tests 7 | export DO_NOT_TRACK=1 8 | 9 | if [ -f .env ]; then 10 | # shellcheck disable=SC2002,SC2046 11 | export $(cat .env | xargs) 12 | fi 13 | 14 | function dbt { 15 | uv run dbt "${@}" 16 | } 17 | 18 | function setup { 19 | uv sync --all-extras 20 | uvx pre-commit install 21 | } 22 | 23 | function docker-run-clickhouse { 24 | docker run \ 25 | -d \ 26 | -p 8123:8123 \ 27 | --name dbt-linreg-clickhouse \ 28 | --ulimit nofile=262144:262144 \ 29 | clickhouse/clickhouse-server 30 | } 31 | 32 | function test { 33 | local target="${1-"duckdb"}" 34 | 35 | if [ -z "${GITHUB_ACTIONS}" ] && [ "${target}" = "postgres" ]; 36 | then 37 | createdb "${POSTGRES_DB-"dbt_linreg"}" || true 38 | fi 39 | 40 | if [ "${target}" = "clickhouse" ]; 41 | then 42 | docker-run-clickhouse || true 43 | fi 44 | 45 | if [ -z "${GITHUB_ACTIONS}" ] && [ "${target}" = "duckdb" ]; 46 | then 47 | rm -f dbt.duckdb 48 | fi 49 | 50 | uv run scripts.py gen-test-cases 51 | dbt deps --target "${target}" 52 | dbt seed --target "${target}" 53 | dbt run --target "${target}" --selector "${target}-selector" 54 | dbt test --target "${target}" --selector "${target}-selector" --store-failures 55 | } 56 | 57 | function lint { 58 | uv run pre-commit run -a 59 | } 60 | 61 | function docs:deploy { 62 | # shellcheck disable=SC2046 63 | uv run $(xargs -I{} echo --with {} < docs/requirements.txt) mkdocs gh-deploy -f docs/mkdocs.yml 64 | } 65 | 66 | function help { 67 | echo 
"$0 " 68 | echo "Tasks:" 69 | compgen -A function | cat -n 70 | } 71 | 72 | TIMEFORMAT=$'\nTask completed in %3lR' 73 | time "${@:-help}" 74 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | jobs: 10 | pre-commit: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v4 15 | - uses: pre-commit/action@v3.0.0 16 | integration-tests: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | dbt_core: [1.5.*, 1.8.*] 21 | db_target: [duckdb, postgres, clickhouse] 22 | services: 23 | postgres: 24 | image: ${{ (matrix.db_target == 'postgres') && 'postgres' || '' }} 25 | env: 26 | POSTGRES_USER: postgres 27 | POSTGRES_PASSWORD: postgres 28 | POSTGRES_DB: dbt_linreg 29 | ports: 30 | - 5432:5432 31 | options: >- 32 | --health-cmd pg_isready 33 | --health-interval 10s 34 | --health-timeout 5s 35 | --health-retries 5 36 | steps: 37 | - uses: actions/checkout@v4 38 | - name: Install uv 39 | uses: astral-sh/setup-uv@v5 40 | - name: Setup 41 | run: | 42 | sudo apt-get update 43 | sudo apt-get install 44 | chmod +x ./run 45 | uv venv 46 | uv sync --extra python-dev 47 | uv pip install -U "dbt-core==$DBT_CORE_VERSION" "dbt-${DBT_TARGET}==$DBT_CORE_VERSION" 48 | env: 49 | UV_NO_SYNC: true 50 | DO_NOT_TRACK: 1 51 | DBT_CORE_VERSION: ${{ matrix.dbt_core }} 52 | DBT_TARGET: ${{ matrix.db_target }} 53 | - name: Test 54 | run: ./run test "${DBT_TARGET}" 55 | env: 56 | UV_NO_SYNC: true 57 | DO_NOT_TRACK: 1 58 | DBT_TARGET: ${{ matrix.db_target }} 59 | POSTGRES_HOST: localhost 60 | POSTGRES_USER: postgres 61 | POSTGRES_PASSWORD: postgres 62 | POSTGRES_DB: dbt_linreg 63 | -------------------------------------------------------------------------------- 
/integration_tests/tests/test_collinear_matrix_regression_chol.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'const' as variable_name, 19.757104885315176 as coefficient, 2.992803142237603 as standard_error, 6.601538406078909 as t_statistic 6 | union all 7 | select 'x1' as variable_name, 9.90708767581426 as coefficient, 0.5692826957191374 as standard_error, 17.402755696445837 as t_statistic 8 | union all 9 | select 'x2' as variable_name, 6.187473206056227 as coefficient, 1.0880807259333622 as standard_error, 5.686593888287631 as t_statistic 10 | union all 11 | select 'x3' as variable_name, 19.66874583168642 as coefficient, 0.5601379212447676 as standard_error, 35.11411223146169 as t_statistic 12 | union all 13 | select 'x4' as variable_name, 3.7192417102253468 as coefficient, 0.15560940177101745 as standard_error, 23.901137514160553 as t_statistic 14 | union all 15 | select 'x5' as variable_name, 13.444273483323244 as coefficient, 0.5121595119107619 as standard_error, 26.250168493728488 as t_statistic 16 | 17 | ) 18 | 19 | select 20 | coalesce(base.variable_name, expected.variable_name) as variable_name, 21 | expected.coefficient as expected_coefficient, 22 | base.coefficient as actual_coefficient, 23 | expected.standard_error as expected_standard_error, 24 | base.standard_error as actual_standard_error, 25 | expected.t_statistic as expected_t_statistic, 26 | base.t_statistic as actual_t_statistic 27 | from {{ ref('collinear_matrix_regression_chol') }} as base 28 | full outer join expected 29 | on base.variable_name = expected.variable_name 30 | where 31 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }} 32 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }} 33 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }} 34 | or base.coefficient is 
null 35 | or base.standard_error is null 36 | or base.t_statistic is null 37 | or expected.coefficient is null 38 | -------------------------------------------------------------------------------- /integration_tests/tests/test_collinear_matrix_regression_chol_unoptimized.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'const' as variable_name, 19.757104885315176 as coefficient, 2.992803142237603 as standard_error, 6.601538406078909 as t_statistic 6 | union all 7 | select 'x1' as variable_name, 9.90708767581426 as coefficient, 0.5692826957191374 as standard_error, 17.402755696445837 as t_statistic 8 | union all 9 | select 'x2' as variable_name, 6.187473206056227 as coefficient, 1.0880807259333622 as standard_error, 5.686593888287631 as t_statistic 10 | union all 11 | select 'x3' as variable_name, 19.66874583168642 as coefficient, 0.5601379212447676 as standard_error, 35.11411223146169 as t_statistic 12 | union all 13 | select 'x4' as variable_name, 3.7192417102253468 as coefficient, 0.15560940177101745 as standard_error, 23.901137514160553 as t_statistic 14 | union all 15 | select 'x5' as variable_name, 13.444273483323244 as coefficient, 0.5121595119107619 as standard_error, 26.250168493728488 as t_statistic 16 | 17 | ) 18 | 19 | select 20 | coalesce(base.variable_name, expected.variable_name) as variable_name, 21 | expected.coefficient as expected_coefficient, 22 | base.coefficient as actual_coefficient, 23 | expected.standard_error as expected_standard_error, 24 | base.standard_error as actual_standard_error, 25 | expected.t_statistic as expected_t_statistic, 26 | base.t_statistic as actual_t_statistic 27 | from {{ ref('collinear_matrix_regression_chol_unoptimized') }} as base 28 | full outer join expected 29 | on base.variable_name = expected.variable_name 30 | where 31 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }} 32 | or 
abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }} 33 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }} 34 | or base.coefficient is null 35 | or base.standard_error is null 36 | or base.t_statistic is null 37 | or expected.coefficient is null 38 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ### `0.3.1` 4 | 5 | - Fix bug in `vars:` implementation of method options. 6 | 7 | ### `0.3.0` 8 | 9 | - Official support for Clickhouse! 10 | - Rename `format=` and `format_options=` to `output=` and `output_options=` to make the API consistent with **dbt_pca**. 11 | - Allow for setting method and output options globally with `vars:` 12 | 13 | ### `0.2.6` 14 | 15 | - Fix bug with `group_by` on multiple variables; contributed by [@svkohler](https://github.com/dwreeves/dbt_linreg/issues/21). 16 | 17 | ### `0.2.5` 18 | 19 | - Fix bug where `exog` and `group_by` did not handle `str` inputs e.g. `exog="x"`. 20 | - Fix bug where `group_by` for `method='fwl'` with exactly 1 exog variable did not work. (Explanation: `method='fwl'` dispatches to a different macro for the special case of 1 exog variable, and `group_by` was not implemented correctly here.) 21 | - Fix bug where `safe` mode did not work for `method='chol'`. 22 | - Improved docs by hiding everything except `ols()`, improved description of `ols()` macro, and added missing arg. 23 | 24 | ### `0.2.4` 25 | 26 | - Fix minor incompatibility with Redshift; contributed by [@steelcd](https://github.com/steelcd). 27 | 28 | ### `0.2.3` 29 | 30 | - Added Postgres support in integration tests + fixed bugs that prevented Postgres from working. 31 | 32 | ### `0.2.2` 33 | 34 | - Added dbt documentation of the `ols()` macro. 35 | 36 | ### `0.2.1` 37 | 38 | - Added `.dbtignore`. 
39 | 40 | ### `0.2.0` 41 | 42 | - Add `chol` method to `dbt_linreg.ols()`, and also set as the default method. (This method is significantly faster than `fwl`, and has a few other benefits.) 43 | - Add standard error column in `long` format for `chol` method. 44 | 45 | ### `0.1.2` 46 | 47 | - Added the ability to turn off/on the constant term with `add_constant: bool = True` kwarg. 48 | - Fixed error that occurred when rendering a 1-variable ridge regression. 49 | 50 | ### `0.1.1` 51 | 52 | - Fixed namespacing issue with CTEs-- all CTEs generated by `dbt_linreg` now start with `_dbt_linreg_`, to reduce namespace conflicts with user generated SQL. 53 | - Locked the dbt-core version requirement to `>=1.2.0` (for now) because one of this package's dependencies (`modules.itertools.combinations`) is not available prior to `1.2.0`. 54 | 55 | ### `0.1.0` 56 | 57 | - Initial release 58 | -------------------------------------------------------------------------------- /integration_tests/tests/test_groups_matrix_regression_chol_optimized.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'a' as gb_var, 'const' as variable_name, -0.06563066041472207 as coefficient, 0.053945103940799474 as standard_error, -1.2166194078844779 as t_statistic 6 | union all 7 | select 'a' as gb_var, 'x1' as variable_name, 0.9905419281557593 as coefficient, 0.015209571618398615 as standard_error, 65.12622136954383 as t_statistic 8 | union all 9 | select 'a' as gb_var, 'x2' as variable_name, 4.948221700496285 as coefficient, 0.02906881854690807 as standard_error, 170.2243829590593 as t_statistic 10 | union all 11 | select 'a' as gb_var, 'x3' as variable_name, 0.031234030051974747 as coefficient, 0.014337008978330493 as standard_error, 2.178559705108859 as t_statistic 12 | union all 13 | select 'b' as gb_var, 'const' as variable_name, 2.0117130483709955 as coefficient, 0.035587045398501334 as standard_error, 
56.529364150464545 as t_statistic 14 | union all 15 | select 'b' as gb_var, 'x1' as variable_name, 2.996331112245573 as coefficient, 0.006731681784764358 as standard_error, 445.1088462064698 as t_statistic 16 | union all 17 | select 'b' as gb_var, 'x2' as variable_name, 9.019683491736044 as coefficient, 0.008744674914389008 as standard_error, 1031.4486907791759 as t_statistic 18 | union all 19 | select 'b' as gb_var, 'x3' as variable_name, 0.016151316166848173 as coefficient, 0.0072206704541224525 as standard_error, 2.2368166875178472 as t_statistic 20 | 21 | ) 22 | 23 | select 24 | coalesce(base.variable_name, expected.variable_name) as variable_name, 25 | expected.coefficient as expected_coefficient, 26 | base.coefficient as actual_coefficient, 27 | expected.standard_error as expected_standard_error, 28 | base.standard_error as actual_standard_error, 29 | expected.t_statistic as expected_t_statistic, 30 | base.t_statistic as actual_t_statistic 31 | from {{ ref('groups_matrix_regression_chol_optimized') }} as base 32 | full outer join expected 33 | on 34 | base.gb_var = expected.gb_var 35 | and base.variable_name = expected.variable_name 36 | where 37 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") }} 38 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }} 39 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }} 40 | or base.coefficient is null 41 | or base.standard_error is null 42 | or base.t_statistic is null 43 | or expected.coefficient is null 44 | -------------------------------------------------------------------------------- /integration_tests/tests/test_groups_matrix_regression_chol_unoptimized.sql: -------------------------------------------------------------------------------- 1 | with 2 | 3 | expected as ( 4 | 5 | select 'a' as gb_var, 'const' as variable_name, -0.06563066041472207 as coefficient, 
0.053945103940799474 as standard_error, -1.2166194078844779 as t_statistic 6 | union all 7 | select 'a' as gb_var, 'x1' as variable_name, 0.9905419281557593 as coefficient, 0.015209571618398615 as standard_error, 65.12622136954383 as t_statistic 8 | union all 9 | select 'a' as gb_var, 'x2' as variable_name, 4.948221700496285 as coefficient, 0.02906881854690807 as standard_error, 170.2243829590593 as t_statistic 10 | union all 11 | select 'a' as gb_var, 'x3' as variable_name, 0.031234030051974747 as coefficient, 0.014337008978330493 as standard_error, 2.178559705108859 as t_statistic 12 | union all 13 | select 'b' as gb_var, 'const' as variable_name, 2.0117130483709955 as coefficient, 0.035587045398501334 as standard_error, 56.529364150464545 as t_statistic 14 | union all 15 | select 'b' as gb_var, 'x1' as variable_name, 2.996331112245573 as coefficient, 0.006731681784764358 as standard_error, 445.1088462064698 as t_statistic 16 | union all 17 | select 'b' as gb_var, 'x2' as variable_name, 9.019683491736044 as coefficient, 0.008744674914389008 as standard_error, 1031.4486907791759 as t_statistic 18 | union all 19 | select 'b' as gb_var, 'x3' as variable_name, 0.016151316166848173 as coefficient, 0.0072206704541224525 as standard_error, 2.2368166875178472 as t_statistic 20 | 21 | ) 22 | 23 | select 24 | coalesce(base.variable_name, expected.variable_name) as variable_name, 25 | expected.coefficient as expected_coefficient, 26 | base.coefficient as actual_coefficient, 27 | expected.standard_error as expected_standard_error, 28 | base.standard_error as actual_standard_error, 29 | expected.t_statistic as expected_t_statistic, 30 | base.t_statistic as actual_t_statistic 31 | from {{ ref('groups_matrix_regression_chol_unoptimized') }} as base 32 | full outer join expected 33 | on 34 | base.gb_var = expected.gb_var 35 | and base.variable_name = expected.variable_name 36 | where 37 | abs(base.coefficient - expected.coefficient) > {{ var("_test_precision_collinear_matrix") 
}} 38 | or abs(base.standard_error - expected.standard_error) > {{ var("_test_precision_collinear_matrix") }} 39 | or abs(base.t_statistic - expected.t_statistic) > {{ var("_test_precision_collinear_matrix") }} 40 | or base.coefficient is null 41 | or base.standard_error is null 42 | or base.t_statistic is null 43 | or expected.coefficient is null 44 | -------------------------------------------------------------------------------- /macros/linear_regression/ols_impl_special/_ols_1var.sql: -------------------------------------------------------------------------------- 1 | {% macro _ols_1var(table, 2 | endog, 3 | exog, 4 | add_constant=True, 5 | output=None, 6 | output_options=None, 7 | group_by=None, 8 | alpha=None) -%} 9 | {%- set exog_aliased = ['x1'] %} 10 | (with 11 | {%- if alpha %} 12 | _dbt_linreg_cmeans as ( 13 | select 14 | {{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }} 15 | avg({{ endog }}) as y, 16 | avg({{ exog[0] }}) as x1, 17 | count(*) as ct 18 | from 19 | {{ table }} 20 | {%- if group_by %} 21 | group by 22 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }} 23 | {%- endif %} 24 | ), 25 | {%- endif %} 26 | _dbt_linreg_base as ( 27 | select 28 | {{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }} 29 | {%- if alpha and add_constant %} 30 | b.{{ endog }} - _dbt_linreg_cmeans.y as y, 31 | b.{{ exog[0] }} - _dbt_linreg_cmeans.x1 as x1, 32 | {%- else %} 33 | {{ endog }} as y, 34 | b.{{ exog[0] }} as x1, 35 | {%- endif %} 36 | false as fake 37 | from 38 | {{ table }} as b 39 | {%- if alpha %} 40 | {%- if add_constant %} 41 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_cmeans') | indent(2) }} 42 | {%- endif %} 43 | union all 44 | select 45 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }} 46 | 0 as y, 47 | pow(x1 * ct, 0.5) as x1, 48 | true as fake 49 | from _dbt_linreg_cmeans 50 | {%- endif %} 51 | ), 52 | _dbt_linreg_final_coefs as ( 53 | select 54 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | 
indent(4) }} 55 | {%- if add_constant %} 56 | avg({{ dbt_linreg._filter_and_center_if_alpha('b.y', alpha) }}) 57 | - avg({{ dbt_linreg._filter_and_center_if_alpha('b.x1', alpha) }}) * {{ dbt_linreg.regress('b.y', 'b.x1') }} 58 | as x0_coef, 59 | {%- endif %} 60 | {{ dbt_linreg.regress('b.y', 'b.x1', add_constant=add_constant) }} as x1_coef 61 | from _dbt_linreg_base as b 62 | {%- if alpha and add_constant %} 63 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_cmeans') | indent(2) }} 64 | {%- endif %} 65 | {%- if group_by %} 66 | group by 67 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }} 68 | {%- endif %} 69 | ) 70 | {{ 71 | dbt_linreg.final_select( 72 | exog=exog, 73 | exog_aliased=['x1'], 74 | add_constant=add_constant, 75 | group_by=group_by, 76 | output=output, 77 | output_options=output_options, 78 | calculate_standard_error=False 79 | ) 80 | }} 81 | ) 82 | {%- endmacro %} 83 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | integration_tests/package-lock.yml 2 | dbt.duckdb 3 | dbt.duckdb.wal 4 | .user.yml 5 | docs/site/ 6 | integration_tests/seeds/*.csv 7 | integration_tests/package-lock.yml 8 | 9 | # Created by https://www.toptal.com/developers/gitignore/api/dbt,osx,windows,visualstudiocode,python 10 | # Edit at https://www.toptal.com/developers/gitignore?templates=dbt,osx,windows,visualstudiocode,python 11 | 12 | ### dbt ### 13 | target/ 14 | dbt_modules/ 15 | dbt_packages/ 16 | logs/ 17 | 18 | ### OSX ### 19 | # General 20 | .DS_Store 21 | .AppleDouble 22 | .LSOverride 23 | 24 | # Icon must end with two \r 25 | Icon 26 | 27 | 28 | # Thumbnails 29 | ._* 30 | 31 | # Files that might appear in the root of a volume 32 | .DocumentRevisions-V100 33 | .fseventsd 34 | .Spotlight-V100 35 | .TemporaryItems 36 | .Trashes 37 | .VolumeIcon.icns 38 | .com.apple.timemachine.donotpresent 39 | 40 | # Directories potentially created
on remote AFP share 41 | .AppleDB 42 | .AppleDesktop 43 | Network Trash Folder 44 | Temporary Items 45 | .apdisk 46 | 47 | ### Python ### 48 | # Byte-compiled / optimized / DLL files 49 | __pycache__/ 50 | *.py[cod] 51 | *$py.class 52 | 53 | # C extensions 54 | *.so 55 | 56 | # Distribution / packaging 57 | .Python 58 | build/ 59 | develop-eggs/ 60 | dist/ 61 | downloads/ 62 | eggs/ 63 | .eggs/ 64 | lib/ 65 | lib64/ 66 | parts/ 67 | sdist/ 68 | var/ 69 | wheels/ 70 | share/python-wheels/ 71 | *.egg-info/ 72 | .installed.cfg 73 | *.egg 74 | MANIFEST 75 | 76 | # PyInstaller 77 | # Usually these files are written by a python script from a template 78 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 79 | *.manifest 80 | *.spec 81 | 82 | # Installer logs 83 | pip-log.txt 84 | pip-delete-this-directory.txt 85 | 86 | # Unit test / coverage reports 87 | htmlcov/ 88 | .tox/ 89 | .nox/ 90 | .coverage 91 | .coverage.* 92 | .cache 93 | nosetests.xml 94 | coverage.xml 95 | *.cover 96 | *.py,cover 97 | .hypothesis/ 98 | .pytest_cache/ 99 | cover/ 100 | 101 | # Translations 102 | *.mo 103 | *.pot 104 | 105 | # Django stuff: 106 | *.log 107 | local_settings.py 108 | db.sqlite3 109 | db.sqlite3-journal 110 | 111 | # Flask stuff: 112 | instance/ 113 | .webassets-cache 114 | 115 | # Scrapy stuff: 116 | .scrapy 117 | 118 | # Sphinx documentation 119 | docs/_build/ 120 | 121 | # PyBuilder 122 | .pybuilder/ 123 | 124 | # Jupyter Notebook 125 | .ipynb_checkpoints 126 | 127 | # IPython 128 | profile_default/ 129 | ipython_config.py 130 | 131 | # pyenv 132 | # For a library or package, you might want to ignore these files since the code is 133 | # intended to run in multiple environments; otherwise, check them in: 134 | # .python-version 135 | 136 | # pipenv 137 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
138 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 139 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 140 | # install all needed dependencies. 141 | #Pipfile.lock 142 | 143 | # poetry 144 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 145 | # This is especially recommended for binary packages to ensure reproducibility, and is more 146 | # commonly ignored for libraries. 147 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 148 | #poetry.lock 149 | 150 | # pdm 151 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 152 | #pdm.lock 153 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 154 | # in version control. 155 | # https://pdm.fming.dev/#use-with-ide 156 | .pdm.toml 157 | 158 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 159 | __pypackages__/ 160 | 161 | # Celery stuff 162 | celerybeat-schedule 163 | celerybeat.pid 164 | 165 | # SageMath parsed files 166 | *.sage.py 167 | 168 | # Environments 169 | .env 170 | .venv 171 | env/ 172 | venv/ 173 | ENV/ 174 | env.bak/ 175 | venv.bak/ 176 | 177 | # Spyder project settings 178 | .spyderproject 179 | .spyproject 180 | 181 | # Rope project settings 182 | .ropeproject 183 | 184 | # mkdocs documentation 185 | /site 186 | 187 | # mypy 188 | .mypy_cache/ 189 | .dmypy.json 190 | dmypy.json 191 | 192 | # Pyre type checker 193 | .pyre/ 194 | 195 | # pytype static type analyzer 196 | .pytype/ 197 | 198 | # Cython debug symbols 199 | cython_debug/ 200 | 201 | # PyCharm 202 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 203 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 204 | # and can be added to the global gitignore or 
merged into this file. For a more nuclear 205 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 206 | #.idea/ 207 | 208 | ### Python Patch ### 209 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 210 | poetry.toml 211 | 212 | # ruff 213 | .ruff_cache/ 214 | 215 | ### VisualStudioCode ### 216 | .vscode/* 217 | !.vscode/settings.json 218 | !.vscode/tasks.json 219 | !.vscode/launch.json 220 | !.vscode/extensions.json 221 | !.vscode/*.code-snippets 222 | 223 | # Local History for Visual Studio Code 224 | .history/ 225 | 226 | # Built Visual Studio Code Extensions 227 | *.vsix 228 | 229 | ### VisualStudioCode Patch ### 230 | # Ignore all local history of files 231 | .history 232 | .ionide 233 | 234 | ### Windows ### 235 | # Windows thumbnail cache files 236 | Thumbs.db 237 | Thumbs.db:encryptable 238 | ehthumbs.db 239 | ehthumbs_vista.db 240 | 241 | # Dump file 242 | *.stackdump 243 | 244 | # Folder config file 245 | [Dd]esktop.ini 246 | 247 | # Recycle Bin used on file shares 248 | $RECYCLE.BIN/ 249 | 250 | # Windows Installer files 251 | *.cab 252 | *.msi 253 | *.msix 254 | *.msm 255 | *.msp 256 | 257 | # Windows shortcuts 258 | *.lnk 259 | 260 | # End of https://www.toptal.com/developers/gitignore/api/dbt,osx,windows,visualstudiocode,python 261 | -------------------------------------------------------------------------------- /macros/linear_regression/ols.sql: -------------------------------------------------------------------------------- 1 | {% macro ols(table, 2 | endog=none, 3 | exog=none, 4 | x=none, 5 | y=none, 6 | add_constant=true, 7 | output='wide', 8 | output_options=none, 9 | format=none, 10 | format_options=none, 11 | group_by=none, 12 | alpha=none, 13 | method=none, 14 | method_options=none) -%} 15 | 16 | {############################################################################# 17 | 18 | This function does 3 things: 19 | 20 | 1. 
Resolves and casts polymorphic inputs. 21 | 2. Validates inputs. 22 | 3. Dispatches the appropriate call. 23 | 24 | The actual calculations occur elsewhere in the code, depending on the 25 | implementation chosen. 26 | 27 | #############################################################################} 28 | 29 | {# Format the variables, and cast strings to lists #} 30 | {# ----------------------------------------------- #} 31 | 32 | {% if x is not none and exog is none %} 33 | {% set exog = x %} 34 | {% elif x is not none and exog is not none %} 35 | {{ exceptions.raise_compiler_error( 36 | "Please specify either `exog` (preferred) or `x`, not both." 37 | " `x` is just an alias for `exog`." 38 | ) }} 39 | {% endif %} 40 | 41 | {% if format_options is not none and output_options is not none %} 42 | {{ exceptions.raise_compiler_error( 43 | "`format_options` is deprecated and is another name for `output_options`." 44 | " Please only set the `output_options`." 45 | ) }} 46 | {% endif %} 47 | 48 | 49 | {% if format is not none and output is not none %} 50 | {{ exceptions.raise_compiler_error( 51 | "`format` is deprecated and is another name for `output`." 52 | " Please only set the `output`." 53 | ) }} 54 | {% elif format is not none and output is none %} 55 | {% set output = format %} 56 | {% endif %} 57 | 58 | {% if output_options is none %} 59 | {% if format_options is not none %} 60 | {% set output_options = format_options %} 61 | {% else %} 62 | {% set output_options = {} %} 63 | {% endif %} 64 | {% endif %} 65 | 66 | {% if method_options is none %} 67 | {% set method_options = {} %} 68 | {% endif %} 69 | 70 | {% if y is not none and endog is none %} 71 | {% set endog = y %} 72 | {% elif y is not none and endog is not none %} 73 | {{ exceptions.raise_compiler_error( 74 | "Please specify either `endog` (preferred) or `y`, not both." 75 | " `y` is just an alias for `endog`."
76 | ) }} 77 | {% endif %} 78 | 79 | {% if exog is not iterable %} 80 | {% if exog is none %} 81 | {% set exog = [] %} 82 | {% else %} 83 | {% set exog = [exog] %} 84 | {% endif %} 85 | {% elif exog is string %} 86 | {% set exog = [exog] %} 87 | {% endif %} 88 | 89 | {% if group_by is not iterable %} 90 | {% if group_by is none %} 91 | {% set group_by = [] %} 92 | {% else %} 93 | {% set group_by = [group_by] %} 94 | {% endif %} 95 | {% elif group_by is string %} 96 | {% set group_by = [group_by] %} 97 | {% endif %} 98 | 99 | {% if alpha is not iterable and alpha is not none %} 100 | {% set alpha = [alpha] * (exog | length) %} 101 | {% endif %} 102 | 103 | {% if method is none %} 104 | {% set method = 'chol' %} 105 | {% endif %} 106 | 107 | {# Check for user input errors #} 108 | {# --------------------------- #} 109 | 110 | {% if endog is none %} 111 | {{ exceptions.raise_compiler_error( 112 | "`endog` is not allowed to be None." 113 | " Please specify a target variable / y-variable / endogenous variable for" 114 | " the linear regression." 115 | ) }} 116 | {% endif %} 117 | 118 | {% if not exog and not add_constant %} 119 | {{ exceptions.raise_compiler_error( 120 | "Cannot run dbt_linreg.ols() because there are no exogenous variables" 121 | " / features to regress!" 122 | ) }} 123 | {% endif %} 124 | 125 | {% for i in range((exog | length)) %} 126 | {% for j in range(i, (exog | length)) %} 127 | {% if i != j %} 128 | {% if exog[i] == exog[j] %} 129 | {% if not alpha %} 130 | {{ exceptions.raise_compiler_error( 131 | "Duplicate variables are not allowed without regularization, as" 132 | " that will lead to multicollinearity. Duplicate variable is: " 133 | ~ exog[i] ~ ", which occurs at positions " ~ i ~ " and " ~ j ~ "." 134 | ) }} 135 | {% else %} 136 | {% do log( 137 | "Note: exog variable " ~ exog[i] ~ " is duplicated at positions " 138 | ~ i ~ " and " ~ j ~ "." 
139 | ) %} 140 | {% endif %} 141 | {% endif %} 142 | {% endif %} 143 | {% endfor %} 144 | {% endfor %} 145 | 146 | {% if format is not none and format not in ['wide', 'long'] %} 147 | {{ exceptions.raise_compiler_error( 148 | "Format must be either 'wide' or 'long'; received " ~ output ~ "." 149 | " Also, `format=` is deprecated; it is suggested you use `output=` instead." 150 | ) }} 151 | {% endif %} 152 | 153 | {% if output not in ['wide', 'long'] %} 154 | {{ exceptions.raise_compiler_error( 155 | "Output must be either 'wide' or 'long'; received " ~ output ~ "." 156 | ) }} 157 | {% endif %} 158 | 159 | {% if alpha is not none and (alpha | length) != (exog | length) %} 160 | {{ exceptions.raise_compiler_error( 161 | "The number of values passed in `alpha` must be equivalent to" 162 | " the number of columns in `exog`." 163 | " Received " ~ (exog | length) ~ " exog variables and " ~ (alpha | length) ~ 164 | " alpha parameters." 165 | " Note that the constant term cannot be penalized." 166 | ) }} 167 | {% endif %} 168 | 169 | {% if method == 'chol' %} 170 | {{ return( 171 | dbt_linreg._ols_chol( 172 | table=table, 173 | endog=endog, 174 | exog=exog, 175 | add_constant=add_constant, 176 | output=output, 177 | output_options=output_options, 178 | group_by=group_by, 179 | alpha=alpha, 180 | method_options=method_options 181 | ) 182 | ) }} 183 | {% elif method == 'fwl' %} 184 | {{ return( 185 | dbt_linreg._ols_fwl( 186 | table=table, 187 | endog=endog, 188 | exog=exog, 189 | add_constant=add_constant, 190 | output=output, 191 | output_options=output_options, 192 | group_by=group_by, 193 | alpha=alpha, 194 | method_options=method_options 195 | ) 196 | ) }} 197 | {% else %} 198 | {{ exceptions.raise_compiler_error( 199 | "Invalid method specified. 
The only valid methods are 'chol' and 'fwl'" 200 | ) }} 201 | {% endif %} 202 | 203 | {% endmacro %} 204 | -------------------------------------------------------------------------------- /macros/schema.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | macros: 4 | - name: ols 5 | description: |- 6 | **dbt_linreg** is an easy way to perform linear regression and ridge regression in SQL with OLS. 7 | 8 | The `dbt_linreg.ols()` macro is the core, high-level API for the **dbt_linreg** package. This macro will calculate and output the coefficients of a linear regression specified by the user. The regression can also be L2 regularized using the `alpha` argument, i.e. ridge regression is also supported. 9 | 10 | Here is an example of a dbt model that selects from a dbt model called `simple_matrix`, and runs a regression on `y` using feature columns `xa`, `xb`, and `xc`: 11 | 12 | {% raw %} 13 | ```sql 14 | {{ 15 | config( 16 | materialized="table" 17 | ) 18 | }} 19 | select * from {{ 20 | dbt_linreg.ols( 21 | table=ref('simple_matrix'), 22 | endog='y', 23 | exog=['xa', 'xb', 'xc'], 24 | output='long', 25 | output_options={'round': 5} 26 | ) 27 | }} 28 | ``` 29 | {% endraw %} 30 | 31 | You may also select from a CTE; in this case, just pass a string referencing the CTE: 32 | 33 | {% raw %} 34 | ```sql 35 | {{ 36 | config( 37 | materialized="table" 38 | ) 39 | }} 40 | with my_data as ( 41 | select * from {{ ref('simple_matrix') }} 42 | ) 43 | select * from {{ 44 | dbt_linreg.ols( 45 | table='my_data', 46 | endog='y', 47 | exog=['xa', 'xb', 'xc'], 48 | output='long', 49 | output_options={'round': 5} 50 | ) 51 | }} 52 | ``` 53 | {% endraw %} 54 | 55 | The macro renders a subquery, inclusive of parentheses. 
56 | 57 | Please see the README / full documentation for more information: [https://dwreeves.github.io/dbt_linreg/](https://dwreeves.github.io/dbt_linreg/) 58 | arguments: 59 | - name: table 60 | type: string 61 | description: Name of table or CTE to pull the data from. You can use `ref()` or `source()` here if you'd like. 62 | - name: endog 63 | type: string 64 | description: The endogenous variable / y variable / target variable of the regression. (You can also specify `y=...` instead of `endog=...` if you prefer.) 65 | - name: exog 66 | type: string or list of strings 67 | description: The exogenous variables / X variables / features of the regression. (You can also specify `x=...` instead of `exog=...` if you prefer.) 68 | - name: add_constant 69 | type: boolean 70 | description: 'If true, a constant term is added automatically to the regression. (Default: `true`)' 71 | - name: output 72 | type: string 73 | description: |- 74 | Either "wide" or "long" format for coefficients. See **Outputs and output options** in the README for more. 75 | - If `wide`, the variables span the columns with their original variable names, and the coefficients fill a single row. 76 | - If `long`, the coefficients are in a single column called `coefficient`, and the variable names are in a single column called `variable_name`. 77 | - name: output_options 78 | type: dict 79 | description: See **Outputs and output options** section in the README for more. 80 | - name: group_by 81 | type: string or list of strings 82 | description: If specified, the regression will be grouped by these variables, and individual regressions will run on each group. 83 | - name: alpha 84 | type: number or list of numbers 85 | description: If not null, the regression will be run as a ridge regression with a penalty of `alpha`. See **Notes** section in the README for more information. 86 | - name: method 87 | type: string 88 | description: The method used to calculate the regression.
Only `chol` and `fwl` are valid inputs for now. See **Methods and method options** in the README for more. 89 | - name: method_options 90 | type: dict 91 | description: Options specific to the estimation method. See **Methods and method options** in the README for more. 92 | # Everything down here is just for intermediary calculations or helper functions. 93 | # There is no point to showing these in the docs. 94 | # The truly curious can just look at the source code. 95 | # 96 | # Please generate the below with the following command: 97 | # >>> python scripts.py gen-hide-macros-yaml 98 | - name: _alias_exog 99 | docs: 100 | show: false 101 | - name: _alias_gb_cols 102 | docs: 103 | show: false 104 | - name: _cell_or_alias 105 | docs: 106 | show: false 107 | - name: _cholesky_decomposition 108 | docs: 109 | show: false 110 | - name: _filter_and_center_if_alpha 111 | docs: 112 | show: false 113 | - name: _filter_if_alpha 114 | docs: 115 | show: false 116 | - name: _format_wide_variable_column 117 | docs: 118 | show: false 119 | - name: _forward_substitution 120 | docs: 121 | show: false 122 | - name: _gb_cols 123 | docs: 124 | show: false 125 | - name: _get_method_option 126 | docs: 127 | show: false 128 | - name: _get_output_option 129 | docs: 130 | show: false 131 | - name: _join_on_groups 132 | docs: 133 | show: false 134 | - name: _maybe_round 135 | docs: 136 | show: false 137 | - name: _ols_0var 138 | docs: 139 | show: false 140 | - name: _ols_1var 141 | docs: 142 | show: false 143 | - name: _ols_chol 144 | docs: 145 | show: false 146 | - name: _ols_fwl 147 | docs: 148 | show: false 149 | - name: _orth_x_intercept 150 | docs: 151 | show: false 152 | - name: _orth_x_slope 153 | docs: 154 | show: false 155 | - name: _regress_or_alias 156 | docs: 157 | show: false 158 | - name: _safe_sqrt 159 | docs: 160 | show: false 161 | - name: _strip_quotes 162 | docs: 163 | show: false 164 | - name: _traverse_intercepts 165 | docs: 166 | show: false 167 | - name: 
_traverse_slopes 168 | docs: 169 | show: false 170 | - name: _unalias_gb_cols 171 | docs: 172 | show: false 173 | - name: bigquery___safe_sqrt 174 | docs: 175 | show: false 176 | - name: clickhouse___cell_or_alias 177 | docs: 178 | show: false 179 | - name: default___cell_or_alias 180 | docs: 181 | show: false 182 | - name: default___maybe_round 183 | docs: 184 | show: false 185 | - name: default___regress_or_alias 186 | docs: 187 | show: false 188 | - name: default___safe_sqrt 189 | docs: 190 | show: false 191 | - name: default__regress 192 | docs: 193 | show: false 194 | - name: duckdb___cell_or_alias 195 | docs: 196 | show: false 197 | - name: duckdb___regress_or_alias 198 | docs: 199 | show: false 200 | - name: final_select 201 | docs: 202 | show: false 203 | - name: postgres___maybe_round 204 | docs: 205 | show: false 206 | - name: redshift___maybe_round 207 | docs: 208 | show: false 209 | - name: regress 210 | docs: 211 | show: false 212 | - name: snowflake___cell_or_alias 213 | docs: 214 | show: false 215 | - name: snowflake___regress_or_alias 216 | docs: 217 | show: false 218 | - name: snowflake__regress 219 | docs: 220 | show: false 221 | -------------------------------------------------------------------------------- /macros/linear_regression/ols_impl_fwl/_ols_impl_fwl.sql: -------------------------------------------------------------------------------- 1 | {% macro _traverse_slopes(step, x) -%} 2 | {%- set li = [] %} 3 | {%- for i in x %} 4 | {%- for j in x %} 5 | {%- set remaining = x.copy() %} 6 | {%- if i != j %} 7 | {%- do remaining.remove(i) %} 8 | {%- do remaining.remove(j) %} 9 | {%- set comb = modules.itertools.combinations(remaining, step - 1) %} 10 | {%- for c in comb %} 11 | {%- do li.append((i, j, c)) %} 12 | {%- endfor %} 13 | {%- endif %} 14 | {%- endfor %} 15 | {%- endfor %} 16 | {{ return(li) }} 17 | {% endmacro %} 18 | 19 | {% macro _traverse_intercepts(step, x) -%} 20 | {%- set li = [] %} 21 | {%- for i in x %} 22 | {%- set remaining = 
x.copy() %} 23 | {%- do remaining.remove(i) %} 24 | {%- set comb = modules.itertools.combinations(remaining, step) %} 25 | {%- for c in comb %} 26 | {%- set ortho = [] %} 27 | {%- if c %} 28 | {%- for b in c %} 29 | {%- set _c = (c | list) %} 30 | {%- do _c.remove(b) %} 31 | {%- do ortho.append([b] + (modules.itertools.combinations(_c, step - 1) | list)) %} 32 | {%- endfor %} 33 | {%- endif %} 34 | {%- do li.append((i, ortho)) %} 35 | {%- endfor %} 36 | {%- endfor %} 37 | {{ return(li) }} 38 | {% endmacro %} 39 | 40 | {% macro _filter_if_alpha(i, alpha) %} 41 | {% if alpha %} 42 | {{ return('case when not fake then ' ~ i ~ ' end') }} 43 | {% else %} 44 | {{ return(i) }} 45 | {% endif %} 46 | {% endmacro %} 47 | 48 | {% macro _filter_and_center_if_alpha(i, alpha, base_prefix='') %} 49 | {% if alpha %} 50 | {{ return('case when not fake then ' ~ base_prefix ~ i ~ ' + _dbt_linreg_cmeans.' ~ i ~ ' end') }} 51 | {% else %} 52 | {{ return(i) }} 53 | {% endif %} 54 | {% endmacro %} 55 | 56 | {% macro _orth_x_slope(x, o) -%} 57 | {%- if o %} 58 | {{ return(x ~ '_' ~ (o | join(''))) }} 59 | {%- else %} 60 | {{ return(x) }} 61 | {%- endif %} 62 | {% endmacro %} 63 | 64 | {% macro _orth_x_intercept(x, o) -%} 65 | {%- set li = [] %} 66 | {%- for c in o %} 67 | {%- do li.append(c[0]) %} 68 | {%- endfor %} 69 | {{ return(x ~ '_' ~ (li | join(''))) }} 70 | {% endmacro %} 71 | 72 | {% macro _regress_or_alias(y, x, add_constant=True) %} 73 | {{ return( 74 | adapter.dispatch('_regress_or_alias', 'dbt_linreg') 75 | (y, x, add_constant=add_constant) 76 | ) }} 77 | {% endmacro %} 78 | 79 | {# In some but not all query engines, you can select from other columns. 80 | Doing this keeps the compiled SQL cleaner, and for large regressions can 81 | slightly improve the query planner speed (albeit not the execution). 
#} 82 | {% macro default___regress_or_alias(y, x, add_constant=True) %} 83 | {{ return(dbt_linreg.regress(y, x, add_constant=add_constant)) }} 84 | {% endmacro %} 85 | 86 | {% macro snowflake___regress_or_alias(y, x, add_constant=True) %} 87 | {{ return(y ~ '_' ~ x ~ '_coef') }} 88 | {% endmacro %} 89 | 90 | {% macro duckdb___regress_or_alias(y, x, add_constant=True) %} 91 | {{ return(y ~ '_' ~ x ~ '_coef') }} 92 | {% endmacro %} 93 | 94 | 95 | {% macro _ols_fwl(table, 96 | endog, 97 | exog, 98 | add_constant=True, 99 | output=None, 100 | output_options=None, 101 | group_by=None, 102 | alpha=None, 103 | method_options=None) -%} 104 | {%- if (exog | length) == 0 %} 105 | {% do log('Note: exog was empty; running regression on constant term only.') %} 106 | {{ return(dbt_linreg._ols_0var( 107 | table=table, 108 | endog=endog, 109 | exog=exog, 110 | add_constant=add_constant, 111 | output=output, 112 | output_options=output_options, 113 | group_by=group_by, 114 | alpha=alpha 115 | )) }} 116 | {%- elif (exog | length) == 1 %} 117 | {{ return(dbt_linreg._ols_1var( 118 | table=table, 119 | endog=endog, 120 | exog=exog, 121 | add_constant=add_constant, 122 | output=output, 123 | output_options=output_options, 124 | group_by=group_by, 125 | alpha=alpha 126 | )) }} 127 | {%- endif %} 128 | {%- set exog_aliased = dbt_linreg._alias_exog(exog) %} 129 | (with 130 | {%- if alpha %} 131 | _dbt_linreg_cmeans as ( 132 | select 133 | {{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }} 134 | avg({{ endog }}) as y, 135 | {%- for i in exog_aliased %} 136 | avg({{ exog[loop.index0] }}) as {{ i }}, 137 | {%- endfor %} 138 | count(*) as ct 139 | from 140 | {{ table }} 141 | {%- if group_by %} 142 | group by 143 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }} 144 | {%- endif %} 145 | ), 146 | {%- endif %} 147 | _dbt_linreg_step0 as ( 148 | select 149 | {{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }} 150 | {%- if alpha and add_constant %} 151 | b.{{ endog }} - _dbt_linreg_cmeans.y 
as y, 152 | {%- for i in exog_aliased %} 153 | b.{{ exog[loop.index0] }} - _dbt_linreg_cmeans.{{ i }} as {{ i }}, 154 | {%- endfor %} 155 | {%- else %} 156 | {{ endog }} as y, 157 | {%- for i in exog_aliased %} 158 | b.{{ exog[loop.index0] }} as {{ i }}, 159 | {%- endfor %} 160 | {%- endif %} 161 | false as fake 162 | from 163 | {{ table }} as b 164 | {%- if alpha %} 165 | {%- if add_constant %} 166 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_cmeans') | indent(2) }} 167 | {%- endif %} 168 | {%- for i in exog_aliased %} 169 | {%- set i_idx = loop.index0 %} 170 | union all 171 | select 172 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }} 173 | 0 as y, 174 | {%- for j in exog_aliased %} 175 | {%- if i == j %} 176 | pow({{ alpha[i_idx] }} * ct, 0.5) as {{ j }}, 177 | {%- else %} 178 | 0 as {{ j }}, 179 | {%- endif %} 180 | {%- endfor %} 181 | true as fake 182 | from _dbt_linreg_cmeans as cmeans 183 | {%- endfor %} 184 | {%- endif %} 185 | ), 186 | {% for step in range(1, (exog | length)) %} 187 | _dbt_linreg_step{{ step }} as ( 188 | with 189 | __dbt_linreg_coefs{{ step }} as ( 190 | select 191 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(6) }} 192 | {#- Slope terms #} 193 | {%- for _y, _x, _o in dbt_linreg._traverse_slopes(step, exog_aliased) %} 194 | {%- set _c = dbt_linreg._orth_x_slope(_x, _o) %} 195 | {{ dbt_linreg.regress(_y, _c, add_constant=add_constant) }} as {{ _y }}_{{ _c }}_coef, 196 | {%- endfor %} 197 | {#- Constant terms #} 198 | {%- if add_constant %} 199 | {%- for _y, _o in dbt_linreg._traverse_intercepts(step, exog_aliased) %} 200 | avg({{ dbt_linreg._filter_if_alpha(_y, alpha) }}) 201 | {%- for _yi, _xi in _o %} 202 | {%- set _ci = dbt_linreg._orth_x_slope(_yi, _xi) %} 203 | - avg({{ dbt_linreg._filter_if_alpha(_yi, alpha) }}) * {{ dbt_linreg._regress_or_alias(_y, _ci) }} 204 | {%- endfor %} 205 | as {{ dbt_linreg._orth_x_intercept(_y, _o) }}_const 206 | {%- if not loop.last -%} 207 | , 208 | 
{%- endif -%} 209 | {%- endfor %} 210 | {%- endif %} 211 | from _dbt_linreg_step{{ step - 1 }} 212 | {%- if group_by %} 213 | group by 214 | {{ dbt_linreg._gb_cols(group_by) | indent(6) }} 215 | {%- endif %} 216 | ) 217 | select 218 | {{ dbt_linreg._gb_cols(group_by, prefix='b', trailing_comma=True) | indent(4) }} 219 | y, 220 | {%- for i in exog_aliased %} 221 | {{ i }}, 222 | {%- endfor %} 223 | fake, 224 | {%- for _y, _o in dbt_linreg._traverse_intercepts(step, exog_aliased) %} 225 | {{ _y }} 226 | {%- for _yi, _xi in _o %} 227 | {%- set _ci = dbt_linreg._orth_x_slope(_yi, _xi) %} 228 | - {{ _y }}_{{ _ci }}_coef * {{ _yi }} 229 | {%- endfor %} 230 | {%- set _c = dbt_linreg._orth_x_intercept(_y, _o) %} 231 | {%- if add_constant %} 232 | - {{ _c }}_const 233 | {%- endif %} 234 | as {{ _c }} 235 | {%- if not loop.last -%} 236 | , 237 | {%- endif %} 238 | {%- endfor %} 239 | from _dbt_linreg_step0 as b 240 | {{ dbt_linreg._join_on_groups(group_by, 'b', '__dbt_linreg_coefs'~step) | indent(2) }} 241 | ), 242 | {%- if loop.last %} 243 | _dbt_linreg_final_coefs as ( 244 | select 245 | {%- if add_constant %} 246 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }} 247 | avg({{ dbt_linreg._filter_and_center_if_alpha('y', alpha, base_prefix='b.') }}) 248 | {%- for _x, _o in dbt_linreg._traverse_intercepts(step, exog_aliased) %} 249 | - avg({{ dbt_linreg._filter_and_center_if_alpha(_x, alpha, base_prefix='b.') }}) * {{ dbt_linreg.regress('b.y', dbt_linreg._orth_x_intercept('b.' 
~ _x, _o)) }} 250 | {%- endfor %} 251 | as x0_coef, 252 | {%- endif %} 253 | {%- for _x, _o in dbt_linreg._traverse_intercepts(step, exog_aliased) %} 254 | {{ dbt_linreg.regress('b.y', dbt_linreg._orth_x_intercept(_x, _o), add_constant=add_constant) }} as {{ _x }}_coef 255 | {%- if not loop.last -%} 256 | , 257 | {%- endif %} 258 | {%- endfor %} 259 | from _dbt_linreg_step{{ step }} as b 260 | {%- if alpha and add_constant %} 261 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_cmeans') | indent(2) }} 262 | {%- endif %} 263 | {%- if group_by %} 264 | group by 265 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }} 266 | {%- endif %} 267 | ) 268 | {%- endif %} 269 | {%- endfor %} 270 | {{ 271 | dbt_linreg.final_select( 272 | exog=exog, 273 | exog_aliased=exog_aliased, 274 | add_constant=add_constant, 275 | group_by=group_by, 276 | output=output, 277 | output_options=output_options, 278 | calculate_standard_error=False 279 | ) 280 | }} 281 | ) 282 | {%- endmacro %} 283 | -------------------------------------------------------------------------------- /macros/linear_regression/utils/utils.sql: -------------------------------------------------------------------------------- 1 | {############################################################################### 2 | ## Simple univariate regression. 
3 | ###############################################################################} 4 | 5 | {% macro regress(y, x, add_constant=True) %} 6 | {{ return( 7 | adapter.dispatch('regress', 'dbt_linreg') 8 | (y, x, add_constant=add_constant) 9 | ) }} 10 | {% endmacro %} 11 | 12 | {% macro default__regress(y, x, add_constant=True) -%} 13 | {%- if add_constant -%} 14 | covar_pop({{ x }}, {{ y }}) / var_pop({{ x }}) 15 | {%- else -%} 16 | sum({{ x }} * {{ y }}) / sum({{ x }} * {{ x }}) 17 | {%- endif -%} 18 | {%- endmacro %} 19 | 20 | {% macro snowflake__regress(y, x, add_constant=True) -%} 21 | {%- if add_constant -%} 22 | regr_slope({{ y }}, {{ x }}) 23 | {%- else -%} 24 | sum({{ x }} * {{ y }}) / sum({{ x }} * {{ x }}) 25 | {%- endif -%} 26 | {%- endmacro %} 27 | 28 | {############################################################################### 29 | ## Final select 30 | ###############################################################################} 31 | 32 | {# Every OLS method ends with a "_dbt_linreg_final_coefs" CTE with a common 33 | interface. This interface can then be transformed in a standard way using the 34 | final_select() macro, which formats the output for the user.
#} 35 | {% macro final_select(exog=none, 36 | exog_aliased=none, 37 | group_by=none, 38 | add_constant=true, 39 | output=none, 40 | output_options=none, 41 | calculate_standard_error=false) -%} 42 | {%- if output == 'long' %} 43 | {%- if add_constant %} 44 | select 45 | {{ dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }} 46 | {{ dbt.string_literal(dbt_linreg._get_output_option('constant_name', output_options, 'const')) }} as {{ dbt_linreg._get_output_option('variable_column_name', output_options, 'variable_name') }}, 47 | {{ dbt_linreg._maybe_round('x0_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('coefficient_column_name', output_options, 'coefficient') }}{% if calculate_standard_error %}, 48 | {{ dbt_linreg._maybe_round('x0_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('standard_error_column_name', output_options, 'standard_error') }}, 49 | {{ dbt_linreg._maybe_round('x0_coef/x0_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('t_statistic_column_name', output_options, 't_statistic') }} 50 | {%- endif %} 51 | from _dbt_linreg_final_coefs as b 52 | {%- if calculate_standard_error %} 53 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_stderrs') }} 54 | {%- endif %} 55 | {%- if exog_aliased %} 56 | union all 57 | {%- endif %} 58 | {%- endif %} 59 | {%- for i in exog_aliased %} 60 | select 61 | {{ dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }} 62 | {{ dbt.string_literal(dbt_linreg._strip_quotes(exog[loop.index0], output_options)) }} as {{ dbt_linreg._get_output_option('variable_column_name', output_options, 'variable_name') }}, 63 | {{ dbt_linreg._maybe_round(i~'_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('coefficient_column_name', output_options, 'coefficient') }}{% if calculate_standard_error %}, 
64 | {{ dbt_linreg._maybe_round(i~'_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('standard_error_column_name', output_options, 'standard_error') }}, 65 | {{ dbt_linreg._maybe_round(i~'_coef/'~i~'_stderr', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._get_output_option('t_statistic_column_name', output_options, 't_statistic') }} 66 | {%- endif %} 67 | from _dbt_linreg_final_coefs as b 68 | {%- if calculate_standard_error %} 69 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_stderrs') }} 70 | {%- endif %} 71 | {%- if not loop.last %} 72 | union all 73 | {%- endif %} 74 | {%- endfor %} 75 | {%- elif output == 'wide' %} 76 | select 77 | {%- if add_constant -%} 78 | {{ dbt_linreg._unalias_gb_cols(group_by) | indent(2) }} 79 | {{ dbt_linreg._maybe_round('x0_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._format_wide_variable_column(dbt_linreg._get_output_option('constant_name', output_options, 'const'), output_options) }} 80 | {%- if exog_aliased -%} 81 | , 82 | {%- endif -%} 83 | {%- endif -%} 84 | {%- for i in exog_aliased %} 85 | {{ dbt_linreg._maybe_round(i~'_coef', dbt_linreg._get_output_option('round', output_options)) }} as {{ dbt_linreg._format_wide_variable_column(exog[loop.index0], output_options) }} 86 | {%- if not loop.last -%} 87 | , 88 | {%- endif %} 89 | {%- endfor %} 90 | from _dbt_linreg_final_coefs 91 | {%- else %} 92 | {#- Fallback option (which should never happen!) is to just select star. #} 93 | select * from _dbt_linreg_final_coefs 94 | {%- endif %} 95 | {%- endmacro %} 96 | 97 | {############################################################################### 98 | ## Misc. 99 | ###############################################################################} 100 | 101 | {# Users can pass columns such as '"foo"', with the double quotes included. 
102 | In this situation, we want to strip the double quotes when presenting 103 | outputs in a long format. #} 104 | {% macro _strip_quotes(x, output_options) -%} 105 | {% if dbt_linreg._get_output_option('strip_quotes', output_options, True) %} 106 | {% if x[0] == '"' and x[-1] == '"' and (x | length) > 1 %} 107 | {{ return(x[1:-1]) }} 108 | {% endif %} 109 | {% endif %} 110 | {{ return(x)}} 111 | {%- endmacro %} 112 | 113 | {% macro _format_wide_variable_column(x, output_options) -%} 114 | {% if x[0] == '"' and x[-1] == '"' and (x | length) > 1 %} 115 | {% set _add_quotes = True %} 116 | {% set x = x[1:-1] %} 117 | {% else %} 118 | {% set _add_quotes = False %} 119 | {% endif %} 120 | {% if dbt_linreg._get_output_option('variable_column_prefix', output_options) %} 121 | {% set x = dbt_linreg._get_output_option('variable_column_prefix', output_options) ~ x %} 122 | {% endif %} 123 | {% if dbt_linreg._get_output_option('variable_column_suffix', output_options) %} 124 | {% set x = x ~ dbt_linreg._get_output_option('variable_column_suffix', output_options) %} 125 | {% endif %} 126 | {% if _add_quotes %} 127 | {% set x = '"' ~ x ~ '"' %} 128 | {% endif %} 129 | {{ return(x)}} 130 | {%- endmacro %} 131 | 132 | {# To ensure no namespace conflicts, f"gb{index}" is used in group by 133 | statements instead of the actual column names. This macro adds aliases. #} 134 | {% macro _alias_gb_cols(group_by) -%} 135 | {%- if group_by %} 136 | {%- for gb in group_by %} 137 | {{ gb }} as gb{{ loop.index }}, 138 | {%- endfor %} 139 | {%- endif %} 140 | {%- endmacro %} 141 | 142 | {# This macro reverses gb column aliases at the end of an OLS query.
#} 143 | {% macro _unalias_gb_cols(group_by, prefix=None) -%} 144 | {%- if group_by %} 145 | {%- for gb in group_by %} 146 | {%- if prefix %} 147 | {{ prefix }}.gb{{ loop.index }} as {{ gb }}, 148 | {%- else %} 149 | gb{{ loop.index }} as {{ gb }}, 150 | {%- endif %} 151 | {%- endfor %} 152 | {%- endif %} 153 | {%- endmacro %} 154 | 155 | {# Round the final coefficient if the user specifies the `round` format 156 | option. Otherwise, keep as is. #} 157 | 158 | {% macro _maybe_round(x, round_) %} 159 | {{ return( 160 | adapter.dispatch('_maybe_round', 'dbt_linreg')(x, round_) 161 | ) }} 162 | {% endmacro %} 163 | 164 | {% macro default___maybe_round(x, round_) %} 165 | {% if round_ is not none %} 166 | {{ return('round(' ~ x ~ ', ' ~ round_ ~ ')') }} 167 | {% else %} 168 | {{ return(x) }} 169 | {% endif %} 170 | {% endmacro %} 171 | 172 | {% macro postgres___maybe_round(x, round_) %} 173 | {% if round_ is not none %} 174 | {{ return('round((' ~ x ~ ')::numeric, ' ~ round_ ~ ')') }} 175 | {% else %} 176 | {{ return('(' ~ x ~ ')::numeric') }} 177 | {% endif %} 178 | {% endmacro %} 179 | 180 | {% macro redshift___maybe_round(x, round_) %} 181 | {% if round_ is not none %} 182 | {{ return('round(' ~ x ~ ', ' ~ round_ ~ ')') }} 183 | {% else %} 184 | {{ return(x) }} 185 | {% endif %} 186 | {% endmacro %} 187 | 188 | {# Alias and write group by columns in a standard way. #} 189 | {% macro _gb_cols(group_by, trailing_comma=False, prefix=None) -%} 190 | {%- if group_by %} 191 | {%- for gb in group_by %} 192 | {%- if prefix %} 193 | {{ prefix }}.gb{{ loop.index }} 194 | {%- else %} 195 | gb{{ loop.index }} 196 | {%- endif %} 197 | {%- if (not loop.last) or trailing_comma -%} 198 | , 199 | {%- endif %} 200 | {%- endfor %} 201 | {%- endif %} 202 | {%- endmacro %} 203 | 204 | {# Take exog and gen a list containing 'x1', 'x2', etc. 
#} 205 | {% macro _alias_exog(x) -%} 206 | {% set li = [] %} 207 | {% for i in x %} 208 | {% do li.append('x' ~ loop.index) %} 209 | {% endfor %} 210 | {{ return(li) }} 211 | {%- endmacro %} 212 | 213 | {# Join on gb1, gb2 etc. from a table to another table. 214 | If there is no group by column, assume `join_to` is just 1 row. 215 | And in that case, just do a cross join. #} 216 | {% macro _join_on_groups(group_by, join_from, join_to) -%} 217 | {%- if not group_by %} 218 | cross join {{ join_to }} 219 | {%- else %} 220 | inner join {{ join_to }} 221 | on 222 | {%- for _ in group_by %} 223 | {{ join_from }}.gb{{ loop.index }} = {{ join_to }}.gb{{ loop.index }} 224 | {% if not loop.last -%} 225 | and 226 | {%- endif %} 227 | {%- endfor %} 228 | {%- endif %} 229 | {%- endmacro %} 230 | 231 | {% macro _get_output_option(field, output_options, default=none) %} 232 | {{ return(output_options.get(field, var("dbt_linreg", {}).get("output_options", {}).get(field, default))) }} 233 | {% endmacro %} 234 | 235 | {% macro _get_method_option(method, field, method_options, default=none) %} 236 | {{ return(method_options.get(field, var("dbt_linreg", {}).get("method_options", {}).get(method, {}).get(field, default))) }} 237 | {% endmacro %} 238 | -------------------------------------------------------------------------------- /macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql: -------------------------------------------------------------------------------- 1 | {# In some warehouses, you can reference newly created column aliases 2 | in the query you wrote. 3 | If that's not available, the previous calc will be in the dict. 
#} 4 | 5 | {% macro _cell_or_alias(i, j, d, prefix=none, isa=none) %} 6 | {% if isa is not none %} 7 | {% if isa %} 8 | {{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }} 9 | {% else %} 10 | {{ return(d[(i, j)]) }} 11 | {% endif %} 12 | {% endif %} 13 | {{ return( 14 | adapter.dispatch('_cell_or_alias', 'dbt_linreg') 15 | (i, j, d, prefix, isa) 16 | ) }} 17 | {% endmacro %} 18 | 19 | {% macro default___cell_or_alias(i, j, d, prefix=none, isa=none) %} 20 | {{ return(d[(i, j)]) }} 21 | {% endmacro %} 22 | 23 | {% macro snowflake___cell_or_alias(i, j, d, prefix=none, isa=none) %} 24 | {{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }} 25 | {% endmacro %} 26 | 27 | {% macro duckdb___cell_or_alias(i, j, d, prefix=none, isa=none) %} 28 | {{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }} 29 | {% endmacro %} 30 | 31 | {% macro clickhouse___cell_or_alias(i, j, d, prefix=none, isa=none) %} 32 | {{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }} 33 | {% endmacro %} 34 | 35 | {% macro _safe_sqrt(x, safe=True) %} 36 | {{ return( 37 | adapter.dispatch('_safe_sqrt', 'dbt_linreg') 38 | (x, safe) 39 | ) }} 40 | {% endmacro %} 41 | 42 | {% macro default___safe_sqrt(x, safe=True) %} 43 | {% if safe %} 44 | {{ return('case when ('~x~') >= 0 then sqrt('~x~') end') }} 45 | {% endif %} 46 | {{ return('sqrt('~x~')') }} 47 | {% endmacro %} 48 | 49 | {% macro bigquery___safe_sqrt(x, safe=True) %} 50 | {% if safe %} 51 | {{ return('safe.sqrt('~x~')') }} 52 | {% endif %} 53 | {{ return('sqrt('~x~')') }} 54 | {% endmacro %} 55 | 56 | {% macro _cholesky_decomposition(li, subquery_optimization=true, safe=true, isa=none) %} 57 | {% set d = {} %} 58 | {% for i in li %} 59 | {% for j in range(li[0], i + 1) %} 60 | {% if i == li[0] and j == li[0] %} 61 | {% do d.update({(i, j): dbt_linreg._safe_sqrt(x='x'~i~'x'~j, safe=safe)}) %} 62 | {% else %} 63 | {% set ns = namespace() %} 64 | {% set ns.s = 'x'~j~'x'~i %} 
65 | {% for k in range(li[0], j) %} 66 | {% if subquery_optimization and i != j %} 67 | {% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, isa=isa)~'*i'~j~'j'~k %} 68 | {% else %} 69 | {% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, isa=isa)~'*'~dbt_linreg._cell_or_alias(i=j, j=k, d=d, isa=isa) %} 70 | {% endif %} 71 | {% endfor %} 72 | {% if i == j %} 73 | {% do d.update({(i, j): dbt_linreg._safe_sqrt(x=ns.s, safe=safe)}) %} 74 | {% else %} 75 | {% if safe %} 76 | {% do d.update({(i, j): '('~ns.s~')/nullif('~dbt_linreg._cell_or_alias(i=j, j=j, d=d, isa=isa) ~ ', 0)'}) %} 77 | {% else %} 78 | {% do d.update({(i, j): '('~ns.s~')/'~dbt_linreg._cell_or_alias(i=j, j=j, d=d, isa=isa)}) %} 79 | {% endif %} 80 | {% endif %} 81 | {% endif %} 82 | {% endfor %} 83 | {% endfor %} 84 | {{ return(d) }} 85 | {% endmacro %} 86 | 87 | {% macro _forward_substitution(li, safe=true, isa=none) %} 88 | {% set d = {} %} 89 | {% for i, j in modules.itertools.combinations_with_replacement(li, 2) %} 90 | {% set ns = namespace() %} 91 | {% if i == j %} 92 | {% set ns.numerator = '1' %} 93 | {% else %} 94 | {% set ns.numerator = '(' %} 95 | {% for k in range(i, j) %} 96 | {% set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, prefix="inv_", isa=isa) %} 97 | {% endfor %} 98 | {% set ns.numerator = ns.numerator~')' %} 99 | {% endif %} 100 | {% if safe %} 101 | {% do d.update({(i, j): '('~ns.numerator~'/nullif(i'~j~'j'~j~', 0))'}) %} 102 | {% else %} 103 | {% do d.update({(i, j): '('~ns.numerator~'/i'~j~'j'~j~')'}) %} 104 | {% endif %} 105 | {% endfor %} 106 | {{ return(d) }} 107 | {% endmacro %} 108 | 109 | {% macro _ols_chol(table, 110 | endog, 111 | exog, 112 | add_constant=True, 113 | output=None, 114 | output_options=None, 115 | group_by=None, 116 | alpha=None, 117 | method_options=None) -%} 118 | {%- if (exog | length) == 0 %} 119 | {% do log('Note: exog was empty; running regression on constant term only.') %} 120 | {{ 
return(dbt_linreg._ols_0var( 121 | table=table, 122 | endog=endog, 123 | exog=exog, 124 | add_constant=add_constant, 125 | output=output, 126 | output_options=output_options, 127 | group_by=group_by, 128 | alpha=alpha 129 | )) }} 130 | {%- endif %} 131 | {%- set subquery_optimization = dbt_linreg._get_method_option('chol', 'subquery_optimization', method_options, true) %} 132 | {%- set safe_mode = dbt_linreg._get_method_option('chol', 'safe', method_options, true) %} 133 | {% set isa = dbt_linreg._get_method_option('chol', 'intra_select_aliasing', method_options) %} 134 | {%- set calculate_standard_error = dbt_linreg._get_output_option('calculate_standard_error', output_options, (not alpha) and output == 'long') %} 135 | {%- if alpha and calculate_standard_error %} 136 | {% do log( 137 | 'Warning: Standard errors are NOT designed to take into account ridge regression regularization.' 138 | ) %} 139 | {%- endif %} 140 | {%- if add_constant %} 141 | {% set xmin = 0 %} 142 | {%- else %} 143 | {% set xmin = 1 %} 144 | {%- endif %} 145 | {%- set xcols = (range(xmin, (exog | length) + 1) | list) %} 146 | {%- set upto = (xcols | length) %} 147 | {%- set exog_aliased = dbt_linreg._alias_exog(exog) %} 148 | (with 149 | _dbt_linreg_base as ( 150 | select 151 | {{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }} 152 | {{ endog }} as y, 153 | {%- if add_constant %} 154 | 1 as x0, 155 | {%- endif %} 156 | {%- for i in range(1, (exog | length) + 1) %} 157 | b.{{ exog[loop.index0] }} as x{{ i }} 158 | {%- if not loop.last -%} 159 | , 160 | {%- endif %} 161 | {%- endfor %} 162 | from 163 | {{ table }} as b 164 | ), 165 | _dbt_linreg_xtx as ( 166 | select 167 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }} 168 | {%- for i, j in modules.itertools.combinations_with_replacement(xcols, 2) %} 169 | {%- if alpha and i == j and i > 0 %} 170 | sum(b.x{{ i }} * b.x{{ j }} + {{ alpha[i-1] }}) as x{{ i }}x{{ j }} 171 | {%- else %} 172 | sum(b.x{{ i }} * b.x{{ j }}) as 
x{{ i }}x{{ j }} 173 | {%- endif %} 174 | {%- if not loop.last -%} 175 | , 176 | {%- endif %} 177 | {%- endfor %} 178 | from _dbt_linreg_base as b 179 | {%- if group_by %} 180 | group by 181 | {{ dbt_linreg._gb_cols(group_by) | indent(4) }} 182 | {%- endif %} 183 | ), 184 | _dbt_linreg_chol as ( 185 | 186 | {%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_mode, isa=isa) %} 187 | {%- if subquery_optimization %} 188 | {%- for i in (xcols | reverse) %} 189 | select 190 | *, 191 | {%- for j in range(xmin, i + 1) %} 192 | {{ d[(i, j)] }} as i{{ i }}j{{ j }} 193 | {%- if not loop.last -%} 194 | , 195 | {%- endif %} 196 | {%- endfor %} 197 | {%- if not loop.last %} 198 | from ( 199 | {%- else %} 200 | from _dbt_linreg_xtx{% for close_ct in range(upto - 1) %}) as ic{{ close_ct }}{% endfor %} 201 | {%- endif %} 202 | {%- endfor %} 203 | {%- else %} 204 | select 205 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }} 206 | {%- for k, v in d.items() %} 207 | {{ v }} as {{ 'i'~k[0]~'j'~k[1] }} 208 | {%- if not loop.last -%} 209 | , 210 | {%- endif %} 211 | {%- endfor %} 212 | from _dbt_linreg_xtx 213 | {%- endif %} 214 | ), 215 | _dbt_linreg_inverse_chol as ( 216 | {#- The optimal way to calculate is to do each diagonal at a time. 
#} 217 | {%- set d = dbt_linreg._forward_substitution(li=xcols, safe=safe_mode, isa=isa) %} 218 | {%- if subquery_optimization %} 219 | {%- for gap in (range(0, upto) | reverse) %} 220 | select *, 221 | {%- for j in range(gap + xmin, upto + xmin) %} 222 | {%- set i = j - gap %} 223 | {{ d[(i, j)] }} as inv_i{{ i }}j{{ j }} 224 | {%- if not loop.last -%} 225 | , 226 | {%- endif %} 227 | {%- endfor %} 228 | {%- if not loop.last %} 229 | from ( 230 | {%- else %} 231 | from _dbt_linreg_chol{% for close_ct in range(upto - 1) %}) as ic{{ close_ct }}{% endfor %} 232 | {%- endif %} 233 | {%- endfor %} 234 | {%- else %} 235 | select 236 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }} 237 | {%- for k, v in d.items() %} 238 | {{ v }} as inv_{{ 'i'~k[0]~'j'~k[1] }} 239 | {%- if not loop.last -%} 240 | , 241 | {%- endif %} 242 | {%- endfor %} 243 | from _dbt_linreg_chol 244 | {%- endif %} 245 | ), 246 | _dbt_linreg_inverse_xtx as ( 247 | select 248 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }} 249 | {%- for i, j in modules.itertools.combinations_with_replacement(xcols, 2) %} 250 | {%- if not add_constant %} 251 | {%- set upto = upto + 1 %} 252 | {%- endif %} 253 | {%- for k in range(j, upto) %} 254 | inv_i{{ i }}j{{ k }} * inv_i{{ j }}j{{ k }}{%- if not loop.last %} + {% endif -%} 255 | {%- endfor %} 256 | as inv_x{{ i }}x{{ j }} 257 | {%- if not loop.last -%} 258 | , 259 | {%- endif %} 260 | {%- endfor %} 261 | from _dbt_linreg_inverse_chol 262 | ), 263 | _dbt_linreg_final_coefs as ( 264 | select 265 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True, prefix='b') | indent(4) }} 266 | {%- for x1 in xcols %} 267 | sum(( 268 | {%- for x2 in xcols %} 269 | {%- if x2 > x1 %} 270 | b.x{{ x2 }} * inv_x{{ x1 }}x{{ x2 }} 271 | {%- else %} 272 | b.x{{ x2 }} * inv_x{{ x2 }}x{{ x1 }} 273 | {%- endif %} 274 | {%- if not loop.last %} + {% endif -%} 275 | {%- endfor %} 276 | ) * b.y) as x{{ x1 }}_coef 277 | {%- if not loop.last -%} 278 | , 
279 | {%- endif %} 280 | {%- endfor %} 281 | from 282 | _dbt_linreg_base as b 283 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_inverse_xtx') | indent(2) }} 284 | {%- if group_by %} 285 | group by 286 | {{ dbt_linreg._gb_cols(group_by, prefix='b') | indent(4) }} 287 | {%- endif %} 288 | ){%- if calculate_standard_error %}, 289 | _dbt_linreg_resid as ( 290 | select 291 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True, prefix='b') | indent(4) }} 292 | avg(pow(y 293 | {%- for x in xcols %} 294 | - x{{ x }} * x{{ x }}_coef 295 | {%- endfor %} 296 | , 2)) as resid_square_mean, 297 | count(*) as n 298 | from 299 | _dbt_linreg_base as b 300 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_final_coefs') | indent(2) }} 301 | {%- if group_by %} 302 | group by 303 | {{ dbt_linreg._gb_cols(group_by, prefix='b') | indent(2) }} 304 | {%- endif %} 305 | ), 306 | _dbt_linreg_stderrs as ( 307 | select 308 | {{ dbt_linreg._gb_cols(group_by, trailing_comma=True, prefix='b') | indent(4) }} 309 | {%- for x in xcols %} 310 | sqrt(inv_x{{ x }}x{{ x }} * resid_square_mean * n / (n - {{ upto }})) as x{{ x }}_stderr 311 | {%- if not loop.last -%} 312 | , 313 | {%- endif %} 314 | {%- endfor %} 315 | from 316 | _dbt_linreg_resid as b 317 | {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_inverse_xtx') | indent(2) }} 318 | ) 319 | {%- endif %} 320 | {{ 321 | dbt_linreg.final_select( 322 | exog=exog, 323 | exog_aliased=exog_aliased, 324 | add_constant=add_constant, 325 | group_by=group_by, 326 | output=output, 327 | output_options=output_options, 328 | calculate_standard_error=calculate_standard_error 329 | ) 330 | }}) 331 | {% endmacro %} 332 | -------------------------------------------------------------------------------- /scripts.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is used for generation of CSV files for integration test cases, 3 | and also for manual verification + generation of 
test case values. 4 | """ 5 | import json 6 | import os 7 | import os.path as op 8 | import warnings 9 | from typing import NamedTuple 10 | from typing import Optional 11 | from typing import Protocol 12 | 13 | import numpy as np 14 | import pandas as pd 15 | import rich_click as click 16 | import statsmodels.api as sm 17 | import yaml 18 | from tabulate import tabulate 19 | 20 | 21 | # Suppress iteritems warning 22 | warnings.simplefilter("ignore", category=FutureWarning) 23 | 24 | # No scientific notation 25 | np.set_printoptions(suppress=True) 26 | 27 | 28 | DIR = op.dirname(__file__) 29 | 30 | DEFAULT_SIZE = 10_000 31 | DEFAULT_SEED = 382479347 32 | 33 | 34 | class TestCase(NamedTuple): 35 | df: pd.DataFrame 36 | x_cols: list[str] 37 | y_col: str 38 | group: Optional[str] = None 39 | 40 | 41 | class TestCaseCallable(Protocol): 42 | def __call__(self, size: int, seed: int) -> TestCase: 43 | pass 44 | 45 | 46 | def gram_schmidt(df: pd.DataFrame): 47 | q = pd.DataFrame(index=df.index) 48 | for c, v in df.items(): 49 | v_new = v.copy() 50 | for _, u in q.items(): 51 | v_new -= u * u.dot(v) / u.dot(u) 52 | q[c] = v_new / np.linalg.norm(v_new) 53 | return q 54 | 55 | 56 | def simple_matrix(size: int = DEFAULT_SIZE, seed: int = DEFAULT_SEED) -> TestCase: 57 | # Gram Schmidt makes any matrix into the simplest test case because 58 | # orthogonalization guarantees round and predictable coefficients. 59 | # 60 | # That said, we also want to cover test cases unorthogonalized. 61 | # Otherwise, it kind of beats the point of writing all that multiple 62 | # regression logic using the FWL theorem. 63 | # 64 | # So although it is a good and clean test case, it can't be the only one. 
65 | rs = np.random.RandomState(seed=seed) 66 | df = pd.DataFrame(index=range(size)) 67 | df["const"] = 1 68 | 69 | coefficients = pd.Series({ 70 | "const": 10, 71 | "xa": 5, 72 | "xb": 7, 73 | "xc": 9, 74 | "xd": 11, 75 | "xe": 13, 76 | "xf": 15, 77 | "xg": 17, 78 | "xh": 19, 79 | "xi": 21, 80 | "xj": 23 81 | }) 82 | 83 | for c in coefficients.index: 84 | if c == "const": 85 | continue 86 | df[c] = rs.normal(0, 1, size=size) 87 | df["epsilon"] = rs.normal(0, 10, size=size) 88 | 89 | feature_cols = list(coefficients.index) 90 | x_cols = [i for i in feature_cols if i != "const"] 91 | non_const_cols = x_cols + ["epsilon"] 92 | 93 | # Center the non-constant columns 94 | # This is kinda like orthogonalizing w/r/t constant term. 95 | # So doing this here means we don't need to Gram Schmidt the constant. 96 | for c in non_const_cols: 97 | df[c] -= df[c].mean() 98 | 99 | df[non_const_cols] = gram_schmidt(df[non_const_cols]) 100 | 101 | df["y"] = df[feature_cols].dot(coefficients) + df["epsilon"] 102 | 103 | return TestCase(df=df, y_col="y", x_cols=x_cols) 104 | 105 | 106 | def collinear_matrix(size: int = DEFAULT_SIZE, seed: int = DEFAULT_SEED) -> TestCase: 107 | rs = np.random.RandomState(seed=seed) 108 | df = pd.DataFrame(index=range(size)) 109 | df["const"] = 1 110 | df["x1"] = 2 + rs.normal(0, 1, size=size) 111 | df["x2"] = 1 - df["x1"] + rs.normal(0, 3, size=size) 112 | df["x3"] = 3 + 2 * df["x2"] + rs.normal(0, 1, size=size) 113 | df["x4"] = -3 + 0.5 * (df["x1"] * df["x3"]) + rs.normal(0, 1, size=size) 114 | df["x5"] = 4 + 0.5 * np.sin(3 * df["x2"]) + rs.normal(0, 1, size=size) 115 | df["epsilon"] = rs.normal(0, 4, size=size) 116 | 117 | coefficients = pd.Series({ 118 | "const": 20, 119 | "x1": 5, 120 | "x2": 7, 121 | "x3": 9, 122 | "x4": 11, 123 | "x5": 13 124 | }) 125 | 126 | x_cols = list(coefficients.index) 127 | 128 | # coefficients will not exactly match due to OVB 129 | df["y"] = ( 130 | df[coefficients.index].dot(coefficients) 131 | + (df["x3"] + 
np.sin(df["x1"])) ** 2 132 | + df["epsilon"] 133 | ) 134 | 135 | return TestCase(df=df, y_col="y", x_cols=x_cols) 136 | 137 | 138 | def groups_matrix(size: int = DEFAULT_SIZE, seed: int = DEFAULT_SEED) -> TestCase: 139 | rs = np.random.RandomState(seed=seed) 140 | size1 = size // 2 141 | size2 = size - size1 142 | 143 | df1 = pd.DataFrame(index=range(size1)) 144 | df1["gb_var"] = "a" 145 | df1["const"] = 1 146 | df1["x1"] = 2 + rs.normal(0, 1, size=size1) 147 | df1["x2"] = 1 - df1["x1"] + rs.normal(0, 3, size=size1) 148 | df1["x3"] = 3 + 2 * df1["x2"] + rs.normal(0, 1, size=size1) 149 | df1["y"] = 1 * df1["x1"] + 2 * df1["x2"] + 3 * df1["x2"] + rs.normal(0, 1, size=size1) 150 | 151 | df2 = pd.DataFrame(index=range(size2)) 152 | df2["gb_var"] = "b" 153 | df2["const"] = 1 154 | df2["x1"] = 6 + rs.normal(0, 3, size=size2) 155 | df2["x2"] = 3 + df2["x1"] + rs.normal(0, 3, size=size2) 156 | df2["x3"] = -1 - df2["x2"] + rs.normal(0, 2, size=size2) 157 | df2["y"] = 2 + 3 * df2["x1"] + 4 * df2["x2"] + 5 * df2["x2"] + rs.normal(0, 1, size=size2) 158 | 159 | df = pd.concat([df1, df2], axis=0).reset_index() 160 | 161 | return TestCase( 162 | df=df, 163 | y_col="y", 164 | x_cols=["const", "x1", "x2", "x3"], 165 | group="gb_var" 166 | ) 167 | 168 | 169 | ALL_TEST_CASES: dict[str, TestCaseCallable] = { 170 | "simple_matrix": simple_matrix, 171 | "collinear_matrix": collinear_matrix, 172 | "groups_matrix": groups_matrix 173 | } 174 | 175 | 176 | def click_option_seed(**kwargs): 177 | return click.option( 178 | "--seed", "-s", 179 | default=DEFAULT_SEED, 180 | show_default=True, 181 | help="Seed used to generate data.", 182 | **kwargs 183 | ) 184 | 185 | 186 | def click_option_size(**kwargs): 187 | return click.option( 188 | "--size", "-n", 189 | default=DEFAULT_SIZE, 190 | show_default=True, 191 | help="Number of rows to generate.", 192 | **kwargs 193 | ) 194 | 195 | 196 | @click.group("main", context_settings=dict(help_option_names=["-h", "--help"])) 197 | def cli(): 198 |
"""CLI for manually testing the code base.""" 199 | 200 | 201 | @cli.command("regress") 202 | @click.option("--table", "-t", 203 | required=True, 204 | type=click.Choice(ALL_TEST_CASES.keys()), 205 | help="Table to regress against.") 206 | @click.option("--const/--no-const", 207 | default=True, 208 | type=click.BOOL, 209 | show_default=True, 210 | help="If true, add the constant term.") 211 | @click.option("--columns", "-c", 212 | default=None, 213 | type=click.INT, 214 | show_default=True, 215 | help="Number of columns to regress.") 216 | @click.option("--alpha", "-a", 217 | default=None, 218 | type=click.FLOAT, 219 | show_default=True, 220 | help="Alpha for the regression.") 221 | @click_option_size() 222 | @click_option_seed() 223 | def regress(table: str, const: bool, columns: int, alpha: float, size: int, seed: int): 224 | """ 225 | Run regression on integration test cases. 226 | 227 | Use me for either manual verification of test cases, or for generating new 228 | test cases. (All numeric values for test cases were generated using this 229 | CLI.) 
230 | """ 231 | callback = ALL_TEST_CASES[table] 232 | 233 | click.echo(click.style("=" * 80, fg="blue")) 234 | click.echo( 235 | click.style("Test case: ", fg="blue", bold=True) 236 | + 237 | click.style(table, fg="blue") 238 | ) 239 | click.echo(click.style("=" * 80, fg="blue")) 240 | 241 | test_case = callback(size, seed) 242 | 243 | if columns is None: 244 | x_cols = test_case.x_cols 245 | else: 246 | # K plus Constant (1) 247 | x_cols = test_case.x_cols[:columns+1] 248 | 249 | if not const: 250 | x_cols = [i for i in x_cols if i != "const"] 251 | 252 | def _run_model(cond=None): 253 | if cond is None: 254 | cond = slice(None) 255 | y = test_case.df.loc[cond, test_case.y_col] 256 | x_mat = test_case.df.loc[cond, x_cols] 257 | if alpha: 258 | if const: 259 | alpha_arr = [0, *([alpha] * (len(x_mat.columns) - 1))] 260 | else: 261 | alpha_arr = [alpha] * len(x_mat.columns) 262 | model = sm.OLS( 263 | y, 264 | x_mat 265 | ).fit_regularized(L1_wt=0, alpha=alpha_arr) 266 | else: 267 | model = sm.OLS(y, x_mat).fit() 268 | res_df = pd.DataFrame(index=x_mat.columns) 269 | res_df["coef"] = model.params 270 | res_df["stderr"] = model.bse 271 | res_df["tstat"] = res_df["coef"] / res_df["stderr"] 272 | click.echo( 273 | tabulate( 274 | res_df, 275 | headers=["column name", "coef", "stderr", "tstat"], 276 | disable_numparse=True, 277 | tablefmt="psql", 278 | ) 279 | ) 280 | 281 | if test_case.group: 282 | for c in test_case.df[test_case.group].unique(): 283 | click.echo(click.style(f"{test_case.group} - {c}", fg="green")) 284 | _run_model(cond=(test_case.df[test_case.group] == c)) 285 | else: 286 | _run_model() 287 | 288 | 289 | def echo_table_name(s: str): 290 | click.echo(click.style("=" * 80, fg="green")) 291 | click.echo( 292 | click.style("Table: ", fg="green", bold=True) 293 | + 294 | click.style(s, fg="green") 295 | ) 296 | click.echo(click.style("=" * 80, fg="green")) 297 | 298 | 299 | @cli.command("gen-test-cases") 300 | @click.option("--table", "-t", "tables", 301 | 
multiple=True, 302 | default=None, 303 | show_default=True, 304 | help="Generate a specific table. If None, generate all tables.") 305 | @click_option_size() 306 | @click_option_seed() 307 | @click.option("--skip-if-exists", is_flag=True, 308 | help="Skip if the file exists. Otherwise, overwrite.") 309 | def gen_test_cases(tables: list[str], size: int, seed: int, skip_if_exists: bool): 310 | """Generate integration test cases (CSV files).""" 311 | if not tables: 312 | tables = ALL_TEST_CASES 313 | for table_name in tables: 314 | file_name = f"{DIR}/integration_tests/seeds/{table_name}.csv" 315 | if skip_if_exists and op.exists(file_name): 316 | click.echo("File " + click.style(file_name, fg="blue") + " already exists; skipping.") 317 | continue 318 | 319 | callback = ALL_TEST_CASES[table_name] 320 | 321 | echo_table_name(table_name) 322 | 323 | test_case = callback(size, seed) 324 | y = test_case.df[test_case.y_col] 325 | x_mat = test_case.df[test_case.x_cols] 326 | 327 | click.echo() 328 | li = [] 329 | for i in range(1, len(x_mat.columns) + 1): 330 | model = sm.OLS( 331 | y, 332 | sm.add_constant(x_mat.iloc[:, :i]) 333 | ).fit() 334 | params = model.params.rename(f"{i}-var").reindex(x_mat.columns) 335 | params = params.apply( 336 | lambda s: "{:.5f}".format(s) 337 | if pd.notna(s) 338 | else None 339 | ) 340 | expand_by = params.apply(lambda s: len(s) if s is not None else 0).max() 341 | params = params.where( 342 | pd.notna(params), 343 | click.style("-" * expand_by, fg="bright_black") 344 | ) 345 | 346 | params.apply(lambda s: len(s)).max() 347 | li.append(params) 348 | coefs = pd.concat(li, axis=1) 349 | click.echo( 350 | tabulate( 351 | coefs, 352 | headers=coefs.columns, 353 | disable_numparse=True, 354 | tablefmt="psql", 355 | ) 356 | ) 357 | 358 | 359 | all_cols = [test_case.y_col, *test_case.x_cols] 360 | if test_case.group: 361 | all_cols.append(test_case.group) 362 | 363 | test_case.df[all_cols].to_csv(file_name, index=False) 364 | click.echo( 365 | 
click.style(f"Wrote DataFrame to file {file_name!r}", fg="yellow") 366 | ) 367 | click.echo("") 368 | click.echo(click.style("Done!", fg="green")) 369 | 370 | 371 | @cli.command("gen-hide-macros-yaml") 372 | @click.option("--parse/--no-parse", is_flag=True, default=True) 373 | def gen_hide_args_yaml(parse: bool) -> None: 374 | """Generates the YAML that hides the macros from the docs. 375 | 376 | Requires the `manifest.json` to be available. 377 | (`dbt parse --profiles-dir ./integration_tests/profiles`) 378 | 379 | Recommended to `| pbcopy` this command, then paste in `macros/schema.yml`. 380 | 381 | This is not enforced during CICD, beware! 382 | """ 383 | 384 | if parse: 385 | from dbt.cli.main import dbtRunner 386 | os.environ["DO_NOT_TRACK"] = "1" 387 | dbtRunner().invoke( 388 | [ 389 | "parse", 390 | "--profiles-dir", op.join(op.dirname(__file__), "integration_tests", "profiles"), 391 | "--project-dir", op.dirname(__file__) 392 | ] 393 | ) 394 | 395 | exclude_from_hiding = ["ols"] 396 | with open("target/manifest.json") as f: 397 | manifest = json.load(f) 398 | 399 | macros = [ 400 | data["name"] for fqn, data 401 | in manifest["macros"].items() 402 | if data.get("package_name", "") == "dbt_linreg" 403 | and data.get("name") not in exclude_from_hiding 404 | ] 405 | 406 | out = [ 407 | {"name": macro, "docs": {"show": False}} 408 | for macro in sorted(macros) 409 | ] 410 | 411 | print(" " + yaml.safe_dump(out, sort_keys=False).replace("\n", "\n ")) 412 | 413 | 414 | if __name__ == "__main__": 415 | cli() 416 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |


7 |

8 | Linear regression in any SQL dialect, powered by dbt. 9 |

10 |


14 | 15 | # Overview 16 | 17 | **dbt_linreg** is an easy way to perform linear regression and ridge regression in SQL (Snowflake, DuckDB, Clickhouse, and more) with OLS using dbt's Jinja2 templating. 18 | 19 | Reasons to use **dbt_linreg**: 20 | 21 | - 📈 **Linear regression in pure SQL:** With the power of Jinja2 templating and some behind-the-scenes math tricks, it is possible to implement multiple and multivariate regression in pure SQL. Most SQL engines (even OLAP engines) do not have a multiple regression implementation of OLS, so this fills a valuable niche. **`dbt_linreg` implements true OLS, not an approximation!** 22 | - 📱 **Simple interface:** Just define a `table=` (which works with `ref()`, `source()`, and CTEs), a y-variable with `endog=`, your x-variables in a list with `exog=...`, and you're all set! Note that the API is loosely based on Statsmodels's naming conventions. 23 | - 🤖 **Support for ridge regression:** Just pass in `alpha=scalar` or `alpha=[scalar1, scalar2, ...]` to regularize your regressions. (Note: regressors are not automatically standardized.) 24 | - 🤸‍ **Flexibility:** Tons of formatting options available to return coefficients the way you want. 25 | - 🤗 **User friendly:** The API provides comprehensive feedback on input errors. 26 | - 💪 **Durable and tested:** Everything in this code base is tested against equivalent regressions performed in Statsmodels with high precision assertions (between 10e-6 and 10e-7, depending on the database engine). 27 | 28 | _Note: If you enjoy this project, you may also enjoy my other dbt machine learning project, [**dbt_pca**](https://github.com/dwreeves/dbt_pca)._ 😊 29 | 30 | # Installation 31 | 32 | dbt-core `>=1.2.0` is required to install `dbt_linreg`.
33 | 34 | Add this to the `packages:` list in your dbt project's `packages.yml`: 35 | 36 | ```yaml 37 | - package: "dwreeves/dbt_linreg" 38 | version: "0.3.1" 39 | ``` 40 | 41 | The full file will look something like this: 42 | 43 | ```yaml 44 | packages: 45 | # ... 46 | # Other packages here 47 | # ... 48 | - package: "dwreeves/dbt_linreg" 49 | version: "0.3.1" 50 | ``` 51 | 52 | # Examples 53 | 54 | ### Simple example 55 | 56 | The following example runs a linear regression of 3 columns `xa + xb + xc` on `y`, using data in the dbt model named `simple_matrix`. It outputs the data in "long" format, and rounds the coefficients to 5 decimal places: 57 | 58 | ```sql 59 | {{ 60 | config( 61 | materialized="table" 62 | ) 63 | }} 64 | select * from {{ 65 | dbt_linreg.ols( 66 | table=ref('simple_matrix'), 67 | endog='y', 68 | exog=['xa', 'xb', 'xc'], 69 | output='long', 70 | output_options={'round': 5} 71 | ) 72 | }} as linreg 73 | ``` 74 | 75 | Output: 76 | 77 | |variable_name|coefficient|standard_error|t_statistic| 78 | |---|---|---|---| 79 | |const|10.0|0.00462|2163.27883| 80 | |xa|5.0|0.46226|10.81639| 81 | |xb|7.0|0.46226|15.14295| 82 | |xc|9.0|0.46226|19.46951| 83 | 84 | Note: `simple_matrix` is one of the test cases, so you can try this yourself! Standard errors are constant across `xa`, `xb`, `xc`, because `simple_matrix` is orthonormal. 85 | 86 | ### Complex example 87 | 88 | The following hypothetical example shows multiple ridge regressions (one per `product_id`) on a table that is preprocessed substantially. After the fact, predictions are run, and the R-squared of each regression is calculated at the end. 89 | 90 | This example shows that, although `dbt_linreg` does not implement everything for you, the OLS implementation does most of the hard work. This gives you the freedom to do things you've never been able to do before in SQL!
91 | 92 | ```sql 93 | {{ 94 | config( 95 | materialized="table" 96 | ) 97 | }} 98 | with 99 | 100 | preprocessed_data as ( 101 | 102 | select 103 | product_id, 104 | price, 105 | log(price) as log_price, 106 | epoch(time) as t, 107 | sin(epoch(time)*pi()*2 / (60*60*24*365)) as sin_t, 108 | cos(epoch(time)*pi()*2 / (60*60*24*365)) as cos_t 109 | from 110 | {{ ref('prices') }} 111 | 112 | ), 113 | 114 | preprocessed_and_normalized_data as ( 115 | 116 | select 117 | product_id, 118 | price, 119 | log(price) as log_price, 120 | (t - avg(t) over ()) / (stddev(t) over ()) as t_norm, 121 | (sin_t - avg(sin_t) over ()) / (stddev(sin_t) over ()) as sin_t_norm, 122 | (cos_t - avg(cos_t) over ()) / (stddev(cos_t) over ()) as cos_t_norm 123 | from 124 | preprocessed_data 125 | 126 | ), 127 | 128 | coefficients as ( 129 | 130 | select * from {{ 131 | dbt_linreg.ols( 132 | table='preprocessed_and_normalized_data', 133 | endog='log_price', 134 | exog=['t_norm', 'sin_t_norm', 'cos_t_norm'], 135 | group_by=['product_id'], 136 | alpha=0.0001 137 | ) 138 | }} 139 | 140 | ), 141 | 142 | predict as ( 143 | 144 | select 145 | d.product_id, 146 | d.t_norm, 147 | d.price, 148 | exp( 149 | c.const 150 | + d.t_norm * c.t_norm 151 | + d.sin_t_norm * c.sin_t_norm 152 | + d.cos_t_norm * c.cos_t_norm) as predicted_price 153 | from 154 | preprocessed_and_normalized_data as d 155 | join 156 | coefficients as c 157 | on 158 | d.product_id = c.product_id 159 | 160 | ) 161 | 162 | select 163 | product_id, 164 | pow(corr(predicted_price, price), 2) as r_squared 165 | from 166 | predict 167 | group by 168 | product_id 169 | ``` 170 | 171 | # Supported Databases 172 | 173 | **dbt_linreg** should work with most SQL databases, but so far, testing has been done for the following database tools: 174 | 175 | | Database | Supported | Precision asserted in CI\* | Supported since version | 176 | |----------------|-----------|----------------------------|-------------------------| 177 | | **Snowflake** | ✅
| _n/a_ | 0.1.0 | 178 | | **DuckDB** | ✅ | 10e-7 | 0.1.0 | 179 | | **Postgres**† | ✅ | 10e-7 | 0.2.3 | 180 | | **Redshift** | ✅ | _n/a_ | 0.2.4 | 181 | | **Clickhouse** | ✅ | 10e-6 | 0.3.0 | 182 | 183 | If **dbt_linreg** does not work in your database tool, please let me know in a bug report. 184 | 185 | > _\* Precision is for test cases using the **collinear_matrix** for unregularized regressions, in comparison to the output of the same regression in the Python package Statsmodels using `sm.OLS().fit(method="pinv")`. For example, coefficients for unregularized regressions performed in DuckDB are asserted to be within 10e-7 of Statsmodels._ 186 | 187 | > _† Minimal support for Postgres. Postgres is syntactically supported, but is not performant under certain circumstances._ 188 | 189 | # API 190 | 191 | The only function available in the public API is the `dbt_linreg.ols()` macro. 192 | 193 | Using Python typing notation, the full API for `dbt_linreg.ols()` looks like this: 194 | 195 | ```python 196 | def ols( 197 | table: str, 198 | endog: str, 199 | exog: Union[str, list[str]], 200 | add_constant: bool = True, 201 | output: Literal['wide', 'long'] = 'wide', 202 | output_options: Optional[dict[str, Any]] = None, 203 | group_by: Optional[Union[str, list[str]]] = None, 204 | alpha: Optional[Union[float, list[float]]] = None, 205 | method: Literal['chol', 'fwl'] = 'chol', 206 | method_options: Optional[dict[str, Any]] = None 207 | ): 208 | ... 209 | ``` 210 | 211 | Where: 212 | 213 | - **table**: Name of table or CTE to pull the data from. You can use `ref()` or `source()` here if you'd like. 214 | - **endog**: The endogenous variable / y variable / target variable of the regression. (You can also specify `y=...` instead of `endog=...` if you prefer.) 215 | - **exog**: The exogenous variables / X variables / features of the regression. (You can also specify `x=...` instead of `exog=...` if you prefer.) 
216 | - **add_constant**: If true, a constant term is added automatically to the regression. 217 | - **output**: Either "wide" or "long" output format for coefficients. See **Outputs and output options** for more. 218 | - If `wide`, the variables span the columns with their original variable names, and the coefficients fill a single row. 219 | - If `long`, the coefficients are in a single column called `coefficient`, and the variable names are in a single column called `variable_name`. 220 | - **output_options**: See the **Outputs and output options** section for more. 221 | - **group_by**: If specified, the regression will be grouped by these variables, and individual regressions will run on each group. 222 | - **alpha**: If not null, the regression will be run as a ridge regression with a penalty of `alpha`. See the **Notes** section for more information. 223 | - **method**: The method used to calculate the regression. See **Methods and method options** for more. 224 | - **method_options**: Options specific to the estimation method. See **Methods and method options** for more. 225 | 226 | # Outputs and output options 227 | 228 | Outputs can be returned in either `output='long'` or `output='wide'` format. 229 | 230 | All outputs have their own output options, which can be passed into the `output_options=` arg as a dict, e.g. `output_options={'foo': 'bar'}`. 231 | 232 | `output=` and `output_options=` were formerly named `format=` and `format_options=` respectively. 233 | The old names have been deprecated to make **dbt_linreg**'s API more consistent with **dbt_pca**'s API. 234 | 235 | ### Options for `output='long'` 236 | 237 | - **round** (`int`; default = `None`): If not None, round all coefficients to `round` number of digits. 238 | - **constant_name** (`string`; default = `'const'`): String name that refers to constant term. 239 | - **variable_column_name** (`string`; default = `'variable_name'`): Column name storing strings of variable names.
240 | - **coefficient_column_name** (`string`; default = `'coefficient'`): Column name storing model coefficients. 241 | - **strip_quotes** (`bool`; default = `True`): If true, strip outer quotes from column names if provided; if false, always use string literals. 242 | 243 | These options are available for `output='long'` only when `method='chol'`: 244 | 245 | - **calculate_standard_error** (`bool`; default = `True if not alpha else False`): If true, provide the standard error in the output. 246 | - **standard_error_column_name** (`string`; default = `'standard_error'`): Column name storing the standard error for the parameter. 247 | - **t_statistic_column_name** (`string`; default = `'t_statistic'`): Column name storing the t-statistic for the parameter. 248 | 249 | ### Options for `output='wide'` 250 | 251 | - **round** (`int`; default = `None`): If not None, round all coefficients to `round` number of digits. 252 | - **constant_name** (`string`; default = `'const'`): String name that refers to constant term. 253 | - **variable_column_prefix** (`string`; default = `None`): If not None, prefix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.) 254 | - **variable_column_suffix** (`string`; default = `None`): If not None, suffix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.) 255 | 256 | ## Setting output options globally 257 | 258 | Output options can be set globally via `vars`, e.g. in your `dbt_project.yml`: 259 | 260 | ```yaml 261 | # dbt_project.yml 262 | vars: 263 | dbt_linreg: 264 | output_options: 265 | round: 5 266 | ``` 267 | 268 | Output options passed via `ols()` always take precedence over globally set output options. 
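
The precedence rule described above (an option passed to `ols()` beats a global `vars` setting, which in turn beats the documented default) can be sketched in Python. This is only an illustration of the lookup order, not the package's actual Jinja macro; `resolve_option` and its parameter names are hypothetical.

```python
from typing import Any, Optional


def resolve_option(
    name: str,
    call_options: Optional[dict[str, Any]] = None,
    global_options: Optional[dict[str, Any]] = None,
    default: Any = None,
) -> Any:
    """Hypothetical helper: call-level value, else global `vars` value, else default."""
    # 1. A value passed directly via output_options= always wins.
    if call_options is not None and name in call_options:
        return call_options[name]
    # 2. Otherwise fall back to the project-wide setting, if any.
    if global_options is not None and name in global_options:
        return global_options[name]
    # 3. Otherwise use the documented default.
    return default


# Mirrors the dbt_project.yml snippet: vars -> dbt_linreg -> output_options.
global_opts = {"round": 5}

# The call-level option overrides the global one.
assert resolve_option("round", {"round": 3}, global_opts) == 3
# With no call-level option, the global setting applies.
assert resolve_option("round", None, global_opts) == 5
# An option set nowhere falls back to its default.
assert resolve_option("constant_name", None, global_opts, default="const") == "const"
```

The same lookup order is described later for method options, so a single helper of this shape would plausibly cover both cases.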
269 | 270 | # Methods and method options 271 | 272 | There are currently two valid methods for calculating regression coefficients: 273 | 274 | - `chol`: Uses Cholesky decomposition to calculate the pseudo-inverse. 275 | - `fwl`: Uses a "Frisch-Waugh-Lovell" approach, which consists of calculating univariate regressions to get multiple regression coefficients. 276 | 277 | ## `chol` method 278 | 279 | **👍 This is the suggested method (and the default) for calculating regressions!** 280 | 281 | This method calculates regression coefficients using the Moore-Penrose pseudo-inverse, and the inverse of **X'X** is calculated using Cholesky decomposition, hence it is referred to as `chol`. 282 | 283 | ### Options for `method='chol'` 284 | 285 | Specify these in a dict using the `method_options=` kwarg: 286 | 287 | - **safe** (`bool`; default: `True`): If True, returns null coefficients instead of an error when X is perfectly multicollinear. If False, a negative value may be passed into a SQRT() or a divide by zero may occur, and most SQL engines will raise an error when this happens. 288 | - **subquery_optimization** (`bool`; default = `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened. 289 | - **intra_select_aliasing** (`bool`; default = `[depends on db]`): If True, within a single select statement, column aliases are used to refer to other columns created during that select. This can significantly reduce the text of a SQL query, but not all SQL engines support this. By default, for all databases officially supported by **dbt_linreg**, the best option is already selected. For unsupported databases, the default is `False` for broad compatibility, so if you are running **dbt_linreg** in an officially unsupported database engine which supports this feature, you may want to modify this option globally in your `vars` to be `true`. 
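
As a rough illustration of the linear algebra behind the `chol` method, the sketch below computes OLS coefficients and standard errors in plain Python: form **X'X**, take its Cholesky factor, invert the factor by forward substitution, and multiply the pieces back together. It is a sketch under stated assumptions, not the generated SQL: the function names are hypothetical, there is no grouping or ridge penalty, and it assumes **X'X** is positive definite (the `safe` option's null handling is not reproduced).

```python
import math


def cholesky(a):
    """Lower-triangular L with a = L L^T, for a symmetric positive-definite matrix."""
    k = len(a)
    low = [[0.0] * k for _ in range(k)]
    for i in range(k):
        for j in range(i + 1):
            s = a[i][j] - sum(low[i][m] * low[j][m] for m in range(j))
            low[i][j] = math.sqrt(s) if i == j else s / low[j][j]
    return low


def invert_lower(low):
    """Invert a lower-triangular matrix by forward substitution, column by column."""
    k = len(low)
    inv = [[0.0] * k for _ in range(k)]
    for j in range(k):
        for i in range(j, k):
            rhs = 1.0 if i == j else 0.0
            s = rhs - sum(low[i][m] * inv[m][j] for m in range(j, i))
            inv[i][j] = s / low[i][i]
    return inv


def ols_chol(x, y):
    """Hypothetical helper: OLS coefficients and standard errors via Cholesky."""
    n, k = len(x), len(x[0])
    xtx = [[sum(r[i] * r[j] for r in x) for j in range(k)] for i in range(k)]
    xty = [sum(r[i] * yi for r, yi in zip(x, y)) for i in range(k)]
    l_inv = invert_lower(cholesky(xtx))
    # (X'X)^-1 = (L L^T)^-1 = L^-T L^-1
    xtx_inv = [
        [sum(l_inv[m][i] * l_inv[m][j] for m in range(k)) for j in range(k)]
        for i in range(k)
    ]
    coefs = [sum(xtx_inv[i][j] * xty[j] for j in range(k)) for i in range(k)]
    # Residual sum of squares, scaled by the degrees of freedom (n - k).
    rss = sum(
        (yi - sum(c * v for c, v in zip(coefs, r))) ** 2 for r, yi in zip(x, y)
    )
    stderr = [math.sqrt(xtx_inv[i][i] * rss / (n - k)) for i in range(k)]
    return coefs, stderr
```

Each row of `x` is one observation, with a leading `1.0` column if a constant term is wanted. For example, `ols_chol([[1.0, float(t)] for t in range(5)], [1.0, 3.0, 2.0, 5.0, 4.0])` fits an intercept of 1.4 and a slope of 0.8, up to floating-point error.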

## `fwl` method

**This method is generally not recommended.**

Simple univariate regression coefficients are just `covar_pop(y, x) / var_pop(x)`.

The multiple regression implementation uses a technique described in section `3.2.3 Multiple Regression from Simple Univariate Regression` of TEoSL ([source](https://hastie.su.domains/Papers/ESLII.pdf#page=71)). Econometricians know this as the Frisch-Waugh-Lovell theorem, hence the method is referred to as `fwl` internally in the code base.

Ridge regression is implemented using the augmentation technique described in Exercise 12 of Chapter 3 of TEoSL ([source](https://hastie.su.domains/Papers/ESLII.pdf#page=115)).

There are a few reasons why this method is discouraged in favor of the `chol` method:

- 🐌 It tends to be much slower in OLAP systems, and struggles to handle a large number of columns efficiently.
- 📊 It does not calculate standard errors.
- 😕 For ridge regression, coefficients are not accurate; they tend to be off by roughly 0.01%.
- ⚠️ It does not work in all databases because it relies on `COVAR_POP`.

So when should you use `fwl`? The main use case is in OLTP systems (e.g. Postgres) for unregularized coefficient estimation. Long story short, the `chol` method relies on subquery optimization to be more performant than `fwl`; however, OLTP systems do not benefit at all from subquery optimization. This means that `fwl` is slightly more performant in this context.

## Setting method options globally

Method options can be set globally via `vars`, e.g. in your `dbt_project.yml`. Each `method` gets its own config, e.g. the `dbt_linreg: method_options: chol: ...` namespace only applies to the `chol` method.
Here is an example:

```yaml
# dbt_project.yml
vars:
  dbt_linreg:
    method_options:
      chol:
        intra_select_aliasing: true
```

Method options passed via `ols()` always take precedence over globally set method options.

# Notes

- ⚠️ **If your coefficients are null, it does not mean dbt_linreg is broken; it most likely means your feature columns are perfectly multicollinear.** If you are 100% sure that is not the issue, please file a bug report with a minimal reproducible example.

- Regularization is implemented using nearly the same approach as Statsmodels; the only difference is that the constant term can never be regularized. This means:
  - A scalar input (e.g. `alpha=0.01`) will apply an alpha of `0.01` to all features.
  - An array input (e.g. `alpha=[0.01, 0.02, 0.03, 0.04, 0.05]`) will apply an alpha of `0.01` to the first column, `0.02` to the second column, etc.
  - `alpha` is equivalent to what TEoSL refers to as "lambda," times the sample size N. That is to say: `α ≡ λ * N`.
  - (Of course, you can regularize the constant term by DIYing your own constant term and doing `add_constant=false`.)

- Regularization as currently implemented for the `chol` method tends to be very slow in OLTP systems (e.g. Postgres), but is very performant in OLAP systems (e.g. Snowflake, DuckDB, BigQuery, Redshift). As dbt is more commonly used in OLAP contexts, the code base is optimized for the OLAP use case.
  - That said, it may be possible to make regularization in OLTP more performant (e.g. with augmentation of the design matrix), so PRs are welcome.

- Regression coefficients in Postgres are always `numeric` types.

## Possible future features

Some things that could happen in the future:

- Weighted least squares (WLS)
- Efficient multivariate regression (i.e.
multiple endogenous vectors sharing a single design matrix)
- P-values
- Heteroskedasticity robust standard errors
- Recursive CTE implementations + long formatted inputs

Note that although I maintain this library (as of writing in 2025), I do not actively update it much with new features, so the items on this wish list are unlikely to be implemented unless I personally need them or someone else contributes them.

# FAQ

### How does this work?

See the **Methods and method options** section for a full breakdown of each linear regression implementation.

All approaches were validated using Statsmodels `sm.OLS()`.

### BigQuery (or other database) has linear regression implemented natively. Why should I use `dbt_linreg` over that?

You don't have to use this. Most warehouses don't support multiple regression out of the box, so this fills a niche for those database tools which don't.

That said, even in BigQuery, it may be very useful to extract coefficients within a query instead of generating a separate `MODEL` object through a DDL statement, for a few reasons. Even in more black box predictive contexts, being able to predict in the same `SELECT` statement as training can be convenient. Additionally, BigQuery does not expose model coefficients to users, and this can be a dealbreaker in many contexts where you care about your coefficients as measurements, not as predictive model parameters. Lastly, `group_by` is akin to estimating parameters for multiple linear regressions at once.

Overall, I would say this is pretty different from what BigQuery's `CREATE MODEL` is doing; use whatever makes sense for your use case! But keep in mind that for large numbers of variables, a native implementation of linear regression will be noticeably more efficient than this implementation.
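To make the "predict in the same `SELECT` statement as training" point concrete, here is a hedged sketch. It assumes the `table`, `endog`, and `exog` arguments, hypothetical model and column names, and that the wide output exposes one coefficient column per variable plus the constant (named `const` by default, per the `constant_name` option); verify the actual output columns before relying on it.

```sql
-- Sketch only: `training_data`, `y`, `x1`, and `x2` are hypothetical names,
-- and the wide output is assumed to have columns `const`, `x1`, and `x2`.
with coefs as (
  select * from {{
    dbt_linreg.ols(
      table=ref('training_data'),
      endog='y',
      exog=['x1', 'x2'],
      output='wide'
    )
  }} as linreg
)

select
  t.*,
  -- Fitted value: intercept plus each coefficient times its feature.
  c.const + c.x1 * t.x1 + c.x2 * t.x2 as y_hat
from {{ ref('training_data') }} as t
cross join coefs as c
```

The cross join works because an ungrouped wide regression yields a single row of coefficients; with `group_by`, you would presumably join on the grouping columns instead.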

### Why is L2 regularization / ridge regression supported, but not L1 regularization / LASSO?

There is no closed-form solution to L1 regularization, which makes it very hard to add through raw SQL. L2 regularization has a closed-form solution and can be implemented using a pre-processing trick.

### Is the `group_by=[...]` argument like categorical variables / one-hot encodings?

No. The `group_by` argument runs a separate linear regression within each group, and each individual partition is its own `y` vector and `X` matrix. This is _not_ a replacement for dummy variables.

### Why aren't categorical variables / one-hot encodings supported?

I opt to leave out dummy variable support because it's tricky, and I want to keep the API clean and mull over how best to implement that at the highest level.

Note that you couldn't simply add categorical variables in the same list as numeric variables because Jinja2 templating is not natively aware of the types you're feeding through it, nor does Jinja2 know the values that a string variable can take on. The way you would actually implement categorical variables is with `group by` trickery (i.e. center both y and X by categorical variable group means), although I am not sure how to do that efficiently for more than one categorical variable column.

If you'd like to regress on a categorical variable, for now you'll need to do your own feature engineering, e.g. `(foo = 'bar')::int as foo_bar, (foo = 'baz')::int as foo_baz`.

### Why are there no p-values?

This is something that might happen in the future. P-values would require a lookup on a dimension table, which is a significant amount of work to manage nicely.

In the meantime, you can implement this yourself: create a dimension table and left join the t-statistic to it on a half-open interval to look up a p-value.
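One way the lookup described above might look is sketched below. It is deliberately coarse: the inline lookup table uses standard two-sided critical values of the normal distribution (1.645, 1.960, 2.576), which is only a reasonable approximation to the t-distribution for large samples, and it returns an upper bound on the p-value, not an exact value. The model name `my_regression_long` and the `variable_name` column are assumptions; `coefficient` and `t_statistic` are the default column names listed earlier.

```sql
-- Sketch only: a coarse, large-sample (normal approximation) lookup.
-- Each row covers the half-open interval [abs_t_lower, abs_t_upper).
with p_value_lookup as (
  select 0.000 as abs_t_lower, 1.645 as abs_t_upper, 1.00 as p_value_upper_bound
  union all select 1.645, 1.960, 0.10
  union all select 1.960, 2.576, 0.05
  union all select 2.576, null, 0.01  -- null upper bound means "no upper limit"
)

select
  c.variable_name,
  c.coefficient,
  c.t_statistic,
  p.p_value_upper_bound
from {{ ref('my_regression_long') }} as c
left join p_value_lookup as p
  on abs(c.t_statistic) >= p.abs_t_lower
  and (abs(c.t_statistic) < p.abs_t_upper or p.abs_t_upper is null)
```

A finer-grained table, or one keyed on degrees of freedom as well, would follow the same join pattern; a real dimension table would likely live in a seed or model of its own.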

# Trademark & Copyright

dbt is a trademark of dbt Labs.

This package is **unaffiliated** with dbt Labs.

--------------------------------------------------------------------------------