├── .github └── workflows │ ├── ci.yml │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.rst ├── pyproject.toml ├── setup.cfg ├── src └── sql_tstring │ ├── __init__.py │ ├── parser.py │ ├── py.typed │ └── t.py ├── tests ├── test_dialect.py ├── test_identifiers.py ├── test_parameters.py └── test_parsing.py └── tox.ini /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | push: 4 | branches: [ main ] 5 | pull_request: 6 | branches: [ main ] 7 | workflow_dispatch: 8 | 9 | jobs: 10 | tox: 11 | name: ${{ matrix.name }} 12 | runs-on: ubuntu-latest 13 | 14 | container: python:${{ matrix.python }} 15 | 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | include: 20 | - {name: '3.13', python: '3.13', tox: py313} 21 | - {name: '3.12', python: '3.12', tox: py312} 22 | - {name: 'format', python: '3.13', tox: format} 23 | - {name: 'mypy', python: '3.13', tox: mypy} 24 | - {name: 'pep8', python: '3.13', tox: pep8} 25 | - {name: 'package', python: '3.13', tox: package} 26 | 27 | steps: 28 | - uses: actions/checkout@v4 29 | 30 | - name: update pip 31 | run: | 32 | pip install -U wheel 33 | pip install -U setuptools 34 | python -m pip install -U pip 35 | 36 | - run: pip install tox 37 | 38 | - run: tox -e ${{ matrix.tox }} 39 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | on: 3 | push: 4 | tags: 5 | - '*' 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v4 11 | 12 | - uses: actions/setup-python@v4 13 | with: 14 | python-version: 3.13 15 | 16 | - run: | 17 | pip install pdm 18 | pdm build 19 | 20 | - uses: actions/upload-artifact@v4 21 | with: 22 | path: ./dist 23 | 24 | pypi-publish: 25 | needs: ['build'] 26 | environment: 'publish' 27 | 28 | name: upload release to PyPI 29 | runs-on: 
ubuntu-latest 30 | permissions: 31 | # IMPORTANT: this permission is mandatory for trusted publishing 32 | id-token: write 33 | steps: 34 | - uses: actions/download-artifact@v4 35 | 36 | - name: Publish package distributions to PyPI 37 | uses: pypa/gh-action-pypi-publish@release/v1 38 | with: 39 | packages_dir: artifact/ 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm-project.org/#use-with-ide 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | # Emacs backups 165 | *~ 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright P G Jones 2025. 
2 | 3 | Permission is hereby granted, free of charge, to any person 4 | obtaining a copy of this software and associated documentation 5 | files (the "Software"), to deal in the Software without 6 | restriction, including without limitation the rights to use, 7 | copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the 9 | Software is furnished to do so, subject to the following 10 | conditions: 11 | 12 | The above copyright notice and this permission notice shall be 13 | included in all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 17 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 19 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 20 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | SQL-tString 2 | =========== 3 | 4 | |Build Status| |pypi| |python| |license| 5 | 6 | SQL-tString allows for t-string based construction of sql queries 7 | without allowing for SQL injection. The basic usage is as follows, 8 | 9 | .. code-block:: python 10 | 11 | from sql_tstring import sql 12 | 13 | a = 1 14 | 15 | query, values = sql( 16 | t"""SELECT a, b, c 17 | FROM tbl 18 | WHERE a = {a}""", 19 | ) 20 | 21 | The ``query`` is a ``str`` and ``values`` a ``list[Any]``, both are 22 | then typically passed to a DB connection. Note the parameters can only 23 | be identifiers that identify variables (in the above example in the 24 | locals()) e.g. ``{a - 1}`` is not valid. 
25 | 26 | SQL-tString will convert parameters to SQL placeholders where 27 | appropriate. In other locations SQL-tString will allow pre defined 28 | column or table names to be used, 29 | 30 | .. code-block:: python 31 | 32 | from sql_tstring import sql, sql_context 33 | 34 | col = "a" 35 | table = "tbl" 36 | 37 | with sql_context(columns={"a"}, tables={"tbl"}): 38 | query, values = sql( 39 | t"SELECT {col} FROM {table}", 40 | ) 41 | 42 | If the value of ``col`` or ``table`` does not match the valid values 43 | given to the ``sql_context`` function an error will be raised. 44 | 45 | Rewriting values 46 | ---------------- 47 | 48 | SQL-tString will also remove parameters if they are set to the special 49 | value of ``Absent`` (or ``RewritingValue.ABSENT``). This is most 50 | useful for optional updates, or conditionals, 51 | 52 | .. code-block:: python 53 | 54 | from sql_tstring import Absent, sql 55 | 56 | a = Absent 57 | b = Absent 58 | 59 | query, values = sql( 60 | t"""UPDATE tbl 61 | SET a = {a}, 62 | b = 1 63 | WHERE b = {b}""", 64 | ) 65 | 66 | As both ``a`` and ``b`` are ``Absent`` the above ``query`` will be 67 | ``UPDATE tbl SET b = 1``. 68 | 69 | In addition for conditionals the values ``IsNull`` (or 70 | ``RewritingValue.IS_NULL``) and ``IsNotNull`` (or 71 | ``RewritingValue.IS_NOT_NULL``) can be used to rewrite the conditional 72 | as expected. This is useful as ``x = NULL`` never evaluates to true in SQL. 73 | 74 | Paramstyle (dialect) 75 | -------------------- 76 | 77 | By default SQL-tString uses the ``qmark`` `paramstyle 78 | <https://peps.python.org/pep-0249/#paramstyle>`_ (dialect) but also 79 | supports the ``$`` paramstyle or asyncpg dialect. This is best changed 80 | globally via, 81 | 82 | ..
code-block:: python 83 | 84 | from sql_tstring import Context, set_context 85 | 86 | set_context(Context(dialect="asyncpg")) 87 | 88 | Pre Python 3.14 usage 89 | --------------------- 90 | 91 | t-strings were introduced in Python 3.14 via `PEP 750 92 | <https://peps.python.org/pep-0750/>`_, however this library can be 93 | used with Python 3.12 and 3.13 as follows, 94 | 95 | .. code-block:: python 96 | 97 | from sql_tstring import sql 98 | 99 | a = 1 100 | 101 | query, values = sql( 102 | """SELECT a, b, c 103 | FROM tbl 104 | WHERE a = {a}""", 105 | locals(), 106 | ) 107 | 108 | Please note though that only simple variable identifiers can be placed 109 | within the braces. 110 | 111 | .. |Build Status| image:: https://github.com/pgjones/sql-tstring/actions/workflows/ci.yml/badge.svg 112 | :target: https://github.com/pgjones/sql-tstring/commits/main 113 | 114 | .. |pypi| image:: https://img.shields.io/pypi/v/sql-tstring.svg 115 | :target: https://pypi.python.org/pypi/Sql-Tstring/ 116 | 117 | .. |python| image:: https://img.shields.io/pypi/pyversions/sql-tstring.svg 118 | :target: https://pypi.python.org/pypi/Sql-Tstring/ 119 | 120 | ..
|license| image:: https://img.shields.io/badge/license-MIT-blue.svg 121 | :target: https://github.com/pgjones/sql-tstring/blob/main/LICENSE 122 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "sql-tstring" 3 | version = "0.3.0" 4 | description = "SQL builder via string templates" 5 | authors = [ 6 | {name = "pgjones", email = "philip.graham.jones@googlemail.com"}, 7 | ] 8 | classifiers = [ 9 | "Development Status :: 3 - Alpha", 10 | "Intended Audience :: Developers", 11 | "Operating System :: OS Independent", 12 | "Programming Language :: Python", 13 | "Programming Language :: Python :: 3", 14 | "Programming Language :: Python :: 3.12", 15 | "Programming Language :: Python :: 3.13", 16 | "Topic :: Software Development :: Libraries :: Python Modules", 17 | ] 18 | include = ["src/sql_tstring/py.typed"] 19 | license = "MIT" 20 | license-files = ["LICENSE"] 21 | readme = "README.rst" 22 | repository = "https://github.com/pgjones/sql-tstring/" 23 | dependencies = [] 24 | requires-python = ">=3.12" 25 | 26 | [project.optional-dependencies] 27 | docs = ["pydata_sphinx_theme"] 28 | 29 | [tool.black] 30 | line-length = 100 31 | 32 | [tool.isort] 33 | combine_as_imports = true 34 | force_grid_wrap = 0 35 | include_trailing_comma = true 36 | known_first_party = "sql_tstring, tests" 37 | line_length = 100 38 | multi_line_output = 3 39 | no_lines_before = "LOCALFOLDER" 40 | order_by_type = false 41 | reverse_relative = true 42 | 43 | [tool.mypy] 44 | allow_redefinition = true 45 | disallow_any_generics = false 46 | disallow_subclassing_any = true 47 | disallow_untyped_calls = false 48 | disallow_untyped_defs = true 49 | implicit_reexport = true 50 | no_implicit_optional = true 51 | show_error_codes = true 52 | strict = true 53 | strict_equality = true 54 | strict_optional = false 55 | warn_redundant_casts = true 56 | 
@unique
class RewritingValue(Enum):
    """Sentinel parameter values that rewrite the query rather than being bound.

    ``_replace_placeholder`` inspects a placeholder's value for these members
    and edits the parsed tree accordingly before the query is printed.
    """

    # The parameter is missing. In a ``VALUES`` clause the placeholder becomes
    # the text ``DEFAULT``; in a ``FOR UPDATE`` clause the whole clause is
    # marked removed; otherwise the enclosing expression is marked removed.
    ABSENT = auto()
    # In condition clauses (``WHERE``/``HAVING``) the operator in the same
    # expression is rewritten to ``IS`` and the placeholder to ``NULL``.
    IS_NULL = auto()
    # As IS_NULL, but the operator is rewritten to ``IS NOT``.
    IS_NOT_NULL = auto()
def sql(
    query_or_template: str | Template | TTemplate, values: dict[str, typing.Any] | None = None
) -> tuple[str, list]:
    """Render a template into a SQL string and its list of bound values.

    Call either with a template alone, or with a plain query string together
    with the ``values`` mapping used to resolve its ``{name}`` fields.

    Returns:
        ``(query, values)`` where ``query`` is the text of every parsed
        statement concatenated in order and ``values`` the parameters to bind.

    Raises:
        ValueError: if the argument combination is neither of the two
            supported call shapes.
    """
    usage_error = "Must call with a template, or a query string and values"

    template: Template
    if values is None:
        if not isinstance(query_or_template, (Template, TTemplate)):
            raise ValueError(usage_error)
        template = query_or_template
    else:
        if not isinstance(query_or_template, str):
            raise ValueError(usage_error)
        template = t(query_or_template, values)

    statements = parse(template)
    fragments: list[str] = []
    bound: list[typing.Any] = []
    ctx = get_context()
    for statement in statements:
        # Work on a copy: placeholder replacement mutates the tree in place.
        working_copy = deepcopy(statement)
        extracted = _replace_placeholders(working_copy, 0)
        # Seed the printer with one slot per value already bound so numbered
        # placeholders continue counting across statements.
        already_bound = [None] * len(bound)
        fragments.append(_print_node(working_copy, already_bound, ctx.dialect))
        bound.extend(extracted)

    return "".join(fragments), bound
replace(context) 114 | 115 | def __enter__(self) -> Context: 116 | self._original_context = get_context() 117 | set_context(self._context) 118 | return self._context 119 | 120 | def __exit__( 121 | self, 122 | _type: type[BaseException] | None, 123 | _value: BaseException | None, 124 | _traceback: TracebackType | None, 125 | ) -> None: 126 | set_context(self._original_context) 127 | 128 | 129 | def _check_valid( 130 | value: object, 131 | *, 132 | case_sensitive: set[str] | None = None, 133 | case_insensitive: set[str] | None = None, 134 | ) -> None: 135 | if case_sensitive is None: 136 | case_sensitive = set() 137 | if case_insensitive is None: 138 | case_insensitive = set() 139 | if not isinstance(value, str) or ( 140 | value not in case_sensitive and value.lower() not in case_insensitive 141 | ): 142 | raise ValueError( 143 | f"{value} is not valid, must be one of {case_sensitive} or {case_insensitive}" 144 | ) 145 | 146 | 147 | def _print_node( 148 | node: Element, 149 | placeholders: list | None = None, 150 | dialect: str = "sql", 151 | ) -> str: 152 | if placeholders is None: 153 | placeholders = [] 154 | 155 | match node: 156 | case Statement(): 157 | result = " ".join(_print_node(clause, placeholders, dialect) for clause in node.clauses) 158 | case Clause() | ExpressionGroup(): 159 | result = "" 160 | 161 | for expression in node.expressions: 162 | addition = _print_node(expression, placeholders, dialect) 163 | separator = "" 164 | if result != "": 165 | separator = expression.separator 166 | if addition != "": 167 | result += f" {separator} {addition}" 168 | 169 | result = result.strip() 170 | 171 | if isinstance(node, ExpressionGroup): 172 | if result != "": 173 | result = f"({result})" 174 | else: 175 | if node.removed: 176 | result = "" 177 | elif result == "" and not node.properties.allow_empty: 178 | result = "" 179 | else: 180 | result = f"{node.text} {result}" 181 | case Expression(): 182 | if not node.removed: 183 | result = " 
def _replace_placeholders(
    node: Element,
    index: int,
) -> list[typing.Any]:
    """Walk ``node`` depth first, resolving every placeholder beneath it.

    ``index`` is the position of ``node`` within its parent's list; it is only
    consumed when ``node`` itself is a placeholder. Returns the values that
    must be bound as parameters, in traversal order. Nodes of any other kind
    contribute nothing.
    """
    if isinstance(node, Placeholder):
        return list(_replace_placeholder(node, index))

    # The generators below read the live lists, so replacements made by
    # _replace_placeholder while walking are observed exactly as they would be
    # by a plain ``for`` loop over the same list.
    if isinstance(node, Statement):
        children = ((child, 0) for child in node.clauses)
    elif isinstance(node, (Clause, ExpressionGroup)):
        children = ((child, position) for position, child in enumerate(node.expressions))
    elif isinstance(node, (Expression, Function, Group, Literal)):
        children = ((child, position) for position, child in enumerate(node.parts))
    else:
        return []

    collected: list[typing.Any] = []
    for child, position in children:
        collected.extend(_replace_placeholders(child, position))
    return collected
243 | placeholder_type = clause_or_function.properties.placeholder_type 244 | 245 | value = node.value 246 | new_node: Part | Placeholder 247 | if value is RewritingValue.ABSENT: 248 | if placeholder_type == PlaceholderType.VARIABLE_DEFAULT: 249 | new_node = Part(text="DEFAULT", parent=node.parent) 250 | node.parent.parts[index] = new_node 251 | elif placeholder_type == PlaceholderType.LOCK: 252 | if clause is not None: 253 | clause.removed = True 254 | else: 255 | expression = node.parent 256 | while not isinstance(expression, Expression): 257 | expression = expression.parent 258 | 259 | expression.removed = True 260 | elif isinstance(node.parent, Literal): 261 | if not isinstance(value, str): 262 | raise RuntimeError("Invalid placeholder usage") 263 | else: 264 | for position, part in enumerate(node.parent.parent.parts): 265 | if part is node.parent: 266 | value = "" 267 | for bit in part.parts: # type: ignore[union-attr] 268 | if isinstance(bit, Placeholder): 269 | value += str(bit.value) 270 | else: 271 | value += bit.text # type: ignore[union-attr] 272 | node.parent.parent.parts[position] = Placeholder( 273 | parent=part.parent, value=value # type: ignore[arg-type] 274 | ) 275 | result.append(value) 276 | else: 277 | if clause is not None and clause.text.lower() == "order by": 278 | _check_valid( 279 | value, 280 | case_sensitive=ctx.columns, 281 | case_insensitive={"asc", "ascending", "desc", "descending"}, 282 | ) 283 | new_node = Part(text=typing.cast(str, value), parent=node.parent) 284 | elif placeholder_type == PlaceholderType.COLUMN: 285 | _check_valid(value, case_sensitive=ctx.columns) 286 | new_node = Part(text=typing.cast(str, value), parent=node.parent) 287 | elif placeholder_type == PlaceholderType.TABLE: 288 | _check_valid(value, case_sensitive=ctx.tables) 289 | new_node = Part(text=typing.cast(str, value), parent=node.parent) 290 | elif placeholder_type == PlaceholderType.LOCK: 291 | _check_valid(value, case_insensitive={"", "nowait", "skip 
@unique
class PlaceholderType(Enum):
    """How a placeholder appearing inside a given clause is interpreted.

    Each ``ClauseProperties`` entry in ``CLAUSES`` carries one of these, and
    ``sql_tstring._replace_placeholder`` branches on it.
    """

    # Value must be one of the context's ``columns``; it is inlined as text
    # (used by SELECT, GROUP BY and ORDER BY).
    COLUMN = auto()
    # Assigned to clauses such as INSERT INTO, UPDATE, WITH and RETURNING.
    # NOTE(review): no DISALLOWED-specific check is visible in
    # _parse_placeholder or _replace_placeholder - confirm where it is enforced.
    DISALLOWED = auto()
    # FOR UPDATE: value must be "", "nowait" or "skip locked" (compared
    # case-insensitively) and is inlined as text; Absent removes the clause.
    LOCK = auto()
    # Value must be one of the context's ``tables``; it is inlined as text
    # (used by FROM, DELETE FROM and the JOIN variants).
    TABLE = auto()
    # Value is bound as a query parameter. Also the type assumed for
    # placeholders whose nearest Clause/Function ancestor is a Function.
    VARIABLE = auto()
    # As VARIABLE, but IsNull/IsNotNull rewrite the comparison to
    # ``IS NULL`` / ``IS NOT NULL`` (used by WHERE and HAVING).
    VARIABLE_CONDITION = auto()
    # As VARIABLE, but Absent is rendered as ``DEFAULT`` (used by VALUES).
    VARIABLE_DEFAULT = auto()
placeholder_type=PlaceholderType.TABLE, separators=set() 40 | ) 41 | 42 | CLAUSES: ClauseDictionary = { 43 | "delete": { 44 | "from": { 45 | "": ClauseProperties( 46 | allow_empty=False, placeholder_type=PlaceholderType.TABLE, separators=set() 47 | ), 48 | }, 49 | }, 50 | "default": { 51 | "values": { 52 | "": ClauseProperties( 53 | allow_empty=True, 54 | placeholder_type=PlaceholderType.DISALLOWED, 55 | separators=set(), 56 | ), 57 | }, 58 | }, 59 | "for": { 60 | "update": { 61 | "": ClauseProperties( 62 | allow_empty=True, placeholder_type=PlaceholderType.LOCK, separators=set() 63 | ) 64 | }, 65 | }, 66 | "full": { 67 | "join": { 68 | "": _JOIN_CLAUSE, 69 | }, 70 | "outer": { 71 | "join": { 72 | "": _JOIN_CLAUSE, 73 | }, 74 | }, 75 | }, 76 | "group": { 77 | "by": { 78 | "": ClauseProperties( 79 | allow_empty=False, placeholder_type=PlaceholderType.COLUMN, separators={","} 80 | ) 81 | }, 82 | }, 83 | "inner": { 84 | "join": { 85 | "": _JOIN_CLAUSE, 86 | }, 87 | }, 88 | "insert": { 89 | "into": { 90 | "": ClauseProperties( 91 | allow_empty=True, 92 | placeholder_type=PlaceholderType.DISALLOWED, 93 | separators=set(), 94 | ) 95 | }, 96 | }, 97 | "left": { 98 | "join": { 99 | "": _JOIN_CLAUSE, 100 | }, 101 | "outer": { 102 | "join": { 103 | "": _JOIN_CLAUSE, 104 | }, 105 | }, 106 | }, 107 | "on": { 108 | "conflict": { 109 | "": ClauseProperties( 110 | allow_empty=True, 111 | placeholder_type=PlaceholderType.DISALLOWED, 112 | separators=set(), 113 | ) 114 | }, 115 | "": ClauseProperties( 116 | allow_empty=False, placeholder_type=PlaceholderType.VARIABLE, separators={","} 117 | ), 118 | }, 119 | "order": { 120 | "by": { 121 | "": ClauseProperties( 122 | allow_empty=False, placeholder_type=PlaceholderType.COLUMN, separators={","} 123 | ) 124 | }, 125 | }, 126 | "right": { 127 | "join": { 128 | "": _JOIN_CLAUSE, 129 | }, 130 | "outer": { 131 | "join": { 132 | "": _JOIN_CLAUSE, 133 | }, 134 | }, 135 | }, 136 | "do": { 137 | "update": { 138 | "set": { 139 | "": 
ClauseProperties( 140 | allow_empty=False, 141 | placeholder_type=PlaceholderType.VARIABLE, 142 | separators={","}, 143 | ), 144 | }, 145 | }, 146 | "": ClauseProperties( 147 | allow_empty=False, placeholder_type=PlaceholderType.DISALLOWED, separators=set() 148 | ), 149 | }, 150 | "from": { 151 | "": ClauseProperties( 152 | allow_empty=False, placeholder_type=PlaceholderType.TABLE, separators=set() 153 | ) 154 | }, 155 | "having": { 156 | "": ClauseProperties( 157 | allow_empty=False, 158 | placeholder_type=PlaceholderType.VARIABLE_CONDITION, 159 | separators={"and", "or"}, 160 | ) 161 | }, 162 | "join": { 163 | "": _JOIN_CLAUSE, 164 | }, 165 | "limit": { 166 | "": ClauseProperties( 167 | allow_empty=False, placeholder_type=PlaceholderType.VARIABLE, separators=set() 168 | ) 169 | }, 170 | "offset": { 171 | "": ClauseProperties( 172 | allow_empty=False, placeholder_type=PlaceholderType.VARIABLE, separators=set() 173 | ) 174 | }, 175 | "returning": { 176 | "": ClauseProperties( 177 | allow_empty=False, placeholder_type=PlaceholderType.DISALLOWED, separators={","} 178 | ) 179 | }, 180 | "select": { 181 | "": ClauseProperties( 182 | allow_empty=False, placeholder_type=PlaceholderType.COLUMN, separators={","} 183 | ) 184 | }, 185 | "set": { 186 | "": ClauseProperties( 187 | allow_empty=False, placeholder_type=PlaceholderType.VARIABLE, separators={","} 188 | ) 189 | }, 190 | "update": { 191 | "": ClauseProperties( 192 | allow_empty=False, placeholder_type=PlaceholderType.DISALLOWED, separators=set() 193 | ) 194 | }, 195 | "values": { 196 | "": ClauseProperties( 197 | allow_empty=False, 198 | placeholder_type=PlaceholderType.VARIABLE_DEFAULT, 199 | separators={","}, 200 | ) 201 | }, 202 | "with": { 203 | "": ClauseProperties( 204 | allow_empty=False, placeholder_type=PlaceholderType.DISALLOWED, separators=set() 205 | ) 206 | }, 207 | "where": { 208 | "": ClauseProperties( 209 | allow_empty=False, 210 | placeholder_type=PlaceholderType.VARIABLE_CONDITION, 211 | 
@dataclass
class Clause:
    """A single SQL clause of a statement, holding its expressions.

    A new clause always starts with one empty ``Expression``; the parser adds
    further expressions when it meets one of the clause's separators.
    """

    # Statement this clause belongs to.
    parent: Statement
    # Parsing/rendering rules for the clause (placeholder type, separators,
    # whether it may be printed without any expression text).
    properties: ClauseProperties
    # Clause keyword text; printed ahead of the expressions by _print_node.
    text: str
    # Expressions in source order; populated in __post_init__.
    expressions: list[Expression] = field(init=False)
    # When True the clause is rendered as an empty string.
    removed: bool = False

    def __post_init__(self) -> None:
        # Seed with one empty expression so parts always have a container.
        self.expressions = [Expression(self)]
303 | class Function: 304 | name: str 305 | parent: Expression | Function | Group 306 | parts: list[Function | Group | Literal | Operator | Part | Placeholder | Statement] = field( 307 | default_factory=list 308 | ) 309 | 310 | 311 | @dataclass 312 | class Literal: 313 | parent: Expression | Function | Group 314 | parts: list[Operator | Part | Placeholder] = field(default_factory=list) 315 | 316 | 317 | @dataclass 318 | class Operator: 319 | parent: Expression | Function | Group | Literal 320 | text: str 321 | 322 | 323 | type ParentNode = Clause | Expression | ExpressionGroup | Function | Group 324 | type Node = ParentNode | Literal | Statement 325 | type Element = Node | Operator | Part | Placeholder 326 | 327 | 328 | def parse(template: Template) -> list[Statement]: 329 | statements = [Statement()] 330 | current_node: Node = statements[0] 331 | _parse_template(template, current_node, statements) 332 | return statements 333 | 334 | 335 | def _parse_template(template: Template, current_node: Node, statements: list[Statement]) -> None: 336 | for item in template: 337 | match item: 338 | case Interpolation(value, _, _, _): # type: ignore[misc] 339 | if isinstance(value, (Template, TTemplate)): 340 | _parse_template(value, current_node, statements) 341 | else: 342 | _parse_placeholder(current_node, value) 343 | case str() as raw: 344 | current_node = _parse_string(raw, current_node, statements) 345 | 346 | 347 | def _parse_placeholder( 348 | current_node: Node, 349 | value: object, 350 | ) -> None: 351 | if isinstance(current_node, (Expression, Function, Group, Literal)): 352 | parent = current_node 353 | elif isinstance(current_node, Statement): 354 | raise ValueError("Invalid syntax") 355 | else: # Clause | ExpressionGroup 356 | parent = current_node.expressions[-1] 357 | placeholder = Placeholder(parent=parent, value=value) 358 | parent.parts.append(placeholder) 359 | 360 | 361 | def _parse_string( 362 | raw: str, 363 | current_node: Node, 364 | statements: 
list[Statement], 365 | ) -> Node: 366 | tokens = [part.strip() for part in SPLIT_RE.split(raw) if part.strip() != ""] 367 | index = 0 368 | while index < len(tokens): 369 | raw_current_token = tokens[index] 370 | current_token = raw_current_token.lower() 371 | 372 | consumed = 1 373 | if isinstance(current_node, Literal): 374 | if current_token == "'": 375 | current_node = _find_node( # type: ignore[assignment] 376 | current_node.parent, (Clause, ExpressionGroup, Function, Group) 377 | ) 378 | else: 379 | current_node.parts.append(Part(parent=current_node, text=raw_current_token)) 380 | elif isinstance(current_node, (Function, Group)): 381 | if current_token == ")": 382 | group_or_function = _find_node(current_node, (Function, Group)) 383 | current_node = _find_node( # type: ignore[assignment] 384 | group_or_function.parent, (Clause, ExpressionGroup, Function, Group) 385 | ) 386 | else: 387 | current_node, consumed = _parse_token( 388 | current_node, raw_current_token, current_token, tokens[index:], statements 389 | ) 390 | elif isinstance(current_node, ExpressionGroup): 391 | if current_token == ")": 392 | group = _find_node(current_node, ExpressionGroup) 393 | current_node = _find_node( # type: ignore[assignment] 394 | group.parent, (Clause, ExpressionGroup) 395 | ) 396 | else: 397 | clause = _find_node(current_node, Clause) 398 | if current_token in clause.properties.separators: 399 | current_node.expressions.append( 400 | Expression(parent=current_node, separator=raw_current_token) 401 | ) 402 | else: 403 | current_node, consumed = _parse_token( 404 | current_node, raw_current_token, current_token, tokens[index:], statements 405 | ) 406 | elif isinstance(current_node, Clause): 407 | if current_token in current_node.properties.separators: 408 | current_node.expressions.append( 409 | Expression(parent=current_node, separator=raw_current_token) 410 | ) 411 | else: 412 | current_node, consumed = _parse_token( 413 | current_node, raw_current_token, current_token, 
tokens[index:], statements 414 | ) 415 | elif isinstance(current_node, Expression): 416 | clause = _find_node(current_node, Clause) 417 | if current_token in clause.properties.separators: 418 | parent_group = cast( 419 | Clause | ExpressionGroup, _find_node(current_node, (Clause, ExpressionGroup)) 420 | ) 421 | current_node = Expression(parent=parent_group, separator=raw_current_token) 422 | parent_group.expressions.append(current_node) 423 | else: 424 | current_node, consumed = _parse_token( 425 | current_node, raw_current_token, current_token, tokens[index:], statements 426 | ) 427 | else: # Statement 428 | current_node, consumed = _parse_token( 429 | current_node, raw_current_token, current_token, tokens[index:], statements 430 | ) 431 | 432 | index += consumed 433 | 434 | return current_node 435 | 436 | 437 | def _parse_token( 438 | current_node: ParentNode | Statement, 439 | raw_current_token: str, 440 | current_token: str, 441 | tokens: list[str], 442 | statements: list[Statement], 443 | ) -> tuple[Node, int]: 444 | if current_token in CLAUSES: 445 | return _parse_clause(current_node, tokens) 446 | elif current_token == ";": 447 | statements.append(Statement()) 448 | return statements[-1], 1 449 | elif not isinstance(current_node, Statement): 450 | if current_token in OPERATORS: 451 | return _parse_operator(current_node, tokens) 452 | elif current_token == "'": 453 | return _parse_literal(current_node) 454 | elif current_token == "(": 455 | return _parse_group(current_node) 456 | elif current_token.endswith("("): 457 | return _parse_function(current_node, raw_current_token[:-1]) 458 | elif current_token == ")": 459 | current_node = _find_node( # type: ignore[assignment] 460 | current_node, (ExpressionGroup, Function, Group) 461 | ) 462 | return current_node.parent, 1 463 | else: 464 | return _parse_part(current_node, raw_current_token) 465 | else: 466 | raise ValueError("Invalid syntax") 467 | 468 | 469 | def _parse_clause( 470 | current_node: ParentNode | 
Statement, 471 | tokens: list[str], 472 | ) -> tuple[Clause, int]: 473 | index = 0 474 | clause_entry = CLAUSES 475 | text = "" 476 | while index < len(tokens) and tokens[index].lower() in clause_entry: 477 | clause_entry = cast(ClauseDictionary, clause_entry[tokens[index].lower()]) 478 | text = f"{text} {tokens[index]}".strip() 479 | index += 1 480 | 481 | if isinstance(current_node, (Function, Group)): 482 | statement = Statement(parent=current_node) 483 | current_node.parts.append(statement) 484 | current_node = statement 485 | elif isinstance(current_node, ExpressionGroup): 486 | statement = Statement(parent=current_node) 487 | current_node.expressions[-1].parts.append(statement) 488 | current_node = statement 489 | else: # Clause | Expression | Statement 490 | current_node = _find_node(current_node, Statement) 491 | 492 | clause_properties = cast(ClauseProperties, clause_entry[""]) 493 | clause = Clause( 494 | parent=current_node, 495 | properties=clause_properties, 496 | text=text, 497 | ) 498 | current_node.clauses.append(clause) 499 | return clause, index 500 | 501 | 502 | def _parse_operator[T: ParentNode]( 503 | current_node: T, 504 | tokens: list[str], 505 | ) -> tuple[T, int]: 506 | index = 0 507 | operator_entry = OPERATORS 508 | text = "" 509 | while index < len(tokens) and tokens[index].lower() in operator_entry: 510 | operator_entry = operator_entry[tokens[index].lower()] 511 | text = f"{text} {tokens[index]}".strip() 512 | index += 1 513 | if isinstance(current_node, (Expression, Function, Group)): 514 | parent = current_node 515 | else: # Clause | ExpressionGroup 516 | parent = current_node.expressions[-1] 517 | parent.parts.append(Operator(parent=parent, text=text)) 518 | return current_node, index 519 | 520 | 521 | def _parse_group( 522 | current_node: ParentNode, 523 | ) -> tuple[ExpressionGroup | Group, int]: 524 | group: ExpressionGroup | Group 525 | if isinstance(current_node, (Expression, Function, Group)): 526 | group = 
Group(parent=current_node) 527 | current_node.parts.append(group) 528 | return group, 1 529 | else: # Clause | ExpressionGroup 530 | parent = current_node.expressions[-1] 531 | if len(parent.parts) == 0: 532 | group = ExpressionGroup(parent=parent) 533 | else: 534 | group = Group(parent=parent) 535 | parent.parts.append(group) 536 | return group, 1 537 | 538 | 539 | def _parse_function( 540 | current_node: ParentNode, 541 | name: str, 542 | ) -> tuple[Function, int]: 543 | if isinstance(current_node, (Expression, Function, Group)): 544 | parent = current_node 545 | else: # Clause | ExpressionGroup 546 | parent = current_node.expressions[-1] 547 | func = Function(name=name, parent=parent) 548 | parent.parts.append(func) 549 | return func, 1 550 | 551 | 552 | def _parse_literal(current_node: ParentNode) -> tuple[Literal, int]: 553 | if isinstance(current_node, (Expression, Function, Group)): 554 | value = Literal(parent=current_node) 555 | current_node.parts.append(value) 556 | return value, 1 557 | else: # Clause | ExpressionGroup 558 | parent = current_node.expressions[-1] 559 | value = Literal(parent=parent) 560 | parent.parts.append(value) 561 | return value, 1 562 | 563 | 564 | def _parse_part[T: ParentNode]( 565 | current_node: T, 566 | text: str, 567 | ) -> tuple[T, int]: 568 | if isinstance(current_node, (Expression, Function, Group)): 569 | parent = current_node 570 | else: # Clause | ExpressionGroup 571 | parent = current_node.expressions[-1] 572 | parent.parts.append(Part(parent=parent, text=text)) 573 | return current_node, 1 574 | 575 | 576 | def _find_node[T: Element](current_node: Element, target: type[T] | tuple[type[T], ...]) -> T: 577 | while not isinstance(current_node, target): 578 | if current_node is None: 579 | raise ValueError("Parsing Error") 580 | current_node = current_node.parent 581 | return cast(T, current_node) 582 | -------------------------------------------------------------------------------- /src/sql_tstring/py.typed: 
-------------------------------------------------------------------------------- 1 | Marker 2 | -------------------------------------------------------------------------------- /src/sql_tstring/t.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Any, Iterator 3 | 4 | PLACEHOLDER_RE = re.compile(r"(?<=(? None: 11 | self.value = value 12 | self.expr = "" 13 | self.conv = None 14 | self.format_spec = "" 15 | 16 | 17 | class Template: 18 | def __init__(self, values: list[str | Interpolation]) -> None: 19 | self._values = values 20 | 21 | def __iter__(self) -> Iterator[str | Interpolation]: 22 | return iter(self._values) 23 | 24 | 25 | def t(raw: str, values: dict[str, Any]) -> Template: 26 | parts: list[str | Interpolation] = [] 27 | position = 0 28 | for match_ in PLACEHOLDER_RE.finditer(raw): 29 | end = match_.start() - 1 30 | if position != end: 31 | parts.append(raw[position:end].replace("{{", "{").replace("}}", "}")) 32 | position = match_.end() + 1 33 | parts.append(Interpolation(value=values[match_.group(0)])) 34 | 35 | if position != len(raw): 36 | parts.append(raw[position:].replace("{{", "{").replace("}}", "}")) 37 | 38 | return Template(parts) 39 | -------------------------------------------------------------------------------- /tests/test_dialect.py: -------------------------------------------------------------------------------- 1 | from sql_tstring import RewritingValue, sql, sql_context 2 | 3 | 4 | def test_asyncpg() -> None: 5 | a = 1 6 | b = RewritingValue.ABSENT 7 | c = 2 8 | with sql_context(dialect="asyncpg"): 9 | assert ("SELECT x FROM y WHERE a = $1 AND c = $2", [1, 2]) == sql( 10 | "SELECT x FROM y WHERE a = {a} AND b = {b} AND c = {c}", locals() 11 | ) 12 | -------------------------------------------------------------------------------- /tests/test_identifiers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from 
sql_tstring import RewritingValue, sql, sql_context 4 | 5 | 6 | def test_order_by() -> None: 7 | a = RewritingValue.ABSENT 8 | b = "x" 9 | with sql_context(columns={"x"}): 10 | assert ("SELECT x FROM y ORDER BY x", []) == sql( 11 | "SELECT x FROM y ORDER BY {a}, {b}", locals() 12 | ) 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "a", 17 | ["ASC", "ASCENDING", "asc", "ascending"], 18 | ) 19 | def test_order_by_direction(a: str) -> None: 20 | b = "x" 21 | with sql_context(columns={"x"}): 22 | assert (f"SELECT x FROM y ORDER BY x {a}", []) == sql( 23 | "SELECT x FROM y ORDER BY {b} {a}", locals() 24 | ) 25 | 26 | 27 | def test_order_by_invalid_column() -> None: 28 | a = RewritingValue.ABSENT 29 | b = "x" 30 | with pytest.raises(ValueError): 31 | sql("SELECT x FROM y ORDER BY {a}, {b}", locals()) 32 | 33 | 34 | @pytest.mark.parametrize( 35 | "lock_type, expected", 36 | ( 37 | ("", "SELECT x FROM y FOR UPDATE"), 38 | ("NOWAIT", "SELECT x FROM y FOR UPDATE NOWAIT"), 39 | ("SKIP LOCKED", "SELECT x FROM y FOR UPDATE SKIP LOCKED"), 40 | ), 41 | ) 42 | def test_lock(lock_type: str, expected: str) -> None: 43 | assert (expected, []) == sql("SELECT x FROM y FOR UPDATE {lock_type}", locals()) 44 | 45 | 46 | def test_absent_lock() -> None: 47 | a = RewritingValue.ABSENT 48 | assert ("SELECT x FROM y", []) == sql("SELECT x FROM y FOR UPDATE {a}", locals()) 49 | -------------------------------------------------------------------------------- /tests/test_parameters.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | 5 | from sql_tstring import RewritingValue, sql, sql_context, t 6 | 7 | TZ = "uk" 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "query, expected_query, expected_values", 12 | [ 13 | ( 14 | "SELECT x FROM y WHERE x = {val}", 15 | "SELECT x FROM y WHERE x = ?", 16 | [2], 17 | ), 18 | ( 19 | "SELECT x FROM y WHERE DATE(x AT TIME ZONE {TZ}) >= {val}", 20 | "SELECT x FROM y WHERE DATE(x AT 
TIME ZONE ?) >= ?", 21 | ["uk", 2], 22 | ), 23 | ( 24 | "SELECT x FROM y WHERE x = ANY({val})", 25 | "SELECT x FROM y WHERE x = ANY(?)", 26 | [2], 27 | ), 28 | ( 29 | "SELECT x FROM y JOIN z ON u = {val}", 30 | "SELECT x FROM y JOIN z ON u = ?", 31 | [2], 32 | ), 33 | ( 34 | "UPDATE x SET x = {val}", 35 | "UPDATE x SET x = ?", 36 | [2], 37 | ), 38 | ( 39 | "SELECT {col} FROM {tbl}", 40 | "SELECT col FROM tbl", 41 | [], 42 | ), 43 | ( 44 | "SELECT x FROM y LIMIT {val} OFFSET {val}", 45 | "SELECT x FROM y LIMIT ? OFFSET ?", 46 | [2, 2], 47 | ), 48 | ( 49 | "SELECT x FROM y ORDER BY ARRAY_POSITION({val}, x)", 50 | "SELECT x FROM y ORDER BY ARRAY_POSITION(? , x)", 51 | [2], 52 | ), 53 | ( 54 | "SELECT x FROM y WHERE x LIKE '%{col}'", 55 | "SELECT x FROM y WHERE x LIKE ?", 56 | ["%col"], 57 | ), 58 | ( 59 | "INSERT INTO y (x) VALUES (2) ON CONFLICT DO UPDATE SET x = {val}", 60 | "INSERT INTO y (x) VALUES (2) ON CONFLICT DO UPDATE SET x = ?", 61 | [2], 62 | ), 63 | ], 64 | ) 65 | def test_placeholders(query: str, expected_query: str, expected_values: list[Any]) -> None: 66 | col = "col" 67 | tbl = "tbl" 68 | val = 2 69 | with sql_context(columns={"col"}, tables={"tbl"}): 70 | assert (expected_query, expected_values) == sql(query, locals() | globals()) 71 | 72 | 73 | @pytest.mark.parametrize( 74 | "query, expected_query, expected_values", 75 | [ 76 | ( 77 | "SELECT x FROM y WHERE x = {val}", 78 | "SELECT x FROM y", 79 | [], 80 | ), 81 | ( 82 | "SELECT x FROM y WHERE x = 2 AND z = ANY({val})", 83 | "SELECT x FROM y WHERE x = 2", 84 | [], 85 | ), 86 | ( 87 | "SELECT x FROM y WHERE x = 2 AND (u = {val} OR v = 1)", 88 | "SELECT x FROM y WHERE x = 2 AND (v = 1)", 89 | [], 90 | ), 91 | ( 92 | "SELECT x FROM y WHERE x = 2 AND (v = 1 OR u = {val})", 93 | "SELECT x FROM y WHERE x = 2 AND (v = 1)", 94 | [], 95 | ), 96 | ( 97 | "SELECT x FROM y WHERE x = {val} AND (v = 1 OR u = 2)", 98 | "SELECT x FROM y WHERE (v = 1 OR u = 2)", 99 | [], 100 | ), 101 | ( 102 | "SELECT x FROM y WHERE 
x = 2 AND (v = {val} OR u = {val})", 103 | "SELECT x FROM y WHERE x = 2", 104 | [], 105 | ), 106 | ( 107 | "SELECT x FROM y JOIN z ON u = {val}", 108 | "SELECT x FROM y JOIN z", 109 | [], 110 | ), 111 | ( 112 | "SELECT x FROM y LIMIT {val} OFFSET {val}", 113 | "SELECT x FROM y", 114 | [], 115 | ), 116 | ( 117 | "SELECT x FROM y ORDER BY ARRAY_POSITION({val}, x)", 118 | "SELECT x FROM y", 119 | [], 120 | ), 121 | ( 122 | "SELECT x FROM y WHERE x LIKE '%{val}'", 123 | "SELECT x FROM y", 124 | [], 125 | ), 126 | ( 127 | "UPDATE y SET x = {val}, u = 2", 128 | "UPDATE y SET u = 2", 129 | [], 130 | ), 131 | ( 132 | "INSERT INTO y (x) VALUES ({val})", 133 | "INSERT INTO y (x) VALUES (DEFAULT)", 134 | [], 135 | ), 136 | ( 137 | "SELECT x FROM y WHERE x = ANY('{{1}}') AND y = {val}", 138 | "SELECT x FROM y WHERE x = ANY('{1}')", 139 | [], 140 | ), 141 | ( 142 | "SELECT x FROM y WHERE x = 1 AND y = {val}", 143 | "SELECT x FROM y WHERE x = 1", 144 | [], 145 | ), 146 | ], 147 | ) 148 | def test_absent(query: str, expected_query: str, expected_values: list[Any]) -> None: 149 | val = RewritingValue.ABSENT 150 | assert (expected_query, expected_values) == sql(query, locals()) 151 | 152 | 153 | @pytest.mark.parametrize( 154 | "query, expected_query, expected_values", 155 | [ 156 | ( 157 | "SELECT x FROM y WHERE x = {val}", 158 | "SELECT x FROM y WHERE x IS NULL", 159 | [], 160 | ), 161 | ( 162 | "SELECT x FROM y WHERE x != {val}", 163 | "SELECT x FROM y WHERE x IS NULL", 164 | [], 165 | ), 166 | ( 167 | "SELECT x FROM y WHERE a = ANY(a) AND ((x = '1' AND b = {val}) OR c = 1)", 168 | "SELECT x FROM y WHERE a = ANY(a) AND ((x = '1' AND b IS NULL) OR c = 1)", 169 | [], 170 | ), 171 | ], 172 | ) 173 | def test_is_null(query: str, expected_query: str, expected_values: list[Any]) -> None: 174 | val = RewritingValue.IS_NULL 175 | assert (expected_query, expected_values) == sql(query, locals()) 176 | 177 | 178 | @pytest.mark.parametrize( 179 | "query, expected_query, expected_values", 180 
| [ 181 | ( 182 | "SELECT x FROM y WHERE x = {val}", 183 | "SELECT x FROM y WHERE x IS NOT NULL", 184 | [], 185 | ), 186 | ( 187 | "SELECT x FROM y WHERE x != {val}", 188 | "SELECT x FROM y WHERE x IS NOT NULL", 189 | [], 190 | ), 191 | ], 192 | ) 193 | def test_is_not_null(query: str, expected_query: str, expected_values: list[Any]) -> None: 194 | val = RewritingValue.IS_NOT_NULL 195 | assert (expected_query, expected_values) == sql(query, locals()) 196 | 197 | 198 | def test_nested() -> None: 199 | a = "a" 200 | inner = t("x = {a}", locals()) 201 | query, values = sql("SELECT x FROM y WHERE {inner}", locals()) 202 | assert query == "SELECT x FROM y WHERE x = ?" 203 | assert values == ["a"] 204 | -------------------------------------------------------------------------------- /tests/test_parsing.py: -------------------------------------------------------------------------------- 1 | from sql_tstring import sql, t 2 | 3 | 4 | def test_literals() -> None: 5 | query, _ = sql("SELECT x FROM y WHERE x = 'NONE' AND y = 'FUNC_LIKE('", locals()) 6 | assert query == "SELECT x FROM y WHERE x = 'NONE' AND y = 'FUNC_LIKE('" 7 | 8 | 9 | def test_function_literals() -> None: 10 | query, _ = sql("SELECT COALESCE(x, 'a') FROM y WHERE x = 'NONE'", locals()) 11 | assert query == "SELECT COALESCE(x , 'a') FROM y WHERE x = 'NONE'" 12 | 13 | 14 | def test_literal_quote() -> None: 15 | query, _ = sql("SELECT x FROM y WHERE x = 'AB''C'", locals()) 16 | assert query == "SELECT x FROM y WHERE x = 'AB''C'" 17 | 18 | 19 | def test_quoted() -> None: 20 | query, _ = sql('SELECT "x" FROM "y"', locals()) 21 | assert query == 'SELECT "x" FROM "y"' 22 | 23 | 24 | def test_whitespace() -> None: 25 | query, _ = sql( 26 | """SELECT 27 | "x" FROM "y" """, 28 | locals(), 29 | ) 30 | assert query == 'SELECT "x" FROM "y"' 31 | 32 | 33 | def test_delete_from() -> None: 34 | query, _ = sql("DELETE FROM y WHERE x = 'NONE'", locals()) 35 | assert query == "DELETE FROM y WHERE x = 'NONE'" 36 | 37 | 38 | def 
test_function() -> None: 39 | query, _ = sql("SELECT COALESCE(x, now())", locals()) 40 | assert query == "SELECT COALESCE(x , now())" 41 | 42 | 43 | def test_lowercase() -> None: 44 | query, _ = sql("select x from y where x = 2", locals()) 45 | assert query == "select x from y where x = 2" 46 | 47 | 48 | def test_cte() -> None: 49 | query, _ = sql( 50 | """WITH cte AS (SELECT DISTINCT x FROM y) 51 | SELECT DISTINCT x 52 | FROM z 53 | WHERE x NOT IN (SELECT a FROM b)""", 54 | locals(), 55 | ) 56 | assert ( 57 | query 58 | == """WITH cte AS (SELECT DISTINCT x FROM y) SELECT DISTINCT x FROM z WHERE x NOT IN (SELECT a FROM b)""" # noqa: E501 59 | ) 60 | 61 | 62 | def test_with_conflict() -> None: 63 | a = "A" 64 | b = "B" 65 | query, _ = sql( 66 | """INSERT INTO x (a, b) 67 | VALUES ({a}, {b}) 68 | ON CONFLICT (a) DO UPDATE SET b = {b} 69 | RETURNING a, b""", 70 | locals(), 71 | ) 72 | assert ( 73 | query 74 | == "INSERT INTO x (a , b) VALUES (? , ?) ON CONFLICT (a) DO UPDATE SET b = ? RETURNING a , b" # noqa: E501 75 | ) 76 | 77 | 78 | def test_default_insert() -> None: 79 | query, _ = sql("INSERT INTO tbl DEFAULT VALUES RETURNING id", locals()) 80 | assert query == "INSERT INTO tbl DEFAULT VALUES RETURNING id" 81 | 82 | 83 | def test_grouping() -> None: 84 | query, _ = sql("SELECT x FROM y WHERE (DATE(x) = 1 OR x = 2) AND y = 3", locals()) 85 | assert query == "SELECT x FROM y WHERE (DATE(x) = 1 OR x = 2) AND y = 3" 86 | 87 | 88 | def test_subquery() -> None: 89 | query, _ = sql( 90 | "SELECT x FROM y JOIN (SELECT z FROM a) ON z = x", 91 | locals(), 92 | ) 93 | assert query == "SELECT x FROM y JOIN (SELECT z FROM a) ON z = x" 94 | 95 | 96 | def test_functional_subquery() -> None: 97 | query, _ = sql( 98 | "SELECT x, ARRAY(SELECT a FROM b) FROM y", 99 | locals(), 100 | ) 101 | assert query == "SELECT x , ARRAY(SELECT a FROM b) FROM y" 102 | 103 | 104 | def test_nested() -> None: 105 | inner = t("x = 'a'", locals()) 106 | query, _ = sql("SELECT x FROM y WHERE {inner}", 
locals()) 107 | assert query == "SELECT x FROM y WHERE x = 'a'" 108 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = format,mypy,py312,py313,pep8,package 3 | isolated_build = True 4 | 5 | [testenv] 6 | deps = 7 | pytest 8 | pytest-cov 9 | pytest-sugar 10 | commands = pytest --cov=sql_tstring {posargs} 11 | passenv = DATABASE_URL 12 | 13 | [testenv:format] 14 | basepython = python3.13 15 | deps = 16 | black 17 | isort 18 | commands = 19 | black --check --diff src/sql_tstring/ tests/ 20 | isort --check --diff src/sql_tstring/ tests/ 21 | 22 | [testenv:pep8] 23 | basepython = python3.13 24 | deps = 25 | flake8 26 | flake8-bugbear 27 | flake8-print 28 | pep8-naming 29 | commands = flake8 src/sql_tstring/ tests/ 30 | 31 | [testenv:mypy] 32 | basepython = python3.13 33 | deps = 34 | mypy 35 | pytest 36 | commands = 37 | mypy src/sql_tstring/ tests/ 38 | 39 | [testenv:package] 40 | basepython = python3.13 41 | deps = 42 | pdm 43 | twine 44 | commands = 45 | pdm build 46 | twine check dist/* 47 | --------------------------------------------------------------------------------