├── .gitignore ├── .travis.yml ├── LICENSE ├── README.rst ├── pytest.ini ├── setup.py └── warp_prism ├── __init__.py ├── _warp_prism.c └── tests ├── __init__.py └── test_warp_prism.py /.gitignore: -------------------------------------------------------------------------------- 1 | .bundle 2 | db/*.sqlite3 3 | log/*.log 4 | *.log 5 | tmp/**/* 6 | tmp/* 7 | *.swp 8 | *~ 9 | #mac autosaving file 10 | .DS_Store 11 | *.py[co] 12 | 13 | # Installer logs 14 | pip-log.txt 15 | 16 | # Unit test / coverage reports 17 | .coverage 18 | .tox 19 | test.log 20 | .noseids 21 | *.xlsx 22 | 23 | # Compiled python files 24 | *.py[co] 25 | 26 | # Packages 27 | *.egg 28 | *.egg-info 29 | dist 30 | build 31 | eggs 32 | cover 33 | parts 34 | bin 35 | var 36 | sdist 37 | develop-eggs 38 | .installed.cfg 39 | coverage.xml 40 | nosetests.xml 41 | 42 | # C Extensions 43 | *.o 44 | *.so 45 | *.out 46 | 47 | # Vim 48 | *.swp 49 | *.swo 50 | 51 | # Built documentation 52 | docs/_build/* 53 | 54 | # database of vbench 55 | benchmarks.db 56 | 57 | # Vagrant temp folder 58 | .vagrant 59 | 60 | # pytest 61 | .cache 62 | 63 | .dir-locals.el 64 | 65 | TAGS 66 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: false 3 | python: 4 | - "3.4" 5 | - "3.5" 6 | - "3.6" 7 | 8 | env: 9 | - CC=gcc CXX=g++ 10 | - CC=gcc-5 CXX=g++-5 11 | 12 | addons: 13 | postgresql: "9.4" 14 | apt: 15 | sources: 16 | - ubuntu-toolchain-r-test 17 | packages: 18 | - gcc-5 19 | - g++-5 20 | 21 | install: 22 | - ${CC} --version 23 | - pip install numpy 24 | - python -c "import numpy;print(numpy.__version__)" 25 | - pip install -e .[dev] 26 | 27 | script: 28 | - py.test warp_prism 29 | - flake8 warp_prism 30 | 31 | notifications: 32 | email: false 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | warp_prism 2 | ========== 3 | 4 | Quickly move data from postgres to numpy or pandas. 5 | 6 | API 7 | --- 8 | 9 | ``to_arrays(query, *, bind=None)`` 10 | `````````````````````````````````` 11 | 12 | .. code-block:: 13 | 14 | Run the query returning the results as np.ndarrays. 15 | 16 | Parameters 17 | ---------- 18 | query : sa.sql.Selectable 19 | The query to run. This can be a select or a table. 20 | bind : sa.Engine, optional 21 | The engine used to create the connection. If not provided 22 | ``query.bind`` will be used. 23 | 24 | Returns 25 | ------- 26 | arrays : dict[str, (np.ndarray, np.ndarray)] 27 | A map from column name to the result arrays. The first array holds the 28 | values and the second array is a boolean mask for NULLs. The values 29 | where the mask is False are zero-filled placeholders and should be ignored. 30 | 31 | 32 | ``to_dataframe(query, *, bind=None, null_values=None)`` 33 | ``````````````````````````````````````````````````````` 34 | 35 | .. code-block:: 36 | 37 | Run the query returning the results as a pd.DataFrame. 38 | 39 | Parameters 40 | ---------- 41 | query : sa.sql.Selectable 42 | The query to run. This can be a select or a table. 43 | bind : sa.Engine, optional 44 | The engine used to create the connection. If not provided 45 | ``query.bind`` will be used. 46 | null_values : dict[str, any] 47 | The null values to use for each column. This falls back to 48 | ``warp_prism.null_values`` for columns that are not specified. 49 | 50 | Returns 51 | ------- 52 | df : pd.DataFrame 53 | A pandas DataFrame holding the results of the query. The columns 54 | of the DataFrame will be named the same and be in the same order as the 55 | query. 56 | 57 | 58 | ``register_odo_dataframe_edge()`` 59 | ````````````````````````````````` 60 | 61 | .. code-block:: 62 | 63 | Register an odo edge for sqlalchemy selectable objects to dataframe. 64 | 65 | This edge will have a lower cost than the default edge so it will be 66 | selected as the fastest path. 67 | 68 | If the selectable is not in a postgres database, it will fall back to the 69 | default odo edge. 70 | 71 | 72 | Comparisons 73 | ----------- 74 | 75 | A quick comparison between ``warp_prism``, ``odo``, and ``pd.read_sql_table``. 76 | 77 | In this example we will read real data for VIX from quandl stored in a local 78 | postgres database using ``warp_prism``, ``odo``, and ``pd.read_sql_table``. 79 | After that, we will use ``odo`` to create a table with two float columns and 80 | 1000000 rows and query it with the three tools again. 81 | 82 | ..
code-block:: python 83 | 84 | In [1]: import warp_prism 85 | 86 | In [2]: from odo import odo, resource 87 | 88 | In [3]: import pandas as pd 89 | 90 | In [4]: table = resource( 91 | ...: 'postgresql://localhost/bz::yahoo_index_vix', 92 | ...: schema='quandl', 93 | ...: ) 94 | 95 | In [5]: warp_prism.to_dataframe(table).head() 96 | Out[5]: 97 | asof_date open_ high low close volume \ 98 | 0 2016-01-08 22.959999 27.080000 22.480000 27.010000 0.0 99 | 1 2015-12-04 17.430000 17.650000 14.690000 14.810000 0.0 100 | 2 2015-10-29 14.800000 15.460000 14.330000 14.610000 0.0 101 | 3 2015-12-21 19.639999 20.209999 18.700001 18.700001 0.0 102 | 4 2015-10-26 14.760000 15.430000 14.680000 15.290000 0.0 103 | 104 | adjusted_close timestamp 105 | 0 27.010000 2016-01-11 23:14:54.682220 106 | 1 14.810000 2016-01-11 23:14:54.682220 107 | 2 14.610000 2016-01-11 23:14:54.682220 108 | 3 18.700001 2016-01-11 23:14:54.682220 109 | 4 15.290000 2016-01-11 23:14:54.682220 110 | 111 | In [6]: odo(table, pd.DataFrame).head() 112 | Out[6]: 113 | asof_date open_ high low close volume \ 114 | 0 2016-01-08 22.959999 27.080000 22.480000 27.010000 0.0 115 | 1 2015-12-04 17.430000 17.650000 14.690000 14.810000 0.0 116 | 2 2015-10-29 14.800000 15.460000 14.330000 14.610000 0.0 117 | 3 2015-12-21 19.639999 20.209999 18.700001 18.700001 0.0 118 | 4 2015-10-26 14.760000 15.430000 14.680000 15.290000 0.0 119 | 120 | adjusted_close timestamp 121 | 0 27.010000 2016-01-11 23:14:54.682220 122 | 1 14.810000 2016-01-11 23:14:54.682220 123 | 2 14.610000 2016-01-11 23:14:54.682220 124 | 3 18.700001 2016-01-11 23:14:54.682220 125 | 4 15.290000 2016-01-11 23:14:54.682220 126 | 127 | In [7]: pd.read_sql_table(table.name, table.bind, table.schema).head() 128 | Out[7]: 129 | asof_date open_ high low close volume \ 130 | 0 2016-01-08 22.959999 27.080000 22.480000 27.010000 0.0 131 | 1 2015-12-04 17.430000 17.650000 14.690000 14.810000 0.0 132 | 2 2015-10-29 14.800000 15.460000 14.330000 14.610000 0.0 133 | 3 2015-12-21 19.639999 20.209999 18.700001 18.700001 0.0 134 | 4 2015-10-26 14.760000 15.430000 14.680000 15.290000 0.0 135 | 136 | adjusted_close timestamp 137 | 0 27.010000 2016-01-11 23:14:54.682220 138 | 1 14.810000 2016-01-11 23:14:54.682220 139 | 2 14.610000 2016-01-11 23:14:54.682220 140 | 3 18.700001 2016-01-11 23:14:54.682220 141 | 4 15.290000 2016-01-11 23:14:54.682220 142 | 143 | In [8]: len(warp_prism.to_dataframe(table)) 144 | Out[8]: 6565 145 | 146 | In [9]: %timeit warp_prism.to_dataframe(table) 147 | 100 loops, best of 3: 7.55 ms per loop 148 | 149 | In [10]: %timeit odo(table, pd.DataFrame) 150 | 10 loops, best of 3: 49.9 ms per loop 151 | 152 | In [11]: %timeit pd.read_sql_table(table.name, table.bind, table.schema) 153 | 10 loops, best of 3: 61.8 ms per loop 154 | 155 | In [12]: big_table = odo( 156 | ...: pd.DataFrame({ 157 | ...: 'a': np.random.rand(1000000), 158 | ...: 'b': np.random.rand(1000000)}, 159 | ...: ), 160 | ...: 'postgresql://localhost/test::largefloattest', 161 | ...: ) 162 | 163 | In [13]: %timeit warp_prism.to_dataframe(big_table) 164 | 1 loop, best of 3: 248 ms per loop 165 | 166 | In [14]: %timeit odo(big_table, pd.DataFrame) 167 | 1 loop, best of 3: 1.51 s per loop 168 | 169 | In [15]: %timeit pd.read_sql_table(big_table.name, big_table.bind) 170 | 1 loop, best of 3: 1.9 s per loop 171 | 172 | 173 | Installation 174 | ------------ 175 | 176 | Warp Prism can be pip installed but requires numpy to build its C extensions: 177 | 178 | .. 
code-block:: 179 | 180 | $ pip install numpy 181 | $ pip install warp_prism 182 | 183 | 184 | License 185 | ------- 186 | 187 | Warp Prism is licensed under the Apache License, Version 2.0. 188 | 189 | Warp Prism is sponsored by `Quantopian `_ where it 190 | is used to fetch data for use in `Zipline `_ through the 191 | `Pipeline API `_ or interactively 192 | with `Blaze `_. 193 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = --doctest-modules --ignore setup.py -vv 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup, Extension, find_packages 3 | import sys 4 | 5 | import numpy as np 6 | 7 | long_description = '' 8 | 9 | if 'upload' in sys.argv: 10 | with open('README.rst') as f: 11 | long_description = f.read() 12 | 13 | classifiers = [ 14 | 'Development Status :: 5 - Production/Stable', 15 | 'Intended Audience :: Developers', 16 | 'Intended Audience :: Science/Research', 17 | 'License :: OSI Approved :: Apache Software License', 18 | 'Natural Language :: English', 19 | 'Programming Language :: C', 20 | 'Programming Language :: Python :: 3', 21 | 'Programming Language :: Python :: Implementation :: CPython', 22 | 'Topic :: Scientific/Engineering', 23 | ] 24 | 25 | setup( 26 | name='warp_prism', 27 | version='0.1.1', 28 | description='Quickly move data from postgres to numpy or pandas.', 29 | author='Quantopian Inc.', 30 | author_email='opensource@gmail.com', 31 | packages=find_packages(), 32 | long_description=long_description, 33 | license='Apache 2.0', 34 | classifiers=classifiers, 35 | url='https://github.com/quantopian/warp_prism', 36 | ext_modules=[ 37 | Extension( 38 | 'warp_prism._warp_prism', 39 | ['warp_prism/_warp_prism.c'], 40 | include_dirs=[np.get_include()], 41 | extra_compile_args=['-std=c99', '-Wall', '-Wextra'], 42 | ), 43 | ], 44 | install_requires=[ 45 | 'datashape', 46 | 'numpy', 47 | 'pandas', 48 | 'sqlalchemy', 49 | 'psycopg2', 50 | 'odo', 51 | 'toolz', 52 | 'networkx<=1.11', 53 | ], 54 | extras_require={ 55 | 'dev': [ 56 | 'flake8==3.3.0', 57 | 'pycodestyle==2.3.1', 58 | 'pyflakes==1.5.0', 59 | 'pytest==3.0.6', 60 | ], 61 | }, 62 | ) 63 | -------------------------------------------------------------------------------- /warp_prism/__init__.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | from datashape import discover 4 | from datashape.predicates import istabular 5 | import numpy as np 6 | from odo import convert 7 | import pandas as pd 8 | import sqlalchemy as sa 9 | from sqlalchemy.ext.compiler import compiles 10 | from toolz import keymap 11 | 12 | from ._warp_prism import ( 13 | raw_to_arrays as _raw_to_arrays, 14 | typeid_map as _raw_typeid_map, 15 | ) 16 | 17 | 18 | __version__ = '0.1.1' 19 | 20 | 21 | _typeid_map = keymap(np.dtype, _raw_typeid_map) 22 | _object_type_id = _raw_typeid_map['object'] 23 | 24 | 25 | class _CopyToBinary(sa.sql.expression.Executable, sa.sql.ClauseElement): 26 | 27 | def __init__(self, element, bind): 28 | self.element = element 29 | self._bind = bind 30 | 31 | @property 32 | def bind(self): 33 | return self._bind 34 | 35 | 36 | def literal_compile(s): 37 | """Compile a sql expression with bind params inlined as literals.
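This is needed because ``COPY`` statements cannot be executed with bind parameters, so the query has to be rendered to a complete SQL string before it is wrapped in ``COPY ... TO STDOUT (FORMAT BINARY)``.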
38 | 39 | Parameters 40 | ---------- 41 | s : Selectable 42 | The expression to compile. 43 | 44 | Returns 45 | ------- 46 | cs : str 47 | An equivalent sql string. 48 | """ 49 | return str(s.compile(compile_kwargs={'literal_binds': True})) 50 | 51 | 52 | @compiles(_CopyToBinary, 'postgresql') 53 | def _compile_copy_to_binary_postgres(element, compiler, **kwargs): 54 | selectable = element.element 55 | return compiler.process( 56 | sa.text( 57 | 'COPY {stmt} TO STDOUT (FORMAT BINARY)'.format( 58 | stmt=( 59 | compiler.preparer.format_table(selectable) 60 | if isinstance(selectable, sa.Table) else 61 | '({})'.format(literal_compile(selectable)) 62 | ), 63 | ) 64 | ), 65 | **kwargs 66 | ) 67 | 68 | 69 | def _warp_prism_types(query): 70 | for name, dtype in discover(query).measure.fields: 71 | try: 72 | np_dtype = getattr(dtype, 'ty', dtype).to_numpy_dtype() 73 | if np_dtype.kind == 'U': 74 | yield _object_type_id 75 | else: 76 | yield _typeid_map[np_dtype] 77 | except KeyError: 78 | raise TypeError( 79 | 'warp_prism cannot query columns of type %s' % dtype, 80 | ) 81 | 82 | 83 | def _getbind(selectable, bind): 84 | """Return an explicitly passed connection or infer the connection from 85 | the selectable. 86 | 87 | Parameters 88 | ---------- 89 | selectable : sa.sql.Selectable 90 | The selectable object being queried. 91 | bind : bind or None 92 | The explicit connection or engine to use to execute the query. 93 | 94 | Returns 95 | ------- 96 | bind : The bind which should be used to execute the query. 97 | """ 98 | if bind is None: 99 | return selectable.bind 100 | 101 | if isinstance(bind, sa.engine.base.Engine): 102 | return bind 103 | 104 | return sa.create_engine(bind) 105 | 106 | 107 | def to_arrays(query, *, bind=None): 108 | """Run the query returning the results as np.ndarrays. 109 | 110 | Parameters 111 | ---------- 112 | query : sa.sql.Selectable 113 | The query to run. This can be a select or a table. 114 | bind : sa.Engine, optional 115 | The engine used to create the connection. If not provided 116 | ``query.bind`` will be used. 117 | 118 | Returns 119 | ------- 120 | arrays : dict[str, (np.ndarray, np.ndarray)] 121 | A map from column name to the result arrays. The first array holds the 122 | values and the second array is a boolean mask for NULLs. The values 123 | where the mask is False are zero-filled placeholders and should be ignored. 124 | """ 125 | # check types before doing any work 126 | types = tuple(_warp_prism_types(query)) 127 | 128 | buf = BytesIO() 129 | bind = _getbind(query, bind) 130 | 131 | stmt = _CopyToBinary(query, bind) 132 | with bind.connect() as conn: 133 | conn.connection.cursor().copy_expert(literal_compile(stmt), buf) 134 | out = _raw_to_arrays(buf.getbuffer(), types) 135 | column_names = query.c.keys() 136 | return {column_names[n]: v for n, v in enumerate(out)} 137 | 138 | 139 | null_values = keymap(np.dtype, { 140 | 'float32': np.nan, 141 | 'float64': np.nan, 142 | 'int16': np.nan, 143 | 'int32': np.nan, 144 | 'int64': np.nan, 145 | 'bool': np.nan, 146 | 'datetime64[ns]': np.datetime64('nat', 'ns'), 147 | 'object': None, 148 | }) 149 | 150 | # alias because ``to_dataframe`` shadows this name 151 | _default_null_values_for_type = null_values 152 | 153 | 154 | def to_dataframe(query, *, bind=None, null_values=None): 155 | """Run the query returning the results as a pd.DataFrame. 156 | 157 | Parameters 158 | ---------- 159 | query : sa.sql.Selectable 160 | The query to run. This can be a select or a table.
161 | bind : sa.Engine, optional 162 | The engine used to create the connection. If not provided 163 | ``query.bind`` will be used. 164 | null_values : dict[str, any] 165 | The null values to use for each column. This falls back to 166 | ``warp_prism.null_values`` for columns that are not specified. 167 | 168 | Returns 169 | ------- 170 | df : pd.DataFrame 171 | A pandas DataFrame holding the results of the query. The columns 172 | of the DataFrame will be named the same and be in the same order as the 173 | query. 174 | """ 175 | arrays = to_arrays(query, bind=bind) 176 | 177 | if null_values is None: 178 | null_values = {} 179 | 180 | for name, (array, mask) in arrays.items(): 181 | if array.dtype.kind == 'i': 182 | if not mask.all(): 183 | try: 184 | null = null_values[name] 185 | except KeyError: 186 | # no explicit override, cast to float and use NaN as null 187 | array = array.astype('float64') 188 | null = np.nan 189 | 190 | array[~mask] = null 191 | 192 | arrays[name] = array 193 | continue 194 | 195 | if array.dtype.kind == 'M': 196 | # pandas needs datetime64[ns], not ``us`` or ``D`` 197 | array = array.astype('datetime64[ns]') 198 | 199 | try: 200 | null = null_values[name] 201 | except KeyError: 202 | null = _default_null_values_for_type[array.dtype] 203 | 204 | array[~mask] = null 205 | arrays[name] = array 206 | 207 | return pd.DataFrame(arrays, columns=[column.name for column in query.c]) 208 | 209 | 210 | def register_odo_dataframe_edge(): 211 | """Register an odo edge for sqlalchemy selectable objects to dataframe. 212 | 213 | This edge will have a lower cost than the default edge so it will be 214 | selected as the fastest path. 215 | 216 | If the selectable is not in a postgres database, it will fall back to the 217 | default odo edge.
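A minimal usage sketch (the database URI and table name here are
illustrative):

.. code-block:: python

    import pandas as pd
    from odo import odo, resource
    import warp_prism

    warp_prism.register_odo_dataframe_edge()

    table = resource('postgresql://localhost/db::some_table')
    df = odo(table, pd.DataFrame)  # now serviced by warp_prism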
218 | """ 219 | # estimating 8 times faster 220 | df_cost = convert.graph.edge[sa.sql.Select][pd.DataFrame]['cost'] / 8 221 | 222 | @convert.register( 223 | pd.DataFrame, 224 | (sa.sql.Select, sa.sql.Selectable), 225 | cost=df_cost, 226 | ) 227 | def select_or_selectable_to_frame(el, bind=None, dshape=None, **kwargs): 228 | bind = _getbind(el, bind) 229 | 230 | if bind.dialect.name != 'postgresql': 231 | # fall back to the general edge 232 | raise NotImplementedError() 233 | 234 | return to_dataframe(el, bind=bind) 235 | 236 | # higher priority than df edge so that 237 | # ``odo('select one_column from ...', list)`` returns a list of scalars 238 | # instead of a list of tuples of length 1 239 | @convert.register( 240 | pd.Series, 241 | (sa.sql.Select, sa.sql.Selectable), 242 | cost=df_cost - 1, 243 | ) 244 | def select_or_selectable_to_series(el, bind=None, dshape=None, **kwargs): 245 | bind = _getbind(el, bind) 246 | 247 | if istabular(dshape) or bind.dialect.name != 'postgresql': 248 | # fall back to the general edge 249 | raise NotImplementedError() 250 | 251 | return to_dataframe(el, bind=bind).iloc[:, 0] 252 | -------------------------------------------------------------------------------- /warp_prism/_warp_prism.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "Python.h" 7 | #include "numpy/arrayobject.h" 8 | 9 | const char* const signature = "PGCOPY\n\377\r\n\0"; 10 | const size_t signature_len = 11; 11 | 12 | const size_t column_buffer_growth_factor = 2; 13 | const size_t starting_column_buffer_length = 4096; 14 | 15 | #if __GNUC__ >= 5 16 | #define add_overflow __builtin_add_overflow 17 | #define mul_overflow __builtin_mul_overflow 18 | #else 19 | static inline bool add_overflow(size_t a, size_t b, size_t* out) { 20 | *out = a + b; 21 | return *out < a; 22 | } 23 | 24 | static inline bool mul_overflow(size_t a, size_t b, size_t* out) { 25 | *out = a * b; 26 | return a > SIZE_MAX / b; 27 | } 28 | #endif 29 | 30 | #ifdef __ORDER_LITTLE_ENDIAN__ 31 | #define MAYBE_BSWAP(arg, size) __builtin_bswap ## size (arg) 32 | #else 33 | #define MAYBE_BSWAP(arg, size) arg 34 | #endif 35 | 36 | #ifndef likely 37 | #define likely(p) __builtin_expect(!!(p), 1) 38 | #endif 39 | 40 | #ifndef unlikely 41 | #define unlikely(p) __builtin_expect(!!(p), 0) 42 | #endif 43 | 44 | #define TYPE(size) uint ## size ## _t 45 | 46 | #define DEFINE_READ(size) \ 47 | static inline TYPE(size) read ## size (const char* buffer) { \ 48 | return MAYBE_BSWAP(*(TYPE(size)*) buffer, size); \ 49 | } 50 | 51 | static inline uint8_t read8(const char* buffer) { 52 | return *buffer; 53 | } 54 | 55 | DEFINE_READ(16) 56 | DEFINE_READ(32) 57 | DEFINE_READ(64) 58 | 59 | #undef DEFINE_READ 60 | 61 | #define DEFINE_WRITE(size) \ 62 | static inline TYPE(size) write ## size (char* buffer, TYPE(size) value) { \ 63 | return *((TYPE(size)*) buffer) = value; \ 64 | } 65 | 66 | DEFINE_WRITE(8) 67 | DEFINE_WRITE(16) 68 | DEFINE_WRITE(32) 69 | DEFINE_WRITE(64) 70 | 71 | #undef DEFINE_WRITE 72 | 73 | typedef int (*parse_function)(char* column_buffer, 74 | const char * const input_buffer, 75 | size_t len); 76 | typedef void (*free_function)(void* colbuffer, size_t rowcount); 77 | typedef int (*write_null_function)(char* dst, size_t size); 78 | 79 | typedef struct { 80 | const char* const dtype_name; 81 | parse_function parse; 82 | free_function free; 83 | write_null_function write_null; 84 | size_t size; 85 | PyArray_Descr* dtype; 86 | } 
87 | 88 | static int parse_int16(char* column_buffer, 89 | const char* const input_buffer, 90 | size_t len) { 91 | if (unlikely(len != sizeof(uint16_t))) { 92 | PyErr_Format(PyExc_ValueError, "mismatched int16 size: %zu", len); 93 | return -1; 94 | } 95 | 96 | write16(column_buffer, read16(input_buffer)); 97 | return 0; 98 | } 99 | 100 | static int parse_int32(char* column_buffer, 101 | const char* const input_buffer, 102 | size_t len) { 103 | if (unlikely(len != sizeof(uint32_t))) { 104 | PyErr_Format(PyExc_ValueError, "mismatched int32 size: %zu", len); 105 | return -1; 106 | } 107 | 108 | write32(column_buffer, read32(input_buffer)); 109 | return 0; 110 | } 111 | 112 | static int parse_int64(char* column_buffer, 113 | const char* const input_buffer, 114 | size_t len) { 115 | if (unlikely(len != sizeof(uint64_t))) { 116 | PyErr_Format(PyExc_ValueError, "mismatched int64 size: %zu", len); 117 | return -1; 118 | } 119 | 120 | write64(column_buffer, read64(input_buffer)); 121 | return 0; 122 | } 123 | 124 | /* 2000-01-01 in us since 1970-01-01 00:00:00+0000 125 | postgres stores datetimes as us since jan 1 2000 *not* jan 1 1970 */ 126 | const int64_t datetime_offset = 946684800000000l; 127 | 128 | static int parse_datetime(char* column_buffer, 129 | const char* const input_buffer, 130 | size_t len) { 131 | if (unlikely(len != sizeof(int64_t))) { 132 | PyErr_Format(PyExc_ValueError, "mismatched datetime size: %zu", len); 133 | return -1; 134 | } 135 | 136 | write64(column_buffer, read64(input_buffer) + datetime_offset); 137 | return 0; 138 | } 139 | 140 | /* 2000-01-01 in days since 1970-01-01; 141 | postgres stores date as days since jan 1 2000 *not* jan 1 1970 */ 142 | const int32_t date_offset = 10957; 143 | 144 | static int parse_date(char* column_buffer, 145 | const char* const input_buffer, 146 | size_t len) { 147 | if (unlikely(len != sizeof(int32_t))) { 148 | PyErr_Format(PyExc_ValueError, "mismatched date size: %zu", len); 149 | return -1; 150 | } 151 | 152 | /* We read 32 bits of data but write it as 64 bits; postgres uses 32 bit 153 | integers for dates but numpy datetime64[D] uses 64.
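As a worked example: postgres encodes 2000-01-02 as day 1, and 1 + date_offset = 10958 days since the unix epoch, i.e. numpy datetime64('2000-01-02'). (10957 = 30 * 365 + 7 leap days for 1972-1996.)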
*/ 154 | write64(column_buffer, read32(input_buffer) + date_offset); 155 | return 0; 156 | } 157 | 158 | static int parse_float32(char* column_buffer, 159 | const char* const input_buffer, 160 | size_t len) { 161 | if (unlikely(len != sizeof(float))) { 162 | PyErr_Format(PyExc_ValueError, "mismatched float32 size: %zu", len); 163 | return -1; 164 | } 165 | 166 | write32(column_buffer, read32(input_buffer)); 167 | return 0; 168 | } 169 | 170 | static int parse_float64(char* column_buffer, 171 | const char* const input_buffer, 172 | size_t len) { 173 | if (unlikely(len != sizeof(double))) { 174 | PyErr_Format(PyExc_ValueError, "mismatched float64 size: %zu", len); 175 | return -1; 176 | } 177 | 178 | write64(column_buffer, read64(input_buffer)); 179 | return 0; 180 | } 181 | 182 | static int parse_bool(char* column_buffer, 183 | const char* const input_buffer, 184 | size_t len) { 185 | if (unlikely(len != sizeof(uint8_t))) { 186 | PyErr_Format(PyExc_ValueError, "mismatched bool size: %zu", len); 187 | return -1; 188 | } 189 | 190 | write8(column_buffer, read8(input_buffer)); 191 | return 0; 192 | } 193 | 194 | static int parse_text(char* column_buffer, 195 | const char* const input_buffer, 196 | size_t len) { 197 | PyObject* value = PyUnicode_FromStringAndSize(input_buffer, len); 198 | if (unlikely(!value)) { 199 | return -1; 200 | } 201 | 202 | *(PyObject**) column_buffer = value; 203 | return 0; 204 | } 205 | 206 | static void simple_free(void* colbuffer, 207 | size_t rowcount __attribute__((unused))) { 208 | PyMem_Free(colbuffer); 209 | } 210 | 211 | static void free_object(PyObject** colbuffer, size_t rowcount) { 212 | for (size_t n = 0; n < rowcount; ++n) { 213 | Py_XDECREF(colbuffer[n]); 214 | } 215 | 216 | PyMem_Free(colbuffer); 217 | } 218 | 219 | static int simple_write_null(char* dst, size_t size) { 220 | memset(dst, 0, size); 221 | return 0; 222 | } 223 | 224 | static int datetime_write_null(char* dst, size_t size) { 225 | if (size != sizeof(int64_t)) { 226 | PyErr_Format(PyExc_ValueError, 227 | "wrong size for NULL datetime field: %zu, expected %zu", 228 | size, 229 | sizeof(int64_t)); 230 | return -1; 231 | } 232 | 233 | write64(dst, NPY_DATETIME_NAT); 234 | return 0; 235 | } 236 | 237 | static int object_write_null(char* dst, size_t size) { 238 | if (size != sizeof(PyObject*)) { 239 | PyErr_Format(PyExc_ValueError, 240 | "wrong size for NULL object field: %zu, expected %zu", 241 | size, 242 | sizeof(PyObject*)); 243 | return -1; 244 | } 245 | 246 | Py_INCREF(Py_None); 247 | *(PyObject**) dst = Py_None; 248 | return 0; 249 | } 250 | 251 | warp_prism_type int16_type = { 252 | "int16", 253 | (parse_function) parse_int16, 254 | simple_free, 255 | simple_write_null, 256 | sizeof(int16_t), 257 | NULL, 258 | }; 259 | 260 | warp_prism_type int32_type = { 261 | "int32", 262 | (parse_function) parse_int32, 263 | simple_free, 264 | simple_write_null, 265 | sizeof(uint32_t), 266 | NULL, 267 | }; 268 | 269 | warp_prism_type int64_type = { 270 | "int64", 271 | (parse_function) parse_int64, 272 | simple_free, 273 | simple_write_null, 274 | sizeof(int64_t), 275 | NULL, 276 | }; 277 | 278 | warp_prism_type float32_type = { 279 | "float32", 280 | (parse_function) parse_float32, 281 | simple_free, 282 | simple_write_null, 283 | sizeof(float), 284 | NULL, 285 | }; 286 | 287 | warp_prism_type float64_type = { 288 | "float64", 289 | (parse_function) parse_float64, 290 | simple_free, 291 | simple_write_null, 292 | sizeof(double), 293 | NULL, 294 | }; 295 | 296 | warp_prism_type bool_type = { 297 | 
"bool", 298 | (parse_function) parse_bool, 299 | simple_free, 300 | simple_write_null, 301 | sizeof(bool), 302 | NULL, 303 | }; 304 | 305 | warp_prism_type string_type = { 306 | "object", 307 | (parse_function) parse_text, 308 | (free_function) free_object, 309 | object_write_null, 310 | sizeof(PyObject*), 311 | NULL, 312 | }; 313 | 314 | warp_prism_type datetime_type = { 315 | "datetime64[us]", 316 | (parse_function) parse_datetime, 317 | simple_free, 318 | datetime_write_null, 319 | sizeof(int64_t), 320 | NULL, 321 | }; 322 | 323 | warp_prism_type date_type = { 324 | "datetime64[D]", 325 | (parse_function) parse_date, 326 | simple_free, 327 | datetime_write_null, 328 | sizeof(int64_t), 329 | NULL, 330 | }; 331 | 332 | const warp_prism_type* typeids[] = { 333 | &int16_type, 334 | &int32_type, 335 | &int64_type, 336 | &float32_type, 337 | &float64_type, 338 | &bool_type, 339 | &string_type, 340 | &datetime_type, 341 | &date_type, 342 | }; 343 | 344 | const size_t max_typeid = sizeof(typeids) / sizeof(warp_prism_type*); 345 | 346 | static inline bool have_oids(uint32_t flags) { 347 | return flags & (1 << 16); 348 | } 349 | 350 | static inline bool valid_flags(uint32_t flags) { 351 | return flags == 0 || flags == (1 << 16); 352 | } 353 | 354 | static inline bool assert_can_consume(size_t size, 355 | size_t cursor, 356 | size_t buffer_len) { 357 | size_t new_cursor; 358 | if (unlikely(add_overflow(cursor, size, &new_cursor))) { 359 | PyErr_Format(PyExc_ValueError, 360 | "consuming %zu bytes would cause an overflow", 361 | size); 362 | return true; 363 | } 364 | /* new_cursor is the *next* location we would do a read. This cannot be 365 | greater than the buffer_len because that means we would have read past 366 | the buffer's allocated space. It may be exactly equal to buffer_len which 367 | means we have consumed all of the input (but no more). For example: 368 | 369 | If cursor = 0; size = buffer_len = 10: then new_cursor = 0 + 10 = 10. 370 | This means that new_cursor is exactly equal to buffer_len but we have not 371 | done an out of bounds read. 
*/ 372 | if (unlikely(new_cursor > buffer_len)) { 373 | PyErr_Format(PyExc_ValueError, 374 | "reading %zu bytes would cause an out of bounds access", 375 | size); 376 | return true; 377 | } 378 | return false; 379 | } 380 | 381 | #define DEFINE_CONSUME(size) \ 382 | static inline TYPE(size) consume ## size (const char* buffer, \ 383 | size_t* cursor) { \ 384 | TYPE(size) ret = read ## size (&buffer[*cursor]); \ 385 | *cursor += sizeof(TYPE(size)); \ 386 | return ret; \ 387 | } 388 | 389 | DEFINE_CONSUME(16) 390 | DEFINE_CONSUME(32) 391 | 392 | #undef DEFINE_CONSUME 393 | 394 | #define DEFINE_CHECKED_CONSUME(size) \ 395 | static inline bool checked_consume ## size (const char* buffer, \ 396 | size_t* cursor, \ 397 | size_t buffer_len, \ 398 | TYPE(size)* out) { \ 399 | if (assert_can_consume(sizeof(TYPE(size)), *cursor, buffer_len)) { \ 400 | return true; \ 401 | } \ 402 | *out = consume ## size (buffer, cursor); \ 403 | return false; \ 404 | } 405 | 406 | DEFINE_CHECKED_CONSUME(16) 407 | DEFINE_CHECKED_CONSUME(32) 408 | 409 | #undef DEFINE_CHECKED_CONSUME 410 | #undef TYPE 411 | 412 | static inline void free_outarrays(uint16_t ncolumns, 413 | size_t rowcount, 414 | const warp_prism_type** column_types, 415 | char** outarrays, 416 | bool** outmasks) { 417 | for (uint_fast16_t n = 0; n < ncolumns; ++n) { 418 | column_types[n]->free(outarrays[n], rowcount); 419 | PyMem_Free(outmasks[n]); 420 | } 421 | } 422 | 423 | static inline int allocate_outarrays(uint16_t ncolumns, 424 | const warp_prism_type** column_types, 425 | char** outarrays, 426 | bool** outmasks) { 427 | uint_fast16_t n = 0; 428 | size_t mask_allocation_size; 429 | 430 | if (unlikely(mul_overflow(starting_column_buffer_length, 431 | sizeof(bool), 432 | &mask_allocation_size))) { 433 | /* this should literally never happen */ 434 | PyErr_SetString(PyExc_OverflowError, 435 | "allocation size would overflow"); 436 | goto error; 437 | } 438 | 439 | for (; n < ncolumns; ++n) { 440 | size_t allocation_size; 441 | if (unlikely(mul_overflow(starting_column_buffer_length, 442 | column_types[n]->size, 443 | &allocation_size))) { 444 | PyErr_SetString(PyExc_OverflowError, 445 | "allocation size would overflow"); 446 | goto error; 447 | } 448 | outarrays[n] = PyMem_Malloc(allocation_size); 449 | if (!outarrays[n]) { 450 | goto error; 451 | } 452 | 453 | outmasks[n] = PyMem_Malloc(mask_allocation_size); 454 | if (!outmasks[n]) { 455 | /* free_outarrays expects that both the array and the mask are 456 | present but we failed halfway through the allocation. Clean up 457 | just the outarray and then, at the error label, feed 458 | free_outarrays n to clean up the n fully-allocated columns.
*/ 459 | column_types[n]->free(outarrays[n], 0); 460 | goto error; 461 | } 462 | } 463 | return 0; 464 | 465 | error: 466 | if (n > 0) { 467 | /* free the column and mask buffers that have already been allocated */ 468 | free_outarrays(n, 469 | 0, 470 | column_types, 471 | outarrays, 472 | outmasks); 473 | } 474 | return -1; 475 | } 476 | 477 | static inline int grow_outarrays(uint16_t ncolumns, 478 | size_t* row_count, 479 | const warp_prism_type** column_types, 480 | char** outarrays, 481 | bool** outmasks) { 482 | size_t old_row_count = *row_count; 483 | size_t new_row_count; 484 | size_t new_mask_size; 485 | uint_fast16_t n; 486 | 487 | if (unlikely(mul_overflow(old_row_count, 488 | column_buffer_growth_factor, 489 | &new_row_count))) { 490 | PyErr_SetString(PyExc_OverflowError, "row count would overflow"); 491 | goto error; 492 | } 493 | *row_count = new_row_count; 494 | 495 | if (unlikely(mul_overflow(new_row_count, 496 | sizeof(bool), 497 | &new_mask_size))) { 498 | PyErr_SetString(PyExc_OverflowError, "mask size would overflow"); 499 | goto error; 500 | } 501 | 502 | for (n = 0; n < ncolumns; ++n) { 503 | size_t allocation_size; 504 | char* new; 505 | bool* newmask; 506 | 507 | if (unlikely(mul_overflow(new_row_count, 508 | column_types[n]->size, 509 | &allocation_size))) { 510 | PyErr_SetString(PyExc_OverflowError, 511 | "allocation size would overflow"); 512 | goto error; 513 | } 514 | 515 | new = PyMem_Realloc(outarrays[n], allocation_size); 516 | if (!new) { 517 | goto error; 518 | } 519 | outarrays[n] = new; 520 | 521 | newmask = PyMem_Realloc(outmasks[n], new_mask_size); 522 | if (!newmask) { 523 | goto error; 524 | } 525 | outmasks[n] = newmask; 526 | } 527 | return 0; 528 | error: 529 | free_outarrays(ncolumns, 530 | old_row_count, 531 | column_types, 532 | outarrays, 533 | outmasks); 534 | return -1; 535 | } 536 |
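/* Layout of the binary COPY stream consumed below (per the PostgreSQL COPY documentation): an 11-byte signature, a uint32 flags field, and a uint32 header-extension length plus that many bytes of extension data; then one tuple per row: an int16 field count followed, for each field, by an int32 byte length (-1 encodes NULL) and that many bytes of data. A field count of -1 terminates the stream. All integers are in network (big-endian) byte order, hence MAYBE_BSWAP in the readers. */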
537 | int warp_prism_read_binary_results(const char* const input_buffer, 538 | size_t input_len, 539 | const uint16_t ncolumns, 540 | const warp_prism_type** column_types, 541 | size_t* written_rows, 542 | char** outarrays, 543 | bool** outmasks) { 544 | size_t cursor = 0; 545 | uint32_t flags; 546 | size_t row_count = 0; 547 | size_t allocated_rows = starting_column_buffer_length; 548 | uint32_t extension_area; 549 | 550 | if (input_len < signature_len || 551 | memcmp(input_buffer, signature, signature_len)) { 552 | 553 | PyErr_SetString(PyExc_ValueError, "missing postgres signature"); 554 | return -1; 555 | } 556 | 557 | /* advance the cursor up to the flags segment */ 558 | cursor += signature_len; 559 | 560 | /* flags field */ 561 | if (checked_consume32(input_buffer, 562 | &cursor, 563 | input_len, 564 | &flags)) { 565 | return -1; 566 | } 567 | 568 | if (!valid_flags(flags)) { 569 | PyErr_SetString(PyExc_ValueError, "invalid flags in header"); 570 | return -1; 571 | } 572 | 573 | /* read the header extension length; a non-zero extension area is not supported */ 574 | if (checked_consume32(input_buffer, 575 | &cursor, 576 | input_len, 577 | &extension_area)) { 578 | return -1; 579 | } 580 | cursor += extension_area; 581 | if (extension_area) { 582 | PyErr_SetString(PyExc_ValueError, "non-zero extension area length"); 583 | return -1; 584 | } 585 | 586 | if (allocate_outarrays(ncolumns, column_types, outarrays, outmasks)) { 587 | return -1; 588 | } 589 | 590 | 591 | while (true) { 592 | int16_t field_count; 593 | 594 | if (checked_consume16(input_buffer, 595 | &cursor, 596 | input_len, 597 | (uint16_t*) &field_count)) { 598 | free_outarrays(ncolumns, 599 | row_count, 600 | column_types, 601 | outarrays, 602 | outmasks); 603 | return -1; 604 | } 605 | 606 | if (field_count == -1) { 607 | /* field_count == -1 signals the end of the input data */ 608 | break; 609 | } 610 | 611 | if (field_count != ncolumns) { 612 | PyErr_Format(PyExc_ValueError, 613 | "mismatched field_count and ncolumns on row %zu:" 614 | " %d != %d", 615 | row_count, 616 | field_count, 617 | ncolumns); free_outarrays(ncolumns, row_count, column_types, outarrays, outmasks); 618 | return -1; 619 | } 620 | 621 | if (have_oids(flags)) { 622 | uint32_t oid; 623 | if (checked_consume32(input_buffer, 624 | &cursor, 625 | input_len, 626 | &oid)) { 627 | free_outarrays(ncolumns, 628 | row_count, 629 | column_types, 630 | outarrays, 631 | outmasks); 632 | return -1; 633 | } 634 | } 635 | 636 | /* advance the row count; grow arrays if needed */ 637 | if (row_count++ == allocated_rows) { 638 | if (grow_outarrays(ncolumns, 639 | &allocated_rows, 640 | column_types, 641 | outarrays, 642 | outmasks)) { 643 | return -1; 644 | } 645 | } 646 | 647 | for (uint_fast16_t n = 0; n < ncolumns; ++n) { 648 | const warp_prism_type* column_type = column_types[n]; 649 | int32_t datalen; 650 | size_t row_ix = row_count - 1; 651 | char* column_buffer = &outarrays[n][row_ix * column_type->size]; 652 | 653 | if (checked_consume32(input_buffer, 654 | &cursor, 655 | input_len, 656 | (uint32_t*) &datalen)) { 657 | goto error; 658 | } 659 | 660 | if (!(outmasks[n][row_ix] = (datalen != -1))) { 661 | if (column_type->write_null(column_buffer, column_type->size)) { 662 | goto error; 663 | } 664 | 665 | /* no value bytes follow a null */ 666 | continue; 667 | } 668 | 669 | if (assert_can_consume(datalen, cursor, input_len) || 670 | column_type->parse(column_buffer, 671 | &input_buffer[cursor], 672 | datalen)) { 673 | goto error; 674 | } 675 | cursor += datalen; 676 | continue; 677 | 678 | error: 679 | /* Write a NULL of the correct size to all of the columns that 680 | have not yet been written. This ensures that we can properly 681 | clean up all of the column arrays with `free_outarrays`.
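For object columns the zeroed slots are NULL PyObject pointers, which Py_XDECREF in free_object safely ignores.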
*/ 682 | for (; n < ncolumns; ++n) { 683 | const warp_prism_type* type = column_types[n]; 684 | char* buffer = &outarrays[n][row_ix * type->size]; 685 | memset(buffer, 0, type->size); 686 | } 687 | free_outarrays(ncolumns, 688 | row_count, 689 | column_types, 690 | outarrays, 691 | outmasks); 692 | return -1; 693 | } 694 | } 695 | *written_rows = row_count; 696 | return 0; 697 | } 698 | 699 | typedef struct { 700 | char* buffer; 701 | const warp_prism_type* type; 702 | size_t rowcount; 703 | } capsule_contents; 704 | 705 | static void free_acapsule(PyObject* capsule) { 706 | capsule_contents* c = PyCapsule_GetPointer(capsule, NULL); 707 | 708 | if (c) { 709 | c->type->free(c->buffer, c->rowcount); 710 | PyMem_Free(c); 711 | } 712 | } 713 | 714 | static void free_mcapsule(PyObject* capsule) { 715 | PyMem_Free(PyCapsule_GetPointer(capsule, NULL)); 716 | } 717 | 718 | static PyObject* warp_prism_to_arrays(PyObject* self __attribute__((unused)), 719 | PyObject* args) { 720 | Py_buffer view; 721 | PyObject* pytypeids; 722 | Py_ssize_t ncolumns; 723 | const warp_prism_type** types = NULL; 724 | char** outarrays = NULL; 725 | bool** outmasks = NULL; 726 | Py_ssize_t n; 727 | size_t written_rows; 728 | PyObject* out; 729 | 730 | if (PyTuple_GET_SIZE(args) != 2) { 731 | PyErr_SetString(PyExc_TypeError, 732 | "expected exactly 2 arguments (buffer, type_ids)"); 733 | return NULL; 734 | } 735 | 736 | pytypeids = PyTuple_GET_ITEM(args, 1); 737 | 738 | if (!PyTuple_Check(pytypeids)) { 739 | PyErr_SetString(PyExc_TypeError, "type_ids must be a tuple"); 740 | return NULL; 741 | } 742 | ncolumns = PyTuple_GET_SIZE(pytypeids); 743 | if (ncolumns > UINT16_MAX) { 744 | PyErr_SetString(PyExc_ValueError, "column count must fit in uint16_t"); 745 | return NULL; 746 | } 747 | 748 | if (!(outarrays = PyMem_Malloc(sizeof(char*) * ncolumns))) { 749 | goto free_arrays; 750 | } 751 | if (!(outmasks = PyMem_Malloc(sizeof(bool*) * ncolumns))) { 752 | goto free_arrays; 753 | } 754 | if (!(types = PyMem_Malloc(sizeof(warp_prism_type*) * ncolumns))) { 755 | goto free_arrays; 756 | } 757 | 758 | for (n = 0; n < ncolumns; ++n) { 759 | unsigned long id_ix; 760 | 761 | id_ix = PyLong_AsUnsignedLong(PyTuple_GET_ITEM(pytypeids, n)); 762 | if (PyErr_Occurred() || id_ix >= max_typeid) { if (!PyErr_Occurred()) { PyErr_SetString(PyExc_ValueError, "type id out of range"); } 763 | goto free_arrays; 764 | } 765 | 766 | types[n] = typeids[id_ix]; 767 | 768 | } 769 | 770 | if (!(out = PyTuple_New(ncolumns))) { 771 | goto free_arrays; 772 | } 773 | 774 | if (PyObject_GetBuffer(PyTuple_GET_ITEM(args, 0), 775 | &view, 776 | PyBUF_CONTIG_RO)) { 777 | Py_DECREF(out); goto free_arrays; 778 | } 779 | 780 | if (warp_prism_read_binary_results(view.buf, 781 | view.len, 782 | ncolumns, 783 | types, 784 | &written_rows, 785 | outarrays, 786 | outmasks)) { 787 | PyBuffer_Release(&view); Py_DECREF(out); 788 | goto free_arrays; 789 | } 790 | PyBuffer_Release(&view); 791 | 792 | for (n = 0; n < ncolumns; ++n) { 793 | capsule_contents* ac; 794 | PyObject* acapsule; 795 | PyObject* mcapsule; 796 | PyObject* andaray; 797 | PyObject* mndarray; 798 | PyObject* pair; 799 | 800 | Py_INCREF(types[n]->dtype); 801 | if (!(andaray = PyArray_NewFromDescr(&PyArray_Type, 802 | types[n]->dtype, 803 | 1, 804 | (npy_intp*) &written_rows, 805 | NULL, 806 | outarrays[n], 807 | NPY_ARRAY_CARRAY, 808 | NULL))) { 809 | Py_DECREF(out); 810 | goto clear_arrays; 811 | } 812 |
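/* Hand ownership of the raw column buffer to a capsule and make it the ndarray's base object: the array then keeps the buffer alive, and free_acapsule releases it (via the type's free function) exactly once when the array dies. */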
ac->rowcount = written_rows; 822 | 823 | if (!(acapsule = PyCapsule_New(ac, NULL, free_acapsule))) { 824 | PyMem_Free(ac); 825 | Py_DECREF(andaray); 826 | Py_DECREF(out); 827 | goto clear_arrays; 828 | } 829 | 830 | if (PyArray_SetBaseObject((PyArrayObject*) andaray, acapsule)) { 831 | Py_DECREF(acapsule); 832 | Py_DECREF(andaray); 833 | Py_DECREF(out); 834 | goto clear_arrays; 835 | } 836 | 837 | if (!(mndarray = PyArray_SimpleNewFromData(1, 838 | (npy_intp*) &written_rows, 839 | NPY_BOOL, 840 | outmasks[n]))) { 841 | Py_DECREF(andaray); 842 | Py_DECREF(out); 843 | goto clear_arrays; 844 | } 845 | 846 | if (!(mcapsule = PyCapsule_New(outmasks[n], NULL, free_mcapsule))) { 847 | Py_DECREF(andaray); 848 | Py_DECREF(mndarray); 849 | Py_DECREF(out); 850 | goto clear_arrays; 851 | } 852 | 853 | if (PyArray_SetBaseObject((PyArrayObject*) mndarray, mcapsule)) { 854 | Py_DECREF(andaray); 855 | Py_DECREF(mndarray); 856 | Py_DECREF(mcapsule); 857 | Py_DECREF(out); 858 | goto clear_arrays; 859 | } 860 | 861 | if (!(pair = PyTuple_New(2))) { 862 | Py_DECREF(andaray); 863 | Py_DECREF(mndarray); 864 | Py_DECREF(out); 865 | goto clear_arrays; 866 | } 867 | 868 | PyTuple_SET_ITEM(pair, 0, andaray); 869 | PyTuple_SET_ITEM(pair, 1, mndarray); 870 | PyTuple_SET_ITEM(out, n, pair); 871 | } 872 | 873 | return out; 874 | 875 | clear_arrays: 876 | free_outarrays(ncolumns, written_rows, types, outarrays, outmasks); 877 | free_arrays: 878 | PyMem_Free(outarrays); 879 | PyMem_Free(outmasks); 880 | PyMem_Free(types); 881 | return NULL; 882 | } 883 | 884 | PyObject* test_overflow_operations(PyObject* self __attribute__((unused))) { 885 | size_t out; 886 | 887 | /* we actually want to run this code; gcc can figure out this out at compile 888 | time without the volatile */ 889 | volatile int a = 2; 890 | 891 | if (!add_overflow(SIZE_MAX, a, &out)) { 892 | PyErr_SetString(PyExc_AssertionError, "add_overflow max + 2 failed"); 893 | return NULL; 894 | } 895 | 896 | if (!mul_overflow(SIZE_MAX, a, &out)) { 897 | PyErr_SetString(PyExc_AssertionError, "mul_overflow max * 2 failed"); 898 | return NULL; 899 | } 900 | 901 | /* this should not overflow */ 902 | if (mul_overflow(2, a, &out)) { 903 | PyErr_SetString(PyExc_AssertionError, "mul_overflow 2 * 2 failed"); 904 | return NULL; 905 | } 906 | 907 | if (out != 4) { 908 | PyErr_Format(PyExc_AssertionError, 909 | "mul_overflow 2 * 2 failed; %ld != 4", 910 | out); 911 | return NULL; 912 | } 913 | 914 | Py_RETURN_NONE; 915 | } 916 | 917 | PyMethodDef methods[] = { 918 | {"raw_to_arrays", (PyCFunction) warp_prism_to_arrays, METH_VARARGS, NULL}, 919 | {"test_overflow_operations", (PyCFunction) test_overflow_operations, METH_NOARGS, NULL}, 920 | {NULL}, 921 | }; 922 | 923 | static struct PyModuleDef _warp_prism_module = { 924 | PyModuleDef_HEAD_INIT, 925 | "warp_prism._warp_prism", 926 | "", 927 | -1, 928 | methods, 929 | NULL, 930 | NULL, 931 | NULL, 932 | NULL 933 | }; 934 | 935 | PyMODINIT_FUNC PyInit__warp_prism(void) { 936 | PyObject* m; 937 | PyObject* typeid_map; 938 | PyObject* signature_ob; 939 | 940 | /* This is needed to setup the numpy C-API. 
*/ 941 | import_array(); 942 | 943 | if (!(typeid_map = PyDict_New())) { 944 | return NULL; 945 | } 946 | 947 | for (size_t n = 0; n < max_typeid; ++n) { 948 | PyObject* dtype_name_ob; 949 | PyObject* n_ob; 950 | int err; 951 | 952 | 953 | if (!(dtype_name_ob = PyUnicode_FromString(typeids[n]->dtype_name))) { 954 | Py_DECREF(typeid_map); 955 | return NULL; 956 | } 957 | 958 | if (!PyArray_DescrConverter(dtype_name_ob, 959 | (PyArray_Descr**) &typeids[n]->dtype)) { 960 | Py_DECREF(dtype_name_ob); 961 | Py_DECREF(typeid_map); 962 | return NULL; 963 | } 964 | 965 | 966 | if (!(n_ob = PyLong_FromLong(n))) { 967 | Py_DECREF(dtype_name_ob); 968 | Py_DECREF(typeid_map); 969 | return NULL; 970 | } 971 | 972 | err = PyDict_SetItem(typeid_map, dtype_name_ob, n_ob); 973 | Py_DECREF(dtype_name_ob); 974 | Py_DECREF(n_ob); 975 | if (err) { 976 | Py_DECREF(typeid_map); 977 | return NULL; 978 | } 979 | } 980 | 981 | if (!(m = PyModule_Create(&_warp_prism_module))) { 982 | Py_DECREF(typeid_map); 983 | return NULL; 984 | } 985 | 986 | if (PyModule_AddObject(m, "typeid_map", typeid_map)) { 987 | Py_DECREF(typeid_map); 988 | Py_DECREF(m); 989 | return NULL; 990 | } 991 | 992 | if (!(signature_ob = PyBytes_FromStringAndSize(signature, signature_len))) { 993 | Py_DECREF(m); 994 | return NULL; 995 | } 996 | 997 | if (PyModule_AddObject(m, "postgres_signature", signature_ob)) { 998 | Py_DECREF(signature_ob); 999 | Py_DECREF(m); 1000 | return NULL; 1001 | } 1002 | 1003 | return m; 1004 | } 1005 | -------------------------------------------------------------------------------- /warp_prism/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from uuid import uuid4 3 | import warnings 4 | 5 | from odo import resource 6 | import sqlalchemy as sa 7 | 8 | 9 | def _dropdb(root_conn, db_name): 10 | root_conn.execute('COMMIT') 11 | root_conn.execute('DROP DATABASE %s' % db_name) 12 | 13 | 14 | @contextmanager 15 | def disposable_engine(uri): 16 | """An engine which is disposed on exit. 17 | 18 | Parameters 19 | ---------- 20 | uri : str 21 | The uri to the db. 22 | 23 | Yields 24 | ------ 25 | engine : sa.engine.Engine 26 | """ 27 | engine = resource(uri) 28 | try: 29 | yield engine 30 | finally: 31 | engine.dispose() 32 | 33 | 34 | _pg_stat_activity = sa.Table( 35 | 'pg_stat_activity', 36 | sa.MetaData(), 37 | sa.Column('pid', sa.Integer), 38 | ) 39 | 40 | 41 | @contextmanager 42 | def tmp_db_uri(): 43 | """Create a temporary postgres database to run the tests against. 44 | """ 45 | db_name = '_warp_prism_test_' + uuid4().hex 46 | root = 'postgresql://localhost/' 47 | uri = root + db_name 48 | with disposable_engine(root + 'postgres') as e, e.connect() as root_conn: 49 | root_conn.execute('COMMIT') 50 | root_conn.execute('CREATE DATABASE %s' % db_name) 51 | try: 52 | yield uri 53 | finally: 54 | resource(uri).dispose() 55 | try: 56 | _dropdb(root_conn, db_name) 57 | except sa.exc.OperationalError: 58 | # We couldn't drop the db. The most likely cause is that there 59 | # are active queries. Even more likely is that these are 60 | # rollbacks because there was an exception somewhere inside the 61 | # tests. We will cancel all the running queries and try to drop 62 | # the database again. 
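# pg_stat_activity has one row per backend; terminate every backend except our own connection so that DROP DATABASE can acquire its lock.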
                pid = _pg_stat_activity.c.pid
                root_conn.execute(
                    sa.select(
                        (sa.func.pg_terminate_backend(pid),),
                    ).where(
                        pid != sa.func.pg_backend_pid(),
                    )
                )
                try:
                    _dropdb(root_conn, db_name)
                except sa.exc.OperationalError:  # pragma: no cover
                    # The database STILL wasn't cleaned up. Just tell the user
                    # to deal with this manually.
                    warnings.warn(
                        "leaking database '%s', please delete it manually" %
                        db_name,
                    )
--------------------------------------------------------------------------------
/warp_prism/tests/test_warp_prism.py:
--------------------------------------------------------------------------------
from string import ascii_letters
import struct
from uuid import uuid4

from datashape import var, R, Option, dshape
import numpy as np
from odo import resource, odo
import pandas as pd
import pytest
import sqlalchemy as sa

from warp_prism._warp_prism import (
    postgres_signature,
    raw_to_arrays,
    test_overflow_operations as _test_overflow_operations,
)
from warp_prism import (
    to_arrays,
    to_dataframe,
    null_values as null_values_for_type,
    _typeid_map,
)
from warp_prism.tests import tmp_db_uri as tmp_db_uri_ctx


@pytest.fixture(scope='module')
def tmp_db_uri():
    with tmp_db_uri_ctx() as db_uri:
        yield db_uri


@pytest.fixture
def tmp_table_uri(tmp_db_uri):
    return '%s::%s%s' % (tmp_db_uri, 'table_', uuid4().hex)


def check_roundtrip_nonnull(table_uri, data, dtype, sqltype):
    """Check that the data round-trips through postgres when warp_prism is
    used to read it back.

    Parameters
    ----------
    table_uri : str
        The uri to a unique table.
    data : np.ndarray
        The input data.
    dtype : str
        The dtype of the data.
    sqltype : type
        The sqlalchemy type of the data.
    """
    input_dataframe = pd.DataFrame({'a': data})
    table = odo(input_dataframe, table_uri, dshape=var * R['a': dtype])
    # Ensure that odo created the table correctly. If these fail the other
    # tests are not well defined.
    assert table.columns.keys() == ['a']
    assert isinstance(table.columns['a'].type, sqltype)

    arrays = to_arrays(table)
    assert len(arrays) == 1
    array, mask = arrays['a']
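    # ``to_arrays`` maps each column name to an (array, mask) pair where the
    # mask is True for non-null cells; nothing was inserted as NULL here, so
    # the mask must be all True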
    assert (array == data).all()
    assert mask.all()

    output_dataframe = to_dataframe(table)
    pd.util.testing.assert_frame_equal(output_dataframe, input_dataframe)


@pytest.mark.parametrize('dtype,sqltype,start,stop,step', (
    ('int16', sa.SmallInteger, 0, 5000, 1),
    ('int32', sa.Integer, 0, 5000, 1),
    ('int64', sa.BigInteger, 0, 5000, 1),
    ('float32', sa.REAL, 0, 2500, 0.5),
    ('float64', sa.FLOAT, 0, 2500, 0.5),
))
def test_numeric_type_nonnull(tmp_table_uri,
                              dtype,
                              sqltype,
                              start,
                              stop,
                              step):
    data = np.arange(start, stop, step, dtype=dtype)
    check_roundtrip_nonnull(tmp_table_uri, data, dtype, sqltype)


def test_bool_type_nonnull(tmp_table_uri):
    data = np.array([True] * 2500 + [False] * 2500, dtype=bool)
    check_roundtrip_nonnull(tmp_table_uri, data, 'bool', sa.Boolean)


def test_string_type_nonnull(tmp_table_uri):
    data = np.array(list(ascii_letters) * 200, dtype='object')
    check_roundtrip_nonnull(tmp_table_uri, data, 'string', sa.String)


def test_datetime_type_nonnull(tmp_table_uri):
    data = pd.date_range(
        '2000',
        '2016',
    ).values.astype('datetime64[us]')
    check_roundtrip_nonnull(tmp_table_uri, data, 'datetime', sa.DateTime)


def test_date_type_nonnull(tmp_table_uri):
    data = pd.date_range(
        '2000',
        '2016',
    ).values.astype('datetime64[D]')
    check_roundtrip_nonnull(tmp_table_uri, data, 'date', sa.Date)


def check_roundtrip_null_values(table_uri,
                                data,
                                dtype,
                                sqltype,
                                null_values,
                                mask,
                                *,
                                astype=False):
    """Check that data containing NULLs round-trips through postgres when
    warp_prism is used to read it back.

    Parameters
    ----------
    table_uri : str
        The uri to a unique table.
    data : iterable[any]
        The input data.
    dtype : str
        The dtype of the data.
    sqltype : type
        The sqlalchemy type of the data.
    null_values : dict[str, any]
        A map from column name to the value to coerce ``NULL`` to in that
        column.
    mask : np.ndarray[bool]
        A mask which is True for the non-null elements of ``data``.
    astype : bool, optional
        Coerce the input data to the given dtype before making assertions
        about the output data.
    """
    table = resource(table_uri, dshape=var * R['a': Option(dtype)])
    # Ensure that odo created the table correctly. If these fail the other
    # tests are not well defined.
    assert table.columns.keys() == ['a']
    assert isinstance(table.columns['a'].type, sqltype)
    table.insert().values([{'a': v} for v in data]).execute()

    arrays = to_arrays(table)
    assert len(arrays) == 1
    array, actual_mask = arrays['a']
    assert (actual_mask == mask).all()
    assert (array[mask] == data[mask]).all()

    output_dataframe = to_dataframe(table, null_values=null_values)
    if astype:
        data = data.astype(dshape(dtype).measure.to_numpy_dtype())
    expected_dataframe = pd.DataFrame({'a': data})
    expected_dataframe[~mask] = null_values.get(
        'a',
        null_values_for_type[
            array.dtype
            if array.dtype.kind != 'M' else
            np.dtype('datetime64[ns]')
        ],
    )
    pd.util.testing.assert_frame_equal(
        output_dataframe,
        expected_dataframe,
        check_dtype=False,
    )


def check_roundtrip_null(table_uri,
                         data,
                         dtype,
                         sqltype,
                         null,
                         mask,
                         *,
                         astype=False):
    """Check that data containing NULLs round-trips through postgres when
    warp_prism is used to read it back.

    Parameters
    ----------
    table_uri : str
        The uri to a unique table.
    data : iterable[any]
        The input data.
    dtype : str
        The dtype of the data.
    sqltype : type
        The sqlalchemy type of the data.
    null : any
        The value to coerce ``NULL`` to.
    mask : np.ndarray[bool]
        A mask which is True for the non-null elements of ``data``.
    astype : bool, optional
        Coerce the input data to the given dtype before making assertions
        about the output data.
    """
    check_roundtrip_null_values(
        table_uri,
        data,
        dtype,
        sqltype,
        {'a': null},
        mask,
        astype=astype,
    )


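# The null tests below all build their masks the same way: tiling
# ``[True, False]`` marks alternating elements as non-null, for example:
#
#     >>> np.tile(np.array([True, False]), 3)
#     array([ True, False,  True, False,  True, False])
#
# ``data[~mask] = None`` then turns every odd-indexed element into a SQL NULL
# when the rows are inserted.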
@pytest.mark.parametrize('dtype,sqltype,start,stop,step,null', (
    ('int16', sa.SmallInteger, 0, 5000, 1, -1),
    ('int32', sa.Integer, 0, 5000, 1, -1),
    ('int64', sa.BigInteger, 0, 5000, 1, -1),
    ('float32', sa.REAL, 0, 2500, 0.5, -1.0),
    ('float64', sa.FLOAT, 0, 2500, 0.5, -1.0),
))
def test_numeric_type_null(tmp_table_uri,
                           dtype,
                           sqltype,
                           start,
                           stop,
                           step,
                           null):
    data = np.arange(start, stop, step, dtype=dtype).astype(object)
    mask = np.tile(np.array([True, False]), len(data) // 2)
    data[~mask] = None
    check_roundtrip_null(tmp_table_uri, data, dtype, sqltype, null, mask)


@pytest.mark.parametrize('dtype,sqltype', (
    ('int16', sa.SmallInteger),
    ('int32', sa.Integer),
    ('int64', sa.BigInteger),
))
def test_numeric_default_null_promote(tmp_table_uri, dtype, sqltype):
    data = np.arange(0, 100, dtype=dtype).astype(object)
    mask = np.tile(np.array([True, False]), len(data) // 2)
    data[~mask] = None
    check_roundtrip_null_values(tmp_table_uri, data, dtype, sqltype, {}, mask)


def test_bool_type_null(tmp_table_uri):
    data = np.array([True] * 2500 + [False] * 2500, dtype=bool).astype(object)
    mask = np.tile(np.array([True, False]), len(data) // 2)
    data[~mask] = None
    check_roundtrip_null(tmp_table_uri, data, 'bool', sa.Boolean, False, mask)


def test_string_type_null(tmp_table_uri):
    data = np.array(list(ascii_letters) * 200, dtype='object')
    mask = np.tile(np.array([True, False]), len(data) // 2)
    data[~mask] = None
    check_roundtrip_null(
        tmp_table_uri,
        data,
        'string',
        sa.String,
        'ayy lmao',
        mask,
    )


def test_datetime_type_null(tmp_table_uri):
    data = np.array(
        list(pd.date_range(
            '2000',
            '2016',
        )),
        dtype=object,
    )[:-1]  # slice the last element off to have an even number
    mask = np.tile(np.array([True, False]), len(data) // 2)
    data[~mask] = None
    check_roundtrip_null(
        tmp_table_uri,
        data,
        'datetime',
        sa.DateTime,
        pd.Timestamp('1995-12-13').to_datetime64(),
        mask,
    )


def test_date_type_null(tmp_table_uri):
    data = np.arange(
        '2000',
        '2016',
        dtype='datetime64[D]',
    ).astype(object)
    mask = np.tile(np.array([True, False]), len(data) // 2)
    data[~mask] = None
    check_roundtrip_null(
        tmp_table_uri,
        data,
        'date',
        sa.Date,
        pd.Timestamp('1995-12-13').to_datetime64(),
        mask,
        astype=True,
    )


def _pack_as_invalid_size_postgres_binary_data(char, itemsize, value):
    """Create mock postgres binary data for testing the column data size
    checks.

    Parameters
    ----------
    char : str
        The format char for struct.
    itemsize : int
        The size of the given type in bytes. The second field's data_size
        is declared as ``itemsize - 1`` to trip the size check.
    value : any
        The value to pack; it will appear twice.

    Returns
    -------
    binary_data : bytes
        The binary data to feed to raw_to_arrays.
    """
    return postgres_signature + struct.pack(
        '>iihi{char}hi{char}'.format(char=char),
        0,  # flags
        0,  # extension area size
        1,  # field_count
        itemsize,  # data_size
        value,
        1,  # field_count
        itemsize - 1,  # incorrect size for the given type
        value,  # default value of the given type
    )


@pytest.mark.parametrize('dtype', map(np.dtype, (
    'bool',
    'int16',
    'int32',
    'float32',
    'float64',
)))
def test_invalid_numeric_size(dtype):
    input_data = _pack_as_invalid_size_postgres_binary_data(
        dtype.char,
        dtype.itemsize,
        dtype.type(),
    )

    with pytest.raises(ValueError) as e:
        raw_to_arrays(input_data, (_typeid_map[dtype],))

    assert str(e.value) == 'mismatched %s size: %s' % (
        dtype.name,
        dtype.itemsize - 1,
    )


# timedelta to adjust a numpy datetime into a postgres datetime
_epoch_offset = np.datetime64('2000-01-01') - np.datetime64('1970-01-01')


def test_invalid_datetime_size():
    input_data = _pack_as_invalid_size_postgres_binary_data(
        'q',  # int64_t (quadword)
        8,
        (pd.Timestamp('2014-01-01').to_datetime64().astype('datetime64[us]') +
         _epoch_offset).view('int64'),
    )

    dtype = np.dtype('datetime64[us]')
    with pytest.raises(ValueError) as e:
        raw_to_arrays(input_data, (_typeid_map[dtype],))

    assert str(e.value) == 'mismatched datetime size: 7'


def test_invalid_date_size():
    input_data = _pack_as_invalid_size_postgres_binary_data(
        'i',  # int32_t
        4,
        (np.datetime64('2014-01-01', 'D') + _epoch_offset).view('int64'),
    )

    dtype = np.dtype('datetime64[D]')
    with pytest.raises(ValueError) as e:
        raw_to_arrays(input_data, (_typeid_map[dtype],))

    assert str(e.value) == 'mismatched date size: 3'


def test_invalid_text():
    input_data = postgres_signature + struct.pack(
        '>iihi1si1shi{}si1s'.format(len(postgres_signature) + 1),
        0,  # flags
        0,  # extension area size

        # row 0
        2,  # field_count
        1,  # data_size
        b'\0',
        1,  # data_size
        b'\0',

        # row 1
        2,  # field_count
        len(postgres_signature) + 1,  # data_size
        postgres_signature + b'\0',  # postgres signature is invalid unicode
        1,  # data_size
        b'\1',
    )
    # the invalid unicode goes in the first column to check that the error
    # path can clean up the second column's cell before a string has been
    # written there

    str_typeid = _typeid_map[np.dtype(object)]
    with pytest.raises(UnicodeDecodeError):
        raw_to_arrays(input_data, (str_typeid, str_typeid))


def test_missing_signature():
    input_data = b''

    with pytest.raises(ValueError) as e:
        raw_to_arrays(input_data, ())

    assert str(e.value) == 'missing postgres signature'


def test_missing_flags():
    input_data = postgres_signature

    with pytest.raises(ValueError) as e:
        raw_to_arrays(input_data, ())

    assert (
        str(e.value) == 'reading 4 bytes would cause an out of bounds access'
    )


def test_missing_extension_length():
    input_data = postgres_signature + (0).to_bytes(4, 'big')

    with pytest.raises(ValueError) as e:
        raw_to_arrays(input_data, ())

    assert (
        str(e.value) == 'reading 4 bytes would cause an out of bounds access'
    )


def test_missing_end_marker():
    # flags and extension area size are both present; the tuple stream then
    # ends without the 16 bit field_count of -1 that marks the end of the data
    input_data = postgres_signature + (0).to_bytes(4, 'big') * 2

    with pytest.raises(ValueError) as e:
        raw_to_arrays(input_data, ())

    assert (
        str(e.value) == 'reading 2 bytes would cause an out of bounds access'
    )


def test_overflow_operations():
    # the C function is also named ``test_overflow_operations``, so it is
    # imported with a leading underscore to keep pytest from collecting it;
    # this wrapper just invokes it
    _test_overflow_operations()
--------------------------------------------------------------------------------
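For reference, a minimal sketch of the public API these tests exercise; the
database URI and table name are illustrative only, and this assumes a
reachable postgres instance:

    from odo import resource
    from warp_prism import to_arrays, to_dataframe

    # hypothetical table; any sqlalchemy table bound to a postgres engine,
    # e.g. via odo's ``uri::tablename`` syntax, works here
    table = resource('postgresql://localhost/mydb::mytable')

    # dict mapping column name -> (values, mask); mask is True where the
    # cell is non-null
    arrays = to_arrays(table)
    values, mask = arrays['a']

    # full dataframe, with NULLs in column 'a' replaced by -1
    df = to_dataframe(table, null_values={'a': -1})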