├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── buildpg ├── __init__.py ├── asyncpg.py ├── clauses.py ├── components.py ├── funcs.py ├── logic.py ├── main.py └── version.py ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── check_tag.py ├── conftest.py ├── requirements-linting.txt ├── requirements.txt ├── test_clauses.py ├── test_logic.py ├── test_main.py └── test_query.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | tags: 8 | - '**' 9 | pull_request: {} 10 | 11 | jobs: 12 | test: 13 | name: test py${{ matrix.python-version }} 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | python-version: ['3.7', '3.8', '3.9', '3.10'] 18 | 19 | runs-on: ubuntu-latest 20 | 21 | services: 22 | postgres: 23 | image: postgres:12 24 | env: 25 | POSTGRES_USER: postgres 26 | POSTGRES_PASSWORD: postgres 27 | ports: 28 | - 5432:5432 29 | options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 30 | 31 | env: 32 | PYTHON: ${{ matrix.python-version }} 33 | PGPASSWORD: postgres 34 | 35 | steps: 36 | - uses: actions/checkout@v2 37 | 38 | - name: set up python 39 | uses: actions/setup-python@v1 40 | with: 41 | python-version: ${{ matrix.python-version }} 42 | 43 | - run: pip install -U wheel 44 | - run: pip install -r tests/requirements.txt 45 | - run: pip install . 46 | - run: pip freeze 47 | 48 | - name: test 49 | run: make test 50 | 51 | - run: coverage xml 52 | 53 | - uses: codecov/codecov-action@v1.5.2 54 | with: 55 | file: ./coverage.xml 56 | env_vars: PYTHON 57 | 58 | lint: 59 | runs-on: ubuntu-latest 60 | 61 | steps: 62 | - uses: actions/checkout@v2 63 | 64 | - uses: actions/setup-python@v1 65 | with: 66 | python-version: '3.8' 67 | 68 | - run: pip install -U wheel 69 | - run: pip install -r tests/requirements-linting.txt 70 | - run: pip install . 
71 | 72 | - run: make lint 73 | 74 | deploy: 75 | needs: 76 | - test 77 | - lint 78 | if: "success() && startsWith(github.ref, 'refs/tags/')" 79 | runs-on: ubuntu-latest 80 | 81 | steps: 82 | - uses: actions/checkout@v2 83 | 84 | - name: set up python 85 | uses: actions/setup-python@v1 86 | with: 87 | python-version: '3.8' 88 | 89 | - name: install 90 | run: pip install -U pip setuptools twine wheel 91 | 92 | - name: set version 93 | run: VERSION_PATH='buildpg/version.py' python <(curl -Ls https://git.io/JT3rm) 94 | 95 | - name: build 96 | run: python setup.py sdist bdist_wheel 97 | 98 | - run: twine check dist/* 99 | 100 | - name: upload to pypi 101 | run: twine upload dist/* 102 | env: 103 | TWINE_USERNAME: __token__ 104 | TWINE_PASSWORD: ${{ secrets.pypi_token }} 105 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | env/ 3 | *.py[cod] 4 | *.egg-info/ 5 | build/ 6 | dist/ 7 | .cache/ 8 | .mypy_cache/ 9 | test.py 10 | .coverage 11 | htmlcov/ 12 | benchmarks/*.json 13 | docs/_build/ 14 | docs/.TMP_HISTORY.rst 15 | .pytest_cache/ 16 | .vscode/ 17 | venv/ 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2018 Samuel Colvin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included 
in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := all 2 | isort = isort buildpg tests 3 | black = black -S -l 120 --target-version py38 buildpg tests 4 | 5 | .PHONY: install 6 | install: 7 | pip install -U setuptools pip 8 | pip install -U -r tests/requirements.txt 9 | pip install -r tests/requirements-linting.txt 10 | pip install -e . 
11 | 12 | .PHONY: format 13 | format: 14 | $(isort) 15 | $(black) 16 | 17 | .PHONY: lint 18 | lint: 19 | flake8 buildpg/ tests/ 20 | $(isort) --check-only 21 | $(black) --check 22 | 23 | .PHONY: test 24 | test: 25 | pytest --cov=buildpg 26 | 27 | .PHONY: testcov 28 | testcov: 29 | pytest --cov=buildpg 30 | @echo "building coverage html" 31 | @coverage html 32 | 33 | .PHONY: all 34 | all: lint testcov 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # buildpg 2 | 3 | [![CI](https://github.com/samuelcolvin/buildpg/workflows/ci/badge.svg?event=push)](https://github.com/samuelcolvin/buildpg/actions?query=event%3Apush+branch%3Amaster+workflow%3Aci) 4 | [![Coverage](https://codecov.io/gh/samuelcolvin/buildpg/branch/master/graph/badge.svg)](https://codecov.io/gh/samuelcolvin/buildpg) 5 | [![pypi](https://img.shields.io/pypi/v/buildpg.svg)](https://pypi.python.org/pypi/buildpg) 6 | [![versions](https://img.shields.io/pypi/pyversions/buildpg.svg)](https://github.com/samuelcolvin/buildpg) 7 | [![license](https://img.shields.io/github/license/samuelcolvin/buildpg.svg)](https://github.com/samuelcolvin/buildpg/blob/master/LICENSE) 8 | 9 | Query building for the postgresql prepared statements and asyncpg. 10 | 11 | Lots of more powerful features, including full clause construction, multiple values, logic functions, 12 | query pretty-printing and different variable substitution - below is just a very quick summary. 13 | Please check the code and tests for examples. 
14 | 15 | ## Building Queries 16 | 17 | Simple variable substitution: 18 | 19 | ```py 20 | from buildpg import render 21 | 22 | render('select * from mytable where x=:foo and y=:bar', foo=123, bar='whatever') 23 | >> 'select * from mytable where x=$1 and y=$2', [123, 'whatever'] 24 | ``` 25 | 26 | 27 | Use of `V` to substitute identifiers such as column names: 28 | 29 | ```py 30 | from buildpg import V, render 31 | 32 | render('select * from mytable where :col=:foo', col=V('x'), foo=456) 33 | >> 'select * from mytable where x=$1', [456] 34 | ``` 35 | 36 | Complex logic: 37 | 38 | ```py 39 | from buildpg import V, funcs, render 40 | 41 | where_logic = V('foo.bar') == 123 42 | if spam_value: 43 | where_logic &= V('foo.spam') <= spam_value 44 | 45 | if exclude_cake: 46 | where_logic &= funcs.NOT(V('foo.cake').in_([1, 2, 3])) 47 | 48 | render('select * from foo where :where', where=where_logic) 49 | >> 'select * from foo where foo.bar = $1 AND foo.spam <= $2 AND not(foo.cake in $3)', [123, spam_value, [1, 2, 3]] 50 | ``` 51 | 52 | Values usage: 53 | 54 | ```py 55 | from buildpg import Values, render 56 | 57 | render('insert into the_table (:values__names) values :values', values=Values(a=123, b=456, c='hello')) 58 | >> 'insert into the_table (a, b, c) values ($1, $2, $3)', [123, 456, 'hello'] 59 | ``` 60 | 61 | ## With asyncpg 62 | 63 | As a wrapper around *asyncpg*: 64 | 65 | ```py 66 | import asyncio 67 | from buildpg import asyncpg 68 | 69 | async def main(): 70 | async with asyncpg.create_pool_b('postgres://postgres@localhost:5432/db') as pool: 71 | await pool.fetchval_b('select spam from mytable where x=:foo and y=:bar', foo=123, bar='whatever') 72 | >> 42 73 | 74 | asyncio.run(main()) 75 | ``` 76 | 77 | 78 | Both the pool and connections have `*_b` variants of all common query methods: 79 | 80 | - `execute_b` 81 | - `executemany_b` 82 | - `fetch_b` 83 | - `fetchval_b` 84 | - `fetchrow_b` 85 | - `cursor_b` 86 | 87 | 88 | ## Operators 89 | 90 | | Python operator/function | SQL operator | 91 | |
------------------------ | ------------ | 92 | | `&` | `AND` | 93 | | `|` | `OR` | 94 | | `=` | `=` | 95 | | `!=` | `!=` | 96 | | `<` | `<` | 97 | | `<=` | `<=` | 98 | | `>` | `>` | 99 | | `>=` | `>=` | 100 | | `+` | `+` | 101 | | `-` | `-` | 102 | | `*` | `*` | 103 | | `/` | `/` | 104 | | `%` | `%` | 105 | | `**` | `^` | 106 | | `-` | `-` | 107 | | `~` | `not(...)` | 108 | | `sqrt` | `|/` | 109 | | `abs` | `@` | 110 | | `contains` | `@>` | 111 | | `contained_by` | `<@` | 112 | | `overlap` | `&&` | 113 | | `like` | `LIKE` | 114 | | `ilike` | `ILIKE` | 115 | | `cat` | `||` | 116 | | `in_` | `in` | 117 | | `from_` | `from` | 118 | | `at_time_zone` | `AT TIME ZONE` | 119 | | `matches` | `@@` | 120 | | `is_` | `is` | 121 | | `is_not` | `is not` | 122 | | `for_` | `for` | 123 | | `factorial` | `!` | 124 | | `cast` | `::` | 125 | | `asc` | `ASC` | 126 | | `desc` | `DESC` | 127 | | `comma` | `,` | 128 | | `on` | `ON` | 129 | | `as_` | `AS` | 130 | | `nulls_first` | `NULLS FIRST` | 131 | | `nulls_last` | `NULLS LAST` | 132 | 133 | Usage: 134 | 135 | ```py 136 | from buildpg import V, S, render 137 | 138 | def show(component): 139 | sql, params = render(':c', c=component) 140 | print(f'sql="{sql}" params={params}') 141 | 142 | show(V('foobar').contains([1, 2, 3])) 143 | #> sql="foobar @> $1" params=[[1, 2, 3]] 144 | show(V('foobar') == 4) 145 | #> sql="foobar = $1" params=[4] 146 | show(~V('foobar')) 147 | #> sql="not(foobar)" params=[] 148 | show(S(625).sqrt()) 149 | #> sql="|/ $1" params=[625] 150 | show(V('foo').is_not('true')) 151 | #> sql="foo is not true" params=[] 152 | ``` 153 | 154 | ## Functions 155 | 156 | | Python function | SQL function | 157 | | ------------------------------------------- | ------------- | 158 | | `AND(*args)` | ` and ...` | 159 | | `OR(*args)` | ` or ...` | 160 | | `NOT(arg)` | `not()` | 161 | | `comma_sep(*args)` | `, , ...` | 162 | | `count(expr)` | `count(expr)` | 163 | | `any(arg)` | `any()` | 164 | | `now()` | `now()` | 165 | | `cast(v, 
cast_type)` | `::` | 166 | | `upper(string)` | `upper()` | 167 | | `lower(string)` | `lower()` | 168 | | `length(string)` | `length()` | 169 | | `left(string, n)` | `left(, )` | 170 | | `right(string, n)` | `right(, )` | 171 | | `extract(expr)` | `extract()` | 172 | | `sqrt(n)` | `|/` | 173 | | `abs(n)` | `@` | 174 | | `factorial(n)` | `!` | 175 | | `position(substring, string)` | `position( in from )` | 178 | | `to_tsquery(arg1, text-None)` | `to_tsquery()` | 179 | 180 | Usage: 181 | 182 | ```py 183 | from buildpg import V, render, funcs 184 | 185 | def show(component): 186 | sql, params = render(':c', c=component) 187 | print(f'sql="{sql}" params={params}') 188 | 189 | show(funcs.AND(V('x') == 4, V('y') > 6)) 190 | #> sql="x = $1 AND y > $2" params=[4, 6] 191 | show(funcs.position('foo', 'this has foo in it')) 192 | #> sql="position($1 in $2)" params=['foo', 'this has foo in it'] 193 | ``` 194 | -------------------------------------------------------------------------------- /buildpg/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from . 
import clauses, funcs  # NOTE(review): continues the `from .` that ends the previous dump line (buildpg/__init__.py)
from .components import *
from .logic import *
from .main import *
from .version import *

# ---- /buildpg/asyncpg.py ----
import sys
from textwrap import indent

from asyncpg import *  # noqa
from asyncpg.pool import Pool
from asyncpg.protocol import Record

from .main import render

try:
    import sqlparse
    from pygments import highlight
    from pygments.formatters import Terminal256Formatter
    from pygments.lexers.sql import PlPgsqlLexer
except ImportError:  # pragma: no cover
    # pretty-printing is optional: without sqlparse/pygments queries are printed unformatted
    sqlparse = None


class _BuildPgMixin:
    """
    Add ``*_b`` variants of the asyncpg query methods to a connection or pool class.

    Each ``*_b`` method renders ``query_template`` with ``buildpg.render`` and passes the resulting query and
    params to the matching asyncpg method. The ``_timeout``, ``_column`` and ``_prefetch`` keyword arguments are
    forwarded to asyncpg (as ``timeout``, ``column``, ``prefetch``) instead of being rendered into the query.
    ``print_`` may be ``True`` to print the rendered query to stdout, or a callable to print it with.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _format_sql(sql, formatted):
        """Return sql reindented and highlighted if formatted is true and sqlparse is available, else stripped."""
        if formatted and sqlparse is not None:
            sql = indent(sqlparse.format(str(sql), reindent=True), ' ' * 4)
            return highlight(sql, PlPgsqlLexer(), Terminal256Formatter(style='monokai')).strip('\n')
        else:
            return sql.strip('\r\n ')

    def _print_query(self, print_, sql, args):
        """Print the params and query if print_ is truthy; a non-callable print_ means the builtin print."""
        if print_:
            if not callable(print_):
                print_ = print
                # only colourise when stdout is an interactive terminal
                formatted = sys.stdout.isatty()
            else:
                # a custom printer opts in to formatting via a truthy ``formatted`` attribute
                formatted = getattr(print_, 'formatted', False)
            print_(f'params: {args} query:\n{self._format_sql(sql, formatted)}')

    def print_b(self, query_template, *, print_=True, **kwargs):
        """Render and print query_template without executing it; return ``(query, args)``."""
        query, args = render(query_template, **kwargs)
        self._print_query(print_, query, args)
        return query, args

    async def execute_b(self, query_template, *, _timeout: float = None, print_=False, **kwargs):
        query, args = render(query_template, **kwargs)
        self._print_query(print_, query, args)
        return await self.execute(query, *args, timeout=_timeout)

    async def executemany_b(self, query_template, args, *, timeout: float = None, print_=False):
        """
        Execute query_template once for each item of args.

        Each item is rendered in place of ``:values`` (presumably ``Values`` instances). The query text is rendered
        from the first item only, so all items must have the same shape. Unlike the other ``*_b`` methods this one
        takes ``timeout`` rather than ``_timeout``.
        """
        query, _ = render(query_template, values=args[0])
        args_ = [render.get_params(a) for a in args]
        # print the params actually sent to the database, consistent with the other *_b methods
        self._print_query(print_, query, args_)
        return await self.executemany(query, args_, timeout=timeout)

    def cursor_b(self, query_template, *, _timeout: float = None, _prefetch=None, print_=False, **kwargs):
        query, args = render(query_template, **kwargs)
        self._print_query(print_, query, args)
        return self.cursor(query, *args, timeout=_timeout, prefetch=_prefetch)

    async def fetch_b(self, query_template, *, _timeout: float = None, print_=False, **kwargs):
        query, args = render(query_template, **kwargs)
        self._print_query(print_, query, args)
        return await self.fetch(query, *args, timeout=_timeout)

    async def fetchval_b(self, query_template, *, _timeout: float = None, _column=0, print_=False, **kwargs):
        query, args = render(query_template, **kwargs)
        self._print_query(print_, query, args)
        return await self.fetchval(query, *args, timeout=_timeout, column=_column)

    async def fetchrow_b(self, query_template, *, _timeout: float = None, print_=False, **kwargs):
        query, args = render(query_template, **kwargs)
        self._print_query(print_, query, args)
        return await self.fetchrow(query, *args, timeout=_timeout)


class BuildPgConnection(_BuildPgMixin, Connection):  # noqa
    """asyncpg ``Connection`` with the ``*_b`` query methods."""


async def connect_b(*args, **kwargs):
    """Like ``asyncpg.connect`` but returns a ``BuildPgConnection`` unless ``connection_class`` is given."""
    kwargs.setdefault('connection_class', BuildPgConnection)
    return await connect(*args, **kwargs)  # noqa


class BuildPgPool(_BuildPgMixin, Pool):
    """asyncpg ``Pool`` with the ``*_b`` query methods."""


def create_pool_b(
    dsn=None,
    *,
    min_size=10,
    max_size=10,
    max_queries=50000,
    max_inactive_connection_lifetime=300.0,
    setup=None,
    init=None,
    loop=None,
    connection_class=BuildPgConnection,
    record_class=Record,
    **connect_kwargs,
):
    """
    Create a connection pool with buildpg support.

    Identical to ``asyncpg.create_pool`` except that both the pool and its connections have the ``*_b`` variants
    of ``execute``, ``executemany``, ``fetch``, ``fetchval``, ``fetchrow`` etc. Like ``asyncpg.create_pool``, the
    returned pool can be used with an ``async with`` block or awaited directly.

    Arguments are exactly the same as ``asyncpg.create_pool``.
    """
    return BuildPgPool(
        dsn,
        connection_class=connection_class,
        min_size=min_size,
        max_size=max_size,
        max_queries=max_queries,
        loop=loop,
        setup=setup,
        init=init,
        max_inactive_connection_lifetime=max_inactive_connection_lifetime,
        record_class=record_class,
        **connect_kwargs,
    )


# ---- /buildpg/clauses.py ----
from . import funcs, logic
from .components import Component, RawDangerous, yield_sep


class Clauses(Component):
    """Several clauses rendered one per line; produced by adding ``Clause`` objects together."""

    __slots__ = ('clauses',)

    def __init__(self, *clauses):
        self.clauses = list(clauses)

    def render(self):
        yield from yield_sep(self.clauses, sep=RawDangerous('\n'))

    def __add__(self, other):
        # NOTE: appends in place and returns self, so ``a + b`` also modifies ``a``
        self.clauses.append(other)
        return self


class Clause(Component):
    """A single clause: the keyword in ``base``, a space, then ``logic``."""

    __slots__ = ('logic',)
    base = NotImplemented  # set by subclasses to the clause keyword, e.g. 'WHERE'

    def __init__(self, logic):
        self.logic = logic

    def render(self):
        yield RawDangerous(self.base + ' ')
        yield self.logic

    def __add__(self, other):
        return Clauses(self, other)


def component_or_var(v):
    """Return v if it is already a Component, otherwise wrap it in ``logic.Var`` (which validates it)."""
    return v if isinstance(v, Component) else logic.Var(v)


class Select(Clause):
    base = 'SELECT'

    def __init__(self, select):
        # a plain iterable of field names is turned into a comma separated list of Vars
        if not isinstance(select, Component):
            select = logic.select_fields(*select)
        super().__init__(select)


class CommaClause(Clause):
    """Clause whose body is its arguments separated by commas; non-Component arguments become Vars."""

    def __init__(self, *fields):
        super().__init__(funcs.comma_sep(*[component_or_var(f) for f in fields]))


class From(CommaClause):
    base = 'FROM'


class OrderBy(CommaClause):
    base = 'ORDER BY'


class Limit(Clause):
    base = 'LIMIT'

    def __init__(self, limit_value):
        super().__init__(limit_value)


class Join(Clause):
    """``JOIN <table>``, followed by ``ON <on_clause>`` when on_clause is truthy."""

    base = 'JOIN'

    def __init__(self, table, on_clause=None):
        v = logic.as_var(table)
        if on_clause:
            v = v.on(on_clause)
        super().__init__(v)


class LeftJoin(Join):
    base = 'LEFT JOIN'


class RightJoin(Join):
    base = 'RIGHT JOIN'


class FullJoin(Join):
    base = 'FULL JOIN'


class CrossJoin(Join):
    base = 'CROSS JOIN'


class Where(Clause):
    base = 'WHERE'


class Offset(Clause):
    base = 'OFFSET'

    def __init__(self, offset_value):
        super().__init__(offset_value)


# ---- /buildpg/components.py ----
import re

__all__ = (
    'check_word',
    'BuildError',
    'ComponentError',
    'UnsafeError',
    'yield_sep',
    'RawDangerous',
    'VarLiteral',
    'Component',
    'Values',
    'MultipleValues',
    'SetValues',
    'JoinComponent',
)

# matches any character other than an ASCII word character, '.' or '*'
NOT_WORD = re.compile(r'[^\w.*]', flags=re.A)


def check_word(s):
    """Raise TypeError if s is not a str, UnsafeError if it contains characters matched by NOT_WORD."""
    if not isinstance(s, str):
        raise TypeError('value is not a string')
    if NOT_WORD.search(s):
        raise UnsafeError(f'str contain unsafe (non word) characters: "{s}"')


def check_word_many(args):
    """Raise UnsafeError listing every item of args that is not a str or contains unsafe characters."""
    # single pass, so one-shot iterables are checked completely
    unsafe = [a for a in args if not isinstance(a, str) or NOT_WORD.search(a)]
    if unsafe:
        raise UnsafeError(f'raw arguments contain unsafe (non word) characters: {unsafe}')
class BuildError(RuntimeError):
    pass


class ComponentError(BuildError):
    pass


class UnsafeError(ComponentError):
    pass


class RawDangerous(str):
    # NOTE(review): marks trusted SQL text; presumably the renderer (main.py, truncated in this dump) emits it
    # verbatim while other values become query params
    pass


def yield_sep(iterable, sep=RawDangerous(', ')):
    """
    Yield the items of iterable with sep between each adjacent pair.

    An empty iterable yields nothing, so e.g. a zero-argument ``Func('now')`` renders as ``now()``.
    """
    for index, value in enumerate(iterable):
        if index:
            yield sep
        yield value


class VarLiteral(RawDangerous):
    """A string validated by ``check_word`` so it can be inlined into SQL, e.g. as an identifier."""

    def __init__(self, s: str):
        # str.__new__ has already stored the value; __init__ only needs to validate it
        check_word(s)


class Component:
    """Base class for SQL fragments; ``render()`` yields strings, nested Components or values."""

    def render(self):
        raise NotImplementedError()

    def __str__(self):
        # flatten nested components into one string; non-str values are converted with str()
        return ''.join(self._get_chunks(self.render()))

    @classmethod
    def _get_chunks(cls, gen):
        for chunk in gen:
            if isinstance(chunk, Component):
                yield from cls._get_chunks(chunk.render())
            elif isinstance(chunk, str):
                yield chunk
            else:
                yield str(chunk)

    def __repr__(self):
        return f'<SQL: "{self}">'


class Values(Component):
    """A parenthesised, comma separated row of values; keyword arguments also supply column names."""

    __slots__ = 'values', 'names'

    def __init__(self, *args, **kwargs):
        if (args and kwargs) or (not args and not kwargs):
            raise ValueError('either args or kwargs are required but not both')

        if args:
            self.names = None
            self.values = args
        else:
            self.names, self.values = zip(*kwargs.items())
            # names are inlined into the SQL by render_names, so they must be plain words
            check_word_many(self.names)

    def render(self):
        yield RawDangerous('(')
        yield from yield_sep(self.values)
        yield RawDangerous(')')

    def render_names(self):
        if not self.names:
            raise ComponentError('"names" are not available for nameless values')
        yield RawDangerous(', '.join(self.names))


class MultipleValues(Component):
    """
    Several ``Values`` rows rendered comma separated, e.g. for a multi-row ``INSERT ... VALUES``.

    Raises ValueError if no rows are given, if any argument is not ``Values``, or if the rows' names (or, for
    nameless rows, their lengths) differ.
    """

    __slots__ = 'names', 'rows', 'render_names'

    def __init__(self, *args):
        if not args:
            raise ValueError('at least one Values() row is required')
        # check every argument, including the first, before reading attributes from it
        if not all(isinstance(r, Values) for r in args):
            raise ValueError('either all or no arguments should be Values()')
        first = args[0]
        self.names = first.names
        self.render_names = first.render_names
        self.rows = [first]
        expected_len = len(first.values)
        for r in args[1:]:
            if r.names != self.names:
                raise ValueError(f'names of different rows do not match {r.names} != {self.names}')
            if not self.names and len(r.values) != expected_len:
                raise ValueError(f"row lengths don't match {r.values} != {expected_len}")
            self.rows.append(r)

    def render(self):
        yield from yield_sep(self.rows)


class SetValues(Component):
    """``name = value`` pairs separated by commas, e.g. for an ``UPDATE ... SET`` clause; names are validated."""

    __slots__ = ('kwargs',)

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @staticmethod
    def _yield_pairs(k, v):
        yield VarLiteral(k)
        yield RawDangerous(' = ')
        yield v

    def render(self):
        iter_ = iter(self.kwargs.items())
        yield from self._yield_pairs(*next(iter_))
        for k, v in iter_:
            yield RawDangerous(', ')
            yield from self._yield_pairs(k, v)


class JoinComponent(Component):
    """Render items separated by sep (default ``', '``)."""

    __slots__ = 'sep', 'items'

    def __init__(self, items, sep=RawDangerous(', ')):
        self.items = items
        self.sep = sep

    def render(self):
        yield from yield_sep(self.items, self.sep)


# ---- /buildpg/funcs.py ----
from . import logic
from .components import JoinComponent


def AND(arg, *args):
    """Combine the arguments with ``AND``."""
    # NOTE: if arg is an SqlBlock with no operator yet, ``&=`` modifies arg itself (see SqlBlock.operate)
    v = logic.as_sql_block(arg)
    for a in args:
        v &= a
    return v


def OR(arg, *args):
    """Combine the arguments with ``OR``."""
    v = logic.as_sql_block(arg)
    for a in args:
        v |= a
    return v


def comma_sep(arg, *args):
    return JoinComponent((arg,) + args)


def count(expr):
    return logic.Func('COUNT', logic.as_var(expr))


def NOT(arg):
    return logic.Func('not', arg)


def any(arg):
    return logic.Func('any', arg)


def now():
    return logic.SqlBlock(logic.RawDangerous('now()'))


def cast(v, cast_type):
    return logic.as_sql_block(v).cast(cast_type)


def upper(string):
    return logic.Func('upper', string)


def lower(string):
    return logic.Func('lower', string)


def length(string):
    return logic.Func('length', string)


def left(string, n):
    return logic.Func('left', string, n)


def right(string, n):
    return logic.Func('right', string, n)


def extract(expr):
    return logic.Func('extract', expr)


def sqrt(n):
    return logic.as_sql_block(n).sqrt()


def abs(n):
    return logic.as_sql_block(n).abs()


def factorial(n):
    return logic.as_sql_block(n).factorial()


def position(substring, string):
    """``position(<substring> in <string>)``."""
    return logic.Func('position', logic.SqlBlock(substring).in_(string))


def substring(string, pattern, escape=None):
    """``substring(<string> from <pattern> [for <escape>])``."""
    a = logic.SqlBlock(string).from_(pattern)
    if escape:
        a = a.for_(escape)
    return logic.Func('substring', a)


def to_tsvector(arg1, document=None):
    return logic.Func('to_tsvector', arg1) if document is None else logic.Func('to_tsvector', arg1, document)


def to_tsquery(arg1, text=None):
    return logic.Func('to_tsquery', arg1) if text is None else logic.Func('to_tsquery', arg1, text)


# ---- /buildpg/logic.py ----
from enum import Enum, unique
from typing import Union

from .components import Component, JoinComponent, RawDangerous, VarLiteral, check_word, yield_sep

__all__ = ('LogicError', 'SqlBlock', 'Func', 'Not', 'Var', 'S', 'V', 'select_fields', 'Empty')


class LogicError(RuntimeError):
    pass


@unique
class Operator(str, Enum):
    """Binary and postfix operators; each value is the exact text rendered between (or after) the operands."""

    and_ = ' AND '
    or_ = ' OR '
    eq = ' = '
    ne = ' != '
    lt = ' < '
    le = ' <= '
    gt = ' > '
    ge = ' >= '
    add = ' + '
    sub = ' - '
    mul = ' * '
    div = ' / '
    mod = ' % '
    pow = ' ^ '
    contains = ' @> '
    contained_by = ' <@ '
    overlap = ' && '
    like = ' LIKE '
    ilike = ' ILIKE '
    cat = ' || '
    in_ = ' in '
    from_ = ' from '
    at_time_zone = ' AT TIME ZONE '
    matches = ' @@ '
    is_ = ' is '
    is_not = ' is not '
    for_ = ' for '
    factorial = '!'
    cast = '::'
    asc = ' ASC'
    desc = ' DESC'
    comma = ', '
    on = ' ON '
    as_ = ' AS '
    func = '_function_'  # marker op used by Func; Func.render does not emit it
    nulls_first = ' NULLS FIRST'
    nulls_last = ' NULLS LAST'


@unique
class LeftOperator(str, Enum):
    """Prefix operators; each value is rendered before the operand."""

    neg = '-'
    sqrt = '|/ '
    abs = '@ '


# higher binds tighter; an operand whose operator binds less tightly than its parent's is bracketed
PRECEDENCE = {
    Operator.func: 140,
    Operator.cast: 130,
    LeftOperator.neg: 120,
    LeftOperator.sqrt: 120,
    LeftOperator.abs: 120,
    Operator.pow: 100,
    Operator.mod: 90,
    Operator.mul: 80,
    Operator.div: 70,
    Operator.add: 60,
    Operator.sub: 50,
    Operator.contains: 35,
    Operator.contained_by: 35,
    Operator.overlap: 35,
    Operator.like: 35,
    Operator.ilike: 35,
    Operator.cat: 35,
    Operator.in_: 35,
    Operator.from_: 35,
    Operator.at_time_zone: 35,
    Operator.matches: 35,
    Operator.is_: 35,
    Operator.is_not: 35,
    Operator.eq: 40,
    Operator.ne: 40,
    Operator.lt: 40,
    Operator.le: 40,
    Operator.gt: 40,
    Operator.ge: 40,
    Operator.asc: 30,
    Operator.desc: 30,
    Operator.nulls_first: 30,
    Operator.nulls_last: 30,
    Operator.on: 30,
    Operator.as_: 30,
    Operator.and_: 20,
    Operator.or_: 10,
    Operator.comma: 0,
}

# PostgreSQL gives ``+``/``-`` one precedence level and ``*``/``/``/``%`` another, and evaluates them (and ``^``)
# left to right. PRECEDENCE gives them distinct values, which is fine for left operands but not for right ones:
# without extra bracketing ``a - (b + c)`` would render as ``a - b + c``.
_ARITHMETIC_TIER = {
    Operator.add: 1,
    Operator.sub: 1,
    Operator.mul: 2,
    Operator.div: 2,
    Operator.mod: 2,
    Operator.pow: 3,
}
# operators for which ``a op (b op2 c)`` differs from ``(a op b) op2 c`` when op2 is in the same tier
_NON_ASSOCIATIVE = {Operator.sub, Operator.div, Operator.mod, Operator.pow}


class SqlBlock(Component):
    """An expression node: ``v1``, optionally followed by ``op`` and ``v2``; Python operators build larger trees."""

    __slots__ = 'v1', 'op', 'v2'

    def __init__(self, v1, *, op=None, v2=None):
        self.v1 = v1
        self.op = op
        self.v2 = v2

    def operate(self, op: Union[RawDangerous, Operator], v2=None):
        """Apply op: wraps self in a new SqlBlock if it already has an operator, otherwise sets op on self."""
        if self.op:
            # op already completed
            return SqlBlock(self, op=op, v2=v2)
        else:
            # NOTE: mutates and returns self, so a bare block reused in two expressions is shared
            self.op = op
            self.v2 = v2
            return self

    def __and__(self, other):
        return self.operate(Operator.and_, other)

    def __or__(self, other):
        return self.operate(Operator.or_, other)

    def __eq__(self, other):
        return self.operate(Operator.eq, other)

    def __ne__(self, other):
        return self.operate(Operator.ne, other)

    def __lt__(self, other):
        return self.operate(Operator.lt, other)

    def __le__(self, other):
        return self.operate(Operator.le, other)

    def __gt__(self, other):
        return self.operate(Operator.gt, other)

    def __ge__(self, other):
        return self.operate(Operator.ge, other)

    def __add__(self, other):
        return self.operate(Operator.add, other)

    def __sub__(self, other):
        return self.operate(Operator.sub, other)

    def __mul__(self, other):
        return self.operate(Operator.mul, other)

    def __truediv__(self, other):
        return self.operate(Operator.div, other)

    def __mod__(self, other):
        return self.operate(Operator.mod, other)

    def __pow__(self, other):
        return self.operate(Operator.pow, other)

    def __invert__(self):
        return Not(self)

    def __neg__(self):
        return LeftOp(LeftOperator.neg, self)

    def sqrt(self):
        return LeftOp(LeftOperator.sqrt, self)

    def abs(self):
        return LeftOp(LeftOperator.abs, self)

    def factorial(self):
        return self.operate(Operator.factorial)

    def contains(self, other):
        return self.operate(Operator.contains, other)

    def contained_by(self, other):
        return self.operate(Operator.contained_by, other)

    def overlap(self, other):
        return self.operate(Operator.overlap, other)

    def like(self, other):
        return self.operate(Operator.like, other)

    def ilike(self, other):
        return self.operate(Operator.ilike, other)

    def cat(self, other):
        return self.operate(Operator.cat, other)

    def in_(self, other):
        return self.operate(Operator.in_, other)

    def from_(self, other):
        return self.operate(Operator.from_, other)

    def at_time_zone(self, other):
        return self.operate(Operator.at_time_zone, other)

    def matches(self, other):
        return self.operate(Operator.matches, other)

    def is_(self, other):
        return self.operate(Operator.is_, other)

    def is_not(self, other):
        return self.operate(Operator.is_not, other)

    def for_(self, other):
        return self.operate(Operator.for_, other)

    def cast(self, cast_type):
        return self.operate(Operator.cast, as_var(cast_type))

    def asc(self):
        return self.operate(Operator.asc)

    def desc(self):
        return self.operate(Operator.desc)

    def nulls_first(self):
        return self.operate(Operator.nulls_first)

    def nulls_last(self):
        return self.operate(Operator.nulls_last)

    def comma(self, other):
        return self.operate(Operator.comma, other)

    def on(self, other):
        return self.operate(Operator.on, other)

    def as_(self, other):
        return self.operate(Operator.as_, as_var(other))

    def _should_parenthesise(self, v, right=False):
        """
        Return True if operand v must be bracketed when rendered inside this block.

        v is bracketed when its operator binds less tightly than this block's. For the right-hand operand
        (``right=True``) of ``-``, ``/``, ``%`` or ``^`` it is also bracketed when its operator shares PostgreSQL's
        precedence level, e.g. ``a - (b + c)``, ``a / (b * c)``, ``a ^ (b ^ c)``.
        """
        if not self.op:
            return False
        sub_op = getattr(v, 'op', None)
        if not sub_op:
            return False
        if PRECEDENCE.get(self.op, 0) > PRECEDENCE.get(sub_op, -1):
            return True
        return (
            right
            and self.op in _NON_ASSOCIATIVE
            and _ARITHMETIC_TIER.get(self.op) == _ARITHMETIC_TIER.get(sub_op)
        )

    def _bracket(self, v, right=False):
        if self._should_parenthesise(v, right):
            yield RawDangerous('(')
            yield v
            yield RawDangerous(')')
        else:
            yield v

    def render(self):
        yield from self._bracket(self.v1)
        if self.op:
            if isinstance(self.op, RawDangerous):
                yield self.op
            else:
                yield RawDangerous(self.op.value)
            if self.v2 is not None:
                yield from self._bracket(self.v2, right=True)


def as_sql_block(n):
    """Return n if it is an SqlBlock, otherwise wrap it in one."""
    if isinstance(n, SqlBlock):
        return n
    else:
        return SqlBlock(n)


def as_var(n):
    """Return n if it is an SqlBlock, otherwise wrap it in a Var (validating it with check_word)."""
    if isinstance(n, SqlBlock):
        return n
    else:
        return Var(n)


class Func(SqlBlock):
    """A function call ``func(arg1, arg2, ...)``; func must be a plain word unless allow_unsafe is set."""

    __slots__ = 'v1', 'op', 'func', 'v2'
    allow_unsafe = False

    def __init__(self, func, *args):
        self.func = func
        if not self.allow_unsafe:
            check_word(func)
        super().__init__(args, op=Operator.func)

    def operate(self, op: Operator, v2=None):
        # a call is already complete, so any further operator always wraps it
        return SqlBlock(self, op=op, v2=v2)

    def render(self):
        yield RawDangerous(self.func + '(')
        yield from yield_sep(self.v1)
        yield RawDangerous(')')


class LeftOp(Func):
    """A prefix operator (a LeftOperator stored as func) applied to a single operand."""

    allow_unsafe = True

    def render(self):
        yield RawDangerous(self.func.value)
        yield from self._bracket(self.v1[0])


class Not(Func):
    """``not(v)``; inverting it again returns the original v."""

    def __init__(self, v):
        super().__init__('not', v)

    def __invert__(self):
        return self.v1[0]


class Var(SqlBlock):
    """An SqlBlock whose v1 is inlined as a validated VarLiteral (e.g. a column or table name)."""

    def __init__(self, v1, *, op: Operator = None, v2=None):
        super().__init__(VarLiteral(v1), op=op, v2=v2)


def select_fields(arg, *args):
    return JoinComponent([as_var(v) for v in (arg,) + args])


class Empty(SqlBlock):
    """An SqlBlock with an empty v1, e.g. to start an expression from an operator."""

    def __init__(self, *, op: Operator = None, v2=None):
        super().__init__(VarLiteral(''), op=op, v2=v2)


S = SqlBlock
V = Var

# ---- /buildpg/main.py ----
# NOTE(review): from here the dump is truncated/garbled: the Renderer definition is cut off mid-signature and
# runs straight into the tail of setup.py (the rest of main.py, version.py, setup.cfg and most of setup.py are
# missing). The surviving text is kept verbatim below, commented out, rather than guessed at:
# 1 | import re 2 | from functools import partial 3 | 4 | from .components import BuildError, Component, ComponentError, RawDangerous 5 | 6 | __all__ = ('Renderer', 'render') 7 | 8 | 9 | class Renderer: 10 | __slots__ = 'regex', 'sep' 11 | 12 | def __init__(self, regex=r'(?=3.6', 48 | zip_safe=True, 49 | ) 50 |
# ---- /tests/__init__.py ----
https://raw.githubusercontent.com/samuelcolvin/buildpg/e2a16abea5c7607b53c501dbae74a5765ba66e15/tests/__init__.py -------------------------------------------------------------------------------- /tests/check_tag.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import sys 4 | 5 | from buildpg.version import VERSION 6 | 7 | git_tag = os.getenv('TRAVIS_TAG') 8 | if git_tag: 9 | if git_tag.lower().lstrip('v') != str(VERSION).lower(): 10 | print('✖ "TRAVIS_TAG" environment variable does not match package version: "%s" vs. "%s"' % (git_tag, VERSION)) 11 | sys.exit(1) 12 | else: 13 | print('✓ "TRAVIS_TAG" environment variable matches package version: "%s" vs. "%s"' % (git_tag, VERSION)) 14 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import pytest 4 | 5 | from buildpg import asyncpg 6 | 7 | DB_NAME = 'buildpg_test' 8 | 9 | # some simple data to use in queries 10 | POPULATE_DB = """ 11 | DROP SCHEMA public CASCADE; 12 | CREATE SCHEMA public; 13 | 14 | CREATE TABLE companies ( 15 | id SERIAL PRIMARY KEY, 16 | name VARCHAR(255) NOT NULL UNIQUE, 17 | description TEXT 18 | ); 19 | 20 | CREATE TABLE users ( 21 | id SERIAL PRIMARY KEY, 22 | company INT NOT NULL REFERENCES companies ON DELETE CASCADE, 23 | first_name VARCHAR(255), 24 | last_name VARCHAR(255), 25 | value INT, 26 | created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP 27 | ); 28 | 29 | INSERT INTO companies (name, description) VALUES 30 | ('foobar', 'this is some test data'), 31 | ('egg plant', 'this is some more test data'); 32 | 33 | WITH values_ (company_name_, first_name_, last_name_, value_, created_) AS (VALUES 34 | ('foobar', 'Franks', 'spencer', 44, date '2020-01-28'), 35 | ('foobar', 'Fred', 'blogs', -10, date '2021-01-28'), 36 | ('egg plant', 'Joe', null, 1000, 
date '2022-01-28') 37 | ) 38 | INSERT INTO users (company, first_name, last_name, value, created) 39 | SELECT c.id, first_name_, last_name_, value_, created_ FROM values_ 40 | JOIN companies AS c ON company_name_=c.name; 41 | """ 42 | 43 | 44 | async def _reset_db(): 45 | conn = await asyncpg.connect('postgresql://postgres@localhost') 46 | try: 47 | if not await conn.fetchval('SELECT 1 from pg_database WHERE datname=$1;', DB_NAME): 48 | await conn.execute(f'CREATE DATABASE {DB_NAME};') 49 | finally: 50 | await conn.close() 51 | conn = await asyncpg.connect(f'postgresql://postgres@localhost/{DB_NAME}') 52 | try: 53 | await conn.execute(POPULATE_DB) 54 | finally: 55 | await conn.close() 56 | 57 | 58 | @pytest.fixture(scope='session') 59 | def db(): 60 | asyncio.run(_reset_db()) 61 | 62 | 63 | @pytest.fixture 64 | async def conn(db): 65 | conn = await asyncpg.connect_b(f'postgresql://postgres@localhost/{DB_NAME}') 66 | tr = conn.transaction() 67 | await tr.start() 68 | 69 | yield conn 70 | 71 | await tr.rollback() 72 | await conn.close() 73 | -------------------------------------------------------------------------------- /tests/requirements-linting.txt: -------------------------------------------------------------------------------- 1 | black==20.8b1 2 | flake8==3.9.2 3 | flake8-quotes==3.2.0 4 | isort==5.9.2 5 | pycodestyle==2.7.0 6 | pyflakes==2.3.1 7 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | asyncpg==0.23.0 2 | coverage==5.5 3 | pygments==2.9.0 4 | pytest==6.2.4 5 | pytest-asyncio==0.15.1 6 | pytest-cov==2.12.1 7 | pytest-mock==3.6.1 8 | pytest-sugar==0.9.4 9 | sqlparse==0.4.2 10 | -------------------------------------------------------------------------------- /tests/test_clauses.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from buildpg import V, 
clauses, render 4 | 5 | 6 | @pytest.mark.parametrize( 7 | 'block,expected_query,expected_params', 8 | [ 9 | (lambda: clauses.Select(['foo', 'bar']), 'SELECT foo, bar', []), 10 | (lambda: clauses.Select(['foo', 'bar']), 'SELECT foo, bar', []), 11 | (lambda: clauses.Select(V('foo').comma(V('bar'))), 'SELECT foo, bar', []), 12 | (lambda: clauses.Select([V('foo').as_('x'), V('bar').as_('y')]), 'SELECT foo AS x, bar AS y', []), 13 | (lambda: clauses.From('foobar'), 'FROM foobar', []), 14 | (lambda: clauses.From('foo', 'bar'), 'FROM foo, bar', []), 15 | (lambda: clauses.From('foo', V('bar')), 'FROM foo, bar', []), 16 | (lambda: clauses.Join('foobar', V('x.id') == V('y.id')), 'JOIN foobar ON x.id = y.id', []), 17 | (lambda: clauses.CrossJoin('xxx'), 'CROSS JOIN xxx', []), 18 | (lambda: clauses.From('a') + clauses.Join('b') + clauses.Join('c'), 'FROM a\nJOIN b\nJOIN c', []), 19 | (lambda: clauses.OrderBy('apple', V('pear').desc()), 'ORDER BY apple, pear DESC', []), 20 | (lambda: clauses.Limit(20), 'LIMIT $1', [20]), 21 | (lambda: clauses.Offset(20), 'OFFSET $1', [20]), 22 | (lambda: clauses.Join('foobar', V('x.value') == 0), 'JOIN foobar ON x.value = $1', [0]), 23 | ], 24 | ) 25 | def test_simple_blocks(block, expected_query, expected_params): 26 | query, params = render(':v', v=block()) 27 | assert expected_query == query 28 | assert expected_params == params 29 | 30 | 31 | def test_where(): 32 | query, params = render(':v', v=clauses.Where((V('x') == 4) & (V('y').like('xxx')))) 33 | assert 'WHERE x = $1 AND y LIKE $2' == query 34 | assert [4, 'xxx'] == params 35 | 36 | 37 | @pytest.mark.parametrize('value,expected', [(0, 0), (False, False), ('', '')]) 38 | def test_falsey_values(value, expected): 39 | query, params = render('WHERE :a', a=V('a') == value) 40 | assert query == 'WHERE a = $1' 41 | assert params == [expected] 42 | assert type(params[0]) == type(expected) 43 | 44 | 45 | def test_args(): 46 | query, params = render('WHERE :a', a=V('a') == True) # noqa: E712 47 
| assert query == 'WHERE a = $1' 48 | assert params == [True] 49 | assert params[0] is True 50 | -------------------------------------------------------------------------------- /tests/test_logic.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from buildpg import Empty, Func, RawDangerous, S, SqlBlock, V, Var, funcs, render, select_fields 4 | 5 | args = 'template', 'var', 'expected_query', 'expected_params' 6 | TESTS = [ 7 | {'template': 'eq: :var', 'var': lambda: S(1) > 2, 'expected_query': 'eq: $1 > $2', 'expected_params': [1, 2]}, 8 | {'template': 'and: :var', 'var': lambda: S(1) & 2, 'expected_query': 'and: $1 AND $2', 'expected_params': [1, 2]}, 9 | {'template': 'or: :var', 'var': lambda: S(1) | 2, 'expected_query': 'or: $1 OR $2', 'expected_params': [1, 2]}, 10 | { 11 | 'template': 'inv: :var', 12 | 'var': lambda: ~(S(1) % 2), 13 | 'expected_query': 'inv: not($1 % $2)', 14 | 'expected_params': [1, 2], 15 | }, 16 | { 17 | 'template': 'double inv: :var', 18 | 'var': lambda: ~~(S(1) % 2), 19 | 'expected_query': 'double inv: $1 % $2', 20 | 'expected_params': [1, 2], 21 | }, 22 | {'template': 'eq: :var', 'var': lambda: Var('foo') > 2, 'expected_query': 'eq: foo > $1', 'expected_params': [2]}, 23 | { 24 | 'template': 'chain: :var', 25 | 'var': lambda: V('x') + 4 + 2 + 1, 26 | 'expected_query': 'chain: x + $1 + $2 + $3', 27 | 'expected_params': [4, 2, 1], 28 | }, 29 | { 30 | 'template': 'complex: :var', 31 | 'var': lambda: (Var('foo') > 2) & ((SqlBlock('x') / 7 > 3) | (Var('another') ** 4 == 42)), 32 | 'expected_query': 'complex: foo > $1 AND ($2 / $3 > $4 OR another ^ $5 = $6)', 33 | 'expected_params': [2, 'x', 7, 3, 4, 42], 34 | }, 35 | { 36 | 'template': 'complex AND OR: :var', 37 | # as above but using the AND and OR functions for clear usage 38 | 'var': lambda: funcs.AND(V('foo') > 2, funcs.OR(S('x') / 7 > 3, V('another') ** 4 == 42)), 39 | 'expected_query': 'complex AND OR: foo > $1 AND ($2 / $3 > 
$4 OR another ^ $5 = $6)', 40 | 'expected_params': [2, 'x', 7, 3, 4, 42], 41 | }, 42 | { 43 | 'template': 'inv chain: :var', 44 | 'var': lambda: ~(V('x') == 2) | V('y'), 45 | 'expected_query': 'inv chain: not(x = $1) OR y', 46 | 'expected_params': [2], 47 | }, 48 | { 49 | 'template': 'func: :var', 50 | 'var': lambda: funcs.upper(12), 51 | 'expected_query': 'func: upper($1)', 52 | 'expected_params': [12], 53 | }, 54 | { 55 | 'template': 'func args: :var', 56 | 'var': lambda: funcs.left(V('x').cat('xx'), V('y') + 4 + 5), 57 | 'expected_query': 'func args: left(x || $1, y + $2 + $3)', 58 | 'expected_params': ['xx', 4, 5], 59 | }, 60 | { 61 | 'template': 'abs neg: :var', 62 | 'var': lambda: funcs.abs(-S(4)), 63 | 'expected_query': 'abs neg: @ -$1', 64 | 'expected_params': [4], 65 | }, 66 | { 67 | 'template': 'abs brackets: :var', 68 | 'var': lambda: funcs.abs(S(4) + S(5)), 69 | 'expected_query': 'abs brackets: @ ($1 + $2)', 70 | 'expected_params': [4, 5], 71 | }, 72 | ] 73 | 74 | 75 | @pytest.mark.parametrize(','.join(args), [[t[a] for a in args] for t in TESTS], ids=[t['template'] for t in TESTS]) 76 | def test_render(template, var, expected_query, expected_params): 77 | query, params = render(template, var=var()) 78 | assert expected_query == query 79 | assert expected_params == params 80 | 81 | 82 | @pytest.mark.parametrize( 83 | 'block,expected_query', 84 | [ 85 | (lambda: SqlBlock(1) == 2, '$1 = $2'), 86 | (lambda: SqlBlock(1) != 2, '$1 != $2'), 87 | (lambda: SqlBlock(1) < 2, '$1 < $2'), 88 | (lambda: SqlBlock(1) <= 2, '$1 <= $2'), 89 | (lambda: SqlBlock(1) > 2, '$1 > $2'), 90 | (lambda: SqlBlock(1) >= 2, '$1 >= $2'), 91 | (lambda: SqlBlock(1) + 2, '$1 + $2'), 92 | (lambda: SqlBlock(1) - 2, '$1 - $2'), 93 | (lambda: SqlBlock(1) * 2, '$1 * $2'), 94 | (lambda: SqlBlock(1) / 2, '$1 / $2'), 95 | (lambda: SqlBlock(1) % 2, '$1 % $2'), 96 | (lambda: SqlBlock(1) ** 2, '$1 ^ $2'), 97 | (lambda: V('x').contains(V('y')), 'x @> y'), 98 | (lambda: V('x').overlap(V('y')), 'x && 
y'), 99 | (lambda: V('x').contained_by(V('y')), 'x <@ y'), 100 | (lambda: V('x').like('y'), 'x LIKE $1'), 101 | (lambda: V('x').ilike('y'), 'x ILIKE $1'), 102 | (lambda: V('x').cat('y'), 'x || $1'), 103 | (lambda: V('x').comma(V('y')), 'x, y'), 104 | (lambda: V('a').asc(), 'a ASC'), 105 | (lambda: V('a').desc(), 'a DESC'), 106 | (lambda: V('a').nulls_first(), 'a NULLS FIRST'), 107 | (lambda: V('a').nulls_last(), 'a NULLS LAST'), 108 | (lambda: V('a').desc().nulls_first(), 'a DESC NULLS FIRST'), 109 | (lambda: ~V('a'), 'not(a)'), 110 | (lambda: -V('a'), '-a'), 111 | (lambda: V('x').operate(RawDangerous(' foobar '), V('y')), 'x foobar y'), 112 | (lambda: funcs.sqrt(4), '|/ $1'), 113 | (lambda: funcs.abs(4), '@ $1'), 114 | (lambda: funcs.factorial(4), '$1!'), 115 | (lambda: funcs.count('*'), 'COUNT(*)'), 116 | (lambda: funcs.count(V('*')), 'COUNT(*)'), 117 | (lambda: funcs.count('*').as_('foobar'), 'COUNT(*) AS foobar'), 118 | (lambda: funcs.upper('a'), 'upper($1)'), 119 | (lambda: funcs.lower('a'), 'lower($1)'), 120 | (lambda: funcs.lower('a') > 4, 'lower($1) > $2'), 121 | (lambda: funcs.length('a'), 'length($1)'), 122 | (lambda: funcs.position('a', 'b'), 'position($1 in $2)'), 123 | (lambda: funcs.substring('a', 'b'), 'substring($1 from $2)'), 124 | (lambda: funcs.substring('x', 2, 3), 'substring($1 from $2 for $3)'), 125 | (lambda: funcs.extract(V('epoch').from_(V('foo.bar'))).cast('int'), 'extract(epoch from foo.bar)::int'), 126 | (lambda: funcs.AND('a', 'b', 'c'), '$1 AND $2 AND $3'), 127 | (lambda: funcs.AND('a', 'b', V('c') | V('d')), '$1 AND $2 AND (c OR d)'), 128 | (lambda: funcs.OR('a', 'b', V('c') & V('d')), '$1 OR $2 OR c AND d'), 129 | (lambda: funcs.comma_sep('a', 'b', V('c') | V('d')), '$1, $2, c OR d'), 130 | (lambda: funcs.comma_sep(V('first_name'), 123), 'first_name, $1'), 131 | (lambda: Func('foobar', V('x'), V('y')), 'foobar(x, y)'), 132 | (lambda: Func('foobar', funcs.comma_sep('x', 'y')), 'foobar($1, $2)'), 133 | (lambda: Empty() & (V('foo') == 
4), ' AND foo = $1'), 134 | (lambda: Empty() & (V('bar') == 4), ' AND bar = $1'), 135 | (lambda: V('epoch').at_time_zone('MST'), 'epoch AT TIME ZONE $1'), 136 | (lambda: S('2032-02-16 19:38:40-08').at_time_zone('MST'), '$1 AT TIME ZONE $2'), 137 | (lambda: V('foo').matches(V('bar')), 'foo @@ bar'), 138 | (lambda: V('foo').is_(V('bar')), 'foo is bar'), 139 | (lambda: V('foo').is_not(V('bar')), 'foo is not bar'), 140 | ( 141 | lambda: funcs.to_tsvector('fat cats ate rats').matches(funcs.to_tsquery('cat & rat')), 142 | 'to_tsvector($1) @@ to_tsquery($2)', 143 | ), 144 | (lambda: funcs.now(), 'now()'), 145 | (lambda: select_fields('foo', 'bar'), 'foo, bar'), 146 | (lambda: select_fields('foo', S(RawDangerous("'raw text'"))), "foo, 'raw text'"), 147 | (lambda: funcs.NOT(V('a').in_([1, 2])), 'not(a in $1)'), 148 | (lambda: ~V('a').in_([1, 2]), 'not(a in $1)'), 149 | (lambda: funcs.NOT(V('a') == funcs.any([1, 2])), 'not(a = any($1))'), 150 | ], 151 | ) 152 | def test_simple_blocks(block, expected_query): 153 | query, _ = render(':v', v=block()) 154 | assert expected_query == query 155 | 156 | 157 | @pytest.mark.parametrize( 158 | 'block,s', 159 | [ 160 | (lambda: SqlBlock(1) == 2, ''), 161 | (lambda: SqlBlock('x') != 1, ''), 162 | (lambda: SqlBlock('y') < 2, ''), 163 | ], 164 | ) 165 | def test_block_repr(block, s): 166 | assert s == repr(block()) 167 | 168 | 169 | def test_recursion(): 170 | # without JoinComponent this would cause a recursion error, see #35 171 | values = [f'foo_{i}' for i in range(5000)] 172 | query, query_args = render(':t', t=funcs.comma_sep(*values)) 173 | 174 | assert query.startswith('$1, $2') 175 | assert query.endswith('$4998, $4999, $5000') 176 | assert len(query_args) == 5000 177 | -------------------------------------------------------------------------------- /tests/test_main.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from buildpg import BuildError, MultipleValues, Renderer, 
SetValues, UnsafeError, Values, VarLiteral, render 4 | 5 | args = 'template', 'ctx', 'expected_query', 'expected_params' 6 | TESTS = [ 7 | {'template': 'simple: :v', 'ctx': lambda: dict(v=1), 'expected_query': 'simple: $1', 'expected_params': [1]}, 8 | { 9 | 'template': 'multiple: :a :c :b', 10 | 'ctx': lambda: dict(a=1, b=2, c=3), 11 | 'expected_query': 'multiple: $1 $2 $3', 12 | 'expected_params': [1, 3, 2], 13 | }, 14 | { 15 | 'template': 'values: :a', 16 | 'ctx': lambda: dict(a=Values(1, 2, 3)), 17 | 'expected_query': 'values: ($1, $2, $3)', 18 | 'expected_params': [1, 2, 3], 19 | }, 20 | { 21 | 'template': 'named values: :a :a__names', 22 | 'ctx': lambda: dict(a=Values(foo=1, bar=2)), 23 | 'expected_query': 'named values: ($1, $2) foo, bar', 24 | 'expected_params': [1, 2], 25 | }, 26 | { 27 | 'template': 'multiple values: :a', 28 | 'ctx': lambda: dict(a=MultipleValues(Values(3, 2, 1), Values('i', 'j', 'k'))), 29 | 'expected_query': 'multiple values: ($1, $2, $3), ($4, $5, $6)', 30 | 'expected_params': [3, 2, 1, 'i', 'j', 'k'], 31 | }, 32 | { 33 | 'template': 'set values: :a', 34 | 'ctx': lambda: dict(a=SetValues(foo=123, bar='b', c='this is a value')), 35 | 'expected_query': 'set values: foo = $1, bar = $2, c = $3', 36 | 'expected_params': [123, 'b', 'this is a value'], 37 | }, 38 | { 39 | 'template': 'numeric: :1000 :v1000', 40 | 'ctx': lambda: {'1000': 1, 'v1000': 2}, 41 | 'expected_query': 'numeric: :1000 $1', 42 | 'expected_params': [2], 43 | }, 44 | { 45 | 'template': 'select * from a where x=:a and y=:b', 46 | 'ctx': lambda: {'a': 123, 'b': 456}, 47 | 'expected_query': 'select * from a where x=$1 and y=$2', 48 | 'expected_params': [123, 456], 49 | }, 50 | { 51 | 'template': 'select * from a where x=:a and y=:b', 52 | 'ctx': lambda: {'a': 123, 'b': 123}, 53 | 'expected_query': 'select * from a where x=$1 and y=$2', 54 | 'expected_params': [123, 123], 55 | }, 56 | { 57 | 'template': 'select * from a where x=:a and y=:a', 58 | 'ctx': lambda: {'a': 123, 'b': 
123}, 59 | 'expected_query': 'select * from a where x=$1 and y=$1', 60 | 'expected_params': [123], 61 | }, 62 | { 63 | 'template': 'values: :a, again: :a', 64 | 'ctx': lambda: dict(a=Values(1, 2, 3)), 65 | 'expected_query': 'values: ($1, $2, $3), again: ($1, $2, $3)', 66 | 'expected_params': [1, 2, 3], 67 | }, 68 | { 69 | 'template': 'values: :a, different: :b', 70 | 'ctx': lambda: dict(a=Values(1, 2, 3), b=Values(1, 2, 3)), 71 | 'expected_query': 'values: ($1, $2, $3), different: ($4, $5, $6)', 72 | 'expected_params': [1, 2, 3, 1, 2, 3], 73 | }, 74 | ] 75 | 76 | 77 | @pytest.mark.parametrize(','.join(args), [[t[a] for a in args] for t in TESTS], ids=[t['template'] for t in TESTS]) 78 | def test_render(template, ctx, expected_query, expected_params): 79 | ctx = ctx() 80 | query, params = render(template, **ctx) 81 | assert expected_query == query 82 | assert expected_params == params 83 | 84 | 85 | @pytest.mark.parametrize( 86 | 'component,s', [(lambda: MultipleValues(Values(3, 2, 1), Values(1, 2, 3)), '')] 87 | ) 88 | def test_component_repr(component, s): 89 | assert s == repr(component()) 90 | 91 | 92 | def test_different_regex(): 93 | customer_render = Renderer(r'{{ ?([\w.]+) ?}}', '.') 94 | q, params = customer_render('testing: {{ a}} {{c }} {{b}}', a=1, b=2, c=3, x=Values(foo=1, bar=2)) 95 | assert q == 'testing: $1 $2 $3' 96 | assert params == [1, 3, 2] 97 | 98 | q, params = customer_render('testing: {{ x }} {{ x.names }}', x=Values(foo=1, bar=2)) 99 | assert q == 'testing: ($1, $2) foo, bar' 100 | assert params == [1, 2] 101 | 102 | 103 | @pytest.mark.parametrize( 104 | 'query,ctx,msg', 105 | [ 106 | (':a :b', dict(a=1), 'variable "b" not found in context'), 107 | (':a__names', dict(a=Values(1, 2)), '"a": "names" are not available for nameless values'), 108 | ( 109 | ':a__missing', 110 | dict(a=1), 111 | '"a": error building content, AttributeError: \'int\' object has no attribute \'render_missing\'', 112 | ), 113 | ], 114 | ) 115 | def test_errors(query, 
ctx, msg): 116 | with pytest.raises(BuildError) as exc_info: 117 | render(query, **ctx) 118 | assert msg in str(exc_info.value) 119 | 120 | 121 | @pytest.mark.parametrize( 122 | 'func,exc', 123 | [ 124 | (lambda: VarLiteral('"y"'), UnsafeError), 125 | (lambda: Values(**{';foobar': 'xx'}), UnsafeError), 126 | (lambda: VarLiteral(1.1), TypeError), 127 | (lambda: Values(1, 2, c=3), ValueError), 128 | (lambda: MultipleValues(Values(1), 42), ValueError), 129 | (lambda: MultipleValues(Values(a=1, b=2), Values(b=1, a=2)), ValueError), 130 | (lambda: MultipleValues(Values(1), Values(1, 2)), ValueError), 131 | ], 132 | ) 133 | def test_other_errors(func, exc): 134 | with pytest.raises(exc): 135 | func() 136 | -------------------------------------------------------------------------------- /tests/test_query.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import pytest 4 | 5 | from buildpg import MultipleValues, S, V, Values, asyncpg, funcs, render, select_fields 6 | 7 | from .conftest import DB_NAME 8 | 9 | pytestmark = pytest.mark.asyncio 10 | 11 | 12 | async def test_manual_select(conn): 13 | query, params = render('SELECT :v FROM users ORDER BY first_name', v=select_fields('first_name', 'last_name')) 14 | v = await conn.fetch(query, *params) 15 | assert [ 16 | {'first_name': 'Franks', 'last_name': 'spencer'}, 17 | {'first_name': 'Fred', 'last_name': 'blogs'}, 18 | {'first_name': 'Joe', 'last_name': None}, 19 | ] == [dict(r) for r in v] 20 | 21 | 22 | async def test_manual_logic(conn): 23 | query, params = render( 24 | 'SELECT :select FROM users WHERE :where ORDER BY :order_by', 25 | select=select_fields('first_name'), 26 | where=V('created') > datetime(2021, 1, 1), 27 | order_by=V('last_name'), 28 | ) 29 | v = await conn.fetch(query, *params) 30 | assert ['Fred', 'Joe'] == [r[0] for r in v] 31 | 32 | 33 | async def test_logic(conn, capsys): 34 | a, b, c, d = await conn.fetchrow_b( 35 | 'SELECT 
:a, :b, :c, :d::int', a=funcs.cast(5, 'int') * 5, b=funcs.sqrt(676), c=S('a').cat('b'), d=987_654 36 | ) 37 | assert a == 25 38 | assert b == 26 39 | assert c == 'ab' 40 | assert d == 987_654 41 | assert 'SELECT' not in capsys.readouterr().out 42 | 43 | 44 | async def test_multiple_values_execute(conn): 45 | co_id = await conn.fetchval('SELECT id FROM companies') 46 | v = MultipleValues( 47 | Values(company=co_id, first_name='anne', last_name=None, value=3, created=datetime(2032, 1, 1)), 48 | Values(company=co_id, first_name='ben', last_name='better', value=5, created=datetime(2032, 1, 2)), 49 | Values(company=co_id, first_name='charlie', last_name='cat', value=5, created=V('DEFAULT')), 50 | ) 51 | await conn.execute_b('INSERT INTO users (:values__names) VALUES :values', values=v) 52 | assert 6 == await conn.fetchval('SELECT COUNT(*) FROM users') 53 | 54 | 55 | async def test_values_executemany(conn): 56 | co_id = await conn.fetchval('SELECT id FROM companies') 57 | v = [ 58 | Values(company=co_id, first_name='anne', last_name=None, value=3, created=datetime(2032, 1, 1)), 59 | Values(company=co_id, first_name='ben', last_name='better', value=5, created=datetime(2032, 1, 2)), 60 | Values(company=co_id, first_name='charlie', last_name='cat', value=5, created=datetime(2032, 1, 3)), 61 | ] 62 | await conn.executemany_b('INSERT INTO users (:values__names) VALUES :values', v) 63 | assert 6 == await conn.fetchval('SELECT COUNT(*) FROM users') 64 | 65 | 66 | async def test_values_executemany_default(conn): 67 | co_id = await conn.fetchval('SELECT id FROM companies') 68 | v = [ 69 | Values(company=co_id, first_name='anne', last_name=None, value=V('DEFAULT')), 70 | Values(company=co_id, first_name='ben', last_name='better', value=V('DEFAULT')), 71 | Values(company=co_id, first_name='charlie', last_name='cat', value=V('DEFAULT')), 72 | ] 73 | await conn.executemany_b('INSERT INTO users (:values__names) VALUES :values', v) 74 | assert 6 == await conn.fetchval('SELECT COUNT(*) 
FROM users') 75 | 76 | 77 | async def test_position(conn): 78 | result = await conn.fetch_b('SELECT :a', a=funcs.position('xx', 'testing xx more')) 79 | assert [dict(r) for r in result] == [{'position': 9}] 80 | 81 | 82 | async def test_substring(conn): 83 | a = await conn.fetchval_b('SELECT :a', a=funcs.substring('Samuel', '...$')) 84 | assert a == 'uel' 85 | 86 | 87 | async def test_substring_for(conn): 88 | a = await conn.fetchval_b('SELECT :a', a=funcs.substring('Samuel', funcs.cast(2, 'int'), funcs.cast(3, 'int'))) 89 | assert a == 'amu' 90 | 91 | 92 | async def test_cursor(conn): 93 | results = [] 94 | q = 'SELECT :s FROM users ORDER BY :o' 95 | async for r in conn.cursor_b(q, s=select_fields('first_name'), o=V('first_name').asc()): 96 | results.append(tuple(r)) 97 | if len(results) == 2: 98 | break 99 | assert results == [('Franks',), ('Fred',)] 100 | 101 | 102 | async def test_pool(): 103 | async with asyncpg.create_pool_b(f'postgresql://postgres@localhost/{DB_NAME}') as pool: 104 | async with pool.acquire() as conn: 105 | v = await conn.fetchval_b('SELECT :v FROM users ORDER BY id LIMIT 1', v=funcs.right(V('first_name'), 3)) 106 | 107 | assert v == 'red' 108 | 109 | 110 | async def test_log(conn, capsys): 111 | a = await conn.fetchval_b('SELECT :a', print_=True, __timeout=12, a=funcs.cast(5, 'int') * 5) 112 | assert a == 25 113 | 114 | assert 'SELECT' in capsys.readouterr().out 115 | 116 | 117 | async def test_print(conn, capsys): 118 | query, params = conn.print_b('SELECT :a', a=funcs.cast(5, 'int') * 5) 119 | assert query == 'SELECT $1::int * $2' 120 | assert params == [5, 5] 121 | 122 | assert 'SELECT' in capsys.readouterr().out 123 | 124 | 125 | async def test_log_callable(conn): 126 | logged_message = None 127 | 128 | def foobar(arg): 129 | nonlocal logged_message 130 | logged_message = arg 131 | 132 | a = await conn.fetchval_b('SELECT :a', print_=foobar, __timeout=12, a=funcs.cast(5, 'int') * 5) 133 | assert a == 25 134 | 135 | assert 'SELECT' in 
logged_message 136 | 137 | 138 | async def test_coloured_callable(conn): 139 | logged_message = None 140 | 141 | def foobar(arg): 142 | nonlocal logged_message 143 | logged_message = arg 144 | 145 | foobar.formatted = True 146 | a = await conn.fetchval_b('SELECT :a', print_=foobar, __timeout=12, a=funcs.cast(5, 'int') * 5) 147 | assert a == 25 148 | 149 | assert 'SELECT' in logged_message 150 | assert '\x1b[' in logged_message 151 | 152 | 153 | async def test_pool_fetch(): 154 | async with asyncpg.create_pool_b(f'postgresql://postgres@localhost/{DB_NAME}') as pool: 155 | v = await pool.fetchval_b('SELECT :v FROM users ORDER BY id LIMIT 1', v=funcs.right(V('first_name'), 3)) 156 | 157 | assert v == 'red' 158 | --------------------------------------------------------------------------------