├── .gitignore ├── README.md ├── pytest_dbt_adapter ├── __init__.py ├── builtin.py ├── exceptions.py ├── projects │ ├── __init__.py │ ├── base.yml │ ├── data_test_ephemerals.yml │ ├── data_tests.yml │ ├── empty.yml │ ├── ephemeral.yml │ ├── incremental.yml │ ├── schema_tests.yml │ ├── snapshot_cc.yml │ └── snapshot_ts.yml ├── sequences │ ├── __init__.py │ ├── base.yml │ ├── data_test.yml │ ├── data_test_ephemeral_models.yml │ ├── empty.yml │ ├── ephemeral.yml │ ├── incremental.yml │ ├── schema_test.yml │ ├── snapshot_strategy_check_cols.yml │ └── snapshot_strategy_timestamp.yml └── spec_file.py ├── requirements.txt ├── setup.py └── specs ├── postgres.dbtspec ├── presto.dbtspec ├── spark-databricks.dbtspec └── spark.dbtspec /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | *.pyc 3 | __pycache__ 4 | .tox/ 5 | .idea/ 6 | build/ 7 | dist/ 8 | .DS_Store 9 | .pytest_cache 10 | env 11 | 12 | .vscode 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > :warning: **THIS SUITE IS NO LONGER THE RECOMMENDED WAY TO TEST ADAPTERS.** 2 | If you are building and testing a new dbt adapter, please read instead: ["Testing a new adapter"](https://docs.getdbt.com/guides/advanced/adapter-development/4-testing-a-new-adapter) 3 | 4 | ## Installation and use 5 | 6 | `pip install pytest-dbt-adapter` 7 | 8 | 9 | After installing, write a specfile: a YAML file ending in `.dbtspec`. See the included spark/postgres examples in `specs`. You can also write custom test sequences and override existing default projects. 10 | 11 | You should then be able to run your spec with `pytest path/to/mytest.dbtspec`. You'll need dbt-core and your adapter plugin installed in the environment as well. 
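For orientation, a minimal specfile might look like the sketch below. The connection details are placeholders for a hypothetical postgres target (see `specs/postgres.dbtspec` for a real example), and `base`/`empty` are names of builtin sequences shipped in `sequences/`:

```yaml
target:
  type: postgres
  host: localhost
  port: 5432
  user: dbt_test
  pass: dbt_test
  dbname: dbt_test
  threads: 1
  # the suite fills in this var with a unique value on every run
  schema: "dbt_test_{{ var('_dbt_random_suffix') }}"
sequences:
  # keys are test names (selectable with pytest's -k flag);
  # values are inline sequence definitions or builtin sequence names
  test_dbt_base: base
  test_dbt_empty: empty
```

With this saved as `my_adapter.dbtspec`, `pytest my_adapter.dbtspec -k test_dbt_base` would run only the `base` sequence.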
12 | 13 | This package also provides a module named `dbt_adapter_tests` with helpers for writing integration tests in Python if necessary. For maintainability purposes, this should only be used as a last resort, for scenarios that are impossible to capture with a specfile. 14 | 15 | 16 | ## Specs 17 | 18 | A spec is composed of at least two blocks: 19 | - a `target` block 20 | - a `sequences` block 21 | - The keys are test names. You can select from these names with pytest's `-k` flag. 22 | - The values are test sequence definitions. 23 | 24 | Optionally, there is also: 25 | - a `projects` block 26 | 27 | ### Targets 28 | 29 | A target block is just like a target block you'd use in dbt Core, with one special requirement: the `schema` field should include `{{ var('_dbt_random_suffix') }}` somewhere; the test suite supplies a unique value for this variable on each run. 30 | 31 | 32 | ### Sequences 33 | 34 | A sequence has a `name` (the sequence name), a `project` (the name of the project to use), and a `sequence` (a list of test steps). You can declare new sequences inline, or use the name of a builtin sequence. You can find examples in the form of the builtin sequences in the `sequences/` folder. 35 | 36 | You are encouraged to use as many sequences as you can from the builtin list without modification. 37 | 38 | 39 | ### Projects 40 | 41 | The minimal project contains only a `name` field, whose value is the project's name; sequences refer to projects by this name. 42 | 43 | A project also has an optional `paths` block, where the keys are file paths (relative to the `dbt_project.yml` that will be written) and the values are the contents of those files. 44 | 45 | There is also an optional `dbt_project_yml` block: a dictionary that will be merged into the default `dbt_project.yml` (which sets name, version, and config-version). 
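Putting those pieces together, a project definition might look like the following sketch. `my_project` and `my_model` are placeholder names; `files.seeds.base` and `files.tests.passing` are known-file references resolved by the suite, and `test-paths` mirrors the builtin `data_tests` project:

```yaml
projects:
  - name: my_project
    paths:
      # key: path relative to the generated dbt_project.yml
      # value: a known-file reference, or literal file contents
      seeds/base.csv: files.seeds.base
      tests/passing.sql: files.tests.passing
      models/my_model.sql: |
        select * from {{ source('raw', 'seed') }}
    dbt_project_yml:
      # merged into the default dbt_project.yml
      test-paths:
        - tests
```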
46 | 47 | 48 | Instead of declaring a `name` field, a project definition may have an `overrides` field that names a builtin project. The test suite will update the named builtin project with those overrides, instead of overwriting the full project with a new one. 49 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | from .spec_file import DbtSpecFile 2 | 3 | 4 | # custom test loader 5 | def pytest_collect_file(parent, path): 6 | if path.ext == ".dbtspec": # and path.basename.startswith("test"): 7 | return DbtSpecFile.from_parent(parent, fspath=path) 8 | 9 | 10 | def pytest_addoption(parser): 11 | group = parser.getgroup('dbtadapter') 12 | group.addoption( 13 | '--no-drop-schema', 14 | action='store_false', 15 | dest='drop_schema', 16 | 17 | ) 18 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/builtin.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Dict, Any 4 | 5 | import yaml 6 | 7 | from .exceptions import TestProcessingException 8 | 9 | 10 | DEFAULT_DBT_PROJECT = { 11 | 'name': 'dbt_test_project', 12 | 'config-version': 2, 13 | 'version': '1.0.0', 14 | } 15 | 16 | 17 | NAMES_BASE = """ 18 | id,name,some_date 19 | 1,Easton,1981-05-20T06:46:51 20 | 2,Lillian,1978-09-03T18:10:33 21 | 3,Jeremiah,1982-03-11T03:59:51 22 | 4,Nolan,1976-05-06T20:21:35 23 | 5,Hannah,1982-06-23T05:41:26 24 | 6,Eleanor,1991-08-10T23:12:21 25 | 7,Lily,1971-03-29T14:58:02 26 | 8,Jonathan,1988-02-26T02:55:24 27 | 9,Adrian,1994-02-09T13:14:23 28 | 10,Nora,1976-03-01T16:51:39 29 | """.lstrip() 30 | 31 | 32 | NAMES_EXTENDED = NAMES_BASE + """ 33 | 11,Mateo,2014-09-07T17:04:27 34 | 12,Julian,2000-02-04T11:48:30 35 | 13,Gabriel,2001-07-10T07:32:52 36 | 14,Isaac,2002-11-24T03:22:28 37 | 
15,Levi,2009-11-15T11:57:15 38 | 16,Elizabeth,2005-04-09T03:50:11 39 | 17,Grayson,2019-08-06T19:28:17 40 | 18,Dylan,2014-03-01T11:50:41 41 | 19,Jayden,2009-06-06T07:12:49 42 | 20,Luke,2003-12-05T21:42:18 43 | """.lstrip() 44 | 45 | 46 | NAMES_ADD_COLUMN = """ 47 | id,name,some_date,last_initial 48 | 1,Easton,1981-05-20T06:46:51,A 49 | 2,Lillian,1978-09-03T18:10:33,B 50 | 3,Jeremiah,1982-03-11T03:59:51,C 51 | 4,Nolan,1976-05-06T20:21:35,D 52 | 5,Hannah,1982-06-23T05:41:26,E 53 | 6,Eleanor,1991-08-10T23:12:21,F 54 | 7,Lily,1971-03-29T14:58:02,G 55 | 8,Jonathan,1988-02-26T02:55:24,H 56 | 9,Adrian,1994-02-09T13:14:23,I 57 | 10,Nora,1976-03-01T16:51:39,J 58 | """.lstrip() 59 | 60 | 61 | class Model: 62 | def __init__(self, config, body): 63 | self.config = config 64 | self.body = body 65 | 66 | @classmethod 67 | def from_dict(cls, dct): 68 | try: 69 | config = dct.get('config', {}) 70 | if 'materialized' in dct: 71 | config['materialized'] = dct['materialized'] 72 | return cls(config=config, body=dct['body']) 73 | except KeyError as exc: 74 | raise TestProcessingException( 75 | f'Invalid test, model is missing key {exc}' 76 | ) 77 | 78 | def config_params(self): 79 | if not self.config: 80 | return '' 81 | else: 82 | pairs = ', '.join( 83 | '{!s}={!r}'.format(key, value) 84 | for key, value in self.config.items() 85 | ) 86 | return '{{ config(' + pairs + ') }}' 87 | 88 | def render(self): 89 | return '\n'.join([self.config_params(), self.body]) 90 | 91 | 92 | class DbtProject: 93 | def __init__( 94 | self, 95 | name: str, 96 | dbt_project_yml: str, 97 | paths: Dict[str, str], 98 | facts: Dict[str, Any] 99 | ): 100 | self.name = name 101 | self.dbt_project_yml = dbt_project_yml 102 | self.paths = paths 103 | self.facts = facts 104 | 105 | def write(self, path: str): 106 | project_path = os.path.join(path, 'project') 107 | os.makedirs(project_path) 108 | with open(os.path.join(project_path, 'dbt_project.yml'), 'w') as fp: 109 | 
fp.write(yaml.safe_dump(self.dbt_project_yml)) 110 | 111 | for relpath, contents in self.paths.items(): 112 | fullpath = os.path.join(project_path, relpath) 113 | os.makedirs(os.path.dirname(fullpath), exist_ok=True) 114 | 115 | if isinstance(contents, str) and contents.startswith('files.'): 116 | contents_path = contents.split('.')[1:] 117 | cur = KNOWN_FILES 118 | for part in contents_path: 119 | if part not in cur: 120 | raise TestProcessingException( 121 | f'at known file lookup {contents}, could not find ' 122 | f'part {part} in known files path' 123 | ) 124 | cur = cur[part] 125 | contents = cur 126 | 127 | if relpath.startswith('models/') and relpath.endswith('.sql'): 128 | if isinstance(contents, dict): 129 | model = Model.from_dict(contents) 130 | contents = model.render() 131 | if not isinstance(contents, str): 132 | raise TestProcessingException(f'{contents} is not a string') 133 | 134 | with open(fullpath, 'w') as fp: 135 | fp.write(contents) 136 | 137 | @classmethod 138 | def from_dict(cls, dct, overriding=None): 139 | if overriding is None: 140 | overriding = {} 141 | 142 | paths: Dict[str, Any] 143 | facts: Dict[str, Any] 144 | dbt_project_yml: Dict[str, Any] 145 | 146 | if 'overrides' in dct: 147 | name = dct['overrides'] 148 | if name not in overriding: 149 | raise TestProcessingException( 150 | f'Invalid project definition, override name {name} not ' 151 | 'known' 152 | ) from None 153 | dbt_project_yml = overriding[name].dbt_project_yml.copy() 154 | paths = overriding[name].paths.copy() 155 | facts = overriding[name].facts.copy() 156 | else: 157 | try: 158 | name = dct['name'] 159 | except KeyError: 160 | raise TestProcessingException( 161 | f'Invalid project definition, no name in {dct}' 162 | ) from None 163 | 164 | dbt_project_yml = DEFAULT_DBT_PROJECT.copy() 165 | paths = {} 166 | facts = {} 167 | 168 | dbt_project_yml.update(dct.get('dbt_project_yml', {})) 169 | 170 | paths.update(dct.get('paths', {})) 171 | facts.update(dct.get('facts', 
{})) 172 | return cls( 173 | name=name, 174 | dbt_project_yml=dbt_project_yml, 175 | paths=paths, 176 | facts=facts, 177 | ) 178 | 179 | 180 | SEED_SOURCE_YML = """ 181 | version: 2 182 | sources: 183 | - name: raw 184 | schema: "{{ target.schema }}" 185 | tables: 186 | - name: seed 187 | identifier: "{{ var('seed_name', 'base') }}" 188 | """ 189 | 190 | TEST_SEEDS_SCHEMA_YML_TEST_BASE = """ 191 | version: 2 192 | models: 193 | - name: base 194 | columns: 195 | - name: id 196 | tests: 197 | - not_null 198 | """ 199 | 200 | TEST_MODELS_SCHEMA_YML_TEST_VIEW = """ 201 | version: 2 202 | models: 203 | - name: view 204 | columns: 205 | - name: id 206 | tests: 207 | - not_null 208 | """ 209 | 210 | TEST_MODELS_SCHEMA_YML_TEST_TABLE = """ 211 | version: 2 212 | models: 213 | - name: table 214 | columns: 215 | - name: id 216 | tests: 217 | - not_null 218 | """ 219 | 220 | 221 | TEST_PASSING_DATA_TEST = """ 222 | select * from ( 223 | select 1 as id 224 | ) as my_subquery 225 | where id = 2 226 | """ 227 | 228 | TEST_FAILING_DATA_TEST = """ 229 | select * from ( 230 | select 1 as id 231 | ) as my_subquery 232 | where id = 1 233 | """ 234 | 235 | 236 | TEST_EPHEMERAL_DATA_TEST_PASSING = ''' 237 | with my_other_cool_cte as ( 238 | select id, name from {{ ref('ephemeral') }} 239 | where id > 1000 240 | ) 241 | select name, id from my_other_cool_cte 242 | ''' 243 | 244 | 245 | TEST_EPHEMERAL_DATA_TEST_FAILING = ''' 246 | with my_other_cool_cte as ( 247 | select id, name from {{ ref('ephemeral') }} 248 | where id < 1000 249 | ) 250 | select name, id from my_other_cool_cte 251 | ''' 252 | 253 | INCREMENTAL_MODEL = """ 254 | select * from {{ source('raw', 'seed') }} 255 | {% if is_incremental() %} 256 | where id > (select max(id) from {{ this }}) 257 | {% endif %} 258 | """.strip() 259 | 260 | 261 | CC_ALL_SNAPSHOT_SQL = ''' 262 | {% snapshot cc_all_snapshot %} 263 | {{ config( 264 | check_cols='all', unique_key='id', strategy='check', 265 | target_database=database, 
target_schema=schema 266 | ) }} 267 | select * from {{ ref(var('seed_name', 'base')) }} 268 | {% endsnapshot %} 269 | '''.strip() 270 | 271 | 272 | CC_NAME_SNAPSHOT_SQL = ''' 273 | {% snapshot cc_name_snapshot %} 274 | {{ config( 275 | check_cols=['name'], unique_key='id', strategy='check', 276 | target_database=database, target_schema=schema 277 | ) }} 278 | select * from {{ ref(var('seed_name', 'base')) }} 279 | {% endsnapshot %} 280 | '''.strip() 281 | 282 | 283 | CC_DATE_SNAPSHOT_SQL = ''' 284 | {% snapshot cc_date_snapshot %} 285 | {{ config( 286 | check_cols=['some_date'], unique_key='id', strategy='check', 287 | target_database=database, target_schema=schema 288 | ) }} 289 | select * from {{ ref(var('seed_name', 'base')) }} 290 | {% endsnapshot %} 291 | '''.strip() 292 | 293 | 294 | TS_SNAPSHOT_SQL = ''' 295 | {% snapshot ts_snapshot %} 296 | {{ config( 297 | strategy='timestamp', 298 | unique_key='id', 299 | updated_at='some_date', 300 | target_database=database, 301 | target_schema=schema, 302 | )}} 303 | select * from {{ ref(var('seed_name', 'base')) }} 304 | {% endsnapshot %} 305 | '''.strip() 306 | 307 | 308 | EPHEMERAL_WITH_CTE = """ 309 | with my_cool_cte as ( 310 | select name, id from {{ ref('base') }} 311 | ) 312 | select id, name from my_cool_cte where id is not null 313 | """ 314 | 315 | 316 | KNOWN_FILES = { 317 | 'seeds': { 318 | 'base': NAMES_BASE, 319 | 'newcolumns': NAMES_ADD_COLUMN, 320 | 'added': NAMES_EXTENDED, 321 | }, 322 | 'models': { 323 | 'base_materialized_var': """ 324 | {{ config(materialized=var("materialized_var", "table"))}} 325 | select * from {{ source('raw', 'seed') }} 326 | """, 327 | 'base_table': { 328 | 'materialized': 'table', 329 | 'body': "select * from {{ source('raw', 'seed') }}", 330 | }, 331 | 'base_view': { 332 | 'materialized': 'view', 333 | 'body': "select * from {{ source('raw', 'seed') }}", 334 | }, 335 | 'ephemeral': { 336 | 'materialized': 'ephemeral', 337 | 'body': "select * from {{ source('raw', 'seed') 
}}", 338 | }, 339 | 'ephemeral_with_cte': { 340 | 'materialized': 'ephemeral', 341 | 'body': EPHEMERAL_WITH_CTE, 342 | }, 343 | 'ephemeral_view': { 344 | 'materialized': 'view', 345 | 'body': "select * from {{ ref('ephemeral') }}", 346 | }, 347 | 'ephemeral_table': { 348 | 'materialized': 'table', 349 | 'body': "select * from {{ ref('ephemeral') }}", 350 | }, 351 | 'incremental': { 352 | 'materialized': 'incremental', 353 | 'body': INCREMENTAL_MODEL, 354 | } 355 | }, 356 | 'snapshots': { 357 | 'check_cols_all': CC_ALL_SNAPSHOT_SQL, 358 | 'check_cols_name': CC_NAME_SNAPSHOT_SQL, 359 | 'check_cols_date': CC_DATE_SNAPSHOT_SQL, 360 | 'timestamp': TS_SNAPSHOT_SQL, 361 | }, 362 | 'tests': { 363 | 'passing': TEST_PASSING_DATA_TEST, 364 | 'failing': TEST_FAILING_DATA_TEST, 365 | 'ephemeral': { 366 | 'passing': TEST_EPHEMERAL_DATA_TEST_PASSING, 367 | 'failing': TEST_EPHEMERAL_DATA_TEST_FAILING, 368 | } 369 | }, 370 | 'schemas': { 371 | 'base': SEED_SOURCE_YML, 372 | 'test_seed': TEST_SEEDS_SCHEMA_YML_TEST_BASE, 373 | 'test_view': TEST_MODELS_SCHEMA_YML_TEST_VIEW, 374 | 'test_table': TEST_MODELS_SCHEMA_YML_TEST_TABLE, 375 | }, 376 | } 377 | 378 | 379 | THIS_DIR = Path(__file__).parent 380 | PROJECT_DIR = THIS_DIR / 'projects' 381 | SEQUENCE_DIR = THIS_DIR / 'sequences' 382 | 383 | 384 | def _get_named_yaml_dicts(path: Path) -> Dict[str, Dict[str, Any]]: 385 | result = {} 386 | for project_path in path.glob('**/*.yml'): 387 | try: 388 | data = yaml.safe_load(project_path.read_text()) 389 | name = data['name'] 390 | except KeyError as exc: 391 | raise ImportError( 392 | f'Invalid project file at {project_path}: no name' 393 | ) from exc 394 | except Exception as exc: 395 | raise ImportError( 396 | f'Could not read project file at {project_path}: {exc}' 397 | ) from exc 398 | result[name] = data 399 | return result 400 | 401 | 402 | DEFAULT_PROJECTS = _get_named_yaml_dicts(PROJECT_DIR) 403 | 404 | BUILTIN_TEST_SEQUENCES = _get_named_yaml_dicts(SEQUENCE_DIR) 405 | 
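The `files.` indirection that `DbtProject.write` performs over `KNOWN_FILES` amounts to a dotted-path lookup into a nested dict. A standalone sketch of that lookup loop:

```python
from typing import Any, Dict


def resolve_known_file(ref: str, known: Dict[str, Any]) -> Any:
    """Resolve a reference like 'files.seeds.base' against a nested dict,
    mirroring the lookup in DbtProject.write."""
    parts = ref.split('.')[1:]  # drop the leading 'files' segment
    cur: Any = known
    for part in parts:
        if part not in cur:
            raise KeyError(
                f'at known file lookup {ref}, could not find part {part}'
            )
        cur = cur[part]
    return cur


# hypothetical known-files mapping for illustration
known_files = {'seeds': {'base': 'id,name\n1,Easton\n'}}
seed_csv = resolve_known_file('files.seeds.base', known_files)
```

Any value that does not start with `files.` is written to disk verbatim, so project authors can mix references and literal file contents freely.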
-------------------------------------------------------------------------------- /pytest_dbt_adapter/exceptions.py: -------------------------------------------------------------------------------- 1 | class DBTException(Exception): 2 | """ custom exception for error reporting. """ 3 | 4 | 5 | class TestProcessingException(Exception): 6 | pass 7 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/projects/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbt-labs/dbt-adapter-tests/c447c95ec6420d738729eee34dfc379a49ba64dc/pytest_dbt_adapter/projects/__init__.py -------------------------------------------------------------------------------- /pytest_dbt_adapter/projects/base.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | paths: 3 | seeds/base.csv: files.seeds.base 4 | models/view_model.sql: files.models.base_view 5 | models/table_model.sql: files.models.base_table 6 | models/swappable.sql: files.models.base_materialized_var 7 | models/schema.yml: files.schemas.base 8 | dbt_project_yml: 9 | models: 10 | dbt_test_project: 11 | 12 | facts: 13 | seed: 14 | length: 1 15 | names: 16 | - base 17 | run: 18 | length: 3 19 | names: 20 | - view_model 21 | - table_model 22 | - swappable 23 | catalog: 24 | nodes: 25 | length: 4 26 | sources: 27 | length: 1 28 | persisted_relations: 29 | - base 30 | - view_model 31 | - table_model 32 | - swappable 33 | base: 34 | rowcount: 10 35 | expected_types_view: 36 | base: table 37 | view_model: view 38 | table_model: table 39 | swappable: view 40 | expected_types_table: 41 | base: table 42 | view_model: view 43 | table_model: table 44 | swappable: table 45 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/projects/data_test_ephemerals.yml: 
-------------------------------------------------------------------------------- 1 | name: data_test_ephemeral_models 2 | paths: 3 | seeds/base.csv: files.seeds.base 4 | models/ephemeral.sql: files.models.ephemeral_with_cte 5 | models/passing_model.sql: files.tests.ephemeral.passing 6 | models/failing_model.sql: files.tests.ephemeral.failing 7 | models/schema.yml: files.schemas.base 8 | tests/passing.sql: files.tests.ephemeral.passing 9 | tests/failing.sql: files.tests.ephemeral.failing 10 | dbt_project_yml: 11 | test-paths: 12 | - tests 13 | 14 | facts: 15 | seed: 16 | length: 1 17 | names: 18 | - base 19 | run: 20 | length: 2 21 | names: 22 | - passing_model 23 | - failing_model 24 | test: 25 | length: 2 26 | names: 27 | - passing 28 | - failing 29 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/projects/data_tests.yml: -------------------------------------------------------------------------------- 1 | name: data_tests 2 | paths: 3 | tests/passing.sql: files.tests.passing 4 | tests/failing.sql: files.tests.failing 5 | dbt_project_yml: 6 | test-paths: 7 | - tests 8 | 9 | facts: 10 | test: 11 | length: 2 12 | names: 13 | - passing 14 | - failing 15 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/projects/empty.yml: -------------------------------------------------------------------------------- 1 | name: empty 2 | facts: 3 | seed: 4 | length: 0 5 | run: 6 | length: 0 7 | catalog: 8 | nodes: 9 | length: 0 10 | sources: 11 | length: 0 12 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/projects/ephemeral.yml: -------------------------------------------------------------------------------- 1 | name: ephemeral 2 | paths: 3 | seeds/base.csv: files.seeds.base 4 | models/ephemeral.sql: files.models.ephemeral 5 | models/view_model.sql: files.models.ephemeral_view 6 | models/table_model.sql: 
files.models.ephemeral_table 7 | models/schema.yml: files.schemas.base 8 | 9 | facts: 10 | seed: 11 | length: 1 12 | names: 13 | - base 14 | run: 15 | length: 2 16 | names: 17 | - view_model 18 | - table_model 19 | catalog: 20 | nodes: 21 | length: 3 22 | sources: 23 | length: 1 24 | persisted_relations: 25 | - base 26 | - view_model 27 | - table_model 28 | base: 29 | rowcount: 10 30 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/projects/incremental.yml: -------------------------------------------------------------------------------- 1 | name: incremental 2 | paths: 3 | seeds/base.csv: files.seeds.base 4 | seeds/added.csv: files.seeds.added 5 | models/incremental.sql: files.models.incremental 6 | models/schema.yml: files.schemas.base 7 | 8 | facts: 9 | seed: 10 | length: 2 11 | names: 12 | - base 13 | - added 14 | run: 15 | length: 1 16 | names: 17 | - incremental 18 | catalog: 19 | nodes: 20 | length: 3 21 | sources: 22 | length: 1 23 | persisted_relations: 24 | - base 25 | - added 26 | - incremental 27 | base: 28 | rowcount: 10 29 | added: 30 | rowcount: 20 31 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/projects/schema_tests.yml: -------------------------------------------------------------------------------- 1 | name: schema_tests 2 | paths: 3 | seeds/base.csv: files.seeds.base 4 | seeds/schema.yml: files.schemas.test_seed 5 | models/view_model.sql: files.models.base_view 6 | models/table_model.sql: files.models.base_table 7 | models/schema.yml: files.schemas.base 8 | models/schema_view.yml: files.schemas.test_view 9 | models/schema_table.yml: files.schemas.test_table 10 | 11 | facts: 12 | seed: 13 | length: 1 14 | names: 15 | - base 16 | run: 17 | length: 2 18 | names: 19 | - view_model 20 | - table_model 21 | test: 22 | length: 3 23 | catalog: 24 | nodes: 25 | length: 3 26 | sources: 27 | length: 1 28 | persisted_relations: 29 | - base 30 | - 
view_model 31 | - table_model 32 | base: 33 | rowcount: 10 34 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/projects/snapshot_cc.yml: -------------------------------------------------------------------------------- 1 | name: snapshot_strategy_check_cols 2 | paths: 3 | seeds/base.csv: files.seeds.base 4 | seeds/newcolumns.csv: files.seeds.newcolumns 5 | seeds/added.csv: files.seeds.added 6 | snapshots/cc_all_snapshot.sql: files.snapshots.check_cols_all 7 | snapshots/cc_date_snapshot.sql: files.snapshots.check_cols_date 8 | snapshots/cc_name_snapshot.sql: files.snapshots.check_cols_name 9 | 10 | facts: 11 | seed: 12 | length: 3 13 | names: 14 | - base 15 | - newcolumns 16 | - added 17 | snapshot: 18 | length: 3 19 | names: 20 | - cc_all_snapshot 21 | - cc_name_snapshot 22 | - cc_date_snapshot 23 | base: 24 | rowcount: 10 25 | added: 26 | rowcount: 20 27 | newcolumns: 28 | rowcount: 10 29 | added_plus_ten: 30 | rowcount: 30 31 | added_plus_twenty: 32 | rowcount: 40 33 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/projects/snapshot_ts.yml: -------------------------------------------------------------------------------- 1 | name: snapshot_strategy_timestamp 2 | paths: 3 | seeds/base.csv: files.seeds.base 4 | seeds/newcolumns.csv: files.seeds.newcolumns 5 | seeds/added.csv: files.seeds.added 6 | snapshots/ts_snapshot.sql: files.snapshots.timestamp 7 | 8 | facts: 9 | seed: 10 | length: 3 11 | names: 12 | - base 13 | - newcolumns 14 | - added 15 | snapshot: 16 | length: 1 17 | names: 18 | - ts_snapshot 19 | base: 20 | rowcount: 10 21 | added: 22 | rowcount: 20 23 | newcolumns: 24 | rowcount: 10 25 | added_plus_ten: 26 | rowcount: 30 27 | added_plus_twenty: 28 | rowcount: 40 29 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/sequences/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbt-labs/dbt-adapter-tests/c447c95ec6420d738729eee34dfc379a49ba64dc/pytest_dbt_adapter/sequences/__init__.py -------------------------------------------------------------------------------- /pytest_dbt_adapter/sequences/base.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | project: base 3 | sequence: 4 | - type: dbt 5 | cmd: seed 6 | - type: run_results 7 | length: fact.seed.length 8 | - type: dbt 9 | cmd: run 10 | - type: run_results 11 | length: fact.run.length 12 | - type: relation_types 13 | expect: fact.expected_types_table 14 | - type: relation_rows 15 | name: base 16 | length: fact.base.rowcount 17 | - type: relations_equal 18 | relations: fact.persisted_relations 19 | - type: dbt 20 | cmd: docs generate 21 | - type: catalog 22 | exists: True 23 | nodes: 24 | length: fact.catalog.nodes.length 25 | sources: 26 | length: fact.catalog.sources.length 27 | # now swap 28 | - type: dbt 29 | cmd: run -m swappable 30 | vars: 31 | materialized_var: view 32 | - type: run_results 33 | length: 1 34 | - type: relation_types 35 | expect: fact.expected_types_view 36 | # now incremental 37 | - type: dbt 38 | cmd: run -m swappable 39 | vars: 40 | materialized_var: incremental 41 | - type: run_results 42 | length: 1 43 | - type: relation_types 44 | expect: fact.expected_types_table 45 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/sequences/data_test.yml: -------------------------------------------------------------------------------- 1 | name: data_test 2 | project: data_tests 3 | sequence: 4 | - type: dbt 5 | cmd: test 6 | check: false 7 | - type: run_results 8 | length: fact.test.length 9 | names: fact.test.names 10 | attributes: 11 | passing.status: pass 12 | failing.status: fail 13 | 
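A `run_results` step like the one above asserts against dbt's `target/run_results.json` artifact. As an illustration only (the real assertions live in `spec_file.py`), extracting per-node statuses from a parsed artifact of that shape might look like:

```python
from typing import Any, Dict


def results_summary(run_results: Dict[str, Any]) -> Dict[str, str]:
    """Illustrative sketch: map node names to statuses from a parsed
    run_results.json artifact, assuming the dbt >= 0.19 shape where each
    result carries a unique_id like 'test.my_project.passing' and a status."""
    summary = {}
    for result in run_results['results']:
        name = result['unique_id'].split('.')[-1]
        summary[name] = result['status']
    return summary
```

Against such a summary, the `attributes` in the `data_test` sequence amount to checks like `summary['passing'] == 'pass'` and `summary['failing'] == 'fail'`.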
-------------------------------------------------------------------------------- /pytest_dbt_adapter/sequences/data_test_ephemeral_models.yml: -------------------------------------------------------------------------------- 1 | name: data_test_ephemeral_models 2 | project: data_test_ephemeral_models 3 | sequence: 4 | - type: dbt 5 | cmd: seed 6 | - type: run_results 7 | length: fact.seed.length 8 | - type: dbt 9 | cmd: test 10 | check: false 11 | - type: run_results 12 | length: fact.test.length 13 | names: fact.test.names 14 | attributes: 15 | passing.status: pass 16 | failing.status: fail 17 | - type: dbt 18 | cmd: run 19 | - type: run_results 20 | length: fact.run.length 21 | names: fact.run.names 22 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/sequences/empty.yml: -------------------------------------------------------------------------------- 1 | name: empty 2 | project: empty 3 | sequence: 4 | - type: dbt 5 | cmd: seed 6 | - type: run_results 7 | exists: True 8 | - type: dbt 9 | cmd: run 10 | - type: run_results 11 | exists: True 12 | - type: catalog 13 | exists: False 14 | - type: dbt 15 | cmd: docs generate 16 | - type: run_results 17 | exists: True 18 | - type: catalog 19 | exists: True 20 | nodes: 21 | length: fact.catalog.nodes.length 22 | sources: 23 | length: fact.catalog.sources.length 24 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/sequences/ephemeral.yml: -------------------------------------------------------------------------------- 1 | name: ephemeral 2 | project: ephemeral 3 | sequence: 4 | - type: dbt 5 | cmd: seed 6 | - type: run_results 7 | length: fact.seed.length 8 | - type: dbt 9 | cmd: run 10 | - type: run_results 11 | length: fact.run.length 12 | - type: relation_rows 13 | name: base 14 | length: fact.base.rowcount 15 | - type: relations_equal 16 | relations: fact.persisted_relations 17 | - type: dbt 18 | cmd: docs generate 
19 | - type: catalog 20 | exists: True 21 | nodes: 22 | length: fact.catalog.nodes.length 23 | sources: 24 | length: fact.catalog.sources.length 25 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/sequences/incremental.yml: -------------------------------------------------------------------------------- 1 | name: incremental 2 | project: incremental 3 | sequence: 4 | - type: dbt 5 | cmd: seed 6 | - type: run_results 7 | length: fact.seed.length 8 | - type: dbt 9 | cmd: run 10 | vars: 11 | seed_name: base 12 | - type: relation_rows 13 | name: base 14 | length: fact.base.rowcount 15 | - type: run_results 16 | length: fact.run.length 17 | - type: relations_equal 18 | relations: 19 | - base 20 | - incremental 21 | - type: dbt 22 | cmd: run 23 | vars: 24 | seed_name: added 25 | - type: relation_rows 26 | name: added 27 | length: fact.added.rowcount 28 | - type: run_results 29 | length: fact.run.length 30 | - type: relations_equal 31 | relations: 32 | - added 33 | - incremental 34 | - type: dbt 35 | cmd: docs generate 36 | - type: catalog 37 | exists: True 38 | nodes: 39 | length: fact.catalog.nodes.length 40 | sources: 41 | length: fact.catalog.sources.length 42 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/sequences/schema_test.yml: -------------------------------------------------------------------------------- 1 | name: schema_test 2 | project: schema_tests 3 | sequence: 4 | - type: dbt 5 | cmd: seed 6 | - type: dbt 7 | cmd: test -m base 8 | - type: run_results 9 | length: fact.seed.length 10 | - type: dbt 11 | cmd: run 12 | - type: dbt 13 | cmd: test 14 | length: fact.test.length 15 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/sequences/snapshot_strategy_check_cols.yml: -------------------------------------------------------------------------------- 1 | name: snapshot_strategy_check_cols 2 | project: 
snapshot_strategy_check_cols 3 | sequence: 4 | - type: dbt 5 | cmd: seed 6 | - type: run_results 7 | length: fact.seed.length 8 | - type: dbt 9 | cmd: snapshot 10 | - type: relation_rows 11 | name: cc_all_snapshot 12 | length: fact.base.rowcount 13 | - type: relation_rows 14 | name: cc_name_snapshot 15 | length: fact.base.rowcount 16 | - type: relation_rows 17 | name: cc_date_snapshot 18 | length: fact.base.rowcount 19 | # point at the "added" seed so the snapshot sees 10 new rows 20 | - type: dbt 21 | cmd: snapshot 22 | vars: 23 | seed_name: added 24 | - type: relation_rows 25 | name: cc_all_snapshot 26 | length: fact.added.rowcount 27 | - type: relation_rows 28 | name: cc_name_snapshot 29 | length: fact.added.rowcount 30 | - type: relation_rows 31 | name: cc_date_snapshot 32 | length: fact.added.rowcount 33 | # update some timestamps in the "added" seed so the snapshot sees 10 more new rows 34 | - type: update_rows 35 | name: added 36 | dst_col: some_date 37 | clause: 38 | src_col: some_date 39 | type: add_timestamp 40 | where: id > 10 and id < 21 41 | - type: dbt 42 | cmd: snapshot 43 | vars: 44 | seed_name: added 45 | - type: relation_rows 46 | name: cc_all_snapshot 47 | length: fact.added_plus_ten.rowcount 48 | - type: relation_rows 49 | # unchanged: only the timestamp changed 50 | name: cc_name_snapshot 51 | length: fact.added.rowcount 52 | - type: relation_rows 53 | name: cc_date_snapshot 54 | length: fact.added_plus_ten.rowcount 55 | - type: update_rows 56 | name: added 57 | dst_col: name 58 | clause: 59 | src_col: name 60 | type: add_string 61 | value: _updated 62 | where: id < 11 63 | - type: dbt 64 | cmd: snapshot 65 | vars: 66 | seed_name: added 67 | - type: relation_rows 68 | name: cc_all_snapshot 69 | length: fact.added_plus_twenty.rowcount 70 | - type: relation_rows 71 | name: cc_name_snapshot 72 | length: fact.added_plus_ten.rowcount 73 | # does not see name updates 74 | - type: relation_rows 75 | name: cc_date_snapshot 76 | length: 
fact.added_plus_ten.rowcount 77 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/sequences/snapshot_strategy_timestamp.yml: -------------------------------------------------------------------------------- 1 | name: snapshot_strategy_timestamp 2 | project: snapshot_strategy_timestamp 3 | sequence: 4 | - type: dbt 5 | cmd: seed 6 | - type: run_results 7 | length: fact.seed.length 8 | - type: dbt 9 | cmd: snapshot 10 | - type: relation_rows 11 | name: ts_snapshot 12 | length: fact.base.rowcount 13 | # point at the "added" seed so the snapshot sees 10 new rows 14 | - type: dbt 15 | cmd: snapshot 16 | vars: 17 | seed_name: added 18 | - type: relation_rows 19 | name: ts_snapshot 20 | length: fact.added.rowcount 21 | # update some timestamps in the "added" seed so the snapshot sees 10 more new rows 22 | - type: update_rows 23 | name: added 24 | dst_col: some_date 25 | clause: 26 | src_col: some_date 27 | type: add_timestamp 28 | where: id > 10 and id < 21 29 | - type: dbt 30 | cmd: snapshot 31 | vars: 32 | seed_name: added 33 | - type: relation_rows 34 | name: ts_snapshot 35 | length: fact.added_plus_ten.rowcount 36 | - type: update_rows 37 | name: added 38 | dst_col: name 39 | clause: 40 | src_col: name 41 | type: add_string 42 | value: _updated 43 | where: id < 11 44 | - type: dbt 45 | cmd: snapshot 46 | vars: 47 | seed_name: added 48 | # no change in row count, because the timestamp was not updated 49 | - type: relation_rows 50 | name: ts_snapshot 51 | length: fact.added_plus_ten.rowcount 52 | -------------------------------------------------------------------------------- /pytest_dbt_adapter/spec_file.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import shlex 5 | import tempfile 6 | from datetime import datetime 7 | from itertools import chain, repeat 8 | from subprocess import run, CalledProcessError, PIPE 9 | from typing 
import Dict, Any, Iterable 10 | 11 | import pytest 12 | import yaml 13 | 14 | from .exceptions import TestProcessingException, DBTException 15 | from .builtin import BUILTIN_TEST_SEQUENCES, DEFAULT_PROJECTS, DbtProject 16 | 17 | 18 | from dbt.adapters.factory import FACTORY 19 | from dbt.config import RuntimeConfig 20 | from dbt.main import parse_args 21 | 22 | 23 | class DbtSpecFile(pytest.File): 24 | def collect(self): 25 | with self.fspath.open() as fp: 26 | raw = yaml.safe_load(fp) 27 | if not raw: 28 | return 29 | try: 30 | raw_target = raw['target'] 31 | except KeyError: 32 | raise TestProcessingException( 33 | 'Invalid dbtspec: target not found' 34 | ) from None 35 | 36 | projects = { 37 | k: DbtProject.from_dict(v) 38 | for k, v in DEFAULT_PROJECTS.items() 39 | } 40 | 41 | for project in raw.get('projects', []): 42 | parsed = DbtProject.from_dict(project, projects) 43 | projects[parsed.name] = parsed 44 | 45 | try: 46 | sequences = raw['sequences'] 47 | except KeyError: 48 | raise TestProcessingException( 49 | 'Invalid dbtspec: sequences not found' 50 | ) from None 51 | 52 | for name, testdef in sequences.items(): 53 | if isinstance(testdef, str): 54 | try: 55 | testdef = BUILTIN_TEST_SEQUENCES[testdef] 56 | except KeyError: 57 | raise TestProcessingException( 58 | f'Unknown builtin test name {testdef}' 59 | ) 60 | try: 61 | project_name = testdef['project'] 62 | except KeyError: 63 | raise TestProcessingException( 64 | f'Invalid dbtspec: no project in sequence {testdef}' 65 | ) from None 66 | 67 | try: 68 | project = projects[project_name] 69 | except KeyError: 70 | raise TestProcessingException( 71 | f'Invalid dbtspec: project {project_name} unknown' 72 | ) from None 73 | 74 | try: 75 | sequence = testdef['sequence'] 76 | except KeyError: 77 | raise TestProcessingException( 78 | f'Invalid dbtspec: no sequence in sequence {testdef}' 79 | ) from None 80 | 81 | yield DbtItem.from_parent( 82 | self, 83 | name=name, 84 | target=raw_target, 85 | 
sequence=sequence, 86 | project=project, 87 | ) 88 | 89 | 90 | class DbtItem(pytest.Item): 91 | def __init__(self, name, parent, target, sequence, project): 92 | super().__init__(name, parent) 93 | self.target = target 94 | self.sequence = sequence 95 | self.project = project 96 | self.adapter = None 97 | self.schema_relation = None 98 | start = datetime.utcnow().strftime('%y%m%d%H%M%S%f') 99 | randval = random.SystemRandom().randint(0, 999999) 100 | self.random_suffix = f'{start}{randval:06}' 101 | 102 | def _base_vars(self): 103 | return {'_dbt_random_suffix': self.random_suffix} 104 | 105 | def _get_adapter(self, tmpdir): 106 | project_path = os.path.join(tmpdir, 'project') 107 | args = parse_args([ 108 | 'compile', '--profile', 'dbt-pytest', '--target', 'default', 109 | '--project-dir', project_path, '--profiles-dir', tmpdir, 110 | '--vars', yaml.safe_dump(self._base_vars()), 111 | ]) 112 | with open(os.path.join(args.profiles_dir, 'profiles.yml')) as fp: 113 | data = yaml.safe_load(fp) 114 | try: 115 | profile = data[args.profile] 116 | except KeyError: 117 | raise ValueError(f'profile {args.profile} not found') 118 | try: 119 | outputs = profile['outputs'] 120 | except KeyError: 121 | raise ValueError(f'malformed profile {args.profile}') 122 | try: 123 | target = outputs[args.target] 124 | except KeyError: 125 | raise ValueError( 126 | f'target {args.target} not found in {args.profile}' 127 | ) 128 | try: 129 | adapter_type = target['type'] 130 | except KeyError: 131 | raise ValueError( 132 | f'target {args.target} in {args.profile} has no type') 133 | _ = FACTORY.load_plugin(adapter_type) 134 | config = RuntimeConfig.from_args(args) 135 | 136 | FACTORY.register_adapter(config) 137 | adapter = FACTORY.lookup_adapter(config.credentials.type) 138 | return adapter 139 | 140 | @staticmethod 141 | def _get_from_dict(dct: Dict[str, Any], keypath: Iterable[str]): 142 | value = dct 143 | for key in keypath: 144 | value = value[key] 145 | return value 146 | 147 | def 
_update_nested_dict( 148 | dct: Dict[str, Any], keypath: Iterable[str], value: Any 149 | ): 150 | next_key, keypath = keypath[0], keypath[1:] 151 | for cur_key in keypath: 152 | if next_key not in dct: 153 | dct[next_key] = {} 154 | dct = dct[next_key] 155 | next_key = cur_key 156 | dct[next_key] = value 157 | 158 | def get_fact(self, key): 159 | if isinstance(key, str) and key.startswith('fact.'): 160 | parts = key.split('.')[1:] 161 | try: 162 | return self._get_from_dict(self.project.facts, parts) 163 | except KeyError: 164 | pass 165 | return key 166 | 167 | def _relation_from_name(self, name: str): 168 | """reverse-engineer a relation (including quoting) from a given name and 169 | the adapter. 170 | 171 | This does assume that relations are split by the `.` character. 172 | 173 | Note that this doesn't really have to be correct, it only has to 174 | round-trip properly. Still, do our best to get this right. 175 | """ 176 | cls = self.adapter.Relation 177 | credentials = self.adapter.config.credentials 178 | quote_policy = cls.get_default_quote_policy().to_dict() 179 | include_policy = cls.get_default_include_policy().to_dict() 180 | kwargs = {} 181 | 182 | parts = name.split('.') 183 | if len(parts) == 0: # I think this is literally impossible! 184 | raise TestProcessingException(f'Invalid test name {name}') 185 | 186 | names = ['database', 'schema', 'identifier'] 187 | defaults = [credentials.database, credentials.schema, None] 188 | values = chain(repeat(None, 3 - len(parts)), parts) 189 | for name, value, default in zip(names, values, defaults): 190 | # no quote policy -> use the default 191 | if value is None: 192 | if default is None: 193 | include_policy[name] = False 194 | value = default 195 | else: 196 | include_policy[name] = True 197 | # if we have a value, we can figure out the quote policy. 
198 | trimmed = value[1:-1] 199 | if self.adapter.quote(trimmed) == value: 200 | quote_policy[name] = True 201 | value = trimmed 202 | else: 203 | quote_policy[name] = False 204 | kwargs[name] = value 205 | 206 | return cls.create( 207 | include_policy=include_policy, 208 | quote_policy=quote_policy, 209 | **kwargs 210 | ) 211 | 212 | def step_dbt(self, sequence_item, tmpdir): 213 | if 'cmd' not in sequence_item: 214 | raise TestProcessingException( 215 | f'Got item type cmd, but no cmd in {sequence_item}' 216 | ) 217 | cmd = shlex.split(sequence_item['cmd']) 218 | partial_parse = sequence_item.get('partial_parse', False) 219 | extra = [ 220 | '--target', 'default', 221 | '--profile', 'dbt-pytest', 222 | '--profiles-dir', tmpdir, 223 | '--project-dir', os.path.join(tmpdir, 'project') 224 | ] 225 | base_cmd = ['dbt', '--debug'] 226 | 227 | if partial_parse: 228 | base_cmd.append('--partial-parse') 229 | else: 230 | base_cmd.append('--no-partial-parse') 231 | 232 | full_cmd = base_cmd + cmd + extra 233 | cli_vars = sequence_item.get('vars', {}).copy() 234 | cli_vars.update(self._base_vars()) 235 | if cli_vars: 236 | full_cmd.extend(('--vars', yaml.safe_dump(cli_vars))) 237 | expect_passes = sequence_item.get('check', True) 238 | result = run(full_cmd, check=False, stdout=PIPE, stderr=PIPE) 239 | print(result.stdout.decode('utf-8')) 240 | if expect_passes: 241 | if result.returncode != 0: 242 | raise TestProcessingException( 243 | f'Command {full_cmd} failed, expected pass! Got ' 244 | f'rc={result.returncode}' 245 | ) 246 | else: 247 | if result.returncode == 0: 248 | raise TestProcessingException( 249 | f'Command {full_cmd} passed, expected failure! 
Got ' 250 | f'rc={result.returncode}' 251 | ) 252 | return result 253 | 254 | @staticmethod 255 | def _build_expected_attributes_dict( 256 | values: Dict[str, Any] 257 | ) -> Dict[str, Any]: 258 | # turn keys into nested dicts 259 | attributes = {} 260 | for key, value in values.items(): 261 | parts = key.split('.', 1) 262 | if len(parts) != 2: 263 | raise TestProcessingException( 264 | f'Expected a longer keypath, only got "{key}" ' 265 | '(no attributes?)' 266 | ) 267 | name, keypath = parts 268 | 269 | if name not in attributes: 270 | attributes[name] = {} 271 | attributes[name][keypath] = value 272 | return attributes 273 | 274 | @staticmethod 275 | def _get_name( 276 | result: Dict[str, Any], 277 | nodes: Dict[str, Any] 278 | ) -> str: 279 | """Given a run result get the unique_id and lookup the name from 280 | a dict of nodes mapped to their unique_id. 281 | """ 282 | try: 283 | unique_id = result['unique_id'] 284 | except KeyError as exc: 285 | raise DBTException( 286 | f'Invalid result, missing required key {exc}' 287 | ) from None 288 | try: 289 | return nodes[unique_id]['name'] 290 | except KeyError as exc: 291 | raise DBTException( 292 | f'Invalid node, missing required key {exc}' 293 | ) from None 294 | 295 | def step_run_results(self, sequence_item, tmpdir): 296 | run_results_path = os.path.join( 297 | tmpdir, 'project', 'target', 'run_results.json') 298 | manifest_path = os.path.join( 299 | tmpdir, 'project', 'target', 'manifest.json') 300 | 301 | expect_exists = sequence_item.get('exists', True) 302 | 303 | assert expect_exists == os.path.exists(run_results_path) 304 | if not expect_exists: 305 | return None 306 | 307 | try: 308 | with open( 309 | run_results_path 310 | ) as results_fp, open( 311 | manifest_path 312 | ) as manifest_fp: 313 | run_results_data = json.load(results_fp) 314 | manifest_data = json.load(manifest_fp) 315 | except Exception as exc: 316 | raise DBTException( 317 | f'could not load run_results.json: {exc}' 318 | ) from exc 319 
| try: 320 | results = run_results_data['results'] 321 | except KeyError: 322 | raise DBTException( 323 | 'Invalid run_results.json - no results' 324 | ) from None 325 | try: 326 | nodes = manifest_data['nodes'] 327 | except KeyError: 328 | raise DBTException( 329 | 'Invalid manifest.json - no nodes' 330 | ) from None 331 | 332 | if 'length' in sequence_item: 333 | expected = self.get_fact(sequence_item['length']) 334 | assert expected == len(results) 335 | if 'names' in sequence_item: 336 | expected_names = set(self.get_fact(sequence_item['names'])) 337 | extra_results_ok = sequence_item.get('extra_results_ok', False) 338 | 339 | for result in results: 340 | name = self._get_name(result, nodes) 341 | if (not extra_results_ok) and (name not in expected_names): 342 | raise DBTException( 343 | f'Got unexpected name {name} in results' 344 | ) 345 | expected_names.discard(name) 346 | if expected_names: 347 | raise DBTException( 348 | f'Nodes missing from run_results: {list(expected_names)}' 349 | ) 350 | if 'attributes' in sequence_item: 351 | values = self.get_fact(sequence_item['attributes']) 352 | 353 | attributes = self._build_expected_attributes_dict(values) 354 | 355 | for result in results: 356 | name = self._get_name(result, nodes) 357 | if name in attributes: 358 | for key, value in attributes[name].items(): 359 | try: 360 | assert self._get_from_dict(result, key.split('.')) == value 361 | except KeyError as exc: 362 | raise DBTException( 363 | f'Invalid result, missing required key {exc}' 364 | ) from None 365 | 366 | def _expected_catalog_member(self, sequence_item, catalog, member_name): 367 | if member_name not in catalog: 368 | raise DBTException( 369 | f'invalid catalog.json: no {member_name}!'
370 | ) 371 | 372 | actual = catalog[member_name] 373 | expected = sequence_item.get(member_name, {}) 374 | if 'length' in expected: 375 | expected_length = self.get_fact(expected['length']) 376 | assert len(actual) == expected_length 377 | 378 | if 'names' in expected: 379 | extra_nodes_ok = expected.get('extra_nodes_ok', False) 380 | expected_names = set(self.get_fact(expected['names'])) 381 | for node in actual.values(): 382 | try: 383 | name = node['metadata']['name'] 384 | except KeyError as exc: 385 | singular = member_name[:-1] 386 | raise TestProcessingException( 387 | f'Invalid catalog {singular}: missing key {exc}' 388 | ) from None 389 | if (not extra_nodes_ok) and (name not in expected_names): 390 | raise DBTException( 391 | f'Got unexpected name {name} in catalog' 392 | ) 393 | expected_names.discard(name) 394 | if expected_names: 395 | raise DBTException( 396 | f'{member_name.title()} missing from catalog: ' 397 | f'{list(expected_names)}' 398 | ) 399 | 400 | def step_catalog(self, sequence_item, tmpdir): 401 | path = os.path.join(tmpdir, 'project', 'target', 'catalog.json') 402 | expect_exists = sequence_item.get('exists', True) 403 | 404 | assert expect_exists == os.path.exists(path) 405 | if not expect_exists: 406 | return None 407 | 408 | try: 409 | with open(path) as fp: 410 | catalog = json.load(fp) 411 | except Exception as exc: 412 | raise DBTException( 413 | f'could not load catalog.json: {exc}' 414 | ) from exc 415 | 416 | self._expected_catalog_member(sequence_item, catalog, 'nodes') 417 | self._expected_catalog_member(sequence_item, catalog, 'sources') 418 | 419 | def step_relations_equal(self, sequence_item): 420 | if 'relations' not in sequence_item: 421 | raise TestProcessingException( 422 | 'Invalid relations_equal: no relations' 423 | ) 424 | relation_names = self.get_fact(sequence_item['relations']) 425 | assert isinstance(relation_names, list) 426 | if len(relation_names) < 2: 427 | raise TestProcessingException( 428 | 'Not
enough relations to compare', 429 | ) 430 | relations = [ 431 | self._relation_from_name(name) for name in relation_names 432 | ] 433 | with self.adapter.connection_named('_test'): 434 | basis, compares = relations[0], relations[1:] 435 | columns = [ 436 | c.name for c in self.adapter.get_columns_in_relation(basis) 437 | ] 438 | 439 | for relation in compares: 440 | sql = self.adapter.get_rows_different_sql( 441 | basis, relation, column_names=columns 442 | ) 443 | _, tbl = self.adapter.execute(sql, fetch=True) 444 | num_rows = len(tbl) 445 | assert num_rows == 1, f'Invalid sql query from get_rows_different_sql: incorrect number of rows ({num_rows})' 446 | num_cols = len(tbl[0]) 447 | assert num_cols == 2, f'Invalid sql query from get_rows_different_sql: incorrect number of cols ({num_cols})' 448 | row_count_difference = tbl[0][0] 449 | assert row_count_difference == 0, f'Got {row_count_difference} difference in row count between {basis} and {relation}' 450 | rows_mismatched = tbl[0][1] 451 | assert rows_mismatched == 0, f'Got {rows_mismatched} different rows between {basis} and {relation}' 452 | 453 | def step_relation_rows(self, sequence_item): 454 | if 'name' not in sequence_item: 455 | raise TestProcessingException('Invalid relation_rows: no name') 456 | if 'length' not in sequence_item: 457 | raise TestProcessingException('Invalid relation_rows: no length') 458 | name = self.get_fact(sequence_item['name']) 459 | length = self.get_fact(sequence_item['length']) 460 | relation = self._relation_from_name(name) 461 | with self.adapter.connection_named('_test'): 462 | _, tbl = self.adapter.execute( 463 | f'select count(*) as num_rows from {relation}', 464 | fetch=True 465 | ) 466 | 467 | assert len(tbl) == 1 and len(tbl[0]) == 1, \ 468 | 'count did not return 1 row with 1 column' 469 | assert tbl[0][0] == length, \ 470 | f'expected {name} to have {length} rows, but it has {tbl[0][0]}' 471 | 472 | def _generate_update_clause(self, clause) -> str: 473 | if 'type' not
in clause: 474 | raise TestProcessingException( 475 | 'invalid update_rows clause: no type' 476 | ) 477 | clause_type = clause['type'] 478 | 479 | if clause_type == 'add_timestamp': 480 | if 'src_col' not in clause: 481 | raise TestProcessingException( 482 | 'Invalid update_rows clause: no src_col' 483 | ) 484 | add_to = self.get_fact(clause['src_col']) 485 | kwargs = { 486 | k: self.get_fact(v) for k, v in clause.items() 487 | if k in ('interval', 'number') 488 | } 489 | with self.adapter.connection_named('_test'): 490 | return self.adapter.timestamp_add_sql( 491 | add_to=add_to, 492 | **kwargs 493 | ) 494 | elif clause_type == 'add_string': 495 | if 'src_col' not in clause: 496 | raise TestProcessingException( 497 | 'Invalid update_rows clause: no src_col' 498 | ) 499 | if 'value' not in clause: 500 | raise TestProcessingException( 501 | 'Invalid update_rows clause: no value' 502 | ) 503 | src_col = self.get_fact(clause['src_col']) 504 | value = self.get_fact(clause['value']) 505 | location = clause.get('location', 'append') 506 | with self.adapter.connection_named('_test'): 507 | return self.adapter.string_add_sql( 508 | src_col, value, location 509 | ) 510 | else: 511 | raise TestProcessingException( 512 | f'Unknown clause type in update_rows: {clause_type}' 513 | ) 514 | 515 | def step_relation_types(self, sequence_item): 516 | """ 517 | type: relation_types 518 | expect: 519 | foo: view 520 | bar: table 521 | """ 522 | if 'expect' not in sequence_item: 523 | raise TestProcessingException('Invalid relation_types: no expect') 524 | expected = self.get_fact(sequence_item['expect']) 525 | 526 | expected_relation_values = {} 527 | found_relations = [] 528 | schemas = set() 529 | 530 | for key, value in expected.items(): 531 | relation = self._relation_from_name(key) 532 | expected_relation_values[relation] = value 533 | schemas.add(relation.without_identifier()) 534 | with self.adapter.connection_named('__test'): 535 | for schema in schemas: 536 | 
found_relations.extend( 537 | self.adapter.list_relations_without_caching(schema)) 538 | 539 | for key, value in expected.items(): 540 | for relation in found_relations: 541 | # this might be too broad 542 | if relation.identifier == key: 543 | assert relation.type == value, ( 544 | f'Got an unexpected relation type of {relation.type} ' 545 | f'for relation {key}, expected {value}' 546 | ) 547 | 548 | def step_update_rows(self, sequence_item): 549 | """ 550 | type: update_rows 551 | name: base 552 | dst_col: some_date 553 | clause: 554 | type: add_timestamp 555 | src_col: some_date 556 | where: id > 10 557 | """ 558 | if 'name' not in sequence_item: 559 | raise TestProcessingException('Invalid update_rows: no name') 560 | if 'dst_col' not in sequence_item: 561 | raise TestProcessingException('Invalid update_rows: no dst_col') 562 | 563 | if 'clause' not in sequence_item: 564 | raise TestProcessingException('Invalid update_rows: no clause') 565 | 566 | clause = self.get_fact(sequence_item['clause']) 567 | if isinstance(clause, dict): 568 | clause = self._generate_update_clause(clause) 569 | 570 | where = None 571 | if 'where' in sequence_item: 572 | where = self.get_fact(sequence_item['where']) 573 | 574 | name = self.get_fact(sequence_item['name']) 575 | dst_col = self.get_fact(sequence_item['dst_col']) 576 | relation = self._relation_from_name(name) 577 | 578 | with self.adapter.connection_named('_test'): 579 | sql = self.adapter.update_column_sql( 580 | dst_name=str(relation), 581 | dst_column=dst_col, 582 | clause=clause, 583 | where_clause=where, 584 | ) 585 | self.adapter.execute(sql, auto_begin=True) 586 | self.adapter.commit_if_has_connection() 587 | 588 | def _write_profile(self, tmpdir): 589 | profile_data = { 590 | 'config': { 591 | 'send_anonymous_usage_stats': False, 592 | }, 593 | 'dbt-pytest': { 594 | 'target': 'default', 595 | 'outputs': { 596 | 'default': self.target, 597 | }, 598 | }, 599 | } 600 | with open(os.path.join(tmpdir, 'profiles.yml'), 
'w') as fp: 601 | fp.write(yaml.safe_dump(profile_data)) 602 | 603 | def _add_context(self, error_str, idx, test_item): 604 | item_type = test_item['type'] 605 | return f'{error_str} in test index {idx} (item_type={item_type})' 606 | 607 | def run_test_item(self, idx, test_item, tmpdir): 608 | try: 609 | item_type = test_item['type'] 610 | except KeyError: 611 | raise TestProcessingException( 612 | f'Could not find type in {test_item}' 613 | ) from None 614 | print(f'Executing step {idx+1}/{len(self.sequence)}') 615 | try: 616 | if item_type == 'dbt': 617 | assert os.path.exists(tmpdir) 618 | self.step_dbt(test_item, tmpdir) 619 | elif item_type == 'run_results': 620 | self.step_run_results(test_item, tmpdir) 621 | elif item_type == 'catalog': 622 | self.step_catalog(test_item, tmpdir) 623 | elif item_type == 'relations_equal': 624 | self.step_relations_equal(test_item) 625 | elif item_type == 'relation_rows': 626 | self.step_relation_rows(test_item) 627 | elif item_type == 'update_rows': 628 | self.step_update_rows(test_item) 629 | elif item_type == 'relation_types': 630 | self.step_relation_types(test_item) 631 | else: 632 | raise TestProcessingException( 633 | f'Unknown item type {item_type}' 634 | ) 635 | except AssertionError as exc: 636 | if len(exc.args) == 1: 637 | arg = self._add_context(exc.args[0], idx, test_item) 638 | exc.args = (arg,) 639 | else: # multiple args: append the context as an extra arg 640 | exc.args = exc.args + (self._add_context('', idx, test_item),) 641 | raise 642 | 643 | def runtest(self): 644 | FACTORY.reset_adapters() 645 | with tempfile.TemporaryDirectory() as tmpdir: 646 | self._write_profile(tmpdir) 647 | self.project.write(tmpdir) 648 | self.adapter = self._get_adapter(tmpdir) 649 | 650 | self.schema_relation = self.adapter.Relation.create( 651 | database=self.adapter.config.credentials.database, 652 | schema=self.adapter.config.credentials.schema, 653 | quote_policy=self.adapter.config.quoting, 654 | ) 655 | 656 | try: 657 | for idx, test_item in
enumerate(self.sequence): 658 | self.run_test_item(idx, test_item, tmpdir) 659 | finally: 660 | with self.adapter.connection_named('__test'): 661 | if self.config.getoption('drop_schema'): 662 | self.adapter.drop_schema(self.schema_relation) 663 | 664 | return True 665 | 666 | def repr_failure(self, excinfo): 667 | """ called when self.runtest() raises an exception. """ 668 | if isinstance(excinfo.value, DBTException): 669 | return "\n".join([ 670 | "usecase execution failed", 671 | " spec failed: {!r}".format(excinfo.value.args), 672 | " no further details known at this point.", 673 | ]) 674 | elif isinstance(excinfo.value, CalledProcessError): 675 | failed = str(excinfo.value.cmd) 676 | stdout = excinfo.value.stdout.decode('utf-8') 677 | stderr = excinfo.value.stderr.decode('utf-8') 678 | return '\n'.join([ 679 | f'failed to execute "{failed}":', 680 | f' output: {stdout}', 681 | f' error: {stderr}', 682 | f' rc: {excinfo.value.returncode}', 683 | ]) 684 | elif isinstance(excinfo.value, TestProcessingException): 685 | return str(excinfo.value) 686 | else: 687 | return f'Unknown error: {excinfo.value}' 688 | 689 | def reportinfo(self): 690 | return self.fspath, 0, "usecase: {}".format(self.name) 691 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dbt-core>=1.0.0b1 2 | pytest 3 | typing_extensions>=3.7.4,<3.8 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | def read(path): 5 | with open(path, encoding='utf-8') as fp: 6 | return fp.read() 7 | 8 | 9 | setup( 10 | name='pytest-dbt-adapter', 11 | packages=['pytest_dbt_adapter'], 12 | author="dbt Labs", 13 | author_email="info@dbtlabs.com", 14 | url="https://github.com/dbt-labs/dbt-adapter-tests", 15 |
version='0.6.0', 16 | package_data={ 17 | 'pytest_dbt_adapter': [ 18 | 'projects/*.yml', 19 | 'sequences/*.yml', 20 | ] 21 | }, 22 | entry_points={ 23 | 'pytest11': [ 24 | 'pytest_dbt_adapter = pytest_dbt_adapter', 25 | ] 26 | }, 27 | install_requires=['py>=1.3.0', 'pytest>=6', 'pyyaml>=3.11'], 28 | description="A pytest plugin for testing dbt adapter plugins", 29 | long_description=read('README.md'), 30 | long_description_content_type='text/markdown', 31 | python_requires=">=3.6.2", 32 | ) 33 | -------------------------------------------------------------------------------- /specs/postgres.dbtspec: -------------------------------------------------------------------------------- 1 | target: 2 | type: postgres 3 | host: localhost 4 | user: root 5 | pass: password 6 | database: dbt 7 | schema: "dbt_test_{{ var('_dbt_random_suffix') }}" 8 | port: 5432 9 | threads: 8 10 | sequences: 11 | test_dbt_empty: empty 12 | test_dbt_base: base 13 | test_dbt_ephemeral: ephemeral 14 | test_dbt_incremental: incremental 15 | test_dbt_snapshot_strategy_timestamp: snapshot_strategy_timestamp 16 | test_dbt_snapshot_strategy_check_cols: snapshot_strategy_check_cols 17 | test_dbt_data_test: data_test 18 | test_dbt_schema_test: schema_test 19 | test_dbt_ephemeral_data_tests: data_test_ephemeral_models 20 | -------------------------------------------------------------------------------- /specs/presto.dbtspec: -------------------------------------------------------------------------------- 1 | target: 2 | type: presto 3 | threads: 4 4 | host: localhost 5 | port: 8080 6 | user: presto 7 | database: memory 8 | schema: default 9 | projects: 10 | # incremental models aren't allowed 11 | - overrides: base 12 | paths: 13 | models/swappable.sql: | 14 | {% set materialized_var = "table" %} 15 | {% if var("materialized_var", "table") == "view" %} 16 | {% set materialized_var = "view" %} 17 | {% endif %} 18 | {{ config(materialized=materialized_var) }} 19 | select * from {{ source('raw', 'seed') }} 
20 | sequences: 21 | test_dbt_empty: empty 22 | test_dbt_base: base 23 | test_dbt_ephemeral: ephemeral 24 | # no incrementals, no snapshots 25 | # test_dbt_incremental: incremental 26 | # test_dbt_snapshot_strategy_timestamp: snapshot_strategy_timestamp 27 | # test_dbt_snapshot_strategy_check_cols: snapshot_strategy_check_cols 28 | test_dbt_data_test: data_test 29 | test_dbt_schema_test: schema_test 30 | test_dbt_ephemeral_data_tests: data_test_ephemeral_models 31 | -------------------------------------------------------------------------------- /specs/spark-databricks.dbtspec: -------------------------------------------------------------------------------- 1 | target: 2 | type: spark 3 | host: "{{ env_var('DBT_DATABRICKS_HOST_NAME') }}" 4 | cluster: "{{ env_var('DBT_DATABRICKS_CLUSTER_NAME') }}" 5 | token: "{{ env_var('DBT_DATABRICKS_TOKEN') }}" 6 | method: http 7 | port: 443 8 | schema: "analytics_{{ var('_dbt_random_suffix') }}" 9 | connect_retries: 5 10 | connect_timeout: 60 11 | projects: 12 | - overrides: incremental 13 | paths: 14 | "models/incremental.sql": 15 | materialized: incremental 16 | body: "select * from {{ source('raw', 'seed') }}" 17 | facts: 18 | base: 19 | rowcount: 10 20 | added: 21 | rowcount: 20 22 | - overrides: snapshot_strategy_check_cols 23 | dbt_project_yml: &file_format_delta 24 | # we're going to UPDATE the seed tables as part of testing, so we must make them delta format 25 | seeds: 26 | dbt_test_project: 27 | file_format: delta 28 | snapshots: 29 | dbt_test_project: 30 | file_format: delta 31 | - overrides: snapshot_strategy_timestamp 32 | dbt_project_yml: *file_format_delta 33 | sequences: 34 | test_dbt_empty: empty 35 | test_dbt_base: base 36 | test_dbt_ephemeral: ephemeral 37 | test_dbt_incremental: incremental 38 | test_dbt_snapshot_strategy_timestamp: snapshot_strategy_timestamp 39 | test_dbt_snapshot_strategy_check_cols: snapshot_strategy_check_cols 40 | test_dbt_data_test: data_test 41 | test_dbt_ephemeral_data_tests: 
data_test_ephemeral_models 42 | test_dbt_schema_test: schema_test 43 | 44 | -------------------------------------------------------------------------------- /specs/spark.dbtspec: -------------------------------------------------------------------------------- 1 | target: 2 | type: spark 3 | host: localhost 4 | user: dbt 5 | method: thrift 6 | port: 10000 7 | connect_retries: 5 8 | connect_timeout: 60 9 | schema: "analytics_{{ var('_dbt_random_suffix') }}" 10 | projects: 11 | - overrides: incremental 12 | paths: 13 | "models/incremental.sql": 14 | materialized: incremental 15 | body: "select * from {{ source('raw', 'seed') }}" 16 | facts: 17 | base: 18 | rowcount: 10 19 | added: 20 | rowcount: 20 21 | sequences: 22 | test_dbt_empty: empty 23 | test_dbt_base: base 24 | test_dbt_ephemeral: ephemeral 25 | test_dbt_incremental: incremental 26 | # snapshots require delta format 27 | # test_dbt_snapshot_strategy_timestamp: snapshot_strategy_timestamp 28 | # test_dbt_snapshot_strategy_check_cols: snapshot_strategy_check_cols 29 | test_dbt_data_test: data_test 30 | test_dbt_schema_test: schema_test 31 | # the local cluster currently tests on spark 2.x, which does not support this 32 | # if we upgrade it to 3.x, we can enable this test 33 | # test_dbt_ephemeral_data_tests: data_test_ephemeral_models 34 | --------------------------------------------------------------------------------
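For quick reference, a minimal specfile combining the patterns from the specs above might look like the following. This is an illustrative sketch, not a file shipped with the package: the adapter `type` and connection fields are placeholders for your own plugin's credentials, while `base`, the step types, and `fact.seed.length` are the builtins shown in the `projects/` and `sequences/` folders.

```yaml
# hypothetical myadapter.dbtspec -- run with: pytest myadapter.dbtspec
target:
  type: myadapter  # placeholder: your adapter plugin's type
  host: localhost
  schema: "dbt_test_{{ var('_dbt_random_suffix') }}"
  threads: 1
sequences:
  # reuse a builtin sequence by name
  test_dbt_base: base
  # or declare a sequence inline (needs a project and a list of steps)
  test_dbt_seed_length:
    project: base
    sequence:
      - type: dbt
        cmd: seed
      - type: run_results
        length: fact.seed.length
```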