├── .coveragerc ├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.rst ├── TCLIService ├── TCLIService-remote ├── TCLIService.py ├── __init__.py ├── constants.py └── ttypes.py ├── dev_requirements.txt ├── generate.py ├── pyhive ├── __init__.py ├── common.py ├── exc.py ├── hive.py ├── presto.py ├── sasl_compat.py ├── sqlalchemy_hive.py ├── sqlalchemy_presto.py ├── sqlalchemy_trino.py ├── tests │ ├── __init__.py │ ├── dbapi_test_case.py │ ├── ldif_data │ │ ├── INITIAL_TESTDATA.ldif │ │ └── base.ldif │ ├── sqlalchemy_test_case.py │ ├── test_common.py │ ├── test_hive.py │ ├── test_presto.py │ ├── test_sasl_compat.py │ ├── test_sqlalchemy_hive.py │ ├── test_sqlalchemy_presto.py │ ├── test_sqlalchemy_trino.py │ └── test_trino.py └── trino.py ├── scripts ├── ldap_config │ └── slapd.conf ├── make_many_rows.sh ├── make_one_row.sh ├── make_one_row_complex.sh ├── make_test_database.sh ├── make_test_tables.sh ├── thrift-patches │ └── TCLIService.patch ├── travis-conf │ ├── com │ │ └── dropbox │ │ │ └── DummyPasswdAuthenticationProvider.java │ ├── hive │ │ ├── hive-site-custom.xml │ │ ├── hive-site-ldap.xml │ │ └── hive-site.xml │ ├── presto │ │ ├── catalog │ │ │ └── hive.properties │ │ ├── config.properties │ │ ├── jvm.config │ │ └── node.properties │ └── trino │ │ ├── catalog │ │ └── hive.properties │ │ ├── config.properties │ │ ├── jvm.config │ │ └── node.properties ├── travis-install.sh └── update_thrift_bindings.sh ├── setup.cfg └── setup.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | cover/ 2 | .coverage 3 | /dist/ 4 | .DS_Store 5 | *.egg 6 | /env/ 7 | /htmlcov/ 8 | .idea/ 9 | .project 10 | *.pyc 11 | .pydevproject 12 | /PyHive.egg-info/ 13 | .settings 14 | .cache/ 15 | *.iml 16 | 
/scripts/.thrift_gen 17 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | language: python 3 | dist: trusty 4 | matrix: 5 | include: 6 | # https://docs.python.org/devguide/#status-of-python-branches 7 | # One build pulls latest versions dynamically 8 | - python: 3.6 9 | env: CDH=cdh5 CDH_VERSION=5 PRESTO=RELEASE TRINO=RELEASE SQLALCHEMY=sqlalchemy>=1.3.0 10 | - python: 3.6 11 | env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0 12 | - python: 3.5 13 | env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0 14 | - python: 3.4 15 | env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0 16 | - python: 2.7 17 | env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0 18 | install: 19 | - ./scripts/travis-install.sh 20 | - pip install codecov 21 | script: PYTHONHASHSEED=random pytest -v 22 | after_success: codecov 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014 Dropbox, Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ======================================================== 2 | PyHive project has been donated to Apache Kyuubi 3 | ======================================================== 4 | 5 | You can follow it's development and report any issues you are experiencing here: https://github.com/apache/kyuubi/tree/master/python/pyhive 6 | 7 | 8 | 9 | Legacy notes / instructions 10 | =========================== 11 | 12 | PyHive 13 | ********** 14 | 15 | 16 | PyHive is a collection of Python `DB-API `_ and 17 | `SQLAlchemy `_ interfaces for `Presto `_ , 18 | `Hive `_ and `Trino `_. 19 | 20 | Usage 21 | ********** 22 | 23 | DB-API 24 | ------ 25 | .. code-block:: python 26 | 27 | from pyhive import presto # or import hive or import trino 28 | cursor = presto.connect('localhost').cursor() # or use hive.connect or use trino.connect 29 | cursor.execute('SELECT * FROM my_awesome_data LIMIT 10') 30 | print cursor.fetchone() 31 | print cursor.fetchall() 32 | 33 | DB-API (asynchronous) 34 | --------------------- 35 | .. 
code-block:: python 36 | 37 | from pyhive import hive 38 | from TCLIService.ttypes import TOperationState 39 | cursor = hive.connect('localhost').cursor() 40 | cursor.execute('SELECT * FROM my_awesome_data LIMIT 10', async=True) 41 | 42 | status = cursor.poll().operationState 43 | while status in (TOperationState.INITIALIZED_STATE, TOperationState.RUNNING_STATE): 44 | logs = cursor.fetch_logs() 45 | for message in logs: 46 | print message 47 | 48 | # If needed, an asynchronous query can be cancelled at any time with: 49 | # cursor.cancel() 50 | 51 | status = cursor.poll().operationState 52 | 53 | print cursor.fetchall() 54 | 55 | In Python 3.7 `async` became a keyword; you can use `async_` instead: 56 | 57 | .. code-block:: python 58 | 59 | cursor.execute('SELECT * FROM my_awesome_data LIMIT 10', async_=True) 60 | 61 | 62 | SQLAlchemy 63 | ---------- 64 | First install this package to register it with SQLAlchemy, see ``entry_points`` in ``setup.py``. 65 | 66 | .. code-block:: python 67 | 68 | from sqlalchemy import * 69 | from sqlalchemy.engine import create_engine 70 | from sqlalchemy.schema import * 71 | # Presto 72 | engine = create_engine('presto://localhost:8080/hive/default') 73 | # Trino 74 | engine = create_engine('trino+pyhive://localhost:8080/hive/default') 75 | # Hive 76 | engine = create_engine('hive://localhost:10000/default') 77 | 78 | # SQLAlchemy < 2.0 79 | logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True) 80 | print select([func.count('*')], from_obj=logs).scalar() 81 | 82 | # Hive + HTTPS + LDAP or basic Auth 83 | engine = create_engine('hive+https://username:password@localhost:10000/') 84 | logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True) 85 | print select([func.count('*')], from_obj=logs).scalar() 86 | 87 | # SQLAlchemy >= 2.0 88 | metadata_obj = MetaData() 89 | books = Table("books", metadata_obj, Column("id", Integer), Column("title", String), Column("primary_author", String)) 90 | 
metadata_obj.create_all(engine) 91 | inspector = inspect(engine) 92 | inspector.get_columns('books') 93 | 94 | with engine.connect() as con: 95 | data = [{ "id": 1, "title": "The Hobbit", "primary_author": "Tolkien" }, 96 | { "id": 2, "title": "The Silmarillion", "primary_author": "Tolkien" }] 97 | con.execute(books.insert(), data[0]) 98 | result = con.execute(text("select * from books")) 99 | print(result.fetchall()) 100 | 101 | Note: query generation functionality is not exhaustive or fully tested, but there should be no 102 | problem with raw SQL. 103 | 104 | Passing session configuration 105 | ----------------------------- 106 | 107 | .. code-block:: python 108 | 109 | # DB-API 110 | hive.connect('localhost', configuration={'hive.exec.reducers.max': '123'}) 111 | presto.connect('localhost', session_props={'query_max_run_time': '1234m'}) 112 | trino.connect('localhost', session_props={'query_max_run_time': '1234m'}) 113 | # SQLAlchemy 114 | create_engine( 115 | 'presto://user@host:443/hive', 116 | connect_args={'protocol': 'https', 117 | 'session_props': {'query_max_run_time': '1234m'}} 118 | ) 119 | create_engine( 120 | 'trino+pyhive://user@host:443/hive', 121 | connect_args={'protocol': 'https', 122 | 'session_props': {'query_max_run_time': '1234m'}} 123 | ) 124 | create_engine( 125 | 'hive://user@host:10000/database', 126 | connect_args={'configuration': {'hive.exec.reducers.max': '123'}}, 127 | ) 128 | # SQLAlchemy with LDAP 129 | create_engine( 130 | 'hive://user:password@host:10000/database', 131 | connect_args={'auth': 'LDAP'}, 132 | ) 133 | 134 | Requirements 135 | ************ 136 | 137 | Install using 138 | 139 | - ``pip install 'pyhive[hive]'`` or ``pip install 'pyhive[hive_pure_sasl]'`` for the Hive interface 140 | - ``pip install 'pyhive[presto]'`` for the Presto interface 141 | - ``pip install 'pyhive[trino]'`` for the Trino interface 142 | 143 | Note: ``'pyhive[hive]'`` extras uses `sasl `_ that doesn't support Python 3.11, See `github issue `_. 
144 | Hence PyHive also supports `pure-sasl `_ via additional extras ``'pyhive[hive_pure_sasl]'`` which support Python 3.11. 145 | 146 | PyHive works with 147 | 148 | - Python 2.7 / Python 3 149 | - For Presto: `Presto installation `_ 150 | - For Trino: `Trino installation `_ 151 | - For Hive: `HiveServer2 `_ daemon 152 | 153 | Changelog 154 | ********* 155 | See https://github.com/dropbox/PyHive/releases. 156 | 157 | Contributing 158 | ************ 159 | - Please fill out the Dropbox Contributor License Agreement at https://opensource.dropbox.com/cla/ and note this in your pull request. 160 | - Changes must come with tests, with the exception of trivial things like fixing comments. See .travis.yml for the test environment setup. 161 | - Notes on project scope: 162 | 163 | - This project is intended to be a minimal Hive/Presto client that does that one thing and nothing else. 164 | Features that can be implemented on top of PyHive, such integration with your favorite data analysis library, are likely out of scope. 165 | - We prefer having a small number of generic features over a large number of specialized, inflexible features. 166 | For example, the Presto code takes an arbitrary ``requests_session`` argument for customizing HTTP calls, as opposed to having a separate parameter/branch for each ``requests`` option. 167 | 168 | Tips for test environment setup 169 | **************************************** 170 | You can setup test environment by following ``.travis.yaml`` in this repository. It uses `Cloudera's CDH 5 `_ which requires username and password for download. 171 | It may not be feasible for everyone to get those credentials. Hence below are alternative instructions to setup test environment. 172 | 173 | You can clone `this repository `_ which has Docker Compose setup for Presto and Hive. 
174 | You can add below lines to its docker-compose.yaml to start Trino in same environment:: 175 | 176 | trino: 177 | image: trinodb/trino:351 178 | ports: 179 | - "18080:18080" 180 | volumes: 181 | - ./trino:/etc/trino 182 | 183 | Note: ``./trino`` for docker volume defined above is `trino config from PyHive repository `_ 184 | 185 | Then run:: 186 | docker-compose up -d 187 | 188 | Testing 189 | ******* 190 | .. image:: https://travis-ci.org/dropbox/PyHive.svg 191 | :target: https://travis-ci.org/dropbox/PyHive 192 | .. image:: http://codecov.io/github/dropbox/PyHive/coverage.svg?branch=master 193 | :target: http://codecov.io/github/dropbox/PyHive?branch=master 194 | 195 | Run the following in an environment with Hive/Presto:: 196 | 197 | ./scripts/make_test_tables.sh 198 | virtualenv --no-site-packages env 199 | source env/bin/activate 200 | pip install -e . 201 | pip install -r dev_requirements.txt 202 | py.test 203 | 204 | WARNING: This drops/creates tables named ``one_row``, ``one_row_complex``, and ``many_rows``, plus a 205 | database called ``pyhive_test_database``. 206 | 207 | Updating TCLIService 208 | ******************** 209 | 210 | The TCLIService module is autogenerated using a ``TCLIService.thrift`` file. To update it, the 211 | ``generate.py`` file can be used: ``python generate.py ``. When left blank, the 212 | version for Hive 2.3 will be downloaded. 
213 | -------------------------------------------------------------------------------- /TCLIService/TCLIService-remote: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Autogenerated by Thrift Compiler (0.10.0) 4 | # 5 | # DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING 6 | # 7 | # options string: py 8 | # 9 | 10 | import sys 11 | import pprint 12 | if sys.version_info[0] > 2: 13 | from urllib.parse import urlparse 14 | else: 15 | from urlparse import urlparse 16 | from thrift.transport import TTransport, TSocket, TSSLSocket, THttpClient 17 | from thrift.protocol.TBinaryProtocol import TBinaryProtocol 18 | 19 | from TCLIService import TCLIService 20 | from TCLIService.ttypes import * 21 | 22 | if len(sys.argv) <= 1 or sys.argv[1] == '--help': 23 | print('') 24 | print('Usage: ' + sys.argv[0] + ' [-h host[:port]] [-u url] [-f[ramed]] [-s[sl]] [-novalidate] [-ca_certs certs] [-keyfile keyfile] [-certfile certfile] function [arg1 [arg2...]]') 25 | print('') 26 | print('Functions:') 27 | print(' TOpenSessionResp OpenSession(TOpenSessionReq req)') 28 | print(' TCloseSessionResp CloseSession(TCloseSessionReq req)') 29 | print(' TGetInfoResp GetInfo(TGetInfoReq req)') 30 | print(' TExecuteStatementResp ExecuteStatement(TExecuteStatementReq req)') 31 | print(' TGetTypeInfoResp GetTypeInfo(TGetTypeInfoReq req)') 32 | print(' TGetCatalogsResp GetCatalogs(TGetCatalogsReq req)') 33 | print(' TGetSchemasResp GetSchemas(TGetSchemasReq req)') 34 | print(' TGetTablesResp GetTables(TGetTablesReq req)') 35 | print(' TGetTableTypesResp GetTableTypes(TGetTableTypesReq req)') 36 | print(' TGetColumnsResp GetColumns(TGetColumnsReq req)') 37 | print(' TGetFunctionsResp GetFunctions(TGetFunctionsReq req)') 38 | print(' TGetPrimaryKeysResp GetPrimaryKeys(TGetPrimaryKeysReq req)') 39 | print(' TGetCrossReferenceResp GetCrossReference(TGetCrossReferenceReq req)') 40 | print(' TGetOperationStatusResp 
GetOperationStatus(TGetOperationStatusReq req)') 41 | print(' TCancelOperationResp CancelOperation(TCancelOperationReq req)') 42 | print(' TCloseOperationResp CloseOperation(TCloseOperationReq req)') 43 | print(' TGetResultSetMetadataResp GetResultSetMetadata(TGetResultSetMetadataReq req)') 44 | print(' TFetchResultsResp FetchResults(TFetchResultsReq req)') 45 | print(' TGetDelegationTokenResp GetDelegationToken(TGetDelegationTokenReq req)') 46 | print(' TCancelDelegationTokenResp CancelDelegationToken(TCancelDelegationTokenReq req)') 47 | print(' TRenewDelegationTokenResp RenewDelegationToken(TRenewDelegationTokenReq req)') 48 | print(' TGetLogResp GetLog(TGetLogReq req)') 49 | print('') 50 | sys.exit(0) 51 | 52 | pp = pprint.PrettyPrinter(indent=2) 53 | host = 'localhost' 54 | port = 9090 55 | uri = '' 56 | framed = False 57 | ssl = False 58 | validate = True 59 | ca_certs = None 60 | keyfile = None 61 | certfile = None 62 | http = False 63 | argi = 1 64 | 65 | if sys.argv[argi] == '-h': 66 | parts = sys.argv[argi + 1].split(':') 67 | host = parts[0] 68 | if len(parts) > 1: 69 | port = int(parts[1]) 70 | argi += 2 71 | 72 | if sys.argv[argi] == '-u': 73 | url = urlparse(sys.argv[argi + 1]) 74 | parts = url[1].split(':') 75 | host = parts[0] 76 | if len(parts) > 1: 77 | port = int(parts[1]) 78 | else: 79 | port = 80 80 | uri = url[2] 81 | if url[4]: 82 | uri += '?%s' % url[4] 83 | http = True 84 | argi += 2 85 | 86 | if sys.argv[argi] == '-f' or sys.argv[argi] == '-framed': 87 | framed = True 88 | argi += 1 89 | 90 | if sys.argv[argi] == '-s' or sys.argv[argi] == '-ssl': 91 | ssl = True 92 | argi += 1 93 | 94 | if sys.argv[argi] == '-novalidate': 95 | validate = False 96 | argi += 1 97 | 98 | if sys.argv[argi] == '-ca_certs': 99 | ca_certs = sys.argv[argi+1] 100 | argi += 2 101 | 102 | if sys.argv[argi] == '-keyfile': 103 | keyfile = sys.argv[argi+1] 104 | argi += 2 105 | 106 | if sys.argv[argi] == '-certfile': 107 | certfile = sys.argv[argi+1] 108 | argi += 2 109 
| 110 | cmd = sys.argv[argi] 111 | args = sys.argv[argi + 1:] 112 | 113 | if http: 114 | transport = THttpClient.THttpClient(host, port, uri) 115 | else: 116 | if ssl: 117 | socket = TSSLSocket.TSSLSocket(host, port, validate=validate, ca_certs=ca_certs, keyfile=keyfile, certfile=certfile) 118 | else: 119 | socket = TSocket.TSocket(host, port) 120 | if framed: 121 | transport = TTransport.TFramedTransport(socket) 122 | else: 123 | transport = TTransport.TBufferedTransport(socket) 124 | protocol = TBinaryProtocol(transport) 125 | client = TCLIService.Client(protocol) 126 | transport.open() 127 | 128 | if cmd == 'OpenSession': 129 | if len(args) != 1: 130 | print('OpenSession requires 1 args') 131 | sys.exit(1) 132 | pp.pprint(client.OpenSession(eval(args[0]),)) 133 | 134 | elif cmd == 'CloseSession': 135 | if len(args) != 1: 136 | print('CloseSession requires 1 args') 137 | sys.exit(1) 138 | pp.pprint(client.CloseSession(eval(args[0]),)) 139 | 140 | elif cmd == 'GetInfo': 141 | if len(args) != 1: 142 | print('GetInfo requires 1 args') 143 | sys.exit(1) 144 | pp.pprint(client.GetInfo(eval(args[0]),)) 145 | 146 | elif cmd == 'ExecuteStatement': 147 | if len(args) != 1: 148 | print('ExecuteStatement requires 1 args') 149 | sys.exit(1) 150 | pp.pprint(client.ExecuteStatement(eval(args[0]),)) 151 | 152 | elif cmd == 'GetTypeInfo': 153 | if len(args) != 1: 154 | print('GetTypeInfo requires 1 args') 155 | sys.exit(1) 156 | pp.pprint(client.GetTypeInfo(eval(args[0]),)) 157 | 158 | elif cmd == 'GetCatalogs': 159 | if len(args) != 1: 160 | print('GetCatalogs requires 1 args') 161 | sys.exit(1) 162 | pp.pprint(client.GetCatalogs(eval(args[0]),)) 163 | 164 | elif cmd == 'GetSchemas': 165 | if len(args) != 1: 166 | print('GetSchemas requires 1 args') 167 | sys.exit(1) 168 | pp.pprint(client.GetSchemas(eval(args[0]),)) 169 | 170 | elif cmd == 'GetTables': 171 | if len(args) != 1: 172 | print('GetTables requires 1 args') 173 | sys.exit(1) 174 | 
pp.pprint(client.GetTables(eval(args[0]),)) 175 | 176 | elif cmd == 'GetTableTypes': 177 | if len(args) != 1: 178 | print('GetTableTypes requires 1 args') 179 | sys.exit(1) 180 | pp.pprint(client.GetTableTypes(eval(args[0]),)) 181 | 182 | elif cmd == 'GetColumns': 183 | if len(args) != 1: 184 | print('GetColumns requires 1 args') 185 | sys.exit(1) 186 | pp.pprint(client.GetColumns(eval(args[0]),)) 187 | 188 | elif cmd == 'GetFunctions': 189 | if len(args) != 1: 190 | print('GetFunctions requires 1 args') 191 | sys.exit(1) 192 | pp.pprint(client.GetFunctions(eval(args[0]),)) 193 | 194 | elif cmd == 'GetPrimaryKeys': 195 | if len(args) != 1: 196 | print('GetPrimaryKeys requires 1 args') 197 | sys.exit(1) 198 | pp.pprint(client.GetPrimaryKeys(eval(args[0]),)) 199 | 200 | elif cmd == 'GetCrossReference': 201 | if len(args) != 1: 202 | print('GetCrossReference requires 1 args') 203 | sys.exit(1) 204 | pp.pprint(client.GetCrossReference(eval(args[0]),)) 205 | 206 | elif cmd == 'GetOperationStatus': 207 | if len(args) != 1: 208 | print('GetOperationStatus requires 1 args') 209 | sys.exit(1) 210 | pp.pprint(client.GetOperationStatus(eval(args[0]),)) 211 | 212 | elif cmd == 'CancelOperation': 213 | if len(args) != 1: 214 | print('CancelOperation requires 1 args') 215 | sys.exit(1) 216 | pp.pprint(client.CancelOperation(eval(args[0]),)) 217 | 218 | elif cmd == 'CloseOperation': 219 | if len(args) != 1: 220 | print('CloseOperation requires 1 args') 221 | sys.exit(1) 222 | pp.pprint(client.CloseOperation(eval(args[0]),)) 223 | 224 | elif cmd == 'GetResultSetMetadata': 225 | if len(args) != 1: 226 | print('GetResultSetMetadata requires 1 args') 227 | sys.exit(1) 228 | pp.pprint(client.GetResultSetMetadata(eval(args[0]),)) 229 | 230 | elif cmd == 'FetchResults': 231 | if len(args) != 1: 232 | print('FetchResults requires 1 args') 233 | sys.exit(1) 234 | pp.pprint(client.FetchResults(eval(args[0]),)) 235 | 236 | elif cmd == 'GetDelegationToken': 237 | if len(args) != 1: 238 | 
print('GetDelegationToken requires 1 args') 239 | sys.exit(1) 240 | pp.pprint(client.GetDelegationToken(eval(args[0]),)) 241 | 242 | elif cmd == 'CancelDelegationToken': 243 | if len(args) != 1: 244 | print('CancelDelegationToken requires 1 args') 245 | sys.exit(1) 246 | pp.pprint(client.CancelDelegationToken(eval(args[0]),)) 247 | 248 | elif cmd == 'RenewDelegationToken': 249 | if len(args) != 1: 250 | print('RenewDelegationToken requires 1 args') 251 | sys.exit(1) 252 | pp.pprint(client.RenewDelegationToken(eval(args[0]),)) 253 | 254 | elif cmd == 'GetLog': 255 | if len(args) != 1: 256 | print('GetLog requires 1 args') 257 | sys.exit(1) 258 | pp.pprint(client.GetLog(eval(args[0]),)) 259 | 260 | else: 261 | print('Unrecognized method %s' % cmd) 262 | sys.exit(1) 263 | 264 | transport.close() 265 | -------------------------------------------------------------------------------- /TCLIService/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['ttypes', 'constants', 'TCLIService'] 2 | -------------------------------------------------------------------------------- /TCLIService/constants.py: -------------------------------------------------------------------------------- 1 | # 2 | # Autogenerated by Thrift Compiler (0.10.0) 3 | # 4 | # DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING 5 | # 6 | # options string: py 7 | # 8 | 9 | from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException 10 | from thrift.protocol.TProtocol import TProtocolException 11 | import sys 12 | from .ttypes import * 13 | PRIMITIVE_TYPES = set(( 14 | 0, 15 | 1, 16 | 2, 17 | 3, 18 | 4, 19 | 5, 20 | 6, 21 | 7, 22 | 8, 23 | 9, 24 | 15, 25 | 16, 26 | 17, 27 | 18, 28 | 19, 29 | 20, 30 | 21, 31 | )) 32 | COMPLEX_TYPES = set(( 33 | 10, 34 | 11, 35 | 12, 36 | 13, 37 | 14, 38 | )) 39 | COLLECTION_TYPES = set(( 40 | 10, 41 | 11, 42 | )) 43 | TYPE_NAMES = { 44 | 0: "BOOLEAN", 45 | 1: "TINYINT", 46 | 
2: "SMALLINT", 47 | 3: "INT", 48 | 4: "BIGINT", 49 | 5: "FLOAT", 50 | 6: "DOUBLE", 51 | 7: "STRING", 52 | 8: "TIMESTAMP", 53 | 9: "BINARY", 54 | 10: "ARRAY", 55 | 11: "MAP", 56 | 12: "STRUCT", 57 | 13: "UNIONTYPE", 58 | 15: "DECIMAL", 59 | 16: "NULL", 60 | 17: "DATE", 61 | 18: "VARCHAR", 62 | 19: "CHAR", 63 | 20: "INTERVAL_YEAR_MONTH", 64 | 21: "INTERVAL_DAY_TIME", 65 | } 66 | CHARACTER_MAXIMUM_LENGTH = "characterMaximumLength" 67 | PRECISION = "precision" 68 | SCALE = "scale" 69 | -------------------------------------------------------------------------------- /dev_requirements.txt: -------------------------------------------------------------------------------- 1 | # test-only packages: pin everything to minimize change 2 | flake8==3.4.1 3 | mock==2.0.0 4 | pycodestyle==2.3.1 5 | pytest==3.2.1 6 | pytest-cov==2.5.1 7 | pytest-flake8==0.8.1 8 | pytest-random==0.2 9 | pytest-timeout==1.2.0 10 | 11 | # actual dependencies: let things break if a package changes 12 | requests>=1.0.0 13 | requests_kerberos>=0.12.0 14 | sasl>=0.2.1 15 | pure-sasl>=0.6.2 16 | kerberos>=1.3.0 17 | thrift>=0.10.0 18 | #thrift_sasl>=0.1.0 19 | git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches 20 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file can be used to generate a new version of the TCLIService package 3 | using a TCLIService.thrift URL. 4 | 5 | If no URL is specified, the file for Hive 2.3 will be downloaded. 
6 | 7 | Usage: 8 | 9 | python generate.py THRIFT_URL 10 | 11 | or 12 | 13 | python generate.py 14 | """ 15 | import shutil 16 | import sys 17 | from os import path 18 | from urllib.request import urlopen 19 | import subprocess 20 | 21 | here = path.abspath(path.dirname(__file__)) 22 | 23 | PACKAGE = 'TCLIService' 24 | GENERATED = 'gen-py' 25 | 26 | HIVE_SERVER2_URL = \ 27 | 'https://raw.githubusercontent.com/apache/hive/branch-2.3/service-rpc/if/TCLIService.thrift' 28 | 29 | 30 | def save_url(url): 31 | data = urlopen(url).read() 32 | file_path = path.join(here, url.rsplit('/', 1)[-1]) 33 | with open(file_path, 'wb') as f: 34 | f.write(data) 35 | 36 | 37 | def main(hive_server2_url): 38 | save_url(hive_server2_url) 39 | hive_server2_path = path.join(here, hive_server2_url.rsplit('/', 1)[-1]) 40 | 41 | subprocess.call(['thrift', '-r', '--gen', 'py', hive_server2_path]) 42 | shutil.move(path.join(here, PACKAGE), path.join(here, PACKAGE + '.old')) 43 | shutil.move(path.join(here, GENERATED, PACKAGE), path.join(here, PACKAGE)) 44 | shutil.rmtree(path.join(here, PACKAGE + '.old')) 45 | 46 | 47 | if __name__ == '__main__': 48 | if len(sys.argv) > 1: 49 | url = sys.argv[1] 50 | else: 51 | url = HIVE_SERVER2_URL 52 | main(url) 53 | -------------------------------------------------------------------------------- /pyhive/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import unicode_literals 3 | __version__ = '0.7.0' 4 | -------------------------------------------------------------------------------- /pyhive/common.py: -------------------------------------------------------------------------------- 1 | """Package private common utilities. Do not use directly. 2 | 3 | Many docstrings in this file are based on PEP-249, which is in the public domain. 
4 | """ 5 | 6 | from __future__ import absolute_import 7 | from __future__ import unicode_literals 8 | from builtins import bytes 9 | from builtins import int 10 | from builtins import object 11 | from builtins import str 12 | from past.builtins import basestring 13 | from pyhive import exc 14 | import abc 15 | import collections 16 | import time 17 | import datetime 18 | from future.utils import with_metaclass 19 | from itertools import islice 20 | 21 | try: 22 | from collections.abc import Iterable 23 | except ImportError: 24 | from collections import Iterable 25 | 26 | 27 | class DBAPICursor(with_metaclass(abc.ABCMeta, object)): 28 | """Base class for some common DB-API logic""" 29 | 30 | _STATE_NONE = 0 31 | _STATE_RUNNING = 1 32 | _STATE_FINISHED = 2 33 | 34 | def __init__(self, poll_interval=1): 35 | self._poll_interval = poll_interval 36 | self._reset_state() 37 | self.lastrowid = None 38 | 39 | def _reset_state(self): 40 | """Reset state about the previous query in preparation for running another query""" 41 | # State to return as part of DB-API 42 | self._rownumber = 0 43 | 44 | # Internal helper state 45 | self._state = self._STATE_NONE 46 | self._data = collections.deque() 47 | self._columns = None 48 | 49 | def _fetch_while(self, fn): 50 | while fn(): 51 | self._fetch_more() 52 | if fn(): 53 | time.sleep(self._poll_interval) 54 | 55 | @abc.abstractproperty 56 | def description(self): 57 | raise NotImplementedError # pragma: no cover 58 | 59 | def close(self): 60 | """By default, do nothing""" 61 | pass 62 | 63 | @abc.abstractmethod 64 | def _fetch_more(self): 65 | """Get more results, append it to ``self._data``, and update ``self._state``.""" 66 | raise NotImplementedError # pragma: no cover 67 | 68 | @property 69 | def rowcount(self): 70 | """By default, return -1 to indicate that this is not supported.""" 71 | return -1 72 | 73 | @abc.abstractmethod 74 | def execute(self, operation, parameters=None): 75 | """Prepare and execute a database operation 
(query or command). 76 | 77 | Parameters may be provided as sequence or mapping and will be bound to variables in the 78 | operation. Variables are specified in a database-specific notation (see the module's 79 | ``paramstyle`` attribute for details). 80 | 81 | Return values are not defined. 82 | """ 83 | raise NotImplementedError # pragma: no cover 84 | 85 | def executemany(self, operation, seq_of_parameters): 86 | """Prepare a database operation (query or command) and then execute it against all parameter 87 | sequences or mappings found in the sequence ``seq_of_parameters``. 88 | 89 | Only the final result set is retained. 90 | 91 | Return values are not defined. 92 | """ 93 | for parameters in seq_of_parameters[:-1]: 94 | self.execute(operation, parameters) 95 | while self._state != self._STATE_FINISHED: 96 | self._fetch_more() 97 | if seq_of_parameters: 98 | self.execute(operation, seq_of_parameters[-1]) 99 | 100 | def fetchone(self): 101 | """Fetch the next row of a query result set, returning a single sequence, or ``None`` when 102 | no more data is available. 103 | 104 | An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to 105 | :py:meth:`execute` did not produce any result set or no call was issued yet. 106 | """ 107 | if self._state == self._STATE_NONE: 108 | raise exc.ProgrammingError("No query yet") 109 | 110 | # Sleep until we're done or we have some data to return 111 | self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED) 112 | 113 | if not self._data: 114 | return None 115 | else: 116 | self._rownumber += 1 117 | return self._data.popleft() 118 | 119 | def fetchmany(self, size=None): 120 | """Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a 121 | list of tuples). An empty sequence is returned when no more rows are available. 122 | 123 | The number of rows to fetch per call is specified by the parameter. 
If it is not given, the 124 | cursor's arraysize determines the number of rows to be fetched. The method should try to 125 | fetch as many rows as indicated by the size parameter. If this is not possible due to the 126 | specified number of rows not being available, fewer rows may be returned. 127 | 128 | An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to 129 | :py:meth:`execute` did not produce any result set or no call was issued yet. 130 | """ 131 | if size is None: 132 | size = self.arraysize 133 | return list(islice(iter(self.fetchone, None), size)) 134 | 135 | def fetchall(self): 136 | """Fetch all (remaining) rows of a query result, returning them as a sequence of sequences 137 | (e.g. a list of tuples). 138 | 139 | An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to 140 | :py:meth:`execute` did not produce any result set or no call was issued yet. 141 | """ 142 | return list(iter(self.fetchone, None)) 143 | 144 | @property 145 | def arraysize(self): 146 | """This read/write attribute specifies the number of rows to fetch at a time with 147 | :py:meth:`fetchmany`. It defaults to 1 meaning to fetch a single row at a time. 148 | """ 149 | return self._arraysize 150 | 151 | @arraysize.setter 152 | def arraysize(self, value): 153 | self._arraysize = value 154 | 155 | def setinputsizes(self, sizes): 156 | """Does nothing by default""" 157 | pass 158 | 159 | def setoutputsize(self, size, column=None): 160 | """Does nothing by default""" 161 | pass 162 | 163 | # 164 | # Optional DB API Extensions 165 | # 166 | 167 | @property 168 | def rownumber(self): 169 | """This read-only attribute should provide the current 0-based index of the cursor in the 170 | result set. 171 | 172 | The index can be seen as index of the cursor in a sequence (the result set). The next fetch 173 | operation will fetch the row indexed by ``rownumber`` in that sequence. 
174 | """ 175 | return self._rownumber 176 | 177 | def __next__(self): 178 | """Return the next row from the currently executing SQL statement using the same semantics 179 | as :py:meth:`fetchone`. A ``StopIteration`` exception is raised when the result set is 180 | exhausted. 181 | """ 182 | one = self.fetchone() 183 | if one is None: 184 | raise StopIteration 185 | else: 186 | return one 187 | 188 | next = __next__ 189 | 190 | def __iter__(self): 191 | """Return self to make cursors compatible to the iteration protocol.""" 192 | return self 193 | 194 | 195 | class DBAPITypeObject(object): 196 | # Taken from http://www.python.org/dev/peps/pep-0249/#implementation-hints 197 | def __init__(self, *values): 198 | self.values = values 199 | 200 | def __cmp__(self, other): 201 | if other in self.values: 202 | return 0 203 | if other < self.values: 204 | return 1 205 | else: 206 | return -1 207 | 208 | 209 | class ParamEscaper(object): 210 | _DATE_FORMAT = "%Y-%m-%d" 211 | _TIME_FORMAT = "%H:%M:%S.%f" 212 | _DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT) 213 | 214 | def escape_args(self, parameters): 215 | if isinstance(parameters, dict): 216 | return {k: self.escape_item(v) for k, v in parameters.items()} 217 | elif isinstance(parameters, (list, tuple)): 218 | return tuple(self.escape_item(x) for x in parameters) 219 | else: 220 | raise exc.ProgrammingError("Unsupported param format: {}".format(parameters)) 221 | 222 | def escape_number(self, item): 223 | return item 224 | 225 | def escape_string(self, item): 226 | # Need to decode UTF-8 because of old sqlalchemy. 227 | # Newer SQLAlchemy checks dialect.supports_unicode_binds before encoding Unicode strings 228 | # as byte strings. The old version always encodes Unicode as byte strings, which breaks 229 | # string formatting here. 
230 | if isinstance(item, bytes): 231 | item = item.decode('utf-8') 232 | # This is good enough when backslashes are literal, newlines are just followed, and the way 233 | # to escape a single quote is to put two single quotes. 234 | # (i.e. only special character is single quote) 235 | return "'{}'".format(item.replace("'", "''")) 236 | 237 | def escape_sequence(self, item): 238 | l = map(str, map(self.escape_item, item)) 239 | return '(' + ','.join(l) + ')' 240 | 241 | def escape_datetime(self, item, format, cutoff=0): 242 | dt_str = item.strftime(format) 243 | formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str 244 | return "'{}'".format(formatted) 245 | 246 | def escape_item(self, item): 247 | if item is None: 248 | return 'NULL' 249 | elif isinstance(item, (int, float)): 250 | return self.escape_number(item) 251 | elif isinstance(item, basestring): 252 | return self.escape_string(item) 253 | elif isinstance(item, Iterable): 254 | return self.escape_sequence(item) 255 | elif isinstance(item, datetime.datetime): 256 | return self.escape_datetime(item, self._DATETIME_FORMAT) 257 | elif isinstance(item, datetime.date): 258 | return self.escape_datetime(item, self._DATE_FORMAT) 259 | else: 260 | raise exc.ProgrammingError("Unsupported object {}".format(item)) 261 | 262 | 263 | class UniversalSet(object): 264 | """set containing everything""" 265 | def __contains__(self, item): 266 | return True 267 | -------------------------------------------------------------------------------- /pyhive/exc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Package private common utilities. Do not use directly. 
3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import unicode_literals 6 | 7 | __all__ = [ 8 | 'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'InternalError', 'OperationalError', 9 | 'ProgrammingError', 'DataError', 'NotSupportedError', 10 | ] 11 | 12 | 13 | class Error(Exception): 14 | """Exception that is the base class of all other error exceptions. 15 | 16 | You can use this to catch all errors with one single except statement. 17 | """ 18 | pass 19 | 20 | 21 | class Warning(Exception): 22 | """Exception raised for important warnings like data truncations while inserting, etc.""" 23 | pass 24 | 25 | 26 | class InterfaceError(Error): 27 | """Exception raised for errors that are related to the database interface rather than the 28 | database itself. 29 | """ 30 | pass 31 | 32 | 33 | class DatabaseError(Error): 34 | """Exception raised for errors that are related to the database.""" 35 | pass 36 | 37 | 38 | class InternalError(DatabaseError): 39 | """Exception raised when the database encounters an internal error, e.g. the cursor is not valid 40 | anymore, the transaction is out of sync, etc.""" 41 | pass 42 | 43 | 44 | class OperationalError(DatabaseError): 45 | """Exception raised for errors that are related to the database's operation and not necessarily 46 | under the control of the programmer, e.g. an unexpected disconnect occurs, the data source name 47 | is not found, a transaction could not be processed, a memory allocation error occurred during 48 | processing, etc. 49 | """ 50 | pass 51 | 52 | 53 | class ProgrammingError(DatabaseError): 54 | """Exception raised for programming errors, e.g. table not found or already exists, syntax error 55 | in the SQL statement, wrong number of parameters specified, etc. 56 | """ 57 | pass 58 | 59 | 60 | class DataError(DatabaseError): 61 | """Exception raised for errors that are due to problems with the processed data like division by 62 | zero, numeric value out of range, etc. 
63 | """ 64 | pass 65 | 66 | 67 | class NotSupportedError(DatabaseError): 68 | """Exception raised in case a method or database API was used which is not supported by the 69 | database, e.g. requesting a ``.rollback()`` on a connection that does not support transaction or 70 | has transactions turned off. 71 | """ 72 | pass 73 | -------------------------------------------------------------------------------- /pyhive/hive.py: -------------------------------------------------------------------------------- 1 | """DB-API implementation backed by HiveServer2 (Thrift API) 2 | 3 | See http://www.python.org/dev/peps/pep-0249/ 4 | 5 | Many docstrings in this file are based on the PEP, which is in the public domain. 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import unicode_literals 10 | 11 | import base64 12 | import datetime 13 | import re 14 | from decimal import Decimal 15 | from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context 16 | 17 | 18 | from TCLIService import TCLIService 19 | from TCLIService import constants 20 | from TCLIService import ttypes 21 | from pyhive import common 22 | from pyhive.common import DBAPITypeObject 23 | # Make all exceptions visible in this module per DB-API 24 | from pyhive.exc import * # noqa 25 | from builtins import range 26 | import contextlib 27 | from future.utils import iteritems 28 | import getpass 29 | import logging 30 | import sys 31 | import thrift.transport.THttpClient 32 | import thrift.protocol.TBinaryProtocol 33 | import thrift.transport.TSocket 34 | import thrift.transport.TTransport 35 | 36 | # PEP 249 module globals 37 | apilevel = '2.0' 38 | threadsafety = 2 # Threads may share the module and connections. 39 | paramstyle = 'pyformat' # Python extended format codes, e.g. 
...WHERE name=%(name)s 40 | 41 | _logger = logging.getLogger(__name__) 42 | 43 | _TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)') 44 | 45 | ssl_cert_parameter_map = { 46 | "none": CERT_NONE, 47 | "optional": CERT_OPTIONAL, 48 | "required": CERT_REQUIRED, 49 | } 50 | 51 | 52 | def get_sasl_client(host, sasl_auth, service=None, username=None, password=None): 53 | import sasl 54 | sasl_client = sasl.Client() 55 | sasl_client.setAttr('host', host) 56 | 57 | if sasl_auth == 'GSSAPI': 58 | sasl_client.setAttr('service', service) 59 | elif sasl_auth == 'PLAIN': 60 | sasl_client.setAttr('username', username) 61 | sasl_client.setAttr('password', password) 62 | else: 63 | raise ValueError("sasl_auth only supports GSSAPI and PLAIN") 64 | 65 | sasl_client.init() 66 | return sasl_client 67 | 68 | 69 | def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None): 70 | from pyhive.sasl_compat import PureSASLClient 71 | 72 | if sasl_auth == 'GSSAPI': 73 | sasl_kwargs = {'service': service} 74 | elif sasl_auth == 'PLAIN': 75 | sasl_kwargs = {'username': username, 'password': password} 76 | else: 77 | raise ValueError("sasl_auth only supports GSSAPI and PLAIN") 78 | 79 | return PureSASLClient(host=host, **sasl_kwargs) 80 | 81 | 82 | def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None): 83 | try: 84 | return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) 85 | # The sasl library is available 86 | except ImportError: 87 | # Fallback to pure-sasl library 88 | return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) 89 | 90 | 91 | def _parse_timestamp(value): 92 | if value: 93 | match = _TIMESTAMP_PATTERN.match(value) 94 | if match: 95 | if match.group(2): 96 | format = '%Y-%m-%d %H:%M:%S.%f' 97 | # use the pattern to truncate the value 98 | value = match.group() 99 | else: 100 | format = 
# Converters applied to raw column values, keyed by Thrift type-code name.
TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal,
                   "TIMESTAMP_TYPE": _parse_timestamp}


class HiveParamEscaper(common.ParamEscaper):
    """Escapes parameters using HiveQL string-literal rules (backslash escapes)."""

    def escape_string(self, item):
        # backslashes and single quotes need to be escaped
        # TODO verify against parser
        # Need to decode UTF-8 because of old sqlalchemy.
        # Newer SQLAlchemy checks dialect.supports_unicode_binds before encoding Unicode
        # strings as byte strings. The old version always encodes Unicode as byte strings,
        # which breaks string formatting here.
        if isinstance(item, bytes):
            item = item.decode('utf-8')
        return "'{}'".format(
            item
            .replace('\\', '\\\\')
            .replace("'", "\\'")
            .replace('\r', '\\r')
            .replace('\n', '\\n')
            .replace('\t', '\\t')
        )


_escaper = HiveParamEscaper()


def connect(*args, **kwargs):
    """Constructor for creating a connection to the database. See class
    :py:class:`Connection` for arguments.

    :returns: a :py:class:`Connection` object.
    """
    return Connection(*args, **kwargs)


class Connection(object):
    """Wraps a Thrift session"""

    def __init__(
        self,
        host=None,
        port=None,
        scheme=None,
        username=None,
        database='default',
        auth=None,
        configuration=None,
        kerberos_service_name=None,
        password=None,
        check_hostname=None,
        ssl_cert=None,
        thrift_transport=None
    ):
        """Connect to HiveServer2

        :param host: What host HiveServer2 runs on
        :param port: What port HiveServer2 runs on. Defaults to 10000.
        :param auth: The value of hive.server2.authentication used by HiveServer2.
            Defaults to ``NONE``.
        :param configuration: A dictionary of Hive settings (functionally same as the `set`
            command)
        :param kerberos_service_name: Use with auth='KERBEROS' only
        :param password: Use with auth='LDAP' or auth='CUSTOM' only
        :param thrift_transport: A ``TTransportBase`` for custom advanced usage.
            Incompatible with host, port, auth, kerberos_service_name, and password.

        The way to support LDAP and GSSAPI is originated from cloudera/Impyla:
        https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62
        /impala/_thrift_api.py#L152-L160
        """
        if scheme in ("https", "http") and thrift_transport is None:
            # BUG FIX: this previously read ``port or 1000`` — a typo dropping a zero.
            # The documented default (see the docstring above and the socket path below)
            # is 10000.
            port = port or 10000
            ssl_context = None
            if scheme == "https":
                ssl_context = create_default_context()
                ssl_context.check_hostname = check_hostname == "true"
                ssl_cert = ssl_cert or "none"
                ssl_context.verify_mode = ssl_cert_parameter_map.get(ssl_cert, CERT_NONE)
            thrift_transport = thrift.transport.THttpClient.THttpClient(
                uri_or_host="{scheme}://{host}:{port}/cliservice/".format(
                    scheme=scheme, host=host, port=port
                ),
                ssl_context=ssl_context,
            )

            if auth in ("BASIC", "NOSASL", "NONE", None):
                # Always needs the Authorization header
                self._set_authorization_header(thrift_transport, username, password)
            elif auth == "KERBEROS" and kerberos_service_name:
                self._set_kerberos_header(thrift_transport, kerberos_service_name, host)
            else:
                raise ValueError(
                    "Authentication is not valid use one of:"
                    "BASIC, NOSASL, KERBEROS, NONE"
                )
            # The transport now fully encodes connection parameters; clear them so the
            # validation below treats this as a plain thrift_transport connection.
            host, port, auth, kerberos_service_name, password = (
                None, None, None, None, None
            )

        username = username or getpass.getuser()
        configuration = configuration or {}

        if (password is not None) != (auth in ('LDAP', 'CUSTOM')):
            raise ValueError("Password should be set if and only if in LDAP or CUSTOM mode; "
                             "Remove password or use one of those modes")
        if (kerberos_service_name is not None) != (auth == 'KERBEROS'):
            raise ValueError("kerberos_service_name should be set if and only if in KERBEROS mode")
        if thrift_transport is not None:
            has_incompatible_arg = (
                host is not None
                or port is not None
                or auth is not None
                or kerberos_service_name is not None
                or password is not None
            )
            if has_incompatible_arg:
                raise ValueError("thrift_transport cannot be used with "
                                 "host/port/auth/kerberos_service_name/password")

        if thrift_transport is not None:
            self._transport = thrift_transport
        else:
            if port is None:
                port = 10000
            if auth is None:
                auth = 'NONE'
            socket = thrift.transport.TSocket.TSocket(host, port)
            if auth == 'NOSASL':
                # NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml
                self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
            elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
                # Defer import so package dependency is optional
                import thrift_sasl

                if auth == 'KERBEROS':
                    # KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
                    sasl_auth = 'GSSAPI'
                else:
                    sasl_auth = 'PLAIN'
                    if password is None:
                        # Password doesn't matter in NONE mode, just needs to be nonempty.
                        password = 'x'

                self._transport = thrift_sasl.TSaslClientTransport(
                    lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth,
                                               service=kerberos_service_name,
                                               username=username, password=password),
                    sasl_auth, socket)
            else:
                # All HS2 config options:
                # https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration
                # PAM currently left to end user via thrift_transport option.
                raise NotImplementedError(
                    "Only NONE, NOSASL, LDAP, KERBEROS, CUSTOM "
                    "authentication are supported, got {}".format(auth))

        protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
        self._client = TCLIService.Client(protocol)
        # oldest version that still contains features we care about
        # "V6 uses binary type for binary payload (was string) and uses columnar result set"
        protocol_version = ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6

        try:
            self._transport.open()
            open_session_req = ttypes.TOpenSessionReq(
                client_protocol=protocol_version,
                configuration=configuration,
                username=username,
            )
            response = self._client.OpenSession(open_session_req)
            _check_status(response)
            assert response.sessionHandle is not None, "Expected a session from OpenSession"
            self._sessionHandle = response.sessionHandle
            assert response.serverProtocolVersion == protocol_version, \
                "Unable to handle protocol version {}".format(response.serverProtocolVersion)
            with contextlib.closing(self.cursor()) as cursor:
                cursor.execute('USE `{}`'.format(database))
        except:
            # Don't leak an open transport when session setup fails.
            self._transport.close()
            raise

    @staticmethod
    def _set_authorization_header(transport, username=None, password=None):
        """Attach an HTTP Basic ``Authorization`` header to *transport*.

        Placeholder credentials are used when none are supplied, since the header must
        always be present.
        """
        username = username or "user"
        password = password or "pass"
        auth_credentials = "{username}:{password}".format(
            username=username, password=password
        ).encode("UTF-8")
        auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode(
            "UTF-8"
        )
        transport.setCustomHeaders(
            {
                "Authorization": "Basic {auth_credentials_base64}".format(
                    auth_credentials_base64=auth_credentials_base64
                )
            }
        )

    @staticmethod
    def _set_kerberos_header(transport, kerberos_service_name, host):
        """Attach an HTTP ``Negotiate`` header obtained via GSSAPI to *transport*."""
        import kerberos

        __, krb_context = kerberos.authGSSClientInit(
            service="{kerberos_service_name}@{host}".format(
                kerberos_service_name=kerberos_service_name, host=host
            )
        )
        kerberos.authGSSClientClean(krb_context, "")
        kerberos.authGSSClientStep(krb_context, "")
        auth_header = kerberos.authGSSClientResponse(krb_context)

        transport.setCustomHeaders(
            {
                "Authorization": "Negotiate {auth_header}".format(
                    auth_header=auth_header
                )
            }
        )

    def __enter__(self):
        """Transport should already be opened by __init__"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Call close"""
        self.close()

    def close(self):
        """Close the underlying session and Thrift transport"""
        req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle)
        response = self._client.CloseSession(req)
        self._transport.close()
        _check_status(response)

    def commit(self):
        """Hive does not support transactions, so this does nothing."""
        pass

    def cursor(self, *args, **kwargs):
        """Return a new :py:class:`Cursor` object using the connection."""
        return Cursor(self, *args, **kwargs)

    @property
    def client(self):
        # The raw TCLIService client; used by Cursor.
        return self._client

    @property
    def sessionHandle(self):
        # The Thrift session handle; used by Cursor when executing statements.
        return self._sessionHandle

    def rollback(self):
        raise NotSupportedError("Hive does not have transactions")  # pragma: no cover
class Cursor(common.DBAPICursor):
    """These objects represent a database cursor, which is used to manage the context of a
    fetch operation.

    Cursors are not isolated, i.e., any changes done to the database by a cursor are
    immediately visible by other cursors or connections.
    """

    def __init__(self, connection, arraysize=1000):
        self._operationHandle = None
        super(Cursor, self).__init__()
        self._arraysize = arraysize
        self._connection = connection

    def _reset_state(self):
        """Reset state about the previous query in preparation for running another query"""
        super(Cursor, self)._reset_state()
        self._description = None
        if self._operationHandle is not None:
            request = ttypes.TCloseOperationReq(self._operationHandle)
            try:
                response = self._connection.client.CloseOperation(request)
                _check_status(response)
            finally:
                # Drop the handle even if closing it server-side failed.
                self._operationHandle = None

    @property
    def arraysize(self):
        return self._arraysize

    @arraysize.setter
    def arraysize(self, value):
        """Array size cannot be None, and should be an integer"""
        default_arraysize = 1000
        try:
            self._arraysize = int(value) or default_arraysize
        except TypeError:
            self._arraysize = default_arraysize

    @property
    def description(self):
        """This read-only attribute is a sequence of 7-item sequences.

        Each of these sequences contains information describing one result column:

        - name
        - type_code
        - display_size (None in current implementation)
        - internal_size (None in current implementation)
        - precision (None in current implementation)
        - scale (None in current implementation)
        - null_ok (always True in current implementation)

        This attribute will be ``None`` for operations that do not return rows or if the
        cursor has not had an operation invoked via the :py:meth:`execute` method yet.

        The ``type_code`` can be interpreted by comparing it to the Type Objects specified
        in the section below.
        """
        if self._operationHandle is None or not self._operationHandle.hasResultSet:
            return None
        if self._description is None:
            req = ttypes.TGetResultSetMetadataReq(self._operationHandle)
            response = self._connection.client.GetResultSetMetadata(req)
            _check_status(response)
            columns = response.schema.columns
            self._description = []
            for col in columns:
                primary_type_entry = col.typeDesc.types[0]
                if primary_type_entry.primitiveEntry is None:
                    # All fancy stuff maps to string
                    type_code = ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.STRING_TYPE]
                else:
                    type_id = primary_type_entry.primitiveEntry.type
                    type_code = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
                self._description.append((
                    col.columnName.decode('utf-8') if sys.version_info[0] == 2 else col.columnName,
                    type_code.decode('utf-8') if sys.version_info[0] == 2 else type_code,
                    None, None, None, None, True
                ))
        return self._description

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the operation handle"""
        self._reset_state()

    def execute(self, operation, parameters=None, **kwargs):
        """Prepare and execute a database operation (query or command).

        Return values are not defined.
        """
        # backward compatibility with Python < 3.7, where ``async`` became a keyword
        for kw in ['async', 'async_']:
            if kw in kwargs:
                async_ = kwargs[kw]
                break
        else:
            async_ = False

        # Prepare statement
        if parameters is None:
            sql = operation
        else:
            sql = operation % _escaper.escape_args(parameters)

        self._reset_state()

        self._state = self._STATE_RUNNING
        _logger.info('%s', sql)

        req = ttypes.TExecuteStatementReq(self._connection.sessionHandle,
                                          sql, runAsync=async_)
        _logger.debug(req)
        response = self._connection.client.ExecuteStatement(req)
        _check_status(response)
        self._operationHandle = response.operationHandle

    def cancel(self):
        """Ask the server to cancel the running operation."""
        req = ttypes.TCancelOperationReq(
            operationHandle=self._operationHandle,
        )
        response = self._connection.client.CancelOperation(req)
        _check_status(response)

    def _fetch_more(self):
        """Send another TFetchResultsReq and update state"""
        assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more"
        assert(self._operationHandle is not None), "Should have an op handle in _fetch_more"
        if not self._operationHandle.hasResultSet:
            raise ProgrammingError("No result set")
        req = ttypes.TFetchResultsReq(
            operationHandle=self._operationHandle,
            orientation=ttypes.TFetchOrientation.FETCH_NEXT,
            maxRows=self.arraysize,
        )
        response = self._connection.client.FetchResults(req)
        _check_status(response)
        schema = self.description
        assert not response.results.rows, 'expected data in columnar format'
        columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
                   zip(response.results.columns, schema)]
        new_data = list(zip(*columns))
        self._data += new_data
        # response.hasMoreRows seems to always be False, so we instead check the number of
        # rows:
        # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
        # if not response.hasMoreRows:
        if not new_data:
            self._state = self._STATE_FINISHED

    def poll(self, get_progress_update=True):
        """Poll for and return the raw status data provided by the Hive Thrift REST API.

        :returns: ``ttypes.TGetOperationStatusResp``
        :raises: ``ProgrammingError`` when no query has been started
        .. note::
            This is not a part of DB-API.
        """
        if self._state == self._STATE_NONE:
            raise ProgrammingError("No query yet")

        req = ttypes.TGetOperationStatusReq(
            operationHandle=self._operationHandle,
            getProgressUpdate=get_progress_update,
        )
        response = self._connection.client.GetOperationStatus(req)
        _check_status(response)

        return response

    def fetch_logs(self):
        """Retrieve the logs produced by the execution of the query.

        Can be called multiple times to fetch the logs produced after the previous call.
        :returns: list
        :raises: ``ProgrammingError`` when no query has been started
        .. note::
            This is not a part of DB-API.
        """
        if self._state == self._STATE_NONE:
            raise ProgrammingError("No query yet")

        try:  # Older Hive instances require logs to be retrieved using GetLog
            req = ttypes.TGetLogReq(operationHandle=self._operationHandle)
            logs = self._connection.client.GetLog(req).log.splitlines()
        except ttypes.TApplicationException as e:  # Otherwise, retrieve logs using newer method
            if e.type != ttypes.TApplicationException.UNKNOWN_METHOD:
                raise
            logs = []
            while True:
                req = ttypes.TFetchResultsReq(
                    operationHandle=self._operationHandle,
                    orientation=ttypes.TFetchOrientation.FETCH_NEXT,
                    maxRows=self.arraysize,
                    fetchType=1  # 0: results, 1: logs
                )
                response = self._connection.client.FetchResults(req)
                _check_status(response)
                assert not response.results.rows, 'expected data in columnar format'
                assert len(response.results.columns) == 1, response.results.columns
                new_logs = _unwrap_column(response.results.columns[0])
                logs += new_logs

                if not new_logs:
                    break

        return logs


#
# Type Objects and Constructors
#


# Expose one DBAPITypeObject per primitive Thrift type (e.g. STRING_TYPE) at module level.
for type_id in constants.PRIMITIVE_TYPES:
    name = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
    setattr(sys.modules[__name__], name, DBAPITypeObject([name]))


#
# Private utilities
#


def _unwrap_column(col, type_=None):
    """Return a list of raw values from a TColumn instance.

    Exactly one of the TColumn's per-type fields is populated; nulls are applied from the
    accompanying bitmap, and a type-specific converter is applied when one is registered.
    """
    for attr, wrapper in iteritems(col.__dict__):
        if wrapper is not None:
            result = wrapper.values
            nulls = wrapper.nulls  # bit set describing what's null
            assert isinstance(nulls, bytes)
            for i, char in enumerate(nulls):
                byte = ord(char) if sys.version_info[0] == 2 else char
                for b in range(8):
                    if byte & (1 << b):
                        idx = i * 8 + b
                        # BUG FIX: the nulls bitmap is padded to whole bytes, so a set
                        # padding bit can refer past the end of the values list;
                        # previously this raised IndexError.
                        if idx < len(result):
                            result[idx] = None
            converter = TYPES_CONVERTER.get(type_, None)
            # (converter is only non-None when type_ matched, so no extra type_ check.)
            if converter:
                result = [converter(row) if row else row for row in result]
            return result
    raise DataError("Got empty column value {}".format(col))  # pragma: no cover
converter and type_: 602 | result = [converter(row) if row else row for row in result] 603 | return result 604 | raise DataError("Got empty column value {}".format(col)) # pragma: no cover 605 | 606 | 607 | def _check_status(response): 608 | """Raise an OperationalError if the status is not success""" 609 | _logger.debug(response) 610 | if response.status.statusCode != ttypes.TStatusCode.SUCCESS_STATUS: 611 | raise OperationalError(response) 612 | -------------------------------------------------------------------------------- /pyhive/presto.py: -------------------------------------------------------------------------------- 1 | """DB-API implementation backed by Presto 2 | 3 | See http://www.python.org/dev/peps/pep-0249/ 4 | 5 | Many docstrings in this file are based on the PEP, which is in the public domain. 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import unicode_literals 10 | 11 | from builtins import object 12 | from decimal import Decimal 13 | 14 | from pyhive import common 15 | from pyhive.common import DBAPITypeObject 16 | # Make all exceptions visible in this module per DB-API 17 | from pyhive.exc import * # noqa 18 | import base64 19 | import getpass 20 | import datetime 21 | import logging 22 | import requests 23 | from requests.auth import HTTPBasicAuth 24 | import os 25 | 26 | try: # Python 3 27 | import urllib.parse as urlparse 28 | except ImportError: # Python 2 29 | import urlparse 30 | 31 | 32 | # PEP 249 module globals 33 | apilevel = '2.0' 34 | threadsafety = 2 # Threads may share the module and connections. 35 | paramstyle = 'pyformat' # Python extended format codes, e.g. 
_logger = logging.getLogger(__name__)

TYPES_CONVERTER = {
    "decimal": Decimal,
    # As of Presto 0.69, binary data is returned as the varbinary type in base64 format
    "varbinary": base64.b64decode,
}


class PrestoParamEscaper(common.ParamEscaper):
    def escape_datetime(self, item, format):
        """Presto temporal literals need a ``timestamp``/``date`` keyword prefix and at
        most millisecond precision (hence the cutoff of 3)."""
        literal_kind = "timestamp" if isinstance(item, datetime.datetime) else "date"
        quoted = super(PrestoParamEscaper, self).escape_datetime(item, format, 3)
        return "{} {}".format(literal_kind, quoted)


_escaper = PrestoParamEscaper()


def connect(*args, **kwargs):
    """Constructor for creating a connection to the database. See class
    :py:class:`Connection` for arguments.

    :returns: a :py:class:`Connection` object.
    """
    return Connection(*args, **kwargs)


class Connection(object):
    """Presto does not have a notion of a persistent connection.

    Thus, these objects are small stateless factories for cursors, which do all the real
    work.
    """

    def __init__(self, *args, **kwargs):
        # Stash the arguments; each cursor() call forwards them to a new Cursor.
        self._args = args
        self._kwargs = kwargs

    def close(self):
        """Presto does not have anything to close"""
        # TODO cancel outstanding queries?
        pass

    def commit(self):
        """Presto does not support transactions"""
        pass

    def cursor(self):
        """Return a new :py:class:`Cursor` object using the connection."""
        return Cursor(*self._args, **self._kwargs)

    def rollback(self):
        raise NotSupportedError("Presto does not have transactions")  # pragma: no cover


class Cursor(common.DBAPICursor):
    """These objects represent a database cursor, which is used to manage the context of a
    fetch operation.

    Cursors are not isolated, i.e., any changes done to the database by a cursor are
    immediately visible by other cursors or connections.
    """
97 | """ 98 | 99 | def __init__(self, host, port='8080', username=None, principal_username=None, catalog='hive', 100 | schema='default', poll_interval=1, source='pyhive', session_props=None, 101 | protocol='http', password=None, requests_session=None, requests_kwargs=None, 102 | KerberosRemoteServiceName=None, KerberosPrincipal=None, 103 | KerberosConfigPath=None, KerberosKeytabPath=None, 104 | KerberosCredentialCachePath=None, KerberosUseCanonicalHostname=None): 105 | """ 106 | :param host: hostname to connect to, e.g. ``presto.example.com`` 107 | :param port: int -- port, defaults to 8080 108 | :param username: string -- defaults to system user name 109 | :param principal_username: string -- defaults to ``username`` argument if it exists, 110 | else defaults to system user name 111 | :param catalog: string -- defaults to ``hive`` 112 | :param schema: string -- defaults to ``default`` 113 | :param poll_interval: float -- how often to ask the Presto REST interface for a progress 114 | update, defaults to a second 115 | :param source: string -- arbitrary identifier (shows up in the Presto monitoring page) 116 | :param protocol: string -- network protocol, valid options are ``http`` and ``https``. 117 | defaults to ``http`` 118 | :param password: string -- Deprecated. Defaults to ``None``. 119 | Using BasicAuth, requires ``https``. 120 | Prefer ``requests_kwargs={'auth': HTTPBasicAuth(username, password)}``. 121 | May not be specified with ``requests_kwargs['auth']``. 122 | :param requests_session: a ``requests.Session`` object for advanced usage. If absent, this 123 | class will use the default requests behavior of making a new session per HTTP request. 124 | Caller is responsible for closing session. 125 | :param requests_kwargs: Additional ``**kwargs`` to pass to requests 126 | :param KerberosRemoteServiceName: string -- Presto coordinator Kerberos service name. 127 | This parameter is required for Kerberos authentiation. 
128 | :param KerberosPrincipal: string -- The principal to use when authenticating to 129 | the Presto coordinator. 130 | :param KerberosConfigPath: string -- Kerberos configuration file. 131 | (default: /etc/krb5.conf) 132 | :param KerberosKeytabPath: string -- Kerberos keytab file. 133 | :param KerberosCredentialCachePath: string -- Kerberos credential cache. 134 | :param KerberosUseCanonicalHostname: boolean -- Use the canonical hostname of the 135 | Presto coordinator for the Kerberos service principal by first resolving the 136 | hostname to an IP address and then doing a reverse DNS lookup for that IP address. 137 | This is enabled by default. 138 | """ 139 | super(Cursor, self).__init__(poll_interval) 140 | # Config 141 | self._host = host 142 | self._port = port 143 | """ 144 | Presto User Impersonation: https://docs.starburstdata.com/latest/security/impersonation.html 145 | 146 | User impersonation allows the execution of queries in Presto based on principal_username 147 | argument, instead of executing the query as the account which authenticated against Presto. 148 | (Usually a service account) 149 | 150 | Allows for a service account to authenticate with Presto, and then leverage the 151 | principal_username as the user Presto will execute the query as. This is required by 152 | applications that leverage authentication methods like SAML, where the application has a 153 | username, but not a password to still leverage user specific Presto Resource Groups and 154 | Authorization rules that would not be applied when only using a shared service account. 155 | This also allows auditing of who is executing a query in these environments, instead of 156 | having all queryes run by the shared service account. 
157 | """ 158 | self._username = principal_username or username or getpass.getuser() 159 | self._catalog = catalog 160 | self._schema = schema 161 | self._arraysize = 1 162 | self._poll_interval = poll_interval 163 | self._source = source 164 | self._session_props = session_props if session_props is not None else {} 165 | self.last_query_id = None 166 | 167 | if protocol not in ('http', 'https'): 168 | raise ValueError("Protocol must be http/https, was {!r}".format(protocol)) 169 | self._protocol = protocol 170 | 171 | self._requests_session = requests_session or requests 172 | 173 | requests_kwargs = dict(requests_kwargs) if requests_kwargs is not None else {} 174 | 175 | if KerberosRemoteServiceName is not None: 176 | from requests_kerberos import HTTPKerberosAuth, OPTIONAL 177 | 178 | hostname_override = None 179 | if KerberosUseCanonicalHostname is not None \ 180 | and KerberosUseCanonicalHostname.lower() == 'false': 181 | hostname_override = host 182 | if KerberosConfigPath is not None: 183 | os.environ['KRB5_CONFIG'] = KerberosConfigPath 184 | if KerberosKeytabPath is not None: 185 | os.environ['KRB5_CLIENT_KTNAME'] = KerberosKeytabPath 186 | if KerberosCredentialCachePath is not None: 187 | os.environ['KRB5CCNAME'] = KerberosCredentialCachePath 188 | 189 | requests_kwargs['auth'] = HTTPKerberosAuth(mutual_authentication=OPTIONAL, 190 | principal=KerberosPrincipal, 191 | service=KerberosRemoteServiceName, 192 | hostname_override=hostname_override) 193 | 194 | else: 195 | if password is not None and 'auth' in requests_kwargs: 196 | raise ValueError("Cannot use both password and requests_kwargs authentication") 197 | for k in ('method', 'url', 'data', 'headers'): 198 | if k in requests_kwargs: 199 | raise ValueError("Cannot override requests argument {}".format(k)) 200 | if password is not None: 201 | requests_kwargs['auth'] = HTTPBasicAuth(username, password) 202 | if protocol != 'https': 203 | raise ValueError("Protocol must be https when passing a password") 
204 | self._requests_kwargs = requests_kwargs 205 | 206 | self._reset_state() 207 | 208 | def _reset_state(self): 209 | """Reset state about the previous query in preparation for running another query""" 210 | super(Cursor, self)._reset_state() 211 | self._nextUri = None 212 | self._columns = None 213 | 214 | @property 215 | def description(self): 216 | """This read-only attribute is a sequence of 7-item sequences. 217 | 218 | Each of these sequences contains information describing one result column: 219 | 220 | - name 221 | - type_code 222 | - display_size (None in current implementation) 223 | - internal_size (None in current implementation) 224 | - precision (None in current implementation) 225 | - scale (None in current implementation) 226 | - null_ok (always True in current implementation) 227 | 228 | The ``type_code`` can be interpreted by comparing it to the Type Objects specified in the 229 | section below. 230 | """ 231 | # Sleep until we're done or we got the columns 232 | self._fetch_while( 233 | lambda: self._columns is None and 234 | self._state not in (self._STATE_NONE, self._STATE_FINISHED) 235 | ) 236 | if self._columns is None: 237 | return None 238 | return [ 239 | # name, type_code, display_size, internal_size, precision, scale, null_ok 240 | (col['name'], col['type'], None, None, None, None, True) 241 | for col in self._columns 242 | ] 243 | 244 | def execute(self, operation, parameters=None): 245 | """Prepare and execute a database operation (query or command). 246 | 247 | Return values are not defined. 
248 | """ 249 | headers = { 250 | 'X-Presto-Catalog': self._catalog, 251 | 'X-Presto-Schema': self._schema, 252 | 'X-Presto-Source': self._source, 253 | 'X-Presto-User': self._username, 254 | } 255 | 256 | if self._session_props: 257 | headers['X-Presto-Session'] = ','.join( 258 | '{}={}'.format(propname, propval) 259 | for propname, propval in self._session_props.items() 260 | ) 261 | 262 | # Prepare statement 263 | if parameters is None: 264 | sql = operation 265 | else: 266 | sql = operation % _escaper.escape_args(parameters) 267 | 268 | self._reset_state() 269 | 270 | self._state = self._STATE_RUNNING 271 | url = urlparse.urlunparse(( 272 | self._protocol, 273 | '{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None)) 274 | _logger.info('%s', sql) 275 | _logger.debug("Headers: %s", headers) 276 | response = self._requests_session.post( 277 | url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs) 278 | self._process_response(response) 279 | 280 | def cancel(self): 281 | if self._state == self._STATE_NONE: 282 | raise ProgrammingError("No query yet") 283 | if self._nextUri is None: 284 | assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None" 285 | return 286 | 287 | response = self._requests_session.delete(self._nextUri, **self._requests_kwargs) 288 | if response.status_code != requests.codes.no_content: 289 | fmt = "Unexpected status code after cancel {}\n{}" 290 | raise OperationalError(fmt.format(response.status_code, response.content)) 291 | 292 | self._state = self._STATE_FINISHED 293 | self._nextUri = None 294 | 295 | def poll(self): 296 | """Poll for and return the raw status data provided by the Presto REST API. 297 | 298 | :returns: dict -- JSON status information or ``None`` if the query is done 299 | :raises: ``ProgrammingError`` when no query has been started 300 | 301 | .. note:: 302 | This is not a part of DB-API. 
303 | """ 304 | if self._state == self._STATE_NONE: 305 | raise ProgrammingError("No query yet") 306 | if self._nextUri is None: 307 | assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None" 308 | return None 309 | response = self._requests_session.get(self._nextUri, **self._requests_kwargs) 310 | self._process_response(response) 311 | return response.json() 312 | 313 | def _fetch_more(self): 314 | """Fetch the next URI and update state""" 315 | self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs)) 316 | 317 | def _process_data(self, rows): 318 | for i, col in enumerate(self.description): 319 | col_type = col[1].split("(")[0].lower() 320 | if col_type in TYPES_CONVERTER: 321 | for row in rows: 322 | if row[i] is not None: 323 | row[i] = TYPES_CONVERTER[col_type](row[i]) 324 | 325 | def _process_response(self, response): 326 | """Given the JSON response from Presto's REST API, update the internal state with the next 327 | URI and any data from the response 328 | """ 329 | # TODO handle HTTP 503 330 | if response.status_code != requests.codes.ok: 331 | fmt = "Unexpected status code {}\n{}" 332 | raise OperationalError(fmt.format(response.status_code, response.content)) 333 | 334 | response_json = response.json() 335 | _logger.debug("Got response %s", response_json) 336 | assert self._state == self._STATE_RUNNING, "Should be running if processing response" 337 | self._nextUri = response_json.get('nextUri') 338 | self._columns = response_json.get('columns') 339 | if 'id' in response_json: 340 | self.last_query_id = response_json['id'] 341 | if 'X-Presto-Clear-Session' in response.headers: 342 | propname = response.headers['X-Presto-Clear-Session'] 343 | self._session_props.pop(propname, None) 344 | if 'X-Presto-Set-Session' in response.headers: 345 | propname, propval = response.headers['X-Presto-Set-Session'].split('=', 1) 346 | self._session_props[propname] = propval 347 | if 'data' in response_json: 
348 | assert self._columns 349 | new_data = response_json['data'] 350 | self._process_data(new_data) 351 | self._data += map(tuple, new_data) 352 | if 'nextUri' not in response_json: 353 | self._state = self._STATE_FINISHED 354 | if 'error' in response_json: 355 | raise DatabaseError(response_json['error']) 356 | 357 | 358 | # 359 | # Type Objects and Constructors 360 | # 361 | 362 | 363 | # See types in presto-main/src/main/java/com/facebook/presto/tuple/TupleInfo.java 364 | FIXED_INT_64 = DBAPITypeObject(['bigint']) 365 | VARIABLE_BINARY = DBAPITypeObject(['varchar']) 366 | DOUBLE = DBAPITypeObject(['double']) 367 | BOOLEAN = DBAPITypeObject(['boolean']) 368 | -------------------------------------------------------------------------------- /pyhive/sasl_compat.py: -------------------------------------------------------------------------------- 1 | # Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py 2 | # which uses Apache-2.0 license as of 21 May 2023. 3 | # This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl 4 | # via PR https://github.com/cloudera/impyla/pull/179 5 | # Even though thrift_sasl lists pure-sasl as dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34 6 | # but it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82 7 | # Hence this code is required for the fallback to work. 
8 | 9 | 10 | from puresasl.client import SASLClient, SASLError 11 | from contextlib import contextmanager 12 | 13 | @contextmanager 14 | def error_catcher(self, Exc = Exception): 15 | try: 16 | self.error = None 17 | yield 18 | except Exc as e: 19 | self.error = str(e) 20 | 21 | 22 | class PureSASLClient(SASLClient): 23 | def __init__(self, *args, **kwargs): 24 | self.error = None 25 | super(PureSASLClient, self).__init__(*args, **kwargs) 26 | 27 | def start(self, mechanism): 28 | with error_catcher(self, SASLError): 29 | if isinstance(mechanism, list): 30 | self.choose_mechanism(mechanism) 31 | else: 32 | self.choose_mechanism([mechanism]) 33 | return True, self.mechanism, self.process() 34 | # else 35 | return False, mechanism, None 36 | 37 | def encode(self, incoming): 38 | with error_catcher(self): 39 | return True, self.unwrap(incoming) 40 | # else 41 | return False, None 42 | 43 | def decode(self, outgoing): 44 | with error_catcher(self): 45 | return True, self.wrap(outgoing) 46 | # else 47 | return False, None 48 | 49 | def step(self, challenge=None): 50 | with error_catcher(self): 51 | return True, self.process(challenge) 52 | # else 53 | return False, None 54 | 55 | def getError(self): 56 | return self.error 57 | -------------------------------------------------------------------------------- /pyhive/sqlalchemy_hive.py: -------------------------------------------------------------------------------- 1 | """Integration between SQLAlchemy and Hive. 2 | 3 | Some code based on 4 | https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py 5 | which is released under the MIT license. 
6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import unicode_literals 10 | 11 | import datetime 12 | import decimal 13 | 14 | import re 15 | from sqlalchemy import exc 16 | from sqlalchemy.sql import text 17 | try: 18 | from sqlalchemy import processors 19 | except ImportError: 20 | # Required for SQLAlchemy>=2.0 21 | from sqlalchemy.engine import processors 22 | from sqlalchemy import types 23 | from sqlalchemy import util 24 | # TODO shouldn't use mysql type 25 | try: 26 | from sqlalchemy.databases import mysql 27 | mysql_tinyinteger = mysql.MSTinyInteger 28 | except ImportError: 29 | # Required for SQLAlchemy>2.0 30 | from sqlalchemy.dialects import mysql 31 | mysql_tinyinteger = mysql.base.MSTinyInteger 32 | from sqlalchemy.engine import default 33 | from sqlalchemy.sql import compiler 34 | from sqlalchemy.sql.compiler import SQLCompiler 35 | 36 | from pyhive import hive 37 | from pyhive.common import UniversalSet 38 | 39 | from dateutil.parser import parse 40 | from decimal import Decimal 41 | 42 | 43 | class HiveStringTypeBase(types.TypeDecorator): 44 | """Translates strings returned by Thrift into something else""" 45 | impl = types.String 46 | 47 | def process_bind_param(self, value, dialect): 48 | raise NotImplementedError("Writing to Hive not supported") 49 | 50 | 51 | class HiveDate(HiveStringTypeBase): 52 | """Translates date strings to date objects""" 53 | impl = types.DATE 54 | 55 | def process_result_value(self, value, dialect): 56 | return processors.str_to_date(value) 57 | 58 | def result_processor(self, dialect, coltype): 59 | def process(value): 60 | if isinstance(value, datetime.datetime): 61 | return value.date() 62 | elif isinstance(value, datetime.date): 63 | return value 64 | elif value is not None: 65 | return parse(value).date() 66 | else: 67 | return None 68 | 69 | return process 70 | 71 | def adapt(self, impltype, **kwargs): 72 | return self.impl 73 | 74 | 75 | class HiveTimestamp(HiveStringTypeBase): 76 | 
"""Translates timestamp strings to datetime objects""" 77 | impl = types.TIMESTAMP 78 | 79 | def process_result_value(self, value, dialect): 80 | return processors.str_to_datetime(value) 81 | 82 | def result_processor(self, dialect, coltype): 83 | def process(value): 84 | if isinstance(value, datetime.datetime): 85 | return value 86 | elif value is not None: 87 | return parse(value) 88 | else: 89 | return None 90 | 91 | return process 92 | 93 | def adapt(self, impltype, **kwargs): 94 | return self.impl 95 | 96 | 97 | class HiveDecimal(HiveStringTypeBase): 98 | """Translates strings to decimals""" 99 | impl = types.DECIMAL 100 | 101 | def process_result_value(self, value, dialect): 102 | if value is not None: 103 | return decimal.Decimal(value) 104 | else: 105 | return None 106 | 107 | def result_processor(self, dialect, coltype): 108 | def process(value): 109 | if isinstance(value, Decimal): 110 | return value 111 | elif value is not None: 112 | return Decimal(value) 113 | else: 114 | return None 115 | 116 | return process 117 | 118 | def adapt(self, impltype, **kwargs): 119 | return self.impl 120 | 121 | 122 | class HiveIdentifierPreparer(compiler.IdentifierPreparer): 123 | # Just quote everything to make things simpler / easier to upgrade 124 | reserved_words = UniversalSet() 125 | 126 | def __init__(self, dialect): 127 | super(HiveIdentifierPreparer, self).__init__( 128 | dialect, 129 | initial_quote='`', 130 | ) 131 | 132 | 133 | _type_map = { 134 | 'boolean': types.Boolean, 135 | 'tinyint': mysql_tinyinteger, 136 | 'smallint': types.SmallInteger, 137 | 'int': types.Integer, 138 | 'bigint': types.BigInteger, 139 | 'float': types.Float, 140 | 'double': types.Float, 141 | 'string': types.String, 142 | 'varchar': types.String, 143 | 'char': types.String, 144 | 'date': HiveDate, 145 | 'timestamp': HiveTimestamp, 146 | 'binary': types.String, 147 | 'array': types.String, 148 | 'map': types.String, 149 | 'struct': types.String, 150 | 'uniontype': types.String, 151 | 
'decimal': HiveDecimal, 152 | } 153 | 154 | 155 | class HiveCompiler(SQLCompiler): 156 | def visit_concat_op_binary(self, binary, operator, **kw): 157 | return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right)) 158 | 159 | def visit_insert(self, *args, **kwargs): 160 | result = super(HiveCompiler, self).visit_insert(*args, **kwargs) 161 | # Massage the result into Hive's format 162 | # INSERT INTO `pyhive_test_database`.`test_table` (`a`) SELECT ... 163 | # => 164 | # INSERT INTO TABLE `pyhive_test_database`.`test_table` SELECT ... 165 | regex = r'^(INSERT INTO) ([^\s]+) \([^\)]*\)' 166 | assert re.search(regex, result), "Unexpected visit_insert result: {}".format(result) 167 | return re.sub(regex, r'\1 TABLE \2', result) 168 | 169 | def visit_column(self, *args, **kwargs): 170 | result = super(HiveCompiler, self).visit_column(*args, **kwargs) 171 | dot_count = result.count('.') 172 | assert dot_count in (0, 1, 2), "Unexpected visit_column result {}".format(result) 173 | if dot_count == 2: 174 | # we have something of the form schema.table.column 175 | # hive doesn't like the schema in front, so chop it out 176 | result = result[result.index('.') + 1:] 177 | return result 178 | 179 | def visit_char_length_func(self, fn, **kw): 180 | return 'length{}'.format(self.function_argspec(fn, **kw)) 181 | 182 | 183 | class HiveTypeCompiler(compiler.GenericTypeCompiler): 184 | def visit_INTEGER(self, type_): 185 | return 'INT' 186 | 187 | def visit_NUMERIC(self, type_): 188 | return 'DECIMAL' 189 | 190 | def visit_CHAR(self, type_): 191 | return 'STRING' 192 | 193 | def visit_VARCHAR(self, type_): 194 | return 'STRING' 195 | 196 | def visit_NCHAR(self, type_): 197 | return 'STRING' 198 | 199 | def visit_TEXT(self, type_): 200 | return 'STRING' 201 | 202 | def visit_CLOB(self, type_): 203 | return 'STRING' 204 | 205 | def visit_BLOB(self, type_): 206 | return 'BINARY' 207 | 208 | def visit_TIME(self, type_): 209 | return 'TIMESTAMP' 210 | 211 | def 
visit_DATE(self, type_): 212 | return 'TIMESTAMP' 213 | 214 | def visit_DATETIME(self, type_): 215 | return 'TIMESTAMP' 216 | 217 | 218 | class HiveExecutionContext(default.DefaultExecutionContext): 219 | """This is pretty much the same as SQLiteExecutionContext to work around the same issue. 220 | 221 | http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#dotted-column-names 222 | 223 | engine = create_engine('hive://...', execution_options={'hive_raw_colnames': True}) 224 | """ 225 | 226 | @util.memoized_property 227 | def _preserve_raw_colnames(self): 228 | # Ideally, this would also gate on hive.resultset.use.unique.column.names 229 | return self.execution_options.get('hive_raw_colnames', False) 230 | 231 | def _translate_colname(self, colname): 232 | # Adjust for dotted column names. 233 | # When hive.resultset.use.unique.column.names is true (the default), Hive returns column 234 | # names as "tablename.colname" in cursor.description. 235 | if not self._preserve_raw_colnames and '.' 
in colname: 236 | return colname.split('.')[-1], colname 237 | else: 238 | return colname, None 239 | 240 | 241 | class HiveDialect(default.DefaultDialect): 242 | name = 'hive' 243 | driver = 'thrift' 244 | execution_ctx_cls = HiveExecutionContext 245 | preparer = HiveIdentifierPreparer 246 | statement_compiler = HiveCompiler 247 | supports_views = True 248 | supports_alter = True 249 | supports_pk_autoincrement = False 250 | supports_default_values = False 251 | supports_empty_insert = False 252 | supports_native_decimal = True 253 | supports_native_boolean = True 254 | supports_unicode_statements = True 255 | supports_unicode_binds = True 256 | returns_unicode_strings = True 257 | description_encoding = None 258 | supports_multivalues_insert = True 259 | type_compiler = HiveTypeCompiler 260 | supports_sane_rowcount = False 261 | supports_statement_cache = False 262 | 263 | @classmethod 264 | def dbapi(cls): 265 | return hive 266 | 267 | @classmethod 268 | def import_dbapi(cls): 269 | return hive 270 | 271 | def create_connect_args(self, url): 272 | kwargs = { 273 | 'host': url.host, 274 | 'port': url.port or 10000, 275 | 'username': url.username, 276 | 'password': url.password, 277 | 'database': url.database or 'default', 278 | } 279 | kwargs.update(url.query) 280 | return [], kwargs 281 | 282 | def get_schema_names(self, connection, **kw): 283 | # Equivalent to SHOW DATABASES 284 | return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))] 285 | 286 | def get_view_names(self, connection, schema=None, **kw): 287 | # Hive does not provide functionality to query tableType 288 | # This allows reflection to not crash at the cost of being inaccurate 289 | return self.get_table_names(connection, schema, **kw) 290 | 291 | def _get_table_columns(self, connection, table_name, schema): 292 | full_table = table_name 293 | if schema: 294 | full_table = schema + '.' + table_name 295 | # TODO using TGetColumnsReq hangs after sending TFetchResultsReq. 
296 | # Using DESCRIBE works but is uglier. 297 | try: 298 | # This needs the table name to be unescaped (no backticks). 299 | rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall() 300 | except exc.OperationalError as e: 301 | # Does the table exist? 302 | regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}' 303 | regex = regex_fmt.format(re.escape(full_table)) 304 | if re.search(regex, e.args[0]): 305 | raise exc.NoSuchTableError(full_table) 306 | else: 307 | raise 308 | else: 309 | # Hive is stupid: this is what I get from DESCRIBE some_schema.does_not_exist 310 | regex = r'Table .* does not exist' 311 | if len(rows) == 1 and re.match(regex, rows[0].col_name): 312 | raise exc.NoSuchTableError(full_table) 313 | return rows 314 | 315 | def has_table(self, connection, table_name, schema=None, **kw): 316 | try: 317 | self._get_table_columns(connection, table_name, schema) 318 | return True 319 | except exc.NoSuchTableError: 320 | return False 321 | 322 | def get_columns(self, connection, table_name, schema=None, **kw): 323 | rows = self._get_table_columns(connection, table_name, schema) 324 | # Strip whitespace 325 | rows = [[col.strip() if col else None for col in row] for row in rows] 326 | # Filter out empty rows and comment 327 | rows = [row for row in rows if row[0] and row[0] != '# col_name'] 328 | result = [] 329 | for (col_name, col_type, _comment) in rows: 330 | if col_name == '# Partition Information': 331 | break 332 | # Take out the more detailed type information 333 | # e.g. 
'map' -> 'map' 334 | # 'decimal(10,1)' -> decimal 335 | col_type = re.search(r'^\w+', col_type).group(0) 336 | try: 337 | coltype = _type_map[col_type] 338 | except KeyError: 339 | util.warn("Did not recognize type '%s' of column '%s'" % (col_type, col_name)) 340 | coltype = types.NullType 341 | 342 | result.append({ 343 | 'name': col_name, 344 | 'type': coltype, 345 | 'nullable': True, 346 | 'default': None, 347 | }) 348 | return result 349 | 350 | def get_foreign_keys(self, connection, table_name, schema=None, **kw): 351 | # Hive has no support for foreign keys. 352 | return [] 353 | 354 | def get_pk_constraint(self, connection, table_name, schema=None, **kw): 355 | # Hive has no support for primary keys. 356 | return [] 357 | 358 | def get_indexes(self, connection, table_name, schema=None, **kw): 359 | rows = self._get_table_columns(connection, table_name, schema) 360 | # Strip whitespace 361 | rows = [[col.strip() if col else None for col in row] for row in rows] 362 | # Filter out empty rows and comment 363 | rows = [row for row in rows if row[0] and row[0] != '# col_name'] 364 | for i, (col_name, _col_type, _comment) in enumerate(rows): 365 | if col_name == '# Partition Information': 366 | break 367 | # Handle partition columns 368 | col_names = [] 369 | for col_name, _col_type, _comment in rows[i + 1:]: 370 | col_names.append(col_name) 371 | if col_names: 372 | return [{'name': 'partition', 'column_names': col_names, 'unique': False}] 373 | else: 374 | return [] 375 | 376 | def get_table_names(self, connection, schema=None, **kw): 377 | query = 'SHOW TABLES' 378 | if schema: 379 | query += ' IN ' + self.identifier_preparer.quote_identifier(schema) 380 | return [row[0] for row in connection.execute(text(query))] 381 | 382 | def do_rollback(self, dbapi_connection): 383 | # No transactions for Hive 384 | pass 385 | 386 | def _check_unicode_returns(self, connection, additional_tests=None): 387 | # We decode everything as UTF-8 388 | return True 389 | 390 | def 
_check_unicode_description(self, connection): 391 | # We decode everything as UTF-8 392 | return True 393 | 394 | 395 | class HiveHTTPDialect(HiveDialect): 396 | 397 | name = "hive" 398 | scheme = "http" 399 | driver = "rest" 400 | 401 | def create_connect_args(self, url): 402 | kwargs = { 403 | "host": url.host, 404 | "port": url.port or 10000, 405 | "scheme": self.scheme, 406 | "username": url.username or None, 407 | "password": url.password or None, 408 | } 409 | if url.query: 410 | kwargs.update(url.query) 411 | return [], kwargs 412 | return ([], kwargs) 413 | 414 | 415 | class HiveHTTPSDialect(HiveHTTPDialect): 416 | 417 | name = "hive" 418 | scheme = "https" 419 | -------------------------------------------------------------------------------- /pyhive/sqlalchemy_presto.py: -------------------------------------------------------------------------------- 1 | """Integration between SQLAlchemy and Presto. 2 | 3 | Some code based on 4 | https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py 5 | which is released under the MIT license. 
6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import unicode_literals 10 | 11 | import re 12 | import sqlalchemy 13 | from sqlalchemy import exc 14 | from sqlalchemy import types 15 | from sqlalchemy import util 16 | # TODO shouldn't use mysql type 17 | from sqlalchemy.sql import text 18 | try: 19 | from sqlalchemy.databases import mysql 20 | mysql_tinyinteger = mysql.MSTinyInteger 21 | except ImportError: 22 | # Required for SQLAlchemy>=2.0 23 | from sqlalchemy.dialects import mysql 24 | mysql_tinyinteger = mysql.base.MSTinyInteger 25 | from sqlalchemy.engine import default 26 | from sqlalchemy.sql import compiler 27 | from sqlalchemy.sql.compiler import SQLCompiler 28 | 29 | from pyhive import presto 30 | from pyhive.common import UniversalSet 31 | 32 | sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) 33 | 34 | class PrestoIdentifierPreparer(compiler.IdentifierPreparer): 35 | # Just quote everything to make things simpler / easier to upgrade 36 | reserved_words = UniversalSet() 37 | 38 | 39 | _type_map = { 40 | 'boolean': types.Boolean, 41 | 'tinyint': mysql_tinyinteger, 42 | 'smallint': types.SmallInteger, 43 | 'integer': types.Integer, 44 | 'bigint': types.BigInteger, 45 | 'real': types.Float, 46 | 'double': types.Float, 47 | 'varchar': types.String, 48 | 'timestamp': types.TIMESTAMP, 49 | 'date': types.DATE, 50 | 'varbinary': types.VARBINARY, 51 | } 52 | 53 | 54 | class PrestoCompiler(SQLCompiler): 55 | def visit_char_length_func(self, fn, **kw): 56 | return 'length{}'.format(self.function_argspec(fn, **kw)) 57 | 58 | 59 | class PrestoTypeCompiler(compiler.GenericTypeCompiler): 60 | def visit_CLOB(self, type_, **kw): 61 | raise ValueError("Presto does not support the CLOB column type.") 62 | 63 | def visit_NCLOB(self, type_, **kw): 64 | raise ValueError("Presto does not support the NCLOB column type.") 65 | 66 | def visit_DATETIME(self, type_, **kw): 67 | raise ValueError("Presto does 
not support the DATETIME column type.") 68 | 69 | def visit_FLOAT(self, type_, **kw): 70 | return 'DOUBLE' 71 | 72 | def visit_TEXT(self, type_, **kw): 73 | if type_.length: 74 | return 'VARCHAR({:d})'.format(type_.length) 75 | else: 76 | return 'VARCHAR' 77 | 78 | 79 | class PrestoDialect(default.DefaultDialect): 80 | name = 'presto' 81 | driver = 'rest' 82 | paramstyle = 'pyformat' 83 | preparer = PrestoIdentifierPreparer 84 | statement_compiler = PrestoCompiler 85 | supports_alter = False 86 | supports_pk_autoincrement = False 87 | supports_default_values = False 88 | supports_empty_insert = False 89 | supports_multivalues_insert = True 90 | supports_unicode_statements = True 91 | supports_unicode_binds = True 92 | supports_statement_cache = False 93 | returns_unicode_strings = True 94 | description_encoding = None 95 | supports_native_boolean = True 96 | type_compiler = PrestoTypeCompiler 97 | 98 | @classmethod 99 | def dbapi(cls): 100 | return presto 101 | 102 | @classmethod 103 | def import_dbapi(cls): 104 | return presto 105 | 106 | def create_connect_args(self, url): 107 | db_parts = (url.database or 'hive').split('/') 108 | kwargs = { 109 | 'host': url.host, 110 | 'port': url.port or 8080, 111 | 'username': url.username, 112 | 'password': url.password 113 | } 114 | kwargs.update(url.query) 115 | if len(db_parts) == 1: 116 | kwargs['catalog'] = db_parts[0] 117 | elif len(db_parts) == 2: 118 | kwargs['catalog'] = db_parts[0] 119 | kwargs['schema'] = db_parts[1] 120 | else: 121 | raise ValueError("Unexpected database format {}".format(url.database)) 122 | return [], kwargs 123 | 124 | def get_schema_names(self, connection, **kw): 125 | return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))] 126 | 127 | def _get_table_columns(self, connection, table_name, schema): 128 | full_table = self.identifier_preparer.quote_identifier(table_name) 129 | if schema: 130 | full_table = self.identifier_preparer.quote_identifier(schema) + '.' 
+ full_table 131 | try: 132 | return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table))) 133 | except (presto.DatabaseError, exc.DatabaseError) as e: 134 | # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which 135 | # it successfully does in the Hive version. The difference with Presto is that this 136 | # error is raised when fetching the cursor's description rather than the initial execute 137 | # call. SQLAlchemy doesn't handle this. Thus, we catch the unwrapped 138 | # presto.DatabaseError here. 139 | # Does the table exist? 140 | msg = ( 141 | e.args[0].get('message') if e.args and isinstance(e.args[0], dict) 142 | else e.args[0] if e.args and isinstance(e.args[0], str) 143 | else None 144 | ) 145 | regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name)) 146 | if msg and re.search(regex, msg): 147 | raise exc.NoSuchTableError(table_name) 148 | else: 149 | raise 150 | 151 | def has_table(self, connection, table_name, schema=None, **kw): 152 | try: 153 | self._get_table_columns(connection, table_name, schema) 154 | return True 155 | except exc.NoSuchTableError: 156 | return False 157 | 158 | def get_columns(self, connection, table_name, schema=None, **kw): 159 | rows = self._get_table_columns(connection, table_name, schema) 160 | result = [] 161 | for row in rows: 162 | try: 163 | coltype = _type_map[row.Type] 164 | except KeyError: 165 | util.warn("Did not recognize type '%s' of column '%s'" % (row.Type, row.Column)) 166 | coltype = types.NullType 167 | result.append({ 168 | 'name': row.Column, 169 | 'type': coltype, 170 | # newer Presto no longer includes this column 171 | 'nullable': getattr(row, 'Null', True), 172 | 'default': None, 173 | }) 174 | return result 175 | 176 | def get_foreign_keys(self, connection, table_name, schema=None, **kw): 177 | # Hive has no support for foreign keys. 
178 | return [] 179 | 180 | def get_pk_constraint(self, connection, table_name, schema=None, **kw): 181 | # Hive has no support for primary keys. 182 | return [] 183 | 184 | def get_indexes(self, connection, table_name, schema=None, **kw): 185 | rows = self._get_table_columns(connection, table_name, schema) 186 | col_names = [] 187 | for row in rows: 188 | part_key = 'Partition Key' 189 | # Presto puts this information in one of 3 places depending on version 190 | # - a boolean column named "Partition Key" 191 | # - a string in the "Comment" column 192 | # - a string in the "Extra" column 193 | if sqlalchemy_version >= 1.4: 194 | row = row._mapping 195 | is_partition_key = ( 196 | (part_key in row and row[part_key]) 197 | or row['Comment'].startswith(part_key) 198 | or ('Extra' in row and 'partition key' in row['Extra']) 199 | ) 200 | if is_partition_key: 201 | col_names.append(row['Column']) 202 | if col_names: 203 | return [{'name': 'partition', 'column_names': col_names, 'unique': False}] 204 | else: 205 | return [] 206 | 207 | def get_table_names(self, connection, schema=None, **kw): 208 | query = 'SHOW TABLES' 209 | if schema: 210 | query += ' FROM ' + self.identifier_preparer.quote_identifier(schema) 211 | return [row.Table for row in connection.execute(text(query))] 212 | 213 | def do_rollback(self, dbapi_connection): 214 | # No transactions for Presto 215 | pass 216 | 217 | def _check_unicode_returns(self, connection, additional_tests=None): 218 | # requests gives back Unicode strings 219 | return True 220 | 221 | def _check_unicode_description(self, connection): 222 | # requests gives back Unicode strings 223 | return True 224 | -------------------------------------------------------------------------------- /pyhive/sqlalchemy_trino.py: -------------------------------------------------------------------------------- 1 | """Integration between SQLAlchemy and Trino. 
2 | 3 | Some code based on 4 | https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py 5 | which is released under the MIT license. 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import unicode_literals 10 | 11 | import re 12 | from sqlalchemy import exc 13 | from sqlalchemy import types 14 | from sqlalchemy import util 15 | # TODO shouldn't use mysql type 16 | try: 17 | from sqlalchemy.databases import mysql 18 | mysql_tinyinteger = mysql.MSTinyInteger 19 | except ImportError: 20 | # Required for SQLAlchemy>=2.0 21 | from sqlalchemy.dialects import mysql 22 | mysql_tinyinteger = mysql.base.MSTinyInteger 23 | from sqlalchemy.engine import default 24 | from sqlalchemy.sql import compiler 25 | from sqlalchemy.sql.compiler import SQLCompiler 26 | 27 | from pyhive import trino 28 | from pyhive.common import UniversalSet 29 | from pyhive.sqlalchemy_presto import PrestoDialect, PrestoCompiler, PrestoIdentifierPreparer 30 | 31 | class TrinoIdentifierPreparer(PrestoIdentifierPreparer): 32 | pass 33 | 34 | 35 | _type_map = { 36 | 'boolean': types.Boolean, 37 | 'tinyint': mysql_tinyinteger, 38 | 'smallint': types.SmallInteger, 39 | 'integer': types.Integer, 40 | 'bigint': types.BigInteger, 41 | 'real': types.Float, 42 | 'double': types.Float, 43 | 'varchar': types.String, 44 | 'timestamp': types.TIMESTAMP, 45 | 'date': types.DATE, 46 | 'varbinary': types.VARBINARY, 47 | } 48 | 49 | 50 | class TrinoCompiler(PrestoCompiler): 51 | pass 52 | 53 | 54 | class TrinoTypeCompiler(PrestoCompiler): 55 | def visit_CLOB(self, type_, **kw): 56 | raise ValueError("Trino does not support the CLOB column type.") 57 | 58 | def visit_NCLOB(self, type_, **kw): 59 | raise ValueError("Trino does not support the NCLOB column type.") 60 | 61 | def visit_DATETIME(self, type_, **kw): 62 | raise ValueError("Trino does not support the DATETIME column type.") 63 | 64 | def visit_FLOAT(self, type_, **kw): 65 | return 'DOUBLE' 66 | 67 | def 
visit_TEXT(self, type_, **kw): 68 | if type_.length: 69 | return 'VARCHAR({:d})'.format(type_.length) 70 | else: 71 | return 'VARCHAR' 72 | 73 | 74 | class TrinoDialect(PrestoDialect): 75 | name = 'trino' 76 | supports_statement_cache = False 77 | 78 | @classmethod 79 | def dbapi(cls): 80 | return trino 81 | 82 | @classmethod 83 | def import_dbapi(cls): 84 | return trino 85 | -------------------------------------------------------------------------------- /pyhive/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/PyHive/ac09074a652fd50e10b57a7f0bbc4f6410961301/pyhive/tests/__init__.py -------------------------------------------------------------------------------- /pyhive/tests/dbapi_test_case.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """Shared DB-API test cases""" 3 | 4 | from __future__ import absolute_import 5 | from __future__ import unicode_literals 6 | from builtins import object 7 | from builtins import range 8 | from future.utils import with_metaclass 9 | from pyhive import exc 10 | import abc 11 | import contextlib 12 | import functools 13 | 14 | 15 | def with_cursor(fn): 16 | """Pass a cursor to the given function and handle cleanup. 17 | 18 | The cursor is taken from ``self.connect()``. 
    # --- tail of the with_cursor decorator (its def/docstring open above this chunk) ---
    @functools.wraps(fn)
    def wrapped_fn(self, *args, **kwargs):
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()) as cursor:
                fn(self, cursor, *args, **kwargs)
    return wrapped_fn


class DBAPITestCase(with_metaclass(abc.ABCMeta, object)):
    """Shared DB-API 2.0 test cases, parameterized over backends via connect()."""

    @abc.abstractmethod
    def connect(self):
        """Return a fresh DB-API connection to the backend under test."""
        raise NotImplementedError  # pragma: no cover

    @with_cursor
    def test_fetchone(self, cursor):
        """fetchone advances rownumber and returns None past the last row."""
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(cursor.rownumber, 0)
        self.assertEqual(cursor.fetchone(), (1,))
        self.assertEqual(cursor.rownumber, 1)
        self.assertIsNone(cursor.fetchone())

    @with_cursor
    def test_fetchall(self, cursor):
        """fetchall returns every row of small and large result sets."""
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(cursor.fetchall(), [(1,)])
        cursor.execute('SELECT a FROM many_rows ORDER BY a')
        self.assertEqual(cursor.fetchall(), [(i,) for i in range(10000)])

    @with_cursor
    def test_null_param(self, cursor):
        """A None parameter round-trips as SQL NULL."""
        cursor.execute('SELECT %s FROM one_row', (None,))
        self.assertEqual(cursor.fetchall(), [(None,)])

    @with_cursor
    def test_iterator(self, cursor):
        """Cursors are iterable and raise StopIteration when exhausted."""
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(list(cursor), [(1,)])
        self.assertRaises(StopIteration, cursor.__next__)

    @with_cursor
    def test_description_initial(self, cursor):
        """description is None before any statement is executed."""
        self.assertIsNone(cursor.description)

    @with_cursor
    def test_description_failed(self, cursor):
        """description stays None when a statement fails to execute."""
        try:
            cursor.execute('blah_blah')
            self.assertIsNone(cursor.description)
        except exc.DatabaseError:
            pass

    @with_cursor
    def test_bad_query(self, cursor):
        """An invalid query surfaces as DatabaseError by fetch time at latest."""
        def run():
            cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist')
            cursor.fetchone()
        self.assertRaises(exc.DatabaseError, run)

    @with_cursor
    def test_concurrent_execution(self, cursor):
        """Re-executing on the same cursor discards the previous result set."""
        cursor.execute('SELECT * FROM one_row')
        cursor.execute('SELECT * FROM one_row')
        self.assertEqual(cursor.fetchall(), [(1,)])

    @with_cursor
    def test_executemany(self, cursor):
        """executemany leaves only the last parameter set's results available."""
        for length in 1, 2:
            cursor.executemany(
                'SELECT %(x)d FROM one_row',
                [{'x': i} for i in range(1, length + 1)]
            )
            self.assertEqual(cursor.fetchall(), [(length,)])

    @with_cursor
    def test_executemany_none(self, cursor):
        """executemany with an empty sequence executes nothing at all."""
        cursor.executemany('should_never_get_used', [])
        self.assertIsNone(cursor.description)
        self.assertRaises(exc.ProgrammingError, cursor.fetchone)

    @with_cursor
    def test_fetchone_no_data(self, cursor):
        """Fetching before any execute raises ProgrammingError."""
        self.assertRaises(exc.ProgrammingError, cursor.fetchone)

    @with_cursor
    def test_fetchmany(self, cursor):
        """fetchmany honors its size argument and returns the remainder at the end."""
        cursor.execute('SELECT * FROM many_rows LIMIT 15')
        self.assertEqual(cursor.fetchmany(0), [])
        self.assertEqual(len(cursor.fetchmany(10)), 10)
        self.assertEqual(len(cursor.fetchmany(10)), 5)

    @with_cursor
    def test_arraysize(self, cursor):
        """fetchmany with no argument uses cursor.arraysize."""
        cursor.arraysize = 5
        cursor.execute('SELECT * FROM many_rows LIMIT 20')
        self.assertEqual(len(cursor.fetchmany()), 5)

    @with_cursor
    def test_polling_loop(self, cursor):
        """Try to trigger the polling logic in fetchone()"""
        cursor._poll_interval = 0
        cursor.execute('SELECT COUNT(*) FROM many_rows')
        self.assertEqual(cursor.fetchone(), (10000,))

    @with_cursor
    def test_no_params(self, cursor):
        """Without a parameter sequence, %-style placeholders pass through literally."""
        cursor.execute("SELECT '%(x)s' FROM one_row")
        self.assertEqual(cursor.fetchall(), [('%(x)s',)])

    def test_escape(self):
        """Verify that funny characters can be escaped as strings and SELECTed back"""
        bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t '''
        self.run_escape_case(bad_str)

    @with_cursor
    def run_escape_case(self, cursor, bad_str):
        # Exercise both positional and named parameter styles with the hostile string.
        cursor.execute(
            'SELECT %d, %s FROM one_row',
            (1, bad_str)
        )
        self.assertEqual(cursor.fetchall(), [(1, bad_str,)])
        cursor.execute(
            'SELECT %(a)d, %(b)s FROM one_row',
            {'a': 1, 'b': bad_str}
        )
        self.assertEqual(cursor.fetchall(), [(1, bad_str)])

    @with_cursor
    def test_invalid_params(self, cursor):
        """Non-sequence or unescapable parameters raise ProgrammingError."""
        self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', 'hi'))
        self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', [object]))

    def test_open_close(self):
        """Connections and cursors can be opened and closed without use."""
        with contextlib.closing(self.connect()):
            pass
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()):
                pass

    @with_cursor
    def test_unicode(self, cursor):
        """Non-ASCII strings round-trip through parameters."""
        unicode_str = "王兢"
        cursor.execute(
            'SELECT %s FROM one_row',
            (unicode_str,)
        )
        self.assertEqual(cursor.fetchall(), [(unicode_str,)])

    @with_cursor
    def test_null(self, cursor):
        """NULL values come back as None, both constant and computed."""
        cursor.execute('SELECT null FROM many_rows')
        self.assertEqual(cursor.fetchall(), [(None,)] * 10000)
        cursor.execute('SELECT IF(a % 11 = 0, null, a) FROM many_rows')
        self.assertEqual(cursor.fetchall(), [(None if a % 11 == 0 else a,) for a in range(10000)])

    @with_cursor
    def test_sql_where_in(self, cursor):
        """A Python list parameter expands into a SQL IN (...) clause."""
        cursor.execute('SELECT * FROM many_rows where a in %s', ([1, 2, 3],))
        self.assertEqual(len(cursor.fetchall()), 3)
        cursor.execute('SELECT * FROM many_rows where b in %s limit 10',
                       (['blah'],))
        self.assertEqual(len(cursor.fetchall()), 10)
existing 7 | sn: testentry 8 | mail: test.entry@example.com 9 | mail: te@example.com 10 | ou:: UGFsbcO2 11 | userPassword: testpw 12 | givenName: i am 13 | description: A test entry for pyhive 14 | -------------------------------------------------------------------------------- /pyhive/tests/ldif_data/base.ldif: -------------------------------------------------------------------------------- 1 | # ldapadd -x -h localhost -p 389 -D "cn=admin,dc=test,dc=com" -w secret -f base.ldif 2 | 3 | dn: dc=example,dc=com 4 | objectClass: dcObject 5 | objectClass: organizationalUnit 6 | #dc: test 7 | ou: Test 8 | -------------------------------------------------------------------------------- /pyhive/tests/sqlalchemy_test_case.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import absolute_import 3 | from __future__ import unicode_literals 4 | 5 | import abc 6 | import re 7 | import contextlib 8 | import functools 9 | 10 | import pytest 11 | import sqlalchemy 12 | from builtins import object 13 | from future.utils import with_metaclass 14 | from sqlalchemy.exc import NoSuchTableError 15 | from sqlalchemy.schema import Index 16 | from sqlalchemy.schema import MetaData 17 | from sqlalchemy.schema import Table 18 | from sqlalchemy.sql import expression, text 19 | from sqlalchemy import String 20 | 21 | sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) 22 | 23 | def with_engine_connection(fn): 24 | """Pass a connection to the given function and handle cleanup. 25 | 26 | The connection is taken from ``self.create_engine()``. 
27 | """ 28 | @functools.wraps(fn) 29 | def wrapped_fn(self, *args, **kwargs): 30 | engine = self.create_engine() 31 | try: 32 | with contextlib.closing(engine.connect()) as connection: 33 | fn(self, engine, connection, *args, **kwargs) 34 | finally: 35 | engine.dispose() 36 | return wrapped_fn 37 | 38 | def reflect_table(engine, connection, table, include_columns, exclude_columns, resolve_fks): 39 | if sqlalchemy_version >= 1.4: 40 | insp = sqlalchemy.inspect(engine) 41 | insp.reflect_table( 42 | table, 43 | include_columns=include_columns, 44 | exclude_columns=exclude_columns, 45 | resolve_fks=resolve_fks, 46 | ) 47 | else: 48 | engine.dialect.reflecttable( 49 | connection, table, include_columns=include_columns, 50 | exclude_columns=exclude_columns, resolve_fks=resolve_fks) 51 | 52 | 53 | class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)): 54 | @with_engine_connection 55 | def test_basic_query(self, engine, connection): 56 | rows = connection.execute(text('SELECT * FROM one_row')).fetchall() 57 | self.assertEqual(len(rows), 1) 58 | self.assertEqual(rows[0].number_of_rows, 1) # number_of_rows is the column name 59 | self.assertEqual(len(rows[0]), 1) 60 | 61 | @with_engine_connection 62 | def test_one_row_complex_null(self, engine, connection): 63 | one_row_complex_null = Table('one_row_complex_null', MetaData(), autoload_with=engine) 64 | rows = connection.execute(one_row_complex_null.select()).fetchall() 65 | self.assertEqual(len(rows), 1) 66 | self.assertEqual(list(rows[0]), [None] * len(rows[0])) 67 | 68 | @with_engine_connection 69 | def test_reflect_no_such_table(self, engine, connection): 70 | """reflecttable should throw an exception on an invalid table""" 71 | self.assertRaises( 72 | NoSuchTableError, 73 | lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) 74 | self.assertRaises( 75 | NoSuchTableError, 76 | lambda: Table('this_does_not_exist', MetaData(schema='also_does_not_exist'), autoload_with=engine)) 77 | 78 | 
@with_engine_connection 79 | def test_reflect_include_columns(self, engine, connection): 80 | """When passed include_columns, reflecttable should filter out other columns""" 81 | 82 | one_row_complex = Table('one_row_complex', MetaData()) 83 | reflect_table(engine, connection, one_row_complex, include_columns=['int'], 84 | exclude_columns=[], resolve_fks=True) 85 | 86 | self.assertEqual(len(one_row_complex.c), 1) 87 | self.assertIsNotNone(one_row_complex.c.int) 88 | self.assertRaises(AttributeError, lambda: one_row_complex.c.tinyint) 89 | 90 | @with_engine_connection 91 | def test_reflect_with_schema(self, engine, connection): 92 | dummy = Table('dummy_table', MetaData(schema='pyhive_test_database'), autoload_with=engine) 93 | self.assertEqual(len(dummy.c), 1) 94 | self.assertIsNotNone(dummy.c.a) 95 | 96 | @pytest.mark.filterwarnings('default:Omitting:sqlalchemy.exc.SAWarning') 97 | @with_engine_connection 98 | def test_reflect_partitions(self, engine, connection): 99 | """reflecttable should get the partition column as an index""" 100 | many_rows = Table('many_rows', MetaData(), autoload_with=engine) 101 | self.assertEqual(len(many_rows.c), 2) 102 | self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)})) 103 | 104 | many_rows = Table('many_rows', MetaData()) 105 | reflect_table(engine, connection, many_rows, include_columns=['a'], 106 | exclude_columns=[], resolve_fks=True) 107 | 108 | self.assertEqual(len(many_rows.c), 1) 109 | self.assertFalse(many_rows.c.a.index) 110 | self.assertFalse(many_rows.indexes) 111 | 112 | many_rows = Table('many_rows', MetaData()) 113 | reflect_table(engine, connection, many_rows, include_columns=['b'], 114 | exclude_columns=[], resolve_fks=True) 115 | 116 | self.assertEqual(len(many_rows.c), 1) 117 | self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)})) 118 | 119 | @with_engine_connection 120 | def test_unicode(self, engine, connection): 121 | """Verify that unicode strings 
make it through SQLAlchemy and the backend""" 122 | unicode_str = "中文" 123 | one_row = Table('one_row', MetaData()) 124 | 125 | if sqlalchemy_version >= 1.4: 126 | returned_str = connection.execute(sqlalchemy.select( 127 | expression.bindparam("好", unicode_str, type_=String())).select_from(one_row)).scalar() 128 | else: 129 | returned_str = connection.execute(sqlalchemy.select([ 130 | expression.bindparam("好", unicode_str, type_=String())]).select_from(one_row)).scalar() 131 | 132 | self.assertEqual(returned_str, unicode_str) 133 | 134 | @with_engine_connection 135 | def test_reflect_schemas(self, engine, connection): 136 | insp = sqlalchemy.inspect(engine) 137 | schemas = insp.get_schema_names() 138 | self.assertIn('pyhive_test_database', schemas) 139 | self.assertIn('default', schemas) 140 | 141 | @with_engine_connection 142 | def test_get_table_names(self, engine, connection): 143 | meta = MetaData() 144 | meta.reflect(bind=engine) 145 | self.assertIn('one_row', meta.tables) 146 | self.assertIn('one_row_complex', meta.tables) 147 | 148 | insp = sqlalchemy.inspect(engine) 149 | self.assertIn( 150 | 'dummy_table', 151 | insp.get_table_names(schema='pyhive_test_database'), 152 | ) 153 | 154 | @with_engine_connection 155 | def test_has_table(self, engine, connection): 156 | if sqlalchemy_version >= 1.4: 157 | insp = sqlalchemy.inspect(engine) 158 | self.assertTrue(insp.has_table("one_row")) 159 | self.assertFalse(insp.has_table("this_table_does_not_exist")) 160 | else: 161 | self.assertTrue(Table('one_row', MetaData(bind=engine)).exists()) 162 | self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists()) 163 | 164 | @with_engine_connection 165 | def test_char_length(self, engine, connection): 166 | one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) 167 | 168 | if sqlalchemy_version >= 1.4: 169 | result = connection.execute(sqlalchemy.select(sqlalchemy.func.char_length(one_row_complex.c.string))).scalar() 170 | else: 
            # --- tail of SqlAlchemyTestCase.test_char_length (pre-1.4 branch; method opens above) ---
            result = connection.execute(sqlalchemy.select([sqlalchemy.func.char_length(one_row_complex.c.string)])).scalar()

        self.assertEqual(result, len('a string'))


# encoding: utf-8
from __future__ import absolute_import
from __future__ import unicode_literals
from pyhive import common
import datetime
import unittest


class TestCommon(unittest.TestCase):
    """Unit tests for pyhive.common parameter escaping."""

    def test_escape_args(self):
        """ParamEscaper quotes strings, passes numbers through, renders
        sequences/sets as parenthesized tuples, and formats dates/datetimes
        as quoted ISO strings — for both dict and sequence parameter styles."""
        escaper = common.ParamEscaper()
        self.assertEqual(escaper.escape_args({'foo': 'bar'}),
                         {'foo': "'bar'"})
        self.assertEqual(escaper.escape_args({'foo': 123}),
                         {'foo': 123})
        self.assertEqual(escaper.escape_args({'foo': 123.456}),
                         {'foo': 123.456})
        self.assertEqual(escaper.escape_args({'foo': ['a', 'b', 'c']}),
                         {'foo': "('a','b','c')"})
        self.assertEqual(escaper.escape_args({'foo': ('a', 'b', 'c')}),
                         {'foo': "('a','b','c')"})
        # Sets are unordered, so either rendering is acceptable.
        self.assertIn(escaper.escape_args({'foo': {'a', 'b'}}),
                      ({'foo': "('a','b')"}, {'foo': "('b','a')"}))
        self.assertIn(escaper.escape_args({'foo': frozenset(['a', 'b'])}),
                      ({'foo': "('a','b')"}, {'foo': "('b','a')"}))

        self.assertEqual(escaper.escape_args(('bar',)),
                         ("'bar'",))
        self.assertEqual(escaper.escape_args([123]),
                         (123,))
        self.assertEqual(escaper.escape_args((123.456,)),
                         (123.456,))
        self.assertEqual(escaper.escape_args((['a', 'b', 'c'],)),
                         ("('a','b','c')",))
        self.assertEqual(escaper.escape_args((['你好', 'b', 'c'],)),
                         ("('你好','b','c')",))
        self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)),
                         ("'2020-04-17'",))
        self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)),
                         ("'2020-04-17 12:00:00.123456'",))
-------------------------------------------------------------------------------- /pyhive/tests/test_hive.py: -------------------------------------------------------------------------------- 1 | """Hive integration tests. 2 | 3 | These rely on having a Hive+Hadoop cluster set up with HiveServer2 running. 4 | They also require a tables created by make_test_tables.sh. 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import unicode_literals 9 | 10 | import contextlib 11 | import datetime 12 | import os 13 | import socket 14 | import subprocess 15 | import time 16 | import unittest 17 | from decimal import Decimal 18 | 19 | import mock 20 | import thrift.transport.TSocket 21 | import thrift.transport.TTransport 22 | import thrift_sasl 23 | from thrift.transport.TTransport import TTransportException 24 | 25 | from TCLIService import ttypes 26 | from pyhive import hive 27 | from pyhive.tests.dbapi_test_case import DBAPITestCase 28 | from pyhive.tests.dbapi_test_case import with_cursor 29 | 30 | _HOST = 'localhost' 31 | 32 | 33 | class TestHive(unittest.TestCase, DBAPITestCase): 34 | __test__ = True 35 | 36 | def connect(self): 37 | return hive.connect(host=_HOST, configuration={'mapred.job.tracker': 'local'}) 38 | 39 | @with_cursor 40 | def test_description(self, cursor): 41 | cursor.execute('SELECT * FROM one_row') 42 | 43 | desc = [('one_row.number_of_rows', 'INT_TYPE', None, None, None, None, True)] 44 | self.assertEqual(cursor.description, desc) 45 | 46 | @with_cursor 47 | def test_complex(self, cursor): 48 | cursor.execute('SELECT * FROM one_row_complex') 49 | self.assertEqual(cursor.description, [ 50 | ('one_row_complex.boolean', 'BOOLEAN_TYPE', None, None, None, None, True), 51 | ('one_row_complex.tinyint', 'TINYINT_TYPE', None, None, None, None, True), 52 | ('one_row_complex.smallint', 'SMALLINT_TYPE', None, None, None, None, True), 53 | ('one_row_complex.int', 'INT_TYPE', None, None, None, None, True), 54 | ('one_row_complex.bigint', 
'BIGINT_TYPE', None, None, None, None, True), 55 | ('one_row_complex.float', 'FLOAT_TYPE', None, None, None, None, True), 56 | ('one_row_complex.double', 'DOUBLE_TYPE', None, None, None, None, True), 57 | ('one_row_complex.string', 'STRING_TYPE', None, None, None, None, True), 58 | ('one_row_complex.timestamp', 'TIMESTAMP_TYPE', None, None, None, None, True), 59 | ('one_row_complex.binary', 'BINARY_TYPE', None, None, None, None, True), 60 | ('one_row_complex.array', 'ARRAY_TYPE', None, None, None, None, True), 61 | ('one_row_complex.map', 'MAP_TYPE', None, None, None, None, True), 62 | ('one_row_complex.struct', 'STRUCT_TYPE', None, None, None, None, True), 63 | ('one_row_complex.union', 'UNION_TYPE', None, None, None, None, True), 64 | ('one_row_complex.decimal', 'DECIMAL_TYPE', None, None, None, None, True), 65 | ]) 66 | rows = cursor.fetchall() 67 | expected = [( 68 | True, 69 | 127, 70 | 32767, 71 | 2147483647, 72 | 9223372036854775807, 73 | 0.5, 74 | 0.25, 75 | 'a string', 76 | datetime.datetime(1970, 1, 1, 0, 0), 77 | b'123', 78 | '[1,2]', 79 | '{1:2,3:4}', 80 | '{"a":1,"b":2}', 81 | '{0:1}', 82 | Decimal('0.1'), 83 | )] 84 | self.assertEqual(rows, expected) 85 | # catch unicode/str 86 | self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) 87 | 88 | @with_cursor 89 | def test_async(self, cursor): 90 | cursor.execute('SELECT * FROM one_row', async_=True) 91 | unfinished_states = ( 92 | ttypes.TOperationState.INITIALIZED_STATE, 93 | ttypes.TOperationState.RUNNING_STATE, 94 | ) 95 | while cursor.poll().operationState in unfinished_states: 96 | cursor.fetch_logs() 97 | assert cursor.poll().operationState == ttypes.TOperationState.FINISHED_STATE 98 | 99 | self.assertEqual(len(cursor.fetchall()), 1) 100 | 101 | @with_cursor 102 | def test_cancel(self, cursor): 103 | # Need to do a JOIN to force a MR job. Without it, Hive optimizes the query to a fetch 104 | # operator and prematurely declares the query done. 
105 | cursor.execute( 106 | "SELECT reflect('java.lang.Thread', 'sleep', 1000L * 1000L * 1000L) " 107 | "FROM one_row a JOIN one_row b", 108 | async_=True 109 | ) 110 | self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.RUNNING_STATE) 111 | assert any('Stage' in line for line in cursor.fetch_logs()) 112 | cursor.cancel() 113 | self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.CANCELED_STATE) 114 | 115 | def test_noops(self): 116 | """The DB-API specification requires that certain actions exist, even though they might not 117 | be applicable.""" 118 | # Wohoo inflating coverage stats! 119 | with contextlib.closing(self.connect()) as connection: 120 | with contextlib.closing(connection.cursor()) as cursor: 121 | self.assertEqual(cursor.rowcount, -1) 122 | cursor.setinputsizes([]) 123 | cursor.setoutputsize(1, 'blah') 124 | connection.commit() 125 | 126 | @mock.patch('TCLIService.TCLIService.Client.OpenSession') 127 | def test_open_failed(self, open_session): 128 | open_session.return_value.serverProtocolVersion = \ 129 | ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1 130 | self.assertRaises(hive.OperationalError, self.connect) 131 | 132 | def test_escape(self): 133 | # Hive thrift translates newlines into multiple rows. WTF. 
134 | bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t ''' 135 | self.run_escape_case(bad_str) 136 | 137 | def test_newlines(self): 138 | """Verify that newlines are passed through correctly""" 139 | cursor = self.connect().cursor() 140 | orig = ' \r\n \r \n ' 141 | cursor.execute( 142 | 'SELECT %s FROM one_row', 143 | (orig,) 144 | ) 145 | result = cursor.fetchall() 146 | self.assertEqual(result, [(orig,)]) 147 | 148 | @with_cursor 149 | def test_no_result_set(self, cursor): 150 | cursor.execute('USE default') 151 | self.assertIsNone(cursor.description) 152 | self.assertRaises(hive.ProgrammingError, cursor.fetchone) 153 | 154 | def test_ldap_connection(self): 155 | rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) 156 | orig_ldap = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site-ldap.xml') 157 | orig_none = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site.xml') 158 | des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml') 159 | try: 160 | subprocess.check_call(['sudo', 'cp', orig_ldap, des]) 161 | _restart_hs2() 162 | with contextlib.closing(hive.connect( 163 | host=_HOST, username='existing', auth='LDAP', password='testpw') 164 | ) as connection: 165 | with contextlib.closing(connection.cursor()) as cursor: 166 | cursor.execute('SELECT * FROM one_row') 167 | self.assertEqual(cursor.fetchall(), [(1,)]) 168 | 169 | self.assertRaisesRegexp( 170 | TTransportException, 'Error validating the login', 171 | lambda: hive.connect( 172 | host=_HOST, username='existing', auth='LDAP', password='wrong') 173 | ) 174 | 175 | finally: 176 | subprocess.check_call(['sudo', 'cp', orig_none, des]) 177 | _restart_hs2() 178 | 179 | def test_invalid_ldap_config(self): 180 | """password should be set if and only if using LDAP""" 181 | self.assertRaisesRegexp(ValueError, 'Password.*LDAP', 182 | lambda: hive.connect(_HOST, password='')) 183 | self.assertRaisesRegexp(ValueError, 'Password.*LDAP', 184 | lambda: 
hive.connect(_HOST, auth='LDAP')) 185 | 186 | def test_invalid_kerberos_config(self): 187 | """kerberos_service_name should be set if and only if using KERBEROS""" 188 | self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS', 189 | lambda: hive.connect(_HOST, kerberos_service_name='')) 190 | self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS', 191 | lambda: hive.connect(_HOST, auth='KERBEROS')) 192 | 193 | def test_invalid_transport(self): 194 | """transport and auth are incompatible""" 195 | socket = thrift.transport.TSocket.TSocket('localhost', 10000) 196 | transport = thrift.transport.TTransport.TBufferedTransport(socket) 197 | self.assertRaisesRegexp( 198 | ValueError, 'thrift_transport cannot be used with', 199 | lambda: hive.connect(_HOST, thrift_transport=transport) 200 | ) 201 | 202 | def test_custom_transport(self): 203 | socket = thrift.transport.TSocket.TSocket('localhost', 10000) 204 | sasl_auth = 'PLAIN' 205 | 206 | transport = thrift_sasl.TSaslClientTransport(lambda: hive.get_installed_sasl(host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'), sasl_auth, socket) 207 | conn = hive.connect(thrift_transport=transport) 208 | with contextlib.closing(conn): 209 | with contextlib.closing(conn.cursor()) as cursor: 210 | cursor.execute('SELECT * FROM one_row') 211 | self.assertEqual(cursor.fetchall(), [(1,)]) 212 | 213 | def test_custom_connection(self): 214 | rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) 215 | orig_ldap = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site-custom.xml') 216 | orig_none = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site.xml') 217 | des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml') 218 | try: 219 | subprocess.check_call(['sudo', 'cp', orig_ldap, des]) 220 | _restart_hs2() 221 | with contextlib.closing(hive.connect( 222 | host=_HOST, username='the-user', auth='CUSTOM', password='p4ssw0rd') 223 | ) as 
connection: 224 | with contextlib.closing(connection.cursor()) as cursor: 225 | cursor.execute('SELECT * FROM one_row') 226 | self.assertEqual(cursor.fetchall(), [(1,)]) 227 | 228 | self.assertRaisesRegexp( 229 | TTransportException, 'Error validating the login', 230 | lambda: hive.connect( 231 | host=_HOST, username='the-user', auth='CUSTOM', password='wrong') 232 | ) 233 | 234 | finally: 235 | subprocess.check_call(['sudo', 'cp', orig_none, des]) 236 | _restart_hs2() 237 | 238 | 239 | def _restart_hs2(): 240 | subprocess.check_call(['sudo', 'service', 'hive-server2', 'restart']) 241 | with contextlib.closing(socket.socket()) as s: 242 | while s.connect_ex(('localhost', 10000)) != 0: 243 | time.sleep(1) 244 | -------------------------------------------------------------------------------- /pyhive/tests/test_presto.py: -------------------------------------------------------------------------------- 1 | """Presto integration tests. 2 | 3 | These rely on having a Presto+Hadoop cluster set up. 4 | They also require a tables created by make_test_tables.sh. 
5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import unicode_literals 9 | 10 | import contextlib 11 | import os 12 | from decimal import Decimal 13 | 14 | import requests 15 | 16 | from pyhive import exc 17 | from pyhive import presto 18 | from pyhive.tests.dbapi_test_case import DBAPITestCase 19 | from pyhive.tests.dbapi_test_case import with_cursor 20 | import mock 21 | import unittest 22 | import datetime 23 | 24 | _HOST = 'localhost' 25 | _PORT = '8080' 26 | 27 | 28 | class TestPresto(unittest.TestCase, DBAPITestCase): 29 | __test__ = True 30 | 31 | def connect(self): 32 | return presto.connect(host=_HOST, port=_PORT, source=self.id()) 33 | 34 | def test_bad_protocol(self): 35 | self.assertRaisesRegexp(ValueError, 'Protocol must be', 36 | lambda: presto.connect('localhost', protocol='nonsense').cursor()) 37 | 38 | def test_escape_args(self): 39 | escaper = presto.PrestoParamEscaper() 40 | 41 | self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)), 42 | ("date '2020-04-17'",)) 43 | self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), 44 | ("timestamp '2020-04-17 12:00:00.123'",)) 45 | 46 | @with_cursor 47 | def test_description(self, cursor): 48 | cursor.execute('SELECT 1 AS foobar FROM one_row') 49 | self.assertEqual(cursor.description, [('foobar', 'integer', None, None, None, None, True)]) 50 | self.assertIsNotNone(cursor.last_query_id) 51 | 52 | @with_cursor 53 | def test_complex(self, cursor): 54 | cursor.execute('SELECT * FROM one_row_complex') 55 | # TODO Presto drops the union field 56 | if os.environ.get('PRESTO') == '0.147': 57 | tinyint_type = 'integer' 58 | smallint_type = 'integer' 59 | float_type = 'double' 60 | else: 61 | # some later version made these map to more specific types 62 | tinyint_type = 'tinyint' 63 | smallint_type = 'smallint' 64 | float_type = 'real' 65 | self.assertEqual(cursor.description, [ 66 | ('boolean', 'boolean', None, None, None, None, True), 67 
| ('tinyint', tinyint_type, None, None, None, None, True), 68 | ('smallint', smallint_type, None, None, None, None, True), 69 | ('int', 'integer', None, None, None, None, True), 70 | ('bigint', 'bigint', None, None, None, None, True), 71 | ('float', float_type, None, None, None, None, True), 72 | ('double', 'double', None, None, None, None, True), 73 | ('string', 'varchar', None, None, None, None, True), 74 | ('timestamp', 'timestamp', None, None, None, None, True), 75 | ('binary', 'varbinary', None, None, None, None, True), 76 | ('array', 'array(integer)', None, None, None, None, True), 77 | ('map', 'map(integer,integer)', None, None, None, None, True), 78 | ('struct', 'row(a integer,b integer)', None, None, None, None, True), 79 | # ('union', 'varchar', None, None, None, None, True), 80 | ('decimal', 'decimal(10,1)', None, None, None, None, True), 81 | ]) 82 | rows = cursor.fetchall() 83 | expected = [( 84 | True, 85 | 127, 86 | 32767, 87 | 2147483647, 88 | 9223372036854775807, 89 | 0.5, 90 | 0.25, 91 | 'a string', 92 | '1970-01-01 00:00:00.000', 93 | b'123', 94 | [1, 2], 95 | {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON 96 | [1, 2], # struct is returned as a list of elements 97 | # '{0:1}', 98 | Decimal('0.1'), 99 | )] 100 | self.assertEqual(rows, expected) 101 | # catch unicode/str 102 | self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) 103 | 104 | @with_cursor 105 | def test_cancel(self, cursor): 106 | cursor.execute( 107 | "SELECT a.a * rand(), b.a * rand()" 108 | "FROM many_rows a " 109 | "CROSS JOIN many_rows b " 110 | ) 111 | self.assertIn(cursor.poll()['stats']['state'], ( 112 | 'STARTING', 'PLANNING', 'RUNNING', 'WAITING_FOR_RESOURCES', 'QUEUED')) 113 | cursor.cancel() 114 | self.assertIsNotNone(cursor.last_query_id) 115 | self.assertIsNone(cursor.poll()) 116 | 117 | def test_noops(self): 118 | """The DB-API specification requires that certain actions exist, even though they might not 119 | be 
applicable.""" 120 | # Wohoo inflating coverage stats! 121 | connection = self.connect() 122 | cursor = connection.cursor() 123 | self.assertEqual(cursor.rowcount, -1) 124 | cursor.setinputsizes([]) 125 | cursor.setoutputsize(1, 'blah') 126 | self.assertIsNone(cursor.last_query_id) 127 | connection.commit() 128 | 129 | @mock.patch('requests.post') 130 | def test_non_200(self, post): 131 | cursor = self.connect().cursor() 132 | post.return_value.status_code = 404 133 | self.assertRaises(exc.OperationalError, lambda: cursor.execute('show tables')) 134 | 135 | @with_cursor 136 | def test_poll(self, cursor): 137 | self.assertRaises(presto.ProgrammingError, cursor.poll) 138 | 139 | cursor.execute('SELECT * FROM one_row') 140 | while True: 141 | status = cursor.poll() 142 | if status is None: 143 | break 144 | self.assertIn('stats', status) 145 | 146 | def fail(*args, **kwargs): 147 | self.fail("Should not need requests.get after done polling") # pragma: no cover 148 | 149 | with mock.patch('requests.get', fail): 150 | self.assertEqual(cursor.fetchall(), [(1,)]) 151 | 152 | @with_cursor 153 | def test_set_session(self, cursor): 154 | id = None 155 | self.assertIsNone(cursor.last_query_id) 156 | cursor.execute("SET SESSION query_max_run_time = '1234m'") 157 | self.assertIsNotNone(cursor.last_query_id) 158 | id = cursor.last_query_id 159 | cursor.fetchall() 160 | self.assertEqual(id, cursor.last_query_id) 161 | 162 | cursor.execute('SHOW SESSION') 163 | self.assertIsNotNone(cursor.last_query_id) 164 | self.assertNotEqual(id, cursor.last_query_id) 165 | id = cursor.last_query_id 166 | rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] 167 | self.assertEqual(len(rows), 1) 168 | session_prop = rows[0] 169 | self.assertEqual(session_prop[1], '1234m') 170 | self.assertEqual(id, cursor.last_query_id) 171 | 172 | cursor.execute('RESET SESSION query_max_run_time') 173 | self.assertIsNotNone(cursor.last_query_id) 174 | self.assertNotEqual(id, 
cursor.last_query_id) 175 | id = cursor.last_query_id 176 | cursor.fetchall() 177 | self.assertEqual(id, cursor.last_query_id) 178 | 179 | cursor.execute('SHOW SESSION') 180 | self.assertIsNotNone(cursor.last_query_id) 181 | self.assertNotEqual(id, cursor.last_query_id) 182 | id = cursor.last_query_id 183 | rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] 184 | self.assertEqual(len(rows), 1) 185 | session_prop = rows[0] 186 | self.assertNotEqual(session_prop[1], '1234m') 187 | self.assertEqual(id, cursor.last_query_id) 188 | 189 | def test_set_session_in_constructor(self): 190 | conn = presto.connect( 191 | host=_HOST, source=self.id(), session_props={'query_max_run_time': '1234m'} 192 | ) 193 | with contextlib.closing(conn): 194 | with contextlib.closing(conn.cursor()) as cursor: 195 | cursor.execute('SHOW SESSION') 196 | rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] 197 | assert len(rows) == 1 198 | session_prop = rows[0] 199 | assert session_prop[1] == '1234m' 200 | 201 | cursor.execute('RESET SESSION query_max_run_time') 202 | cursor.fetchall() 203 | 204 | cursor.execute('SHOW SESSION') 205 | rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] 206 | assert len(rows) == 1 207 | session_prop = rows[0] 208 | assert session_prop[1] != '1234m' 209 | 210 | def test_invalid_protocol_config(self): 211 | """protocol should be https when passing password""" 212 | self.assertRaisesRegexp( 213 | ValueError, 'Protocol.*https.*password', lambda: presto.connect( 214 | host=_HOST, username='user', password='secret', protocol='http').cursor() 215 | ) 216 | 217 | def test_invalid_password_and_kwargs(self): 218 | """password and requests_kwargs authentication are incompatible""" 219 | self.assertRaisesRegexp( 220 | ValueError, 'Cannot use both', lambda: presto.connect( 221 | host=_HOST, username='user', password='secret', protocol='https', 222 | requests_kwargs={'auth': requests.auth.HTTPBasicAuth('user', 'secret')} 
223 | ).cursor() 224 | ) 225 | 226 | def test_invalid_kwargs(self): 227 | """some kwargs are reserved""" 228 | self.assertRaisesRegexp( 229 | ValueError, 'Cannot override', lambda: presto.connect( 230 | host=_HOST, username='user', requests_kwargs={'url': 'test'} 231 | ).cursor() 232 | ) 233 | 234 | def test_requests_kwargs(self): 235 | connection = presto.connect( 236 | host=_HOST, port=_PORT, source=self.id(), 237 | requests_kwargs={'proxies': {'http': 'localhost:9999'}}, 238 | ) 239 | cursor = connection.cursor() 240 | self.assertRaises(requests.exceptions.ProxyError, 241 | lambda: cursor.execute('SELECT * FROM one_row')) 242 | 243 | def test_requests_session(self): 244 | with requests.Session() as session: 245 | connection = presto.connect( 246 | host=_HOST, port=_PORT, source=self.id(), requests_session=session 247 | ) 248 | cursor = connection.cursor() 249 | cursor.execute('SELECT * FROM one_row') 250 | self.assertEqual(cursor.fetchall(), [(1,)]) 251 | -------------------------------------------------------------------------------- /pyhive/tests/test_sasl_compat.py: -------------------------------------------------------------------------------- 1 | ''' 2 | http://www.opensource.org/licenses/mit-license.php 3 | 4 | Copyright 2007-2011 David Alan Cridland 5 | Copyright 2011 Lance Stout 6 | Copyright 2012 Tyler L Hobbs 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy of this 9 | software and associated documentation files (the "Software"), to deal in the Software 10 | without restriction, including without limitation the rights to use, copy, modify, merge, 11 | publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons 12 | to whom the Software is furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all copies or 15 | substantial portions of the Software. 
# Continuation of the MIT license docstring of pyhive/tests/test_sasl_compat.py
# (the first half appears earlier in this dump), preserved as comments:
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
# '''

# This file was generated by referring test cases from the pure-sasl repo i.e.
# https://github.com/thobbs/pure-sasl/tree/master/tests/unit and by refactoring
# them to cover wrapper functions in sasl_compat.py along with added coverage
# for functions exclusive to sasl_compat.py.

import unittest
import base64
import hashlib
import hmac
import kerberos
from mock import patch
import six
import struct
from puresasl import SASLProtocolException, QOP
from puresasl.client import SASLError
from pyhive.sasl_compat import PureSASLClient, error_catcher


class TestPureSASLClient(unittest.TestCase):
    """Test cases for initialization of SASL client using PureSASLClient class"""

    def setUp(self):
        self.sasl_kwargs = {}
        self.sasl = PureSASLClient('localhost', **self.sasl_kwargs)

    def test_start_no_mechanism(self):
        """Test starting SASL authentication with no mechanism."""
        success, mechanism, response = self.sasl.start(mechanism=None)
        self.assertFalse(success)
        self.assertIsNone(mechanism)
        self.assertIsNone(response)
        self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')

    def test_start_wrong_mechanism(self):
        """Test starting SASL authentication with a single unsupported mechanism."""
        success, mechanism, response = self.sasl.start(mechanism='WRONG')
        self.assertFalse(success)
        self.assertEqual(mechanism, 'WRONG')
        self.assertIsNone(response)
        self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')

    def test_start_list_of_invalid_mechanisms(self):
        """Test starting SASL authentication with a list of unsupported mechanisms."""
        self.sasl.start(['invalid1', 'invalid2'])
        self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')

    def test_start_list_of_valid_mechanisms(self):
        """Test starting SASL authentication with a list of supported mechanisms."""
        self.sasl.start(['PLAIN', 'DIGEST-MD5', 'CRAM-MD5'])
        # Validate right mechanism is chosen based on score.
        self.assertEqual(self.sasl._chosen_mech.name, 'DIGEST-MD5')

    def test_error_catcher_no_error(self):
        """Test the error_catcher with no error."""
        with error_catcher(self.sasl):
            result, _, _ = self.sasl.start(mechanism='ANONYMOUS')

        self.assertEqual(self.sasl.getError(), None)
        self.assertEqual(result, True)

    def test_error_catcher_with_error(self):
        """Test the error_catcher with an error."""
        with error_catcher(self.sasl):
            result, _, _ = self.sasl.start(mechanism='WRONG')

        self.assertEqual(result, False)
        self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')


# Assuming client initialization went well and a mechanism is chosen, below are
# the test cases for the different mechanisms.
# (Was a stray module-level string with typos; converted to a comment since it
# was never a docstring.)

class _BaseMechanismTests(unittest.TestCase):
    """Base test case for SASL mechanisms."""

    mechanism = 'ANONYMOUS'
    sasl_kwargs = {}

    def setUp(self):
        self.sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs)
        self.mechanism_class = self.sasl._chosen_mech

    def test_init_basic(self, *args):
        """The chosen mechanism keeps a back-reference to its client."""
        sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs)
        mech = sasl._chosen_mech
        self.assertIs(mech.sasl, sasl)

    def test_step_basic(self, *args):
        """A single step with an arbitrary challenge succeeds and returns bytes."""
        success, response = self.sasl.step(six.b('string'))
        self.assertTrue(success)
        self.assertIsInstance(response, six.binary_type)

    def test_decode_encode(self, *args):
        """Default mechanisms refuse encode/decode before negotiation completes."""
        self.assertEqual(self.sasl.encode('msg'), (False, None))
        self.assertEqual(self.sasl.getError(), '')
        self.assertEqual(self.sasl.decode('msg'), (False, None))
        self.assertEqual(self.sasl.getError(), '')


class AnonymousMechanismTest(_BaseMechanismTests):
    """Test case for the Anonymous SASL mechanism."""

    mechanism = 'ANONYMOUS'


class PlainTextMechanismTest(_BaseMechanismTests):
    """Test case for the PlainText SASL mechanism."""

    mechanism = 'PLAIN'
    username = 'user'
    password = 'pass'
    sasl_kwargs = {'username': username, 'password': password}

    def test_step(self):
        """PLAIN responds \\x00user\\x00pass regardless of the challenge."""
        for challenge in (None, '', b'asdf', u"\U0001F44D"):
            success, response = self.sasl.step(challenge)
            self.assertTrue(success)
            self.assertEqual(response, six.b(f'\x00{self.username}\x00{self.password}'))
            self.assertIsInstance(response, six.binary_type)

    def test_step_with_authorization_id_or_identity(self):
        challenge = u"\U0001F44D"
        identity = 'user2'

        # Test that we can pass an identity
        sasl_kwargs = self.sasl_kwargs.copy()
        sasl_kwargs.update({'identity': identity})
        sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs)
        success, response = sasl.step(challenge)
        self.assertTrue(success)
        self.assertEqual(response, six.b(f'{identity}\x00{self.username}\x00{self.password}'))
        self.assertIsInstance(response, six.binary_type)
        self.assertTrue(sasl.complete)

        # Test that the sasl authorization_id has priority over identity
        auth_id = 'user3'
        sasl_kwargs.update({'authorization_id': auth_id})
        sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs)
        success, response = sasl.step(challenge)
        self.assertTrue(success)
        self.assertEqual(response, six.b(f'{auth_id}\x00{self.username}\x00{self.password}'))
        self.assertIsInstance(response, six.binary_type)
        self.assertTrue(sasl.complete)

    def test_decode_encode(self):
        """PLAIN passes payloads through unchanged."""
        msg = 'msg'
        self.assertEqual(self.sasl.decode(msg), (True, msg))
        self.assertEqual(self.sasl.encode(msg), (True, msg))

# NOTE(review): in the original dump the tail of PlainTextMechanismTest (from
# `auth_id = 'user3'` onward) spilled into the next dump region; it has been
# consolidated into the class above.
# NOTE(review): this dump region began mid-statement (`auth_id = 'user3'`),
# continuing PlainTextMechanismTest from the previous region.  That fragment
# is preserved below as comments; the reconstructed classes follow.
#         auth_id = 'user3'
#         sasl_kwargs.update({'authorization_id': auth_id})
#         sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs)
#         success, response = sasl.step(challenge)
#         self.assertTrue(success)
#         self.assertEqual(response, six.b(f'{auth_id}\x00{self.username}\x00{self.password}'))
#         self.assertIsInstance(response, six.binary_type)
#         self.assertTrue(sasl.complete)
#
#     def test_decode_encode(self):
#         msg = 'msg'
#         self.assertEqual(self.sasl.decode(msg), (True, msg))
#         self.assertEqual(self.sasl.encode(msg), (True, msg))


class ExternalMechanismTest(_BaseMechanismTests):
    """Test case for the External SASL mechanisms"""

    mechanism = 'EXTERNAL'

    def test_step(self):
        """EXTERNAL always succeeds with an empty response."""
        self.assertEqual(self.sasl.step(), (True, b''))

    def test_decode_encode(self):
        """EXTERNAL passes payloads through unchanged."""
        msg = 'msg'
        self.assertEqual(self.sasl.decode(msg), (True, msg))
        self.assertEqual(self.sasl.encode(msg), (True, msg))


@patch('puresasl.mechanisms.kerberos.authGSSClientStep')
@patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=base64.b64encode(six.b('some\x00 response')))
class GSSAPIMechanismTest(_BaseMechanismTests):
    """Test case for the GSSAPI SASL mechanism."""

    mechanism = 'GSSAPI'
    service = 'GSSAPI'
    sasl_kwargs = {'service': service}

    @patch('puresasl.mechanisms.kerberos.authGSSClientWrap')
    @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap')
    def test_decode_encode(self, _inner1, _inner2, authGSSClientResponse, *args):
        # bypassing step setup by setting qop directly
        self.mechanism_class.qop = QOP.AUTH
        msg = b'msg'
        self.assertEqual(self.sasl.decode(msg), (True, msg))
        self.assertEqual(self.sasl.encode(msg), (True, msg))

        # Test behavior with integrity/confidentiality QOPs for Kerberos auth.
        for qop in (QOP.AUTH_INT, QOP.AUTH_CONF):
            self.mechanism_class.qop = qop
            with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=1):
                self.assertEqual(self.sasl.decode(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
                self.assertEqual(self.sasl.encode(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
            # indentation of this branch is ambiguous in the flattened dump;
            # both nestings behave identically since the inner patch overrides
            if qop == QOP.AUTH_CONF:
                with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=0):
                    self.assertEqual(self.sasl.encode(msg), (False, None))
                    self.assertEqual(self.sasl.getError(), 'Error: confidentiality requested, but not honored by the server.')

    def test_step_no_user(self, authGSSClientResponse, *args):
        msg = six.b('whatever')

        # no user
        self.assertEqual(self.sasl.step(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
        with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=''):
            self.assertEqual(self.sasl.step(msg), (True, six.b('')))

        username = 'username'
        # with user; this has to be last because it sets mechanism.user
        with patch('puresasl.mechanisms.kerberos.authGSSClientStep', return_value=kerberos.AUTH_GSS_COMPLETE):
            with patch('puresasl.mechanisms.kerberos.authGSSClientUserName', return_value=six.b(username)):
                self.assertEqual(self.sasl.step(msg), (True, six.b('')))
                self.assertEqual(self.mechanism_class.user, six.b(username))

    @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap')
    def test_step_qop(self, *args):
        self.mechanism_class._have_negotiated_details = True
        self.mechanism_class.user = 'user'
        msg = six.b('msg')
        self.assertEqual(self.sasl.step(msg), (False, None))
        self.assertEqual(self.sasl.getError(), 'Bad response from server')

        max_len = 100
        self.assertLess(max_len, self.sasl.max_buffer)
        for i, qop in QOP.bit_map.items():
            # server answer: one QOP flag byte followed by a 3-byte max length
            qop_size = struct.pack('!i', i << 24 | max_len)
            response = base64.b64encode(qop_size)
            with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=response):
                with patch('puresasl.mechanisms.kerberos.authGSSClientWrap') as authGSSClientWrap:
                    self.mechanism_class.complete = False
                    self.assertEqual(self.sasl.step(msg), (True, qop_size))
                    self.assertTrue(self.mechanism_class.complete)
                    self.assertEqual(self.mechanism_class.qop, qop)
                    self.assertEqual(self.mechanism_class.max_buffer, max_len)

                    args = authGSSClientWrap.call_args[0]
                    out_data = args[1]
                    out = base64.b64decode(out_data)
                    self.assertEqual(out[:4], qop_size)
                    self.assertEqual(out[4:], six.b(self.mechanism_class.user))


class CramMD5MechanismTest(_BaseMechanismTests):
    """Test case for the CRAM-MD5 SASL mechanism."""

    mechanism = 'CRAM-MD5'
    username = 'user'
    password = 'pass'
    sasl_kwargs = {'username': username, 'password': password}

    def test_step(self):
        success, response = self.sasl.step(None)
        self.assertTrue(success)
        self.assertIsNone(response)
        challenge = six.b('msg')
        hash = hmac.HMAC(key=six.b(self.password), digestmod=hashlib.md5)
        hash.update(challenge)
        success, response = self.sasl.step(challenge)
        self.assertTrue(success)
        self.assertIn(six.b(self.username), response)
        self.assertIn(six.b(hash.hexdigest()), response)
        self.assertIsInstance(response, six.binary_type)
        self.assertTrue(self.sasl.complete)

    def test_decode_encode(self):
        """CRAM-MD5 passes payloads through unchanged."""
        msg = 'msg'
        self.assertEqual(self.sasl.decode(msg), (True, msg))
        self.assertEqual(self.sasl.encode(msg), (True, msg))


class DigestMD5MechanismTest(_BaseMechanismTests):
    """Test case for the DIGEST-MD5 SASL mechanism."""

    mechanism = 'DIGEST-MD5'
    username = 'user'
    password = 'pass'
    sasl_kwargs = {'username': username, 'password': password}

    def test_decode_encode(self):
        """DIGEST-MD5 passes payloads through unchanged."""
        msg = 'msg'
        self.assertEqual(self.sasl.decode(msg), (True, msg))
        self.assertEqual(self.sasl.encode(msg), (True, msg))

    def test_step_basic(self, *args):
        # The generic base-class step test does not apply to DIGEST-MD5.
        pass

    def test_step(self):
        """Test a SASL step with dummy challenge for DIGEST-MD5 mechanism."""
        testChallenge = (
            b'nonce="rmD6R8aMYVWH+/ih9HGBr3xNGAR6o2DUxpKlgDz6gUQ=",r'
            b'ealm="example.org",qop="auth,auth-int,auth-conf",cipher="rc4-40,rc'
            b'4-56,rc4,des,3des",maxbuf=65536,charset=utf-8,algorithm=md5-sess'
        )
        result, response = self.sasl.step(testChallenge)
        self.assertTrue(result)
        self.assertIsNotNone(response)

    def test_step_server_answer(self):
        """Test a SASL step with a proper server answer for DIGEST-MD5 mechanism."""
        sasl_kwargs = {'username': "chris", 'password': "secret"}
        sasl = PureSASLClient('elwood.innosoft.com',
                              service="imap",
                              mechanism=self.mechanism,
                              mutual_auth=True,
                              **sasl_kwargs)
        testChallenge = (
            b'utf-8,username="chris",realm="elwood.innosoft.com",'
            b'nonce="OA6MG9tEQGm2hh",nc=00000001,cnonce="OA6MHXh6VqTrRk",'
            b'digest-uri="imap/elwood.innosoft.com",'
            b'response=d388dad90d4bbd760a152321f2143af7,qop=auth'
        )
        sasl.step(testChallenge)
        sasl._chosen_mech.cnonce = b"OA6MHXh6VqTrRk"

        serverResponse = (
            b'rspauth=ea40f60335c427b5527b84dbabcdfffd'
        )
        sasl.step(serverResponse)
        # assert that step chooses the only supported QOP for DIGEST-MD5
        # BUG FIX: the original asserted on self.sasl (the untouched fixture
        # client from setUp, whose qop is the default), not the local client
        # that actually performed this handshake — making the check vacuous.
        self.assertEqual(sasl.qop, QOP.AUTH)

# --- end of pyhive/tests/test_sasl_compat.py ---------------------------------
# The dump continues with pyhive/tests/test_sqlalchemy_hive.py; its first
# import lines were split across dump regions and the reconstructed file
# starts in the next region.
unicode_literals 3 | from builtins import str 4 | from pyhive.sqlalchemy_hive import HiveDate 5 | from pyhive.sqlalchemy_hive import HiveDecimal 6 | from pyhive.sqlalchemy_hive import HiveTimestamp 7 | from sqlalchemy.exc import NoSuchTableError, OperationalError 8 | from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase 9 | from pyhive.tests.sqlalchemy_test_case import with_engine_connection 10 | from sqlalchemy import types 11 | from sqlalchemy.engine import create_engine 12 | from sqlalchemy.schema import Column 13 | from sqlalchemy.schema import MetaData 14 | from sqlalchemy.schema import Table 15 | from sqlalchemy.sql import text 16 | import contextlib 17 | import datetime 18 | import decimal 19 | import sqlalchemy.types 20 | import unittest 21 | import re 22 | 23 | sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) 24 | 25 | _ONE_ROW_COMPLEX_CONTENTS = [ 26 | True, 27 | 127, 28 | 32767, 29 | 2147483647, 30 | 9223372036854775807, 31 | 0.5, 32 | 0.25, 33 | 'a string', 34 | datetime.datetime(1970, 1, 1), 35 | b'123', 36 | '[1,2]', 37 | '{1:2,3:4}', 38 | '{"a":1,"b":2}', 39 | '{0:1}', 40 | decimal.Decimal('0.1'), 41 | ] 42 | 43 | 44 | # [ 45 | # ('boolean', 'boolean', ''), 46 | # ('tinyint', 'tinyint', ''), 47 | # ('smallint', 'smallint', ''), 48 | # ('int', 'int', ''), 49 | # ('bigint', 'bigint', ''), 50 | # ('float', 'float', ''), 51 | # ('double', 'double', ''), 52 | # ('string', 'string', ''), 53 | # ('timestamp', 'timestamp', ''), 54 | # ('binary', 'binary', ''), 55 | # ('array', 'array', ''), 56 | # ('map', 'map', ''), 57 | # ('struct', 'struct', ''), 58 | # ('union', 'uniontype', ''), 59 | # ('decimal', 'decimal(10,1)', '') 60 | # ] 61 | 62 | 63 | class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase): 64 | def create_engine(self): 65 | return create_engine('hive://localhost:10000/default') 66 | 67 | @with_engine_connection 68 | def test_dotted_column_names(self, engine, connection): 69 | 
"""When Hive returns a dotted column name, both the non-dotted version should be available 70 | as an attribute, and the dotted version should remain available as a key. 71 | """ 72 | row = connection.execute(text('SELECT * FROM one_row')).fetchone() 73 | 74 | if sqlalchemy_version >= 1.4: 75 | row = row._mapping 76 | 77 | assert row.keys() == ['number_of_rows'] 78 | assert 'number_of_rows' in row 79 | assert row.number_of_rows == 1 80 | assert row['number_of_rows'] == 1 81 | assert getattr(row, 'one_row.number_of_rows') == 1 82 | assert row['one_row.number_of_rows'] == 1 83 | 84 | @with_engine_connection 85 | def test_dotted_column_names_raw(self, engine, connection): 86 | """When Hive returns a dotted column name, and raw mode is on, nothing should be modified. 87 | """ 88 | row = connection.execution_options(hive_raw_colnames=True).execute(text('SELECT * FROM one_row')).fetchone() 89 | 90 | if sqlalchemy_version >= 1.4: 91 | row = row._mapping 92 | 93 | assert row.keys() == ['one_row.number_of_rows'] 94 | assert 'number_of_rows' not in row 95 | assert getattr(row, 'one_row.number_of_rows') == 1 96 | assert row['one_row.number_of_rows'] == 1 97 | 98 | @with_engine_connection 99 | def test_reflect_no_such_table(self, engine, connection): 100 | """reflecttable should throw an exception on an invalid table""" 101 | self.assertRaises( 102 | NoSuchTableError, 103 | lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) 104 | self.assertRaises( 105 | OperationalError, 106 | lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine)) 107 | 108 | @with_engine_connection 109 | def test_reflect_select(self, engine, connection): 110 | """reflecttable should be able to fill in a table from the name""" 111 | one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) 112 | self.assertEqual(len(one_row_complex.c), 15) 113 | self.assertIsInstance(one_row_complex.c.string, Column) 114 | row = 
connection.execute(one_row_complex.select()).fetchone() 115 | self.assertEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS) 116 | 117 | # TODO some of these types could be filled in better 118 | self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean) 119 | self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer) 120 | self.assertIsInstance(one_row_complex.c.smallint.type, types.Integer) 121 | self.assertIsInstance(one_row_complex.c.int.type, types.Integer) 122 | self.assertIsInstance(one_row_complex.c.bigint.type, types.BigInteger) 123 | self.assertIsInstance(one_row_complex.c.float.type, types.Float) 124 | self.assertIsInstance(one_row_complex.c.double.type, types.Float) 125 | self.assertIsInstance(one_row_complex.c.string.type, types.String) 126 | self.assertIsInstance(one_row_complex.c.timestamp.type, HiveTimestamp) 127 | self.assertIsInstance(one_row_complex.c.binary.type, types.String) 128 | self.assertIsInstance(one_row_complex.c.array.type, types.String) 129 | self.assertIsInstance(one_row_complex.c.map.type, types.String) 130 | self.assertIsInstance(one_row_complex.c.struct.type, types.String) 131 | self.assertIsInstance(one_row_complex.c.union.type, types.String) 132 | self.assertIsInstance(one_row_complex.c.decimal.type, HiveDecimal) 133 | 134 | @with_engine_connection 135 | def test_type_map(self, engine, connection): 136 | """sqlalchemy should use the dbapi_type_map to infer types from raw queries""" 137 | row = connection.execute(text('SELECT * FROM one_row_complex')).fetchone() 138 | self.assertListEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS) 139 | 140 | @with_engine_connection 141 | def test_reserved_words(self, engine, connection): 142 | """Hive uses backticks""" 143 | # Use keywords for the table/column name 144 | fake_table = Table('select', MetaData(), Column('map', sqlalchemy.types.String)) 145 | query = str(fake_table.select().where(fake_table.c.map == 'a').compile(engine)) 146 | self.assertIn('`select`', query) 147 | 
self.assertIn('`map`', query) 148 | self.assertNotIn('"select"', query) 149 | self.assertNotIn('"map"', query) 150 | 151 | def test_switch_database(self): 152 | engine = create_engine('hive://localhost:10000/pyhive_test_database') 153 | try: 154 | with contextlib.closing(engine.connect()) as connection: 155 | self.assertIn( 156 | ('dummy_table',), 157 | connection.execute(text('SHOW TABLES')).fetchall() 158 | ) 159 | connection.execute(text('USE default')) 160 | self.assertIn( 161 | ('one_row',), 162 | connection.execute(text('SHOW TABLES')).fetchall() 163 | ) 164 | finally: 165 | engine.dispose() 166 | 167 | @with_engine_connection 168 | def test_lots_of_types(self, engine, connection): 169 | # Presto doesn't have raw CREATE TABLE support, so we ony test hive 170 | # take type list from sqlalchemy.types 171 | types = [ 172 | 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT', 173 | 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 174 | 'BOOLEAN', 'SMALLINT', 'DATE', 'TIME', 175 | 'String', 'Integer', 'SmallInteger', 176 | 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'LargeBinary', 177 | 'Boolean', 'Unicode', 'UnicodeText', 178 | ] 179 | cols = [] 180 | for i, t in enumerate(types): 181 | cols.append(Column(str(i), getattr(sqlalchemy.types, t))) 182 | cols.append(Column('hive_date', HiveDate)) 183 | cols.append(Column('hive_decimal', HiveDecimal)) 184 | cols.append(Column('hive_timestamp', HiveTimestamp)) 185 | table = Table('test_table', MetaData(schema='pyhive_test_database'), *cols,) 186 | table.drop(checkfirst=True, bind=connection) 187 | table.create(bind=connection) 188 | connection.execute(text('SET mapred.job.tracker=local')) 189 | connection.execute(text('USE pyhive_test_database')) 190 | big_number = 10 ** 10 - 1 191 | connection.execute(text(""" 192 | INSERT OVERWRITE TABLE test_table 193 | SELECT 194 | 1, "a", "a", "a", "a", "a", 0.1, 195 | 0.1, 0.1, 0, 0, "a", "a", 196 | false, 1, 0, 0, 197 | "a", 1, 1, 198 | 0.1, 0.1, 0, 0, 0, 
"a", 199 | false, "a", "a", 200 | 0, :big_number, 123 + 2000 201 | FROM default.one_row 202 | """), {"big_number": big_number}) 203 | row = connection.execute(text("select * from test_table")).fetchone() 204 | self.assertEqual(row.hive_date, datetime.datetime(1970, 1, 1, 0, 0)) 205 | self.assertEqual(row.hive_decimal, decimal.Decimal(big_number)) 206 | self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123000)) 207 | table.drop(bind=connection) 208 | 209 | @with_engine_connection 210 | def test_insert_select(self, engine, connection): 211 | one_row = Table('one_row', MetaData(), autoload_with=engine) 212 | table = Table('insert_test', MetaData(schema='pyhive_test_database'), 213 | Column('a', sqlalchemy.types.Integer)) 214 | table.drop(checkfirst=True, bind=connection) 215 | table.create(bind=connection) 216 | connection.execute(text('SET mapred.job.tracker=local')) 217 | # NOTE(jing) I'm stuck on a version of Hive without INSERT ... VALUES 218 | connection.execute(table.insert().from_select(['a'], one_row.select())) 219 | 220 | result = connection.execute(table.select()).fetchall() 221 | expected = [(1,)] 222 | self.assertEqual(result, expected) 223 | 224 | @with_engine_connection 225 | def test_insert_values(self, engine, connection): 226 | table = Table('insert_test', MetaData(schema='pyhive_test_database'), 227 | Column('a', sqlalchemy.types.Integer),) 228 | table.drop(checkfirst=True, bind=connection) 229 | table.create(bind=connection) 230 | connection.execute(table.insert().values([{'a': 1}, {'a': 2}])) 231 | 232 | result = connection.execute(table.select()).fetchall() 233 | expected = [(1,), (2,)] 234 | self.assertEqual(result, expected) 235 | 236 | @with_engine_connection 237 | def test_supports_san_rowcount(self, engine, connection): 238 | self.assertFalse(engine.dialect.supports_sane_rowcount_returning) 239 | -------------------------------------------------------------------------------- /pyhive/tests/test_sqlalchemy_presto.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import unicode_literals 3 | from builtins import str 4 | from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase 5 | from pyhive.tests.sqlalchemy_test_case import with_engine_connection 6 | from sqlalchemy import types 7 | from sqlalchemy.engine import create_engine 8 | from sqlalchemy.schema import Column 9 | from sqlalchemy.schema import MetaData 10 | from sqlalchemy.schema import Table 11 | from sqlalchemy.sql import text 12 | from sqlalchemy.types import String 13 | from decimal import Decimal 14 | 15 | import contextlib 16 | import unittest 17 | 18 | 19 | class TestSqlAlchemyPresto(unittest.TestCase, SqlAlchemyTestCase): 20 | def create_engine(self): 21 | return create_engine('presto://localhost:8080/hive/default?source={}'.format(self.id())) 22 | 23 | def test_bad_format(self): 24 | self.assertRaises( 25 | ValueError, 26 | lambda: create_engine('presto://localhost:8080/hive/default/what'), 27 | ) 28 | 29 | @with_engine_connection 30 | def test_reflect_select(self, engine, connection): 31 | """reflecttable should be able to fill in a table from the name""" 32 | one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) 33 | # Presto ignores the union column 34 | self.assertEqual(len(one_row_complex.c), 15 - 1) 35 | self.assertIsInstance(one_row_complex.c.string, Column) 36 | rows = connection.execute(one_row_complex.select()).fetchall() 37 | self.assertEqual(len(rows), 1) 38 | self.assertEqual(list(rows[0]), [ 39 | True, 40 | 127, 41 | 32767, 42 | 2147483647, 43 | 9223372036854775807, 44 | 0.5, 45 | 0.25, 46 | 'a string', 47 | '1970-01-01 00:00:00.000', 48 | b'123', 49 | [1, 2], 50 | {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON 51 | [1, 2], # struct is returned as a list of elements 52 | # '{0:1}', 53 | Decimal('0.1'), 54 | ]) 55 | 56 | # TODO some of 
these types could be filled in better 57 | self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean) 58 | self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer) 59 | self.assertIsInstance(one_row_complex.c.smallint.type, types.Integer) 60 | self.assertIsInstance(one_row_complex.c.int.type, types.Integer) 61 | self.assertIsInstance(one_row_complex.c.bigint.type, types.BigInteger) 62 | self.assertIsInstance(one_row_complex.c.float.type, types.Float) 63 | self.assertIsInstance(one_row_complex.c.double.type, types.Float) 64 | self.assertIsInstance(one_row_complex.c.string.type, String) 65 | self.assertIsInstance(one_row_complex.c.timestamp.type, types.TIMESTAMP) 66 | self.assertIsInstance(one_row_complex.c.binary.type, types.VARBINARY) 67 | self.assertIsInstance(one_row_complex.c.array.type, types.NullType) 68 | self.assertIsInstance(one_row_complex.c.map.type, types.NullType) 69 | self.assertIsInstance(one_row_complex.c.struct.type, types.NullType) 70 | self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType) 71 | 72 | def test_url_default(self): 73 | engine = create_engine('presto://localhost:8080/hive') 74 | try: 75 | with contextlib.closing(engine.connect()) as connection: 76 | self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1) 77 | finally: 78 | engine.dispose() 79 | 80 | @with_engine_connection 81 | def test_reserved_words(self, engine, connection): 82 | """Presto uses double quotes, not backticks""" 83 | # Use keywords for the table/column name 84 | fake_table = Table('select', MetaData(), Column('current_timestamp', String)) 85 | query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine)) 86 | self.assertIn('"select"', query) 87 | self.assertIn('"current_timestamp"', query) 88 | self.assertNotIn('`select`', query) 89 | self.assertNotIn('`current_timestamp`', query) 90 | -------------------------------------------------------------------------------- 
/pyhive/tests/test_sqlalchemy_trino.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.engine import create_engine 2 | from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase 3 | from pyhive.tests.sqlalchemy_test_case import with_engine_connection 4 | from sqlalchemy.exc import NoSuchTableError, DatabaseError 5 | from sqlalchemy.schema import MetaData, Table, Column 6 | from sqlalchemy.types import String 7 | from sqlalchemy.sql import text 8 | from sqlalchemy import types 9 | from decimal import Decimal 10 | 11 | import unittest 12 | import contextlib 13 | 14 | 15 | class TestSqlAlchemyTrino(unittest.TestCase, SqlAlchemyTestCase): 16 | def create_engine(self): 17 | return create_engine('trino+pyhive://localhost:18080/hive/default?source={}'.format(self.id())) 18 | 19 | def test_bad_format(self): 20 | self.assertRaises( 21 | ValueError, 22 | lambda: create_engine('trino+pyhive://localhost:18080/hive/default/what'), 23 | ) 24 | 25 | @with_engine_connection 26 | def test_reflect_select(self, engine, connection): 27 | """reflecttable should be able to fill in a table from the name""" 28 | one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) 29 | # Presto ignores the union column 30 | self.assertEqual(len(one_row_complex.c), 15 - 1) 31 | self.assertIsInstance(one_row_complex.c.string, Column) 32 | rows = connection.execute(one_row_complex.select()).fetchall() 33 | self.assertEqual(len(rows), 1) 34 | self.assertEqual(list(rows[0]), [ 35 | True, 36 | 127, 37 | 32767, 38 | 2147483647, 39 | 9223372036854775807, 40 | 0.5, 41 | 0.25, 42 | 'a string', 43 | '1970-01-01 00:00:00.000', 44 | b'123', 45 | [1, 2], 46 | {"1": 2, "3": 4}, 47 | [1, 2], 48 | Decimal('0.1'), 49 | ]) 50 | 51 | self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean) 52 | self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer) 53 | self.assertIsInstance(one_row_complex.c.smallint.type, 
types.Integer) 54 | self.assertIsInstance(one_row_complex.c.int.type, types.Integer) 55 | self.assertIsInstance(one_row_complex.c.bigint.type, types.BigInteger) 56 | self.assertIsInstance(one_row_complex.c.float.type, types.Float) 57 | self.assertIsInstance(one_row_complex.c.double.type, types.Float) 58 | self.assertIsInstance(one_row_complex.c.string.type, String) 59 | self.assertIsInstance(one_row_complex.c.timestamp.type, types.NullType) 60 | self.assertIsInstance(one_row_complex.c.binary.type, types.VARBINARY) 61 | self.assertIsInstance(one_row_complex.c.array.type, types.NullType) 62 | self.assertIsInstance(one_row_complex.c.map.type, types.NullType) 63 | self.assertIsInstance(one_row_complex.c.struct.type, types.NullType) 64 | self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType) 65 | 66 | @with_engine_connection 67 | def test_reflect_no_such_table(self, engine, connection): 68 | """reflecttable should throw an exception on an invalid table""" 69 | self.assertRaises( 70 | NoSuchTableError, 71 | lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) 72 | self.assertRaises( 73 | DatabaseError, 74 | lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine)) 75 | 76 | def test_url_default(self): 77 | engine = create_engine('trino+pyhive://localhost:18080/hive') 78 | try: 79 | with contextlib.closing(engine.connect()) as connection: 80 | self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1) 81 | finally: 82 | engine.dispose() 83 | 84 | @with_engine_connection 85 | def test_reserved_words(self, engine, connection): 86 | """Trino uses double quotes, not backticks""" 87 | # Use keywords for the table/column name 88 | fake_table = Table('select', MetaData(), Column('current_timestamp', String)) 89 | query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine)) 90 | self.assertIn('"select"', query) 91 | 
self.assertIn('"current_timestamp"', query) 92 | self.assertNotIn('`select`', query) 93 | self.assertNotIn('`current_timestamp`', query) 94 | -------------------------------------------------------------------------------- /pyhive/tests/test_trino.py: -------------------------------------------------------------------------------- 1 | """Trino integration tests. 2 | 3 | These rely on having a Trino+Hadoop cluster set up. 4 | They also require a tables created by make_test_tables.sh. 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import unicode_literals 9 | 10 | import contextlib 11 | import os 12 | from decimal import Decimal 13 | 14 | import requests 15 | 16 | from pyhive import exc 17 | from pyhive import trino 18 | from pyhive.tests.dbapi_test_case import DBAPITestCase 19 | from pyhive.tests.dbapi_test_case import with_cursor 20 | from pyhive.tests.test_presto import TestPresto 21 | import mock 22 | import unittest 23 | import datetime 24 | 25 | _HOST = 'localhost' 26 | _PORT = '18080' 27 | 28 | 29 | class TestTrino(TestPresto): 30 | __test__ = True 31 | 32 | def connect(self): 33 | return trino.connect(host=_HOST, port=_PORT, source=self.id()) 34 | 35 | def test_bad_protocol(self): 36 | self.assertRaisesRegexp(ValueError, 'Protocol must be', 37 | lambda: trino.connect('localhost', protocol='nonsense').cursor()) 38 | 39 | def test_escape_args(self): 40 | escaper = trino.TrinoParamEscaper() 41 | 42 | self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)), 43 | ("date '2020-04-17'",)) 44 | self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), 45 | ("timestamp '2020-04-17 12:00:00.123'",)) 46 | 47 | @with_cursor 48 | def test_description(self, cursor): 49 | cursor.execute('SELECT 1 AS foobar FROM one_row') 50 | self.assertEqual(cursor.description, [('foobar', 'integer', None, None, None, None, True)]) 51 | self.assertIsNotNone(cursor.last_query_id) 52 | 53 | @with_cursor 54 | def 
test_complex(self, cursor): 55 | cursor.execute('SELECT * FROM one_row_complex') 56 | # TODO Trino drops the union field 57 | 58 | tinyint_type = 'tinyint' 59 | smallint_type = 'smallint' 60 | float_type = 'real' 61 | self.assertEqual(cursor.description, [ 62 | ('boolean', 'boolean', None, None, None, None, True), 63 | ('tinyint', tinyint_type, None, None, None, None, True), 64 | ('smallint', smallint_type, None, None, None, None, True), 65 | ('int', 'integer', None, None, None, None, True), 66 | ('bigint', 'bigint', None, None, None, None, True), 67 | ('float', float_type, None, None, None, None, True), 68 | ('double', 'double', None, None, None, None, True), 69 | ('string', 'varchar', None, None, None, None, True), 70 | ('timestamp', 'timestamp', None, None, None, None, True), 71 | ('binary', 'varbinary', None, None, None, None, True), 72 | ('array', 'array(integer)', None, None, None, None, True), 73 | ('map', 'map(integer,integer)', None, None, None, None, True), 74 | ('struct', 'row(a integer,b integer)', None, None, None, None, True), 75 | # ('union', 'varchar', None, None, None, None, True), 76 | ('decimal', 'decimal(10,1)', None, None, None, None, True), 77 | ]) 78 | rows = cursor.fetchall() 79 | expected = [( 80 | True, 81 | 127, 82 | 32767, 83 | 2147483647, 84 | 9223372036854775807, 85 | 0.5, 86 | 0.25, 87 | 'a string', 88 | '1970-01-01 00:00:00.000', 89 | b'123', 90 | [1, 2], 91 | {"1": 2, "3": 4}, # Trino converts all keys to strings so that they're valid JSON 92 | [1, 2], # struct is returned as a list of elements 93 | # '{0:1}', 94 | Decimal('0.1'), 95 | )] 96 | self.assertEqual(rows, expected) 97 | # catch unicode/str 98 | self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) -------------------------------------------------------------------------------- /pyhive/trino.py: -------------------------------------------------------------------------------- 1 | """DB-API implementation backed by Trino 2 | 3 | See 
http://www.python.org/dev/peps/pep-0249/ 4 | 5 | Many docstrings in this file are based on the PEP, which is in the public domain. 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import unicode_literals 10 | 11 | import logging 12 | 13 | import requests 14 | 15 | # Make all exceptions visible in this module per DB-API 16 | from pyhive.common import DBAPITypeObject 17 | from pyhive.exc import * # noqa 18 | from pyhive.presto import Connection as PrestoConnection, Cursor as PrestoCursor, PrestoParamEscaper 19 | 20 | try: # Python 3 21 | import urllib.parse as urlparse 22 | except ImportError: # Python 2 23 | import urlparse 24 | 25 | # PEP 249 module globals 26 | apilevel = '2.0' 27 | threadsafety = 2 # Threads may share the module and connections. 28 | paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s 29 | 30 | _logger = logging.getLogger(__name__) 31 | 32 | 33 | class TrinoParamEscaper(PrestoParamEscaper): 34 | pass 35 | 36 | 37 | _escaper = TrinoParamEscaper() 38 | 39 | 40 | def connect(*args, **kwargs): 41 | """Constructor for creating a connection to the database. See class :py:class:`Connection` for 42 | arguments. 43 | 44 | :returns: a :py:class:`Connection` object. 45 | """ 46 | return Connection(*args, **kwargs) 47 | 48 | 49 | class Connection(PrestoConnection): 50 | def __init__(self, *args, **kwargs): 51 | super().__init__(*args, **kwargs) 52 | 53 | def cursor(self): 54 | """Return a new :py:class:`Cursor` object using the connection.""" 55 | return Cursor(*self._args, **self._kwargs) 56 | 57 | 58 | class Cursor(PrestoCursor): 59 | """These objects represent a database cursor, which is used to manage the context of a fetch 60 | operation. 61 | 62 | Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately 63 | visible by other cursors or connections. 
64 | """ 65 | 66 | def execute(self, operation, parameters=None): 67 | """Prepare and execute a database operation (query or command). 68 | 69 | Return values are not defined. 70 | """ 71 | headers = { 72 | 'X-Trino-Catalog': self._catalog, 73 | 'X-Trino-Schema': self._schema, 74 | 'X-Trino-Source': self._source, 75 | 'X-Trino-User': self._username, 76 | } 77 | 78 | if self._session_props: 79 | headers['X-Trino-Session'] = ','.join( 80 | '{}={}'.format(propname, propval) 81 | for propname, propval in self._session_props.items() 82 | ) 83 | 84 | # Prepare statement 85 | if parameters is None: 86 | sql = operation 87 | else: 88 | sql = operation % _escaper.escape_args(parameters) 89 | 90 | self._reset_state() 91 | 92 | self._state = self._STATE_RUNNING 93 | url = urlparse.urlunparse(( 94 | self._protocol, 95 | '{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None)) 96 | _logger.info('%s', sql) 97 | _logger.debug("Headers: %s", headers) 98 | response = self._requests_session.post( 99 | url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs) 100 | self._process_response(response) 101 | 102 | def _process_response(self, response): 103 | """Given the JSON response from Trino's REST API, update the internal state with the next 104 | URI and any data from the response 105 | """ 106 | # TODO handle HTTP 503 107 | if response.status_code != requests.codes.ok: 108 | fmt = "Unexpected status code {}\n{}" 109 | raise OperationalError(fmt.format(response.status_code, response.content)) 110 | 111 | response_json = response.json() 112 | _logger.debug("Got response %s", response_json) 113 | assert self._state == self._STATE_RUNNING, "Should be running if processing response" 114 | self._nextUri = response_json.get('nextUri') 115 | self._columns = response_json.get('columns') 116 | if 'id' in response_json: 117 | self.last_query_id = response_json['id'] 118 | if 'X-Trino-Clear-Session' in response.headers: 119 | propname = 
response.headers['X-Trino-Clear-Session'] 120 | self._session_props.pop(propname, None) 121 | if 'X-Trino-Set-Session' in response.headers: 122 | propname, propval = response.headers['X-Trino-Set-Session'].split('=', 1) 123 | self._session_props[propname] = propval 124 | if 'data' in response_json: 125 | assert self._columns 126 | new_data = response_json['data'] 127 | self._process_data(new_data) 128 | self._data += map(tuple, new_data) 129 | if 'nextUri' not in response_json: 130 | self._state = self._STATE_FINISHED 131 | if 'error' in response_json: 132 | raise DatabaseError(response_json['error']) 133 | 134 | 135 | # 136 | # Type Objects and Constructors 137 | # 138 | 139 | 140 | # See types in Trino's core/trino-main module (package io.trino; formerly com.facebook.presto TupleInfo.java) 141 | FIXED_INT_64 = DBAPITypeObject(['bigint']) 142 | VARIABLE_BINARY = DBAPITypeObject(['varchar']) 143 | DOUBLE = DBAPITypeObject(['double']) 144 | BOOLEAN = DBAPITypeObject(['boolean']) 145 | -------------------------------------------------------------------------------- /scripts/ldap_config/slapd.conf: -------------------------------------------------------------------------------- 1 | # See slapd.conf(5) for details on configuration options. 
2 | include /etc/ldap/schema/core.schema 3 | include /etc/ldap/schema/cosine.schema 4 | include /etc/ldap/schema/inetorgperson.schema 5 | include /etc/ldap/schema/nis.schema 6 | 7 | pidfile /tmp/slapd/slapd.pid 8 | argsfile /tmp/slapd/slapd.args 9 | 10 | modulepath /usr/lib/openldap 11 | 12 | database ldif 13 | directory /tmp/slapd 14 | 15 | suffix "dc=example,dc=com" 16 | rootdn "cn=admin,dc=example,dc=com" 17 | rootpw {SSHA}AIzygLSXlArhAMzddUriXQxf7UlkqopP 18 | -------------------------------------------------------------------------------- /scripts/make_many_rows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | temp_file=/tmp/pyhive_test_data_many_rows.tsv 4 | seq 0 9999 > $temp_file 5 | 6 | hive -e " 7 | DROP TABLE IF EXISTS many_rows; 8 | CREATE TABLE many_rows ( 9 | a INT 10 | ) PARTITIONED BY ( 11 | b STRING 12 | ) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' STORED AS TEXTFILE; 13 | LOAD DATA LOCAL INPATH '$temp_file' INTO TABLE many_rows PARTITION (b='blah'); 14 | " 15 | rm -f $temp_file 16 | -------------------------------------------------------------------------------- /scripts/make_one_row.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | hive -e ' 3 | set mapred.job.tracker=local; 4 | DROP TABLE IF EXISTS one_row; 5 | CREATE TABLE one_row (number_of_rows INT); 6 | INSERT INTO TABLE one_row VALUES (1); 7 | ' 8 | -------------------------------------------------------------------------------- /scripts/make_one_row_complex.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | 3 | COLUMNS=' 4 | `boolean` BOOLEAN, 5 | `tinyint` TINYINT, 6 | `smallint` SMALLINT, 7 | `int` INT, 8 | `bigint` BIGINT, 9 | `float` FLOAT, 10 | `double` DOUBLE, 11 | `string` STRING, 12 | `timestamp` TIMESTAMP, 13 | `binary` BINARY, 14 | `array` ARRAY, 15 | `map` MAP, 16 | `struct` STRUCT, 17 | `union` 
UNIONTYPE, 18 | `decimal` DECIMAL(10, 1) 19 | ' 20 | 21 | hive -e " 22 | set mapred.job.tracker=local; 23 | DROP TABLE IF EXISTS one_row_complex; 24 | DROP TABLE IF EXISTS one_row_complex_null; 25 | CREATE TABLE one_row_complex ($COLUMNS); 26 | CREATE TABLE one_row_complex_null ($COLUMNS); 27 | INSERT OVERWRITE TABLE one_row_complex SELECT 28 | true, 29 | 127, 30 | 32767, 31 | 2147483647, 32 | 9223372036854775807, 33 | 0.5, 34 | 0.25, 35 | 'a string', 36 | 0, 37 | '123', 38 | array(1, 2), 39 | map(1, 2, 3, 4), 40 | named_struct('a', 1, 'b', 2), 41 | create_union(0, 1, 'test_string'), 42 | 0.1 43 | FROM one_row; 44 | INSERT OVERWRITE TABLE one_row_complex_null SELECT 45 | null, 46 | null, 47 | null, 48 | null, 49 | null, 50 | null, 51 | null, 52 | null, 53 | null, 54 | null, 55 | IF(false, array(1, 2), null), 56 | IF(false, map(1, 2, 3, 4), null), 57 | IF(false, named_struct('a', 1, 'b', 2), null), 58 | IF(false, create_union(0, 1, 'test_string'), null), 59 | null 60 | FROM one_row; 61 | " 62 | 63 | # Note: using IF(false, ...) above to work around https://issues.apache.org/jira/browse/HIVE-4022 64 | # The problem is that a "void" type cannot be inserted into a complex type field. 65 | -------------------------------------------------------------------------------- /scripts/make_test_database.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | hive -e ' 3 | DROP DATABASE IF EXISTS pyhive_test_database CASCADE; 4 | CREATE DATABASE pyhive_test_database; 5 | CREATE TABLE pyhive_test_database.dummy_table (a INT); 6 | ' 7 | -------------------------------------------------------------------------------- /scripts/make_test_tables.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | # Hive must be on the path for this script to work. 
3 | # WARNING: drops and recreates tables called one_row, one_row_complex, and many_rows, plus a 4 | # database called pyhive_test_database. 5 | 6 | $(dirname $0)/make_one_row.sh 7 | $(dirname $0)/make_one_row_complex.sh 8 | $(dirname $0)/make_many_rows.sh 9 | $(dirname $0)/make_test_database.sh 10 | -------------------------------------------------------------------------------- /scripts/thrift-patches/TCLIService.patch: -------------------------------------------------------------------------------- 1 | --- TCLIService.thrift 2017-04-16 22:51:45.000000000 -0700 2 | +++ TCLIService.thrift 2017-03-17 17:12:07.000000000 -0700 3 | @@ -32,8 +32,9 @@ 4 | // * Service names begin with the letter "T", use a capital letter for each 5 | // new word (with no underscores), and end with the word "Service". 6 | 7 | namespace java org.apache.hive.service.rpc.thrift 8 | namespace cpp apache.hive.service.rpc.thrift 9 | +namespace py TCLIService 10 | 11 | // List of protocol versions. A new token should be 12 | // added to the end of this list every time a change is made. 
13 | @@ -1224,6 +1225,19 @@ 14 | 6: required i64 startTime 15 | } 16 | 17 | +// GetLog() is not present in newer versions of Hive, but we add it here for backwards compatibility 18 | +struct TGetLogReq { 19 | + // operation handle 20 | + 1: required TOperationHandle operationHandle 21 | +} 22 | + 23 | +struct TGetLogResp { 24 | + // status of the request 25 | + 1: required TStatus status 26 | + // log content as text 27 | + 2: required string log 28 | +} 29 | + 30 | service TCLIService { 31 | 32 | TOpenSessionResp OpenSession(1:TOpenSessionReq req); 33 | @@ -1267,4 +1281,7 @@ 34 | TCancelDelegationTokenResp CancelDelegationToken(1:TCancelDelegationTokenReq req); 35 | 36 | TRenewDelegationTokenResp RenewDelegationToken(1:TRenewDelegationTokenReq req); 37 | + 38 | + // Adding older log retrieval method for backward compatibility 39 | + TGetLogResp GetLog(1:TGetLogReq req); 40 | } 41 | -------------------------------------------------------------------------------- /scripts/travis-conf/com/dropbox/DummyPasswdAuthenticationProvider.java: -------------------------------------------------------------------------------- 1 | package com.dropbox; 2 | 3 | import org.apache.hive.service.auth.PasswdAuthenticationProvider; 4 | 5 | import javax.security.sasl.AuthenticationException; 6 | 7 | public class DummyPasswdAuthenticationProvider implements PasswdAuthenticationProvider { 8 | @Override 9 | public void Authenticate(String user, String password) throws AuthenticationException { 10 | if (!user.equals("the-user") || !password.equals("p4ssw0rd")) { 11 | throw new AuthenticationException(); 12 | } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /scripts/travis-conf/hive/hive-site-custom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | hive.metastore.uris 6 | thrift://localhost:9083 7 | 8 | 9 | javax.jdo.option.ConnectionURL 10 | 
jdbc:derby:;databaseName=/var/lib/hive/metastore/metastore_db;create=true 11 | 12 | 13 | fs.defaultFS 14 | file:/// 15 | 16 | 17 | hive.server2.authentication 18 | CUSTOM 19 | 20 | 21 | hive.server2.custom.authentication.class 22 | com.dropbox.DummyPasswdAuthenticationProvider 23 | 24 | 25 | -------------------------------------------------------------------------------- /scripts/travis-conf/hive/hive-site-ldap.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | hive.metastore.uris 6 | thrift://localhost:9083 7 | 8 | 9 | javax.jdo.option.ConnectionURL 10 | jdbc:derby:;databaseName=/var/lib/hive/metastore/metastore_db;create=true 11 | 12 | 13 | fs.defaultFS 14 | file:/// 15 | 16 | 17 | hive.server2.authentication 18 | LDAP 19 | 20 | 21 | hive.server2.authentication.ldap.url 22 | ldap://localhost:3389 23 | 24 | 25 | hive.server2.authentication.ldap.baseDN 26 | dc=example,dc=com 27 | 28 | 29 | hive.server2.authentication.ldap.guidKey 30 | cn 31 | 32 | 33 | -------------------------------------------------------------------------------- /scripts/travis-conf/hive/hive-site.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | hive.metastore.uris 6 | thrift://localhost:9083 7 | 8 | 9 | javax.jdo.option.ConnectionURL 10 | jdbc:derby:;databaseName=/var/lib/hive/metastore/metastore_db;create=true 11 | 12 | 13 | fs.defaultFS 14 | file:/// 15 | 16 | 17 | -------------------------------------------------------------------------------- /scripts/travis-conf/presto/catalog/hive.properties: -------------------------------------------------------------------------------- 1 | connector.name=hive-hadoop2 2 | hive.metastore.uri=thrift://localhost:9083 3 | -------------------------------------------------------------------------------- /scripts/travis-conf/presto/config.properties: -------------------------------------------------------------------------------- 1 | coordinator=true 
2 | node-scheduler.include-coordinator=true 3 | http-server.http.port=8080 4 | query.max-memory=100MB 5 | query.max-memory-per-node=100MB 6 | discovery-server.enabled=true 7 | discovery.uri=http://localhost:8080 8 | -------------------------------------------------------------------------------- /scripts/travis-conf/presto/jvm.config: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/PyHive/ac09074a652fd50e10b57a7f0bbc4f6410961301/scripts/travis-conf/presto/jvm.config -------------------------------------------------------------------------------- /scripts/travis-conf/presto/node.properties: -------------------------------------------------------------------------------- 1 | node.environment=production 2 | node.id=ffffffff-ffff-ffff-ffff-ffffffffffff 3 | node.data-dir=/tmp/presto/data 4 | -------------------------------------------------------------------------------- /scripts/travis-conf/trino/catalog/hive.properties: -------------------------------------------------------------------------------- 1 | connector.name=hive-hadoop2 2 | hive.metastore.uri=thrift://localhost:9083 3 | -------------------------------------------------------------------------------- /scripts/travis-conf/trino/config.properties: -------------------------------------------------------------------------------- 1 | coordinator=true 2 | node-scheduler.include-coordinator=true 3 | http-server.http.port=18080 4 | query.max-memory=100MB 5 | query.max-memory-per-node=100MB 6 | discovery-server.enabled=true 7 | discovery.uri=http://localhost:18080 8 | -------------------------------------------------------------------------------- /scripts/travis-conf/trino/jvm.config: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/PyHive/ac09074a652fd50e10b57a7f0bbc4f6410961301/scripts/travis-conf/trino/jvm.config 
-------------------------------------------------------------------------------- /scripts/travis-conf/trino/node.properties: -------------------------------------------------------------------------------- 1 | node.environment=production 2 | node.id=11111111-1111-1111-1111-111111111111 3 | node.data-dir=/tmp/trino/data 4 | -------------------------------------------------------------------------------- /scripts/travis-install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | 3 | source /etc/lsb-release 4 | 5 | echo "deb [arch=amd64] https://archive.cloudera.com/${CDH}/ubuntu/${DISTRIB_CODENAME}/amd64/cdh ${DISTRIB_CODENAME}-cdh${CDH_VERSION} contrib 6 | deb-src https://archive.cloudera.com/${CDH}/ubuntu/${DISTRIB_CODENAME}/amd64/cdh ${DISTRIB_CODENAME}-cdh${CDH_VERSION} contrib" | sudo tee /etc/apt/sources.list.d/cloudera.list 7 | sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 327574EE02A818DD 8 | sudo apt-get -q update 9 | 10 | sudo apt-get -q install -y oracle-java8-installer python-dev g++ libsasl2-dev maven 11 | sudo update-java-alternatives -s java-8-oracle 12 | 13 | # 14 | # LDAP 15 | # 16 | sudo apt-get -q -y --no-install-suggests --no-install-recommends --force-yes install ldap-utils slapd 17 | sudo mkdir -p /tmp/slapd 18 | sudo slapd -f $(dirname $0)/ldap_config/slapd.conf -h ldap://localhost:3389 & 19 | while ! 
nc -vz localhost 3389; do sleep 1; done 20 | sudo ldapadd -h localhost:3389 -D cn=admin,dc=example,dc=com -w test -f $(dirname $0)/../pyhive/tests/ldif_data/base.ldif 21 | sudo ldapadd -h localhost:3389 -D cn=admin,dc=example,dc=com -w test -f $(dirname $0)/../pyhive/tests/ldif_data/INITIAL_TESTDATA.ldif 22 | 23 | # 24 | # Hive 25 | # 26 | 27 | sudo apt-get -q install -y --force-yes hive 28 | 29 | javac -cp /usr/lib/hive/lib/hive-service.jar $(dirname $0)/travis-conf/com/dropbox/DummyPasswdAuthenticationProvider.java 30 | jar cf $(dirname $0)/dummy-auth.jar -C $(dirname $0)/travis-conf com 31 | sudo cp $(dirname $0)/dummy-auth.jar /usr/lib/hive/lib 32 | 33 | # Hack around broken symlink in Hive's installation 34 | # /usr/lib/hive/lib/zookeeper.jar -> ../../zookeeper/zookeeper.jar 35 | # Without this, Hive fails to start up due to failing to find ZK classes. 36 | sudo ln -nsfv /usr/share/java/zookeeper.jar /usr/lib/hive/lib/zookeeper.jar 37 | 38 | sudo mkdir -p /user/hive 39 | sudo chown hive:hive /user/hive 40 | sudo cp $(dirname $0)/travis-conf/hive/hive-site.xml /etc/hive/conf/hive-site.xml 41 | sudo apt-get -q install -y --force-yes hive-metastore hive-server2 || (grep . /var/log/hive/* && exit 2) 42 | 43 | while ! nc -vz localhost 9083; do sleep 1; done 44 | while ! nc -vz localhost 10000; do sleep 1; done 45 | 46 | sudo -Eu hive $(dirname $0)/make_test_tables.sh 47 | 48 | # 49 | # Presto 50 | # 51 | 52 | sudo apt-get -q install -y python # Use python2 for presto server 53 | 54 | mvn -q org.apache.maven.plugins:maven-dependency-plugin:3.0.0:copy \ 55 | -Dartifact=com.facebook.presto:presto-server:${PRESTO}:tar.gz \ 56 | -DoutputDirectory=. 
57 | tar -x -z -f presto-server-*.tar.gz 58 | rm -rf presto-server 59 | mv presto-server-*/ presto-server 60 | 61 | cp -r $(dirname $0)/travis-conf/presto presto-server/etc 62 | 63 | /usr/bin/python2.7 presto-server/bin/launcher.py start 64 | 65 | # 66 | # Trino 67 | # 68 | 69 | sudo apt-get -q install -y python # Use python2 for trino server 70 | 71 | mvn -q org.apache.maven.plugins:maven-dependency-plugin:3.0.0:copy \ 72 | -Dartifact=io.trino:trino-server:${TRINO}:tar.gz \ 73 | -DoutputDirectory=. 74 | tar -x -z -f trino-server-*.tar.gz 75 | rm -rf trino-server 76 | mv trino-server-*/ trino-server 77 | 78 | cp -r $(dirname $0)/travis-conf/trino trino-server/etc 79 | 80 | /usr/bin/python2.7 trino-server/bin/launcher.py start 81 | 82 | # 83 | # Python 84 | # 85 | 86 | pip install $SQLALCHEMY 87 | pip install -e . 88 | pip install -r dev_requirements.txt 89 | 90 | # Sleep so Presto has time to start up. 91 | # Otherwise we might get 'No nodes available to run query' or 'Presto server is still initializing' 92 | while ! grep -q 'SERVER STARTED' /tmp/presto/data/var/log/server.log; do sleep 1; done 93 | 94 | # Sleep so Trino has time to start up. 95 | # Otherwise we might get 'No nodes available to run query' or 'Trino server is still initializing' 96 | while ! 
grep -q 'SERVER STARTED' /tmp/trino/data/var/log/server.log; do sleep 1; done 97 | -------------------------------------------------------------------------------- /scripts/update_thrift_bindings.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | 3 | HIVE_VERSION='2.3.0' # Must be a released version 4 | 5 | # Create a temporary directory 6 | scriptdir=`dirname $0` 7 | tmpdir=$scriptdir/.thrift_gen 8 | 9 | # Clean up previous generation attempts, in case it breaks things 10 | rm -rf $tmpdir 11 | mkdir $tmpdir 12 | 13 | # Download TCLIService.thrift from Hive 14 | curl -o $tmpdir/TCLIService.thrift \ 15 | https://raw.githubusercontent.com/apache/hive/rel/release-$HIVE_VERSION/service-rpc/if/TCLIService.thrift 16 | 17 | # Apply patch that adds legacy GetLog methods 18 | patch -d $tmpdir < $scriptdir/thrift-patches/TCLIService.patch 19 | 20 | thrift -r --gen py -out $scriptdir/../ $tmpdir/TCLIService.thrift 21 | rm $scriptdir/../__init__.py 22 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [egg_info] 2 | tag_build = dev 3 | 4 | [tool:pytest] 5 | timeout = 100 6 | timeout_method = thread 7 | addopts = --random --tb=short --cov pyhive --cov-report html --cov-report term --flake8 8 | norecursedirs = env 9 | python_files = test_*.py 10 | flake8-max-line-length = 100 11 | flake8-ignore = 12 | TCLIService/*.py ALL 13 | pyhive/sqlalchemy_backports.py ALL 14 | presto-server/** ALL 15 | pyhive/hive.py F405 16 | pyhive/presto.py F405 17 | pyhive/trino.py F405 18 | W503 19 | filterwarnings = 20 | error 21 | # For Python 2 flake8 22 | default:You passed a bytestring as `filenames`:DeprecationWarning:flake8.options.config 23 | default::UnicodeWarning:_pytest.warnings 24 | # TODO For old sqlalchemy 25 | default:cgi.parse_qsl is deprecated:PendingDeprecationWarning:sqlalchemy.engine.url 26 | # 
TODO 27 | default:Did not recognize type:sqlalchemy.exc.SAWarning 28 | default:The Binary type has been renamed to LargeBinary:sqlalchemy.exc.SADeprecationWarning 29 | default:Please use assertRaisesRegex instead:DeprecationWarning 30 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | from setuptools.command.test import test as TestCommand 5 | import pyhive 6 | import sys 7 | 8 | 9 | class PyTest(TestCommand): 10 | def finalize_options(self): 11 | TestCommand.finalize_options(self) 12 | self.test_args = [] 13 | self.test_suite = True 14 | 15 | def run_tests(self): 16 | # import here, cause outside the eggs aren't loaded 17 | import pytest 18 | errno = pytest.main(self.test_args) 19 | sys.exit(errno) 20 | 21 | 22 | with open('README.rst') as readme: 23 | long_description = readme.read() 24 | 25 | setup( 26 | name="PyHive", 27 | version=pyhive.__version__, 28 | description="Python interface to Hive", 29 | long_description=long_description, 30 | url='https://github.com/dropbox/PyHive', 31 | author="Jing Wang", 32 | author_email="jing@dropbox.com", 33 | license="Apache License, Version 2.0", 34 | packages=['pyhive', 'TCLIService'], 35 | classifiers=[ 36 | "Intended Audience :: Developers", 37 | "License :: OSI Approved :: Apache Software License", 38 | "Operating System :: OS Independent", 39 | "Topic :: Database :: Front-Ends", 40 | ], 41 | install_requires=[ 42 | 'future', 43 | 'python-dateutil', 44 | ], 45 | extras_require={ 46 | 'presto': ['requests>=1.0.0'], 47 | 'trino': ['requests>=1.0.0'], 48 | 'hive': ['sasl>=0.2.1', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], 49 | 'hive_pure_sasl': ['pure-sasl>=0.6.2', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], 50 | 'sqlalchemy': ['sqlalchemy>=1.3.0'], 51 | 'kerberos': ['requests_kerberos>=0.12.0'], 52 | }, 53 | tests_require=[ 54 | 
'mock>=1.0.0', 55 | 'pytest', 56 | 'pytest-cov', 57 | 'requests>=1.0.0', 58 | 'requests_kerberos>=0.12.0', 59 | 'sasl>=0.2.1', 60 | 'pure-sasl>=0.6.2', 61 | 'kerberos>=1.3.0', 62 | 'sqlalchemy>=1.3.0', 63 | 'thrift>=0.10.0', 64 | ], 65 | cmdclass={'test': PyTest}, 66 | package_data={ 67 | '': ['*.rst'], 68 | }, 69 | entry_points={ 70 | 'sqlalchemy.dialects': [ 71 | 'hive = pyhive.sqlalchemy_hive:HiveDialect', 72 | "hive.http = pyhive.sqlalchemy_hive:HiveHTTPDialect", 73 | "hive.https = pyhive.sqlalchemy_hive:HiveHTTPSDialect", 74 | 'presto = pyhive.sqlalchemy_presto:PrestoDialect', 75 | 'trino.pyhive = pyhive.sqlalchemy_trino:TrinoDialect', 76 | ], 77 | } 78 | ) 79 | --------------------------------------------------------------------------------