├── .gitignore ├── .gitmodules ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── psqlparse ├── __init__.py ├── exceptions.py ├── nodes │ ├── __init__.py │ ├── nodes.py │ ├── parsenodes.py │ ├── primnodes.py │ ├── utils.py │ └── value.py ├── parser.pyx └── pg_query.pxd ├── requirements.txt ├── setup.py └── test ├── __init__.py └── test_parse.py /.gitignore: -------------------------------------------------------------------------------- 1 | .cache 2 | .coveragerc 3 | libpg_query-9.5-latest.zip 4 | libpg_query-9.5-latest/ 5 | *.so 6 | psqlparse/parser.c 7 | build/ 8 | dist/ 9 | .idea/ 10 | *.egg-info 11 | __pycache__/ 12 | *.pyc 13 | .virtualenv/ 14 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "libpg_query"] 2 | path = libpg_query 3 | url = https://github.com/lfittl/libpg_query.git 4 | branch = 9.5-latest 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "2.7" 4 | - "3.5" 5 | install: 6 | - pip install -r requirements.txt 7 | - pip install 'pep8==1.7.0' 8 | - USE_CYTHON=1 python setup.py build_ext --inplace -f install 9 | # actual test run can be added later 10 | script: 11 | - python -m unittest discover --verbose 12 | - python -m pep8 . 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Aldo Culquicondor 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. 
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
13 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md LICENSE 2 | include psqlparse/pg_query.pxd psqlparse/parser.pyx 3 | include libpg_query/LICENSE libpg_query/Makefile libpg_query/pg_query.h 4 | recursive-include libpg_query/src *.c *.h 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | psqlparse 2 | ========= 3 | [![Build Status](https://travis-ci.org/alculquicondor/psqlparse.svg?branch=master)](https://travis-ci.org/alculquicondor/psqlparse) 4 | 5 | **This project is not maintained anymore**. If you would like to maintain it, send me a DM on Twitter @alculquicondor. 6 | 7 | This Python module uses [libpg\_query](https://github.com/lfittl/libpg_query) to parse SQL 8 | queries and return the internal PostgreSQL parse tree. 9 | 10 | Installation 11 | ------------ 12 | 13 | ```shell 14 | pip install psqlparse 15 | ``` 16 | 17 | Usage 18 | ----- 19 | 20 | ```python 21 | import psqlparse 22 | statements = psqlparse.parse('SELECT * from mytable') 23 | used_tables = statements[0].tables()  # {'mytable'} 24 | ``` 25 | 26 | `tables()` is only available from version 1.0rc1 onwards. 27 | 28 | Development 29 | ----------- 30 | 31 | 0. Update dependencies 32 | 33 | ```shell 34 | git submodule update --init 35 | ``` 36 | 37 | 1. Install requirements: 38 | 39 | ```shell 40 | pip install -r requirements.txt 41 | ``` 42 | 43 | 2. Build Cython extension 44 | 45 | ```shell 46 | USE_CYTHON=1 python setup.py build_ext --inplace 47 | ``` 48 | 49 | 3. Perform changes 50 | 51 | 4.
Run tests 52 | 53 | ```shell 54 | pytest 55 | ``` 56 | 57 | Maintainers 58 | ------------ 59 | 60 | - [Aldo Culquicondor](https://github.com/alculquicondor/) 61 | - [Kevin Zúñiga](https://github.com/kevinzg/) 62 | -------------------------------------------------------------------------------- /psqlparse/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser import parse, parse_dict 2 | -------------------------------------------------------------------------------- /psqlparse/exceptions.py: -------------------------------------------------------------------------------- 1 | import six 2 | 3 | 4 | @six.python_2_unicode_compatible 5 | class PSqlParseError(Exception): 6 | 7 | def __init__(self, message, lineno, cursorpos): 8 | self.message = message 9 | self.lineno = lineno 10 | self.cursorpos = cursorpos 11 | 12 | def __str__(self): 13 | return self.message 14 | -------------------------------------------------------------------------------- /psqlparse/nodes/__init__.py: -------------------------------------------------------------------------------- 1 | from .parsenodes import (SelectStmt, InsertStmt, UpdateStmt, DeleteStmt, 2 | WithClause, CommonTableExpr, RangeSubselect, 3 | ResTarget, ColumnRef, FuncCall, AStar, AExpr, AConst, 4 | TypeCast, TypeName, SortBy, WindowDef, LockingClause, 5 | RangeFunction, AArrayExpr, AIndices, MultiAssignRef) 6 | from .primnodes import (RangeVar, JoinExpr, Alias, IntoClause, BoolExpr, 7 | SubLink, SetToDefault, CaseExpr, CaseWhen, NullTest, 8 | BooleanTest, RowExpr) 9 | from .value import Integer, String, Float 10 | -------------------------------------------------------------------------------- /psqlparse/nodes/nodes.py: -------------------------------------------------------------------------------- 1 | import six 2 | 3 | 4 | class Node(object): 5 | 6 | def tables(self): 7 | """ 8 | Generic method that does a depth-first search on the node attributes. 
9 | 10 | Child classes should override this method for better performance. 11 | """ 12 | _tables = set() 13 | 14 | for attr in six.itervalues(self.__dict__): 15 | if isinstance(attr, list): 16 | for item in attr: 17 | if isinstance(item, Node): 18 | _tables |= item.tables() 19 | elif isinstance(attr, Node): 20 | _tables |= attr.tables() 21 | 22 | return _tables 23 | -------------------------------------------------------------------------------- /psqlparse/nodes/parsenodes.py: -------------------------------------------------------------------------------- 1 | import six 2 | 3 | from .utils import build_from_item 4 | from .nodes import Node 5 | 6 | 7 | class Statement(Node): 8 | 9 | statement = '' 10 | 11 | def __str__(self): 12 | return self.statement 13 | 14 | 15 | class SelectStmt(Statement): 16 | 17 | statement = 'SELECT' 18 | 19 | def __init__(self, obj): 20 | self.distinct_clause = build_from_item(obj, 'distinctClause') 21 | self.into_clause = build_from_item(obj, 'intoClause') 22 | self.target_list = build_from_item(obj, 'targetList') 23 | self.from_clause = build_from_item(obj, 'fromClause') 24 | self.where_clause = build_from_item(obj, 'whereClause') 25 | self.group_clause = build_from_item(obj, 'groupClause') 26 | self.having_clause = build_from_item(obj, 'havingClause') 27 | self.window_clause = build_from_item(obj, 'windowClause') 28 | 29 | self.values_lists = build_from_item(obj, 'valuesLists') 30 | 31 | self.sort_clause = build_from_item(obj, 'sortClause') 32 | self.limit_offset = build_from_item(obj, 'limitOffset') 33 | self.limit_count = build_from_item(obj, 'limitCount') 34 | self.locking_clause = build_from_item(obj, 'lockingClause') 35 | self.with_clause = build_from_item(obj, 'withClause') 36 | 37 | self.op = obj.get('op') 38 | self.all = obj.get('all') 39 | self.larg = build_from_item(obj, 'larg') 40 | self.rarg = build_from_item(obj, 'rarg') 41 | 42 | def tables(self): 43 | _tables = set() 44 | if self.target_list: 45 | for item in 
self.target_list: 46 | _tables |= item.tables() 47 | if self.from_clause: 48 | for item in self.from_clause: 49 | _tables |= item.tables() 50 | if self.where_clause: 51 | _tables |= self.where_clause.tables() 52 | if self.with_clause: 53 | _tables |= self.with_clause.tables() 54 | 55 | if self.larg: 56 | _tables |= self.larg.tables() 57 | if self.rarg: 58 | _tables |= self.rarg.tables() 59 | 60 | return _tables 61 | 62 | 63 | class InsertStmt(Statement): 64 | 65 | statement = 'INSERT INTO' 66 | 67 | def __init__(self, obj): 68 | self.relation = build_from_item(obj, 'relation') 69 | self.cols = build_from_item(obj, 'cols') 70 | self.select_stmt = build_from_item(obj, 'selectStmt') 71 | self.on_conflict_clause = build_from_item(obj, 'onConflictClause') 72 | self.returning_list = build_from_item(obj, 'returningList') 73 | self.with_clause = build_from_item(obj, 'withClause') 74 | 75 | def tables(self): 76 | _tables = self.relation.tables() | self.select_stmt.tables() 77 | 78 | if self.with_clause: 79 | _tables |= self.with_clause.tables() 80 | 81 | return _tables 82 | 83 | 84 | class UpdateStmt(Statement): 85 | 86 | statement = 'UPDATE' 87 | 88 | def __init__(self, obj): 89 | self.relation = build_from_item(obj, 'relation') 90 | self.target_list = build_from_item(obj, 'targetList') 91 | self.where_clause = build_from_item(obj, 'whereClause') 92 | self.from_clause = build_from_item(obj, 'fromClause') 93 | self.returning_list = build_from_item(obj, 'returningList') 94 | self.with_clause = build_from_item(obj, 'withClause') 95 | 96 | def tables(self): 97 | _tables = self.relation.tables() 98 | 99 | if self.where_clause: 100 | _tables |= self.where_clause.tables() 101 | if self.from_clause: 102 | for item in self.from_clause: 103 | _tables |= item.tables() 104 | if self.with_clause: 105 | _tables |= self.with_clause.tables() 106 | 107 | return _tables 108 | 109 | 110 | class DeleteStmt(Statement): 111 | 112 | statement = 'DELETE FROM' 113 | 114 | def __init__(self, obj): 
115 | self.relation = build_from_item(obj, 'relation') 116 | self.using_clause = build_from_item(obj, 'usingClause') 117 | self.where_clause = build_from_item(obj, 'whereClause') 118 | self.returning_list = build_from_item(obj, 'returningList') 119 | self.with_clause = build_from_item(obj, 'withClause') 120 | 121 | def tables(self): 122 | _tables = self.relation.tables() 123 | 124 | if self.using_clause: 125 | for item in self.using_clause: 126 | _tables |= item.tables() 127 | if self.where_clause: 128 | _tables |= self.where_clause.tables() 129 | if self.with_clause: 130 | _tables |= self.with_clause.tables() 131 | 132 | return _tables 133 | 134 | 135 | class WithClause(Node): 136 | 137 | def __init__(self, obj): 138 | self.ctes = build_from_item(obj, 'ctes') 139 | self.recursive = obj.get('recursive') 140 | self.location = obj.get('location') 141 | 142 | def __repr__(self): 143 | return '' % len(self.ctes) 144 | 145 | def __str__(self): 146 | s = 'WITH ' 147 | if self.recursive: 148 | s += 'RECURSIVE ' 149 | s += ', '.join( 150 | ['%s AS (%s)' % (name, query) 151 | for name, query in six.iteritems(self.ctes)]) 152 | return s 153 | 154 | def tables(self): 155 | _tables = set() 156 | for item in self.ctes: 157 | _tables |= item.tables() 158 | return _tables 159 | 160 | 161 | class CommonTableExpr(Node): 162 | 163 | def __init__(self, obj): 164 | self.ctename = obj.get('ctename') 165 | self.aliascolnames = build_from_item(obj, 'aliascolnames') 166 | self.ctequery = build_from_item(obj, 'ctequery') 167 | self.location = obj.get('location') 168 | self.cterecursive = obj.get('cterecursive') 169 | self.cterefcount = obj.get('cterefcount') 170 | self.ctecolnames = build_from_item(obj, 'ctecolnames') 171 | self.ctecoltypes = build_from_item(obj, 'ctecoltypes') 172 | self.ctecoltypmods = build_from_item(obj, 'ctecoltypmods') 173 | self.ctecolcollations = build_from_item(obj, 'ctecolcollations') 174 | 175 | def tables(self): 176 | return self.ctequery.tables() 177 | 178 | 
179 | class RangeSubselect(Node): 180 | 181 | def __init__(self, obj): 182 | self.lateral = obj.get('lateral') 183 | self.subquery = build_from_item(obj, 'subquery') 184 | self.alias = build_from_item(obj, 'alias') 185 | 186 | def tables(self): 187 | return self.subquery.tables() 188 | 189 | 190 | class ResTarget(Node): 191 | """ 192 | Result target. 193 | 194 | In a SELECT target list, 'name' is the column label from an 195 | 'AS ColumnLabel' clause, or NULL if there was none, and 'val' is the 196 | value expression itself. The 'indirection' field is not used. 197 | 198 | INSERT uses ResTarget in its target-column-names list. Here, 'name' is 199 | the name of the destination column, 'indirection' stores any subscripts 200 | attached to the destination, and 'val' is not used. 201 | 202 | In an UPDATE target list, 'name' is the name of the destination column, 203 | 'indirection' stores any subscripts attached to the destination, and 204 | 'val' is the expression to assign. 205 | """ 206 | 207 | def __init__(self, obj): 208 | self.name = obj.get('name') 209 | self.indirection = build_from_item(obj, 'indirection') 210 | self.val = build_from_item(obj, 'val') 211 | self.location = obj.get('location') 212 | 213 | def tables(self): 214 | _tables = set() 215 | if isinstance(self.val, list): 216 | for item in self.val: 217 | _tables |= item.tables() 218 | elif isinstance(self.val, Node): 219 | _tables |= self.val.tables() 220 | 221 | return _tables 222 | 223 | 224 | class ColumnRef(Node): 225 | 226 | def __init__(self, obj): 227 | self.fields = build_from_item(obj, 'fields') 228 | self.location = obj.get('location') 229 | 230 | def tables(self): 231 | return set() 232 | 233 | 234 | class FuncCall(Node): 235 | 236 | def __init__(self, obj): 237 | self.funcname = build_from_item(obj, 'funcname') 238 | self.args = build_from_item(obj, 'args') 239 | self.agg_order = build_from_item(obj, 'agg_order') 240 | self.agg_filter = build_from_item(obj, 'agg_filter') 241 | 
self.agg_within_group = obj.get('agg_within_group') 242 | self.agg_star = obj.get('agg_star') 243 | self.agg_distinct = obj.get('agg_distinct') 244 | self.func_variadic = obj.get('func_variadic') 245 | self.over = build_from_item(obj, 'over') 246 | self.location = obj.get('location') 247 | 248 | def tables(self): 249 | _tables = set() 250 | if self.args: 251 | for item in self.args: 252 | _tables |= item.tables() 253 | return _tables 254 | 255 | 256 | class AStar(Node): 257 | 258 | def __init__(self, obj): 259 | pass 260 | 261 | def tables(self): 262 | return set() 263 | 264 | 265 | class AExpr(Node): 266 | 267 | def __init__(self, obj): 268 | self.kind = obj.get('kind') 269 | self.name = build_from_item(obj, 'name') 270 | self.lexpr = build_from_item(obj, 'lexpr') 271 | self.rexpr = build_from_item(obj, 'rexpr') 272 | self.location = obj.get('location') 273 | 274 | def tables(self): 275 | _tables = set() 276 | 277 | if isinstance(self.lexpr, list): 278 | for item in self.lexpr: 279 | _tables |= item.tables() 280 | elif isinstance(self.lexpr, Node): 281 | _tables |= self.lexpr.tables() 282 | 283 | if isinstance(self.rexpr, list): 284 | for item in self.rexpr: 285 | _tables |= item.tables() 286 | elif isinstance(self.rexpr, Node): 287 | _tables |= self.rexpr.tables() 288 | 289 | return _tables 290 | 291 | 292 | class AConst(Node): 293 | 294 | def __init__(self, obj): 295 | self.val = build_from_item(obj, 'val') 296 | self.location = obj.get('location') 297 | 298 | def tables(self): 299 | return set() 300 | 301 | 302 | class TypeCast(Node): 303 | 304 | def __init__(self, obj): 305 | self.arg = build_from_item(obj, 'arg') 306 | self.type_name = build_from_item(obj, 'typeName') 307 | self.location = obj.get('location') 308 | 309 | 310 | class TypeName(Node): 311 | 312 | def __init__(self, obj): 313 | self.names = build_from_item(obj, 'names') 314 | self.type_oid = obj.get('typeOid') 315 | self.setof = obj.get('setof') 316 | self.pct_type = obj.get('pct_type') 317 | 
self.typmods = build_from_item(obj, 'typmods') 318 | self.typemod = obj.get('typemod') 319 | self.array_bounds = build_from_item(obj, 'arrayBounds') 320 | self.location = obj.get('location') 321 | 322 | 323 | class SortBy(Node): 324 | 325 | def __init__(self, obj): 326 | self.node = build_from_item(obj, 'node') 327 | self.sortby_dir = obj.get('sortby_dir') 328 | self.sortby_nulls = obj.get('sortby_nulls') 329 | self.use_op = build_from_item(obj, 'useOp') 330 | self.location = obj.get('location') 331 | 332 | 333 | class WindowDef(Node): 334 | 335 | def __init__(self, obj): 336 | self.name = obj.get('name') 337 | self.refname = obj.get('refname') 338 | self.partition_clause = build_from_item(obj, 'partitionClause') 339 | self.order_clause = build_from_item(obj, 'orderClause') 340 | self.frame_options = obj.get('frameOptions') 341 | self.start_offset = build_from_item(obj, 'startOffset') 342 | self.end_offset = build_from_item(obj, 'endOffset') 343 | self.location = obj.get('location') 344 | 345 | 346 | class LockingClause(Node): 347 | 348 | def __init__(self, obj): 349 | self.locked_rels = build_from_item(obj, 'lockedRels') 350 | self.strength = build_from_item(obj, 'strength') 351 | self.wait_policy = obj.get('waitPolicy') 352 | 353 | 354 | class RangeFunction(Node): 355 | 356 | def __init__(self, obj): 357 | self.lateral = obj.get('lateral') 358 | self.ordinality = obj.get('ordinality') 359 | self.is_rowsfrom = obj.get('is_rowsfrom') 360 | self.functions = build_from_item(obj, 'functions') 361 | self.alias = build_from_item(obj, 'alias') 362 | self.coldeflist = build_from_item(obj, 'coldeflist') 363 | 364 | 365 | class AArrayExpr(Node): 366 | 367 | def __init__(self, obj): 368 | self.elements = build_from_item(obj, 'elements') 369 | self.location = obj.get('location') 370 | 371 | 372 | class AIndices(Node): 373 | def __init__(self, obj): 374 | self.lidx = build_from_item(obj, 'lidx') 375 | self.uidx = build_from_item(obj, 'uidx') 376 | 377 | 378 | class 
MultiAssignRef(Node): 379 | 380 | def __init__(self, obj): 381 | self.source = build_from_item(obj, 'source') 382 | self.colno = obj.get('colno') 383 | self.ncolumns = obj.get('ncolumns') 384 | -------------------------------------------------------------------------------- /psqlparse/nodes/primnodes.py: -------------------------------------------------------------------------------- 1 | from .utils import build_from_item 2 | from .nodes import Node 3 | 4 | 5 | class RangeVar(Node): 6 | 7 | def __init__(self, obj): 8 | """ 9 | Range variable, used in FROM clauses 10 | 11 | Also used to represent table names in utility statements; there, 12 | the alias field is not used, and inhOpt shows whether to apply the 13 | operation recursively to child tables. 14 | """ 15 | 16 | self.catalogname = obj.get('catalogname') 17 | self.schemaname = obj.get('schemaname') 18 | self.relname = obj.get('relname') 19 | self.inh_opt = obj.get('inhOpt') 20 | self.relpersistence = obj.get('relpersistence') 21 | self.alias = build_from_item(obj, 'alias') 22 | self.location = obj['location'] 23 | 24 | def __repr__(self): 25 | return '' % self.relname 26 | 27 | def __str__(self): 28 | return '%s' % self.relname 29 | 30 | def tables(self): 31 | components = [ 32 | getattr(self, name) for name in ('schemaname', 'relname') 33 | if getattr(self, name, None) is not None 34 | ] 35 | return {'.'.join(components)} 36 | 37 | 38 | class JoinExpr(Node): 39 | """ 40 | For SQL JOIN expressions 41 | """ 42 | 43 | def __init__(self, obj): 44 | self.jointype = obj.get('jointype') 45 | self.is_natural = obj.get('isNatural') 46 | self.larg = build_from_item(obj, 'larg') 47 | self.rarg = build_from_item(obj, 'rarg') 48 | self.using_clause = build_from_item(obj, 'usingClause') 49 | self.quals = build_from_item(obj, 'quals') 50 | self.alias = build_from_item(obj, 'alias') 51 | 52 | def __repr__(self): 53 | return '' % self.jointype 54 | 55 | def __str__(self): 56 | return '%s JOIN %s ON ()' % (self.larg, 
self.rarg) 57 | 58 | def tables(self): 59 | return self.larg.tables() | self.rarg.tables() 60 | 61 | 62 | class Alias(Node): 63 | 64 | def __init__(self, obj): 65 | self.aliasname = obj.get('aliasname') 66 | self.colnames = build_from_item(obj, 'colnames') 67 | 68 | def tables(self): 69 | return set() 70 | 71 | 72 | class IntoClause(Node): 73 | 74 | def __init__(self, obj): 75 | self._obj = obj 76 | 77 | 78 | class Expr(Node): 79 | """ 80 | Expr - generic superclass for executable-expression nodes 81 | """ 82 | 83 | 84 | class BoolExpr(Expr): 85 | 86 | def __init__(self, obj): 87 | self.boolop = obj.get('boolop') 88 | self.args = build_from_item(obj, 'args') 89 | self.location = obj.get('location') 90 | 91 | def tables(self): 92 | _tables = set() 93 | for item in self.args: 94 | _tables |= item.tables() 95 | return _tables 96 | 97 | 98 | class SubLink(Expr): 99 | 100 | def __init__(self, obj): 101 | self.sub_link_type = obj.get('subLinkType') 102 | self.sub_link_id = obj.get('subLinkId') 103 | self.testexpr = build_from_item(obj, 'testexpr') 104 | self.oper_name = build_from_item(obj, 'operName') 105 | self.subselect = build_from_item(obj, 'subselect') 106 | self.location = obj.get('location') 107 | 108 | def tables(self): 109 | return self.subselect.tables() 110 | 111 | 112 | class SetToDefault(Node): 113 | 114 | def __init__(self, obj): 115 | self.type_id = obj.get('typeId') 116 | self.type_mod = obj.get('typeMod') 117 | self.collation = obj.get('collation') 118 | self.location = obj.get('location') 119 | 120 | 121 | class CaseExpr(Node): 122 | 123 | def __init__(self, obj): 124 | self.casetype = obj.get('casetype') 125 | self.casecollid = obj.get('casecollid') 126 | self.arg = build_from_item(obj, 'arg') 127 | self.args = build_from_item(obj, 'args') 128 | self.defresult = build_from_item(obj, 'defresult') 129 | self.location = obj.get('location') 130 | 131 | 132 | class CaseWhen(Node): 133 | 134 | def __init__(self, obj): 135 | self.expr = build_from_item(obj, 
'expr') 136 | self.result = build_from_item(obj, 'result') 137 | self.location = obj.get('location') 138 | 139 | 140 | class NullTest(Node): 141 | 142 | def __init__(self, obj): 143 | self.arg = build_from_item(obj, 'arg') 144 | self.nulltesttype = obj.get('nulltesttype') 145 | self.argisrow = obj.get('argisrow') 146 | self.location = obj.get('location') 147 | 148 | 149 | class BooleanTest(Node): 150 | 151 | def __init__(self, obj): 152 | self.arg = build_from_item(obj, 'arg') 153 | self.booltesttype = obj.get('booltesttype') 154 | self.location = obj.get('location') 155 | 156 | 157 | class RowExpr(Node): 158 | 159 | def __init__(self, obj): 160 | self.args = build_from_item(obj, 'args') 161 | self.colnames = build_from_item(obj, 'colnames') 162 | self.location = obj['location'] 163 | self.row_format = obj.get('row_format') 164 | self.type_id = obj.get('typeId') 165 | -------------------------------------------------------------------------------- /psqlparse/nodes/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | from six import next, iterkeys, itervalues 4 | 5 | 6 | module = importlib.import_module('psqlparse.nodes') 7 | 8 | 9 | def get_node_class(class_name): 10 | class_name = class_name.replace('_', '') 11 | return getattr(module, class_name, None) 12 | 13 | 14 | def build_from_obj(obj): 15 | if isinstance(obj, list): 16 | return [build_from_obj(item) for item in obj] 17 | if not isinstance(obj, dict): 18 | return obj 19 | _class = get_node_class(next(iterkeys(obj))) 20 | return _class(next(itervalues(obj))) if _class else obj 21 | 22 | 23 | def build_from_item(obj, key): 24 | return build_from_obj(obj[key]) if key in obj else None 25 | -------------------------------------------------------------------------------- /psqlparse/nodes/value.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Value(object): 5 | __metaclass__ = 
abc.ABCMeta 6 | 7 | def __str__(self): 8 | return str(self.val) 9 | 10 | @abc.abstractproperty 11 | def val(self): 12 | pass 13 | 14 | 15 | class Integer(Value): 16 | 17 | def __init__(self, obj): 18 | self.ival = obj.get('ival') 19 | 20 | def __int__(self): 21 | return self.ival 22 | 23 | @property 24 | def val(self): 25 | return self.ival 26 | 27 | 28 | class String(Value): 29 | 30 | def __init__(self, obj): 31 | self.str = obj.get('str') 32 | 33 | @property 34 | def val(self): 35 | return self.str 36 | 37 | 38 | class Float(Value): 39 | 40 | def __init__(self, obj): 41 | self.str = obj.get('str') 42 | self.fval = float(self.str) 43 | 44 | def __float__(self): 45 | return self.fval 46 | 47 | @property 48 | def val(self): 49 | return self.fval 50 | -------------------------------------------------------------------------------- /psqlparse/parser.pyx: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import six 4 | 5 | from .nodes.utils import build_from_obj 6 | from .exceptions import PSqlParseError 7 | from .pg_query cimport (pg_query_parse, pg_query_free_parse_result, 8 | PgQueryParseResult) 9 | 10 | 11 | def parse_dict(query): 12 | cdef bytes encoded_query 13 | cdef PgQueryParseResult result 14 | 15 | if isinstance(query, six.text_type): 16 | encoded_query = query.encode('utf8') 17 | elif isinstance(query, six.binary_type): 18 | encoded_query = query 19 | else: 20 | encoded_query = six.text_type(query).encode('utf8') 21 | 22 | result = pg_query_parse(encoded_query) 23 | if result.error: 24 | error = PSqlParseError(result.error.message.decode('utf8'), 25 | result.error.lineno, result.error.cursorpos) 26 | pg_query_free_parse_result(result) 27 | raise error 28 | 29 | statement_dicts = json.loads(result.parse_tree.decode('utf8'), 30 | strict=False) 31 | pg_query_free_parse_result(result) 32 | return statement_dicts 33 | 34 | 35 | def parse(query): 36 | return [build_from_obj(obj) for obj in parse_dict(query)] 
37 | -------------------------------------------------------------------------------- /psqlparse/pg_query.pxd: -------------------------------------------------------------------------------- 1 | cdef extern from "pg_query.h": 2 | 3 | ctypedef struct PgQueryError: 4 | char *message 5 | int lineno 6 | int cursorpos 7 | 8 | ctypedef struct PgQueryParseResult: 9 | char *parse_tree 10 | PgQueryError *error 11 | 12 | PgQueryParseResult pg_query_parse(const char* input) 13 | 14 | void pg_query_free_parse_result(PgQueryParseResult result); 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.28.5 2 | six==1.10.0 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension 2 | from setuptools.command.build_ext import build_ext 3 | import os.path 4 | import subprocess 5 | import sys 6 | 7 | 8 | libpg_query = os.path.join('.', 'libpg_query') 9 | 10 | 11 | class PSqlParseBuildExt(build_ext): 12 | 13 | def run(self): 14 | return_code = subprocess.call(['make', '-C', libpg_query, 'build']) 15 | if return_code: 16 | sys.stderr.write(''' 17 | An error occurred during extension building. 18 | Make sure you have bison and flex installed on your system. 
19 | ''') 20 | sys.exit(return_code) 21 | build_ext.run(self) 22 | 23 | 24 | USE_CYTHON = bool(os.environ.get('USE_CYTHON')) 25 | 26 | ext = '.pyx' if USE_CYTHON else '.c' 27 | 28 | libraries = ['pg_query'] 29 | 30 | extensions = [ 31 | Extension('psqlparse.parser', 32 | ['psqlparse/parser' + ext], 33 | libraries=libraries, 34 | include_dirs=[libpg_query], 35 | library_dirs=[libpg_query]) 36 | ] 37 | 38 | if USE_CYTHON: 39 | from Cython.Build import cythonize 40 | extensions = cythonize(extensions) 41 | 42 | setup(name='psqlparse', 43 | version='1.0-rc7', 44 | url='https://github.com/alculquicondor/psqlparse', 45 | author='Aldo Culquicondor', 46 | author_email='aldo@amigocloud.com', 47 | description='Parse SQL queries using the PostgreSQL query parser', 48 | install_requires=['six'], 49 | license='BSD', 50 | cmdclass={'build_ext': PSqlParseBuildExt}, 51 | packages=['psqlparse', 'psqlparse.nodes'], 52 | ext_modules=extensions) 53 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alculquicondor/psqlparse/72decb854590f70cbc54c549cd033df4a256b68b/test/__init__.py -------------------------------------------------------------------------------- /test/test_parse.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from psqlparse import parse 4 | from psqlparse.exceptions import PSqlParseError 5 | from psqlparse import nodes 6 | 7 | 8 | class SelectQueriesTest(unittest.TestCase): 9 | 10 | def test_select_all_no_where(self): 11 | query = "SELECT * FROM my_table" 12 | stmt = parse(query).pop() 13 | self.assertIsInstance(stmt, nodes.SelectStmt) 14 | 15 | self.assertIsNone(stmt.where_clause) 16 | 17 | self.assertEqual(len(stmt.target_list), 1) 18 | target = stmt.target_list[0] 19 | self.assertIsInstance(target, nodes.ResTarget) 20 | 
self.assertIsInstance(target.val.fields[0], nodes.AStar) 21 | 22 | self.assertEqual(len(stmt.from_clause), 1) 23 | from_clause = stmt.from_clause[0] 24 | self.assertIsInstance(from_clause, nodes.RangeVar) 25 | self.assertEqual(from_clause.relname, 'my_table') 26 | 27 | def test_select_one_column_where(self): 28 | query = ("SELECT col1 FROM my_table " 29 | "WHERE my_attribute LIKE 'condition'" 30 | " AND other = 5.6 AND extra > 5") 31 | stmt = parse(query).pop() 32 | self.assertIsInstance(stmt, nodes.SelectStmt) 33 | 34 | self.assertEqual(len(stmt.target_list), 1) 35 | target = stmt.target_list[0] 36 | self.assertIsInstance(target, nodes.ResTarget) 37 | self.assertEqual(target.val.fields[0].val, 'col1') 38 | 39 | self.assertIsInstance(stmt.where_clause, nodes.BoolExpr) 40 | self.assertEqual(len(stmt.where_clause.args), 3) 41 | 42 | one = stmt.where_clause.args[0] 43 | self.assertIsInstance(one, nodes.AExpr) 44 | self.assertEqual(str(one.lexpr.fields[0].val), 'my_attribute') 45 | self.assertEqual(str(one.name[0]), '~~') 46 | self.assertEqual(str(one.rexpr.val), 'condition') 47 | 48 | two = stmt.where_clause.args[1] 49 | self.assertIsInstance(two, nodes.AExpr) 50 | self.assertEqual(str(two.lexpr.fields[0].val), 'other') 51 | self.assertEqual(str(two.name[0]), '=') 52 | self.assertEqual(float(two.rexpr.val), 5.6) 53 | 54 | three = stmt.where_clause.args[2] 55 | self.assertIsInstance(three, nodes.AExpr) 56 | self.assertEqual(str(three.lexpr.fields[0].val), 'extra') 57 | self.assertEqual(str(three.name[0]), '>') 58 | self.assertEqual(int(three.rexpr.val), 5) 59 | 60 | def test_select_join(self): 61 | query = "SELECT * FROM table_one JOIN table_two USING (common)" 62 | stmt = parse(query).pop() 63 | self.assertIsInstance(stmt, nodes.SelectStmt) 64 | 65 | self.assertEqual(len(stmt.from_clause), 1) 66 | join_expr = stmt.from_clause[0] 67 | self.assertIsInstance(join_expr, nodes.JoinExpr) 68 | 69 | self.assertIsInstance(join_expr.larg, nodes.RangeVar) 70 | 
self.assertEqual(join_expr.larg.relname, 'table_one') 71 | 72 | self.assertIsInstance(join_expr.rarg, nodes.RangeVar) 73 | self.assertEqual(join_expr.rarg.relname, 'table_two') 74 | 75 | def test_select_with(self): 76 | query = ("WITH fake_table AS (SELECT SUM(countable) AS total " 77 | "FROM inner_table GROUP BY groupable) " 78 | "SELECT * FROM fake_table") 79 | stmt = parse(query).pop() 80 | self.assertIsInstance(stmt, nodes.SelectStmt) 81 | 82 | self.assertIsInstance(stmt.with_clause, nodes.WithClause) 83 | self.assertEqual(len(stmt.with_clause.ctes), 1) 84 | self.assertIsNone(stmt.with_clause.recursive) 85 | 86 | with_query = stmt.with_clause.ctes[0] 87 | self.assertEqual(with_query.ctename, 'fake_table') 88 | self.assertIsInstance(with_query.ctequery, nodes.SelectStmt) 89 | 90 | def test_select_subquery(self): 91 | query = "SELECT * FROM (SELECT something FROM dataset) AS other" 92 | stmt = parse(query).pop() 93 | 94 | self.assertIsInstance(stmt, nodes.SelectStmt) 95 | self.assertEqual(len(stmt.from_clause), 1) 96 | sub_query = stmt.from_clause[0] 97 | self.assertIsInstance(sub_query, nodes.RangeSubselect) 98 | self.assertIsInstance(sub_query.alias, nodes.Alias) 99 | self.assertEqual(sub_query.alias.aliasname, 'other') 100 | self.assertIsInstance(sub_query.subquery, nodes.SelectStmt) 101 | 102 | self.assertEqual(len(stmt.target_list), 1) 103 | 104 | def test_select_from_values(self): 105 | query = ("SELECT * FROM " 106 | "(VALUES (1, 'one'), (2, 'two')) AS t (num, letter)") 107 | stmt = parse(query).pop() 108 | self.assertIsInstance(stmt, nodes.SelectStmt) 109 | 110 | self.assertEqual(len(stmt.from_clause), 1) 111 | self.assertIsInstance(stmt.from_clause[0], nodes.RangeSubselect) 112 | 113 | alias = stmt.from_clause[0].alias 114 | self.assertIsInstance(alias, nodes.Alias) 115 | self.assertEqual(alias.aliasname, 't') 116 | self.assertEqual(['num', 'letter'], [str(v) for v in alias.colnames]) 117 | 118 | subquery = stmt.from_clause[0].subquery 119 | 
self.assertIsInstance(subquery, nodes.SelectStmt) 120 | self.assertEqual(len(subquery.values_lists), 2) 121 | self.assertEqual([1, 'one'], [v.val.val 122 | for v in subquery.values_lists[0]]) 123 | 124 | def test_select_case(self): 125 | query = ("SELECT a, CASE WHEN a=1 THEN 'one' WHEN a=2 THEN 'two'" 126 | " ELSE 'other' END FROM test") 127 | stmt = parse(query).pop() 128 | self.assertIsInstance(stmt, nodes.SelectStmt) 129 | self.assertEqual(len(stmt.target_list), 2) 130 | target = stmt.target_list[1] 131 | self.assertIsInstance(target.val, nodes.CaseExpr) 132 | self.assertIsNone(target.val.arg) 133 | self.assertEqual(len(target.val.args), 2) 134 | self.assertIsInstance(target.val.args[0], nodes.CaseWhen) 135 | self.assertIsInstance(target.val.args[0].expr, nodes.AExpr) 136 | self.assertIsInstance(target.val.args[0].result, nodes.AConst) 137 | self.assertIsInstance(target.val.defresult, nodes.AConst) 138 | 139 | query = ("SELECT CASE a.value WHEN 0 THEN '1' ELSE '2' END FROM " 140 | "sometable a") 141 | stmt = parse(query).pop() 142 | self.assertIsInstance(stmt, nodes.SelectStmt) 143 | self.assertEqual(len(stmt.target_list), 1) 144 | target = stmt.target_list[0] 145 | self.assertIsInstance(target.val, nodes.CaseExpr) 146 | self.assertIsInstance(target.val.arg, nodes.ColumnRef) 147 | 148 | def test_select_union(self): 149 | query = "SELECT * FROM table_one UNION select * FROM table_two" 150 | stmt = parse(query).pop() 151 | self.assertIsInstance(stmt, nodes.SelectStmt) 152 | 153 | self.assertIsInstance(stmt.larg, nodes.SelectStmt) 154 | self.assertIsInstance(stmt.rarg, nodes.SelectStmt) 155 | 156 | def test_function_call(self): 157 | query = "SELECT * FROM my_table WHERE ST_Intersects(geo1, geo2)" 158 | stmt = parse(query).pop() 159 | self.assertIsInstance(stmt, nodes.SelectStmt) 160 | 161 | func_call = stmt.where_clause 162 | self.assertIsInstance(func_call, nodes.FuncCall) 163 | self.assertEqual(str(func_call.funcname[0]), 'st_intersects') 164 | 
self.assertEqual(str(func_call.args[0].fields[0]), 'geo1') 165 | self.assertEqual(str(func_call.args[1].fields[0]), 'geo2') 166 | 167 | def test_select_type_cast(self): 168 | query = "SELECT 'accbf276-705b-11e7-b8e4-0242ac120002'::UUID" 169 | stmt = parse(query).pop() 170 | self.assertIsInstance(stmt, nodes.SelectStmt) 171 | self.assertEqual(len(stmt.target_list), 1) 172 | target = stmt.target_list[0] 173 | self.assertIsInstance(target, nodes.ResTarget) 174 | self.assertIsInstance(target.val, nodes.TypeCast) 175 | self.assertIsInstance(target.val.arg, nodes.AConst) 176 | self.assertEqual(target.val.arg.val.val, 177 | 'accbf276-705b-11e7-b8e4-0242ac120002') 178 | self.assertIsInstance(target.val.type_name, nodes.TypeName) 179 | self.assertEqual(target.val.type_name.names[0].val, "uuid") 180 | 181 | def test_select_order_by(self): 182 | query = "SELECT * FROM my_table ORDER BY field DESC NULLS FIRST" 183 | stmt = parse(query).pop() 184 | self.assertIsInstance(stmt, nodes.SelectStmt) 185 | self.assertEqual(len(stmt.sort_clause), 1) 186 | self.assertIsInstance(stmt.sort_clause[0], nodes.SortBy) 187 | self.assertIsInstance(stmt.sort_clause[0].node, nodes.ColumnRef) 188 | self.assertEqual(stmt.sort_clause[0].sortby_dir, 2) 189 | self.assertEqual(stmt.sort_clause[0].sortby_nulls, 1) 190 | 191 | query = "SELECT * FROM my_table ORDER BY field USING @>" 192 | stmt = parse(query).pop() 193 | self.assertIsInstance(stmt.sort_clause[0], nodes.SortBy) 194 | self.assertIsInstance(stmt.sort_clause[0].node, nodes.ColumnRef) 195 | self.assertEqual(len(stmt.sort_clause[0].use_op), 1) 196 | self.assertIsInstance(stmt.sort_clause[0].use_op[0], nodes.String) 197 | self.assertEqual(stmt.sort_clause[0].use_op[0].val, '@>') 198 | 199 | def test_select_window(self): 200 | query = "SELECT salary, sum(salary) OVER () FROM empsalary" 201 | stmt = parse(query).pop() 202 | self.assertIsInstance(stmt, nodes.SelectStmt) 203 | self.assertEqual(len(stmt.target_list), 2) 204 | target = 
stmt.target_list[1] 205 | self.assertIsInstance(target.val, nodes.FuncCall) 206 | self.assertIsInstance(target.val.over, nodes.WindowDef) 207 | self.assertIsNone(target.val.over.order_clause) 208 | self.assertIsNone(target.val.over.partition_clause) 209 | 210 | query = ("SELECT salary, sum(salary) " 211 | "OVER (ORDER BY salary) " 212 | "FROM empsalary") 213 | stmt = parse(query).pop() 214 | self.assertIsInstance(stmt, nodes.SelectStmt) 215 | self.assertEqual(len(stmt.target_list), 2) 216 | target = stmt.target_list[1] 217 | self.assertIsInstance(target.val, nodes.FuncCall) 218 | self.assertIsInstance(target.val.over, nodes.WindowDef) 219 | self.assertEqual(len(target.val.over.order_clause), 1) 220 | self.assertIsInstance(target.val.over.order_clause[0], nodes.SortBy) 221 | self.assertIsNone(target.val.over.partition_clause) 222 | 223 | query = ("SELECT salary, avg(salary) " 224 | "OVER (PARTITION BY depname) " 225 | "FROM empsalary") 226 | stmt = parse(query).pop() 227 | self.assertIsInstance(stmt, nodes.SelectStmt) 228 | self.assertEqual(len(stmt.target_list), 2) 229 | target = stmt.target_list[1] 230 | self.assertIsInstance(target.val, nodes.FuncCall) 231 | self.assertIsInstance(target.val.over, nodes.WindowDef) 232 | self.assertIsNone(target.val.over.order_clause) 233 | self.assertEqual(len(target.val.over.partition_clause), 1) 234 | self.assertIsInstance(target.val.over.partition_clause[0], 235 | nodes.ColumnRef) 236 | 237 | def test_select_locks(self): 238 | query = "SELECT m.* FROM mytable m FOR UPDATE" 239 | stmt = parse(query).pop() 240 | self.assertIsInstance(stmt, nodes.SelectStmt) 241 | self.assertEqual(len(stmt.locking_clause), 1) 242 | self.assertIsInstance(stmt.locking_clause[0], nodes.LockingClause) 243 | self.assertEqual(stmt.locking_clause[0].strength, 4) 244 | 245 | query = "SELECT m.* FROM mytable m FOR SHARE of m nowait" 246 | stmt = parse(query).pop() 247 | self.assertIsInstance(stmt, nodes.SelectStmt) 248 | 
self.assertEqual(len(stmt.locking_clause), 1) 249 | self.assertIsInstance(stmt.locking_clause[0], nodes.LockingClause) 250 | self.assertEqual(stmt.locking_clause[0].strength, 2) 251 | self.assertEqual(len(stmt.locking_clause[0].locked_rels), 1) 252 | self.assertIsInstance(stmt.locking_clause[0].locked_rels[0], 253 | nodes.RangeVar) 254 | self.assertEqual(stmt.locking_clause[0].locked_rels[0].relname, 'm') 255 | self.assertEqual(stmt.locking_clause[0].wait_policy, 2) 256 | 257 | def test_select_is_null(self): 258 | query = "SELECT m.* FROM mytable m WHERE m.foo IS NULL" 259 | stmt = parse(query).pop() 260 | self.assertIsInstance(stmt, nodes.SelectStmt) 261 | self.assertIsInstance(stmt.where_clause, nodes.NullTest) 262 | self.assertEqual(stmt.where_clause.nulltesttype, 0) 263 | 264 | query = "SELECT m.* FROM mytable m WHERE m.foo IS NOT NULL" 265 | stmt = parse(query).pop() 266 | self.assertIsInstance(stmt, nodes.SelectStmt) 267 | self.assertIsInstance(stmt.where_clause, nodes.NullTest) 268 | self.assertEqual(stmt.where_clause.nulltesttype, 1) 269 | 270 | def test_select_is_true(self): 271 | query = "SELECT m.* FROM mytable m WHERE m.foo IS TRUE" 272 | stmt = parse(query).pop() 273 | self.assertIsInstance(stmt, nodes.SelectStmt) 274 | self.assertIsInstance(stmt.where_clause, nodes.BooleanTest) 275 | self.assertEqual(stmt.where_clause.booltesttype, 0) 276 | 277 | def test_select_range_function(self): 278 | query = ("SELECT m.name AS mname, pname " 279 | "FROM manufacturers m, LATERAL get_product_names(m.id) pname") 280 | stmt = parse(query).pop() 281 | self.assertIsInstance(stmt, nodes.SelectStmt) 282 | self.assertEqual(len(stmt.from_clause), 2) 283 | second = stmt.from_clause[1] 284 | self.assertIsInstance(second, nodes.RangeFunction) 285 | self.assertTrue(second.lateral) 286 | self.assertEqual(len(second.functions), 1) 287 | self.assertEqual(len(second.functions[0]), 2) 288 | self.assertIsInstance(second.functions[0][0], nodes.FuncCall) 289 | 290 | def 
test_select_array(self): 291 | query = ("SELECT * FROM unnest(ARRAY['a','b','c','d','e','f']) " 292 | "WITH ORDINALITY") 293 | stmt = parse(query).pop() 294 | self.assertIsInstance(stmt, nodes.SelectStmt) 295 | self.assertEqual(len(stmt.from_clause), 1) 296 | outer = stmt.from_clause[0] 297 | self.assertIsInstance(outer, nodes.RangeFunction) 298 | self.assertTrue(outer.ordinality) 299 | self.assertEqual(len(outer.functions), 1) 300 | inner = outer.functions[0][0] 301 | self.assertIsInstance(inner, nodes.FuncCall) 302 | self.assertEqual(len(inner.args), 1) 303 | self.assertIsInstance(inner.args[0], nodes.AArrayExpr) 304 | self.assertEqual(len(inner.args[0].elements), 6) 305 | 306 | def test_select_where_in_many(self): 307 | query = ( 308 | "SELECT * FROM my_table WHERE (a, b) in (('a', 'b'), ('c', 'd'))") 309 | stmt = parse(query).pop() 310 | self.assertEqual(2, len(stmt.where_clause.rexpr)) 311 | for node in stmt.where_clause.rexpr: 312 | self.assertIsInstance(node, nodes.RowExpr) 313 | 314 | 315 | class InsertQueriesTest(unittest.TestCase): 316 | 317 | def test_insert_no_where(self): 318 | query = "INSERT INTO my_table(id, name) VALUES(1, 'some')" 319 | stmt = parse(query).pop() 320 | 321 | self.assertIsInstance(stmt, nodes.InsertStmt) 322 | self.assertIsNone(stmt.returning_list) 323 | 324 | self.assertEqual(stmt.relation.relname, 'my_table') 325 | 326 | self.assertEqual(len(stmt.cols), 2) 327 | self.assertEqual(stmt.cols[0].name, 'id') 328 | self.assertEqual(stmt.cols[1].name, 'name') 329 | 330 | self.assertIsInstance(stmt.select_stmt, nodes.SelectStmt) 331 | self.assertEqual(len(stmt.select_stmt.values_lists), 1) 332 | 333 | def test_insert_select(self): 334 | query = "INSERT INTO my_table(id, name) SELECT 1, 'some'" 335 | stmt = parse(query).pop() 336 | 337 | self.assertIsInstance(stmt, nodes.InsertStmt) 338 | 339 | self.assertIsInstance(stmt.select_stmt, nodes.SelectStmt) 340 | targets = stmt.select_stmt.target_list 341 | self.assertEqual(len(targets), 2) 342 
| self.assertIsInstance(targets[0], nodes.ResTarget) 343 | self.assertEqual(int(targets[0].val.val), 1) 344 | self.assertIsInstance(targets[1], nodes.ResTarget) 345 | self.assertEqual(str(targets[1].val.val), 'some') 346 | 347 | def test_insert_returning(self): 348 | query = "INSERT INTO my_table(id) VALUES (5) RETURNING id, \"date\"" 349 | stmt = parse(query).pop() 350 | 351 | self.assertIsInstance(stmt, nodes.InsertStmt) 352 | self.assertEqual(len(stmt.returning_list), 2) 353 | self.assertIsInstance(stmt.returning_list[0], nodes.ResTarget) 354 | self.assertEqual(str(stmt.returning_list[0].val.fields[0]), 'id') 355 | self.assertIsInstance(stmt.returning_list[1], nodes.ResTarget) 356 | self.assertEqual(str(stmt.returning_list[1].val.fields[0]), 'date') 357 | 358 | 359 | class UpdateQueriesTest(unittest.TestCase): 360 | 361 | def test_update_to_default(self): 362 | query = "UPDATE my_table SET the_value = DEFAULT" 363 | stmt = parse(query).pop() 364 | 365 | self.assertIsInstance(stmt, nodes.UpdateStmt) 366 | self.assertEqual(len(stmt.target_list), 1) 367 | self.assertIsInstance(stmt.target_list[0], nodes.ResTarget) 368 | self.assertIsInstance(stmt.target_list[0].val, nodes.SetToDefault) 369 | 370 | def test_update_array(self): 371 | query = ("UPDATE tictactoe " 372 | "SET board[1:3][1:3] = " 373 | "'{{" "," "," "},{" "," "," "},{" "," "," "}}' " 374 | "WHERE game = 1") 375 | stmt = parse(query).pop() 376 | 377 | self.assertIsInstance(stmt, nodes.UpdateStmt) 378 | self.assertEqual(len(stmt.target_list), 1) 379 | self.assertIsInstance(stmt.target_list[0], nodes.ResTarget) 380 | indirection = stmt.target_list[0].indirection 381 | self.assertEqual(len(indirection), 2) 382 | self.assertIsInstance(indirection[0], nodes.AIndices) 383 | self.assertIsInstance(indirection[1], nodes.AIndices) 384 | self.assertIsInstance(indirection[0].lidx, nodes.AConst) 385 | self.assertIsInstance(indirection[0].uidx, nodes.AConst) 386 | 387 | def test_update_multi_assign(self): 388 | query = 
("UPDATE accounts " 389 | "SET (contact_first_name, contact_last_name) " 390 | "= (SELECT first_name, last_name FROM salesmen " 391 | "WHERE salesmen.id = accounts.sales_id)") 392 | stmt = parse(query).pop() 393 | 394 | self.assertIsInstance(stmt, nodes.UpdateStmt) 395 | self.assertEqual(len(stmt.target_list), 2) 396 | self.assertIsInstance(stmt.target_list[0], nodes.ResTarget) 397 | first = stmt.target_list[0] 398 | self.assertIsInstance(first, nodes.ResTarget) 399 | self.assertEqual(first.name, 'contact_first_name') 400 | self.assertIsInstance(first.val, nodes.MultiAssignRef) 401 | self.assertEqual(first.val.ncolumns, 2) 402 | self.assertEqual(first.val.colno, 1) 403 | self.assertIsInstance(first.val.source, nodes.SubLink) 404 | self.assertIsInstance(first.val.source.subselect, nodes.SelectStmt) 405 | 406 | 407 | class MultipleQueriesTest(unittest.TestCase): 408 | 409 | def test_has_insert_and_select_statement(self): 410 | query = ("INSERT INTO my_table(id) VALUES(1); " 411 | "SELECT * FROM my_table") 412 | stmts = parse(query) 413 | stmt_types = [type(stmt) for stmt in stmts] 414 | self.assertListEqual([nodes.InsertStmt, nodes.SelectStmt], stmt_types) 415 | 416 | def test_has_update_and_delete_statement(self): 417 | query = ("UPDATE my_table SET id = 5; " 418 | "DELETE FROM my_table") 419 | stmts = parse(query) 420 | stmt_types = [type(stmt) for stmt in stmts] 421 | self.assertListEqual([nodes.UpdateStmt, nodes.DeleteStmt], stmt_types) 422 | 423 | 424 | class WrongQueriesTest(unittest.TestCase): 425 | 426 | def test_syntax_error_select_statement(self): 427 | query = "SELECT * FRO my_table" 428 | try: 429 | parse(query) 430 | self.fail('Syntax error not generating an PSqlParseError') 431 | except PSqlParseError as e: 432 | self.assertEqual(e.cursorpos, 10) 433 | self.assertEqual(e.message, 'syntax error at or near "FRO"') 434 | 435 | def test_incomplete_insert_statement(self): 436 | query = "INSERT INTO my_table" 437 | try: 438 | parse(query) 439 | 
self.fail('Syntax error not generating an PSqlParseError') 440 | except PSqlParseError as e: 441 | self.assertEqual(e.cursorpos, 21) 442 | self.assertEqual(e.message, 'syntax error at end of input') 443 | 444 | def test_case_no_value(self): 445 | query = ("SELECT a, CASE WHEN a=1 THEN 'one' WHEN a=2 THEN " 446 | " ELSE 'other' END FROM test") 447 | try: 448 | parse(query) 449 | self.fail('Syntax error not generating an PSqlParseError') 450 | except PSqlParseError as e: 451 | self.assertEqual(e.cursorpos, 51) 452 | self.assertEqual(e.message, 'syntax error at or near "ELSE"') 453 | 454 | 455 | class TablesTest(unittest.TestCase): 456 | 457 | def test_simple_select(self): 458 | query = "SELECT * FROM table_one, table_two" 459 | stmt = parse(query).pop() 460 | self.assertEqual(stmt.tables(), {'table_one', 'table_two'}) 461 | 462 | def test_simple_select_using_schema_names(self): 463 | query = "SELECT * FROM table_one, public.table_one" 464 | stmt = parse(query).pop() 465 | self.assertEqual(stmt.tables(), {'table_one', 'public.table_one'}) 466 | 467 | def test_select_with(self): 468 | query = ("WITH fake_table AS (SELECT * FROM inner_table) " 469 | "SELECT * FROM fake_table") 470 | stmt = parse(query).pop() 471 | self.assertEqual(stmt.tables(), {'inner_table', 'fake_table'}) 472 | 473 | def test_update_subquery(self): 474 | query = ("UPDATE dataset SET a = 5 WHERE " 475 | "id IN (SELECT * from table_one) OR" 476 | " age IN (select * from table_two)") 477 | stmt = parse(query).pop() 478 | self.assertEqual(stmt.tables(), 479 | {'table_one', 'table_two', 'dataset'}) 480 | 481 | def test_update_from(self): 482 | query = "UPDATE dataset SET a = 5 FROM extra WHERE b = c" 483 | stmt = parse(query).pop() 484 | self.assertEqual(stmt.tables(), {'dataset', 'extra'}) 485 | 486 | def test_join(self): 487 | query = ("SELECT * FROM table_one JOIN table_two USING (common_1)" 488 | " JOIN table_three USING (common_2)") 489 | stmt = parse(query).pop() 490 | 
self.assertEqual(stmt.tables(), 491 | {'table_one', 'table_two', 'table_three'}) 492 | 493 | def test_insert_select(self): 494 | query = "INSERT INTO table_one(id, name) SELECT * from table_two" 495 | stmt = parse(query).pop() 496 | self.assertEqual(stmt.tables(), {'table_one', 'table_two'}) 497 | 498 | def test_insert_with(self): 499 | query = ("WITH fake as (SELECT * FROM inner_table) " 500 | "INSERT INTO dataset SELECT * FROM fake") 501 | stmt = parse(query).pop() 502 | self.assertEqual(stmt.tables(), {'inner_table', 'fake', 'dataset'}) 503 | 504 | def test_delete(self): 505 | query = ("DELETE FROM dataset USING table_one " 506 | "WHERE x = y OR x IN (SELECT * from table_two)") 507 | stmt = parse(query).pop() 508 | self.assertEqual(stmt.tables(), {'dataset', 'table_one', 509 | 'table_two'}) 510 | 511 | def test_select_union(self): 512 | query = "select * FROM table_one UNION select * FROM table_two" 513 | stmt = parse(query).pop() 514 | self.assertIsInstance(stmt, nodes.SelectStmt) 515 | 516 | self.assertEqual(stmt.tables(), {'table_one', 'table_two'}) 517 | 518 | def test_where_in_expr(self): 519 | query = "SELECT * FROM my_table WHERE (a, b) in ('a', 'b')" 520 | stmt = parse(query).pop() 521 | self.assertIsInstance(stmt, nodes.SelectStmt) 522 | self.assertEqual(stmt.tables(), {'my_table'}) 523 | 524 | def test_where_in_expr_many(self): 525 | query = ( 526 | "SELECT * FROM my_table WHERE (a, b) in (('a', 'b'), ('c', 'd'))") 527 | stmt = parse(query).pop() 528 | self.assertIsInstance(stmt, nodes.SelectStmt) 529 | self.assertEqual(stmt.tables(), {'my_table'}) 530 | 531 | def test_select_target_list(self): 532 | query = "SELECT (SELECT * FROM table_one)" 533 | stmt = parse(query).pop() 534 | self.assertIsInstance(stmt, nodes.SelectStmt) 535 | self.assertEqual(stmt.tables(), {'table_one'}) 536 | 537 | def test_func_call(self): 538 | query = "SELECT my_func((select * from table_one))" 539 | stmt = parse(query).pop() 540 | self.assertIsInstance(stmt, 
nodes.SelectStmt) 541 | self.assertEqual(stmt.tables(), {'table_one'}) 542 | --------------------------------------------------------------------------------