├── .circleci └── config.yml ├── .gitignore ├── ChangeLog.md ├── LICENSE ├── README.rst ├── pyproject.toml ├── setup.py └── src ├── pyrqlite ├── __init__.py ├── _ephemeral.py ├── connections.py ├── constants.py ├── cursors.py ├── dbapi2.py ├── exceptions.py ├── extensions.py ├── row.py └── types.py └── test ├── test_dbapi.py ├── test_row.py └── test_types.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | jobs: 3 | build: 4 | machine: 5 | image: ubuntu-2204:2024.05.1 6 | environment: 7 | RQLITE_VERSION: 8.32.1 8 | steps: 9 | - checkout 10 | - run: cat /etc/lsb-release 11 | - run: pip install pytest pytest-cov setuptools 12 | - run: | 13 | curl -L https://github.com/rqlite/rqlite/releases/download/v${RQLITE_VERSION}/rqlite-v${RQLITE_VERSION}-linux-amd64.tar.gz -o rqlite-v${RQLITE_VERSION}-linux-amd64.tar.gz 14 | tar xvfz rqlite-v${RQLITE_VERSION}-linux-amd64.tar.gz 15 | cp rqlite-v${RQLITE_VERSION}-linux-amd64/rqlited /home/circleci/project/rqlited 16 | - run: 17 | name: Run tests 18 | environment: 19 | RQLITED_PATH: /home/circleci/project/rqlited 20 | PYTHONPATH: src 21 | command: python setup.py test 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyo 3 | __pycache__ 4 | *.egg-info 5 | /.cache 6 | /.coverage 7 | /htmlcov 8 | -------------------------------------------------------------------------------- /ChangeLog.md: -------------------------------------------------------------------------------- 1 | ## 2.2.3 (May 25 2024) 2 | - [PR #62](https://github.com/rqlite/pyrqlite/pull/62): Support for empty Datetime column. 3 | 4 | ## 2.2.2 (Jan 21 2024) 5 | - [PR #59](https://github.com/rqlite/pyrqlite/pull/59): Row: Allow non-unique column names. 
6 | 7 | ## 2.2.1 (Jan 3 2024) 8 | - [PR #57](https://github.com/rqlite/pyrqlite/issues/57): Add project.license to pyproject.toml. 9 | 10 | ## 2.1.1 (Dec 23 2021) 11 | - [PR #40](https://github.com/rqlite/pyrqlite/pull/40): Support ssl_context. 12 | 13 | ## 2.1 (Feb 28 2021) 14 | - [PR #4](https://github.com/rqlite/pyrqlite/pull/4): Leader redirection. 15 | - [PR #22](https://github.com/rqlite/pyrqlite/pull/22): Basic auth support. 16 | - [PR #25](https://github.com/rqlite/pyrqlite/pull/25): Named parameters. 17 | - [PR #28](https://github.com/rqlite/pyrqlite/pull/28): Cursor.fetchmany(). 18 | - [PR #32](https://github.com/rqlite/pyrqlite/pull/32): HTTPS support. 19 | 20 | ## 2.0 (May 1st 2016) 21 | - Compatible with rqlite v2.0. 22 | - Remove PARSE_DECLTYPES support. 23 | - Remove rqlite v1.0 compatibility. 24 | 25 | ## 1.0 (February 25th 2016) 26 | - Compatible with rqlite v1.0. 27 | - Parse sql in order to map results to column names and types. 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Zac Medico 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ------------- 2 | pyrqlite 3 | ------------- 4 | 5 | .. image:: https://circleci.com/gh/rqlite/pyrqlite.svg?style=svg 6 | :target: https://circleci.com/gh/rqlite/pyrqlite 7 | 8 | This package contains a pure-Python rqlite client library. 9 | 10 | .. contents:: 11 | 12 | Requirements 13 | ------------- 14 | 15 | * Python -- one of the following: 16 | 17 | - CPython_ >= 2.7 or >= 3.3 18 | 19 | * rqlite Server 20 | 21 | 22 | Installation 23 | ------------ 24 | 25 | The last stable release is available on github and can be installed with ``pip``:: 26 | 27 | $ pip install git+https://github.com/rqlite/pyrqlite.git 28 | 29 | You can also just clone the repo and install it from source:: 30 | 31 | $ git clone https://github.com/rqlite/pyrqlite.git 32 | $ cd pyrqlite 33 | $ python setup.py install 34 | 35 | Finally (e.g. if ``pip`` is not available), a tarball can be downloaded 36 | from GitHub and installed with Setuptools:: 37 | 38 | $ # X.Y.Z is the desired pyrqlite version (e.g. 2.2.1). 39 | $ curl -L https://github.com/rqlite/pyrqlite/archive/refs/tags/vX.Y.Z.tar.gz | tar xz 40 | $ cd pyrqlite* 41 | $ python setup.py install 42 | $ # The folder pyrqlite* can be safely removed now. 43 | 44 | You may need to run the installation process with ``root`` privileges. 
45 | 46 | Test Suite 47 | ---------- 48 | 49 | To run all the tests, execute the script ``setup.py``:: 50 | 51 | $ python setup.py test 52 | 53 | pytest (https://pytest.org/) and pytest-cov are required to run the test 54 | suite. They can both be installed with ``pip`` 55 | 56 | Example 57 | ------- 58 | 59 | The following code creates a connection and executes some statements: 60 | 61 | .. code:: python 62 | 63 | import pyrqlite.dbapi2 as dbapi2 64 | 65 | # Connect to the database 66 | connection = dbapi2.connect( 67 | host='localhost', 68 | port=4001, 69 | ) 70 | 71 | try: 72 | with connection.cursor() as cursor: 73 | cursor.execute('CREATE TABLE foo (id integer not null primary key, name text)') 74 | cursor.executemany('INSERT INTO foo(name) VALUES(?)', seq_of_parameters=(('a',), ('b',))) 75 | 76 | with connection.cursor() as cursor: 77 | # Read a single record with qmark parameter style 78 | sql = "SELECT `id`, `name` FROM `foo` WHERE `name`=?" 79 | cursor.execute(sql, ('a',)) 80 | result = cursor.fetchone() 81 | print(result) 82 | # Read a single record with named parameter style 83 | sql = "SELECT `id`, `name` FROM `foo` WHERE `name`=:name" 84 | cursor.execute(sql, {'name': 'b'}) 85 | result = cursor.fetchone() 86 | print(result) 87 | finally: 88 | connection.close() 89 | 90 | .. code:: python 91 | 92 | This example will print: 93 | 94 | 95 | (1, 'a') 96 | (2, 'b') 97 | 98 | Paramstyle 99 | ------------- 100 | 101 | Only qmark and named paramstyles (as defined in PEP 249) are supported. 102 | 103 | Limitations 104 | ------------- 105 | Transactions are not supported. 106 | 107 | Resources 108 | ------------- 109 | DB-API 2.0: http://www.python.org/dev/peps/pep-0249 110 | 111 | 112 | License 113 | ------------- 114 | pyrqlite is released under the MIT License. See LICENSE for more information. 
115 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "pyrqlite" 3 | version = "2.2.3" 4 | authors = [ 5 | { name="Zac Medico", email="zmedico@gmail.com" }, 6 | { name="Philip O'Toole", email="philip.otoole@yahoo.com" }, 7 | ] 8 | description = "Python rqlite client library" 9 | readme = "README.md" 10 | requires-python = ">=3.3" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: MIT License", 14 | "Operating System :: OS Independent", 15 | ] 16 | dynamic = ["maintainers"] 17 | 18 | [project.license] 19 | file = "LICENSE" 20 | 21 | [project.urls] 22 | "Homepage" = "https://www.rqlite.io" 23 | "Bug Tracker" = "https://github.com/rqlite/pyrqlite/issues" 24 | 25 | [build-system] 26 | requires = [ 27 | "setuptools", 28 | "wheel", 29 | ] 30 | build-backend = "setuptools.build_meta" 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | from os.path import isdir, islink, relpath, dirname 5 | import subprocess 6 | import sys 7 | from setuptools import ( 8 | Command, 9 | setup, 10 | find_packages, 11 | ) 12 | 13 | sys.path.insert(0, 'src') 14 | from pyrqlite.constants import ( 15 | __author__, 16 | __email__, 17 | __license__, 18 | ) 19 | 20 | class PyTest(Command): 21 | user_options = [('match=', 'k', 'Run only tests that match the provided expressions')] 22 | 23 | def initialize_options(self): 24 | self.match = None 25 | 26 | def finalize_options(self): 27 | pass 28 | 29 | def run(self): 30 | testpath = 'src/test' 31 | buildlink = 'build/lib/test' 32 | 33 | if isdir(dirname(buildlink)): 34 | if islink(buildlink): 35 | os.unlink(buildlink) 36 | 37 | os.symlink(relpath(testpath, dirname(buildlink)), 
buildlink) 38 | testpath = buildlink 39 | 40 | try: 41 | os.environ['EPYTHON'] = 'python{}.{}'.format(sys.version_info.major, sys.version_info.minor) 42 | subprocess.check_call(['py.test', '-v', testpath, '-s', 43 | '--cov-report=html', '--cov-report=term-missing'] + 44 | (['-k', self.match] if self.match else []) + 45 | ['--cov={}'.format(p) for p in find_packages(dirname(testpath), exclude=['test'])]) 46 | 47 | finally: 48 | if islink(buildlink): 49 | os.unlink(buildlink) 50 | 51 | 52 | class PyLint(Command): 53 | user_options = [('errorsonly', 'E', 'Check only errors with pylint'), 54 | ('format=', 'f', 'Change the output format')] 55 | 56 | def initialize_options(self): 57 | self.errorsonly = 0 58 | self.format = 'colorized' 59 | 60 | def finalize_options(self): 61 | pass 62 | 63 | def run(self): 64 | cli_options = ['-E'] if self.errorsonly else [] 65 | cli_options.append('--output-format={0}'.format(self.format)) 66 | errno = subprocess.call(['pylint'] + cli_options + [ 67 | "--msg-template='{C}:{msg_id}:{path}:{line:3d},{column}: {obj}: {msg} ({symbol})'"] + 68 | find_packages('src', exclude=['test']), cwd='./src') 69 | raise SystemExit(errno) 70 | 71 | 72 | setup( 73 | name="pyrqlite", 74 | url='https://github.com/rqlite/pyrqlite/', 75 | author=__author__, 76 | author_email=__email__, 77 | maintainer=__author__, 78 | maintainer_email=__email__, 79 | description='python DB API 2.0 driver for rqlite', 80 | license=__license__, 81 | package_dir={'': 'src'}, 82 | packages=find_packages('src', exclude=['test']), 83 | platforms=['Posix'], 84 | cmdclass={'test': PyTest, 'lint': PyLint}, 85 | tests_require=['pytest', 'pytest-cov'], 86 | classifiers=[ 87 | 'Development Status :: 3 - Alpha', 88 | 'Environment :: Console', 89 | 'Intended Audience :: Developers', 90 | 'License :: OSI Approved :: MIT License', 91 | 'Programming Language :: Python', 92 | 'Programming Language :: Python :: 2', 93 | 'Programming Language :: Python :: 2.7', 94 | 'Programming Language :: 
RQLITED_PATH = os.environ.get("RQLITED_PATH", "rqlited")


class EphemeralRqlited(object):
    """Context manager that runs a temporary ``rqlited`` process.

    On entry a fresh temporary directory is created and the executable named
    by ``RQLITED_PATH`` is started against it on two free localhost ports.
    On exit the process is terminated and the directory is removed.

    Attributes:
        host: Host name used for both addresses (``'localhost'`` once
            started, ``None`` before).
        http: ``(host, port)`` tuple passed to the daemon as ``-http-addr``.
        raft: ``(host, port)`` tuple passed to the daemon as ``-raft-addr``.
    """

    # Seconds to sleep between two readiness probes.
    _POLL_INTERVAL = 0.5

    # Upper bound on spawn attempts, so a daemon that always exits early
    # produces an error instead of an endless respawn loop.
    _MAX_START_ATTEMPTS = 10

    def __init__(self):
        self.host = None
        self.http = None
        self.raft = None
        self._tempdir = None
        self._proc = None

    @staticmethod
    def _unused_ports(host, count):
        """Return ``count`` port numbers that were free on ``host``.

        All sockets are bound before any of them is closed so that the
        returned ports are distinct.  The ports are released again before
        returning, hence another process may grab them in the meantime.
        """
        sockets = []
        ports = []
        try:
            sockets.extend(
                socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                for _ in range(count))
            for sock in sockets:
                sock.bind((host, 0))
                ports.append(sock.getsockname()[-1])
        finally:
            while sockets:
                sockets.pop().close()

        return ports

    @staticmethod
    def _test_port(host, port, timeout=None):
        """Return True if a TCP connection to ``(host, port)`` succeeds."""
        try:
            with contextlib.closing(
                    socket.create_connection((host, port), timeout=timeout)):
                return True
        except socket.error:
            return False

    @staticmethod
    def _test_readyz(host, port):
        """Return True if ``GET /readyz`` on ``(host, port)`` answers 200."""
        try:
            with contextlib.closing(HTTPConnection(host, port=port)) as conn:
                conn.request("GET", "/readyz")
                return conn.getresponse().status == 200
        except Exception:
            # Any failure (refused connection, bad response, ...) simply
            # means "not ready yet" for the polling loop.
            return False

    def _is_ready(self):
        """Return True once the HTTP port accepts connections and is ready."""
        return self._test_port(*self.http) and self._test_readyz(*self.http)

    def _start(self):
        """Create the data directory and spawn the daemon until it is ready.

        Raises:
            EnvironmentError: if the executable cannot be started.
            RuntimeError: if the daemon exited before becoming ready on
                every one of ``_MAX_START_ATTEMPTS`` attempts.
        """
        self._tempdir = tempfile.mkdtemp()
        self.host = 'localhost'
        filename = RQLITED_PATH

        # Allocation of unused ports is racy, so retry
        # until ports have been successfully acquired.
        attempts = 0
        while self._proc is None:
            if attempts >= self._MAX_START_ATTEMPTS:
                raise RuntimeError(
                    '{} exited before becoming ready ({} attempts)'.format(
                        filename, attempts))
            attempts += 1

            http_port, raft_port = self._unused_ports(self.host, 2)
            self.http = (self.host, http_port)
            self.raft = (self.host, raft_port)
            with open(os.devnull, mode='wb', buffering=0) as devnull:
                try:
                    self._proc = subprocess.Popen(
                        [filename,
                         '-http-addr', '{}:{}'.format(*self.http),
                         '-raft-addr', '{}:{}'.format(*self.raft),
                         self._tempdir],
                        stdout=devnull, stderr=devnull)
                except EnvironmentError as e:
                    if e.errno == errno.ENOENT and sys.version_info.major < 3:
                        # Add filename to clarify exception message.
                        e.filename = filename
                    raise

                # Stop polling as soon as the daemon exits: a dead process
                # will never become ready, and waiting for it would hang
                # forever instead of retrying with new ports.
                while self._proc.poll() is None and not self._is_ready():
                    time.sleep(self._POLL_INTERVAL)

                if self._proc.poll() is not None:
                    self._proc = None

    def __enter__(self):
        try:
            self._start()
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so release the
            # temporary directory (and any half-started process) here.
            self.__exit__(*sys.exc_info())
            raise
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Stop the daemon, then delete its data directory.

        Safe to call more than once.  Always returns False so exceptions
        raised inside the ``with`` body propagate.
        """
        try:
            proc, self._proc = self._proc, None
            if proc is not None:
                # Stop the writer before removing the files it writes to.
                if proc.poll() is None:
                    proc.terminate()
                proc.wait()
        finally:
            tempdir, self._tempdir = self._tempdir, None
            if tempdir is not None:
                shutil.rmtree(tempdir)
        return False
timeout=None, 54 | detect_types=0, 55 | max_redirects=UNLIMITED_REDIRECTS, 56 | ): 57 | 58 | self.messages = [] 59 | self.scheme = scheme 60 | self.host = host 61 | self.port = port 62 | self.ssl_context = ssl_context 63 | self._headers = {} 64 | if not (user is None or password is None): 65 | self._headers['Authorization'] = 'Basic ' + \ 66 | codecs.encode('{}:{}'.format(user, password).encode('utf-8'), 67 | 'base64').decode('utf-8').rstrip('\n') 68 | 69 | if connect_timeout is not DeprecationWarning: 70 | warnings.warn( 71 | "connect_timeout parameter is deprecated and renamed to timeout", 72 | DeprecationWarning, 73 | stacklevel=1, 74 | ) 75 | timeout = connect_timeout 76 | self.timeout = timeout 77 | self.max_redirects = max_redirects 78 | self.detect_types = detect_types 79 | self.parse_decltypes = detect_types & PARSE_DECLTYPES 80 | self.parse_colnames = detect_types & PARSE_COLNAMES 81 | self._ephemeral = None 82 | if scheme == ':memory:': 83 | self._ephemeral = _EphemeralRqlited().__enter__() 84 | self.host, self.port = self._ephemeral.http 85 | self._connection = self._init_connection() 86 | 87 | @property 88 | def connect_timeout(self): 89 | warnings.warn( 90 | "connect_timeout attribute is deprecated and renamed to timeout", 91 | DeprecationWarning, 92 | stacklevel=1, 93 | ) 94 | return self.timeout 95 | 96 | def _init_connection(self): 97 | timeout = None if self.timeout is None else float(self.timeout) 98 | if self.scheme in ('http', ':memory:'): 99 | return HTTPConnection(self.host, port=self.port, timeout=timeout) 100 | elif self.scheme == 'https': 101 | return HTTPSConnection(self.host, port=self.port, context=self.ssl_context, 102 | timeout=timeout) 103 | else: 104 | raise Connection.ProgrammingError('Unsupported scheme %r' % self.scheme) 105 | 106 | def _retry_request(self, method, uri, body=None, headers={}): 107 | tries = 10 108 | while tries: 109 | tries -= 1 110 | try: 111 | self._connection.request(method, uri, body=body, 112 | 
headers=dict(self._headers, **headers)) 113 | return self._connection.getresponse() 114 | except Exception: 115 | if not tries: 116 | raise 117 | self._connection.close() 118 | self._connection = self._init_connection() 119 | 120 | def _fetch_response(self, method, uri, body=None, headers={}): 121 | """ 122 | Fetch a response, handling redirection. 123 | """ 124 | response = self._retry_request(method, uri, body=body, headers=headers) 125 | redirects = 0 126 | 127 | while response.status == 301 and \ 128 | response.getheader('Location') is not None and \ 129 | (self.max_redirects == UNLIMITED_REDIRECTS or redirects < self.max_redirects): 130 | redirects += 1 131 | uri = response.getheader('Location') 132 | location = urlparse(uri) 133 | 134 | logging.getLogger(__name__).debug("status: %s reason: '%s' location: '%s'", 135 | response.status, response.reason, uri) 136 | 137 | if self.host != location.hostname or self.port != location.port: 138 | self._connection.close() 139 | self.host = location.hostname 140 | self.port = location.port 141 | self._connection = self._init_connection() 142 | 143 | response = self._retry_request(method, uri, body=body, headers=headers) 144 | 145 | return response 146 | 147 | def close(self): 148 | """Close the connection now (rather than whenever .__del__() is 149 | called). 150 | 151 | The connection will be unusable from this point forward; an 152 | Error (or subclass) exception will be raised if any operation 153 | is attempted with the connection. The same applies to all 154 | cursor objects trying to use the connection. 
Note that closing 155 | a connection without committing the changes first will cause an 156 | implicit rollback to be performed.""" 157 | self._connection.close() 158 | if self._ephemeral is not None: 159 | self._ephemeral.__exit__(None, None, None) 160 | self._ephemeral = None 161 | 162 | def __del__(self): 163 | self.close() 164 | 165 | def commit(self): 166 | """Database modules that do not support transactions should 167 | implement this method with void functionality.""" 168 | pass 169 | 170 | def rollback(self): 171 | """This method is optional since not all databases provide 172 | transaction support. """ 173 | pass 174 | 175 | def cursor(self, factory=None): 176 | """Return a new Cursor Object using the connection.""" 177 | if factory: 178 | return factory(self) 179 | else: 180 | return Cursor(self) 181 | 182 | def execute(self, *args, **kwargs): 183 | return self.cursor().execute(*args, **kwargs) 184 | 185 | def ping(self, reconnect=True): 186 | if self._connection.sock is None: 187 | if reconnect: 188 | self._connection = self._init_connection() 189 | else: 190 | raise self.Error("Already closed") 191 | try: 192 | self.execute("SELECT 1") 193 | except Exception: 194 | if reconnect: 195 | self._connection = self._init_connection() 196 | self.ping(False) 197 | else: 198 | raise 199 | 200 | -------------------------------------------------------------------------------- /src/pyrqlite/constants.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import unicode_literals 3 | 4 | __project__ = "pyrqlite" 5 | 6 | __author__ = "Zac Medico" 7 | __email__ = "zmedico@gmail.com" 8 | 9 | __copyright__ = "Copyright (C) 2016 Zac Medico" 10 | __license__ = "MIT" 11 | __description__ = "Python dbapi2 driver for rqlite" 12 | 13 | UNLIMITED_REDIRECTS = -1 14 | -------------------------------------------------------------------------------- /src/pyrqlite/cursors.py: 
if sys.version_info[0] >= 3:
    basestring = str
    _urlencode = urlencode
else:
    # avoid UnicodeEncodeError from urlencode
    def _urlencode(query, doseq=0):
        return urlencode(dict(
            (k if isinstance(k, bytes) else k.encode('utf-8'),
             v if isinstance(v, bytes) else v.encode('utf-8'))
            for k, v in query.items()), doseq=doseq)


class Cursor(object):
    """DB-API 2.0 style cursor that sends statements through a connection.

    Statements are rendered to plain SQL text on the client (parameters are
    substituted by ``_substitute_params``) and sent over HTTP using the
    connection's ``_fetch_response``.  Result rows of the last ``execute``
    are buffered in memory and handed out by the ``fetch*`` methods.

    Attributes:
        messages: Empty list; never filled by this class.
        lastrowid: ``last_insert_id`` reported for the last INSERT.
        description: Tuple of 7-item tuples (name first, the rest ``None``)
            for the last row-returning statement, else ``None``.
        rownumber: Index of the next buffered row to be fetched.
        rowcount: Affected rows for UPDATE/DELETE/``executemany``, otherwise
            the number of buffered rows; ``-1`` before the first statement.
        arraysize: Default batch size of ``fetchmany``.
    """
    arraysize = 1

    # Parameter markers: ``?`` (qmark) and ``:name`` (named, letters only).
    _QMARK_RE = re.compile(r"(\?)")
    _NAMED_RE = re.compile(r"(:{1}[a-zA-Z]+?\b)")

    def __init__(self, connection):
        self._connection = connection
        self.messages = []
        self.lastrowid = None
        self.description = None
        self.rownumber = 0
        self.rowcount = -1
        self.arraysize = 1
        self._rows = None
        self._column_type_cache = {}

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()

    @property
    def connection(self):
        """The connection object this cursor was created with."""
        return self._connection

    def close(self):
        """Drop the buffered rows; later fetches return no data."""
        self._rows = None

    def _request(self, method, uri, body=None, headers=None):
        """Send one HTTP request and return the decoded JSON payload.

        Args:
            method: HTTP method name.
            uri: Request path including the query string.
            body: Optional request body.
            headers: Optional dict of extra request headers.

        Returns:
            The response body parsed as JSON; objects become ``OrderedDict``.

        Raises:
            Error: if the response status is not 200.
        """
        logger = logging.getLogger(__name__)
        headers = {} if headers is None else headers
        logger.debug(
            'request method: %s uri: %s headers: %s body: %s',
            method,
            uri,
            headers,
            body)
        response = self.connection._fetch_response(
            method, uri, body=body, headers=headers)
        if response.status != 200:
            raise Error(
                "received unexpected http status: %d" % response.status)
        logger.debug(
            "status: %s reason: %s",
            response.status,
            response.reason)
        response_text = response.read().decode('utf-8')
        logger.debug("raw response: %s", response_text)
        response_json = json.loads(
            response_text, object_pairs_hook=OrderedDict)
        # Pretty-printing is costly, so only do it when it will be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "formatted response: %s",
                json.dumps(response_json, indent=4))
        return response_json

    def _substitute_params(self, operation, parameters):
        '''
        Return ``operation`` with its parameter markers replaced by the
        adapted values from ``parameters``.

        ``parameters`` may be ``None`` (no markers allowed), a dict (named
        ``:name`` markers) or a sequence (``?`` markers).  Markers are found
        with plain regular expressions, so marker-like text inside SQL string
        literals is counted as well.

        SQLite natively supports only the types TEXT, INTEGER, REAL, BLOB and
        NULL

        Raises ProgrammingError when markers and parameters do not match.
        '''
        qmark_matches = self._QMARK_RE.findall(operation)
        named_matches = self._NAMED_RE.findall(operation)
        param_matches = len(qmark_matches) + len(named_matches)

        # Matches but no parameters
        if param_matches > 0 and parameters is None:
            raise ProgrammingError('parameter required but not given: %s' %
                                   operation)

        # No regex matches and no parameters.
        if parameters is None:
            return operation

        if qmark_matches and named_matches:
            raise ProgrammingError('different parameter types in operation '
                                   'not permitted: %s %s' %
                                   (operation, parameters))

        if isinstance(parameters, dict):
            # parameters is a dict or a dict subclass
            if qmark_matches:
                raise ProgrammingError('Unamed binding used, but you supplied '
                                       'a dictionary (which has only names): '
                                       '%s %s' % (operation, parameters))

            def bind(match):
                # match.group(0) is the marker including its leading colon.
                try:
                    value = parameters[match.group(0)[1:]]
                except KeyError:
                    raise ProgrammingError('the named parameters given do not '
                                           'match operation: %s %s' %
                                           (operation, parameters))
                return _adapt_from_python(value)

            # One pass over the statement: each marker is replaced exactly
            # once, so ``:a`` cannot clobber ``:ab`` and already substituted
            # values are never rescanned.
            return self._NAMED_RE.sub(bind, operation)

        # parameters is a sequence
        if param_matches != len(parameters):
            raise ProgrammingError('incorrect number of parameters '
                                   '(%s != %s): %s %s' % (param_matches,
                                   len(parameters), operation, parameters))
        if named_matches:
            raise ProgrammingError('Named binding used, but you supplied a'
                                   ' sequence (which has no names): %s %s' %
                                   (operation, parameters))
        subst = []
        for i, part in enumerate(operation.split('?')):
            subst.append(part)
            if i < len(parameters):
                subst.append(_adapt_from_python(parameters[i]))
        return ''.join(subst)

    def _get_sql_command(self, sql_str):
        """Return the first word of ``sql_str`` upper-cased ('' if blank)."""
        words = sql_str.split(None, 1)
        return words[0].upper() if words else ''

    @staticmethod
    def _write_path(queue, wait):
        """Build the write endpoint path with the optional flags."""
        path = "/db/execute?transaction"
        if queue:
            path = path + "&queue"
        if wait:
            path = path + "&wait"
        return path

    def execute(self, operation, parameters=None, queue=False, wait=False, consistency=None):
        """Run one statement and buffer its result rows.

        SELECT and PRAGMA statements are sent with GET to ``/db/query``;
        everything else is sent with POST to ``/db/execute``.

        Args:
            operation: SQL text, optionally with ``?`` or ``:name`` markers.
            parameters: Sequence or dict matching the markers, or ``None``.
            queue: Append the ``queue`` flag to the write request.
            wait: Append the ``wait`` flag to the write request.
            consistency: Optional value sent as ``level`` with read requests.

        Returns:
            This cursor.

        Raises:
            ValueError: if ``operation`` is not a string.
            ProgrammingError: if markers and parameters do not match.
            Error: on a non-200 response or when a result reports an error.
        """
        if not isinstance(operation, basestring):
            raise ValueError(
                "argument must be a string, not '{}'".format(type(operation).__name__))

        operation = self._substitute_params(operation, parameters)

        command = self._get_sql_command(operation)
        if command in ('SELECT', 'PRAGMA'):
            params = {'q': operation}
            if consistency:
                params["level"] = consistency
            payload = self._request("GET", "/db/query?" + _urlencode(params))
        else:
            payload = self._request(
                "POST", self._write_path(queue, wait),
                headers={'Content-Type': 'application/json'},
                body=json.dumps([operation]))

        last_insert_id = None
        rows_affected = -1
        payload_rows = {}
        try:
            results = payload["results"]
        except KeyError:
            pass
        else:
            rows_affected = 0
            for item in results:
                if 'error' in item:
                    logging.getLogger(__name__).error(json.dumps(item))
                    raise Error(json.dumps(item))
                rows_affected += item.get('rows_affected', 0)
                last_insert_id = item.get('last_insert_id', last_insert_id)
                if 'columns' in item:
                    payload_rows = item

        try:
            fields = payload_rows['columns']
        except KeyError:
            # Statement returned no result set.
            self.description = None
            self._rows = []
            if command == 'INSERT':
                self.lastrowid = last_insert_id
        else:
            rows = []
            description = []
            for field in fields:
                description.append((
                    _column_stripper(field, parse_colnames=self.connection.parse_colnames),
                    None,
                    None,
                    None,
                    None,
                    None,
                    None,
                ))

            try:
                values = payload_rows['values']
                types = payload_rows['types']
            except KeyError:
                pass
            else:
                if values:
                    converters = [_convert_to_python(field, type_,
                                                     parse_decltypes=self.connection.parse_decltypes,
                                                     parse_colnames=self.connection.parse_colnames)
                                  for field, type_ in zip(fields, types)]
                    for payload_row in values:
                        row = []
                        for field, converter, value in zip(fields, converters, payload_row):
                            row.append((field, (value if converter is None
                                                else converter(value))))
                        rows.append(Row(row))
            self._rows = rows
            self.description = tuple(description)

        self.rownumber = 0
        if command in ('UPDATE', 'DELETE'):
            # sqlalchemy's _emit_update_statements function asserts
            # rowcount for each update, and _emit_delete_statements
            # warns unless rowcount matches
            self.rowcount = rows_affected
        else:
            self.rowcount = len(self._rows)
        return self

    def executemany(self, operation, seq_of_parameters=None, queue=False, wait=False):
        """Run ``operation`` once per parameter set in a single request.

        Args:
            operation: SQL text with ``?`` or ``:name`` markers.
            seq_of_parameters: Iterable of parameter sets; must be given
                despite the ``None`` default.
            queue: Append the ``queue`` flag to the request.
            wait: Append the ``wait`` flag to the request.

        No rows are buffered afterwards; ``rowcount`` is the sum of the
        reported ``rows_affected`` values (``-1`` if no results came back).

        Raises:
            ValueError: if ``operation`` is not a string.
            ProgrammingError: if markers and a parameter set do not match.
            Error: on a non-200 response.
        """
        if not isinstance(operation, basestring):
            raise ValueError(
                "argument must be a string, not '{}'".format(type(operation).__name__))

        statements = [self._substitute_params(operation, parameters)
                      for parameters in seq_of_parameters]

        payload = self._request(
            "POST", self._write_path(queue, wait),
            headers={'Content-Type': 'application/json'},
            body=json.dumps(statements))
        rows_affected = -1
        try:
            results = payload["results"]
        except KeyError:
            pass
        else:
            rows_affected = 0
            for item in results:
                if 'error' in item:
                    # NOTE(review): unlike execute(), errors are only logged
                    # here, not raised -- confirm this is intended.
                    logging.getLogger(__name__).error(json.dumps(item))
                rows_affected += item.get('rows_affected', 0)
        self._rows = []
        self.rownumber = 0
        self.rowcount = rows_affected

    def fetchone(self):
        ''' Fetch the next row, or None when no buffered row is left '''
        if self._rows is None or self.rownumber >= len(self._rows):
            return None
        result = self._rows[self.rownumber]
        self.rownumber += 1
        return result

    def fetchmany(self, size=None):
        """Return up to ``size`` (default ``arraysize``) further rows."""
        wanted = self.arraysize if size is None else size
        # Bound by the rows actually buffered: rowcount holds the number of
        # affected rows after UPDATE/DELETE/executemany, when no rows exist.
        rows = self._rows or []
        start = self.rownumber
        stop = min(len(rows), start + max(wanted, 0))
        batch = list(rows[start:stop])
        self.rownumber = start + len(batch)
        return batch

    def fetchall(self):
        """Return all buffered rows that have not been fetched yet."""
        rows = self._rows or []
        remaining = list(rows[self.rownumber:])
        self.rownumber += len(remaining)
        return remaining

    def setinputsizes(self, sizes):
        raise NotImplementedError(self)

    def setoutputsize(self, size, column=None):
        raise NotImplementedError(self)

    def scroll(self, value, mode='relative'):
        raise NotImplementedError(self)

    def next(self):
        raise NotImplementedError(self)

    def __iter__(self):
        """Yield the remaining buffered rows one by one."""
        while True:
            row = self.fetchone()
            if row is None:
                return
            yield row
sqlite dialect 62 | sqlite_version_info = (3, 10, 0) 63 | 64 | # Compat with native sqlite module 65 | from .extensions import PARSE_DECLTYPES, PARSE_COLNAMES 66 | -------------------------------------------------------------------------------- /src/pyrqlite/exceptions.py: -------------------------------------------------------------------------------- 1 | from sqlite3 import (Warning, Error, InterfaceError, DatabaseError, DataError, 2 | OperationalError, IntegrityError, InternalError, 3 | ProgrammingError, NotSupportedError) 4 | -------------------------------------------------------------------------------- /src/pyrqlite/extensions.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | """ 4 | SQLite natively supports only the types TEXT, INTEGER, REAL, BLOB and NULL. 5 | And RQLite always answers 'bytes' values. 6 | 7 | Converters transforms RQLite answers to Python native types. 8 | Adapters transforms Python native types to RQLite-aware values. 9 | """ 10 | 11 | import codecs 12 | import datetime 13 | import functools 14 | import re 15 | import sqlite3 16 | import sys 17 | 18 | from .exceptions import InterfaceError 19 | 20 | if sys.version_info[0] >= 3: 21 | basestring = str 22 | unicode = str 23 | 24 | PARSE_DECLTYPES = 1 25 | PARSE_COLNAMES = 2 26 | 27 | 28 | def _decoder(conv_func): 29 | """ The Python sqlite3 interface returns always byte strings. 30 | This function converts the received value to a regular string before 31 | passing it to the receiver function. 
32 | """ 33 | return lambda s: conv_func(s.decode('utf-8')) 34 | 35 | if sys.version_info[0] >= 3: 36 | 37 | def _escape_string(value): 38 | if isinstance(value, bytes): 39 | return "X'{}'".format( 40 | codecs.encode(value, 'hex').decode('utf-8')) 41 | 42 | return "'{}'".format(value.replace("'", "''")) 43 | 44 | def _adapt_datetime(val): 45 | return val.isoformat(" ") 46 | else: 47 | 48 | def _escape_string(value): 49 | if isinstance(value, bytes): 50 | try: 51 | value = value.decode('utf-8') 52 | except UnicodeDecodeError: 53 | # Encode as a BLOB literal containing hexadecimal data 54 | return "X'{}'".format( 55 | codecs.encode(value, 'hex').decode('utf-8')) 56 | 57 | return "'{}'".format(value.replace("'", "''")) 58 | 59 | def _adapt_datetime(val): 60 | return val.isoformat(b" ") 61 | 62 | def _adapt_date(val): 63 | return val.isoformat() 64 | 65 | def _convert_date(val): 66 | return datetime.date(*map(int, val.split('T')[0].split("-"))) 67 | 68 | def _convert_timestamp(val): 69 | datepart, timepart = val.split("T") 70 | year, month, day = map(int, datepart.split("-")) 71 | timepart_full = timepart.strip('Z').split(".") 72 | hours, minutes, seconds = map(int, timepart_full[0].split(":")) 73 | if len(timepart_full) == 2: 74 | microseconds = int('{:0<6.6}'.format(timepart_full[1])) 75 | else: 76 | microseconds = 0 77 | 78 | val = datetime.datetime(year, month, day, hours, minutes, seconds, microseconds) 79 | return val 80 | 81 | 82 | def _null_wrapper(converter, value): 83 | if value is not None: 84 | value = converter(value) 85 | return value 86 | 87 | 88 | adapters = { 89 | bytes: lambda x: x, 90 | float: lambda x: x, 91 | int: lambda x: x, 92 | bool: int, 93 | unicode: lambda x: x.encode('utf-8'), 94 | type(None): lambda x: None, 95 | datetime.date: _adapt_date, 96 | datetime.datetime: _adapt_datetime, 97 | 98 | } 99 | adapters = {(type_, sqlite3.PrepareProtocol): val for type_, val in adapters.items()} 100 | _default_adapters = adapters.copy() 101 | 102 | 
# Registry of value converters, keyed by upper-cased type name.  Most entries
# are routed through _null_wrapper, which passes None through untouched.
converters = {
    'UNICODE': functools.partial(_null_wrapper, lambda x: x.decode('utf-8')),
    'INTEGER': functools.partial(_null_wrapper, int),
    'BOOL': functools.partial(_null_wrapper, bool),
    'FLOAT': functools.partial(_null_wrapper, float),
    'REAL': functools.partial(_null_wrapper, float),
    'NULL': lambda x: None,
    'BLOB': lambda x: x,
    'DATE': functools.partial(_null_wrapper, _convert_date),
    # DATETIME stays a string: only the 'T' separator and trailing 'Z' are
    # normalised.
    'DATETIME': lambda x: None if x is None else x.replace('T', ' ').rstrip('Z'),
    'TIMESTAMP': functools.partial(_null_wrapper, _convert_timestamp),
}

# Type names whose values are handed to the converter as received.  A
# converter for any other type name gets its input base64-decoded first
# (see _convert_to_python).
_native_converters = (
    'BOOL',
    'FLOAT',
    'INTEGER',
    'REAL',
    'NUMBER',
    'NULL',
    'DATE',
    'DATETIME',
    'TIMESTAMP',
)

# SQLite TEXT affinity: https://www.sqlite.org/datatype3.html
_text_affinity_re = re.compile(r'CHAR|CLOB|TEXT')


def register_converter(type_string, function):
    """Register *function* as the converter for *type_string*.

    The type name is stored upper-cased, replacing any existing entry.
    """
    key = type_string.upper()
    converters[key] = function


def register_adapter(type_, function):
    """Register *function* as the adapter for Python type *type_*.

    Entries are keyed by ``(type_, sqlite3.PrepareProtocol)``.
    """
    key = (type_, sqlite3.PrepareProtocol)
    adapters[key] = function


def _convert_to_python(column_name, type_, parse_decltypes=False, parse_colnames=False):
    """
    Pick the converter for a result column, mimicking the stock sqlite3 module.

    A ``[type]`` marker in the column name (honoured only when
    *parse_colnames* is set) takes precedence over the declared type, as in
    _sqlite/cursor.c.

    Returns a callable to apply to each value of the column, or a false
    value (normally None) when values should be left untouched.
    """
    if type_ == '':  # q="select 3.0" -> type='' column_name='3.0' value=3
        if column_name.isdigit():
            type_ = 'int'
        else:
            whole, _, fraction = column_name.partition('.')
            digits_only = [whole.isdigit(), fraction.isdigit()]
            if all(digits_only):  # 3.14
                type_ = 'real'

    converter = None
    type_upper = None

    has_brackets = '[' in column_name and ']' in column_name
    if has_brackets and parse_colnames:
        type_upper = column_name.upper().partition('[')[-1].partition(']')[0]
        converter = converters.get(type_upper)

    if not converter:
        type_upper = type_.upper()
        if parse_decltypes:
            ## From: https://github.com/python/cpython/blob/c72b6008e0578e334f962ee298279a23ba298856/Modules/_sqlite/cursor.c#L167
            # Converter names are split at '(' and blanks, so that
            # 'INTEGER NOT NULL' is treated as 'INTEGER' and 'NUMBER(10)'
            # as 'NUMBER'.
            type_upper = type_upper.partition('(')[0].partition(' ')[0]
        # Native converters always apply; any other registered converter
        # applies only when declared types are being parsed.
        usable = type_upper in _native_converters or parse_decltypes
        if type_upper in converters and usable:
            converter = converters[type_upper]

    if converter:
        if type_upper in _native_converters:
            return converter
        # Non-native converters receive base64-decoded input.
        return functools.partial(_decode_base64_converter, converter)

    if not type_upper or _text_affinity_re.search(type_upper):
        # Python's sqlite3 module has a text_factory attribute which
        # returns unicode by default, so text is passed through as is.
        return converter

    return _conditional_string_decode_base64


def _adapt_from_python(value):
    """Render a Python value as text for embedding in a SQL statement.

    Strings skip adapter lookup.  Other values use the adapter registered
    for their exact type, falling back to the default adapters and then to
    the object's own ``__adapt__`` / ``__conform__`` hooks.

    Raises InterfaceError when no adaptation route exists.
    """
    if isinstance(value, basestring):
        adapted = value
    else:
        lookup_key = (type(value), sqlite3.PrepareProtocol)
        adapter = adapters.get(lookup_key)
        try:
            if adapter is None:
                # Fall back to _default_adapters, so that ObjectAdaptationTests
                # teardown will correctly restore the default state.
                adapter = _default_adapters[lookup_key]
        except KeyError as missing:
            # No adapter registered.  Let the object adapt itself via
            # PEP-246, which the stdlib sqlite3 module still implements.
            if hasattr(value, '__adapt__'):
                adapted = value.__adapt__(sqlite3.PrepareProtocol)
            elif hasattr(value, '__conform__'):
                adapted = value.__conform__(sqlite3.PrepareProtocol)
            else:
                raise InterfaceError(missing)
        else:
            adapted = adapter(value)

    # The adapter may itself have returned a string, which needs quoting.
    if isinstance(adapted, (bytes, unicode)):
        return _escape_string(adapted)
    if adapted is None:
        return 'NULL'
    return str(adapted)


def _column_stripper(column_name, parse_colnames=False):
    """Return *column_name* cut at its first space when *parse_colnames* is set."""
    if not parse_colnames:
        return column_name
    return column_name.partition(' ')[0]


def _decode_base64_converter(converter, value):
    """Base64-decode *value* and pass the result to *converter*.

    None is returned unchanged; text is encoded to UTF-8 bytes first.
    """
    if value is None:
        return value
    raw = value if isinstance(value, bytes) else value.encode('utf-8')
    return converter(codecs.decode(raw, 'base64'))


def _conditional_string_decode_base64(value):
    """Base64-decode *value* if it is a string; return anything else as is."""
    if not isinstance(value, basestring):
        return value
    if not isinstance(value, bytes):
        value = value.encode('utf-8')
    return codecs.decode(value, 'base64')
-------------------------------------------------------------------------------- /src/pyrqlite/row.py: -------------------------------------------------------------------------------- 1 | 2 | try: 3 | # pylint: disable=no-name-in-module 4 | from collections.abc import Mapping 5 | except ImportError: 6 | from collections import Mapping 7 | 8 | from collections import OrderedDict 9 | 10 | 11 | class Row(tuple, Mapping): 12 | 13 | def __new__(cls, items): 14 | return tuple.__new__(cls, (item[1] for item in items)) 15 | 16 | def __init__(self, items): 17 | super(Row, self).__init__() 18 | self._items = tuple(items) 19 | # If the same column name appears more than once then only the 20 | # value for the first occurence is indexed (see __getitem__). 21 | d = {} 22 | for k, v in self._items: 23 | d.setdefault(k, v) 24 | self._dict = d 25 | 26 | def __getitem__(self, k): 27 | """ 28 | If the same column name appears more than once then this returns 29 | the value from the first matching column just like sqlite3.Row. 30 | """ 31 | try: 32 | return self._dict[k] 33 | except (KeyError, TypeError): 34 | if isinstance(k, (int, slice)): 35 | return tuple.__getitem__(self, k) 36 | else: 37 | raise 38 | 39 | def __iter__(self): 40 | """ 41 | Return an iterator over the values of the row. 42 | 43 | View contains all values. 44 | """ 45 | return tuple.__iter__(self) 46 | 47 | def items(self): 48 | """ 49 | Return a new view of the row’s items ((key, value) pairs). 50 | 51 | View contains all items, multiple items can have the same key. 52 | """ 53 | for item in self._items: 54 | yield item 55 | 56 | def values(self): 57 | """ 58 | Return a new view of the rows’s values. 59 | 60 | View contains all values. 61 | """ 62 | for item in self._items: 63 | yield item[1] 64 | 65 | def keys(self): 66 | """ 67 | Returns a list of column names which are not necessarily 68 | unique, just like sqlite3.Row. 
69 | """ 70 | return [item[0] for item in self._items] 71 | 72 | def __len__(self): 73 | return tuple.__len__(self) 74 | 75 | def __delitem__(self, k): 76 | raise NotImplementedError(self) 77 | 78 | def pop(self, k): 79 | raise NotImplementedError(self) 80 | -------------------------------------------------------------------------------- /src/pyrqlite/types.py: -------------------------------------------------------------------------------- 1 | 2 | import datetime 3 | import sys 4 | 5 | Binary = bytes 6 | Date = datetime.date 7 | Time = datetime.time 8 | Timestamp = datetime.datetime 9 | 10 | # pylint: disable=undefined-variable 11 | STRING = str if sys.version_info[0] >= 3 else unicode 12 | BINARY = bytes 13 | NUMBER = float 14 | DATETIME = Timestamp 15 | # pylint: disable=undefined-variable 16 | ROWID = int if sys.version_info[0] >= 3 else long 17 | -------------------------------------------------------------------------------- /src/test/test_dbapi.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | # pysqlite2/test/dbapi.py: tests for DB-API compliance 3 | # 4 | # Copyright (C) 2004-2010 Gerhard Höring 5 | # 6 | # This file is part of pysqlite. 7 | # 8 | # This software is provided 'as-is', without any express or implied 9 | # warranty. In no event will the authors be held liable for any damages 10 | # arising from the use of this software. 11 | # 12 | # Permission is granted to anyone to use this software for any purpose, 13 | # including commercial applications, and to alter it and redistribute it 14 | # freely, subject to the following restrictions: 15 | # 16 | # 1. The origin of this software must not be misrepresented; you must not 17 | # claim that you wrote the original software. If you use this software 18 | # in a product, an acknowledgment in the product documentation would be 19 | # appreciated but is not required. 20 | # 2. 
Altered source versions must be plainly marked as such, and must not be 21 | # misrepresented as being the original software. 22 | # 3. This notice may not be removed or altered from any source distribution. 23 | 24 | from __future__ import print_function 25 | 26 | import sys 27 | import unittest 28 | 29 | import pyrqlite.dbapi2 as sqlite 30 | 31 | if sys.version_info[0] >= 3: 32 | StandardError = Exception 33 | 34 | 35 | class ModuleTests(unittest.TestCase): 36 | def test_CheckAPILevel(self): 37 | self.assertEqual(sqlite.apilevel, "2.0", 38 | "apilevel is %s, should be 2.0" % sqlite.apilevel) 39 | 40 | def test_CheckThreadSafety(self): 41 | self.assertEqual(sqlite.threadsafety, 1, 42 | "threadsafety is %d, should be 1" % sqlite.threadsafety) 43 | 44 | def test_CheckParamStyle(self): 45 | self.assertEqual(sqlite.paramstyle, "qmark", 46 | "paramstyle is '%s', should be 'qmark'" % 47 | sqlite.paramstyle) 48 | 49 | def test_CheckWarning(self): 50 | self.assertTrue(issubclass(sqlite.Warning, StandardError), 51 | "Warning is not a subclass of StandardError") 52 | 53 | def test_CheckError(self): 54 | self.assertTrue(issubclass(sqlite.Error, StandardError), 55 | "Error is not a subclass of StandardError") 56 | 57 | def test_CheckInterfaceError(self): 58 | self.assertTrue(issubclass(sqlite.InterfaceError, sqlite.Error), 59 | "InterfaceError is not a subclass of Error") 60 | 61 | def test_CheckDatabaseError(self): 62 | self.assertTrue(issubclass(sqlite.DatabaseError, sqlite.Error), 63 | "DatabaseError is not a subclass of Error") 64 | 65 | def test_CheckDataError(self): 66 | self.assertTrue(issubclass(sqlite.DataError, sqlite.DatabaseError), 67 | "DataError is not a subclass of DatabaseError") 68 | 69 | def test_CheckOperationalError(self): 70 | self.assertTrue(issubclass(sqlite.OperationalError, sqlite.DatabaseError), 71 | "OperationalError is not a subclass of DatabaseError") 72 | 73 | def test_CheckIntegrityError(self): 74 | 
self.assertTrue(issubclass(sqlite.IntegrityError, sqlite.DatabaseError), 75 | "IntegrityError is not a subclass of DatabaseError") 76 | 77 | def test_CheckInternalError(self): 78 | self.assertTrue(issubclass(sqlite.InternalError, sqlite.DatabaseError), 79 | "InternalError is not a subclass of DatabaseError") 80 | 81 | def test_CheckProgrammingError(self): 82 | self.assertTrue(issubclass(sqlite.ProgrammingError, sqlite.DatabaseError), 83 | "ProgrammingError is not a subclass of DatabaseError") 84 | 85 | def test_CheckNotSupportedError(self): 86 | self.assertTrue(issubclass(sqlite.NotSupportedError, 87 | sqlite.DatabaseError), 88 | "NotSupportedError is not a subclass of DatabaseError") 89 | 90 | class ConnectionTests(unittest.TestCase): 91 | @classmethod 92 | def setUpClass(cls): 93 | cls.cx = sqlite.connect(":memory:") 94 | 95 | def setUp(self): 96 | cu = self.cx.cursor() 97 | cu.execute("create table test(id integer primary key, name text)") 98 | cu.execute("insert into test(name) values (?)", ("foo",)) 99 | 100 | def tearDown(self): 101 | self.cx.execute("drop table test") 102 | 103 | @classmethod 104 | def tearDownClass(cls): 105 | cls.cx.close() 106 | del cls.cx 107 | 108 | def test_CheckCommit(self): 109 | self.cx.commit() 110 | 111 | def test_CheckCommitAfterNoChanges(self): 112 | """ 113 | A commit should also work when no changes were made to the database. 114 | """ 115 | self.cx.commit() 116 | self.cx.commit() 117 | 118 | def test_CheckRollback(self): 119 | self.cx.rollback() 120 | 121 | def test_CheckRollbackAfterNoChanges(self): 122 | """ 123 | A rollback should also work when no changes were made to the database. 
124 | """ 125 | self.cx.rollback() 126 | self.cx.rollback() 127 | 128 | def test_CheckCursor(self): 129 | cu = self.cx.cursor() 130 | 131 | @unittest.skip('not implemented') 132 | def test_CheckFailedOpen(self): 133 | YOU_CANNOT_OPEN_THIS = "/foo/bar/bla/23534/mydb.db" 134 | try: 135 | con = sqlite.connect(YOU_CANNOT_OPEN_THIS) 136 | except sqlite.OperationalError: 137 | return 138 | self.fail("should have raised an OperationalError") 139 | 140 | def test_CheckClose(self): 141 | # This would interfere with other tests, and 142 | # tearDownClass exercises it already. 143 | #self.cx.close() 144 | pass 145 | 146 | def test_CheckPing(self): 147 | self.cx.ping() 148 | 149 | def test_CheckExceptions(self): 150 | # Optional DB-API extension. 151 | self.assertEqual(self.cx.Warning, sqlite.Warning) 152 | self.assertEqual(self.cx.Error, sqlite.Error) 153 | self.assertEqual(self.cx.InterfaceError, sqlite.InterfaceError) 154 | self.assertEqual(self.cx.DatabaseError, sqlite.DatabaseError) 155 | self.assertEqual(self.cx.DataError, sqlite.DataError) 156 | self.assertEqual(self.cx.OperationalError, sqlite.OperationalError) 157 | self.assertEqual(self.cx.IntegrityError, sqlite.IntegrityError) 158 | self.assertEqual(self.cx.InternalError, sqlite.InternalError) 159 | self.assertEqual(self.cx.ProgrammingError, sqlite.ProgrammingError) 160 | self.assertEqual(self.cx.NotSupportedError, sqlite.NotSupportedError) 161 | 162 | class CursorTests(unittest.TestCase): 163 | @classmethod 164 | def setUpClass(cls): 165 | cls.cx = sqlite.connect(":memory:") 166 | 167 | def setUp(self): 168 | self.cu = self.cx.cursor() 169 | self.cu.execute("create table test(id integer primary key, name text, income number)") 170 | self.cu.execute("insert into test(name) values (?)", ("foo",)) 171 | 172 | def tearDown(self): 173 | self.cu.close() 174 | self.cx.execute("drop table test") 175 | 176 | @classmethod 177 | def tearDownClass(cls): 178 | cls.cx.close() 179 | del cls.cx 180 | 181 | def 
test_CheckExecuteNoArgs(self): 182 | self.cu.execute("delete from test") 183 | 184 | @unittest.skip('not implemented') 185 | def test_CheckExecuteIllegalSql(self): 186 | try: 187 | self.cu.execute("select asdf") 188 | self.fail("should have raised an OperationalError") 189 | except sqlite.OperationalError: 190 | return 191 | except: 192 | self.fail("raised wrong exception") 193 | 194 | @unittest.skip('not implemented') 195 | def test_CheckExecuteTooMuchSql(self): 196 | try: 197 | self.cu.execute("select 5+4; select 4+5") 198 | self.fail("should have raised a Warning") 199 | except sqlite.Warning: 200 | return 201 | except: 202 | self.fail("raised wrong exception") 203 | 204 | @unittest.skip('not implemented') 205 | def test_CheckExecuteTooMuchSql2(self): 206 | self.cu.execute("select 5+4; -- foo bar") 207 | 208 | @unittest.skip('not implemented') 209 | def test_CheckExecuteTooMuchSql3(self): 210 | self.cu.execute(""" 211 | select 5+4; 212 | 213 | /* 214 | foo 215 | */ 216 | """) 217 | 218 | def test_CheckExecuteWrongSqlArg(self): 219 | try: 220 | self.cu.execute(42) 221 | self.fail("should have raised a ValueError") 222 | except ValueError: 223 | return 224 | except: 225 | self.fail("raised wrong exception.") 226 | 227 | def test_CheckExecuteArgInt(self): 228 | self.cu.execute("insert into test(id) values (?)", (42,)) 229 | 230 | def test_CheckExecuteArgFloat(self): 231 | self.cu.execute("insert into test(income) values (?)", (2500.32,)) 232 | 233 | def test_CheckExecuteArgString(self): 234 | self.cu.execute("insert into test(name) values (?)", ("Hugo",)) 235 | 236 | @unittest.expectedFailure 237 | def test_CheckExecuteArgStringWithZeroByte(self): 238 | self.cu.execute("insert into test(name) values (?)", ("Hu\x00go",)) 239 | 240 | self.cu.execute("select name from test where id=?", (self.cu.lastrowid,)) 241 | row = self.cu.fetchone() 242 | self.assertEqual(row[0], "Hu\x00go") 243 | 244 | def test_CheckExecuteWrongNoOfArgs1(self): 245 | # too many parameters 246 | 
try: 247 | self.cu.execute("insert into test(id) values (?)", (17, "Egon")) 248 | self.fail("should have raised ProgrammingError") 249 | except sqlite.ProgrammingError: 250 | pass 251 | 252 | def test_CheckExecuteWrongNoOfArgs2(self): 253 | # too little parameters 254 | try: 255 | self.cu.execute("insert into test(id) values (?)") 256 | self.fail("should have raised ProgrammingError") 257 | except sqlite.ProgrammingError: 258 | pass 259 | 260 | def test_CheckExecuteWrongNoOfArgs3(self): 261 | # no parameters, parameters are needed 262 | try: 263 | self.cu.execute("insert into test(id) values (?)") 264 | self.fail("should have raised ProgrammingError") 265 | except sqlite.ProgrammingError: 266 | pass 267 | 268 | def test_CheckExecuteParamList(self): 269 | self.cu.execute("insert into test(name) values ('foo')") 270 | self.cu.execute("select name from test where name=?", ["foo"]) 271 | row = self.cu.fetchone() 272 | self.assertEqual(row[0], "foo") 273 | 274 | def test_CheckExecuteParamListQueueWait(self): 275 | self.cu.execute("insert into test(name) values ('foo')", queue=True, wait=True) 276 | self.cu.execute("select name from test where name=?", ["foo"]) 277 | row = self.cu.fetchone() 278 | self.assertEqual(row[0], "foo") 279 | 280 | def test_CheckExecuteParamListConsistency(self): 281 | self.cu.execute("insert into test(name) values ('foo')") 282 | for c in ['strong', 'linearizable', 'weak', 'none', None]: 283 | self.cu.execute("select name from test where name=?", ["foo"], consistency=c) 284 | row = self.cu.fetchone() 285 | self.assertEqual(row[0], "foo") 286 | 287 | def test_CheckExecuteParamSequence(self): 288 | class L(object): 289 | def __len__(self): 290 | return 1 291 | def __getitem__(self, x): 292 | assert x == 0 293 | return "foo" 294 | 295 | self.cu.execute("insert into test(name) values ('foo')") 296 | self.cu.execute("select name from test where name=?", L()) 297 | row = self.cu.fetchone() 298 | self.assertEqual(row[0], "foo") 299 | 300 | def 
test_CheckExecuteDictMapping(self): 301 | self.cu.execute("insert into test(name) values ('foo')") 302 | self.cu.execute("select name from test where name=:name", {"name": "foo"}) 303 | row = self.cu.fetchone() 304 | self.assertEqual(row[0], "foo") 305 | 306 | def test_CheckExecuteDictMapping_Mapping(self): 307 | # Test only works with Python 2.5 or later 308 | if sys.version_info < (2, 5, 0): 309 | return 310 | 311 | class D(dict): 312 | def __missing__(self, key): 313 | return "foo" 314 | 315 | self.cu.execute("insert into test(name) values ('foo')") 316 | self.cu.execute("select name from test where name=:name", D()) 317 | row = self.cu.fetchone() 318 | self.assertEqual(row[0], "foo") 319 | 320 | def test_CheckExecuteDictMappingTooLittleArgs(self): 321 | self.cu.execute("insert into test(name) values ('foo')") 322 | try: 323 | self.cu.execute("select name from test where name=:name and id=:id", {"name": "foo"}) 324 | self.fail("should have raised ProgrammingError") 325 | except sqlite.ProgrammingError: 326 | pass 327 | 328 | def test_CheckExecuteDictMappingNoArgs(self): 329 | self.cu.execute("insert into test(name) values ('foo')") 330 | try: 331 | self.cu.execute("select name from test where name=:name") 332 | self.fail("should have raised ProgrammingError") 333 | except sqlite.ProgrammingError: 334 | pass 335 | 336 | def test_CheckExecuteNamedWithoutDict(self): 337 | self.cu.execute("insert into test(name) values ('foo')") 338 | try: 339 | self.cu.execute("select name from test where name=:name", ("name",)) 340 | self.fail("should have raised ProgrammingError") 341 | except sqlite.ProgrammingError: 342 | pass 343 | 344 | def test_CheckExecuteDictMappingUnnamed(self): 345 | self.cu.execute("insert into test(name) values ('foo')") 346 | try: 347 | self.cu.execute("select name from test where name=?", {"name": "foo"}) 348 | self.fail("should have raised ProgrammingError") 349 | except sqlite.ProgrammingError: 350 | pass 351 | 352 | def test_CheckClose(self): 353 
| self.cu.close() 354 | self.cu = self.cx.cursor() 355 | 356 | def test_CheckRowcountExecute(self): 357 | self.cu.execute("delete from test") 358 | self.cu.execute("insert into test(name, income) values (?, ?)", ("?", "1")) 359 | self.cu.execute("select name from test where name=?", ("?",)) 360 | self.assertEqual(self.cu.rowcount, 1, 361 | msg="test failed for https://github.com/rqlite/pyrqlite/issues/30") 362 | self.cu.execute("insert into test(name) values ('foo')") 363 | self.cu.execute("update test set name='bar'") 364 | self.assertEqual(self.cu.rowcount, 2) 365 | 366 | @unittest.skip('not implemented') 367 | def test_CheckRowcountSelect(self): 368 | """ 369 | pysqlite does not know the rowcount of SELECT statements, because we 370 | don't fetch all rows after executing the select statement. The rowcount 371 | has thus to be -1. 372 | """ 373 | self.cu.execute("select 5 union select 6") 374 | self.assertEqual(self.cu.rowcount, -1) 375 | 376 | def test_CheckRowcountExecutemany(self): 377 | self.cu.execute("delete from test") 378 | self.cu.executemany("insert into test(name) values (?)", [(1,), (2,), (3,)]) 379 | self.assertEqual(self.cu.rowcount, 3) 380 | 381 | @unittest.skip('Cursor.total_changes is not implemented') 382 | def test_CheckTotalChanges(self): 383 | self.cu.execute("insert into test(name) values ('foo')") 384 | self.cu.execute("insert into test(name) values ('foo')") 385 | if self.cx.total_changes < 2: 386 | self.fail("total changes reported wrong value") 387 | 388 | # Checks for executemany: 389 | # Sequences are required by the DB-API, iterators 390 | # enhancements in pysqlite. 
391 | 392 | def test_CheckExecuteManySequence(self): 393 | self.cu.executemany("insert into test(income) values (?)", [(x,) for x in range(100, 110)]) 394 | 395 | def test_CheckExecuteManyIterator(self): 396 | class MyIter: 397 | def __init__(self): 398 | self.value = 5 399 | 400 | def __iter__(self): 401 | return self 402 | 403 | def __next__(self): 404 | if self.value == 10: 405 | raise StopIteration 406 | else: 407 | self.value += 1 408 | return (self.value,) 409 | 410 | next = __next__ 411 | 412 | self.cu.executemany("insert into test(income) values (?)", MyIter()) 413 | 414 | def test_CheckExecuteManyGenerator(self): 415 | def mygen(): 416 | for i in range(5): 417 | yield (i,) 418 | 419 | self.cu.executemany("insert into test(income) values (?)", mygen()) 420 | 421 | def test_CheckExecuteManyWrongSqlArg(self): 422 | try: 423 | self.cu.executemany(42, [(3,)]) 424 | self.fail("should have raised a ValueError") 425 | except ValueError: 426 | return 427 | except: 428 | self.fail("raised wrong exception.") 429 | 430 | @unittest.skip('not implemented') 431 | def test_CheckExecuteManySelect(self): 432 | try: 433 | self.cu.executemany("select ?", [(3,)]) 434 | self.fail("should have raised a ProgrammingError") 435 | except sqlite.ProgrammingError: 436 | return 437 | except: 438 | self.fail("raised wrong exception.") 439 | 440 | def test_CheckExecuteManyNotIterable(self): 441 | try: 442 | self.cu.executemany("insert into test(income) values (?)", 42) 443 | self.fail("should have raised a TypeError") 444 | except TypeError: 445 | return 446 | except Exception as e: 447 | print("raised", e.__class__) 448 | self.fail("raised wrong exception.") 449 | 450 | def test_CheckFetchIter(self): 451 | # Optional DB-API extension. 
452 | self.cu.execute("delete from test") 453 | self.cu.execute("insert into test(id) values (?)", (5,)) 454 | self.cu.execute("insert into test(id) values (?)", (6,)) 455 | self.cu.execute("select id from test order by id") 456 | lst = [] 457 | for row in self.cu: 458 | lst.append(row[0]) 459 | self.assertEqual(lst[0], 5) 460 | self.assertEqual(lst[1], 6) 461 | 462 | def test_CheckFetchone(self): 463 | self.cu.execute("select name from test") 464 | row = self.cu.fetchone() 465 | self.assertEqual(row[0], "foo") 466 | row = self.cu.fetchone() 467 | self.assertEqual(row, None) 468 | 469 | def test_CheckFetchoneNoStatement(self): 470 | cur = self.cx.cursor() 471 | row = cur.fetchone() 472 | self.assertEqual(row, None) 473 | 474 | def test_CheckArraySize(self): 475 | # must default ot 1 476 | self.assertEqual(self.cu.arraysize, 1) 477 | 478 | # now set to 2 479 | self.cu.arraysize = 2 480 | 481 | # now make the query return 3 rows 482 | self.cu.execute("delete from test") 483 | self.cu.execute("insert into test(name) values ('A')") 484 | self.cu.execute("insert into test(name) values ('B')") 485 | self.cu.execute("insert into test(name) values ('C')") 486 | self.cu.execute("select name from test") 487 | res = self.cu.fetchmany() 488 | 489 | self.assertEqual(len(res), 2) 490 | 491 | def test_CheckFetchmany(self): 492 | self.cu.execute("select name from test") 493 | res = self.cu.fetchmany(100) 494 | self.assertEqual(len(res), 1) 495 | res = self.cu.fetchmany(100) 496 | self.assertEqual(res, []) 497 | 498 | def test_CheckFetchmanyKwArg(self): 499 | """Checks if fetchmany works with keyword arguments""" 500 | self.cu.execute("select name from test") 501 | res = self.cu.fetchmany(size=100) 502 | self.assertEqual(len(res), 1) 503 | 504 | def test_CheckFetchall(self): 505 | self.cu.execute("select name from test") 506 | res = self.cu.fetchall() 507 | self.assertEqual(len(res), 1) 508 | res = self.cu.fetchall() 509 | self.assertEqual(res, []) 510 | 511 | 
@unittest.skip('Cursor.setinputsizes is not implemented') 512 | def test_CheckSetinputsizes(self): 513 | self.cu.setinputsizes([3, 4, 5]) 514 | 515 | @unittest.skip('Cursor.setoutputsize is not implemented') 516 | def test_CheckSetoutputsize(self): 517 | self.cu.setoutputsize(5, 0) 518 | 519 | @unittest.skip('Cursor.setoutputsize is not implemented') 520 | def test_CheckSetoutputsizeNoColumn(self): 521 | self.cu.setoutputsize(42) 522 | 523 | def test_CheckCursorConnection(self): 524 | # Optional DB-API extension. 525 | self.assertEqual(self.cu.connection, self.cx) 526 | 527 | def test_CheckWrongCursorCallable(self): 528 | try: 529 | def f(): pass 530 | cur = self.cx.cursor(f) 531 | self.fail("should have raised a TypeError") 532 | except TypeError: 533 | return 534 | self.fail("should have raised a ValueError") 535 | 536 | @unittest.skip('not implemented') 537 | def test_CheckCursorWrongClass(self): 538 | class Foo: pass 539 | foo = Foo() 540 | try: 541 | cur = sqlite.Cursor(foo) 542 | self.fail("should have raised a ValueError") 543 | except TypeError: 544 | pass 545 | 546 | 547 | class ConstructorTests(unittest.TestCase): 548 | def test_CheckDate(self): 549 | d = sqlite.Date(2004, 10, 28) 550 | 551 | def test_CheckTime(self): 552 | t = sqlite.Time(12, 39, 35) 553 | 554 | def test_CheckTimestamp(self): 555 | ts = sqlite.Timestamp(2004, 10, 28, 12, 39, 35) 556 | 557 | def test_CheckDateFromTicks(self): 558 | d = sqlite.DateFromTicks(42) 559 | 560 | def test_CheckTimeFromTicks(self): 561 | t = sqlite.TimeFromTicks(42) 562 | 563 | def test_CheckTimestampFromTicks(self): 564 | ts = sqlite.TimestampFromTicks(42) 565 | 566 | def test_CheckBinary(self): 567 | self.assertEqual( 568 | b"\0'", 569 | sqlite.Binary( 570 | chr(0).encode() + b"'" if sys.version_info[0] >= 3 else chr(0) + b"'" 571 | ), 572 | ) 573 | 574 | class ExtensionTests(unittest.TestCase): 575 | @classmethod 576 | def setUpClass(cls): 577 | cls.con = sqlite.connect(":memory:") 578 | 579 | def 
tearDown(self): 580 | for row in self.con.execute( 581 | "SELECT name FROM sqlite_master WHERE type='table'").fetchall(): 582 | self.con.execute("drop table '{}'".format(row[0])) 583 | 584 | @classmethod 585 | def tearDownClass(cls): 586 | cls.con.close() 587 | del cls.con 588 | 589 | @unittest.skip('Cursor.executescript is not implemented') 590 | def test_CheckScriptStringSql(self): 591 | cur = self.con.cursor() 592 | cur.executescript(""" 593 | -- bla bla 594 | /* a stupid comment */ 595 | create table a(i); 596 | insert into a(i) values (5); 597 | """) 598 | cur.execute("select i from a") 599 | res = cur.fetchone()[0] 600 | self.assertEqual(res, 5) 601 | 602 | @unittest.skip('Cursor.executescript is not implemented') 603 | def test_CheckScriptStringUnicode(self): 604 | cur = self.con.cursor() 605 | cur.executescript(u""" 606 | create table a(i); 607 | insert into a(i) values (5); 608 | select i from a; 609 | delete from a; 610 | insert into a(i) values (6); 611 | """) 612 | cur.execute("select i from a") 613 | res = cur.fetchone()[0] 614 | self.assertEqual(res, 6) 615 | 616 | @unittest.skip('Cursor.executescript is not implemented') 617 | def test_CheckScriptSyntaxError(self): 618 | cur = self.con.cursor() 619 | raised = False 620 | try: 621 | cur.executescript("create table test(x); asdf; create table test2(x)") 622 | except sqlite.OperationalError: 623 | raised = True 624 | self.assertEqual(raised, True, "should have raised an exception") 625 | 626 | @unittest.skip('Cursor.executescript is not implemented') 627 | def test_CheckScriptErrorNormal(self): 628 | cur = self.con.cursor() 629 | raised = False 630 | try: 631 | cur.executescript("create table test(sadfsadfdsa); select foo from hurz;") 632 | except sqlite.OperationalError: 633 | raised = True 634 | self.assertEqual(raised, True, "should have raised an exception") 635 | 636 | def test_CheckConnectionExecute(self): 637 | result = self.con.execute("select 5").fetchone()[0] 638 | self.assertEqual(result, 5, 
"Basic test of Connection.execute") 639 | 640 | @unittest.skip('Connection.executemany is not implemented') 641 | def test_CheckConnectionExecutemany(self): 642 | con = self.con 643 | con.execute("create table test(foo)") 644 | con.executemany("insert into test(foo) values (?)", [(3,), (4,)]) 645 | result = con.execute("select foo from test order by foo").fetchall() 646 | self.assertEqual(result[0][0], 3, "Basic test of Connection.executemany") 647 | self.assertEqual(result[1][0], 4, "Basic test of Connection.executemany") 648 | 649 | @unittest.skip('Connection.executescript is not implemented') 650 | def test_CheckConnectionExecutescript(self): 651 | con = self.con 652 | con.executescript("create table test(foo); insert into test(foo) values (5);") 653 | result = con.execute("select foo from test").fetchone()[0] 654 | self.assertEqual(result, 5, "Basic test of Connection.executescript") 655 | 656 | @unittest.skip('not implemented') 657 | class ClosedConTests(unittest.TestCase): 658 | @classmethod 659 | def setUpClass(cls): 660 | cls.con = sqlite.connect(":memory:") 661 | cls.cur = cls.con.cursor() 662 | cls.con.close() 663 | 664 | @classmethod 665 | def tearDownClass(cls): 666 | del cls.cur 667 | del cls.con 668 | 669 | def test_CheckClosedConCursor(self): 670 | con = self.con 671 | try: 672 | cur = con.cursor() 673 | self.fail("Should have raised a ProgrammingError") 674 | except sqlite.ProgrammingError: 675 | pass 676 | except: 677 | self.fail("Should have raised a ProgrammingError") 678 | 679 | def test_CheckClosedConCommit(self): 680 | con = self.con 681 | try: 682 | con.commit() 683 | self.fail("Should have raised a ProgrammingError") 684 | except sqlite.ProgrammingError: 685 | pass 686 | except: 687 | self.fail("Should have raised a ProgrammingError") 688 | 689 | def test_CheckClosedConRollback(self): 690 | con = self.con 691 | try: 692 | con.rollback() 693 | self.fail("Should have raised a ProgrammingError") 694 | except sqlite.ProgrammingError: 695 | pass 
696 | except: 697 | self.fail("Should have raised a ProgrammingError") 698 | 699 | def test_CheckClosedCurExecute(self): 700 | cur = self.cur 701 | try: 702 | cur.execute("select 4") 703 | self.fail("Should have raised a ProgrammingError") 704 | except sqlite.ProgrammingError: 705 | pass 706 | except: 707 | self.fail("Should have raised a ProgrammingError") 708 | 709 | def test_CheckClosedCreateFunction(self): 710 | con = self.con 711 | def f(x): return 17 712 | try: 713 | con.create_function("foo", 1, f) 714 | self.fail("Should have raised a ProgrammingError") 715 | except sqlite.ProgrammingError: 716 | pass 717 | except: 718 | self.fail("Should have raised a ProgrammingError") 719 | 720 | def test_CheckClosedCreateAggregate(self): 721 | con = self.con 722 | class Agg: 723 | def __init__(self): 724 | pass 725 | def step(self, x): 726 | pass 727 | def finalize(self): 728 | return 17 729 | try: 730 | con.create_aggregate("foo", 1, Agg) 731 | self.fail("Should have raised a ProgrammingError") 732 | except sqlite.ProgrammingError: 733 | pass 734 | except: 735 | self.fail("Should have raised a ProgrammingError") 736 | 737 | def test_CheckClosedSetAuthorizer(self): 738 | con = self.con 739 | def authorizer(*args): 740 | return sqlite.DENY 741 | try: 742 | con.set_authorizer(authorizer) 743 | self.fail("Should have raised a ProgrammingError") 744 | except sqlite.ProgrammingError: 745 | pass 746 | except: 747 | self.fail("Should have raised a ProgrammingError") 748 | 749 | def test_CheckClosedSetProgressCallback(self): 750 | con = self.con 751 | def progress(): pass 752 | try: 753 | con.set_progress_handler(progress, 100) 754 | self.fail("Should have raised a ProgrammingError") 755 | except sqlite.ProgrammingError: 756 | pass 757 | except: 758 | self.fail("Should have raised a ProgrammingError") 759 | 760 | def test_CheckClosedCall(self): 761 | con = self.con 762 | try: 763 | con() 764 | self.fail("Should have raised a ProgrammingError") 765 | except 
sqlite.ProgrammingError: 766 | pass 767 | except: 768 | self.fail("Should have raised a ProgrammingError") 769 | 770 | @unittest.skip('not implemented') 771 | class ClosedCurTests(unittest.TestCase): 772 | @classmethod 773 | def setUpClass(cls): 774 | cls.con = sqlite.connect(":memory:") 775 | cls.cur = cls.con.cursor() 776 | cls.cur.close() 777 | 778 | @classmethod 779 | def tearDownClass(cls): 780 | cls.con.close() 781 | del cls.cur 782 | del cls.con 783 | 784 | def test_CheckClosed(self): 785 | cur = self.cur 786 | 787 | for method_name in ("execute", "executemany", "executescript", "fetchall", "fetchmany", "fetchone"): 788 | if method_name in ("execute", "executescript"): 789 | params = ("select 4 union select 5",) 790 | elif method_name == "executemany": 791 | params = ("insert into foo(bar) values (?)", [(3,), (4,)]) 792 | else: 793 | params = [] 794 | 795 | try: 796 | method = getattr(cur, method_name) 797 | 798 | method(*params) 799 | self.fail("Should have raised a ProgrammingError: method " + method_name) 800 | except sqlite.ProgrammingError: 801 | pass 802 | except: 803 | self.fail("Should have raised a ProgrammingError: " + method_name) 804 | 805 | def suite(): 806 | loader = unittest.TestLoader() 807 | module_suite = loader.loadTestsFromTestCase(ModuleTests) 808 | connection_suite = loader.loadTestsFromTestCase(ConnectionTests) 809 | cursor_suite = loader.loadTestsFromTestCase(CursorTests) 810 | constructor_suite = loader.loadTestsFromTestCase(ConstructorTests) 811 | ext_suite = loader.loadTestsFromTestCase(ExtensionTests) 812 | closed_con_suite = loader.loadTestsFromTestCase(ClosedConTests) 813 | closed_cur_suite = loader.loadTestsFromTestCase(ClosedCurTests) 814 | return unittest.TestSuite((module_suite, connection_suite, cursor_suite, constructor_suite, ext_suite, closed_con_suite, closed_cur_suite)) 815 | 816 | def main(): 817 | runner = unittest.TextTestRunner(verbosity=2) 818 | runner.run(suite()) 819 | 820 | if __name__ == "__main__": 821 | 
main() 822 | -------------------------------------------------------------------------------- /src/test/test_row.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pyrqlite.dbapi2 as sqlite 4 | from pyrqlite.row import Row 5 | 6 | 7 | class NonUniqueColumnsTest(unittest.TestCase): 8 | @classmethod 9 | def setUpClass(cls): 10 | cls.cx = sqlite.connect(":memory:") 11 | 12 | def setUp(self): 13 | """ 14 | Create tables to join with some columns that have the same names, 15 | to demonstrated rows with non-unique column names. 16 | """ 17 | self.cu = self.cx.cursor() 18 | self.cu.execute("create table tbl1(id integer primary key, name text)") 19 | self.cu.execute("insert into tbl1(id, name) values (1, 'foo')") 20 | self.cu.execute( 21 | "create table tbl2(id integer primary key, tbl1id integer, name text)" 22 | ) 23 | self.cu.execute("insert into tbl2(id, tbl1id, name) values (2, 1, 'bar')") 24 | 25 | def tearDown(self): 26 | self.cu.close() 27 | 28 | @classmethod 29 | def tearDownClass(cls): 30 | cls.cx.close() 31 | del cls.cx 32 | 33 | def testJoin(self): 34 | """ 35 | Test a join that demonstrates rows with non-unique column names. 
36 | """ 37 | self.cu.execute("select * from tbl1 inner join tbl2 on tbl2.tbl1id = tbl1.id") 38 | row = self.cu.fetchone() 39 | self.assertEqual(tuple(row.keys()), ("id", "name", "id", "tbl1id", "name")) 40 | self.assertEqual(tuple(row.values()), (1, "foo", 2, 1, "bar")) 41 | self.assertEqual(row["name"], "foo") 42 | self.assertEqual(row["id"], 1) 43 | self.assertEqual(row["tbl1id"], 1) 44 | 45 | 46 | def test_row(): 47 | row = Row([('foo', 'foo'), ('bar', 'bar')]) 48 | assert len(row) == 2 49 | assert row['foo'] == 'foo' 50 | assert row['bar'] == 'bar' 51 | assert row[0] == 'foo' 52 | assert row[1] == 'bar' 53 | 54 | try: 55 | row[2] 56 | except IndexError: 57 | pass 58 | else: 59 | assert False 60 | 61 | try: 62 | row['non-existent'] 63 | except KeyError: 64 | pass 65 | else: 66 | assert False 67 | 68 | 69 | def test_row_with_non_unique_columns(): 70 | items = [("foo", "foo"), ("bar", "bar"), ("foo", "bar")] 71 | row = Row(items) 72 | assert len(row) == 3 73 | 74 | assert row["foo"] == "foo" 75 | assert row["bar"] == "bar" 76 | assert row[0] == "foo" 77 | assert row[1] == "bar" 78 | assert row[2] == "bar" 79 | assert list(row.items()) == items 80 | assert list(row.values()) == [item[1] for item in items] 81 | -------------------------------------------------------------------------------- /src/test/test_types.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | # pysqlite2/test/types.py: tests for type conversion and detection 3 | # 4 | # Copyright (C) 2005 Gerhard Höring 5 | # 6 | # This file is part of pysqlite. 7 | # 8 | # This software is provided 'as-is', without any express or implied 9 | # warranty. In no event will the authors be held liable for any damages 10 | # arising from the use of this software. 
11 | # 12 | # Permission is granted to anyone to use this software for any purpose, 13 | # including commercial applications, and to alter it and redistribute it 14 | # freely, subject to the following restrictions: 15 | # 16 | # 1. The origin of this software must not be misrepresented; you must not 17 | # claim that you wrote the original software. If you use this software 18 | # in a product, an acknowledgment in the product documentation would be 19 | # appreciated but is not required. 20 | # 2. Altered source versions must be plainly marked as such, and must not be 21 | # misrepresented as being the original software. 22 | # 3. This notice may not be removed or altered from any source distribution. 23 | 24 | import datetime 25 | import sys 26 | import unittest 27 | import pyrqlite.dbapi2 as sqlite 28 | try: 29 | import zlib 30 | except ImportError: 31 | zlib = None 32 | 33 | 34 | class SqliteTypeTests(unittest.TestCase): 35 | @classmethod 36 | def setUpClass(cls): 37 | cls.con = sqlite.connect(":memory:") 38 | cls.cur = cls.con.cursor() 39 | if cls.cur.execute("pragma table_info(test)").fetchall(): 40 | cls.cur.execute("drop table test") 41 | 42 | def setUp(self): 43 | self.cur.execute("create table test(i INTEGER, s VARCHAR, f NUMBER, b BLOB)") 44 | 45 | def tearDown(self): 46 | self.cur.execute("drop table test") 47 | 48 | @classmethod 49 | def tearDownClass(cls): 50 | cls.cur.close() 51 | cls.con.close() 52 | 53 | def test_CheckString(self): 54 | self.cur.execute("insert into test(s) values (?)", (u"Österreich",)) 55 | self.cur.execute("select s from test") 56 | row = self.cur.fetchone() 57 | self.assertEqual(row[0], u"Österreich") 58 | 59 | def test_CheckStringNull(self): 60 | self.cur.execute("insert into test(s) values (?)", (None,)) 61 | self.cur.execute("select s from test") 62 | row = self.cur.fetchone() 63 | self.assertEqual(row[0], None) 64 | 65 | def test_CheckSmallInt(self): 66 | self.cur.execute("insert into test(i) values (?)", (42,)) 67 | 
self.cur.execute("select i from test") 68 | row = self.cur.fetchone() 69 | self.assertEqual(row[0], 42) 70 | 71 | def test_CheckLargeInt(self): 72 | num = 2**40 73 | self.cur.execute("insert into test(i) values (?)", (num,)) 74 | self.cur.execute("select i from test") 75 | row = self.cur.fetchone() 76 | self.assertEqual(row[0], num) 77 | 78 | def test_CheckIntNull(self): 79 | self.cur.execute("insert into test(i) values (?)", (None,)) 80 | self.cur.execute("select i from test") 81 | row = self.cur.fetchone() 82 | self.assertEqual(row[0], None) 83 | 84 | def test_CheckFloat(self): 85 | val = 3.14 86 | self.cur.execute("insert into test(f) values (?)", (val,)) 87 | self.cur.execute("select f from test") 88 | row = self.cur.fetchone() 89 | self.assertEqual(row[0], val) 90 | 91 | def test_CheckFloatNull(self): 92 | self.cur.execute("insert into test(f) values (?)", (None,)) 93 | self.cur.execute("select f from test") 94 | row = self.cur.fetchone() 95 | self.assertEqual(row[0], None) 96 | 97 | def test_CheckBlob(self): 98 | sample = b"\x99Guglhupf" 99 | val = sample 100 | self.cur.execute("insert into test(b) values (?)", (val,)) 101 | self.cur.execute("select b from test") 102 | row = self.cur.fetchone() 103 | self.assertEqual(row[0], sample) 104 | 105 | def test_CheckBlobNull(self): 106 | self.cur.execute("insert into test(b) values (?)", (None,)) 107 | self.cur.execute("select b from test") 108 | row = self.cur.fetchone() 109 | self.assertEqual(row[0], None) 110 | 111 | def test_CheckUnicodeExecute(self): 112 | self.cur.execute(u"select 'Österreich'") 113 | row = self.cur.fetchone() 114 | self.assertEqual(row[0], u"Österreich") 115 | 116 | def test_CheckNullExecute(self): 117 | self.cur.execute("select null") 118 | row = self.cur.fetchone() 119 | self.assertEqual(row[0], None) 120 | 121 | def test_PragmaTableInfo(self): 122 | self.cur.execute("pragma table_info('test')") 123 | rows = self.cur.fetchall() 124 | self.assertEqual(rows, 125 | [ 126 | (0, 'i', 'INTEGER', 
0, None, 0), 127 | (1, 's', 'VARCHAR', 0, None, 0), 128 | (2, 'f', 'NUMBER', 0, None, 0), 129 | (3, 'b', 'BLOB', 0, None, 0), 130 | ] 131 | ) 132 | 133 | class DeclTypesTests(unittest.TestCase): 134 | class Foo: 135 | def __init__(self, _val): 136 | if isinstance(_val, bytes): 137 | # sqlite3 always calls __init__ with a bytes created from a 138 | # UTF-8 string when __conform__ was used to store the object. 139 | _val = _val.decode('utf-8') 140 | self.val = _val 141 | 142 | def __eq__(self, other): 143 | if not isinstance(other, DeclTypesTests.Foo): 144 | return NotImplemented 145 | return self.val == other.val 146 | 147 | def __conform__(self, protocol): 148 | if protocol is sqlite.PrepareProtocol: 149 | return self.val 150 | else: 151 | return None 152 | 153 | def __str__(self): 154 | return "<%s>" % self.val 155 | 156 | @classmethod 157 | def setUpClass(cls): 158 | cls.con = sqlite.connect(":memory:", 159 | detect_types=sqlite.PARSE_DECLTYPES) 160 | cls.cur = cls.con.cursor() 161 | if cls.cur.execute("pragma table_info(test)").fetchall(): 162 | cls.cur.execute("drop table test") 163 | 164 | # override float, make them always return the same number 165 | sqlite.converters["FLOAT"] = lambda x: 47.2 166 | 167 | # and implement two custom ones 168 | sqlite.converters["BOOL"] = lambda x: bool(int(x)) 169 | sqlite.converters["FOO"] = DeclTypesTests.Foo 170 | sqlite.converters["WRONG"] = lambda x: "WRONG" 171 | sqlite.converters["NUMBER"] = float 172 | 173 | def setUp(self): 174 | self.cur.execute("create table test(i int, s str, f float, b bool, u unicode, foo foo, bin blob, n1 number, n2 number(5))") 175 | 176 | def tearDown(self): 177 | self.cur.execute("drop table test") 178 | 179 | @classmethod 180 | def tearDownClass(cls): 181 | del sqlite.converters["FLOAT"] 182 | del sqlite.converters["BOOL"] 183 | del sqlite.converters["FOO"] 184 | del sqlite.converters["WRONG"] 185 | del sqlite.converters["NUMBER"] 186 | cls.cur.close() 187 | cls.con.close() 188 | 189 | 
@unittest.skipIf(sys.version_info[0] >= 3, "AssertionError: b'foo' != 'foo'") 190 | def test_CheckString(self): 191 | # default 192 | self.cur.execute("insert into test(s) values (?)", ("foo",)) 193 | self.cur.execute('select s as "s [WRONG]" from test') 194 | row = self.cur.fetchone() 195 | self.assertEqual(row[0], "foo") 196 | 197 | def test_CheckSmallInt(self): 198 | # default 199 | self.cur.execute("insert into test(i) values (?)", (42,)) 200 | self.cur.execute("select i from test") 201 | row = self.cur.fetchone() 202 | self.assertEqual(row[0], 42) 203 | 204 | def test_CheckLargeInt(self): 205 | # default 206 | num = 2**40 207 | self.cur.execute("insert into test(i) values (?)", (num,)) 208 | self.cur.execute("select i from test") 209 | row = self.cur.fetchone() 210 | self.assertEqual(row[0], num) 211 | 212 | def test_CheckFloat(self): 213 | # custom 214 | val = 3.14 215 | self.cur.execute("insert into test(f) values (?)", (val,)) 216 | self.cur.execute("select f from test") 217 | row = self.cur.fetchone() 218 | self.assertEqual(row[0], 47.2) 219 | 220 | def test_CheckBool(self): 221 | # custom 222 | self.cur.execute("insert into test(b) values (?)", (False,)) 223 | self.cur.execute("select b from test") 224 | row = self.cur.fetchone() 225 | self.assertEqual(row[0], False) 226 | 227 | self.cur.execute("delete from test") 228 | self.cur.execute("insert into test(b) values (?)", (True,)) 229 | self.cur.execute("select b from test") 230 | row = self.cur.fetchone() 231 | self.assertEqual(row[0], True) 232 | 233 | @unittest.expectedFailure # binascii.Error: decoding with 'base64' codec failed (Error: Incorrect padding) 234 | def test_CheckUnicode(self): 235 | # default 236 | val = u"\xd6sterreich" 237 | self.cur.execute("insert into test(u) values (?)", (val,)) 238 | self.cur.execute("select u from test") 239 | row = self.cur.fetchone() 240 | self.assertEqual(row[0], val) 241 | 242 | @unittest.expectedFailure # binascii.Error: decoding with 'base64' codec failed 
(Error: Incorrect padding) 243 | def test_CheckFoo(self): 244 | val = DeclTypesTests.Foo("bla") 245 | self.cur.execute("insert into test(foo) values (?)", (val,)) 246 | self.cur.execute("select foo from test") 247 | row = self.cur.fetchone() 248 | self.assertEqual(row[0], val) 249 | 250 | def test_CheckUnsupportedSeq(self): 251 | class Bar: pass 252 | val = Bar() 253 | with self.assertRaises(sqlite.InterfaceError): 254 | self.cur.execute("insert into test(f) values (?)", (val,)) 255 | 256 | @unittest.skip('named paramstyle is not implemented') 257 | def test_CheckUnsupportedDict(self): 258 | class Bar: pass 259 | val = Bar() 260 | with self.assertRaises(sqlite.InterfaceError): 261 | self.cur.execute("insert into test(f) values (:val)", {"val": val}) 262 | 263 | def test_CheckBlob(self): 264 | # default 265 | sample = b"Guglhupf" 266 | val = sample 267 | self.cur.execute("insert into test(bin) values (?)", (val,)) 268 | self.cur.execute("select bin from test") 269 | row = self.cur.fetchone() 270 | self.assertEqual(row[0], sample) 271 | 272 | def test_CheckBlobNull(self): 273 | self.cur.execute("insert into test(bin) values (?)", (None,)) 274 | self.cur.execute("select bin from test") 275 | row = self.cur.fetchone() 276 | self.assertEqual(row[0], None) 277 | 278 | def test_CheckNumber1(self): 279 | self.cur.execute("insert into test(n1) values (5)") 280 | value = self.cur.execute("select n1 from test").fetchone()[0] 281 | # if the converter is not used, it's an int instead of a float 282 | self.assertEqual(type(value), float) 283 | 284 | def test_CheckNumber2(self): 285 | """Checks whether converter names are cut off at '(' characters""" 286 | self.cur.execute("insert into test(n2) values (5)") 287 | value = self.cur.execute("select n2 from test").fetchone()[0] 288 | # if the converter is not used, it's an int instead of a float 289 | self.assertEqual(type(value), float) 290 | 291 | class ColNamesTests(unittest.TestCase): 292 | @classmethod 293 | def setUpClass(cls): 
294 | cls.con = sqlite.connect(":memory:", 295 | detect_types=sqlite.PARSE_COLNAMES) 296 | cls.cur = cls.con.cursor() 297 | if cls.cur.execute("pragma table_info(test)").fetchall(): 298 | cls.cur.execute("drop table test") 299 | 300 | sqlite.converters["FOO"] = lambda x: "[%s]" % x.decode("ascii") 301 | sqlite.converters["BAR"] = lambda x: "<%s>" % x.decode("ascii") 302 | sqlite.converters["EXC"] = lambda x: 5/0 303 | sqlite.converters["B1B1"] = lambda x: "MARKER" 304 | 305 | def setUp(self): 306 | self.cur.execute("create table test(x foo)") 307 | 308 | def tearDown(self): 309 | self.cur.execute("drop table test") 310 | 311 | @classmethod 312 | def tearDownClass(cls): 313 | del sqlite.converters["FOO"] 314 | del sqlite.converters["BAR"] 315 | del sqlite.converters["EXC"] 316 | del sqlite.converters["B1B1"] 317 | cls.cur.close() 318 | cls.con.close() 319 | 320 | @unittest.expectedFailure # binascii.Error: decoding with 'base64' codec failed (Error: Incorrect padding) 321 | def test_CheckDeclTypeNotUsed(self): 322 | """ 323 | Assures that the declared type is not used when PARSE_DECLTYPES 324 | is not set. 325 | """ 326 | self.cur.execute("insert into test(x) values (?)", ("xxx",)) 327 | self.cur.execute("select x from test") 328 | val = self.cur.fetchone()[0] 329 | self.assertEqual(val, b"xxx") 330 | 331 | def test_CheckNone(self): 332 | self.cur.execute("insert into test(x) values (?)", (None,)) 333 | self.cur.execute("select x from test") 334 | val = self.cur.fetchone()[0] 335 | self.assertEqual(val, None) 336 | 337 | @unittest.expectedFailure # binascii.Error: decoding with 'base64' codec failed (Error: Incorrect padding) 338 | def test_CheckColName(self): 339 | self.cur.execute("insert into test(x) values (?)", ("xxx",)) 340 | self.cur.execute('select x as "x [bar]" from test') 341 | val = self.cur.fetchone()[0] 342 | self.assertEqual(val, "") 343 | 344 | # Check if the stripping of colnames works. 
Everything after the first 345 | # whitespace should be stripped. 346 | self.assertEqual(self.cur.description[0][0], "x") 347 | 348 | @unittest.expectedFailure # https://github.com/rqlite/pyrqlite/issues/16 349 | def test_CheckCaseInConverterName(self): 350 | self.cur.execute("select 'other' as \"x [b1b1]\"") 351 | val = self.cur.fetchone()[0] 352 | self.assertEqual(val, "MARKER") 353 | 354 | def test_CheckCursorDescriptionNoRow(self): 355 | """ 356 | cursor.description should at least provide the column name(s), even if 357 | no row returned. 358 | """ 359 | self.cur.execute("select * from test where 0 = 1") 360 | self.assertEqual(self.cur.description[0][0], "x") 361 | 362 | def test_CheckCursorDescriptionInsert(self): 363 | self.cur.execute("insert into test values (1)") 364 | self.assertIsNone(self.cur.description) 365 | 366 | 367 | @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "CTEs not supported") 368 | class CommonTableExpressionTests(unittest.TestCase): 369 | 370 | @classmethod 371 | def setUpClass(cls): 372 | cls.con = sqlite.connect(":memory:") 373 | cls.cur = cls.con.cursor() 374 | if cls.cur.execute("pragma table_info(test)").fetchall(): 375 | cls.cur.execute("drop table test") 376 | 377 | def setUp(self): 378 | self.cur.execute("create table test(x foo)") 379 | 380 | def tearDown(self): 381 | self.cur.execute("drop table test") 382 | 383 | @classmethod 384 | def tearDownClass(cls): 385 | cls.cur.close() 386 | cls.con.close() 387 | 388 | ## Disabled until this is resolved: https://github.com/rqlite/rqlite/issues/255 389 | # def test_CheckCursorDescriptionCTESimple(self): 390 | # self.cur.execute("with one as (select 1) select * from one") 391 | # self.assertIsNotNone(self.cur.description) 392 | # self.assertEqual(self.cur.description[0][0], "1") 393 | # 394 | # def test_CheckCursorDescriptionCTESMultipleColumns(self): 395 | # self.cur.execute("insert into test values(1)") 396 | # self.cur.execute("insert into test values(2)") 397 | # 
self.cur.execute("with testCTE as (select * from test) select * from testCTE") 398 | # self.assertIsNotNone(self.cur.description) 399 | # self.assertEqual(self.cur.description[0][0], "x") 400 | # 401 | # def test_CheckCursorDescriptionCTE(self): 402 | # self.cur.execute("insert into test values (1)") 403 | # self.cur.execute("with bar as (select * from test) select * from test where x = 1") 404 | # self.assertIsNotNone(self.cur.description) 405 | # self.assertEqual(self.cur.description[0][0], "x") 406 | # self.cur.execute("with bar as (select * from test) select * from test where x = 2") 407 | # self.assertIsNotNone(self.cur.description) 408 | # self.assertEqual(self.cur.description[0][0], "x") 409 | 410 | 411 | class ObjectAdaptationTests(unittest.TestCase): 412 | def cast(obj): 413 | return float(obj) 414 | cast = staticmethod(cast) 415 | 416 | @classmethod 417 | def setUpClass(cls): 418 | cls.con = sqlite.connect(":memory:") 419 | cls.cur = cls.con.cursor() 420 | if cls.cur.execute("pragma table_info(test)").fetchall(): 421 | cls.cur.execute("drop table test") 422 | try: 423 | del sqlite.adapters[int] 424 | except: 425 | pass 426 | sqlite.register_adapter(int, ObjectAdaptationTests.cast) 427 | 428 | @classmethod 429 | def tearDownClass(cls): 430 | del sqlite.adapters[(int, sqlite.PrepareProtocol)] 431 | cls.cur.close() 432 | cls.con.close() 433 | 434 | def test_CheckCasterIsUsed(self): 435 | self.cur.execute("select ?", (4,)) 436 | val = self.cur.fetchone()[0] 437 | self.assertEqual(type(val), float) 438 | 439 | @unittest.skipUnless(zlib, "requires zlib") 440 | class BinaryConverterTests(unittest.TestCase): 441 | def convert(s): 442 | return zlib.decompress(s) 443 | convert = staticmethod(convert) 444 | 445 | @classmethod 446 | def setUpClass(cls): 447 | cls.con = sqlite.connect(":memory:", 448 | detect_types=sqlite.PARSE_COLNAMES) 449 | if cls.con.execute("pragma table_info(test)").fetchall(): 450 | cls.con.execute("drop table test") 451 | 
sqlite.register_converter("bin", BinaryConverterTests.convert) 452 | 453 | @classmethod 454 | def tearDownClass(cls): 455 | cls.con.close() 456 | 457 | @unittest.expectedFailure # https://github.com/rqlite/pyrqlite/issues/17 458 | def test_CheckBinaryInputForConverter(self): 459 | testdata = b"abcdefg" * 10 460 | compressed = zlib.compress(testdata) 461 | result = self.con.execute('select ? as "x [bin]"', (compressed,)).fetchone()[0] 462 | self.assertEqual(testdata, result) 463 | 464 | class DateTimeTests(unittest.TestCase): 465 | @classmethod 466 | def setUpClass(cls): 467 | cls.con = sqlite.connect(":memory:", 468 | detect_types=sqlite.PARSE_DECLTYPES) 469 | cls.cur = cls.con.cursor() 470 | if cls.cur.execute("pragma table_info(test)").fetchall(): 471 | cls.cur.execute("drop table test") 472 | 473 | def setUp(self): 474 | self.cur.execute("create table test(d date, ts timestamp, dt datetime)") 475 | 476 | def tearDown(self): 477 | self.cur.execute("drop table test") 478 | 479 | @classmethod 480 | def tearDownClass(cls): 481 | cls.cur.close() 482 | cls.con.close() 483 | 484 | def test_CheckSqliteDate(self): 485 | d = sqlite.Date(2004, 2, 14) 486 | self.cur.execute("insert into test(d) values (?)", (d,)) 487 | self.cur.execute("select d from test") 488 | d2 = self.cur.fetchone()[0] 489 | self.assertEqual(d, d2) 490 | 491 | def test_CheckSqliteTimestamp(self): 492 | ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0) 493 | self.cur.execute("insert into test(ts) values (?)", (ts,)) 494 | self.cur.execute("select ts from test") 495 | ts2 = self.cur.fetchone()[0] 496 | self.assertEqual(ts, ts2) 497 | 498 | @unittest.skipIf(sqlite.sqlite_version_info < (3, 1), 499 | 'the date functions are available on 3.1 or later') 500 | def test_CheckSqlTimestamp(self): 501 | now = datetime.datetime.utcnow() 502 | self.cur.execute("insert into test(ts) values (current_timestamp)") 503 | self.cur.execute("select ts from test") 504 | ts = self.cur.fetchone()[0] 505 | 
self.assertEqual(type(ts), datetime.datetime) 506 | self.assertEqual(ts.year, now.year) 507 | 508 | def test_CheckDateTimeSubSeconds(self): 509 | ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0, 500000) 510 | self.cur.execute("insert into test(ts) values (?)", (ts,)) 511 | self.cur.execute("select ts from test") 512 | ts2 = self.cur.fetchone()[0] 513 | self.assertEqual(ts, ts2) 514 | 515 | def test_CheckDateTimeSubSecondsFloatingPoint(self): 516 | ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0, 510241) 517 | self.cur.execute("insert into test(ts) values (?)", (ts,)) 518 | self.cur.execute("select ts from test") 519 | ts2 = self.cur.fetchone()[0] 520 | self.assertEqual(ts, ts2) 521 | 522 | def test_CheckSqlDatetime(self): 523 | now = datetime.datetime.utcnow() 524 | self.cur.execute("insert into test(dt) values (?)", (now,)) 525 | self.cur.execute("select dt from test") 526 | dt = self.cur.fetchone()[0] 527 | self.assertEqual(dt, now.isoformat(' ').rstrip('0').rstrip('.')) 528 | 529 | def suite(): 530 | loader = unittest.TestLoader() 531 | sqlite_type_suite = loader.loadTestsFromTestCase(SqliteTypeTests) 532 | decltypes_type_suite = loader.loadTestsFromTestCase(DeclTypesTests) 533 | colnames_type_suite = loader.loadTestsFromTestCase(ColNamesTests) 534 | adaptation_suite = loader.loadTestsFromTestCase(ObjectAdaptationTests) 535 | bin_suite = loader.loadTestsFromTestCase(BinaryConverterTests) 536 | date_suite = loader.loadTestsFromTestCase(DateTimeTests) 537 | cte_suite = loader.loadTestsFromTestCase(CommonTableExpressionTests) 538 | return unittest.TestSuite((sqlite_type_suite, decltypes_type_suite, colnames_type_suite, adaptation_suite, bin_suite, date_suite, cte_suite)) 539 | 540 | def main(): 541 | runner = unittest.TextTestRunner(verbosity=2) 542 | runner.run(suite()) 543 | 544 | if __name__ == "__main__": 545 | main() 546 | --------------------------------------------------------------------------------