├── .gitattributes ├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.adoc ├── multi ├── pg8000 ├── __init__.py ├── _version.py └── core.py ├── setup.cfg ├── setup.py ├── tests ├── connection_settings.py ├── dbapi20.py ├── performance.py ├── stress.py ├── test_connection.py ├── test_copy.py ├── test_dbapi.py ├── test_error_recovery.py ├── test_paramstyle.py ├── test_pg8000_dbapi20.py ├── test_query.py ├── test_typeconversion.py └── test_typeobjects.py ├── tox.ini └── versioneer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | pg8000/_version.py export-subst 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[co] 2 | *.swp 3 | *.orig 4 | *.class 5 | build 6 | pg8000.egg-info 7 | tmp 8 | dist 9 | .tox 10 | MANIFEST 11 | venv 12 | .cache 13 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | language: python 3 | python: 4 | - "2.7" 5 | - "3.5" 6 | - "3.6" 7 | - "pypy3.5" 8 | 9 | env: 10 | - DB="9.5" 11 | - DB="9.6" 12 | 13 | services: 14 | - postgresql 15 | addons: 16 | postgresql: "9.5" 17 | postgresql: "9.6" 18 | 19 | before_install: 20 | - sudo service postgresql stop 21 | - cd /etc/postgresql/$DB/main 22 | - sudo chmod ugo+rw pg_hba.conf 23 | - sudo cp pg_hba.conf old_pg_hba.conf 24 | - sudo echo "host pg8000_md5 all 127.0.0.1/32 md5" > pg_hba.conf 25 | - sudo echo "host pg8000_gss all 127.0.0.1/32 gss" >> pg_hba.conf 26 | - sudo echo "host pg8000_password all 127.0.0.1/32 password" >> pg_hba.conf 27 | - cat old_pg_hba.conf >> pg_hba.conf 28 | - sudo service postgresql start $DB 29 | - psql -U postgres -tc 'create extension hstore;' 30 | - psql -U postgres -tc 'show server_version;' 31 | - psql -U postgres -tc 
"alter user postgres with password 'pw';" 32 | - psql -U postgres -tc "alter system set client_min_messages = notice;" 33 | - sudo service postgresql reload $DB 34 | - psql -U postgres -tc 'show client_min_messages;' 35 | - cd $TRAVIS_BUILD_DIR 36 | 37 | install: 38 | - pip install nose 39 | - pip install pytz 40 | 41 | script: 42 | - nosetests 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2007-2009, Mathieu Fenniak 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are 6 | met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | * The name of the author may not be used to endorse or promote products 14 | derived from this software without specific prior written permission. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 19 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 20 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 21 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 22 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 23 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 24 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 25 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 26 | POSSIBILITY OF SUCH DAMAGE. 27 | 28 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.adoc 2 | include versioneer.py 3 | include pg8000/_version.py 4 | include LICENSE 5 | include doc/* 6 | -------------------------------------------------------------------------------- /README.adoc: -------------------------------------------------------------------------------- 1 | = pg8000 2 | 3 | Hello, the pg8000 repository has moved to https://github.com/tlocke/pg8000. 4 | -------------------------------------------------------------------------------- /multi: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # set postgres share memory to minimum to trigger unpinned buffer errors. 4 | 5 | for i in {1..100} 6 | do 7 | python -m pg8000.tests.stress & 8 | done 9 | wait 10 | echo "All processes done!"
11 | -------------------------------------------------------------------------------- /pg8000/__init__.py: -------------------------------------------------------------------------------- 1 | from pg8000.core import ( 2 | Warning, Bytea, DataError, DatabaseError, InterfaceError, ProgrammingError, 3 | Error, OperationalError, IntegrityError, InternalError, NotSupportedError, 4 | ArrayContentNotHomogenousError, ArrayDimensionsNotConsistentError, 5 | ArrayContentNotSupportedError, utc, Connection, Cursor, Binary, Date, 6 | DateFromTicks, Time, TimeFromTicks, Timestamp, TimestampFromTicks, BINARY, 7 | Interval, PGEnum, PGJson, PGJsonb, PGTsvector, PGText, PGVarchar) 8 | from ._version import get_versions 9 | __version__ = get_versions()['version'] 10 | del get_versions 11 | 12 | # Copyright (c) 2007-2009, Mathieu Fenniak 13 | # Copyright (c) The Contributors 14 | # All rights reserved. 15 | # 16 | # Redistribution and use in source and binary forms, with or without 17 | # modification, are permitted provided that the following conditions are 18 | # met: 19 | # 20 | # * Redistributions of source code must retain the above copyright notice, 21 | # this list of conditions and the following disclaimer. 22 | # * Redistributions in binary form must reproduce the above copyright notice, 23 | # this list of conditions and the following disclaimer in the documentation 24 | # and/or other materials provided with the distribution. 25 | # * The name of the author may not be used to endorse or promote products 26 | # derived from this software without specific prior written permission. 27 | # 28 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 29 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 30 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 31 | # ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 32 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 33 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 36 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 37 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 38 | # POSSIBILITY OF SUCH DAMAGE. 39 | 40 | __author__ = "Mathieu Fenniak" 41 | 42 | 43 | def connect( 44 | user, host='localhost', unix_sock=None, port=5432, database=None, 45 | password=None, ssl=False, timeout=None, application_name=None, 46 | max_prepared_statements=1000): 47 | """Creates a connection to a PostgreSQL database. 48 | 49 | This function is part of the `DBAPI 2.0 specification 50 | `_; however, the arguments of the 51 | function are not defined by the specification. 52 | 53 | :param user: 54 | The username to connect to the PostgreSQL server with. 55 | 56 | If your server character encoding is not ``ascii`` or ``utf8``, then 57 | you need to provide ``user`` as bytes, eg. 58 | ``"my_name".encode('EUC-JP')``. 59 | 60 | :keyword host: 61 | The hostname of the PostgreSQL server to connect with. Providing this 62 | parameter is necessary for TCP/IP connections. One of either ``host`` 63 | or ``unix_sock`` must be provided. The default is ``localhost``. 64 | 65 | :keyword unix_sock: 66 | The path to the UNIX socket to access the database through, for 67 | example, ``'/tmp/.s.PGSQL.5432'``. One of either ``host`` or 68 | ``unix_sock`` must be provided. 69 | 70 | :keyword port: 71 | The TCP/IP port of the PostgreSQL server instance. This parameter 72 | defaults to ``5432``, the registered common port of PostgreSQL TCP/IP 73 | servers. 74 | 75 | :keyword database: 76 | The name of the database instance to connect with. 
This parameter is 77 | optional; if omitted, the PostgreSQL server will assume the database 78 | name is the same as the username. 79 | 80 | If your server character encoding is not ``ascii`` or ``utf8``, then 81 | you need to provide ``database`` as bytes, eg. 82 | ``"my_db".encode('EUC-JP')``. 83 | 84 | :keyword password: 85 | The user password to connect to the server with. This parameter is 86 | optional; if omitted and the database server requests password-based 87 | authentication, the connection will fail to open. If this parameter 88 | is provided but not requested by the server, no error will occur. 89 | 90 | If your server character encoding is not ``ascii`` or ``utf8``, then 91 | you need to provide ``password`` as bytes, eg. 92 | ``"my_password".encode('EUC-JP')``. 93 | 94 | :keyword application_name: 95 | The name will be displayed in the pg_stat_activity view. 96 | This parameter is optional. 97 | 98 | :keyword ssl: 99 | Use SSL encryption for TCP/IP sockets if ``True``. Defaults to 100 | ``False``. 101 | 102 | :keyword timeout: 103 | Only used with Python 3, this is the time in seconds before the 104 | connection to the database will time out. The default is ``None`` which 105 | means no timeout. 106 | 107 | :rtype: 108 | A :class:`Connection` object. 109 | """ 110 | return Connection( 111 | user, host, unix_sock, port, database, password, ssl, timeout, 112 | application_name, max_prepared_statements) 113 | 114 | 115 | apilevel = "2.0" 116 | """The DBAPI level supported, currently "2.0". 117 | 118 | This property is part of the `DBAPI 2.0 specification 119 | `_. 120 | """ 121 | 122 | threadsafety = 1 123 | """Integer constant stating the level of thread safety the DBAPI interface 124 | supports. This DBAPI module supports sharing of the module only. Connections 125 | and cursors may not be shared between threads. This gives pg8000 a threadsafety 126 | value of 1. 127 | 128 | This property is part of the `DBAPI 2.0 specification 129 | `_.
130 | """ 131 | 132 | paramstyle = 'format' 133 | 134 | max_prepared_statements = 1000 135 | 136 | # I have no idea what this would be used for by a client app. Should it be 137 | # TEXT, VARCHAR, CHAR? It will only compare against row_description's 138 | # type_code if it is this one type. It is the varchar type oid for now, this 139 | # appears to match expectations in the DB API 2.0 compliance test suite. 140 | 141 | STRING = 1043 142 | """String type oid.""" 143 | 144 | 145 | NUMBER = 1700 146 | """Numeric type oid""" 147 | 148 | DATETIME = 1114 149 | """Timestamp type oid""" 150 | 151 | ROWID = 26 152 | """ROWID type oid""" 153 | 154 | __all__ = [ 155 | Warning, Bytea, DataError, DatabaseError, connect, InterfaceError, 156 | ProgrammingError, Error, OperationalError, IntegrityError, InternalError, 157 | NotSupportedError, ArrayContentNotHomogenousError, 158 | ArrayDimensionsNotConsistentError, ArrayContentNotSupportedError, utc, 159 | Connection, Cursor, Binary, Date, DateFromTicks, Time, TimeFromTicks, 160 | Timestamp, TimestampFromTicks, BINARY, Interval, PGEnum, PGJson, PGJsonb, 161 | PGTsvector, PGText, PGVarchar] 162 | 163 | """Version string for pg8000. 164 | 165 | .. versionadded:: 1.9.11 166 | """ 167 | -------------------------------------------------------------------------------- /pg8000/_version.py: -------------------------------------------------------------------------------- 1 | 2 | # This file helps to compute a version number in source trees obtained from 3 | # git-archive tarball (such as those provided by githubs download-from-tag 4 | # feature). Distribution tarballs (built by setup.py sdist) and build 5 | # directories (produced by setup.py build) will contain a much shorter file 6 | # that just contains the computed version number. 7 | 8 | # This file is released into the public domain. 
Generated by 9 | # versioneer-0.15 (https://github.com/warner/python-versioneer) 10 | 11 | import errno 12 | import os 13 | import re 14 | import subprocess 15 | import sys 16 | 17 | 18 | def get_keywords(): 19 | # these strings will be replaced by git during git-archive. 20 | # setup.py/versioneer.py will grep for the variable names, so they must 21 | # each be defined on a line of their own. _version.py will just call 22 | # get_keywords(). 23 | git_refnames = " (HEAD -> master)" 24 | git_full = "412eace074514ada824e7a102765e37e2cda8eaa" 25 | keywords = {"refnames": git_refnames, "full": git_full} 26 | return keywords 27 | 28 | 29 | class VersioneerConfig: 30 | pass 31 | 32 | 33 | def get_config(): 34 | # these strings are filled in when 'setup.py versioneer' creates 35 | # _version.py 36 | cfg = VersioneerConfig() 37 | cfg.VCS = "git" 38 | cfg.style = "pep440" 39 | cfg.tag_prefix = "" 40 | cfg.parentdir_prefix = "pg8000-" 41 | cfg.versionfile_source = "pg8000/_version.py" 42 | cfg.verbose = False 43 | return cfg 44 | 45 | 46 | class NotThisMethod(Exception): 47 | pass 48 | 49 | 50 | LONG_VERSION_PY = {} 51 | HANDLERS = {} 52 | 53 | 54 | def register_vcs_handler(vcs, method): # decorator 55 | def decorate(f): 56 | if vcs not in HANDLERS: 57 | HANDLERS[vcs] = {} 58 | HANDLERS[vcs][method] = f 59 | return f 60 | return decorate 61 | 62 | 63 | def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False): 64 | assert isinstance(commands, list) 65 | p = None 66 | for c in commands: 67 | try: 68 | dispcmd = str([c] + args) 69 | # remember shell=False, so use git.cmd on windows, not just git 70 | p = subprocess.Popen([c] + args, cwd=cwd, stdout=subprocess.PIPE, 71 | stderr=(subprocess.PIPE if hide_stderr 72 | else None)) 73 | break 74 | except EnvironmentError: 75 | e = sys.exc_info()[1] 76 | if e.errno == errno.ENOENT: 77 | continue 78 | if verbose: 79 | print("unable to run %s" % dispcmd) 80 | print(e) 81 | return None 82 | else: 83 | if verbose: 84 | 
print("unable to find command, tried %s" % (commands,)) 85 | return None 86 | stdout = p.communicate()[0].strip() 87 | if sys.version_info[0] >= 3: 88 | stdout = stdout.decode() 89 | if p.returncode != 0: 90 | if verbose: 91 | print("unable to run %s (error)" % dispcmd) 92 | return None 93 | return stdout 94 | 95 | 96 | def versions_from_parentdir(parentdir_prefix, root, verbose): 97 | # Source tarballs conventionally unpack into a directory that includes 98 | # both the project name and a version string. 99 | dirname = os.path.basename(root) 100 | if not dirname.startswith(parentdir_prefix): 101 | if verbose: 102 | print("guessing rootdir is '%s', but '%s' doesn't start with " 103 | "prefix '%s'" % (root, dirname, parentdir_prefix)) 104 | raise NotThisMethod("rootdir doesn't start with parentdir_prefix") 105 | return {"version": dirname[len(parentdir_prefix):], 106 | "full-revisionid": None, 107 | "dirty": False, "error": None} 108 | 109 | 110 | @register_vcs_handler("git", "get_keywords") 111 | def git_get_keywords(versionfile_abs): 112 | # the code embedded in _version.py can just fetch the value of these 113 | # keywords. When used from setup.py, we don't want to import _version.py, 114 | # so we do it with a regexp instead. This function is not used from 115 | # _version.py. 
116 | keywords = {} 117 | try: 118 | f = open(versionfile_abs, "r") 119 | for line in f.readlines(): 120 | if line.strip().startswith("git_refnames ="): 121 | mo = re.search(r'=\s*"(.*)"', line) 122 | if mo: 123 | keywords["refnames"] = mo.group(1) 124 | if line.strip().startswith("git_full ="): 125 | mo = re.search(r'=\s*"(.*)"', line) 126 | if mo: 127 | keywords["full"] = mo.group(1) 128 | f.close() 129 | except EnvironmentError: 130 | pass 131 | return keywords 132 | 133 | 134 | @register_vcs_handler("git", "keywords") 135 | def git_versions_from_keywords(keywords, tag_prefix, verbose): 136 | if not keywords: 137 | raise NotThisMethod("no keywords at all, weird") 138 | refnames = keywords["refnames"].strip() 139 | if refnames.startswith("$Format"): 140 | if verbose: 141 | print("keywords are unexpanded, not using") 142 | raise NotThisMethod("unexpanded keywords, not a git-archive tarball") 143 | refs = set([r.strip() for r in refnames.strip("()").split(",")]) 144 | # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of 145 | # just "foo-1.0". If we see a "tag: " prefix, prefer those. 146 | TAG = "tag: " 147 | tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) 148 | if not tags: 149 | # Either we're using git < 1.8.3, or there really are no tags. We use 150 | # a heuristic: assume all version tags have a digit. The old git %d 151 | # expansion behaves like git log --decorate=short and strips out the 152 | # refs/heads/ and refs/tags/ prefixes that would let us distinguish 153 | # between branches and tags. By ignoring refnames without digits, we 154 | # filter out many common branch names like "release" and 155 | # "stabilization", as well as "HEAD" and "master". 156 | tags = set([r for r in refs if re.search(r'\d', r)]) 157 | if verbose: 158 | print("discarding '%s', no digits" % ",".join(refs-tags)) 159 | if verbose: 160 | print("likely tags: %s" % ",".join(sorted(tags))) 161 | for ref in sorted(tags): 162 | # sorting will prefer e.g. 
"2.0" over "2.0rc1" 163 | if ref.startswith(tag_prefix): 164 | r = ref[len(tag_prefix):] 165 | if verbose: 166 | print("picking %s" % r) 167 | return {"version": r, 168 | "full-revisionid": keywords["full"].strip(), 169 | "dirty": False, "error": None 170 | } 171 | # no suitable tags, so version is "0+unknown", but full hex is still there 172 | if verbose: 173 | print("no suitable tags, using unknown + full revision id") 174 | return {"version": "0+unknown", 175 | "full-revisionid": keywords["full"].strip(), 176 | "dirty": False, "error": "no suitable tags"} 177 | 178 | 179 | @register_vcs_handler("git", "pieces_from_vcs") 180 | def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): 181 | # this runs 'git' from the root of the source tree. This only gets called 182 | # if the git-archive 'subst' keywords were *not* expanded, and 183 | # _version.py hasn't already been rewritten with a short version string, 184 | # meaning we're inside a checked out source tree. 185 | 186 | if not os.path.exists(os.path.join(root, ".git")): 187 | if verbose: 188 | print("no .git in %s" % root) 189 | raise NotThisMethod("no .git directory") 190 | 191 | GITS = ["git"] 192 | if sys.platform == "win32": 193 | GITS = ["git.cmd", "git.exe"] 194 | # if there is a tag, this yields TAG-NUM-gHEX[-dirty] 195 | # if there are no tags, this yields HEX[-dirty] (no NUM) 196 | describe_out = run_command(GITS, ["describe", "--tags", "--dirty", 197 | "--always", "--long"], 198 | cwd=root) 199 | # --long was added in git-1.5.5 200 | if describe_out is None: 201 | raise NotThisMethod("'git describe' failed") 202 | describe_out = describe_out.strip() 203 | full_out = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) 204 | if full_out is None: 205 | raise NotThisMethod("'git rev-parse' failed") 206 | full_out = full_out.strip() 207 | 208 | pieces = {} 209 | pieces["long"] = full_out 210 | pieces["short"] = full_out[:7] # maybe improved later 211 | pieces["error"] = None 212 | 213 | 
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] 214 | # TAG might have hyphens. 215 | git_describe = describe_out 216 | 217 | # look for -dirty suffix 218 | dirty = git_describe.endswith("-dirty") 219 | pieces["dirty"] = dirty 220 | if dirty: 221 | git_describe = git_describe[:git_describe.rindex("-dirty")] 222 | 223 | # now we have TAG-NUM-gHEX or HEX 224 | 225 | if "-" in git_describe: 226 | # TAG-NUM-gHEX 227 | mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) 228 | if not mo: 229 | # unparseable. Maybe git-describe is misbehaving? 230 | pieces["error"] = ("unable to parse git-describe output: '%s'" 231 | % describe_out) 232 | return pieces 233 | 234 | # tag 235 | full_tag = mo.group(1) 236 | if not full_tag.startswith(tag_prefix): 237 | if verbose: 238 | fmt = "tag '%s' doesn't start with prefix '%s'" 239 | print(fmt % (full_tag, tag_prefix)) 240 | pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" 241 | % (full_tag, tag_prefix)) 242 | return pieces 243 | pieces["closest-tag"] = full_tag[len(tag_prefix):] 244 | 245 | # distance: number of commits since tag 246 | pieces["distance"] = int(mo.group(2)) 247 | 248 | # commit: short hex revision ID 249 | pieces["short"] = mo.group(3) 250 | 251 | else: 252 | # HEX: no tags 253 | pieces["closest-tag"] = None 254 | count_out = run_command(GITS, ["rev-list", "HEAD", "--count"], 255 | cwd=root) 256 | pieces["distance"] = int(count_out) # total number of commits 257 | 258 | return pieces 259 | 260 | 261 | def plus_or_dot(pieces): 262 | if "+" in pieces.get("closest-tag", ""): 263 | return "." 264 | return "+" 265 | 266 | 267 | def render_pep440(pieces): 268 | # now build up version string, with post-release "local version 269 | # identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you 270 | # get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty 271 | 272 | # exceptions: 273 | # 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] 274 | 275 | if pieces["closest-tag"]: 276 | rendered = pieces["closest-tag"] 277 | if pieces["distance"] or pieces["dirty"]: 278 | rendered += plus_or_dot(pieces) 279 | rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) 280 | if pieces["dirty"]: 281 | rendered += ".dirty" 282 | else: 283 | # exception #1 284 | rendered = "0+untagged.%d.g%s" % (pieces["distance"], 285 | pieces["short"]) 286 | if pieces["dirty"]: 287 | rendered += ".dirty" 288 | return rendered 289 | 290 | 291 | def render_pep440_pre(pieces): 292 | # TAG[.post.devDISTANCE] . No -dirty 293 | 294 | # exceptions: 295 | # 1: no tags. 0.post.devDISTANCE 296 | 297 | if pieces["closest-tag"]: 298 | rendered = pieces["closest-tag"] 299 | if pieces["distance"]: 300 | rendered += ".post.dev%d" % pieces["distance"] 301 | else: 302 | # exception #1 303 | rendered = "0.post.dev%d" % pieces["distance"] 304 | return rendered 305 | 306 | 307 | def render_pep440_post(pieces): 308 | # TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that 309 | # .dev0 sorts backwards (a dirty tree will appear "older" than the 310 | # corresponding clean one), but you shouldn't be releasing software with 311 | # -dirty anyways. 312 | 313 | # exceptions: 314 | # 1: no tags. 0.postDISTANCE[.dev0] 315 | 316 | if pieces["closest-tag"]: 317 | rendered = pieces["closest-tag"] 318 | if pieces["distance"] or pieces["dirty"]: 319 | rendered += ".post%d" % pieces["distance"] 320 | if pieces["dirty"]: 321 | rendered += ".dev0" 322 | rendered += plus_or_dot(pieces) 323 | rendered += "g%s" % pieces["short"] 324 | else: 325 | # exception #1 326 | rendered = "0.post%d" % pieces["distance"] 327 | if pieces["dirty"]: 328 | rendered += ".dev0" 329 | rendered += "+g%s" % pieces["short"] 330 | return rendered 331 | 332 | 333 | def render_pep440_old(pieces): 334 | # TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. 335 | 336 | # exceptions: 337 | # 1: no tags. 
0.postDISTANCE[.dev0] 338 | 339 | if pieces["closest-tag"]: 340 | rendered = pieces["closest-tag"] 341 | if pieces["distance"] or pieces["dirty"]: 342 | rendered += ".post%d" % pieces["distance"] 343 | if pieces["dirty"]: 344 | rendered += ".dev0" 345 | else: 346 | # exception #1 347 | rendered = "0.post%d" % pieces["distance"] 348 | if pieces["dirty"]: 349 | rendered += ".dev0" 350 | return rendered 351 | 352 | 353 | def render_git_describe(pieces): 354 | # TAG[-DISTANCE-gHEX][-dirty], like 'git describe --tags --dirty 355 | # --always' 356 | 357 | # exceptions: 358 | # 1: no tags. HEX[-dirty] (note: no 'g' prefix) 359 | 360 | if pieces["closest-tag"]: 361 | rendered = pieces["closest-tag"] 362 | if pieces["distance"]: 363 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 364 | else: 365 | # exception #1 366 | rendered = pieces["short"] 367 | if pieces["dirty"]: 368 | rendered += "-dirty" 369 | return rendered 370 | 371 | 372 | def render_git_describe_long(pieces): 373 | # TAG-DISTANCE-gHEX[-dirty], like 'git describe --tags --dirty 374 | # --always -long'. The distance/hash is unconditional. 375 | 376 | # exceptions: 377 | # 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) 378 | 379 | if pieces["closest-tag"]: 380 | rendered = pieces["closest-tag"] 381 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 382 | else: 383 | # exception #1 384 | rendered = pieces["short"] 385 | if pieces["dirty"]: 386 | rendered += "-dirty" 387 | return rendered 388 | 389 | 390 | def render(pieces, style): 391 | if pieces["error"]: 392 | return {"version": "unknown", 393 | "full-revisionid": pieces.get("long"), 394 | "dirty": None, 395 | "error": pieces["error"]} 396 | 397 | if not style or style == "default": 398 | style = "pep440" # the default 399 | 400 | if style == "pep440": 401 | rendered = render_pep440(pieces) 402 | elif style == "pep440-pre": 403 | rendered = render_pep440_pre(pieces) 404 | elif style == "pep440-post": 405 | rendered = render_pep440_post(pieces) 406 | elif style == "pep440-old": 407 | rendered = render_pep440_old(pieces) 408 | elif style == "git-describe": 409 | rendered = render_git_describe(pieces) 410 | elif style == "git-describe-long": 411 | rendered = render_git_describe_long(pieces) 412 | else: 413 | raise ValueError("unknown style '%s'" % style) 414 | 415 | return {"version": rendered, "full-revisionid": pieces["long"], 416 | "dirty": pieces["dirty"], "error": None} 417 | 418 | 419 | def get_versions(): 420 | # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have 421 | # __file__, we can work backwards from there to the root. Some 422 | # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which 423 | # case we can only use expanded keywords. 424 | 425 | cfg = get_config() 426 | verbose = cfg.verbose 427 | 428 | try: 429 | return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, 430 | verbose) 431 | except NotThisMethod: 432 | pass 433 | 434 | try: 435 | root = os.path.realpath(__file__) 436 | # versionfile_source is the relative path from the top of the source 437 | # tree (where the .git directory might live) to this file. 
Invert 438 | # this to find the root from __file__. 439 | for i in cfg.versionfile_source.split('/'): 440 | root = os.path.dirname(root) 441 | except NameError: 442 | return {"version": "0+unknown", "full-revisionid": None, 443 | "dirty": None, 444 | "error": "unable to find root of source tree"} 445 | 446 | try: 447 | pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) 448 | return render(pieces, cfg.style) 449 | except NotThisMethod: 450 | pass 451 | 452 | try: 453 | if cfg.parentdir_prefix: 454 | return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) 455 | except NotThisMethod: 456 | pass 457 | 458 | return {"version": "0+unknown", "full-revisionid": None, 459 | "dirty": None, 460 | "error": "unable to compute version"} 461 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [upload_docs] 2 | upload-dir = build/sphinx/html 3 | 4 | [bdist_wheel] 5 | universal=1 6 | 7 | [versioneer] 8 | VCS = git 9 | style = pep440 10 | versionfile_source = pg8000/_version.py 11 | versionfile_build = pg8000/_version.py 12 | tag_prefix = 13 | parentdir_prefix = pg8000- 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import versioneer 4 | from setuptools import setup 5 | 6 | long_description = """\ 7 | 8 | pg8000 9 | ------ 10 | 11 | pg8000 is a Pure-Python interface to the PostgreSQL database engine. It is \ 12 | one of many PostgreSQL interfaces for the Python programming language. pg8000 \ 13 | is somewhat distinctive in that it is written entirely in Python and does not \ 14 | rely on any external libraries (such as a compiled python module, or \ 15 | PostgreSQL's libpq library). pg8000 supports the standard Python DB-API \ 16 | version 2.0. 
17 | 18 | pg8000's name comes from the belief that it is probably about the 8000th \ 19 | PostgreSQL interface for Python.""" 20 | 21 | cmdclass = dict(versioneer.get_cmdclass()) 22 | version = versioneer.get_version() 23 | 24 | setup( 25 | name="pg8000", 26 | version=version, 27 | cmdclass=cmdclass, 28 | description="PostgreSQL interface library", 29 | long_description=long_description, 30 | author="Mathieu Fenniak", 31 | author_email="biziqe@mathieu.fenniak.net", 32 | url="https://github.com/mfenniak/pg8000", 33 | license="BSD", 34 | install_requires=[ 35 | "six>=1.10.0", 36 | ], 37 | classifiers=[ 38 | "Development Status :: 4 - Beta", 39 | "Intended Audience :: Developers", 40 | "License :: OSI Approved :: BSD License", 41 | "Programming Language :: Python", 42 | "Programming Language :: Python :: 2", 43 | "Programming Language :: Python :: 2.7", 44 | "Programming Language :: Python :: 3", 45 | "Programming Language :: Python :: 3.3", 46 | "Programming Language :: Python :: 3.4", 47 | "Programming Language :: Python :: 3.5", 48 | "Programming Language :: Python :: Implementation", 49 | "Programming Language :: Python :: Implementation :: CPython", 50 | "Programming Language :: Python :: Implementation :: Jython", 51 | "Programming Language :: Python :: Implementation :: PyPy", 52 | "Operating System :: OS Independent", 53 | "Topic :: Database :: Front-Ends", 54 | "Topic :: Software Development :: Libraries :: Python Modules", 55 | ], 56 | keywords="postgresql dbapi", 57 | packages=("pg8000",), 58 | command_options={ 59 | 'build_sphinx': { 60 | 'version': ('setup.py', version), 61 | 'release': ('setup.py', version)}}, 62 | ) 63 | -------------------------------------------------------------------------------- /tests/connection_settings.py: -------------------------------------------------------------------------------- 1 | db_connect = { 2 | 'user': 'postgres', 3 | 'password': 'pw', 4 | 'port': 5432} 5 | 
-------------------------------------------------------------------------------- /tests/dbapi20.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import unittest 3 | import time 4 | import warnings 5 | from six import b 6 | ''' Python DB API 2.0 driver compliance unit test suite. 7 | 8 | This software is Public Domain and may be used without restrictions. 9 | 10 | "Now we have booze and barflies entering the discussion, plus rumours of 11 | DBAs on drugs... and I won't tell you what flashes through my mind each 12 | time I read the subject line with 'Anal Compliance' in it. All around 13 | this is turning out to be a thoroughly unwholesome unit test." 14 | 15 | -- Ian Bicking 16 | ''' 17 | 18 | __rcs_id__ = '$Id: dbapi20.py,v 1.10 2003/10/09 03:14:14 zenzen Exp $' 19 | __version__ = '$Revision: 1.10 $'[11:-2] 20 | __author__ = 'Stuart Bishop ' 21 | 22 | 23 | # $Log: dbapi20.py,v $ 24 | # Revision 1.10 2003/10/09 03:14:14 zenzen 25 | # Add test for DB API 2.0 optional extension, where database exceptions 26 | # are exposed as attributes on the Connection object. 27 | # 28 | # Revision 1.9 2003/08/13 01:16:36 zenzen 29 | # Minor tweak from Stefan Fleiter 30 | # 31 | # Revision 1.8 2003/04/10 00:13:25 zenzen 32 | # Changes, as per suggestions by M.-A. 
Lemburg 33 | # - Add a table prefix, to ensure namespace collisions can always be avoided 34 | # 35 | # Revision 1.7 2003/02/26 23:33:37 zenzen 36 | # Break out DDL into helper functions, as per request by David Rushby 37 | # 38 | # Revision 1.6 2003/02/21 03:04:33 zenzen 39 | # Stuff from Henrik Ekelund: 40 | # added test_None 41 | # added test_nextset & hooks 42 | # 43 | # Revision 1.5 2003/02/17 22:08:43 zenzen 44 | # Implement suggestions and code from Henrik Eklund - test that 45 | # cursor.arraysize defaults to 1 & generic cursor.callproc test added 46 | # 47 | # Revision 1.4 2003/02/15 00:16:33 zenzen 48 | # Changes, as per suggestions and bug reports by M.-A. Lemburg, 49 | # Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar 50 | # - Class renamed 51 | # - Now a subclass of TestCase, to avoid requiring the driver stub 52 | # to use multiple inheritance 53 | # - Reversed the polarity of buggy test in test_description 54 | # - Test exception heirarchy correctly 55 | # - self.populate is now self._populate(), so if a driver stub 56 | # overrides self.ddl1 this change propogates 57 | # - VARCHAR columns now have a width, which will hopefully make the 58 | # DDL even more portible (this will be reversed if it causes more problems) 59 | # - cursor.rowcount being checked after various execute and fetchXXX methods 60 | # - Check for fetchall and fetchmany returning empty lists after results 61 | # are exhausted (already checking for empty lists if select retrieved 62 | # nothing 63 | # - Fix bugs in test_setoutputsize_basic and test_setinputsizes 64 | # 65 | 66 | 67 | class DatabaseAPI20Test(unittest.TestCase): 68 | ''' Test a database self.driver for DB API 2.0 compatibility. 69 | This implementation tests Gadfly, but the TestCase 70 | is structured so that other self.drivers can subclass this 71 | test case to ensure compiliance with the DB-API. 
It is 72 | expected that this TestCase may be expanded in the future 73 | if ambiguities or edge conditions are discovered. 74 | 75 | The 'Optional Extensions' are not yet being tested. 76 | 77 | self.drivers should subclass this test, overriding setUp, tearDown, 78 | self.driver, connect_args and connect_kw_args. Class specification 79 | should be as follows: 80 | 81 | import dbapi20 82 | class mytest(dbapi20.DatabaseAPI20Test): 83 | [...] 84 | 85 | Don't 'import DatabaseAPI20Test from dbapi20', or you will 86 | confuse the unit tester - just 'import dbapi20'. 87 | ''' 88 | 89 | # The self.driver module. This should be the module where the 'connect' 90 | # method is to be found 91 | driver = None 92 | connect_args = () # List of arguments to pass to connect 93 | connect_kw_args = {} # Keyword arguments for connect 94 | table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables 95 | 96 | ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix 97 | ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix 98 | xddl1 = 'drop table %sbooze' % table_prefix 99 | xddl2 = 'drop table %sbarflys' % table_prefix 100 | 101 | # Name of stored procedure to convert 102 | # string->lowercase 103 | lowerfunc = 'lower' 104 | 105 | # Some drivers may need to override these helpers, for example adding 106 | # a 'commit' after the execute. 107 | def executeDDL1(self, cursor): 108 | cursor.execute(self.ddl1) 109 | 110 | def executeDDL2(self, cursor): 111 | cursor.execute(self.ddl2) 112 | 113 | def setUp(self): 114 | ''' self.drivers should override this method to perform required setup 115 | if any is necessary, such as creating the database. 116 | ''' 117 | pass 118 | 119 | def tearDown(self): 120 | ''' self.drivers should override this method to perform required 121 | cleanup if any is necessary, such as deleting the test database. 122 | The default drops the tables that may be created. 
123 | ''' 124 | con = self._connect() 125 | try: 126 | cur = con.cursor() 127 | for ddl in (self.xddl1, self.xddl2): 128 | try: 129 | cur.execute(ddl) 130 | con.commit() 131 | except self.driver.Error: 132 | # Assume table didn't exist. Other tests will check if 133 | # execute is busted. 134 | pass 135 | finally: 136 | con.close() 137 | 138 | def _connect(self): 139 | try: 140 | return self.driver.connect( 141 | *self.connect_args, **self.connect_kw_args) 142 | except AttributeError: 143 | self.fail("No connect method found in self.driver module") 144 | 145 | def test_connect(self): 146 | con = self._connect() 147 | con.close() 148 | 149 | def test_apilevel(self): 150 | try: 151 | # Must exist 152 | apilevel = self.driver.apilevel 153 | # Must equal 2.0 154 | self.assertEqual(apilevel, '2.0') 155 | except AttributeError: 156 | self.fail("Driver doesn't define apilevel") 157 | 158 | def test_threadsafety(self): 159 | try: 160 | # Must exist 161 | threadsafety = self.driver.threadsafety 162 | # Must be a valid value 163 | self.assertEqual(threadsafety in (0, 1, 2, 3), True) 164 | except AttributeError: 165 | self.fail("Driver doesn't define threadsafety") 166 | 167 | def test_paramstyle(self): 168 | try: 169 | # Must exist 170 | paramstyle = self.driver.paramstyle 171 | # Must be a valid value 172 | self.assertEqual( 173 | paramstyle in ( 174 | 'qmark', 'numeric', 'named', 'format', 'pyformat'), True) 175 | except AttributeError: 176 | self.fail("Driver doesn't define paramstyle") 177 | 178 | def test_Exceptions(self): 179 | # Make sure required exceptions exist, and are in the 180 | # defined heirarchy. 
181 | self.assertEqual(issubclass(self.driver.Warning, Exception), True) 182 | self.assertEqual(issubclass(self.driver.Error, Exception), True) 183 | self.assertEqual( 184 | issubclass(self.driver.InterfaceError, self.driver.Error), True) 185 | self.assertEqual( 186 | issubclass(self.driver.DatabaseError, self.driver.Error), True) 187 | self.assertEqual( 188 | issubclass(self.driver.OperationalError, self.driver.Error), True) 189 | self.assertEqual( 190 | issubclass(self.driver.IntegrityError, self.driver.Error), True) 191 | self.assertEqual( 192 | issubclass(self.driver.InternalError, self.driver.Error), True) 193 | self.assertEqual( 194 | issubclass(self.driver.ProgrammingError, self.driver.Error), True) 195 | self.assertEqual( 196 | issubclass(self.driver.NotSupportedError, self.driver.Error), True) 197 | 198 | def test_ExceptionsAsConnectionAttributes(self): 199 | # OPTIONAL EXTENSION 200 | # Test for the optional DB API 2.0 extension, where the exceptions 201 | # are exposed as attributes on the Connection object 202 | # I figure this optional extension will be implemented by any 203 | # driver author who is using this test suite, so it is enabled 204 | # by default. 
205 | warnings.simplefilter("ignore") 206 | con = self._connect() 207 | drv = self.driver 208 | self.assertEqual(con.Warning is drv.Warning, True) 209 | self.assertEqual(con.Error is drv.Error, True) 210 | self.assertEqual(con.InterfaceError is drv.InterfaceError, True) 211 | self.assertEqual(con.DatabaseError is drv.DatabaseError, True) 212 | self.assertEqual(con.OperationalError is drv.OperationalError, True) 213 | self.assertEqual(con.IntegrityError is drv.IntegrityError, True) 214 | self.assertEqual(con.InternalError is drv.InternalError, True) 215 | self.assertEqual(con.ProgrammingError is drv.ProgrammingError, True) 216 | self.assertEqual(con.NotSupportedError is drv.NotSupportedError, True) 217 | warnings.resetwarnings() 218 | con.close() 219 | 220 | def test_commit(self): 221 | con = self._connect() 222 | try: 223 | # Commit must work, even if it doesn't do anything 224 | con.commit() 225 | finally: 226 | con.close() 227 | 228 | def test_rollback(self): 229 | con = self._connect() 230 | # If rollback is defined, it should either work or throw 231 | # the documented exception 232 | if hasattr(con, 'rollback'): 233 | try: 234 | con.rollback() 235 | except self.driver.NotSupportedError: 236 | pass 237 | con.close() 238 | 239 | def test_cursor(self): 240 | con = self._connect() 241 | try: 242 | con.cursor() 243 | finally: 244 | con.close() 245 | 246 | def test_cursor_isolation(self): 247 | con = self._connect() 248 | try: 249 | # Make sure cursors created from the same connection have 250 | # the documented transaction isolation level 251 | cur1 = con.cursor() 252 | cur2 = con.cursor() 253 | self.executeDDL1(cur1) 254 | cur1.execute( 255 | "insert into %sbooze values ('Victoria Bitter')" % ( 256 | self.table_prefix)) 257 | cur2.execute("select name from %sbooze" % self.table_prefix) 258 | booze = cur2.fetchall() 259 | self.assertEqual(len(booze), 1) 260 | self.assertEqual(len(booze[0]), 1) 261 | self.assertEqual(booze[0][0], 'Victoria Bitter') 262 | finally: 
263 | con.close() 264 | 265 | def test_description(self): 266 | con = self._connect() 267 | try: 268 | cur = con.cursor() 269 | self.executeDDL1(cur) 270 | self.assertEqual( 271 | cur.description, None, 272 | 'cursor.description should be none after executing a ' 273 | 'statement that can return no rows (such as DDL)') 274 | cur.execute('select name from %sbooze' % self.table_prefix) 275 | self.assertEqual( 276 | len(cur.description), 1, 277 | 'cursor.description describes too many columns') 278 | self.assertEqual( 279 | len(cur.description[0]), 7, 280 | 'cursor.description[x] tuples must have 7 elements') 281 | self.assertEqual( 282 | cur.description[0][0].lower(), b('name'), 283 | 'cursor.description[x][0] must return column name') 284 | self.assertEqual( 285 | cur.description[0][1], self.driver.STRING, 286 | 'cursor.description[x][1] must return column type. Got %r' 287 | % cur.description[0][1]) 288 | 289 | # Make sure self.description gets reset 290 | self.executeDDL2(cur) 291 | self.assertEqual( 292 | cur.description, None, 293 | 'cursor.description not being set to None when executing ' 294 | 'no-result statements (eg. 
DDL)') 295 | finally: 296 | con.close() 297 | 298 | def test_rowcount(self): 299 | con = self._connect() 300 | try: 301 | cur = con.cursor() 302 | self.executeDDL1(cur) 303 | self.assertEqual( 304 | cur.rowcount, -1, 305 | 'cursor.rowcount should be -1 after executing no-result ' 306 | 'statements') 307 | cur.execute( 308 | "insert into %sbooze values ('Victoria Bitter')" % ( 309 | self.table_prefix)) 310 | self.assertEqual( 311 | cur.rowcount in (-1, 1), True, 312 | 'cursor.rowcount should == number or rows inserted, or ' 313 | 'set to -1 after executing an insert statement') 314 | cur.execute("select name from %sbooze" % self.table_prefix) 315 | self.assertEqual( 316 | cur.rowcount in (-1, 1), True, 317 | 'cursor.rowcount should == number of rows returned, or ' 318 | 'set to -1 after executing a select statement') 319 | self.executeDDL2(cur) 320 | self.assertEqual( 321 | cur.rowcount, -1, 322 | 'cursor.rowcount not being reset to -1 after executing ' 323 | 'no-result statements') 324 | finally: 325 | con.close() 326 | 327 | lower_func = 'lower' 328 | 329 | def test_callproc(self): 330 | con = self._connect() 331 | try: 332 | cur = con.cursor() 333 | if self.lower_func and hasattr(cur, 'callproc'): 334 | r = cur.callproc(self.lower_func, ('FOO',)) 335 | self.assertEqual(len(r), 1) 336 | self.assertEqual(r[0], 'FOO') 337 | r = cur.fetchall() 338 | self.assertEqual(len(r), 1, 'callproc produced no result set') 339 | self.assertEqual( 340 | len(r[0]), 1, 'callproc produced invalid result set') 341 | self.assertEqual( 342 | r[0][0], 'foo', 'callproc produced invalid results') 343 | finally: 344 | con.close() 345 | 346 | def test_close(self): 347 | con = self._connect() 348 | try: 349 | cur = con.cursor() 350 | finally: 351 | con.close() 352 | 353 | # cursor.execute should raise an Error if called after connection 354 | # closed 355 | self.assertRaises(self.driver.Error, self.executeDDL1, cur) 356 | 357 | # connection.commit should raise an Error if called after 
connection' 358 | # closed.' 359 | self.assertRaises(self.driver.Error, con.commit) 360 | 361 | # connection.close should raise an Error if called more than once 362 | self.assertRaises(self.driver.Error, con.close) 363 | 364 | def test_execute(self): 365 | con = self._connect() 366 | try: 367 | cur = con.cursor() 368 | self._paraminsert(cur) 369 | finally: 370 | con.close() 371 | 372 | def _paraminsert(self, cur): 373 | self.executeDDL1(cur) 374 | cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( 375 | self.table_prefix)) 376 | self.assertEqual(cur.rowcount in (-1, 1), True) 377 | 378 | if self.driver.paramstyle == 'qmark': 379 | cur.execute( 380 | 'insert into %sbooze values (?)' % self.table_prefix, 381 | ("Cooper's",)) 382 | elif self.driver.paramstyle == 'numeric': 383 | cur.execute( 384 | 'insert into %sbooze values (:1)' % self.table_prefix, 385 | ("Cooper's",)) 386 | elif self.driver.paramstyle == 'named': 387 | cur.execute( 388 | 'insert into %sbooze values (:beer)' % self.table_prefix, 389 | {'beer': "Cooper's"}) 390 | elif self.driver.paramstyle == 'format': 391 | cur.execute( 392 | 'insert into %sbooze values (%%s)' % self.table_prefix, 393 | ("Cooper's",)) 394 | elif self.driver.paramstyle == 'pyformat': 395 | cur.execute( 396 | 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, 397 | {'beer': "Cooper's"}) 398 | else: 399 | self.fail('Invalid paramstyle') 400 | self.assertEqual(cur.rowcount in (-1, 1), True) 401 | 402 | cur.execute('select name from %sbooze' % self.table_prefix) 403 | res = cur.fetchall() 404 | self.assertEqual( 405 | len(res), 2, 'cursor.fetchall returned too few rows') 406 | beers = [res[0][0], res[1][0]] 407 | beers.sort() 408 | self.assertEqual( 409 | beers[0], "Cooper's", 410 | 'cursor.fetchall retrieved incorrect data, or data inserted ' 411 | 'incorrectly') 412 | self.assertEqual( 413 | beers[1], "Victoria Bitter", 414 | 'cursor.fetchall retrieved incorrect data, or data inserted ' 415 | 'incorrectly') 
416 | 417 | def test_executemany(self): 418 | con = self._connect() 419 | try: 420 | cur = con.cursor() 421 | self.executeDDL1(cur) 422 | largs = [("Cooper's",), ("Boag's",)] 423 | margs = [{'beer': "Cooper's"}, {'beer': "Boag's"}] 424 | if self.driver.paramstyle == 'qmark': 425 | cur.executemany( 426 | 'insert into %sbooze values (?)' % self.table_prefix, 427 | largs 428 | ) 429 | elif self.driver.paramstyle == 'numeric': 430 | cur.executemany( 431 | 'insert into %sbooze values (:1)' % self.table_prefix, 432 | largs 433 | ) 434 | elif self.driver.paramstyle == 'named': 435 | cur.executemany( 436 | 'insert into %sbooze values (:beer)' % self.table_prefix, 437 | margs 438 | ) 439 | elif self.driver.paramstyle == 'format': 440 | cur.executemany( 441 | 'insert into %sbooze values (%%s)' % self.table_prefix, 442 | largs 443 | ) 444 | elif self.driver.paramstyle == 'pyformat': 445 | cur.executemany( 446 | 'insert into %sbooze values (%%(beer)s)' % ( 447 | self.table_prefix), margs) 448 | else: 449 | self.fail('Unknown paramstyle') 450 | self.assertEqual( 451 | cur.rowcount in (-1, 2), True, 452 | 'insert using cursor.executemany set cursor.rowcount to ' 453 | 'incorrect value %r' % cur.rowcount) 454 | cur.execute('select name from %sbooze' % self.table_prefix) 455 | res = cur.fetchall() 456 | self.assertEqual( 457 | len(res), 2, 458 | 'cursor.fetchall retrieved incorrect number of rows') 459 | beers = [res[0][0], res[1][0]] 460 | beers.sort() 461 | self.assertEqual(beers[0], "Boag's", 'incorrect data retrieved') 462 | self.assertEqual(beers[1], "Cooper's", 'incorrect data retrieved') 463 | finally: 464 | con.close() 465 | 466 | def test_fetchone(self): 467 | con = self._connect() 468 | try: 469 | cur = con.cursor() 470 | 471 | # cursor.fetchone should raise an Error if called before 472 | # executing a select-type query 473 | self.assertRaises(self.driver.Error, cur.fetchone) 474 | 475 | # cursor.fetchone should raise an Error if called after 476 | # executing a query 
that cannnot return rows 477 | self.executeDDL1(cur) 478 | self.assertRaises(self.driver.Error, cur.fetchone) 479 | 480 | cur.execute('select name from %sbooze' % self.table_prefix) 481 | self.assertEqual( 482 | cur.fetchone(), None, 483 | 'cursor.fetchone should return None if a query retrieves ' 484 | 'no rows') 485 | self.assertEqual(cur.rowcount in (-1, 0), True) 486 | 487 | # cursor.fetchone should raise an Error if called after 488 | # executing a query that cannnot return rows 489 | cur.execute( 490 | "insert into %sbooze values ('Victoria Bitter')" % ( 491 | self.table_prefix)) 492 | self.assertRaises(self.driver.Error, cur.fetchone) 493 | 494 | cur.execute('select name from %sbooze' % self.table_prefix) 495 | r = cur.fetchone() 496 | self.assertEqual( 497 | len(r), 1, 498 | 'cursor.fetchone should have retrieved a single row') 499 | self.assertEqual( 500 | r[0], 'Victoria Bitter', 501 | 'cursor.fetchone retrieved incorrect data') 502 | self.assertEqual( 503 | cur.fetchone(), None, 504 | 'cursor.fetchone should return None if no more rows available') 505 | self.assertEqual(cur.rowcount in (-1, 1), True) 506 | finally: 507 | con.close() 508 | 509 | samples = [ 510 | 'Carlton Cold', 511 | 'Carlton Draft', 512 | 'Mountain Goat', 513 | 'Redback', 514 | 'Victoria Bitter', 515 | 'XXXX' 516 | ] 517 | 518 | def _populate(self): 519 | ''' Return a list of sql commands to setup the DB for the fetch 520 | tests. 
521 | ''' 522 | populate = [ 523 | "insert into %sbooze values ('%s')" % (self.table_prefix, s) 524 | for s in self.samples] 525 | return populate 526 | 527 | def test_fetchmany(self): 528 | con = self._connect() 529 | try: 530 | cur = con.cursor() 531 | 532 | # cursor.fetchmany should raise an Error if called without 533 | # issuing a query 534 | self.assertRaises(self.driver.Error, cur.fetchmany, 4) 535 | 536 | self.executeDDL1(cur) 537 | for sql in self._populate(): 538 | cur.execute(sql) 539 | 540 | cur.execute('select name from %sbooze' % self.table_prefix) 541 | r = cur.fetchmany() 542 | self.assertEqual( 543 | len(r), 1, 544 | 'cursor.fetchmany retrieved incorrect number of rows, ' 545 | 'default of arraysize is one.') 546 | cur.arraysize = 10 547 | r = cur.fetchmany(3) # Should get 3 rows 548 | self.assertEqual( 549 | len(r), 3, 550 | 'cursor.fetchmany retrieved incorrect number of rows') 551 | r = cur.fetchmany(4) # Should get 2 more 552 | self.assertEqual( 553 | len(r), 2, 554 | 'cursor.fetchmany retrieved incorrect number of rows') 555 | r = cur.fetchmany(4) # Should be an empty sequence 556 | self.assertEqual( 557 | len(r), 0, 558 | 'cursor.fetchmany should return an empty sequence after ' 559 | 'results are exhausted') 560 | self.assertEqual(cur.rowcount in (-1, 6), True) 561 | 562 | # Same as above, using cursor.arraysize 563 | cur.arraysize = 4 564 | cur.execute('select name from %sbooze' % self.table_prefix) 565 | r = cur.fetchmany() # Should get 4 rows 566 | self.assertEqual( 567 | len(r), 4, 568 | 'cursor.arraysize not being honoured by fetchmany') 569 | r = cur.fetchmany() # Should get 2 more 570 | self.assertEqual(len(r), 2) 571 | r = cur.fetchmany() # Should be an empty sequence 572 | self.assertEqual(len(r), 0) 573 | self.assertEqual(cur.rowcount in (-1, 6), True) 574 | 575 | cur.arraysize = 6 576 | cur.execute('select name from %sbooze' % self.table_prefix) 577 | rows = cur.fetchmany() # Should get all rows 578 | self.assertEqual(cur.rowcount 
in (-1, 6), True) 579 | self.assertEqual(len(rows), 6) 580 | self.assertEqual(len(rows), 6) 581 | rows = [row[0] for row in rows] 582 | rows.sort() 583 | 584 | # Make sure we get the right data back out 585 | for i in range(0, 6): 586 | self.assertEqual( 587 | rows[i], self.samples[i], 588 | 'incorrect data retrieved by cursor.fetchmany') 589 | 590 | rows = cur.fetchmany() # Should return an empty list 591 | self.assertEqual( 592 | len(rows), 0, 593 | 'cursor.fetchmany should return an empty sequence if ' 594 | 'called after the whole result set has been fetched') 595 | self.assertEqual(cur.rowcount in (-1, 6), True) 596 | 597 | self.executeDDL2(cur) 598 | cur.execute('select name from %sbarflys' % self.table_prefix) 599 | r = cur.fetchmany() # Should get empty sequence 600 | self.assertEqual( 601 | len(r), 0, 602 | 'cursor.fetchmany should return an empty sequence if ' 603 | 'query retrieved no rows') 604 | self.assertEqual(cur.rowcount in (-1, 0), True) 605 | 606 | finally: 607 | con.close() 608 | 609 | def test_fetchall(self): 610 | con = self._connect() 611 | try: 612 | cur = con.cursor() 613 | # cursor.fetchall should raise an Error if called 614 | # without executing a query that may return rows (such 615 | # as a select) 616 | self.assertRaises(self.driver.Error, cur.fetchall) 617 | 618 | self.executeDDL1(cur) 619 | for sql in self._populate(): 620 | cur.execute(sql) 621 | 622 | # cursor.fetchall should raise an Error if called 623 | # after executing a a statement that cannot return rows 624 | self.assertRaises(self.driver.Error, cur.fetchall) 625 | 626 | cur.execute('select name from %sbooze' % self.table_prefix) 627 | rows = cur.fetchall() 628 | self.assertEqual(cur.rowcount in (-1, len(self.samples)), True) 629 | self.assertEqual( 630 | len(rows), len(self.samples), 631 | 'cursor.fetchall did not retrieve all rows') 632 | rows = [r[0] for r in rows] 633 | rows.sort() 634 | for i in range(0, len(self.samples)): 635 | self.assertEqual( 636 | rows[i], 
self.samples[i], 637 | 'cursor.fetchall retrieved incorrect rows') 638 | rows = cur.fetchall() 639 | self.assertEqual( 640 | len(rows), 0, 641 | 'cursor.fetchall should return an empty list if called ' 642 | 'after the whole result set has been fetched') 643 | self.assertEqual(cur.rowcount in (-1, len(self.samples)), True) 644 | 645 | self.executeDDL2(cur) 646 | cur.execute('select name from %sbarflys' % self.table_prefix) 647 | rows = cur.fetchall() 648 | self.assertEqual(cur.rowcount in (-1, 0), True) 649 | self.assertEqual( 650 | len(rows), 0, 651 | 'cursor.fetchall should return an empty list if ' 652 | 'a select query returns no rows') 653 | 654 | finally: 655 | con.close() 656 | 657 | def test_mixedfetch(self): 658 | con = self._connect() 659 | try: 660 | cur = con.cursor() 661 | self.executeDDL1(cur) 662 | for sql in self._populate(): 663 | cur.execute(sql) 664 | 665 | cur.execute('select name from %sbooze' % self.table_prefix) 666 | rows1 = cur.fetchone() 667 | rows23 = cur.fetchmany(2) 668 | rows4 = cur.fetchone() 669 | rows56 = cur.fetchall() 670 | self.assertEqual(cur.rowcount in (-1, 6), True) 671 | self.assertEqual( 672 | len(rows23), 2, 'fetchmany returned incorrect number of rows') 673 | self.assertEqual( 674 | len(rows56), 2, 'fetchall returned incorrect number of rows') 675 | 676 | rows = [rows1[0]] 677 | rows.extend([rows23[0][0], rows23[1][0]]) 678 | rows.append(rows4[0]) 679 | rows.extend([rows56[0][0], rows56[1][0]]) 680 | rows.sort() 681 | for i in range(0, len(self.samples)): 682 | self.assertEqual( 683 | rows[i], self.samples[i], 684 | 'incorrect data retrieved or inserted') 685 | finally: 686 | con.close() 687 | 688 | def help_nextset_setUp(self, cur): 689 | ''' Should create a procedure called deleteme 690 | that returns two result sets, first the 691 | number of rows in booze then "name from booze" 692 | ''' 693 | raise NotImplementedError('Helper not implemented') 694 | 695 | def help_nextset_tearDown(self, cur): 696 | 'If cleaning up is 
needed after nextSetTest' 697 | raise NotImplementedError('Helper not implemented') 698 | 699 | def test_nextset(self): 700 | con = self._connect() 701 | try: 702 | cur = con.cursor() 703 | if not hasattr(cur, 'nextset'): 704 | return 705 | 706 | try: 707 | self.executeDDL1(cur) 708 | sql = self._populate() 709 | for sql in self._populate(): 710 | cur.execute(sql) 711 | 712 | self.help_nextset_setUp(cur) 713 | 714 | cur.callproc('deleteme') 715 | numberofrows = cur.fetchone() 716 | assert numberofrows[0] == len(self.samples) 717 | assert cur.nextset() 718 | names = cur.fetchall() 719 | assert len(names) == len(self.samples) 720 | s = cur.nextset() 721 | assert s is None, 'No more return sets, should return None' 722 | finally: 723 | self.help_nextset_tearDown(cur) 724 | 725 | finally: 726 | con.close() 727 | 728 | ''' 729 | def test_nextset(self): 730 | raise NotImplementedError('Drivers need to override this test') 731 | ''' 732 | 733 | def test_arraysize(self): 734 | # Not much here - rest of the tests for this are in test_fetchmany 735 | con = self._connect() 736 | try: 737 | cur = con.cursor() 738 | self.assertEqual( 739 | hasattr(cur, 'arraysize'), True, 740 | 'cursor.arraysize must be defined') 741 | finally: 742 | con.close() 743 | 744 | def test_setinputsizes(self): 745 | con = self._connect() 746 | try: 747 | cur = con.cursor() 748 | cur.setinputsizes((25,)) 749 | self._paraminsert(cur) # Make sure cursor still works 750 | finally: 751 | con.close() 752 | 753 | def test_setoutputsize_basic(self): 754 | # Basic test is to make sure setoutputsize doesn't blow up 755 | con = self._connect() 756 | try: 757 | cur = con.cursor() 758 | cur.setoutputsize(1000) 759 | cur.setoutputsize(2000, 0) 760 | self._paraminsert(cur) # Make sure the cursor still works 761 | finally: 762 | con.close() 763 | 764 | def test_setoutputsize(self): 765 | # Real test for setoutputsize is driver dependant 766 | raise NotImplementedError('Driver need to override this test') 767 | 768 | 
def test_None(self): 769 | con = self._connect() 770 | try: 771 | cur = con.cursor() 772 | self.executeDDL1(cur) 773 | cur.execute( 774 | 'insert into %sbooze values (NULL)' % self.table_prefix) 775 | cur.execute('select name from %sbooze' % self.table_prefix) 776 | r = cur.fetchall() 777 | self.assertEqual(len(r), 1) 778 | self.assertEqual(len(r[0]), 1) 779 | self.assertEqual(r[0][0], None, 'NULL value not returned as None') 780 | finally: 781 | con.close() 782 | 783 | def test_Date(self): 784 | self.driver.Date(2002, 12, 25) 785 | self.driver.DateFromTicks( 786 | time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) 787 | # Can we assume this? API doesn't specify, but it seems implied 788 | # self.assertEqual(str(d1),str(d2)) 789 | 790 | def test_Time(self): 791 | self.driver.Time(13, 45, 30) 792 | self.driver.TimeFromTicks( 793 | time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) 794 | # Can we assume this? API doesn't specify, but it seems implied 795 | # self.assertEqual(str(t1),str(t2)) 796 | 797 | def test_Timestamp(self): 798 | self.driver.Timestamp(2002, 12, 25, 13, 45, 30) 799 | self.driver.TimestampFromTicks( 800 | time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0))) 801 | # Can we assume this? 
API doesn't specify, but it seems implied 802 | # self.assertEqual(str(t1),str(t2)) 803 | 804 | def test_Binary(self): 805 | self.driver.Binary(b('Something')) 806 | self.driver.Binary(b('')) 807 | 808 | def test_STRING(self): 809 | self.assertEqual( 810 | hasattr(self.driver, 'STRING'), True, 811 | 'module.STRING must be defined') 812 | 813 | def test_BINARY(self): 814 | self.assertEqual( 815 | hasattr(self.driver, 'BINARY'), True, 816 | 'module.BINARY must be defined.') 817 | 818 | def test_NUMBER(self): 819 | self.assertTrue( 820 | hasattr(self.driver, 'NUMBER'), 'module.NUMBER must be defined.') 821 | 822 | def test_DATETIME(self): 823 | self.assertEqual( 824 | hasattr(self.driver, 'DATETIME'), True, 825 | 'module.DATETIME must be defined.') 826 | 827 | def test_ROWID(self): 828 | self.assertEqual( 829 | hasattr(self.driver, 'ROWID'), True, 830 | 'module.ROWID must be defined.') 831 | -------------------------------------------------------------------------------- /tests/performance.py: -------------------------------------------------------------------------------- 1 | import pg8000 2 | from pg8000.tests.connection_settings import db_connect 3 | import time 4 | import warnings 5 | from contextlib import closing 6 | from decimal import Decimal 7 | 8 | 9 | whole_begin_time = time.time() 10 | 11 | tests = ( 12 | ("cast(id / 100 as int2)", 'int2'), 13 | ("cast(id as int4)", 'int4'), 14 | ("cast(id * 100 as int8)", 'int8'), 15 | ("(id %% 2) = 0", 'bool'), 16 | ("N'Static text string'", 'txt'), 17 | ("cast(id / 100 as float4)", 'float4'), 18 | ("cast(id / 100 as float8)", 'float8'), 19 | ("cast(id / 100 as numeric)", 'numeric'), 20 | ("timestamp '2001-09-28' + id * interval '1 second'", 'timestamp'), 21 | ) 22 | 23 | with warnings.catch_warnings(), closing(pg8000.connect(**db_connect)) as db: 24 | for txt, name in tests: 25 | query = """SELECT {0} AS column1, {0} AS column2, {0} AS column3, 26 | {0} AS column4, {0} AS column5, {0} AS column6, {0} AS column7 27 | 
FROM (SELECT generate_series(1, 10000) AS id) AS tbl""".format(txt) 28 | cursor = db.cursor() 29 | print("Beginning %s test..." % name) 30 | for i in range(1, 5): 31 | begin_time = time.time() 32 | cursor.execute(query) 33 | for row in cursor: 34 | pass 35 | end_time = time.time() 36 | print("Attempt %s - %s seconds." % (i, end_time - begin_time)) 37 | db.commit() 38 | cursor = db.cursor() 39 | cursor.execute( 40 | "CREATE TEMPORARY TABLE t1 (f1 serial primary key, " 41 | "f2 bigint not null, f3 varchar(50) null, f4 bool)") 42 | db.commit() 43 | params = [(Decimal('7.4009'), 'season of mists...', True)] * 1000 44 | print("Beginning executemany test...") 45 | for i in range(1, 5): 46 | begin_time = time.time() 47 | cursor.executemany( 48 | "insert into t1 (f2, f3, f4) values (%s, %s, %s)", params) 49 | db.commit() 50 | end_time = time.time() 51 | print("Attempt {0} took {1} seconds.".format(i, end_time - begin_time)) 52 | 53 | print("Beginning reuse statements test...") 54 | begin_time = time.time() 55 | for i in range(2000): 56 | cursor.execute("select count(*) from t1") 57 | cursor.fetchall() 58 | print("Took {0} seconds.".format(time.time() - begin_time)) 59 | 60 | print("Whole time - %s seconds." 
% (time.time() - whole_begin_time)) 61 | -------------------------------------------------------------------------------- /tests/stress.py: -------------------------------------------------------------------------------- 1 | import pg8000 2 | from pg8000.tests.connection_settings import db_connect 3 | from contextlib import closing 4 | 5 | 6 | with closing(pg8000.connect(**db_connect)) as db: 7 | for i in range(100): 8 | cursor = db.cursor() 9 | cursor.execute(""" 10 | SELECT n.nspname as "Schema", 11 | pg_catalog.format_type(t.oid, NULL) AS "Name", 12 | pg_catalog.obj_description(t.oid, 'pg_type') as "Description" 13 | FROM pg_catalog.pg_type t 14 | LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace 15 | left join pg_catalog.pg_namespace kj on n.oid = t.typnamespace 16 | WHERE (t.typrelid = 0 17 | OR (SELECT c.relkind = 'c' 18 | FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid)) 19 | AND NOT EXISTS( 20 | SELECT 1 FROM pg_catalog.pg_type el 21 | WHERE el.oid = t.typelem AND el.typarray = t.oid) 22 | AND pg_catalog.pg_type_is_visible(t.oid) 23 | ORDER BY 1, 2;""") 24 | -------------------------------------------------------------------------------- /tests/test_connection.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pg8000 3 | from connection_settings import db_connect 4 | from six import PY2, u 5 | import sys 6 | from distutils.version import LooseVersion 7 | import socket 8 | import struct 9 | 10 | 11 | # Check if running in Jython 12 | if 'java' in sys.platform: 13 | from javax.net.ssl import TrustManager, X509TrustManager 14 | from jarray import array 15 | from javax.net.ssl import SSLContext 16 | 17 | class TrustAllX509TrustManager(X509TrustManager): 18 | '''Define a custom TrustManager which will blindly accept all 19 | certificates''' 20 | 21 | def checkClientTrusted(self, chain, auth): 22 | pass 23 | 24 | def checkServerTrusted(self, chain, auth): 25 | pass 26 | 27 | def 
getAcceptedIssuers(self): 28 | return None 29 | # Create a static reference to an SSLContext which will use 30 | # our custom TrustManager 31 | trust_managers = array([TrustAllX509TrustManager()], TrustManager) 32 | TRUST_ALL_CONTEXT = SSLContext.getInstance("SSL") 33 | TRUST_ALL_CONTEXT.init(None, trust_managers, None) 34 | # Keep a static reference to the JVM's default SSLContext for restoring 35 | # at a later time 36 | DEFAULT_CONTEXT = SSLContext.getDefault() 37 | 38 | 39 | def trust_all_certificates(f): 40 | '''Decorator function that will make it so the context of the decorated 41 | method will run with our TrustManager that accepts all certificates''' 42 | def wrapped(*args, **kwargs): 43 | # Only do this if running under Jython 44 | if 'java' in sys.platform: 45 | from javax.net.ssl import SSLContext 46 | SSLContext.setDefault(TRUST_ALL_CONTEXT) 47 | try: 48 | res = f(*args, **kwargs) 49 | return res 50 | finally: 51 | SSLContext.setDefault(DEFAULT_CONTEXT) 52 | else: 53 | return f(*args, **kwargs) 54 | return wrapped 55 | 56 | 57 | # Tests related to connecting to a database. 
58 | class Tests(unittest.TestCase): 59 | def testSocketMissing(self): 60 | conn_params = { 61 | 'unix_sock': "/file-does-not-exist", 62 | 'user': "doesn't-matter"} 63 | self.assertRaises(pg8000.InterfaceError, pg8000.connect, **conn_params) 64 | 65 | def testDatabaseMissing(self): 66 | data = db_connect.copy() 67 | data["database"] = "missing-db" 68 | self.assertRaises(pg8000.ProgrammingError, pg8000.connect, **data) 69 | 70 | def testNotify(self): 71 | 72 | try: 73 | db = pg8000.connect(**db_connect) 74 | self.assertEqual(list(db.notifications), []) 75 | cursor = db.cursor() 76 | cursor.execute("LISTEN test") 77 | cursor.execute("NOTIFY test") 78 | db.commit() 79 | 80 | cursor.execute("VALUES (1, 2), (3, 4), (5, 6)") 81 | self.assertEqual(len(db.notifications), 1) 82 | self.assertEqual(db.notifications[0][1], "test") 83 | finally: 84 | cursor.close() 85 | db.close() 86 | 87 | # This requires a line in pg_hba.conf that requires md5 for the database 88 | # pg8000_md5 89 | 90 | def testMd5(self): 91 | data = db_connect.copy() 92 | data["database"] = "pg8000_md5" 93 | 94 | # Should only raise an exception saying db doesn't exist 95 | if PY2: 96 | self.assertRaises( 97 | pg8000.ProgrammingError, pg8000.connect, **data) 98 | else: 99 | self.assertRaisesRegex( 100 | pg8000.ProgrammingError, '3D000', pg8000.connect, **data) 101 | 102 | # This requires a line in pg_hba.conf that requires gss for the database 103 | # pg8000_gss 104 | 105 | def testGss(self): 106 | data = db_connect.copy() 107 | data["database"] = "pg8000_gss" 108 | 109 | # Should raise an exception saying gss isn't supported 110 | if PY2: 111 | self.assertRaises(pg8000.InterfaceError, pg8000.connect, **data) 112 | else: 113 | self.assertRaisesRegex( 114 | pg8000.InterfaceError, 115 | "Authentication method 7 not supported by pg8000.", 116 | pg8000.connect, **data) 117 | 118 | @trust_all_certificates 119 | def testSsl(self): 120 | data = db_connect.copy() 121 | data["ssl"] = True 122 | db = 
pg8000.connect(**data) 123 | db.close() 124 | 125 | # This requires a line in pg_hba.conf that requires 'password' for the 126 | # database pg8000_password 127 | 128 | def testPassword(self): 129 | data = db_connect.copy() 130 | data["database"] = "pg8000_password" 131 | 132 | # Should only raise an exception saying db doesn't exist 133 | if PY2: 134 | self.assertRaises( 135 | pg8000.ProgrammingError, pg8000.connect, **data) 136 | else: 137 | self.assertRaisesRegex( 138 | pg8000.ProgrammingError, '3D000', pg8000.connect, **data) 139 | 140 | def testUnicodeDatabaseName(self): 141 | data = db_connect.copy() 142 | data["database"] = "pg8000_sn\uFF6Fw" 143 | 144 | # Should only raise an exception saying db doesn't exist 145 | if PY2: 146 | self.assertRaises( 147 | pg8000.ProgrammingError, pg8000.connect, **data) 148 | else: 149 | self.assertRaisesRegex( 150 | pg8000.ProgrammingError, '3D000', pg8000.connect, **data) 151 | 152 | def testBytesDatabaseName(self): 153 | data = db_connect.copy() 154 | 155 | # Should only raise an exception saying db doesn't exist 156 | if PY2: 157 | data["database"] = "pg8000_sn\uFF6Fw" 158 | self.assertRaises( 159 | pg8000.ProgrammingError, pg8000.connect, **data) 160 | else: 161 | data["database"] = bytes("pg8000_sn\uFF6Fw", 'utf8') 162 | self.assertRaisesRegex( 163 | pg8000.ProgrammingError, '3D000', pg8000.connect, **data) 164 | 165 | def testBytesPassword(self): 166 | db = pg8000.connect(**db_connect) 167 | # Create user 168 | username = 'boltzmann' 169 | password = u('cha\uFF6Fs') 170 | cur = db.cursor() 171 | 172 | # Delete user if left over from previous run 173 | try: 174 | cur.execute("drop role " + username) 175 | except pg8000.ProgrammingError: 176 | cur.execute("rollback") 177 | 178 | cur.execute( 179 | "create user " + username + " with password '" + password + "';") 180 | cur.execute('commit;') 181 | db.close() 182 | 183 | data = db_connect.copy() 184 | data['user'] = username 185 | data['password'] = password.encode('utf8') 
186 | data['database'] = 'pg8000_md5' 187 | if PY2: 188 | self.assertRaises( 189 | pg8000.ProgrammingError, pg8000.connect, **data) 190 | else: 191 | self.assertRaisesRegex( 192 | pg8000.ProgrammingError, '3D000', pg8000.connect, **data) 193 | 194 | db = pg8000.connect(**db_connect) 195 | cur = db.cursor() 196 | cur.execute("drop role " + username) 197 | cur.execute("commit;") 198 | db.close() 199 | 200 | def testBrokenPipe(self): 201 | db1 = pg8000.connect(**db_connect) 202 | db2 = pg8000.connect(**db_connect) 203 | 204 | try: 205 | cur1 = db1.cursor() 206 | cur2 = db2.cursor() 207 | 208 | cur1.execute("select pg_backend_pid()") 209 | pid1 = cur1.fetchone()[0] 210 | 211 | cur2.execute("select pg_terminate_backend(%s)", (pid1,)) 212 | try: 213 | cur1.execute("select 1") 214 | except Exception as e: 215 | self.assertTrue(isinstance(e, (socket.error, struct.error))) 216 | 217 | cur2.close() 218 | finally: 219 | db1.close() 220 | db2.close() 221 | 222 | def testApplicatioName(self): 223 | params = db_connect.copy() 224 | params['application_name'] = 'my test application name' 225 | db = pg8000.connect(**params) 226 | cur = db.cursor() 227 | 228 | if db._server_version >= LooseVersion('9.2'): 229 | cur.execute('select application_name from pg_stat_activity ' 230 | ' where pid = pg_backend_pid()') 231 | else: 232 | # for pg9.1 and earlier, procpod field rather than pid 233 | cur.execute('select application_name from pg_stat_activity ' 234 | ' where procpid = pg_backend_pid()') 235 | 236 | application_name = cur.fetchone()[0] 237 | self.assertEqual(application_name, 'my test application name') 238 | 239 | 240 | if __name__ == "__main__": 241 | unittest.main() 242 | -------------------------------------------------------------------------------- /tests/test_copy.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pg8000 3 | from connection_settings import db_connect 4 | from six import b, BytesIO, u, 
iteritems 5 | from sys import exc_info 6 | 7 | 8 | class Tests(unittest.TestCase): 9 | def setUp(self): 10 | self.db = pg8000.connect(**db_connect) 11 | try: 12 | cursor = self.db.cursor() 13 | try: 14 | cursor = self.db.cursor() 15 | cursor.execute("DROP TABLE t1") 16 | except pg8000.DatabaseError: 17 | e = exc_info()[1] 18 | # the only acceptable error is: 19 | msg = e.args[0] 20 | code = msg['C'] 21 | self.assertEqual( 22 | code, '42P01', # table does not exist 23 | "incorrect error for drop table: " + str(msg)) 24 | self.db.rollback() 25 | cursor.execute( 26 | "CREATE TEMPORARY TABLE t1 (f1 int primary key, " 27 | "f2 int not null, f3 varchar(50) null)") 28 | finally: 29 | cursor.close() 30 | 31 | def tearDown(self): 32 | self.db.close() 33 | 34 | def testCopyToWithTable(self): 35 | try: 36 | cursor = self.db.cursor() 37 | cursor.execute( 38 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (1, 1, 1)) 39 | cursor.execute( 40 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (2, 2, 2)) 41 | cursor.execute( 42 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", (3, 3, 3)) 43 | 44 | stream = BytesIO() 45 | cursor.execute("copy t1 to stdout", stream=stream) 46 | self.assertEqual( 47 | stream.getvalue(), b("1\t1\t1\n2\t2\t2\n3\t3\t3\n")) 48 | self.assertEqual(cursor.rowcount, 3) 49 | self.db.commit() 50 | finally: 51 | cursor.close() 52 | 53 | def testCopyToWithQuery(self): 54 | try: 55 | cursor = self.db.cursor() 56 | stream = BytesIO() 57 | cursor.execute( 58 | "COPY (SELECT 1 as One, 2 as Two) TO STDOUT WITH DELIMITER " 59 | "'X' CSV HEADER QUOTE AS 'Y' FORCE QUOTE Two", stream=stream) 60 | self.assertEqual(stream.getvalue(), b('oneXtwo\n1XY2Y\n')) 61 | self.assertEqual(cursor.rowcount, 1) 62 | self.db.rollback() 63 | finally: 64 | cursor.close() 65 | 66 | def testCopyFromWithTable(self): 67 | try: 68 | cursor = self.db.cursor() 69 | stream = BytesIO(b("1\t1\t1\n2\t2\t2\n3\t3\t3\n")) 70 | cursor.execute("copy t1 from STDIN", stream=stream) 71 | 
self.assertEqual(cursor.rowcount, 3) 72 | 73 | cursor.execute("SELECT * FROM t1 ORDER BY f1") 74 | retval = cursor.fetchall() 75 | self.assertEqual(retval, ([1, 1, '1'], [2, 2, '2'], [3, 3, '3'])) 76 | self.db.rollback() 77 | finally: 78 | cursor.close() 79 | 80 | def testCopyFromWithQuery(self): 81 | try: 82 | cursor = self.db.cursor() 83 | stream = BytesIO(b("f1Xf2\n1XY1Y\n")) 84 | cursor.execute( 85 | "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER " 86 | "QUOTE AS 'Y' FORCE NOT NULL f1", stream=stream) 87 | self.assertEqual(cursor.rowcount, 1) 88 | 89 | cursor.execute("SELECT * FROM t1 ORDER BY f1") 90 | retval = cursor.fetchall() 91 | self.assertEqual(retval, ([1, 1, None],)) 92 | self.db.commit() 93 | finally: 94 | cursor.close() 95 | 96 | def testCopyFromWithError(self): 97 | try: 98 | cursor = self.db.cursor() 99 | stream = BytesIO(b("f1Xf2\n\n1XY1Y\n")) 100 | cursor.execute( 101 | "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER " 102 | "QUOTE AS 'Y' FORCE NOT NULL f1", stream=stream) 103 | self.assertTrue(False, "Should have raised an exception") 104 | except: 105 | args_dict = { 106 | 'S': u('ERROR'), 107 | 'C': u('22P02'), 108 | 'M': u('invalid input syntax for integer: ""'), 109 | 'W': u('COPY t1, line 2, column f1: ""'), 110 | 'F': u('numutils.c'), 111 | 'R': u('pg_atoi') 112 | } 113 | args = exc_info()[1].args[0] 114 | for k, v in iteritems(args_dict): 115 | self.assertEqual(args[k], v) 116 | finally: 117 | cursor.close() 118 | 119 | 120 | if __name__ == "__main__": 121 | unittest.main() 122 | -------------------------------------------------------------------------------- /tests/test_dbapi.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import time 4 | import pg8000 5 | import datetime 6 | from connection_settings import db_connect 7 | from sys import exc_info 8 | from six import b 9 | from distutils.version import LooseVersion 10 | 11 | 12 | # DBAPI 
compatible interface tests 13 | class Tests(unittest.TestCase): 14 | def setUp(self): 15 | self.db = pg8000.connect(**db_connect) 16 | 17 | # Neither Windows nor Jython 2.5.3 have a time.tzset() so skip 18 | if hasattr(time, 'tzset'): 19 | os.environ['TZ'] = "UTC" 20 | time.tzset() 21 | self.HAS_TZSET = True 22 | else: 23 | self.HAS_TZSET = False 24 | 25 | try: 26 | c = self.db.cursor() 27 | try: 28 | c = self.db.cursor() 29 | c.execute("DROP TABLE t1") 30 | except pg8000.DatabaseError: 31 | e = exc_info()[1] 32 | # the only acceptable error is table does not exist 33 | self.assertEqual(e.args[0]['C'], '42P01') 34 | self.db.rollback() 35 | c.execute( 36 | "CREATE TEMPORARY TABLE t1 " 37 | "(f1 int primary key, f2 int not null, f3 varchar(50) null)") 38 | c.execute( 39 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 40 | (1, 1, None)) 41 | c.execute( 42 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 43 | (2, 10, None)) 44 | c.execute( 45 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 46 | (3, 100, None)) 47 | c.execute( 48 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 49 | (4, 1000, None)) 50 | c.execute( 51 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 52 | (5, 10000, None)) 53 | self.db.commit() 54 | finally: 55 | c.close() 56 | 57 | def tearDown(self): 58 | self.db.close() 59 | 60 | def testParallelQueries(self): 61 | try: 62 | c1 = self.db.cursor() 63 | c2 = self.db.cursor() 64 | 65 | c1.execute("SELECT f1, f2, f3 FROM t1") 66 | while 1: 67 | row = c1.fetchone() 68 | if row is None: 69 | break 70 | f1, f2, f3 = row 71 | c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,)) 72 | while 1: 73 | row = c2.fetchone() 74 | if row is None: 75 | break 76 | f1, f2, f3 = row 77 | finally: 78 | c1.close() 79 | c2.close() 80 | 81 | self.db.rollback() 82 | 83 | def testQmark(self): 84 | orig_paramstyle = pg8000.paramstyle 85 | try: 86 | pg8000.paramstyle = "qmark" 87 | c1 = self.db.cursor() 88 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE 
f1 > ?", (3,)) 89 | while 1: 90 | row = c1.fetchone() 91 | if row is None: 92 | break 93 | f1, f2, f3 = row 94 | self.db.rollback() 95 | finally: 96 | pg8000.paramstyle = orig_paramstyle 97 | c1.close() 98 | 99 | def testNumeric(self): 100 | orig_paramstyle = pg8000.paramstyle 101 | try: 102 | pg8000.paramstyle = "numeric" 103 | c1 = self.db.cursor() 104 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > :1", (3,)) 105 | while 1: 106 | row = c1.fetchone() 107 | if row is None: 108 | break 109 | f1, f2, f3 = row 110 | self.db.rollback() 111 | finally: 112 | pg8000.paramstyle = orig_paramstyle 113 | c1.close() 114 | 115 | def testNamed(self): 116 | orig_paramstyle = pg8000.paramstyle 117 | try: 118 | pg8000.paramstyle = "named" 119 | c1 = self.db.cursor() 120 | c1.execute( 121 | "SELECT f1, f2, f3 FROM t1 WHERE f1 > :f1", {"f1": 3}) 122 | while 1: 123 | row = c1.fetchone() 124 | if row is None: 125 | break 126 | f1, f2, f3 = row 127 | self.db.rollback() 128 | finally: 129 | pg8000.paramstyle = orig_paramstyle 130 | c1.close() 131 | 132 | def testFormat(self): 133 | orig_paramstyle = pg8000.paramstyle 134 | try: 135 | pg8000.paramstyle = "format" 136 | c1 = self.db.cursor() 137 | c1.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (3,)) 138 | while 1: 139 | row = c1.fetchone() 140 | if row is None: 141 | break 142 | f1, f2, f3 = row 143 | self.db.commit() 144 | finally: 145 | pg8000.paramstyle = orig_paramstyle 146 | c1.close() 147 | 148 | def testPyformat(self): 149 | orig_paramstyle = pg8000.paramstyle 150 | try: 151 | pg8000.paramstyle = "pyformat" 152 | c1 = self.db.cursor() 153 | c1.execute( 154 | "SELECT f1, f2, f3 FROM t1 WHERE f1 > %(f1)s", {"f1": 3}) 155 | while 1: 156 | row = c1.fetchone() 157 | if row is None: 158 | break 159 | f1, f2, f3 = row 160 | self.db.commit() 161 | finally: 162 | pg8000.paramstyle = orig_paramstyle 163 | c1.close() 164 | 165 | def testArraysize(self): 166 | try: 167 | c1 = self.db.cursor() 168 | c1.arraysize = 3 169 | 
c1.execute("SELECT * FROM t1") 170 | retval = c1.fetchmany() 171 | self.assertEqual(len(retval), c1.arraysize) 172 | finally: 173 | c1.close() 174 | self.db.commit() 175 | 176 | def testDate(self): 177 | val = pg8000.Date(2001, 2, 3) 178 | self.assertEqual(val, datetime.date(2001, 2, 3)) 179 | 180 | def testTime(self): 181 | val = pg8000.Time(4, 5, 6) 182 | self.assertEqual(val, datetime.time(4, 5, 6)) 183 | 184 | def testTimestamp(self): 185 | val = pg8000.Timestamp(2001, 2, 3, 4, 5, 6) 186 | self.assertEqual(val, datetime.datetime(2001, 2, 3, 4, 5, 6)) 187 | 188 | def testDateFromTicks(self): 189 | if self.HAS_TZSET: 190 | val = pg8000.DateFromTicks(1173804319) 191 | self.assertEqual(val, datetime.date(2007, 3, 13)) 192 | 193 | def testTimeFromTicks(self): 194 | if self.HAS_TZSET: 195 | val = pg8000.TimeFromTicks(1173804319) 196 | self.assertEqual(val, datetime.time(16, 45, 19)) 197 | 198 | def testTimestampFromTicks(self): 199 | if self.HAS_TZSET: 200 | val = pg8000.TimestampFromTicks(1173804319) 201 | self.assertEqual(val, datetime.datetime(2007, 3, 13, 16, 45, 19)) 202 | 203 | def testBinary(self): 204 | v = pg8000.Binary(b("\x00\x01\x02\x03\x02\x01\x00")) 205 | self.assertEqual(v, b("\x00\x01\x02\x03\x02\x01\x00")) 206 | self.assertTrue(isinstance(v, pg8000.BINARY)) 207 | 208 | def testRowCount(self): 209 | try: 210 | c1 = self.db.cursor() 211 | c1.execute("SELECT * FROM t1") 212 | 213 | # Before PostgreSQL 9 we don't know the row count for a select 214 | if self.db._server_version > LooseVersion('8.0.0'): 215 | self.assertEqual(5, c1.rowcount) 216 | 217 | c1.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",)) 218 | self.assertEqual(2, c1.rowcount) 219 | 220 | c1.execute("DELETE FROM t1") 221 | self.assertEqual(5, c1.rowcount) 222 | finally: 223 | c1.close() 224 | self.db.commit() 225 | 226 | def testFetchMany(self): 227 | try: 228 | cursor = self.db.cursor() 229 | cursor.arraysize = 2 230 | cursor.execute("SELECT * FROM t1") 231 | 
self.assertEqual(2, len(cursor.fetchmany())) 232 | self.assertEqual(2, len(cursor.fetchmany())) 233 | self.assertEqual(1, len(cursor.fetchmany())) 234 | self.assertEqual(0, len(cursor.fetchmany())) 235 | finally: 236 | cursor.close() 237 | self.db.commit() 238 | 239 | def testIterator(self): 240 | from warnings import filterwarnings 241 | filterwarnings("ignore", "DB-API extension cursor.next()") 242 | filterwarnings("ignore", "DB-API extension cursor.__iter__()") 243 | 244 | try: 245 | cursor = self.db.cursor() 246 | cursor.execute("SELECT * FROM t1 ORDER BY f1") 247 | f1 = 0 248 | for row in cursor: 249 | next_f1 = row[0] 250 | assert next_f1 > f1 251 | f1 = next_f1 252 | except: 253 | cursor.close() 254 | 255 | self.db.commit() 256 | 257 | # Vacuum can't be run inside a transaction, so we need to turn 258 | # autocommit on. 259 | def testVacuum(self): 260 | self.db.autocommit = True 261 | try: 262 | cursor = self.db.cursor() 263 | cursor.execute("vacuum") 264 | finally: 265 | cursor.close() 266 | 267 | def testPreparedStatement(self): 268 | cursor = self.db.cursor() 269 | cursor.execute( 270 | 'PREPARE gen_series AS SELECT generate_series(1, 10);') 271 | cursor.execute('EXECUTE gen_series') 272 | 273 | 274 | if __name__ == "__main__": 275 | unittest.main() 276 | -------------------------------------------------------------------------------- /tests/test_error_recovery.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pg8000 3 | from connection_settings import db_connect 4 | import warnings 5 | import datetime 6 | from sys import exc_info 7 | 8 | 9 | class TestException(Exception): 10 | pass 11 | 12 | 13 | class Tests(unittest.TestCase): 14 | def setUp(self): 15 | self.db = pg8000.connect(**db_connect) 16 | 17 | def tearDown(self): 18 | self.db.close() 19 | 20 | def raiseException(self, value): 21 | raise TestException("oh noes!") 22 | 23 | def testPyValueFail(self): 24 | # Ensure that if 
types.py_value throws an exception, the original 25 | # exception is raised (TestException), and the connection is 26 | # still usable after the error. 27 | orig = self.db.py_types[datetime.time] 28 | self.db.py_types[datetime.time] = ( 29 | orig[0], orig[1], self.raiseException) 30 | 31 | try: 32 | c = self.db.cursor() 33 | try: 34 | try: 35 | c.execute("SELECT %s as f1", (datetime.time(10, 30),)) 36 | c.fetchall() 37 | # shouldn't get here, exception should be thrown 38 | self.fail() 39 | except TestException: 40 | # should be TestException type, this is OK! 41 | self.db.rollback() 42 | finally: 43 | self.db.py_types[datetime.time] = orig 44 | 45 | # ensure that the connection is still usable for a new query 46 | c.execute("VALUES ('hw3'::text)") 47 | self.assertEqual(c.fetchone()[0], "hw3") 48 | finally: 49 | c.close() 50 | 51 | def testNoDataErrorRecovery(self): 52 | for i in range(1, 4): 53 | try: 54 | try: 55 | cursor = self.db.cursor() 56 | cursor.execute("DROP TABLE t1") 57 | finally: 58 | cursor.close() 59 | except pg8000.DatabaseError: 60 | e = exc_info()[1] 61 | # the only acceptable error is 'table does not exist' 62 | self.assertEqual(e.args[0]['C'], '42P01') 63 | self.db.rollback() 64 | 65 | def testClosedConnection(self): 66 | warnings.simplefilter("ignore") 67 | my_db = pg8000.connect(**db_connect) 68 | cursor = my_db.cursor() 69 | my_db.close() 70 | try: 71 | cursor.execute("VALUES ('hw1'::text)") 72 | self.fail("Should have raised an exception") 73 | except: 74 | e = exc_info()[1] 75 | self.assertTrue(isinstance(e, self.db.InterfaceError)) 76 | self.assertEqual(str(e), 'connection is closed') 77 | 78 | warnings.resetwarnings() 79 | 80 | 81 | if __name__ == "__main__": 82 | unittest.main() 83 | -------------------------------------------------------------------------------- /tests/test_paramstyle.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pg8000 3 | 4 | 5 | # Tests of the 
convert_paramstyle function. 6 | class Tests(unittest.TestCase): 7 | def testQmark(self): 8 | new_query, make_args = pg8000.core.convert_paramstyle( 9 | "qmark", "SELECT ?, ?, \"field_?\" FROM t " 10 | "WHERE a='say ''what?''' AND b=? AND c=E'?\\'test\\'?'") 11 | self.assertEqual( 12 | new_query, "SELECT $1, $2, \"field_?\" FROM t WHERE " 13 | "a='say ''what?''' AND b=$3 AND c=E'?\\'test\\'?'") 14 | self.assertEqual(make_args((1, 2, 3)), (1, 2, 3)) 15 | 16 | def testQmark2(self): 17 | new_query, make_args = pg8000.core.convert_paramstyle( 18 | "qmark", "SELECT ?, ?, * FROM t WHERE a=? AND b='are you ''sure?'") 19 | self.assertEqual( 20 | new_query, 21 | "SELECT $1, $2, * FROM t WHERE a=$3 AND b='are you ''sure?'") 22 | self.assertEqual(make_args((1, 2, 3)), (1, 2, 3)) 23 | 24 | def testNumeric(self): 25 | new_query, make_args = pg8000.core.convert_paramstyle( 26 | "numeric", "SELECT :2, :1, * FROM t WHERE a=:3") 27 | self.assertEqual(new_query, "SELECT $2, $1, * FROM t WHERE a=$3") 28 | self.assertEqual(make_args((1, 2, 3)), (1, 2, 3)) 29 | 30 | def testNamed(self): 31 | new_query, make_args = pg8000.core.convert_paramstyle( 32 | "named", "SELECT :f_2, :f1 FROM t WHERE a=:f_2") 33 | self.assertEqual(new_query, "SELECT $1, $2 FROM t WHERE a=$1") 34 | self.assertEqual(make_args({"f_2": 1, "f1": 2}), (1, 2)) 35 | 36 | def testFormat(self): 37 | new_query, make_args = pg8000.core.convert_paramstyle( 38 | "format", "SELECT %s, %s, \"f1_%%\", E'txt_%%' " 39 | "FROM t WHERE a=%s AND b='75%%' AND c = '%' -- Comment with %") 40 | self.assertEqual( 41 | new_query, 42 | "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND " 43 | "b='75%%' AND c = '%' -- Comment with %") 44 | self.assertEqual(make_args((1, 2, 3)), (1, 2, 3)) 45 | 46 | sql = r"""COMMENT ON TABLE test_schema.comment_test """ \ 47 | r"""IS 'the test % '' " \ table comment'""" 48 | new_query, make_args = pg8000.core.convert_paramstyle("format", sql) 49 | self.assertEqual(new_query, sql) 50 | 51 | def 
testFormatMultiline(self): 52 | new_query, make_args = pg8000.core.convert_paramstyle( 53 | "format", "SELECT -- Comment\n%s FROM t") 54 | self.assertEqual( 55 | new_query, 56 | "SELECT -- Comment\n$1 FROM t") 57 | 58 | def testPyformat(self): 59 | new_query, make_args = pg8000.core.convert_paramstyle( 60 | "pyformat", "SELECT %(f2)s, %(f1)s, \"f1_%%\", E'txt_%%' " 61 | "FROM t WHERE a=%(f2)s AND b='75%%'") 62 | self.assertEqual( 63 | new_query, 64 | "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$1 AND " 65 | "b='75%%'") 66 | self.assertEqual(make_args({"f2": 1, "f1": 2, "f3": 3}), (1, 2)) 67 | 68 | # pyformat should support %s and an array, too: 69 | new_query, make_args = pg8000.core.convert_paramstyle( 70 | "pyformat", "SELECT %s, %s, \"f1_%%\", E'txt_%%' " 71 | "FROM t WHERE a=%s AND b='75%%'") 72 | self.assertEqual( 73 | new_query, 74 | "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND " 75 | "b='75%%'") 76 | self.assertEqual(make_args((1, 2, 3)), (1, 2, 3)) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /tests/test_pg8000_dbapi20.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import dbapi20 3 | import unittest 4 | import pg8000 5 | from connection_settings import db_connect 6 | 7 | 8 | class Tests(dbapi20.DatabaseAPI20Test): 9 | driver = pg8000 10 | connect_args = () 11 | connect_kw_args = db_connect 12 | 13 | lower_func = 'lower' # For stored procedure test 14 | 15 | def test_nextset(self): 16 | pass 17 | 18 | def test_setoutputsize(self): 19 | pass 20 | 21 | 22 | if __name__ == '__main__': 23 | unittest.main() 24 | -------------------------------------------------------------------------------- /tests/test_query.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pg8000 3 | from connection_settings import 
db_connect 4 | from six import u 5 | from sys import exc_info 6 | import datetime 7 | from distutils.version import LooseVersion 8 | 9 | from warnings import filterwarnings 10 | 11 | 12 | # Tests relating to the basic operation of the database driver, driven by the 13 | # pg8000 custom interface. 14 | class Tests(unittest.TestCase): 15 | def setUp(self): 16 | self.db = pg8000.connect(**db_connect) 17 | filterwarnings("ignore", "DB-API extension cursor.next()") 18 | filterwarnings("ignore", "DB-API extension cursor.__iter__()") 19 | self.db.paramstyle = 'format' 20 | try: 21 | cursor = self.db.cursor() 22 | try: 23 | cursor.execute("DROP TABLE t1") 24 | except pg8000.DatabaseError: 25 | e = exc_info()[1] 26 | # the only acceptable error is 'table does not exist' 27 | self.assertEqual(e.args[0]['C'], '42P01') 28 | self.db.rollback() 29 | cursor.execute( 30 | "CREATE TEMPORARY TABLE t1 (f1 int primary key, " 31 | "f2 bigint not null, f3 varchar(50) null)") 32 | finally: 33 | cursor.close() 34 | 35 | self.db.commit() 36 | 37 | def tearDown(self): 38 | self.db.close() 39 | 40 | def testDatabaseError(self): 41 | try: 42 | cursor = self.db.cursor() 43 | self.assertRaises( 44 | pg8000.ProgrammingError, cursor.execute, 45 | "INSERT INTO t99 VALUES (1, 2, 3)") 46 | finally: 47 | cursor.close() 48 | 49 | self.db.rollback() 50 | 51 | def testParallelQueries(self): 52 | try: 53 | cursor = self.db.cursor() 54 | cursor.execute( 55 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 56 | (1, 1, None)) 57 | cursor.execute( 58 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 59 | (2, 10, None)) 60 | cursor.execute( 61 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 62 | (3, 100, None)) 63 | cursor.execute( 64 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 65 | (4, 1000, None)) 66 | cursor.execute( 67 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 68 | (5, 10000, None)) 69 | try: 70 | c1 = self.db.cursor() 71 | c2 = self.db.cursor() 72 | c1.execute("SELECT f1, 
f2, f3 FROM t1") 73 | for row in c1: 74 | f1, f2, f3 = row 75 | c2.execute( 76 | "SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,)) 77 | for row in c2: 78 | f1, f2, f3 = row 79 | finally: 80 | c1.close() 81 | c2.close() 82 | finally: 83 | cursor.close() 84 | self.db.rollback() 85 | 86 | def testParallelOpenPortals(self): 87 | try: 88 | c1, c2 = self.db.cursor(), self.db.cursor() 89 | c1count, c2count = 0, 0 90 | q = "select * from generate_series(1, %s)" 91 | params = (100,) 92 | c1.execute(q, params) 93 | c2.execute(q, params) 94 | for c2row in c2: 95 | c2count += 1 96 | for c1row in c1: 97 | c1count += 1 98 | finally: 99 | c1.close() 100 | c2.close() 101 | self.db.rollback() 102 | 103 | self.assertEqual(c1count, c2count) 104 | 105 | # Run a query on a table, alter the structure of the table, then run the 106 | # original query again. 107 | 108 | def testAlter(self): 109 | try: 110 | cursor = self.db.cursor() 111 | cursor.execute("select * from t1") 112 | cursor.execute("alter table t1 drop column f3") 113 | cursor.execute("select * from t1") 114 | finally: 115 | cursor.close() 116 | self.db.rollback() 117 | 118 | # Run a query on a table, drop then re-create the table, then run the 119 | # original query again. 120 | 121 | def testCreate(self): 122 | try: 123 | cursor = self.db.cursor() 124 | cursor.execute("select * from t1") 125 | cursor.execute("drop table t1") 126 | cursor.execute("create temporary table t1 (f1 int primary key)") 127 | cursor.execute("select * from t1") 128 | finally: 129 | cursor.close() 130 | self.db.rollback() 131 | 132 | def testInsertReturning(self): 133 | try: 134 | cursor = self.db.cursor() 135 | cursor.execute("CREATE TABLE t2 (id serial, data text)") 136 | 137 | # Test INSERT ... RETURNING with one row... 
138 | cursor.execute( 139 | "INSERT INTO t2 (data) VALUES (%s) RETURNING id", 140 | ("test1",)) 141 | row_id = cursor.fetchone()[0] 142 | cursor.execute("SELECT data FROM t2 WHERE id = %s", (row_id,)) 143 | self.assertEqual("test1", cursor.fetchone()[0]) 144 | 145 | # Before PostgreSQL 9 we don't know the row count for a select 146 | if self.db._server_version > LooseVersion('8.0.0'): 147 | self.assertEqual(cursor.rowcount, 1) 148 | 149 | # Test with multiple rows... 150 | cursor.execute( 151 | "INSERT INTO t2 (data) VALUES (%s), (%s), (%s) " 152 | "RETURNING id", ("test2", "test3", "test4")) 153 | self.assertEqual(cursor.rowcount, 3) 154 | ids = tuple([x[0] for x in cursor]) 155 | self.assertEqual(len(ids), 3) 156 | finally: 157 | cursor.close() 158 | self.db.rollback() 159 | 160 | def testRowCount(self): 161 | # Before PostgreSQL 9 we don't know the row count for a select 162 | if self.db._server_version > LooseVersion('8.0.0'): 163 | try: 164 | cursor = self.db.cursor() 165 | expected_count = 57 166 | cursor.executemany( 167 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 168 | tuple((i, i, None) for i in range(expected_count))) 169 | 170 | # Check rowcount after executemany 171 | self.assertEqual(expected_count, cursor.rowcount) 172 | self.db.commit() 173 | 174 | cursor.execute("SELECT * FROM t1") 175 | 176 | # Check row_count without doing any reading first... 177 | self.assertEqual(expected_count, cursor.rowcount) 178 | 179 | # Check rowcount after reading some rows, make sure it still 180 | # works... 181 | for i in range(expected_count // 2): 182 | cursor.fetchone() 183 | self.assertEqual(expected_count, cursor.rowcount) 184 | finally: 185 | cursor.close() 186 | self.db.commit() 187 | 188 | try: 189 | cursor = self.db.cursor() 190 | # Restart the cursor, read a few rows, and then check rowcount 191 | # again... 
192 | cursor = self.db.cursor() 193 | cursor.execute("SELECT * FROM t1") 194 | for i in range(expected_count // 3): 195 | cursor.fetchone() 196 | self.assertEqual(expected_count, cursor.rowcount) 197 | self.db.rollback() 198 | 199 | # Should be -1 for a command with no results 200 | cursor.execute("DROP TABLE t1") 201 | self.assertEqual(-1, cursor.rowcount) 202 | finally: 203 | cursor.close() 204 | self.db.commit() 205 | 206 | def testRowCountUpdate(self): 207 | try: 208 | cursor = self.db.cursor() 209 | cursor.execute( 210 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 211 | (1, 1, None)) 212 | cursor.execute( 213 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 214 | (2, 10, None)) 215 | cursor.execute( 216 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 217 | (3, 100, None)) 218 | cursor.execute( 219 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 220 | (4, 1000, None)) 221 | cursor.execute( 222 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 223 | (5, 10000, None)) 224 | cursor.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",)) 225 | self.assertEqual(cursor.rowcount, 2) 226 | finally: 227 | cursor.close() 228 | self.db.commit() 229 | 230 | def testIntOid(self): 231 | try: 232 | cursor = self.db.cursor() 233 | # https://bugs.launchpad.net/pg8000/+bug/230796 234 | cursor.execute( 235 | "SELECT typname FROM pg_type WHERE oid = %s", (100,)) 236 | finally: 237 | cursor.close() 238 | self.db.rollback() 239 | 240 | def testUnicodeQuery(self): 241 | try: 242 | cursor = self.db.cursor() 243 | cursor.execute( 244 | u( 245 | "CREATE TEMPORARY TABLE \u043c\u0435\u0441\u0442\u043e " 246 | "(\u0438\u043c\u044f VARCHAR(50), " 247 | "\u0430\u0434\u0440\u0435\u0441 VARCHAR(250))")) 248 | finally: 249 | cursor.close() 250 | self.db.commit() 251 | 252 | def testExecutemany(self): 253 | try: 254 | cursor = self.db.cursor() 255 | cursor.executemany( 256 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 257 | ((1, 1, 'Avast ye!'), (2, 1, None))) 
258 | 259 | cursor.executemany( 260 | "select %s", 261 | ( 262 | (datetime.datetime(2014, 5, 7, tzinfo=pg8000.core.utc), ), 263 | (datetime.datetime(2014, 5, 7),))) 264 | finally: 265 | cursor.close() 266 | self.db.commit() 267 | 268 | # Check that autocommit stays off 269 | # We keep track of whether we're in a transaction or not by using the 270 | # READY_FOR_QUERY message. 271 | def testTransactions(self): 272 | try: 273 | cursor = self.db.cursor() 274 | cursor.execute("commit") 275 | cursor.execute( 276 | "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", 277 | (1, 1, "Zombie")) 278 | cursor.execute("rollback") 279 | cursor.execute("select * from t1") 280 | 281 | # Before PostgreSQL 9 we don't know the row count for a select 282 | if self.db._server_version > LooseVersion('8.0.0'): 283 | self.assertEqual(cursor.rowcount, 0) 284 | finally: 285 | cursor.close() 286 | self.db.commit() 287 | 288 | def testIn(self): 289 | try: 290 | cursor = self.db.cursor() 291 | cursor.execute( 292 | "SELECT typname FROM pg_type WHERE oid = any(%s)", ([16, 23],)) 293 | ret = cursor.fetchall() 294 | self.assertEqual(ret[0][0], 'bool') 295 | finally: 296 | cursor.close() 297 | 298 | def test_no_previous_tpc(self): 299 | try: 300 | self.db.tpc_begin('Stacey') 301 | cursor = self.db.cursor() 302 | cursor.execute("SELECT * FROM pg_type") 303 | self.db.tpc_commit() 304 | finally: 305 | cursor.close() 306 | 307 | # Check that tpc_recover() doesn't start a transaction 308 | def test_tpc_recover(self): 309 | try: 310 | self.db.tpc_recover() 311 | cursor = self.db.cursor() 312 | self.db.autocommit = True 313 | 314 | # If tpc_recover() has started a transaction, this will fail 315 | cursor.execute("VACUUM") 316 | finally: 317 | cursor.close() 318 | 319 | # An empty query should raise a ProgrammingError 320 | def test_empty_query(self): 321 | try: 322 | cursor = self.db.cursor() 323 | self.assertRaises(pg8000.ProgrammingError, cursor.execute, "") 324 | finally: 325 | cursor.close() 326 | 327 
| # rolling back when not in a transaction doesn't generate a warning 328 | def test_rollback_no_transaction(self): 329 | try: 330 | # Remove any existing notices 331 | self.db.notices.clear() 332 | 333 | cursor = self.db.cursor() 334 | 335 | # First, verify that a raw rollback does produce a notice 336 | self.db.execute(cursor, "rollback", None) 337 | 338 | self.assertEqual(1, len(self.db.notices)) 339 | # 25P01 is the code for no_active_sql_tronsaction. It has 340 | # a message and severity name, but those might be 341 | # localized/depend on the server version. 342 | self.assertEqual(self.db.notices.pop().get(b'C'), b'25P01') 343 | 344 | # Now going through the rollback method doesn't produce 345 | # any notices because it knows we're not in a transaction. 346 | self.db.rollback() 347 | 348 | self.assertEqual(0, len(self.db.notices)) 349 | 350 | finally: 351 | cursor.close() 352 | 353 | def test_context_manager_class(self): 354 | self.assertTrue('__enter__' in pg8000.core.Cursor.__dict__) 355 | self.assertTrue('__exit__' in pg8000.core.Cursor.__dict__) 356 | 357 | with self.db.cursor() as cursor: 358 | cursor.execute('select 1') 359 | 360 | def test_deallocate_prepared_statements(self): 361 | try: 362 | cursor = self.db.cursor() 363 | cursor.execute("select * from t1") 364 | cursor.execute("alter table t1 drop column f3") 365 | cursor.execute("select count(*) from pg_prepared_statements") 366 | res = cursor.fetchall() 367 | self.assertEqual(res[0][0], 1) 368 | finally: 369 | cursor.close() 370 | self.db.rollback() 371 | 372 | if __name__ == "__main__": 373 | unittest.main() 374 | -------------------------------------------------------------------------------- /tests/test_typeconversion.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import pg8000 3 | from pg8000 import PGJsonb, PGEnum 4 | import datetime 5 | import decimal 6 | import struct 7 | from connection_settings import db_connect 8 | from six 
import b, PY2, u 9 | import uuid 10 | import os 11 | import time 12 | from distutils.version import LooseVersion 13 | import sys 14 | import json 15 | import pytz 16 | from collections import OrderedDict 17 | 18 | 19 | IS_JYTHON = sys.platform.lower().count('java') > 0 20 | 21 | 22 | # Type conversion tests 23 | class Tests(unittest.TestCase): 24 | def setUp(self): 25 | self.db = pg8000.connect(**db_connect) 26 | self.cursor = self.db.cursor() 27 | 28 | def tearDown(self): 29 | self.cursor.close() 30 | self.cursor = None 31 | self.db.close() 32 | 33 | def testTimeRoundtrip(self): 34 | self.cursor.execute("SELECT %s as f1", (datetime.time(4, 5, 6),)) 35 | retval = self.cursor.fetchall() 36 | self.assertEqual(retval[0][0], datetime.time(4, 5, 6)) 37 | 38 | def testDateRoundtrip(self): 39 | v = datetime.date(2001, 2, 3) 40 | self.cursor.execute("SELECT %s as f1", (v,)) 41 | retval = self.cursor.fetchall() 42 | self.assertEqual(retval[0][0], v) 43 | 44 | def testBoolRoundtrip(self): 45 | self.cursor.execute("SELECT %s as f1", (True,)) 46 | retval = self.cursor.fetchall() 47 | self.assertEqual(retval[0][0], True) 48 | 49 | def testNullRoundtrip(self): 50 | # We can't just "SELECT %s" and set None as the parameter, since it has 51 | # no type. That would result in a PG error, "could not determine data 52 | # type of parameter %s". So we create a temporary table, insert null 53 | # values, and read them back. 54 | self.cursor.execute( 55 | "CREATE TEMPORARY TABLE TestNullWrite " 56 | "(f1 int4, f2 timestamp, f3 varchar)") 57 | self.cursor.execute( 58 | "INSERT INTO TestNullWrite VALUES (%s, %s, %s)", 59 | (None, None, None,)) 60 | self.cursor.execute("SELECT * FROM TestNullWrite") 61 | retval = self.cursor.fetchone() 62 | self.assertEqual(retval, [None, None, None]) 63 | 64 | def testNullSelectFailure(self): 65 | # See comment in TestNullRoundtrip. This test is here to ensure that 66 | # this behaviour is documented and doesn't mysteriously change. 
67 | self.assertRaises( 68 | pg8000.ProgrammingError, self.cursor.execute, 69 | "SELECT %s as f1", (None,)) 70 | self.db.rollback() 71 | 72 | def testDecimalRoundtrip(self): 73 | values = ( 74 | "1.1", "-1.1", "10000", "20000", "-1000000000.123456789", "1.0", 75 | "12.44") 76 | for v in values: 77 | self.cursor.execute("SELECT %s as f1", (decimal.Decimal(v),)) 78 | retval = self.cursor.fetchall() 79 | self.assertEqual(str(retval[0][0]), v) 80 | 81 | def testFloatRoundtrip(self): 82 | # This test ensures that the binary float value doesn't change in a 83 | # roundtrip to the server. That could happen if the value was 84 | # converted to text and got rounded by a decimal place somewhere. 85 | val = 1.756e-12 86 | bin_orig = struct.pack("!d", val) 87 | self.cursor.execute("SELECT %s as f1", (val,)) 88 | retval = self.cursor.fetchall() 89 | bin_new = struct.pack("!d", retval[0][0]) 90 | self.assertEqual(bin_new, bin_orig) 91 | 92 | def test_float_plus_infinity_roundtrip(self): 93 | v = float('inf') 94 | self.cursor.execute("SELECT %s as f1", (v,)) 95 | retval = self.cursor.fetchall() 96 | self.assertEqual(retval[0][0], v) 97 | 98 | def testStrRoundtrip(self): 99 | v = "hello world" 100 | self.cursor.execute( 101 | "create temporary table test_str (f character varying(255))") 102 | self.cursor.execute("INSERT INTO test_str VALUES (%s)", (v,)) 103 | self.cursor.execute("SELECT * from test_str") 104 | retval = self.cursor.fetchall() 105 | self.assertEqual(retval[0][0], v) 106 | 107 | if PY2: 108 | v = "hello \xce\x94 world" 109 | self.cursor.execute("SELECT cast(%s as varchar) as f1", (v,)) 110 | retval = self.cursor.fetchall() 111 | self.assertEqual(retval[0][0], v.decode('utf8')) 112 | 113 | def test_str_then_int(self): 114 | v1 = "hello world" 115 | self.cursor.execute("SELECT cast(%s as varchar) as f1", (v1,)) 116 | retval = self.cursor.fetchall() 117 | self.assertEqual(retval[0][0], v1) 118 | 119 | v2 = 1 120 | self.cursor.execute("SELECT cast(%s as varchar) as f1", 
(v2,)) 121 | retval = self.cursor.fetchall() 122 | self.assertEqual(retval[0][0], str(v2)) 123 | 124 | def testUnicodeRoundtrip(self): 125 | v = u("hello \u0173 world") 126 | self.cursor.execute("SELECT cast(%s as varchar) as f1", (v,)) 127 | retval = self.cursor.fetchall() 128 | self.assertEqual(retval[0][0], v) 129 | 130 | def testLongRoundtrip(self): 131 | self.cursor.execute( 132 | "SELECT %s", (50000000000000,)) 133 | retval = self.cursor.fetchall() 134 | self.assertEqual(retval[0][0], 50000000000000) 135 | 136 | def testIntExecuteMany(self): 137 | self.cursor.executemany("SELECT %s", ((1,), (40000,))) 138 | self.cursor.fetchall() 139 | 140 | v = ([None], [4]) 141 | self.cursor.execute( 142 | "create temporary table test_int (f integer)") 143 | self.cursor.executemany("INSERT INTO test_int VALUES (%s)", v) 144 | self.cursor.execute("SELECT * from test_int") 145 | retval = self.cursor.fetchall() 146 | self.assertEqual(retval, v) 147 | 148 | def testIntRoundtrip(self): 149 | int2 = 21 150 | int4 = 23 151 | int8 = 20 152 | 153 | test_values = [ 154 | (0, int2), 155 | (-32767, int2), 156 | (-32768, int4), 157 | (+32767, int2), 158 | (+32768, int4), 159 | (-2147483647, int4), 160 | (-2147483648, int8), 161 | (+2147483647, int4), 162 | (+2147483648, int8), 163 | (-9223372036854775807, int8), 164 | (+9223372036854775807, int8)] 165 | 166 | for value, typoid in test_values: 167 | self.cursor.execute("SELECT %s", (value,)) 168 | retval = self.cursor.fetchall() 169 | self.assertEqual(retval[0][0], value) 170 | column_name, column_typeoid = self.cursor.description[0][0:2] 171 | self.assertEqual(column_typeoid, typoid) 172 | 173 | def testByteaRoundtrip(self): 174 | self.cursor.execute( 175 | "SELECT %s as f1", 176 | (pg8000.Binary(b("\x00\x01\x02\x03\x02\x01\x00")),)) 177 | retval = self.cursor.fetchall() 178 | self.assertEqual(retval[0][0], b("\x00\x01\x02\x03\x02\x01\x00")) 179 | 180 | def test_bytearray_round_trip(self): 181 | binary = b'\x00\x01\x02\x03\x02\x01\x00' 
182 | self.cursor.execute("SELECT %s as f1", (bytearray(binary),)) 183 | retval = self.cursor.fetchall() 184 | self.assertEqual(retval[0][0], binary) 185 | 186 | def test_bytearray_subclass_round_trip(self): 187 | class BClass(bytearray): 188 | pass 189 | binary = b'\x00\x01\x02\x03\x02\x01\x00' 190 | self.cursor.execute("SELECT %s as f1", (BClass(binary),)) 191 | retval = self.cursor.fetchall() 192 | self.assertEqual(retval[0][0], binary) 193 | 194 | def testTimestampRoundtrip(self): 195 | v = datetime.datetime(2001, 2, 3, 4, 5, 6, 170000) 196 | self.cursor.execute("SELECT %s as f1", (v,)) 197 | retval = self.cursor.fetchall() 198 | self.assertEqual(retval[0][0], v) 199 | 200 | # Test that time zone doesn't affect it 201 | # Jython 2.5.3 doesn't have a time.tzset() so skip 202 | if not IS_JYTHON: 203 | orig_tz = os.environ.get('TZ') 204 | os.environ['TZ'] = "America/Edmonton" 205 | time.tzset() 206 | 207 | self.cursor.execute("SELECT %s as f1", (v,)) 208 | retval = self.cursor.fetchall() 209 | self.assertEqual(retval[0][0], v) 210 | 211 | if orig_tz is None: 212 | del os.environ['TZ'] 213 | else: 214 | os.environ['TZ'] = orig_tz 215 | time.tzset() 216 | 217 | def testIntervalRoundtrip(self): 218 | v = pg8000.Interval(microseconds=123456789, days=2, months=24) 219 | self.cursor.execute("SELECT %s as f1", (v,)) 220 | retval = self.cursor.fetchall() 221 | self.assertEqual(retval[0][0], v) 222 | 223 | v = datetime.timedelta(seconds=30) 224 | self.cursor.execute("SELECT %s as f1", (v,)) 225 | retval = self.cursor.fetchall() 226 | self.assertEqual(retval[0][0], v) 227 | 228 | def test_enum_str_round_trip(self): 229 | try: 230 | self.cursor.execute( 231 | "create type lepton as enum ('electron', 'muon', 'tau')") 232 | except pg8000.ProgrammingError: 233 | self.db.rollback() 234 | 235 | v = 'muon' 236 | self.cursor.execute("SELECT cast(%s as lepton) as f1", (v,)) 237 | retval = self.cursor.fetchall() 238 | self.assertEqual(retval[0][0], v) 239 | self.cursor.execute( 240 | 
"CREATE TEMPORARY TABLE testenum " 241 | "(f1 lepton)") 242 | self.cursor.execute( 243 | "INSERT INTO testenum VALUES (cast(%s as lepton))", ('electron',)) 244 | self.cursor.execute("drop table testenum") 245 | self.cursor.execute("drop type lepton") 246 | self.db.commit() 247 | 248 | def test_enum_custom_round_trip(self): 249 | class Lepton(object): 250 | # Implements PEP 435 in the minimal fashion needed 251 | __members__ = OrderedDict() 252 | 253 | def __init__(self, name, value, alias=None): 254 | self.name = name 255 | self.value = value 256 | self.__members__[name] = self 257 | setattr(self.__class__, name, self) 258 | if alias: 259 | self.__members__[alias] = self 260 | setattr(self.__class__, alias, self) 261 | 262 | try: 263 | self.cursor.execute( 264 | "create type lepton as enum ('1', '2', '3')") 265 | except pg8000.ProgrammingError: 266 | self.db.rollback() 267 | 268 | v = Lepton('muon', '2') 269 | self.cursor.execute( 270 | "SELECT cast(%s as lepton) as f1", (PGEnum(v),)) 271 | retval = self.cursor.fetchall() 272 | self.assertEqual(retval[0][0], v.value) 273 | self.cursor.execute("drop type lepton") 274 | self.db.commit() 275 | 276 | def test_enum_py_round_trip(self): 277 | try: 278 | from enum import Enum 279 | 280 | class Lepton(Enum): 281 | electron = '1' 282 | muon = '2' 283 | tau = '3' 284 | 285 | try: 286 | self.cursor.execute( 287 | "create type lepton as enum ('1', '2', '3')") 288 | except pg8000.ProgrammingError: 289 | self.db.rollback() 290 | 291 | v = Lepton.muon 292 | self.cursor.execute("SELECT cast(%s as lepton) as f1", (v,)) 293 | retval = self.cursor.fetchall() 294 | self.assertEqual(retval[0][0], v.value) 295 | 296 | self.cursor.execute( 297 | "CREATE TEMPORARY TABLE testenum " 298 | "(f1 lepton)") 299 | self.cursor.execute( 300 | "INSERT INTO testenum VALUES (cast(%s as lepton))", 301 | (Lepton.electron,)) 302 | self.cursor.execute("drop table testenum") 303 | self.cursor.execute("drop type lepton") 304 | self.db.commit() 305 | except 
ImportError: 306 | pass 307 | 308 | def testXmlRoundtrip(self): 309 | v = 'gatccgagtac' 310 | self.cursor.execute("select xmlparse(content %s) as f1", (v,)) 311 | retval = self.cursor.fetchall() 312 | self.assertEqual(retval[0][0], v) 313 | 314 | def testUuidRoundtrip(self): 315 | v = uuid.UUID('911460f2-1f43-fea2-3e2c-e01fd5b5069d') 316 | self.cursor.execute("select %s as f1", (v,)) 317 | retval = self.cursor.fetchall() 318 | self.assertEqual(retval[0][0], v) 319 | 320 | def testInetRoundtrip(self): 321 | try: 322 | import ipaddress 323 | 324 | v = ipaddress.ip_network('192.168.0.0/28') 325 | self.cursor.execute("select %s as f1", (v,)) 326 | retval = self.cursor.fetchall() 327 | self.assertEqual(retval[0][0], v) 328 | 329 | v = ipaddress.ip_address('192.168.0.1') 330 | self.cursor.execute("select %s as f1", (v,)) 331 | retval = self.cursor.fetchall() 332 | self.assertEqual(retval[0][0], v) 333 | 334 | except ImportError: 335 | for v in ('192.168.100.128/25', '192.168.0.1'): 336 | self.cursor.execute( 337 | "select cast(cast(%s as varchar) as inet) as f1", (v,)) 338 | retval = self.cursor.fetchall() 339 | self.assertEqual(retval[0][0], v) 340 | 341 | def testXidRoundtrip(self): 342 | v = 86722 343 | self.cursor.execute( 344 | "select cast(cast(%s as varchar) as xid) as f1", (v,)) 345 | retval = self.cursor.fetchall() 346 | self.assertEqual(retval[0][0], v) 347 | 348 | # Should complete without an exception 349 | self.cursor.execute( 350 | "select * from pg_locks where transactionid = %s", (97712,)) 351 | retval = self.cursor.fetchall() 352 | 353 | def testInt2VectorIn(self): 354 | self.cursor.execute("select cast('1 2' as int2vector) as f1") 355 | retval = self.cursor.fetchall() 356 | self.assertEqual(retval[0][0], [1, 2]) 357 | 358 | # Should complete without an exception 359 | self.cursor.execute("select indkey from pg_index") 360 | retval = self.cursor.fetchall() 361 | 362 | def testTimestampTzOut(self): 363 | self.cursor.execute( 364 | "SELECT '2001-02-03 
04:05:06.17 America/Edmonton'" 365 | "::timestamp with time zone") 366 | retval = self.cursor.fetchall() 367 | dt = retval[0][0] 368 | self.assertEqual(dt.tzinfo is not None, True, "no tzinfo returned") 369 | self.assertEqual( 370 | dt.astimezone(pg8000.utc), 371 | datetime.datetime(2001, 2, 3, 11, 5, 6, 170000, pg8000.utc), 372 | "retrieved value match failed") 373 | 374 | def testTimestampTzRoundtrip(self): 375 | if not IS_JYTHON: 376 | mst = pytz.timezone("America/Edmonton") 377 | v1 = mst.localize(datetime.datetime(2001, 2, 3, 4, 5, 6, 170000)) 378 | self.cursor.execute("SELECT %s as f1", (v1,)) 379 | retval = self.cursor.fetchall() 380 | v2 = retval[0][0] 381 | self.assertNotEqual(v2.tzinfo, None) 382 | self.assertEqual(v1, v2) 383 | 384 | def testTimestampMismatch(self): 385 | if not IS_JYTHON: 386 | mst = pytz.timezone("America/Edmonton") 387 | self.cursor.execute("SET SESSION TIME ZONE 'America/Edmonton'") 388 | try: 389 | self.cursor.execute( 390 | "CREATE TEMPORARY TABLE TestTz " 391 | "(f1 timestamp with time zone, " 392 | "f2 timestamp without time zone)") 393 | self.cursor.execute( 394 | "INSERT INTO TestTz (f1, f2) VALUES (%s, %s)", ( 395 | # insert timestamp into timestamptz field (v1) 396 | datetime.datetime(2001, 2, 3, 4, 5, 6, 170000), 397 | # insert timestamptz into timestamp field (v2) 398 | mst.localize(datetime.datetime( 399 | 2001, 2, 3, 4, 5, 6, 170000)))) 400 | self.cursor.execute("SELECT f1, f2 FROM TestTz") 401 | retval = self.cursor.fetchall() 402 | 403 | # when inserting a timestamp into a timestamptz field, 404 | # postgresql assumes that it is in local time. So the value 405 | # that comes out will be the server's local time interpretation 406 | # of v1. We've set the server's TZ to MST, the time should 407 | # be... 
408 | f1 = retval[0][0] 409 | self.assertEqual( 410 | f1, datetime.datetime( 411 | 2001, 2, 3, 11, 5, 6, 170000, pytz.utc)) 412 | 413 | # inserting the timestamptz into a timestamp field, pg8000 414 | # converts the value into UTC, and then the PG server converts 415 | # it into local time for insertion into the field. When we 416 | # query for it, we get the same time back, like the tz was 417 | # dropped. 418 | f2 = retval[0][1] 419 | self.assertEqual( 420 | f2, datetime.datetime(2001, 2, 3, 4, 5, 6, 170000)) 421 | finally: 422 | self.cursor.execute("SET SESSION TIME ZONE DEFAULT") 423 | 424 | def testNameOut(self): 425 | # select a field that is of "name" type: 426 | self.cursor.execute("SELECT usename FROM pg_user") 427 | self.cursor.fetchall() 428 | # It is sufficient that no errors were encountered. 429 | 430 | def testOidOut(self): 431 | self.cursor.execute("SELECT oid FROM pg_type") 432 | self.cursor.fetchall() 433 | # It is sufficient that no errors were encountered. 434 | 435 | def testBooleanOut(self): 436 | self.cursor.execute("SELECT cast('t' as bool)") 437 | retval = self.cursor.fetchall() 438 | self.assertTrue(retval[0][0]) 439 | 440 | def testNumericOut(self): 441 | for num in ('5000', '50.34'): 442 | self.cursor.execute("SELECT " + num + "::numeric") 443 | retval = self.cursor.fetchall() 444 | self.assertEqual(str(retval[0][0]), num) 445 | 446 | def testInt2Out(self): 447 | self.cursor.execute("SELECT 5000::smallint") 448 | retval = self.cursor.fetchall() 449 | self.assertEqual(retval[0][0], 5000) 450 | 451 | def testInt4Out(self): 452 | self.cursor.execute("SELECT 5000::integer") 453 | retval = self.cursor.fetchall() 454 | self.assertEqual(retval[0][0], 5000) 455 | 456 | def testInt8Out(self): 457 | self.cursor.execute("SELECT 50000000000000::bigint") 458 | retval = self.cursor.fetchall() 459 | self.assertEqual(retval[0][0], 50000000000000) 460 | 461 | def testFloat4Out(self): 462 | self.cursor.execute("SELECT 1.1::real") 463 | retval = 
self.cursor.fetchall() 464 | self.assertEqual(retval[0][0], 1.1000000238418579) 465 | 466 | def testFloat8Out(self): 467 | self.cursor.execute("SELECT 1.1::double precision") 468 | retval = self.cursor.fetchall() 469 | self.assertEqual(retval[0][0], 1.1000000000000001) 470 | 471 | def testVarcharOut(self): 472 | self.cursor.execute("SELECT 'hello'::varchar(20)") 473 | retval = self.cursor.fetchall() 474 | self.assertEqual(retval[0][0], "hello") 475 | 476 | def testCharOut(self): 477 | self.cursor.execute("SELECT 'hello'::char(20)") 478 | retval = self.cursor.fetchall() 479 | self.assertEqual(retval[0][0], "hello ") 480 | 481 | def testTextOut(self): 482 | self.cursor.execute("SELECT 'hello'::text") 483 | retval = self.cursor.fetchall() 484 | self.assertEqual(retval[0][0], "hello") 485 | 486 | def testIntervalOut(self): 487 | self.cursor.execute( 488 | "SELECT '1 month 16 days 12 hours 32 minutes 64 seconds'" 489 | "::interval") 490 | retval = self.cursor.fetchall() 491 | expected_value = pg8000.Interval( 492 | microseconds=(12 * 60 * 60 * 1000 * 1000) + 493 | (32 * 60 * 1000 * 1000) + (64 * 1000 * 1000), 494 | days=16, months=1) 495 | self.assertEqual(retval[0][0], expected_value) 496 | 497 | self.cursor.execute("select interval '30 seconds'") 498 | retval = self.cursor.fetchall() 499 | expected_value = datetime.timedelta(seconds=30) 500 | self.assertEqual(retval[0][0], expected_value) 501 | 502 | self.cursor.execute("select interval '12 days 30 seconds'") 503 | retval = self.cursor.fetchall() 504 | expected_value = datetime.timedelta(days=12, seconds=30) 505 | self.assertEqual(retval[0][0], expected_value) 506 | 507 | def testTimestampOut(self): 508 | self.cursor.execute("SELECT '2001-02-03 04:05:06.17'::timestamp") 509 | retval = self.cursor.fetchall() 510 | self.assertEqual( 511 | retval[0][0], datetime.datetime(2001, 2, 3, 4, 5, 6, 170000)) 512 | 513 | # confirms that pg8000's binary output methods have the same output for 514 | # a data type as the PG server 
515 | def testBinaryOutputMethods(self): 516 | methods = ( 517 | ("float8send", 22.2), 518 | ("timestamp_send", datetime.datetime(2001, 2, 3, 4, 5, 6, 789)), 519 | ("byteasend", pg8000.Binary(b("\x01\x02"))), 520 | ("interval_send", pg8000.Interval(1234567, 123, 123)),) 521 | for method_out, value in methods: 522 | self.cursor.execute("SELECT %s(%%s) as f1" % method_out, (value,)) 523 | retval = self.cursor.fetchall() 524 | self.assertEqual( 525 | retval[0][0], self.db.make_params((value,))[0][2](value)) 526 | 527 | def testInt4ArrayOut(self): 528 | self.cursor.execute( 529 | "SELECT '{1,2,3,4}'::INT[] AS f1, " 530 | "'{{1,2,3},{4,5,6}}'::INT[][] AS f2, " 531 | "'{{{1,2},{3,4}},{{NULL,6},{7,8}}}'::INT[][][] AS f3") 532 | f1, f2, f3 = self.cursor.fetchone() 533 | self.assertEqual(f1, [1, 2, 3, 4]) 534 | self.assertEqual(f2, [[1, 2, 3], [4, 5, 6]]) 535 | self.assertEqual(f3, [[[1, 2], [3, 4]], [[None, 6], [7, 8]]]) 536 | 537 | def testInt2ArrayOut(self): 538 | self.cursor.execute( 539 | "SELECT '{1,2,3,4}'::INT2[] AS f1, " 540 | "'{{1,2,3},{4,5,6}}'::INT2[][] AS f2, " 541 | "'{{{1,2},{3,4}},{{NULL,6},{7,8}}}'::INT2[][][] AS f3") 542 | f1, f2, f3 = self.cursor.fetchone() 543 | self.assertEqual(f1, [1, 2, 3, 4]) 544 | self.assertEqual(f2, [[1, 2, 3], [4, 5, 6]]) 545 | self.assertEqual(f3, [[[1, 2], [3, 4]], [[None, 6], [7, 8]]]) 546 | 547 | def testInt8ArrayOut(self): 548 | self.cursor.execute( 549 | "SELECT '{1,2,3,4}'::INT8[] AS f1, " 550 | "'{{1,2,3},{4,5,6}}'::INT8[][] AS f2, " 551 | "'{{{1,2},{3,4}},{{NULL,6},{7,8}}}'::INT8[][][] AS f3") 552 | f1, f2, f3 = self.cursor.fetchone() 553 | self.assertEqual(f1, [1, 2, 3, 4]) 554 | self.assertEqual(f2, [[1, 2, 3], [4, 5, 6]]) 555 | self.assertEqual(f3, [[[1, 2], [3, 4]], [[None, 6], [7, 8]]]) 556 | 557 | def testBoolArrayOut(self): 558 | self.cursor.execute( 559 | "SELECT '{TRUE,FALSE,FALSE,TRUE}'::BOOL[] AS f1, " 560 | "'{{TRUE,FALSE,TRUE},{FALSE,TRUE,FALSE}}'::BOOL[][] AS f2, " 561 | 
"'{{{TRUE,FALSE},{FALSE,TRUE}},{{NULL,TRUE},{FALSE,FALSE}}}'" 562 | "::BOOL[][][] AS f3") 563 | f1, f2, f3 = self.cursor.fetchone() 564 | self.assertEqual(f1, [True, False, False, True]) 565 | self.assertEqual(f2, [[True, False, True], [False, True, False]]) 566 | self.assertEqual( 567 | f3, 568 | [[[True, False], [False, True]], [[None, True], [False, False]]]) 569 | 570 | def testFloat4ArrayOut(self): 571 | self.cursor.execute( 572 | "SELECT '{1,2,3,4}'::FLOAT4[] AS f1, " 573 | "'{{1,2,3},{4,5,6}}'::FLOAT4[][] AS f2, " 574 | "'{{{1,2},{3,4}},{{NULL,6},{7,8}}}'::FLOAT4[][][] AS f3") 575 | f1, f2, f3 = self.cursor.fetchone() 576 | self.assertEqual(f1, [1, 2, 3, 4]) 577 | self.assertEqual(f2, [[1, 2, 3], [4, 5, 6]]) 578 | self.assertEqual(f3, [[[1, 2], [3, 4]], [[None, 6], [7, 8]]]) 579 | 580 | def testFloat8ArrayOut(self): 581 | self.cursor.execute( 582 | "SELECT '{1,2,3,4}'::FLOAT8[] AS f1, " 583 | "'{{1,2,3},{4,5,6}}'::FLOAT8[][] AS f2, " 584 | "'{{{1,2},{3,4}},{{NULL,6},{7,8}}}'::FLOAT8[][][] AS f3") 585 | f1, f2, f3 = self.cursor.fetchone() 586 | self.assertEqual(f1, [1, 2, 3, 4]) 587 | self.assertEqual(f2, [[1, 2, 3], [4, 5, 6]]) 588 | self.assertEqual(f3, [[[1, 2], [3, 4]], [[None, 6], [7, 8]]]) 589 | 590 | def testIntArrayRoundtrip(self): 591 | # send small int array, should be sent as INT2[] 592 | self.cursor.execute("SELECT %s as f1", ([1, 2, 3],)) 593 | retval = self.cursor.fetchall() 594 | self.assertEqual(retval[0][0], [1, 2, 3]) 595 | column_name, column_typeoid = self.cursor.description[0][0:2] 596 | self.assertEqual(column_typeoid, 1005, "type should be INT2[]") 597 | 598 | # test multi-dimensional array, should be sent as INT2[] 599 | self.cursor.execute("SELECT %s as f1", ([[1, 2], [3, 4]],)) 600 | retval = self.cursor.fetchall() 601 | self.assertEqual(retval[0][0], [[1, 2], [3, 4]]) 602 | 603 | column_name, column_typeoid = self.cursor.description[0][0:2] 604 | self.assertEqual(column_typeoid, 1005, "type should be INT2[]") 605 | 606 | # a larger 
value should kick it up to INT4[]... 607 | self.cursor.execute("SELECT %s as f1 -- integer[]", ([70000, 2, 3],)) 608 | retval = self.cursor.fetchall() 609 | self.assertEqual(retval[0][0], [70000, 2, 3]) 610 | column_name, column_typeoid = self.cursor.description[0][0:2] 611 | self.assertEqual(column_typeoid, 1007, "type should be INT4[]") 612 | 613 | # a much larger value should kick it up to INT8[]... 614 | self.cursor.execute( 615 | "SELECT %s as f1 -- bigint[]", ([7000000000, 2, 3],)) 616 | retval = self.cursor.fetchall() 617 | self.assertEqual( 618 | retval[0][0], [7000000000, 2, 3], 619 | "retrieved value match failed") 620 | column_name, column_typeoid = self.cursor.description[0][0:2] 621 | self.assertEqual(column_typeoid, 1016, "type should be INT8[]") 622 | 623 | def testIntArrayWithNullRoundtrip(self): 624 | self.cursor.execute("SELECT %s as f1", ([1, None, 3],)) 625 | retval = self.cursor.fetchall() 626 | self.assertEqual(retval[0][0], [1, None, 3]) 627 | 628 | def testFloatArrayRoundtrip(self): 629 | self.cursor.execute("SELECT %s as f1", ([1.1, 2.2, 3.3],)) 630 | retval = self.cursor.fetchall() 631 | self.assertEqual(retval[0][0], [1.1, 2.2, 3.3]) 632 | 633 | def testBoolArrayRoundtrip(self): 634 | self.cursor.execute("SELECT %s as f1", ([True, False, None],)) 635 | retval = self.cursor.fetchall() 636 | self.assertEqual(retval[0][0], [True, False, None]) 637 | 638 | def testStringArrayOut(self): 639 | self.cursor.execute("SELECT '{a,b,c}'::TEXT[] AS f1") 640 | self.assertEqual(self.cursor.fetchone()[0], ["a", "b", "c"]) 641 | self.cursor.execute("SELECT '{a,b,c}'::CHAR[] AS f1") 642 | self.assertEqual(self.cursor.fetchone()[0], ["a", "b", "c"]) 643 | self.cursor.execute("SELECT '{a,b,c}'::VARCHAR[] AS f1") 644 | self.assertEqual(self.cursor.fetchone()[0], ["a", "b", "c"]) 645 | self.cursor.execute("SELECT '{a,b,c}'::CSTRING[] AS f1") 646 | self.assertEqual(self.cursor.fetchone()[0], ["a", "b", "c"]) 647 | self.cursor.execute("SELECT '{a,b,c}'::NAME[] 
AS f1") 648 | self.assertEqual(self.cursor.fetchone()[0], ["a", "b", "c"]) 649 | self.cursor.execute("SELECT '{}'::text[];") 650 | self.assertEqual(self.cursor.fetchone()[0], []) 651 | self.cursor.execute("SELECT '{NULL,\"NULL\",NULL,\"\"}'::text[];") 652 | self.assertEqual(self.cursor.fetchone()[0], [None, 'NULL', None, ""]) 653 | 654 | def testNumericArrayOut(self): 655 | self.cursor.execute("SELECT '{1.1,2.2,3.3}'::numeric[] AS f1") 656 | self.assertEqual( 657 | self.cursor.fetchone()[0], [ 658 | decimal.Decimal("1.1"), decimal.Decimal("2.2"), 659 | decimal.Decimal("3.3")]) 660 | 661 | def testNumericArrayRoundtrip(self): 662 | v = [decimal.Decimal("1.1"), None, decimal.Decimal("3.3")] 663 | self.cursor.execute("SELECT %s as f1", (v,)) 664 | retval = self.cursor.fetchall() 665 | self.assertEqual(retval[0][0], v) 666 | 667 | def testStringArrayRoundtrip(self): 668 | v = [ 669 | "Hello!", "World!", "abcdefghijklmnopqrstuvwxyz", "", 670 | "A bunch of random characters:", 671 | " ~!@#$%^&*()_+`1234567890-=[]\\{}|{;':\",./<>?\t", None] 672 | self.cursor.execute("SELECT %s as f1", (v,)) 673 | retval = self.cursor.fetchall() 674 | self.assertEqual(retval[0][0], v) 675 | 676 | def testUnicodeArrayRoundtrip(self): 677 | if PY2: 678 | v = map(unicode, ("Second", "To", None)) # noqa 679 | self.cursor.execute("SELECT %s as f1", (v,)) 680 | retval = self.cursor.fetchall() 681 | self.assertEqual(retval[0][0], v) 682 | 683 | def testEmptyArray(self): 684 | v = [] 685 | self.cursor.execute("SELECT %s as f1", (v,)) 686 | retval = self.cursor.fetchall() 687 | self.assertEqual(retval[0][0], v) 688 | 689 | def testArrayContentNotSupported(self): 690 | class Kajigger(object): 691 | pass 692 | self.assertRaises( 693 | pg8000.ArrayContentNotSupportedError, 694 | self.db.array_inspect, [[Kajigger()], [None], [None]]) 695 | self.db.rollback() 696 | 697 | def testArrayDimensions(self): 698 | for arr in ( 699 | [1, [2]], [[1], [2], [3, 4]], 700 | [[[1]], [[2]], [[3, 4]]], 701 | [[[1]], 
[[2]], [[3, 4]]], 702 | [[[[1]]], [[[2]]], [[[3, 4]]]], 703 | [[1, 2, 3], [4, [5], 6]]): 704 | 705 | arr_send = self.db.array_inspect(arr)[2] 706 | self.assertRaises( 707 | pg8000.ArrayDimensionsNotConsistentError, arr_send, arr) 708 | self.db.rollback() 709 | 710 | def testArrayHomogenous(self): 711 | arr = [[[1]], [[2]], [[3.1]]] 712 | arr_send = self.db.array_inspect(arr)[2] 713 | self.assertRaises( 714 | pg8000.ArrayContentNotHomogenousError, arr_send, arr) 715 | self.db.rollback() 716 | 717 | def testArrayInspect(self): 718 | self.db.array_inspect([1, 2, 3]) 719 | self.db.array_inspect([[1], [2], [3]]) 720 | self.db.array_inspect([[[1]], [[2]], [[3]]]) 721 | 722 | def testMacaddr(self): 723 | self.cursor.execute("SELECT macaddr '08002b:010203'") 724 | retval = self.cursor.fetchall() 725 | self.assertEqual(retval[0][0], "08:00:2b:01:02:03") 726 | 727 | def testTsvectorRoundtrip(self): 728 | self.cursor.execute( 729 | "SELECT cast(%s as tsvector)", 730 | ('a fat cat sat on a mat and ate a fat rat',)) 731 | retval = self.cursor.fetchall() 732 | self.assertEqual( 733 | retval[0][0], "'a' 'and' 'ate' 'cat' 'fat' 'mat' 'on' 'rat' 'sat'") 734 | 735 | def testHstoreRoundtrip(self): 736 | val = '"a"=>"1"' 737 | self.cursor.execute("SELECT cast(%s as hstore)", (val,)) 738 | retval = self.cursor.fetchall() 739 | self.assertEqual(retval[0][0], val) 740 | 741 | def testJsonRoundtrip(self): 742 | if self.db._server_version >= LooseVersion('9.2'): 743 | val = {'name': 'Apollo 11 Cave', 'zebra': True, 'age': 26.003} 744 | self.cursor.execute( 745 | "SELECT %s", (pg8000.PGJson(val),)) 746 | retval = self.cursor.fetchall() 747 | self.assertEqual(retval[0][0], val) 748 | 749 | def testJsonbRoundtrip(self): 750 | if self.db._server_version >= LooseVersion('9.4'): 751 | val = {'name': 'Apollo 11 Cave', 'zebra': True, 'age': 26.003} 752 | self.cursor.execute( 753 | "SELECT cast(%s as jsonb)", (json.dumps(val),)) 754 | retval = self.cursor.fetchall() 755 | 
self.assertEqual(retval[0][0], val) 756 | 757 | def test_json_access_object(self): 758 | if self.db._server_version >= LooseVersion('9.4'): 759 | val = {'name': 'Apollo 11 Cave', 'zebra': True, 'age': 26.003} 760 | self.cursor.execute( 761 | "SELECT cast(%s as json) -> %s", (json.dumps(val), 'name')) 762 | retval = self.cursor.fetchall() 763 | self.assertEqual(retval[0][0], 'Apollo 11 Cave') 764 | 765 | def test_jsonb_access_object(self): 766 | if self.db._server_version >= LooseVersion('9.4'): 767 | val = {'name': 'Apollo 11 Cave', 'zebra': True, 'age': 26.003} 768 | self.cursor.execute( 769 | "SELECT cast(%s as jsonb) -> %s", (json.dumps(val), 'name')) 770 | retval = self.cursor.fetchall() 771 | self.assertEqual(retval[0][0], 'Apollo 11 Cave') 772 | 773 | def test_json_access_array(self): 774 | if self.db._server_version >= LooseVersion('9.4'): 775 | val = [-1, -2, -3, -4, -5] 776 | self.cursor.execute( 777 | "SELECT cast(%s as json) -> %s", (json.dumps(val), 2)) 778 | retval = self.cursor.fetchall() 779 | self.assertEqual(retval[0][0], -3) 780 | 781 | def testJsonbAccessArray(self): 782 | if self.db._server_version >= LooseVersion('9.4'): 783 | val = [-1, -2, -3, -4, -5] 784 | self.cursor.execute( 785 | "SELECT cast(%s as jsonb) -> %s", (json.dumps(val), 2)) 786 | retval = self.cursor.fetchall() 787 | self.assertEqual(retval[0][0], -3) 788 | 789 | def test_jsonb_access_path(self): 790 | if self.db._server_version >= LooseVersion('9.4'): 791 | j = { 792 | "a": [1, 2, 3], 793 | "b": [4, 5, 6]} 794 | 795 | path = ['a', '2'] 796 | 797 | self.cursor.execute("SELECT %s #>> %s", [PGJsonb(j), path]) 798 | retval = self.cursor.fetchall() 799 | self.assertEqual(retval[0][0], str(j[path[0]][int(path[1])])) 800 | 801 | def test_timestamp_send_float(self): 802 | assert b('A\xbe\x19\xcf\x80\x00\x00\x00') == \ 803 | pg8000.core.timestamp_send_float( 804 | datetime.datetime(2016, 1, 2, 0, 0)) 805 | 806 | def test_infinity_timestamp_roundtrip(self): 807 | v = 'infinity' 808 | 
self.cursor.execute("SELECT cast(%s as timestamp) as f1", (v,)) 809 | retval = self.cursor.fetchall() 810 | self.assertEqual(retval[0][0], v) 811 | 812 | 813 | if __name__ == "__main__": 814 | unittest.main() 815 | -------------------------------------------------------------------------------- /tests/test_typeobjects.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from pg8000 import Interval 3 | 4 | 5 | # Type conversion tests 6 | class Tests(unittest.TestCase): 7 | def setUp(self): 8 | pass 9 | 10 | def tearDown(self): 11 | pass 12 | 13 | def testIntervalConstructor(self): 14 | i = Interval(days=1) 15 | self.assertEqual(i.months, 0) 16 | self.assertEqual(i.days, 1) 17 | self.assertEqual(i.microseconds, 0) 18 | 19 | def intervalRangeTest(self, parameter, in_range, out_of_range): 20 | for v in out_of_range: 21 | try: 22 | Interval(**{parameter: v}) 23 | self.fail("expected OverflowError") 24 | except OverflowError: 25 | pass 26 | for v in in_range: 27 | Interval(**{parameter: v}) 28 | 29 | def testIntervalDaysRange(self): 30 | out_of_range_days = (-2147483648, +2147483648,) 31 | in_range_days = (-2147483647, +2147483647,) 32 | self.intervalRangeTest("days", in_range_days, out_of_range_days) 33 | 34 | def testIntervalMonthsRange(self): 35 | out_of_range_months = (-2147483648, +2147483648,) 36 | in_range_months = (-2147483647, +2147483647,) 37 | self.intervalRangeTest("months", in_range_months, out_of_range_months) 38 | 39 | def testIntervalMicrosecondsRange(self): 40 | out_of_range_microseconds = ( 41 | -9223372036854775808, +9223372036854775808,) 42 | in_range_microseconds = ( 43 | -9223372036854775807, +9223372036854775807,) 44 | self.intervalRangeTest( 45 | "microseconds", in_range_microseconds, out_of_range_microseconds) 46 | 47 | 48 | if __name__ == "__main__": 49 | unittest.main() 50 | -------------------------------------------------------------------------------- /tox.ini: 
-------------------------------------------------------------------------------- 1 | # Tox (http://tox.testrun.org/) is a tool for running tests 2 | # in multiple virtualenvs. This configuration file will run the 3 | # test suite on all supported python versions. To use it, "pip install tox" 4 | # and then run "tox" from this directory. 5 | 6 | [tox] 7 | skip_missing_interpreters=True 8 | 9 | 10 | [testenv] 11 | commands = 12 | nosetests -x 13 | python -m doctest README.adoc 14 | flake8 pg8000 15 | python setup.py check 16 | deps = 17 | nose 18 | flake8 19 | pytz 20 | -------------------------------------------------------------------------------- /versioneer.py: -------------------------------------------------------------------------------- 1 | 2 | # Version: 0.15 3 | 4 | """ 5 | The Versioneer 6 | ============== 7 | 8 | * like a rocketeer, but for versions! 9 | * https://github.com/warner/python-versioneer 10 | * Brian Warner 11 | * License: Public Domain 12 | * Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, and pypy 13 | * [![Latest Version] 14 | (https://pypip.in/version/versioneer/badge.svg?style=flat) 15 | ](https://pypi.python.org/pypi/versioneer/) 16 | * [![Build Status] 17 | (https://travis-ci.org/warner/python-versioneer.png?branch=master) 18 | ](https://travis-ci.org/warner/python-versioneer) 19 | 20 | This is a tool for managing a recorded version number in distutils-based 21 | python projects. The goal is to remove the tedious and error-prone "update 22 | the embedded version string" step from your release process. Making a new 23 | release should be as easy as recording a new tag in your version-control 24 | system, and maybe making new tarballs. 
25 | 26 | 27 | ## Quick Install 28 | 29 | * `pip install versioneer` to somewhere in your $PATH 30 | * add a `[versioneer]` section to your setup.cfg (see below) 31 | * run `versioneer install` in your source tree, commit the results 32 | 33 | ## Version Identifiers 34 | 35 | Source trees come from a variety of places: 36 | 37 | * a version-control system checkout (mostly used by developers) 38 | * a nightly tarball, produced by build automation 39 | * a snapshot tarball, produced by a web-based VCS browser, like github's 40 | "tarball from tag" feature 41 | * a release tarball, produced by "setup.py sdist", distributed through PyPI 42 | 43 | Within each source tree, the version identifier (either a string or a number, 44 | this tool is format-agnostic) can come from a variety of places: 45 | 46 | * ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows 47 | about recent "tags" and an absolute revision-id 48 | * the name of the directory into which the tarball was unpacked 49 | * an expanded VCS keyword ($Id$, etc) 50 | * a `_version.py` created by some earlier build step 51 | 52 | For released software, the version identifier is closely related to a VCS 53 | tag. Some projects use tag names that include more than just the version 54 | string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool 55 | needs to strip the tag prefix to extract the version identifier. For 56 | unreleased software (between tags), the version identifier should provide 57 | enough information to help developers recreate the same tree, while also 58 | giving them an idea of roughly how old the tree is (after version 1.2, before 59 | version 1.3).
Many VCS systems can report a description that captures this, 60 | for example `git describe --tags --dirty --always` reports things like 61 | "0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the 62 | 0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has 63 | uncommitted changes). 64 | 65 | The version identifier is used for multiple purposes: 66 | 67 | * to allow the module to self-identify its version: `myproject.__version__` 68 | * to choose a name and prefix for a 'setup.py sdist' tarball 69 | 70 | ## Theory of Operation 71 | 72 | Versioneer works by adding a special `_version.py` file into your source 73 | tree, where your `__init__.py` can import it. This `_version.py` knows how to 74 | dynamically ask the VCS tool for version information at import time. 75 | 76 | `_version.py` also contains `$Revision$` markers, and the installation 77 | process marks `_version.py` to have this marker rewritten with a tag name 78 | during the `git archive` command. As a result, generated tarballs will 79 | contain enough information to get the proper version. 80 | 81 | To allow `setup.py` to compute a version too, a `versioneer.py` is added to 82 | the top level of your source tree, next to `setup.py` and the `setup.cfg` 83 | that configures it. This overrides several distutils/setuptools commands to 84 | compute the version when invoked, and changes `setup.py build` and `setup.py 85 | sdist` to replace `_version.py` with a small static file that contains just 86 | the generated version data. 87 | 88 | ## Installation 89 | 90 | First, decide on values for the following configuration variables: 91 | 92 | * `VCS`: the version control system you use. Currently accepts "git". 93 | 94 | * `style`: the style of version string to be produced. See "Styles" below for 95 | details. Defaults to "pep440", which looks like 96 | `TAG[+DISTANCE.gSHORTHASH[.dirty]]`.
97 | 98 | * `versionfile_source`: 99 | 100 | A project-relative pathname into which the generated version strings should 101 | be written. This is usually a `_version.py` next to your project's main 102 | `__init__.py` file, so it can be imported at runtime. If your project uses 103 | `src/myproject/__init__.py`, this should be `src/myproject/_version.py`. 104 | This file should be checked in to your VCS as usual: the copy created below 105 | by `setup.py setup_versioneer` will include code that parses expanded VCS 106 | keywords in generated tarballs. The 'build' and 'sdist' commands will 107 | replace it with a copy that has just the calculated version string. 108 | 109 | This must be set even if your project does not have any modules (and will 110 | therefore never import `_version.py`), since "setup.py sdist" -based trees 111 | still need somewhere to record the pre-calculated version strings. Anywhere 112 | in the source tree should do. If there is a `__init__.py` next to your 113 | `_version.py`, the `setup.py setup_versioneer` command (described below) 114 | will append some `__version__`-setting assignments, if they aren't already 115 | present. 116 | 117 | * `versionfile_build`: 118 | 119 | Like `versionfile_source`, but relative to the build directory instead of 120 | the source directory. These will differ when your setup.py uses 121 | 'package_dir='. If you have `package_dir={'myproject': 'src/myproject'}`, 122 | then you will probably have `versionfile_build='myproject/_version.py'` and 123 | `versionfile_source='src/myproject/_version.py'`. 124 | 125 | If this is set to None, then `setup.py build` will not attempt to rewrite 126 | any `_version.py` in the built tree. If your project does not have any 127 | libraries (e.g. if it only builds a script), then you should use 128 | `versionfile_build = None` and override `distutils.command.build_scripts` 129 | to explicitly insert a copy of `versioneer.get_version()` into your 130 | generated script. 
131 | 132 | * `tag_prefix`: 133 | 134 | a string, like 'PROJECTNAME-', which appears at the start of all VCS tags. 135 | If your tags look like 'myproject-1.2.0', then you should use 136 | tag_prefix='myproject-'. If you use unprefixed tags like '1.2.0', this 137 | should be an empty string. 138 | 139 | * `parentdir_prefix`: 140 | 141 | an optional string, frequently the same as tag_prefix, which appears at the 142 | start of all unpacked tarball filenames. If your tarball unpacks into 143 | 'myproject-1.2.0', this should be 'myproject-'. To disable this feature, 144 | just omit the field from your `setup.cfg`. 145 | 146 | This tool provides one script, named `versioneer`. That script has one mode, 147 | "install", which writes a copy of `versioneer.py` into the current directory 148 | and runs `versioneer.py setup` to finish the installation. 149 | 150 | To versioneer-enable your project: 151 | 152 | * 1: Modify your `setup.cfg`, adding a section named `[versioneer]` and 153 | populating it with the configuration values you decided earlier (note that 154 | the option names are not case-sensitive): 155 | 156 | ```` 157 | [versioneer] 158 | VCS = git 159 | style = pep440 160 | versionfile_source = src/myproject/_version.py 161 | versionfile_build = myproject/_version.py 162 | tag_prefix = 163 | parentdir_prefix = myproject- 164 | ```` 165 | 166 | * 2: Run `versioneer install`. This will do the following: 167 | 168 | * copy `versioneer.py` into the top of your source tree 169 | * create `_version.py` in the right place (`versionfile_source`) 170 | * modify your `__init__.py` (if one exists next to `_version.py`) to define 171 | `__version__` (by calling a function from `_version.py`) 172 | * modify your `MANIFEST.in` to include both `versioneer.py` and the 173 | generated `_version.py` in sdist tarballs 174 | 175 | `versioneer install` will complain about any problems it finds with your 176 | `setup.py` or `setup.cfg`.
Run it multiple times until you have fixed all 177 | the problems. 178 | 179 | * 3: add an `import versioneer` to your setup.py, and add the following 180 | arguments to the setup() call: 181 | 182 | version=versioneer.get_version(), 183 | cmdclass=versioneer.get_cmdclass(), 184 | 185 | * 4: commit these changes to your VCS. To make sure you won't forget, 186 | `versioneer install` will mark everything it touched for addition using 187 | `git add`. Don't forget to add `setup.py` and `setup.cfg` too. 188 | 189 | ## Post-Installation Usage 190 | 191 | Once established, all uses of your tree from a VCS checkout should get the 192 | current version string. All generated tarballs should include an embedded 193 | version string (so users who unpack them will not need a VCS tool installed). 194 | 195 | If you distribute your project through PyPI, then the release process should 196 | boil down to two steps: 197 | 198 | * 1: git tag 1.0 199 | * 2: python setup.py register sdist upload 200 | 201 | If you distribute it through github (i.e. users use github to generate 202 | tarballs with `git archive`), the process is: 203 | 204 | * 1: git tag 1.0 205 | * 2: git push; git push --tags 206 | 207 | Versioneer will report "0+untagged.NUMCOMMITS.gHASH" until your tree has at 208 | least one tag in its history. 209 | 210 | ## Version-String Flavors 211 | 212 | Code which uses Versioneer can learn about its version string at runtime by 213 | importing `_version` from your main `__init__.py` file and running the 214 | `get_versions()` function. From the "outside" (e.g. in `setup.py`), you can 215 | import the top-level `versioneer.py` and run `get_versions()`. 216 | 217 | Both functions return a dictionary with different flavors of version 218 | information: 219 | 220 | * `['version']`: A condensed version string, rendered using the selected 221 | style. This is the most commonly used value for the project's version 222 | string.
The default "pep440" style yields strings like `0.11`, 223 | `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section 224 | below for alternative styles. 225 | 226 | * `['full-revisionid']`: detailed revision identifier. For Git, this is the 227 | full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". 228 | 229 | * `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that 230 | this is only accurate if run in a VCS checkout, otherwise it is likely to 231 | be False or None. 232 | 233 | * `['error']`: if the version string could not be computed, this will be set 234 | to a string describing the problem, otherwise it will be None. It may be 235 | useful to throw an exception in setup.py if this is set, to avoid e.g. 236 | creating tarballs with a version string of "unknown". 237 | 238 | Some variants are more useful than others. Including `full-revisionid` in a 239 | bug report should allow developers to reconstruct the exact code being tested 240 | (or indicate the presence of local changes that should be shared with the 241 | developers). `version` is suitable for display in an "about" box or a CLI 242 | `--version` output: it can be easily compared against release notes and lists 243 | of bugs fixed in various releases. 244 | 245 | The installer adds the following text to your `__init__.py` to place a basic 246 | version in `YOURPROJECT.__version__`: 247 | 248 | from ._version import get_versions 249 | __version__ = get_versions()['version'] 250 | del get_versions 251 | 252 | ## Styles 253 | 254 | The setup.cfg `style=` configuration controls how the VCS information is 255 | rendered into a version string. 256 | 257 | The default style, "pep440", produces a PEP440-compliant string, equal to the 258 | un-prefixed tag name for actual releases, and containing an additional "local 259 | version" section with more detail for in-between builds.
For Git, this is 260 | TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags 261 | --dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the 262 | tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and 263 | that this commit is two revisions ("+2") beyond the "0.11" tag. For released 264 | software (exactly equal to a known tag), the identifier will only contain the 265 | stripped tag, e.g. "0.11". 266 | 267 | Other styles are available. See details.md in the Versioneer source tree for 268 | descriptions. 269 | 270 | ## Debugging 271 | 272 | Versioneer tries to avoid fatal errors: if something goes wrong, it will tend 273 | to return a version of "0+unknown". To investigate the problem, run `setup.py 274 | version`, which will run the version-lookup code in a verbose mode, and will 275 | display the full contents of `get_versions()` (including the `error` string, 276 | which may help identify what went wrong). 277 | 278 | ## Updating Versioneer 279 | 280 | To upgrade your project to a new release of Versioneer, do the following: 281 | 282 | * install the new Versioneer (`pip install -U versioneer` or equivalent) 283 | * edit `setup.cfg`, if necessary, to include any new configuration settings 284 | indicated by the release notes 285 | * re-run `versioneer install` in your source tree, to replace 286 | `SRC/_version.py` 287 | * commit any changed files 288 | 289 | ### Upgrading to 0.15 290 | 291 | Starting with this version, Versioneer is configured with a `[versioneer]` 292 | section in your `setup.cfg` file. Earlier versions required the `setup.py` to 293 | set attributes on the `versioneer` module immediately after import. The new 294 | version will refuse to run (raising an exception during import) until you 295 | have provided the necessary `setup.cfg` section. 
296 | 297 | In addition, the Versioneer package provides an executable named 298 | `versioneer`, and the installation process is driven by running `versioneer 299 | install`. In 0.14 and earlier, the executable was named 300 | `versioneer-installer` and was run without an argument. 301 | 302 | ### Upgrading to 0.14 303 | 304 | 0.14 changes the format of the version string. 0.13 and earlier used 305 | hyphen-separated strings like "0.11-2-g1076c97-dirty". 0.14 and beyond use a 306 | plus-separated "local version" section strings, with dot-separated 307 | components, like "0.11+2.g1076c97". PEP440-strict tools did not like the old 308 | format, but should be ok with the new one. 309 | 310 | ### Upgrading from 0.11 to 0.12 311 | 312 | Nothing special. 313 | 314 | ### Upgrading from 0.10 to 0.11 315 | 316 | You must add a `versioneer.VCS = "git"` to your `setup.py` before re-running 317 | `setup.py setup_versioneer`. This will enable the use of additional 318 | version-control systems (SVN, etc) in the future. 319 | 320 | ## Future Directions 321 | 322 | This tool is designed to be easily extended to other version-control 323 | systems: all VCS-specific components are in separate directories like 324 | src/git/ . The top-level `versioneer.py` script is assembled from these 325 | components by running make-versioneer.py . In the future, make-versioneer.py 326 | will take a VCS name as an argument, and will construct a version of 327 | `versioneer.py` that is specific to the given VCS. It might also take the 328 | configuration arguments that are currently provided manually during 329 | installation by editing setup.py . Alternatively, it might go the other 330 | direction and include code from all supported VCS systems, reducing the 331 | number of intermediate scripts. 332 | 333 | 334 | ## License 335 | 336 | To make Versioneer easier to embed, all its code is hereby released into the 337 | public domain.
The `_version.py` that it creates is also in the public 338 | domain. 339 | 340 | """ 341 | 342 | from __future__ import print_function 343 | try: 344 | import configparser 345 | except ImportError: 346 | import ConfigParser as configparser 347 | import errno 348 | import json 349 | import os 350 | import re 351 | import subprocess 352 | import sys 353 | 354 | 355 | class VersioneerConfig: 356 | pass 357 | 358 | 359 | def get_root(): 360 | # we require that all commands are run from the project root, i.e. the 361 | # directory that contains setup.py, setup.cfg, and versioneer.py . 362 | root = os.path.realpath(os.path.abspath(os.getcwd())) 363 | setup_py = os.path.join(root, "setup.py") 364 | versioneer_py = os.path.join(root, "versioneer.py") 365 | if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): 366 | # allow 'python path/to/setup.py COMMAND' 367 | root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) 368 | setup_py = os.path.join(root, "setup.py") 369 | versioneer_py = os.path.join(root, "versioneer.py") 370 | if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): 371 | err = ("Versioneer was unable to run the project root directory. " 372 | "Versioneer requires setup.py to be executed from " 373 | "its immediate directory (like 'python setup.py COMMAND'), " 374 | "or in a way that lets it use sys.argv[0] to find the root " 375 | "(like 'python path/to/setup.py COMMAND').") 376 | raise VersioneerBadRootError(err) 377 | try: 378 | # Certain runtime workflows (setup.py install/develop in a setuptools 379 | # tree) execute all dependencies in a single python process, so 380 | # "versioneer" may be imported multiple times, and python's shared 381 | # module-import table will cache the first one. So we can't use 382 | # os.path.dirname(__file__), as that will find whichever 383 | # versioneer.py was first imported, even in later projects. 
384 | me = os.path.realpath(os.path.abspath(__file__)) 385 | if os.path.splitext(me)[0] != os.path.splitext(versioneer_py)[0]: 386 | print("Warning: build in %s is using versioneer.py from %s" 387 | % (os.path.dirname(me), versioneer_py)) 388 | except NameError: 389 | pass 390 | return root 391 | 392 | 393 | def get_config_from_root(root): 394 | # This might raise EnvironmentError (if setup.cfg is missing), or 395 | # configparser.NoSectionError (if it lacks a [versioneer] section), or 396 | # configparser.NoOptionError (if it lacks "VCS="). See the docstring at 397 | # the top of versioneer.py for instructions on writing your setup.cfg . 398 | setup_cfg = os.path.join(root, "setup.cfg") 399 | parser = configparser.SafeConfigParser() 400 | with open(setup_cfg, "r") as f: 401 | parser.readfp(f) 402 | VCS = parser.get("versioneer", "VCS") # mandatory 403 | 404 | def get(parser, name): 405 | if parser.has_option("versioneer", name): 406 | return parser.get("versioneer", name) 407 | return None 408 | cfg = VersioneerConfig() 409 | cfg.VCS = VCS 410 | cfg.style = get(parser, "style") or "" 411 | cfg.versionfile_source = get(parser, "versionfile_source") 412 | cfg.versionfile_build = get(parser, "versionfile_build") 413 | cfg.tag_prefix = get(parser, "tag_prefix") 414 | cfg.parentdir_prefix = get(parser, "parentdir_prefix") 415 | cfg.verbose = get(parser, "verbose") 416 | return cfg 417 | 418 | 419 | class NotThisMethod(Exception): 420 | pass 421 | 422 | # these dictionaries contain VCS-specific tools 423 | LONG_VERSION_PY = {} 424 | HANDLERS = {} 425 | 426 | 427 | def register_vcs_handler(vcs, method): # decorator 428 | def decorate(f): 429 | if vcs not in HANDLERS: 430 | HANDLERS[vcs] = {} 431 | HANDLERS[vcs][method] = f 432 | return f 433 | return decorate 434 | 435 | 436 | def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False): 437 | assert isinstance(commands, list) 438 | p = None 439 | for c in commands: 440 | try: 441 | dispcmd = str([c] + 
args) 442 | # remember shell=False, so use git.cmd on windows, not just git 443 | p = subprocess.Popen([c] + args, cwd=cwd, stdout=subprocess.PIPE, 444 | stderr=(subprocess.PIPE if hide_stderr 445 | else None)) 446 | break 447 | except EnvironmentError: 448 | e = sys.exc_info()[1] 449 | if e.errno == errno.ENOENT: 450 | continue 451 | if verbose: 452 | print("unable to run %s" % dispcmd) 453 | print(e) 454 | return None 455 | else: 456 | if verbose: 457 | print("unable to find command, tried %s" % (commands,)) 458 | return None 459 | stdout = p.communicate()[0].strip() 460 | if sys.version_info[0] >= 3: 461 | stdout = stdout.decode() 462 | if p.returncode != 0: 463 | if verbose: 464 | print("unable to run %s (error)" % dispcmd) 465 | return None 466 | return stdout 467 | LONG_VERSION_PY['git'] = ''' 468 | # This file helps to compute a version number in source trees obtained from 469 | # git-archive tarball (such as those provided by githubs download-from-tag 470 | # feature). Distribution tarballs (built by setup.py sdist) and build 471 | # directories (produced by setup.py build) will contain a much shorter file 472 | # that just contains the computed version number. 473 | 474 | # This file is released into the public domain. Generated by 475 | # versioneer-0.15 (https://github.com/warner/python-versioneer) 476 | 477 | import errno 478 | import os 479 | import re 480 | import subprocess 481 | import sys 482 | 483 | 484 | def get_keywords(): 485 | # these strings will be replaced by git during git-archive. 486 | # setup.py/versioneer.py will grep for the variable names, so they must 487 | # each be defined on a line of their own. _version.py will just call 488 | # get_keywords(). 
489 | git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" 490 | git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" 491 | keywords = {"refnames": git_refnames, "full": git_full} 492 | return keywords 493 | 494 | 495 | class VersioneerConfig: 496 | pass 497 | 498 | 499 | def get_config(): 500 | # these strings are filled in when 'setup.py versioneer' creates 501 | # _version.py 502 | cfg = VersioneerConfig() 503 | cfg.VCS = "git" 504 | cfg.style = "%(STYLE)s" 505 | cfg.tag_prefix = "%(TAG_PREFIX)s" 506 | cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" 507 | cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" 508 | cfg.verbose = False 509 | return cfg 510 | 511 | 512 | class NotThisMethod(Exception): 513 | pass 514 | 515 | 516 | LONG_VERSION_PY = {} 517 | HANDLERS = {} 518 | 519 | 520 | def register_vcs_handler(vcs, method): # decorator 521 | def decorate(f): 522 | if vcs not in HANDLERS: 523 | HANDLERS[vcs] = {} 524 | HANDLERS[vcs][method] = f 525 | return f 526 | return decorate 527 | 528 | 529 | def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False): 530 | assert isinstance(commands, list) 531 | p = None 532 | for c in commands: 533 | try: 534 | dispcmd = str([c] + args) 535 | # remember shell=False, so use git.cmd on windows, not just git 536 | p = subprocess.Popen([c] + args, cwd=cwd, stdout=subprocess.PIPE, 537 | stderr=(subprocess.PIPE if hide_stderr 538 | else None)) 539 | break 540 | except EnvironmentError: 541 | e = sys.exc_info()[1] 542 | if e.errno == errno.ENOENT: 543 | continue 544 | if verbose: 545 | print("unable to run %%s" %% dispcmd) 546 | print(e) 547 | return None 548 | else: 549 | if verbose: 550 | print("unable to find command, tried %%s" %% (commands,)) 551 | return None 552 | stdout = p.communicate()[0].strip() 553 | if sys.version_info[0] >= 3: 554 | stdout = stdout.decode() 555 | if p.returncode != 0: 556 | if verbose: 557 | print("unable to run %%s (error)" %% dispcmd) 558 | return None 559 | return stdout 560 | 561 | 562 | def 
versions_from_parentdir(parentdir_prefix, root, verbose): 563 | # Source tarballs conventionally unpack into a directory that includes 564 | # both the project name and a version string. 565 | dirname = os.path.basename(root) 566 | if not dirname.startswith(parentdir_prefix): 567 | if verbose: 568 | print("guessing rootdir is '%%s', but '%%s' doesn't start with " 569 | "prefix '%%s'" %% (root, dirname, parentdir_prefix)) 570 | raise NotThisMethod("rootdir doesn't start with parentdir_prefix") 571 | return {"version": dirname[len(parentdir_prefix):], 572 | "full-revisionid": None, 573 | "dirty": False, "error": None} 574 | 575 | 576 | @register_vcs_handler("git", "get_keywords") 577 | def git_get_keywords(versionfile_abs): 578 | # the code embedded in _version.py can just fetch the value of these 579 | # keywords. When used from setup.py, we don't want to import _version.py, 580 | # so we do it with a regexp instead. This function is not used from 581 | # _version.py. 582 | keywords = {} 583 | try: 584 | f = open(versionfile_abs, "r") 585 | for line in f.readlines(): 586 | if line.strip().startswith("git_refnames ="): 587 | mo = re.search(r'=\s*"(.*)"', line) 588 | if mo: 589 | keywords["refnames"] = mo.group(1) 590 | if line.strip().startswith("git_full ="): 591 | mo = re.search(r'=\s*"(.*)"', line) 592 | if mo: 593 | keywords["full"] = mo.group(1) 594 | f.close() 595 | except EnvironmentError: 596 | pass 597 | return keywords 598 | 599 | 600 | @register_vcs_handler("git", "keywords") 601 | def git_versions_from_keywords(keywords, tag_prefix, verbose): 602 | if not keywords: 603 | raise NotThisMethod("no keywords at all, weird") 604 | refnames = keywords["refnames"].strip() 605 | if refnames.startswith("$Format"): 606 | if verbose: 607 | print("keywords are unexpanded, not using") 608 | raise NotThisMethod("unexpanded keywords, not a git-archive tarball") 609 | refs = set([r.strip() for r in refnames.strip("()").split(",")]) 610 | # starting in git-1.8.3, tags are 
listed as "tag: foo-1.0" instead of 611 | # just "foo-1.0". If we see a "tag: " prefix, prefer those. 612 | TAG = "tag: " 613 | tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) 614 | if not tags: 615 | # Either we're using git < 1.8.3, or there really are no tags. We use 616 | # a heuristic: assume all version tags have a digit. The old git %%d 617 | # expansion behaves like git log --decorate=short and strips out the 618 | # refs/heads/ and refs/tags/ prefixes that would let us distinguish 619 | # between branches and tags. By ignoring refnames without digits, we 620 | # filter out many common branch names like "release" and 621 | # "stabilization", as well as "HEAD" and "master". 622 | tags = set([r for r in refs if re.search(r'\d', r)]) 623 | if verbose: 624 | print("discarding '%%s', no digits" %% ",".join(refs-tags)) 625 | if verbose: 626 | print("likely tags: %%s" %% ",".join(sorted(tags))) 627 | for ref in sorted(tags): 628 | # sorting will prefer e.g. "2.0" over "2.0rc1" 629 | if ref.startswith(tag_prefix): 630 | r = ref[len(tag_prefix):] 631 | if verbose: 632 | print("picking %%s" %% r) 633 | return {"version": r, 634 | "full-revisionid": keywords["full"].strip(), 635 | "dirty": False, "error": None 636 | } 637 | # no suitable tags, so version is "0+unknown", but full hex is still there 638 | if verbose: 639 | print("no suitable tags, using unknown + full revision id") 640 | return {"version": "0+unknown", 641 | "full-revisionid": keywords["full"].strip(), 642 | "dirty": False, "error": "no suitable tags"} 643 | 644 | 645 | @register_vcs_handler("git", "pieces_from_vcs") 646 | def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): 647 | # this runs 'git' from the root of the source tree. This only gets called 648 | # if the git-archive 'subst' keywords were *not* expanded, and 649 | # _version.py hasn't already been rewritten with a short version string, 650 | # meaning we're inside a checked out source tree. 
651 | 652 | if not os.path.exists(os.path.join(root, ".git")): 653 | if verbose: 654 | print("no .git in %%s" %% root) 655 | raise NotThisMethod("no .git directory") 656 | 657 | GITS = ["git"] 658 | if sys.platform == "win32": 659 | GITS = ["git.cmd", "git.exe"] 660 | # if there is a tag, this yields TAG-NUM-gHEX[-dirty] 661 | # if there are no tags, this yields HEX[-dirty] (no NUM) 662 | describe_out = run_command(GITS, ["describe", "--tags", "--dirty", 663 | "--always", "--long"], 664 | cwd=root) 665 | # --long was added in git-1.5.5 666 | if describe_out is None: 667 | raise NotThisMethod("'git describe' failed") 668 | describe_out = describe_out.strip() 669 | full_out = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) 670 | if full_out is None: 671 | raise NotThisMethod("'git rev-parse' failed") 672 | full_out = full_out.strip() 673 | 674 | pieces = {} 675 | pieces["long"] = full_out 676 | pieces["short"] = full_out[:7] # maybe improved later 677 | pieces["error"] = None 678 | 679 | # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] 680 | # TAG might have hyphens. 681 | git_describe = describe_out 682 | 683 | # look for -dirty suffix 684 | dirty = git_describe.endswith("-dirty") 685 | pieces["dirty"] = dirty 686 | if dirty: 687 | git_describe = git_describe[:git_describe.rindex("-dirty")] 688 | 689 | # now we have TAG-NUM-gHEX or HEX 690 | 691 | if "-" in git_describe: 692 | # TAG-NUM-gHEX 693 | mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) 694 | if not mo: 695 | # unparseable. Maybe git-describe is misbehaving? 
696 | pieces["error"] = ("unable to parse git-describe output: '%%s'" 697 | %% describe_out) 698 | return pieces 699 | 700 | # tag 701 | full_tag = mo.group(1) 702 | if not full_tag.startswith(tag_prefix): 703 | if verbose: 704 | fmt = "tag '%%s' doesn't start with prefix '%%s'" 705 | print(fmt %% (full_tag, tag_prefix)) 706 | pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" 707 | %% (full_tag, tag_prefix)) 708 | return pieces 709 | pieces["closest-tag"] = full_tag[len(tag_prefix):] 710 | 711 | # distance: number of commits since tag 712 | pieces["distance"] = int(mo.group(2)) 713 | 714 | # commit: short hex revision ID 715 | pieces["short"] = mo.group(3) 716 | 717 | else: 718 | # HEX: no tags 719 | pieces["closest-tag"] = None 720 | count_out = run_command(GITS, ["rev-list", "HEAD", "--count"], 721 | cwd=root) 722 | pieces["distance"] = int(count_out) # total number of commits 723 | 724 | return pieces 725 | 726 | 727 | def plus_or_dot(pieces): 728 | if "+" in pieces.get("closest-tag", ""): 729 | return "." 730 | return "+" 731 | 732 | 733 | def render_pep440(pieces): 734 | # now build up version string, with post-release "local version 735 | # identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you 736 | # get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty 737 | 738 | # exceptions: 739 | # 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] 740 | 741 | if pieces["closest-tag"]: 742 | rendered = pieces["closest-tag"] 743 | if pieces["distance"] or pieces["dirty"]: 744 | rendered += plus_or_dot(pieces) 745 | rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) 746 | if pieces["dirty"]: 747 | rendered += ".dirty" 748 | else: 749 | # exception #1 750 | rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], 751 | pieces["short"]) 752 | if pieces["dirty"]: 753 | rendered += ".dirty" 754 | return rendered 755 | 756 | 757 | def render_pep440_pre(pieces): 758 | # TAG[.post.devDISTANCE] . 
No -dirty 759 | 760 | # exceptions: 761 | # 1: no tags. 0.post.devDISTANCE 762 | 763 | if pieces["closest-tag"]: 764 | rendered = pieces["closest-tag"] 765 | if pieces["distance"]: 766 | rendered += ".post.dev%%d" %% pieces["distance"] 767 | else: 768 | # exception #1 769 | rendered = "0.post.dev%%d" %% pieces["distance"] 770 | return rendered 771 | 772 | 773 | def render_pep440_post(pieces): 774 | # TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that 775 | # .dev0 sorts backwards (a dirty tree will appear "older" than the 776 | # corresponding clean one), but you shouldn't be releasing software with 777 | # -dirty anyways. 778 | 779 | # exceptions: 780 | # 1: no tags. 0.postDISTANCE[.dev0] 781 | 782 | if pieces["closest-tag"]: 783 | rendered = pieces["closest-tag"] 784 | if pieces["distance"] or pieces["dirty"]: 785 | rendered += ".post%%d" %% pieces["distance"] 786 | if pieces["dirty"]: 787 | rendered += ".dev0" 788 | rendered += plus_or_dot(pieces) 789 | rendered += "g%%s" %% pieces["short"] 790 | else: 791 | # exception #1 792 | rendered = "0.post%%d" %% pieces["distance"] 793 | if pieces["dirty"]: 794 | rendered += ".dev0" 795 | rendered += "+g%%s" %% pieces["short"] 796 | return rendered 797 | 798 | 799 | def render_pep440_old(pieces): 800 | # TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. 801 | 802 | # exceptions: 803 | # 1: no tags. 0.postDISTANCE[.dev0] 804 | 805 | if pieces["closest-tag"]: 806 | rendered = pieces["closest-tag"] 807 | if pieces["distance"] or pieces["dirty"]: 808 | rendered += ".post%%d" %% pieces["distance"] 809 | if pieces["dirty"]: 810 | rendered += ".dev0" 811 | else: 812 | # exception #1 813 | rendered = "0.post%%d" %% pieces["distance"] 814 | if pieces["dirty"]: 815 | rendered += ".dev0" 816 | return rendered 817 | 818 | 819 | def render_git_describe(pieces): 820 | # TAG[-DISTANCE-gHEX][-dirty], like 'git describe --tags --dirty 821 | # --always' 822 | 823 | # exceptions: 824 | # 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) 825 | 826 | if pieces["closest-tag"]: 827 | rendered = pieces["closest-tag"] 828 | if pieces["distance"]: 829 | rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) 830 | else: 831 | # exception #1 832 | rendered = pieces["short"] 833 | if pieces["dirty"]: 834 | rendered += "-dirty" 835 | return rendered 836 | 837 | 838 | def render_git_describe_long(pieces): 839 | # TAG-DISTANCE-gHEX[-dirty], like 'git describe --tags --dirty 840 | # --always -long'. The distance/hash is unconditional. 841 | 842 | # exceptions: 843 | # 1: no tags. HEX[-dirty] (note: no 'g' prefix) 844 | 845 | if pieces["closest-tag"]: 846 | rendered = pieces["closest-tag"] 847 | rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) 848 | else: 849 | # exception #1 850 | rendered = pieces["short"] 851 | if pieces["dirty"]: 852 | rendered += "-dirty" 853 | return rendered 854 | 855 | 856 | def render(pieces, style): 857 | if pieces["error"]: 858 | return {"version": "unknown", 859 | "full-revisionid": pieces.get("long"), 860 | "dirty": None, 861 | "error": pieces["error"]} 862 | 863 | if not style or style == "default": 864 | style = "pep440" # the default 865 | 866 | if style == "pep440": 867 | rendered = render_pep440(pieces) 868 | elif style == "pep440-pre": 869 | rendered = render_pep440_pre(pieces) 870 | elif style == "pep440-post": 871 | rendered = render_pep440_post(pieces) 872 | elif style == "pep440-old": 873 | rendered = render_pep440_old(pieces) 874 | elif style == "git-describe": 875 | rendered = render_git_describe(pieces) 876 | elif style == "git-describe-long": 877 | rendered = render_git_describe_long(pieces) 878 | else: 879 | raise ValueError("unknown style '%%s'" %% style) 880 | 881 | return {"version": rendered, "full-revisionid": pieces["long"], 882 | "dirty": pieces["dirty"], "error": None} 883 | 884 | 885 | def get_versions(): 886 | # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. 
If we have 887 | # __file__, we can work backwards from there to the root. Some 888 | # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which 889 | # case we can only use expanded keywords. 890 | 891 | cfg = get_config() 892 | verbose = cfg.verbose 893 | 894 | try: 895 | return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, 896 | verbose) 897 | except NotThisMethod: 898 | pass 899 | 900 | try: 901 | root = os.path.realpath(__file__) 902 | # versionfile_source is the relative path from the top of the source 903 | # tree (where the .git directory might live) to this file. Invert 904 | # this to find the root from __file__. 905 | for i in cfg.versionfile_source.split('/'): 906 | root = os.path.dirname(root) 907 | except NameError: 908 | return {"version": "0+unknown", "full-revisionid": None, 909 | "dirty": None, 910 | "error": "unable to find root of source tree"} 911 | 912 | try: 913 | pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) 914 | return render(pieces, cfg.style) 915 | except NotThisMethod: 916 | pass 917 | 918 | try: 919 | if cfg.parentdir_prefix: 920 | return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) 921 | except NotThisMethod: 922 | pass 923 | 924 | return {"version": "0+unknown", "full-revisionid": None, 925 | "dirty": None, 926 | "error": "unable to compute version"} 927 | ''' 928 | 929 | 930 | @register_vcs_handler("git", "get_keywords") 931 | def git_get_keywords(versionfile_abs): 932 | # the code embedded in _version.py can just fetch the value of these 933 | # keywords. When used from setup.py, we don't want to import _version.py, 934 | # so we do it with a regexp instead. This function is not used from 935 | # _version.py. 
936 | keywords = {} 937 | try: 938 | f = open(versionfile_abs, "r") 939 | for line in f.readlines(): 940 | if line.strip().startswith("git_refnames ="): 941 | mo = re.search(r'=\s*"(.*)"', line) 942 | if mo: 943 | keywords["refnames"] = mo.group(1) 944 | if line.strip().startswith("git_full ="): 945 | mo = re.search(r'=\s*"(.*)"', line) 946 | if mo: 947 | keywords["full"] = mo.group(1) 948 | f.close() 949 | except EnvironmentError: 950 | pass 951 | return keywords 952 | 953 | 954 | @register_vcs_handler("git", "keywords") 955 | def git_versions_from_keywords(keywords, tag_prefix, verbose): 956 | if not keywords: 957 | raise NotThisMethod("no keywords at all, weird") 958 | refnames = keywords["refnames"].strip() 959 | if refnames.startswith("$Format"): 960 | if verbose: 961 | print("keywords are unexpanded, not using") 962 | raise NotThisMethod("unexpanded keywords, not a git-archive tarball") 963 | refs = set([r.strip() for r in refnames.strip("()").split(",")]) 964 | # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of 965 | # just "foo-1.0". If we see a "tag: " prefix, prefer those. 966 | TAG = "tag: " 967 | tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) 968 | if not tags: 969 | # Either we're using git < 1.8.3, or there really are no tags. We use 970 | # a heuristic: assume all version tags have a digit. The old git %d 971 | # expansion behaves like git log --decorate=short and strips out the 972 | # refs/heads/ and refs/tags/ prefixes that would let us distinguish 973 | # between branches and tags. By ignoring refnames without digits, we 974 | # filter out many common branch names like "release" and 975 | # "stabilization", as well as "HEAD" and "master". 976 | tags = set([r for r in refs if re.search(r'\d', r)]) 977 | if verbose: 978 | print("discarding '%s', no digits" % ",".join(refs-tags)) 979 | if verbose: 980 | print("likely tags: %s" % ",".join(sorted(tags))) 981 | for ref in sorted(tags): 982 | # sorting will prefer e.g. 
"2.0" over "2.0rc1" 983 | if ref.startswith(tag_prefix): 984 | r = ref[len(tag_prefix):] 985 | if verbose: 986 | print("picking %s" % r) 987 | return {"version": r, 988 | "full-revisionid": keywords["full"].strip(), 989 | "dirty": False, "error": None 990 | } 991 | # no suitable tags, so version is "0+unknown", but full hex is still there 992 | if verbose: 993 | print("no suitable tags, using unknown + full revision id") 994 | return {"version": "0+unknown", 995 | "full-revisionid": keywords["full"].strip(), 996 | "dirty": False, "error": "no suitable tags"} 997 | 998 | 999 | @register_vcs_handler("git", "pieces_from_vcs") 1000 | def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): 1001 | # this runs 'git' from the root of the source tree. This only gets called 1002 | # if the git-archive 'subst' keywords were *not* expanded, and 1003 | # _version.py hasn't already been rewritten with a short version string, 1004 | # meaning we're inside a checked out source tree. 
1005 | 1006 | if not os.path.exists(os.path.join(root, ".git")): 1007 | if verbose: 1008 | print("no .git in %s" % root) 1009 | raise NotThisMethod("no .git directory") 1010 | 1011 | GITS = ["git"] 1012 | if sys.platform == "win32": 1013 | GITS = ["git.cmd", "git.exe"] 1014 | # if there is a tag, this yields TAG-NUM-gHEX[-dirty] 1015 | # if there are no tags, this yields HEX[-dirty] (no NUM) 1016 | describe_out = run_command(GITS, ["describe", "--tags", "--dirty", 1017 | "--always", "--long"], 1018 | cwd=root) 1019 | # --long was added in git-1.5.5 1020 | if describe_out is None: 1021 | raise NotThisMethod("'git describe' failed") 1022 | describe_out = describe_out.strip() 1023 | full_out = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) 1024 | if full_out is None: 1025 | raise NotThisMethod("'git rev-parse' failed") 1026 | full_out = full_out.strip() 1027 | 1028 | pieces = {} 1029 | pieces["long"] = full_out 1030 | pieces["short"] = full_out[:7] # maybe improved later 1031 | pieces["error"] = None 1032 | 1033 | # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] 1034 | # TAG might have hyphens. 1035 | git_describe = describe_out 1036 | 1037 | # look for -dirty suffix 1038 | dirty = git_describe.endswith("-dirty") 1039 | pieces["dirty"] = dirty 1040 | if dirty: 1041 | git_describe = git_describe[:git_describe.rindex("-dirty")] 1042 | 1043 | # now we have TAG-NUM-gHEX or HEX 1044 | 1045 | if "-" in git_describe: 1046 | # TAG-NUM-gHEX 1047 | mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) 1048 | if not mo: 1049 | # unparseable. Maybe git-describe is misbehaving? 
1050 | pieces["error"] = ("unable to parse git-describe output: '%s'" 1051 | % describe_out) 1052 | return pieces 1053 | 1054 | # tag 1055 | full_tag = mo.group(1) 1056 | if not full_tag.startswith(tag_prefix): 1057 | if verbose: 1058 | fmt = "tag '%s' doesn't start with prefix '%s'" 1059 | print(fmt % (full_tag, tag_prefix)) 1060 | pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" 1061 | % (full_tag, tag_prefix)) 1062 | return pieces 1063 | pieces["closest-tag"] = full_tag[len(tag_prefix):] 1064 | 1065 | # distance: number of commits since tag 1066 | pieces["distance"] = int(mo.group(2)) 1067 | 1068 | # commit: short hex revision ID 1069 | pieces["short"] = mo.group(3) 1070 | 1071 | else: 1072 | # HEX: no tags 1073 | pieces["closest-tag"] = None 1074 | count_out = run_command(GITS, ["rev-list", "HEAD", "--count"], 1075 | cwd=root) 1076 | pieces["distance"] = int(count_out) # total number of commits 1077 | 1078 | return pieces 1079 | 1080 | 1081 | def do_vcs_install(manifest_in, versionfile_source, ipy): 1082 | GITS = ["git"] 1083 | if sys.platform == "win32": 1084 | GITS = ["git.cmd", "git.exe"] 1085 | files = [manifest_in, versionfile_source] 1086 | if ipy: 1087 | files.append(ipy) 1088 | try: 1089 | me = __file__ 1090 | if me.endswith(".pyc") or me.endswith(".pyo"): 1091 | me = os.path.splitext(me)[0] + ".py" 1092 | versioneer_file = os.path.relpath(me) 1093 | except NameError: 1094 | versioneer_file = "versioneer.py" 1095 | files.append(versioneer_file) 1096 | present = False 1097 | try: 1098 | f = open(".gitattributes", "r") 1099 | for line in f.readlines(): 1100 | if line.strip().startswith(versionfile_source): 1101 | if "export-subst" in line.strip().split()[1:]: 1102 | present = True 1103 | f.close() 1104 | except EnvironmentError: 1105 | pass 1106 | if not present: 1107 | f = open(".gitattributes", "a+") 1108 | f.write("%s export-subst\n" % versionfile_source) 1109 | f.close() 1110 | files.append(".gitattributes") 1111 | run_command(GITS, 
["add", "--"] + files) 1112 | 1113 | 1114 | def versions_from_parentdir(parentdir_prefix, root, verbose): 1115 | # Source tarballs conventionally unpack into a directory that includes 1116 | # both the project name and a version string. 1117 | dirname = os.path.basename(root) 1118 | if not dirname.startswith(parentdir_prefix): 1119 | if verbose: 1120 | print("guessing rootdir is '%s', but '%s' doesn't start with " 1121 | "prefix '%s'" % (root, dirname, parentdir_prefix)) 1122 | raise NotThisMethod("rootdir doesn't start with parentdir_prefix") 1123 | return {"version": dirname[len(parentdir_prefix):], 1124 | "full-revisionid": None, 1125 | "dirty": False, "error": None} 1126 | 1127 | SHORT_VERSION_PY = """ 1128 | # This file was generated by 'versioneer.py' (0.15) from 1129 | # revision-control system data, or from the parent directory name of an 1130 | # unpacked source archive. Distribution tarballs contain a pre-generated copy 1131 | # of this file. 1132 | 1133 | import json 1134 | import sys 1135 | 1136 | version_json = ''' 1137 | %s 1138 | ''' # END VERSION_JSON 1139 | 1140 | 1141 | def get_versions(): 1142 | return json.loads(version_json) 1143 | """ 1144 | 1145 | 1146 | def versions_from_file(filename): 1147 | try: 1148 | with open(filename) as f: 1149 | contents = f.read() 1150 | except EnvironmentError: 1151 | raise NotThisMethod("unable to read _version.py") 1152 | mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", 1153 | contents, re.M | re.S) 1154 | if not mo: 1155 | raise NotThisMethod("no version_json in _version.py") 1156 | return json.loads(mo.group(1)) 1157 | 1158 | 1159 | def write_to_version_file(filename, versions): 1160 | os.unlink(filename) 1161 | contents = json.dumps(versions, sort_keys=True, 1162 | indent=1, separators=(",", ": ")) 1163 | with open(filename, "w") as f: 1164 | f.write(SHORT_VERSION_PY % contents) 1165 | 1166 | print("set %s to '%s'" % (filename, versions["version"])) 1167 | 1168 | 1169 | def 
plus_or_dot(pieces): 1170 | if "+" in pieces.get("closest-tag", ""): 1171 | return "." 1172 | return "+" 1173 | 1174 | 1175 | def render_pep440(pieces): 1176 | # now build up version string, with post-release "local version 1177 | # identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you 1178 | # get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty 1179 | 1180 | # exceptions: 1181 | # 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] 1182 | 1183 | if pieces["closest-tag"]: 1184 | rendered = pieces["closest-tag"] 1185 | if pieces["distance"] or pieces["dirty"]: 1186 | rendered += plus_or_dot(pieces) 1187 | rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) 1188 | if pieces["dirty"]: 1189 | rendered += ".dirty" 1190 | else: 1191 | # exception #1 1192 | rendered = "0+untagged.%d.g%s" % (pieces["distance"], 1193 | pieces["short"]) 1194 | if pieces["dirty"]: 1195 | rendered += ".dirty" 1196 | return rendered 1197 | 1198 | 1199 | def render_pep440_pre(pieces): 1200 | # TAG[.post.devDISTANCE] . No -dirty 1201 | 1202 | # exceptions: 1203 | # 1: no tags. 0.post.devDISTANCE 1204 | 1205 | if pieces["closest-tag"]: 1206 | rendered = pieces["closest-tag"] 1207 | if pieces["distance"]: 1208 | rendered += ".post.dev%d" % pieces["distance"] 1209 | else: 1210 | # exception #1 1211 | rendered = "0.post.dev%d" % pieces["distance"] 1212 | return rendered 1213 | 1214 | 1215 | def render_pep440_post(pieces): 1216 | # TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that 1217 | # .dev0 sorts backwards (a dirty tree will appear "older" than the 1218 | # corresponding clean one), but you shouldn't be releasing software with 1219 | # -dirty anyways. 1220 | 1221 | # exceptions: 1222 | # 1: no tags. 
0.postDISTANCE[.dev0] 1223 | 1224 | if pieces["closest-tag"]: 1225 | rendered = pieces["closest-tag"] 1226 | if pieces["distance"] or pieces["dirty"]: 1227 | rendered += ".post%d" % pieces["distance"] 1228 | if pieces["dirty"]: 1229 | rendered += ".dev0" 1230 | rendered += plus_or_dot(pieces) 1231 | rendered += "g%s" % pieces["short"] 1232 | else: 1233 | # exception #1 1234 | rendered = "0.post%d" % pieces["distance"] 1235 | if pieces["dirty"]: 1236 | rendered += ".dev0" 1237 | rendered += "+g%s" % pieces["short"] 1238 | return rendered 1239 | 1240 | 1241 | def render_pep440_old(pieces): 1242 | # TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. 1243 | 1244 | # exceptions: 1245 | # 1: no tags. 0.postDISTANCE[.dev0] 1246 | 1247 | if pieces["closest-tag"]: 1248 | rendered = pieces["closest-tag"] 1249 | if pieces["distance"] or pieces["dirty"]: 1250 | rendered += ".post%d" % pieces["distance"] 1251 | if pieces["dirty"]: 1252 | rendered += ".dev0" 1253 | else: 1254 | # exception #1 1255 | rendered = "0.post%d" % pieces["distance"] 1256 | if pieces["dirty"]: 1257 | rendered += ".dev0" 1258 | return rendered 1259 | 1260 | 1261 | def render_git_describe(pieces): 1262 | # TAG[-DISTANCE-gHEX][-dirty], like 'git describe --tags --dirty 1263 | # --always' 1264 | 1265 | # exceptions: 1266 | # 1: no tags. HEX[-dirty] (note: no 'g' prefix) 1267 | 1268 | if pieces["closest-tag"]: 1269 | rendered = pieces["closest-tag"] 1270 | if pieces["distance"]: 1271 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 1272 | else: 1273 | # exception #1 1274 | rendered = pieces["short"] 1275 | if pieces["dirty"]: 1276 | rendered += "-dirty" 1277 | return rendered 1278 | 1279 | 1280 | def render_git_describe_long(pieces): 1281 | # TAG-DISTANCE-gHEX[-dirty], like 'git describe --tags --dirty 1282 | # --always -long'. The distance/hash is unconditional. 1283 | 1284 | # exceptions: 1285 | # 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) 1286 | 1287 | if pieces["closest-tag"]: 1288 | rendered = pieces["closest-tag"] 1289 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) 1290 | else: 1291 | # exception #1 1292 | rendered = pieces["short"] 1293 | if pieces["dirty"]: 1294 | rendered += "-dirty" 1295 | return rendered 1296 | 1297 | 1298 | def render(pieces, style): 1299 | if pieces["error"]: 1300 | return {"version": "unknown", 1301 | "full-revisionid": pieces.get("long"), 1302 | "dirty": None, 1303 | "error": pieces["error"]} 1304 | 1305 | if not style or style == "default": 1306 | style = "pep440" # the default 1307 | 1308 | if style == "pep440": 1309 | rendered = render_pep440(pieces) 1310 | elif style == "pep440-pre": 1311 | rendered = render_pep440_pre(pieces) 1312 | elif style == "pep440-post": 1313 | rendered = render_pep440_post(pieces) 1314 | elif style == "pep440-old": 1315 | rendered = render_pep440_old(pieces) 1316 | elif style == "git-describe": 1317 | rendered = render_git_describe(pieces) 1318 | elif style == "git-describe-long": 1319 | rendered = render_git_describe_long(pieces) 1320 | else: 1321 | raise ValueError("unknown style '%s'" % style) 1322 | 1323 | return {"version": rendered, "full-revisionid": pieces["long"], 1324 | "dirty": pieces["dirty"], "error": None} 1325 | 1326 | 1327 | class VersioneerBadRootError(Exception): 1328 | pass 1329 | 1330 | 1331 | def get_versions(verbose=False): 1332 | # returns dict with two keys: 'version' and 'full' 1333 | 1334 | if "versioneer" in sys.modules: 1335 | # see the discussion in cmdclass.py:get_cmdclass() 1336 | del sys.modules["versioneer"] 1337 | 1338 | root = get_root() 1339 | cfg = get_config_from_root(root) 1340 | 1341 | assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" 1342 | handlers = HANDLERS.get(cfg.VCS) 1343 | assert handlers, "unrecognized VCS '%s'" % cfg.VCS 1344 | verbose = verbose or cfg.verbose 1345 | assert cfg.versionfile_source is not None, \ 1346 | 
"please set versioneer.versionfile_source" 1347 | assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" 1348 | 1349 | versionfile_abs = os.path.join(root, cfg.versionfile_source) 1350 | 1351 | # extract version from first of: _version.py, VCS command (e.g. 'git 1352 | # describe'), parentdir. This is meant to work for developers using a 1353 | # source checkout, for users of a tarball created by 'setup.py sdist', 1354 | # and for users of a tarball/zipball created by 'git archive' or github's 1355 | # download-from-tag feature or the equivalent in other VCSes. 1356 | 1357 | get_keywords_f = handlers.get("get_keywords") 1358 | from_keywords_f = handlers.get("keywords") 1359 | if get_keywords_f and from_keywords_f: 1360 | try: 1361 | keywords = get_keywords_f(versionfile_abs) 1362 | ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) 1363 | if verbose: 1364 | print("got version from expanded keyword %s" % ver) 1365 | return ver 1366 | except NotThisMethod: 1367 | pass 1368 | 1369 | try: 1370 | ver = versions_from_file(versionfile_abs) 1371 | if verbose: 1372 | print("got version from file %s %s" % (versionfile_abs, ver)) 1373 | return ver 1374 | except NotThisMethod: 1375 | pass 1376 | 1377 | from_vcs_f = handlers.get("pieces_from_vcs") 1378 | if from_vcs_f: 1379 | try: 1380 | pieces = from_vcs_f(cfg.tag_prefix, root, verbose) 1381 | ver = render(pieces, cfg.style) 1382 | if verbose: 1383 | print("got version from VCS %s" % ver) 1384 | return ver 1385 | except NotThisMethod: 1386 | pass 1387 | 1388 | try: 1389 | if cfg.parentdir_prefix: 1390 | ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) 1391 | if verbose: 1392 | print("got version from parentdir %s" % ver) 1393 | return ver 1394 | except NotThisMethod: 1395 | pass 1396 | 1397 | if verbose: 1398 | print("unable to compute version") 1399 | 1400 | return {"version": "0+unknown", "full-revisionid": None, 1401 | "dirty": None, "error": "unable to compute version"} 1402 | 1403 | 
def get_version():
    """Return just the version string from get_versions()."""
    return get_versions()["version"]


def get_cmdclass():
    """Return setup() command classes that embed the computed version.

    The returned dict always contains "version" and "sdist". It contains
    "build_py", or "build_exe" in its place when cx_Freeze has already
    been imported. Intended usage:
    ``setup(cmdclass=versioneer.get_cmdclass(), ...)``.
    """
    # Remove ourselves from sys.modules. This fixes the "python setup.py
    # develop" case (also 'install' and 'easy_install .'), in which
    # subdependencies of the main project are built (using setup.py
    # bdist_egg) in the same python process. Assume a main project A and a
    # dependency B, which use different versions of Versioneer. A's setup.py
    # imports A's Versioneer, leaving it in sys.modules by the time B's
    # setup.py is executed, causing B to run with the wrong versioneer.
    # Setuptools wraps the sub-dep builds in a sandbox that restores
    # sys.modules to its pre-build state, so the parent is protected against
    # the child's "import versioneer". By removing ourselves from
    # sys.modules here, before the child build happens, we protect the child
    # from the parent's versioneer too.
    # Also see https://github.com/warner/python-versioneer/issues/52
    if "versioneer" in sys.modules:
        del sys.modules["versioneer"]

    cmds = {}

    # we add "version" to both distutils and setuptools
    from distutils.core import Command

    class cmd_version(Command):
        """Implement 'setup.py version': print the computed version info."""

        description = "report generated version string"
        user_options = []
        boolean_options = []

        def initialize_options(self):
            pass

        def finalize_options(self):
            pass

        def run(self):
            vers = get_versions(verbose=True)
            print("Version: %s" % vers["version"])
            print(" full-revisionid: %s" % vers.get("full-revisionid"))
            print(" dirty: %s" % vers.get("dirty"))
            if vers["error"]:
                print(" error: %s" % vers["error"])
    cmds["version"] = cmd_version

    # we override "build_py" in both distutils and setuptools
    #
    # most invocation pathways end up running build_py:
    #  distutils/build -> build_py
    #  distutils/install -> distutils/build ->..
    #  setuptools/bdist_wheel -> distutils/install ->..
    #  setuptools/bdist_egg -> distutils/install_lib -> build_py
    #  setuptools/install -> bdist_egg ->..
    #  setuptools/develop -> ?

    from distutils.command.build_py import build_py as _build_py

    class cmd_build_py(_build_py):
        """build_py that rewrites the built _version.py with fixed values."""

        def run(self):
            root = get_root()
            cfg = get_config_from_root(root)
            # Compute the version before building, while the source tree
            # (and its VCS metadata) is still the reference point.
            versions = get_versions()
            _build_py.run(self)
            # now locate _version.py in the new build/ directory and replace
            # it with an updated value
            if cfg.versionfile_build:
                target_versionfile = os.path.join(self.build_lib,
                                                  cfg.versionfile_build)
                print("UPDATING %s" % target_versionfile)
                write_to_version_file(target_versionfile, versions)
    cmds["build_py"] = cmd_build_py

    if "cx_Freeze" in sys.modules:  # cx_freeze enabled?
        from cx_Freeze.dist import build_exe as _build_exe

        class cmd_build_exe(_build_exe):
            """cx_Freeze build_exe that freezes a static _version.py.

            The source-tree versionfile is temporarily replaced with a
            static version file for the build, then regenerated from
            LONG_VERSION_PY afterwards.
            """

            def run(self):
                root = get_root()
                cfg = get_config_from_root(root)
                versions = get_versions()
                target_versionfile = cfg.versionfile_source
                print("UPDATING %s" % target_versionfile)
                write_to_version_file(target_versionfile, versions)

                _build_exe.run(self)
                # Restore the dynamic _version.py that the static file
                # replaced for the duration of the build.
                os.unlink(target_versionfile)
                with open(cfg.versionfile_source, "w") as f:
                    LONG = LONG_VERSION_PY[cfg.VCS]
                    f.write(LONG %
                            {"DOLLAR": "$",
                             "STYLE": cfg.style,
                             "TAG_PREFIX": cfg.tag_prefix,
                             "PARENTDIR_PREFIX": cfg.parentdir_prefix,
                             "VERSIONFILE_SOURCE": cfg.versionfile_source,
                             })
        cmds["build_exe"] = cmd_build_exe
        # build_exe replaces build_py in the cx_Freeze case.
        del cmds["build_py"]

    # we override different "sdist" commands for both environments
    if "setuptools" in sys.modules:
        from setuptools.command.sdist import sdist as _sdist
    else:
        from distutils.command.sdist import sdist as _sdist

    class cmd_sdist(_sdist):
        """sdist that ships a static _version.py in the release tree."""

        def run(self):
            versions = get_versions()
            # Stash the result so make_release_tree() writes the same
            # version that was put into the distribution metadata.
            self._versioneer_generated_versions = versions
            # unless we update this, the command will keep using the old
            # version
            self.distribution.metadata.version = versions["version"]
            return _sdist.run(self)

        def make_release_tree(self, base_dir, files):
            root = get_root()
            cfg = get_config_from_root(root)
            _sdist.make_release_tree(self, base_dir, files)
            # now locate _version.py in the new base_dir directory
            # (remembering that it may be a hardlink) and replace it with an
            # updated value
            target_versionfile = os.path.join(base_dir, cfg.versionfile_source)
            print("UPDATING %s" % target_versionfile)
            write_to_version_file(target_versionfile,
                                  self._versioneer_generated_versions)
    cmds["sdist"] = cmd_sdist

    return cmds


# Printed to stderr by do_setup() when setup.cfg lacks usable configuration.
CONFIG_ERROR = """
setup.cfg is missing the necessary Versioneer configuration. You need
a section like:

[versioneer]
VCS = git
style = pep440
versionfile_source = src/myproject/_version.py
versionfile_build = myproject/_version.py
tag_prefix = ""
parentdir_prefix = myproject-

You will also need to edit your setup.py to use the results:

import versioneer
setup(version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass(), ...)

Please read the docstring in ./versioneer.py for configuration instructions,
edit setup.cfg, and re-run the installer or 'python versioneer.py setup'.
"""

# Appended to setup.cfg by do_setup() when the file or section is missing.
SAMPLE_CONFIG = """
# See the docstring in versioneer.py for instructions. Note that you must
# re-run 'versioneer.py setup' after changing this section, and commit the
# resulting files.

[versioneer]
#VCS = git
#style = pep440
#versionfile_source =
#versionfile_build =
#tag_prefix =
#parentdir_prefix =

"""

# Appended to the package __init__.py by do_setup(). Its exact text matters:
# do_setup() detects a previous install by substring match against it.
INIT_PY_SNIPPET = """
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
"""


def do_setup():
    """Install Versioneer into the project rooted at get_root().

    Steps:

    * Write ``versionfile_source`` from ``LONG_VERSION_PY``.
    * Append ``INIT_PY_SNIPPET`` to the sibling ``__init__.py`` if it exists
      and lacks the snippet.
    * Make sure ``versioneer.py`` and ``versionfile_source`` are listed in
      MANIFEST.in.
    * Call ``do_vcs_install``.

    Returns 0 on success. Returns 1 when reading the configuration fails.
    If setup.cfg is unreadable or has no [versioneer] section, a sample
    section is appended to it before CONFIG_ERROR is printed.
    """
    root = get_root()
    try:
        cfg = get_config_from_root(root)
    except (EnvironmentError, configparser.NoSectionError,
            configparser.NoOptionError) as e:
        if isinstance(e, (EnvironmentError, configparser.NoSectionError)):
            print("Adding sample versioneer config to setup.cfg",
                  file=sys.stderr)
            with open(os.path.join(root, "setup.cfg"), "a") as f:
                f.write(SAMPLE_CONFIG)
        print(CONFIG_ERROR, file=sys.stderr)
        return 1

    print(" creating %s" % cfg.versionfile_source)
    with open(cfg.versionfile_source, "w") as f:
        LONG = LONG_VERSION_PY[cfg.VCS]
        f.write(LONG % {"DOLLAR": "$",
                        "STYLE": cfg.style,
                        "TAG_PREFIX": cfg.tag_prefix,
                        "PARENTDIR_PREFIX": cfg.parentdir_prefix,
                        "VERSIONFILE_SOURCE": cfg.versionfile_source,
                        })

    ipy = os.path.join(os.path.dirname(cfg.versionfile_source),
                       "__init__.py")
    if os.path.exists(ipy):
        try:
            with open(ipy, "r") as f:
                old = f.read()
        except EnvironmentError:
            old = ""
        if INIT_PY_SNIPPET not in old:
            print(" appending to %s" % ipy)
            with open(ipy, "a") as f:
                f.write(INIT_PY_SNIPPET)
        else:
            print(" %s unmodified" % ipy)
    else:
        print(" %s doesn't exist, ok" % ipy)
        # None tells do_vcs_install there is no __init__.py to track.
        ipy = None

    # Make sure both the top-level "versioneer.py" and versionfile_source
    # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so
    # they'll be copied into source distributions. Pip won't be able to
    # install the package without this.
    manifest_in = os.path.join(root, "MANIFEST.in")
    simple_includes = set()
    try:
        with open(manifest_in, "r") as f:
            for line in f:
                if line.startswith("include "):
                    for include in line.split()[1:]:
                        simple_includes.add(include)
    except EnvironmentError:
        pass
    # That doesn't cover everything MANIFEST.in can do
    # (http://docs.python.org/2/distutils/sourcedist.html#commands), so
    # it might give some false negatives. Appending redundant 'include'
    # lines is safe, though.
    if "versioneer.py" not in simple_includes:
        print(" appending 'versioneer.py' to MANIFEST.in")
        with open(manifest_in, "a") as f:
            f.write("include versioneer.py\n")
    else:
        print(" 'versioneer.py' already in MANIFEST.in")
    if cfg.versionfile_source not in simple_includes:
        print(" appending versionfile_source ('%s') to MANIFEST.in" %
              cfg.versionfile_source)
        with open(manifest_in, "a") as f:
            f.write("include %s\n" % cfg.versionfile_source)
    else:
        print(" versionfile_source already in MANIFEST.in")

    # Make VCS-specific changes. For git, this means creating/changing
    # .gitattributes to mark _version.py for export-time keyword
    # substitution.
    do_vcs_install(manifest_in, cfg.versionfile_source, ipy)
    return 0


def scan_setup_py():
    """Heuristically check ./setup.py for correct Versioneer integration.

    Returns the number of problems reported (0, 1 or 2). One problem is
    counted when any of "import versioneer", "versioneer.get_cmdclass()"
    or "versioneer.get_version()" is missing. Another is counted when
    obsolete "versioneer.VCS" / "versioneer.versionfile_source" setters
    are present.
    """
    found = set()
    setters = False
    errors = 0
    with open("setup.py", "r") as f:
        for line in f:
            if "import versioneer" in line:
                found.add("import")
            if "versioneer.get_cmdclass()" in line:
                found.add("cmdclass")
            if "versioneer.get_version()" in line:
                found.add("get_version")
            if "versioneer.VCS" in line:
                setters = True
            if "versioneer.versionfile_source" in line:
                setters = True
    if len(found) != 3:
        print("")
        print("Your setup.py appears to be missing some important items")
        print("(but I might be wrong). Please make sure it has something")
        print("roughly like the following:")
        print("")
        print(" import versioneer")
        print(" setup( version=versioneer.get_version(),")
        print(" cmdclass=versioneer.get_cmdclass(), ...)")
        print("")
        errors += 1
    if setters:
        print("You should remove lines like 'versioneer.VCS = ' and")
        print("'versioneer.versionfile_source = ' . This configuration")
        print("now lives in setup.cfg, and should be removed from setup.py")
        print("")
        errors += 1
    return errors


if __name__ == "__main__":
    # Indexing sys.argv[1] without this check raises IndexError when the
    # script is run with no arguments; report usage instead.
    if len(sys.argv) < 2:
        print("usage: python versioneer.py setup", file=sys.stderr)
        sys.exit(1)
    cmd = sys.argv[1]
    if cmd == "setup":
        errors = do_setup()
        errors += scan_setup_py()
        if errors:
            sys.exit(1)
# --------------------------------------------------------------------------------