├── .gitignore ├── LICENSE ├── Makefile ├── README.rst ├── requirements.txt ├── sampleflaskapp.py ├── setup.py ├── sqlbag ├── __init__.py ├── createdrop.py ├── flask │ ├── __init__.py │ └── sessions.py ├── misc.py ├── pg │ ├── __init__.py │ ├── datetimes.py │ └── postgresql.py ├── sqla.py ├── sqla_orm.py ├── util_mysql.py └── util_pg.py ├── tests ├── common.py ├── test_createdrop.py ├── test_flask.py ├── test_misc.py ├── test_pg.py ├── test_sqla.py └── test_sqla_orm.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | 3 | # C extensions 4 | *.so 5 | 6 | # Packages 7 | *.egg 8 | *.egg-info 9 | dist 10 | build 11 | eggs 12 | parts 13 | bin 14 | var 15 | sdist 16 | develop-eggs 17 | .installed.cfg 18 | lib 19 | lib64 20 | __pycache__ 21 | wheelhouse 22 | 23 | # Installer logs 24 | pip-log.txt 25 | 26 | # Unit test / coverage reports 27 | .coverage 28 | .tox 29 | nosetests.xml 30 | 31 | # Translations 32 | *.mo 33 | 34 | # Mr Developer 35 | .mr.developer.cfg 36 | .project 37 | .pydevproject 38 | 39 | .cache 40 | 41 | docs/_build 42 | 43 | scrap 44 | .pytest_cache 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. 
We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to <http://unlicense.org/> 25 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: docs 2 | 3 | # test commands and arguments 4 | tcommand = PYTHONPATH=. py.test -x 5 | tmessy = -svv 6 | targs = --cov-report term-missing --cov sqlbag 7 | 8 | init: apt pip 9 | 10 | apt: 11 | sudo apt-get install python-dev python3-dev libyaml-dev libmariadbclient-dev 12 | 13 | pip: 14 | pip install --upgrade pip 15 | pip install --upgrade -r requirements.txt 16 | 17 | tox: 18 | tox tests 19 | 20 | test: 21 | $(tcommand) $(targs) tests 22 | 23 | stest: 24 | $(tcommand) $(tmessy) $(targs) tests 25 | 26 | fmt: 27 | isort -rc . 28 | black . 29 | 30 | clean: 31 | find .
-name \*.pyc -delete 32 | rm -rf .cache 33 | rm -rf build dist 34 | 35 | lint: 36 | flake8 sqlbag 37 | flake8 tests 38 | 39 | tidy: clean lint 40 | 41 | all: clean lint tox 42 | 43 | publish: 44 | python setup.py sdist bdist_wheel --universal 45 | twine upload dist/* 46 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | sqlbag: various sql boilerplate 2 | =============================== 3 | 4 | This is just a collection of handy code for doing database things. 5 | 6 | What is in the box 7 | ------------------ 8 | 9 | Connections, flask setup, SQLAlchemy ORM helpers, temporary database setup and teardown (handy for integration tests). 10 | 11 | Installation 12 | ------------ 13 | 14 | Simply install with `pip `_: 15 | 16 | .. code-block:: shell 17 | 18 | $ pip install sqlbag 19 | 20 | If you want you can install the database drivers you need at the same time, by specifying one of the optional bundles. 21 | 22 | If you're using postgres, this installs ``sqlbag`` and ``psycopg2``: 23 | 24 | .. code-block:: shell 25 | 26 | $ pip install sqlbag[pg] 27 | 28 | If you're installing MySQL/MariaDB then this installs ``pymysql`` as well: 29 | 30 | .. code-block:: shell 31 | 32 | $ pip install sqlbag[maria] 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 
2 | 3 | pytest 4 | pytest-cov 5 | pytest-xdist 6 | pytest-sugar 7 | 8 | mock 9 | 10 | pylint 11 | flake8 12 | 13 | psycopg2 14 | pymysql 15 | 16 | flask 17 | 18 | tox 19 | yapf 20 | 21 | pep8 22 | virtualenv 23 | wheel 24 | 25 | pendulum 26 | pytz 27 | 28 | twine 29 | black 30 | isort 31 | -------------------------------------------------------------------------------- /sampleflaskapp.py: -------------------------------------------------------------------------------- 1 | from flask import Flask 2 | 3 | from sqlbag.flask import FS, proxies, session_setup 4 | 5 | s = proxies.s 6 | 7 | 8 | def get_app(): 9 | a = Flask(__name__) 10 | a.s = FS("postgresql:///example", echo=True) 11 | session_setup(a) 12 | return a 13 | 14 | 15 | app = get_app() 16 | 17 | 18 | @app.route("/") 19 | def hello(): 20 | # returns 'Hello World!' as a response 21 | return s.execute("select 'Hello world!'").scalar() 22 | 23 | 24 | if __name__ == "__main__": 25 | app.run() 26 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import io 4 | 5 | from setuptools import find_packages, setup 6 | 7 | with io.open("README.rst") as f: 8 | readme = f.read() 9 | 10 | setup( 11 | name="sqlbag", 12 | version="0.1.1579049654", 13 | url="https://github.com/djrobstep/sqlbag", 14 | description="various snippets of SQL-related boilerplate", 15 | long_description=readme, 16 | author="Robert Lechte", 17 | author_email="robertlechte@gmail.com", 18 | install_requires=[ 19 | "pathlib; python_version<'3'", 20 | "six", 21 | "sqlalchemy"], 22 | zip_safe=False, 23 | packages=find_packages(), 24 | classifiers=["Development Status :: 3 - Alpha"], 25 | extras_require={ 26 | "pg": ["psycopg2"], 27 | "pendulum": ["pendulum", "relativedelta"], 28 | "maria": ["pymysql"], 29 | }, 30 | ) 31 | -------------------------------------------------------------------------------- 
/sqlbag/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """sqlbag is a bunch of handy SQL things. 3 | 4 | This is a whole bunch of useful boilerplate and helper methods, for working 5 | with SQL databases, particularly PostgreSQL. 6 | 7 | """ 8 | 9 | from __future__ import absolute_import, division, print_function, unicode_literals 10 | 11 | from .misc import ( 12 | quoted_identifier, 13 | load_sql_from_folder, 14 | load_sql_from_file, 15 | sql_from_file, 16 | sql_from_folder, 17 | sql_from_folder_iter, 18 | ) # noqa 19 | 20 | from .sqla import ( 21 | S, 22 | raw_execute, 23 | admin_db_connection, 24 | _killquery, 25 | kill_other_connections, 26 | session, 27 | DB_ERROR_TUPLE, 28 | raw_connection, 29 | get_raw_autocommit_connection, 30 | copy_url, 31 | alter_url, 32 | connection_from_s_or_c, 33 | C, 34 | ) # noqa 35 | 36 | from .sqla_orm import ( 37 | row2dict, 38 | Base, 39 | metadata_from_session, 40 | sqlachanges, 41 | get_properties, 42 | ) # noqa 43 | 44 | from .createdrop import ( 45 | database_exists, 46 | create_database, 47 | drop_database, 48 | temporary_database, 49 | can_select, 50 | ) # noqa 51 | 52 | try: 53 | from . 
import pg 54 | except ImportError: 55 | pg = None 56 | -------------------------------------------------------------------------------- /sqlbag/createdrop.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import copy 4 | import getpass 5 | import os 6 | import random 7 | import string 8 | import tempfile 9 | from contextlib import contextmanager 10 | 11 | from sqlalchemy import create_engine 12 | from sqlalchemy.exc import InternalError, OperationalError, ProgrammingError 13 | 14 | from sqlbag import quoted_identifier 15 | 16 | from .sqla import ( 17 | admin_db_connection, 18 | connection_from_s_or_c, 19 | make_url, 20 | kill_other_connections, 21 | ) 22 | 23 | 24 | def database_exists(db_url, test_can_select=False): 25 | url = make_url(db_url) 26 | name = url.database 27 | db_type = url.get_dialect().name 28 | 29 | if not test_can_select: 30 | if db_type == "sqlite": 31 | return name is None or name == ":memory:" or os.path.exists(name) 32 | elif db_type in ["postgresql", "mysql"]: 33 | with admin_db_connection(url) as s: 34 | return _database_exists(s, name) 35 | return can_select(url) 36 | 37 | 38 | def can_select(url): 39 | text = "select 1" 40 | 41 | e = create_engine(url) 42 | 43 | try: 44 | e.execute(text) 45 | return True 46 | except (ProgrammingError, OperationalError, InternalError): 47 | return False 48 | 49 | 50 | def _database_exists(session_or_connection, name): 51 | c = connection_from_s_or_c(session_or_connection) 52 | e = copy.copy(c.engine) 53 | url = make_url(e.url) 54 | dbtype = url.get_dialect().name 55 | 56 | if dbtype == "postgresql": 57 | EXISTENCE = """ 58 | SELECT 1 59 | FROM pg_catalog.pg_database 60 | WHERE datname = %s 61 | """ 62 | 63 | result = c.execute(EXISTENCE, (name,)).scalar() 64 | 65 | return bool(result) 66 | elif dbtype == "mysql": 67 | EXISTENCE = """ 68 | SELECT SCHEMA_NAME 69 | FROM 
INFORMATION_SCHEMA.SCHEMATA 70 | WHERE SCHEMA_NAME = %s 71 | """ 72 | 73 | result = c.execute(EXISTENCE, (name,)).scalar() 74 | 75 | return bool(result) 76 | 77 | 78 | def create_database(db_url, template=None, wipe_if_existing=False): 79 | target_url = make_url(db_url) 80 | dbtype = target_url.get_dialect().name 81 | 82 | if wipe_if_existing: 83 | drop_database(db_url) 84 | 85 | if database_exists(target_url): 86 | return False 87 | else: 88 | 89 | if dbtype == "sqlite": 90 | can_select(target_url) 91 | return True 92 | 93 | with admin_db_connection(target_url) as c: 94 | if template: 95 | t = "template {}".format(quoted_identifier(template)) 96 | else: 97 | t = "" 98 | 99 | c.execute( 100 | """ 101 | create database {} {}; 102 | """.format( 103 | quoted_identifier(target_url.database), t 104 | ) 105 | ) 106 | return True 107 | 108 | 109 | def drop_database(db_url): 110 | url = make_url(db_url) 111 | 112 | dbtype = url.get_dialect().name 113 | name = url.database 114 | 115 | if database_exists(url): 116 | if dbtype == "sqlite": 117 | if name and name != ":memory:": 118 | os.remove(name) 119 | return True 120 | else: 121 | return False 122 | else: 123 | with admin_db_connection(url) as c: 124 | if dbtype == "postgresql": 125 | 126 | REVOKE = "revoke connect on database {} from public" 127 | revoke = REVOKE.format(quoted_identifier(name)) 128 | c.execute(revoke) 129 | 130 | kill_other_connections(c, name, hardkill=True) 131 | 132 | c.execute( 133 | """ 134 | drop database if exists {}; 135 | """.format( 136 | quoted_identifier(name) 137 | ) 138 | ) 139 | return True 140 | else: 141 | return False 142 | 143 | 144 | def _current_username(): 145 | return getpass.getuser() 146 | 147 | 148 | def temporary_name(prefix="sqlbag_tmp_"): 149 | random_letters = [random.choice(string.ascii_lowercase) for _ in range(10)] 150 | rnd = "".join(random_letters) 151 | tempname = prefix + rnd 152 | return tempname 153 | 154 | 155 | @contextmanager 156 | def 
temporary_database(dialect="postgresql", do_not_delete=False, host=None): 157 | """ 158 | Args: 159 | dialect(str): Type of database to create (either 'postgresql', 'mysql', or 'sqlite'). 160 | do_not_delete: Do not delete the database as this method usually would. 161 | 162 | Creates a temporary database for the duration of the context manager scope. Cleans it up when finished unless do_not_delete is specified. 163 | 164 | PostgreSQL, MySQL/MariaDB, and SQLite are supported. This method's mysql creation code uses the pymysql driver, so make sure you have that installed. 165 | """ 166 | 167 | host = host or "" 168 | 169 | if dialect == "sqlite": 170 | tmp = tempfile.NamedTemporaryFile(delete=False) 171 | 172 | try: 173 | url = "sqlite:///" + tmp.name 174 | yield url 175 | 176 | finally: 177 | if not do_not_delete: 178 | os.remove(tmp.name) 179 | 180 | else: 181 | tempname = temporary_name() 182 | 183 | current_username = _current_username() 184 | 185 | url = "{}://{}@{}/{}".format(dialect, current_username, host, tempname) 186 | 187 | if url.startswith("mysql:"): 188 | url = url.replace("mysql:", "mysql+pymysql:", 1) 189 | 190 | try: 191 | create_database(url) 192 | yield url 193 | finally: 194 | if not do_not_delete: 195 | drop_database(url) 196 | -------------------------------------------------------------------------------- /sqlbag/flask/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Flask-specific code. 3 | 4 | Helps you setup per-request database connections for flask apps. 
5 | 6 | """ 7 | 8 | from __future__ import absolute_import, division, print_function, unicode_literals 9 | 10 | from .sessions import FS, session_setup, proxies 11 | 12 | __all__ = ("FS", "session_setup", "proxies") 13 | -------------------------------------------------------------------------------- /sqlbag/flask/sessions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import absolute_import, division, print_function, unicode_literals 4 | 5 | from flask import _app_ctx_stack, current_app 6 | from sqlalchemy import create_engine 7 | from sqlalchemy.orm import scoped_session, sessionmaker 8 | from werkzeug.local import LocalProxy 9 | 10 | FLASK_SCOPED_SESSION_MAKERS = [] 11 | COMMIT_AFTER_REQUEST = [] 12 | 13 | 14 | def session_setup(app): 15 | """Args: 16 | app (Flask Application): The flask application to set up. 17 | 18 | Wires up any sessions created with `FS` to commit automatically once the request response is complete. 19 | """ 20 | f = flask_smart_after_request 21 | fteardown = flask_smart_teardown_appcontext 22 | 23 | funcs = app.after_request_funcs.get(None, []) 24 | 25 | if f not in funcs: 26 | funcs.append(f) 27 | app.after_request_funcs[None] = funcs 28 | 29 | if fteardown not in app.teardown_appcontext_funcs: 30 | app.teardown_appcontext_funcs.append(fteardown) 31 | 32 | 33 | def FS(*args, **kwargs): 34 | """ 35 | Args: 36 | args: Same arguments as SQLAlchemy's create_engine. 37 | kwargs: Same arguments as SQLAlchemy's create_engine. 38 | 39 | Returns: 40 | scoped_session: An SQLAlchemy scoped_session object (for more details, 41 | see SQLAlchemy docs). 42 | 43 | create this object in your initialization code. 44 | 45 | >>> s = FS('postgresql:///webdb') 46 | 47 | and make sure you've called `session_setup` somewhere in your init code also. 
After that, simply use it in your route methods like so: 48 | 49 | >>> results = s.execute('select a from b') 50 | 51 | all usages of this `s` object within the same request will use this same session. 52 | """ 53 | 54 | commit_after_request = kwargs.get("commit_after_request", True) 55 | 56 | s = scoped_session( 57 | sessionmaker(bind=create_engine(*args, **kwargs)), 58 | scopefunc=_app_ctx_stack.__ident_func__, 59 | ) 60 | 61 | FLASK_SCOPED_SESSION_MAKERS.append(s) 62 | COMMIT_AFTER_REQUEST.append(bool(commit_after_request)) 63 | return s 64 | 65 | 66 | def flask_smart_after_request(resp): 67 | is_error = 400 <= resp.status_code < 600 68 | 69 | for do_commit, scoped in zip(COMMIT_AFTER_REQUEST, FLASK_SCOPED_SESSION_MAKERS): 70 | if do_commit: 71 | if not is_error: 72 | scoped.commit() 73 | return resp 74 | 75 | 76 | def flask_smart_teardown_appcontext(exception=None): 77 | for scoped in FLASK_SCOPED_SESSION_MAKERS: 78 | scoped.remove() 79 | 80 | 81 | class Proxies(object): 82 | def __getattr__(self, name): 83 | def get_proxy(): 84 | return getattr(current_app, name) 85 | 86 | return LocalProxy(get_proxy) 87 | 88 | 89 | proxies = Proxies() 90 | -------------------------------------------------------------------------------- /sqlbag/misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import io 4 | import sys 5 | from pathlib import Path 6 | 7 | from .sqla import raw_execute 8 | 9 | 10 | def quoted_identifier(identifier): 11 | """One-liner to add double-quote marks around an SQL identifier 12 | (table name, view name, etc), and to escape double-quote marks. 13 | 14 | Args: 15 | identifier(str): the unquoted identifier 16 | """ 17 | 18 | return '"{}"'.format(identifier.replace('"', '""')) 19 | 20 | 21 | def sql_from_file(fpath): 22 | """ 23 | Args: 24 | fpath (str): The path to the file. 
25 | 26 | Returns: 27 | sql (str): The file contents as a string, with any whitespace stripped 28 | from the start and end. 29 | 30 | Merely opens a file and returns the contents stripped of whitespace. 31 | """ 32 | with io.open(str(fpath)) as f: 33 | return f.read().strip() 34 | 35 | 36 | def sql_from_folder_iter(fpath): 37 | """ 38 | Args: 39 | fpath (str): The path to the folder. 40 | Yields: 41 | tuple: ``(fpath, sql)`` for each non-empty .sql file found, where 42 | ``sql`` is the stripped file contents. 43 | 44 | Iterate through all the .sql files in a folder. 45 | """ 46 | folder = Path(fpath) 47 | 48 | sql_files = sorted(folder.glob("**/*.sql")) 49 | 50 | for fpath in sql_files: 51 | sql = sql_from_file(fpath) 52 | if sql: 53 | yield fpath, sql 54 | 55 | 56 | def sql_from_folder(fpath): 57 | return list(sql for _, sql in sql_from_folder_iter(fpath)) 58 | 59 | 60 | def load_sql_from_folder(s, fpath, verbose=False, out=None): 61 | """ 62 | Args: 63 | s (Session): Applies the SQL to this session. 64 | fpath (str): The path to the file. 65 | verbose (bool): Prints some information as it loads files. 66 | out (stream): Change where verbose mode prints to. defaults to sys.stdout 67 | 68 | Returns: 69 | None. Each .sql file found is executed against the session. 70 | 71 | Iterate through all the .sql files in a folder. 72 | """ 73 | 74 | if verbose: 75 | if not out: 76 | out = sys.stdout # pragma: no cover 77 | out.write("Running all .sql files in: {}".format(fpath)) 78 | 79 | for fpath, text in sql_from_folder_iter(fpath): 80 | if verbose: 81 | out.write(" Running SQL in: {}".format(fpath)) 82 | raw_execute(s, text) 83 | 84 | 85 | def load_sql_from_file(s_or_c, fpath): 86 | """ 87 | Args: 88 | s_or_c: :class:`Session` or :class:`Connection` to use. 89 | fpath (str): The path to the file. 90 | Returns: 91 | sql (str): The sql that was executed 92 | 93 | Run the SQL from a single file against the session or connection.
94 | """ 95 | 96 | text = sql_from_file(fpath) 97 | 98 | if text: 99 | raw_execute(s_or_c, text) 100 | 101 | return text 102 | -------------------------------------------------------------------------------- /sqlbag/pg/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | 4 | from .postgresql import ( 5 | pg_notices, 6 | pg_print_notices, 7 | pg_errorname_lookup, 8 | errorcode_from_error, 9 | ) # noqa 10 | 11 | from .datetimes import use_pendulum_for_time_types, format_relativedelta # noqa 12 | -------------------------------------------------------------------------------- /sqlbag/pg/datetimes.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta, tzinfo 2 | 3 | import pendulum 4 | from dateutil.relativedelta import relativedelta 5 | from psycopg2.extensions import AsIs, new_type, register_adapter, register_type 6 | 7 | ZERO = timedelta(0) 8 | HOUR = timedelta(hours=1) 9 | 10 | PENDULUM_DATETIME_TYPE = type(pendulum.now("UTC")) 11 | 12 | 13 | # A UTC class. 
14 | class UTC(tzinfo): 15 | """UTC""" 16 | 17 | def utcoffset(self, dt): 18 | return ZERO 19 | 20 | def tzname(self, dt): 21 | return "UTC" 22 | 23 | def dst(self, dt): 24 | return ZERO 25 | 26 | 27 | utc = UTC() 28 | 29 | 30 | def vanilla(pendulum_dt): 31 | x = pendulum_dt.in_timezone("UTC") 32 | 33 | return datetime( 34 | x.year, x.month, x.day, x.hour, x.minute, x.second, x.microsecond, tzinfo=utc 35 | ) 36 | 37 | 38 | def naive(pendulum_dt): 39 | x = pendulum_dt 40 | return x.naive() 41 | 42 | 43 | def utcnow(): 44 | return pendulum.now("UTC") 45 | 46 | 47 | def localnow(): 48 | return pendulum.now() 49 | 50 | 51 | def parse_time_of_day(x): 52 | return pendulum.parse(x).time() 53 | 54 | 55 | def combine_date_and_time(date, time, timezone="UTC"): 56 | naive = datetime.combine(date, time) 57 | return pendulum.instance(naive, tz=timezone) 58 | 59 | 60 | OID_TIMESTAMP = 1114 61 | OID_TIMESTAMPTZ = 1184 62 | OID_DATE = 1082 63 | OID_TIME = 1083 64 | OID_INTERVAL = 1186 65 | 66 | 67 | def tokens_iter(s): 68 | tokens = s.split() 69 | 70 | while tokens: 71 | if ":" in tokens[0]: 72 | x, tokens = tokens[0], tokens[1:] 73 | t = pendulum.parse(x, strict=False).time() 74 | 75 | yield { 76 | "hours": x.startswith("-") and -t.hour or t.hour, 77 | "minutes": t.minute, 78 | "seconds": t.second, 79 | "microseconds": t.microsecond, 80 | } 81 | else: 82 | x, tokens = tokens[:2], tokens[2:] 83 | x[1] = x[1].replace("mons", "months") 84 | yield {x[1]: int(x[0])} 85 | 86 | 87 | def parse_interval_values(s): 88 | values = {} 89 | [values.update(_) for _ in tokens_iter(s)] 90 | 91 | for k in list(values): 92 | if not k.endswith("s"): 93 | values[k + "s"] = values.pop(k) 94 | return values 95 | 96 | 97 | def format_relativedelta(rd): 98 | RELATIVEDELTA_FIELDS = [ 99 | "years", 100 | "months", 101 | "days", 102 | "hours", 103 | "minutes", 104 | "seconds", 105 | "microseconds", 106 | ] 107 | 108 | fields = [(k, getattr(rd, k)) for k in RELATIVEDELTA_FIELDS if getattr(rd, k)] 109 | 110 | 
s = " ".join("{} {}".format(v, k) for k, v in fields) 111 | 112 | return s 113 | 114 | 115 | class sqlbagrelativedelta(relativedelta): 116 | def __str__(self): 117 | return format_relativedelta(self) 118 | 119 | 120 | def cast_timestamp(value, cur): 121 | if value is None: 122 | return None 123 | return pendulum.parse(value).naive() 124 | 125 | 126 | def cast_timestamptz(value, cur): 127 | if value is None: 128 | return None 129 | return pendulum.parse(value).in_timezone("UTC") 130 | 131 | 132 | def cast_time(value, cur): 133 | if value is None: 134 | return None 135 | return pendulum.parse(value).time() 136 | 137 | 138 | def cast_date(value, cur): 139 | if value is None: 140 | return None 141 | return pendulum.parse(value).date() 142 | 143 | 144 | def cast_interval(value, cur): 145 | if value is None: 146 | return None 147 | values = parse_interval_values(value) 148 | return sqlbagrelativedelta(**values) 149 | 150 | 151 | def adapt_datetime(dt): 152 | if not isinstance(dt, PENDULUM_DATETIME_TYPE): 153 | dt = pendulum.instance(dt) 154 | in_utc = dt.in_timezone("UTC") 155 | return AsIs("'{}'".format(in_utc)) 156 | 157 | 158 | def adapt_relativedelta(rd): 159 | return AsIs("'{}'".format(format_relativedelta(rd))) 160 | 161 | 162 | def register_cast(oid, typename, method): 163 | new_t = new_type((oid,), typename, method) 164 | register_type(new_t) 165 | 166 | 167 | def use_pendulum_for_time_types(): 168 | register_cast(OID_TIMESTAMP, "TIMESTAMP", cast_timestamp) 169 | register_cast(OID_TIMESTAMPTZ, "TIMESTAMPTZ", cast_timestamptz) 170 | register_cast(OID_DATE, "DATE", cast_date) 171 | register_cast(OID_TIME, "TIME", cast_time) 172 | register_cast(OID_INTERVAL, "INTERVAL", cast_interval) 173 | 174 | register_adapter(datetime, adapt_datetime) 175 | register_adapter(relativedelta, adapt_relativedelta) 176 | 177 | 178 | utc = UTC() 179 | -------------------------------------------------------------------------------- /sqlbag/pg/postgresql.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import sys 4 | 5 | import six 6 | from psycopg2 import errorcodes as pgerrorcodes 7 | 8 | from sqlbag import raw_connection 9 | 10 | if not six.PY2: 11 | unicode = str 12 | 13 | 14 | def errorcode_from_error(e): 15 | """ 16 | Get the error code from a particular error/exception caused by PostgreSQL. 17 | """ 18 | return e.orig.pgcode 19 | 20 | 21 | def pg_errorname_lookup(pgcode): 22 | """ 23 | Args: 24 | pgcode(int): A PostgreSQL error code. 25 | 26 | Returns: 27 | The error name from a PostgreSQL error code as per: https://www.postgresql.org/docs/9.5/static/errcodes-appendix.html 28 | """ 29 | 30 | return pgerrorcodes.lookup(str(pgcode)) 31 | 32 | 33 | def pg_notices(s, wipe=False): 34 | """ 35 | Args: 36 | s(:class:`sqlalchemy.orm.Session`): The session in question. 37 | wipe(bool): If true, clears the notices after reading them. 38 | 39 | Returns: 40 | notices(list): The list of notices. 41 | 42 | Grab the list of notices that PostgreSQL has generated for the connection. 43 | 44 | Optionally wipes/clears the list so you won't see the same ones if you check again later. 45 | """ 46 | 47 | c = raw_connection(s) 48 | notices = list(c.notices) 49 | if wipe: 50 | del c.notices[:] 51 | return notices 52 | 53 | 54 | def pg_print_notices(s, out=None, wipe=True): 55 | """ 56 | Args: 57 | s(sqlalchemy.orm.Session): The session. 58 | out(stream): Output stream to print notices to. If None, use sys.stdout 59 | wipe(bool): If True, wipes the current notices after reading them. 60 | 61 | Print the notices generated for a session. 
62 | """ 63 | if not out: 64 | out = sys.stdout # pragma: no cover 65 | 66 | for n in pg_notices(s, wipe=wipe): 67 | for line in n.splitlines(): 68 | out.write(unicode(line)) 69 | -------------------------------------------------------------------------------- /sqlbag/sqla.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous helpful stuff for working with the SQLAlchemy core.""" 2 | 3 | from __future__ import absolute_import, division, print_function, unicode_literals 4 | 5 | import copy 6 | import getpass 7 | from contextlib import contextmanager 8 | from packaging import version 9 | 10 | import sqlalchemy 11 | import sqlalchemy.engine.url 12 | import sqlalchemy.exc 13 | import sqlalchemy.orm 14 | import sqlalchemy.orm.session 15 | from six import string_types 16 | from sqlalchemy import create_engine 17 | from sqlalchemy.engine import Engine 18 | from sqlalchemy.engine.url import make_url 19 | from sqlalchemy.orm import scoped_session, sessionmaker 20 | from sqlalchemy.pool import NullPool 21 | from sqlalchemy.sql import text 22 | 23 | from .util_mysql import MYSQL_KILLQUERY_FORMAT as MYSQL_KILL 24 | from .util_pg import PSQL_KILLQUERY_FORMAT_INCLUDING_DROPPED as PG_KILL 25 | 26 | DB_ERROR_TUPLE = ( 27 | sqlalchemy.exc.OperationalError, 28 | sqlalchemy.exc.InternalError, 29 | sqlalchemy.exc.ProgrammingError, 30 | ) 31 | 32 | SCOPED_SESSION_MAKERS = {} 33 | 34 | SQLA14 = version.parse(sqlalchemy.__version__) >= version.parse('1.4.0b1') 35 | 36 | 37 | try: 38 | import secrets 39 | scopefunc = secrets.token_hex 40 | except ImportError: 41 | import random 42 | scopefunc = random.random 43 | 44 | 45 | def copy_url(db_url): 46 | """ 47 | Args: 48 | db_url: Already existing SQLAlchemy :class:`URL`, or URL string. 49 | Returns: 50 | A brand new SQLAlchemy :class:`URL`. 51 | 52 | Make a copy of a SQLAlchemy :class:`URL`. 
53 | """ 54 | return copy.copy(make_url(db_url)) 55 | 56 | 57 | def alter_url(db_url, **kwargs): 58 | """ 59 | Args: 60 | db_url: Already existing SQLAlchemy :class:`URL`, or URL string. 61 | **kwargs: Attributes to modify 62 | Returns: 63 | A brand new SQLAlchemy :class:`URL`. 64 | 65 | Return a copy of a SQLAlchemy :class:`URL` with some modifications. 66 | """ 67 | db_url = make_url(db_url) 68 | if SQLA14: 69 | return db_url.set(**kwargs) 70 | else: 71 | db_url = copy.copy(db_url) 72 | for k, v in kwargs.items(): 73 | setattr(db_url, k, v) 74 | return db_url 75 | 76 | def connection_from_s_or_c(s_or_c): 77 | """Args: 78 | s_or_c (str): Either an SQLAlchemy ORM :class:`Session`, or a core 79 | :class:`Connection`. 80 | 81 | Returns: 82 | Connection: An SQLAlchemy Core connection. If you passed in a 83 | :class:`Session`, it's the Connection associated with that session. 84 | 85 | Get you a method that can do both. This is handy for writing methods 86 | that can accept both :class:`Session`s and core :class:`Connection`s. 87 | 88 | """ 89 | try: 90 | s_or_c.engine 91 | return s_or_c 92 | except AttributeError: 93 | return s_or_c.connection() 94 | 95 | 96 | def get_raw_autocommit_connection(url_or_engine_or_connection): 97 | """ 98 | Args: 99 | url_or_engine_or_connection (str): A URL string or SQLAlchemy engine object, or already existing DBAPI connection. 100 | 101 | Returns: 102 | A connection in autocommit mode. 103 | 104 | Sometimes you just want to autocommit.
105 | 106 | """ 107 | x = url_or_engine_or_connection 108 | 109 | if isinstance(x, string_types): 110 | import psycopg2 111 | 112 | c = psycopg2.connect(x) 113 | x = c 114 | elif isinstance(x, Engine): 115 | sqla_connection = x.connect() 116 | sqla_connection.execution_options(isolation_level="AUTOCOMMIT") 117 | sqla_connection.detach() 118 | x = sqla_connection.connection.connection 119 | elif hasattr(x, "protocol_version"): 120 | # this is already a DBAPI connection object 121 | pass 122 | else: 123 | raise ValueError("must pass a url or engine or DBAPI connection object") 124 | x.autocommit = True 125 | return x 126 | 127 | 128 | def session(*args, **kwargs): 129 | """ 130 | Returns: 131 | Session: A new SQLAlchemy :class:`Session`. 132 | 133 | Boilerplate method to create a database session. 134 | 135 | Pass in the same parameters as you'd pass to create_engine. Internally, 136 | this uses SQLAlchemy's `scoped_session` session constructors, which means 137 | that calling it again with the same parameters will reuse the 138 | `scoped_session`. 139 | 140 | :class:`S ` creates a session in the same way but in the form of a 141 | context manager. 142 | """ 143 | Session = get_scoped_session_maker(*args, **kwargs) 144 | return Session() 145 | 146 | 147 | @contextmanager 148 | def S(*args, **kwargs): 149 | """Boilerplate context manager for creating and using sessions. 150 | 151 | This makes using a database session as simple as: 152 | 153 | .. code-block:: python 154 | 155 | with S('postgresql:///databasename') as s: 156 | s.execute('select 1;') 157 | 158 | Does `commit()` on close, `rollback()` on exception. 159 | 160 | Also uses `scoped_session` under the hood. 
161 | 162 | """ 163 | Session = get_scoped_session_maker(*args, **kwargs) 164 | 165 | try: 166 | session = Session() 167 | yield session 168 | session.commit() 169 | except Exception: 170 | session.rollback() 171 | raise 172 | finally: 173 | session.close() 174 | 175 | 176 | def get_scoped_session_maker(*args, **kwargs): 177 | """ 178 | Creates a scoped session maker, and saves it for reuse next time. 179 | 180 | """ 181 | 182 | tup = (args, frozenset(kwargs.items())) 183 | if tup not in SCOPED_SESSION_MAKERS: 184 | SCOPED_SESSION_MAKERS[tup] = scoped_session( 185 | sessionmaker(bind=create_engine(*args, **kwargs)), scopefunc=scopefunc 186 | ) 187 | return SCOPED_SESSION_MAKERS[tup] 188 | 189 | 190 | def raw_connection(s_or_c_or_rawc): 191 | """ 192 | Args: 193 | s_or_c_or_rawc (str): SQLAlchemy :class:`Session` or :class:`Connection` or already existing DBAPI connection. 194 | Returns: 195 | connection (str): Raw DBAPI connection 196 | 197 | Get the raw DBAPI connection from 198 | """ 199 | x = s_or_c_or_rawc 200 | 201 | try: 202 | return connection_from_s_or_c(x).connection 203 | except TypeError: 204 | return x 205 | 206 | 207 | def raw_execute(s, statements): 208 | raw_connection(s).cursor().execute(statements) 209 | 210 | 211 | @contextmanager 212 | def C(*args, **kwargs): 213 | """ 214 | Hello it's me. 
215 | """ 216 | e = create_engine(*args, **kwargs) 217 | c = e.connect() 218 | trans = c.begin() 219 | 220 | try: 221 | yield c 222 | trans.commit() 223 | except Exception: 224 | trans.rollback() 225 | raise 226 | finally: 227 | c.close() 228 | 229 | 230 | @contextmanager 231 | def admin_db_connection(db_url): 232 | url = make_url(db_url) 233 | dbtype = url.get_dialect().name 234 | 235 | if dbtype == "postgresql": 236 | url = alter_url(url, database='') 237 | 238 | if not url.username: 239 | url = alter_url(url, username=getpass.getuser()) 240 | 241 | elif not dbtype == "sqlite": 242 | url = alter_url(url, database='') 243 | 244 | if dbtype == "postgresql": 245 | with C(url, poolclass=NullPool, isolation_level="AUTOCOMMIT") as c: 246 | yield c 247 | 248 | elif dbtype == "mysql": 249 | with C(url, poolclass=NullPool) as c: 250 | c.execute( 251 | """ 252 | SET sql_mode = 'ANSI'; 253 | """ 254 | ) 255 | yield c 256 | 257 | elif dbtype == "sqlite": 258 | with C(url, poolclass=NullPool) as c: 259 | yield c 260 | 261 | 262 | def _killquery(dbtype, dbname, hardkill): 263 | where = [] 264 | 265 | if dbtype == "postgresql": 266 | sql = PG_KILL 267 | 268 | if not hardkill: 269 | where.append("psa.state = 'idle'") 270 | if dbname: 271 | where.append("datname = :databasename") 272 | elif dbtype == "mysql": 273 | sql = MYSQL_KILL 274 | 275 | if not hardkill: 276 | where.append("COMMAND = 'Sleep'") 277 | if dbname: 278 | where.append("DB = :databasename") 279 | else: 280 | raise NotImplementedError 281 | 282 | where = " and ".join(where) 283 | 284 | if where: 285 | sql += " and {}".format(where) 286 | return sql 287 | 288 | 289 | def kill_other_connections(s_or_c, dbname=None, hardkill=False): 290 | """ 291 | Args: 292 | s_or_c: SQLAlchemy Session or Connection. Needs to have the appropriate permssions to kill connections. For best results use :class:`admin_db_connection`. 293 | dbname: Name of database. If `None`, kills connections to all databases on the server. 
def metadata_from_session(s):
    """
    Args:
        s: an SQLAlchemy :class:`Session`
    Returns:
        The reflected metadata.

    Reflect the schema that this session is bound to into a fresh
    :class:`MetaData` object and return it.
    """
    reflected = MetaData()
    reflected.reflect(bind=s.bind)
    return reflected
34 | """ 35 | 36 | @declared_attr 37 | def __tablename__(cls): 38 | """Convert CamelCase class name to underscores_between_words 39 | table name.""" 40 | name = cls.__name__ 41 | return name[0].lower() + re.sub( 42 | r"([A-Z])", lambda m: "_" + m.group(0).lower(), name[1:] 43 | ) 44 | 45 | def __repr__(self): 46 | items = row2dict(self).items() 47 | return "{0}({1})".format( 48 | self.__class__.__name__, ", ".join(["{0}={1!r}".format(*_) for _ in items]) 49 | ) 50 | 51 | @property 52 | def _sqlachanges(self): 53 | """ 54 | Return the changes you've made to this object so far this session. 55 | """ 56 | return sqlachanges(self) 57 | 58 | @property 59 | def _ordereddict(self): 60 | """ 61 | Return this object's properties as an OrderedDict. 62 | """ 63 | return row2dict(self) 64 | 65 | def __str__(self): 66 | return repr(self) 67 | 68 | 69 | def sqlachanges(sa_object): 70 | """ 71 | Returns the changes made to this object so far this session, in {'propertyname': [listofvalues] } format. 72 | """ 73 | attrs = inspect(sa_object).attrs 74 | return { 75 | a.key: list(reversed(a.history.sum())) 76 | for a in attrs 77 | if len(a.history.sum()) > 1 78 | } 79 | 80 | 81 | def row2dict(sa_object): 82 | """ 83 | Converts a mapped object into an OrderedDict. 84 | """ 85 | return OrderedDict( 86 | (pname, getattr(sa_object, pname)) for pname in get_properties(sa_object) 87 | ) 88 | 89 | 90 | def get_properties(instance): 91 | """ 92 | Gets the mapped properties of this mapped object. 
93 | """ 94 | 95 | def _props(): 96 | mapper = sqlalchemy.orm.object_mapper(instance) 97 | for prop in mapper.iterate_properties: 98 | yield prop.key 99 | 100 | return list(_props()) 101 | -------------------------------------------------------------------------------- /sqlbag/util_mysql.py: -------------------------------------------------------------------------------- 1 | MYSQL_KILLQUERY_FORMAT = """ 2 | select 3 | *, 4 | ID as process_id, 5 | connection_id() as cid 6 | from 7 | information_schema.processlist 8 | where 9 | ID != connection_id() 10 | """ 11 | -------------------------------------------------------------------------------- /sqlbag/util_pg.py: -------------------------------------------------------------------------------- 1 | PSQL_KILLQUERY_FORMAT = """ 2 | select 3 | pg_terminate_backend(psa.pid) 4 | from 5 | pg_stat_activity psa 6 | where 7 | psa.pid != pg_backend_pid() 8 | """ 9 | 10 | PSQL_KILLQUERY_FORMAT_HARD = """ 11 | select 12 | pg_terminate_backend(psa.pid) 13 | from 14 | pg_stat_activity psa 15 | where datname = :databasename 16 | """ 17 | 18 | # this is needed because pg_stat_activity doesn't 19 | # show activity of dropped users properly 20 | STAT_ACTIVITY_INCLUDING_DROPPED_USERS = """ 21 | with psa as ( 22 | SELECT 23 | *, 24 | (select datname from pg_database d where d.oid = s.datid) 25 | as datname 26 | FROM pg_stat_get_activity(NULL::integer) s 27 | 28 | ) 29 | select * from psa; 30 | """ 31 | 32 | PSQL_KILLQUERY_FORMAT_INCLUDING_DROPPED = """ 33 | with psa as ( 34 | SELECT 35 | *, 36 | (select datname from pg_database d where d.oid = s.datid) 37 | as datname 38 | FROM pg_stat_get_activity(NULL::integer) s 39 | ) 40 | select 41 | pg_terminate_backend(psa.pid) 42 | from 43 | psa 44 | where psa.pid != pg_backend_pid() 45 | """ 46 | -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 
@pytest.fixture(scope="module")
def db():
    """Module-scoped fixture yielding the URI of a temporary postgres db.

    Uses plain pytest.fixture: pytest.yield_fixture is deprecated and was
    removed in pytest 6.2; yield has been supported by pytest.fixture
    since pytest 3.0.
    """
    with temporary_database("postgresql") as dburi:
        yield dburi
def test_flask_integration(db):
    """Smoke-test sqlbag's flask session integration end to end."""
    app = Flask(__name__)

    s = FS(db)
    s2 = FS(db)

    @app.route("/")
    def hello():
        s.execute("select 1")
        s2.execute("select 2")
        return "ok"

    session_setup(app)

    client = app.test_client()
    result = client.get("/")

    # The original fetched the response but never checked it, so a failing
    # view (500) went unnoticed. Assert both status and body.
    assert result.status_code == 200
    assert result.data == b"ok"

    # TODO: should test this a lot more thoroughly
def test_errors_and_messages(db):
    """Check error-name lookup, notice capture/printing, and error codes."""
    assert pg_errorname_lookup(22005) == "ERROR_IN_ASSIGNMENT"

    with S(db) as s:
        s.execute("drop table if exists x")
        assert pg_notices(s) == ['NOTICE: table "x" does not exist, skipping\n']
        assert pg_notices(s, wipe=True) == [
            'NOTICE: table "x" does not exist, skipping\n'
        ]
        assert pg_notices(s) == []

        out = io.StringIO()
        s.execute("drop table if exists x")
        pg_print_notices(s, out=out)

        assert out.getvalue() == 'NOTICE: table "x" does not exist, skipping'

        out = io.StringIO()
        pg_print_notices(s, out=out)

        assert out.getvalue() == ""

        s.execute("create table x(id text)")

    # The original used try/except here, which passed vacuously if no error
    # was raised. raises() fails the test when the duplicate CREATE succeeds.
    with raises(DB_ERROR_TUPLE) as excinfo:
        with S(db) as s:
            s.execute("create table x(id text)")
    assert errorcode_from_error(excinfo.value) == "42P07"
def test_datetime_primitives():
    """Exercise the timezone/date/time helpers in sqlbag.pg.datetimes."""
    now = datetime.now()

    utc = UTC()
    for moment in (now, None):
        assert utc.utcoffset(moment) == ZERO
        assert utc.dst(moment) == ZERO
    assert utc.tzname(now) == "UTC"

    wrapped = pendulum.instance(now)
    stripped = naive(wrapped)
    assert stripped == now
    # naive() should return pendulum's type, not a plain datetime
    assert type(stripped) == type(wrapped)

    utc_now = utcnow()
    assert utc_now.tz == utc_now.in_timezone("UTC").tz

    local_now = localnow()
    plain = vanilla(local_now)
    assert pendulum.instance(plain) == local_now

    time_of_day = parse_time_of_day("2015-01-01 12:34:56")
    assert str(time_of_day) == "12:34:56"

    combined = combine_date_and_time(pendulum.Date(2017, 1, 1), time_of_day)
    assert str(combined) == "2017-01-01T12:34:56+00:00"

    delta = sqlbagrelativedelta(days=5, weeks=6, months=7)
    assert str(delta) == "7 months 47 days"
def test_basic(db, tmpdir):
    """End-to-end exercise of sqlbag's session/connection/file helpers."""
    url = copy_url(db)

    with S(db) as s:
        kill_other_connections(s, url.database)

    s = session(db)
    s.commit()
    s.close()

    with C(url) as c:
        c.execute("select 1")
        core_to_raw = raw_connection(c)

    with raises(ProgrammingError):
        with C(url) as c:
            c.execute("select bad")

    with admin_db_connection("sqlite://") as c:
        pass

    with temporary_database("mysql") as mysql_url:
        url = copy_url(mysql_url)

        with S(mysql_url) as s:
            s.execute("select 1")

        with admin_db_connection(mysql_url) as c:
            kq = _killquery(url.get_dialect().name, None, True)
            assert kq == MYSQL_KILLQUERY_EXPECTED_ALL
            kq = _killquery(url.get_dialect().name, url.database, False)
            assert kq == MYSQL_KILLQUERY_EXPECTED
            kill_other_connections(c, url.database, False)

    tempd = str(tmpdir / "sqlfiles")
    os.makedirs(tempd)

    tempf1 = str(tmpdir / "sqlfiles/f.sql")

    tempf2 = str(tmpdir / "f2.sql")
    tempf3 = str(tmpdir / "f3.sql")

    # Close the handles deterministically: the original
    # io.open(path, "w").write(...) leaked the file objects and relied on
    # refcounting to flush/close them.
    with io.open(tempf1, "w") as f:
        f.write("create table x(a text);")
    with io.open(tempf2, "w") as f:
        f.write("select * from x;")
    with io.open(tempf3, "w") as f:
        f.write("")

    out = io.StringIO()

    with S(db) as s:
        load_sql_from_folder(s, str(tempd), out=out, verbose=True)
        load_sql_from_file(s, str(tempf2))
        load_sql_from_file(s, str(tempf3))

        assert sql_from_folder(str(tempd)) == ["create table x(a text);"]

        session_to_raw = raw_connection(s)
        raw_to_raw = raw_connection(session_to_raw)

        assert type(raw_to_raw) == type(core_to_raw) == type(session_to_raw)

    a = db
    b = create_engine(db)
    c = psycopg2.connect(db)

    for x in [a, b, c]:
        cc = get_raw_autocommit_connection(x)
        try:
            assert type(cc) == type(c)
            # E712: identity check instead of `== True`
            assert cc.autocommit is True
        finally:
            cc.close()

    with raises(ValueError):
        get_raw_autocommit_connection(1)

    with raises(NotImplementedError):
        _killquery("oracle", "db", False)
def test_orm_stuff():
    """Round-trip an ORM object and check the repr/changes/dict helpers."""
    with temporary_database() as url:
        with S(url) as s:
            Base.metadata.create_all(s.bind.engine)

        with S(url) as s:
            s.add(Something(name="kanye"))
            s.commit()

            fetched = s.query(Something).all()[0]

            prefix = "u" if six.PY2 else ""
            expected_repr = "Something(id=1, name={}'kanye')".format(prefix)
            assert repr(fetched) == str(fetched) == expected_repr

            assert metadata_from_session(s).schema == Base.metadata.schema
            assert fetched._sqlachanges == {}
            assert fetched._ordereddict == OrderedDict(
                [("id", 1), ("name", "kanye")]
            )

            fetched.name = "kanye west"
            assert fetched._sqlachanges == {"name": ["kanye", "kanye west"]}

            s.commit()
            # committing clears the pending attribute history
            assert fetched._sqlachanges == {}
This configuration file will run the 3 | # test suite on all supported python versions. To use it, "pip install tox" 4 | # and then run "tox" from this directory. 5 | 6 | [tox] 7 | envlist = py27,py36 8 | toxworkdir = {homedir}/.toxfiles{toxinidir} 9 | 10 | [testenv] 11 | commands = py.test \ 12 | [] # substitute with tox positional arguments 13 | 14 | deps = 15 | -rrequirements.txt 16 | 17 | 18 | [flake8] 19 | ignore = E501,D100,D101,D102,D103,D104,D105 20 | --------------------------------------------------------------------------------