├── .coveragerc
├── .gitignore
├── .travis.yml
├── LICENSE
├── MANIFEST.in
├── README.rst
├── TCLIService
├── TCLIService-remote
├── TCLIService.py
├── __init__.py
├── constants.py
└── ttypes.py
├── dev_requirements.txt
├── generate.py
├── pyhive
├── __init__.py
├── common.py
├── exc.py
├── hive.py
├── presto.py
├── sasl_compat.py
├── sqlalchemy_hive.py
├── sqlalchemy_presto.py
├── sqlalchemy_trino.py
├── tests
│ ├── __init__.py
│ ├── dbapi_test_case.py
│ ├── ldif_data
│ │ ├── INITIAL_TESTDATA.ldif
│ │ └── base.ldif
│ ├── sqlalchemy_test_case.py
│ ├── test_common.py
│ ├── test_hive.py
│ ├── test_presto.py
│ ├── test_sasl_compat.py
│ ├── test_sqlalchemy_hive.py
│ ├── test_sqlalchemy_presto.py
│ ├── test_sqlalchemy_trino.py
│ └── test_trino.py
└── trino.py
├── scripts
├── ldap_config
│ └── slapd.conf
├── make_many_rows.sh
├── make_one_row.sh
├── make_one_row_complex.sh
├── make_test_database.sh
├── make_test_tables.sh
├── thrift-patches
│ └── TCLIService.patch
├── travis-conf
│ ├── com
│ │ └── dropbox
│ │ │ └── DummyPasswdAuthenticationProvider.java
│ ├── hive
│ │ ├── hive-site-custom.xml
│ │ ├── hive-site-ldap.xml
│ │ └── hive-site.xml
│ ├── presto
│ │ ├── catalog
│ │ │ └── hive.properties
│ │ ├── config.properties
│ │ ├── jvm.config
│ │ └── node.properties
│ └── trino
│ │ ├── catalog
│ │ └── hive.properties
│ │ ├── config.properties
│ │ ├── jvm.config
│ │ └── node.properties
├── travis-install.sh
└── update_thrift_bindings.sh
├── setup.cfg
└── setup.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | branch = True
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | cover/
2 | .coverage
3 | /dist/
4 | .DS_Store
5 | *.egg
6 | /env/
7 | /htmlcov/
8 | .idea/
9 | .project
10 | *.pyc
11 | .pydevproject
12 | /PyHive.egg-info/
13 | .settings
14 | .cache/
15 | *.iml
16 | /scripts/.thrift_gen
17 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: required
2 | language: python
3 | dist: trusty
4 | matrix:
5 | include:
6 | # https://docs.python.org/devguide/#status-of-python-branches
7 | # One build pulls latest versions dynamically
8 | - python: 3.6
9 | env: CDH=cdh5 CDH_VERSION=5 PRESTO=RELEASE TRINO=RELEASE SQLALCHEMY=sqlalchemy>=1.3.0
10 | - python: 3.6
11 | env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0
12 | - python: 3.5
13 | env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0
14 | - python: 3.4
15 | env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0
16 | - python: 2.7
17 | env: CDH=cdh5 CDH_VERSION=5.10.1 PRESTO=0.147 TRINO=351 SQLALCHEMY=sqlalchemy>=1.3.0
18 | install:
19 | - ./scripts/travis-install.sh
20 | - pip install codecov
21 | script: PYTHONHASHSEED=random pytest -v
22 | after_success: codecov
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2014 Dropbox, Inc.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | ========================================================
2 | PyHive project has been donated to Apache Kyuubi
3 | ========================================================
4 |
5 | You can follow it's development and report any issues you are experiencing here: https://github.com/apache/kyuubi/tree/master/python/pyhive
6 |
7 |
8 |
9 | Legacy notes / instructions
10 | ===========================
11 |
12 | PyHive
13 | **********
14 |
15 |
16 | PyHive is a collection of Python `DB-API `_ and
17 | `SQLAlchemy `_ interfaces for `Presto `_ ,
18 | `Hive `_ and `Trino `_.
19 |
20 | Usage
21 | **********
22 |
23 | DB-API
24 | ------
25 | .. code-block:: python
26 |
27 | from pyhive import presto # or import hive or import trino
28 | cursor = presto.connect('localhost').cursor() # or use hive.connect or use trino.connect
29 | cursor.execute('SELECT * FROM my_awesome_data LIMIT 10')
30 | print cursor.fetchone()
31 | print cursor.fetchall()
32 |
33 | DB-API (asynchronous)
34 | ---------------------
35 | .. code-block:: python
36 |
37 | from pyhive import hive
38 | from TCLIService.ttypes import TOperationState
39 | cursor = hive.connect('localhost').cursor()
40 | cursor.execute('SELECT * FROM my_awesome_data LIMIT 10', async=True)
41 |
42 | status = cursor.poll().operationState
43 | while status in (TOperationState.INITIALIZED_STATE, TOperationState.RUNNING_STATE):
44 | logs = cursor.fetch_logs()
45 | for message in logs:
46 | print message
47 |
48 | # If needed, an asynchronous query can be cancelled at any time with:
49 | # cursor.cancel()
50 |
51 | status = cursor.poll().operationState
52 |
53 | print cursor.fetchall()
54 |
55 | In Python 3.7 `async` became a keyword; you can use `async_` instead:
56 |
57 | .. code-block:: python
58 |
59 | cursor.execute('SELECT * FROM my_awesome_data LIMIT 10', async_=True)
60 |
61 |
62 | SQLAlchemy
63 | ----------
64 | First install this package to register it with SQLAlchemy, see ``entry_points`` in ``setup.py``.
65 |
66 | .. code-block:: python
67 |
68 | from sqlalchemy import *
69 | from sqlalchemy.engine import create_engine
70 | from sqlalchemy.schema import *
71 | # Presto
72 | engine = create_engine('presto://localhost:8080/hive/default')
73 | # Trino
74 | engine = create_engine('trino+pyhive://localhost:8080/hive/default')
75 | # Hive
76 | engine = create_engine('hive://localhost:10000/default')
77 |
78 | # SQLAlchemy < 2.0
79 | logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
80 | print select([func.count('*')], from_obj=logs).scalar()
81 |
82 | # Hive + HTTPS + LDAP or basic Auth
83 | engine = create_engine('hive+https://username:password@localhost:10000/')
84 | logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
85 | print select([func.count('*')], from_obj=logs).scalar()
86 |
87 | # SQLAlchemy >= 2.0
88 | metadata_obj = MetaData()
89 | books = Table("books", metadata_obj, Column("id", Integer), Column("title", String), Column("primary_author", String))
90 | metadata_obj.create_all(engine)
91 | inspector = inspect(engine)
92 | inspector.get_columns('books')
93 |
94 | with engine.connect() as con:
95 | data = [{ "id": 1, "title": "The Hobbit", "primary_author": "Tolkien" },
96 | { "id": 2, "title": "The Silmarillion", "primary_author": "Tolkien" }]
97 | con.execute(books.insert(), data[0])
98 | result = con.execute(text("select * from books"))
99 | print(result.fetchall())
100 |
101 | Note: query generation functionality is not exhaustive or fully tested, but there should be no
102 | problem with raw SQL.
103 |
104 | Passing session configuration
105 | -----------------------------
106 |
107 | .. code-block:: python
108 |
109 | # DB-API
110 | hive.connect('localhost', configuration={'hive.exec.reducers.max': '123'})
111 | presto.connect('localhost', session_props={'query_max_run_time': '1234m'})
112 | trino.connect('localhost', session_props={'query_max_run_time': '1234m'})
113 | # SQLAlchemy
114 | create_engine(
115 | 'presto://user@host:443/hive',
116 | connect_args={'protocol': 'https',
117 | 'session_props': {'query_max_run_time': '1234m'}}
118 | )
119 | create_engine(
120 | 'trino+pyhive://user@host:443/hive',
121 | connect_args={'protocol': 'https',
122 | 'session_props': {'query_max_run_time': '1234m'}}
123 | )
124 | create_engine(
125 | 'hive://user@host:10000/database',
126 | connect_args={'configuration': {'hive.exec.reducers.max': '123'}},
127 | )
128 | # SQLAlchemy with LDAP
129 | create_engine(
130 | 'hive://user:password@host:10000/database',
131 | connect_args={'auth': 'LDAP'},
132 | )
133 |
134 | Requirements
135 | ************
136 |
137 | Install using
138 |
139 | - ``pip install 'pyhive[hive]'`` or ``pip install 'pyhive[hive_pure_sasl]'`` for the Hive interface
140 | - ``pip install 'pyhive[presto]'`` for the Presto interface
141 | - ``pip install 'pyhive[trino]'`` for the Trino interface
142 |
143 | Note: ``'pyhive[hive]'`` extras uses `sasl `_ that doesn't support Python 3.11, See `github issue `_.
144 | Hence PyHive also supports `pure-sasl `_ via additional extras ``'pyhive[hive_pure_sasl]'`` which support Python 3.11.
145 |
146 | PyHive works with
147 |
148 | - Python 2.7 / Python 3
149 | - For Presto: `Presto installation `_
150 | - For Trino: `Trino installation `_
151 | - For Hive: `HiveServer2 `_ daemon
152 |
153 | Changelog
154 | *********
155 | See https://github.com/dropbox/PyHive/releases.
156 |
157 | Contributing
158 | ************
159 | - Please fill out the Dropbox Contributor License Agreement at https://opensource.dropbox.com/cla/ and note this in your pull request.
160 | - Changes must come with tests, with the exception of trivial things like fixing comments. See .travis.yml for the test environment setup.
161 | - Notes on project scope:
162 |
163 | - This project is intended to be a minimal Hive/Presto client that does that one thing and nothing else.
164 | Features that can be implemented on top of PyHive, such integration with your favorite data analysis library, are likely out of scope.
165 | - We prefer having a small number of generic features over a large number of specialized, inflexible features.
166 | For example, the Presto code takes an arbitrary ``requests_session`` argument for customizing HTTP calls, as opposed to having a separate parameter/branch for each ``requests`` option.
167 |
168 | Tips for test environment setup
169 | ****************************************
170 | You can setup test environment by following ``.travis.yaml`` in this repository. It uses `Cloudera's CDH 5 `_ which requires username and password for download.
171 | It may not be feasible for everyone to get those credentials. Hence below are alternative instructions to setup test environment.
172 |
173 | You can clone `this repository `_ which has Docker Compose setup for Presto and Hive.
174 | You can add below lines to its docker-compose.yaml to start Trino in same environment::
175 |
176 | trino:
177 | image: trinodb/trino:351
178 | ports:
179 | - "18080:18080"
180 | volumes:
181 | - ./trino:/etc/trino
182 |
183 | Note: ``./trino`` for docker volume defined above is `trino config from PyHive repository `_
184 |
185 | Then run::
186 | docker-compose up -d
187 |
188 | Testing
189 | *******
190 | .. image:: https://travis-ci.org/dropbox/PyHive.svg
191 | :target: https://travis-ci.org/dropbox/PyHive
192 | .. image:: http://codecov.io/github/dropbox/PyHive/coverage.svg?branch=master
193 | :target: http://codecov.io/github/dropbox/PyHive?branch=master
194 |
195 | Run the following in an environment with Hive/Presto::
196 |
197 | ./scripts/make_test_tables.sh
198 | virtualenv --no-site-packages env
199 | source env/bin/activate
200 | pip install -e .
201 | pip install -r dev_requirements.txt
202 | py.test
203 |
204 | WARNING: This drops/creates tables named ``one_row``, ``one_row_complex``, and ``many_rows``, plus a
205 | database called ``pyhive_test_database``.
206 |
207 | Updating TCLIService
208 | ********************
209 |
210 | The TCLIService module is autogenerated using a ``TCLIService.thrift`` file. To update it, the
211 | ``generate.py`` file can be used: ``python generate.py ``. When left blank, the
212 | version for Hive 2.3 will be downloaded.
213 |
--------------------------------------------------------------------------------
/TCLIService/TCLIService-remote:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # Autogenerated by Thrift Compiler (0.10.0)
4 | #
5 | # DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
6 | #
7 | # options string: py
8 | #
9 |
10 | import sys
11 | import pprint
12 | if sys.version_info[0] > 2:
13 | from urllib.parse import urlparse
14 | else:
15 | from urlparse import urlparse
16 | from thrift.transport import TTransport, TSocket, TSSLSocket, THttpClient
17 | from thrift.protocol.TBinaryProtocol import TBinaryProtocol
18 |
19 | from TCLIService import TCLIService
20 | from TCLIService.ttypes import *
21 |
22 | if len(sys.argv) <= 1 or sys.argv[1] == '--help':
23 | print('')
24 | print('Usage: ' + sys.argv[0] + ' [-h host[:port]] [-u url] [-f[ramed]] [-s[sl]] [-novalidate] [-ca_certs certs] [-keyfile keyfile] [-certfile certfile] function [arg1 [arg2...]]')
25 | print('')
26 | print('Functions:')
27 | print(' TOpenSessionResp OpenSession(TOpenSessionReq req)')
28 | print(' TCloseSessionResp CloseSession(TCloseSessionReq req)')
29 | print(' TGetInfoResp GetInfo(TGetInfoReq req)')
30 | print(' TExecuteStatementResp ExecuteStatement(TExecuteStatementReq req)')
31 | print(' TGetTypeInfoResp GetTypeInfo(TGetTypeInfoReq req)')
32 | print(' TGetCatalogsResp GetCatalogs(TGetCatalogsReq req)')
33 | print(' TGetSchemasResp GetSchemas(TGetSchemasReq req)')
34 | print(' TGetTablesResp GetTables(TGetTablesReq req)')
35 | print(' TGetTableTypesResp GetTableTypes(TGetTableTypesReq req)')
36 | print(' TGetColumnsResp GetColumns(TGetColumnsReq req)')
37 | print(' TGetFunctionsResp GetFunctions(TGetFunctionsReq req)')
38 | print(' TGetPrimaryKeysResp GetPrimaryKeys(TGetPrimaryKeysReq req)')
39 | print(' TGetCrossReferenceResp GetCrossReference(TGetCrossReferenceReq req)')
40 | print(' TGetOperationStatusResp GetOperationStatus(TGetOperationStatusReq req)')
41 | print(' TCancelOperationResp CancelOperation(TCancelOperationReq req)')
42 | print(' TCloseOperationResp CloseOperation(TCloseOperationReq req)')
43 | print(' TGetResultSetMetadataResp GetResultSetMetadata(TGetResultSetMetadataReq req)')
44 | print(' TFetchResultsResp FetchResults(TFetchResultsReq req)')
45 | print(' TGetDelegationTokenResp GetDelegationToken(TGetDelegationTokenReq req)')
46 | print(' TCancelDelegationTokenResp CancelDelegationToken(TCancelDelegationTokenReq req)')
47 | print(' TRenewDelegationTokenResp RenewDelegationToken(TRenewDelegationTokenReq req)')
48 | print(' TGetLogResp GetLog(TGetLogReq req)')
49 | print('')
50 | sys.exit(0)
51 |
52 | pp = pprint.PrettyPrinter(indent=2)
53 | host = 'localhost'
54 | port = 9090
55 | uri = ''
56 | framed = False
57 | ssl = False
58 | validate = True
59 | ca_certs = None
60 | keyfile = None
61 | certfile = None
62 | http = False
63 | argi = 1
64 |
65 | if sys.argv[argi] == '-h':
66 | parts = sys.argv[argi + 1].split(':')
67 | host = parts[0]
68 | if len(parts) > 1:
69 | port = int(parts[1])
70 | argi += 2
71 |
72 | if sys.argv[argi] == '-u':
73 | url = urlparse(sys.argv[argi + 1])
74 | parts = url[1].split(':')
75 | host = parts[0]
76 | if len(parts) > 1:
77 | port = int(parts[1])
78 | else:
79 | port = 80
80 | uri = url[2]
81 | if url[4]:
82 | uri += '?%s' % url[4]
83 | http = True
84 | argi += 2
85 |
86 | if sys.argv[argi] == '-f' or sys.argv[argi] == '-framed':
87 | framed = True
88 | argi += 1
89 |
90 | if sys.argv[argi] == '-s' or sys.argv[argi] == '-ssl':
91 | ssl = True
92 | argi += 1
93 |
94 | if sys.argv[argi] == '-novalidate':
95 | validate = False
96 | argi += 1
97 |
98 | if sys.argv[argi] == '-ca_certs':
99 | ca_certs = sys.argv[argi+1]
100 | argi += 2
101 |
102 | if sys.argv[argi] == '-keyfile':
103 | keyfile = sys.argv[argi+1]
104 | argi += 2
105 |
106 | if sys.argv[argi] == '-certfile':
107 | certfile = sys.argv[argi+1]
108 | argi += 2
109 |
110 | cmd = sys.argv[argi]
111 | args = sys.argv[argi + 1:]
112 |
113 | if http:
114 | transport = THttpClient.THttpClient(host, port, uri)
115 | else:
116 | if ssl:
117 | socket = TSSLSocket.TSSLSocket(host, port, validate=validate, ca_certs=ca_certs, keyfile=keyfile, certfile=certfile)
118 | else:
119 | socket = TSocket.TSocket(host, port)
120 | if framed:
121 | transport = TTransport.TFramedTransport(socket)
122 | else:
123 | transport = TTransport.TBufferedTransport(socket)
124 | protocol = TBinaryProtocol(transport)
125 | client = TCLIService.Client(protocol)
126 | transport.open()
127 |
128 | if cmd == 'OpenSession':
129 | if len(args) != 1:
130 | print('OpenSession requires 1 args')
131 | sys.exit(1)
132 | pp.pprint(client.OpenSession(eval(args[0]),))
133 |
134 | elif cmd == 'CloseSession':
135 | if len(args) != 1:
136 | print('CloseSession requires 1 args')
137 | sys.exit(1)
138 | pp.pprint(client.CloseSession(eval(args[0]),))
139 |
140 | elif cmd == 'GetInfo':
141 | if len(args) != 1:
142 | print('GetInfo requires 1 args')
143 | sys.exit(1)
144 | pp.pprint(client.GetInfo(eval(args[0]),))
145 |
146 | elif cmd == 'ExecuteStatement':
147 | if len(args) != 1:
148 | print('ExecuteStatement requires 1 args')
149 | sys.exit(1)
150 | pp.pprint(client.ExecuteStatement(eval(args[0]),))
151 |
152 | elif cmd == 'GetTypeInfo':
153 | if len(args) != 1:
154 | print('GetTypeInfo requires 1 args')
155 | sys.exit(1)
156 | pp.pprint(client.GetTypeInfo(eval(args[0]),))
157 |
158 | elif cmd == 'GetCatalogs':
159 | if len(args) != 1:
160 | print('GetCatalogs requires 1 args')
161 | sys.exit(1)
162 | pp.pprint(client.GetCatalogs(eval(args[0]),))
163 |
164 | elif cmd == 'GetSchemas':
165 | if len(args) != 1:
166 | print('GetSchemas requires 1 args')
167 | sys.exit(1)
168 | pp.pprint(client.GetSchemas(eval(args[0]),))
169 |
170 | elif cmd == 'GetTables':
171 | if len(args) != 1:
172 | print('GetTables requires 1 args')
173 | sys.exit(1)
174 | pp.pprint(client.GetTables(eval(args[0]),))
175 |
176 | elif cmd == 'GetTableTypes':
177 | if len(args) != 1:
178 | print('GetTableTypes requires 1 args')
179 | sys.exit(1)
180 | pp.pprint(client.GetTableTypes(eval(args[0]),))
181 |
182 | elif cmd == 'GetColumns':
183 | if len(args) != 1:
184 | print('GetColumns requires 1 args')
185 | sys.exit(1)
186 | pp.pprint(client.GetColumns(eval(args[0]),))
187 |
188 | elif cmd == 'GetFunctions':
189 | if len(args) != 1:
190 | print('GetFunctions requires 1 args')
191 | sys.exit(1)
192 | pp.pprint(client.GetFunctions(eval(args[0]),))
193 |
194 | elif cmd == 'GetPrimaryKeys':
195 | if len(args) != 1:
196 | print('GetPrimaryKeys requires 1 args')
197 | sys.exit(1)
198 | pp.pprint(client.GetPrimaryKeys(eval(args[0]),))
199 |
200 | elif cmd == 'GetCrossReference':
201 | if len(args) != 1:
202 | print('GetCrossReference requires 1 args')
203 | sys.exit(1)
204 | pp.pprint(client.GetCrossReference(eval(args[0]),))
205 |
206 | elif cmd == 'GetOperationStatus':
207 | if len(args) != 1:
208 | print('GetOperationStatus requires 1 args')
209 | sys.exit(1)
210 | pp.pprint(client.GetOperationStatus(eval(args[0]),))
211 |
212 | elif cmd == 'CancelOperation':
213 | if len(args) != 1:
214 | print('CancelOperation requires 1 args')
215 | sys.exit(1)
216 | pp.pprint(client.CancelOperation(eval(args[0]),))
217 |
218 | elif cmd == 'CloseOperation':
219 | if len(args) != 1:
220 | print('CloseOperation requires 1 args')
221 | sys.exit(1)
222 | pp.pprint(client.CloseOperation(eval(args[0]),))
223 |
224 | elif cmd == 'GetResultSetMetadata':
225 | if len(args) != 1:
226 | print('GetResultSetMetadata requires 1 args')
227 | sys.exit(1)
228 | pp.pprint(client.GetResultSetMetadata(eval(args[0]),))
229 |
230 | elif cmd == 'FetchResults':
231 | if len(args) != 1:
232 | print('FetchResults requires 1 args')
233 | sys.exit(1)
234 | pp.pprint(client.FetchResults(eval(args[0]),))
235 |
236 | elif cmd == 'GetDelegationToken':
237 | if len(args) != 1:
238 | print('GetDelegationToken requires 1 args')
239 | sys.exit(1)
240 | pp.pprint(client.GetDelegationToken(eval(args[0]),))
241 |
242 | elif cmd == 'CancelDelegationToken':
243 | if len(args) != 1:
244 | print('CancelDelegationToken requires 1 args')
245 | sys.exit(1)
246 | pp.pprint(client.CancelDelegationToken(eval(args[0]),))
247 |
248 | elif cmd == 'RenewDelegationToken':
249 | if len(args) != 1:
250 | print('RenewDelegationToken requires 1 args')
251 | sys.exit(1)
252 | pp.pprint(client.RenewDelegationToken(eval(args[0]),))
253 |
254 | elif cmd == 'GetLog':
255 | if len(args) != 1:
256 | print('GetLog requires 1 args')
257 | sys.exit(1)
258 | pp.pprint(client.GetLog(eval(args[0]),))
259 |
260 | else:
261 | print('Unrecognized method %s' % cmd)
262 | sys.exit(1)
263 |
264 | transport.close()
265 |
--------------------------------------------------------------------------------
/TCLIService/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['ttypes', 'constants', 'TCLIService']
2 |
--------------------------------------------------------------------------------
/TCLIService/constants.py:
--------------------------------------------------------------------------------
1 | #
2 | # Autogenerated by Thrift Compiler (0.10.0)
3 | #
4 | # DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
5 | #
6 | # options string: py
7 | #
8 |
9 | from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException
10 | from thrift.protocol.TProtocol import TProtocolException
11 | import sys
12 | from .ttypes import *
13 | PRIMITIVE_TYPES = set((
14 | 0,
15 | 1,
16 | 2,
17 | 3,
18 | 4,
19 | 5,
20 | 6,
21 | 7,
22 | 8,
23 | 9,
24 | 15,
25 | 16,
26 | 17,
27 | 18,
28 | 19,
29 | 20,
30 | 21,
31 | ))
32 | COMPLEX_TYPES = set((
33 | 10,
34 | 11,
35 | 12,
36 | 13,
37 | 14,
38 | ))
39 | COLLECTION_TYPES = set((
40 | 10,
41 | 11,
42 | ))
43 | TYPE_NAMES = {
44 | 0: "BOOLEAN",
45 | 1: "TINYINT",
46 | 2: "SMALLINT",
47 | 3: "INT",
48 | 4: "BIGINT",
49 | 5: "FLOAT",
50 | 6: "DOUBLE",
51 | 7: "STRING",
52 | 8: "TIMESTAMP",
53 | 9: "BINARY",
54 | 10: "ARRAY",
55 | 11: "MAP",
56 | 12: "STRUCT",
57 | 13: "UNIONTYPE",
58 | 15: "DECIMAL",
59 | 16: "NULL",
60 | 17: "DATE",
61 | 18: "VARCHAR",
62 | 19: "CHAR",
63 | 20: "INTERVAL_YEAR_MONTH",
64 | 21: "INTERVAL_DAY_TIME",
65 | }
66 | CHARACTER_MAXIMUM_LENGTH = "characterMaximumLength"
67 | PRECISION = "precision"
68 | SCALE = "scale"
69 |
--------------------------------------------------------------------------------
/dev_requirements.txt:
--------------------------------------------------------------------------------
1 | # test-only packages: pin everything to minimize change
2 | flake8==3.4.1
3 | mock==2.0.0
4 | pycodestyle==2.3.1
5 | pytest==3.2.1
6 | pytest-cov==2.5.1
7 | pytest-flake8==0.8.1
8 | pytest-random==0.2
9 | pytest-timeout==1.2.0
10 |
11 | # actual dependencies: let things break if a package changes
12 | requests>=1.0.0
13 | requests_kerberos>=0.12.0
14 | sasl>=0.2.1
15 | pure-sasl>=0.6.2
16 | kerberos>=1.3.0
17 | thrift>=0.10.0
18 | #thrift_sasl>=0.1.0
19 | git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches
20 |
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | """
2 | This file can be used to generate a new version of the TCLIService package
3 | using a TCLIService.thrift URL.
4 |
5 | If no URL is specified, the file for Hive 2.3 will be downloaded.
6 |
7 | Usage:
8 |
9 | python generate.py THRIFT_URL
10 |
11 | or
12 |
13 | python generate.py
14 | """
15 | import shutil
16 | import sys
17 | from os import path
18 | from urllib.request import urlopen
19 | import subprocess
20 |
21 | here = path.abspath(path.dirname(__file__))
22 |
23 | PACKAGE = 'TCLIService'
24 | GENERATED = 'gen-py'
25 |
26 | HIVE_SERVER2_URL = \
27 | 'https://raw.githubusercontent.com/apache/hive/branch-2.3/service-rpc/if/TCLIService.thrift'
28 |
29 |
30 | def save_url(url):
31 | data = urlopen(url).read()
32 | file_path = path.join(here, url.rsplit('/', 1)[-1])
33 | with open(file_path, 'wb') as f:
34 | f.write(data)
35 |
36 |
37 | def main(hive_server2_url):
38 | save_url(hive_server2_url)
39 | hive_server2_path = path.join(here, hive_server2_url.rsplit('/', 1)[-1])
40 |
41 | subprocess.call(['thrift', '-r', '--gen', 'py', hive_server2_path])
42 | shutil.move(path.join(here, PACKAGE), path.join(here, PACKAGE + '.old'))
43 | shutil.move(path.join(here, GENERATED, PACKAGE), path.join(here, PACKAGE))
44 | shutil.rmtree(path.join(here, PACKAGE + '.old'))
45 |
46 |
47 | if __name__ == '__main__':
48 | if len(sys.argv) > 1:
49 | url = sys.argv[1]
50 | else:
51 | url = HIVE_SERVER2_URL
52 | main(url)
53 |
--------------------------------------------------------------------------------
/pyhive/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import unicode_literals
3 | __version__ = '0.7.0'
4 |
--------------------------------------------------------------------------------
/pyhive/common.py:
--------------------------------------------------------------------------------
1 | """Package private common utilities. Do not use directly.
2 |
3 | Many docstrings in this file are based on PEP-249, which is in the public domain.
4 | """
5 |
6 | from __future__ import absolute_import
7 | from __future__ import unicode_literals
8 | from builtins import bytes
9 | from builtins import int
10 | from builtins import object
11 | from builtins import str
12 | from past.builtins import basestring
13 | from pyhive import exc
14 | import abc
15 | import collections
16 | import time
17 | import datetime
18 | from future.utils import with_metaclass
19 | from itertools import islice
20 |
21 | try:
22 | from collections.abc import Iterable
23 | except ImportError:
24 | from collections import Iterable
25 |
26 |
27 | class DBAPICursor(with_metaclass(abc.ABCMeta, object)):
28 | """Base class for some common DB-API logic"""
29 |
30 | _STATE_NONE = 0
31 | _STATE_RUNNING = 1
32 | _STATE_FINISHED = 2
33 |
34 | def __init__(self, poll_interval=1):
35 | self._poll_interval = poll_interval
36 | self._reset_state()
37 | self.lastrowid = None
38 |
39 | def _reset_state(self):
40 | """Reset state about the previous query in preparation for running another query"""
41 | # State to return as part of DB-API
42 | self._rownumber = 0
43 |
44 | # Internal helper state
45 | self._state = self._STATE_NONE
46 | self._data = collections.deque()
47 | self._columns = None
48 |
49 | def _fetch_while(self, fn):
50 | while fn():
51 | self._fetch_more()
52 | if fn():
53 | time.sleep(self._poll_interval)
54 |
55 | @abc.abstractproperty
56 | def description(self):
57 | raise NotImplementedError # pragma: no cover
58 |
59 | def close(self):
60 | """By default, do nothing"""
61 | pass
62 |
63 | @abc.abstractmethod
64 | def _fetch_more(self):
65 | """Get more results, append it to ``self._data``, and update ``self._state``."""
66 | raise NotImplementedError # pragma: no cover
67 |
68 | @property
69 | def rowcount(self):
70 | """By default, return -1 to indicate that this is not supported."""
71 | return -1
72 |
73 | @abc.abstractmethod
74 | def execute(self, operation, parameters=None):
75 | """Prepare and execute a database operation (query or command).
76 |
77 | Parameters may be provided as sequence or mapping and will be bound to variables in the
78 | operation. Variables are specified in a database-specific notation (see the module's
79 | ``paramstyle`` attribute for details).
80 |
81 | Return values are not defined.
82 | """
83 | raise NotImplementedError # pragma: no cover
84 |
85 | def executemany(self, operation, seq_of_parameters):
86 | """Prepare a database operation (query or command) and then execute it against all parameter
87 | sequences or mappings found in the sequence ``seq_of_parameters``.
88 |
89 | Only the final result set is retained.
90 |
91 | Return values are not defined.
92 | """
93 | for parameters in seq_of_parameters[:-1]:
94 | self.execute(operation, parameters)
95 | while self._state != self._STATE_FINISHED:
96 | self._fetch_more()
97 | if seq_of_parameters:
98 | self.execute(operation, seq_of_parameters[-1])
99 |
100 | def fetchone(self):
101 | """Fetch the next row of a query result set, returning a single sequence, or ``None`` when
102 | no more data is available.
103 |
104 | An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to
105 | :py:meth:`execute` did not produce any result set or no call was issued yet.
106 | """
107 | if self._state == self._STATE_NONE:
108 | raise exc.ProgrammingError("No query yet")
109 |
110 | # Sleep until we're done or we have some data to return
111 | self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED)
112 |
113 | if not self._data:
114 | return None
115 | else:
116 | self._rownumber += 1
117 | return self._data.popleft()
118 |
119 | def fetchmany(self, size=None):
120 | """Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a
121 | list of tuples). An empty sequence is returned when no more rows are available.
122 |
123 | The number of rows to fetch per call is specified by the parameter. If it is not given, the
124 | cursor's arraysize determines the number of rows to be fetched. The method should try to
125 | fetch as many rows as indicated by the size parameter. If this is not possible due to the
126 | specified number of rows not being available, fewer rows may be returned.
127 |
128 | An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to
129 | :py:meth:`execute` did not produce any result set or no call was issued yet.
130 | """
131 | if size is None:
132 | size = self.arraysize
133 | return list(islice(iter(self.fetchone, None), size))
134 |
135 | def fetchall(self):
136 | """Fetch all (remaining) rows of a query result, returning them as a sequence of sequences
137 | (e.g. a list of tuples).
138 |
139 | An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to
140 | :py:meth:`execute` did not produce any result set or no call was issued yet.
141 | """
142 | return list(iter(self.fetchone, None))
143 |
144 | @property
145 | def arraysize(self):
146 | """This read/write attribute specifies the number of rows to fetch at a time with
147 | :py:meth:`fetchmany`. It defaults to 1 meaning to fetch a single row at a time.
148 | """
149 | return self._arraysize
150 |
151 | @arraysize.setter
152 | def arraysize(self, value):
153 | self._arraysize = value
154 |
155 | def setinputsizes(self, sizes):
156 | """Does nothing by default"""
157 | pass
158 |
159 | def setoutputsize(self, size, column=None):
160 | """Does nothing by default"""
161 | pass
162 |
163 | #
164 | # Optional DB API Extensions
165 | #
166 |
167 | @property
168 | def rownumber(self):
169 | """This read-only attribute should provide the current 0-based index of the cursor in the
170 | result set.
171 |
172 | The index can be seen as index of the cursor in a sequence (the result set). The next fetch
173 | operation will fetch the row indexed by ``rownumber`` in that sequence.
174 | """
175 | return self._rownumber
176 |
177 | def __next__(self):
178 | """Return the next row from the currently executing SQL statement using the same semantics
179 | as :py:meth:`fetchone`. A ``StopIteration`` exception is raised when the result set is
180 | exhausted.
181 | """
182 | one = self.fetchone()
183 | if one is None:
184 | raise StopIteration
185 | else:
186 | return one
187 |
188 | next = __next__
189 |
190 | def __iter__(self):
191 | """Return self to make cursors compatible to the iteration protocol."""
192 | return self
193 |
194 |
195 | class DBAPITypeObject(object):
196 | # Taken from http://www.python.org/dev/peps/pep-0249/#implementation-hints
197 | def __init__(self, *values):
198 | self.values = values
199 |
200 | def __cmp__(self, other):
201 | if other in self.values:
202 | return 0
203 | if other < self.values:
204 | return 1
205 | else:
206 | return -1
207 |
208 |
209 | class ParamEscaper(object):
210 | _DATE_FORMAT = "%Y-%m-%d"
211 | _TIME_FORMAT = "%H:%M:%S.%f"
212 | _DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT)
213 |
214 | def escape_args(self, parameters):
215 | if isinstance(parameters, dict):
216 | return {k: self.escape_item(v) for k, v in parameters.items()}
217 | elif isinstance(parameters, (list, tuple)):
218 | return tuple(self.escape_item(x) for x in parameters)
219 | else:
220 | raise exc.ProgrammingError("Unsupported param format: {}".format(parameters))
221 |
222 | def escape_number(self, item):
223 | return item
224 |
225 | def escape_string(self, item):
226 | # Need to decode UTF-8 because of old sqlalchemy.
227 | # Newer SQLAlchemy checks dialect.supports_unicode_binds before encoding Unicode strings
228 | # as byte strings. The old version always encodes Unicode as byte strings, which breaks
229 | # string formatting here.
230 | if isinstance(item, bytes):
231 | item = item.decode('utf-8')
232 | # This is good enough when backslashes are literal, newlines are just followed, and the way
233 | # to escape a single quote is to put two single quotes.
234 | # (i.e. only special character is single quote)
235 | return "'{}'".format(item.replace("'", "''"))
236 |
237 | def escape_sequence(self, item):
238 | l = map(str, map(self.escape_item, item))
239 | return '(' + ','.join(l) + ')'
240 |
241 | def escape_datetime(self, item, format, cutoff=0):
242 | dt_str = item.strftime(format)
243 | formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str
244 | return "'{}'".format(formatted)
245 |
246 | def escape_item(self, item):
247 | if item is None:
248 | return 'NULL'
249 | elif isinstance(item, (int, float)):
250 | return self.escape_number(item)
251 | elif isinstance(item, basestring):
252 | return self.escape_string(item)
253 | elif isinstance(item, Iterable):
254 | return self.escape_sequence(item)
255 | elif isinstance(item, datetime.datetime):
256 | return self.escape_datetime(item, self._DATETIME_FORMAT)
257 | elif isinstance(item, datetime.date):
258 | return self.escape_datetime(item, self._DATE_FORMAT)
259 | else:
260 | raise exc.ProgrammingError("Unsupported object {}".format(item))
261 |
262 |
263 | class UniversalSet(object):
264 | """set containing everything"""
265 | def __contains__(self, item):
266 | return True
267 |
--------------------------------------------------------------------------------
/pyhive/exc.py:
--------------------------------------------------------------------------------
1 | """
2 | Package private common utilities. Do not use directly.
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import unicode_literals
6 |
7 | __all__ = [
8 | 'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'InternalError', 'OperationalError',
9 | 'ProgrammingError', 'DataError', 'NotSupportedError',
10 | ]
11 |
12 |
13 | class Error(Exception):
14 | """Exception that is the base class of all other error exceptions.
15 |
16 | You can use this to catch all errors with one single except statement.
17 | """
18 | pass
19 |
20 |
21 | class Warning(Exception):
22 | """Exception raised for important warnings like data truncations while inserting, etc."""
23 | pass
24 |
25 |
26 | class InterfaceError(Error):
27 | """Exception raised for errors that are related to the database interface rather than the
28 | database itself.
29 | """
30 | pass
31 |
32 |
33 | class DatabaseError(Error):
34 | """Exception raised for errors that are related to the database."""
35 | pass
36 |
37 |
38 | class InternalError(DatabaseError):
39 | """Exception raised when the database encounters an internal error, e.g. the cursor is not valid
40 | anymore, the transaction is out of sync, etc."""
41 | pass
42 |
43 |
44 | class OperationalError(DatabaseError):
45 | """Exception raised for errors that are related to the database's operation and not necessarily
46 | under the control of the programmer, e.g. an unexpected disconnect occurs, the data source name
47 | is not found, a transaction could not be processed, a memory allocation error occurred during
48 | processing, etc.
49 | """
50 | pass
51 |
52 |
53 | class ProgrammingError(DatabaseError):
54 | """Exception raised for programming errors, e.g. table not found or already exists, syntax error
55 | in the SQL statement, wrong number of parameters specified, etc.
56 | """
57 | pass
58 |
59 |
60 | class DataError(DatabaseError):
61 | """Exception raised for errors that are due to problems with the processed data like division by
62 | zero, numeric value out of range, etc.
63 | """
64 | pass
65 |
66 |
67 | class NotSupportedError(DatabaseError):
68 | """Exception raised in case a method or database API was used which is not supported by the
69 | database, e.g. requesting a ``.rollback()`` on a connection that does not support transaction or
70 | has transactions turned off.
71 | """
72 | pass
73 |
--------------------------------------------------------------------------------
/pyhive/hive.py:
--------------------------------------------------------------------------------
1 | """DB-API implementation backed by HiveServer2 (Thrift API)
2 |
3 | See http://www.python.org/dev/peps/pep-0249/
4 |
5 | Many docstrings in this file are based on the PEP, which is in the public domain.
6 | """
7 |
8 | from __future__ import absolute_import
9 | from __future__ import unicode_literals
10 |
11 | import base64
12 | import datetime
13 | import re
14 | from decimal import Decimal
15 | from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
16 |
17 |
18 | from TCLIService import TCLIService
19 | from TCLIService import constants
20 | from TCLIService import ttypes
21 | from pyhive import common
22 | from pyhive.common import DBAPITypeObject
23 | # Make all exceptions visible in this module per DB-API
24 | from pyhive.exc import * # noqa
25 | from builtins import range
26 | import contextlib
27 | from future.utils import iteritems
28 | import getpass
29 | import logging
30 | import sys
31 | import thrift.transport.THttpClient
32 | import thrift.protocol.TBinaryProtocol
33 | import thrift.transport.TSocket
34 | import thrift.transport.TTransport
35 |
36 | # PEP 249 module globals
37 | apilevel = '2.0'
38 | threadsafety = 2 # Threads may share the module and connections.
39 | paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
40 |
41 | _logger = logging.getLogger(__name__)
42 |
43 | _TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
44 |
45 | ssl_cert_parameter_map = {
46 | "none": CERT_NONE,
47 | "optional": CERT_OPTIONAL,
48 | "required": CERT_REQUIRED,
49 | }
50 |
51 |
52 | def get_sasl_client(host, sasl_auth, service=None, username=None, password=None):
53 | import sasl
54 | sasl_client = sasl.Client()
55 | sasl_client.setAttr('host', host)
56 |
57 | if sasl_auth == 'GSSAPI':
58 | sasl_client.setAttr('service', service)
59 | elif sasl_auth == 'PLAIN':
60 | sasl_client.setAttr('username', username)
61 | sasl_client.setAttr('password', password)
62 | else:
63 | raise ValueError("sasl_auth only supports GSSAPI and PLAIN")
64 |
65 | sasl_client.init()
66 | return sasl_client
67 |
68 |
69 | def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None):
70 | from pyhive.sasl_compat import PureSASLClient
71 |
72 | if sasl_auth == 'GSSAPI':
73 | sasl_kwargs = {'service': service}
74 | elif sasl_auth == 'PLAIN':
75 | sasl_kwargs = {'username': username, 'password': password}
76 | else:
77 | raise ValueError("sasl_auth only supports GSSAPI and PLAIN")
78 |
79 | return PureSASLClient(host=host, **sasl_kwargs)
80 |
81 |
82 | def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None):
83 | try:
84 | return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password)
85 | # The sasl library is available
86 | except ImportError:
87 | # Fallback to pure-sasl library
88 | return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password)
89 |
90 |
91 | def _parse_timestamp(value):
92 | if value:
93 | match = _TIMESTAMP_PATTERN.match(value)
94 | if match:
95 | if match.group(2):
96 | format = '%Y-%m-%d %H:%M:%S.%f'
97 | # use the pattern to truncate the value
98 | value = match.group()
99 | else:
100 | format = '%Y-%m-%d %H:%M:%S'
101 | value = datetime.datetime.strptime(value, format)
102 | else:
103 | raise Exception(
104 | 'Cannot convert "{}" into a datetime'.format(value))
105 | else:
106 | value = None
107 | return value
108 |
109 |
110 | TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal,
111 | "TIMESTAMP_TYPE": _parse_timestamp}
112 |
113 |
114 | class HiveParamEscaper(common.ParamEscaper):
115 | def escape_string(self, item):
116 | # backslashes and single quotes need to be escaped
117 | # TODO verify against parser
118 | # Need to decode UTF-8 because of old sqlalchemy.
119 | # Newer SQLAlchemy checks dialect.supports_unicode_binds before encoding Unicode strings
120 | # as byte strings. The old version always encodes Unicode as byte strings, which breaks
121 | # string formatting here.
122 | if isinstance(item, bytes):
123 | item = item.decode('utf-8')
124 | return "'{}'".format(
125 | item
126 | .replace('\\', '\\\\')
127 | .replace("'", "\\'")
128 | .replace('\r', '\\r')
129 | .replace('\n', '\\n')
130 | .replace('\t', '\\t')
131 | )
132 |
133 |
134 | _escaper = HiveParamEscaper()
135 |
136 |
137 | def connect(*args, **kwargs):
138 | """Constructor for creating a connection to the database. See class :py:class:`Connection` for
139 | arguments.
140 |
141 | :returns: a :py:class:`Connection` object.
142 | """
143 | return Connection(*args, **kwargs)
144 |
145 |
146 | class Connection(object):
147 | """Wraps a Thrift session"""
148 |
149 | def __init__(
150 | self,
151 | host=None,
152 | port=None,
153 | scheme=None,
154 | username=None,
155 | database='default',
156 | auth=None,
157 | configuration=None,
158 | kerberos_service_name=None,
159 | password=None,
160 | check_hostname=None,
161 | ssl_cert=None,
162 | thrift_transport=None
163 | ):
164 | """Connect to HiveServer2
165 |
166 | :param host: What host HiveServer2 runs on
167 | :param port: What port HiveServer2 runs on. Defaults to 10000.
168 | :param auth: The value of hive.server2.authentication used by HiveServer2.
169 | Defaults to ``NONE``.
170 | :param configuration: A dictionary of Hive settings (functionally same as the `set` command)
171 | :param kerberos_service_name: Use with auth='KERBEROS' only
172 | :param password: Use with auth='LDAP' or auth='CUSTOM' only
173 | :param thrift_transport: A ``TTransportBase`` for custom advanced usage.
174 | Incompatible with host, port, auth, kerberos_service_name, and password.
175 |
176 | The way to support LDAP and GSSAPI is originated from cloudera/Impyla:
177 | https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62
178 | /impala/_thrift_api.py#L152-L160
179 | """
180 | if scheme in ("https", "http") and thrift_transport is None:
181 | port = port or 1000
182 | ssl_context = None
183 | if scheme == "https":
184 | ssl_context = create_default_context()
185 | ssl_context.check_hostname = check_hostname == "true"
186 | ssl_cert = ssl_cert or "none"
187 | ssl_context.verify_mode = ssl_cert_parameter_map.get(ssl_cert, CERT_NONE)
188 | thrift_transport = thrift.transport.THttpClient.THttpClient(
189 | uri_or_host="{scheme}://{host}:{port}/cliservice/".format(
190 | scheme=scheme, host=host, port=port
191 | ),
192 | ssl_context=ssl_context,
193 | )
194 |
195 | if auth in ("BASIC", "NOSASL", "NONE", None):
196 | # Always needs the Authorization header
197 | self._set_authorization_header(thrift_transport, username, password)
198 | elif auth == "KERBEROS" and kerberos_service_name:
199 | self._set_kerberos_header(thrift_transport, kerberos_service_name, host)
200 | else:
201 | raise ValueError(
202 | "Authentication is not valid use one of:"
203 | "BASIC, NOSASL, KERBEROS, NONE"
204 | )
205 | host, port, auth, kerberos_service_name, password = (
206 | None, None, None, None, None
207 | )
208 |
209 | username = username or getpass.getuser()
210 | configuration = configuration or {}
211 |
212 | if (password is not None) != (auth in ('LDAP', 'CUSTOM')):
213 | raise ValueError("Password should be set if and only if in LDAP or CUSTOM mode; "
214 | "Remove password or use one of those modes")
215 | if (kerberos_service_name is not None) != (auth == 'KERBEROS'):
216 | raise ValueError("kerberos_service_name should be set if and only if in KERBEROS mode")
217 | if thrift_transport is not None:
218 | has_incompatible_arg = (
219 | host is not None
220 | or port is not None
221 | or auth is not None
222 | or kerberos_service_name is not None
223 | or password is not None
224 | )
225 | if has_incompatible_arg:
226 | raise ValueError("thrift_transport cannot be used with "
227 | "host/port/auth/kerberos_service_name/password")
228 |
229 | if thrift_transport is not None:
230 | self._transport = thrift_transport
231 | else:
232 | if port is None:
233 | port = 10000
234 | if auth is None:
235 | auth = 'NONE'
236 | socket = thrift.transport.TSocket.TSocket(host, port)
237 | if auth == 'NOSASL':
238 | # NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml
239 | self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
240 | elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
241 | # Defer import so package dependency is optional
242 | import thrift_sasl
243 |
244 | if auth == 'KERBEROS':
245 | # KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
246 | sasl_auth = 'GSSAPI'
247 | else:
248 | sasl_auth = 'PLAIN'
249 | if password is None:
250 | # Password doesn't matter in NONE mode, just needs to be nonempty.
251 | password = 'x'
252 |
253 | self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket)
254 | else:
255 | # All HS2 config options:
256 | # https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration
257 | # PAM currently left to end user via thrift_transport option.
258 | raise NotImplementedError(
259 | "Only NONE, NOSASL, LDAP, KERBEROS, CUSTOM "
260 | "authentication are supported, got {}".format(auth))
261 |
262 | protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
263 | self._client = TCLIService.Client(protocol)
264 | # oldest version that still contains features we care about
265 | # "V6 uses binary type for binary payload (was string) and uses columnar result set"
266 | protocol_version = ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6
267 |
268 | try:
269 | self._transport.open()
270 | open_session_req = ttypes.TOpenSessionReq(
271 | client_protocol=protocol_version,
272 | configuration=configuration,
273 | username=username,
274 | )
275 | response = self._client.OpenSession(open_session_req)
276 | _check_status(response)
277 | assert response.sessionHandle is not None, "Expected a session from OpenSession"
278 | self._sessionHandle = response.sessionHandle
279 | assert response.serverProtocolVersion == protocol_version, \
280 | "Unable to handle protocol version {}".format(response.serverProtocolVersion)
281 | with contextlib.closing(self.cursor()) as cursor:
282 | cursor.execute('USE `{}`'.format(database))
283 | except:
284 | self._transport.close()
285 | raise
286 |
287 | @staticmethod
288 | def _set_authorization_header(transport, username=None, password=None):
289 | username = username or "user"
290 | password = password or "pass"
291 | auth_credentials = "{username}:{password}".format(
292 | username=username, password=password
293 | ).encode("UTF-8")
294 | auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode(
295 | "UTF-8"
296 | )
297 | transport.setCustomHeaders(
298 | {
299 | "Authorization": "Basic {auth_credentials_base64}".format(
300 | auth_credentials_base64=auth_credentials_base64
301 | )
302 | }
303 | )
304 |
305 | @staticmethod
306 | def _set_kerberos_header(transport, kerberos_service_name, host):
307 | import kerberos
308 |
309 | __, krb_context = kerberos.authGSSClientInit(
310 | service="{kerberos_service_name}@{host}".format(
311 | kerberos_service_name=kerberos_service_name, host=host
312 | )
313 | )
314 | kerberos.authGSSClientClean(krb_context, "")
315 | kerberos.authGSSClientStep(krb_context, "")
316 | auth_header = kerberos.authGSSClientResponse(krb_context)
317 |
318 | transport.setCustomHeaders(
319 | {
320 | "Authorization": "Negotiate {auth_header}".format(
321 | auth_header=auth_header
322 | )
323 | }
324 | )
325 |
326 | def __enter__(self):
327 | """Transport should already be opened by __init__"""
328 | return self
329 |
330 | def __exit__(self, exc_type, exc_val, exc_tb):
331 | """Call close"""
332 | self.close()
333 |
334 | def close(self):
335 | """Close the underlying session and Thrift transport"""
336 | req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle)
337 | response = self._client.CloseSession(req)
338 | self._transport.close()
339 | _check_status(response)
340 |
341 | def commit(self):
342 | """Hive does not support transactions, so this does nothing."""
343 | pass
344 |
345 | def cursor(self, *args, **kwargs):
346 | """Return a new :py:class:`Cursor` object using the connection."""
347 | return Cursor(self, *args, **kwargs)
348 |
349 | @property
350 | def client(self):
351 | return self._client
352 |
353 | @property
354 | def sessionHandle(self):
355 | return self._sessionHandle
356 |
357 | def rollback(self):
358 | raise NotSupportedError("Hive does not have transactions") # pragma: no cover
359 |
360 |
361 | class Cursor(common.DBAPICursor):
362 | """These objects represent a database cursor, which is used to manage the context of a fetch
363 | operation.
364 |
365 | Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
366 | visible by other cursors or connections.
367 | """
368 |
369 | def __init__(self, connection, arraysize=1000):
370 | self._operationHandle = None
371 | super(Cursor, self).__init__()
372 | self._arraysize = arraysize
373 | self._connection = connection
374 |
375 | def _reset_state(self):
376 | """Reset state about the previous query in preparation for running another query"""
377 | super(Cursor, self)._reset_state()
378 | self._description = None
379 | if self._operationHandle is not None:
380 | request = ttypes.TCloseOperationReq(self._operationHandle)
381 | try:
382 | response = self._connection.client.CloseOperation(request)
383 | _check_status(response)
384 | finally:
385 | self._operationHandle = None
386 |
387 | @property
388 | def arraysize(self):
389 | return self._arraysize
390 |
391 | @arraysize.setter
392 | def arraysize(self, value):
393 | """Array size cannot be None, and should be an integer"""
394 | default_arraysize = 1000
395 | try:
396 | self._arraysize = int(value) or default_arraysize
397 | except TypeError:
398 | self._arraysize = default_arraysize
399 |
400 | @property
401 | def description(self):
402 | """This read-only attribute is a sequence of 7-item sequences.
403 |
404 | Each of these sequences contains information describing one result column:
405 |
406 | - name
407 | - type_code
408 | - display_size (None in current implementation)
409 | - internal_size (None in current implementation)
410 | - precision (None in current implementation)
411 | - scale (None in current implementation)
412 | - null_ok (always True in current implementation)
413 |
414 | This attribute will be ``None`` for operations that do not return rows or if the cursor has
415 | not had an operation invoked via the :py:meth:`execute` method yet.
416 |
417 | The ``type_code`` can be interpreted by comparing it to the Type Objects specified in the
418 | section below.
419 | """
420 | if self._operationHandle is None or not self._operationHandle.hasResultSet:
421 | return None
422 | if self._description is None:
423 | req = ttypes.TGetResultSetMetadataReq(self._operationHandle)
424 | response = self._connection.client.GetResultSetMetadata(req)
425 | _check_status(response)
426 | columns = response.schema.columns
427 | self._description = []
428 | for col in columns:
429 | primary_type_entry = col.typeDesc.types[0]
430 | if primary_type_entry.primitiveEntry is None:
431 | # All fancy stuff maps to string
432 | type_code = ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.STRING_TYPE]
433 | else:
434 | type_id = primary_type_entry.primitiveEntry.type
435 | type_code = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
436 | self._description.append((
437 | col.columnName.decode('utf-8') if sys.version_info[0] == 2 else col.columnName,
438 | type_code.decode('utf-8') if sys.version_info[0] == 2 else type_code,
439 | None, None, None, None, True
440 | ))
441 | return self._description
442 |
443 | def __enter__(self):
444 | return self
445 |
446 | def __exit__(self, exc_type, exc_val, exc_tb):
447 | self.close()
448 |
449 | def close(self):
450 | """Close the operation handle"""
451 | self._reset_state()
452 |
453 | def execute(self, operation, parameters=None, **kwargs):
454 | """Prepare and execute a database operation (query or command).
455 |
456 | Return values are not defined.
457 | """
458 | # backward compatibility with Python < 3.7
459 | for kw in ['async', 'async_']:
460 | if kw in kwargs:
461 | async_ = kwargs[kw]
462 | break
463 | else:
464 | async_ = False
465 |
466 | # Prepare statement
467 | if parameters is None:
468 | sql = operation
469 | else:
470 | sql = operation % _escaper.escape_args(parameters)
471 |
472 | self._reset_state()
473 |
474 | self._state = self._STATE_RUNNING
475 | _logger.info('%s', sql)
476 |
477 | req = ttypes.TExecuteStatementReq(self._connection.sessionHandle,
478 | sql, runAsync=async_)
479 | _logger.debug(req)
480 | response = self._connection.client.ExecuteStatement(req)
481 | _check_status(response)
482 | self._operationHandle = response.operationHandle
483 |
484 | def cancel(self):
485 | req = ttypes.TCancelOperationReq(
486 | operationHandle=self._operationHandle,
487 | )
488 | response = self._connection.client.CancelOperation(req)
489 | _check_status(response)
490 |
491 | def _fetch_more(self):
492 | """Send another TFetchResultsReq and update state"""
493 | assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more"
494 | assert(self._operationHandle is not None), "Should have an op handle in _fetch_more"
495 | if not self._operationHandle.hasResultSet:
496 | raise ProgrammingError("No result set")
497 | req = ttypes.TFetchResultsReq(
498 | operationHandle=self._operationHandle,
499 | orientation=ttypes.TFetchOrientation.FETCH_NEXT,
500 | maxRows=self.arraysize,
501 | )
502 | response = self._connection.client.FetchResults(req)
503 | _check_status(response)
504 | schema = self.description
505 | assert not response.results.rows, 'expected data in columnar format'
506 | columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
507 | zip(response.results.columns, schema)]
508 | new_data = list(zip(*columns))
509 | self._data += new_data
510 | # response.hasMoreRows seems to always be False, so we instead check the number of rows
511 | # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
512 | # if not response.hasMoreRows:
513 | if not new_data:
514 | self._state = self._STATE_FINISHED
515 |
516 | def poll(self, get_progress_update=True):
517 | """Poll for and return the raw status data provided by the Hive Thrift REST API.
518 | :returns: ``ttypes.TGetOperationStatusResp``
519 | :raises: ``ProgrammingError`` when no query has been started
520 | .. note::
521 | This is not a part of DB-API.
522 | """
523 | if self._state == self._STATE_NONE:
524 | raise ProgrammingError("No query yet")
525 |
526 | req = ttypes.TGetOperationStatusReq(
527 | operationHandle=self._operationHandle,
528 | getProgressUpdate=get_progress_update,
529 | )
530 | response = self._connection.client.GetOperationStatus(req)
531 | _check_status(response)
532 |
533 | return response
534 |
535 | def fetch_logs(self):
536 | """Retrieve the logs produced by the execution of the query.
537 | Can be called multiple times to fetch the logs produced after the previous call.
538 | :returns: list
539 | :raises: ``ProgrammingError`` when no query has been started
540 | .. note::
541 | This is not a part of DB-API.
542 | """
543 | if self._state == self._STATE_NONE:
544 | raise ProgrammingError("No query yet")
545 |
546 | try: # Older Hive instances require logs to be retrieved using GetLog
547 | req = ttypes.TGetLogReq(operationHandle=self._operationHandle)
548 | logs = self._connection.client.GetLog(req).log.splitlines()
549 | except ttypes.TApplicationException as e: # Otherwise, retrieve logs using newer method
550 | if e.type != ttypes.TApplicationException.UNKNOWN_METHOD:
551 | raise
552 | logs = []
553 | while True:
554 | req = ttypes.TFetchResultsReq(
555 | operationHandle=self._operationHandle,
556 | orientation=ttypes.TFetchOrientation.FETCH_NEXT,
557 | maxRows=self.arraysize,
558 | fetchType=1 # 0: results, 1: logs
559 | )
560 | response = self._connection.client.FetchResults(req)
561 | _check_status(response)
562 | assert not response.results.rows, 'expected data in columnar format'
563 | assert len(response.results.columns) == 1, response.results.columns
564 | new_logs = _unwrap_column(response.results.columns[0])
565 | logs += new_logs
566 |
567 | if not new_logs:
568 | break
569 |
570 | return logs
571 |
572 |
573 | #
574 | # Type Objects and Constructors
575 | #
576 |
577 |
578 | for type_id in constants.PRIMITIVE_TYPES:
579 | name = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
580 | setattr(sys.modules[__name__], name, DBAPITypeObject([name]))
581 |
582 |
583 | #
584 | # Private utilities
585 | #
586 |
587 |
588 | def _unwrap_column(col, type_=None):
589 | """Return a list of raw values from a TColumn instance."""
590 | for attr, wrapper in iteritems(col.__dict__):
591 | if wrapper is not None:
592 | result = wrapper.values
593 | nulls = wrapper.nulls # bit set describing what's null
594 | assert isinstance(nulls, bytes)
595 | for i, char in enumerate(nulls):
596 | byte = ord(char) if sys.version_info[0] == 2 else char
597 | for b in range(8):
598 | if byte & (1 << b):
599 | result[i * 8 + b] = None
600 | converter = TYPES_CONVERTER.get(type_, None)
601 | if converter and type_:
602 | result = [converter(row) if row else row for row in result]
603 | return result
604 | raise DataError("Got empty column value {}".format(col)) # pragma: no cover
605 |
606 |
607 | def _check_status(response):
608 | """Raise an OperationalError if the status is not success"""
609 | _logger.debug(response)
610 | if response.status.statusCode != ttypes.TStatusCode.SUCCESS_STATUS:
611 | raise OperationalError(response)
612 |
--------------------------------------------------------------------------------
/pyhive/presto.py:
--------------------------------------------------------------------------------
1 | """DB-API implementation backed by Presto
2 |
3 | See http://www.python.org/dev/peps/pep-0249/
4 |
5 | Many docstrings in this file are based on the PEP, which is in the public domain.
6 | """
7 |
8 | from __future__ import absolute_import
9 | from __future__ import unicode_literals
10 |
11 | from builtins import object
12 | from decimal import Decimal
13 |
14 | from pyhive import common
15 | from pyhive.common import DBAPITypeObject
16 | # Make all exceptions visible in this module per DB-API
17 | from pyhive.exc import * # noqa
18 | import base64
19 | import getpass
20 | import datetime
21 | import logging
22 | import requests
23 | from requests.auth import HTTPBasicAuth
24 | import os
25 |
26 | try: # Python 3
27 | import urllib.parse as urlparse
28 | except ImportError: # Python 2
29 | import urlparse
30 |
31 |
32 | # PEP 249 module globals
33 | apilevel = '2.0'
34 | threadsafety = 2 # Threads may share the module and connections.
35 | paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
36 |
37 | _logger = logging.getLogger(__name__)
38 |
39 | TYPES_CONVERTER = {
40 | "decimal": Decimal,
41 | # As of Presto 0.69, binary data is returned as the varbinary type in base64 format
42 | "varbinary": base64.b64decode
43 | }
44 |
45 | class PrestoParamEscaper(common.ParamEscaper):
46 | def escape_datetime(self, item, format):
47 | _type = "timestamp" if isinstance(item, datetime.datetime) else "date"
48 | formatted = super(PrestoParamEscaper, self).escape_datetime(item, format, 3)
49 | return "{} {}".format(_type, formatted)
50 |
51 |
52 | _escaper = PrestoParamEscaper()
53 |
54 |
55 | def connect(*args, **kwargs):
56 | """Constructor for creating a connection to the database. See class :py:class:`Connection` for
57 | arguments.
58 |
59 | :returns: a :py:class:`Connection` object.
60 | """
61 | return Connection(*args, **kwargs)
62 |
63 |
64 | class Connection(object):
65 | """Presto does not have a notion of a persistent connection.
66 |
67 | Thus, these objects are small stateless factories for cursors, which do all the real work.
68 | """
69 |
70 | def __init__(self, *args, **kwargs):
71 | self._args = args
72 | self._kwargs = kwargs
73 |
74 | def close(self):
75 | """Presto does not have anything to close"""
76 | # TODO cancel outstanding queries?
77 | pass
78 |
79 | def commit(self):
80 | """Presto does not support transactions"""
81 | pass
82 |
83 | def cursor(self):
84 | """Return a new :py:class:`Cursor` object using the connection."""
85 | return Cursor(*self._args, **self._kwargs)
86 |
87 | def rollback(self):
88 | raise NotSupportedError("Presto does not have transactions") # pragma: no cover
89 |
90 |
91 | class Cursor(common.DBAPICursor):
92 | """These objects represent a database cursor, which is used to manage the context of a fetch
93 | operation.
94 |
95 | Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
96 | visible by other cursors or connections.
97 | """
98 |
99 | def __init__(self, host, port='8080', username=None, principal_username=None, catalog='hive',
100 | schema='default', poll_interval=1, source='pyhive', session_props=None,
101 | protocol='http', password=None, requests_session=None, requests_kwargs=None,
102 | KerberosRemoteServiceName=None, KerberosPrincipal=None,
103 | KerberosConfigPath=None, KerberosKeytabPath=None,
104 | KerberosCredentialCachePath=None, KerberosUseCanonicalHostname=None):
105 | """
106 | :param host: hostname to connect to, e.g. ``presto.example.com``
107 | :param port: int -- port, defaults to 8080
108 | :param username: string -- defaults to system user name
109 | :param principal_username: string -- defaults to ``username`` argument if it exists,
110 | else defaults to system user name
111 | :param catalog: string -- defaults to ``hive``
112 | :param schema: string -- defaults to ``default``
113 | :param poll_interval: float -- how often to ask the Presto REST interface for a progress
114 | update, defaults to a second
115 | :param source: string -- arbitrary identifier (shows up in the Presto monitoring page)
116 | :param protocol: string -- network protocol, valid options are ``http`` and ``https``.
117 | defaults to ``http``
118 | :param password: string -- Deprecated. Defaults to ``None``.
119 | Using BasicAuth, requires ``https``.
120 | Prefer ``requests_kwargs={'auth': HTTPBasicAuth(username, password)}``.
121 | May not be specified with ``requests_kwargs['auth']``.
122 | :param requests_session: a ``requests.Session`` object for advanced usage. If absent, this
123 | class will use the default requests behavior of making a new session per HTTP request.
124 | Caller is responsible for closing session.
125 | :param requests_kwargs: Additional ``**kwargs`` to pass to requests
126 | :param KerberosRemoteServiceName: string -- Presto coordinator Kerberos service name.
127 | This parameter is required for Kerberos authentiation.
128 | :param KerberosPrincipal: string -- The principal to use when authenticating to
129 | the Presto coordinator.
130 | :param KerberosConfigPath: string -- Kerberos configuration file.
131 | (default: /etc/krb5.conf)
132 | :param KerberosKeytabPath: string -- Kerberos keytab file.
133 | :param KerberosCredentialCachePath: string -- Kerberos credential cache.
134 | :param KerberosUseCanonicalHostname: boolean -- Use the canonical hostname of the
135 | Presto coordinator for the Kerberos service principal by first resolving the
136 | hostname to an IP address and then doing a reverse DNS lookup for that IP address.
137 | This is enabled by default.
138 | """
139 | super(Cursor, self).__init__(poll_interval)
140 | # Config
141 | self._host = host
142 | self._port = port
143 | """
144 | Presto User Impersonation: https://docs.starburstdata.com/latest/security/impersonation.html
145 |
146 | User impersonation allows the execution of queries in Presto based on principal_username
147 | argument, instead of executing the query as the account which authenticated against Presto.
148 | (Usually a service account)
149 |
150 | Allows for a service account to authenticate with Presto, and then leverage the
151 | principal_username as the user Presto will execute the query as. This is required by
152 | applications that leverage authentication methods like SAML, where the application has a
153 | username, but not a password to still leverage user specific Presto Resource Groups and
154 | Authorization rules that would not be applied when only using a shared service account.
155 | This also allows auditing of who is executing a query in these environments, instead of
156 | having all queryes run by the shared service account.
157 | """
158 | self._username = principal_username or username or getpass.getuser()
159 | self._catalog = catalog
160 | self._schema = schema
161 | self._arraysize = 1
162 | self._poll_interval = poll_interval
163 | self._source = source
164 | self._session_props = session_props if session_props is not None else {}
165 | self.last_query_id = None
166 |
167 | if protocol not in ('http', 'https'):
168 | raise ValueError("Protocol must be http/https, was {!r}".format(protocol))
169 | self._protocol = protocol
170 |
171 | self._requests_session = requests_session or requests
172 |
173 | requests_kwargs = dict(requests_kwargs) if requests_kwargs is not None else {}
174 |
175 | if KerberosRemoteServiceName is not None:
176 | from requests_kerberos import HTTPKerberosAuth, OPTIONAL
177 |
178 | hostname_override = None
179 | if KerberosUseCanonicalHostname is not None \
180 | and KerberosUseCanonicalHostname.lower() == 'false':
181 | hostname_override = host
182 | if KerberosConfigPath is not None:
183 | os.environ['KRB5_CONFIG'] = KerberosConfigPath
184 | if KerberosKeytabPath is not None:
185 | os.environ['KRB5_CLIENT_KTNAME'] = KerberosKeytabPath
186 | if KerberosCredentialCachePath is not None:
187 | os.environ['KRB5CCNAME'] = KerberosCredentialCachePath
188 |
189 | requests_kwargs['auth'] = HTTPKerberosAuth(mutual_authentication=OPTIONAL,
190 | principal=KerberosPrincipal,
191 | service=KerberosRemoteServiceName,
192 | hostname_override=hostname_override)
193 |
194 | else:
195 | if password is not None and 'auth' in requests_kwargs:
196 | raise ValueError("Cannot use both password and requests_kwargs authentication")
197 | for k in ('method', 'url', 'data', 'headers'):
198 | if k in requests_kwargs:
199 | raise ValueError("Cannot override requests argument {}".format(k))
200 | if password is not None:
201 | requests_kwargs['auth'] = HTTPBasicAuth(username, password)
202 | if protocol != 'https':
203 | raise ValueError("Protocol must be https when passing a password")
204 | self._requests_kwargs = requests_kwargs
205 |
206 | self._reset_state()
207 |
208 | def _reset_state(self):
209 | """Reset state about the previous query in preparation for running another query"""
210 | super(Cursor, self)._reset_state()
211 | self._nextUri = None
212 | self._columns = None
213 |
214 | @property
215 | def description(self):
216 | """This read-only attribute is a sequence of 7-item sequences.
217 |
218 | Each of these sequences contains information describing one result column:
219 |
220 | - name
221 | - type_code
222 | - display_size (None in current implementation)
223 | - internal_size (None in current implementation)
224 | - precision (None in current implementation)
225 | - scale (None in current implementation)
226 | - null_ok (always True in current implementation)
227 |
228 | The ``type_code`` can be interpreted by comparing it to the Type Objects specified in the
229 | section below.
230 | """
231 | # Sleep until we're done or we got the columns
232 | self._fetch_while(
233 | lambda: self._columns is None and
234 | self._state not in (self._STATE_NONE, self._STATE_FINISHED)
235 | )
236 | if self._columns is None:
237 | return None
238 | return [
239 | # name, type_code, display_size, internal_size, precision, scale, null_ok
240 | (col['name'], col['type'], None, None, None, None, True)
241 | for col in self._columns
242 | ]
243 |
244 | def execute(self, operation, parameters=None):
245 | """Prepare and execute a database operation (query or command).
246 |
247 | Return values are not defined.
248 | """
249 | headers = {
250 | 'X-Presto-Catalog': self._catalog,
251 | 'X-Presto-Schema': self._schema,
252 | 'X-Presto-Source': self._source,
253 | 'X-Presto-User': self._username,
254 | }
255 |
256 | if self._session_props:
257 | headers['X-Presto-Session'] = ','.join(
258 | '{}={}'.format(propname, propval)
259 | for propname, propval in self._session_props.items()
260 | )
261 |
262 | # Prepare statement
263 | if parameters is None:
264 | sql = operation
265 | else:
266 | sql = operation % _escaper.escape_args(parameters)
267 |
268 | self._reset_state()
269 |
270 | self._state = self._STATE_RUNNING
271 | url = urlparse.urlunparse((
272 | self._protocol,
273 | '{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None))
274 | _logger.info('%s', sql)
275 | _logger.debug("Headers: %s", headers)
276 | response = self._requests_session.post(
277 | url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs)
278 | self._process_response(response)
279 |
280 | def cancel(self):
281 | if self._state == self._STATE_NONE:
282 | raise ProgrammingError("No query yet")
283 | if self._nextUri is None:
284 | assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None"
285 | return
286 |
287 | response = self._requests_session.delete(self._nextUri, **self._requests_kwargs)
288 | if response.status_code != requests.codes.no_content:
289 | fmt = "Unexpected status code after cancel {}\n{}"
290 | raise OperationalError(fmt.format(response.status_code, response.content))
291 |
292 | self._state = self._STATE_FINISHED
293 | self._nextUri = None
294 |
295 | def poll(self):
296 | """Poll for and return the raw status data provided by the Presto REST API.
297 |
298 | :returns: dict -- JSON status information or ``None`` if the query is done
299 | :raises: ``ProgrammingError`` when no query has been started
300 |
301 | .. note::
302 | This is not a part of DB-API.
303 | """
304 | if self._state == self._STATE_NONE:
305 | raise ProgrammingError("No query yet")
306 | if self._nextUri is None:
307 | assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None"
308 | return None
309 | response = self._requests_session.get(self._nextUri, **self._requests_kwargs)
310 | self._process_response(response)
311 | return response.json()
312 |
313 | def _fetch_more(self):
314 | """Fetch the next URI and update state"""
315 | self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs))
316 |
317 | def _process_data(self, rows):
318 | for i, col in enumerate(self.description):
319 | col_type = col[1].split("(")[0].lower()
320 | if col_type in TYPES_CONVERTER:
321 | for row in rows:
322 | if row[i] is not None:
323 | row[i] = TYPES_CONVERTER[col_type](row[i])
324 |
325 | def _process_response(self, response):
326 | """Given the JSON response from Presto's REST API, update the internal state with the next
327 | URI and any data from the response
328 | """
329 | # TODO handle HTTP 503
330 | if response.status_code != requests.codes.ok:
331 | fmt = "Unexpected status code {}\n{}"
332 | raise OperationalError(fmt.format(response.status_code, response.content))
333 |
334 | response_json = response.json()
335 | _logger.debug("Got response %s", response_json)
336 | assert self._state == self._STATE_RUNNING, "Should be running if processing response"
337 | self._nextUri = response_json.get('nextUri')
338 | self._columns = response_json.get('columns')
339 | if 'id' in response_json:
340 | self.last_query_id = response_json['id']
341 | if 'X-Presto-Clear-Session' in response.headers:
342 | propname = response.headers['X-Presto-Clear-Session']
343 | self._session_props.pop(propname, None)
344 | if 'X-Presto-Set-Session' in response.headers:
345 | propname, propval = response.headers['X-Presto-Set-Session'].split('=', 1)
346 | self._session_props[propname] = propval
347 | if 'data' in response_json:
348 | assert self._columns
349 | new_data = response_json['data']
350 | self._process_data(new_data)
351 | self._data += map(tuple, new_data)
352 | if 'nextUri' not in response_json:
353 | self._state = self._STATE_FINISHED
354 | if 'error' in response_json:
355 | raise DatabaseError(response_json['error'])
356 |
357 |
358 | #
359 | # Type Objects and Constructors
360 | #
361 |
362 |
363 | # See types in presto-main/src/main/java/com/facebook/presto/tuple/TupleInfo.java
364 | FIXED_INT_64 = DBAPITypeObject(['bigint'])
365 | VARIABLE_BINARY = DBAPITypeObject(['varchar'])
366 | DOUBLE = DBAPITypeObject(['double'])
367 | BOOLEAN = DBAPITypeObject(['boolean'])
368 |
--------------------------------------------------------------------------------
/pyhive/sasl_compat.py:
--------------------------------------------------------------------------------
1 | # Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py
2 | # which uses Apache-2.0 license as of 21 May 2023.
3 | # This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl
4 | # via PR https://github.com/cloudera/impyla/pull/179
5 | # Even though thrift_sasl lists pure-sasl as dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34
6 | # but it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82
7 | # Hence this code is required for the fallback to work.
8 |
9 |
10 | from puresasl.client import SASLClient, SASLError
11 | from contextlib import contextmanager
12 |
13 | @contextmanager
14 | def error_catcher(self, Exc = Exception):
15 | try:
16 | self.error = None
17 | yield
18 | except Exc as e:
19 | self.error = str(e)
20 |
21 |
22 | class PureSASLClient(SASLClient):
23 | def __init__(self, *args, **kwargs):
24 | self.error = None
25 | super(PureSASLClient, self).__init__(*args, **kwargs)
26 |
27 | def start(self, mechanism):
28 | with error_catcher(self, SASLError):
29 | if isinstance(mechanism, list):
30 | self.choose_mechanism(mechanism)
31 | else:
32 | self.choose_mechanism([mechanism])
33 | return True, self.mechanism, self.process()
34 | # else
35 | return False, mechanism, None
36 |
37 | def encode(self, incoming):
38 | with error_catcher(self):
39 | return True, self.unwrap(incoming)
40 | # else
41 | return False, None
42 |
43 | def decode(self, outgoing):
44 | with error_catcher(self):
45 | return True, self.wrap(outgoing)
46 | # else
47 | return False, None
48 |
49 | def step(self, challenge=None):
50 | with error_catcher(self):
51 | return True, self.process(challenge)
52 | # else
53 | return False, None
54 |
55 | def getError(self):
56 | return self.error
57 |
--------------------------------------------------------------------------------
/pyhive/sqlalchemy_hive.py:
--------------------------------------------------------------------------------
1 | """Integration between SQLAlchemy and Hive.
2 |
3 | Some code based on
4 | https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py
5 | which is released under the MIT license.
6 | """
7 |
8 | from __future__ import absolute_import
9 | from __future__ import unicode_literals
10 |
11 | import datetime
12 | import decimal
13 |
14 | import re
15 | from sqlalchemy import exc
16 | from sqlalchemy.sql import text
17 | try:
18 | from sqlalchemy import processors
19 | except ImportError:
20 | # Required for SQLAlchemy>=2.0
21 | from sqlalchemy.engine import processors
22 | from sqlalchemy import types
23 | from sqlalchemy import util
24 | # TODO shouldn't use mysql type
25 | try:
26 | from sqlalchemy.databases import mysql
27 | mysql_tinyinteger = mysql.MSTinyInteger
28 | except ImportError:
29 | # Required for SQLAlchemy>2.0
30 | from sqlalchemy.dialects import mysql
31 | mysql_tinyinteger = mysql.base.MSTinyInteger
32 | from sqlalchemy.engine import default
33 | from sqlalchemy.sql import compiler
34 | from sqlalchemy.sql.compiler import SQLCompiler
35 |
36 | from pyhive import hive
37 | from pyhive.common import UniversalSet
38 |
39 | from dateutil.parser import parse
40 | from decimal import Decimal
41 |
42 |
43 | class HiveStringTypeBase(types.TypeDecorator):
44 | """Translates strings returned by Thrift into something else"""
45 | impl = types.String
46 |
47 | def process_bind_param(self, value, dialect):
48 | raise NotImplementedError("Writing to Hive not supported")
49 |
50 |
51 | class HiveDate(HiveStringTypeBase):
52 | """Translates date strings to date objects"""
53 | impl = types.DATE
54 |
55 | def process_result_value(self, value, dialect):
56 | return processors.str_to_date(value)
57 |
58 | def result_processor(self, dialect, coltype):
59 | def process(value):
60 | if isinstance(value, datetime.datetime):
61 | return value.date()
62 | elif isinstance(value, datetime.date):
63 | return value
64 | elif value is not None:
65 | return parse(value).date()
66 | else:
67 | return None
68 |
69 | return process
70 |
71 | def adapt(self, impltype, **kwargs):
72 | return self.impl
73 |
74 |
75 | class HiveTimestamp(HiveStringTypeBase):
76 | """Translates timestamp strings to datetime objects"""
77 | impl = types.TIMESTAMP
78 |
79 | def process_result_value(self, value, dialect):
80 | return processors.str_to_datetime(value)
81 |
82 | def result_processor(self, dialect, coltype):
83 | def process(value):
84 | if isinstance(value, datetime.datetime):
85 | return value
86 | elif value is not None:
87 | return parse(value)
88 | else:
89 | return None
90 |
91 | return process
92 |
93 | def adapt(self, impltype, **kwargs):
94 | return self.impl
95 |
96 |
97 | class HiveDecimal(HiveStringTypeBase):
98 | """Translates strings to decimals"""
99 | impl = types.DECIMAL
100 |
101 | def process_result_value(self, value, dialect):
102 | if value is not None:
103 | return decimal.Decimal(value)
104 | else:
105 | return None
106 |
107 | def result_processor(self, dialect, coltype):
108 | def process(value):
109 | if isinstance(value, Decimal):
110 | return value
111 | elif value is not None:
112 | return Decimal(value)
113 | else:
114 | return None
115 |
116 | return process
117 |
118 | def adapt(self, impltype, **kwargs):
119 | return self.impl
120 |
121 |
122 | class HiveIdentifierPreparer(compiler.IdentifierPreparer):
123 | # Just quote everything to make things simpler / easier to upgrade
124 | reserved_words = UniversalSet()
125 |
126 | def __init__(self, dialect):
127 | super(HiveIdentifierPreparer, self).__init__(
128 | dialect,
129 | initial_quote='`',
130 | )
131 |
132 |
133 | _type_map = {
134 | 'boolean': types.Boolean,
135 | 'tinyint': mysql_tinyinteger,
136 | 'smallint': types.SmallInteger,
137 | 'int': types.Integer,
138 | 'bigint': types.BigInteger,
139 | 'float': types.Float,
140 | 'double': types.Float,
141 | 'string': types.String,
142 | 'varchar': types.String,
143 | 'char': types.String,
144 | 'date': HiveDate,
145 | 'timestamp': HiveTimestamp,
146 | 'binary': types.String,
147 | 'array': types.String,
148 | 'map': types.String,
149 | 'struct': types.String,
150 | 'uniontype': types.String,
151 | 'decimal': HiveDecimal,
152 | }
153 |
154 |
155 | class HiveCompiler(SQLCompiler):
156 | def visit_concat_op_binary(self, binary, operator, **kw):
157 | return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right))
158 |
159 | def visit_insert(self, *args, **kwargs):
160 | result = super(HiveCompiler, self).visit_insert(*args, **kwargs)
161 | # Massage the result into Hive's format
162 | # INSERT INTO `pyhive_test_database`.`test_table` (`a`) SELECT ...
163 | # =>
164 | # INSERT INTO TABLE `pyhive_test_database`.`test_table` SELECT ...
165 | regex = r'^(INSERT INTO) ([^\s]+) \([^\)]*\)'
166 | assert re.search(regex, result), "Unexpected visit_insert result: {}".format(result)
167 | return re.sub(regex, r'\1 TABLE \2', result)
168 |
169 | def visit_column(self, *args, **kwargs):
170 | result = super(HiveCompiler, self).visit_column(*args, **kwargs)
171 | dot_count = result.count('.')
172 | assert dot_count in (0, 1, 2), "Unexpected visit_column result {}".format(result)
173 | if dot_count == 2:
174 | # we have something of the form schema.table.column
175 | # hive doesn't like the schema in front, so chop it out
176 | result = result[result.index('.') + 1:]
177 | return result
178 |
179 | def visit_char_length_func(self, fn, **kw):
180 | return 'length{}'.format(self.function_argspec(fn, **kw))
181 |
182 |
183 | class HiveTypeCompiler(compiler.GenericTypeCompiler):
184 | def visit_INTEGER(self, type_):
185 | return 'INT'
186 |
187 | def visit_NUMERIC(self, type_):
188 | return 'DECIMAL'
189 |
190 | def visit_CHAR(self, type_):
191 | return 'STRING'
192 |
193 | def visit_VARCHAR(self, type_):
194 | return 'STRING'
195 |
196 | def visit_NCHAR(self, type_):
197 | return 'STRING'
198 |
199 | def visit_TEXT(self, type_):
200 | return 'STRING'
201 |
202 | def visit_CLOB(self, type_):
203 | return 'STRING'
204 |
205 | def visit_BLOB(self, type_):
206 | return 'BINARY'
207 |
208 | def visit_TIME(self, type_):
209 | return 'TIMESTAMP'
210 |
211 | def visit_DATE(self, type_):
212 | return 'TIMESTAMP'
213 |
214 | def visit_DATETIME(self, type_):
215 | return 'TIMESTAMP'
216 |
217 |
218 | class HiveExecutionContext(default.DefaultExecutionContext):
219 | """This is pretty much the same as SQLiteExecutionContext to work around the same issue.
220 |
221 | http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#dotted-column-names
222 |
223 | engine = create_engine('hive://...', execution_options={'hive_raw_colnames': True})
224 | """
225 |
226 | @util.memoized_property
227 | def _preserve_raw_colnames(self):
228 | # Ideally, this would also gate on hive.resultset.use.unique.column.names
229 | return self.execution_options.get('hive_raw_colnames', False)
230 |
231 | def _translate_colname(self, colname):
232 | # Adjust for dotted column names.
233 | # When hive.resultset.use.unique.column.names is true (the default), Hive returns column
234 | # names as "tablename.colname" in cursor.description.
235 | if not self._preserve_raw_colnames and '.' in colname:
236 | return colname.split('.')[-1], colname
237 | else:
238 | return colname, None
239 |
240 |
241 | class HiveDialect(default.DefaultDialect):
242 | name = 'hive'
243 | driver = 'thrift'
244 | execution_ctx_cls = HiveExecutionContext
245 | preparer = HiveIdentifierPreparer
246 | statement_compiler = HiveCompiler
247 | supports_views = True
248 | supports_alter = True
249 | supports_pk_autoincrement = False
250 | supports_default_values = False
251 | supports_empty_insert = False
252 | supports_native_decimal = True
253 | supports_native_boolean = True
254 | supports_unicode_statements = True
255 | supports_unicode_binds = True
256 | returns_unicode_strings = True
257 | description_encoding = None
258 | supports_multivalues_insert = True
259 | type_compiler = HiveTypeCompiler
260 | supports_sane_rowcount = False
261 | supports_statement_cache = False
262 |
263 | @classmethod
264 | def dbapi(cls):
265 | return hive
266 |
267 | @classmethod
268 | def import_dbapi(cls):
269 | return hive
270 |
271 | def create_connect_args(self, url):
272 | kwargs = {
273 | 'host': url.host,
274 | 'port': url.port or 10000,
275 | 'username': url.username,
276 | 'password': url.password,
277 | 'database': url.database or 'default',
278 | }
279 | kwargs.update(url.query)
280 | return [], kwargs
281 |
282 | def get_schema_names(self, connection, **kw):
283 | # Equivalent to SHOW DATABASES
284 | return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))]
285 |
286 | def get_view_names(self, connection, schema=None, **kw):
287 | # Hive does not provide functionality to query tableType
288 | # This allows reflection to not crash at the cost of being inaccurate
289 | return self.get_table_names(connection, schema, **kw)
290 |
291 | def _get_table_columns(self, connection, table_name, schema):
292 | full_table = table_name
293 | if schema:
294 | full_table = schema + '.' + table_name
295 | # TODO using TGetColumnsReq hangs after sending TFetchResultsReq.
296 | # Using DESCRIBE works but is uglier.
297 | try:
298 | # This needs the table name to be unescaped (no backticks).
299 | rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall()
300 | except exc.OperationalError as e:
301 | # Does the table exist?
302 | regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}'
303 | regex = regex_fmt.format(re.escape(full_table))
304 | if re.search(regex, e.args[0]):
305 | raise exc.NoSuchTableError(full_table)
306 | else:
307 | raise
308 | else:
309 | # Hive is stupid: this is what I get from DESCRIBE some_schema.does_not_exist
310 | regex = r'Table .* does not exist'
311 | if len(rows) == 1 and re.match(regex, rows[0].col_name):
312 | raise exc.NoSuchTableError(full_table)
313 | return rows
314 |
315 | def has_table(self, connection, table_name, schema=None, **kw):
316 | try:
317 | self._get_table_columns(connection, table_name, schema)
318 | return True
319 | except exc.NoSuchTableError:
320 | return False
321 |
322 | def get_columns(self, connection, table_name, schema=None, **kw):
323 | rows = self._get_table_columns(connection, table_name, schema)
324 | # Strip whitespace
325 | rows = [[col.strip() if col else None for col in row] for row in rows]
326 | # Filter out empty rows and comment
327 | rows = [row for row in rows if row[0] and row[0] != '# col_name']
328 | result = []
329 | for (col_name, col_type, _comment) in rows:
330 | if col_name == '# Partition Information':
331 | break
332 | # Take out the more detailed type information
333 | # e.g. 'map' -> 'map'
334 | # 'decimal(10,1)' -> decimal
335 | col_type = re.search(r'^\w+', col_type).group(0)
336 | try:
337 | coltype = _type_map[col_type]
338 | except KeyError:
339 | util.warn("Did not recognize type '%s' of column '%s'" % (col_type, col_name))
340 | coltype = types.NullType
341 |
342 | result.append({
343 | 'name': col_name,
344 | 'type': coltype,
345 | 'nullable': True,
346 | 'default': None,
347 | })
348 | return result
349 |
350 | def get_foreign_keys(self, connection, table_name, schema=None, **kw):
351 | # Hive has no support for foreign keys.
352 | return []
353 |
354 | def get_pk_constraint(self, connection, table_name, schema=None, **kw):
355 | # Hive has no support for primary keys.
356 | return []
357 |
358 | def get_indexes(self, connection, table_name, schema=None, **kw):
359 | rows = self._get_table_columns(connection, table_name, schema)
360 | # Strip whitespace
361 | rows = [[col.strip() if col else None for col in row] for row in rows]
362 | # Filter out empty rows and comment
363 | rows = [row for row in rows if row[0] and row[0] != '# col_name']
364 | for i, (col_name, _col_type, _comment) in enumerate(rows):
365 | if col_name == '# Partition Information':
366 | break
367 | # Handle partition columns
368 | col_names = []
369 | for col_name, _col_type, _comment in rows[i + 1:]:
370 | col_names.append(col_name)
371 | if col_names:
372 | return [{'name': 'partition', 'column_names': col_names, 'unique': False}]
373 | else:
374 | return []
375 |
376 | def get_table_names(self, connection, schema=None, **kw):
377 | query = 'SHOW TABLES'
378 | if schema:
379 | query += ' IN ' + self.identifier_preparer.quote_identifier(schema)
380 | return [row[0] for row in connection.execute(text(query))]
381 |
382 | def do_rollback(self, dbapi_connection):
383 | # No transactions for Hive
384 | pass
385 |
386 | def _check_unicode_returns(self, connection, additional_tests=None):
387 | # We decode everything as UTF-8
388 | return True
389 |
390 | def _check_unicode_description(self, connection):
391 | # We decode everything as UTF-8
392 | return True
393 |
394 |
395 | class HiveHTTPDialect(HiveDialect):
396 |
397 | name = "hive"
398 | scheme = "http"
399 | driver = "rest"
400 |
401 | def create_connect_args(self, url):
402 | kwargs = {
403 | "host": url.host,
404 | "port": url.port or 10000,
405 | "scheme": self.scheme,
406 | "username": url.username or None,
407 | "password": url.password or None,
408 | }
409 | if url.query:
410 | kwargs.update(url.query)
411 | return [], kwargs
412 | return ([], kwargs)
413 |
414 |
415 | class HiveHTTPSDialect(HiveHTTPDialect):
416 |
417 | name = "hive"
418 | scheme = "https"
419 |
--------------------------------------------------------------------------------
/pyhive/sqlalchemy_presto.py:
--------------------------------------------------------------------------------
1 | """Integration between SQLAlchemy and Presto.
2 |
3 | Some code based on
4 | https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py
5 | which is released under the MIT license.
6 | """
7 |
8 | from __future__ import absolute_import
9 | from __future__ import unicode_literals
10 |
11 | import re
12 | import sqlalchemy
13 | from sqlalchemy import exc
14 | from sqlalchemy import types
15 | from sqlalchemy import util
16 | # TODO shouldn't use mysql type
17 | from sqlalchemy.sql import text
18 | try:
19 | from sqlalchemy.databases import mysql
20 | mysql_tinyinteger = mysql.MSTinyInteger
21 | except ImportError:
22 | # Required for SQLAlchemy>=2.0
23 | from sqlalchemy.dialects import mysql
24 | mysql_tinyinteger = mysql.base.MSTinyInteger
25 | from sqlalchemy.engine import default
26 | from sqlalchemy.sql import compiler
27 | from sqlalchemy.sql.compiler import SQLCompiler
28 |
29 | from pyhive import presto
30 | from pyhive.common import UniversalSet
31 |
32 | sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
33 |
34 | class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
35 | # Just quote everything to make things simpler / easier to upgrade
36 | reserved_words = UniversalSet()
37 |
38 |
39 | _type_map = {
40 | 'boolean': types.Boolean,
41 | 'tinyint': mysql_tinyinteger,
42 | 'smallint': types.SmallInteger,
43 | 'integer': types.Integer,
44 | 'bigint': types.BigInteger,
45 | 'real': types.Float,
46 | 'double': types.Float,
47 | 'varchar': types.String,
48 | 'timestamp': types.TIMESTAMP,
49 | 'date': types.DATE,
50 | 'varbinary': types.VARBINARY,
51 | }
52 |
53 |
54 | class PrestoCompiler(SQLCompiler):
55 | def visit_char_length_func(self, fn, **kw):
56 | return 'length{}'.format(self.function_argspec(fn, **kw))
57 |
58 |
59 | class PrestoTypeCompiler(compiler.GenericTypeCompiler):
60 | def visit_CLOB(self, type_, **kw):
61 | raise ValueError("Presto does not support the CLOB column type.")
62 |
63 | def visit_NCLOB(self, type_, **kw):
64 | raise ValueError("Presto does not support the NCLOB column type.")
65 |
66 | def visit_DATETIME(self, type_, **kw):
67 | raise ValueError("Presto does not support the DATETIME column type.")
68 |
69 | def visit_FLOAT(self, type_, **kw):
70 | return 'DOUBLE'
71 |
72 | def visit_TEXT(self, type_, **kw):
73 | if type_.length:
74 | return 'VARCHAR({:d})'.format(type_.length)
75 | else:
76 | return 'VARCHAR'
77 |
78 |
79 | class PrestoDialect(default.DefaultDialect):
80 | name = 'presto'
81 | driver = 'rest'
82 | paramstyle = 'pyformat'
83 | preparer = PrestoIdentifierPreparer
84 | statement_compiler = PrestoCompiler
85 | supports_alter = False
86 | supports_pk_autoincrement = False
87 | supports_default_values = False
88 | supports_empty_insert = False
89 | supports_multivalues_insert = True
90 | supports_unicode_statements = True
91 | supports_unicode_binds = True
92 | supports_statement_cache = False
93 | returns_unicode_strings = True
94 | description_encoding = None
95 | supports_native_boolean = True
96 | type_compiler = PrestoTypeCompiler
97 |
98 | @classmethod
99 | def dbapi(cls):
100 | return presto
101 |
102 | @classmethod
103 | def import_dbapi(cls):
104 | return presto
105 |
106 | def create_connect_args(self, url):
107 | db_parts = (url.database or 'hive').split('/')
108 | kwargs = {
109 | 'host': url.host,
110 | 'port': url.port or 8080,
111 | 'username': url.username,
112 | 'password': url.password
113 | }
114 | kwargs.update(url.query)
115 | if len(db_parts) == 1:
116 | kwargs['catalog'] = db_parts[0]
117 | elif len(db_parts) == 2:
118 | kwargs['catalog'] = db_parts[0]
119 | kwargs['schema'] = db_parts[1]
120 | else:
121 | raise ValueError("Unexpected database format {}".format(url.database))
122 | return [], kwargs
123 |
124 | def get_schema_names(self, connection, **kw):
125 | return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))]
126 |
127 | def _get_table_columns(self, connection, table_name, schema):
128 | full_table = self.identifier_preparer.quote_identifier(table_name)
129 | if schema:
130 | full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table
131 | try:
132 | return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table)))
133 | except (presto.DatabaseError, exc.DatabaseError) as e:
134 | # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which
135 | # it successfully does in the Hive version. The difference with Presto is that this
136 | # error is raised when fetching the cursor's description rather than the initial execute
137 | # call. SQLAlchemy doesn't handle this. Thus, we catch the unwrapped
138 | # presto.DatabaseError here.
139 | # Does the table exist?
140 | msg = (
141 | e.args[0].get('message') if e.args and isinstance(e.args[0], dict)
142 | else e.args[0] if e.args and isinstance(e.args[0], str)
143 | else None
144 | )
145 | regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name))
146 | if msg and re.search(regex, msg):
147 | raise exc.NoSuchTableError(table_name)
148 | else:
149 | raise
150 |
151 | def has_table(self, connection, table_name, schema=None, **kw):
152 | try:
153 | self._get_table_columns(connection, table_name, schema)
154 | return True
155 | except exc.NoSuchTableError:
156 | return False
157 |
158 | def get_columns(self, connection, table_name, schema=None, **kw):
159 | rows = self._get_table_columns(connection, table_name, schema)
160 | result = []
161 | for row in rows:
162 | try:
163 | coltype = _type_map[row.Type]
164 | except KeyError:
165 | util.warn("Did not recognize type '%s' of column '%s'" % (row.Type, row.Column))
166 | coltype = types.NullType
167 | result.append({
168 | 'name': row.Column,
169 | 'type': coltype,
170 | # newer Presto no longer includes this column
171 | 'nullable': getattr(row, 'Null', True),
172 | 'default': None,
173 | })
174 | return result
175 |
176 | def get_foreign_keys(self, connection, table_name, schema=None, **kw):
177 | # Hive has no support for foreign keys.
178 | return []
179 |
180 | def get_pk_constraint(self, connection, table_name, schema=None, **kw):
181 | # Hive has no support for primary keys.
182 | return []
183 |
184 | def get_indexes(self, connection, table_name, schema=None, **kw):
185 | rows = self._get_table_columns(connection, table_name, schema)
186 | col_names = []
187 | for row in rows:
188 | part_key = 'Partition Key'
189 | # Presto puts this information in one of 3 places depending on version
190 | # - a boolean column named "Partition Key"
191 | # - a string in the "Comment" column
192 | # - a string in the "Extra" column
193 | if sqlalchemy_version >= 1.4:
194 | row = row._mapping
195 | is_partition_key = (
196 | (part_key in row and row[part_key])
197 | or row['Comment'].startswith(part_key)
198 | or ('Extra' in row and 'partition key' in row['Extra'])
199 | )
200 | if is_partition_key:
201 | col_names.append(row['Column'])
202 | if col_names:
203 | return [{'name': 'partition', 'column_names': col_names, 'unique': False}]
204 | else:
205 | return []
206 |
207 | def get_table_names(self, connection, schema=None, **kw):
208 | query = 'SHOW TABLES'
209 | if schema:
210 | query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
211 | return [row.Table for row in connection.execute(text(query))]
212 |
213 | def do_rollback(self, dbapi_connection):
214 | # No transactions for Presto
215 | pass
216 |
217 | def _check_unicode_returns(self, connection, additional_tests=None):
218 | # requests gives back Unicode strings
219 | return True
220 |
221 | def _check_unicode_description(self, connection):
222 | # requests gives back Unicode strings
223 | return True
224 |
--------------------------------------------------------------------------------
/pyhive/sqlalchemy_trino.py:
--------------------------------------------------------------------------------
1 | """Integration between SQLAlchemy and Trino.
2 |
3 | Some code based on
4 | https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py
5 | which is released under the MIT license.
6 | """
7 |
8 | from __future__ import absolute_import
9 | from __future__ import unicode_literals
10 |
11 | import re
12 | from sqlalchemy import exc
13 | from sqlalchemy import types
14 | from sqlalchemy import util
15 | # TODO shouldn't use mysql type
16 | try:
17 | from sqlalchemy.databases import mysql
18 | mysql_tinyinteger = mysql.MSTinyInteger
19 | except ImportError:
20 | # Required for SQLAlchemy>=2.0
21 | from sqlalchemy.dialects import mysql
22 | mysql_tinyinteger = mysql.base.MSTinyInteger
23 | from sqlalchemy.engine import default
24 | from sqlalchemy.sql import compiler
25 | from sqlalchemy.sql.compiler import SQLCompiler
26 |
27 | from pyhive import trino
28 | from pyhive.common import UniversalSet
29 | from pyhive.sqlalchemy_presto import PrestoDialect, PrestoCompiler, PrestoIdentifierPreparer
30 |
31 | class TrinoIdentifierPreparer(PrestoIdentifierPreparer):
32 | pass
33 |
34 |
35 | _type_map = {
36 | 'boolean': types.Boolean,
37 | 'tinyint': mysql_tinyinteger,
38 | 'smallint': types.SmallInteger,
39 | 'integer': types.Integer,
40 | 'bigint': types.BigInteger,
41 | 'real': types.Float,
42 | 'double': types.Float,
43 | 'varchar': types.String,
44 | 'timestamp': types.TIMESTAMP,
45 | 'date': types.DATE,
46 | 'varbinary': types.VARBINARY,
47 | }
48 |
49 |
50 | class TrinoCompiler(PrestoCompiler):
51 | pass
52 |
53 |
54 | class TrinoTypeCompiler(PrestoCompiler):
55 | def visit_CLOB(self, type_, **kw):
56 | raise ValueError("Trino does not support the CLOB column type.")
57 |
58 | def visit_NCLOB(self, type_, **kw):
59 | raise ValueError("Trino does not support the NCLOB column type.")
60 |
61 | def visit_DATETIME(self, type_, **kw):
62 | raise ValueError("Trino does not support the DATETIME column type.")
63 |
64 | def visit_FLOAT(self, type_, **kw):
65 | return 'DOUBLE'
66 |
67 | def visit_TEXT(self, type_, **kw):
68 | if type_.length:
69 | return 'VARCHAR({:d})'.format(type_.length)
70 | else:
71 | return 'VARCHAR'
72 |
73 |
74 | class TrinoDialect(PrestoDialect):
75 | name = 'trino'
76 | supports_statement_cache = False
77 |
78 | @classmethod
79 | def dbapi(cls):
80 | return trino
81 |
82 | @classmethod
83 | def import_dbapi(cls):
84 | return trino
85 |
--------------------------------------------------------------------------------
/pyhive/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/PyHive/ac09074a652fd50e10b57a7f0bbc4f6410961301/pyhive/tests/__init__.py
--------------------------------------------------------------------------------
/pyhive/tests/dbapi_test_case.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """Shared DB-API test cases"""
3 |
4 | from __future__ import absolute_import
5 | from __future__ import unicode_literals
6 | from builtins import object
7 | from builtins import range
8 | from future.utils import with_metaclass
9 | from pyhive import exc
10 | import abc
11 | import contextlib
12 | import functools
13 |
14 |
15 | def with_cursor(fn):
16 | """Pass a cursor to the given function and handle cleanup.
17 |
18 | The cursor is taken from ``self.connect()``.
19 | """
20 | @functools.wraps(fn)
21 | def wrapped_fn(self, *args, **kwargs):
22 | with contextlib.closing(self.connect()) as connection:
23 | with contextlib.closing(connection.cursor()) as cursor:
24 | fn(self, cursor, *args, **kwargs)
25 | return wrapped_fn
26 |
27 |
28 | class DBAPITestCase(with_metaclass(abc.ABCMeta, object)):
29 | @abc.abstractmethod
30 | def connect(self):
31 | raise NotImplementedError # pragma: no cover
32 |
33 | @with_cursor
34 | def test_fetchone(self, cursor):
35 | cursor.execute('SELECT * FROM one_row')
36 | self.assertEqual(cursor.rownumber, 0)
37 | self.assertEqual(cursor.fetchone(), (1,))
38 | self.assertEqual(cursor.rownumber, 1)
39 | self.assertIsNone(cursor.fetchone())
40 |
41 | @with_cursor
42 | def test_fetchall(self, cursor):
43 | cursor.execute('SELECT * FROM one_row')
44 | self.assertEqual(cursor.fetchall(), [(1,)])
45 | cursor.execute('SELECT a FROM many_rows ORDER BY a')
46 | self.assertEqual(cursor.fetchall(), [(i,) for i in range(10000)])
47 |
48 | @with_cursor
49 | def test_null_param(self, cursor):
50 | cursor.execute('SELECT %s FROM one_row', (None,))
51 | self.assertEqual(cursor.fetchall(), [(None,)])
52 |
53 | @with_cursor
54 | def test_iterator(self, cursor):
55 | cursor.execute('SELECT * FROM one_row')
56 | self.assertEqual(list(cursor), [(1,)])
57 | self.assertRaises(StopIteration, cursor.__next__)
58 |
59 | @with_cursor
60 | def test_description_initial(self, cursor):
61 | self.assertIsNone(cursor.description)
62 |
63 | @with_cursor
64 | def test_description_failed(self, cursor):
65 | try:
66 | cursor.execute('blah_blah')
67 | self.assertIsNone(cursor.description)
68 | except exc.DatabaseError:
69 | pass
70 |
71 | @with_cursor
72 | def test_bad_query(self, cursor):
73 | def run():
74 | cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist')
75 | cursor.fetchone()
76 | self.assertRaises(exc.DatabaseError, run)
77 |
78 | @with_cursor
79 | def test_concurrent_execution(self, cursor):
80 | cursor.execute('SELECT * FROM one_row')
81 | cursor.execute('SELECT * FROM one_row')
82 | self.assertEqual(cursor.fetchall(), [(1,)])
83 |
84 | @with_cursor
85 | def test_executemany(self, cursor):
86 | for length in 1, 2:
87 | cursor.executemany(
88 | 'SELECT %(x)d FROM one_row',
89 | [{'x': i} for i in range(1, length + 1)]
90 | )
91 | self.assertEqual(cursor.fetchall(), [(length,)])
92 |
93 | @with_cursor
94 | def test_executemany_none(self, cursor):
95 | cursor.executemany('should_never_get_used', [])
96 | self.assertIsNone(cursor.description)
97 | self.assertRaises(exc.ProgrammingError, cursor.fetchone)
98 |
99 | @with_cursor
100 | def test_fetchone_no_data(self, cursor):
101 | self.assertRaises(exc.ProgrammingError, cursor.fetchone)
102 |
103 | @with_cursor
104 | def test_fetchmany(self, cursor):
105 | cursor.execute('SELECT * FROM many_rows LIMIT 15')
106 | self.assertEqual(cursor.fetchmany(0), [])
107 | self.assertEqual(len(cursor.fetchmany(10)), 10)
108 | self.assertEqual(len(cursor.fetchmany(10)), 5)
109 |
110 | @with_cursor
111 | def test_arraysize(self, cursor):
112 | cursor.arraysize = 5
113 | cursor.execute('SELECT * FROM many_rows LIMIT 20')
114 | self.assertEqual(len(cursor.fetchmany()), 5)
115 |
116 | @with_cursor
117 | def test_polling_loop(self, cursor):
118 | """Try to trigger the polling logic in fetchone()"""
119 | cursor._poll_interval = 0
120 | cursor.execute('SELECT COUNT(*) FROM many_rows')
121 | self.assertEqual(cursor.fetchone(), (10000,))
122 |
123 | @with_cursor
124 | def test_no_params(self, cursor):
125 | cursor.execute("SELECT '%(x)s' FROM one_row")
126 | self.assertEqual(cursor.fetchall(), [('%(x)s',)])
127 |
128 | def test_escape(self):
129 | """Verify that funny characters can be escaped as strings and SELECTed back"""
130 | bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t '''
131 | self.run_escape_case(bad_str)
132 |
133 | @with_cursor
134 | def run_escape_case(self, cursor, bad_str):
135 | cursor.execute(
136 | 'SELECT %d, %s FROM one_row',
137 | (1, bad_str)
138 | )
139 | self.assertEqual(cursor.fetchall(), [(1, bad_str,)])
140 | cursor.execute(
141 | 'SELECT %(a)d, %(b)s FROM one_row',
142 | {'a': 1, 'b': bad_str}
143 | )
144 | self.assertEqual(cursor.fetchall(), [(1, bad_str)])
145 |
146 | @with_cursor
147 | def test_invalid_params(self, cursor):
148 | self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', 'hi'))
149 | self.assertRaises(exc.ProgrammingError, lambda: cursor.execute('', [object]))
150 |
151 | def test_open_close(self):
152 | with contextlib.closing(self.connect()):
153 | pass
154 | with contextlib.closing(self.connect()) as connection:
155 | with contextlib.closing(connection.cursor()):
156 | pass
157 |
158 | @with_cursor
159 | def test_unicode(self, cursor):
160 | unicode_str = "王兢"
161 | cursor.execute(
162 | 'SELECT %s FROM one_row',
163 | (unicode_str,)
164 | )
165 | self.assertEqual(cursor.fetchall(), [(unicode_str,)])
166 |
167 | @with_cursor
168 | def test_null(self, cursor):
169 | cursor.execute('SELECT null FROM many_rows')
170 | self.assertEqual(cursor.fetchall(), [(None,)] * 10000)
171 | cursor.execute('SELECT IF(a % 11 = 0, null, a) FROM many_rows')
172 | self.assertEqual(cursor.fetchall(), [(None if a % 11 == 0 else a,) for a in range(10000)])
173 |
174 | @with_cursor
175 | def test_sql_where_in(self, cursor):
176 | cursor.execute('SELECT * FROM many_rows where a in %s', ([1, 2, 3],))
177 | self.assertEqual(len(cursor.fetchall()), 3)
178 | cursor.execute('SELECT * FROM many_rows where b in %s limit 10',
179 | (['blah'],))
180 | self.assertEqual(len(cursor.fetchall()), 10)
181 |
--------------------------------------------------------------------------------
/pyhive/tests/ldif_data/INITIAL_TESTDATA.ldif:
--------------------------------------------------------------------------------
1 | dn: cn=existing,dc=example,dc=com
2 | objectClass: inetOrgPerson
3 | objectClass: organizationalPerson
4 | objectClass: person
5 | objectClass: top
6 | cn: existing
7 | sn: testentry
8 | mail: test.entry@example.com
9 | mail: te@example.com
10 | ou:: UGFsbcO2
11 | userPassword: testpw
12 | givenName: i am
13 | description: A test entry for pyhive
14 |
--------------------------------------------------------------------------------
/pyhive/tests/ldif_data/base.ldif:
--------------------------------------------------------------------------------
1 | # ldapadd -x -h localhost -p 389 -D "cn=admin,dc=test,dc=com" -w secret -f base.ldif
2 |
3 | dn: dc=example,dc=com
4 | objectClass: dcObject
5 | objectClass: organizationalUnit
6 | #dc: test
7 | ou: Test
8 |
--------------------------------------------------------------------------------
/pyhive/tests/sqlalchemy_test_case.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | from __future__ import absolute_import
3 | from __future__ import unicode_literals
4 |
5 | import abc
6 | import re
7 | import contextlib
8 | import functools
9 |
10 | import pytest
11 | import sqlalchemy
12 | from builtins import object
13 | from future.utils import with_metaclass
14 | from sqlalchemy.exc import NoSuchTableError
15 | from sqlalchemy.schema import Index
16 | from sqlalchemy.schema import MetaData
17 | from sqlalchemy.schema import Table
18 | from sqlalchemy.sql import expression, text
19 | from sqlalchemy import String
20 |
21 | sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
22 |
23 | def with_engine_connection(fn):
24 | """Pass a connection to the given function and handle cleanup.
25 |
26 | The connection is taken from ``self.create_engine()``.
27 | """
28 | @functools.wraps(fn)
29 | def wrapped_fn(self, *args, **kwargs):
30 | engine = self.create_engine()
31 | try:
32 | with contextlib.closing(engine.connect()) as connection:
33 | fn(self, engine, connection, *args, **kwargs)
34 | finally:
35 | engine.dispose()
36 | return wrapped_fn
37 |
38 | def reflect_table(engine, connection, table, include_columns, exclude_columns, resolve_fks):
39 | if sqlalchemy_version >= 1.4:
40 | insp = sqlalchemy.inspect(engine)
41 | insp.reflect_table(
42 | table,
43 | include_columns=include_columns,
44 | exclude_columns=exclude_columns,
45 | resolve_fks=resolve_fks,
46 | )
47 | else:
48 | engine.dialect.reflecttable(
49 | connection, table, include_columns=include_columns,
50 | exclude_columns=exclude_columns, resolve_fks=resolve_fks)
51 |
52 |
53 | class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)):
54 | @with_engine_connection
55 | def test_basic_query(self, engine, connection):
56 | rows = connection.execute(text('SELECT * FROM one_row')).fetchall()
57 | self.assertEqual(len(rows), 1)
58 | self.assertEqual(rows[0].number_of_rows, 1) # number_of_rows is the column name
59 | self.assertEqual(len(rows[0]), 1)
60 |
61 | @with_engine_connection
62 | def test_one_row_complex_null(self, engine, connection):
63 | one_row_complex_null = Table('one_row_complex_null', MetaData(), autoload_with=engine)
64 | rows = connection.execute(one_row_complex_null.select()).fetchall()
65 | self.assertEqual(len(rows), 1)
66 | self.assertEqual(list(rows[0]), [None] * len(rows[0]))
67 |
68 | @with_engine_connection
69 | def test_reflect_no_such_table(self, engine, connection):
70 | """reflecttable should throw an exception on an invalid table"""
71 | self.assertRaises(
72 | NoSuchTableError,
73 | lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine))
74 | self.assertRaises(
75 | NoSuchTableError,
76 | lambda: Table('this_does_not_exist', MetaData(schema='also_does_not_exist'), autoload_with=engine))
77 |
78 | @with_engine_connection
79 | def test_reflect_include_columns(self, engine, connection):
80 | """When passed include_columns, reflecttable should filter out other columns"""
81 |
82 | one_row_complex = Table('one_row_complex', MetaData())
83 | reflect_table(engine, connection, one_row_complex, include_columns=['int'],
84 | exclude_columns=[], resolve_fks=True)
85 |
86 | self.assertEqual(len(one_row_complex.c), 1)
87 | self.assertIsNotNone(one_row_complex.c.int)
88 | self.assertRaises(AttributeError, lambda: one_row_complex.c.tinyint)
89 |
90 | @with_engine_connection
91 | def test_reflect_with_schema(self, engine, connection):
92 | dummy = Table('dummy_table', MetaData(schema='pyhive_test_database'), autoload_with=engine)
93 | self.assertEqual(len(dummy.c), 1)
94 | self.assertIsNotNone(dummy.c.a)
95 |
96 | @pytest.mark.filterwarnings('default:Omitting:sqlalchemy.exc.SAWarning')
97 | @with_engine_connection
98 | def test_reflect_partitions(self, engine, connection):
99 | """reflecttable should get the partition column as an index"""
100 | many_rows = Table('many_rows', MetaData(), autoload_with=engine)
101 | self.assertEqual(len(many_rows.c), 2)
102 | self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)}))
103 |
104 | many_rows = Table('many_rows', MetaData())
105 | reflect_table(engine, connection, many_rows, include_columns=['a'],
106 | exclude_columns=[], resolve_fks=True)
107 |
108 | self.assertEqual(len(many_rows.c), 1)
109 | self.assertFalse(many_rows.c.a.index)
110 | self.assertFalse(many_rows.indexes)
111 |
112 | many_rows = Table('many_rows', MetaData())
113 | reflect_table(engine, connection, many_rows, include_columns=['b'],
114 | exclude_columns=[], resolve_fks=True)
115 |
116 | self.assertEqual(len(many_rows.c), 1)
117 | self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)}))
118 |
119 | @with_engine_connection
120 | def test_unicode(self, engine, connection):
121 | """Verify that unicode strings make it through SQLAlchemy and the backend"""
122 | unicode_str = "中文"
123 | one_row = Table('one_row', MetaData())
124 |
125 | if sqlalchemy_version >= 1.4:
126 | returned_str = connection.execute(sqlalchemy.select(
127 | expression.bindparam("好", unicode_str, type_=String())).select_from(one_row)).scalar()
128 | else:
129 | returned_str = connection.execute(sqlalchemy.select([
130 | expression.bindparam("好", unicode_str, type_=String())]).select_from(one_row)).scalar()
131 |
132 | self.assertEqual(returned_str, unicode_str)
133 |
134 | @with_engine_connection
135 | def test_reflect_schemas(self, engine, connection):
136 | insp = sqlalchemy.inspect(engine)
137 | schemas = insp.get_schema_names()
138 | self.assertIn('pyhive_test_database', schemas)
139 | self.assertIn('default', schemas)
140 |
141 | @with_engine_connection
142 | def test_get_table_names(self, engine, connection):
143 | meta = MetaData()
144 | meta.reflect(bind=engine)
145 | self.assertIn('one_row', meta.tables)
146 | self.assertIn('one_row_complex', meta.tables)
147 |
148 | insp = sqlalchemy.inspect(engine)
149 | self.assertIn(
150 | 'dummy_table',
151 | insp.get_table_names(schema='pyhive_test_database'),
152 | )
153 |
154 | @with_engine_connection
155 | def test_has_table(self, engine, connection):
156 | if sqlalchemy_version >= 1.4:
157 | insp = sqlalchemy.inspect(engine)
158 | self.assertTrue(insp.has_table("one_row"))
159 | self.assertFalse(insp.has_table("this_table_does_not_exist"))
160 | else:
161 | self.assertTrue(Table('one_row', MetaData(bind=engine)).exists())
162 | self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists())
163 |
164 | @with_engine_connection
165 | def test_char_length(self, engine, connection):
166 | one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine)
167 |
168 | if sqlalchemy_version >= 1.4:
169 | result = connection.execute(sqlalchemy.select(sqlalchemy.func.char_length(one_row_complex.c.string))).scalar()
170 | else:
171 | result = connection.execute(sqlalchemy.select([sqlalchemy.func.char_length(one_row_complex.c.string)])).scalar()
172 |
173 | self.assertEqual(result, len('a string'))
174 |
--------------------------------------------------------------------------------
/pyhive/tests/test_common.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | from __future__ import absolute_import
3 | from __future__ import unicode_literals
4 | from pyhive import common
5 | import datetime
6 | import unittest
7 |
8 |
9 | class TestCommon(unittest.TestCase):
10 | def test_escape_args(self):
11 | escaper = common.ParamEscaper()
12 | self.assertEqual(escaper.escape_args({'foo': 'bar'}),
13 | {'foo': "'bar'"})
14 | self.assertEqual(escaper.escape_args({'foo': 123}),
15 | {'foo': 123})
16 | self.assertEqual(escaper.escape_args({'foo': 123.456}),
17 | {'foo': 123.456})
18 | self.assertEqual(escaper.escape_args({'foo': ['a', 'b', 'c']}),
19 | {'foo': "('a','b','c')"})
20 | self.assertEqual(escaper.escape_args({'foo': ('a', 'b', 'c')}),
21 | {'foo': "('a','b','c')"})
22 | self.assertIn(escaper.escape_args({'foo': {'a', 'b'}}),
23 | ({'foo': "('a','b')"}, {'foo': "('b','a')"}))
24 | self.assertIn(escaper.escape_args({'foo': frozenset(['a', 'b'])}),
25 | ({'foo': "('a','b')"}, {'foo': "('b','a')"}))
26 |
27 | self.assertEqual(escaper.escape_args(('bar',)),
28 | ("'bar'",))
29 | self.assertEqual(escaper.escape_args([123]),
30 | (123,))
31 | self.assertEqual(escaper.escape_args((123.456,)),
32 | (123.456,))
33 | self.assertEqual(escaper.escape_args((['a', 'b', 'c'],)),
34 | ("('a','b','c')",))
35 | self.assertEqual(escaper.escape_args((['你好', 'b', 'c'],)),
36 | ("('你好','b','c')",))
37 | self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)),
38 | ("'2020-04-17'",))
39 | self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)),
40 | ("'2020-04-17 12:00:00.123456'",))
41 |
--------------------------------------------------------------------------------
/pyhive/tests/test_hive.py:
--------------------------------------------------------------------------------
1 | """Hive integration tests.
2 |
3 | These rely on having a Hive+Hadoop cluster set up with HiveServer2 running.
4 | They also require a tables created by make_test_tables.sh.
5 | """
6 |
7 | from __future__ import absolute_import
8 | from __future__ import unicode_literals
9 |
10 | import contextlib
11 | import datetime
12 | import os
13 | import socket
14 | import subprocess
15 | import time
16 | import unittest
17 | from decimal import Decimal
18 |
19 | import mock
20 | import thrift.transport.TSocket
21 | import thrift.transport.TTransport
22 | import thrift_sasl
23 | from thrift.transport.TTransport import TTransportException
24 |
25 | from TCLIService import ttypes
26 | from pyhive import hive
27 | from pyhive.tests.dbapi_test_case import DBAPITestCase
28 | from pyhive.tests.dbapi_test_case import with_cursor
29 |
30 | _HOST = 'localhost'
31 |
32 |
33 | class TestHive(unittest.TestCase, DBAPITestCase):
34 | __test__ = True
35 |
36 | def connect(self):
37 | return hive.connect(host=_HOST, configuration={'mapred.job.tracker': 'local'})
38 |
39 | @with_cursor
40 | def test_description(self, cursor):
41 | cursor.execute('SELECT * FROM one_row')
42 |
43 | desc = [('one_row.number_of_rows', 'INT_TYPE', None, None, None, None, True)]
44 | self.assertEqual(cursor.description, desc)
45 |
46 | @with_cursor
47 | def test_complex(self, cursor):
48 | cursor.execute('SELECT * FROM one_row_complex')
49 | self.assertEqual(cursor.description, [
50 | ('one_row_complex.boolean', 'BOOLEAN_TYPE', None, None, None, None, True),
51 | ('one_row_complex.tinyint', 'TINYINT_TYPE', None, None, None, None, True),
52 | ('one_row_complex.smallint', 'SMALLINT_TYPE', None, None, None, None, True),
53 | ('one_row_complex.int', 'INT_TYPE', None, None, None, None, True),
54 | ('one_row_complex.bigint', 'BIGINT_TYPE', None, None, None, None, True),
55 | ('one_row_complex.float', 'FLOAT_TYPE', None, None, None, None, True),
56 | ('one_row_complex.double', 'DOUBLE_TYPE', None, None, None, None, True),
57 | ('one_row_complex.string', 'STRING_TYPE', None, None, None, None, True),
58 | ('one_row_complex.timestamp', 'TIMESTAMP_TYPE', None, None, None, None, True),
59 | ('one_row_complex.binary', 'BINARY_TYPE', None, None, None, None, True),
60 | ('one_row_complex.array', 'ARRAY_TYPE', None, None, None, None, True),
61 | ('one_row_complex.map', 'MAP_TYPE', None, None, None, None, True),
62 | ('one_row_complex.struct', 'STRUCT_TYPE', None, None, None, None, True),
63 | ('one_row_complex.union', 'UNION_TYPE', None, None, None, None, True),
64 | ('one_row_complex.decimal', 'DECIMAL_TYPE', None, None, None, None, True),
65 | ])
66 | rows = cursor.fetchall()
67 | expected = [(
68 | True,
69 | 127,
70 | 32767,
71 | 2147483647,
72 | 9223372036854775807,
73 | 0.5,
74 | 0.25,
75 | 'a string',
76 | datetime.datetime(1970, 1, 1, 0, 0),
77 | b'123',
78 | '[1,2]',
79 | '{1:2,3:4}',
80 | '{"a":1,"b":2}',
81 | '{0:1}',
82 | Decimal('0.1'),
83 | )]
84 | self.assertEqual(rows, expected)
85 | # catch unicode/str
86 | self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0])))
87 |
88 | @with_cursor
89 | def test_async(self, cursor):
90 | cursor.execute('SELECT * FROM one_row', async_=True)
91 | unfinished_states = (
92 | ttypes.TOperationState.INITIALIZED_STATE,
93 | ttypes.TOperationState.RUNNING_STATE,
94 | )
95 | while cursor.poll().operationState in unfinished_states:
96 | cursor.fetch_logs()
97 | assert cursor.poll().operationState == ttypes.TOperationState.FINISHED_STATE
98 |
99 | self.assertEqual(len(cursor.fetchall()), 1)
100 |
101 | @with_cursor
102 | def test_cancel(self, cursor):
103 | # Need to do a JOIN to force a MR job. Without it, Hive optimizes the query to a fetch
104 | # operator and prematurely declares the query done.
105 | cursor.execute(
106 | "SELECT reflect('java.lang.Thread', 'sleep', 1000L * 1000L * 1000L) "
107 | "FROM one_row a JOIN one_row b",
108 | async_=True
109 | )
110 | self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.RUNNING_STATE)
111 | assert any('Stage' in line for line in cursor.fetch_logs())
112 | cursor.cancel()
113 | self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.CANCELED_STATE)
114 |
115 | def test_noops(self):
116 | """The DB-API specification requires that certain actions exist, even though they might not
117 | be applicable."""
118 | # Wohoo inflating coverage stats!
119 | with contextlib.closing(self.connect()) as connection:
120 | with contextlib.closing(connection.cursor()) as cursor:
121 | self.assertEqual(cursor.rowcount, -1)
122 | cursor.setinputsizes([])
123 | cursor.setoutputsize(1, 'blah')
124 | connection.commit()
125 |
126 | @mock.patch('TCLIService.TCLIService.Client.OpenSession')
127 | def test_open_failed(self, open_session):
128 | open_session.return_value.serverProtocolVersion = \
129 | ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1
130 | self.assertRaises(hive.OperationalError, self.connect)
131 |
132 | def test_escape(self):
133 | # Hive thrift translates newlines into multiple rows. WTF.
134 | bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t '''
135 | self.run_escape_case(bad_str)
136 |
137 | def test_newlines(self):
138 | """Verify that newlines are passed through correctly"""
139 | cursor = self.connect().cursor()
140 | orig = ' \r\n \r \n '
141 | cursor.execute(
142 | 'SELECT %s FROM one_row',
143 | (orig,)
144 | )
145 | result = cursor.fetchall()
146 | self.assertEqual(result, [(orig,)])
147 |
148 | @with_cursor
149 | def test_no_result_set(self, cursor):
150 | cursor.execute('USE default')
151 | self.assertIsNone(cursor.description)
152 | self.assertRaises(hive.ProgrammingError, cursor.fetchone)
153 |
154 | def test_ldap_connection(self):
155 | rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
156 | orig_ldap = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site-ldap.xml')
157 | orig_none = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site.xml')
158 | des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml')
159 | try:
160 | subprocess.check_call(['sudo', 'cp', orig_ldap, des])
161 | _restart_hs2()
162 | with contextlib.closing(hive.connect(
163 | host=_HOST, username='existing', auth='LDAP', password='testpw')
164 | ) as connection:
165 | with contextlib.closing(connection.cursor()) as cursor:
166 | cursor.execute('SELECT * FROM one_row')
167 | self.assertEqual(cursor.fetchall(), [(1,)])
168 |
169 | self.assertRaisesRegexp(
170 | TTransportException, 'Error validating the login',
171 | lambda: hive.connect(
172 | host=_HOST, username='existing', auth='LDAP', password='wrong')
173 | )
174 |
175 | finally:
176 | subprocess.check_call(['sudo', 'cp', orig_none, des])
177 | _restart_hs2()
178 |
179 | def test_invalid_ldap_config(self):
180 | """password should be set if and only if using LDAP"""
181 | self.assertRaisesRegexp(ValueError, 'Password.*LDAP',
182 | lambda: hive.connect(_HOST, password=''))
183 | self.assertRaisesRegexp(ValueError, 'Password.*LDAP',
184 | lambda: hive.connect(_HOST, auth='LDAP'))
185 |
186 | def test_invalid_kerberos_config(self):
187 | """kerberos_service_name should be set if and only if using KERBEROS"""
188 | self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS',
189 | lambda: hive.connect(_HOST, kerberos_service_name=''))
190 | self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS',
191 | lambda: hive.connect(_HOST, auth='KERBEROS'))
192 |
193 | def test_invalid_transport(self):
194 | """transport and auth are incompatible"""
195 | socket = thrift.transport.TSocket.TSocket('localhost', 10000)
196 | transport = thrift.transport.TTransport.TBufferedTransport(socket)
197 | self.assertRaisesRegexp(
198 | ValueError, 'thrift_transport cannot be used with',
199 | lambda: hive.connect(_HOST, thrift_transport=transport)
200 | )
201 |
202 | def test_custom_transport(self):
203 | socket = thrift.transport.TSocket.TSocket('localhost', 10000)
204 | sasl_auth = 'PLAIN'
205 |
206 | transport = thrift_sasl.TSaslClientTransport(lambda: hive.get_installed_sasl(host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'), sasl_auth, socket)
207 | conn = hive.connect(thrift_transport=transport)
208 | with contextlib.closing(conn):
209 | with contextlib.closing(conn.cursor()) as cursor:
210 | cursor.execute('SELECT * FROM one_row')
211 | self.assertEqual(cursor.fetchall(), [(1,)])
212 |
213 | def test_custom_connection(self):
214 | rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
215 | orig_ldap = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site-custom.xml')
216 | orig_none = os.path.join(rootdir, 'scripts', 'travis-conf', 'hive', 'hive-site.xml')
217 | des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml')
218 | try:
219 | subprocess.check_call(['sudo', 'cp', orig_ldap, des])
220 | _restart_hs2()
221 | with contextlib.closing(hive.connect(
222 | host=_HOST, username='the-user', auth='CUSTOM', password='p4ssw0rd')
223 | ) as connection:
224 | with contextlib.closing(connection.cursor()) as cursor:
225 | cursor.execute('SELECT * FROM one_row')
226 | self.assertEqual(cursor.fetchall(), [(1,)])
227 |
228 | self.assertRaisesRegexp(
229 | TTransportException, 'Error validating the login',
230 | lambda: hive.connect(
231 | host=_HOST, username='the-user', auth='CUSTOM', password='wrong')
232 | )
233 |
234 | finally:
235 | subprocess.check_call(['sudo', 'cp', orig_none, des])
236 | _restart_hs2()
237 |
238 |
239 | def _restart_hs2():
240 | subprocess.check_call(['sudo', 'service', 'hive-server2', 'restart'])
241 | with contextlib.closing(socket.socket()) as s:
242 | while s.connect_ex(('localhost', 10000)) != 0:
243 | time.sleep(1)
244 |
--------------------------------------------------------------------------------
/pyhive/tests/test_presto.py:
--------------------------------------------------------------------------------
1 | """Presto integration tests.
2 |
3 | These rely on having a Presto+Hadoop cluster set up.
4 | They also require a tables created by make_test_tables.sh.
5 | """
6 |
7 | from __future__ import absolute_import
8 | from __future__ import unicode_literals
9 |
10 | import contextlib
11 | import os
12 | from decimal import Decimal
13 |
14 | import requests
15 |
16 | from pyhive import exc
17 | from pyhive import presto
18 | from pyhive.tests.dbapi_test_case import DBAPITestCase
19 | from pyhive.tests.dbapi_test_case import with_cursor
20 | import mock
21 | import unittest
22 | import datetime
23 |
24 | _HOST = 'localhost'
25 | _PORT = '8080'
26 |
27 |
28 | class TestPresto(unittest.TestCase, DBAPITestCase):
29 | __test__ = True
30 |
31 | def connect(self):
32 | return presto.connect(host=_HOST, port=_PORT, source=self.id())
33 |
34 | def test_bad_protocol(self):
35 | self.assertRaisesRegexp(ValueError, 'Protocol must be',
36 | lambda: presto.connect('localhost', protocol='nonsense').cursor())
37 |
38 | def test_escape_args(self):
39 | escaper = presto.PrestoParamEscaper()
40 |
41 | self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)),
42 | ("date '2020-04-17'",))
43 | self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)),
44 | ("timestamp '2020-04-17 12:00:00.123'",))
45 |
46 | @with_cursor
47 | def test_description(self, cursor):
48 | cursor.execute('SELECT 1 AS foobar FROM one_row')
49 | self.assertEqual(cursor.description, [('foobar', 'integer', None, None, None, None, True)])
50 | self.assertIsNotNone(cursor.last_query_id)
51 |
52 | @with_cursor
53 | def test_complex(self, cursor):
54 | cursor.execute('SELECT * FROM one_row_complex')
55 | # TODO Presto drops the union field
56 | if os.environ.get('PRESTO') == '0.147':
57 | tinyint_type = 'integer'
58 | smallint_type = 'integer'
59 | float_type = 'double'
60 | else:
61 | # some later version made these map to more specific types
62 | tinyint_type = 'tinyint'
63 | smallint_type = 'smallint'
64 | float_type = 'real'
65 | self.assertEqual(cursor.description, [
66 | ('boolean', 'boolean', None, None, None, None, True),
67 | ('tinyint', tinyint_type, None, None, None, None, True),
68 | ('smallint', smallint_type, None, None, None, None, True),
69 | ('int', 'integer', None, None, None, None, True),
70 | ('bigint', 'bigint', None, None, None, None, True),
71 | ('float', float_type, None, None, None, None, True),
72 | ('double', 'double', None, None, None, None, True),
73 | ('string', 'varchar', None, None, None, None, True),
74 | ('timestamp', 'timestamp', None, None, None, None, True),
75 | ('binary', 'varbinary', None, None, None, None, True),
76 | ('array', 'array(integer)', None, None, None, None, True),
77 | ('map', 'map(integer,integer)', None, None, None, None, True),
78 | ('struct', 'row(a integer,b integer)', None, None, None, None, True),
79 | # ('union', 'varchar', None, None, None, None, True),
80 | ('decimal', 'decimal(10,1)', None, None, None, None, True),
81 | ])
82 | rows = cursor.fetchall()
83 | expected = [(
84 | True,
85 | 127,
86 | 32767,
87 | 2147483647,
88 | 9223372036854775807,
89 | 0.5,
90 | 0.25,
91 | 'a string',
92 | '1970-01-01 00:00:00.000',
93 | b'123',
94 | [1, 2],
95 | {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON
96 | [1, 2], # struct is returned as a list of elements
97 | # '{0:1}',
98 | Decimal('0.1'),
99 | )]
100 | self.assertEqual(rows, expected)
101 | # catch unicode/str
102 | self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0])))
103 |
104 | @with_cursor
105 | def test_cancel(self, cursor):
106 | cursor.execute(
107 | "SELECT a.a * rand(), b.a * rand()"
108 | "FROM many_rows a "
109 | "CROSS JOIN many_rows b "
110 | )
111 | self.assertIn(cursor.poll()['stats']['state'], (
112 | 'STARTING', 'PLANNING', 'RUNNING', 'WAITING_FOR_RESOURCES', 'QUEUED'))
113 | cursor.cancel()
114 | self.assertIsNotNone(cursor.last_query_id)
115 | self.assertIsNone(cursor.poll())
116 |
117 | def test_noops(self):
118 | """The DB-API specification requires that certain actions exist, even though they might not
119 | be applicable."""
120 | # Wohoo inflating coverage stats!
121 | connection = self.connect()
122 | cursor = connection.cursor()
123 | self.assertEqual(cursor.rowcount, -1)
124 | cursor.setinputsizes([])
125 | cursor.setoutputsize(1, 'blah')
126 | self.assertIsNone(cursor.last_query_id)
127 | connection.commit()
128 |
129 | @mock.patch('requests.post')
130 | def test_non_200(self, post):
131 | cursor = self.connect().cursor()
132 | post.return_value.status_code = 404
133 | self.assertRaises(exc.OperationalError, lambda: cursor.execute('show tables'))
134 |
135 | @with_cursor
136 | def test_poll(self, cursor):
137 | self.assertRaises(presto.ProgrammingError, cursor.poll)
138 |
139 | cursor.execute('SELECT * FROM one_row')
140 | while True:
141 | status = cursor.poll()
142 | if status is None:
143 | break
144 | self.assertIn('stats', status)
145 |
146 | def fail(*args, **kwargs):
147 | self.fail("Should not need requests.get after done polling") # pragma: no cover
148 |
149 | with mock.patch('requests.get', fail):
150 | self.assertEqual(cursor.fetchall(), [(1,)])
151 |
152 | @with_cursor
153 | def test_set_session(self, cursor):
154 | id = None
155 | self.assertIsNone(cursor.last_query_id)
156 | cursor.execute("SET SESSION query_max_run_time = '1234m'")
157 | self.assertIsNotNone(cursor.last_query_id)
158 | id = cursor.last_query_id
159 | cursor.fetchall()
160 | self.assertEqual(id, cursor.last_query_id)
161 |
162 | cursor.execute('SHOW SESSION')
163 | self.assertIsNotNone(cursor.last_query_id)
164 | self.assertNotEqual(id, cursor.last_query_id)
165 | id = cursor.last_query_id
166 | rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time']
167 | self.assertEqual(len(rows), 1)
168 | session_prop = rows[0]
169 | self.assertEqual(session_prop[1], '1234m')
170 | self.assertEqual(id, cursor.last_query_id)
171 |
172 | cursor.execute('RESET SESSION query_max_run_time')
173 | self.assertIsNotNone(cursor.last_query_id)
174 | self.assertNotEqual(id, cursor.last_query_id)
175 | id = cursor.last_query_id
176 | cursor.fetchall()
177 | self.assertEqual(id, cursor.last_query_id)
178 |
179 | cursor.execute('SHOW SESSION')
180 | self.assertIsNotNone(cursor.last_query_id)
181 | self.assertNotEqual(id, cursor.last_query_id)
182 | id = cursor.last_query_id
183 | rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time']
184 | self.assertEqual(len(rows), 1)
185 | session_prop = rows[0]
186 | self.assertNotEqual(session_prop[1], '1234m')
187 | self.assertEqual(id, cursor.last_query_id)
188 |
189 | def test_set_session_in_constructor(self):
190 | conn = presto.connect(
191 | host=_HOST, source=self.id(), session_props={'query_max_run_time': '1234m'}
192 | )
193 | with contextlib.closing(conn):
194 | with contextlib.closing(conn.cursor()) as cursor:
195 | cursor.execute('SHOW SESSION')
196 | rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time']
197 | assert len(rows) == 1
198 | session_prop = rows[0]
199 | assert session_prop[1] == '1234m'
200 |
201 | cursor.execute('RESET SESSION query_max_run_time')
202 | cursor.fetchall()
203 |
204 | cursor.execute('SHOW SESSION')
205 | rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time']
206 | assert len(rows) == 1
207 | session_prop = rows[0]
208 | assert session_prop[1] != '1234m'
209 |
210 | def test_invalid_protocol_config(self):
211 | """protocol should be https when passing password"""
212 | self.assertRaisesRegexp(
213 | ValueError, 'Protocol.*https.*password', lambda: presto.connect(
214 | host=_HOST, username='user', password='secret', protocol='http').cursor()
215 | )
216 |
217 | def test_invalid_password_and_kwargs(self):
218 | """password and requests_kwargs authentication are incompatible"""
219 | self.assertRaisesRegexp(
220 | ValueError, 'Cannot use both', lambda: presto.connect(
221 | host=_HOST, username='user', password='secret', protocol='https',
222 | requests_kwargs={'auth': requests.auth.HTTPBasicAuth('user', 'secret')}
223 | ).cursor()
224 | )
225 |
226 | def test_invalid_kwargs(self):
227 | """some kwargs are reserved"""
228 | self.assertRaisesRegexp(
229 | ValueError, 'Cannot override', lambda: presto.connect(
230 | host=_HOST, username='user', requests_kwargs={'url': 'test'}
231 | ).cursor()
232 | )
233 |
234 | def test_requests_kwargs(self):
235 | connection = presto.connect(
236 | host=_HOST, port=_PORT, source=self.id(),
237 | requests_kwargs={'proxies': {'http': 'localhost:9999'}},
238 | )
239 | cursor = connection.cursor()
240 | self.assertRaises(requests.exceptions.ProxyError,
241 | lambda: cursor.execute('SELECT * FROM one_row'))
242 |
243 | def test_requests_session(self):
244 | with requests.Session() as session:
245 | connection = presto.connect(
246 | host=_HOST, port=_PORT, source=self.id(), requests_session=session
247 | )
248 | cursor = connection.cursor()
249 | cursor.execute('SELECT * FROM one_row')
250 | self.assertEqual(cursor.fetchall(), [(1,)])
251 |
--------------------------------------------------------------------------------
/pyhive/tests/test_sasl_compat.py:
--------------------------------------------------------------------------------
1 | '''
2 | http://www.opensource.org/licenses/mit-license.php
3 |
4 | Copyright 2007-2011 David Alan Cridland
5 | Copyright 2011 Lance Stout
6 | Copyright 2012 Tyler L Hobbs
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy of this
9 | software and associated documentation files (the "Software"), to deal in the Software
10 | without restriction, including without limitation the rights to use, copy, modify, merge,
11 | publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
12 | to whom the Software is furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all copies or
15 | substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
18 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
19 | PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
20 | FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
21 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22 | DEALINGS IN THE SOFTWARE.
23 | '''
24 | # This file was generated by referring test cases from the pure-sasl repo i.e. https://github.com/thobbs/pure-sasl/tree/master/tests/unit
25 | # and by refactoring them to cover wrapper functions in sasl_compat.py along with added coverage for functions exclusive to sasl_compat.py.
26 |
27 | import unittest
28 | import base64
29 | import hashlib
30 | import hmac
31 | import kerberos
32 | from mock import patch
33 | import six
34 | import struct
35 | from puresasl import SASLProtocolException, QOP
36 | from puresasl.client import SASLError
37 | from pyhive.sasl_compat import PureSASLClient, error_catcher
38 |
39 |
40 | class TestPureSASLClient(unittest.TestCase):
41 | """Test cases for initialization of SASL client using PureSASLClient class"""
42 |
43 | def setUp(self):
44 | self.sasl_kwargs = {}
45 | self.sasl = PureSASLClient('localhost', **self.sasl_kwargs)
46 |
47 | def test_start_no_mechanism(self):
48 | """Test starting SASL authentication with no mechanism."""
49 | success, mechanism, response = self.sasl.start(mechanism=None)
50 | self.assertFalse(success)
51 | self.assertIsNone(mechanism)
52 | self.assertIsNone(response)
53 | self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
54 |
55 | def test_start_wrong_mechanism(self):
56 | """Test starting SASL authentication with a single unsupported mechanism."""
57 | success, mechanism, response = self.sasl.start(mechanism='WRONG')
58 | self.assertFalse(success)
59 | self.assertEqual(mechanism, 'WRONG')
60 | self.assertIsNone(response)
61 | self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
62 |
63 | def test_start_list_of_invalid_mechanisms(self):
64 | """Test starting SASL authentication with a list of unsupported mechanisms."""
65 | self.sasl.start(['invalid1', 'invalid2'])
66 | self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
67 |
68 | def test_start_list_of_valid_mechanisms(self):
69 | """Test starting SASL authentication with a list of supported mechanisms."""
70 | self.sasl.start(['PLAIN', 'DIGEST-MD5', 'CRAM-MD5'])
71 | # Validate right mechanism is chosen based on score.
72 | self.assertEqual(self.sasl._chosen_mech.name, 'DIGEST-MD5')
73 |
74 | def test_error_catcher_no_error(self):
75 | """Test the error_catcher with no error."""
76 | with error_catcher(self.sasl):
77 | result, _, _ = self.sasl.start(mechanism='ANONYMOUS')
78 |
79 | self.assertEqual(self.sasl.getError(), None)
80 | self.assertEqual(result, True)
81 |
82 | def test_error_catcher_with_error(self):
83 | """Test the error_catcher with an error."""
84 | with error_catcher(self.sasl):
85 | result, _, _ = self.sasl.start(mechanism='WRONG')
86 |
87 | self.assertEqual(result, False)
88 | self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
89 |
90 | """Assuming Client initilization went well and a mechanism is chosen, Below are the test cases for different mechanims"""
91 |
92 | class _BaseMechanismTests(unittest.TestCase):
93 | """Base test case for SASL mechanisms."""
94 |
95 | mechanism = 'ANONYMOUS'
96 | sasl_kwargs = {}
97 |
98 | def setUp(self):
99 | self.sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs)
100 | self.mechanism_class = self.sasl._chosen_mech
101 |
102 | def test_init_basic(self, *args):
103 | sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs)
104 | mech = sasl._chosen_mech
105 | self.assertIs(mech.sasl, sasl)
106 |
107 | def test_step_basic(self, *args):
108 | success, response = self.sasl.step(six.b('string'))
109 | self.assertTrue(success)
110 | self.assertIsInstance(response, six.binary_type)
111 |
112 | def test_decode_encode(self, *args):
113 | self.assertEqual(self.sasl.encode('msg'), (False, None))
114 | self.assertEqual(self.sasl.getError(), '')
115 | self.assertEqual(self.sasl.decode('msg'), (False, None))
116 | self.assertEqual(self.sasl.getError(), '')
117 |
118 |
119 | class AnonymousMechanismTest(_BaseMechanismTests):
120 | """Test case for the Anonymous SASL mechanism."""
121 |
122 | mechanism = 'ANONYMOUS'
123 |
124 |
125 | class PlainTextMechanismTest(_BaseMechanismTests):
126 | """Test case for the PlainText SASL mechanism."""
127 |
128 | mechanism = 'PLAIN'
129 | username = 'user'
130 | password = 'pass'
131 | sasl_kwargs = {'username': username, 'password': password}
132 |
133 | def test_step(self):
134 | for challenge in (None, '', b'asdf', u"\U0001F44D"):
135 | success, response = self.sasl.step(challenge)
136 | self.assertTrue(success)
137 | self.assertEqual(response, six.b(f'\x00{self.username}\x00{self.password}'))
138 | self.assertIsInstance(response, six.binary_type)
139 |
140 | def test_step_with_authorization_id_or_identity(self):
141 | challenge = u"\U0001F44D"
142 | identity = 'user2'
143 |
144 | # Test that we can pass an identity
145 | sasl_kwargs = self.sasl_kwargs.copy()
146 | sasl_kwargs.update({'identity': identity})
147 | sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs)
148 | success, response = sasl.step(challenge)
149 | self.assertTrue(success)
150 | self.assertEqual(response, six.b(f'{identity}\x00{self.username}\x00{self.password}'))
151 | self.assertIsInstance(response, six.binary_type)
152 | self.assertTrue(sasl.complete)
153 |
154 | # Test that the sasl authorization_id has priority over identity
155 | auth_id = 'user3'
156 | sasl_kwargs.update({'authorization_id': auth_id})
157 | sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs)
158 | success, response = sasl.step(challenge)
159 | self.assertTrue(success)
160 | self.assertEqual(response, six.b(f'{auth_id}\x00{self.username}\x00{self.password}'))
161 | self.assertIsInstance(response, six.binary_type)
162 | self.assertTrue(sasl.complete)
163 |
164 | def test_decode_encode(self):
165 | msg = 'msg'
166 | self.assertEqual(self.sasl.decode(msg), (True, msg))
167 | self.assertEqual(self.sasl.encode(msg), (True, msg))
168 |
169 |
170 | class ExternalMechanismTest(_BaseMechanismTests):
171 | """Test case for the External SASL mechanisms"""
172 |
173 | mechanism = 'EXTERNAL'
174 |
175 | def test_step(self):
176 | self.assertEqual(self.sasl.step(), (True, b''))
177 |
178 | def test_decode_encode(self):
179 | msg = 'msg'
180 | self.assertEqual(self.sasl.decode(msg), (True, msg))
181 | self.assertEqual(self.sasl.encode(msg), (True, msg))
182 |
183 |
184 | @patch('puresasl.mechanisms.kerberos.authGSSClientStep')
185 | @patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=base64.b64encode(six.b('some\x00 response')))
186 | class GSSAPIMechanismTest(_BaseMechanismTests):
187 | """Test case for the GSSAPI SASL mechanism."""
188 |
189 | mechanism = 'GSSAPI'
190 | service = 'GSSAPI'
191 | sasl_kwargs = {'service': service}
192 |
193 | @patch('puresasl.mechanisms.kerberos.authGSSClientWrap')
194 | @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap')
195 | def test_decode_encode(self, _inner1, _inner2, authGSSClientResponse, *args):
196 | # bypassing step setup by setting qop directly
197 | self.mechanism_class.qop = QOP.AUTH
198 | msg = b'msg'
199 | self.assertEqual(self.sasl.decode(msg), (True, msg))
200 | self.assertEqual(self.sasl.encode(msg), (True, msg))
201 |
202 | # Test for behavior with different QOP like data integrity and confidentiality for Kerberos authentication
203 | for qop in (QOP.AUTH_INT, QOP.AUTH_CONF):
204 | self.mechanism_class.qop = qop
205 | with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=1):
206 | self.assertEqual(self.sasl.decode(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
207 | self.assertEqual(self.sasl.encode(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
208 | if qop == QOP.AUTH_CONF:
209 | with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=0):
210 | self.assertEqual(self.sasl.encode(msg), (False, None))
211 | self.assertEqual(self.sasl.getError(), 'Error: confidentiality requested, but not honored by the server.')
212 |
213 | def test_step_no_user(self, authGSSClientResponse, *args):
214 | msg = six.b('whatever')
215 |
216 | # no user
217 | self.assertEqual(self.sasl.step(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
218 | with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=''):
219 | self.assertEqual(self.sasl.step(msg), (True, six.b('')))
220 |
221 | username = 'username'
222 | # with user; this has to be last because it sets mechanism.user
223 | with patch('puresasl.mechanisms.kerberos.authGSSClientStep', return_value=kerberos.AUTH_GSS_COMPLETE):
224 | with patch('puresasl.mechanisms.kerberos.authGSSClientUserName', return_value=six.b(username)):
225 | self.assertEqual(self.sasl.step(msg), (True, six.b('')))
226 | self.assertEqual(self.mechanism_class.user, six.b(username))
227 |
228 | @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap')
229 | def test_step_qop(self, *args):
230 | self.mechanism_class._have_negotiated_details = True
231 | self.mechanism_class.user = 'user'
232 | msg = six.b('msg')
233 | self.assertEqual(self.sasl.step(msg), (False, None))
234 | self.assertEqual(self.sasl.getError(), 'Bad response from server')
235 |
236 | max_len = 100
237 | self.assertLess(max_len, self.sasl.max_buffer)
238 | for i, qop in QOP.bit_map.items():
239 | qop_size = struct.pack('!i', i << 24 | max_len)
240 | response = base64.b64encode(qop_size)
241 | with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=response):
242 | with patch('puresasl.mechanisms.kerberos.authGSSClientWrap') as authGSSClientWrap:
243 | self.mechanism_class.complete = False
244 | self.assertEqual(self.sasl.step(msg), (True, qop_size))
245 | self.assertTrue(self.mechanism_class.complete)
246 | self.assertEqual(self.mechanism_class.qop, qop)
247 | self.assertEqual(self.mechanism_class.max_buffer, max_len)
248 |
249 | args = authGSSClientWrap.call_args[0]
250 | out_data = args[1]
251 | out = base64.b64decode(out_data)
252 | self.assertEqual(out[:4], qop_size)
253 | self.assertEqual(out[4:], six.b(self.mechanism_class.user))
254 |
255 |
256 | class CramMD5MechanismTest(_BaseMechanismTests):
257 | """Test case for the CRAM-MD5 SASL mechanism."""
258 |
259 | mechanism = 'CRAM-MD5'
260 | username = 'user'
261 | password = 'pass'
262 | sasl_kwargs = {'username': username, 'password': password}
263 |
264 | def test_step(self):
265 | success, response = self.sasl.step(None)
266 | self.assertTrue(success)
267 | self.assertIsNone(response)
268 | challenge = six.b('msg')
269 | hash = hmac.HMAC(key=six.b(self.password), digestmod=hashlib.md5)
270 | hash.update(challenge)
271 | success, response = self.sasl.step(challenge)
272 | self.assertTrue(success)
273 | self.assertIn(six.b(self.username), response)
274 | self.assertIn(six.b(hash.hexdigest()), response)
275 | self.assertIsInstance(response, six.binary_type)
276 | self.assertTrue(self.sasl.complete)
277 |
278 | def test_decode_encode(self):
279 | msg = 'msg'
280 | self.assertEqual(self.sasl.decode(msg), (True, msg))
281 | self.assertEqual(self.sasl.encode(msg), (True, msg))
282 |
283 |
284 | class DigestMD5MechanismTest(_BaseMechanismTests):
285 | """Test case for the DIGEST-MD5 SASL mechanism."""
286 |
287 | mechanism = 'DIGEST-MD5'
288 | username = 'user'
289 | password = 'pass'
290 | sasl_kwargs = {'username': username, 'password': password}
291 |
292 | def test_decode_encode(self):
293 | msg = 'msg'
294 | self.assertEqual(self.sasl.decode(msg), (True, msg))
295 | self.assertEqual(self.sasl.encode(msg), (True, msg))
296 |
297 | def test_step_basic(self, *args):
298 | pass
299 |
300 | def test_step(self):
301 | """Test a SASL step with dummy challenge for DIGEST-MD5 mechanism."""
302 | testChallenge = (
303 | b'nonce="rmD6R8aMYVWH+/ih9HGBr3xNGAR6o2DUxpKlgDz6gUQ=",r'
304 | b'ealm="example.org",qop="auth,auth-int,auth-conf",cipher="rc4-40,rc'
305 | b'4-56,rc4,des,3des",maxbuf=65536,charset=utf-8,algorithm=md5-sess'
306 | )
307 | result, response = self.sasl.step(testChallenge)
308 | self.assertTrue(result)
309 | self.assertIsNotNone(response)
310 |
311 | def test_step_server_answer(self):
312 | """Test a SASL step with a proper server answer for DIGEST-MD5 mechanism."""
313 | sasl_kwargs = {'username': "chris", 'password': "secret"}
314 | sasl = PureSASLClient('elwood.innosoft.com',
315 | service="imap",
316 | mechanism=self.mechanism,
317 | mutual_auth=True,
318 | **sasl_kwargs)
319 | testChallenge = (
320 | b'utf-8,username="chris",realm="elwood.innosoft.com",'
321 | b'nonce="OA6MG9tEQGm2hh",nc=00000001,cnonce="OA6MHXh6VqTrRk",'
322 | b'digest-uri="imap/elwood.innosoft.com",'
323 | b'response=d388dad90d4bbd760a152321f2143af7,qop=auth'
324 | )
325 | sasl.step(testChallenge)
326 | sasl._chosen_mech.cnonce = b"OA6MHXh6VqTrRk"
327 |
328 | serverResponse = (
329 | b'rspauth=ea40f60335c427b5527b84dbabcdfffd'
330 | )
331 | sasl.step(serverResponse)
332 | # assert that step choses the only supported QOP for for DIGEST-MD5
333 | self.assertEqual(self.sasl.qop, QOP.AUTH)
334 |
--------------------------------------------------------------------------------
/pyhive/tests/test_sqlalchemy_hive.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import unicode_literals
3 | from builtins import str
4 | from pyhive.sqlalchemy_hive import HiveDate
5 | from pyhive.sqlalchemy_hive import HiveDecimal
6 | from pyhive.sqlalchemy_hive import HiveTimestamp
7 | from sqlalchemy.exc import NoSuchTableError, OperationalError
8 | from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase
9 | from pyhive.tests.sqlalchemy_test_case import with_engine_connection
10 | from sqlalchemy import types
11 | from sqlalchemy.engine import create_engine
12 | from sqlalchemy.schema import Column
13 | from sqlalchemy.schema import MetaData
14 | from sqlalchemy.schema import Table
15 | from sqlalchemy.sql import text
16 | import contextlib
17 | import datetime
18 | import decimal
19 | import sqlalchemy.types
20 | import unittest
21 | import re
22 |
23 | sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
24 |
25 | _ONE_ROW_COMPLEX_CONTENTS = [
26 | True,
27 | 127,
28 | 32767,
29 | 2147483647,
30 | 9223372036854775807,
31 | 0.5,
32 | 0.25,
33 | 'a string',
34 | datetime.datetime(1970, 1, 1),
35 | b'123',
36 | '[1,2]',
37 | '{1:2,3:4}',
38 | '{"a":1,"b":2}',
39 | '{0:1}',
40 | decimal.Decimal('0.1'),
41 | ]
42 |
43 |
44 | # [
45 | # ('boolean', 'boolean', ''),
46 | # ('tinyint', 'tinyint', ''),
47 | # ('smallint', 'smallint', ''),
48 | # ('int', 'int', ''),
49 | # ('bigint', 'bigint', ''),
50 | # ('float', 'float', ''),
51 | # ('double', 'double', ''),
52 | # ('string', 'string', ''),
53 | # ('timestamp', 'timestamp', ''),
54 | # ('binary', 'binary', ''),
55 | # ('array', 'array', ''),
56 | # ('map', 'map', ''),
57 | # ('struct', 'struct', ''),
58 | # ('union', 'uniontype', ''),
59 | # ('decimal', 'decimal(10,1)', '')
60 | # ]
61 |
62 |
63 | class TestSqlAlchemyHive(unittest.TestCase, SqlAlchemyTestCase):
64 | def create_engine(self):
65 | return create_engine('hive://localhost:10000/default')
66 |
67 | @with_engine_connection
68 | def test_dotted_column_names(self, engine, connection):
69 | """When Hive returns a dotted column name, both the non-dotted version should be available
70 | as an attribute, and the dotted version should remain available as a key.
71 | """
72 | row = connection.execute(text('SELECT * FROM one_row')).fetchone()
73 |
74 | if sqlalchemy_version >= 1.4:
75 | row = row._mapping
76 |
77 | assert row.keys() == ['number_of_rows']
78 | assert 'number_of_rows' in row
79 | assert row.number_of_rows == 1
80 | assert row['number_of_rows'] == 1
81 | assert getattr(row, 'one_row.number_of_rows') == 1
82 | assert row['one_row.number_of_rows'] == 1
83 |
84 | @with_engine_connection
85 | def test_dotted_column_names_raw(self, engine, connection):
86 | """When Hive returns a dotted column name, and raw mode is on, nothing should be modified.
87 | """
88 | row = connection.execution_options(hive_raw_colnames=True).execute(text('SELECT * FROM one_row')).fetchone()
89 |
90 | if sqlalchemy_version >= 1.4:
91 | row = row._mapping
92 |
93 | assert row.keys() == ['one_row.number_of_rows']
94 | assert 'number_of_rows' not in row
95 | assert getattr(row, 'one_row.number_of_rows') == 1
96 | assert row['one_row.number_of_rows'] == 1
97 |
98 | @with_engine_connection
99 | def test_reflect_no_such_table(self, engine, connection):
100 | """reflecttable should throw an exception on an invalid table"""
101 | self.assertRaises(
102 | NoSuchTableError,
103 | lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine))
104 | self.assertRaises(
105 | OperationalError,
106 | lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine))
107 |
108 | @with_engine_connection
109 | def test_reflect_select(self, engine, connection):
110 | """reflecttable should be able to fill in a table from the name"""
111 | one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine)
112 | self.assertEqual(len(one_row_complex.c), 15)
113 | self.assertIsInstance(one_row_complex.c.string, Column)
114 | row = connection.execute(one_row_complex.select()).fetchone()
115 | self.assertEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS)
116 |
117 | # TODO some of these types could be filled in better
118 | self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean)
119 | self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer)
120 | self.assertIsInstance(one_row_complex.c.smallint.type, types.Integer)
121 | self.assertIsInstance(one_row_complex.c.int.type, types.Integer)
122 | self.assertIsInstance(one_row_complex.c.bigint.type, types.BigInteger)
123 | self.assertIsInstance(one_row_complex.c.float.type, types.Float)
124 | self.assertIsInstance(one_row_complex.c.double.type, types.Float)
125 | self.assertIsInstance(one_row_complex.c.string.type, types.String)
126 | self.assertIsInstance(one_row_complex.c.timestamp.type, HiveTimestamp)
127 | self.assertIsInstance(one_row_complex.c.binary.type, types.String)
128 | self.assertIsInstance(one_row_complex.c.array.type, types.String)
129 | self.assertIsInstance(one_row_complex.c.map.type, types.String)
130 | self.assertIsInstance(one_row_complex.c.struct.type, types.String)
131 | self.assertIsInstance(one_row_complex.c.union.type, types.String)
132 | self.assertIsInstance(one_row_complex.c.decimal.type, HiveDecimal)
133 |
134 | @with_engine_connection
135 | def test_type_map(self, engine, connection):
136 | """sqlalchemy should use the dbapi_type_map to infer types from raw queries"""
137 | row = connection.execute(text('SELECT * FROM one_row_complex')).fetchone()
138 | self.assertListEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS)
139 |
140 | @with_engine_connection
141 | def test_reserved_words(self, engine, connection):
142 | """Hive uses backticks"""
143 | # Use keywords for the table/column name
144 | fake_table = Table('select', MetaData(), Column('map', sqlalchemy.types.String))
145 | query = str(fake_table.select().where(fake_table.c.map == 'a').compile(engine))
146 | self.assertIn('`select`', query)
147 | self.assertIn('`map`', query)
148 | self.assertNotIn('"select"', query)
149 | self.assertNotIn('"map"', query)
150 |
151 | def test_switch_database(self):
152 | engine = create_engine('hive://localhost:10000/pyhive_test_database')
153 | try:
154 | with contextlib.closing(engine.connect()) as connection:
155 | self.assertIn(
156 | ('dummy_table',),
157 | connection.execute(text('SHOW TABLES')).fetchall()
158 | )
159 | connection.execute(text('USE default'))
160 | self.assertIn(
161 | ('one_row',),
162 | connection.execute(text('SHOW TABLES')).fetchall()
163 | )
164 | finally:
165 | engine.dispose()
166 |
167 | @with_engine_connection
168 | def test_lots_of_types(self, engine, connection):
169 | # Presto doesn't have raw CREATE TABLE support, so we ony test hive
170 | # take type list from sqlalchemy.types
171 | types = [
172 | 'INT', 'CHAR', 'VARCHAR', 'NCHAR', 'TEXT', 'Text', 'FLOAT',
173 | 'NUMERIC', 'DECIMAL', 'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB',
174 | 'BOOLEAN', 'SMALLINT', 'DATE', 'TIME',
175 | 'String', 'Integer', 'SmallInteger',
176 | 'Numeric', 'Float', 'DateTime', 'Date', 'Time', 'LargeBinary',
177 | 'Boolean', 'Unicode', 'UnicodeText',
178 | ]
179 | cols = []
180 | for i, t in enumerate(types):
181 | cols.append(Column(str(i), getattr(sqlalchemy.types, t)))
182 | cols.append(Column('hive_date', HiveDate))
183 | cols.append(Column('hive_decimal', HiveDecimal))
184 | cols.append(Column('hive_timestamp', HiveTimestamp))
185 | table = Table('test_table', MetaData(schema='pyhive_test_database'), *cols,)
186 | table.drop(checkfirst=True, bind=connection)
187 | table.create(bind=connection)
188 | connection.execute(text('SET mapred.job.tracker=local'))
189 | connection.execute(text('USE pyhive_test_database'))
190 | big_number = 10 ** 10 - 1
191 | connection.execute(text("""
192 | INSERT OVERWRITE TABLE test_table
193 | SELECT
194 | 1, "a", "a", "a", "a", "a", 0.1,
195 | 0.1, 0.1, 0, 0, "a", "a",
196 | false, 1, 0, 0,
197 | "a", 1, 1,
198 | 0.1, 0.1, 0, 0, 0, "a",
199 | false, "a", "a",
200 | 0, :big_number, 123 + 2000
201 | FROM default.one_row
202 | """), {"big_number": big_number})
203 | row = connection.execute(text("select * from test_table")).fetchone()
204 | self.assertEqual(row.hive_date, datetime.datetime(1970, 1, 1, 0, 0))
205 | self.assertEqual(row.hive_decimal, decimal.Decimal(big_number))
206 | self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123000))
207 | table.drop(bind=connection)
208 |
209 | @with_engine_connection
210 | def test_insert_select(self, engine, connection):
211 | one_row = Table('one_row', MetaData(), autoload_with=engine)
212 | table = Table('insert_test', MetaData(schema='pyhive_test_database'),
213 | Column('a', sqlalchemy.types.Integer))
214 | table.drop(checkfirst=True, bind=connection)
215 | table.create(bind=connection)
216 | connection.execute(text('SET mapred.job.tracker=local'))
217 | # NOTE(jing) I'm stuck on a version of Hive without INSERT ... VALUES
218 | connection.execute(table.insert().from_select(['a'], one_row.select()))
219 |
220 | result = connection.execute(table.select()).fetchall()
221 | expected = [(1,)]
222 | self.assertEqual(result, expected)
223 |
224 | @with_engine_connection
225 | def test_insert_values(self, engine, connection):
226 | table = Table('insert_test', MetaData(schema='pyhive_test_database'),
227 | Column('a', sqlalchemy.types.Integer),)
228 | table.drop(checkfirst=True, bind=connection)
229 | table.create(bind=connection)
230 | connection.execute(table.insert().values([{'a': 1}, {'a': 2}]))
231 |
232 | result = connection.execute(table.select()).fetchall()
233 | expected = [(1,), (2,)]
234 | self.assertEqual(result, expected)
235 |
236 | @with_engine_connection
237 | def test_supports_san_rowcount(self, engine, connection):
238 | self.assertFalse(engine.dialect.supports_sane_rowcount_returning)
239 |
--------------------------------------------------------------------------------
/pyhive/tests/test_sqlalchemy_presto.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import unicode_literals
3 | from builtins import str
4 | from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase
5 | from pyhive.tests.sqlalchemy_test_case import with_engine_connection
6 | from sqlalchemy import types
7 | from sqlalchemy.engine import create_engine
8 | from sqlalchemy.schema import Column
9 | from sqlalchemy.schema import MetaData
10 | from sqlalchemy.schema import Table
11 | from sqlalchemy.sql import text
12 | from sqlalchemy.types import String
13 | from decimal import Decimal
14 |
15 | import contextlib
16 | import unittest
17 |
18 |
19 | class TestSqlAlchemyPresto(unittest.TestCase, SqlAlchemyTestCase):
20 | def create_engine(self):
21 | return create_engine('presto://localhost:8080/hive/default?source={}'.format(self.id()))
22 |
23 | def test_bad_format(self):
24 | self.assertRaises(
25 | ValueError,
26 | lambda: create_engine('presto://localhost:8080/hive/default/what'),
27 | )
28 |
29 | @with_engine_connection
30 | def test_reflect_select(self, engine, connection):
31 | """reflecttable should be able to fill in a table from the name"""
32 | one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine)
33 | # Presto ignores the union column
34 | self.assertEqual(len(one_row_complex.c), 15 - 1)
35 | self.assertIsInstance(one_row_complex.c.string, Column)
36 | rows = connection.execute(one_row_complex.select()).fetchall()
37 | self.assertEqual(len(rows), 1)
38 | self.assertEqual(list(rows[0]), [
39 | True,
40 | 127,
41 | 32767,
42 | 2147483647,
43 | 9223372036854775807,
44 | 0.5,
45 | 0.25,
46 | 'a string',
47 | '1970-01-01 00:00:00.000',
48 | b'123',
49 | [1, 2],
50 | {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON
51 | [1, 2], # struct is returned as a list of elements
52 | # '{0:1}',
53 | Decimal('0.1'),
54 | ])
55 |
56 | # TODO some of these types could be filled in better
57 | self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean)
58 | self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer)
59 | self.assertIsInstance(one_row_complex.c.smallint.type, types.Integer)
60 | self.assertIsInstance(one_row_complex.c.int.type, types.Integer)
61 | self.assertIsInstance(one_row_complex.c.bigint.type, types.BigInteger)
62 | self.assertIsInstance(one_row_complex.c.float.type, types.Float)
63 | self.assertIsInstance(one_row_complex.c.double.type, types.Float)
64 | self.assertIsInstance(one_row_complex.c.string.type, String)
65 | self.assertIsInstance(one_row_complex.c.timestamp.type, types.TIMESTAMP)
66 | self.assertIsInstance(one_row_complex.c.binary.type, types.VARBINARY)
67 | self.assertIsInstance(one_row_complex.c.array.type, types.NullType)
68 | self.assertIsInstance(one_row_complex.c.map.type, types.NullType)
69 | self.assertIsInstance(one_row_complex.c.struct.type, types.NullType)
70 | self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType)
71 |
72 | def test_url_default(self):
73 | engine = create_engine('presto://localhost:8080/hive')
74 | try:
75 | with contextlib.closing(engine.connect()) as connection:
76 | self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1)
77 | finally:
78 | engine.dispose()
79 |
80 | @with_engine_connection
81 | def test_reserved_words(self, engine, connection):
82 | """Presto uses double quotes, not backticks"""
83 | # Use keywords for the table/column name
84 | fake_table = Table('select', MetaData(), Column('current_timestamp', String))
85 | query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine))
86 | self.assertIn('"select"', query)
87 | self.assertIn('"current_timestamp"', query)
88 | self.assertNotIn('`select`', query)
89 | self.assertNotIn('`current_timestamp`', query)
90 |
--------------------------------------------------------------------------------
/pyhive/tests/test_sqlalchemy_trino.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy.engine import create_engine
2 | from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase
3 | from pyhive.tests.sqlalchemy_test_case import with_engine_connection
4 | from sqlalchemy.exc import NoSuchTableError, DatabaseError
5 | from sqlalchemy.schema import MetaData, Table, Column
6 | from sqlalchemy.types import String
7 | from sqlalchemy.sql import text
8 | from sqlalchemy import types
9 | from decimal import Decimal
10 |
11 | import unittest
12 | import contextlib
13 |
14 |
15 | class TestSqlAlchemyTrino(unittest.TestCase, SqlAlchemyTestCase):
16 | def create_engine(self):
17 | return create_engine('trino+pyhive://localhost:18080/hive/default?source={}'.format(self.id()))
18 |
19 | def test_bad_format(self):
20 | self.assertRaises(
21 | ValueError,
22 | lambda: create_engine('trino+pyhive://localhost:18080/hive/default/what'),
23 | )
24 |
25 | @with_engine_connection
26 | def test_reflect_select(self, engine, connection):
27 | """reflecttable should be able to fill in a table from the name"""
28 | one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine)
29 | # Presto ignores the union column
30 | self.assertEqual(len(one_row_complex.c), 15 - 1)
31 | self.assertIsInstance(one_row_complex.c.string, Column)
32 | rows = connection.execute(one_row_complex.select()).fetchall()
33 | self.assertEqual(len(rows), 1)
34 | self.assertEqual(list(rows[0]), [
35 | True,
36 | 127,
37 | 32767,
38 | 2147483647,
39 | 9223372036854775807,
40 | 0.5,
41 | 0.25,
42 | 'a string',
43 | '1970-01-01 00:00:00.000',
44 | b'123',
45 | [1, 2],
46 | {"1": 2, "3": 4},
47 | [1, 2],
48 | Decimal('0.1'),
49 | ])
50 |
51 | self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean)
52 | self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer)
53 | self.assertIsInstance(one_row_complex.c.smallint.type, types.Integer)
54 | self.assertIsInstance(one_row_complex.c.int.type, types.Integer)
55 | self.assertIsInstance(one_row_complex.c.bigint.type, types.BigInteger)
56 | self.assertIsInstance(one_row_complex.c.float.type, types.Float)
57 | self.assertIsInstance(one_row_complex.c.double.type, types.Float)
58 | self.assertIsInstance(one_row_complex.c.string.type, String)
59 | self.assertIsInstance(one_row_complex.c.timestamp.type, types.NullType)
60 | self.assertIsInstance(one_row_complex.c.binary.type, types.VARBINARY)
61 | self.assertIsInstance(one_row_complex.c.array.type, types.NullType)
62 | self.assertIsInstance(one_row_complex.c.map.type, types.NullType)
63 | self.assertIsInstance(one_row_complex.c.struct.type, types.NullType)
64 | self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType)
65 |
66 | @with_engine_connection
67 | def test_reflect_no_such_table(self, engine, connection):
68 | """reflecttable should throw an exception on an invalid table"""
69 | self.assertRaises(
70 | NoSuchTableError,
71 | lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine))
72 | self.assertRaises(
73 | DatabaseError,
74 | lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine))
75 |
76 | def test_url_default(self):
77 | engine = create_engine('trino+pyhive://localhost:18080/hive')
78 | try:
79 | with contextlib.closing(engine.connect()) as connection:
80 | self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1)
81 | finally:
82 | engine.dispose()
83 |
84 | @with_engine_connection
85 | def test_reserved_words(self, engine, connection):
86 | """Trino uses double quotes, not backticks"""
87 | # Use keywords for the table/column name
88 | fake_table = Table('select', MetaData(), Column('current_timestamp', String))
89 | query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine))
90 | self.assertIn('"select"', query)
91 | self.assertIn('"current_timestamp"', query)
92 | self.assertNotIn('`select`', query)
93 | self.assertNotIn('`current_timestamp`', query)
94 |
--------------------------------------------------------------------------------
/pyhive/tests/test_trino.py:
--------------------------------------------------------------------------------
1 | """Trino integration tests.
2 |
3 | These rely on having a Trino+Hadoop cluster set up.
4 | They also require a tables created by make_test_tables.sh.
5 | """
6 |
7 | from __future__ import absolute_import
8 | from __future__ import unicode_literals
9 |
10 | import contextlib
11 | import os
12 | from decimal import Decimal
13 |
14 | import requests
15 |
16 | from pyhive import exc
17 | from pyhive import trino
18 | from pyhive.tests.dbapi_test_case import DBAPITestCase
19 | from pyhive.tests.dbapi_test_case import with_cursor
20 | from pyhive.tests.test_presto import TestPresto
21 | import mock
22 | import unittest
23 | import datetime
24 |
25 | _HOST = 'localhost'
26 | _PORT = '18080'
27 |
28 |
29 | class TestTrino(TestPresto):
30 | __test__ = True
31 |
32 | def connect(self):
33 | return trino.connect(host=_HOST, port=_PORT, source=self.id())
34 |
35 | def test_bad_protocol(self):
36 | self.assertRaisesRegexp(ValueError, 'Protocol must be',
37 | lambda: trino.connect('localhost', protocol='nonsense').cursor())
38 |
39 | def test_escape_args(self):
40 | escaper = trino.TrinoParamEscaper()
41 |
42 | self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)),
43 | ("date '2020-04-17'",))
44 | self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)),
45 | ("timestamp '2020-04-17 12:00:00.123'",))
46 |
47 | @with_cursor
48 | def test_description(self, cursor):
49 | cursor.execute('SELECT 1 AS foobar FROM one_row')
50 | self.assertEqual(cursor.description, [('foobar', 'integer', None, None, None, None, True)])
51 | self.assertIsNotNone(cursor.last_query_id)
52 |
53 | @with_cursor
54 | def test_complex(self, cursor):
55 | cursor.execute('SELECT * FROM one_row_complex')
56 | # TODO Trino drops the union field
57 |
58 | tinyint_type = 'tinyint'
59 | smallint_type = 'smallint'
60 | float_type = 'real'
61 | self.assertEqual(cursor.description, [
62 | ('boolean', 'boolean', None, None, None, None, True),
63 | ('tinyint', tinyint_type, None, None, None, None, True),
64 | ('smallint', smallint_type, None, None, None, None, True),
65 | ('int', 'integer', None, None, None, None, True),
66 | ('bigint', 'bigint', None, None, None, None, True),
67 | ('float', float_type, None, None, None, None, True),
68 | ('double', 'double', None, None, None, None, True),
69 | ('string', 'varchar', None, None, None, None, True),
70 | ('timestamp', 'timestamp', None, None, None, None, True),
71 | ('binary', 'varbinary', None, None, None, None, True),
72 | ('array', 'array(integer)', None, None, None, None, True),
73 | ('map', 'map(integer,integer)', None, None, None, None, True),
74 | ('struct', 'row(a integer,b integer)', None, None, None, None, True),
75 | # ('union', 'varchar', None, None, None, None, True),
76 | ('decimal', 'decimal(10,1)', None, None, None, None, True),
77 | ])
78 | rows = cursor.fetchall()
79 | expected = [(
80 | True,
81 | 127,
82 | 32767,
83 | 2147483647,
84 | 9223372036854775807,
85 | 0.5,
86 | 0.25,
87 | 'a string',
88 | '1970-01-01 00:00:00.000',
89 | b'123',
90 | [1, 2],
91 | {"1": 2, "3": 4}, # Trino converts all keys to strings so that they're valid JSON
92 | [1, 2], # struct is returned as a list of elements
93 | # '{0:1}',
94 | Decimal('0.1'),
95 | )]
96 | self.assertEqual(rows, expected)
97 | # catch unicode/str
98 | self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0])))
--------------------------------------------------------------------------------
/pyhive/trino.py:
--------------------------------------------------------------------------------
1 | """DB-API implementation backed by Trino
2 |
3 | See http://www.python.org/dev/peps/pep-0249/
4 |
5 | Many docstrings in this file are based on the PEP, which is in the public domain.
6 | """
7 |
8 | from __future__ import absolute_import
9 | from __future__ import unicode_literals
10 |
11 | import logging
12 |
13 | import requests
14 |
15 | # Make all exceptions visible in this module per DB-API
16 | from pyhive.common import DBAPITypeObject
17 | from pyhive.exc import * # noqa
18 | from pyhive.presto import Connection as PrestoConnection, Cursor as PrestoCursor, PrestoParamEscaper
19 |
20 | try: # Python 3
21 | import urllib.parse as urlparse
22 | except ImportError: # Python 2
23 | import urlparse
24 |
25 | # PEP 249 module globals
26 | apilevel = '2.0'
27 | threadsafety = 2 # Threads may share the module and connections.
28 | paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s
29 |
30 | _logger = logging.getLogger(__name__)
31 |
32 |
33 | class TrinoParamEscaper(PrestoParamEscaper):
34 | pass
35 |
36 |
37 | _escaper = TrinoParamEscaper()
38 |
39 |
40 | def connect(*args, **kwargs):
41 | """Constructor for creating a connection to the database. See class :py:class:`Connection` for
42 | arguments.
43 |
44 | :returns: a :py:class:`Connection` object.
45 | """
46 | return Connection(*args, **kwargs)
47 |
48 |
49 | class Connection(PrestoConnection):
50 | def __init__(self, *args, **kwargs):
51 | super().__init__(*args, **kwargs)
52 |
53 | def cursor(self):
54 | """Return a new :py:class:`Cursor` object using the connection."""
55 | return Cursor(*self._args, **self._kwargs)
56 |
57 |
58 | class Cursor(PrestoCursor):
59 | """These objects represent a database cursor, which is used to manage the context of a fetch
60 | operation.
61 |
62 | Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
63 | visible by other cursors or connections.
64 | """
65 |
66 | def execute(self, operation, parameters=None):
67 | """Prepare and execute a database operation (query or command).
68 |
69 | Return values are not defined.
70 | """
71 | headers = {
72 | 'X-Trino-Catalog': self._catalog,
73 | 'X-Trino-Schema': self._schema,
74 | 'X-Trino-Source': self._source,
75 | 'X-Trino-User': self._username,
76 | }
77 |
78 | if self._session_props:
79 | headers['X-Trino-Session'] = ','.join(
80 | '{}={}'.format(propname, propval)
81 | for propname, propval in self._session_props.items()
82 | )
83 |
84 | # Prepare statement
85 | if parameters is None:
86 | sql = operation
87 | else:
88 | sql = operation % _escaper.escape_args(parameters)
89 |
90 | self._reset_state()
91 |
92 | self._state = self._STATE_RUNNING
93 | url = urlparse.urlunparse((
94 | self._protocol,
95 | '{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None))
96 | _logger.info('%s', sql)
97 | _logger.debug("Headers: %s", headers)
98 | response = self._requests_session.post(
99 | url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs)
100 | self._process_response(response)
101 |
102 | def _process_response(self, response):
103 | """Given the JSON response from Trino's REST API, update the internal state with the next
104 | URI and any data from the response
105 | """
106 | # TODO handle HTTP 503
107 | if response.status_code != requests.codes.ok:
108 | fmt = "Unexpected status code {}\n{}"
109 | raise OperationalError(fmt.format(response.status_code, response.content))
110 |
111 | response_json = response.json()
112 | _logger.debug("Got response %s", response_json)
113 | assert self._state == self._STATE_RUNNING, "Should be running if processing response"
114 | self._nextUri = response_json.get('nextUri')
115 | self._columns = response_json.get('columns')
116 | if 'id' in response_json:
117 | self.last_query_id = response_json['id']
118 | if 'X-Trino-Clear-Session' in response.headers:
119 | propname = response.headers['X-Trino-Clear-Session']
120 | self._session_props.pop(propname, None)
121 | if 'X-Trino-Set-Session' in response.headers:
122 | propname, propval = response.headers['X-Trino-Set-Session'].split('=', 1)
123 | self._session_props[propname] = propval
124 | if 'data' in response_json:
125 | assert self._columns
126 | new_data = response_json['data']
127 | self._process_data(new_data)
128 | self._data += map(tuple, new_data)
129 | if 'nextUri' not in response_json:
130 | self._state = self._STATE_FINISHED
131 | if 'error' in response_json:
132 | raise DatabaseError(response_json['error'])
133 |
134 |
135 | #
136 | # Type Objects and Constructors
137 | #
138 |
139 |
140 | # See types in trino-main/src/main/java/com/facebook/trino/tuple/TupleInfo.java
141 | FIXED_INT_64 = DBAPITypeObject(['bigint'])
142 | VARIABLE_BINARY = DBAPITypeObject(['varchar'])
143 | DOUBLE = DBAPITypeObject(['double'])
144 | BOOLEAN = DBAPITypeObject(['boolean'])
145 |
--------------------------------------------------------------------------------
/scripts/ldap_config/slapd.conf:
--------------------------------------------------------------------------------
1 | # See slapd.conf(5) for details on configuration options.
2 | include /etc/ldap/schema/core.schema
3 | include /etc/ldap/schema/cosine.schema
4 | include /etc/ldap/schema/inetorgperson.schema
5 | include /etc/ldap/schema/nis.schema
6 |
7 | pidfile /tmp/slapd/slapd.pid
8 | argsfile /tmp/slapd/slapd.args
9 |
10 | modulepath /usr/lib/openldap
11 |
12 | database ldif
13 | directory /tmp/slapd
14 |
15 | suffix "dc=example,dc=com"
16 | rootdn "cn=admin,dc=example,dc=com"
17 | rootpw {SSHA}AIzygLSXlArhAMzddUriXQxf7UlkqopP
18 |
--------------------------------------------------------------------------------
/scripts/make_many_rows.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | temp_file=/tmp/pyhive_test_data_many_rows.tsv
4 | seq 0 9999 > $temp_file
5 |
6 | hive -e "
7 | DROP TABLE IF EXISTS many_rows;
8 | CREATE TABLE many_rows (
9 | a INT
10 | ) PARTITIONED BY (
11 | b STRING
12 | ) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' STORED AS TEXTFILE;
13 | LOAD DATA LOCAL INPATH '$temp_file' INTO TABLE many_rows PARTITION (b='blah');
14 | "
15 | rm -f $temp_file
16 |
--------------------------------------------------------------------------------
/scripts/make_one_row.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eux
2 | hive -e '
3 | set mapred.job.tracker=local;
4 | DROP TABLE IF EXISTS one_row;
5 | CREATE TABLE one_row (number_of_rows INT);
6 | INSERT INTO TABLE one_row VALUES (1);
7 | '
8 |
--------------------------------------------------------------------------------
/scripts/make_one_row_complex.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eux
2 |
3 | COLUMNS='
4 | `boolean` BOOLEAN,
5 | `tinyint` TINYINT,
6 | `smallint` SMALLINT,
7 | `int` INT,
8 | `bigint` BIGINT,
9 | `float` FLOAT,
10 | `double` DOUBLE,
11 | `string` STRING,
12 | `timestamp` TIMESTAMP,
13 | `binary` BINARY,
14 | `array` ARRAY,
15 | `map` MAP,
16 | `struct` STRUCT,
17 | `union` UNIONTYPE,
18 | `decimal` DECIMAL(10, 1)
19 | '
20 |
21 | hive -e "
22 | set mapred.job.tracker=local;
23 | DROP TABLE IF EXISTS one_row_complex;
24 | DROP TABLE IF EXISTS one_row_complex_null;
25 | CREATE TABLE one_row_complex ($COLUMNS);
26 | CREATE TABLE one_row_complex_null ($COLUMNS);
27 | INSERT OVERWRITE TABLE one_row_complex SELECT
28 | true,
29 | 127,
30 | 32767,
31 | 2147483647,
32 | 9223372036854775807,
33 | 0.5,
34 | 0.25,
35 | 'a string',
36 | 0,
37 | '123',
38 | array(1, 2),
39 | map(1, 2, 3, 4),
40 | named_struct('a', 1, 'b', 2),
41 | create_union(0, 1, 'test_string'),
42 | 0.1
43 | FROM one_row;
44 | INSERT OVERWRITE TABLE one_row_complex_null SELECT
45 | null,
46 | null,
47 | null,
48 | null,
49 | null,
50 | null,
51 | null,
52 | null,
53 | null,
54 | null,
55 | IF(false, array(1, 2), null),
56 | IF(false, map(1, 2, 3, 4), null),
57 | IF(false, named_struct('a', 1, 'b', 2), null),
58 | IF(false, create_union(0, 1, 'test_string'), null),
59 | null
60 | FROM one_row;
61 | "
62 |
63 | # Note: using IF(false, ...) above to work around https://issues.apache.org/jira/browse/HIVE-4022
64 | # The problem is that a "void" type cannot be inserted into a complex type field.
65 |
--------------------------------------------------------------------------------
/scripts/make_test_database.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eux
2 | hive -e '
3 | DROP DATABASE IF EXISTS pyhive_test_database CASCADE;
4 | CREATE DATABASE pyhive_test_database;
5 | CREATE TABLE pyhive_test_database.dummy_table (a INT);
6 | '
7 |
--------------------------------------------------------------------------------
/scripts/make_test_tables.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eux
2 | # Hive must be on the path for this script to work.
3 | # WARNING: drops and recreates tables called one_row, one_row_complex, and many_rows, plus a
4 | # database called pyhive_test_database.
5 |
6 | $(dirname $0)/make_one_row.sh
7 | $(dirname $0)/make_one_row_complex.sh
8 | $(dirname $0)/make_many_rows.sh
9 | $(dirname $0)/make_test_database.sh
10 |
--------------------------------------------------------------------------------
/scripts/thrift-patches/TCLIService.patch:
--------------------------------------------------------------------------------
1 | --- TCLIService.thrift 2017-04-16 22:51:45.000000000 -0700
2 | +++ TCLIService.thrift 2017-03-17 17:12:07.000000000 -0700
3 | @@ -32,8 +32,9 @@
4 | // * Service names begin with the letter "T", use a capital letter for each
5 | // new word (with no underscores), and end with the word "Service".
6 |
7 | namespace java org.apache.hive.service.rpc.thrift
8 | namespace cpp apache.hive.service.rpc.thrift
9 | +namespace py TCLIService
10 |
11 | // List of protocol versions. A new token should be
12 | // added to the end of this list every time a change is made.
13 | @@ -1224,6 +1225,19 @@
14 | 6: required i64 startTime
15 | }
16 |
17 | +// GetLog() is not present in never versions of Hive, but we add it here for backwards compatibility
18 | +struct TGetLogReq {
19 | + // operation handle
20 | + 1: required TOperationHandle operationHandle
21 | +}
22 | +
23 | +struct TGetLogResp {
24 | + // status of the request
25 | + 1: required TStatus status
26 | + // log content as text
27 | + 2: required string log
28 | +}
29 | +
30 | service TCLIService {
31 |
32 | TOpenSessionResp OpenSession(1:TOpenSessionReq req);
33 | @@ -1267,4 +1281,7 @@
34 | TCancelDelegationTokenResp CancelDelegationToken(1:TCancelDelegationTokenReq req);
35 |
36 | TRenewDelegationTokenResp RenewDelegationToken(1:TRenewDelegationTokenReq req);
37 | +
38 | + // Adding older log retrieval method for backward compatibility
39 | + TGetLogResp GetLog(1:TGetLogReq req);
40 | }
41 |
--------------------------------------------------------------------------------
/scripts/travis-conf/com/dropbox/DummyPasswdAuthenticationProvider.java:
--------------------------------------------------------------------------------
1 | package com.dropbox;
2 |
3 | import org.apache.hive.service.auth.PasswdAuthenticationProvider;
4 |
5 | import javax.security.sasl.AuthenticationException;
6 |
7 | public class DummyPasswdAuthenticationProvider implements PasswdAuthenticationProvider {
8 | @Override
9 | public void Authenticate(String user, String password) throws AuthenticationException {
10 | if (!user.equals("the-user") || !password.equals("p4ssw0rd")) {
11 | throw new AuthenticationException();
12 | }
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/scripts/travis-conf/hive/hive-site-custom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | hive.metastore.uris
6 | thrift://localhost:9083
7 |
8 |
9 | javax.jdo.option.ConnectionURL
10 | jdbc:derby:;databaseName=/var/lib/hive/metastore/metastore_db;create=true
11 |
12 |
13 | fs.defaultFS
14 | file:///
15 |
16 |
17 | hive.server2.authentication
18 | CUSTOM
19 |
20 |
21 | hive.server2.custom.authentication.class
22 | com.dropbox.DummyPasswdAuthenticationProvider
23 |
24 |
25 |
--------------------------------------------------------------------------------
/scripts/travis-conf/hive/hive-site-ldap.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | hive.metastore.uris
6 | thrift://localhost:9083
7 |
8 |
9 | javax.jdo.option.ConnectionURL
10 | jdbc:derby:;databaseName=/var/lib/hive/metastore/metastore_db;create=true
11 |
12 |
13 | fs.defaultFS
14 | file:///
15 |
16 |
17 | hive.server2.authentication
18 | LDAP
19 |
20 |
21 | hive.server2.authentication.ldap.url
22 | ldap://localhost:3389
23 |
24 |
25 | hive.server2.authentication.ldap.baseDN
26 | dc=example,dc=com
27 |
28 |
29 | hive.server2.authentication.ldap.guidKey
30 | cn
31 |
32 |
33 |
--------------------------------------------------------------------------------
/scripts/travis-conf/hive/hive-site.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | hive.metastore.uris
6 | thrift://localhost:9083
7 |
8 |
9 | javax.jdo.option.ConnectionURL
10 | jdbc:derby:;databaseName=/var/lib/hive/metastore/metastore_db;create=true
11 |
12 |
13 | fs.defaultFS
14 | file:///
15 |
16 |
17 |
--------------------------------------------------------------------------------
/scripts/travis-conf/presto/catalog/hive.properties:
--------------------------------------------------------------------------------
1 | connector.name=hive-hadoop2
2 | hive.metastore.uri=thrift://localhost:9083
3 |
--------------------------------------------------------------------------------
/scripts/travis-conf/presto/config.properties:
--------------------------------------------------------------------------------
1 | coordinator=true
2 | node-scheduler.include-coordinator=true
3 | http-server.http.port=8080
4 | query.max-memory=100MB
5 | query.max-memory-per-node=100MB
6 | discovery-server.enabled=true
7 | discovery.uri=http://localhost:8080
8 |
--------------------------------------------------------------------------------
/scripts/travis-conf/presto/jvm.config:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/PyHive/ac09074a652fd50e10b57a7f0bbc4f6410961301/scripts/travis-conf/presto/jvm.config
--------------------------------------------------------------------------------
/scripts/travis-conf/presto/node.properties:
--------------------------------------------------------------------------------
1 | node.environment=production
2 | node.id=ffffffff-ffff-ffff-ffff-ffffffffffff
3 | node.data-dir=/tmp/presto/data
4 |
--------------------------------------------------------------------------------
/scripts/travis-conf/trino/catalog/hive.properties:
--------------------------------------------------------------------------------
1 | connector.name=hive-hadoop2
2 | hive.metastore.uri=thrift://localhost:9083
3 |
--------------------------------------------------------------------------------
/scripts/travis-conf/trino/config.properties:
--------------------------------------------------------------------------------
1 | coordinator=true
2 | node-scheduler.include-coordinator=true
3 | http-server.http.port=18080
4 | query.max-memory=100MB
5 | query.max-memory-per-node=100MB
6 | discovery-server.enabled=true
7 | discovery.uri=http://localhost:18080
8 |
--------------------------------------------------------------------------------
/scripts/travis-conf/trino/jvm.config:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/PyHive/ac09074a652fd50e10b57a7f0bbc4f6410961301/scripts/travis-conf/trino/jvm.config
--------------------------------------------------------------------------------
/scripts/travis-conf/trino/node.properties:
--------------------------------------------------------------------------------
1 | node.environment=production
2 | node.id=11111111-1111-1111-1111-111111111111
3 | node.data-dir=/tmp/trino/data
4 |
--------------------------------------------------------------------------------
/scripts/travis-install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eux
2 |
3 | source /etc/lsb-release
4 |
5 | echo "deb [arch=amd64] https://archive.cloudera.com/${CDH}/ubuntu/${DISTRIB_CODENAME}/amd64/cdh ${DISTRIB_CODENAME}-cdh${CDH_VERSION} contrib
6 | deb-src https://archive.cloudera.com/${CDH}/ubuntu/${DISTRIB_CODENAME}/amd64/cdh ${DISTRIB_CODENAME}-cdh${CDH_VERSION} contrib" | sudo tee /etc/apt/sources.list.d/cloudera.list
7 | sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-keys 327574EE02A818DD
8 | sudo apt-get -q update
9 |
10 | sudo apt-get -q install -y oracle-java8-installer python-dev g++ libsasl2-dev maven
11 | sudo update-java-alternatives -s java-8-oracle
12 |
13 | #
14 | # LDAP
15 | #
16 | sudo apt-get -q -y --no-install-suggests --no-install-recommends --force-yes install ldap-utils slapd
17 | sudo mkdir -p /tmp/slapd
18 | sudo slapd -f $(dirname $0)/ldap_config/slapd.conf -h ldap://localhost:3389 &
19 | while ! nc -vz localhost 3389; do sleep 1; done
20 | sudo ldapadd -h localhost:3389 -D cn=admin,dc=example,dc=com -w test -f $(dirname $0)/../pyhive/tests/ldif_data/base.ldif
21 | sudo ldapadd -h localhost:3389 -D cn=admin,dc=example,dc=com -w test -f $(dirname $0)/../pyhive/tests/ldif_data/INITIAL_TESTDATA.ldif
22 |
23 | #
24 | # Hive
25 | #
26 |
27 | sudo apt-get -q install -y --force-yes hive
28 |
29 | javac -cp /usr/lib/hive/lib/hive-service.jar $(dirname $0)/travis-conf/com/dropbox/DummyPasswdAuthenticationProvider.java
30 | jar cf $(dirname $0)/dummy-auth.jar -C $(dirname $0)/travis-conf com
31 | sudo cp $(dirname $0)/dummy-auth.jar /usr/lib/hive/lib
32 |
33 | # Hack around broken symlink in Hive's installation
34 | # /usr/lib/hive/lib/zookeeper.jar -> ../../zookeeper/zookeeper.jar
35 | # Without this, Hive fails to start up due to failing to find ZK classes.
36 | sudo ln -nsfv /usr/share/java/zookeeper.jar /usr/lib/hive/lib/zookeeper.jar
37 |
38 | sudo mkdir -p /user/hive
39 | sudo chown hive:hive /user/hive
40 | sudo cp $(dirname $0)/travis-conf/hive/hive-site.xml /etc/hive/conf/hive-site.xml
41 | sudo apt-get -q install -y --force-yes hive-metastore hive-server2 || (grep . /var/log/hive/* && exit 2)
42 |
43 | while ! nc -vz localhost 9083; do sleep 1; done
44 | while ! nc -vz localhost 10000; do sleep 1; done
45 |
46 | sudo -Eu hive $(dirname $0)/make_test_tables.sh
47 |
48 | #
49 | # Presto
50 | #
51 |
52 | sudo apt-get -q install -y python # Use python2 for presto server
53 |
54 | mvn -q org.apache.maven.plugins:maven-dependency-plugin:3.0.0:copy \
55 | -Dartifact=com.facebook.presto:presto-server:${PRESTO}:tar.gz \
56 | -DoutputDirectory=.
57 | tar -x -z -f presto-server-*.tar.gz
58 | rm -rf presto-server
59 | mv presto-server-*/ presto-server
60 |
61 | cp -r $(dirname $0)/travis-conf/presto presto-server/etc
62 |
63 | /usr/bin/python2.7 presto-server/bin/launcher.py start
64 |
65 | #
66 | # Trino
67 | #
68 |
69 | sudo apt-get -q install -y python # Use python2 for trino server
70 |
71 | mvn -q org.apache.maven.plugins:maven-dependency-plugin:3.0.0:copy \
72 | -Dartifact=io.trino:trino-server:${TRINO}:tar.gz \
73 | -DoutputDirectory=.
74 | tar -x -z -f trino-server-*.tar.gz
75 | rm -rf trino-server
76 | mv trino-server-*/ trino-server
77 |
78 | cp -r $(dirname $0)/travis-conf/trino trino-server/etc
79 |
80 | /usr/bin/python2.7 trino-server/bin/launcher.py start
81 |
82 | #
83 | # Python
84 | #
85 |
86 | pip install $SQLALCHEMY
87 | pip install -e .
88 | pip install -r dev_requirements.txt
89 |
90 | # Sleep so Presto has time to start up.
91 | # Otherwise we might get 'No nodes available to run query' or 'Presto server is still initializing'
92 | while ! grep -q 'SERVER STARTED' /tmp/presto/data/var/log/server.log; do sleep 1; done
93 |
94 | # Sleep so Trino has time to start up.
95 | # Otherwise we might get 'No nodes available to run query' or 'Presto server is still initializing'
96 | while ! grep -q 'SERVER STARTED' /tmp/trino/data/var/log/server.log; do sleep 1; done
97 |
--------------------------------------------------------------------------------
/scripts/update_thrift_bindings.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eux
2 |
3 | HIVE_VERSION='2.3.0' # Must be a released version
4 |
5 | # Create a temporary directory
6 | scriptdir=`dirname $0`
7 | tmpdir=$scriptdir/.thrift_gen
8 |
9 | # Clean up previous generation attempts, in case it breaks things
10 | rm -rf $tmpdir
11 | mkdir $tmpdir
12 |
13 | # Download TCLIService.thrift from Hive
14 | curl -o $tmpdir/TCLIService.thrift \
15 | https://raw.githubusercontent.com/apache/hive/rel/release-$HIVE_VERSION/service-rpc/if/TCLIService.thrift
16 |
17 | # Apply patch that adds legacy GetLog methods
18 | patch -d $tmpdir < $scriptdir/thrift-patches/TCLIService.patch
19 |
20 | thrift -r --gen py -out $scriptdir/../ $tmpdir/TCLIService.thrift
21 | rm $scriptdir/../__init__.py
22 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [egg_info]
2 | tag_build = dev
3 |
4 | [tool:pytest]
5 | timeout = 100
6 | timeout_method = thread
7 | addopts = --random --tb=short --cov pyhive --cov-report html --cov-report term --flake8
8 | norecursedirs = env
9 | python_files = test_*.py
10 | flake8-max-line-length = 100
11 | flake8-ignore =
12 | TCLIService/*.py ALL
13 | pyhive/sqlalchemy_backports.py ALL
14 | presto-server/** ALL
15 | pyhive/hive.py F405
16 | pyhive/presto.py F405
17 | pyhive/trino.py F405
18 | W503
19 | filterwarnings =
20 | error
21 | # For Python 2 flake8
22 | default:You passed a bytestring as `filenames`:DeprecationWarning:flake8.options.config
23 | default::UnicodeWarning:_pytest.warnings
24 | # TODO For old sqlalchemy
25 | default:cgi.parse_qsl is deprecated:PendingDeprecationWarning:sqlalchemy.engine.url
26 | # TODO
27 | default:Did not recognize type:sqlalchemy.exc.SAWarning
28 | default:The Binary type has been renamed to LargeBinary:sqlalchemy.exc.SADeprecationWarning
29 | default:Please use assertRaisesRegex instead:DeprecationWarning
30 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import setup
4 | from setuptools.command.test import test as TestCommand
5 | import pyhive
6 | import sys
7 |
8 |
9 | class PyTest(TestCommand):
10 | def finalize_options(self):
11 | TestCommand.finalize_options(self)
12 | self.test_args = []
13 | self.test_suite = True
14 |
15 | def run_tests(self):
16 | # import here, cause outside the eggs aren't loaded
17 | import pytest
18 | errno = pytest.main(self.test_args)
19 | sys.exit(errno)
20 |
21 |
22 | with open('README.rst') as readme:
23 | long_description = readme.read()
24 |
25 | setup(
26 | name="PyHive",
27 | version=pyhive.__version__,
28 | description="Python interface to Hive",
29 | long_description=long_description,
30 | url='https://github.com/dropbox/PyHive',
31 | author="Jing Wang",
32 | author_email="jing@dropbox.com",
33 | license="Apache License, Version 2.0",
34 | packages=['pyhive', 'TCLIService'],
35 | classifiers=[
36 | "Intended Audience :: Developers",
37 | "License :: OSI Approved :: Apache Software License",
38 | "Operating System :: OS Independent",
39 | "Topic :: Database :: Front-Ends",
40 | ],
41 | install_requires=[
42 | 'future',
43 | 'python-dateutil',
44 | ],
45 | extras_require={
46 | 'presto': ['requests>=1.0.0'],
47 | 'trino': ['requests>=1.0.0'],
48 | 'hive': ['sasl>=0.2.1', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'],
49 | 'hive_pure_sasl': ['pure-sasl>=0.6.2', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'],
50 | 'sqlalchemy': ['sqlalchemy>=1.3.0'],
51 | 'kerberos': ['requests_kerberos>=0.12.0'],
52 | },
53 | tests_require=[
54 | 'mock>=1.0.0',
55 | 'pytest',
56 | 'pytest-cov',
57 | 'requests>=1.0.0',
58 | 'requests_kerberos>=0.12.0',
59 | 'sasl>=0.2.1',
60 | 'pure-sasl>=0.6.2',
61 | 'kerberos>=1.3.0',
62 | 'sqlalchemy>=1.3.0',
63 | 'thrift>=0.10.0',
64 | ],
65 | cmdclass={'test': PyTest},
66 | package_data={
67 | '': ['*.rst'],
68 | },
69 | entry_points={
70 | 'sqlalchemy.dialects': [
71 | 'hive = pyhive.sqlalchemy_hive:HiveDialect',
72 | "hive.http = pyhive.sqlalchemy_hive:HiveHTTPDialect",
73 | "hive.https = pyhive.sqlalchemy_hive:HiveHTTPSDialect",
74 | 'presto = pyhive.sqlalchemy_presto:PrestoDialect',
75 | 'trino.pyhive = pyhive.sqlalchemy_trino:TrinoDialect',
76 | ],
77 | }
78 | )
79 |
--------------------------------------------------------------------------------