├── BASE_VERSION ├── .python-version ├── .gitignore ├── .pre-commit-config.yaml ├── dysql ├── exceptions.py ├── test │ ├── conftest.py │ ├── test_mariadbmap_tojson.py │ ├── __init__.py │ ├── test_annotations.py │ ├── test_sql_insert_templates.py │ ├── test_sql_exists_decorator.py │ ├── test_mappers.py │ ├── test_template_generators.py │ ├── test_database_initialization.py │ ├── test_pydantic_mappers.py │ ├── test_sql_in_list_templates.py │ └── test_sql_decorator.py ├── annotations.py ├── __init__.py ├── multitenancy.py ├── pydantic_mappers.py ├── databases.py ├── test_managers.py ├── mappers.py ├── connections.py └── query_utils.py ├── Dockerfile ├── .github ├── ISSUE_TEMPLATE.md ├── PULL_REQUEST_TEMPLATE.md ├── CONTRIBUTING.md └── workflows │ └── build.yaml ├── LICENSE ├── pyproject.toml ├── CODE_OF_CONDUCT.md └── README.rst /BASE_VERSION: -------------------------------------------------------------------------------- 1 | 3.1 2 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .idea 3 | .tox 4 | venv/ 5 | build/ 6 | dist/ 7 | *.egg-info/ 8 | 9 | *.pyc 10 | *.swp 11 | pytest_cache 12 | 13 | docker-compose.env 14 | docker-compose.yml 15 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.12.7 5 | hooks: 6 | # Run the linter. 7 | - id: ruff-check 8 | args: [ --fix ] 9 | # Run the formatter. 
10 | - id: ruff-format 11 | -------------------------------------------------------------------------------- /dysql/exceptions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 Adobe 3 | All Rights Reserved. 4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 7 | """ 8 | 9 | 10 | class DBNotPreparedError(Exception): 11 | """ 12 | DB Not Prepared Error 13 | """ 14 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG python_version 2 | 3 | FROM python:${python_version}-alpine 4 | 5 | RUN apk add gcc musl-dev libffi-dev openssl-dev git 6 | 7 | WORKDIR /app 8 | 9 | # Install uv 10 | COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/ 11 | 12 | # Install dependencies first to leverage Docker cache 13 | COPY pyproject.toml uv.lock README.rst /app/ 14 | RUN uv sync --locked --no-install-project --no-dev 15 | 16 | # Install the project separately for optimal layer caching 17 | COPY dysql /app/dysql 18 | RUN uv sync --locked --no-dev -------------------------------------------------------------------------------- /dysql/test/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 Adobe 3 | All Rights Reserved. 4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 
7 | """ 8 | 9 | from unittest.mock import patch 10 | import pytest 11 | 12 | 13 | @pytest.fixture(name="mock_create_engine") 14 | def mock_create_engine_fixture(): 15 | create_mock = patch("dysql.databases.sqlalchemy.create_engine") 16 | try: 17 | yield create_mock.start() 18 | finally: 19 | create_mock.stop() 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ### Expected Behaviour 5 | 6 | ### Actual Behaviour 7 | 8 | ### Reproduce Scenario (including but not limited to) 9 | 10 | #### Steps to Reproduce 11 | 12 | #### Platform and Version 13 | 14 | #### Sample Code that illustrates the problem 15 | 16 | #### Logs taken while reproducing problem -------------------------------------------------------------------------------- /dysql/annotations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2023 Adobe 3 | All Rights Reserved. 4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 
7 | """ 8 | 9 | from typing import TypeVar, Annotated 10 | 11 | from pydantic import BeforeValidator 12 | 13 | # pylint: disable=invalid-name 14 | T = TypeVar("T") 15 | 16 | 17 | def _transform_csv(value: str) -> T: 18 | if not value: 19 | return None 20 | 21 | if isinstance(value, str): 22 | return list(map(str.strip, value.split(","))) 23 | 24 | if isinstance(value, list): 25 | return value 26 | # if we don't have a string or type T we aren't going to be able to transform it 27 | return [value] 28 | 29 | 30 | # Annotation that helps transform a CSV string into a list of type T 31 | FromCSVToList = Annotated[T, BeforeValidator(_transform_csv)] 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | © Copyright 2021 Adobe. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["uv_build"] 3 | build-backend = "uv_build" 4 | 5 | [project] 6 | name = "dy-sql" 7 | description = "Dynamically runs SQL queries and executions." 8 | readme = "README.rst" 9 | requires-python = ">=3.9" 10 | license = {text = "MIT"} 11 | authors = [ 12 | {name = "Adobe", email = "noreply@adobe.com"} 13 | ] 14 | urls = { "Homepage" = "https://github.com/adobe/dy-sql" } 15 | version = "3.2" 16 | dependencies = [ 17 | # SQLAlchemy 2+ is not yet supported 18 | "sqlalchemy<2", 19 | # now using features only found in pydantic 2+ 20 | "pydantic>=2", 21 | ] 22 | 23 | [dependency-groups] 24 | dev = [ 25 | "pytest>=6.2.4", 26 | "pytest-randomly>=3.10.1", 27 | "pytest-cov>=2.12.1", 28 | "ruff>=0.12.7", 29 | "ruff-lsp>=0.0.45", 30 | "pre-commit>=3.6", 31 | "docker>=7", 32 | ] 33 | 34 | [tool.uv.build-backend] 35 | module-name = "dysql" 36 | module-root = "" 37 | source-exclude = ["dysql/test"] 38 | wheel-exclude = ["dysql/test"] 39 | 40 | [tool.ruff.lint] 41 | # Enable the copyright rule 42 | preview = true 43 | extend-select = ["CPY"] 44 | 45 | [tool.ruff.lint.flake8-copyright] 46 | min-file-size = 200 47 | notice-rgx = '''""" 48 | Copyright [(2)\d{3}]* Adobe 49 | All Rights Reserved. 50 | 51 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 52 | with the terms of the Adobe license agreement accompanying it. 
53 | """''' 54 | -------------------------------------------------------------------------------- /dysql/test/test_mariadbmap_tojson.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 Adobe 3 | All Rights Reserved. 4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 7 | """ 8 | 9 | import json 10 | import pytest 11 | 12 | from dysql import DbMapResult 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "name, expected_json, mariadb_map", 17 | [ 18 | ( 19 | "basic", 20 | '{"id": 1, "name": "test"}', 21 | DbMapResult(id=1, name="test"), 22 | ), 23 | ( 24 | "basic_with_list", 25 | '{"id": 1, "name": "test", "my_list": ["a", "b", "c"]}', 26 | DbMapResult(id=1, name="test", my_list=["a", "b", "c"]), 27 | ), 28 | ( 29 | "inner_map", 30 | '{"id": 1, "inner_map": {"id": 2, "name": "inner_test"}}', 31 | DbMapResult(id=1, inner_map=DbMapResult(id=2, name="inner_test")), 32 | ), 33 | ( 34 | "inner_list_of_maps", 35 | '{"id": 1, "inner_map_list": [{"id": 2, "name": "inner_test_2"}, {"id": 3, "name": "inner_test_3"}]}', 36 | DbMapResult( 37 | id=1, 38 | inner_map_list=[ 39 | DbMapResult(id=2, name="inner_test_2"), 40 | DbMapResult(id=3, name="inner_test_3"), 41 | ], 42 | ), 43 | ), 44 | ], 45 | ) 46 | def test_raw_json_format(name, expected_json, mariadb_map): 47 | assert json.dumps(mariadb_map.raw()) == expected_json, "error with " + name 48 | -------------------------------------------------------------------------------- /dysql/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 Adobe 3 | All Rights Reserved. 4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 
7 | """ 8 | 9 | # Public imports 10 | from .mappers import ( 11 | BaseMapper, 12 | DbMapResult, 13 | RecordCombiningMapper, 14 | SingleRowMapper, 15 | SingleColumnMapper, 16 | SingleRowAndColumnMapper, 17 | CountMapper, 18 | KeyValueMapper, 19 | ) 20 | from .query_utils import QueryData, QueryDataError, TemplateGenerators 21 | from .connections import ( 22 | sqlexists, 23 | sqlquery, 24 | sqlupdate, 25 | ) 26 | from .databases import ( 27 | is_set_current_database_supported, 28 | reset_current_database, 29 | set_current_database, 30 | set_database_init_hook, 31 | set_default_connection_parameters, 32 | ) 33 | from .multitenancy import use_database_tenant, tenant_database_manager 34 | from .exceptions import DBNotPreparedError 35 | 36 | 37 | __all__ = [ 38 | "BaseMapper", 39 | "DbMapResult", 40 | "RecordCombiningMapper", 41 | "SingleRowMapper", 42 | "SingleColumnMapper", 43 | "SingleRowAndColumnMapper", 44 | "CountMapper", 45 | "KeyValueMapper", 46 | "QueryData", 47 | "QueryDataError", 48 | "TemplateGenerators", 49 | "sqlexists", 50 | "sqlquery", 51 | "sqlupdate", 52 | "is_set_current_database_supported", 53 | "reset_current_database", 54 | "set_current_database", 55 | "set_database_init_hook", 56 | "set_default_connection_parameters", 57 | "use_database_tenant", 58 | "tenant_database_manager", 59 | "DBNotPreparedError", 60 | ] 61 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Description 4 | 5 | 6 | 7 | ## Related Issue 8 | 9 | 10 | 11 | 12 | 13 | 14 | ## Motivation and Context 15 | 16 | 17 | 18 | ## How Has This Been Tested? 
19 | 20 | 21 | 22 | 23 | 24 | ## Screenshots (if appropriate): 25 | 26 | ## Types of changes 27 | 28 | 29 | 30 | - [ ] Bug fix (non-breaking change which fixes an issue) 31 | - [ ] New feature (non-breaking change which adds functionality) 32 | - [ ] Breaking change (fix or feature that would cause existing functionality to change) 33 | 34 | ## Checklist: 35 | 36 | 37 | 38 | 39 | - [ ] I have signed the [Adobe Open Source CLA](https://opensource.adobe.com/cla.html). 40 | - [ ] My code follows the code style of this project. 41 | - [ ] My change requires a change to the documentation. 42 | - [ ] I have updated the documentation accordingly. 43 | - [ ] I have read the **CONTRIBUTING** document. 44 | - [ ] I have added tests to cover my changes. 45 | - [ ] All new and existing tests passed. -------------------------------------------------------------------------------- /dysql/test/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 Adobe 3 | All Rights Reserved. 4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 
7 | """ 8 | 9 | from unittest.mock import Mock 10 | 11 | from dysql import set_default_connection_parameters, databases 12 | 13 | 14 | def setup_mock_engine(mock_create_engine): 15 | """ 16 | build up the basics of a mock engine for the database 17 | :return: mocked engine for use and manipulation in testing 18 | """ 19 | mock_engine = Mock() 20 | mock_engine.connect().execution_options().__enter__ = Mock() 21 | mock_engine.connect().execution_options().__exit__ = Mock() 22 | set_default_connection_parameters("fake", "user", "password", "test") 23 | 24 | # Clear out the databases before attempting to mock anything 25 | databases.DatabaseContainerSingleton().clear() 26 | mock_create_engine.return_value = mock_engine 27 | return mock_engine 28 | 29 | 30 | def verify_query_params(mock_engine, expected_query, expected_args): 31 | verify_query(mock_engine, expected_query) 32 | verify_query_args(mock_engine, expected_args) 33 | 34 | 35 | def verify_query(mock_engine, expected_query): 36 | execute_call = ( 37 | mock_engine.connect.return_value.execution_options.return_value.execute 38 | ) 39 | execute_call.assert_called() 40 | 41 | query = execute_call.call_args[0][0].text 42 | assert query == expected_query 43 | 44 | 45 | def verify_query_args(mock_engine, expected_args): 46 | execute_call = ( 47 | mock_engine.connect.return_value.execution_options.return_value.execute 48 | ) 49 | query_args = execute_call.call_args[0][1] 50 | 51 | assert query_args 52 | for expected_key in expected_args: 53 | expected_value = expected_args[expected_key] 54 | assert query_args.get(expected_key) 55 | assert expected_value == query_args[expected_key] 56 | -------------------------------------------------------------------------------- /dysql/test/test_annotations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2023 Adobe 3 | All Rights Reserved. 
4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 7 | """ 8 | 9 | from typing import List, Union 10 | 11 | import pytest 12 | from pydantic import BaseModel, ValidationError 13 | 14 | from dysql.annotations import FromCSVToList 15 | 16 | 17 | class StrCSVModel(BaseModel): 18 | values: FromCSVToList[List[str]] 19 | 20 | 21 | class IntCSVModel(BaseModel): 22 | values: FromCSVToList[List[int]] 23 | 24 | 25 | class NullableStrCSVModel(BaseModel): 26 | values: FromCSVToList[Union[List[str], None]] 27 | 28 | 29 | class NullableIntCSVModel(BaseModel): 30 | values: FromCSVToList[Union[List[int], None]] 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "cls, values, expected", 35 | [ 36 | (StrCSVModel, "1,2,3", ["1", "2", "3"]), 37 | (StrCSVModel, "a,b", ["a", "b"]), 38 | (StrCSVModel, "a", ["a"]), 39 | (NullableStrCSVModel, "", None), 40 | (NullableStrCSVModel, None, None), 41 | (StrCSVModel, ["a", "b"], ["a", "b"]), 42 | (StrCSVModel, ["a", "b", "c"], ["a", "b", "c"]), 43 | (IntCSVModel, "1,2,3", [1, 2, 3]), 44 | (IntCSVModel, "1", [1]), 45 | (NullableIntCSVModel, "", None), 46 | (NullableIntCSVModel, None, None), 47 | (IntCSVModel, ["1", "2", "3"], [1, 2, 3]), 48 | (IntCSVModel, ["1", "2", "3", 4, 5], [1, 2, 3, 4, 5]), 49 | ], 50 | ) 51 | def test_from_csv_to_list(cls, values, expected): 52 | assert expected == cls(values=values).values 53 | 54 | 55 | @pytest.mark.parametrize( 56 | "cls, values", 57 | [ 58 | (StrCSVModel, ""), 59 | (StrCSVModel, None), 60 | (IntCSVModel, "a,b,c"), 61 | (IntCSVModel, ""), 62 | (IntCSVModel, None), 63 | (IntCSVModel, ["a", "b", "c"]), 64 | ], 65 | ) 66 | def test_from_csv_to_list_invalid(cls, values): 67 | with pytest.raises(ValidationError): 68 | cls(values=values) 69 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: 
-------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thanks for choosing to contribute! 4 | 5 | The following are a set of guidelines to follow when contributing to this project. 6 | 7 | ## Code Of Conduct 8 | 9 | This project adheres to the Adobe [code of conduct](../CODE_OF_CONDUCT.md). By participating, 10 | you are expected to uphold this code. Please report unacceptable behavior to 11 | [Grp-opensourceoffice@adobe.com](mailto:Grp-opensourceoffice@adobe.com). 12 | 13 | ## Have A Question? 14 | 15 | Start by filing an issue. The existing committers on this project work to reach 16 | consensus around project direction and issue solutions within issue threads 17 | (when appropriate). 18 | 19 | ## Contributor License Agreement 20 | 21 | All third-party contributions to this project must be accompanied by a signed contributor 22 | license agreement. This gives Adobe permission to redistribute your contributions 23 | as part of the project. [Sign our CLA](https://opensource.adobe.com/cla.html). You 24 | only need to submit an Adobe CLA one time, so if you have submitted one previously, 25 | you are good to go! 26 | 27 | ## Code Reviews 28 | 29 | All submissions should come in the form of pull requests and need to be reviewed 30 | by project committers. Read [GitHub's pull request documentation](https://help.github.com/articles/about-pull-requests/) 31 | for more information on sending pull requests. 32 | 33 | Lastly, please follow the [pull request template](PULL_REQUEST_TEMPLATE.md) when 34 | submitting a pull request! 35 | 36 | ## From Contributor To Committer 37 | 38 | We love contributions from our community! If you'd like to go a step beyond contributor 39 | and become a committer with full write access and a say in the project, you must 40 | be invited to the project. 
The existing committers employ an internal nomination 41 | process that must reach lazy consensus (silence is approval) before invitations 42 | are issued. If you feel you are qualified and want to get more deeply involved, 43 | feel free to reach out to existing committers to have a conversation about that. 44 | 45 | ## Security Issues 46 | 47 | Security issues shouldn't be reported on this issue tracker. Instead, [file an issue to our security experts](https://helpx.adobe.com/security/alertus.html). -------------------------------------------------------------------------------- /dysql/multitenancy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2025 Adobe 3 | All Rights Reserved. 4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 7 | """ 8 | 9 | import functools 10 | import logging 11 | from typing import Any, Callable, TypeVar 12 | from contextlib import contextmanager 13 | 14 | from dysql.exceptions import DBNotPreparedError 15 | from dysql import ( 16 | set_current_database, 17 | reset_current_database, 18 | ) 19 | 20 | LOGGER = logging.getLogger(__name__) 21 | 22 | F = TypeVar("F", bound=Callable[..., Any]) 23 | 24 | 25 | @contextmanager 26 | def tenant_database_manager(database_key: str): 27 | """ 28 | Context manager for temporarily switching to a different database. 
29 | 30 | :param database_key: the database key to switch to 31 | :raises DBNotPreparedError: if the database key is not set 32 | """ 33 | if not database_key: 34 | raise DBNotPreparedError( 35 | "Cannot switch to database tenant with empty database key" 36 | ) 37 | 38 | try: 39 | LOGGER.debug(f"Switching to database {database_key}") 40 | set_current_database(database_key) 41 | yield 42 | except Exception as e: 43 | LOGGER.error(f"Error while using database {database_key}: {e}") 44 | raise 45 | finally: 46 | try: 47 | reset_current_database() 48 | LOGGER.debug(f"Reset database context from: {database_key}") 49 | except Exception as e: 50 | LOGGER.error(f"Error resetting database context: {e}") 51 | # Don't re-raise here to avoid masking the original exception 52 | 53 | 54 | def use_database_tenant(database_key: str): 55 | """ 56 | Decorator that switches to a specific database for the duration of the function call. 57 | :param database_key: the database key to use 58 | :return: the decorator function 59 | """ 60 | 61 | def decorator(func: F) -> F: 62 | @functools.wraps(func) 63 | def wrapper(*args, **kwargs): 64 | with tenant_database_manager(database_key): 65 | return func(*args, **kwargs) 66 | 67 | return wrapper 68 | 69 | return decorator 70 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Adobe Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, gender identity and expression, level of experience, 9 | nationality, personal appearance, race, religion, or sexual identity and 10 | orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language. 18 | * Being respectful of differing viewpoints and experiences. 19 | * Gracefully accepting constructive criticism. 20 | * Focusing on what is best for the community. 21 | * Showing empathy towards other community members. 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances. 27 | * Trolling, insulting/derogatory comments, and personal or political attacks. 28 | * Public or private harassment. 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission. 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting. 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. 
Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at Grp-opensourceoffice@adobe.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at [https://contributor-covenant.org/version/1/4][version]. 
72 | 73 | [homepage]: https://contributor-covenant.org 74 | [version]: https://contributor-covenant.org/version/1/4/ 75 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Build 2 | on: 3 | - push 4 | - pull_request 5 | jobs: 6 | test: 7 | if: github.repository == 'adobe/dy-sql' 8 | runs-on: ubuntu-latest 9 | strategy: 10 | fail-fast: false 11 | matrix: 12 | python-version: 13 | - '3.9' 14 | - '3.10' 15 | - '3.11' 16 | - '3.12' 17 | - '3.13' 18 | steps: 19 | - uses: actions/checkout@v2 20 | with: 21 | # Fetch all history instead of the latest commit 22 | fetch-depth: 0 23 | - name: Install uv 24 | uses: astral-sh/setup-uv@v6 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | # Install dev dependencies 28 | - name: Install dependencies 29 | run: uv sync --locked 30 | - name: Run pre-commit checks 31 | run: uv run pre-commit run --all-files 32 | - name: Test with pytest 33 | run: uv run pytest --junitxml=test-reports/test-results.xml 34 | - name: Publish test results 35 | uses: EnricoMi/publish-unit-test-result-action/composite@v1 36 | if: github.event_name == 'push' && always() 37 | with: 38 | files: test-reports/test-results.xml 39 | check_name: "Test Results ${{ matrix.python-version }}" 40 | get-version: 41 | if: github.repository == 'adobe/dy-sql' 42 | runs-on: ubuntu-latest 43 | needs: test 44 | outputs: 45 | current-version: ${{ steps.version-number.outputs.CURRENT_VERSION }} 46 | steps: 47 | - uses: actions/checkout@v2 48 | with: 49 | # Fetch all history instead of the latest commit 50 | fetch-depth: 0 51 | - name: Install uv 52 | uses: astral-sh/setup-uv@v6 53 | with: 54 | python-version: 3.9 55 | - name: Install dependencies 56 | run: uv sync --locked 57 | - name: Get current version 58 | id: version-number 59 | run: echo "CURRENT_VERSION=$(uv version --short).$(git rev-list --count HEAD)" >> 
$GITHUB_OUTPUT 60 | - name: Print current version 61 | run: echo CURRENT_VERSION ${{ steps.version-number.outputs.CURRENT_VERSION }} 62 | tag-commit: 63 | if: github.repository == 'adobe/dy-sql' && github.event_name == 'push' && github.ref == 'refs/heads/main' 64 | runs-on: ubuntu-latest 65 | needs: [test, get-version] 66 | steps: 67 | - uses: actions/checkout@v2 68 | with: 69 | # Fetch all history instead of the latest commit 70 | fetch-depth: 0 71 | - name: Tag commit 72 | run: git tag ${{ needs.get-version.outputs.current-version }} && git push --tags 73 | publish-pypi: 74 | if: github.repository == 'adobe/dy-sql' 75 | runs-on: ubuntu-latest 76 | needs: test 77 | steps: 78 | - uses: actions/checkout@v2 79 | with: 80 | # Fetch all history instead of the latest commit 81 | fetch-depth: 0 82 | - name: Install uv 83 | uses: astral-sh/setup-uv@v6 84 | with: 85 | python-version: 3.9 86 | - name: Install dependencies 87 | run: uv sync --locked 88 | - name: Set version 89 | run: uv version --no-sync "$(uv version --short).$(git rev-list --count HEAD)" 90 | - name: Build 91 | run: uv build 92 | - name: Check upload 93 | run: uv run --with twine twine check dist/* 94 | - name: Publish to PyPi 95 | uses: pypa/gh-action-pypi-publish@release/v1 96 | # Only publish on pushes to main 97 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' 98 | with: 99 | user: __token__ 100 | password: ${{ secrets.ADOBE_BOT_PYPI_TOKEN }} 101 | 102 | -------------------------------------------------------------------------------- /dysql/test/test_sql_insert_templates.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 Adobe 3 | All Rights Reserved. 4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 
7 | """ 8 | 9 | import pytest 10 | 11 | import dysql 12 | from dysql import QueryData, sqlupdate, QueryDataError 13 | from dysql.test import ( 14 | verify_query, 15 | verify_query_args, 16 | setup_mock_engine, 17 | ) 18 | 19 | 20 | @pytest.fixture(name="mock_engine", autouse=True) 21 | def mock_engine_fixture(mock_create_engine): 22 | initial_id = 0 23 | 24 | def handle_execute(query=None, args=None): 25 | nonlocal initial_id 26 | if "INSERT INTO get_last(name)" in query.text: 27 | initial_id += 1 28 | if "SELECT LAST_INSERT_ID()" == query.text: 29 | return type("Result", (), {"scalar": lambda: initial_id}) 30 | return [] 31 | 32 | mock_engine = setup_mock_engine(mock_create_engine) 33 | execute_mock = mock_engine.connect().execution_options().execute 34 | execute_mock.side_effect = handle_execute 35 | return mock_engine 36 | 37 | 38 | def test_insert_non_query_data_fails(): 39 | @sqlupdate() 40 | def select_with_string(): 41 | return "INSERT" 42 | 43 | with pytest.raises(QueryDataError): 44 | select_with_string() 45 | 46 | 47 | def test_insert_single_column(mock_engine): 48 | insert_into_single_value(["Tom", "Jerry"]) 49 | verify_query( 50 | mock_engine, 51 | "INSERT INTO table(name) VALUES ( :values__name_col_0 ), ( :values__name_col_1 ) ", 52 | ) 53 | 54 | 55 | def test_insert_single_column_single_value(mock_engine): 56 | insert_into_single_value("Tom") 57 | verify_query(mock_engine, "INSERT INTO table(name) VALUES ( :values__name_col_0 ) ") 58 | 59 | 60 | def test_insert_single_value_empty(): 61 | with pytest.raises( 62 | dysql.query_utils.ListTemplateException, match="['values_name_col']" 63 | ): 64 | insert_into_single_value([]) 65 | 66 | 67 | def test_insert_single_value_no_key(): 68 | with pytest.raises( 69 | dysql.query_utils.ListTemplateException, match="['values_name_col']" 70 | ): 71 | insert_into_single_value(None) 72 | 73 | 74 | def test_insert_multiple_values(mock_engine): 75 | insert_into_multiple_values( 76 | [ 77 | {"name": "Tom", "email": 
"tom@adobe.com"}, 78 | {"name": "Jerry", "email": "jerry@adobe.com"}, 79 | ] 80 | ) 81 | verify_query( 82 | mock_engine, 83 | "INSERT INTO table(name, email) VALUES ( :values__users_0_0, :values__users_0_1 ), " 84 | "( :values__users_1_0, :values__users_1_1 ) ", 85 | ) 86 | verify_query_args( 87 | mock_engine, 88 | { 89 | "values__users_0_0": "Tom", 90 | "values__users_0_1": "tom@adobe.com", 91 | "values__users_1_0": "Jerry", 92 | "values__users_1_1": "jerry@adobe.com", 93 | }, 94 | ) 95 | 96 | 97 | @pytest.mark.parametrize( 98 | "args", 99 | [ 100 | ([("bob", "bob@email.com")]), 101 | ([("bob", "bob@email.com"), ("tom", "tom@email.com")]), 102 | None, 103 | (), 104 | ], 105 | ) 106 | def test_insert_with_callback(args): 107 | def callback(items): 108 | assert items == args 109 | 110 | @sqlupdate(on_success=callback) 111 | def insert(items): 112 | yield QueryData(f"INSERT INTO table(name, email) {items}") 113 | 114 | insert(args) 115 | 116 | 117 | def test_insert_with_callack_not_called(mock_engine): 118 | def callback(): 119 | assert False 120 | 121 | @sqlupdate(on_success=callback) 122 | def insert(): 123 | yield QueryData("INSERT INTO table(name, email)") 124 | 125 | mock_engine.connect().execution_options.return_value.execute.side_effect = ( 126 | Exception() 127 | ) 128 | with pytest.raises(Exception): 129 | insert() 130 | 131 | 132 | @sqlupdate() 133 | def insert_into_multiple_values(users): 134 | yield QueryData( 135 | "INSERT INTO table(name, email) {values__users}", 136 | template_params={"values__users": [(d["name"], d["email"]) for d in users]}, 137 | ) 138 | 139 | 140 | @sqlupdate() 141 | def insert_into_single_value(names): 142 | template_params = {} 143 | if names is not None: 144 | template_params = {"values__name_col": names} 145 | 146 | return QueryData( 147 | "INSERT INTO table(name) {values__name_col}", template_params=template_params 148 | ) 149 | 150 | 151 | def test_last_insert_id(): 152 | @sqlupdate() 153 | def 
insert(get_last_insert_id=None): 154 | yield QueryData("INSERT INTO get_last(name) VALUES ('Tom')") 155 | assert get_last_insert_id 156 | assert get_last_insert_id() == 1 157 | yield QueryData("INSERT INTO get_last(name) VALUES ('Jerry')") 158 | assert get_last_insert_id() == 2 159 | 160 | insert() 161 | 162 | 163 | def test_last_insert_id_removed_before_callback(): 164 | def callback(**kwargs): 165 | assert "get_last_insert_id" not in kwargs 166 | 167 | @sqlupdate() 168 | def insert(get_last_insert_id=None): 169 | assert get_last_insert_id 170 | yield QueryData("INSERT INTO get_last(name) VALUES ('Tom')") 171 | yield QueryData("INSERT INTO get_last(name) VALUES ('Jerry')") 172 | assert get_last_insert_id() == 2 173 | 174 | insert() 175 | -------------------------------------------------------------------------------- /dysql/test/test_sql_exists_decorator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 Adobe 3 | All Rights Reserved. 4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 
7 | """ 8 | 9 | from unittest.mock import Mock 10 | 11 | import pytest 12 | 13 | from dysql import sqlexists, QueryData 14 | from dysql.test import setup_mock_engine 15 | 16 | 17 | TRUE_QUERY = "SELECT 1 from table" 18 | TRUE_QUERY_PARAMS = "SELECT 1 from table where key=:key" 19 | FALSE_QUERY = "SELECT 1 from false_table " 20 | FALSE_QUERY_PARAMS = "SELECT 1 from table where key=:key" 21 | TRUE_PARAMS = {"key": 123} 22 | FALSE_PARAMS = {"key": 456} 23 | SELECT_EXISTS_QUERY = "SELECT 1 WHERE EXISTS ( {} )" 24 | SELECT_EXISTS_NO_WHERE_QUERY = "SELECT EXISTS ( {} )" 25 | EXISTS_QUERIES = { 26 | "true": SELECT_EXISTS_QUERY.format(TRUE_QUERY), 27 | "false": SELECT_EXISTS_QUERY.format(FALSE_QUERY), 28 | "true_params": SELECT_EXISTS_QUERY.format(TRUE_QUERY_PARAMS), 29 | "false_params": SELECT_EXISTS_QUERY.format(FALSE_QUERY_PARAMS), 30 | "true_no_where": SELECT_EXISTS_NO_WHERE_QUERY.format(TRUE_QUERY), 31 | "false_no_where": SELECT_EXISTS_NO_WHERE_QUERY.format(FALSE_QUERY), 32 | } 33 | 34 | 35 | @pytest.fixture(autouse=True) 36 | def mock_engine_fixture(mock_create_engine): 37 | mock_engine = setup_mock_engine(mock_create_engine) 38 | mock_engine.connect.return_value.execution_options.return_value.execute.side_effect = _check_query_and_return_result 39 | mock_engine.connect().execution_options().__enter__ = Mock() 40 | mock_engine.connect().execution_options().__exit__ = Mock() 41 | 42 | 43 | def test_exists_true(): 44 | assert _exists_true() 45 | 46 | 47 | def test_exists_false(): 48 | assert not _exists_false() 49 | 50 | 51 | def test_exists_true_params(): 52 | assert _exists_true_params() 53 | 54 | 55 | def test_exists_false_params(): 56 | assert not _exists_false_params() 57 | 58 | 59 | def test_exists_query_contains_with_exists_true(): 60 | """ 61 | this helps to test that if someone adds an exists statement, we don't add the exists 62 | exists(exists()) will always give a 'true' result 63 | """ 64 | # should match against the same query, should still match 65 | 
assert _exists_specified("true") 66 | 67 | 68 | def test_exists_query_contains_with_exists_false(): 69 | """ 70 | this helps to test that if someone adds an exists statement, we don't add the exists 71 | exists(exists()) will always give a 'true' result 72 | """ 73 | # should match against the same query, should still match 74 | assert not _exists_specified("false") 75 | 76 | 77 | def test_exists_without_where_true(): 78 | """ 79 | this helps avoid an issue discovered when using the 'utf8' charset, some queries with the exists 80 | query would fail to decode the '0' returned with a query as just `SELECT EXISTS` but work as expected 81 | when using `SELECT 1 WHERE EXISTS` 82 | """ 83 | assert _select_exists_no_where_true() 84 | 85 | 86 | def test_exists_witohut_where_false(): 87 | """ 88 | this helps avoid an issue discovered when using the 'utf8' charset, some queries with the exists 89 | query would fail to decode the '0' returned with a query as just `SELECT EXISTS` but work as expected 90 | when using `SELECT 1 WHERE EXISTS` 91 | """ 92 | assert not _select_exists_no_where_false() 93 | 94 | 95 | def test_exists_query_starts_with_exists_handles_whitespace(): 96 | """ 97 | this helps to test that if someone adds an exists statement, we don't add the exists 98 | exists(exists()) will always give a 'true' result 99 | """ 100 | # should trim the whitespace and match appropiately 101 | assert _exists_whitespace() 102 | 103 | 104 | @sqlexists() 105 | def _select_exists_no_where_false(): 106 | return QueryData(EXISTS_QUERIES["false_no_where"]) 107 | 108 | 109 | @sqlexists() 110 | def _select_exists_no_where_true(): 111 | return QueryData(EXISTS_QUERIES["true_no_where"]) 112 | 113 | 114 | @sqlexists() 115 | def _exists_specified(key): 116 | return QueryData(EXISTS_QUERIES[key]) 117 | 118 | 119 | @sqlexists() 120 | def _exists_true(): 121 | return QueryData(TRUE_QUERY) 122 | 123 | 124 | @sqlexists() 125 | def _exists_false(): 126 | return QueryData(FALSE_QUERY) 127 | 128 
| 129 | @sqlexists() 130 | def _exists_true_params(): 131 | return QueryData(TRUE_QUERY, TRUE_PARAMS) 132 | 133 | 134 | @sqlexists() 135 | def _exists_false_params(): 136 | return QueryData(FALSE_QUERY, FALSE_PARAMS) 137 | 138 | 139 | @sqlexists() 140 | def _exists_whitespace(): 141 | return QueryData( 142 | """ 143 | """ 144 | + TRUE_QUERY 145 | ) 146 | 147 | 148 | def _check_query_and_return_result(query, params): 149 | """ 150 | check that the query matches our expectations 151 | and then set the return value for the scalar call. 152 | scalar() returns either 1 or 0, 1 for true, 0 for false 153 | """ 154 | assert query.text in list(EXISTS_QUERIES.values()) 155 | scalar_mock = Mock() 156 | # default mock responses to true, then we only handle setting false responses 157 | scalar_mock.scalar.return_value = 1 158 | if query.text == EXISTS_QUERIES["true_params"]: 159 | assert params.get("key") == 123 160 | if query.text == EXISTS_QUERIES["false_params"]: 161 | assert params.get("key") == 456 162 | scalar_mock.scalar.return_value = 0 163 | elif query.text == EXISTS_QUERIES["false"]: 164 | scalar_mock.scalar.return_value = 0 165 | return scalar_mock 166 | -------------------------------------------------------------------------------- /dysql/pydantic_mappers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 Adobe 3 | All Rights Reserved. 4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 7 | """ 8 | 9 | from typing import Any, Dict, Set 10 | 11 | import sqlalchemy 12 | from pydantic import BaseModel, TypeAdapter 13 | 14 | from .mappers import DbMapResultBase 15 | 16 | 17 | class DbMapResultModel(BaseModel, DbMapResultBase): 18 | """ 19 | Pydantic model that allows for results to be mapped into annotated fields without explicitly 20 | mapping each field. 
    The metadata fields starting with "_" are special fields that allow lists,
    and dictionaries to be mapped correctly if more than one record is mapped to an instance.

    Note that validation *does* occur the very first time map_record is called, but not on subsequent runs. Therefore
    if you desire better validation for list, set, or dict fields, that must most likely be done outside of pydantic.

    Additionally, lists, sets, and dicts will ignore null values from the database. Therefore you must provide default
    values for these fields when used or else validation will fail.
    """

    # List fields that are aggregated into a string of comma separated values with basic string splitting on commas
    _csv_list_fields: Set[str] = set()
    # List fields that are json objects
    _json_fields: Set[str] = set()
    # List fields (type does not matter)
    _list_fields: Set[str] = set()
    # Set fields (type does not matter)
    _set_fields: Set[str] = set()
    # Dictionary key fields as DB field name => model field name
    _dict_key_fields: Dict[str, str] = {}
    # Dictionary value fields as model field name => DB field name (this is reversed from _dict_key_fields!)
    _dict_value_mappings: Dict[str, str] = {}

    @classmethod
    def create_instance(cls, *args, **kwargs) -> "DbMapResultModel":
        # Uses the construct method to prevent validation when mapping results
        return cls.model_construct(*args, **kwargs)

    def _map_json(self, current_dict: dict, record: sqlalchemy.engine.Row, field: str):
        """
        Deserialize a JSON string column into the annotated field type.

        NOTE(review): only the first mapped record populates a JSON field;
        values from subsequent records for the same instance are ignored.
        NOTE(review): instance access to model_fields is deprecated in newer
        pydantic 2.x; prefer type(self).model_fields — confirm target version.
        """
        model_field = self.model_fields[field]
        value = record[field]
        if not value:
            return
        if not self._has_been_mapped():
            current_dict[field] = TypeAdapter(model_field.annotation).validate_json(
                value
            )

    def _map_list(self, current_dict: dict, record: sqlalchemy.engine.Row, field: str):
        """Append the column value to the list field; the first record initializes the list."""
        if record[field] is None:
            # Do not map null entries into lists, this may cause problems in the future but it works
            # around some other issues when fields are nullable
            return
        if self._has_been_mapped():
            current_dict[field].append(record[field])
        else:
            current_dict[field] = [record[field]]

    def _map_set(self, current_dict: dict, record: sqlalchemy.engine.Row, field: str):
        """Add the column value to the set field; the first record initializes the set."""
        if record[field] is None:
            # See note above for lists
            return
        if self._has_been_mapped():
            current_dict[field].add(record[field])
        else:
            current_dict[field] = {record[field]}

    def _map_dict(self, current_dict: dict, record: sqlalchemy.engine.Row, field: str):
        """Map a key column / value column pair into the dict field named in _dict_key_fields."""
        model_field_name = self._dict_key_fields[field]
        value_field = self._dict_value_mappings[model_field_name]
        if record[value_field] is None:
            # See note above for lists
            return
        if self._has_been_mapped():
            current_dict[model_field_name][record[field]] = record[value_field]
        else:
            current_dict[model_field_name] = {record[field]: record[value_field]}

    def _map_list_from_string(
        self, current_dict: dict, record: sqlalchemy.engine.Row, field: str
    ):
        """Split a comma-separated string column and extend the list field with the parsed values."""
        list_string = record[field]
        if not list_string:
            # See note above for lists
            return

        # force it to be a string
        list_string = str(list_string)
        values_from_string = list(map(str.strip, list_string.split(",")))

        model_field = self.model_fields[field]
        # pre-validates the list we are expecting because we want to ensure all records are validated
        values = TypeAdapter(model_field.annotation).validate_python(values_from_string)

        if self._has_been_mapped() and current_dict[field]:
            current_dict[field].extend(values)
        else:
            current_dict[field] = values

    def map_record(self, record: sqlalchemy.engine.Row) -> None:
        """
        Implementation of map_record that handles the special "_" prefixed fields listed at the top of this class.
        The following rules are used:
        - If it is in the _csv_list_fields and is not none, extend the existing list
        - If it is in the _list_fields and is not none, append it to the field specified
        - If it is in the _set_fields and is not none, add it to the field specified
        - If it is in the _dict_key_fields and is not none, add it to the field specified along with the value
          from the _dict_value_mappings
        - Else add it to the fields to bind to this model
        - Remove all DB fields that are present in _dict_value_mappings since they were likely added above
        :param record: the DB record
        """
        current_dict: dict = self.__dict__
        for field in record.keys():
            if field in self._list_fields:
                self._map_list(current_dict, record, field)
            elif field in self._csv_list_fields:
                self._map_list_from_string(current_dict, record, field)
            elif field in self._json_fields:
                self._map_json(current_dict, record, field)
            elif field in self._set_fields:
                self._map_set(current_dict, record, field)
            elif field in self._dict_key_fields:
                self._map_dict(current_dict, record, field)
            else:
                # Ignore fields that should have already been set
                if not self._has_been_mapped():
                    current_dict[field] = record[field]

        # Remove all dict value fields (if present)
        for db_field in self._dict_value_mappings.values():
            current_dict.pop(db_field, None)
        if self._has_been_mapped():
            # At this point, just update the previous record
            self.__dict__.update(current_dict)
        else:
            # Init takes care of validation and assigning values to each field with conversions in place, etc
            self.__init__(**current_dict)  # pylint: disable=unnecessary-dunder-call

    def raw(self) -> dict:
        """Return the mapped result as a plain dict (pydantic model_dump)."""
        return self.model_dump()

    def has(self, field: str) -> bool:
        """Return True if the given field has been set on this instance."""
        return field in self.__dict__

    def _has_been_mapped(self):
        """
        Tells if a record has already been mapped onto this class or not.
        :return: True if map_record has already been called, False otherwise
        """
        return bool(self.model_fields_set)

    def get(self, field: str, default: Any = None) -> Any:
        """Return the value of the given field, or *default* if it is not set."""
        return self.__dict__.get(field, default)


# --------------------------------------------------------------------------------
# dysql/test/test_mappers.py
"""
Copyright 2021 Adobe
All Rights Reserved.

NOTICE: Adobe permits you to use, modify, and distribute this file in accordance
with the terms of the Adobe license agreement accompanying it.
7 | """ 8 | 9 | import pytest 10 | 11 | from dysql import ( 12 | RecordCombiningMapper, 13 | SingleRowMapper, 14 | SingleColumnMapper, 15 | SingleRowAndColumnMapper, 16 | CountMapper, 17 | KeyValueMapper, 18 | ) 19 | from dysql.mappers import MapperError 20 | 21 | 22 | class TestMappers: 23 | """ 24 | Test Mappers 25 | """ 26 | 27 | @staticmethod 28 | def _unwrap_results(results): 29 | return [r.raw() for r in results] 30 | 31 | def test_record_combining(self): 32 | mapper = RecordCombiningMapper() 33 | assert len(mapper.map_records([])) == 0 34 | assert self._unwrap_results(mapper.map_records([{"a": 1, "b": 2}])) == [ 35 | {"a": 1, "b": 2} 36 | ] 37 | assert self._unwrap_results( 38 | mapper.map_records( 39 | [ 40 | {"id": 1, "a": 1, "b": 2}, 41 | {"id": 2, "a": 1, "b": 2}, 42 | {"id": 1, "c": 3}, 43 | ] 44 | ) 45 | ) == [ 46 | {"id": 1, "a": 1, "b": 2, "c": 3}, 47 | {"id": 2, "a": 1, "b": 2}, 48 | ] 49 | 50 | @staticmethod 51 | def test_record_combining_no_record_mapper(): 52 | mapper = RecordCombiningMapper(record_mapper=None) 53 | assert len(mapper.map_records([])) == 0 54 | assert mapper.map_records([{"a": 1, "b": 2}]) == [{"a": 1, "b": 2}] 55 | assert mapper.map_records( 56 | [ 57 | {"id": 1, "a": 1, "b": 2}, 58 | {"id": 2, "a": 1, "b": 2}, 59 | {"id": 1, "c": 3}, 60 | ] 61 | ) == [ 62 | {"id": 1, "a": 1, "b": 2}, 63 | {"id": 2, "a": 1, "b": 2}, 64 | {"id": 1, "c": 3}, 65 | ] 66 | 67 | @staticmethod 68 | def test_single_row(): 69 | mapper = SingleRowMapper() 70 | assert mapper.map_records([]) is None 71 | assert mapper.map_records([{"a": 1, "b": 2}]).raw() == {"a": 1, "b": 2} 72 | assert mapper.map_records( 73 | [ 74 | {"id": 1, "a": 1, "b": 2}, 75 | {"id": 2, "a": 1, "b": 2}, 76 | {"id": 1, "c": 3}, 77 | ] 78 | ).raw() == {"id": 1, "a": 1, "b": 2} 79 | 80 | @staticmethod 81 | def test_single_row_no_record_mapper(): 82 | mapper = SingleRowMapper(record_mapper=None) 83 | assert mapper.map_records([]) is None 84 | assert mapper.map_records([{"a": 1, "b": 2}]) 
== {"a": 1, "b": 2} 85 | assert mapper.map_records( 86 | [ 87 | {"id": 1, "a": 1, "b": 2}, 88 | {"id": 2, "a": 1, "b": 2}, 89 | {"id": 1, "c": 3}, 90 | ] 91 | ) == {"id": 1, "a": 1, "b": 2} 92 | 93 | @staticmethod 94 | def test_single_column(): 95 | mapper = SingleColumnMapper() 96 | assert len(mapper.map_records([])) == 0 97 | assert mapper.map_records([{"a": 1, "b": 2}]) == [1] 98 | assert mapper.map_records( 99 | [ 100 | {"id": "myid1", "a": 1, "b": 2}, 101 | {"id": "myid2", "a": 1, "b": 2}, 102 | {"id": "myid3", "c": 3}, 103 | ] 104 | ) == ["myid1", "myid2", "myid3"] 105 | 106 | @staticmethod 107 | @pytest.mark.parametrize( 108 | "mapper", 109 | [ 110 | SingleRowAndColumnMapper(), 111 | # Alias for the other one 112 | CountMapper(), 113 | ], 114 | ) 115 | def test_single_column_and_row(mapper): 116 | assert mapper.map_records([]) is None 117 | assert mapper.map_records([{"a": 1, "b": 2}]) == 1 118 | assert ( 119 | mapper.map_records( 120 | [ 121 | {"id": "myid", "a": 1, "b": 2}, 122 | {"id": 2, "a": 1, "b": 2}, 123 | {"id": 1, "c": 3}, 124 | ] 125 | ) 126 | == "myid" 127 | ) 128 | 129 | @staticmethod 130 | @pytest.mark.parametrize( 131 | "mapper, expected", 132 | [ 133 | (KeyValueMapper(), {"a": 4, "b": 7}), 134 | (KeyValueMapper(key_column="column_named_something"), {"a": 4, "b": 7}), 135 | (KeyValueMapper(value_column="column_with_some_value"), {"a": 4, "b": 7}), 136 | ( 137 | KeyValueMapper(has_multiple_values_per_key=True), 138 | {"a": [1, 2, 3, 4], "b": [3, 4, 5, 6, 7]}, 139 | ), 140 | ( 141 | KeyValueMapper( 142 | key_column="column_named_something", 143 | has_multiple_values_per_key=True, 144 | ), 145 | {"a": [1, 2, 3, 4], "b": [3, 4, 5, 6, 7]}, 146 | ), 147 | ( 148 | KeyValueMapper( 149 | key_column="column_with_some_value", 150 | value_column="column_named_something", 151 | has_multiple_values_per_key=True, 152 | ), 153 | { 154 | 1: ["a"], 155 | 2: ["a"], 156 | 3: ["a", "b"], 157 | 4: ["a", "b"], 158 | 5: ["b"], 159 | 6: ["b"], 160 | 7: ["b"], 161 | }, 
162 | ), 163 | ( 164 | KeyValueMapper( 165 | key_column="column_with_some_value", 166 | value_column="column_named_something", 167 | ), 168 | {1: "a", 2: "a", 3: "b", 4: "b", 5: "b", 6: "b", 7: "b"}, 169 | ), 170 | ], 171 | ) 172 | def test_key_mapper_key_has_multiple(mapper, expected): 173 | result = mapper.map_records( 174 | [ 175 | HelperRow( 176 | ("column_named_something", "column_with_some_value"), ["a", 1] 177 | ), 178 | HelperRow( 179 | ("column_named_something", "column_with_some_value"), ["a", 2] 180 | ), 181 | HelperRow( 182 | ("column_named_something", "column_with_some_value"), ["a", 3] 183 | ), 184 | HelperRow( 185 | ("column_named_something", "column_with_some_value"), ["a", 4] 186 | ), 187 | HelperRow( 188 | ("column_named_something", "column_with_some_value"), ["b", 3] 189 | ), 190 | HelperRow( 191 | ("column_named_something", "column_with_some_value"), ["b", 4] 192 | ), 193 | HelperRow( 194 | ("column_named_something", "column_with_some_value"), ["b", 5] 195 | ), 196 | HelperRow( 197 | ("column_named_something", "column_with_some_value"), ["b", 6] 198 | ), 199 | HelperRow( 200 | ("column_named_something", "column_with_some_value"), ["b", 7] 201 | ), 202 | ] 203 | ) 204 | assert len(result) == len(expected) 205 | assert result == expected 206 | 207 | @staticmethod 208 | def test_key_mapper_key_value_same(): 209 | with pytest.raises( 210 | MapperError, match="key and value columns cannot be the same" 211 | ): 212 | KeyValueMapper(key_column="same", value_column="same") 213 | 214 | 215 | class HelperRow: # pylint: disable=too-few-public-methods 216 | """ 217 | Helper class does the most basic functionality we see when accessing records passed in 218 | """ 219 | 220 | def __init__(self, fields, values): 221 | self.fields = fields 222 | self.values = values 223 | 224 | def __getitem__(self, key): 225 | if isinstance(key, int): 226 | return self.values[key] 227 | return self.values[self.fields.index(key)] 228 | 
# --------------------------------------------------------------------------------
# dysql/test/test_template_generators.py
"""
Copyright 2021 Adobe
All Rights Reserved.

NOTICE: Adobe permits you to use, modify, and distribute this file in accordance
with the terms of the Adobe license agreement accompanying it.
"""

import pytest

from dysql.query_utils import TemplateGenerators, ListTemplateException


class TestTemplatesGenerators:
    """
    Test we get templates back from Templates
    """

    number_values = [1, 2, 3, 4]
    string_values = ["1", "2", "3", "4"]
    set_values = {1, 2, 3, 4}
    insert_values = [("ironman", 1), ("batman", 2)]
    tuple_values = [(1, 2), (3, 4)]
    query = "column_a IN ( :column_a_0, :column_a_1, :column_a_2, :column_a_3 )"

    query_with_list_of_tuples = (
        "(column_a, column_b) IN (( :column_acolumn_b_0_0, :column_acolumn_b_0_1 ), "
        "( :column_acolumn_b_1_0, :column_acolumn_b_1_1 ))"
    )
    not_query_with_list_of_tuples = (
        "(column_a, column_b) NOT IN (( :column_acolumn_b_0_0, :column_acolumn_b_0_1 ), "
        "( :column_acolumn_b_1_0, :column_acolumn_b_1_1 ))"
    )
    query_with_table = "table.column_a IN ( :table_column_a_0, :table_column_a_1, :table_column_a_2, :table_column_a_3 )"
    not_query = "column_a NOT IN ( :column_a_0, :column_a_1, :column_a_2, :column_a_3 )"
    not_query_with_table = "table.column_a NOT IN ( :table_column_a_0, :table_column_a_1, :table_column_a_2, :table_column_a_3 )"

    # Parametrize data: (template function, column name, values, expected rendered query)
    test_query_data = (
        "template_function, column_name,column_values,expected_query",
        [
            (TemplateGenerators.in_column, "column_a", number_values, query),
            (
                TemplateGenerators.in_column,
                "table.column_a",
                number_values,
                query_with_table,
            ),
            (
                TemplateGenerators.in_column,
                "table.column_a",
                set_values,
                query_with_table,
            ),
            (TemplateGenerators.in_column, "column_a", string_values, query),
            (
                TemplateGenerators.in_column,
                "table.column_a",
                string_values,
                query_with_table,
            ),
            (TemplateGenerators.in_column, "column_a", [], "1 <> 1"),
            (
                TemplateGenerators.in_multi_column,
                "(column_a, column_b)",
                tuple_values,
                query_with_list_of_tuples,
            ),
            (
                TemplateGenerators.not_in_multi_column,
                "(column_a, column_b)",
                tuple_values,
                not_query_with_list_of_tuples,
            ),
            (TemplateGenerators.not_in_column, "column_a", number_values, not_query),
            (
                TemplateGenerators.not_in_column,
                "table.column_a",
                number_values,
                not_query_with_table,
            ),
            (TemplateGenerators.not_in_column, "column_a", string_values, not_query),
            (
                TemplateGenerators.not_in_column,
                "table.column_a",
                string_values,
                not_query_with_table,
            ),
            (
                TemplateGenerators.not_in_column,
                "table.column_a",
                set_values,
                not_query_with_table,
            ),
            (TemplateGenerators.not_in_column, "column_a", [], "1 = 1"),
            (
                TemplateGenerators.values,
                "someid",
                insert_values,
                "VALUES ( :someid_0_0, :someid_0_1 ), ( :someid_1_0, :someid_1_1 )",
            ),
        ],
    )

    parameter_numbers = {
        "column_a_0": number_values[0],
        "column_a_1": number_values[1],
        "column_a_2": number_values[2],
        "column_a_3": number_values[3],
    }
    with_table_parameter_numbers = {
        "table_column_a_0": number_values[0],
        "table_column_a_1": number_values[1],
        "table_column_a_2": number_values[2],
        "table_column_a_3": number_values[3],
    }
    parameter_strings = {
        "column_a_0": string_values[0],
        "column_a_1": string_values[1],
        "column_a_2": string_values[2],
        "column_a_3": string_values[3],
    }
    with_table_parameter_strings = {
        "table_column_a_0": string_values[0],
        "table_column_a_1": string_values[1],
        "table_column_a_2": string_values[2],
        "table_column_a_3": string_values[3],
    }
    # Parametrize data: (template function, column name, values, expected bind params)
    test_params_data = (
        "template_function, column_name,column_values,expected_params",
        [
            (
                TemplateGenerators.in_column,
                "column_a",
                number_values,
                parameter_numbers,
            ),
            (
                TemplateGenerators.in_column,
                "column_a",
                set_values,
                parameter_numbers,
            ),
            (
                TemplateGenerators.in_column,
                "table.column_a",
                number_values,
                with_table_parameter_numbers,
            ),
            (
                TemplateGenerators.in_column,
                "column_a",
                string_values,
                parameter_strings,
            ),
            (
                TemplateGenerators.in_column,
                "table.column_a",
                string_values,
                with_table_parameter_strings,
            ),
            (TemplateGenerators.in_column, "column_a", [], None),
            (
                TemplateGenerators.not_in_column,
                "column_a",
                number_values,
                parameter_numbers,
            ),
            (
                TemplateGenerators.not_in_column,
                "table.column_a",
                number_values,
                with_table_parameter_numbers,
            ),
            (
                TemplateGenerators.not_in_column,
                "column_a",
                string_values,
                parameter_strings,
            ),
            (
                TemplateGenerators.not_in_column,
                "column_a",
                set_values,
                parameter_numbers,
            ),
            (
                TemplateGenerators.not_in_column,
                "table.column_a",
                string_values,
                with_table_parameter_strings,
            ),
            (TemplateGenerators.not_in_column, "column_a", [], None),
            (
                TemplateGenerators.values,
                "someid",
                insert_values,
                {
                    "someid_0_0": insert_values[0][0],
                    "someid_0_1": insert_values[0][1],
                    "someid_1_0": insert_values[1][0],
                    "someid_1_1": insert_values[1][1],
                },
            ),
        ],
    )

    @staticmethod
    @pytest.mark.parametrize(*test_query_data)
    def test_query(template_function, column_name, column_values, expected_query):
        query, _ = template_function(column_name, column_values)
        assert query == expected_query

    @staticmethod
    @pytest.mark.parametrize(*test_params_data)
    def test_params(template_function, column_name, column_values, expected_params):
        _, params = template_function(column_name, column_values)
        assert params == expected_params

    @staticmethod
    def test_insert_none():
        with pytest.raises(ListTemplateException):
            TemplateGenerators.values("someid", None)


# --------------------------------------------------------------------------------
# dysql/test/test_database_initialization.py
"""
Copyright 2021 Adobe
All Rights Reserved.

NOTICE: Adobe permits you to use, modify, and distribute this file in accordance
with the terms of the Adobe license agreement accompanying it.
"""

from unittest import mock

import pytest

import dysql.connections
from dysql import (
    sqlquery,
    DBNotPreparedError,
    set_default_connection_parameters,
    QueryData,
    set_database_init_hook,
)
from dysql.test import setup_mock_engine


"""
These tests reflect what can happen when dealing with a config passing in a value that may not
have been set. There are some parameters that are absolutely necessary, if they aren't set we
want to make sure we are warning the user in different ways.
"""


@sqlquery()
def query():
    """
    This is used to call the code that would initialize the database on the first time.
    If there were any failures this is where we would expect to see them.
36 | """ 37 | return QueryData("SELECT * FROM table") 38 | 39 | 40 | @pytest.fixture(autouse=True, name="mock_engine") 41 | def fixture_mock_engine(mock_create_engine): 42 | dysql.databases.DatabaseContainerSingleton().clear() 43 | dysql.databases._DEFAULT_CONNECTION_PARAMS_BY_KEY.clear() 44 | 45 | # Reset the database before the test 46 | dysql.databases.reset_current_database() 47 | 48 | yield setup_mock_engine(mock_create_engine) 49 | 50 | # Reset database after the test as well 51 | dysql.databases.reset_current_database() 52 | 53 | 54 | @pytest.fixture(autouse=True) 55 | def fixture_reset_init_hook(): 56 | yield 57 | if hasattr(dysql.databases.Database, "hook_method"): 58 | delattr(dysql.databases.Database, "hook_method") 59 | 60 | 61 | def test_nothing_set(): 62 | dysql.databases._DEFAULT_CONNECTION_PARAMS_BY_KEY.clear() 63 | with pytest.raises(DBNotPreparedError) as error: 64 | query() 65 | assert ( 66 | str(error.value) 67 | == "Unable to connect to a database, set_default_connection_parameters must first be called" 68 | ) 69 | 70 | 71 | @pytest.mark.parametrize( 72 | "host, user, password, database, failed_field", 73 | [ 74 | (None, "u", "p", "d", "host"), 75 | ("", "u", "p", "d", "host"), 76 | ("h", None, "p", "d", "user"), 77 | ("h", "", "p", "d", "user"), 78 | ("h", "u", None, "d", "password"), 79 | ("h", "u", "", "d", "password"), 80 | ("h", "u", "p", None, "database"), 81 | ("h", "u", "p", "", "database"), 82 | ], 83 | ) 84 | def test_fields_required(host, user, password, database, failed_field): 85 | with pytest.raises(DBNotPreparedError) as error: 86 | set_default_connection_parameters(host, user, password, database) 87 | assert ( 88 | str(error.value) 89 | == f'Database parameter "{failed_field}" is not set or empty and is required' 90 | ) 91 | 92 | 93 | def test_minimal_credentials(mock_engine): 94 | set_default_connection_parameters("h", "u", "p", "d") 95 | 96 | mock_engine.connect().execution_options().execute.return_value = [] 97 | query() 98 | 
def test_init_hook(mock_engine):
    """The registered init hook fires exactly once with the db name and engine."""
    hook = mock.MagicMock()
    set_database_init_hook(hook)
    set_default_connection_parameters("h", "u", "p", "d")
    dysql.databases.set_current_database("d")

    mock_engine.connect().execution_options().execute.return_value = []
    query()
    hook.assert_called_once_with("d", mock_engine)


def test_init_hook_multiple_databases(mock_engine):
    """The init hook fires once per database as each engine is first used."""
    hook = mock.MagicMock()
    set_database_init_hook(hook)
    set_default_connection_parameters("h1", "u1", "p1", "d1")
    set_default_connection_parameters("h2", "u2", "p2", "d2")

    mock_engine.connect().execution_options().execute.return_value = []
    query()
    dysql.databases.set_current_database("d1")
    query()
    expected_calls = [
        mock.call("test", mock_engine),
        mock.call("d1", mock_engine),
    ]
    assert hook.call_args_list == expected_calls

    dysql.databases.set_current_database("d2")
    query()
    expected_calls.append(mock.call("d2", mock_engine))
    assert hook.call_args_list == expected_calls


def test_current_database_default(mock_engine, mock_create_engine):
    """With no explicit selection, the first configured database is used."""
    container = dysql.databases.DatabaseContainerSingleton()
    assert len(container) == 0
    mock_engine.connect().execution_options().execute.return_value = []
    query()

    # Exactly one database should have been initialized
    assert len(container) == 1
    assert "test" in container
    assert container.current_database.database == "test"
    mock_create_engine.assert_called_once_with(
        "mysql+mysqlconnector://user:password@fake:3306/test?charset=utf8",
        echo=False,
        pool_pre_ping=True,
        pool_recycle=3600,
        pool_size=10,
    )


def test_different_charset_collation(mock_engine, mock_create_engine):
    """Charset/collation overrides show up in the generated engine URL."""
    container = dysql.databases.DatabaseContainerSingleton()

    set_default_connection_parameters(
        "host",
        "user",
        "password",
        "database",
        charset="other",
        collation="other_collation",
    )
    assert len(container) == 0
    dysql.databases.set_current_database("database")
    mock_engine.connect().execution_options().execute.return_value = []
    query()

    # Exactly one engine should have been created, with the overridden URL
    mock_create_engine.assert_called_once_with(
        "mysql+mysqlconnector://user:password@host:3306/database?charset=other&collation=other_collation",
        echo=False,
        pool_pre_ping=True,
        pool_recycle=3600,
        pool_size=10,
    )


def test_is_set_current_database_supported():
    """The deprecated capability check must always report True."""
    assert dysql.databases.is_set_current_database_supported()


def test_current_database_set(mock_engine, mock_create_engine):
    """An explicit database_key selects that connection's parameters."""
    container = dysql.databases.DatabaseContainerSingleton()
    set_default_connection_parameters("h1", "u1", "p1", "d1", database_key="db1")
    dysql.databases.set_current_database("db1")
    mock_engine.connect().execution_options().execute.return_value = []
    query()

    assert len(container) == 1
    assert "db1" in container
    assert container.current_database.database == "db1"
    mock_create_engine.assert_called_once_with(
        "mysql+mysqlconnector://u1:p1@h1:3306/d1?charset=utf8",
        echo=False,
        pool_pre_ping=True,
        pool_recycle=3600,
        pool_size=10,
    )


def test_current_database_cached(mock_engine, mock_create_engine):
    """Engines are created once per database and reused after switching back."""
    container = dysql.databases.DatabaseContainerSingleton()
    mock_engine.connect().execution_options().execute.return_value = []
    query()

    assert len(container) == 1
    assert "test" in container
    assert container.current_database.database == "test"

    set_default_connection_parameters("h1", "u1", "p1", "db1")
    dysql.databases.set_current_database("db1")
    query()
    assert len(container) == 2
    assert "test" in container
    assert container.current_database.database == "db1"

    # Switching back to the default must not create a third engine
    dysql.databases.reset_current_database()
    query()
    assert len(container) == 2
    assert container.current_database.database == "test"

    assert mock_create_engine.call_count == 2
    assert mock_create_engine.call_args_list == [
        mock.call(
            "mysql+mysqlconnector://user:password@fake:3306/test?charset=utf8",
            echo=False,
            pool_pre_ping=True,
            pool_recycle=3600,
            pool_size=10,
        ),
        mock.call(
            "mysql+mysqlconnector://u1:p1@h1:3306/db1?charset=utf8",
            echo=False,
            pool_pre_ping=True,
            pool_recycle=3600,
            pool_size=10,
        ),
    ]
This method will receive the database name 30 | (may be none) and the sqlalchemy engine as parameters. 31 | 32 | :param hook_method: the hook method 33 | """ 34 | Database.set_init_hook(hook_method) 35 | 36 | 37 | def is_set_current_database_supported() -> bool: 38 | """ 39 | Deprecated, left in for backwards compatibility but always returns true. 40 | :return: True 41 | """ 42 | return True 43 | 44 | 45 | def set_current_database(database_key: str) -> None: 46 | """ 47 | Sets the current database key, may be used for multitenancy. This is only supported on Python 3.7+. This uses 48 | contextvars internally to set the name for the current async context. 49 | :param database_key: the arbitrary database key to use for this async context 50 | """ 51 | CURRENT_DATABASE_VAR.set(database_key) 52 | logger.debug(f"Set current database to {database_key}") 53 | 54 | 55 | def reset_current_database() -> None: 56 | """ 57 | Helper method to reset the current database to the default. Internally, this calls `set_current_database` with 58 | an empty string. 59 | """ 60 | set_current_database("") 61 | 62 | 63 | def _get_current_database_key() -> str: 64 | """ 65 | The current database key, using contextvars (if on python 3.7+) or the default database key. 
66 | :return: The current database key 67 | """ 68 | database: Optional[str] = None 69 | if CURRENT_DATABASE_VAR: 70 | database = CURRENT_DATABASE_VAR.get() 71 | if not database and _DEFAULT_CONNECTION_PARAMS_BY_KEY: 72 | # Get first database key 73 | database = next(iter(_DEFAULT_CONNECTION_PARAMS_BY_KEY)) 74 | return database 75 | 76 | 77 | def _validate_param(name: str, value: str) -> None: 78 | if not value: 79 | raise DBNotPreparedError( 80 | f'Database parameter "{name}" is not set or empty and is required' 81 | ) 82 | 83 | 84 | def set_default_connection_parameters( 85 | host: str, 86 | user: str, 87 | password: str, 88 | database: str, 89 | database_key: Optional[str] = None, 90 | port: int = 3306, 91 | pool_size: int = 10, 92 | pool_recycle: int = 3600, 93 | echo_queries: bool = False, 94 | charset: str = "utf8", 95 | collation: Optional[str] = None, 96 | ): 97 | """ 98 | Initializes the parameters to use when connecting to the database. This is a subset of the parameters 99 | used by sqlalchemy. These may be overridden by parameters provided in the QueryData, hence the "default". 100 | 101 | :param host: the db host to try to connect to 102 | :param user: user to connect to the database with 103 | :param password: password for given user 104 | :param database: database to connect to 105 | :param database_key: optional database key that may be used for multitenant DBs, defaults to the database name 106 | :param port: the port to connect to (default 3306) 107 | :param pool_size: number of connections to maintain in the connection pool (default 10) 108 | :param pool_recycle: amount of time to wait between resetting the connections 109 | in the pool (default 3600) 110 | :param echo_queries: this tells sqlalchemy to print the queries when set to True (default false) 111 | :param charset: the charset for the sql engine to initialize with. (default utf8) 112 | :param collation: the collation for the sql engine to initialize with. 
def set_default_connection_parameters(
    host: str,
    user: str,
    password: str,
    database: str,
    database_key: Optional[str] = None,
    port: int = 3306,
    pool_size: int = 10,
    pool_recycle: int = 3600,
    echo_queries: bool = False,
    charset: str = "utf8",
    collation: Optional[str] = None,
):
    """
    Stores the default parameters used when connecting to the database. This is
    a subset of the parameters used by sqlalchemy; individual queries may
    override them, hence "default".

    :param host: the db host to try to connect to
    :param user: user to connect to the database with
    :param password: password for given user
    :param database: database to connect to
    :param database_key: optional key used for multitenant DBs, defaults to the database name
    :param port: the port to connect to (default 3306)
    :param pool_size: number of connections in the connection pool (default 10)
    :param pool_recycle: seconds to wait between resetting pooled connections (default 3600)
    :param echo_queries: tells sqlalchemy to print queries when True (default False)
    :param charset: charset for the sql engine (default utf8)
    :param collation: collation for the sql engine (default unset)
    :exception DBNotPreparedError: when a required parameter is missing or empty
    """
    for param_name, param_value in (
        ("host", host),
        ("user", user),
        ("password", password),
        ("database", database),
    ):
        _validate_param(param_name, param_value)

    database_key = database_key or database
    # Stored keys intentionally mirror this function's parameter names
    _DEFAULT_CONNECTION_PARAMS_BY_KEY[database_key].update(
        host=host,
        user=user,
        password=password,
        database=database,
        database_key=database_key,
        port=port,
        pool_size=pool_size,
        pool_recycle=pool_recycle,
        echo_queries=echo_queries,
        charset=charset,
        collation=collation,
    )


class Database:
    """Holds a lazily-initialized sqlalchemy engine for a single database key."""

    def __init__(self, database_key: Optional[str]) -> None:
        self.database = database_key
        # Engine is built on first access of the `engine` property
        self._engine: Optional[sqlalchemy.engine.Engine] = None

    @classmethod
    def set_init_hook(
        cls,
        hook_method: Callable[[Optional[str], sqlalchemy.engine.Engine], None],
    ) -> None:
        """Stores the engine-initialization callback on the class."""
        cls.hook_method = hook_method

    @property
    def engine(self) -> sqlalchemy.engine.Engine:
        """
        The sqlalchemy engine for this database, created on first access from
        the registered default connection parameters.
        :raises DBNotPreparedError: when no parameters exist for this key
        """
        if self._engine is None:
            params = _DEFAULT_CONNECTION_PARAMS_BY_KEY.get(self.database, {})
            if not params:
                raise DBNotPreparedError(
                    f"No connection parameters found for database key '{self.database}'"
                )
            user = params.get("user")
            password = params.get("password")
            database = params.get("database")
            host = params.get("host")
            port = params.get("port")
            charset = params.get("charset")
            collation = params.get("collation")
            collation_str = f"&collation={collation}" if collation else ""

            url = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}?charset={charset}{collation_str}"
            self._engine = sqlalchemy.create_engine(
                url,
                pool_recycle=params.get("pool_recycle"),
                pool_size=params.get("pool_size"),
                echo=params.get("echo_queries"),
                pool_pre_ping=True,
            )
            # Notify the (optional) class-level init hook with the real db name
            hook: Optional[
                Callable[[Optional[str], sqlalchemy.engine.Engine], None]
            ] = getattr(self.__class__, "hook_method", None)
            if hook:
                hook(database, self._engine)

        return self._engine


class DatabaseContainer(dict):
    """
    Dictionary that always yields a Database instance, creating one on demand
    for any missing key.
    """

    def __getitem__(self, database: Optional[str]) -> Database:
        """
        Returns the Database for the given key, creating it lazily. Also
        verifies that connection parameters have been registered first.
        :param database: the database name (may be None for the default)
        :return: a database instance
        :raises DBNotPreparedError: when set_default_connection_parameters has not yet been called
        """
        if not _DEFAULT_CONNECTION_PARAMS_BY_KEY:
            raise DBNotPreparedError(
                "Unable to connect to a database, set_default_connection_parameters must first be called"
            )

        try:
            return super().__getitem__(database)
        except KeyError:
            instance = Database(database)
            super().__setitem__(database, instance)
            return instance

    @property
    def current_database(self) -> Database:
        """The Database selected for the current context (or the default)."""
        return self.__getitem__(_get_current_database_key())
class DatabaseContainerSingleton(DatabaseContainer):
    """
    Singleton flavor of DatabaseContainer: every instantiation returns the
    same shared instance, enforced by overriding __new__.
    """

    def __new__(cls, *args, **kwargs) -> "DatabaseContainer":
        existing = cls.__dict__.get("__instance__")
        if existing is None:
            existing = DatabaseContainer.__new__(cls)
            cls.__instance__ = existing
            existing.__init__(*args, **kwargs)
        return existing


class DbTestManagerBase(abc.ABC):
    """
    Base class for all test managers. See individual implementations for usage details.
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        host: str,
        username: str,
        password: str,
        db_name: str,
        schema_db_name: Optional[str],
        docker_container: Optional[str] = None,
        keep_db: bool = False,
        **connection_defaults,
    ):  # pylint: disable=too-many-arguments
        """
        Constructor; unknown kwargs are forwarded to set_default_connection_parameters.

        :param host: the host to access the test database
        :param username: the username to access the test database
        :param password: the password to access the test database
        :param db_name: the name of the test database
        :param schema_db_name: the name of the DB to duplicate schema from (using mariadb-dump)
        :param docker_container: name of the docker container running the database (if any)
        :param keep_db: when True, teardown leaves the created database in place
            (useful for debugging)
        """
        self.host = host
        self.username = username
        self.password = password
        self.db_name = db_name
        self.schema_db_name = schema_db_name
        self.docker_container = docker_container
        self.keep_db = keep_db
        self.connection_defaults = connection_defaults
        self.in_docker = self._is_running_in_docker()

    @staticmethod
    def _is_running_in_docker():
        """Best-effort detection of whether this process runs inside a container."""
        if os.path.exists("/.dockerenv"):
            return True

        if os.path.exists("/proc/1/cgroup"):
            with open("/proc/1/cgroup", "rt", encoding="utf8") as fobj:
                contents = fobj.read()
            return any(marker in contents for marker in ("docker", "kubepod", "lxc"))
        return False

    def __enter__(self):
        """Creates the test database and points dysql's defaults at it."""
        LOGGER.debug(f"Setting up database : {self.db_name}")
        # Set the host based on whether we are in buildrunner or not (to test locally)

        self._create_test_db()
        set_default_connection_parameters(
            self.host,
            self.username,
            self.password,
            self.db_name,
            **self.connection_defaults,
        )

        # Set the database if supported by the python runtime
        if is_set_current_database_supported():
            set_current_database(self.db_name)

        if self.schema_db_name:
            self._wait_for_tables_exist()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Tears the test database down unless keep_db was requested."""
        if not self.keep_db:
            LOGGER.debug(f"Tearing down database : {self.db_name}")
            self._tear_down_test_db()

    @abc.abstractmethod
    def _create_test_db(self) -> None:
        """
        Must be overridden in implementing classes. Creates the test database
        from self.schema_db_name (if present) or as an empty database.
        """

    @abc.abstractmethod
    def _get_tables_count(self, db_name: str) -> int:
        """
        Must be overridden in implementing classes. Retrieves the current table
        count for the database.
        :param db_name: the database to retrieve the count for
        :return: the number of tables present in the database
        """

    @abc.abstractmethod
    def _tear_down_test_db(self) -> None:
        """
        Must be overridden in implementing classes. Tears down the test database.
        """

    def _wait_for_tables_exist(self) -> None:
        """
        Polls until the test database reports the same number of tables as the
        schema source database.
        """
        expected_count = self._get_tables_count(self.schema_db_name)
        while True:
            sleep(0.25)
            LOGGER.debug("Tables are still not ready")
            if self._get_tables_count(self.db_name) == expected_count:
                break

    def _run(self, command: str) -> subprocess.CompletedProcess:
        """
        Runs a command, wrapping it in docker commands if necessary.
        :param command: the command to run
        :return: the completed process information
        :raises subprocess.CalledProcessError: when the command fails
        """
        if not self.in_docker:
            # NOTE(review): when the test process is on the host, the command is
            # routed into the DB container via docker exec — confirm this matches
            # the intended deployment layout
            command = f"docker exec {self.docker_container} bash -c '{command}'"

        LOGGER.debug(f"Executing : '{command}'")
        try:
            completed_process = subprocess.run(
                command, shell=True, timeout=30, check=True, capture_output=True
            )
        except subprocess.CalledProcessError:
            LOGGER.exception(f"Error handling command : {command}")
            raise
        LOGGER.debug(f"Executed : {completed_process.stdout}")
        return completed_process


class MariaDbTestManager(DbTestManagerBase):
    """
    May be used in testing to copy and create a database and schema. This class
    works specifically with Maria/MySQL databases.

    Example:

        @pytest.fixture(scope='module', autouse=True)
        def setup_db(self):
            # Pass in the database name and any optional params
            with MariaDbTestManager(f'testdb_{self.__class__.__name__.lower()}'):
                yield
    """

    # pylint: disable=too-few-public-methods

    def __init__(
        self,
        db_name: str,
        schema_db_name: Optional[str] = None,
        echo_queries: bool = False,
        keep_db: bool = False,
        pool_size=3,
        charset="utf8",
    ):  # pylint: disable=too-many-arguments
        """
        :param db_name: the name you want for your test database
        :param schema_db_name: the name of the DB to duplicate schema from (using mariadb-dump)
        :param echo_queries: True if you want to see queries
        :param keep_db: prevents teardown from removing the created DB after tests
            (helpful for debugging)
        :param charset: override the default charset if you need something besides utf8
        """
        super().__init__(
            os.getenv("MARIA_HOST", "localhost"),
            os.getenv("MARIA_USERNAME", "root"),
            os.getenv("MARIA_PASSWORD", "password"),
            db_name,
            schema_db_name,
            port=3306,
            echo_queries=echo_queries,
            pool_size=pool_size,
            docker_container=os.getenv("MARIA_CONTAINER_NAME", "mariadb"),
            keep_db=keep_db,
            charset=charset,
        )

    def _create_test_db(self) -> None:
        # Recreate the database from scratch, then (optionally) copy the schema
        self._run(
            f'mariadb -p{self.password} -h{self.host} -N -e "DROP DATABASE IF EXISTS {self.db_name}"'
        )
        self._run(
            f'mariadb -p{self.password} -h{self.host} -s -N -e "CREATE DATABASE IF NOT EXISTS {self.db_name}"'
        )
        if self.schema_db_name:
            self._run(
                f"mariadb-dump --no-data -p{self.password} {self.schema_db_name} -h{self.host} "
                f"| mariadb -p{self.password} {self.db_name} -h{self.host}"
            )

    def _tear_down_test_db(self) -> None:
        self._run(
            f'echo "DROP DATABASE IF EXISTS {self.db_name} " | mariadb -p{self.password} -h{self.host}'
        )

    @sqlquery(mapper=CountMapper())
    def _get_tables_count(self, db_name: str) -> int:
        # pylint: disable=unused-argument
        return QueryData(
            """
            SELECT count(1)
            FROM information_schema.TABLES
            WHERE TABLE_SCHEMA=:db_name
            """,
            query_params={"db_name": db_name},
        )
class ConversionDbModel(DbMapResultModel):
    # Scalar fields exercised for pydantic type coercion
    id: int
    field_str: str
    field_int: int
    field_bool: bool


class CombiningDbModel(DbMapResultModel):
    # Mapper configuration: which fields accumulate across combined records
    _list_fields: Set[str] = {"list1"}
    _set_fields: Set[str] = {"set1"}
    _dict_key_fields: Dict[str, str] = {"key1": "dict1", "key2": "dict2"}
    _dict_value_mappings: Dict[str, str] = {"dict1": "val1", "dict2": "val2"}

    id: int = None
    list1: List[str]
    set1: Set[str] = set()
    dict1: Dict[str, Any] = {}
    dict2: Dict[str, int] = {}


class DefaultListCombiningDbModel(CombiningDbModel):
    # Same as CombiningDbModel but list1 defaults to empty
    list1: List[str] = []


class ListWithStringsModel(DbMapResultModel):
    # Fields populated by splitting comma-separated DB strings
    _csv_list_fields: Set[str] = {"list1", "list2"}

    id: int
    list1: Optional[List[str]] = None
    list2: List[int] = []  # help test empty list gets filled


class JsonModel(DbMapResultModel):
    # Fields deserialized from JSON strings in the DB
    _json_fields: Set[str] = {"json1", "json2"}

    id: int
    json1: dict
    json2: Optional[dict] = None


class MultiKeyModel(DbMapResultModel):
    # Records are combined on the composite key (a, b)
    @classmethod
    def get_key_columns(cls):
        return ["a", "b"]

    _list_fields = {"c"}
    a: int
    b: str
    c: List[str]


def _raw_all(results):
    """Unwraps a list of mapped results into plain dictionaries."""
    return [mapped.raw() for mapped in results]


def test_field_conversion():
    """Scalar DB values are coerced to the declared field types."""
    mapper = SingleRowMapper(record_mapper=ConversionDbModel)
    row = {"id": 1, "field_str": "str1", "field_int": 1, "field_bool": 1}
    expected = {"id": 1, "field_str": "str1", "field_int": 1, "field_bool": True}
    assert mapper.map_records([row]).raw() == expected


def test_complex_object_record_combining():
    """Rows sharing an id merge their list/set/dict fields into one result."""
    mapper = RecordCombiningMapper(record_mapper=CombiningDbModel)
    assert len(mapper.map_records([])) == 0

    rows = [
        {
            "id": 1,
            "list1": "val1",
            "set1": "val2",
            "key1": "k1",
            "val1": "v1",
            "key2": "k3",
            "val2": 3,
        },
        {"id": 2, "list1": "val1"},
        {
            "id": 1,
            "list1": "val3",
            "set1": "val4",
            "key1": "k2",
            "val1": "v2",
            "key2": "k4",
            "val2": 4,
        },
    ]
    expected = [
        {
            "id": 1,
            "list1": ["val1", "val3"],
            "set1": {"val2", "val4"},
            "dict1": {"k1": "v1", "k2": "v2"},
            "dict2": {"k3": 3, "k4": 4},
        },
        {
            "id": 2,
            "list1": ["val1"],
            "set1": set(),
            "dict1": {},
            "dict2": {},
        },
    ]
    assert _raw_all(mapper.map_records(rows)) == expected


def test_record_combining_multi_column_id():
    """Rows are combined on the (a, b) composite key."""
    mapper = RecordCombiningMapper(MultiKeyModel)
    assert len(mapper.map_records([])) == 0

    rows = [
        {"a": 1, "b": "test", "c": "one"},
        {"a": 1, "b": "test", "c": "two"},
        {"a": 1, "b": "test", "c": "three"},
        {"a": 2, "b": "test", "c": "one"},
        {"a": 2, "b": "test", "c": "two"},
        {"a": 1, "b": "othertest", "c": "one"},
    ]
    expected = [
        {"a": 1, "b": "test", "c": ["one", "two", "three"]},
        {"a": 2, "b": "test", "c": ["one", "two"]},
        {"a": 1, "b": "othertest", "c": ["one"]},
    ]
    assert _raw_all(mapper.map_records(rows)) == expected


def test_complex_object_with_null_values():
    """Missing columns fall back to the model's declared defaults."""
    mapper = SingleRowMapper(record_mapper=DefaultListCombiningDbModel)
    expected = {
        "id": 1,
        "list1": [],
        "set1": set(),
        "dict1": {},
        "dict2": {},
    }
    assert mapper.map_records([{"id": 1}]).raw() == expected


def test_csv_list_field():
    """CSV columns are split into typed lists."""
    mapper = SingleRowMapper(record_mapper=ListWithStringsModel)
    result = mapper.map_records([{"id": 1, "list1": "a,b,c,d", "list2": "1,2,3,4"}])
    assert result.raw() == {
        "id": 1,
        "list1": ["a", "b", "c", "d"],
        "list2": [1, 2, 3, 4],
    }


def test_csv_list_field_single_value():
    """A CSV column with no commas still yields a one-element list."""
    mapper = SingleRowMapper(record_mapper=ListWithStringsModel)
    result = mapper.map_records([{"id": 1, "list1": "a", "list2": "1"}])
    assert result.raw() == {"id": 1, "list1": ["a"], "list2": [1]}


def test_csv_list_field_empty_string():
    """Empty CSV strings leave the field at its declared default."""
    mapper = SingleRowMapper(record_mapper=ListWithStringsModel)
    result = mapper.map_records([{"id": 1, "list1": "", "list2": ""}])
    assert result.raw() == {"id": 1, "list1": None, "list2": []}


def test_csv_list_field_extends():
    """CSV values from combined records are appended in order."""
    mapper = RecordCombiningMapper(record_mapper=ListWithStringsModel)
    results = mapper.map_records(
        [
            {"id": 1, "list1": "a,b", "list2": "1,2"},
            {"id": 1, "list1": "c,d", "list2": "3,4"},
        ]
    )
    assert results[0].raw() == {
        "id": 1,
        "list1": ["a", "b", "c", "d"],
        "list2": [1, 2, 3, 4],
    }


def test_csv_list_field_multiple_records_duplicates():
    """Duplicate CSV values are kept — combining does not deduplicate lists."""
    mapper = RecordCombiningMapper(record_mapper=ListWithStringsModel)
    results = mapper.map_records(
        [
            {"id": 1, "list1": "a,b,c,d", "list2": "1,2,3,4"},
            {"id": 1, "list1": "a,b,c,d", "list2": "1,2,3,4"},
        ]
    )
    assert results[0].raw() == {
        "id": 1,
        "list1": ["a", "b", "c", "d", "a", "b", "c", "d"],
        "list2": [1, 2, 3, 4, 1, 2, 3, 4],
    }


def test_csv_list_field_none_in_first_record():
    """A None CSV value in one record does not block values from later records."""
    mapper = RecordCombiningMapper(record_mapper=ListWithStringsModel)
    results = mapper.map_records(
        [
            {"id": 1, "list1": "a,b,c,d", "list2": None},
            {"id": 1, "list1": "e,f,g", "list2": "1,2,3,4"},
        ]
    )
    assert results[0].raw() == {
        "id": 1,
        "list1": ["a", "b", "c", "d", "e", "f", "g"],
        "list2": [1, 2, 3, 4],
    }


def test_csv_list_field_without_mapping_ignored():
    """Columns not declared on the model are silently dropped."""
    mapper = SingleRowMapper(record_mapper=ListWithStringsModel)
    result = mapper.map_records(
        [{"id": 1, "list1": "a,b, c,d", "list2": "1,2,3,4", "list3": "x,y,z"}]
    )
    assert result.raw() == {
        "id": 1,
        "list1": ["a", "b", "c", "d"],
        "list2": [1, 2, 3, 4],
    }


def test_csv_list_field_invalid_type():
    """A CSV entry that cannot be coerced raises a pydantic ValidationError."""
    mapper = RecordCombiningMapper(record_mapper=ListWithStringsModel)
    with pytest.raises(ValidationError, match="1 validation error for list"):
        mapper.map_records(
            [
                {"id": 1, "list1": "a,b", "list2": "1,2"},
                {"id": 1, "list1": "c,d", "list2": "3,a"},
            ]
        )


def test_json_field():
    """JSON-typed columns are deserialized, including nested structures."""
    mapper = SingleRowMapper(record_mapper=JsonModel)
    payload = {"a": 1, "b": 2, "c": {"x": 10, "y": 9, "z": {"deep": "value"}}}
    result = mapper.map_records([{"id": 1, "json1": json.dumps(payload)}])
    assert result.model_dump() == {"id": 1, "json1": payload, "json2": None}


@pytest.mark.parametrize(
    "json1, json2",
    [
        ('{ "json": value', None),
        ('{ "json": value', '{ "json": value }'),
        ('{ "json": value }', '{ "json": value'),
        (None, '{ "json": value'),
    ],
)
def test_invalid_json(json1, json2):
    """Malformed JSON in either field raises a pydantic ValidationError."""
    with pytest.raises(ValidationError, match="Invalid JSON"):
        mapper = SingleRowMapper(record_mapper=JsonModel)
        mapper.map_records([{"id": 1, "json1": json1, "json2": json2}])


def test_json_none():
    """A None JSON column stays None instead of failing deserialization."""
    mapper = SingleRowMapper(record_mapper=JsonModel)
    result = mapper.map_records(
        [{"id": 1, "json1": '{ "first": "value" }', "json2": None}]
    )
    assert result.model_dump() == {
        "id": 1,
        "json1": {
            "first": "value",
        },
        "json2": None,
    }
LOGGER = logging.getLogger(__name__)


class MapperError(Exception):
    """Raised when a mapper is configured incorrectly."""


class DbMapResultBase(abc.ABC):
    """
    Base contract for objects that accumulate one or more DB records into a
    single mapped result.
    """

    @classmethod
    def get_key_columns(cls):
        """
        Columns identifying a unique result; override for composite keys.
        :return: the list of key column names (default ["id"])
        """
        return ["id"]

    @classmethod
    def create_instance(cls, *args, **kwargs) -> "DbMapResultBase":
        """
        Called instead of constructor, used to support different ways of
        creating objects instead of constructors (if desired).
        :return: an instance of the class
        """
        return cls(*args, **kwargs)

    @abc.abstractmethod
    def map_record(self, record: sqlalchemy.engine.Row) -> None:
        """
        Implements logic to handle the mapping of individual record objects from DB records.
        :param record: a record row from the database
        """

    @abc.abstractmethod
    def raw(self) -> dict:
        """
        Retrieves the raw data stored in this object as a dictionary.
        :return: the dict representing the object data
        """

    @abc.abstractmethod
    def has(self, field: str) -> bool:
        """
        Returns true if the field exists in the object.
        :param field: the field to check
        :return: True if the field exists, false otherwise
        """

    @abc.abstractmethod
    def get(self, field: str, default: Any = None) -> Any:
        """
        Retrieves the value for the field, returning a default value if the field is not available.
        :param field: the field to retrieve
        :param default: the default value to return if the field is not available (defaults to None)
        :return: the value of the field or the default value
        """


class DbMapResult(DbMapResultBase):
    """Default result implementation backed directly by the instance __dict__."""

    def __init__(self, **kwargs):
        self.__dict__ = kwargs
        # pylint: disable=invalid-name
        # Guarantee an "id" attribute so raw() and combining logic can rely on it
        if not self.__dict__.get("id"):
            self.id = None

    def __getitem__(self, field: str) -> Any:
        return self.__dict__.get(field)

    def __setitem__(self, field: str, value: Any) -> None:
        self.__dict__[field] = value

    def __str__(self) -> str:
        return str(self.__dict__)

    def map_record(self, record: sqlalchemy.engine.Row) -> None:
        """
        Copies every column of the record onto this result (later records
        overwrite earlier values for the same column).
        """
        # NOTE(review): relies on record.items(); for SQLAlchemy 2.0 Row objects
        # this may require record._mapping.items() — confirm against the
        # supported SQLAlchemy version
        for column, value in record.items():
            self[column] = value

    def raw(self) -> dict:
        """
        Recursively unwraps nested mapped results into plain dictionaries.
        :return: the dict representing the object data
        """

        def get_raw(value: Any) -> dict:
            if isinstance(value, DbMapResultBase):
                return value.raw()
            return value

        raw = {}
        for key, value in self.__dict__.items():
            if isinstance(value, list):
                raw[key] = list(map(get_raw, value))
            else:
                raw[key] = get_raw(value)
        # Remove the id field if it was never set; pop() keeps this safe even
        # if "id" is somehow absent from the dict
        if raw.get("id") is None:
            raw.pop("id", None)
        return raw

    def has(self, field: str) -> bool:
        # A field holding None is reported as absent
        return self.__dict__.get(field) is not None

    def get(self, field: str, default: Any = None) -> Any:
        return self.__dict__.get(field, default)


class BaseMapper(metaclass=abc.ABCMeta):
    """
    Extend this class and implement the map_records method to map the results from a database query.
    """

    @abc.abstractmethod
    def map_records(self, records: sqlalchemy.engine.CursorResult) -> Any:
        pass


class RecordCombiningMapper(BaseMapper):
    """
    Creates one or more mapped results from multiple records. This returns the same or less number
    of results as there are records since it potentially combines multiple records with the same
    identifier into a single result.

    This depends on the record mapper's key columns (``id`` by default) being present in the
    records. If not, a mapped result is returned for every row without combining.
    """

    def __init__(self, record_mapper: Optional[Type[DbMapResultBase]] = DbMapResult):
        """
        :param record_mapper: the result class records are mapped into
        """
        self.record_mapper = record_mapper

    def _get_lookup(self, record):
        """
        Returns a lookup key derived from the mapper's key columns, or None when
        no key can be computed (so the record stands alone).
        :param record: the record to derive a key from
        :return: a hash of the key column values, or None when key columns are
            unset or any key column is missing from the record
        """
        key_columns = self.record_mapper.get_key_columns()
        if not key_columns:
            # preserving older expectations
            return None

        values = []
        for column in key_columns:
            if column not in record:
                # returning none because we don't want to make assumptions about how to handle a complex key
                return None
            values.append(record[column])
        return hash(tuple(values))

    def map_records(self, records: sqlalchemy.engine.CursorResult) -> Any:
        """
        Maps records into results, combining records that share a lookup key.
        :param records: the DB records to map
        :return: the list of mapped results (or the raw records if no mapper)
        """
        if not self.record_mapper:
            return records

        current_results = OrderedDict()
        current_num = 0
        for record in records:
            lookup = self._get_lookup(record)
            # Fix: compare against None (a hash may legitimately be falsy, e.g. 0)
            # and namespace keyed vs. positional entries so a hash value can
            # never collide with a positional counter.
            if lookup is not None:
                result_key = ("key", lookup)
            else:
                result_key = ("row", current_num)
                current_num += 1
            if result_key not in current_results:
                current_results[result_key] = self.record_mapper.create_instance()
            current_results[result_key].map_record(record)

        return list(current_results.values())


class SingleRowMapper(BaseMapper):
    """
    Returns a single mapped result from one or more records. The first record is returned even if
    there are multiple records from the database.
    """

    def __init__(self, record_mapper: Optional[Type[DbMapResultBase]] = DbMapResult):
        """
        :param record_mapper: the result class the first record is mapped into
        """
        self.record_mapper = record_mapper

    def map_records(self, records: sqlalchemy.engine.CursorResult) -> Any:
        """
        :return: the first record (mapped if a mapper is set), or None when empty
        """
        if not self.record_mapper:
            for record in records:
                return record
            return None

        for record in records:
            mapper = self.record_mapper.create_instance()
            mapper.map_record(record)
            return mapper
        return None


class SingleColumnMapper(BaseMapper):
    """
    Returns the first column value for each record from the database, even if multiple columns are
    defined. This will return a list of scalar values.
    """

    def map_records(self, records: sqlalchemy.engine.CursorResult) -> Any:
        results = []
        for record in records:
            # Only the first column of each record is kept
            for value in record.values():
                results.append(value)
                break
        return results


class SingleRowAndColumnMapper(BaseMapper):
    """
    Returns the first column in the first record from the database, even if multiple records or
    columns exist. This will return a single scalar or None if there are no records.
    """

    def map_records(self, records: sqlalchemy.engine.CursorResult) -> Any:
        for record in records:
            for value in record.values():
                return value
        return None


class CountMapper(SingleRowAndColumnMapper):
    """
    Alias for SingleRowAndColumnMapper, may be used to easily return the result of a count query.
    """


class KeyValueMapper(BaseMapper):
    """
    Maps records into a dictionary keyed by one column with values from another.

    :param key_column: which column (index or name) supplies the dictionary keys
    :param value_column: which column (index or name) supplies the values
    :param has_multiple_values_per_key: when True each key collects a list of
        values instead of a single value (a defaultdict(list) is returned)
    """

    def __init__(self, key_column=0, value_column=1, has_multiple_values_per_key=False):
        """
        :raises MapperError: when key and value columns are the same
        """
        if key_column == value_column:
            raise MapperError("key and value columns cannot be the same")

        self.key_column = key_column
        self.value_column = value_column
        self.has_multiple_values_per_key = has_multiple_values_per_key

    def map_records(self, records: sqlalchemy.engine.CursorResult) -> Any:
        results = {}
        if self.has_multiple_values_per_key:
            # NOTE: intentionally returns the defaultdict itself
            results = defaultdict(list)
        for record in records:
            key = record[self.key_column]
            value = record[self.value_column]
            if self.has_multiple_values_per_key:
                results[key].append(value)
            else:
                results[key] = value
        return results
# pylint: disable=too-many-public-methods
import pytest

import dysql
from dysql import QueryData, sqlquery
from dysql.test import (
    verify_query,
    verify_query_args,
    verify_query_params,
    setup_mock_engine,
)


@pytest.fixture(name="mock_engine", autouse=True)
def mock_engine_fixture(mock_create_engine):
    mock_engine = setup_mock_engine(mock_create_engine)
    mock_engine.connect().execution_options().execute.side_effect = lambda x, y: []
    return mock_engine


def test_list_in_numbers(mock_engine):
    _query(
        "SELECT * FROM table WHERE {in__column_a}",
        template_params={"in__column_a": [1, 2, 3, 4]},
    )
    verify_query_params(
        mock_engine,
        "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1, :in__column_a_2, :in__column_a_3 ) ",
        {
            "in__column_a_0": 1,
            "in__column_a_1": 2,
            "in__column_a_2": 3,
            "in__column_a_3": 4,
        },
    )


def test_list_in__strings(mock_engine):
    _query(
        "SELECT * FROM table WHERE {in__column_a}",
        template_params={"in__column_a": ["a", "b", "c", "d"]},
    )
    verify_query_params(
        mock_engine,
        "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1, :in__column_a_2, :in__column_a_3 ) ",
        {
            "in__column_a_0": "a",
            "in__column_a_1": "b",
            "in__column_a_2": "c",
            "in__column_a_3": "d",
        },
    )


def test_list_not_in_numbers(mock_engine):
    _query(
        "SELECT * FROM table WHERE {not_in__column_b}",
        template_params={"not_in__column_b": [1, 2, 3, 4]},
    )
    verify_query_params(
        mock_engine,
        "SELECT * FROM table WHERE column_b NOT IN ( :not_in__column_b_0, :not_in__column_b_1, "
        ":not_in__column_b_2, :not_in__column_b_3 ) ",
        {
            "not_in__column_b_0": 1,
            "not_in__column_b_1": 2,
            "not_in__column_b_2": 3,
            "not_in__column_b_3": 4,
        },
    )


def test_list_not_in_strings(mock_engine):
    _query(
        "SELECT * FROM table WHERE {not_in__column_b}",
        template_params={"not_in__column_b": ["a", "b", "c", "d"]},
    )
    verify_query_params(
        mock_engine,
        "SELECT * FROM table WHERE column_b NOT IN ( :not_in__column_b_0, :not_in__column_b_1, "
        ":not_in__column_b_2, :not_in__column_b_3 ) ",
        {
            "not_in__column_b_0": "a",
            "not_in__column_b_1": "b",
            "not_in__column_b_2": "c",
            "not_in__column_b_3": "d",
        },
    )


def test_list_in_handles_empty(mock_engine):
    _query(
        "SELECT * FROM table WHERE {in__column_a}", template_params={"in__column_a": []}
    )
    verify_query(mock_engine, "SELECT * FROM table WHERE 1 <> 1 ")


def test_list_in_handles_no_param():
    # BUGFIX: `match` is a regular expression, so an unescaped "['in__column_a']" is a
    # character class and matches almost any message. Escape the brackets so the test
    # actually asserts the missing key is reported.
    with pytest.raises(
        dysql.query_utils.ListTemplateException, match=r"\['in__column_a'\]"
    ):
        _query("SELECT * FROM table WHERE {in__column_a}")


def test_list_in_multiple_lists(mock_engine):
    _query(
        "SELECT * FROM table WHERE {in__column_a} OR {in__column_b}",
        template_params={"in__column_a": ["first", "second"], "in__column_b": [1, 2]},
    )
    verify_query(
        mock_engine,
        "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1 ) "
        "OR column_b IN ( :in__column_b_0, :in__column_b_1 ) ",
    )


def test_list_in_multiple_lists_one_empty(mock_engine):
    _query(
        "SELECT * FROM table WHERE {in__column_a} OR {in__column_b}",
        template_params={"in__column_a": ["first", "second"], "in__column_b": []},
    )
    verify_query(
        mock_engine,
        "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1 ) OR 1 <> 1 ",
    )


def test_list_in_multiple_lists_one_missing():
    # see test_list_in_handles_no_param for why the brackets are escaped
    with pytest.raises(
        dysql.query_utils.ListTemplateException, match=r"\['in__column_a'\]"
    ):
        _query(
            "SELECT * FROM table WHERE {in__column_a} OR {in__column_b} ",
            template_params={"in__column_b": [1, 2]},
        )


def test_list_in_multiple_lists_all_missing():
    # Match only the first missing key as a prefix: this keeps the test valid whether
    # the implementation reports just the first missing key or the whole list of them.
    with pytest.raises(
        dysql.query_utils.ListTemplateException, match=r"\['in__column_a'"
    ):
        _query("SELECT * FROM table WHERE {in__column_a} OR {in__column_b} ")


def test_list_not_in_handles_empty(mock_engine):
    _query(
        "SELECT * FROM table WHERE {not_in__column_b}",
        template_params={"not_in__column_b": []},
    )
    verify_query(mock_engine, "SELECT * FROM table WHERE 1 = 1 ")


def test_list_not_in_handles_no_param():
    # see test_list_in_handles_no_param for why the brackets are escaped
    with pytest.raises(
        dysql.query_utils.ListTemplateException, match=r"\['not_in__column_b'\]"
    ):
        _query("SELECT * FROM table WHERE {not_in__column_b} ")


def test_list_gives_template_space_before(mock_engine):
    _query(
        "SELECT * FROM table WHERE{in__space}", template_params={"in__space": [9, 8]}
    )
    verify_query(
        mock_engine,
        "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) ",
    )


def test_list_gives_template_space_after(mock_engine):
    _query(
        "SELECT * FROM table WHERE {in__space}AND other_condition = 1",
        template_params={"in__space": [9, 8]},
    )
    verify_query(
        mock_engine,
        "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) AND other_condition = 1",
    )


def test_list_gives_template_space_before_and_after(mock_engine):
    _query(
        "SELECT * FROM table WHERE{in__space}AND other_condition = 1",
        template_params={"in__space": [9, 8]},
    )
    verify_query(
        mock_engine,
        "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) AND other_condition = 1",
    )


def test_in_contains_whitespace(mock_engine):
    _query("{in__column_one}", template_params={"in__column_one": [1, 2]})
    verify_query(
        mock_engine, " column_one IN ( :in__column_one_0, :in__column_one_1 ) "
    )


def test_template_handles_table_qualifier(mock_engine):
    """
    when the table is specified with a dot separator, we need to split the table out of the
    keyword and the keyword passed in should be passed in without the keyword

    {in__table.column} -> {table}.{in__column} -> table.column IN (:in__column_0, :in__column_1)
    :return:
    """
    _query(
        "SELECT * FROM table WHERE {in__table.column}",
        template_params={"in__table.column": [1, 2]},
    )
    verify_query(
        mock_engine,
        "SELECT * FROM table WHERE table.column IN ( :in__table_column_0, :in__table_column_1 ) ",
    )
    verify_query_args(mock_engine, {"in__table_column_0": 1, "in__table_column_1": 2})


def test_template_handles_multiple_table_qualifier(mock_engine):
    _query(
        "SELECT * FROM table WHERE {in__table.column} AND {not_in__other_column}",
        template_params={
            "in__table.column": [1, 2],
            "not_in__other_column": ["a", "b"],
        },
    )
    verify_query(
        mock_engine,
        "SELECT * FROM table WHERE table.column IN ( :in__table_column_0, :in__table_column_1 ) "
        "AND other_column NOT IN ( :not_in__other_column_0, :not_in__other_column_1 ) ",
    )
    verify_query_args(
        mock_engine,
        {
            "in__table_column_0": 1,
            "in__table_column_1": 2,
            "not_in__other_column_0": "a",
            "not_in__other_column_1": "b",
        },
    )


def test_empty_in_contains_whitespace(mock_engine):
    _query("{in__column_one}", template_params={"in__column_one": []})
    verify_query(mock_engine, " 1 <> 1 ")


def test_multiple_templates_same_column_diff_table(mock_engine):
    template_params = {
        "in__table.status": ["on", "off", "waiting"],
        "in__other_table.status": ["success", "partial_success"],
    }
    expected_params_from_template = {
        "in__table_status_0": "on",
        "in__table_status_1": "off",
        "in__table_status_2": "waiting",
        "in__other_table_status_0": "success",
        "in__other_table_status_1": "partial_success",
    }

    # writing each of these queries out to help see what we expect compared to
    # the query we actually sent
    _query(
        "SELECT * FROM table WHERE {in__table.status} AND {in__other_table.status}",
        template_params=template_params,
    )
    expected_query = (
        "SELECT * FROM table WHERE table.status IN ( :in__table_status_0, :in__table_status_1, "
        ":in__table_status_2 ) AND other_table.status IN ( :in__other_table_status_0, "
        ":in__other_table_status_1 ) "
    )

    connection = mock_engine.connect.return_value.execution_options.return_value
    execute_call = connection.execute
    execute_call.assert_called_once()
    assert execute_call.call_args[0][0].text == expected_query
    assert execute_call.call_args[0][1] == expected_params_from_template


@sqlquery()
def _query(query, query_params=None, template_params=None):
    return QueryData(query, query_params=query_params, template_params=template_params)
import functools
import inspect
import logging
from typing import Type, Union

import sqlalchemy

from .databases import DatabaseContainerSingleton
from .mappers import (
    BaseMapper,
    DbMapResult,
    RecordCombiningMapper,
)
from .query_utils import get_query_data

logger = logging.getLogger("database")

# Always initialize a database container, it is never set again
_DATABASE_CONTAINER = DatabaseContainerSingleton()


class _ConnectionManager:
    """
    Context manager that opens a connection (and optionally a transaction) against the
    current database and commits/rolls back and closes on exit.
    """

    _connection: sqlalchemy.engine.base.Connection

    # pylint: disable=unused-argument
    def __init__(self, func, isolation_level, transaction, *args, **kwargs):
        """
        :param func: the decorated function (currently unused, kept for future hooks)
        :param isolation_level: the transaction isolation level for the connection
        :param transaction: when True, a transaction is begun and committed/rolled back on exit
        """
        self._transaction = None

        self._connection = (
            _DATABASE_CONTAINER.current_database.engine.connect().execution_options(
                isolation_level=isolation_level
            )
        )
        if transaction:
            self._transaction = self._connection.begin()

    def __enter__(self):
        self._connection.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Handle transaction committing/rollback before closing the connection
        if exc_type:
            # lazy %-style formatting avoids building the message when the level is disabled
            logger.error("Encountered an error : %s", exc_val)

            # Roll back transaction if there is one
            if self._transaction:
                self._transaction.rollback()
        elif self._transaction:
            self._transaction.commit()

        # Close the connection
        self._connection.__exit__(exc_type, exc_val, exc_tb)

    def execute_query(self, query, params=None) -> sqlalchemy.engine.CursorResult:
        """
        Executes the query through the connection with optional parameters.
        :param query: The query to run
        :param params: The parameters to use
        :return: The result proxy object from sqlalchemy
        """
        if not params:
            params = {}

        if isinstance(params, DbMapResult):
            params = params.raw()

        return self._connection.execute(sqlalchemy.text(query), params)


def sqlquery(
    mapper: Union[BaseMapper, Type[BaseMapper]] = None,
    isolation_level: str = "READ_COMMITTED",
):
    """
    query allows for defining a parameterize select query that is then executed
    :param mapper: a class extending from or an instance of BaseMapper, defaults to
        RecordCombiningMapper
    :param isolation_level: the isolation level to use for the transaction, defaults to
        'READ_COMMITTED'
    :return: Depends on the mapper used, see mapper documentation for more details.

    Examples::

        # Uses the default RecordCombiningMapper, returns a list of results
        @sqlquery()
        def get_items(name, date):
            return QueryData("SELECT item1, item2, ... FROM item join item_data ...
                WHERE name=:name and date before :date"), {'name': name, 'date': date})

        # Override the mapper to return a list of scalars
        @sqlquery(mapper=SingleColumnMapper)
        def get_items():
            return QueryData("SELECT item1 FROM item join item_data ...")

        # Override the mapper to return a single (first) result
        @sqlquery(mapper=SingleRowMapper)
        def get_items():
            return QueryData("SELECT item1, item2, ... FROM item join item_data ...")

        # Override the mapper to return a single result from the first result
        @sqlquery(mapper=CountMapper)
        def get_count():
            return QueryData("SELECT count(*) FROM item GROUP BY item1")
    """

    def decorator(func):
        # BUGFIX: functools.wraps must be applied to the wrapper at decoration time.
        # Calling functools.wraps(func, handle_query) inside the wrapper body only built
        # a partial object that was immediately discarded, so the decorated function
        # never preserved func's __name__/__doc__/metadata.
        @functools.wraps(func)
        def handle_query(*args, **kwargs):
            actual_mapper = mapper
            if not actual_mapper:
                # Default to record combining mapper
                actual_mapper = RecordCombiningMapper

            # Handle mapper defined as a class
            if inspect.isclass(actual_mapper):
                actual_mapper = actual_mapper()

            with _ConnectionManager(
                func, isolation_level, False, *args, **kwargs
            ) as conn_manager:
                data = func(*args, **kwargs)
                query, params = get_query_data(data)
                records = conn_manager.execute_query(query, params)
                return actual_mapper.map_records(records)

        return handle_query

    return decorator


def sqlexists(isolation_level="READ_COMMITTED"):
    """
    exists query allows for defining a parameterize select query that is
    wrapped in an exists clause and then executed
    :param isolation_level: the isolation level to use for the transaction, defaults to
        'READ_COMMITTED'
    :return: returns a True or False depending on the value returned from
        the exists query. If Exists returns 1, True. otherwise, False
    Examples::

        @sqlexists
        def get_items_exists(name, date):
            return "SELECT 1 FROM item join item_data ...
                WHERE name=:name and date before :date", { name=name, date=date}

        @sqlexists
        def get_items_exists():
            return "SELECT 1 FROM item WHERE date before now()-100

    """

    def decorator(func):
        # see sqlquery: wraps must decorate the wrapper, not run as a no-op inside it
        @functools.wraps(func)
        def handle_query(*args, **kwargs):
            with _ConnectionManager(
                func, isolation_level, False, *args, **kwargs
            ) as conn_manager:
                data = func(*args, **kwargs)
                query, params = get_query_data(data)

                # normalize the query into a "SELECT 1 WHERE EXISTS ( ... )" form
                query = query.lstrip()
                if query.startswith("SELECT EXISTS"):
                    query = query.replace("SELECT EXISTS", "SELECT 1 WHERE EXISTS")
                if not query.startswith("SELECT 1 WHERE EXISTS"):
                    query = f"SELECT 1 WHERE EXISTS ( {query} )"
                result = conn_manager.execute_query(query, params).scalar()
                return result == 1

        return handle_query

    return decorator


def sqlupdate(
    isolation_level="READ_COMMITTED",
    disable_foreign_key_checks=False,
    on_success=None,
):
    """
    :param isolation_level: should specify whether we can read data from transactions that are not
        yet committed, defaults to READ_COMMITTED
    :param disable_foreign_key_checks: Should be used with caution. Only use this if you have
        confidence the data you are inserting is going to have everything it needs to satisfy
        foreign_key_checks once all the data is applied. Foreign key checks are there to provide
        our database with integrity

        That being said, there are times when you are synchronizing data from another source that
        you have data that may be difficult to sort ahead of time in order to avoid failing foreign
        key constraints this flag helps to deal with cases where the logic may be unnecessarily
        complex.

        Note: this will work as expected when a normal transaction completes successfully and if a
        transaction rolls back this will be left in a clean state as expected before executing
        anything
    :param on_success: optional callback invoked with the original arguments after the
        transaction completes successfully
    Examples::

        @sqlupdate
        def insert_example(key_values)
            return QueryData("INSERT INTO table(id, value) VALUES (:id, :value)", key_values)

        @sqlupdate
        def delete_example(ids)
            return QueryData("DELETE FROM table WHERE id=:id", { "id": id })

        @sqlupdate()
        def insert_with_relations(get_last_insert_id = None):
            yield QueryData("INSERT INTO table(value) VALUES (:value)", key_values)
            id = get_last_insert_id()
            yield "INSERT INTO relation_table(id, value) VALUES (:id, :value)", {
                "id": id,
                "value": "value"
            })
    """

    def update_wrapper(func):
        """
        :param func: should return a tuple representing the query followed by the dictionary of key
            values for query parameters
        """

        # see sqlquery: wraps must decorate the wrapper, not run as a no-op inside it
        @functools.wraps(func)
        def handle_query(*args, **kwargs):
            with _ConnectionManager(
                func, isolation_level, True, *args, **kwargs
            ) as conn_manager:
                if disable_foreign_key_checks:
                    conn_manager.execute_query("SET FOREIGN_KEY_CHECKS=0")
                last_insert_method = "get_last_insert_id"
                if last_insert_method in inspect.signature(func).parameters:
                    # lazily expose the last-insert id so generator bodies can chain inserts
                    kwargs[last_insert_method] = lambda: conn_manager.execute_query(
                        "SELECT LAST_INSERT_ID()"
                    ).scalar()
                if inspect.isgeneratorfunction(func):
                    logger.debug("handling each query before committing transaction")

                    for data in func(*args, **kwargs):
                        query, params = get_query_data(data)
                        if isinstance(params, list):
                            # a list of param dicts executes the same statement once per dict
                            for param in params:
                                conn_manager.execute_query(query, param)
                        else:
                            conn_manager.execute_query(query, params)
                else:
                    data = func(*args, **kwargs)
                    query, params = get_query_data(data)
                    if isinstance(params, list):
                        # ValueError is still an Exception, so existing catchers keep working
                        raise ValueError("Params must not be a list")
                    conn_manager.execute_query(query, params)

                if disable_foreign_key_checks:
                    conn_manager.execute_query("SET FOREIGN_KEY_CHECKS=1")

            # NOTE(review): the callback runs after the with-block so the transaction has
            # committed; the injected helper is removed first so the original signature holds
            if last_insert_method in kwargs:
                del kwargs[last_insert_method]
            if on_success:
                on_success(*args, **kwargs)

        return handle_query

    return update_wrapper
30 | """ 31 | 32 | @classmethod 33 | def get_template(cls, name: str): 34 | if name == "in": 35 | return cls.in_column 36 | if name == "not_in": 37 | return cls.not_in_column 38 | if name == "values": 39 | return cls.values 40 | return None 41 | 42 | @staticmethod 43 | def in_column( 44 | name: str, 45 | values: Union[str, Iterable[str]], 46 | legacy_key: str = None, 47 | is_multi_column: bool = False, 48 | ) -> Tuple[str, Optional[dict]]: 49 | """ 50 | Returns query and params for using "IN" SQL queries. 51 | :param name: the field name 52 | :param values: the field values 53 | :param legacy_key: 54 | :param is_multi_column: needed when more than one field is being compared as part of the in statement 55 | :return: a tuple of the query string and the params dictionary 56 | """ 57 | if not values: 58 | return "1 <> 1", None 59 | key_name = TemplateGenerators._get_key(name, legacy_key, is_multi_column) 60 | keys, values = TemplateGenerators._parameterize_list(key_name, values) 61 | if is_multi_column: 62 | keys = f"({keys})" 63 | return f"{name} IN {keys}", values 64 | 65 | @staticmethod 66 | def in_multi_column( 67 | name: str, values: Union[str, Iterable[str]], legacy_key: str = None 68 | ): 69 | """ 70 | A wrapper for in_column with is_multi_column set to true 71 | """ 72 | return TemplateGenerators.in_column(name, values, legacy_key, True) 73 | 74 | @staticmethod 75 | def not_in_column( 76 | name: str, 77 | values: Union[str, Iterable[str]], 78 | legacy_key: str = None, 79 | is_multi_column: bool = False, 80 | ) -> Tuple[str, Optional[dict]]: 81 | """ 82 | Returns query and params for using "NOT IN" SQL queries. 
83 | :param name: the field name 84 | :param values: the field values 85 | :param legacy_key: 86 | :param is_multi_column: needed when more than one field is being compared as part of the not in statement 87 | :return: a tuple of the query string and the params dictionary 88 | """ 89 | if not values: 90 | return "1 = 1", None 91 | key_name = TemplateGenerators._get_key(name, legacy_key, is_multi_column) 92 | keys, values = TemplateGenerators._parameterize_list(key_name, values) 93 | if is_multi_column: 94 | keys = f"({keys})" 95 | return f"{name} NOT IN {keys}", values 96 | 97 | @staticmethod 98 | def not_in_multi_column( 99 | name: str, values: Union[str, Iterable[str]], legacy_key: str = None 100 | ): 101 | """ 102 | A wrapper for not_in_column with is_multi_column set to true 103 | """ 104 | return TemplateGenerators.not_in_column(name, values, legacy_key, True) 105 | 106 | @staticmethod 107 | def values( 108 | name: str, 109 | values: Union[str, Iterable[str]], 110 | legacy_key: str = None, 111 | ) -> Tuple[str, Optional[dict]]: 112 | """ 113 | Returns query and params for using "VALUES" SQL queries. 
114 | :param name: the field name 115 | :param values: the values 116 | :param legacy_key: 117 | :return: a tuple of the query string and the params dictionary 118 | """ 119 | if not values: 120 | raise ListTemplateException(f"Must have values for {name} template") 121 | key_name = TemplateGenerators._get_key(name, legacy_key, False) 122 | keys, values = TemplateGenerators._parameterize_list(key_name, values) 123 | return f"VALUES {keys}", values 124 | 125 | @staticmethod 126 | def _get_key(key: str, legacy_key: str, is_multi_column: bool) -> str: 127 | key_name = key 128 | if legacy_key: 129 | key_name = legacy_key 130 | if is_multi_column: 131 | key_name = re.sub("[, ()]", "", key_name) 132 | return key_name 133 | 134 | @staticmethod 135 | def _parameterize_inner_list( 136 | key: str, values: Union[str, Iterable[str]] 137 | ) -> Tuple[str, Optional[dict]]: 138 | param_values = {} 139 | parameterized_keys = [] 140 | if not isinstance(values, (list, tuple, set)): 141 | param_values[key] = values 142 | parameterized_keys.append(key) 143 | else: 144 | for index, value in enumerate(values): 145 | parameterized_key = f"{key.replace('.', '_')}_{str(index)}" 146 | param_values[parameterized_key] = value 147 | parameterized_keys.append(parameterized_key) 148 | 149 | return f"( :{', :'.join(parameterized_keys)} )", param_values 150 | 151 | @staticmethod 152 | def _parameterize_list( 153 | key: str, values: Union[str, Iterable[str]] 154 | ) -> Tuple[str, Optional[dict]]: 155 | """ 156 | Build a string with parameterized values and a dictionary 157 | with key value pairs matching the string parameters. 
158 | :return a tuple of the parameter string and the dictionary of parameter values 159 | """ 160 | param_values = {} 161 | param_inner_keys = [] 162 | 163 | if isinstance(values, str): 164 | values = tuple((values,)) 165 | 166 | for index, value in enumerate(values): 167 | if isinstance(value, tuple) or key.startswith("values"): 168 | ( 169 | param_string, 170 | inner_param_values, 171 | ) = TemplateGenerators._parameterize_inner_list( 172 | f"{key}_{str(index)}", value 173 | ) 174 | param_values.update(inner_param_values) 175 | param_inner_keys.append(param_string) 176 | else: 177 | return TemplateGenerators._parameterize_inner_list(key, values) 178 | 179 | return ", ".join(param_inner_keys), param_values 180 | 181 | 182 | class ListTemplateException(Exception): 183 | """ 184 | List Template Exception 185 | """ 186 | 187 | 188 | class QueryDataError(Exception): 189 | """ 190 | This exception is thrown when we expect a QueryData object but end up with something different 191 | """ 192 | 193 | 194 | class QueryData: 195 | # pylint: disable=too-few-public-methods 196 | """ 197 | Query data is a wrapper class that allows us to pass back information for a query with more 198 | information and specification. this helps to have template queries with template parameters but 199 | to still have the ability to have query parameters that get parameterized on the query. 200 | """ 201 | 202 | def __init__( 203 | self, 204 | query: str, 205 | query_params: dict = None, 206 | template_params: dict = None, 207 | ): 208 | """ 209 | Constructor. 
210 | :param query: the SQL query 211 | :param query_params: object or list of objects containing values to apply to parameters 212 | 213 | examples of a queryparams passed to a QueryData 214 | 215 | Single Object 216 | QueryData( 217 | "SELECT * FROM table WHERE value_a=:value_a and value_b=:value_b", 218 | query_params={'value_a': 1, 'value_b': 'b' } 219 | ) 220 | 221 | List of Objects (Note: this will create a query for each object, it may not be possible or 222 | even desirable but, using template_params will create a single query with the params) 223 | QueryData( 224 | "SELECT * FROM table WHERE value_a=:value_a` and value_b=:value_b", 225 | query_params=[{'value_a': 1, 'value_b': 'b' }, {'value_a': 1, 'value_b': 'b' }] 226 | ) 227 | :param template_params: these are templates that can be added to queries to generate a single 228 | parameterized query. 229 | 230 | examples of current list templates and transformations 231 | 232 | "IN (... )" 233 | QueryData( 234 | "SELECT * FROM table where name {in__name}", 235 | template_params={'in__name' : ['bob','tom','chic']} 236 | ) 237 | {in__actor.name} -> actor.name IN ('bob', 'tom', ' ), 238 | 239 | "NOT IN (... )" 240 | QueryData( 241 | "SELECT * FROM table wher name {in__name}", 242 | template_params={'in__name' : ['bob','tom','chic']} 243 | ) 244 | {not_in__actor.name} -> actor.name NOT IN ('name1','name2',...), 245 | 246 | "VALUES (), (), ()..." 247 | {values__actors} -> VALUES (a1.v1,a1.v2,a1.v3), (a1.v1,a1.v2,a1.v3),... 
248 | """ 249 | self.query = query 250 | self.query_params = query_params 251 | self.template_params = template_params 252 | 253 | 254 | def __validate_keys_clean_query(query, template_params): 255 | validated_keys = [] 256 | for groups in re.findall(LIST_TEMPLATE_REGEX, query): 257 | # check first group for the full key 258 | key = f"{groups[2]}__{groups[3] if groups[3] else ''}{groups[4]}" 259 | validated_keys.append(key) 260 | missing_keys = [] 261 | 262 | # validate 263 | if template_params is None or template_params.get(key) is None: 264 | missing_keys.append(key) 265 | elif key == "values" and len(template_params.get(key)) == 0: 266 | missing_keys.append(key) 267 | 268 | if len(missing_keys) > 0: 269 | raise ListTemplateException(f"Missing template keys {missing_keys}") 270 | 271 | # Clean whitespace as templates will add their own padding later on 272 | query = query.replace(groups[0], groups[0].strip()) 273 | return query, validated_keys 274 | 275 | 276 | def __validate_query_and_params(data: QueryData) -> None: 277 | if not isinstance(data, QueryData): 278 | raise QueryDataError( 279 | "SQL annotated methods must return an instance of QueryData for query information" 280 | ) 281 | 282 | 283 | def get_query_data(data: QueryData) -> Tuple[str, dict]: 284 | """ 285 | Retrieves the query and parameters from a QueryData object. 
286 | :param data: the query data object 287 | :return: a tuple of the query string and the params 288 | """ 289 | __validate_query_and_params(data) 290 | 291 | params = {} 292 | query, validated_keys = __validate_keys_clean_query( 293 | data.query, data.template_params 294 | ) 295 | 296 | if data.query_params: 297 | params.update(data.query_params) 298 | 299 | for key in validated_keys: 300 | list_template_key, column_name = tuple(key.split("__")) 301 | template_to_use = TemplateGenerators.get_template(list_template_key) 302 | template_query, param_dict = template_to_use( 303 | column_name, data.template_params[key], legacy_key=key 304 | ) 305 | if param_dict: 306 | params.update(param_dict) 307 | query_key = "{" + key + "}" 308 | query = query.replace(query_key, f" {template_query} ") 309 | 310 | return query, params 311 | -------------------------------------------------------------------------------- /dysql/test/test_sql_decorator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2021 Adobe 3 | All Rights Reserved. 4 | 5 | NOTICE: Adobe permits you to use, modify, and distribute this file in accordance 6 | with the terms of the Adobe license agreement accompanying it. 
7 | """ 8 | 9 | import pytest 10 | 11 | from dysql import ( 12 | sqlupdate, 13 | sqlquery, 14 | DbMapResult, 15 | CountMapper, 16 | SingleRowMapper, 17 | QueryData, 18 | QueryDataError, 19 | ) 20 | from dysql.test import setup_mock_engine 21 | 22 | 23 | class TestSqlSelectDecorator: 24 | """ 25 | Test Sql Select Decorator 26 | """ 27 | 28 | @staticmethod 29 | @pytest.fixture 30 | def mock_results(): 31 | return [ 32 | { 33 | "id": 1, 34 | "name": "jack", 35 | "email": "jack@adobe.com", 36 | "hobbies": ["golf", "bikes", "coding"], 37 | }, 38 | { 39 | "id": 2, 40 | "name": "flora", 41 | "email": "flora@adobe.com", 42 | "hobbies": ["golf", "coding"], 43 | }, 44 | { 45 | "id": 3, 46 | "name": "Terrence", 47 | "email": "terrence@adobe.com", 48 | "hobbies": ["coding", "watching tv", "hot dog eating contest"], 49 | }, 50 | ] 51 | 52 | @staticmethod 53 | @pytest.fixture 54 | def mock_engine(mock_results, mock_create_engine): 55 | mock_engine = setup_mock_engine(mock_create_engine) 56 | mock_engine.connect.return_value.execution_options.return_value.execute.return_value = mock_results 57 | return mock_engine 58 | 59 | @staticmethod 60 | def test_non_query_data_fails(mock_create_engine): 61 | # This argument forces the fixture setup 62 | _ = mock_create_engine 63 | 64 | @sqlquery() 65 | def select_with_string(): 66 | return "SELECT" 67 | 68 | with pytest.raises(QueryDataError): 69 | select_with_string() 70 | 71 | def test_map_not_specified(self, mock_results, mock_engine): 72 | mock_engine.connect.return_value.execution_options.return_value.execute.return_value = mock_results 73 | assert isinstance(list(self._select_all())[0], DbMapResult) 74 | 75 | def test_single(self, mock_results, mock_engine): 76 | mock_engine.connect.return_value.execution_options.return_value.execute.return_value = mock_results 77 | result = self._select_single() 78 | assert isinstance(result, DbMapResult) 79 | 80 | def test_count(self, mock_engine): 81 | 
mock_engine.connect.return_value.execution_options.return_value.execute.return_value = [ 82 | {"count": 2}, 83 | {"count": 3}, 84 | ] 85 | assert self._select_count() == 2 86 | 87 | def test_execute_no_params(self, mock_engine): 88 | self._select_all() 89 | mock_engine.connect.return_value.execution_options.return_value.execute.assert_called() 90 | call_args = mock_engine.connect.return_value.execution_options.return_value.execute.call_args 91 | assert call_args[0][0].text == "SELECT * FROM my_table" 92 | assert call_args[0][1] == {} 93 | 94 | def test_execute_params(self, mock_engine): 95 | self._select_filtered(3) 96 | mock_engine.connect.return_value.execution_options.return_value.execute.assert_called() 97 | 98 | call_args = mock_engine.connect.return_value.execution_options.return_value.execute.call_args 99 | assert call_args[0][0].text == "SELECT * FROM my_table WHERE id=:id" 100 | assert call_args[0][1] == {"id": 3} 101 | 102 | def test_list_results_map(self, mock_results, mock_engine): 103 | mock_engine.connect.return_value.execution_options.return_value.execute.return_value = [ 104 | mock_results[2] 105 | ] 106 | results = self._select_filtered(3) 107 | 108 | assert len(results) == 1 109 | assert len(list(results)[0].hobbies) == 3 110 | 111 | def test_isolation_default(self, mock_engine): 112 | mock_connect = mock_engine.connect.return_value.execution_options 113 | self._select_all() 114 | mock_connect.assert_called_with(isolation_level="READ_COMMITTED") 115 | 116 | def test_isolation_default_read_uncommited(self, mock_engine): 117 | mock_connect = mock_engine.connect.return_value.execution_options 118 | self._select_uncommitted() 119 | mock_connect.assert_called_with(isolation_level="READ_UNCOMMITTED") 120 | mock_connect.return_value.execute.assert_called() 121 | 122 | @staticmethod 123 | @sqlquery() 124 | def _select_all(): 125 | """ 126 | method is static just to help show you can use static methods if 127 | you don't want to instantiate an instance of 
the class 128 | """ 129 | return QueryData("SELECT * FROM my_table") 130 | 131 | @staticmethod 132 | @sqlquery(mapper=SingleRowMapper()) 133 | def _select_single(): 134 | return QueryData("SELECT * FROM my_table") 135 | 136 | @staticmethod 137 | @sqlquery(mapper=CountMapper) 138 | def _select_count(): 139 | return QueryData("SELECT COUNT(*) FROM my_table") 140 | 141 | @staticmethod 142 | @sqlquery() 143 | def _select_filtered(_id): 144 | return QueryData( 145 | "SELECT * FROM my_table WHERE id=:id", query_params={"id": _id} 146 | ) 147 | 148 | @staticmethod 149 | @sqlquery(isolation_level="READ_UNCOMMITTED") 150 | def _select_uncommitted(): 151 | return QueryData("SELECT * FROM uncommitted") 152 | 153 | 154 | class TestSqlUpdateDecorator: 155 | """ 156 | Test Sql Update Decorator 157 | """ 158 | 159 | @staticmethod 160 | @pytest.fixture 161 | def mock_engine(mock_create_engine): 162 | return setup_mock_engine(mock_create_engine) 163 | 164 | @staticmethod 165 | @pytest.fixture 166 | def mock_connect(mock_engine): 167 | return mock_engine.connect.return_value.execution_options 168 | 169 | def test_isolation_default(self, mock_connect): 170 | self._update_something({"id": 1, "value": "test"}) 171 | mock_connect.assert_called_with(isolation_level="READ_COMMITTED") 172 | 173 | def test_isolation_default_read_uncommited(self, mock_connect): 174 | self._update_something_uncommited_isolation({"id": 1, "value": "test"}) 175 | mock_connect.assert_called_with(isolation_level="READ_UNCOMMITTED") 176 | mock_connect.return_value.begin.assert_called() 177 | mock_connect.return_value.begin.return_value.commit.assert_called() 178 | mock_connect.return_value.__exit__.assert_called() 179 | 180 | def test_transaction(self, mock_connect): 181 | self._update_something({"id": 1, "value": "test"}) 182 | mock_connect().begin.assert_called() 183 | mock_connect().begin.return_value.commit.assert_called() 184 | mock_connect().__exit__.assert_called() 185 | 186 | def 
test_transaction_fails(self, mock_connect): 187 | mock_connect().execute.side_effect = Exception("error") 188 | with pytest.raises(Exception): 189 | self._update_something({"id": 1, "value": "test"}) 190 | mock_connect().begin.return_value.commit.assert_not_called() 191 | mock_connect().begin.return_value.rollback.assert_called() 192 | mock_connect().__exit__.assert_called() 193 | 194 | def test_execute_passes_values(self, mock_engine): 195 | values = {"id": 1, "value": "test"} 196 | self._update_something(values) 197 | 198 | execute_call = ( 199 | mock_engine.connect.return_value.execution_options.return_value.execute 200 | ) 201 | execute_call.assert_called() 202 | 203 | execute_call_args = execute_call.call_args[0] 204 | assert execute_call_args[0].text == "INSERT something" 205 | assert values == execute_call_args[1] 206 | 207 | def test_execute_query_values_none(self, mock_engine): 208 | self._update_something(None) 209 | 210 | execute_call = ( 211 | mock_engine.connect.return_value.execution_options.return_value.execute 212 | ) 213 | execute_call.assert_called() 214 | 215 | self._expect_args_list(execute_call.call_args_list[0], "INSERT something") 216 | 217 | def test_execute_query_values_not_given(self, mock_engine): 218 | # pylint: disable=unused-argument 219 | self._update_something_no_params() 220 | 221 | def test_execute_multi_yield(self, mock_connect): 222 | expected_values = [ 223 | {"id": 1, "value": "test 1"}, 224 | {"id": 2, "value": "test 2"}, 225 | {"id": 3, "value": "test 3"}, 226 | ] 227 | self._update_something_multi_yield(expected_values) 228 | 229 | execute_call = mock_connect().execute 230 | assert execute_call.call_count == 3 231 | 232 | for call_args in execute_call.call_args_list: 233 | self._expect_args_list(call_args, "INSERT something") 234 | 235 | def test_execute_fails_list_if_multi_false(self): 236 | expected_values = [ 237 | {"id": 1, "value": "test 1"}, 238 | {"id": 2, "value": "test 2"}, 239 | {"id": 3, "value": "test 3"}, 240 | 
] 241 | with pytest.raises(Exception): 242 | self._update_list_when_multi_false(expected_values) 243 | 244 | def test_execute_multi_yield_and_lists(self, mock_engine): 245 | expected_values = [ 246 | {"id": 1, "value": "test 1"}, 247 | {"id": 2, "value": "test 2"}, 248 | {"id": 3, "value": "test 3"}, 249 | ] 250 | other_expected_values = [ 251 | {"id": 5, "value": "test 5"}, 252 | {"id": 6, "value": "test 6"}, 253 | {"id": 7, "value": "test 7"}, 254 | ] 255 | self._update_yield_with_lists(expected_values, other_expected_values) 256 | 257 | execute_call = ( 258 | mock_engine.connect.return_value.execution_options.return_value.execute 259 | ) 260 | assert execute_call.call_count == 2 261 | 262 | self._expect_args_list( 263 | execute_call.call_args_list[0], 264 | "INSERT some VALUES ( :values__in_0_0, :values__in_0_1 ), ( :values__in_1_0, :values__in_1_1 ), " 265 | "( :values__in_2_0, :values__in_2_1 ) ", 266 | ) 267 | self._expect_args_list( 268 | execute_call.call_args_list[1], 269 | "INSERT some more VALUES ( :values__other_0_0, :values__other_0_1 ), " 270 | "( :values__other_1_0, :values__other_1_1 ), ( :values__other_2_0, :values__other_2_1 ) ", 271 | ) 272 | 273 | def test_execute_multi_yield_and_lists_some_no_params(self, mock_engine): 274 | expected_values = [ 275 | {"id": 1, "value": "test 1"}, 276 | {"id": 2, "value": "test 2"}, 277 | {"id": 3, "value": "test 3"}, 278 | ] 279 | self._update_yield_with_lists_some_no_params(expected_values) 280 | 281 | execute_call = ( 282 | mock_engine.connect.return_value.execution_options.return_value.execute 283 | ) 284 | assert execute_call.call_count == 4 285 | 286 | self._expect_args_list( 287 | execute_call.call_args_list[0], 288 | "INSERT some VALUES ( :values__in_0_0, :values__in_0_1 ), ( :values__in_1_0, :values__in_1_1 ), " 289 | "( :values__in_2_0, :values__in_2_1 ) ", 290 | ) 291 | self._expect_args_list(execute_call.call_args_list[1], "UPDATE some more") 292 | 
self._expect_args_list(execute_call.call_args_list[2], "UPDATE some more") 293 | self._expect_args_list(execute_call.call_args_list[3], "DELETE some more") 294 | 295 | def test_set_foreign_key_checks_default(self, mock_engine): 296 | self._update_something_no_params() 297 | 298 | execute_call = ( 299 | mock_engine.connect.return_value.execution_options.return_value.execute 300 | ) 301 | assert execute_call.call_count == 1 302 | 303 | def test_set_foreign_key_checks_true(self, mock_engine): 304 | self._update_without_foreign_key_checks() 305 | 306 | execute_call = ( 307 | mock_engine.connect.return_value.execution_options.return_value.execute 308 | ) 309 | assert execute_call.call_count == 4 310 | assert execute_call.call_args_list[0][0][0].text == "SET FOREIGN_KEY_CHECKS=0" 311 | assert execute_call.call_args_list[-1][0][0].text == "SET FOREIGN_KEY_CHECKS=1" 312 | 313 | @staticmethod 314 | def _expect_args_list(call_args, expected_query): 315 | assert expected_query == call_args[0][0].text 316 | 317 | @staticmethod 318 | @sqlupdate(disable_foreign_key_checks=True) 319 | def _update_without_foreign_key_checks(): 320 | yield QueryData("INSERT B_RELIES_ON_A") 321 | yield QueryData("INSERT A") 322 | 323 | @staticmethod 324 | @sqlupdate() 325 | def _update_something(values): 326 | return QueryData("INSERT something", query_params=values) 327 | 328 | @staticmethod 329 | @sqlupdate(isolation_level="READ_UNCOMMITTED") 330 | def _update_something_uncommited_isolation(values): 331 | return QueryData(f"INSERT WITH uncommitted {values}") 332 | 333 | @staticmethod 334 | @sqlupdate() 335 | def _update_something_no_params(): 336 | return QueryData("INSERT something") 337 | 338 | @staticmethod 339 | @sqlupdate() 340 | def _update_something_multi_yield(multiple_values): 341 | for values in multiple_values: 342 | yield QueryData("INSERT something", query_params=values) 343 | 344 | @staticmethod 345 | @sqlupdate() 346 | def _update_something_return_list(multiple_values): 347 | return 
QueryData("INSERT something", query_params=multiple_values) 348 | 349 | @staticmethod 350 | @sqlupdate() 351 | def _update_list_when_multi_false(multiple_values): 352 | return "INSERT something", multiple_values 353 | 354 | @staticmethod 355 | @sqlupdate() 356 | def _update_yield_with_lists(multiple_values, other_values): 357 | """ 358 | method is static just to help show you can use static methods if 359 | you don't want to instantiate an instance of the class 360 | 361 | NOTE: 362 | """ 363 | yield QueryData( 364 | "INSERT some {values__in}", 365 | template_params=_get_template_params("values__in", multiple_values), 366 | ) 367 | yield QueryData( 368 | "INSERT some more {values__other}", 369 | template_params=_get_template_params("values__other", other_values), 370 | ) 371 | 372 | @staticmethod 373 | @sqlupdate() 374 | def _update_yield_with_lists_some_no_params(multiple_values): 375 | """ 376 | method is static just to help show you can use static methods if 377 | you don't want to instantiate an instance of the class 378 | """ 379 | yield QueryData( 380 | "INSERT some {values__in}", 381 | template_params=_get_template_params("values__in", multiple_values), 382 | ) 383 | yield QueryData("UPDATE some more") 384 | yield QueryData("UPDATE some more") 385 | yield QueryData("DELETE some more") 386 | 387 | 388 | def _get_template_params(key, values): 389 | """ 390 | template parameters is handled here as an example of what you might end up 391 | doing. 
usually data coming in is going to be a list 392 | 393 | :param key: the template key we want to use to build our template with 394 | :param values: the list of objects or mariadbmaps, 395 | :return: a keyed list of tuples or mariadbmaps 396 | """ 397 | if isinstance(values[0], DbMapResult): 398 | return {key: values} 399 | return {key: [(v["id"], v["value"]) for v in values]} 400 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ###################### 2 | Dynamic SQL (dy-sql) 3 | ###################### 4 | 5 | This project consists of a set of python decorators that eases integration with SQL databases. 6 | These decorators may trigger queries, inserts, updates, and deletes. 7 | 8 | The decorators are a way to help us map our data in python to SQL queries and vice versa. 9 | When we select, insert, update, or delete the queries, we pass the data we want 10 | to insert along with a well defined query. 11 | 12 | This is designed to be done with minimal setup and coding. You need to specify 13 | the database connection parameters and annotate any SQL queries/updates you have with the 14 | decorator that fits your needs. 15 | 16 | Installation 17 | ============ 18 | 19 | .. code-block:: 20 | 21 | pip install dy-sql 22 | 23 | Development Setup 24 | ================= 25 | 26 | To set up the development environment: 27 | 28 | .. 
code-block:: 29 | 30 | uv sync 31 | 32 | Component Breakdown 33 | =================== 34 | * **set_default_connection_parameters** - this function needs to be used to set the database parameters on 35 | initialization so that when a decorator function is called, it can setup a connection pool to a correct database 36 | * **is_set_current_database_supported** - this function may be used to determine if the ``*_current_database`` methods 37 | may be used or not 38 | * **set_current_database** - this function may be used to set the database name for the current async context 39 | (not thread), this is especially useful for multitenant applications 40 | * **reset_current_database** - helper method to reset the current database after ``set_current_database`` has 41 | been used in an async context 42 | * **set_database_init_hook** - sets a method to call whenever a new database is initialized 43 | * **QueryData** - a class that may be returned or yielded from ``sql*`` decorated methods which 44 | contains query information 45 | * **DbMapResult** - base class that can be used when selecting data that helps to map the results of a 46 | query to an object in python 47 | * **DbMapResultModel** - pydantic version of ``DbMapResult`` that allows easy mapping to pydantic models 48 | * **@sqlquery** - decorator for select queries that can return a SQL result in a ``DbMapResult`` 49 | * **@sqlupdate** - decorator for any queries that can change data in the database, this can take a set of 50 | values and yield multiple operations back for insertions or updates inside of a transaction 51 | * **@sqlexists** - decorator for a simplified select query that will return true if a record exists and false otherwise 52 | * **XDbTestManager** - test manager classes that may be used for testing purposes 53 | 54 | Database Preparation 55 | ==================== 56 | In order to initialize a connection pool for the ``sql*`` decorators, the database needs to first be set up 57 | using the 
``set_default_connection_parameters`` method. 58 | 59 | .. code-block:: python 60 | 61 | from dysql import set_default_connection_parameters 62 | 63 | def set_database_from_config(): 64 | maria_db_config = {...} 65 | set_default_connection_parameters( 66 | maria_db_host, 67 | maria_db_user, 68 | maria_db_password, 69 | maria_db_database, 70 | port=maria_db_port, 71 | charset=maria_db_charset 72 | ) 73 | 74 | Note: the keyword arguments are not required and have standard default values, 75 | the port for example defaults to 3306 76 | 77 | Database Init Hook 78 | ================== 79 | At times, it is necessary to perform post-initialization tasks on the database engine after it has been created. 80 | The ``set_database_init_hook`` method may be used in this case. As an example, to instrument the engine using 81 | ``opentelemetry-instrumentation-sqlalchemy``, the following code may be used: 82 | 83 | .. code-block:: python 84 | 85 | from typing import Optional 86 | # Used for type-hints only 87 | import sqlalchemy 88 | from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor 89 | from dysql import set_database_init_hook 90 | 91 | def _instrument_engine(database_name: Optional[str], engine: sqlalchemy.engine.Engine) -> None: 92 | # The database name is unused in this case 93 | _ = database_name 94 | SQLAlchemyInstrumentor().instrument(engine=engine) 95 | 96 | set_database_init_hook(_instrument_engine) 97 | 98 | 99 | Multitenancy 100 | ============ 101 | In some applications, it may be useful to set a database other than the default database in order to support 102 | database-per-tenant configurations. This may be done using various provided methods. 103 | 104 | ..
code-block:: python 105 | 106 | from dysql import ( 107 | set_default_connection_parameters, 108 | reset_current_database, 109 | set_current_database, 110 | use_database_tenant, 111 | tenant_database_manager, 112 | sqlquery, 113 | QueryData, 114 | ) 115 | 116 | def init(): 117 | # Initialize all databases up-front using an arbitrary database key to refer to them later 118 | set_default_connection_parameters( 119 | ... 120 | database_key='db1', 121 | ) 122 | set_default_connection_parameters( 123 | ... 124 | database_key='db2', 125 | ) 126 | 127 | def tenant_query_with_manual_set_reset(): 128 | set_current_database('db2') 129 | try: 130 | query_database() 131 | finally: 132 | reset_current_database() 133 | 134 | def tenant_query_with_context_manager(): 135 | with tenant_database_manager("db2"): 136 | return query_database() 137 | 138 | @use_database_tenant("db2") 139 | @sqlquery() 140 | def tenant_query_with_decorator(): 141 | return QueryData("SELECT * FROM users") 142 | 143 | Decorators 144 | ========== 145 | Decorators are an easy way for us to tell a function to be a 'query' and return 146 | a result without having to have a big chunk of boiler plate code. Once the 147 | database has been prepared, calling a ``sql*`` decorated function will initialize 148 | the database, parse the value returned in your function, make a corresponding 149 | parameterized query and return the results. 150 | 151 | The basic structure is to decorate a method that returns information about the query. 
152 | There are multiple options for returning a query, below is a summary of some of the possibilities: 153 | 154 | * return a ``QueryData`` object that possibly contains ``query_params`` and/or ``template_params`` 155 | * (not available for all ``sql*`` decorators) yield one or more ``QueryData`` objects, 156 | each containing ``query_params`` and/or ``template_params`` 157 | 158 | DbMapResult 159 | ~~~~~~~~~~~ 160 | This class is used in the default mapper (see below) for any ``sqlquery`` decorated method. This class may also be 161 | overridden as shown below. The default class wraps and returns the results of a query for easy access to the data 162 | from the query. For example, if you use the query ``SELECT id, name FROM table``, it would return a list of 163 | ``DbMapResult`` objects where each contains the ``id`` and ``name`` fields. You could then easily loop through 164 | and access the properties as shown in the following example: 165 | 166 | .. code-block:: python 167 | 168 | @sqlquery() 169 | def get_items_from_sql_query(): 170 | return QueryData("SELECT id, name FROM table") 171 | 172 | def get_and_process_items(): 173 | for item in get_items_from_sql_query(): 174 | # we are able to access properties on the object 175 | print('{name} goes with {id}'.format(item.name, item.id)) 176 | 177 | We can inherit from ``DbMapResult`` and override the way our data maps into the 178 | object. This is primarily helpful in cases where we end up with multiple rows 179 | such as a query for a 1-to-many relationship. 180 | 181 | .. 
code-block:: python 182 | 183 | class ExampleMap(DbMapResult): 184 | def map_result(self, result): 185 | # we know we are mapping multiple rows to a single result 186 | if self.id is None: 187 | # in our case we know the id is the same so we only set it the first time 188 | self.id = result['id'] 189 | # initialize our array 190 | self.item_names = [] 191 | 192 | # we know that every result for a given id has a unique item_name 193 | self.item_names.append(result['item_name']) 194 | 195 | @sqlquery(mapping=ExampleMap) 196 | def get_table_items() 197 | return QueryData(""" 198 | SELECT id, name, item_name FROM table 199 | JOIN table_item ON table.id = table_item.table_id 200 | JOIN item ON item.id = table_item.item_id 201 | """) 202 | 203 | def print_item_names() 204 | for table_item in get_table_items(): 205 | for item_name in table_item.item_names: 206 | print(f'table name {table_item.name} has item {item_name}') 207 | 208 | DbMapResultModel (pydantic) 209 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 210 | 211 | If pydantic models are desired to be used, there is a record mapper available. Note that pydantic must be installed, 212 | which is available as an extra package: 213 | 214 | .. code-block:: 215 | 216 | pip install dy-sql[pydantic] 217 | 218 | This model attempts to make mapping records easier, but there are shortcomings of it in more complex cases. 219 | Most fields will "just work" as defined by the type annotations. 220 | 221 | .. code-block:: python 222 | 223 | from dysql.pydantic_mappers import DbMapResultModel 224 | 225 | class PydanticDbModel(DbMapResultModel): 226 | id: int 227 | field_str: str 228 | field_int: int 229 | field_bool: bool 230 | 231 | Mapping a record onto this class will automatically convert types as defined by the type annotations. No ``map_record`` 232 | method needs to be defined since the pydantic model has everything necessary to map database fields. 
233 | 234 | Lists, sets, dicts, csv strings, and json strings (when using the RecordCombiningMapper) require additional configuration on the model class. 235 | 236 | .. code-block:: python 237 | 238 | from dysql.pydantic_mappers import DbMapResultModel 239 | 240 | class ComplexDbModel(DbMapResultModel): 241 | # if any data has been aggregated or saved into a string as a comma delimited list, this will convert to a list 242 | # NOTE this only does simple splitting and is not fully rfc4180 compatible 243 | _csv_list_fields: Set[str] = {'list_from_string'} 244 | # List fields (type does not matter) 245 | _list_fields: Set[str] = {'list1'} 246 | # Set fields (type does not matter) 247 | _set_fields: Set[str] = {'set1'} 248 | # Dictionary key fields as DB field name => model field name 249 | _dict_key_fields: Dict[str, str] = {'key1': 'dict1', 'key2': 'dict2'} 250 | # Dictionary value fields as model field name => DB field name (this is reversed from _dict_key_fields!) 251 | _dict_value_mappings: Dict[str, str] = {'dict1': 'val1', 'dict2': 'val2'} 252 | # JSON string fields. Type can be any dictionary type but for larger json objects its safe to stay with `dict` 253 | _json_fields: Set[str] = {'json1', 'json2'} 254 | 255 | id: int = None 256 | list_from_string: List[str] 257 | list1: List[str] 258 | set1: Set[str] = set() 259 | dict1: Dict[str, Any] = {} 260 | dict2: Dict[str, int] = {} 261 | json1: dict 262 | json2: dict 263 | 264 | .. note:: 265 | 266 | csv strings can be useful in queries where you want to group by an id and then ``group_concat`` some field 267 | 268 | json strings are a handy way to extract json blobs into a python dictionary for ease of use without manually processing 269 | each field everytime you need something. 270 | 271 | In this case, the ``_`` prefixed properties tell the model which fields should be treated differently when combining 272 | multiple rows into a single object. 
For an example of how this works with database rows, see the 273 | ``test_pydantic_mappers.py`` file in the source repository. 274 | 275 | Note that validation **does** occur the very first time ``map_record`` is called, but not on subsequent runs. Therefore 276 | if you desire better validation for list, set, or dict fields, this must most likely be done outside of dysql/pydantic. 277 | Additionally, lists, sets, and dicts will ignore null values from the database. Therefore you must provide default 278 | values for these fields when used or else validation will fail. 279 | 280 | Added annotations when using DbMapResultModel 281 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 282 | When using the ``DbMapResultModel`` mapper, there are some additional annotations that may be used to help with 283 | mapping. These annotations are not required, but may be helpful in some cases. 284 | 285 | * FromCSVToList - this annotation will convert a comma separated string into a list. This is useful when you have 286 | a column containing a csv or a query that uses ``group_concat`` to combine multiple rows into a single row. This 287 | annotation may be used on any field that is a list. For example: 288 | 289 | .. 
code-block:: python 290 | 291 | from dysql.pydantic_mappers import DbMapResultModel, FromCSVToList 292 | 293 | class CsvModel(DbMapResultModel): 294 | id: int 295 | name: str 296 | # This annotation will convert the string into a list of ints 297 | list_from_string_int: FromCSVToList[List[int]] 298 | # This annotation will convert the string into a list of strings 299 | list_from_string: FromCSVToList[List[str]] 300 | # This annotation will convert the string into a list of ints or None if the string is null or empty 301 | list_from_string_int_nullable: FromCSVToList[List[int] | None] 302 | # This annotation will convert the string into a list of strings or None if the string is null or empty 303 | list_from_string_nullable: FromCSVToList[List[str] | None] 304 | 305 | # if using python <= 3.9, you can use typing.Union instead of the pipe operator 306 | # list_from_string_nullable: FromCSVToList[Union[List[str],None]] 307 | 308 | 309 | @sqlquery 310 | ~~~~~~~~~ 311 | This is for making SQL ``select`` calls. An optional mapper may be specified to 312 | change the behavior of what is returned from a decorated method. The default 313 | mapper can combine multiple records into a single result if there is an 314 | ``id`` field present in each record. Mappers available: 315 | 316 | * ``RecordCombiningMapper`` (default) - Returns a list of results where multiple records that can be combined with the 317 | same unique identifer. An optional ``record_mapper`` value may be passed to the constructor to change 318 | how records are mapped to result. By default the ``record_mapper`` used is ``DbMapResult``. The base identifier 319 | is the column ``id`` but an array of columns can be used to create a unique key lookup for combining records. 320 | 321 | .. note:: 322 | The ``_key_columns`` field of the ``DbMapResultModel`` is an array containing only the ``id`` but can 323 | be overriden in derived classes. 
For example, setting ``_key_columns = [ 'a', 'b' ]`` in your derived class 324 | would make it so you class would use the values of columns `a` and `b` in order to uniquely identify 325 | records when being combined. 326 | 327 | * ``SingleRowMapper`` - returns an object for the first record from the database (even if multiple records are 328 | returned). An optional ``record_mapper`` value may be passed to the constructor to change how this first record is 329 | mapped to the result. 330 | * ``SingleColumnMapper`` - Returns a list of scalars with the first column from every record, even if multiple columns 331 | are returned from the database. 332 | * ``SingleRowAndColumnMapper`` - Returns a single scalar value even if multiple records and columns are returned 333 | from the database. 334 | * ``CountMapper`` - alias for ``SingleRowAndColumnMapper`` to make it clear that it may be used for ``count`` queries. 335 | * ``KeyValueMapper`` - returns a dictionary mapping 1 column to the keys and 1 column to the values. 336 | By default the key is mapped to the first column and value is mapped to the second column. You can override the key_column 337 | and value_columns by specifying the name of the columns you want for each. You can also pass in a has_multiple_values 338 | which defaults to False. Doing so will allow you to get a dictionary of lists based on the keys and values you specify. 339 | * Custom mappers may be made by extending the ``BaseMapper`` class and implementing the ``map_records`` method. 340 | 341 | basic query with conditions hardcoded into query and default mapper 342 | 343 | .. code-block:: python 344 | 345 | def get_items(): 346 | items = select_items_for_joe() 347 | # ... work on items 348 | 349 | @sqlquery() 350 | def select_items_for_joe() 351 | return QueryData("SELECT * FROM table WHERE name='joe'") 352 | 353 | basic query with params passed as a dict 354 | 355 | .. 
code-block:: python 356 | 357 | def get_items(): 358 | items = select_items_for_name('joe') 359 | # ... work on items, which contains all records matching the name 360 | 361 | @sqlquery() 362 | def select_items_for_name(name) 363 | return QueryData("SELECT * FROM table WHERE name=:name", query_params={'name': name}) 364 | 365 | query that only returns a single result from the first row 366 | 367 | .. code-block:: python 368 | 369 | def get_joe_id(): 370 | result = get_item_for_name('joe') 371 | return result.get('id') 372 | 373 | # Either an instance or class may be used as the mapper parameter 374 | @sqlquery(mapper=SingleRowMapper()) 375 | def get_item_for_name(name) 376 | return QueryData("SELECT id, name FROM table WHERE name=:name", query_params={'name': name}) 377 | 378 | alternative to the above query that returns the id directly 379 | 380 | .. code-block:: python 381 | 382 | def get_joe_id(): 383 | return get_id_for_name('joe') 384 | 385 | @sqlquery(mapper=SingleRowAndColumnMapper) 386 | def get_id_for_name(name) 387 | return QueryData("SELECT id FROM table WHERE name=:name", query_params={'name': name}) 388 | 389 | query that returns a list of scalar values containing the list of distinct names available 390 | 391 | .. code-block:: python 392 | 393 | def get_unique_names(): 394 | return get_names_from_items() 395 | 396 | @sqlquery(mapper=SingleColumnMapper) 397 | def get_names_from_items() 398 | return QueryData("SELECT DISTINCT(name) FROM table") 399 | 400 | basic count query that only returns the scalar value returned for the count 401 | 402 | .. code-block:: python 403 | 404 | def get_count_for_joe(): 405 | return get_count_for_name('joe') 406 | 407 | @sqlquery(mapper=CountMapper) 408 | def get_count_for_name(name): 409 | return QueryData("SELECT COUNT(*) FROM table WHERE name=:name", query_params={'name': name}) 410 | 411 | 412 | basic query returning dictionary 413 | 414 | .. 
code-block:: python 415 | 416 | @sqlquery(mapper=KeyValueMapper()) 417 | def get_status_by_name(): 418 | return QueryData("SELECT name, status FROM table") 419 | 420 | query returning a dictionary where we are specifying the keys. Note that the columns are returning in a different order 421 | 422 | .. code-block:: python 423 | 424 | @sqlquery(mapper=KeyValueMapper(key_column='name', value_column='status')) 425 | def get_status_by_name(): 426 | return QueryData("SELECT status, name FROM table") 427 | 428 | query returning a dictionary where there are multiple results under each key. Note that here we are essentially grouping under status 429 | 430 | .. code-block:: python 431 | 432 | @sqlquery(mapper=KeyValueMapper(key_column='status', value_column='name', has_multiple_values=True)) 433 | def get_status_by_name(): 434 | return QueryData("SELECT status, name FROM table") 435 | 436 | ========== 437 | @sqlupdate 438 | ========== 439 | Handles any SQL that is not a select. This is primarily, but not limited to, ``insert``, ``update``, and ``delete``. 440 | 441 | 442 | .. code-block:: python 443 | 444 | @sqlupdate() 445 | def insert_items(item_dict): 446 | return QueryData("INSERT INTO", template_params={'in__item_id':item_id_list}) 447 | 448 | 449 | --------------------------------- 450 | multiple queries in a transaction 451 | --------------------------------- 452 | You can yield multiple QueryData objects. This is done in a transaction and it can be helpful for data integrity or just 453 | a nice clean way to run a set of updates. 454 | 455 | .. 
code-block:: python 456 | 457 | @sqlupdate() 458 | def insert_items(item_dict): 459 | insert_values_1, insert_params_1 = TemplateGenerators.values('table1values', _get_values_for_1_from_items(item_dict)) 460 | insert_values_2, insert_params_2 = TemplateGenerators.values('table2values', _get_values_for_2_from_items(item_dict)) 461 | yield QueryData(f'INSERT INTO table_1 {insert_values_1}', query_params=insert_params_1) 462 | yield QueryData(f'INSERT INTO table_2 {insert_values_2}', query_params=insert_params_2) 463 | 464 | -------------------------- 465 | getting the last insert id 466 | -------------------------- 467 | You can assign a callback to be run after a query or set of queries completes successfully. This is useful when you need 468 | to get the last insert id for a table that has an auto incrementing id field. This allows you to set it as a parameter on 469 | a follow up relational table within the same transaction scope. 470 | 471 | .. code-block:: python 472 | 473 | @sqlupdate() 474 | def insert_items_with_callback(item_dict): 475 | insert_values_1, insert_params_1 = TemplateGenerators.values('table1values', _get_values_for_1_from_items(item_dict)) 476 | insert_values_2, insert_params_2 = TemplateGenerators.values('table2values', _get_values_for_2_from_items(item_dict)) 477 | yield QueryData(f'INSERT INTO table_1 {insert_values_1}', query_params=insert_params_1) 478 | yield QueryData(f'INSERT INTO table_2 {insert_values_2}', query_params=insert_params_2) 479 | 480 | def _handle_insert_success(item_dict): 481 | # callback logic here happens after the transaction is complete 482 | 483 | `get_last_insert_id` is a placeholder kwarg that will be automatically overwritten by the sqlupdate decorator at run time. 484 | Therefore, the assigned value in the function definition does not matter. 485 | 486 | 487 | Using `get_last_insert_id` gives you the most recently set id.
You can leverage this for later queries yielded, or you could 488 | use it and set ids in a reference object passed in for access to the ids outside of the sqlupdate function. 489 | 490 | 491 | .. code-block:: python 492 | 493 | @sqlupdate() 494 | def insert_item_with_get_last_insert(item_dict, get_last_insert_id=None): 495 | insert_values, insert_params = TemplateGenerators.values('table1values', _get_values_from_items(item_dict)) 496 | yield QueryData(f'INSERT INTO table_1 {insert_values}', query_params=insert_params) 497 | last_id = get_last_insert_id() 498 | yield QueryData('INSERT INTO related_table_1 (table_1_id, value) VALUES (:table_1_id, :value)', 499 | query_params={'table_1_id': last_id, 'value': 'some_value'}) 500 | 501 | .. note:: 502 | `get_last_insert_id` will get you the last inserted id from the most recently inserted table with an autoincrement. 503 | Be sure to call `get_last_insert_id` right after you yield the query that inserts the record you need the id for. 504 | 505 | 506 | .. code-block:: python 507 | 508 | class Item(BaseModel): 509 | id: int | None = None 510 | name: str 511 | 512 | @sqlupdate() 513 | def insert_items_and_update_ids(items: List[Item], get_last_insert_id=None): 514 | for item in items: 515 | yield QueryData("INSERT INTO table (name) VALUES (:name)", query_params={'name': item.name}) 516 | last_id = get_last_insert_id() 517 | item.id = last_id 518 | 519 | @sqlexists 520 | ~~~~~~~~~~ 521 | This wraps a SQL query to determine if a row exists or not. If at least one row is returned from the query, it will 522 | return True, otherwise False. The query you give here can return anything you want but as good practice, 523 | try to always select as little as possible. For example, below we are just returning 1 because the value itself 524 | isn't used, we just need to know there are records available. 525 | 526 | ..
code-block:: python 527 | 528 | @sqlexists() 529 | def item_exists(item_id): 530 | return QueryData("SELECT 1 FROM table WHERE id=:id", query_params={'id': item_id}) 531 | 532 | Ultimately, the above query becomes ``SELECT EXISTS (SELECT 1 FROM table WHERE id=:id)``. 533 | You'll notice the inner select value isn't actually used in the return. 534 | 535 | Decorator templates 536 | =================== 537 | 538 | Templates and generators for these templates are also provided to simplify SQL query strings. 539 | 540 | 541 | **in** template - this template will allow you to pass a list as a single parameter and have the `IN` 542 | condition built out for you. This allows you to more dynamically include values in your queries. 543 | 544 | .. code-block:: python 545 | 546 | @sqlquery() 547 | def select_items(item_id_list): 548 | return QueryData("SELECT * FROM table WHERE {in__item_id}", 549 | template_params={'in__item_id': item_id_list}) 550 | 551 | 552 | you can also use the TemplateGenerators.in_column method to get back a tuple of query and params 553 | 554 | .. code-block:: python 555 | 556 | @sqlquery() 557 | def select_items(item_id_list): 558 | in_query, in_params = TemplateGenerators.in_column('key', item_id_list) 559 | # NOTE: the query string is using an f-string and passing into query_params instead of template_params 560 | return QueryData(f"SELECT * FROM table WHERE {in_query}", query_params=in_params) 561 | 562 | 563 | **in and not in multi column** - this template works the same as the in and not in template but it will allow you to 564 | pass a list of tuples to an in clause allowing you to match against multiple columns. 565 | `NOTE: this is only available through the TemplateGenerators using query_params and not through the template_params method` 566 | 567 | ..
code-block:: python 568 | 569 | @sqlquery() 570 | def select_multi(tuple_list): 571 | in_query, in_params = TemplateGenerators.in_multi_column('(key1, key2)', tuple_list) 572 | return QueryData(f"SELECT * FROM table WHERE {in_query}", query_params=in_params) 573 | 574 | 575 | .. code-block:: python 576 | 577 | @sqlquery() 578 | def select_multi(tuple_list): 579 | in_query, in_params = TemplateGenerators.not_in_multi_column('(key1, key2)', tuple_list) 580 | return QueryData(f"SELECT * FROM table WHERE {in_query}", query_params=in_params) 581 | 582 | 583 | **not_in** template - this template will allow you to pass a list as a single parameter and have the `NOT IN` 584 | condition built out for you. This allows you to more dynamically exclude values in your queries. 585 | 586 | .. code-block:: python 587 | 588 | @sqlquery() 589 | def select_items(item_id_list): 590 | return QueryData("SELECT * FROM table WHERE {not_in__item_id}", 591 | template_params={'not_in__item_id': item_id_list}) 592 | 593 | 594 | 595 | 596 | you can also use the TemplateGenerators.not_in_column method to get back a tuple of query and params 597 | 598 | .. code-block:: python 599 | 600 | @sqlquery() 601 | def select_items(item_id_list): 602 | not_in_query, not_in_params = TemplateGenerators.not_in_column('key', item_id_list) 603 | # NOTE: the query string is using an f-string and passing into query_params instead of template_params 604 | return QueryData(f"SELECT * FROM table WHERE {not_in_query}", query_params=not_in_params) 605 | 606 | 607 | **values** template - when inserting and you have multiple records to insert, this allows you to pass 608 | multiple records for insert in a single INSERT statement. 609 | 610 | ..
code-block:: python 611 | 612 | @sqlquery() 613 | def insert_items(items): 614 | return QueryData("INSERT INTO table(column_a, column_b) {values__items}", 615 | template_params={'values__items': items}) 616 | 617 | You can write queries that combine ``template_params`` and ``query_params`` as well. 618 | 619 | .. code-block:: python 620 | 621 | @sqlquery() 622 | def select_items(item_id_list, name): 623 | return QueryData("SELECT * FROM table WHERE {in__item_id} and name=:name", 624 | template_params={'in__item_id': item_id_list}, 625 | query_params={'name': name}) 626 | 627 | Testing with Managers 628 | ===================== 629 | 630 | During testing, it may be useful to hook up a real database to the tests. However, this can be difficult to maintain 631 | schema and isolate databases during testing. Database test managers exist for this reason. Usage is very simple with 632 | pytest. 633 | 634 | .. code-block:: python 635 | 636 | @pytest.fixture(scope='module', autouse=True) 637 | def setup_db(self): 638 | # Pass in the database name and any optional params 639 | with MariaDbTestManager(f'testdb_{self.__class__.__name__.lower()}'): 640 | yield 641 | 642 | The Maria database test manager is shown used above, but future implementations may be added for other SQL backends. 643 | --------------------------------------------------------------------------------