├── .coveragerc ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── create_db.py ├── manifest.in ├── postpy ├── __init__.py ├── _version.py ├── admin.py ├── base.py ├── connections.py ├── data_types.py ├── ddl.py ├── dml.py ├── dml_copy.py ├── extensions.py ├── fixtures.py ├── formatting.py ├── pg_encodings.py ├── sql.py └── uuids.py ├── pytest.ini ├── requirements-dev.txt ├── requirements-test.txt ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── test_admin.py ├── test_base.py ├── test_data_types.py ├── test_ddl.py ├── test_dml.py ├── test_dml_copy.py ├── test_extensions.py ├── test_formatting.py ├── test_pg_encodings.py ├── test_sql.py └── test_uuids.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | 4 | [report] 5 | show_missing = true 6 | exclude_lines = 7 | def __repr__ 8 | def __str__ 9 | raise AssertionError 10 | raise NotImplementedError 11 | if __name__ == .__main__.: 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg* 3 | .coverage 4 | .tox 5 | venv 6 | _build 7 | build/ 8 | .DS_Store 9 | .cache/ 10 | .env 11 | .venv 12 | .envrc 13 | /dist 14 | /tmp 15 | 16 | # pycharm 17 | .idea/ 18 | 19 | # ipython 20 | .ipynb_checkpoints/ 21 | 22 | # Sublime 23 | /*.sublime-* 24 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | - "3.6" 5 | 6 | # current requirement to run postgres 9.5 > on travis 7 | sudo: required 8 | dist: trusty 9 | 10 | addons: 11 | postgresql: "9.6" 12 | env: 13 | global: 14 | - DB=postgresql 15 | - PGHOST=localhost 16 | - PGDATABASE=postpy_testing 17 | - PGUSER=postgres 18 | cache: 19 | directories: 20 | - 
$HOME/.cache/pip/wheels 21 | - $HOME/travis/virtualenv/python3.5.2 22 | 23 | install: 24 | - pip install --upgrade pip 25 | - pip install wheel 26 | - pip install -r requirements.txt 27 | - pip install -r requirements-test.txt 28 | - pip install codecov 29 | - python setup.py install 30 | 31 | script: 32 | - python create_db.py 33 | - py.test tests -s --cov 34 | 35 | after_success: 36 | - codecov 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Censible 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/portfoliome/postpy.svg?branch=master)](https://travis-ci.org/portfoliome/postpy) 2 | [![codecov.io](http://codecov.io/github/portfoliome/postpy/coverage.svg?branch=master)](http://codecov.io/github/portfoliome/postpy?branch=master) 3 | [![Code Health](https://landscape.io/github/portfoliome/postpy/master/landscape.svg?style=flat)](https://landscape.io/github/portfoliome/postpy/master) 4 | 5 | # postpy 6 | Postgresql utilities for ETL and data analysis. 7 | 8 | # Purpose 9 | postpy focuses on processes that typically arise from ETL processes and data analysis. Generally, these situations arise when third-party data providers provide a default schema and handle data migration. The benefits over sqlalchemy are dml statements accepting iterable sequences, and upsert statements prior to sqlalchemy 1.1. While the library protects against SQL injection, ddl compiler functions do not check against things like reserved keywords. 10 | 11 | # Example Usage 12 | 13 | Let's say a third-party provider has given you a JSON schema file, all referring to different zipped data files. 
14 | 15 | Mocking out a single file load might look something like: 16 | 17 | ```python 18 | import csv 19 | 20 | from foil.fileio import DelimitedReader 21 | from foil.parsers import parse_str, parse_int, passthrough 22 | 23 | from postpy import dml 24 | 25 | ENCODING = 'utf-8' 26 | 27 | class DataDialect(csv.Dialect): 28 | delimiter = '|' 29 | quotechar = '"' 30 | lineterminator = '\r\n' 31 | doublequote = False 32 | quoting = csv.QUOTE_NONE 33 | 34 | dialect = DataDialect() 35 | 36 | # Gathering table/file attributes 37 | 38 | tablename = 'my_table' 39 | fields = DelimitedReader.from_zipfile(zip_path, filename, encoding=ENCODING, 40 | dialect=dialect, fields=[], converters=[]).header 41 | field_parsers = [parse_str, parse_int, passthrough, parse_int] # would get through reflection or JSON file 42 | 43 | # loading one file and insert 44 | reader = DelimitedReader.from_zipfile(zip_path, filename, encoding=ENCODING, 45 | dialect=dialect, fields=fields, converters=field_parsers) 46 | 47 | # Insert records by loading only 10,000 records/file lines into memory each iteration 48 | dml.insert_many(conn, tablename, fields, records=reader, chunksize=10000) 49 | ``` 50 | 51 | Since each process is very light-weight, each loader can reside on a micro-instance. Queues like RabbitMQ or SNS/SQS can be set up to handle message notifications between each process. 52 | 53 | Instead of worrying about async/threads, each micro-instance can handle a single table load and pass off a message upon completion. 54 | 55 | # Potential Near-term Plans 56 | The ddl compilers may be converted to sqlalchemy compilers to allow for greater flexibility in constraint definitions without reducing code maintainability. Python 3.6's f-strings may be incorporated into the ddl compilers, breaking 3.5 compatibility. 
57 | -------------------------------------------------------------------------------- /create_db.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import psycopg2 5 | from psycopg2 import connect 6 | 7 | 8 | def main(): 9 | db_name = os.environ['PGDATABASE'] 10 | connection_parameters = { 11 | 'host': os.environ['PGHOST'], 12 | 'database': 'postgres', 13 | 'user': os.environ['PGUSER'], 14 | 'password': os.environ['PGPASSWORD'] 15 | } 16 | drop_statement = 'DROP DATABASE IF EXISTS {};'.format(db_name) 17 | ddl_statement = 'CREATE DATABASE {};'.format(db_name) 18 | conn = connect(**connection_parameters) 19 | conn.autocommit = True 20 | 21 | try: 22 | with conn.cursor() as cursor: 23 | cursor.execute(drop_statement) 24 | cursor.execute(ddl_statement) 25 | conn.close() 26 | sys.stdout.write('Created database environment successfully.\n') 27 | except psycopg2.Error: 28 | raise SystemExit( 29 | 'Failed to setup Postgres environment.\n{0}'.format(sys.exc_info()) 30 | ) 31 | 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /manifest.in: -------------------------------------------------------------------------------- 1 | include requirements*.txt 2 | include setup.py 3 | recursive-exclude * *.pyc *.pyo 4 | global-exclude .git* 5 | exclude .travis.yml 6 | include tox.ini 7 | include README.md -------------------------------------------------------------------------------- /postpy/__init__.py: -------------------------------------------------------------------------------- 1 | from postpy._version import version_info, __version__ 2 | 3 | from postpy.connections import connect 4 | from postpy.pg_encodings import get_postgres_encoding 5 | -------------------------------------------------------------------------------- /postpy/_version.py: -------------------------------------------------------------------------------- 1 | 
version_info = (0, 0, 9) 2 | 3 | __version__ = '.'.join(map(str, version_info)) 4 | -------------------------------------------------------------------------------- /postpy/admin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Database administration queries 3 | """ 4 | 5 | import psycopg2 6 | 7 | from postpy.base import Table, Column, Database, PrimaryKey 8 | from postpy.ddl import compile_qualified_name 9 | from postpy.extensions import install_extension 10 | from postpy.sql import select_dict 11 | 12 | 13 | def get_user_tables(conn): 14 | """Retrieve all user tables.""" 15 | 16 | query_string = "select schemaname, relname from pg_stat_user_tables;" 17 | with conn.cursor() as cursor: 18 | cursor.execute(query_string) 19 | tables = cursor.fetchall() 20 | 21 | return tables 22 | 23 | 24 | def get_primary_keys(conn, table: str, schema='public'): 25 | """Returns primary key columns for a specific table.""" 26 | 27 | query = """\ 28 | SELECT 29 | c.constraint_name AS pkey_constraint_name, 30 | c.column_name AS column_name 31 | FROM 32 | information_schema.key_column_usage AS c 33 | JOIN information_schema.table_constraints AS t 34 | ON t.constraint_name = c.constraint_name 35 | AND t.table_catalog = c.table_catalog 36 | AND t.table_schema = c.table_schema 37 | AND t.table_name = c.table_name 38 | WHERE t.constraint_type = 'PRIMARY KEY' 39 | AND c.table_schema=%s 40 | AND c.table_name=%s 41 | ORDER BY c.ordinal_position""" 42 | 43 | for record in select_dict(conn, query, params=(schema, table)): 44 | yield record['column_name'] 45 | 46 | 47 | def get_column_metadata(conn, table: str, schema='public'): 48 | """Returns column data following db.Column parameter specification.""" 49 | query = """\ 50 | SELECT 51 | attname as name, 52 | format_type(atttypid, atttypmod) AS data_type, 53 | NOT attnotnull AS nullable 54 | FROM pg_catalog.pg_attribute 55 | WHERE attrelid=%s::regclass 56 | AND attnum > 0 AND NOT attisdropped 57 | 
ORDER BY attnum;""" 58 | 59 | qualified_name = compile_qualified_name(table, schema=schema) 60 | 61 | for record in select_dict(conn, query, params=(qualified_name,)): 62 | yield record 63 | 64 | 65 | def reflect_table(conn, table_name, schema='public'): 66 | """Reflect basic table attributes.""" 67 | 68 | column_meta = list(get_column_metadata(conn, table_name, schema=schema)) 69 | primary_key_columns = list(get_primary_keys(conn, table_name, schema=schema)) 70 | 71 | columns = [Column(**column_data) for column_data in column_meta] 72 | primary_key = PrimaryKey(primary_key_columns) 73 | 74 | return Table(table_name, columns, primary_key, schema=schema) 75 | 76 | 77 | def reset(db_name): 78 | """Reset database.""" 79 | 80 | conn = psycopg2.connect(database='postgres') 81 | db = Database(db_name) 82 | conn.autocommit = True 83 | 84 | with conn.cursor() as cursor: 85 | cursor.execute(db.drop_statement()) 86 | cursor.execute(db.create_statement()) 87 | conn.close() 88 | 89 | 90 | def install_extensions(extensions, **connection_parameters): 91 | """Install Postgres extension if available. 92 | 93 | Notes 94 | ----- 95 | - superuser is generally required for installing extensions. 96 | - Currently does not support specific schema. 
97 | """ 98 | 99 | from postpy.connections import connect 100 | 101 | conn = connect(**connection_parameters) 102 | conn.autocommit = True 103 | 104 | for extension in extensions: 105 | install_extension(conn, extension) 106 | -------------------------------------------------------------------------------- /postpy/base.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | from foil.formatters import format_repr 4 | 5 | from postpy.ddl import ( 6 | compile_column, compile_qualified_name, compile_primary_key, 7 | compile_create_table, compile_create_temporary_table 8 | ) 9 | 10 | 11 | __all__ = ('Database', 'Schema', 'Table', 'Column', 'PrimaryKey', 'View') 12 | 13 | 14 | class Database: 15 | __slots__ = 'name', 16 | 17 | def __init__(self, name): 18 | self.name = name 19 | 20 | def create_statement(self): 21 | return 'CREATE DATABASE %s;' % self.name 22 | 23 | def drop_statement(self): 24 | return 'DROP DATABASE IF EXISTS %s;' % self.name 25 | 26 | def __repr__(self): 27 | return format_repr(self, self.__slots__) 28 | 29 | 30 | class Schema: 31 | __slots__ = 'name', 32 | 33 | def __init__(self, name): 34 | self.name = name 35 | 36 | def create_statement(self): 37 | return 'CREATE SCHEMA IF NOT EXISTS %s;' % self.name 38 | 39 | def drop_statement(self): 40 | return 'DROP SCHEMA IF EXISTS %s CASCADE;' % self.name 41 | 42 | def __repr__(self): 43 | return format_repr(self, self.__slots__) 44 | 45 | 46 | class Table(namedtuple('Table', 'name columns primary_key schema')): 47 | """Table statement formatter.""" 48 | 49 | __slots__ = () 50 | 51 | def __new__(cls, name: str, columns, primary_key, schema='public'): 52 | return super(Table, cls).__new__(cls, name, columns, 53 | primary_key, 54 | schema) 55 | 56 | def create_statement(self): 57 | return compile_create_table(self.qualified_name, 58 | self.column_statement, 59 | self.primary_key_statement) 60 | 61 | def drop_statement(self): 62 | 
return 'DROP TABLE IF EXISTS {};'.format(self.qualified_name) 63 | 64 | def create_temporary_statement(self): 65 | """Temporary Table Statement formatter.""" 66 | 67 | return compile_create_temporary_table(self.name, 68 | self.column_statement, 69 | self.primary_key_statement) 70 | 71 | def drop_temporary_statement(self): 72 | return 'DROP TABLE IF EXISTS {};'.format(self.name) 73 | 74 | @property 75 | def qualified_name(self): 76 | return compile_qualified_name(self.name, schema=self.schema) 77 | 78 | @property 79 | def column_names(self): 80 | return [column.name for column in self.columns] 81 | 82 | @property 83 | def primary_key_columns(self): 84 | return self.primary_key.column_names 85 | 86 | @property 87 | def column_statement(self): 88 | return ' '.join(c.create_statement() for c in self.columns) 89 | 90 | @property 91 | def primary_key_statement(self): 92 | return self.primary_key.create_statement() 93 | 94 | 95 | class Column(namedtuple('Column', 'name data_type nullable')): 96 | __slots__ = () 97 | 98 | def __new__(cls, name: str, data_type: str, nullable=False): 99 | return super(Column, cls).__new__(cls, name, data_type, nullable) 100 | 101 | def create_statement(self): 102 | return compile_column(self.name, self.data_type, self.nullable) 103 | 104 | 105 | class PrimaryKey(namedtuple('PrimaryKey', ['column_names'])): 106 | __slots__ = () 107 | 108 | def __new__(cls, column_names: list): 109 | return super(PrimaryKey, cls).__new__(cls, column_names) 110 | 111 | def create_statement(self): 112 | return compile_primary_key(self.column_names) 113 | 114 | 115 | class View: 116 | """Postgresql View statement formatter. 117 | 118 | Attributes 119 | ---------- 120 | name : view name 121 | statement: the select or join statement the view is based on. 
122 | """ 123 | 124 | def __init__(self, name: str, statement: str): 125 | self.name = name 126 | self.statement = statement 127 | 128 | def drop_statement(self): 129 | return 'DROP VIEW IF EXISTS {};'.format(self.name) 130 | 131 | def create_statement(self): 132 | return 'CREATE VIEW {name} AS {statement};'.format( 133 | name=self.name, statement=self.statement) 134 | 135 | 136 | def make_delete_table(table: Table, delete_prefix='delete_from__') -> Table: 137 | """Table referencing a delete from using primary key join.""" 138 | 139 | name = delete_prefix + table.name 140 | primary_key = table.primary_key 141 | key_names = set(primary_key.column_names) 142 | columns = [column for column in table.columns if column.name in key_names] 143 | table = Table(name, columns, primary_key) 144 | 145 | return table 146 | 147 | 148 | def split_qualified_name(qualified_name: str, schema='public'): 149 | if '.' in qualified_name: 150 | schema, table = qualified_name.split('.') 151 | else: 152 | table = qualified_name 153 | 154 | return schema, table 155 | 156 | 157 | def order_table_columns(table: Table, column_names: list) -> Table: 158 | """Record table column(s) and primary key columns by specified order.""" 159 | 160 | unordered_columns = table.column_names 161 | index_order = (unordered_columns.index(name) for name in column_names) 162 | ordered_columns = [table.columns[i] for i in index_order] 163 | ordered_pkey_names = [column for column in column_names 164 | if column in table.primary_key_columns] 165 | primary_key = PrimaryKey(ordered_pkey_names) 166 | 167 | return Table(table.name, ordered_columns, primary_key, table.schema) 168 | -------------------------------------------------------------------------------- /postpy/connections.py: -------------------------------------------------------------------------------- 1 | import os 2 | import psycopg2 3 | 4 | 5 | __all__ = ('connect',) 6 | 7 | 8 | def connect(host=None, database=None, user=None, password=None, **kwargs): 9 | 
"""Create a database connection.""" 10 | 11 | host = host or os.environ['PGHOST'] 12 | database = database or os.environ['PGDATABASE'] 13 | user = user or os.environ['PGUSER'] 14 | password = password or os.environ['PGPASSWORD'] 15 | 16 | return psycopg2.connect(host=host, 17 | database=database, 18 | user=user, 19 | password=password, 20 | **kwargs) 21 | -------------------------------------------------------------------------------- /postpy/data_types.py: -------------------------------------------------------------------------------- 1 | from datetime import date, datetime 2 | from decimal import Decimal 3 | from types import MappingProxyType 4 | 5 | from foil.compose import create_quantiles 6 | from psycopg2.extras import NumericRange 7 | 8 | 9 | TYPE_MAP = MappingProxyType({ 10 | 'bool': bool, 11 | 'boolean': bool, 12 | 'smallint': int, 13 | 'integer': int, 14 | 'bigint': int, 15 | 'real': float, 16 | 'float': float, 17 | 'double precision': float, 18 | 'decimal': Decimal, 19 | 'numeric': Decimal, 20 | 'char': str, 21 | 'character': str, 22 | 'text': str, 23 | 'varchar': str, 24 | 'character varying': str, 25 | 'date': date, 26 | 'timestamp': datetime 27 | }) 28 | 29 | 30 | def generate_numeric_range(items, lower_bound, upper_bound): 31 | """Generate postgresql numeric range and label for insertion. 32 | 33 | Parameters 34 | ---------- 35 | items: iterable labels for ranges. 36 | lower_bound: numeric lower bound 37 | upper_bound: numeric upper bound 38 | """ 39 | 40 | quantile_grid = create_quantiles(items, lower_bound, upper_bound) 41 | labels, bounds = (zip(*quantile_grid)) 42 | ranges = ((label, NumericRange(*bound)) 43 | for label, bound in zip(labels, bounds)) 44 | return ranges 45 | -------------------------------------------------------------------------------- /postpy/ddl.py: -------------------------------------------------------------------------------- 1 | """ 2 | ddl.py contains the Data Definition Language for Postgresql Server. 
3 | """ 4 | 5 | from psycopg2.extensions import AsIs 6 | 7 | 8 | def compile_qualified_name(table: str, schema='public') -> str: 9 | """Format table's fully qualified name string.""" 10 | 11 | return '{}.{}'.format(schema, table) 12 | 13 | 14 | def compile_create_table(qualified_name: str, column_statement: str, 15 | primary_key_statement: str) -> str: 16 | """Postgresql Create Table statement formatter.""" 17 | 18 | statement = """ 19 | CREATE TABLE {table} ({columns} {primary_keys}); 20 | """.format(table=qualified_name, 21 | columns=column_statement, 22 | primary_keys=primary_key_statement) 23 | return statement 24 | 25 | 26 | def compile_create_temporary_table(table_name: str, 27 | column_statement: str, 28 | primary_key_statement: str) -> str: 29 | """Postgresql Create Temporary Table statement formatter.""" 30 | 31 | statement = """ 32 | CREATE TEMPORARY TABLE {table} ({columns} {primary_keys}); 33 | """.format(table=table_name, 34 | columns=column_statement, 35 | primary_keys=primary_key_statement) 36 | return statement 37 | 38 | 39 | def compile_column(name: str, data_type: str, nullable: bool) -> str: 40 | """Create column definition statement.""" 41 | 42 | null_str = 'NULL' if nullable else 'NOT NULL' 43 | 44 | return '{name} {data_type} {null},'.format(name=name, 45 | data_type=data_type, 46 | null=null_str) 47 | 48 | 49 | def compile_primary_key(column_names): 50 | return 'PRIMARY KEY ({})'.format(', '.join(column_names)) 51 | 52 | 53 | class CreateTableAs: 54 | def __init__(self, table, parent_table, columns=('*',), *, clause): 55 | self.table = table 56 | self.parent_table = parent_table 57 | self.columns = columns 58 | self.column_str = ', '.join(columns) 59 | self.clause = clause 60 | 61 | def compile(self): 62 | statement = '{create} ({select} {clause})'.format( 63 | create=self._create_statement(), 64 | select=self._select_statement(), 65 | clause=self._clause_statement()) 66 | 67 | return statement 68 | 69 | def compile_with_cte(self, 
common_table_expression): 70 | statement = '{create} (WITH {cte} {select} {clause})'.format( 71 | create=self._create_statement(), 72 | cte=common_table_expression, 73 | select=self._select_statement(), 74 | clause=self._clause_statement()) 75 | 76 | return statement 77 | 78 | def _create_statement(self): 79 | return 'CREATE TABLE {} AS'.format(self.table) 80 | 81 | def _select_statement(self): 82 | return '\n SELECT {column_str} \n FROM {parent_table}'.format( 83 | column_str=self.column_str, parent_table=self.parent_table) 84 | 85 | def _clause_statement(self): 86 | return '\n WHERE %s' % self.clause 87 | 88 | 89 | class MaterializedView: 90 | """Postgres materialized view declaration formatter.""" 91 | 92 | def __init__(self, name, query='', query_values=None): 93 | self.name = name 94 | self.query = query 95 | self.query_values = query_values 96 | 97 | def create(self, no_data=False): 98 | """Declare materalized view.""" 99 | 100 | if self.query: 101 | ddl_statement = self.compile_create_as() 102 | else: 103 | ddl_statement = self.compile_create() 104 | 105 | if no_data: 106 | ddl_statement += '\nWITH NO DATA' 107 | 108 | return ddl_statement, self.query_values 109 | 110 | def compile_create(self): 111 | """Materalized view.""" 112 | 113 | return 'CREATE MATERIALIZED VIEW {}'.format(AsIs(self.name)) 114 | 115 | def compile_create_as(self): 116 | """Build from a select statement.""" 117 | 118 | return '{} AS \n {}'.format(self.compile_create(), self.query) 119 | 120 | def refresh(self): 121 | """Refresh a materialized view.""" 122 | 123 | return 'REFRESH MATERIALIZED VIEW {}'.format(AsIs(self.name)) 124 | 125 | def drop(self): 126 | return 'DROP MATERIALIZED VIEW {}'.format(AsIs(self.name)) 127 | -------------------------------------------------------------------------------- /postpy/dml.py: -------------------------------------------------------------------------------- 1 | """Data Manipulation Language for Postgresql.""" 2 | 3 | import warnings 4 | 5 | from 
foil.iteration import chunks 6 | from psycopg2.extras import NamedTupleCursor 7 | 8 | from postpy.base import make_delete_table 9 | from postpy.formatting import PARAM_STYLES, PYFORMAT 10 | from postpy.sql import execute_transaction 11 | from postpy.dml_copy import BulkDmlPrimaryKey, CopyFromCsvBase, copy_from_csv_sql 12 | 13 | 14 | def create_insert_statement(qualified_name, column_names, table_alias='', 15 | param_style=PYFORMAT): 16 | 17 | column_string = ', '.join(column_names) 18 | param_func = PARAM_STYLES.get(param_style) 19 | value_string = param_func(column_names) 20 | 21 | if table_alias: 22 | table_alias = ' AS %s' % table_alias 23 | 24 | return 'INSERT INTO {0}{1} ({2}) VALUES ({3})'.format(qualified_name, 25 | table_alias, 26 | column_string, 27 | value_string) 28 | 29 | 30 | def insert(conn, qualified_name: str, column_names, records): 31 | """Insert a collection of namedtuple records.""" 32 | 33 | query = create_insert_statement(qualified_name, column_names) 34 | 35 | with conn: 36 | with conn.cursor(cursor_factory=NamedTupleCursor) as cursor: 37 | for record in records: 38 | cursor.execute(query, record) 39 | 40 | 41 | def insert_many(conn, tablename, column_names, records, chunksize=2500): 42 | """Insert many records by chunking data into insert statements. 43 | 44 | Notes 45 | ----- 46 | records should be Iterable collection of namedtuples or tuples. 
47 | """ 48 | 49 | groups = chunks(records, chunksize) 50 | column_str = ','.join(column_names) 51 | insert_template = 'INSERT INTO {table} ({columns}) VALUES {values}'.format( 52 | table=tablename, columns=column_str, values='{0}') 53 | 54 | with conn: 55 | with conn.cursor() as cursor: 56 | for recs in groups: 57 | record_group = list(recs) 58 | records_template_str = ','.join(['%s'] * len(record_group)) 59 | insert_query = insert_template.format(records_template_str) 60 | cursor.execute(insert_query, record_group) 61 | 62 | 63 | def upsert_records(conn, records, upsert_statement): 64 | """Upsert records.""" 65 | 66 | with conn: 67 | with conn.cursor() as cursor: 68 | for record in records: 69 | cursor.execute(upsert_statement, record) 70 | 71 | 72 | def format_upsert(qualified_name, column_names, constraint, clause='', 73 | table_alias='current', param_style=PYFORMAT): 74 | insert_template = create_insert_statement( 75 | qualified_name, column_names, table_alias=table_alias, 76 | param_style=param_style 77 | ) 78 | statement = format_upsert_expert(insert_template, column_names, 79 | constraint, clause, table_alias) 80 | return statement 81 | 82 | 83 | def format_upsert_expert(insert_template, column_names, constraint, clause='', 84 | table_alias='current'): 85 | 86 | constraint_str = ', '.join(constraint) 87 | non_key_columns = [column for column in column_names if column not in constraint] 88 | 89 | if non_key_columns: 90 | non_key_column_str = ', '.join(non_key_columns) 91 | excluded_str = ', '.join('EXCLUDED.' 
+ column for column in non_key_columns) 92 | action = ( 93 | ' DO UPDATE' 94 | ' SET ({non_key_columns}) = ({excluded})' 95 | ' {clause}').format(non_key_columns=non_key_column_str, 96 | excluded=excluded_str, clause=clause, 97 | table_alias=table_alias) 98 | else: 99 | action = ' DO NOTHING' 100 | 101 | statement = ( 102 | '{insert_template}' 103 | ' ON CONFLICT ({constraint})' 104 | '{action}').format(insert_template=insert_template, 105 | constraint=constraint_str, 106 | action=action) 107 | 108 | return statement 109 | 110 | 111 | def format_upsert_primary_key(qualified_name, column_names, primary_key_names, 112 | param_style=PYFORMAT): 113 | warning = 'Deprecation Warning. Function will be removed as of version 0.1.0.' 114 | warnings.warn(warning, DeprecationWarning) 115 | 116 | query = format_upsert(qualified_name, column_names, primary_key_names, 117 | param_style=param_style) 118 | 119 | return query 120 | 121 | 122 | class DeleteManyPrimaryKey: 123 | """Deletes subset of table rows. 124 | 125 | Uses DELETE FROM in conjunction with a where clause 126 | through a temporary table reference containing primary keys 127 | to delete. 
128 | """ 129 | 130 | def __init__(self, table): 131 | self.table = table 132 | self.delete_table = make_delete_table(table) 133 | 134 | def __call__(self, conn, records, chunksize=2500): 135 | with conn: 136 | execute_transaction(conn, 137 | [self.delete_table.create_temporary_statement()]) 138 | 139 | insert_many( 140 | conn, self.delete_table.name, self.delete_table.column_names, 141 | records, chunksize=chunksize) 142 | 143 | delete_from_statement = delete_joined_table_sql( 144 | self.table.qualified_name, self.delete_table.name, 145 | self.table.primary_key.column_names) 146 | 147 | execute_transaction(conn, [delete_from_statement]) 148 | 149 | with conn: 150 | execute_transaction(conn, 151 | [self.delete_table.drop_temporary_statement()]) 152 | 153 | 154 | class UpsertPrimaryKey: 155 | def __init__(self, qualified_name, column_names, primary_key_names): 156 | 157 | self.query = format_upsert( 158 | qualified_name, column_names, primary_key_names 159 | ) 160 | 161 | def __call__(self, conn, records): 162 | upsert_records(conn, records, self.query) 163 | 164 | 165 | def delete_joined_table_sql(qualified_name, removing_qualified_name, primary_key): 166 | """SQL statement for a joined delete from. 167 | Generate SQL statement for deleting the intersection of rows between 168 | both tables from table referenced by tablename. 
169 | """ 170 | 171 | condition_template = 't.{}=d.{}' 172 | where_clause = ' AND '.join(condition_template.format(pkey, pkey) 173 | for pkey in primary_key) 174 | delete_statement = ( 175 | 'DELETE FROM {table} t' 176 | ' USING {delete_table} d' 177 | ' WHERE {where_clause}').format(table=qualified_name, 178 | delete_table=removing_qualified_name, 179 | where_clause=where_clause) 180 | return delete_statement 181 | 182 | 183 | def compile_truncate_table(qualfied_name): 184 | """Delete all data in table and vacuum.""" 185 | 186 | return 'TRUNCATE %s CASCADE;' % qualfied_name 187 | 188 | 189 | class CopyFrom(CopyFromCsvBase): 190 | """Copy from CSV file object.""" 191 | 192 | def __call__(self, conn, file_object): 193 | with conn.cursor() as cursor: 194 | cursor.copy_expert(self.copy_sql, file_object) 195 | 196 | 197 | class CopyFromUpsert(BulkDmlPrimaryKey): 198 | """Upsert subset of table rows contained in a file stream. 199 | 200 | Upsert rows based on same composite primary key. 201 | """ 202 | 203 | TEMP_PREFIX = 'tmp_bulk_upsert' 204 | 205 | _INSERT_TEMPLATE = ( 206 | 'INSERT INTO {table} ({columns})\n' 207 | ' SELECT {columns} FROM {temp_table}\n' 208 | ) 209 | 210 | def make_dml_query(self): 211 | query = self._INSERT_TEMPLATE.format( 212 | table=self.table.qualified_name, columns=self.column_str, 213 | temp_table=self.copy_table.name 214 | ) 215 | query = format_upsert_expert(query, self.table.column_names, 216 | self.table.primary_key_columns) 217 | 218 | return query 219 | 220 | 221 | class CopyFromDelete(BulkDmlPrimaryKey): 222 | """Deletes subset of table rows contained in a file stream. 223 | 224 | Deletes rows with matching composite primary key. 
225 | """ 226 | 227 | TEMP_PREFIX = 'delete_from' 228 | 229 | def make_dml_query(self): 230 | delete_from_statement = delete_joined_table_sql( 231 | self.table.qualified_name, self.copy_table.name, 232 | self.table.primary_key.column_names) 233 | 234 | return delete_from_statement 235 | 236 | 237 | def copy_from_csv(conn, file, qualified_name: str, delimiter=',', encoding='utf8', 238 | null_str='', header=True, escape_str='\\', quote_char='"', 239 | force_not_null=None, force_null=None): 240 | """Copy file-like object to database table. 241 | 242 | Notes 243 | ----- 244 | Implementation defaults to postgres standard except for encoding. 245 | Postgres falls back on client encoding, while function defaults to utf-8. 246 | 247 | References 248 | ---------- 249 | https://www.postgresql.org/docs/current/static/sql-copy.html 250 | 251 | """ 252 | 253 | copy_sql = copy_from_csv_sql(qualified_name, delimiter, encoding, 254 | null_str=null_str, header=header, 255 | escape_str=escape_str, quote_char=quote_char, 256 | force_not_null=force_not_null, 257 | force_null=force_null) 258 | 259 | with conn: 260 | with conn.cursor() as cursor: 261 | cursor.copy_expert(copy_sql, file) 262 | -------------------------------------------------------------------------------- /postpy/dml_copy.py: -------------------------------------------------------------------------------- 1 | """Copy specific dml statement generators.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from random import randint 5 | 6 | from postpy.base import Table 7 | from postpy.pg_encodings import get_postgres_encoding 8 | 9 | 10 | class CopyFromCsvBase(ABC): 11 | def __init__(self, table, delimiter=',', encoding='utf8', 12 | null_str='', header=True, escape_str='\\', quote_char='"', 13 | force_not_null=None, force_null=None): 14 | self.table = table 15 | self.copy_table, self.copy_name = self.get_copy_table(self.table) 16 | self.copy_sql = copy_from_csv_sql(self.copy_name, 17 | delimiter, encoding, 18 | 
class BulkDmlPrimaryKey(CopyFromCsvBase):
    """Row record changes on primary key join.

    Loads the incoming CSV stream into a randomly-named temporary table,
    runs a subclass-provided DML statement joining that temp table to the
    target table, then drops the temp table.  Subclasses supply the DML
    via make_dml_query().
    """

    # Bounds for the random suffix that keeps temp-table names unique.
    RAND_MIN = 0
    RAND_MAX = 10000000
    _TEMP_FORMATTER = '{temp_prefix}_{random}_{table_name}'
    # Subclasses override with a descriptive prefix (e.g. 'tmp_bulk_upsert').
    TEMP_PREFIX = ''

    def __init__(self, table, **kwargs):
        super().__init__(table, **kwargs)
        # DML depends only on table metadata, so compile it once up front.
        self.dml_query = self.make_dml_query()

    def __call__(self, conn, file_object):
        # One cursor drives the full lifecycle, in this exact order:
        # create temp table -> COPY stream in -> apply DML -> drop temp table.
        with conn.cursor() as cursor:
            cursor.execute(self.copy_table.create_temporary_statement())
            cursor.copy_expert(self.copy_sql, file_object)
            cursor.execute(self.dml_query)
            cursor.execute(self.copy_table.drop_temporary_statement())

    def get_copy_table(self, table):
        # Overrides CopyFromCsvBase.get_copy_table so the COPY statement
        # targets the temporary table instead of the real table.
        temp_table = self.make_temp_copy_table()
        qualified_name = temp_table.name

        return temp_table, qualified_name

    def make_temp_copy_table(self):
        # Clone the target table's metadata under the generated temp name.
        temp_table_name = self.generate_temp_table_name()
        table_attributes = self.table._asdict()
        table_attributes['name'] = temp_table_name

        return Table(**table_attributes)

    def generate_temp_table_name(self):
        # Random integer suffix reduces name collisions between concurrent loads.
        rand_char = randint(self.RAND_MIN, self.RAND_MAX)
        temp_table_name = self._TEMP_FORMATTER.format(
            temp_prefix=self.TEMP_PREFIX,
            table_name=self.table.name,
            random=rand_char
        )

        return temp_table_name

    @property
    def column_str(self):
        # Comma-separated column list for embedding into SQL statements.
        return ', '.join(self.table.column_names)

    @abstractmethod
    def make_dml_query(self):
        # Placeholder body (a bare expression, deliberately a no-op);
        # subclasses must return the join DML statement string.
        NotImplemented
quote_char='"', 85 | force_not_null=None, force_null=None): 86 | """Generate copy from csv statement.""" 87 | 88 | options = [] 89 | options.append("DELIMITER '%s'" % delimiter) 90 | options.append("NULL '%s'" % null_str) 91 | 92 | if header: 93 | options.append('HEADER') 94 | 95 | options.append("QUOTE '%s'" % quote_char) 96 | options.append("ESCAPE '%s'" % escape_str) 97 | 98 | if force_not_null: 99 | options.append(_format_force_not_null(column_names=force_not_null)) 100 | 101 | if force_null: 102 | options.append(_format_force_null(column_names=force_null)) 103 | 104 | postgres_encoding = get_postgres_encoding(encoding) 105 | options.append("ENCODING '%s'" % postgres_encoding) 106 | 107 | copy_sql = _format_copy_csv_sql(qualified_name, copy_options=options) 108 | 109 | return copy_sql 110 | 111 | 112 | def _format_copy_csv_sql(qualified_name: str, copy_options: list) -> str: 113 | options_str = ',\n '.join(copy_options) 114 | 115 | copy_sql = """\ 116 | COPY {table} FROM STDIN 117 | WITH ( 118 | FORMAT CSV, 119 | {options})""".format(table=qualified_name, options=options_str) 120 | 121 | return copy_sql 122 | 123 | 124 | def _format_force_not_null(column_names): 125 | column_str = ', '.join(column_names) 126 | force_not_null_str = 'FORCE_NOT_NULL ({})'.format(column_str) 127 | return force_not_null_str 128 | 129 | 130 | def _format_force_null(column_names): 131 | column_str = ', '.join(column_names) 132 | force_null_str = 'FORCE_NULL ({})'.format(column_str) 133 | return force_null_str 134 | -------------------------------------------------------------------------------- /postpy/extensions.py: -------------------------------------------------------------------------------- 1 | import psycopg2 2 | from psycopg2._psycopg import AsIs 3 | 4 | 5 | def install_extension(conn, extension: str): 6 | """Install Postgres extension.""" 7 | 8 | query = 'CREATE EXTENSION IF NOT EXISTS "%s";' 9 | 10 | with conn.cursor() as cursor: 11 | cursor.execute(query, 
def check_extension(conn, extension: str) -> bool:
    """Check to see if an extension is installed.

    Looks the extension up in pg_available_extensions and returns whether
    an installed version is recorded.

    Raises
    ------
    psycopg2.ProgrammingError
        If the extension is not in the catalog at all, i.e. it is not
        even available for installation on this server.
    """

    query = 'SELECT installed_version FROM pg_available_extensions WHERE name=%s;'

    with conn.cursor() as cursor:
        cursor.execute(query, (extension,))
        result = cursor.fetchone()

    if result is None:
        # No catalog row: the extension cannot be installed here at all.
        raise psycopg2.ProgrammingError(
            'Extension is not available for installation.', extension
        )
    else:
        # Row exists: installed_version is NULL (-> None -> False) when the
        # extension is available but not currently installed.
        extension_version = result[0]

    return bool(extension_version)
def squeeze_whitespace(text):
    """Collapse every run of whitespace in *text* into a single space.

    Tabs and newlines are treated like spaces, and leading/trailing
    whitespace is stripped (str.split() discards empty edge fields).
    """

    tokens = text.split()

    return ' '.join(tokens)
def get_postgres_encoding(python_encoding: str) -> str:
    """Python to postgres encoding map.

    Canonicalizes a Python codec name (e.g. 'utf-8', 'latin_1') through
    the stdlib encodings alias table, then maps the canonical name onto
    psycopg2's Postgres encoding name.

    Raises
    ------
    KeyError
        If the codec has no alias entry or no Postgres equivalent
        (both lookups are plain dict indexing).
    """

    # normalize_encoding collapses punctuation runs to underscores;
    # lower() first so the alias lookup keys match.
    encoding = normalize_encoding(python_encoding.lower())
    # Look up with the first underscore removed (e.g. 'utf_8' -> 'utf8'),
    # matching the alias table's key style, then upper-case the result.
    encoding_ = aliases.aliases[encoding.replace('_', '', 1)].upper()
    # psycopg2's map keys carry no underscores, so strip them all.
    pg_encoding = PG_ENCODING_MAP[encoding_.replace('_', '')]

    return pg_encoding
def select_each(conn, query: str, parameter_groups, name=None):
    """Run select query for each parameter set in single transaction.

    Parameters
    ----------
    conn : database connection
    query : select query string
    parameter_groups : iterable of parameter tuples/dicts, one per execution.
    name : server side cursor name. defaults to client side.

    Yields
    ------
    The first row of each execution's result set (cursor.fetchone()),
    or None when an execution returns no rows.

    Notes
    -----
    This is a generator: the transaction opened by ``with conn`` stays
    open until the generator is exhausted or closed.
    """

    with conn:
        with conn.cursor(name=name) as cursor:
            for parameters in parameter_groups:
                cursor.execute(query, parameters)
                # Only the first row per execution is surfaced.
                yield cursor.fetchone()
97 | """ 98 | 99 | with conn.cursor(name) as cursor: 100 | cursor.itersize = 1 101 | cursor.execute(query) 102 | cursor.fetchmany(0) 103 | column_names = [column.name for column in cursor.description] 104 | 105 | return column_names 106 | -------------------------------------------------------------------------------- /postpy/uuids.py: -------------------------------------------------------------------------------- 1 | """Configure psycopg2 to support UUID conversion.""" 2 | 3 | import psycopg2.extras 4 | 5 | from postpy.admin import install_extensions 6 | 7 | 8 | CRYPTO_EXTENSION = 'pgcrypto' 9 | UUID_OSSP_EXTENSION = 'uuid-ossp' 10 | 11 | 12 | def register_client(): 13 | """Have psycopg2 marshall UUID objects automatically.""" 14 | 15 | psycopg2.extras.register_uuid() 16 | 17 | 18 | def register_crypto(): 19 | """Support for UUID's on server side. 20 | 21 | Lighter dependency than uuid-ossp supporting 22 | random_uuid_function for UUID generation. 23 | """ 24 | 25 | install_extensions([CRYPTO_EXTENSION]) 26 | 27 | 28 | def register_uuid(): 29 | """Support for UUID's on server side. 30 | 31 | Notes 32 | ----- 33 | uuid-ossp can be problematic on some platforms. See: 34 | https://www.postgresql.org/docs/current/static/uuid-ossp.html 35 | """ 36 | 37 | install_extensions([UUID_OSSP_EXTENSION]) 38 | 39 | 40 | def random_uuid_function(schema=None): 41 | """Cryptographic random UUID function. 42 | 43 | Generates random database side UUID's. 44 | 45 | Notes 46 | ----- 47 | Lighter dependency than uuid-ossp, but higher 48 | fragmentation on disk if used as auto-generating primary key UUID. 49 | """ 50 | 51 | return '{}gen_random_uuid()'.format(_format_schema(schema)) 52 | 53 | 54 | def uuid_sequence_function(schema=None): 55 | """Sequential UUID generation. 56 | 57 | Sequential UUID creation on database side offering 58 | less table fragmentation issues when used as UUID primary key. 
59 | """ 60 | 61 | return '{}uuid_generate_v1mc()'.format(_format_schema(schema)) 62 | 63 | 64 | def _format_schema(schema): 65 | return '{}.'.format(schema) if schema else '' 66 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests 3 | norecursedirs = .tox build dist *.egg-info 4 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | pylint 3 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest-cov 2 | pytest -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | foil 2 | psycopg2 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [wheel] 2 | python-tag = py35 3 | 4 | [bdist_wheel] 5 | python-tag = py35 6 | 7 | [aliases] 8 | test=pytest 9 | 10 | [flake8] 11 | ignore = E302,E402,F403,E265,E201,E124,E202,E123,E731 12 | max-line-length = 90 13 | exclude = .git,__pycache__,.tox,*.egg 14 | max-complexity = 15 15 | 16 | [metadata] 17 | description-file = README.md 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | from setuptools import find_packages, setup 5 | 6 | ENCODING = 'utf-8' 7 | PACKAGE_NAME = 'postpy' 8 | 9 | local_directory = os.path.abspath(os.path.dirname(__file__)) 10 | version_path = 
def get_requirements(requirement_file):
    """Read a pip requirements file into a list of requirement strings.

    Parameters
    ----------
    requirement_file : str
        Path to a requirements.txt-style file.

    Returns
    -------
    list of str
        One entry per non-blank line, whitespace-stripped.

    Notes
    -----
    The original implementation split on the literal '\\r\\n', which only
    works for CRLF files and returned the whole file as one garbled entry
    on Unix (LF) checkouts; splitlines() handles both.  It also leaked
    the file handle, fixed here with a context manager.
    """

    with open(requirement_file, 'r', encoding=ENCODING) as f:
        lines = f.read().splitlines()

    return [line.strip() for line in lines if line.strip()]
class TestTableStats(PostgreSQLFixture, unittest.TestCase):
    """Integration tests for the introspection helpers in postpy.admin.

    Builds a dedicated schema + table once per class, reflects metadata
    out of it, and drops the schema afterwards.
    """

    @classmethod
    def _prep(cls):
        # Fixture hook invoked by PostgreSQLFixture.setUpClass after connect.
        cls.conn.autocommit = True
        cls.schema = 'stats_test'
        cls.table = 'admin_table_tests'
        create_table_statement = """\
        CREATE TABLE {schema}.{table} (
          mycol CHAR(2),
          mycol2 CHAR(3) NULL,
          PRIMARY KEY (mycol));""".format(schema=cls.schema,
                                          table=cls.table)

        with cls.conn.cursor() as cursor:
            cursor.execute('CREATE SCHEMA {};'.format(cls.schema))
            cursor.execute(create_table_statement)

    def test_get_user_tables(self):
        # The fixture table should be listed as a (schema, table) pair.
        expected = (self.schema, self.table)
        result = get_user_tables(self.conn)

        self.assertIn(expected, result)

    def test_get_column_meta_data(self):
        # Column metadata reports normalized type names and nullability.
        expected = [
            {'name': 'mycol',
             'data_type': 'character(2)',
             'nullable': False},
            {'name': 'mycol2',
             'data_type': 'character(3)',
             'nullable': True}
        ]
        result = list(
            get_column_metadata(self.conn, self.table, schema=self.schema)
        )

        self.assertEqual(expected, result)

    def test_get_primary_keys(self):
        expected = ['mycol']
        result = list(get_primary_keys(self.conn, self.table, self.schema))

        self.assertEqual(expected, result)

    def test_reflect_table(self):
        # reflect_table should reconstruct the full Table definition.
        columns = [Column('mycol', data_type='character(2)', nullable=False),
                   Column('mycol2', data_type='character(3)', nullable=True)]
        primary_key = PrimaryKey(['mycol'])

        expected = Table(self.table, columns, primary_key, schema=self.schema)
        result = reflect_table(self.conn, self.table, self.schema)

        self.assertEqual(expected, result)

    @classmethod
    def _clean(cls):
        # Fixture hook invoked by PostgreSQLFixture.tearDownClass before close.
        statement = 'DROP SCHEMA IF EXISTS {} CASCADE;'.format(cls.schema)

        with cls.conn.cursor() as cursor:
            cursor.execute(statement)
def table_columns():
    """Build the shared column fixture for the city/state test table."""
    specs = [
        ('city', 'VARCHAR(50)', False),
        ('state', 'CHAR(2)', False),
        ('population', 'INTEGER', True),
    ]
    return [Column(name=name, data_type=data_type, nullable=nullable)
            for name, data_type, nullable in specs]


def table_primary_keys():
    """Build the shared composite primary-key fixture (city, state)."""
    key_columns = ['city', 'state']
    return PrimaryKey(key_columns)
class TestSchemaStatements(PostgresStatementFixture, unittest.TestCase):
    """DDL statement generation for postpy.base.Schema."""

    def setUp(self):
        self.schema_name = 'test_schema'
        self.schema = Schema(self.schema_name)

    def test_create_schema_statement(self):
        # Idempotent CREATE SCHEMA form.
        statement = self.schema.create_statement()

        self.assertSQLStatementEqual(
            'CREATE SCHEMA IF NOT EXISTS test_schema;', statement)

    def test_drop_schema_statement(self):
        # Cascading, idempotent DROP SCHEMA form.
        statement = self.schema.drop_statement()

        self.assertSQLStatementEqual(
            'DROP SCHEMA IF EXISTS test_schema CASCADE;', statement)
    def test_create_temporary_statement(self):
        """Temporary-table DDL uses an unqualified name.

        The Table here is built without a schema argument, since temporary
        tables live in a session-private schema and cannot be qualified.
        """
        temp_table = Table(self.tablename, self.columns, self.primary_keys)

        expected = ('CREATE TEMPORARY TABLE create_table_test ('
                    'city VARCHAR(50) NOT NULL, '
                    'state CHAR(2) NOT NULL, '
                    'population INTEGER NULL, '
                    'PRIMARY KEY (city, state));')
        result = temp_table.create_temporary_statement()

        self.assertSQLStatementEqual(expected, result)
class TestCreateTableEvent(PostgreSQLFixture, unittest.TestCase):
    """Integration tests executing Table create/drop DDL on live Postgres."""

    def setUp(self):
        self.tablename = 'create_table_event'
        self.columns = table_columns()
        self.primary_key = table_primary_keys()
        # Existence probe: pg_stat_user_tables lists the session's user tables.
        self.table_query = "select relname from pg_stat_user_tables where relname=%s;"
        self.table = Table(self.tablename, self.columns, self.primary_key)

    def test_create_table(self):
        # Create the table, commit, then confirm it appears in the stats view.
        with self.conn.cursor() as cursor:
            cursor.execute(self.table.create_statement())
        self.conn.commit()

        with self.conn.cursor() as cursor:
            cursor.execute(self.table_query, (self.tablename,))
            table = cursor.fetchone()

        self.assertEqual((self.tablename,), table)

    def test_temporary_table(self):
        # Temp table should be visible while it exists and gone after drop,
        # all within one transaction/cursor.
        with self.conn as conn:
            with conn.cursor() as cursor:
                cursor.execute(self.table.create_temporary_statement())
                cursor.execute(self.table_query, (self.tablename,))
                temp_table = cursor.fetchone()
                cursor.execute(self.table.drop_temporary_statement())
                cursor.execute(self.table_query, (self.tablename,))
                no_table = cursor.fetchone()

        self.assertEqual((self.tablename,), temp_table)
        self.assertTrue(no_table is None)

    def tearDown(self):
        # Drop the non-temporary table so each test starts from a clean slate.
        with self.conn.cursor() as cursor:
            cursor.execute('DROP TABLE IF EXISTS {table};'.format(
                table=self.tablename))
        self.conn.commit()
class TestMakeDeleteTable(unittest.TestCase):
    """make_delete_table should keep only the primary-key columns."""

    def setUp(self):
        self.delete_prefix = 'delete_from__'
        self.primary_key_column = Column(name='city',
                                         data_type='VARCHAR(50)',
                                         nullable=False)
        population_column = Column(name='population',
                                   data_type='INTEGER',
                                   nullable=True)
        self.columns = [self.primary_key_column, population_column]
        self.primary_key = PrimaryKey(['city'])
        self.tablename = 'original_table'
        self.schema = 'to_delete'
        self.table = Table(self.tablename, self.columns,
                           self.primary_key, self.schema)

    def test_make_delete_table(self):
        result = make_delete_table(self.table, delete_prefix=self.delete_prefix)

        # Name gains the prefix; non-key columns are dropped; key survives.
        self.assertEqual(self.delete_prefix + self.tablename, result.name)
        self.assertEqual([self.primary_key_column], result.columns)
        self.assertEqual(self.primary_key, result.primary_key)
class TestNumericRange(unittest.TestCase):
    """Exercise postpy.data_types.generate_numeric_range."""

    def test_generate_numeric_range(self):
        """Labels pair with equal-width, contiguous numeric ranges."""
        labels = ('freezing', 'cold', 'cool', 'hot')
        lower_bound = 0.
        upper_bound = 100.

        # Four labels across [0, 100] -> four 25-wide adjacent ranges.
        bounds = [0., 25., 50., 75., 100.]
        expected = [
            (label, NumericRange(low, high))
            for label, low, high in zip(labels, bounds, bounds[1:])
        ]

        result = list(generate_numeric_range(labels, lower_bound, upper_bound))

        self.assertEqual(expected, result)
class TestCreateTableAs(PostgresStatementFixture, unittest.TestCase):
    """Statement compilation for ddl.CreateTableAs."""

    def setUp(self):
        # Derived table 't' selects two columns from parent 'p'
        # filtered by the WHERE clause below.
        self.create = ddl.CreateTableAs(
            't', 'p', ['one', 'two'], clause='one=two')

    def test_compile(self):
        result = self.create.compile()
        expected = "CREATE TABLE t AS (\nSELECT one, two FROM p WHERE one=two)"

        self.assertSQLStatementEqual(expected, result)

    def test_compile_with_cte(self):
        result = self.create.compile_with_cte('t2 AS (SELECT * FROM t3)')
        expected = ("CREATE TABLE t AS (WITH t2 AS (SELECT * FROM t3)"
                    " SELECT one, two FROM p WHERE one=two)")

        self.assertSQLStatementEqual(expected, result)
def make_records():
    """Return (column_names, rows) fixture data for the city/state tests."""
    column_names = ['city', 'state']
    Record = namedtuple('Record', column_names)
    # 'Zootopia' deliberately carries a NULL state for force-null tests.
    rows = [Record(*values) for values in (('Chicago', 'IL'),
                                           ('New York', 'NY'),
                                           ('Zootopia', None),
                                           ('Miami', 'FL'))]
    return column_names, rows
62 | self.columns, self.records = make_records() 63 | self.table_name = 'insert_record_table' 64 | create_table_stmt = """CREATE TABLE {table} ( 65 | city VARCHAR(50), 66 | state char(2) NULL, 67 | PRIMARY KEY (city)); 68 | """.format(table=self.table_name) 69 | 70 | with self.conn.cursor() as cursor: 71 | cursor.execute(create_table_stmt) 72 | self.conn.commit() 73 | 74 | def test_insert(self): 75 | dml.insert(self.conn, self.table_name, self.columns, self.records) 76 | 77 | expected = self.records 78 | result = get_records(self.conn, self.table_name) 79 | 80 | self.assertEqual(expected, result) 81 | 82 | def test_insert_many_namedtuples(self): 83 | dml.insert_many(self.conn, self.table_name, self.columns, 84 | self.records, chunksize=2) 85 | 86 | expected = self.records 87 | result = get_records(self.conn, self.table_name) 88 | 89 | self.assertEqual(expected, result) 90 | 91 | def test_insert_many_tuples(self): 92 | records = [record[:] for record in self.records] 93 | dml.insert_many(self.conn, self.table_name, 94 | self.columns, records, chunksize=4) 95 | 96 | expected = self.records 97 | result = get_records(self.conn, self.table_name) 98 | 99 | self.assertEqual(expected, result) 100 | 101 | def test_copy_from_csv(self): 102 | self.columns, self.records = make_records() 103 | file_object = io.StringIO(delimited_text()) 104 | 105 | dml.copy_from_csv(self.conn, file_object, self.table_name, '|', 106 | force_null=['state'], encoding='utf-8', null_str='') 107 | 108 | result = get_records(self.conn, self.table_name) 109 | 110 | self.assertEqual(self.records, result) 111 | 112 | 113 | class TestUpsert(PostgresDmlFixture, unittest.TestCase): 114 | 115 | def setUp(self): 116 | self.maxDiff = None 117 | self.table_name = 'upsert_test1' 118 | self.column_names = ['ticker', 'report_date', 'score'] 119 | 120 | with self.conn.cursor() as cursor: 121 | cursor.execute("""\ 122 | CREATE TABLE {tablename} ( 123 | ticker CHAR(4), 124 | report_date DATE, 125 | score INT, 126 | 
PRIMARY KEY(ticker));""".format(tablename=self.table_name)) 127 | self.conn.commit() 128 | constraint = ['ticker'] 129 | clause = 'WHERE current.report_date < EXCLUDED.report_date' 130 | self.upsert_statement = dml.format_upsert(self.table_name, 131 | self.column_names, 132 | constraint, 133 | clause) 134 | self.result_query = "select * from {tablename} where ticker='AAPL'".format( 135 | tablename=self.table_name) 136 | 137 | @skipPGVersionBefore(*PG_UPSERT_VERSION) 138 | def test_new_record(self): 139 | from datetime import date 140 | expected = ('AAPL', date(2014, 4, 1), 5) 141 | dml.upsert_records(self.conn, [expected], self.upsert_statement) 142 | 143 | result = fetch_one_result(self.conn, self.result_query) 144 | 145 | self.assertEqual(expected, result) 146 | 147 | @skipPGVersionBefore(*PG_UPSERT_VERSION) 148 | def test_conflict_replace_record(self): 149 | first = ('AAPL', date(2014, 4, 1), 5) 150 | expected = ('AAPL', date(2015, 4, 1), 5) 151 | dml.upsert_records(self.conn, [first], self.upsert_statement) 152 | dml.upsert_records(self.conn, [expected], self.upsert_statement) 153 | 154 | result = fetch_one_result(self.conn, self.result_query) 155 | 156 | self.assertEqual(expected, result) 157 | 158 | @skipPGVersionBefore(*PG_UPSERT_VERSION) 159 | def test_conflict_no_update(self): 160 | expected = ('AAPL', date(2014, 4, 1), 5) 161 | second = ('AAPL', date(2013, 4, 1), 5) 162 | dml.upsert_records(self.conn, [expected], self.upsert_statement) 163 | dml.upsert_records(self.conn, [second], self.upsert_statement) 164 | 165 | result = fetch_one_result(self.conn, self.result_query) 166 | 167 | self.assertEqual(expected, result) 168 | 169 | @skipPGVersionBefore(*PG_UPSERT_VERSION) 170 | def test_update_on_primary_key(self): 171 | primary_keys = ['ticker'] 172 | upserter = dml.UpsertPrimaryKey(self.table_name, self.column_names, 173 | primary_keys) 174 | first = ('AAPL', date(2014, 4, 1), 5) 175 | expected = ('AAPL', date(2014, 4, 1), 6) 176 | 177 | upserter(self.conn, 
class TestUpsertPrimary(PostgresStatementFixture, unittest.TestCase):
    """Query generation when every column belongs to the primary key."""

    def setUp(self):
        self.table_name = 'foobar'
        self.columns = ['foo', 'bar']
        self.primary_keys = self.columns

    def test_when_all_columns_are_primary_keys(self):
        # With no non-key columns there is nothing to update on conflict,
        # so the generated statement must fall back to DO NOTHING.
        upsert = dml.UpsertPrimaryKey(self.table_name,
                                      self.columns,
                                      self.primary_keys)
        result = upsert.query

        expected = ('INSERT INTO foobar AS current (foo, bar) VALUES (%s, %s)'
                    ' ON CONFLICT (foo, bar) DO NOTHING')

        self.assertSQLStatementEqual(expected, result)
| ) 237 | file_object = io.StringIO(delimited_text()) 238 | 239 | with self.conn: 240 | bulk_upserter(self.conn, file_object) 241 | 242 | result = get_records(self.conn, self.table_name) 243 | 244 | self.assertEqual(self.records, result) 245 | 246 | def test_copy_table_from_csv(self): 247 | self.columns, self.records = make_records() 248 | file_object = io.StringIO(delimited_text()) 249 | copy_from_table = dml.CopyFrom(self.table, 250 | delimiter=self.delimiter, 251 | null_str=self.null_str, 252 | force_null=self.force_null) 253 | 254 | with self.conn: 255 | copy_from_table(self.conn, file_object) 256 | 257 | result = get_records(self.conn, self.table_name) 258 | 259 | self.assertEqual(self.records, result) 260 | 261 | 262 | class TestBulkCopyAllColumnPrimary(PostgresDmlFixture, unittest.TestCase): 263 | 264 | def setUp(self): 265 | self.table_name = 'insert_record_table' 266 | self.column_names, self.records = make_records() 267 | self.records = self.records[0:1] 268 | self.columns = [Column(self.column_names[0], 'VARCHAR(50)'), 269 | Column(self.column_names[1], 'CHAR(2)')] 270 | self.primary_key_names = ['city', 'state'] 271 | self.primary_key = PrimaryKey(self.primary_key_names) 272 | self.table = Table(self.table_name, self.columns, self.primary_key) 273 | self.delimiter = '|' 274 | self.force_null = [] 275 | self.null_str = '' 276 | self.insert_query = 'INSERT INTO {} VALUES (%s, %s)'.format( 277 | self.table_name 278 | ) 279 | 280 | with self.conn.cursor() as cursor: 281 | cursor.execute(self.table.create_statement()) 282 | self.conn.commit() 283 | 284 | @skipPGVersionBefore(*PG_UPSERT_VERSION) 285 | def test_upsert_many_primary_key(self): 286 | records = [('Chicago', 'IL')] 287 | 288 | with self.conn.cursor() as cursor: 289 | cursor.executemany(self.insert_query, records) 290 | self.conn.commit() 291 | 292 | bulk_upserter = dml.CopyFromUpsert( 293 | self.table, delimiter=self.delimiter, null_str=self.null_str, 294 | force_null=self.force_null 295 | ) 296 | 
class TestDeleteRecordStatements(PostgresStatementFixture, unittest.TestCase):
    """SQL generation for deleting rows via a join against a staging table."""

    def test_delete_joined_table_sql(self):
        result = dml.delete_joined_table_sql('table_foo', 'delete_from_foo',
                                             ['city', 'state'])

        expected = ('DELETE FROM table_foo t'
                    ' USING delete_from_foo d'
                    ' WHERE t.city=d.city AND t.state=d.state')

        self.assertSQLStatementEqual(expected, result)
class TestDmlCopyStatements(PostgresStatementFixture, unittest.TestCase):
    """Statement text produced by dml_copy.copy_from_csv_sql."""

    def test_copy_from_csv_sql(self):
        result = copy_from_csv_sql('my_table', delimiter='|', null_str='NULL',
                                   header=False, encoding='latin1',
                                   force_not_null=['foo', 'bar'])

        # 'latin1' should be normalized to PostgreSQL's 'iso8859_1' name.
        expected = ("COPY my_table FROM STDIN"
                    " WITH ("
                    " FORMAT CSV,"
                    " DELIMITER '|',"
                    " NULL 'NULL',"
                    " QUOTE '\"',"
                    " ESCAPE '\\',"
                    " FORCE_NOT_NULL (foo, bar),"
                    " ENCODING 'iso8859_1')")

        self.assertSQLStatementEqual(expected, result)
class TestParamFormatting(unittest.TestCase):
    """Placeholder-string generation for psycopg2 parameter styles."""

    def setUp(self):
        self.column_names = ['foo', 'bar', 'foobar']

    def test_pyformat_parameters(self):
        expected = '%s, %s, %s'
        result = pyformat_parameters(self.column_names)

        self.assertEqual(expected, result)

    def test_named_style_parameters(self):
        # Renamed from the misleading copy-paste name
        # 'test_compile_truncate_table'; it tests named_style_parameters.
        expected = '%(foo)s, %(bar)s, %(foobar)s'
        result = named_style_parameters(self.column_names)

        self.assertEqual(expected, result)
class TestExecute(PostgreSQLFixture, unittest.TestCase):
    """Behavior of sql.execute_transaction(s) against a live connection."""

    def _assert_creates_table(self, execute, table):
        """Run CREATE TABLE through *execute* and verify the table exists.

        Shared body of the two tests below, which previously duplicated
        this logic and differ only in which execution helper they call.
        """
        reset_query = 'DROP TABLE IF EXISTS {};'.format(table)
        statements = ['CREATE TABLE {}();'.format(table)]

        execute(self.conn, statements)

        with self.conn.cursor() as cursor:
            cursor.execute(TABLE_QUERY)

            expected = (table,)
            result = cursor.fetchone()

            self.assertEqual(expected, result)
            cursor.execute(reset_query)

        self.conn.commit()

    def test_execute_transactions(self):
        self._assert_creates_table(sql.execute_transactions,
                                   'execute_transactions')

    def test_execute_transaction(self):
        self._assert_creates_table(sql.execute_transaction,
                                   'execute_transaction')

    def test_doesnt_raise_exception(self):
        """execute_transactions should swallow statement-level errors."""
        query = ["insert nothing into nothing"]
        try:
            sql.execute_transactions(self.conn, query)
        except psycopg2.ProgrammingError:
            self.fail('Raised DB Programming Error')
class TestSelectQueries(PostgreSQLFixture, unittest.TestCase):
    """Row-returning helpers in postpy.sql."""

    def setUp(self):
        self.query = 'select * from generate_series(1,3) as col1;'

    def test_select_dict(self):
        expected = [{'col1': 1}, {'col1': 2}, {'col1': 3}]
        result = list(sql.select_dict(self.conn, self.query))

        self.assertEqual(expected, result)

    def test_select(self):
        Record = namedtuple('Record', 'col1')

        expected = [Record(col1=1), Record(col1=2), Record(col1=3)]
        result = list(sql.select(self.conn, self.query))

        self.assertEqual(expected, result)

    def test_query_columns(self):
        query = "SELECT 1 AS foo, 'cip' AS bar;"

        expected = ['foo', 'bar']
        result = list(sql.query_columns(self.conn, query))

        self.assertEqual(expected, result)

    def test_select_each(self):
        # Previously this code was appended to the end of
        # test_query_columns; it exercises sql.select_each and deserves
        # its own independently-reported test case.
        query = 'select * from generate_series(1,3) as col1 where col1=%s;'
        parameter_groups = [(3,), (2,), (1,)]
        Record = namedtuple('Record', 'col1')

        expected = [Record(col1=3), Record(col1=2), Record(col1=1)]
        result = list(sql.select_each(self.conn, query, parameter_groups))

        self.assertEqual(expected, result)
'my_schema.uuid_generate_v1mc()' 16 | result = uuids.uuid_sequence_function('my_schema') 17 | 18 | self.assertEqual(expected, result) 19 | --------------------------------------------------------------------------------