├── .gitattributes
├── .gitignore
├── .travis.yml
├── LICENSE
├── MANIFEST.in
├── README.adoc
├── multi
├── pg8000
├── __init__.py
├── _version.py
└── core.py
├── setup.cfg
├── setup.py
├── tests
├── connection_settings.py
├── dbapi20.py
├── performance.py
├── stress.py
├── test_connection.py
├── test_copy.py
├── test_dbapi.py
├── test_error_recovery.py
├── test_paramstyle.py
├── test_pg8000_dbapi20.py
├── test_query.py
├── test_typeconversion.py
└── test_typeobjects.py
├── tox.ini
└── versioneer.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | pg8000/_version.py export-subst
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.py[co]
2 | *.swp
3 | *.orig
4 | *.class
5 | build
6 | pg8000.egg-info
7 | tmp
8 | dist
9 | .tox
10 | MANIFEST
11 | venv
12 | .cache
13 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | sudo: required
2 | language: python
3 | python:
4 | - "2.7"
5 | - "3.5"
6 | - "3.6"
7 | - "pypy3.5"
8 |
9 | env:
10 | - DB="9.5"
11 | - DB="9.6"
12 |
13 | services:
14 | - postgresql
15 | addons:
16 | postgresql: "9.5"
17 | postgresql: "9.6"
18 |
19 | before_install:
20 | - sudo service postgresql stop
21 | - cd /etc/postgresql/$DB/main
22 | - sudo chmod ugo+rw pg_hba.conf
23 | - sudo cp pg_hba.conf old_pg_hba.conf
24 | - sudo echo "host pg8000_md5 all 127.0.0.1/32 md5" > pg_hba.conf
25 | - sudo echo "host pg8000_gss all 127.0.0.1/32 gss" >> pg_hba.conf
26 | - sudo echo "host pg8000_password all 127.0.0.1/32 password" >> pg_hba.conf
27 | - cat old_pg_hba.conf >> pg_hba.conf
28 | - sudo service postgresql start $DB
29 | - psql -U postgres -tc 'create extension hstore;'
30 | - psql -U postgres -tc 'show server_version;'
31 | - psql -U postgres -tc "alter user postgres with password 'pw';"
32 | - psql -U postgres -tc "alter system set client_min_messages = notice;"
33 | - sudo service postgresql reload $DB
34 | - psql -U postgres -tc 'show client_min_messages;'
35 | - cd $TRAVIS_BUILD_DIR
36 |
37 | install:
38 | - pip install nose
39 | - pip install pytz
40 |
41 | script:
42 | - nosetests
43 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2007-2009, Mathieu Fenniak
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are
6 | met:
7 |
8 | * Redistributions of source code must retain the above copyright notice,
9 | this list of conditions and the following disclaimer.
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 | * The name of the author may not be used to endorse or promote products
14 | derived from this software without specific prior written permission.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
20 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 | POSSIBILITY OF SUCH DAMAGE.
27 |
28 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.adoc
2 | include versioneer.py
3 | include pg8000/_version.py
4 | include LICENSE
5 | include doc/*
6 |
--------------------------------------------------------------------------------
/README.adoc:
--------------------------------------------------------------------------------
1 | = pg8000
2 |
3 | Hello, the pg8000 repository has moved to https://github.com/tlocke/pg8000.
4 |
--------------------------------------------------------------------------------
/multi:
--------------------------------------------------------------------------------
#!/bin/bash

# Launch 100 concurrent runs of the stress test and wait for all of them.
# Set the PostgreSQL shared memory to its minimum beforehand to trigger
# unpinned buffer errors.
#
# The stress test lives in the top-level tests/ directory (there is no
# pg8000/tests package), so run this script from the repository root.
# NOTE(review): `-m tests.stress` assumes tests/ is importable as a
# (namespace) package from the repo root — confirm against stress.py's imports.

for i in {1..100}
do
    python -m tests.stress &
done
wait
echo "All processes done!"
11 |
--------------------------------------------------------------------------------
/pg8000/__init__.py:
--------------------------------------------------------------------------------
1 | from pg8000.core import (
2 | Warning, Bytea, DataError, DatabaseError, InterfaceError, ProgrammingError,
3 | Error, OperationalError, IntegrityError, InternalError, NotSupportedError,
4 | ArrayContentNotHomogenousError, ArrayDimensionsNotConsistentError,
5 | ArrayContentNotSupportedError, utc, Connection, Cursor, Binary, Date,
6 | DateFromTicks, Time, TimeFromTicks, Timestamp, TimestampFromTicks, BINARY,
7 | Interval, PGEnum, PGJson, PGJsonb, PGTsvector, PGText, PGVarchar)
8 | from ._version import get_versions
9 | __version__ = get_versions()['version']
10 | del get_versions
11 |
12 | # Copyright (c) 2007-2009, Mathieu Fenniak
13 | # Copyright (c) The Contributors
14 | # All rights reserved.
15 | #
16 | # Redistribution and use in source and binary forms, with or without
17 | # modification, are permitted provided that the following conditions are
18 | # met:
19 | #
20 | # * Redistributions of source code must retain the above copyright notice,
21 | # this list of conditions and the following disclaimer.
22 | # * Redistributions in binary form must reproduce the above copyright notice,
23 | # this list of conditions and the following disclaimer in the documentation
24 | # and/or other materials provided with the distribution.
25 | # * The name of the author may not be used to endorse or promote products
26 | # derived from this software without specific prior written permission.
27 | #
28 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
29 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
30 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
31 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
32 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
33 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
36 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
37 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
38 | # POSSIBILITY OF SUCH DAMAGE.
39 |
40 | __author__ = "Mathieu Fenniak"
41 |
42 |
def connect(
        user, host='localhost', unix_sock=None, port=5432, database=None,
        password=None, ssl=False, timeout=None, application_name=None,
        max_prepared_statements=1000):
    """Creates a connection to a PostgreSQL database.

    This function is part of the `DBAPI 2.0 specification
    <https://www.python.org/dev/peps/pep-0249/>`_; however, the arguments of
    the function are not defined by the specification.

    :param user:
        The username to connect to the PostgreSQL server with.

        If your server character encoding is not ``ascii`` or ``utf8``, then
        you need to provide ``user`` as bytes, eg.
        ``"my_name".encode('EUC-JP')``.

    :keyword host:
        The hostname of the PostgreSQL server to connect with. Providing this
        parameter is necessary for TCP/IP connections. One of either ``host``
        or ``unix_sock`` must be provided. The default is ``localhost``.

    :keyword unix_sock:
        The path to the UNIX socket to access the database through, for
        example, ``'/tmp/.s.PGSQL.5432'``. One of either ``host`` or
        ``unix_sock`` must be provided.

    :keyword port:
        The TCP/IP port of the PostgreSQL server instance. This parameter
        defaults to ``5432``, the registered common port of PostgreSQL TCP/IP
        servers.

    :keyword database:
        The name of the database instance to connect with. This parameter is
        optional; if omitted, the PostgreSQL server will assume the database
        name is the same as the username.

        If your server character encoding is not ``ascii`` or ``utf8``, then
        you need to provide ``database`` as bytes, eg.
        ``"my_db".encode('EUC-JP')``.

    :keyword password:
        The user password to connect to the server with. This parameter is
        optional; if omitted and the database server requests password-based
        authentication, the connection will fail to open. If this parameter
        is provided but not requested by the server, no error will occur.

        If your server character encoding is not ``ascii`` or ``utf8``, then
        you need to provide ``password`` as bytes, eg.
        ``"my_password".encode('EUC-JP')``.

    :keyword application_name:
        The name will be displayed in the pg_stat_activity view.
        This parameter is optional.

    :keyword ssl:
        Use SSL encryption for TCP/IP sockets if ``True``. Defaults to
        ``False``.

    :keyword timeout:
        Only used with Python 3, this is the time in seconds before the
        connection to the database will time out. The default is ``None`` which
        means no timeout.

    :keyword max_prepared_statements:
        Passed unchanged to :class:`Connection`. Defaults to ``1000``.
        (From its name, presumably an upper bound on the number of prepared
        statements kept per connection — confirm in ``core.py``.)

    :rtype:
        A :class:`Connection` object.
    """
    # Arguments are forwarded positionally, so this order must match the
    # Connection constructor's parameter order.
    return Connection(
        user, host, unix_sock, port, database, password, ssl, timeout,
        application_name, max_prepared_statements)
113 |
114 |
apilevel = "2.0"
"""The DBAPI level supported, currently "2.0".

This property is part of the `DBAPI 2.0 specification
<https://www.python.org/dev/peps/pep-0249/>`_.
"""

threadsafety = 1
"""Integer constant stating the level of thread safety the DBAPI interface
supports. This DBAPI module supports sharing of the module only. Connections
and cursors may not be shared between threads. This gives pg8000 a threadsafety
value of 1.

This property is part of the `DBAPI 2.0 specification
<https://www.python.org/dev/peps/pep-0249/>`_.
"""

# DBAPI 2.0 paramstyle: 'format' means ANSI C printf format codes in SQL,
# e.g. "... WHERE name=%s".
paramstyle = 'format'

# Same value as the ``max_prepared_statements`` default of connect().
max_prepared_statements = 1000

# I have no idea what this would be used for by a client app. Should it be
# TEXT, VARCHAR, CHAR? It will only compare against row_description's
# type_code if it is this one type. It is the varchar type oid for now, this
# appears to match expectations in the DB API 2.0 compliance test suite.

STRING = 1043
"""String type oid."""


NUMBER = 1700
"""Numeric type oid"""

DATETIME = 1114
"""Timestamp type oid"""

ROWID = 26
"""ROWID type oid"""
153 |
# Public API exported by ``from pg8000 import *``. The import machinery
# requires every entry to be a name string; listing the objects themselves
# makes star-imports fail with TypeError.
__all__ = [
    'Warning', 'Bytea', 'DataError', 'DatabaseError', 'connect',
    'InterfaceError', 'ProgrammingError', 'Error', 'OperationalError',
    'IntegrityError', 'InternalError', 'NotSupportedError',
    'ArrayContentNotHomogenousError', 'ArrayDimensionsNotConsistentError',
    'ArrayContentNotSupportedError', 'utc', 'Connection', 'Cursor', 'Binary',
    'Date', 'DateFromTicks', 'Time', 'TimeFromTicks', 'Timestamp',
    'TimestampFromTicks', 'BINARY', 'Interval', 'PGEnum', 'PGJson', 'PGJsonb',
    'PGTsvector', 'PGText', 'PGVarchar']

# ``__version__`` (set near the top of this module from
# ``_version.get_versions()``): version string for pg8000.
#
# .. versionadded:: 1.9.11
167 |
--------------------------------------------------------------------------------
/pg8000/_version.py:
--------------------------------------------------------------------------------
1 |
2 | # This file helps to compute a version number in source trees obtained from
3 | # git-archive tarball (such as those provided by githubs download-from-tag
4 | # feature). Distribution tarballs (built by setup.py sdist) and build
5 | # directories (produced by setup.py build) will contain a much shorter file
6 | # that just contains the computed version number.
7 |
8 | # This file is released into the public domain. Generated by
9 | # versioneer-0.15 (https://github.com/warner/python-versioneer)
10 |
11 | import errno
12 | import os
13 | import re
14 | import subprocess
15 | import sys
16 |
17 |
def get_keywords():
    """Return the git-archive keywords embedded in this file.

    git replaces these strings during git-archive (export-subst). setup.py
    and versioneer.py grep for the variable names, so each assignment must
    stay on a line of its own.
    """
    git_refnames = " (HEAD -> master)"
    git_full = "412eace074514ada824e7a102765e37e2cda8eaa"
    return dict(refnames=git_refnames, full=git_full)
27 |
28 |
class VersioneerConfig:
    """Plain attribute container for versioneer settings (see get_config)."""
31 |
32 |
def get_config():
    """Return the versioneer settings baked into this file.

    The values are filled in when 'setup.py versioneer' creates _version.py
    and mirror the [versioneer] section of setup.cfg, plus a verbose flag.
    """
    settings = (
        ("VCS", "git"),
        ("style", "pep440"),
        ("tag_prefix", ""),
        ("parentdir_prefix", "pg8000-"),
        ("versionfile_source", "pg8000/_version.py"),
        ("verbose", False),
    )
    config = VersioneerConfig()
    for attr, value in settings:
        setattr(config, attr, value)
    return config
44 |
45 |
class NotThisMethod(Exception):
    """Signal that a version-discovery strategy does not apply here.

    get_versions() catches it and moves on to the next strategy.
    """
48 |
49 |
# Not populated or read anywhere in this module.
LONG_VERSION_PY = dict()
# Registry filled by @register_vcs_handler: HANDLERS[vcs][method] -> function.
HANDLERS = dict()
52 |
53 |
def register_vcs_handler(vcs, method):  # decorator
    """Return a decorator that records a function as HANDLERS[vcs][method].

    The decorated function itself is returned unchanged.
    """
    def decorate(func):
        HANDLERS.setdefault(vcs, {})[method] = func
        return func
    return decorate
61 |
62 |
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False):
    """Run the first launchable executable in *commands* with *args*.

    Candidates are tried in order; one that cannot be found (ENOENT) is
    skipped. Returns the command's stripped stdout (decoded to str on
    Python 3), or None if no candidate could be launched, launching failed
    for another reason, or the command exited with a non-zero status.
    """
    assert isinstance(commands, list)
    stderr_target = subprocess.PIPE if hide_stderr else None
    proc = None
    for candidate in commands:
        argv = [candidate] + args
        dispcmd = str(argv)
        try:
            # shell=False, so on Windows callers pass git.cmd/git.exe
            proc = subprocess.Popen(argv, cwd=cwd, stdout=subprocess.PIPE,
                                    stderr=stderr_target)
        except EnvironmentError as exc:
            if exc.errno == errno.ENOENT:
                continue  # not installed under this name; try the next one
            if verbose:
                print("unable to run %s" % dispcmd)
                print(exc)
            return None
        break
    else:
        if verbose:
            print("unable to find command, tried %s" % (commands,))
        return None

    output = proc.communicate()[0].strip()
    if sys.version_info[0] >= 3:
        output = output.decode()
    if proc.returncode == 0:
        return output
    if verbose:
        print("unable to run %s (error)" % dispcmd)
    return None
94 |
95 |
def versions_from_parentdir(parentdir_prefix, root, verbose):
    """Derive the version from the name of the unpacked source directory.

    Source tarballs conventionally unpack into a directory named with both
    the project name and a version string. If the basename of *root* starts
    with *parentdir_prefix*, the rest of it is taken as the version.

    Raises NotThisMethod when the basename does not start with the prefix.
    """
    dirname = os.path.basename(root)
    if dirname.startswith(parentdir_prefix):
        return {"version": dirname[len(parentdir_prefix):],
                "full-revisionid": None,
                "dirty": False,
                "error": None}
    if verbose:
        print("guessing rootdir is '%s', but '%s' doesn't start with "
              "prefix '%s'" % (root, dirname, parentdir_prefix))
    raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
108 |
109 |
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
    """Scrape the git-archive keywords out of a _version.py file.

    When used from setup.py we don't want to import _version.py, so the
    keyword values are extracted with a regexp instead. This function is
    not used from _version.py itself.

    Returns a dict containing whichever of "refnames" and "full" were found.
    An unreadable file yields an empty dict (best-effort by design).
    """
    keywords = {}
    try:
        # `with` closes the file even if reading fails part-way; the old
        # open()/close() pair leaked the handle in that case.
        with open(versionfile_abs, "r") as f:
            for line in f:
                stripped = line.strip()
                if stripped.startswith("git_refnames ="):
                    mo = re.search(r'=\s*"(.*)"', line)
                    if mo:
                        keywords["refnames"] = mo.group(1)
                if stripped.startswith("git_full ="):
                    mo = re.search(r'=\s*"(.*)"', line)
                    if mo:
                        keywords["full"] = mo.group(1)
    except EnvironmentError:
        # Missing/unreadable file: fall back to "no keywords".
        pass
    return keywords
132 |
133 |
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
    """Compute version info from expanded git-archive keywords.

    Returns a dict with "version", "full-revisionid", "dirty" and "error".
    If no ref looks like a suitable tag, the version is "0+unknown" and
    "error" is "no suitable tags".

    Raises NotThisMethod when *keywords* is empty or still unexpanded.
    """
    if not keywords:
        raise NotThisMethod("no keywords at all, weird")
    refnames = keywords["refnames"].strip()
    if refnames.startswith("$Format"):
        if verbose:
            print("keywords are unexpanded, not using")
        raise NotThisMethod("unexpanded keywords, not a git-archive tarball")

    refs = set(name.strip() for name in refnames.strip("()").split(","))
    # Starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
    # just "foo-1.0"; prefer those when present.
    marker = "tag: "
    tags = set(name[len(marker):] for name in refs if name.startswith(marker))
    if not tags:
        # Older git (or genuinely no tags): old %d expansion strips the
        # refs/heads/ and refs/tags/ prefixes, so assume version tags contain
        # a digit. This filters out names like "HEAD", "master", "release".
        tags = set(name for name in refs if re.search(r'\d', name))
        if verbose:
            print("discarding '%s', no digits" % ",".join(refs - tags))
    if verbose:
        print("likely tags: %s" % ",".join(sorted(tags)))

    # Sorting prefers e.g. "2.0" over "2.0rc1".
    for ref in sorted(tags):
        if not ref.startswith(tag_prefix):
            continue
        version = ref[len(tag_prefix):]
        if verbose:
            print("picking %s" % version)
        return {"version": version,
                "full-revisionid": keywords["full"].strip(),
                "dirty": False, "error": None}

    # No suitable tags: version is unknown, but the full hex is still known.
    if verbose:
        print("no suitable tags, using unknown + full revision id")
    return {"version": "0+unknown",
            "full-revisionid": keywords["full"].strip(),
            "dirty": False, "error": "no suitable tags"}
177 |
178 |
@register_vcs_handler("git", "pieces_from_vcs")
def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
    """Collect version "pieces" by running git in the source tree at *root*.

    This only gets called if the git-archive 'subst' keywords were *not*
    expanded, and _version.py hasn't already been rewritten with a short
    version string, meaning we're inside a checked out source tree.

    Returns a dict with "long" (full commit hash), "short" (abbreviated
    hash), "dirty" (bool) and "error" (None or a message). Unless an error
    was recorded, it also has "closest-tag" (tag minus *tag_prefix*, or None
    when untagged) and "distance" (commits since that tag, or the total
    commit count when untagged).

    Raises NotThisMethod if *root* has no .git or a required git call fails.
    """
    if not os.path.exists(os.path.join(root, ".git")):
        if verbose:
            print("no .git in %s" % root)
        raise NotThisMethod("no .git directory")

    GITS = ["git"]
    if sys.platform == "win32":
        GITS = ["git.cmd", "git.exe"]
    # if there is a tag, this yields TAG-NUM-gHEX[-dirty]
    # if there are no tags, this yields HEX[-dirty] (no NUM)
    # (--long was added in git-1.5.5)
    describe_out = run_command(GITS, ["describe", "--tags", "--dirty",
                                      "--always", "--long"],
                               cwd=root)
    if describe_out is None:
        raise NotThisMethod("'git describe' failed")
    describe_out = describe_out.strip()
    full_out = run_command(GITS, ["rev-parse", "HEAD"], cwd=root)
    if full_out is None:
        raise NotThisMethod("'git rev-parse' failed")
    full_out = full_out.strip()

    pieces = {}
    pieces["long"] = full_out
    pieces["short"] = full_out[:7]  # maybe improved later
    pieces["error"] = None

    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
    # TAG might have hyphens.
    git_describe = describe_out

    # look for -dirty suffix
    dirty = git_describe.endswith("-dirty")
    pieces["dirty"] = dirty
    if dirty:
        git_describe = git_describe[:git_describe.rindex("-dirty")]

    # now we have TAG-NUM-gHEX or HEX
    if "-" in git_describe:
        # TAG-NUM-gHEX
        mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
        if not mo:
            # unparseable. Maybe git-describe is misbehaving?
            pieces["error"] = ("unable to parse git-describe output: '%s'"
                               % describe_out)
            return pieces

        # tag
        full_tag = mo.group(1)
        if not full_tag.startswith(tag_prefix):
            if verbose:
                fmt = "tag '%s' doesn't start with prefix '%s'"
                print(fmt % (full_tag, tag_prefix))
            pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
                               % (full_tag, tag_prefix))
            return pieces
        pieces["closest-tag"] = full_tag[len(tag_prefix):]

        # distance: number of commits since tag
        pieces["distance"] = int(mo.group(2))

        # commit: short hex revision ID
        pieces["short"] = mo.group(3)

    else:
        # HEX: no tags
        pieces["closest-tag"] = None
        count_out = run_command(GITS, ["rev-list", "HEAD", "--count"],
                                cwd=root)
        if count_out is None:
            # Without this guard int(None) raised TypeError, which
            # get_versions() does not catch (it only handles NotThisMethod),
            # so the parentdir fallback was never reached.
            raise NotThisMethod("'git rev-list --count' failed")
        pieces["distance"] = int(count_out)  # total number of commits

    return pieces
259 |
260 |
def plus_or_dot(pieces):
    """Return the separator that starts or continues a PEP 440 local version.

    If the closest tag already contains a "+" (local version segment), more
    local segments are joined with "."; otherwise "+" starts one. A missing
    or None "closest-tag" is treated as having no "+" rather than raising
    TypeError.
    """
    if "+" in (pieces.get("closest-tag") or ""):
        return "."
    return "+"
265 |
266 |
def render_pep440(pieces):
    """Render TAG[+DISTANCE.gHEX[.dirty]] (a PEP 440 local version).

    A tagged build that is then dirtied renders as TAG+0.gHEX.dirty.
    Without any tag (git describe gave just HEX) the result is
    0+untagged.DISTANCE.gHEX[.dirty].
    """
    tag = pieces["closest-tag"]
    if not tag:
        untagged = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            untagged += ".dirty"
        return untagged

    version = tag
    if pieces["distance"] or pieces["dirty"]:
        version += plus_or_dot(pieces)
        version += "%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            version += ".dirty"
    return version
289 |
290 |
def render_pep440_pre(pieces):
    """Render TAG[.post.devDISTANCE]; the dirty flag is ignored.

    Without any tag the result is 0.post.devDISTANCE.
    """
    tag = pieces["closest-tag"]
    if not tag:
        return "0.post.dev%d" % pieces["distance"]
    if pieces["distance"]:
        return "%s.post.dev%d" % (tag, pieces["distance"])
    return tag
305 |
306 |
def render_pep440_post(pieces):
    """Render TAG[.postDISTANCE[.dev0]+gHEX]; ".dev0" marks a dirty tree.

    ".dev0" sorts backwards (a dirty tree looks "older" than the matching
    clean one), but dirty trees shouldn't be released anyway. Without any
    tag the result is 0.postDISTANCE[.dev0]+gHEX.
    """
    tag = pieces["closest-tag"]
    if not tag:
        untagged = "0.post%d" % pieces["distance"]
        if pieces["dirty"]:
            untagged += ".dev0"
        return untagged + ("+g%s" % pieces["short"])

    if not (pieces["distance"] or pieces["dirty"]):
        return tag
    version = tag + (".post%d" % pieces["distance"])
    if pieces["dirty"]:
        version += ".dev0"
    return version + plus_or_dot(pieces) + ("g%s" % pieces["short"])
331 |
332 |
def render_pep440_old(pieces):
    """Render TAG[.postDISTANCE[.dev0]]; ".dev0" marks a dirty tree.

    Without any tag the result is 0.postDISTANCE[.dev0].
    """
    tag = pieces["closest-tag"]
    if tag and not (pieces["distance"] or pieces["dirty"]):
        return tag
    version = "%s.post%d" % (tag or "0", pieces["distance"])
    if pieces["dirty"]:
        version += ".dev0"
    return version
351 |
352 |
def render_git_describe(pieces):
    """Render TAG[-DISTANCE-gHEX][-dirty], like 'git describe --tags --dirty
    --always'.

    Without any tag the result is HEX[-dirty] (note: no 'g' prefix).
    """
    tag = pieces["closest-tag"]
    if not tag:
        described = pieces["short"]
    elif pieces["distance"]:
        described = "%s-%d-g%s" % (tag, pieces["distance"], pieces["short"])
    else:
        described = tag
    return (described + "-dirty") if pieces["dirty"] else described
370 |
371 |
def render_git_describe_long(pieces):
    """Render TAG-DISTANCE-gHEX[-dirty], like 'git describe --tags --dirty
    --always --long'; the distance/hash suffix is unconditional.

    Without any tag the result is HEX[-dirty] (note: no 'g' prefix).
    """
    tag = pieces["closest-tag"]
    described = ("%s-%d-g%s" % (tag, pieces["distance"], pieces["short"])
                 if tag else pieces["short"])
    if pieces["dirty"]:
        described += "-dirty"
    return described
388 |
389 |
def render(pieces, style):
    """Turn version *pieces* into a version dict using the named *style*.

    If *pieces* carries an error, version "unknown" is reported along with
    that error. An empty *style* or "default" means "pep440".

    Raises ValueError for an unrecognised style.
    """
    if pieces["error"]:
        return {"version": "unknown",
                "full-revisionid": pieces.get("long"),
                "dirty": None,
                "error": pieces["error"]}

    if not style or style == "default":
        style = "pep440"  # the default

    renderers = (
        ("pep440", render_pep440),
        ("pep440-pre", render_pep440_pre),
        ("pep440-post", render_pep440_post),
        ("pep440-old", render_pep440_old),
        ("git-describe", render_git_describe),
        ("git-describe-long", render_git_describe_long),
    )
    for name, renderer in renderers:
        if style == name:
            break
    else:
        raise ValueError("unknown style '%s'" % style)

    return {"version": renderer(pieces), "full-revisionid": pieces["long"],
            "dirty": pieces["dirty"], "error": None}
417 |
418 |
def get_versions():
    """Return this package's version dict, trying each strategy in turn.

    1. expanded git-archive keywords embedded in this file;
    2. ``git describe`` in the source tree that contains this file;
    3. the name of the unpacked parent directory (``parentdir_prefix``).

    If none applies, version "0+unknown" is returned with an "error" message.
    """
    cfg = get_config()
    verbose = cfg.verbose

    try:
        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
                                          verbose)
    except NotThisMethod:
        pass

    # This file lives at ROOT/<versionfile_source>: drop one path component
    # per segment of versionfile_source to get back to ROOT. Some
    # py2exe/bbfreeze/non-CPython implementations don't define __file__, in
    # which case only the expanded keywords could have worked.
    try:
        root = os.path.realpath(__file__)
        for _ in cfg.versionfile_source.split('/'):
            root = os.path.dirname(root)
    except NameError:
        return {"version": "0+unknown", "full-revisionid": None,
                "dirty": None,
                "error": "unable to find root of source tree"}

    try:
        return render(git_pieces_from_vcs(cfg.tag_prefix, root, verbose),
                      cfg.style)
    except NotThisMethod:
        pass

    if cfg.parentdir_prefix:
        try:
            return versions_from_parentdir(cfg.parentdir_prefix, root,
                                           verbose)
        except NotThisMethod:
            pass

    return {"version": "0+unknown", "full-revisionid": None,
            "dirty": None,
            "error": "unable to compute version"}
461 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [upload_docs]
2 | upload-dir = build/sphinx/html
3 |
4 | [bdist_wheel]
5 | universal=1
6 |
7 | [versioneer]
8 | VCS = git
9 | style = pep440
10 | versionfile_source = pg8000/_version.py
11 | versionfile_build = pg8000/_version.py
12 | tag_prefix =
13 | parentdir_prefix = pg8000-
14 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import versioneer
from setuptools import setup

long_description = """\

pg8000
------

pg8000 is a Pure-Python interface to the PostgreSQL database engine. It is \
one of many PostgreSQL interfaces for the Python programming language. pg8000 \
is somewhat distinctive in that it is written entirely in Python and does not \
rely on any external libraries (such as a compiled python module, or \
PostgreSQL's libpq library). pg8000 supports the standard Python DB-API \
version 2.0.

pg8000's name comes from the belief that it is probably about the 8000th \
PostgreSQL interface for Python."""

# versioneer computes the version (configured in setup.cfg [versioneer]) and
# supplies the build/sdist command hooks that stamp it into _version.py.
cmdclass = dict(versioneer.get_cmdclass())
version = versioneer.get_version()

setup(
    name="pg8000",
    version=version,
    cmdclass=cmdclass,
    description="PostgreSQL interface library",
    long_description=long_description,
    author="Mathieu Fenniak",
    author_email="biziqe@mathieu.fenniak.net",
    # README.adoc: the repository has moved to tlocke/pg8000.
    url="https://github.com/tlocke/pg8000",
    license="BSD",
    install_requires=[
        "six>=1.10.0",
    ],
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: BSD License",
        "Programming Language :: Python",
        "Programming Language :: Python :: 2",
        "Programming Language :: Python :: 2.7",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.3",
        "Programming Language :: Python :: 3.4",
        "Programming Language :: Python :: 3.5",
        # .travis.yml also tests against Python 3.6.
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: Implementation",
        "Programming Language :: Python :: Implementation :: CPython",
        "Programming Language :: Python :: Implementation :: Jython",
        "Programming Language :: Python :: Implementation :: PyPy",
        "Operating System :: OS Independent",
        "Topic :: Database :: Front-Ends",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    keywords="postgresql dbapi",
    packages=("pg8000",),
    command_options={
        'build_sphinx': {
            'version': ('setup.py', version),
            'release': ('setup.py', version)}},
)
63 |
--------------------------------------------------------------------------------
/tests/connection_settings.py:
--------------------------------------------------------------------------------
# Connection settings shared by the test suite; the keys match keyword
# arguments of pg8000.connect (presumably passed as **db_connect — confirm in
# the tests). The password matches the one .travis.yml sets for postgres.
db_connect = dict(user='postgres', password='pw', port=5432)
5 |
--------------------------------------------------------------------------------
/tests/dbapi20.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
''' Python DB API 2.0 driver compliance unit test suite.

This software is Public Domain and may be used without restrictions.

"Now we have booze and barflies entering the discussion, plus rumours of
DBAs on drugs... and I won't tell you what flashes through my mind each
time I read the subject line with 'Anal Compliance' in it. All around
this is turning out to be a thoroughly unwholesome unit test."

-- Ian Bicking
'''
# The docstring above must come before the imports: only the first statement
# of a module becomes its __doc__ (after the imports it was a dead string).
import unittest
import time
import warnings
from six import b

__rcs_id__ = '$Id: dbapi20.py,v 1.10 2003/10/09 03:14:14 zenzen Exp $'
# Slices the CVS keyword '$Revision: 1.10 $' down to just '1.10'.
__version__ = '$Revision: 1.10 $'[11:-2]
__author__ = 'Stuart Bishop '
21 |
22 |
23 | # $Log: dbapi20.py,v $
24 | # Revision 1.10 2003/10/09 03:14:14 zenzen
25 | # Add test for DB API 2.0 optional extension, where database exceptions
26 | # are exposed as attributes on the Connection object.
27 | #
28 | # Revision 1.9 2003/08/13 01:16:36 zenzen
29 | # Minor tweak from Stefan Fleiter
30 | #
31 | # Revision 1.8 2003/04/10 00:13:25 zenzen
32 | # Changes, as per suggestions by M.-A. Lemburg
33 | # - Add a table prefix, to ensure namespace collisions can always be avoided
34 | #
35 | # Revision 1.7 2003/02/26 23:33:37 zenzen
36 | # Break out DDL into helper functions, as per request by David Rushby
37 | #
38 | # Revision 1.6 2003/02/21 03:04:33 zenzen
39 | # Stuff from Henrik Ekelund:
40 | # added test_None
41 | # added test_nextset & hooks
42 | #
43 | # Revision 1.5 2003/02/17 22:08:43 zenzen
44 | # Implement suggestions and code from Henrik Eklund - test that
45 | # cursor.arraysize defaults to 1 & generic cursor.callproc test added
46 | #
47 | # Revision 1.4 2003/02/15 00:16:33 zenzen
48 | # Changes, as per suggestions and bug reports by M.-A. Lemburg,
49 | # Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
50 | # - Class renamed
51 | # - Now a subclass of TestCase, to avoid requiring the driver stub
52 | # to use multiple inheritance
53 | # - Reversed the polarity of buggy test in test_description
# - Test exception hierarchy correctly
# - self.populate is now self._populate(), so if a driver stub
#   overrides self.ddl1 this change propagates
# - VARCHAR columns now have a width, which will hopefully make the
#   DDL even more portable (this will be reversed if it causes more problems)
59 | # - cursor.rowcount being checked after various execute and fetchXXX methods
60 | # - Check for fetchall and fetchmany returning empty lists after results
61 | # are exhausted (already checking for empty lists if select retrieved
62 | # nothing
63 | # - Fix bugs in test_setoutputsize_basic and test_setinputsizes
64 | #
65 |
66 |
67 | class DatabaseAPI20Test(unittest.TestCase):
68 | ''' Test a database self.driver for DB API 2.0 compatibility.
69 | This implementation tests Gadfly, but the TestCase
70 | is structured so that other self.drivers can subclass this
71 | test case to ensure compiliance with the DB-API. It is
72 | expected that this TestCase may be expanded in the future
73 | if ambiguities or edge conditions are discovered.
74 |
75 | The 'Optional Extensions' are not yet being tested.
76 |
77 | self.drivers should subclass this test, overriding setUp, tearDown,
78 | self.driver, connect_args and connect_kw_args. Class specification
79 | should be as follows:
80 |
81 | import dbapi20
82 | class mytest(dbapi20.DatabaseAPI20Test):
83 | [...]
84 |
85 | Don't 'import DatabaseAPI20Test from dbapi20', or you will
86 | confuse the unit tester - just 'import dbapi20'.
87 | '''
88 |
89 | # The self.driver module. This should be the module where the 'connect'
90 | # method is to be found
91 | driver = None
92 | connect_args = () # List of arguments to pass to connect
93 | connect_kw_args = {} # Keyword arguments for connect
94 | table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
95 |
96 | ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
97 | ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix
98 | xddl1 = 'drop table %sbooze' % table_prefix
99 | xddl2 = 'drop table %sbarflys' % table_prefix
100 |
101 | # Name of stored procedure to convert
102 | # string->lowercase
103 | lowerfunc = 'lower'
104 |
105 | # Some drivers may need to override these helpers, for example adding
106 | # a 'commit' after the execute.
107 | def executeDDL1(self, cursor):
108 | cursor.execute(self.ddl1)
109 |
110 | def executeDDL2(self, cursor):
111 | cursor.execute(self.ddl2)
112 |
113 | def setUp(self):
114 | ''' self.drivers should override this method to perform required setup
115 | if any is necessary, such as creating the database.
116 | '''
117 | pass
118 |
119 | def tearDown(self):
120 | ''' self.drivers should override this method to perform required
121 | cleanup if any is necessary, such as deleting the test database.
122 | The default drops the tables that may be created.
123 | '''
124 | con = self._connect()
125 | try:
126 | cur = con.cursor()
127 | for ddl in (self.xddl1, self.xddl2):
128 | try:
129 | cur.execute(ddl)
130 | con.commit()
131 | except self.driver.Error:
132 | # Assume table didn't exist. Other tests will check if
133 | # execute is busted.
134 | pass
135 | finally:
136 | con.close()
137 |
138 | def _connect(self):
139 | try:
140 | return self.driver.connect(
141 | *self.connect_args, **self.connect_kw_args)
142 | except AttributeError:
143 | self.fail("No connect method found in self.driver module")
144 |
145 | def test_connect(self):
146 | con = self._connect()
147 | con.close()
148 |
149 | def test_apilevel(self):
150 | try:
151 | # Must exist
152 | apilevel = self.driver.apilevel
153 | # Must equal 2.0
154 | self.assertEqual(apilevel, '2.0')
155 | except AttributeError:
156 | self.fail("Driver doesn't define apilevel")
157 |
158 | def test_threadsafety(self):
159 | try:
160 | # Must exist
161 | threadsafety = self.driver.threadsafety
162 | # Must be a valid value
163 | self.assertEqual(threadsafety in (0, 1, 2, 3), True)
164 | except AttributeError:
165 | self.fail("Driver doesn't define threadsafety")
166 |
167 | def test_paramstyle(self):
168 | try:
169 | # Must exist
170 | paramstyle = self.driver.paramstyle
171 | # Must be a valid value
172 | self.assertEqual(
173 | paramstyle in (
174 | 'qmark', 'numeric', 'named', 'format', 'pyformat'), True)
175 | except AttributeError:
176 | self.fail("Driver doesn't define paramstyle")
177 |
178 | def test_Exceptions(self):
179 | # Make sure required exceptions exist, and are in the
180 | # defined heirarchy.
181 | self.assertEqual(issubclass(self.driver.Warning, Exception), True)
182 | self.assertEqual(issubclass(self.driver.Error, Exception), True)
183 | self.assertEqual(
184 | issubclass(self.driver.InterfaceError, self.driver.Error), True)
185 | self.assertEqual(
186 | issubclass(self.driver.DatabaseError, self.driver.Error), True)
187 | self.assertEqual(
188 | issubclass(self.driver.OperationalError, self.driver.Error), True)
189 | self.assertEqual(
190 | issubclass(self.driver.IntegrityError, self.driver.Error), True)
191 | self.assertEqual(
192 | issubclass(self.driver.InternalError, self.driver.Error), True)
193 | self.assertEqual(
194 | issubclass(self.driver.ProgrammingError, self.driver.Error), True)
195 | self.assertEqual(
196 | issubclass(self.driver.NotSupportedError, self.driver.Error), True)
197 |
198 | def test_ExceptionsAsConnectionAttributes(self):
199 | # OPTIONAL EXTENSION
200 | # Test for the optional DB API 2.0 extension, where the exceptions
201 | # are exposed as attributes on the Connection object
202 | # I figure this optional extension will be implemented by any
203 | # driver author who is using this test suite, so it is enabled
204 | # by default.
205 | warnings.simplefilter("ignore")
206 | con = self._connect()
207 | drv = self.driver
208 | self.assertEqual(con.Warning is drv.Warning, True)
209 | self.assertEqual(con.Error is drv.Error, True)
210 | self.assertEqual(con.InterfaceError is drv.InterfaceError, True)
211 | self.assertEqual(con.DatabaseError is drv.DatabaseError, True)
212 | self.assertEqual(con.OperationalError is drv.OperationalError, True)
213 | self.assertEqual(con.IntegrityError is drv.IntegrityError, True)
214 | self.assertEqual(con.InternalError is drv.InternalError, True)
215 | self.assertEqual(con.ProgrammingError is drv.ProgrammingError, True)
216 | self.assertEqual(con.NotSupportedError is drv.NotSupportedError, True)
217 | warnings.resetwarnings()
218 | con.close()
219 |
220 | def test_commit(self):
221 | con = self._connect()
222 | try:
223 | # Commit must work, even if it doesn't do anything
224 | con.commit()
225 | finally:
226 | con.close()
227 |
228 | def test_rollback(self):
229 | con = self._connect()
230 | # If rollback is defined, it should either work or throw
231 | # the documented exception
232 | if hasattr(con, 'rollback'):
233 | try:
234 | con.rollback()
235 | except self.driver.NotSupportedError:
236 | pass
237 | con.close()
238 |
239 | def test_cursor(self):
240 | con = self._connect()
241 | try:
242 | con.cursor()
243 | finally:
244 | con.close()
245 |
246 | def test_cursor_isolation(self):
247 | con = self._connect()
248 | try:
249 | # Make sure cursors created from the same connection have
250 | # the documented transaction isolation level
251 | cur1 = con.cursor()
252 | cur2 = con.cursor()
253 | self.executeDDL1(cur1)
254 | cur1.execute(
255 | "insert into %sbooze values ('Victoria Bitter')" % (
256 | self.table_prefix))
257 | cur2.execute("select name from %sbooze" % self.table_prefix)
258 | booze = cur2.fetchall()
259 | self.assertEqual(len(booze), 1)
260 | self.assertEqual(len(booze[0]), 1)
261 | self.assertEqual(booze[0][0], 'Victoria Bitter')
262 | finally:
263 | con.close()
264 |
265 | def test_description(self):
266 | con = self._connect()
267 | try:
268 | cur = con.cursor()
269 | self.executeDDL1(cur)
270 | self.assertEqual(
271 | cur.description, None,
272 | 'cursor.description should be none after executing a '
273 | 'statement that can return no rows (such as DDL)')
274 | cur.execute('select name from %sbooze' % self.table_prefix)
275 | self.assertEqual(
276 | len(cur.description), 1,
277 | 'cursor.description describes too many columns')
278 | self.assertEqual(
279 | len(cur.description[0]), 7,
280 | 'cursor.description[x] tuples must have 7 elements')
281 | self.assertEqual(
282 | cur.description[0][0].lower(), b('name'),
283 | 'cursor.description[x][0] must return column name')
284 | self.assertEqual(
285 | cur.description[0][1], self.driver.STRING,
286 | 'cursor.description[x][1] must return column type. Got %r'
287 | % cur.description[0][1])
288 |
289 | # Make sure self.description gets reset
290 | self.executeDDL2(cur)
291 | self.assertEqual(
292 | cur.description, None,
293 | 'cursor.description not being set to None when executing '
294 | 'no-result statements (eg. DDL)')
295 | finally:
296 | con.close()
297 |
298 | def test_rowcount(self):
299 | con = self._connect()
300 | try:
301 | cur = con.cursor()
302 | self.executeDDL1(cur)
303 | self.assertEqual(
304 | cur.rowcount, -1,
305 | 'cursor.rowcount should be -1 after executing no-result '
306 | 'statements')
307 | cur.execute(
308 | "insert into %sbooze values ('Victoria Bitter')" % (
309 | self.table_prefix))
310 | self.assertEqual(
311 | cur.rowcount in (-1, 1), True,
312 | 'cursor.rowcount should == number or rows inserted, or '
313 | 'set to -1 after executing an insert statement')
314 | cur.execute("select name from %sbooze" % self.table_prefix)
315 | self.assertEqual(
316 | cur.rowcount in (-1, 1), True,
317 | 'cursor.rowcount should == number of rows returned, or '
318 | 'set to -1 after executing a select statement')
319 | self.executeDDL2(cur)
320 | self.assertEqual(
321 | cur.rowcount, -1,
322 | 'cursor.rowcount not being reset to -1 after executing '
323 | 'no-result statements')
324 | finally:
325 | con.close()
326 |
327 | lower_func = 'lower'
328 |
329 | def test_callproc(self):
330 | con = self._connect()
331 | try:
332 | cur = con.cursor()
333 | if self.lower_func and hasattr(cur, 'callproc'):
334 | r = cur.callproc(self.lower_func, ('FOO',))
335 | self.assertEqual(len(r), 1)
336 | self.assertEqual(r[0], 'FOO')
337 | r = cur.fetchall()
338 | self.assertEqual(len(r), 1, 'callproc produced no result set')
339 | self.assertEqual(
340 | len(r[0]), 1, 'callproc produced invalid result set')
341 | self.assertEqual(
342 | r[0][0], 'foo', 'callproc produced invalid results')
343 | finally:
344 | con.close()
345 |
346 | def test_close(self):
347 | con = self._connect()
348 | try:
349 | cur = con.cursor()
350 | finally:
351 | con.close()
352 |
353 | # cursor.execute should raise an Error if called after connection
354 | # closed
355 | self.assertRaises(self.driver.Error, self.executeDDL1, cur)
356 |
357 | # connection.commit should raise an Error if called after connection'
358 | # closed.'
359 | self.assertRaises(self.driver.Error, con.commit)
360 |
361 | # connection.close should raise an Error if called more than once
362 | self.assertRaises(self.driver.Error, con.close)
363 |
364 | def test_execute(self):
365 | con = self._connect()
366 | try:
367 | cur = con.cursor()
368 | self._paraminsert(cur)
369 | finally:
370 | con.close()
371 |
372 | def _paraminsert(self, cur):
373 | self.executeDDL1(cur)
374 | cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
375 | self.table_prefix))
376 | self.assertEqual(cur.rowcount in (-1, 1), True)
377 |
378 | if self.driver.paramstyle == 'qmark':
379 | cur.execute(
380 | 'insert into %sbooze values (?)' % self.table_prefix,
381 | ("Cooper's",))
382 | elif self.driver.paramstyle == 'numeric':
383 | cur.execute(
384 | 'insert into %sbooze values (:1)' % self.table_prefix,
385 | ("Cooper's",))
386 | elif self.driver.paramstyle == 'named':
387 | cur.execute(
388 | 'insert into %sbooze values (:beer)' % self.table_prefix,
389 | {'beer': "Cooper's"})
390 | elif self.driver.paramstyle == 'format':
391 | cur.execute(
392 | 'insert into %sbooze values (%%s)' % self.table_prefix,
393 | ("Cooper's",))
394 | elif self.driver.paramstyle == 'pyformat':
395 | cur.execute(
396 | 'insert into %sbooze values (%%(beer)s)' % self.table_prefix,
397 | {'beer': "Cooper's"})
398 | else:
399 | self.fail('Invalid paramstyle')
400 | self.assertEqual(cur.rowcount in (-1, 1), True)
401 |
402 | cur.execute('select name from %sbooze' % self.table_prefix)
403 | res = cur.fetchall()
404 | self.assertEqual(
405 | len(res), 2, 'cursor.fetchall returned too few rows')
406 | beers = [res[0][0], res[1][0]]
407 | beers.sort()
408 | self.assertEqual(
409 | beers[0], "Cooper's",
410 | 'cursor.fetchall retrieved incorrect data, or data inserted '
411 | 'incorrectly')
412 | self.assertEqual(
413 | beers[1], "Victoria Bitter",
414 | 'cursor.fetchall retrieved incorrect data, or data inserted '
415 | 'incorrectly')
416 |
417 | def test_executemany(self):
418 | con = self._connect()
419 | try:
420 | cur = con.cursor()
421 | self.executeDDL1(cur)
422 | largs = [("Cooper's",), ("Boag's",)]
423 | margs = [{'beer': "Cooper's"}, {'beer': "Boag's"}]
424 | if self.driver.paramstyle == 'qmark':
425 | cur.executemany(
426 | 'insert into %sbooze values (?)' % self.table_prefix,
427 | largs
428 | )
429 | elif self.driver.paramstyle == 'numeric':
430 | cur.executemany(
431 | 'insert into %sbooze values (:1)' % self.table_prefix,
432 | largs
433 | )
434 | elif self.driver.paramstyle == 'named':
435 | cur.executemany(
436 | 'insert into %sbooze values (:beer)' % self.table_prefix,
437 | margs
438 | )
439 | elif self.driver.paramstyle == 'format':
440 | cur.executemany(
441 | 'insert into %sbooze values (%%s)' % self.table_prefix,
442 | largs
443 | )
444 | elif self.driver.paramstyle == 'pyformat':
445 | cur.executemany(
446 | 'insert into %sbooze values (%%(beer)s)' % (
447 | self.table_prefix), margs)
448 | else:
449 | self.fail('Unknown paramstyle')
450 | self.assertEqual(
451 | cur.rowcount in (-1, 2), True,
452 | 'insert using cursor.executemany set cursor.rowcount to '
453 | 'incorrect value %r' % cur.rowcount)
454 | cur.execute('select name from %sbooze' % self.table_prefix)
455 | res = cur.fetchall()
456 | self.assertEqual(
457 | len(res), 2,
458 | 'cursor.fetchall retrieved incorrect number of rows')
459 | beers = [res[0][0], res[1][0]]
460 | beers.sort()
461 | self.assertEqual(beers[0], "Boag's", 'incorrect data retrieved')
462 | self.assertEqual(beers[1], "Cooper's", 'incorrect data retrieved')
463 | finally:
464 | con.close()
465 |
466 | def test_fetchone(self):
467 | con = self._connect()
468 | try:
469 | cur = con.cursor()
470 |
471 | # cursor.fetchone should raise an Error if called before
472 | # executing a select-type query
473 | self.assertRaises(self.driver.Error, cur.fetchone)
474 |
475 | # cursor.fetchone should raise an Error if called after
476 | # executing a query that cannnot return rows
477 | self.executeDDL1(cur)
478 | self.assertRaises(self.driver.Error, cur.fetchone)
479 |
480 | cur.execute('select name from %sbooze' % self.table_prefix)
481 | self.assertEqual(
482 | cur.fetchone(), None,
483 | 'cursor.fetchone should return None if a query retrieves '
484 | 'no rows')
485 | self.assertEqual(cur.rowcount in (-1, 0), True)
486 |
487 | # cursor.fetchone should raise an Error if called after
488 | # executing a query that cannnot return rows
489 | cur.execute(
490 | "insert into %sbooze values ('Victoria Bitter')" % (
491 | self.table_prefix))
492 | self.assertRaises(self.driver.Error, cur.fetchone)
493 |
494 | cur.execute('select name from %sbooze' % self.table_prefix)
495 | r = cur.fetchone()
496 | self.assertEqual(
497 | len(r), 1,
498 | 'cursor.fetchone should have retrieved a single row')
499 | self.assertEqual(
500 | r[0], 'Victoria Bitter',
501 | 'cursor.fetchone retrieved incorrect data')
502 | self.assertEqual(
503 | cur.fetchone(), None,
504 | 'cursor.fetchone should return None if no more rows available')
505 | self.assertEqual(cur.rowcount in (-1, 1), True)
506 | finally:
507 | con.close()
508 |
509 | samples = [
510 | 'Carlton Cold',
511 | 'Carlton Draft',
512 | 'Mountain Goat',
513 | 'Redback',
514 | 'Victoria Bitter',
515 | 'XXXX'
516 | ]
517 |
518 | def _populate(self):
519 | ''' Return a list of sql commands to setup the DB for the fetch
520 | tests.
521 | '''
522 | populate = [
523 | "insert into %sbooze values ('%s')" % (self.table_prefix, s)
524 | for s in self.samples]
525 | return populate
526 |
527 | def test_fetchmany(self):
528 | con = self._connect()
529 | try:
530 | cur = con.cursor()
531 |
532 | # cursor.fetchmany should raise an Error if called without
533 | # issuing a query
534 | self.assertRaises(self.driver.Error, cur.fetchmany, 4)
535 |
536 | self.executeDDL1(cur)
537 | for sql in self._populate():
538 | cur.execute(sql)
539 |
540 | cur.execute('select name from %sbooze' % self.table_prefix)
541 | r = cur.fetchmany()
542 | self.assertEqual(
543 | len(r), 1,
544 | 'cursor.fetchmany retrieved incorrect number of rows, '
545 | 'default of arraysize is one.')
546 | cur.arraysize = 10
547 | r = cur.fetchmany(3) # Should get 3 rows
548 | self.assertEqual(
549 | len(r), 3,
550 | 'cursor.fetchmany retrieved incorrect number of rows')
551 | r = cur.fetchmany(4) # Should get 2 more
552 | self.assertEqual(
553 | len(r), 2,
554 | 'cursor.fetchmany retrieved incorrect number of rows')
555 | r = cur.fetchmany(4) # Should be an empty sequence
556 | self.assertEqual(
557 | len(r), 0,
558 | 'cursor.fetchmany should return an empty sequence after '
559 | 'results are exhausted')
560 | self.assertEqual(cur.rowcount in (-1, 6), True)
561 |
562 | # Same as above, using cursor.arraysize
563 | cur.arraysize = 4
564 | cur.execute('select name from %sbooze' % self.table_prefix)
565 | r = cur.fetchmany() # Should get 4 rows
566 | self.assertEqual(
567 | len(r), 4,
568 | 'cursor.arraysize not being honoured by fetchmany')
569 | r = cur.fetchmany() # Should get 2 more
570 | self.assertEqual(len(r), 2)
571 | r = cur.fetchmany() # Should be an empty sequence
572 | self.assertEqual(len(r), 0)
573 | self.assertEqual(cur.rowcount in (-1, 6), True)
574 |
575 | cur.arraysize = 6
576 | cur.execute('select name from %sbooze' % self.table_prefix)
577 | rows = cur.fetchmany() # Should get all rows
578 | self.assertEqual(cur.rowcount in (-1, 6), True)
579 | self.assertEqual(len(rows), 6)
580 | self.assertEqual(len(rows), 6)
581 | rows = [row[0] for row in rows]
582 | rows.sort()
583 |
584 | # Make sure we get the right data back out
585 | for i in range(0, 6):
586 | self.assertEqual(
587 | rows[i], self.samples[i],
588 | 'incorrect data retrieved by cursor.fetchmany')
589 |
590 | rows = cur.fetchmany() # Should return an empty list
591 | self.assertEqual(
592 | len(rows), 0,
593 | 'cursor.fetchmany should return an empty sequence if '
594 | 'called after the whole result set has been fetched')
595 | self.assertEqual(cur.rowcount in (-1, 6), True)
596 |
597 | self.executeDDL2(cur)
598 | cur.execute('select name from %sbarflys' % self.table_prefix)
599 | r = cur.fetchmany() # Should get empty sequence
600 | self.assertEqual(
601 | len(r), 0,
602 | 'cursor.fetchmany should return an empty sequence if '
603 | 'query retrieved no rows')
604 | self.assertEqual(cur.rowcount in (-1, 0), True)
605 |
606 | finally:
607 | con.close()
608 |
609 | def test_fetchall(self):
610 | con = self._connect()
611 | try:
612 | cur = con.cursor()
613 | # cursor.fetchall should raise an Error if called
614 | # without executing a query that may return rows (such
615 | # as a select)
616 | self.assertRaises(self.driver.Error, cur.fetchall)
617 |
618 | self.executeDDL1(cur)
619 | for sql in self._populate():
620 | cur.execute(sql)
621 |
622 | # cursor.fetchall should raise an Error if called
623 | # after executing a a statement that cannot return rows
624 | self.assertRaises(self.driver.Error, cur.fetchall)
625 |
626 | cur.execute('select name from %sbooze' % self.table_prefix)
627 | rows = cur.fetchall()
628 | self.assertEqual(cur.rowcount in (-1, len(self.samples)), True)
629 | self.assertEqual(
630 | len(rows), len(self.samples),
631 | 'cursor.fetchall did not retrieve all rows')
632 | rows = [r[0] for r in rows]
633 | rows.sort()
634 | for i in range(0, len(self.samples)):
635 | self.assertEqual(
636 | rows[i], self.samples[i],
637 | 'cursor.fetchall retrieved incorrect rows')
638 | rows = cur.fetchall()
639 | self.assertEqual(
640 | len(rows), 0,
641 | 'cursor.fetchall should return an empty list if called '
642 | 'after the whole result set has been fetched')
643 | self.assertEqual(cur.rowcount in (-1, len(self.samples)), True)
644 |
645 | self.executeDDL2(cur)
646 | cur.execute('select name from %sbarflys' % self.table_prefix)
647 | rows = cur.fetchall()
648 | self.assertEqual(cur.rowcount in (-1, 0), True)
649 | self.assertEqual(
650 | len(rows), 0,
651 | 'cursor.fetchall should return an empty list if '
652 | 'a select query returns no rows')
653 |
654 | finally:
655 | con.close()
656 |
657 | def test_mixedfetch(self):
658 | con = self._connect()
659 | try:
660 | cur = con.cursor()
661 | self.executeDDL1(cur)
662 | for sql in self._populate():
663 | cur.execute(sql)
664 |
665 | cur.execute('select name from %sbooze' % self.table_prefix)
666 | rows1 = cur.fetchone()
667 | rows23 = cur.fetchmany(2)
668 | rows4 = cur.fetchone()
669 | rows56 = cur.fetchall()
670 | self.assertEqual(cur.rowcount in (-1, 6), True)
671 | self.assertEqual(
672 | len(rows23), 2, 'fetchmany returned incorrect number of rows')
673 | self.assertEqual(
674 | len(rows56), 2, 'fetchall returned incorrect number of rows')
675 |
676 | rows = [rows1[0]]
677 | rows.extend([rows23[0][0], rows23[1][0]])
678 | rows.append(rows4[0])
679 | rows.extend([rows56[0][0], rows56[1][0]])
680 | rows.sort()
681 | for i in range(0, len(self.samples)):
682 | self.assertEqual(
683 | rows[i], self.samples[i],
684 | 'incorrect data retrieved or inserted')
685 | finally:
686 | con.close()
687 |
688 | def help_nextset_setUp(self, cur):
689 | ''' Should create a procedure called deleteme
690 | that returns two result sets, first the
691 | number of rows in booze then "name from booze"
692 | '''
693 | raise NotImplementedError('Helper not implemented')
694 |
695 | def help_nextset_tearDown(self, cur):
696 | 'If cleaning up is needed after nextSetTest'
697 | raise NotImplementedError('Helper not implemented')
698 |
699 | def test_nextset(self):
700 | con = self._connect()
701 | try:
702 | cur = con.cursor()
703 | if not hasattr(cur, 'nextset'):
704 | return
705 |
706 | try:
707 | self.executeDDL1(cur)
708 | sql = self._populate()
709 | for sql in self._populate():
710 | cur.execute(sql)
711 |
712 | self.help_nextset_setUp(cur)
713 |
714 | cur.callproc('deleteme')
715 | numberofrows = cur.fetchone()
716 | assert numberofrows[0] == len(self.samples)
717 | assert cur.nextset()
718 | names = cur.fetchall()
719 | assert len(names) == len(self.samples)
720 | s = cur.nextset()
721 | assert s is None, 'No more return sets, should return None'
722 | finally:
723 | self.help_nextset_tearDown(cur)
724 |
725 | finally:
726 | con.close()
727 |
728 | '''
729 | def test_nextset(self):
730 | raise NotImplementedError('Drivers need to override this test')
731 | '''
732 |
733 | def test_arraysize(self):
734 | # Not much here - rest of the tests for this are in test_fetchmany
735 | con = self._connect()
736 | try:
737 | cur = con.cursor()
738 | self.assertEqual(
739 | hasattr(cur, 'arraysize'), True,
740 | 'cursor.arraysize must be defined')
741 | finally:
742 | con.close()
743 |
744 | def test_setinputsizes(self):
745 | con = self._connect()
746 | try:
747 | cur = con.cursor()
748 | cur.setinputsizes((25,))
749 | self._paraminsert(cur) # Make sure cursor still works
750 | finally:
751 | con.close()
752 |
753 | def test_setoutputsize_basic(self):
754 | # Basic test is to make sure setoutputsize doesn't blow up
755 | con = self._connect()
756 | try:
757 | cur = con.cursor()
758 | cur.setoutputsize(1000)
759 | cur.setoutputsize(2000, 0)
760 | self._paraminsert(cur) # Make sure the cursor still works
761 | finally:
762 | con.close()
763 |
764 | def test_setoutputsize(self):
765 | # Real test for setoutputsize is driver dependant
766 | raise NotImplementedError('Driver need to override this test')
767 |
768 | def test_None(self):
769 | con = self._connect()
770 | try:
771 | cur = con.cursor()
772 | self.executeDDL1(cur)
773 | cur.execute(
774 | 'insert into %sbooze values (NULL)' % self.table_prefix)
775 | cur.execute('select name from %sbooze' % self.table_prefix)
776 | r = cur.fetchall()
777 | self.assertEqual(len(r), 1)
778 | self.assertEqual(len(r[0]), 1)
779 | self.assertEqual(r[0][0], None, 'NULL value not returned as None')
780 | finally:
781 | con.close()
782 |
783 | def test_Date(self):
784 | self.driver.Date(2002, 12, 25)
785 | self.driver.DateFromTicks(
786 | time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0)))
787 | # Can we assume this? API doesn't specify, but it seems implied
788 | # self.assertEqual(str(d1),str(d2))
789 |
790 | def test_Time(self):
791 | self.driver.Time(13, 45, 30)
792 | self.driver.TimeFromTicks(
793 | time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0)))
794 | # Can we assume this? API doesn't specify, but it seems implied
795 | # self.assertEqual(str(t1),str(t2))
796 |
797 | def test_Timestamp(self):
798 | self.driver.Timestamp(2002, 12, 25, 13, 45, 30)
799 | self.driver.TimestampFromTicks(
800 | time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)))
801 | # Can we assume this? API doesn't specify, but it seems implied
802 | # self.assertEqual(str(t1),str(t2))
803 |
804 | def test_Binary(self):
805 | self.driver.Binary(b('Something'))
806 | self.driver.Binary(b(''))
807 |
808 | def test_STRING(self):
809 | self.assertEqual(
810 | hasattr(self.driver, 'STRING'), True,
811 | 'module.STRING must be defined')
812 |
813 | def test_BINARY(self):
814 | self.assertEqual(
815 | hasattr(self.driver, 'BINARY'), True,
816 | 'module.BINARY must be defined.')
817 |
818 | def test_NUMBER(self):
819 | self.assertTrue(
820 | hasattr(self.driver, 'NUMBER'), 'module.NUMBER must be defined.')
821 |
822 | def test_DATETIME(self):
823 | self.assertEqual(
824 | hasattr(self.driver, 'DATETIME'), True,
825 | 'module.DATETIME must be defined.')
826 |
827 | def test_ROWID(self):
828 | self.assertEqual(
829 | hasattr(self.driver, 'ROWID'), True,
830 | 'module.ROWID must be defined.')
831 |
--------------------------------------------------------------------------------
/tests/performance.py:
--------------------------------------------------------------------------------
import time
import warnings
from contextlib import closing
from decimal import Decimal

import pg8000
# connection_settings.py lives next to this script in tests/ (the pg8000
# package has no 'tests' subpackage), so import it the same way
# test_connection.py does.
from connection_settings import db_connect


whole_begin_time = time.time()

# (SQL expression, label) pairs; each expression is selected into seven
# columns over 10000 generated rows. '%%' is the escaped '%' for pg8000's
# format paramstyle.
tests = (
    ("cast(id / 100 as int2)", 'int2'),
    ("cast(id as int4)", 'int4'),
    ("cast(id * 100 as int8)", 'int8'),
    ("(id %% 2) = 0", 'bool'),
    ("N'Static text string'", 'txt'),
    ("cast(id / 100 as float4)", 'float4'),
    ("cast(id / 100 as float8)", 'float8'),
    ("cast(id / 100 as numeric)", 'numeric'),
    ("timestamp '2001-09-28' + id * interval '1 second'", 'timestamp'),
)

# NOTE(review): catch_warnings() only restores the warning filters on exit;
# no filter is installed here, so warnings are not actually suppressed.
with warnings.catch_warnings(), closing(pg8000.connect(**db_connect)) as db:
    for txt, name in tests:
        query = """SELECT {0} AS column1, {0} AS column2, {0} AS column3,
            {0} AS column4, {0} AS column5, {0} AS column6, {0} AS column7
            FROM (SELECT generate_series(1, 10000) AS id) AS tbl""".format(txt)
        with closing(db.cursor()) as cursor:
            print("Beginning %s test..." % name)
            for i in range(1, 5):
                begin_time = time.time()
                cursor.execute(query)
                for row in cursor:
                    pass
                end_time = time.time()
                print("Attempt %s - %s seconds." % (i, end_time - begin_time))
    db.commit()

    with closing(db.cursor()) as cursor:
        cursor.execute(
            "CREATE TEMPORARY TABLE t1 (f1 serial primary key, "
            "f2 bigint not null, f3 varchar(50) null, f4 bool)")
        db.commit()
        params = [(Decimal('7.4009'), 'season of mists...', True)] * 1000
        print("Beginning executemany test...")
        for i in range(1, 5):
            begin_time = time.time()
            cursor.executemany(
                "insert into t1 (f2, f3, f4) values (%s, %s, %s)", params)
            db.commit()
            end_time = time.time()
            print("Attempt {0} took {1} seconds.".format(
                i, end_time - begin_time))

        print("Beginning reuse statements test...")
        begin_time = time.time()
        for i in range(2000):
            cursor.execute("select count(*) from t1")
            cursor.fetchall()
        print("Took {0} seconds.".format(time.time() - begin_time))

print("Whole time - %s seconds." % (time.time() - whole_begin_time))
61 |
--------------------------------------------------------------------------------
/tests/stress.py:
--------------------------------------------------------------------------------
from contextlib import closing

import pg8000
# connection_settings.py lives next to this script in tests/ (the pg8000
# package has no 'tests' subpackage), so import it the same way
# test_connection.py does.
from connection_settings import db_connect


# Run the same catalogue query 100 times on one connection, using and
# closing a fresh cursor each time.
with closing(pg8000.connect(**db_connect)) as db:
    for i in range(100):
        with closing(db.cursor()) as cursor:
            # NOTE(review): the "kj" join's condition never references kj,
            # so it behaves as a cross join with every namespace row.
            # Possibly deliberate to make the query heavier - confirm.
            cursor.execute("""
                SELECT n.nspname as "Schema",
                pg_catalog.format_type(t.oid, NULL) AS "Name",
                pg_catalog.obj_description(t.oid, 'pg_type') as "Description"
                FROM pg_catalog.pg_type t
                LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
                left join pg_catalog.pg_namespace kj on n.oid = t.typnamespace
                WHERE (t.typrelid = 0
                OR (SELECT c.relkind = 'c'
                FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid))
                AND NOT EXISTS(
                SELECT 1 FROM pg_catalog.pg_type el
                WHERE el.oid = t.typelem AND el.typarray = t.oid)
                AND pg_catalog.pg_type_is_visible(t.oid)
                ORDER BY 1, 2;""")
24 |
--------------------------------------------------------------------------------
/tests/test_connection.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import pg8000
3 | from connection_settings import db_connect
4 | from six import PY2, u
5 | import sys
6 | from distutils.version import LooseVersion
7 | import socket
8 | import struct
9 |
10 |
# Jython only: build an SSLContext that trusts every certificate, and keep
# hold of the JVM's default one so it can be put back afterwards.
if 'java' in sys.platform:
    from javax.net.ssl import SSLContext, TrustManager, X509TrustManager
    from jarray import array

    class TrustAllX509TrustManager(X509TrustManager):
        '''X509TrustManager whose checks accept any certificate chain
        (test use only - it performs no validation at all).'''

        def checkClientTrusted(self, chain, auth):
            pass

        def checkServerTrusted(self, chain, auth):
            pass

        def getAcceptedIssuers(self):
            return None

    # Module-level SSLContext that uses the accept-everything manager above.
    trust_managers = array([TrustAllX509TrustManager()], TrustManager)
    TRUST_ALL_CONTEXT = SSLContext.getInstance("SSL")
    TRUST_ALL_CONTEXT.init(None, trust_managers, None)
    # The JVM's default SSLContext, captured so it can be restored later.
    DEFAULT_CONTEXT = SSLContext.getDefault()
28 | return None
29 | # Create a static reference to an SSLContext which will use
30 | # our custom TrustManager
31 | trust_managers = array([TrustAllX509TrustManager()], TrustManager)
32 | TRUST_ALL_CONTEXT = SSLContext.getInstance("SSL")
33 | TRUST_ALL_CONTEXT.init(None, trust_managers, None)
34 | # Keep a static reference to the JVM's default SSLContext for restoring
35 | # at a later time
36 | DEFAULT_CONTEXT = SSLContext.getDefault()
37 |
38 |
def trust_all_certificates(f):
    '''Decorator: run ``f`` with an SSLContext that accepts all certificates.

    Under Jython the JVM-wide default SSLContext is switched to
    ``TRUST_ALL_CONTEXT`` for the duration of the call and then restored to
    ``DEFAULT_CONTEXT``, even if ``f`` raises. On any other platform ``f``
    is simply called. The wrapper keeps ``f``'s name and docstring.
    '''
    from functools import wraps

    @wraps(f)
    def wrapped(*args, **kwargs):
        if 'java' not in sys.platform:
            return f(*args, **kwargs)
        # javax is only importable under Jython, hence the local import.
        from javax.net.ssl import SSLContext
        SSLContext.setDefault(TRUST_ALL_CONTEXT)
        try:
            return f(*args, **kwargs)
        finally:
            SSLContext.setDefault(DEFAULT_CONTEXT)
    return wrapped
55 |
56 |
# Tests related to connecting to a database.
class Tests(unittest.TestCase):
    def _assertDatabaseMissing(self, data):
        '''Assert that ``pg8000.connect(**data)`` raises ProgrammingError.

        On Python 3 the error text must also contain SQLSTATE 3D000, i.e.
        the server reported that the database does not exist.
        ``assertRaisesRegex`` is unavailable on Python 2, so only the
        exception type is checked there.
        '''
        if PY2:
            self.assertRaises(
                pg8000.ProgrammingError, pg8000.connect, **data)
        else:
            self.assertRaisesRegex(
                pg8000.ProgrammingError, '3D000', pg8000.connect, **data)

    def testSocketMissing(self):
        conn_params = {
            'unix_sock': "/file-does-not-exist",
            'user': "doesn't-matter"}
        self.assertRaises(pg8000.InterfaceError, pg8000.connect, **conn_params)

    def testDatabaseMissing(self):
        data = db_connect.copy()
        data["database"] = "missing-db"
        self.assertRaises(pg8000.ProgrammingError, pg8000.connect, **data)

    def testNotify(self):
        # Connect and create the cursor before entering the matching
        # try/finally so a failure cannot be masked by a NameError there.
        db = pg8000.connect(**db_connect)
        try:
            self.assertEqual(list(db.notifications), [])
            cursor = db.cursor()
            try:
                cursor.execute("LISTEN test")
                cursor.execute("NOTIFY test")
                db.commit()

                cursor.execute("VALUES (1, 2), (3, 4), (5, 6)")
                self.assertEqual(len(db.notifications), 1)
                self.assertEqual(db.notifications[0][1], "test")
            finally:
                cursor.close()
        finally:
            db.close()

    # This requires a line in pg_hba.conf that requires md5 for the database
    # pg8000_md5
    def testMd5(self):
        data = db_connect.copy()
        data["database"] = "pg8000_md5"

        # Should only raise an exception saying db doesn't exist
        self._assertDatabaseMissing(data)

    # This requires a line in pg_hba.conf that requires gss for the database
    # pg8000_gss
    def testGss(self):
        data = db_connect.copy()
        data["database"] = "pg8000_gss"

        # Should raise an exception saying gss isn't supported
        if PY2:
            self.assertRaises(pg8000.InterfaceError, pg8000.connect, **data)
        else:
            self.assertRaisesRegex(
                pg8000.InterfaceError,
                "Authentication method 7 not supported by pg8000.",
                pg8000.connect, **data)

    @trust_all_certificates
    def testSsl(self):
        data = db_connect.copy()
        data["ssl"] = True
        db = pg8000.connect(**data)
        db.close()

    # This requires a line in pg_hba.conf that requires 'password' for the
    # database pg8000_password
    def testPassword(self):
        data = db_connect.copy()
        data["database"] = "pg8000_password"

        # Should only raise an exception saying db doesn't exist
        self._assertDatabaseMissing(data)

    def testUnicodeDatabaseName(self):
        data = db_connect.copy()
        # u() makes this a real unicode string on Python 2 as well; a plain
        # literal there would keep "\uFF6F" as six ASCII characters.
        data["database"] = u("pg8000_sn\uFF6Fw")

        # Should only raise an exception saying db doesn't exist
        self._assertDatabaseMissing(data)

    def testBytesDatabaseName(self):
        data = db_connect.copy()
        # UTF-8 encoded bytes of the same non-ASCII name, on both Pythons.
        data["database"] = u("pg8000_sn\uFF6Fw").encode('utf8')

        # Should only raise an exception saying db doesn't exist
        self._assertDatabaseMissing(data)

    def testBytesPassword(self):
        # Create a user with a non-ASCII password, then connect with that
        # password supplied as UTF-8 bytes.
        username = 'boltzmann'
        password = u('cha\uFF6Fs')

        db = pg8000.connect(**db_connect)
        try:
            cur = db.cursor()

            # Delete user if left over from previous run
            try:
                cur.execute("drop role " + username)
            except pg8000.ProgrammingError:
                cur.execute("rollback")

            cur.execute(
                "create user " + username + " with password '" + password +
                "';")
            cur.execute('commit;')
        finally:
            db.close()

        try:
            data = db_connect.copy()
            data['user'] = username
            data['password'] = password.encode('utf8')
            data['database'] = 'pg8000_md5'
            self._assertDatabaseMissing(data)
        finally:
            # Always drop the role, even when the assertion above fails.
            db = pg8000.connect(**db_connect)
            try:
                cur = db.cursor()
                cur.execute("drop role " + username)
                cur.execute("commit;")
            finally:
                db.close()

    def testBrokenPipe(self):
        db1 = pg8000.connect(**db_connect)
        db2 = pg8000.connect(**db_connect)

        try:
            cur1 = db1.cursor()
            cur2 = db2.cursor()

            cur1.execute("select pg_backend_pid()")
            pid1 = cur1.fetchone()[0]

            # Kill db1's backend from db2; a later query on db1 may then
            # fail, and if it does the error must be a socket/struct error.
            cur2.execute("select pg_terminate_backend(%s)", (pid1,))
            try:
                cur1.execute("select 1")
            except Exception as e:
                self.assertTrue(isinstance(e, (socket.error, struct.error)))

            cur2.close()
        finally:
            db1.close()
            db2.close()

    def testApplicatioName(self):
        params = db_connect.copy()
        params['application_name'] = 'my test application name'
        db = pg8000.connect(**params)
        try:
            cur = db.cursor()

            if db._server_version >= LooseVersion('9.2'):
                cur.execute('select application_name from pg_stat_activity '
                            ' where pid = pg_backend_pid()')
            else:
                # for pg9.1 and earlier, procpid field rather than pid
                cur.execute('select application_name from pg_stat_activity '
                            ' where procpid = pg_backend_pid()')

            application_name = cur.fetchone()[0]
            self.assertEqual(application_name, 'my test application name')
        finally:
            db.close()
238 |
239 |
if __name__ == '__main__':
    # Run this module's tests when it is executed as a script.
    unittest.main()
242 |
--------------------------------------------------------------------------------
/tests/test_copy.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import pg8000
3 | from connection_settings import db_connect
4 | from six import b, BytesIO, u, iteritems
5 | from sys import exc_info
6 |
7 |
class Tests(unittest.TestCase):
    '''COPY TO/FROM STDOUT/STDIN through the ``stream`` argument.'''

    def setUp(self):
        self.db = pg8000.connect(**db_connect)
        # A single cursor; the original code created a second one and leaked
        # the first.
        cursor = self.db.cursor()
        try:
            try:
                cursor.execute("DROP TABLE t1")
            except pg8000.DatabaseError:
                e = exc_info()[1]
                # the only acceptable error is:
                msg = e.args[0]
                code = msg['C']
                self.assertEqual(
                    code, '42P01',  # table does not exist
                    "incorrect error for drop table: " + str(msg))
                self.db.rollback()
            cursor.execute(
                "CREATE TEMPORARY TABLE t1 (f1 int primary key, "
                "f2 int not null, f3 varchar(50) null)")
        finally:
            cursor.close()

    def tearDown(self):
        self.db.close()

    # In each test the cursor is created before ``try`` so that the
    # ``finally`` never references an unbound name.

    def testCopyToWithTable(self):
        cursor = self.db.cursor()
        try:
            for n in (1, 2, 3):
                cursor.execute(
                    "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)",
                    (n, n, n))

            stream = BytesIO()
            cursor.execute("copy t1 to stdout", stream=stream)
            self.assertEqual(
                stream.getvalue(), b("1\t1\t1\n2\t2\t2\n3\t3\t3\n"))
            self.assertEqual(cursor.rowcount, 3)
            self.db.commit()
        finally:
            cursor.close()

    def testCopyToWithQuery(self):
        cursor = self.db.cursor()
        try:
            stream = BytesIO()
            cursor.execute(
                "COPY (SELECT 1 as One, 2 as Two) TO STDOUT WITH DELIMITER "
                "'X' CSV HEADER QUOTE AS 'Y' FORCE QUOTE Two", stream=stream)
            self.assertEqual(stream.getvalue(), b('oneXtwo\n1XY2Y\n'))
            self.assertEqual(cursor.rowcount, 1)
            self.db.rollback()
        finally:
            cursor.close()

    def testCopyFromWithTable(self):
        cursor = self.db.cursor()
        try:
            stream = BytesIO(b("1\t1\t1\n2\t2\t2\n3\t3\t3\n"))
            cursor.execute("copy t1 from STDIN", stream=stream)
            self.assertEqual(cursor.rowcount, 3)

            cursor.execute("SELECT * FROM t1 ORDER BY f1")
            retval = cursor.fetchall()
            self.assertEqual(retval, ([1, 1, '1'], [2, 2, '2'], [3, 3, '3']))
            self.db.rollback()
        finally:
            cursor.close()

    def testCopyFromWithQuery(self):
        cursor = self.db.cursor()
        try:
            stream = BytesIO(b("f1Xf2\n1XY1Y\n"))
            cursor.execute(
                "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV HEADER "
                "QUOTE AS 'Y' FORCE NOT NULL f1", stream=stream)
            self.assertEqual(cursor.rowcount, 1)

            cursor.execute("SELECT * FROM t1 ORDER BY f1")
            retval = cursor.fetchall()
            self.assertEqual(retval, ([1, 1, None],))
            self.db.commit()
        finally:
            cursor.close()

    def testCopyFromWithError(self):
        cursor = self.db.cursor()
        try:
            # The blank second line is an empty f1, which is not an integer.
            stream = BytesIO(b("f1Xf2\n\n1XY1Y\n"))
            # Only the server error is caught; the former bare ``except``
            # also caught the test's own "should have raised" failure.
            with self.assertRaises(pg8000.DatabaseError) as cm:
                cursor.execute(
                    "COPY t1 (f1, f2) FROM STDIN WITH DELIMITER 'X' CSV "
                    "HEADER QUOTE AS 'Y' FORCE NOT NULL f1", stream=stream)

            args_dict = {
                'S': u('ERROR'),
                'C': u('22P02'),
                'M': u('invalid input syntax for integer: ""'),
                'W': u('COPY t1, line 2, column f1: ""'),
                'F': u('numutils.c'),
                'R': u('pg_atoi')
            }
            args = cm.exception.args[0]
            for k, v in iteritems(args_dict):
                self.assertEqual(args[k], v)
        finally:
            cursor.close()
118 |
119 |
if __name__ == '__main__':
    # Run this module's tests when it is executed as a script.
    unittest.main()
122 |
--------------------------------------------------------------------------------
/tests/test_dbapi.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import time
4 | import pg8000
5 | import datetime
6 | from connection_settings import db_connect
7 | from sys import exc_info
8 | from six import b
9 | from distutils.version import LooseVersion
10 |
11 |
# DBAPI compatible interface tests
class Tests(unittest.TestCase):
    def setUp(self):
        self.db = pg8000.connect(**db_connect)

        # Neither Windows nor Jython 2.5.3 have a time.tzset() so skip
        if hasattr(time, 'tzset'):
            os.environ['TZ'] = "UTC"
            time.tzset()
            self.HAS_TZSET = True
        else:
            self.HAS_TZSET = False

        # One cursor, created before ``try`` (the original created two and
        # leaked the first).
        c = self.db.cursor()
        try:
            try:
                c.execute("DROP TABLE t1")
            except pg8000.DatabaseError:
                e = exc_info()[1]
                # the only acceptable error is table does not exist
                self.assertEqual(e.args[0]['C'], '42P01')
                self.db.rollback()
            c.execute(
                "CREATE TEMPORARY TABLE t1 "
                "(f1 int primary key, f2 int not null, f3 varchar(50) null)")
            for row in ((1, 1, None), (2, 10, None), (3, 100, None),
                        (4, 1000, None), (5, 10000, None)):
                c.execute(
                    "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", row)
            self.db.commit()
        finally:
            c.close()

    def tearDown(self):
        self.db.close()

    def _fetch_with_paramstyle(self, paramstyle, query, args):
        '''Execute ``query`` under ``pg8000.paramstyle = paramstyle``.

        Drains the result with fetchone(), unpacking each row into three
        columns. The cursor is closed and the module-level paramstyle is
        restored afterwards, even on failure.
        '''
        orig_paramstyle = pg8000.paramstyle
        pg8000.paramstyle = paramstyle
        try:
            c1 = self.db.cursor()
            try:
                c1.execute(query, args)
                for row in iter(c1.fetchone, None):
                    f1, f2, f3 = row
            finally:
                c1.close()
        finally:
            pg8000.paramstyle = orig_paramstyle

    def testParallelQueries(self):
        c1 = self.db.cursor()
        c2 = self.db.cursor()
        try:
            c1.execute("SELECT f1, f2, f3 FROM t1")
            for row in iter(c1.fetchone, None):
                f1, f2, f3 = row
                c2.execute("SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,))
                for row in iter(c2.fetchone, None):
                    f1, f2, f3 = row
        finally:
            c1.close()
            c2.close()

        self.db.rollback()

    def testQmark(self):
        self._fetch_with_paramstyle(
            "qmark", "SELECT f1, f2, f3 FROM t1 WHERE f1 > ?", (3,))
        self.db.rollback()

    def testNumeric(self):
        self._fetch_with_paramstyle(
            "numeric", "SELECT f1, f2, f3 FROM t1 WHERE f1 > :1", (3,))
        self.db.rollback()

    def testNamed(self):
        self._fetch_with_paramstyle(
            "named", "SELECT f1, f2, f3 FROM t1 WHERE f1 > :f1", {"f1": 3})
        self.db.rollback()

    def testFormat(self):
        self._fetch_with_paramstyle(
            "format", "SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (3,))
        self.db.commit()

    def testPyformat(self):
        self._fetch_with_paramstyle(
            "pyformat", "SELECT f1, f2, f3 FROM t1 WHERE f1 > %(f1)s",
            {"f1": 3})
        self.db.commit()

    def testArraysize(self):
        c1 = self.db.cursor()
        try:
            c1.arraysize = 3
            c1.execute("SELECT * FROM t1")
            retval = c1.fetchmany()
            self.assertEqual(len(retval), c1.arraysize)
        finally:
            c1.close()
        self.db.commit()

    def testDate(self):
        val = pg8000.Date(2001, 2, 3)
        self.assertEqual(val, datetime.date(2001, 2, 3))

    def testTime(self):
        val = pg8000.Time(4, 5, 6)
        self.assertEqual(val, datetime.time(4, 5, 6))

    def testTimestamp(self):
        val = pg8000.Timestamp(2001, 2, 3, 4, 5, 6)
        self.assertEqual(val, datetime.datetime(2001, 2, 3, 4, 5, 6))

    def testDateFromTicks(self):
        if self.HAS_TZSET:
            val = pg8000.DateFromTicks(1173804319)
            self.assertEqual(val, datetime.date(2007, 3, 13))

    def testTimeFromTicks(self):
        if self.HAS_TZSET:
            val = pg8000.TimeFromTicks(1173804319)
            self.assertEqual(val, datetime.time(16, 45, 19))

    def testTimestampFromTicks(self):
        if self.HAS_TZSET:
            val = pg8000.TimestampFromTicks(1173804319)
            self.assertEqual(val, datetime.datetime(2007, 3, 13, 16, 45, 19))

    def testBinary(self):
        v = pg8000.Binary(b("\x00\x01\x02\x03\x02\x01\x00"))
        self.assertEqual(v, b("\x00\x01\x02\x03\x02\x01\x00"))
        self.assertTrue(isinstance(v, pg8000.BINARY))

    def testRowCount(self):
        c1 = self.db.cursor()
        try:
            c1.execute("SELECT * FROM t1")

            # Before PostgreSQL 9 we don't know the row count for a select
            if self.db._server_version > LooseVersion('8.0.0'):
                self.assertEqual(5, c1.rowcount)

            c1.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",))
            self.assertEqual(2, c1.rowcount)

            c1.execute("DELETE FROM t1")
            self.assertEqual(5, c1.rowcount)
        finally:
            c1.close()
        self.db.commit()

    def testFetchMany(self):
        cursor = self.db.cursor()
        try:
            cursor.arraysize = 2
            cursor.execute("SELECT * FROM t1")
            self.assertEqual(2, len(cursor.fetchmany()))
            self.assertEqual(2, len(cursor.fetchmany()))
            self.assertEqual(1, len(cursor.fetchmany()))
            self.assertEqual(0, len(cursor.fetchmany()))
        finally:
            cursor.close()
        self.db.commit()

    def testIterator(self):
        from warnings import filterwarnings
        filterwarnings("ignore", "DB-API extension cursor.next()")
        filterwarnings("ignore", "DB-API extension cursor.__iter__()")

        # ``finally`` (not a bare ``except``) so assertion failures propagate
        # and the cursor is closed on success too.
        cursor = self.db.cursor()
        try:
            cursor.execute("SELECT * FROM t1 ORDER BY f1")
            f1 = 0
            for row in cursor:
                next_f1 = row[0]
                self.assertTrue(next_f1 > f1)
                f1 = next_f1
            # setUp inserted f1 = 1..5, so iteration must reach 5.
            self.assertEqual(f1, 5)
        finally:
            cursor.close()

        self.db.commit()

    # Vacuum can't be run inside a transaction, so we need to turn
    # autocommit on.
    def testVacuum(self):
        self.db.autocommit = True
        cursor = self.db.cursor()
        try:
            cursor.execute("vacuum")
        finally:
            cursor.close()

    def testPreparedStatement(self):
        cursor = self.db.cursor()
        try:
            cursor.execute(
                'PREPARE gen_series AS SELECT generate_series(1, 10);')
            cursor.execute('EXECUTE gen_series')
        finally:
            cursor.close()
272 |
273 |
if __name__ == '__main__':
    # Run this module's tests when it is executed as a script.
    unittest.main()
276 |
--------------------------------------------------------------------------------
/tests/test_error_recovery.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import pg8000
3 | from connection_settings import db_connect
4 | import warnings
5 | import datetime
6 | from sys import exc_info
7 |
8 |
class TestException(Exception):
    '''Error raised on purpose by ``Tests.raiseException``.'''
11 |
12 |
class Tests(unittest.TestCase):
    '''The connection must stay usable, or fail cleanly, after errors.'''

    def setUp(self):
        self.db = pg8000.connect(**db_connect)

    def tearDown(self):
        self.db.close()

    def raiseException(self, value):
        raise TestException("oh noes!")

    def testPyValueFail(self):
        # Ensure that if types.py_value throws an exception, the original
        # exception is raised (TestException), and the connection is
        # still usable after the error.
        orig = self.db.py_types[datetime.time]
        self.db.py_types[datetime.time] = (
            orig[0], orig[1], self.raiseException)

        c = self.db.cursor()
        try:
            try:
                with self.assertRaises(TestException):
                    c.execute("SELECT %s as f1", (datetime.time(10, 30),))
                    c.fetchall()
                self.db.rollback()
            finally:
                self.db.py_types[datetime.time] = orig

            # ensure that the connection is still usable for a new query
            c.execute("VALUES ('hw3'::text)")
            self.assertEqual(c.fetchone()[0], "hw3")
        finally:
            c.close()

    def testNoDataErrorRecovery(self):
        for _ in range(1, 4):
            try:
                cursor = self.db.cursor()
                try:
                    cursor.execute("DROP TABLE t1")
                finally:
                    cursor.close()
            except pg8000.DatabaseError:
                e = exc_info()[1]
                # the only acceptable error is 'table does not exist'
                self.assertEqual(e.args[0]['C'], '42P01')
                self.db.rollback()

    def testClosedConnection(self):
        # catch_warnings restores the previous filters afterwards, unlike
        # resetwarnings(), which discarded every filter globally and was
        # skipped when the test failed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            my_db = pg8000.connect(**db_connect)
            cursor = my_db.cursor()
            my_db.close()
            # assertRaises, unlike the former bare ``except``, does not
            # swallow the test's own failure.
            with self.assertRaises(self.db.InterfaceError) as cm:
                cursor.execute("VALUES ('hw1'::text)")
            self.assertEqual(str(cm.exception), 'connection is closed')
79 |
80 |
if __name__ == '__main__':
    # Run this module's tests when it is executed as a script.
    unittest.main()
83 |
--------------------------------------------------------------------------------
/tests/test_paramstyle.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import pg8000
3 |
4 |
# Tests of the convert_paramstyle function.
class Tests(unittest.TestCase):
    '''Check pg8000.core.convert_paramstyle's SQL rewriting and arg mapping.'''

    def _convert(self, style, query, expected_query):
        '''Convert ``query``, assert the rewritten SQL, return make_args.'''
        converted, make_args = pg8000.core.convert_paramstyle(style, query)
        self.assertEqual(converted, expected_query)
        return make_args

    def testQmark(self):
        make_args = self._convert(
            "qmark",
            "SELECT ?, ?, \"field_?\" FROM t "
            "WHERE a='say ''what?''' AND b=? AND c=E'?\\'test\\'?'",
            "SELECT $1, $2, \"field_?\" FROM t WHERE "
            "a='say ''what?''' AND b=$3 AND c=E'?\\'test\\'?'")
        self.assertEqual(make_args((1, 2, 3)), (1, 2, 3))

    def testQmark2(self):
        make_args = self._convert(
            "qmark",
            "SELECT ?, ?, * FROM t WHERE a=? AND b='are you ''sure?'",
            "SELECT $1, $2, * FROM t WHERE a=$3 AND b='are you ''sure?'")
        self.assertEqual(make_args((1, 2, 3)), (1, 2, 3))

    def testNumeric(self):
        make_args = self._convert(
            "numeric",
            "SELECT :2, :1, * FROM t WHERE a=:3",
            "SELECT $2, $1, * FROM t WHERE a=$3")
        self.assertEqual(make_args((1, 2, 3)), (1, 2, 3))

    def testNamed(self):
        make_args = self._convert(
            "named",
            "SELECT :f_2, :f1 FROM t WHERE a=:f_2",
            "SELECT $1, $2 FROM t WHERE a=$1")
        self.assertEqual(make_args({"f_2": 1, "f1": 2}), (1, 2))

    def testFormat(self):
        make_args = self._convert(
            "format",
            "SELECT %s, %s, \"f1_%%\", E'txt_%%' "
            "FROM t WHERE a=%s AND b='75%%' AND c = '%' -- Comment with %",
            "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND "
            "b='75%%' AND c = '%' -- Comment with %")
        self.assertEqual(make_args((1, 2, 3)), (1, 2, 3))

        # A statement with no placeholders must come back unchanged.
        comment_sql = r"""COMMENT ON TABLE test_schema.comment_test """ \
            r"""IS 'the test % '' " \ table comment'"""
        self._convert("format", comment_sql, comment_sql)

    def testFormatMultiline(self):
        self._convert(
            "format",
            "SELECT -- Comment\n%s FROM t",
            "SELECT -- Comment\n$1 FROM t")

    def testPyformat(self):
        make_args = self._convert(
            "pyformat",
            "SELECT %(f2)s, %(f1)s, \"f1_%%\", E'txt_%%' "
            "FROM t WHERE a=%(f2)s AND b='75%%'",
            "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$1 AND "
            "b='75%%'")
        self.assertEqual(make_args({"f2": 1, "f1": 2, "f3": 3}), (1, 2))

        # pyformat should support %s and an array, too:
        make_args = self._convert(
            "pyformat",
            "SELECT %s, %s, \"f1_%%\", E'txt_%%' "
            "FROM t WHERE a=%s AND b='75%%'",
            "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND "
            "b='75%%'")
        self.assertEqual(make_args((1, 2, 3)), (1, 2, 3))
77 |
78 |
if __name__ == '__main__':
    # Run this module's tests when it is executed as a script.
    unittest.main()
81 |
--------------------------------------------------------------------------------
/tests/test_pg8000_dbapi20.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import dbapi20
3 | import unittest
4 | import pg8000
5 | from connection_settings import db_connect
6 |
7 |
class Tests(dbapi20.DatabaseAPI20Test):
    '''Run the generic DB-API 2.0 compliance suite against pg8000.'''

    driver = pg8000
    connect_args = ()
    connect_kw_args = db_connect

    lower_func = 'lower'  # For stored procedure test

    def test_nextset(self):
        '''Overridden as a no-op: the base-class check is not run here.'''

    def test_setoutputsize(self):
        '''Overridden as a no-op: the base-class check is not run here.'''
21 |
if __name__ == "__main__":
    # Run the DB-API 2.0 suite when this module is executed as a script.
    unittest.main()
24 |
--------------------------------------------------------------------------------
/tests/test_query.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import pg8000
3 | from connection_settings import db_connect
4 | from six import u
5 | from sys import exc_info
6 | import datetime
7 | from distutils.version import LooseVersion
8 |
9 | from warnings import filterwarnings
10 |
11 |
12 | # Tests relating to the basic operation of the database driver, driven by the
13 | # pg8000 custom interface.
14 | class Tests(unittest.TestCase):
def setUp(self):
    '''Connect, silence cursor-iteration warnings and recreate temp table t1.'''
    self.db = pg8000.connect(**db_connect)
    for message in ("DB-API extension cursor.next()",
                    "DB-API extension cursor.__iter__()"):
        filterwarnings("ignore", message)
    self.db.paramstyle = 'format'
    try:
        cur = self.db.cursor()
        try:
            cur.execute("DROP TABLE t1")
        except pg8000.DatabaseError:
            err = exc_info()[1]
            # Only "table does not exist" (SQLSTATE 42P01) is acceptable.
            self.assertEqual(err.args[0]['C'], '42P01')
            self.db.rollback()
        cur.execute(
            "CREATE TEMPORARY TABLE t1 (f1 int primary key, "
            "f2 bigint not null, f3 varchar(50) null)")
    finally:
        cur.close()

    self.db.commit()
36 |
def tearDown(self):
    '''Release the connection opened in setUp.'''
    connection = self.db
    connection.close()
39 |
def testDatabaseError(self):
    '''Inserting into a nonexistent table raises ProgrammingError.'''
    bad_insert = "INSERT INTO t99 VALUES (1, 2, 3)"
    try:
        cur = self.db.cursor()
        self.assertRaises(pg8000.ProgrammingError, cur.execute, bad_insert)
    finally:
        cur.close()

    self.db.rollback()
50 |
def testParallelQueries(self):
    '''Two cursors on one connection can interleave their result sets.'''
    seed_rows = ((1, 1, None), (2, 10, None), (3, 100, None),
                 (4, 1000, None), (5, 10000, None))
    try:
        cursor = self.db.cursor()
        for values in seed_rows:
            cursor.execute(
                "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", values)
        try:
            outer = self.db.cursor()
            inner = self.db.cursor()
            outer.execute("SELECT f1, f2, f3 FROM t1")
            for f1, f2, f3 in outer:
                inner.execute(
                    "SELECT f1, f2, f3 FROM t1 WHERE f1 > %s", (f1,))
                for f1, f2, f3 in inner:
                    pass
        finally:
            outer.close()
            inner.close()
    finally:
        cursor.close()
        self.db.rollback()
85 |
def testParallelOpenPortals(self):
    '''Two cursors with pending results can be drained in either order.

    Both cursors execute the same generate_series query before either is
    read; each must then yield every row. Checking the exact count (not
    just c1count == c2count) stops the test passing when both are empty.
    '''
    row_total = 100
    c1, c2 = self.db.cursor(), self.db.cursor()
    try:
        q = "select * from generate_series(1, %s)"
        params = (row_total,)
        c1.execute(q, params)
        c2.execute(q, params)
        c2count = sum(1 for _ in c2)
        c1count = sum(1 for _ in c1)
    finally:
        c1.close()
        c2.close()
        self.db.rollback()

    self.assertEqual(c1count, c2count)
    self.assertEqual(c1count, row_total)
104 |
105 | # Run a query on a table, alter the structure of the table, then run the
106 | # original query again.
107 |
def testAlter(self):
    '''Re-running a query after ALTER TABLE on its table still works.'''
    statements = (
        "select * from t1",
        "alter table t1 drop column f3",
        "select * from t1",
    )
    try:
        cur = self.db.cursor()
        for sql in statements:
            cur.execute(sql)
    finally:
        cur.close()
        self.db.rollback()
117 |
118 | # Run a query on a table, drop then re-create the table, then run the
119 | # original query again.
120 |
def testCreate(self):
    '''Re-running a query after its table is dropped and re-created works.'''
    statements = (
        "select * from t1",
        "drop table t1",
        "create temporary table t1 (f1 int primary key)",
        "select * from t1",
    )
    try:
        cur = self.db.cursor()
        for sql in statements:
            cur.execute(sql)
    finally:
        cur.close()
        self.db.rollback()
131 |
def testInsertReturning(self):
    '''INSERT ... RETURNING yields generated ids for one or several rows.'''
    try:
        cur = self.db.cursor()
        cur.execute("CREATE TABLE t2 (id serial, data text)")

        # A single row: the returned id must find the inserted data.
        cur.execute(
            "INSERT INTO t2 (data) VALUES (%s) RETURNING id", ("test1",))
        new_id = cur.fetchone()[0]
        cur.execute("SELECT data FROM t2 WHERE id = %s", (new_id,))
        self.assertEqual("test1", cur.fetchone()[0])

        # Before PostgreSQL 9 we don't know the row count for a select
        if self.db._server_version > LooseVersion('8.0.0'):
            self.assertEqual(cur.rowcount, 1)

        # Several rows in one statement.
        cur.execute(
            "INSERT INTO t2 (data) VALUES (%s), (%s), (%s) "
            "RETURNING id", ("test2", "test3", "test4"))
        self.assertEqual(cur.rowcount, 3)
        returned = [row[0] for row in cur]
        self.assertEqual(len(returned), 3)
    finally:
        cur.close()
        self.db.rollback()
159 |
def testRowCount(self):
    '''cursor.rowcount after executemany, during SELECT reads, and for DDL.'''
    # Before PostgreSQL 9 we don't know the row count for a select
    if self.db._server_version <= LooseVersion('8.0.0'):
        return

    expected_count = 57

    # Cursors are created before ``try`` so ``finally`` never references an
    # unbound name.
    cursor = self.db.cursor()
    try:
        cursor.executemany(
            "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)",
            tuple((i, i, None) for i in range(expected_count)))

        # Check rowcount after executemany
        self.assertEqual(expected_count, cursor.rowcount)
        self.db.commit()

        cursor.execute("SELECT * FROM t1")

        # Check row_count without doing any reading first...
        self.assertEqual(expected_count, cursor.rowcount)

        # Check rowcount after reading some rows, make sure it still
        # works...
        for _ in range(expected_count // 2):
            cursor.fetchone()
        self.assertEqual(expected_count, cursor.rowcount)
    finally:
        cursor.close()
        self.db.commit()

    # Restart with a new cursor (created once; the original code created a
    # second one and leaked the first), read a few rows, check rowcount.
    cursor = self.db.cursor()
    try:
        cursor.execute("SELECT * FROM t1")
        for _ in range(expected_count // 3):
            cursor.fetchone()
        self.assertEqual(expected_count, cursor.rowcount)
        self.db.rollback()

        # Should be -1 for a command with no results
        cursor.execute("DROP TABLE t1")
        self.assertEqual(-1, cursor.rowcount)
    finally:
        cursor.close()
        self.db.commit()
205 |
def testRowCountUpdate(self):
    '''rowcount after UPDATE equals the number of rows changed.'''
    seed_rows = [(1, 1, None), (2, 10, None), (3, 100, None),
                 (4, 1000, None), (5, 10000, None)]
    try:
        cur = self.db.cursor()
        for values in seed_rows:
            cur.execute(
                "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)", values)
        cur.execute("UPDATE t1 SET f3 = %s WHERE f2 > 101", ("Hello!",))
        # Only the rows with f2 = 1000 and f2 = 10000 match.
        self.assertEqual(cur.rowcount, 2)
    finally:
        cur.close()
        self.db.commit()
229 |
def testIntOid(self):
    '''Regression check: bind a Python int against pg_type.oid.'''
    # https://bugs.launchpad.net/pg8000/+bug/230796
    query = "SELECT typname FROM pg_type WHERE oid = %s"
    try:
        cur = self.db.cursor()
        cur.execute(query, (100,))
    finally:
        cur.close()
        self.db.rollback()
239 |
def testUnicodeQuery(self):
    '''DDL with Cyrillic table and column names executes cleanly.'''
    statement = u(
        "CREATE TEMPORARY TABLE \u043c\u0435\u0441\u0442\u043e "
        "(\u0438\u043c\u044f VARCHAR(50), "
        "\u0430\u0434\u0440\u0435\u0441 VARCHAR(250))")
    try:
        cur = self.db.cursor()
        cur.execute(statement)
    finally:
        cur.close()
        self.db.commit()
251 |
def testExecutemany(self):
    """executemany() works for INSERTs and for a parameterised SELECT.

    The SELECT runs once with a tz-aware datetime and once with a naive
    one, so both kinds of timestamp parameter pass through executemany().
    """
    # Created outside ``try`` so ``finally`` never sees an unbound name.
    cursor = self.db.cursor()
    try:
        cursor.executemany(
            "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)",
            ((1, 1, 'Avast ye!'), (2, 1, None)))

        cursor.executemany(
            "select %s",
            (
                (datetime.datetime(2014, 5, 7, tzinfo=pg8000.core.utc), ),
                (datetime.datetime(2014, 5, 7),)))
    finally:
        cursor.close()
        self.db.commit()
267 |
# Check that autocommit stays off
# We keep track of whether we're in a transaction or not by using the
# READY_FOR_QUERY message.
def testTransactions(self):
    """An INSERT undone by a raw ROLLBACK leaves no rows behind.

    A raw COMMIT is issued first; if that switched autocommit on, the
    INSERT would survive the ROLLBACK. Assumes setUp leaves t1 empty
    (not visible here -- confirm).
    """
    # Created outside ``try`` so ``finally`` never sees an unbound name.
    cursor = self.db.cursor()
    try:
        cursor.execute("commit")
        cursor.execute(
            "INSERT INTO t1 (f1, f2, f3) VALUES (%s, %s, %s)",
            (1, 1, "Zombie"))
        cursor.execute("rollback")
        cursor.execute("select * from t1")

        # Before PostgreSQL 9 we don't know the row count for a select, so
        # only check it on 9.0 and later.
        if self.db._server_version >= LooseVersion('9.0.0'):
            self.assertEqual(cursor.rowcount, 0)
    finally:
        cursor.close()
        self.db.commit()
287 |
def testIn(self):
    """A Python list parameter can be used with ``= any(%s)``."""
    # Created outside ``try`` so ``finally`` never sees an unbound name.
    cursor = self.db.cursor()
    try:
        # ORDER BY makes the first row deterministic (oid 16 is 'bool');
        # without it the result order is unspecified.
        cursor.execute(
            "SELECT typname FROM pg_type WHERE oid = any(%s) ORDER BY oid",
            ([16, 23],))
        ret = cursor.fetchall()
        self.assertEqual(ret[0][0], 'bool')
    finally:
        cursor.close()
297 |
def test_no_previous_tpc(self):
    """tpc_begin(), a query and tpc_commit() succeed in sequence."""
    # Both calls happen before ``try``: if either raises, the real error
    # is reported instead of a NameError on ``cursor`` in ``finally``.
    self.db.tpc_begin('Stacey')
    cursor = self.db.cursor()
    try:
        cursor.execute("SELECT * FROM pg_type")
        self.db.tpc_commit()
    finally:
        cursor.close()
306 |
# Check that tpc_recover() doesn't start a transaction
def test_tpc_recover(self):
    """After tpc_recover(), VACUUM can run once autocommit is enabled."""
    # Both calls happen before ``try``: if either raises, the real error
    # is reported instead of a NameError on ``cursor`` in ``finally``.
    self.db.tpc_recover()
    cursor = self.db.cursor()
    try:
        self.db.autocommit = True

        # If tpc_recover() has started a transaction, this will fail
        cursor.execute("VACUUM")
    finally:
        cursor.close()
318 |
# An empty query should raise a ProgrammingError
def test_empty_query(self):
    """Executing an empty string raises pg8000.ProgrammingError."""
    # Created outside ``try`` so ``finally`` never sees an unbound name.
    cursor = self.db.cursor()
    try:
        self.assertRaises(pg8000.ProgrammingError, cursor.execute, "")
    finally:
        cursor.close()
326 |
# rolling back when not in a transaction doesn't generate a warning
def test_rollback_no_transaction(self):
    """Connection.rollback() outside a transaction produces no notice.

    A raw ROLLBACK sent with no transaction open is checked first to make
    sure the server does emit a 25P01 notice in that situation.
    """
    # Remove any existing notices
    self.db.notices.clear()

    # Created outside ``try`` so ``finally`` never sees an unbound name.
    cursor = self.db.cursor()
    try:
        # First, verify that a raw rollback does produce a notice
        self.db.execute(cursor, "rollback", None)

        self.assertEqual(1, len(self.db.notices))
        # 25P01 is the code for no_active_sql_transaction. It has
        # a message and severity name, but those might be
        # localized/depend on the server version.
        self.assertEqual(self.db.notices.pop().get(b'C'), b'25P01')

        # Now going through the rollback method doesn't produce
        # any notices because it knows we're not in a transaction.
        self.db.rollback()

        self.assertEqual(0, len(self.db.notices))
    finally:
        cursor.close()
352 |
def test_context_manager_class(self):
    """Cursor defines __enter__/__exit__ itself and works in ``with``."""
    own_attributes = pg8000.core.Cursor.__dict__
    for protocol_method in ('__enter__', '__exit__'):
        self.assertTrue(protocol_method in own_attributes)

    with self.db.cursor() as cursor:
        cursor.execute('select 1')
359 |
def test_deallocate_prepared_statements(self):
    """After t1 is altered, pg_prepared_statements has exactly one entry.

    Presumably the surviving entry is the count query itself -- confirm.
    """
    # Created outside ``try`` so ``finally`` never sees an unbound name.
    cursor = self.db.cursor()
    try:
        cursor.execute("select * from t1")
        cursor.execute("alter table t1 drop column f3")
        cursor.execute("select count(*) from pg_prepared_statements")
        res = cursor.fetchall()
        self.assertEqual(res[0][0], 1)
    finally:
        cursor.close()
        self.db.rollback()
371 |
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
374 |
--------------------------------------------------------------------------------
/tests/test_typeconversion.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import pg8000
3 | from pg8000 import PGJsonb, PGEnum
4 | import datetime
5 | import decimal
6 | import struct
7 | from connection_settings import db_connect
8 | from six import b, PY2, u
9 | import uuid
10 | import os
11 | import time
12 | from distutils.version import LooseVersion
13 | import sys
14 | import json
15 | import pytz
16 | from collections import OrderedDict
17 |
18 |
# True when sys.platform mentions "java", i.e. when running under Jython.
IS_JYTHON = 'java' in sys.platform.lower()
20 |
21 |
22 | # Type conversion tests
23 | class Tests(unittest.TestCase):
def setUp(self):
    """Open a fresh connection, and a cursor on it, for every test."""
    connection = pg8000.connect(**db_connect)
    self.db = connection
    self.cursor = connection.cursor()
27 |
def tearDown(self):
    """Close the per-test cursor and connection.

    The connection is closed in ``finally`` so that an error while closing
    the cursor does not leak the connection.
    """
    try:
        self.cursor.close()
        self.cursor = None
    finally:
        self.db.close()
32 |
def testTimeRoundtrip(self):
    """A datetime.time parameter comes back unchanged."""
    sent = datetime.time(4, 5, 6)
    self.cursor.execute("SELECT %s as f1", (sent,))
    received = self.cursor.fetchall()[0][0]
    self.assertEqual(received, sent)
37 |
def testDateRoundtrip(self):
    """A datetime.date parameter comes back unchanged."""
    the_date = datetime.date(2001, 2, 3)
    self.cursor.execute("SELECT %s as f1", (the_date,))
    first_value = self.cursor.fetchall()[0][0]
    self.assertEqual(first_value, the_date)
43 |
def testBoolRoundtrip(self):
    """True sent as a parameter is read back equal to True."""
    self.cursor.execute("SELECT %s as f1", (True,))
    value = self.cursor.fetchall()[0][0]
    self.assertEqual(value, True)
48 |
def testNullRoundtrip(self):
    """NULLs written into typed columns are read back as None.

    A bare ``SELECT %s`` with None is not usable: the parameter has no type
    and PostgreSQL fails with "could not determine data type of parameter".
    The NULLs therefore go into a temporary table with typed columns and
    are selected back out.
    """
    self.cursor.execute(
        "CREATE TEMPORARY TABLE TestNullWrite "
        "(f1 int4, f2 timestamp, f3 varchar)")
    three_nulls = (None, None, None)
    self.cursor.execute(
        "INSERT INTO TestNullWrite VALUES (%s, %s, %s)", three_nulls)
    self.cursor.execute("SELECT * FROM TestNullWrite")
    self.assertEqual(self.cursor.fetchone(), [None, None, None])
63 |
def testNullSelectFailure(self):
    """An untyped None in a bare SELECT raises ProgrammingError.

    This pins down the limitation that forces testNullRoundtrip to use a
    typed table, so any change in that behaviour gets noticed.
    """
    with self.assertRaises(pg8000.ProgrammingError):
        self.cursor.execute("SELECT %s as f1", (None,))
    self.db.rollback()
71 |
def testDecimalRoundtrip(self):
    """Decimal parameters come back with an identical string form."""
    for text in (
            "1.1", "-1.1", "10000", "20000", "-1000000000.123456789", "1.0",
            "12.44"):
        self.cursor.execute("SELECT %s as f1", (decimal.Decimal(text),))
        result = self.cursor.fetchall()[0][0]
        self.assertEqual(str(result), text)
80 |
def testFloatRoundtrip(self):
    """A float survives a server round trip bit-for-bit.

    The packed big-endian doubles are compared, so a value that was turned
    into text and rounded somewhere along the way would be caught.
    """
    original = 1.756e-12
    self.cursor.execute("SELECT %s as f1", (original,))
    returned = self.cursor.fetchall()[0][0]
    self.assertEqual(
        struct.pack("!d", returned), struct.pack("!d", original))
91 |
def test_float_plus_infinity_roundtrip(self):
    """float('inf') is sent and read back as positive infinity."""
    infinity = float('inf')
    self.cursor.execute("SELECT %s as f1", (infinity,))
    self.assertEqual(self.cursor.fetchall()[0][0], infinity)
97 |
def testStrRoundtrip(self):
    """A str stored in a varchar column is read back unchanged.

    On Python 2 a UTF-8 encoded byte string is also sent, and its decoded
    unicode form is expected back.
    """
    text = "hello world"
    self.cursor.execute(
        "create temporary table test_str (f character varying(255))")
    self.cursor.execute("INSERT INTO test_str VALUES (%s)", (text,))
    self.cursor.execute("SELECT * from test_str")
    self.assertEqual(self.cursor.fetchall()[0][0], text)

    if not PY2:
        return
    utf8_bytes = "hello \xce\x94 world"
    self.cursor.execute("SELECT cast(%s as varchar) as f1", (utf8_bytes,))
    self.assertEqual(
        self.cursor.fetchall()[0][0], utf8_bytes.decode('utf8'))
112 |
def test_str_then_int(self):
    """The same SQL text works with a str parameter followed by an int one.

    The int is cast to varchar on the server and is expected back as its
    str() form.
    """
    sql = "SELECT cast(%s as varchar) as f1"
    for param, expected in (("hello world", "hello world"), (1, str(1))):
        self.cursor.execute(sql, (param,))
        self.assertEqual(self.cursor.fetchall()[0][0], expected)
123 |
def testUnicodeRoundtrip(self):
    """A non-ASCII unicode string survives a cast to varchar."""
    text = u("hello \u0173 world")
    self.cursor.execute("SELECT cast(%s as varchar) as f1", (text,))
    self.assertEqual(self.cursor.fetchall()[0][0], text)
129 |
def testLongRoundtrip(self):
    """An integer wider than 32 bits comes back unchanged."""
    big = 50000000000000
    self.cursor.execute("SELECT %s", (big,))
    self.assertEqual(self.cursor.fetchall()[0][0], big)
135 |
def testIntExecuteMany(self):
    """executemany() handles int parameters, including NULL ints."""
    self.cursor.executemany("SELECT %s", ((1,), (40000,)))
    self.cursor.fetchall()

    rows = ([None], [4])
    self.cursor.execute(
        "create temporary table test_int (f integer)")
    self.cursor.executemany("INSERT INTO test_int VALUES (%s)", rows)
    self.cursor.execute("SELECT * from test_int")
    self.assertEqual(self.cursor.fetchall(), rows)
147 |
def testIntRoundtrip(self):
    """Each int round-trips and is typed as int2, int4 or int8 by size.

    The expected oids are int2 = 21, int4 = 23 and int8 = 20. The table
    expects the most negative value of each range (-32768, -2147483648)
    to be promoted to the next wider type.
    """
    int2, int4, int8 = 21, 23, 20
    cases = (
        (0, int2),
        (-32767, int2),
        (-32768, int4),
        (+32767, int2),
        (+32768, int4),
        (-2147483647, int4),
        (-2147483648, int8),
        (+2147483647, int4),
        (+2147483648, int8),
        (-9223372036854775807, int8),
        (+9223372036854775807, int8))

    for value, expected_oid in cases:
        self.cursor.execute("SELECT %s", (value,))
        self.assertEqual(self.cursor.fetchall()[0][0], value)
        type_oid = self.cursor.description[0][1]
        self.assertEqual(type_oid, expected_oid)
172 |
def testByteaRoundtrip(self):
    """Bytes wrapped in pg8000.Binary come back as the same bytes."""
    payload = b("\x00\x01\x02\x03\x02\x01\x00")
    self.cursor.execute("SELECT %s as f1", (pg8000.Binary(payload),))
    self.assertEqual(self.cursor.fetchall()[0][0], payload)
179 |
def test_bytearray_round_trip(self):
    """A bytearray parameter comes back equal to the original bytes."""
    raw = b'\x00\x01\x02\x03\x02\x01\x00'
    self.cursor.execute("SELECT %s as f1", (bytearray(raw),))
    self.assertEqual(self.cursor.fetchall()[0][0], raw)
185 |
def test_bytearray_subclass_round_trip(self):
    """A subclass of bytearray is accepted as a binary parameter too."""
    class BClass(bytearray):
        pass

    raw = b'\x00\x01\x02\x03\x02\x01\x00'
    self.cursor.execute("SELECT %s as f1", (BClass(raw),))
    self.assertEqual(self.cursor.fetchall()[0][0], raw)
193 |
def testTimestampRoundtrip(self):
    """A naive datetime round-trips unchanged, whatever the client TZ is.

    The value is sent once in the current process time zone and, where
    time.tzset() exists, once more with TZ=America/Edmonton. The original
    TZ is restored in ``finally`` so that a failing assertion cannot leave
    later tests running in the wrong zone.
    """
    v = datetime.datetime(2001, 2, 3, 4, 5, 6, 170000)
    self.cursor.execute("SELECT %s as f1", (v,))
    retval = self.cursor.fetchall()
    self.assertEqual(retval[0][0], v)

    # Test that time zone doesn't affect it. Jython 2.5.3 (and CPython on
    # Windows) has no time.tzset(), so skip this part there.
    if IS_JYTHON or not hasattr(time, 'tzset'):
        return

    orig_tz = os.environ.get('TZ')
    os.environ['TZ'] = "America/Edmonton"
    time.tzset()
    try:
        self.cursor.execute("SELECT %s as f1", (v,))
        retval = self.cursor.fetchall()
        self.assertEqual(retval[0][0], v)
    finally:
        if orig_tz is None:
            del os.environ['TZ']
        else:
            os.environ['TZ'] = orig_tz
        time.tzset()
216 |
def testIntervalRoundtrip(self):
    """pg8000.Interval and datetime.timedelta both survive a round trip."""
    for interval in (
            pg8000.Interval(microseconds=123456789, days=2, months=24),
            datetime.timedelta(seconds=30)):
        self.cursor.execute("SELECT %s as f1", (interval,))
        self.assertEqual(self.cursor.fetchall()[0][0], interval)
227 |
def test_enum_str_round_trip(self):
    """A plain str can be cast to an enum type and read back."""
    try:
        self.cursor.execute(
            "create type lepton as enum ('electron', 'muon', 'tau')")
    except pg8000.ProgrammingError:
        # Probably the type already exists: clear the failed transaction
        # and carry on with the existing type.
        self.db.rollback()

    label = 'muon'
    self.cursor.execute("SELECT cast(%s as lepton) as f1", (label,))
    self.assertEqual(self.cursor.fetchall()[0][0], label)

    self.cursor.execute("CREATE TEMPORARY TABLE testenum (f1 lepton)")
    self.cursor.execute(
        "INSERT INTO testenum VALUES (cast(%s as lepton))", ('electron',))
    for cleanup_sql in ("drop table testenum", "drop type lepton"):
        self.cursor.execute(cleanup_sql)
    self.db.commit()
247 |
def test_enum_custom_round_trip(self):
    """A PEP 435-like object wrapped in PGEnum is sent as its ``value``."""
    class Lepton(object):
        # Implements PEP 435 in the minimal fashion needed
        __members__ = OrderedDict()

        def __init__(self, name, value, alias=None):
            self.name = name
            self.value = value
            # Register under the name and, if given, the alias as well.
            keys = [name]
            if alias:
                keys.append(alias)
            for key in keys:
                self.__members__[key] = self
                setattr(self.__class__, key, self)

    try:
        self.cursor.execute("create type lepton as enum ('1', '2', '3')")
    except pg8000.ProgrammingError:
        self.db.rollback()

    muon = Lepton('muon', '2')
    self.cursor.execute(
        "SELECT cast(%s as lepton) as f1", (PGEnum(muon),))
    self.assertEqual(self.cursor.fetchall()[0][0], muon.value)
    self.cursor.execute("drop type lepton")
    self.db.commit()
275 |
def test_enum_py_round_trip(self):
    """A standard-library Enum member is sent as its ``value``.

    Skipped when ``enum`` cannot be imported (e.g. Python 2 without a
    backport). Only the import is guarded, so an ImportError raised while
    the test is actually running is no longer silently treated as a pass.
    """
    try:
        from enum import Enum
    except ImportError:
        self.skipTest("the enum module is not available")

    class Lepton(Enum):
        electron = '1'
        muon = '2'
        tau = '3'

    try:
        self.cursor.execute(
            "create type lepton as enum ('1', '2', '3')")
    except pg8000.ProgrammingError:
        self.db.rollback()

    v = Lepton.muon
    self.cursor.execute("SELECT cast(%s as lepton) as f1", (v,))
    retval = self.cursor.fetchall()
    self.assertEqual(retval[0][0], v.value)

    self.cursor.execute(
        "CREATE TEMPORARY TABLE testenum "
        "(f1 lepton)")
    self.cursor.execute(
        "INSERT INTO testenum VALUES (cast(%s as lepton))",
        (Lepton.electron,))
    self.cursor.execute("drop table testenum")
    self.cursor.execute("drop type lepton")
    self.db.commit()
307 |
def testXmlRoundtrip(self):
    """A string parsed as XML content comes back as the same text."""
    fragment = 'gatccgagtac'
    self.cursor.execute("select xmlparse(content %s) as f1", (fragment,))
    self.assertEqual(self.cursor.fetchall()[0][0], fragment)
313 |
def testUuidRoundtrip(self):
    """A uuid.UUID parameter comes back as an equal value."""
    ident = uuid.UUID('911460f2-1f43-fea2-3e2c-e01fd5b5069d')
    self.cursor.execute("select %s as f1", (ident,))
    self.assertEqual(self.cursor.fetchall()[0][0], ident)
319 |
def testInetRoundtrip(self):
    """Network and address values survive an inet round trip.

    With the ``ipaddress`` module, ip_network/ip_address objects are sent
    and compared directly. Without it, strings are cast to inet on the
    server and the results are compared against those strings. Only the
    import is guarded, so an ImportError raised during the round trip is
    not mistaken for "module missing" and rerouted to the fallback.
    """
    try:
        import ipaddress
    except ImportError:
        ipaddress = None

    if ipaddress is None:
        for v in ('192.168.100.128/25', '192.168.0.1'):
            self.cursor.execute(
                "select cast(cast(%s as varchar) as inet) as f1", (v,))
            retval = self.cursor.fetchall()
            self.assertEqual(retval[0][0], v)
        return

    for v in (ipaddress.ip_network('192.168.0.0/28'),
              ipaddress.ip_address('192.168.0.1')):
        self.cursor.execute("select %s as f1", (v,))
        retval = self.cursor.fetchall()
        self.assertEqual(retval[0][0], v)
340 |
def testXidRoundtrip(self):
    """An xid is returned as an int, and an int can be compared with
    pg_locks.transactionid without error."""
    xid = 86722
    self.cursor.execute(
        "select cast(cast(%s as varchar) as xid) as f1", (xid,))
    self.assertEqual(self.cursor.fetchall()[0][0], xid)

    # Only needs to complete without raising; the rows aren't inspected.
    self.cursor.execute(
        "select * from pg_locks where transactionid = %s", (97712,))
    self.cursor.fetchall()
352 |
def testInt2VectorIn(self):
    """int2vector values are decoded to lists of ints."""
    self.cursor.execute("select cast('1 2' as int2vector) as f1")
    self.assertEqual(self.cursor.fetchall()[0][0], [1, 2])

    # Reading a real catalog int2vector column must not raise.
    self.cursor.execute("select indkey from pg_index")
    self.cursor.fetchall()
361 |
def testTimestampTzOut(self):
    """A timestamptz result is tz-aware and denotes the right instant."""
    self.cursor.execute(
        "SELECT '2001-02-03 04:05:06.17 America/Edmonton'"
        "::timestamp with time zone")
    stamp = self.cursor.fetchall()[0][0]
    self.assertIsNotNone(stamp.tzinfo, "no tzinfo returned")
    # 04:05:06.17 in Edmonton in February is 11:05:06.17 UTC.
    expected = datetime.datetime(2001, 2, 3, 11, 5, 6, 170000, pg8000.utc)
    self.assertEqual(
        stamp.astimezone(pg8000.utc), expected,
        "retrieved value match failed")
373 |
def testTimestampTzRoundtrip(self):
    """A tz-aware datetime comes back tz-aware and equal (not on Jython)."""
    if IS_JYTHON:
        return
    edmonton = pytz.timezone("America/Edmonton")
    sent = edmonton.localize(
        datetime.datetime(2001, 2, 3, 4, 5, 6, 170000))
    self.cursor.execute("SELECT %s as f1", (sent,))
    received = self.cursor.fetchall()[0][0]
    self.assertNotEqual(received.tzinfo, None)
    self.assertEqual(sent, received)
383 |
def testTimestampMismatch(self):
    """Store a naive datetime in timestamptz and an aware one in timestamp.

    The session time zone is pinned to America/Edmonton for the test so
    that the expected values are deterministic, and is reset afterwards.
    """
    if IS_JYTHON:
        return
    mst = pytz.timezone("America/Edmonton")
    self.cursor.execute("SET SESSION TIME ZONE 'America/Edmonton'")
    try:
        self.cursor.execute(
            "CREATE TEMPORARY TABLE TestTz "
            "(f1 timestamp with time zone, "
            "f2 timestamp without time zone)")
        naive = datetime.datetime(2001, 2, 3, 4, 5, 6, 170000)
        self.cursor.execute(
            "INSERT INTO TestTz (f1, f2) VALUES (%s, %s)",
            # f1 (timestamptz) receives the naive value; f2 (timestamp)
            # receives the aware one.
            (naive, mst.localize(naive)))
        self.cursor.execute("SELECT f1, f2 FROM TestTz")
        row = self.cursor.fetchall()[0]

        # A timestamp stored into a timestamptz column is taken as the
        # server's local time. With the session TZ set to MST, 04:05 local
        # comes back as 11:05 UTC.
        self.assertEqual(
            row[0],
            datetime.datetime(2001, 2, 3, 11, 5, 6, 170000, pytz.utc))

        # pg8000 sends the aware value as UTC, and the server converts it
        # to local time for the timestamp column. Reading it back gives the
        # original wall-clock time, as if the zone had simply been dropped.
        self.assertEqual(
            row[1], datetime.datetime(2001, 2, 3, 4, 5, 6, 170000))
    finally:
        self.cursor.execute("SET SESSION TIME ZONE DEFAULT")
423 |
def testNameOut(self):
    """Columns of type "name" (pg_user.usename) decode without error."""
    self.cursor.execute("SELECT usename FROM pg_user")
    # Fetching is the whole test: reaching the end means no decode error.
    self.cursor.fetchall()
429 |
def testOidOut(self):
    """oid columns (pg_type.oid) decode without error."""
    self.cursor.execute("SELECT oid FROM pg_type")
    # Fetching is the whole test: reaching the end means no decode error.
    self.cursor.fetchall()
434 |
def testBooleanOut(self):
    """A server-side boolean true is read back as a truthy value."""
    self.cursor.execute("SELECT cast('t' as bool)")
    self.assertTrue(self.cursor.fetchall()[0][0])
439 |
def testNumericOut(self):
    """numeric literals come back with the same string representation."""
    for literal in ('5000', '50.34'):
        self.cursor.execute("SELECT {0}::numeric".format(literal))
        value = self.cursor.fetchall()[0][0]
        self.assertEqual(str(value), literal)
445 |
def testInt2Out(self):
    """A smallint result reads back as 5000."""
    self.cursor.execute("SELECT 5000::smallint")
    self.assertEqual(self.cursor.fetchall()[0][0], 5000)
450 |
def testInt4Out(self):
    """An integer result reads back as 5000."""
    self.cursor.execute("SELECT 5000::integer")
    self.assertEqual(self.cursor.fetchall()[0][0], 5000)
455 |
def testInt8Out(self):
    """A bigint result wider than 32 bits reads back unchanged."""
    self.cursor.execute("SELECT 50000000000000::bigint")
    self.assertEqual(self.cursor.fetchall()[0][0], 50000000000000)
460 |
def testFloat4Out(self):
    """A real (float4) result is the float32 value widened to a double."""
    self.cursor.execute("SELECT 1.1::real")
    # 1.1 is not exact in single precision; the nearest float32, widened
    # to a double, is 1.1000000238418579.
    self.assertEqual(self.cursor.fetchall()[0][0], 1.1000000238418579)
465 |
def testFloat8Out(self):
    """A double precision result reads back as the Python float 1.1."""
    self.cursor.execute("SELECT 1.1::double precision")
    # 1.1000000000000001 and 1.1 denote the same IEEE 754 double.
    self.assertEqual(self.cursor.fetchall()[0][0], 1.1000000000000001)
470 |
def testVarcharOut(self):
    """A varchar(20) result comes back without padding."""
    self.cursor.execute("SELECT 'hello'::varchar(20)")
    self.assertEqual(self.cursor.fetchall()[0][0], "hello")
475 |
def testCharOut(self):
    """A char(20) result comes back blank-padded to its declared length."""
    self.cursor.execute("SELECT 'hello'::char(20)")
    retval = self.cursor.fetchall()
    # PostgreSQL pads character(n) values with spaces to n characters, so
    # the expected value is 'hello' followed by 15 spaces. ljust() makes
    # the intended length explicit rather than relying on a hand-typed run
    # of spaces.
    self.assertEqual(retval[0][0], "hello".ljust(20))
480 |
def testTextOut(self):
    """A text result comes back as the plain string."""
    self.cursor.execute("SELECT 'hello'::text")
    self.assertEqual(self.cursor.fetchall()[0][0], "hello")
485 |
def testIntervalOut(self):
    """interval results decode to the expected Python values.

    The test expects a pg8000.Interval for a value with a month component
    and datetime.timedelta for values made only of days and seconds.
    """
    cases = (
        ("SELECT '1 month 16 days 12 hours 32 minutes 64 seconds'"
         "::interval",
         pg8000.Interval(
             microseconds=(12 * 60 * 60 * 1000 * 1000) +
             (32 * 60 * 1000 * 1000) + (64 * 1000 * 1000),
             days=16, months=1)),
        ("select interval '30 seconds'",
         datetime.timedelta(seconds=30)),
        ("select interval '12 days 30 seconds'",
         datetime.timedelta(days=12, seconds=30)),
    )
    for sql, expected in cases:
        self.cursor.execute(sql)
        self.assertEqual(self.cursor.fetchall()[0][0], expected)
506 |
def testTimestampOut(self):
    """A timestamp literal decodes to the matching datetime."""
    self.cursor.execute("SELECT '2001-02-03 04:05:06.17'::timestamp")
    expected = datetime.datetime(2001, 2, 3, 4, 5, 6, 170000)
    self.assertEqual(self.cursor.fetchall()[0][0], expected)
512 |
# confirms that pg8000's binary output methods have the same output for
# a data type as the PG server
def testBinaryOutputMethods(self):
    cases = (
        ("float8send", 22.2),
        ("timestamp_send", datetime.datetime(2001, 2, 3, 4, 5, 6, 789)),
        ("byteasend", pg8000.Binary(b("\x01\x02"))),
        ("interval_send", pg8000.Interval(1234567, 123, 123)),)
    for server_function, value in cases:
        # "%%s" survives the string formatting as the "%s" placeholder.
        sql = "SELECT %s(%%s) as f1" % server_function
        self.cursor.execute(sql, (value,))
        server_bytes = self.cursor.fetchall()[0][0]
        # The third item of make_params()'s first entry is called here as
        # pg8000's own sender for this value.
        client_bytes = self.db.make_params((value,))[0][2](value)
        self.assertEqual(server_bytes, client_bytes)
526 |
def testInt4ArrayOut(self):
    """1-, 2- and 3-D INT[] results decode to nested lists (NULL -> None)."""
    self.cursor.execute(
        "SELECT '{1,2,3,4}'::INT[] AS f1, "
        "'{{1,2,3},{4,5,6}}'::INT[][] AS f2, "
        "'{{{1,2},{3,4}},{{NULL,6},{7,8}}}'::INT[][][] AS f3")
    flat, matrix, cube = self.cursor.fetchone()
    self.assertEqual(flat, [1, 2, 3, 4])
    self.assertEqual(matrix, [[1, 2, 3], [4, 5, 6]])
    self.assertEqual(cube, [[[1, 2], [3, 4]], [[None, 6], [7, 8]]])
536 |
def testInt2ArrayOut(self):
    """1-, 2- and 3-D INT2[] results decode to nested lists (NULL -> None)."""
    self.cursor.execute(
        "SELECT '{1,2,3,4}'::INT2[] AS f1, "
        "'{{1,2,3},{4,5,6}}'::INT2[][] AS f2, "
        "'{{{1,2},{3,4}},{{NULL,6},{7,8}}}'::INT2[][][] AS f3")
    flat, matrix, cube = self.cursor.fetchone()
    self.assertEqual(flat, [1, 2, 3, 4])
    self.assertEqual(matrix, [[1, 2, 3], [4, 5, 6]])
    self.assertEqual(cube, [[[1, 2], [3, 4]], [[None, 6], [7, 8]]])
546 |
def testInt8ArrayOut(self):
    """1-, 2- and 3-D INT8[] results decode to nested lists (NULL -> None)."""
    self.cursor.execute(
        "SELECT '{1,2,3,4}'::INT8[] AS f1, "
        "'{{1,2,3},{4,5,6}}'::INT8[][] AS f2, "
        "'{{{1,2},{3,4}},{{NULL,6},{7,8}}}'::INT8[][][] AS f3")
    flat, matrix, cube = self.cursor.fetchone()
    self.assertEqual(flat, [1, 2, 3, 4])
    self.assertEqual(matrix, [[1, 2, 3], [4, 5, 6]])
    self.assertEqual(cube, [[[1, 2], [3, 4]], [[None, 6], [7, 8]]])
556 |
def testBoolArrayOut(self):
    """1-, 2- and 3-D BOOL[] results decode to nested lists (NULL -> None)."""
    self.cursor.execute(
        "SELECT '{TRUE,FALSE,FALSE,TRUE}'::BOOL[] AS f1, "
        "'{{TRUE,FALSE,TRUE},{FALSE,TRUE,FALSE}}'::BOOL[][] AS f2, "
        "'{{{TRUE,FALSE},{FALSE,TRUE}},{{NULL,TRUE},{FALSE,FALSE}}}'"
        "::BOOL[][][] AS f3")
    flat, matrix, cube = self.cursor.fetchone()
    self.assertEqual(flat, [True, False, False, True])
    self.assertEqual(matrix, [[True, False, True], [False, True, False]])
    expected_cube = [
        [[True, False], [False, True]],
        [[None, True], [False, False]]]
    self.assertEqual(cube, expected_cube)
569 |
def testFloat4ArrayOut(self):
    """1-, 2- and 3-D FLOAT4[] results decode to nested lists (NULL -> None)."""
    self.cursor.execute(
        "SELECT '{1,2,3,4}'::FLOAT4[] AS f1, "
        "'{{1,2,3},{4,5,6}}'::FLOAT4[][] AS f2, "
        "'{{{1,2},{3,4}},{{NULL,6},{7,8}}}'::FLOAT4[][][] AS f3")
    flat, matrix, cube = self.cursor.fetchone()
    self.assertEqual(flat, [1, 2, 3, 4])
    self.assertEqual(matrix, [[1, 2, 3], [4, 5, 6]])
    self.assertEqual(cube, [[[1, 2], [3, 4]], [[None, 6], [7, 8]]])
579 |
def testFloat8ArrayOut(self):
    """1-, 2- and 3-D FLOAT8[] results decode to nested lists (NULL -> None)."""
    self.cursor.execute(
        "SELECT '{1,2,3,4}'::FLOAT8[] AS f1, "
        "'{{1,2,3},{4,5,6}}'::FLOAT8[][] AS f2, "
        "'{{{1,2},{3,4}},{{NULL,6},{7,8}}}'::FLOAT8[][][] AS f3")
    flat, matrix, cube = self.cursor.fetchone()
    self.assertEqual(flat, [1, 2, 3, 4])
    self.assertEqual(matrix, [[1, 2, 3], [4, 5, 6]])
    self.assertEqual(cube, [[[1, 2], [3, 4]], [[None, 6], [7, 8]]])
589 |
def testIntArrayRoundtrip(self):
    """Int lists are sent as the narrowest int array type that fits.

    Small values go as INT2[] (oid 1005), values past the int2 range as
    INT4[] (oid 1007) and values past the int4 range as INT8[] (oid 1016).
    """
    int2_array, int4_array, int8_array = 1005, 1007, 1016
    cases = (
        # (sql, values, value-mismatch message, expected oid, oid message)
        # a small int array should be sent as INT2[]
        ("SELECT %s as f1", [1, 2, 3], None,
         int2_array, "type should be INT2[]"),
        # multi-dimensional arrays are also sent as INT2[]
        ("SELECT %s as f1", [[1, 2], [3, 4]], None,
         int2_array, "type should be INT2[]"),
        # a larger value should kick it up to INT4[]...
        ("SELECT %s as f1 -- integer[]", [70000, 2, 3], None,
         int4_array, "type should be INT4[]"),
        # a much larger value should kick it up to INT8[]...
        ("SELECT %s as f1 -- bigint[]", [7000000000, 2, 3],
         "retrieved value match failed",
         int8_array, "type should be INT8[]"),
    )
    for sql, values, value_msg, expected_oid, oid_msg in cases:
        self.cursor.execute(sql, (values,))
        self.assertEqual(self.cursor.fetchall()[0][0], values, value_msg)
        type_oid = self.cursor.description[0][1]
        self.assertEqual(type_oid, expected_oid, oid_msg)
622 |
def testIntArrayWithNullRoundtrip(self):
    """An int list containing None round-trips with the NULL intact."""
    values = [1, None, 3]
    self.cursor.execute("SELECT %s as f1", (values,))
    self.assertEqual(self.cursor.fetchall()[0][0], [1, None, 3])
627 |
def testFloatArrayRoundtrip(self):
    """A list of floats round-trips unchanged."""
    values = [1.1, 2.2, 3.3]
    self.cursor.execute("SELECT %s as f1", (values,))
    self.assertEqual(self.cursor.fetchall()[0][0], [1.1, 2.2, 3.3])
632 |
def testBoolArrayRoundtrip(self):
    """A list of booleans with a None round-trips unchanged."""
    values = [True, False, None]
    self.cursor.execute("SELECT %s as f1", (values,))
    self.assertEqual(self.cursor.fetchall()[0][0], [True, False, None])
637 |
def testStringArrayOut(self):
    """Arrays of the text-like types decode to lists of strings.

    Also covers the empty array, and tells a NULL element apart from the
    quoted string "NULL".
    """
    letters = ["a", "b", "c"]
    for type_name in ("TEXT", "CHAR", "VARCHAR", "CSTRING", "NAME"):
        self.cursor.execute("SELECT '{a,b,c}'::%s[] AS f1" % type_name)
        self.assertEqual(self.cursor.fetchone()[0], letters)

    self.cursor.execute("SELECT '{}'::text[];")
    self.assertEqual(self.cursor.fetchone()[0], [])
    self.cursor.execute("SELECT '{NULL,\"NULL\",NULL,\"\"}'::text[];")
    self.assertEqual(self.cursor.fetchone()[0], [None, 'NULL', None, ""])
653 |
def testNumericArrayOut(self):
    """A numeric[] result decodes to a list of Decimals."""
    self.cursor.execute("SELECT '{1.1,2.2,3.3}'::numeric[] AS f1")
    expected = [decimal.Decimal(text) for text in ("1.1", "2.2", "3.3")]
    self.assertEqual(self.cursor.fetchone()[0], expected)
660 |
def testNumericArrayRoundtrip(self):
    """A list of Decimals with a None round-trips unchanged."""
    decimals = [decimal.Decimal("1.1"), None, decimal.Decimal("3.3")]
    self.cursor.execute("SELECT %s as f1", (decimals,))
    self.assertEqual(self.cursor.fetchall()[0][0], decimals)
666 |
def testStringArrayRoundtrip(self):
    """A str list with empty, punctuation-heavy and None items round-trips."""
    strings = [
        "Hello!", "World!", "abcdefghijklmnopqrstuvwxyz", "",
        "A bunch of random characters:",
        " ~!@#$%^&*()_+`1234567890-=[]\\{}|{;':\",./<>?\t", None]
    self.cursor.execute("SELECT %s as f1", (strings,))
    self.assertEqual(self.cursor.fetchall()[0][0], strings)
675 |
def testUnicodeArrayRoundtrip(self):
    """On Python 2, a list of unicode strings with a NULL item round-trips.

    ``map(unicode, ...)`` would turn the None into the string u'None', so
    no NULL element was ever sent. Only the non-None items are converted.
    """
    if PY2:
        v = [None if s is None else unicode(s)  # noqa
             for s in ("Second", "To", None)]
        self.cursor.execute("SELECT %s as f1", (v,))
        retval = self.cursor.fetchall()
        self.assertEqual(retval[0][0], v)
682 |
def testEmptyArray(self):
    """An empty list round-trips as an empty list."""
    empty = []
    self.cursor.execute("SELECT %s as f1", (empty,))
    self.assertEqual(self.cursor.fetchall()[0][0], empty)
688 |
def testArrayContentNotSupported(self):
    """array_inspect() rejects arrays of objects it cannot map to a type."""
    class Kajigger(object):
        pass

    with self.assertRaises(pg8000.ArrayContentNotSupportedError):
        self.db.array_inspect([[Kajigger()], [None], [None]])
    self.db.rollback()
696 |
def testArrayDimensions(self):
    """Ragged or mixed-depth nested lists are rejected by the array sender.

    The duplicated ``[[[1]], [[2]], [[3, 4]]]`` case has been dropped; it
    repeated the previous entry exactly.
    """
    for arr in (
            [1, [2]],
            [[1], [2], [3, 4]],
            [[[1]], [[2]], [[3, 4]]],
            [[[[1]]], [[[2]]], [[[3, 4]]]],
            [[1, 2, 3], [4, [5], 6]]):
        # The third item returned by array_inspect() is the sender; the
        # dimension check fires when that sender is called.
        arr_send = self.db.array_inspect(arr)[2]
        self.assertRaises(
            pg8000.ArrayDimensionsNotConsistentError, arr_send, arr)
    self.db.rollback()
709 |
def testArrayHomogenous(self):
    """Mixing ints and a float in one nested array is rejected on send."""
    mixed = [[[1]], [[2]], [[3.1]]]
    send_mixed = self.db.array_inspect(mixed)[2]
    with self.assertRaises(pg8000.ArrayContentNotHomogenousError):
        send_mixed(mixed)
    self.db.rollback()
716 |
def testArrayInspect(self):
    """array_inspect() accepts consistent int arrays of depth 1 to 3."""
    for nested in ([1, 2, 3], [[1], [2], [3]], [[[1]], [[2]], [[3]]]):
        self.db.array_inspect(nested)
721 |
def testMacaddr(self):
    """A macaddr comes back in the server's colon-separated text form."""
    self.cursor.execute("SELECT macaddr '08002b:010203'")
    self.assertEqual(self.cursor.fetchall()[0][0], "08:00:2b:01:02:03")
726 |
def testTsvectorRoundtrip(self):
    """A string cast to tsvector comes back as sorted, de-duplicated lexemes."""
    sentence = 'a fat cat sat on a mat and ate a fat rat'
    self.cursor.execute("SELECT cast(%s as tsvector)", (sentence,))
    expected = "'a' 'and' 'ate' 'cat' 'fat' 'mat' 'on' 'rat' 'sat'"
    self.assertEqual(self.cursor.fetchall()[0][0], expected)
734 |
def testHstoreRoundtrip(self):
    """An hstore literal cast on the server comes back as the same text."""
    pair = '"a"=>"1"'
    self.cursor.execute("SELECT cast(%s as hstore)", (pair,))
    self.assertEqual(self.cursor.fetchall()[0][0], pair)
740 |
def testJsonRoundtrip(self):
    """A dict wrapped in PGJson comes back as an equal dict (server 9.2+)."""
    if not self.db._server_version >= LooseVersion('9.2'):
        return
    document = {'name': 'Apollo 11 Cave', 'zebra': True, 'age': 26.003}
    self.cursor.execute("SELECT %s", (pg8000.PGJson(document),))
    self.assertEqual(self.cursor.fetchall()[0][0], document)
748 |
def testJsonbRoundtrip(self):
    """JSON text cast to jsonb decodes to an equal dict (server 9.4+)."""
    if not self.db._server_version >= LooseVersion('9.4'):
        return
    document = {'name': 'Apollo 11 Cave', 'zebra': True, 'age': 26.003}
    self.cursor.execute(
        "SELECT cast(%s as jsonb)", (json.dumps(document),))
    self.assertEqual(self.cursor.fetchall()[0][0], document)
756 |
def test_json_access_object(self):
    """The json -> operator extracts an object member by key."""
    # Report a skip instead of silently passing on servers that are too old.
    if self.db._server_version < LooseVersion('9.4'):
        self.skipTest("test requires PostgreSQL 9.4 or later")
    val = {'name': 'Apollo 11 Cave', 'zebra': True, 'age': 26.003}
    self.cursor.execute(
        "SELECT cast(%s as json) -> %s", (json.dumps(val), 'name'))
    retval = self.cursor.fetchall()
    self.assertEqual(retval[0][0], 'Apollo 11 Cave')
764 |
def test_jsonb_access_object(self):
    """The jsonb -> operator extracts an object member by key."""
    # Report a skip instead of silently passing on servers that are too old.
    if self.db._server_version < LooseVersion('9.4'):
        self.skipTest("jsonb requires PostgreSQL 9.4 or later")
    val = {'name': 'Apollo 11 Cave', 'zebra': True, 'age': 26.003}
    self.cursor.execute(
        "SELECT cast(%s as jsonb) -> %s", (json.dumps(val), 'name'))
    retval = self.cursor.fetchall()
    self.assertEqual(retval[0][0], 'Apollo 11 Cave')
772 |
def test_json_access_array(self):
    """The json -> operator extracts an array element by integer index."""
    # Report a skip instead of silently passing on servers that are too old.
    if self.db._server_version < LooseVersion('9.4'):
        self.skipTest("test requires PostgreSQL 9.4 or later")
    val = [-1, -2, -3, -4, -5]
    self.cursor.execute(
        "SELECT cast(%s as json) -> %s", (json.dumps(val), 2))
    retval = self.cursor.fetchall()
    self.assertEqual(retval[0][0], -3)
780 |
def testJsonbAccessArray(self):
    """The jsonb -> operator extracts an array element by integer index."""
    # Report a skip instead of silently passing on servers that are too old.
    if self.db._server_version < LooseVersion('9.4'):
        self.skipTest("jsonb requires PostgreSQL 9.4 or later")
    val = [-1, -2, -3, -4, -5]
    self.cursor.execute(
        "SELECT cast(%s as jsonb) -> %s", (json.dumps(val), 2))
    retval = self.cursor.fetchall()
    self.assertEqual(retval[0][0], -3)
788 |
def test_jsonb_access_path(self):
    """The #>> operator follows a text[] path into a PGJsonb value."""
    # Report a skip instead of silently passing on servers that are too old.
    if self.db._server_version < LooseVersion('9.4'):
        self.skipTest("jsonb requires PostgreSQL 9.4 or later")
    j = {
        "a": [1, 2, 3],
        "b": [4, 5, 6]}

    path = ['a', '2']

    self.cursor.execute("SELECT %s #>> %s", [PGJsonb(j), path])
    retval = self.cursor.fetchall()
    # #>> yields text, so compare against str() of the addressed element
    self.assertEqual(retval[0][0], str(j[path[0]][int(path[1])]))
800 |
def test_timestamp_send_float(self):
    """timestamp_send_float encodes 2016-01-02 00:00 as the expected 8 bytes."""
    expected = b('A\xbe\x19\xcf\x80\x00\x00\x00')
    actual = pg8000.core.timestamp_send_float(
        datetime.datetime(2016, 1, 2, 0, 0))
    # assertEqual (unlike a bare assert) still runs under `python -O` and
    # shows both values on failure.
    self.assertEqual(actual, expected)
805 |
def test_infinity_timestamp_roundtrip(self):
    """The special 'infinity' timestamp is returned as the string 'infinity'."""
    special = 'infinity'
    self.cursor.execute(
        "SELECT cast(%s as timestamp) as f1", (special,))
    rows = self.cursor.fetchall()
    self.assertEqual(rows[0][0], special)
811 |
812 |
if "__main__" == __name__:
    # Run this module's tests when it is executed as a script.
    unittest.main()
815 |
--------------------------------------------------------------------------------
/tests/test_typeobjects.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from pg8000 import Interval
3 |
4 |
class Tests(unittest.TestCase):
    """Tests for the pg8000.Interval type object."""

    def testIntervalConstructor(self):
        i = Interval(days=1)
        self.assertEqual(i.months, 0)
        self.assertEqual(i.days, 1)
        self.assertEqual(i.microseconds, 0)

    def intervalRangeTest(self, parameter, in_range, out_of_range):
        """Check Interval's bounds for one keyword argument.

        Every value in *out_of_range* must make
        ``Interval(**{parameter: v})`` raise OverflowError; every value in
        *in_range* must be accepted.
        """
        for v in out_of_range:
            # Keep only the constructor inside the try, so the failure below
            # is reported on its own and names the offending value.
            try:
                Interval(**{parameter: v})
            except OverflowError:
                pass
            else:
                self.fail(
                    "expected OverflowError for %s=%r" % (parameter, v))
        for v in in_range:
            Interval(**{parameter: v})

    def testIntervalDaysRange(self):
        out_of_range_days = (-2147483648, +2147483648,)
        in_range_days = (-2147483647, +2147483647,)
        self.intervalRangeTest("days", in_range_days, out_of_range_days)

    def testIntervalMonthsRange(self):
        out_of_range_months = (-2147483648, +2147483648,)
        in_range_months = (-2147483647, +2147483647,)
        self.intervalRangeTest("months", in_range_months, out_of_range_months)

    def testIntervalMicrosecondsRange(self):
        out_of_range_microseconds = (
            -9223372036854775808, +9223372036854775808,)
        in_range_microseconds = (
            -9223372036854775807, +9223372036854775807,)
        self.intervalRangeTest(
            "microseconds", in_range_microseconds, out_of_range_microseconds)
46 |
47 |
if "__main__" == __name__:
    # Run this module's tests when it is executed as a script.
    unittest.main()
50 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | # Tox (https://tox.readthedocs.io/) is a tool for running tests
2 | # in multiple virtualenvs. This configuration file will run the
3 | # test suite on all supported python versions. To use it, "pip install tox"
4 | # and then run "tox" from this directory.
5 |
6 | [tox]
7 | skip_missing_interpreters=True
8 |
9 |
10 | [testenv]
11 | commands =
12 | nosetests -x
13 | python -m doctest README.adoc
14 | flake8 pg8000
15 | python setup.py check
16 | deps =
17 | nose
18 | flake8
19 | pytz
20 |
--------------------------------------------------------------------------------
/versioneer.py:
--------------------------------------------------------------------------------
1 |
2 | # Version: 0.15
3 |
4 | """
5 | The Versioneer
6 | ==============
7 |
8 | * like a rocketeer, but for versions!
9 | * https://github.com/warner/python-versioneer
10 | * Brian Warner
11 | * License: Public Domain
12 | * Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, and pypy
13 | * [![Latest Version]
14 | (https://pypip.in/version/versioneer/badge.svg?style=flat)
15 | ](https://pypi.python.org/pypi/versioneer/)
16 | * [![Build Status]
17 | (https://travis-ci.org/warner/python-versioneer.png?branch=master)
18 | ](https://travis-ci.org/warner/python-versioneer)
19 |
20 | This is a tool for managing a recorded version number in distutils-based
21 | python projects. The goal is to remove the tedious and error-prone "update
22 | the embedded version string" step from your release process. Making a new
23 | release should be as easy as recording a new tag in your version-control
24 | system, and maybe making new tarballs.
25 |
26 |
27 | ## Quick Install
28 |
29 | * `pip install versioneer` to somewhere on your $PATH
30 | * add a `[versioneer]` section to your setup.cfg (see below)
31 | * run `versioneer install` in your source tree, commit the results
32 |
33 | ## Version Identifiers
34 |
35 | Source trees come from a variety of places:
36 |
37 | * a version-control system checkout (mostly used by developers)
38 | * a nightly tarball, produced by build automation
39 | * a snapshot tarball, produced by a web-based VCS browser, like github's
40 | "tarball from tag" feature
41 | * a release tarball, produced by "setup.py sdist", distributed through PyPI
42 |
43 | Within each source tree, the version identifier (either a string or a number,
44 | this tool is format-agnostic) can come from a variety of places:
45 |
46 | * ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows
47 | about recent "tags" and an absolute revision-id
48 | * the name of the directory into which the tarball was unpacked
49 | * an expanded VCS keyword ($Id$, etc)
50 | * a `_version.py` created by some earlier build step
51 |
52 | For released software, the version identifier is closely related to a VCS
53 | tag. Some projects use tag names that include more than just the version
54 | string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool
55 | needs to strip the tag prefix to extract the version identifier. For
56 | unreleased software (between tags), the version identifier should provide
57 | enough information to help developers recreate the same tree, while also
58 | giving them an idea of roughly how old the tree is (after version 1.2, before
59 | version 1.3). Many VCS systems can report a description that captures this,
60 | for example `git describe --tags --dirty --always` reports things like
61 | "0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the
62 | 0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has
63 | uncommitted changes).
64 |
65 | The version identifier is used for multiple purposes:
66 |
67 | * to allow the module to self-identify its version: `myproject.__version__`
68 | * to choose a name and prefix for a 'setup.py sdist' tarball
69 |
70 | ## Theory of Operation
71 |
72 | Versioneer works by adding a special `_version.py` file into your source
73 | tree, where your `__init__.py` can import it. This `_version.py` knows how to
74 | dynamically ask the VCS tool for version information at import time.
75 |
76 | `_version.py` also contains `$Revision$` markers, and the installation
77 | process marks `_version.py` to have this marker rewritten with a tag name
78 | during the `git archive` command. As a result, generated tarballs will
79 | contain enough information to get the proper version.
80 |
81 | To allow `setup.py` to compute a version too, a `versioneer.py` is added to
82 | the top level of your source tree, next to `setup.py` and the `setup.cfg`
83 | that configures it. This overrides several distutils/setuptools commands to
84 | compute the version when invoked, and changes `setup.py build` and `setup.py
85 | sdist` to replace `_version.py` with a small static file that contains just
86 | the generated version data.
87 |
88 | ## Installation
89 |
90 | First, decide on values for the following configuration variables:
91 |
92 | * `VCS`: the version control system you use. Currently accepts "git".
93 |
94 | * `style`: the style of version string to be produced. See "Styles" below for
95 | details. Defaults to "pep440", which looks like
96 | `TAG[+DISTANCE.gSHORTHASH[.dirty]]`.
97 |
98 | * `versionfile_source`:
99 |
100 | A project-relative pathname into which the generated version strings should
101 | be written. This is usually a `_version.py` next to your project's main
102 | `__init__.py` file, so it can be imported at runtime. If your project uses
103 | `src/myproject/__init__.py`, this should be `src/myproject/_version.py`.
104 | This file should be checked in to your VCS as usual: the copy created below
105 | by `versioneer install` will include code that parses expanded VCS
106 | keywords in generated tarballs. The 'build' and 'sdist' commands will
107 | replace it with a copy that has just the calculated version string.
108 |
109 | This must be set even if your project does not have any modules (and will
110 | therefore never import `_version.py`), since "setup.py sdist" -based trees
111 | still need somewhere to record the pre-calculated version strings. Anywhere
112 | in the source tree should do. If there is a `__init__.py` next to your
113 | `_version.py`, the `versioneer install` command (described below)
114 | will append some `__version__`-setting assignments, if they aren't already
115 | present.
116 |
117 | * `versionfile_build`:
118 |
119 | Like `versionfile_source`, but relative to the build directory instead of
120 | the source directory. These will differ when your setup.py uses
121 | 'package_dir='. If you have `package_dir={'myproject': 'src/myproject'}`,
122 | then you will probably have `versionfile_build='myproject/_version.py'` and
123 | `versionfile_source='src/myproject/_version.py'`.
124 |
125 | If this is set to None, then `setup.py build` will not attempt to rewrite
126 | any `_version.py` in the built tree. If your project does not have any
127 | libraries (e.g. if it only builds a script), then you should use
128 | `versionfile_build = None` and override `distutils.command.build_scripts`
129 | to explicitly insert a copy of `versioneer.get_version()` into your
130 | generated script.
131 |
132 | * `tag_prefix`:
133 |
134 | a string, like 'PROJECTNAME-', which appears at the start of all VCS tags.
135 | If your tags look like 'myproject-1.2.0', then you should use
136 | tag_prefix='myproject-'. If you use unprefixed tags like '1.2.0', this
137 | should be an empty string.
138 |
139 | * `parentdir_prefix`:
140 |
141 | an optional string, frequently the same as tag_prefix, which appears at the
142 | start of all unpacked tarball filenames. If your tarball unpacks into
143 | 'myproject-1.2.0', this should be 'myproject-'. To disable this feature,
144 | just omit the field from your `setup.cfg`.
145 |
146 | This tool provides one script, named `versioneer`. That script has one mode,
147 | "install", which writes a copy of `versioneer.py` into the current directory
148 | and runs `versioneer.py setup` to finish the installation.
149 |
150 | To versioneer-enable your project:
151 |
152 | * 1: Modify your `setup.cfg`, adding a section named `[versioneer]` and
153 | populating it with the configuration values you decided earlier (note that
154 | the option names are not case-sensitive):
155 |
156 | ````
157 | [versioneer]
158 | VCS = git
159 | style = pep440
160 | versionfile_source = src/myproject/_version.py
161 | versionfile_build = myproject/_version.py
162 | tag_prefix = ""
163 | parentdir_prefix = myproject-
164 | ````
165 |
166 | * 2: Run `versioneer install`. This will do the following:
167 |
168 | * copy `versioneer.py` into the top of your source tree
169 | * create `_version.py` in the right place (`versionfile_source`)
170 | * modify your `__init__.py` (if one exists next to `_version.py`) to define
171 | `__version__` (by calling a function from `_version.py`)
172 | * modify your `MANIFEST.in` to include both `versioneer.py` and the
173 | generated `_version.py` in sdist tarballs
174 |
175 | `versioneer install` will complain about any problems it finds with your
176 | `setup.py` or `setup.cfg`. Run it multiple times until you have fixed all
177 | the problems.
178 |
179 | * 3: add an `import versioneer` to your setup.py, and add the following
180 | arguments to the setup() call:
181 |
182 | version=versioneer.get_version(),
183 | cmdclass=versioneer.get_cmdclass(),
184 |
185 | * 4: commit these changes to your VCS. To make sure you won't forget,
186 | `versioneer install` will mark everything it touched for addition using
187 | `git add`. Don't forget to add `setup.py` and `setup.cfg` too.
188 |
189 | ## Post-Installation Usage
190 |
191 | Once established, all uses of your tree from a VCS checkout should get the
192 | current version string. All generated tarballs should include an embedded
193 | version string (so users who unpack them will not need a VCS tool installed).
194 |
195 | If you distribute your project through PyPI, then the release process should
196 | boil down to two steps:
197 |
198 | * 1: git tag 1.0
199 | * 2: python setup.py register sdist upload
200 |
201 | If you distribute it through github (i.e. users use github to generate
202 | tarballs with `git archive`), the process is:
203 |
204 | * 1: git tag 1.0
205 | * 2: git push; git push --tags
206 |
207 | Versioneer will report "0+untagged.NUMCOMMITS.gHASH" until your tree has at
208 | least one tag in its history.
209 |
210 | ## Version-String Flavors
211 |
212 | Code which uses Versioneer can learn about its version string at runtime by
213 | importing `_version` from your main `__init__.py` file and running the
214 | `get_versions()` function. From the "outside" (e.g. in `setup.py`), you can
215 | import the top-level `versioneer.py` and run `get_versions()`.
216 |
217 | Both functions return a dictionary with different flavors of version
218 | information:
219 |
220 | * `['version']`: A condensed version string, rendered using the selected
221 | style. This is the most commonly used value for the project's version
222 | string. The default "pep440" style yields strings like `0.11`,
223 | `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section
224 | below for alternative styles.
225 |
226 | * `['full-revisionid']`: detailed revision identifier. For Git, this is the
227 | full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac".
228 |
229 | * `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that
230 | this is only accurate if run in a VCS checkout, otherwise it is likely to
231 | be False or None
232 |
233 | * `['error']`: if the version string could not be computed, this will be set
234 | to a string describing the problem, otherwise it will be None. It may be
235 | useful to throw an exception in setup.py if this is set, to avoid e.g.
236 | creating tarballs with a version string of "unknown".
237 |
238 | Some variants are more useful than others. Including `full-revisionid` in a
239 | bug report should allow developers to reconstruct the exact code being tested
240 | (or indicate the presence of local changes that should be shared with the
241 | developers). `version` is suitable for display in an "about" box or a CLI
242 | `--version` output: it can be easily compared against release notes and lists
243 | of bugs fixed in various releases.
244 |
245 | The installer adds the following text to your `__init__.py` to place a basic
246 | version in `YOURPROJECT.__version__`:
247 |
248 | from ._version import get_versions
249 | __version__ = get_versions()['version']
250 | del get_versions
251 |
252 | ## Styles
253 |
254 | The setup.cfg `style=` configuration controls how the VCS information is
255 | rendered into a version string.
256 |
257 | The default style, "pep440", produces a PEP440-compliant string, equal to the
258 | un-prefixed tag name for actual releases, and containing an additional "local
259 | version" section with more detail for in-between builds. For Git, this is
260 | TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags
261 | --dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the
262 | tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and
263 | that this commit is two revisions ("+2") beyond the "0.11" tag. For released
264 | software (exactly equal to a known tag), the identifier will only contain the
265 | stripped tag, e.g. "0.11".
266 |
267 | Other styles are available. See details.md in the Versioneer source tree for
268 | descriptions.
269 |
270 | ## Debugging
271 |
272 | Versioneer tries to avoid fatal errors: if something goes wrong, it will tend
273 | to return a version of "0+unknown". To investigate the problem, run `setup.py
274 | version`, which will run the version-lookup code in a verbose mode, and will
275 | display the full contents of `get_versions()` (including the `error` string,
276 | which may help identify what went wrong).
277 |
278 | ## Updating Versioneer
279 |
280 | To upgrade your project to a new release of Versioneer, do the following:
281 |
282 | * install the new Versioneer (`pip install -U versioneer` or equivalent)
283 | * edit `setup.cfg`, if necessary, to include any new configuration settings
284 | indicated by the release notes
285 | * re-run `versioneer install` in your source tree, to replace
286 | `SRC/_version.py`
287 | * commit any changed files
288 |
289 | ### Upgrading to 0.15
290 |
291 | Starting with this version, Versioneer is configured with a `[versioneer]`
292 | section in your `setup.cfg` file. Earlier versions required the `setup.py` to
293 | set attributes on the `versioneer` module immediately after import. The new
294 | version will refuse to run (raising an exception during import) until you
295 | have provided the necessary `setup.cfg` section.
296 |
297 | In addition, the Versioneer package provides an executable named
298 | `versioneer`, and the installation process is driven by running `versioneer
299 | install`. In 0.14 and earlier, the executable was named
300 | `versioneer-installer` and was run without an argument.
301 |
302 | ### Upgrading to 0.14
303 |
304 | 0.14 changes the format of the version string. 0.13 and earlier used
305 | hyphen-separated strings like "0.11-2-g1076c97-dirty". 0.14 and beyond use a
306 | plus-separated "local version" section strings, with dot-separated
307 | components, like "0.11+2.g1076c97". PEP440-strict tools did not like the old
308 | format, but should be ok with the new one.
309 |
310 | ### Upgrading from 0.11 to 0.12
311 |
312 | Nothing special.
313 |
314 | ### Upgrading from 0.10 to 0.11
315 |
316 | You must add a `versioneer.VCS = "git"` to your `setup.py` before re-running
317 | `setup.py setup_versioneer`. This will enable the use of additional
318 | version-control systems (SVN, etc) in the future.
319 |
320 | ## Future Directions
321 |
322 | This tool is designed to be easily extended to other version-control
323 | systems: all VCS-specific components are in separate directories like
324 | src/git/ . The top-level `versioneer.py` script is assembled from these
325 | components by running make-versioneer.py . In the future, make-versioneer.py
326 | will take a VCS name as an argument, and will construct a version of
327 | `versioneer.py` that is specific to the given VCS. It might also take the
328 | configuration arguments that are currently provided manually during
329 | installation by editing setup.py . Alternatively, it might go the other
330 | direction and include code from all supported VCS systems, reducing the
331 | number of intermediate scripts.
332 |
333 |
334 | ## License
335 |
336 | To make Versioneer easier to embed, all its code is hereby released into the
337 | public domain. The `_version.py` that it creates is also in the public
338 | domain.
339 |
340 | """
341 |
342 | from __future__ import print_function
343 | try:
344 | import configparser
345 | except ImportError:
346 | import ConfigParser as configparser
347 | import errno
348 | import json
349 | import os
350 | import re
351 | import subprocess
352 | import sys
353 |
354 |
class VersioneerConfig:
    """Plain attribute container for the [versioneer] settings.

    get_config_from_root() sets VCS, style, versionfile_source,
    versionfile_build, tag_prefix, parentdir_prefix and verbose on it.
    """
357 |
358 |
def get_root():
    """Return the project root: the directory holding setup.py/versioneer.py.

    All commands are expected to run from the project root, i.e. the
    directory that contains setup.py, setup.cfg and versioneer.py. The
    current working directory is tried first. If it contains neither
    setup.py nor versioneer.py, the directory of sys.argv[0] is tried, which
    supports 'python path/to/setup.py COMMAND'.

    Raises:
        VersioneerBadRootError: neither location contains setup.py or
            versioneer.py.
    """
    def _has_project_files(directory):
        # True if *directory* contains setup.py or versioneer.py.
        return (os.path.exists(os.path.join(directory, "setup.py")) or
                os.path.exists(os.path.join(directory, "versioneer.py")))

    root = os.path.realpath(os.path.abspath(os.getcwd()))
    if not _has_project_files(root):
        # allow 'python path/to/setup.py COMMAND'
        root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0])))
    if not _has_project_files(root):
        err = ("Versioneer was unable to find the project root directory. "
               "Versioneer requires setup.py to be executed from "
               "its immediate directory (like 'python setup.py COMMAND'), "
               "or in a way that lets it use sys.argv[0] to find the root "
               "(like 'python path/to/setup.py COMMAND').")
        raise VersioneerBadRootError(err)
    versioneer_py = os.path.join(root, "versioneer.py")
    # Certain runtime workflows (setup.py install/develop in a setuptools
    # tree) execute all dependencies in a single python process, so
    # "versioneer" may be imported multiple times, and python's shared
    # module-import table will cache the first one. So we can't use
    # os.path.dirname(__file__), as that will find whichever
    # versioneer.py was first imported, even in later projects.
    try:
        me = os.path.realpath(os.path.abspath(__file__))
    except NameError:
        # __file__ is not defined; skip the sanity check.
        return root
    if os.path.splitext(me)[0] != os.path.splitext(versioneer_py)[0]:
        print("Warning: build in %s is using versioneer.py from %s"
              % (os.path.dirname(me), versioneer_py))
    return root
391 |
392 |
def get_config_from_root(root):
    """Read the [versioneer] section of ``<root>/setup.cfg``.

    Returns a VersioneerConfig. VCS is mandatory. style defaults to "" and
    the other options default to None when absent.

    This might raise EnvironmentError (if setup.cfg is missing),
    configparser.NoSectionError (if it lacks a [versioneer] section), or
    configparser.NoOptionError (if it lacks "VCS="). See the docstring at
    the top of versioneer.py for instructions on writing your setup.cfg.
    """
    setup_cfg = os.path.join(root, "setup.cfg")
    # SafeConfigParser and readfp() were deprecated in Python 3.2 and removed
    # in 3.12; on Python 3, ConfigParser/read_file() are the equivalents.
    if sys.version_info[0] >= 3:
        parser = configparser.ConfigParser()
    else:
        parser = configparser.SafeConfigParser()
    with open(setup_cfg, "r") as f:
        if hasattr(parser, "read_file"):
            parser.read_file(f)
        else:
            parser.readfp(f)
    VCS = parser.get("versioneer", "VCS")  # mandatory

    def get(parser, name):
        # Optional option: None when missing from the [versioneer] section.
        if parser.has_option("versioneer", name):
            return parser.get("versioneer", name)
        return None
    cfg = VersioneerConfig()
    cfg.VCS = VCS
    cfg.style = get(parser, "style") or ""
    cfg.versionfile_source = get(parser, "versionfile_source")
    cfg.versionfile_build = get(parser, "versionfile_build")
    cfg.tag_prefix = get(parser, "tag_prefix")
    cfg.parentdir_prefix = get(parser, "parentdir_prefix")
    cfg.verbose = get(parser, "verbose")
    return cfg
417 |
418 |
class NotThisMethod(Exception):
    """Raised when a version-lookup strategy does not apply to this tree."""
421 |
# Per-VCS registries filled in as this module loads:
#   LONG_VERSION_PY maps a VCS name to its _version.py template text;
#   HANDLERS maps a VCS name to {method name: handler function}.
LONG_VERSION_PY = dict()
HANDLERS = dict()
425 |
426 |
def register_vcs_handler(vcs, method):  # decorator
    """Build a decorator that records a function as HANDLERS[vcs][method].

    The decorated function itself is returned unchanged.
    """
    def decorate(f):
        HANDLERS.setdefault(vcs, {})[method] = f
        return f
    return decorate
434 |
435 |
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False):
    """Run the first runnable program from *commands* with *args*.

    Candidates are tried in order. One that is not found (ENOENT) is skipped;
    any other launch error aborts. Returns the stripped stdout (decoded to
    str on Python 3), or None if nothing could be started or the process
    exited with a non-zero status.
    """
    assert isinstance(commands, list)
    stderr_target = subprocess.PIPE if hide_stderr else None
    proc = None
    for candidate in commands:
        argv = [candidate] + args
        dispcmd = str(argv)
        try:
            # shell=False, so callers on Windows must list git.cmd, not git
            proc = subprocess.Popen(argv, cwd=cwd, stdout=subprocess.PIPE,
                                    stderr=stderr_target)
        except EnvironmentError:
            err = sys.exc_info()[1]
            if err.errno == errno.ENOENT:
                continue
            if verbose:
                print("unable to run %s" % dispcmd)
                print(err)
            return None
        break
    else:
        if verbose:
            print("unable to find command, tried %s" % (commands,))
        return None
    output = proc.communicate()[0].strip()
    if sys.version_info[0] >= 3:
        output = output.decode()
    if proc.returncode != 0:
        if verbose:
            print("unable to run %s (error)" % dispcmd)
        return None
    return output
467 | LONG_VERSION_PY['git'] = '''
468 | # This file helps to compute a version number in source trees obtained from
469 | # git-archive tarball (such as those provided by githubs download-from-tag
470 | # feature). Distribution tarballs (built by setup.py sdist) and build
471 | # directories (produced by setup.py build) will contain a much shorter file
472 | # that just contains the computed version number.
473 |
474 | # This file is released into the public domain. Generated by
475 | # versioneer-0.15 (https://github.com/warner/python-versioneer)
476 |
477 | import errno
478 | import os
479 | import re
480 | import subprocess
481 | import sys
482 |
483 |
484 | def get_keywords():
485 | # these strings will be replaced by git during git-archive.
486 | # setup.py/versioneer.py will grep for the variable names, so they must
487 | # each be defined on a line of their own. _version.py will just call
488 | # get_keywords().
489 | git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s"
490 | git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s"
491 | keywords = {"refnames": git_refnames, "full": git_full}
492 | return keywords
493 |
494 |
495 | class VersioneerConfig:
496 | pass
497 |
498 |
499 | def get_config():
500 | # these strings are filled in when 'setup.py versioneer' creates
501 | # _version.py
502 | cfg = VersioneerConfig()
503 | cfg.VCS = "git"
504 | cfg.style = "%(STYLE)s"
505 | cfg.tag_prefix = "%(TAG_PREFIX)s"
506 | cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s"
507 | cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s"
508 | cfg.verbose = False
509 | return cfg
510 |
511 |
512 | class NotThisMethod(Exception):
513 | pass
514 |
515 |
516 | LONG_VERSION_PY = {}
517 | HANDLERS = {}
518 |
519 |
520 | def register_vcs_handler(vcs, method): # decorator
521 | def decorate(f):
522 | if vcs not in HANDLERS:
523 | HANDLERS[vcs] = {}
524 | HANDLERS[vcs][method] = f
525 | return f
526 | return decorate
527 |
528 |
529 | def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False):
530 | assert isinstance(commands, list)
531 | p = None
532 | for c in commands:
533 | try:
534 | dispcmd = str([c] + args)
535 | # remember shell=False, so use git.cmd on windows, not just git
536 | p = subprocess.Popen([c] + args, cwd=cwd, stdout=subprocess.PIPE,
537 | stderr=(subprocess.PIPE if hide_stderr
538 | else None))
539 | break
540 | except EnvironmentError:
541 | e = sys.exc_info()[1]
542 | if e.errno == errno.ENOENT:
543 | continue
544 | if verbose:
545 | print("unable to run %%s" %% dispcmd)
546 | print(e)
547 | return None
548 | else:
549 | if verbose:
550 | print("unable to find command, tried %%s" %% (commands,))
551 | return None
552 | stdout = p.communicate()[0].strip()
553 | if sys.version_info[0] >= 3:
554 | stdout = stdout.decode()
555 | if p.returncode != 0:
556 | if verbose:
557 | print("unable to run %%s (error)" %% dispcmd)
558 | return None
559 | return stdout
560 |
561 |
562 | def versions_from_parentdir(parentdir_prefix, root, verbose):
563 | # Source tarballs conventionally unpack into a directory that includes
564 | # both the project name and a version string.
565 | dirname = os.path.basename(root)
566 | if not dirname.startswith(parentdir_prefix):
567 | if verbose:
568 | print("guessing rootdir is '%%s', but '%%s' doesn't start with "
569 | "prefix '%%s'" %% (root, dirname, parentdir_prefix))
570 | raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
571 | return {"version": dirname[len(parentdir_prefix):],
572 | "full-revisionid": None,
573 | "dirty": False, "error": None}
574 |
575 |
576 | @register_vcs_handler("git", "get_keywords")
577 | def git_get_keywords(versionfile_abs):
578 | # the code embedded in _version.py can just fetch the value of these
579 | # keywords. When used from setup.py, we don't want to import _version.py,
580 | # so we do it with a regexp instead. This function is not used from
581 | # _version.py.
582 | keywords = {}
583 | try:
584 | f = open(versionfile_abs, "r")
585 | for line in f.readlines():
586 | if line.strip().startswith("git_refnames ="):
587 | mo = re.search(r'=\s*"(.*)"', line)
588 | if mo:
589 | keywords["refnames"] = mo.group(1)
590 | if line.strip().startswith("git_full ="):
591 | mo = re.search(r'=\s*"(.*)"', line)
592 | if mo:
593 | keywords["full"] = mo.group(1)
594 | f.close()
595 | except EnvironmentError:
596 | pass
597 | return keywords
598 |
599 |
600 | @register_vcs_handler("git", "keywords")
601 | def git_versions_from_keywords(keywords, tag_prefix, verbose):
602 | if not keywords:
603 | raise NotThisMethod("no keywords at all, weird")
604 | refnames = keywords["refnames"].strip()
605 | if refnames.startswith("$Format"):
606 | if verbose:
607 | print("keywords are unexpanded, not using")
608 | raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
609 | refs = set([r.strip() for r in refnames.strip("()").split(",")])
610 | # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
611 | # just "foo-1.0". If we see a "tag: " prefix, prefer those.
612 | TAG = "tag: "
613 | tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
614 | if not tags:
615 | # Either we're using git < 1.8.3, or there really are no tags. We use
616 | # a heuristic: assume all version tags have a digit. The old git %%d
617 | # expansion behaves like git log --decorate=short and strips out the
618 | # refs/heads/ and refs/tags/ prefixes that would let us distinguish
619 | # between branches and tags. By ignoring refnames without digits, we
620 | # filter out many common branch names like "release" and
621 | # "stabilization", as well as "HEAD" and "master".
622 | tags = set([r for r in refs if re.search(r'\d', r)])
623 | if verbose:
624 | print("discarding '%%s', no digits" %% ",".join(refs-tags))
625 | if verbose:
626 | print("likely tags: %%s" %% ",".join(sorted(tags)))
627 | for ref in sorted(tags):
628 | # sorting will prefer e.g. "2.0" over "2.0rc1"
629 | if ref.startswith(tag_prefix):
630 | r = ref[len(tag_prefix):]
631 | if verbose:
632 | print("picking %%s" %% r)
633 | return {"version": r,
634 | "full-revisionid": keywords["full"].strip(),
635 | "dirty": False, "error": None
636 | }
637 | # no suitable tags, so version is "0+unknown", but full hex is still there
638 | if verbose:
639 | print("no suitable tags, using unknown + full revision id")
640 | return {"version": "0+unknown",
641 | "full-revisionid": keywords["full"].strip(),
642 | "dirty": False, "error": "no suitable tags"}
643 |
644 |
645 | @register_vcs_handler("git", "pieces_from_vcs")
646 | def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
647 | # this runs 'git' from the root of the source tree. This only gets called
648 | # if the git-archive 'subst' keywords were *not* expanded, and
649 | # _version.py hasn't already been rewritten with a short version string,
650 | # meaning we're inside a checked out source tree.
651 |
652 | if not os.path.exists(os.path.join(root, ".git")):
653 | if verbose:
654 | print("no .git in %%s" %% root)
655 | raise NotThisMethod("no .git directory")
656 |
657 | GITS = ["git"]
658 | if sys.platform == "win32":
659 | GITS = ["git.cmd", "git.exe"]
660 | # if there is a tag, this yields TAG-NUM-gHEX[-dirty]
661 | # if there are no tags, this yields HEX[-dirty] (no NUM)
662 | describe_out = run_command(GITS, ["describe", "--tags", "--dirty",
663 | "--always", "--long"],
664 | cwd=root)
665 | # --long was added in git-1.5.5
666 | if describe_out is None:
667 | raise NotThisMethod("'git describe' failed")
668 | describe_out = describe_out.strip()
669 | full_out = run_command(GITS, ["rev-parse", "HEAD"], cwd=root)
670 | if full_out is None:
671 | raise NotThisMethod("'git rev-parse' failed")
672 | full_out = full_out.strip()
673 |
674 | pieces = {}
675 | pieces["long"] = full_out
676 | pieces["short"] = full_out[:7] # maybe improved later
677 | pieces["error"] = None
678 |
679 | # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
680 | # TAG might have hyphens.
681 | git_describe = describe_out
682 |
683 | # look for -dirty suffix
684 | dirty = git_describe.endswith("-dirty")
685 | pieces["dirty"] = dirty
686 | if dirty:
687 | git_describe = git_describe[:git_describe.rindex("-dirty")]
688 |
689 | # now we have TAG-NUM-gHEX or HEX
690 |
691 | if "-" in git_describe:
692 | # TAG-NUM-gHEX
693 | mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
694 | if not mo:
695 | # unparseable. Maybe git-describe is misbehaving?
696 | pieces["error"] = ("unable to parse git-describe output: '%%s'"
697 | %% describe_out)
698 | return pieces
699 |
700 | # tag
701 | full_tag = mo.group(1)
702 | if not full_tag.startswith(tag_prefix):
703 | if verbose:
704 | fmt = "tag '%%s' doesn't start with prefix '%%s'"
705 | print(fmt %% (full_tag, tag_prefix))
706 | pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'"
707 | %% (full_tag, tag_prefix))
708 | return pieces
709 | pieces["closest-tag"] = full_tag[len(tag_prefix):]
710 |
711 | # distance: number of commits since tag
712 | pieces["distance"] = int(mo.group(2))
713 |
714 | # commit: short hex revision ID
715 | pieces["short"] = mo.group(3)
716 |
717 | else:
718 | # HEX: no tags
719 | pieces["closest-tag"] = None
720 | count_out = run_command(GITS, ["rev-list", "HEAD", "--count"],
721 | cwd=root)
722 | pieces["distance"] = int(count_out) # total number of commits
723 |
724 | return pieces
725 |
726 |
727 | def plus_or_dot(pieces):
728 | if "+" in pieces.get("closest-tag", ""):
729 | return "."
730 | return "+"
731 |
732 |
733 | def render_pep440(pieces):
734 | # now build up version string, with post-release "local version
735 | # identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
736 | # get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
737 |
738 | # exceptions:
739 | # 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
740 |
741 | if pieces["closest-tag"]:
742 | rendered = pieces["closest-tag"]
743 | if pieces["distance"] or pieces["dirty"]:
744 | rendered += plus_or_dot(pieces)
745 | rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"])
746 | if pieces["dirty"]:
747 | rendered += ".dirty"
748 | else:
749 | # exception #1
750 | rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"],
751 | pieces["short"])
752 | if pieces["dirty"]:
753 | rendered += ".dirty"
754 | return rendered
755 |
756 |
757 | def render_pep440_pre(pieces):
758 | # TAG[.post.devDISTANCE] . No -dirty
759 |
760 | # exceptions:
761 | # 1: no tags. 0.post.devDISTANCE
762 |
763 | if pieces["closest-tag"]:
764 | rendered = pieces["closest-tag"]
765 | if pieces["distance"]:
766 | rendered += ".post.dev%%d" %% pieces["distance"]
767 | else:
768 | # exception #1
769 | rendered = "0.post.dev%%d" %% pieces["distance"]
770 | return rendered
771 |
772 |
773 | def render_pep440_post(pieces):
774 | # TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that
775 | # .dev0 sorts backwards (a dirty tree will appear "older" than the
776 | # corresponding clean one), but you shouldn't be releasing software with
777 | # -dirty anyways.
778 |
779 | # exceptions:
780 | # 1: no tags. 0.postDISTANCE[.dev0]
781 |
782 | if pieces["closest-tag"]:
783 | rendered = pieces["closest-tag"]
784 | if pieces["distance"] or pieces["dirty"]:
785 | rendered += ".post%%d" %% pieces["distance"]
786 | if pieces["dirty"]:
787 | rendered += ".dev0"
788 | rendered += plus_or_dot(pieces)
789 | rendered += "g%%s" %% pieces["short"]
790 | else:
791 | # exception #1
792 | rendered = "0.post%%d" %% pieces["distance"]
793 | if pieces["dirty"]:
794 | rendered += ".dev0"
795 | rendered += "+g%%s" %% pieces["short"]
796 | return rendered
797 |
798 |
799 | def render_pep440_old(pieces):
800 | # TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty.
801 |
802 | # exceptions:
803 | # 1: no tags. 0.postDISTANCE[.dev0]
804 |
805 | if pieces["closest-tag"]:
806 | rendered = pieces["closest-tag"]
807 | if pieces["distance"] or pieces["dirty"]:
808 | rendered += ".post%%d" %% pieces["distance"]
809 | if pieces["dirty"]:
810 | rendered += ".dev0"
811 | else:
812 | # exception #1
813 | rendered = "0.post%%d" %% pieces["distance"]
814 | if pieces["dirty"]:
815 | rendered += ".dev0"
816 | return rendered
817 |
818 |
819 | def render_git_describe(pieces):
820 | # TAG[-DISTANCE-gHEX][-dirty], like 'git describe --tags --dirty
821 | # --always'
822 |
823 | # exceptions:
824 | # 1: no tags. HEX[-dirty] (note: no 'g' prefix)
825 |
826 | if pieces["closest-tag"]:
827 | rendered = pieces["closest-tag"]
828 | if pieces["distance"]:
829 | rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"])
830 | else:
831 | # exception #1
832 | rendered = pieces["short"]
833 | if pieces["dirty"]:
834 | rendered += "-dirty"
835 | return rendered
836 |
837 |
838 | def render_git_describe_long(pieces):
839 | # TAG-DISTANCE-gHEX[-dirty], like 'git describe --tags --dirty
840 | # --always -long'. The distance/hash is unconditional.
841 |
842 | # exceptions:
843 | # 1: no tags. HEX[-dirty] (note: no 'g' prefix)
844 |
845 | if pieces["closest-tag"]:
846 | rendered = pieces["closest-tag"]
847 | rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"])
848 | else:
849 | # exception #1
850 | rendered = pieces["short"]
851 | if pieces["dirty"]:
852 | rendered += "-dirty"
853 | return rendered
854 |
855 |
856 | def render(pieces, style):
857 | if pieces["error"]:
858 | return {"version": "unknown",
859 | "full-revisionid": pieces.get("long"),
860 | "dirty": None,
861 | "error": pieces["error"]}
862 |
863 | if not style or style == "default":
864 | style = "pep440" # the default
865 |
866 | if style == "pep440":
867 | rendered = render_pep440(pieces)
868 | elif style == "pep440-pre":
869 | rendered = render_pep440_pre(pieces)
870 | elif style == "pep440-post":
871 | rendered = render_pep440_post(pieces)
872 | elif style == "pep440-old":
873 | rendered = render_pep440_old(pieces)
874 | elif style == "git-describe":
875 | rendered = render_git_describe(pieces)
876 | elif style == "git-describe-long":
877 | rendered = render_git_describe_long(pieces)
878 | else:
879 | raise ValueError("unknown style '%%s'" %% style)
880 |
881 | return {"version": rendered, "full-revisionid": pieces["long"],
882 | "dirty": pieces["dirty"], "error": None}
883 |
884 |
885 | def get_versions():
886 | # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
887 | # __file__, we can work backwards from there to the root. Some
888 | # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
889 | # case we can only use expanded keywords.
890 |
891 | cfg = get_config()
892 | verbose = cfg.verbose
893 |
894 | try:
895 | return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
896 | verbose)
897 | except NotThisMethod:
898 | pass
899 |
900 | try:
901 | root = os.path.realpath(__file__)
902 | # versionfile_source is the relative path from the top of the source
903 | # tree (where the .git directory might live) to this file. Invert
904 | # this to find the root from __file__.
905 | for i in cfg.versionfile_source.split('/'):
906 | root = os.path.dirname(root)
907 | except NameError:
908 | return {"version": "0+unknown", "full-revisionid": None,
909 | "dirty": None,
910 | "error": "unable to find root of source tree"}
911 |
912 | try:
913 | pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
914 | return render(pieces, cfg.style)
915 | except NotThisMethod:
916 | pass
917 |
918 | try:
919 | if cfg.parentdir_prefix:
920 | return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
921 | except NotThisMethod:
922 | pass
923 |
924 | return {"version": "0+unknown", "full-revisionid": None,
925 | "dirty": None,
926 | "error": "unable to compute version"}
927 | '''
928 |
929 |
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
    """Scrape the git-archive keywords out of _version.py without importing it.

    The code embedded in _version.py can read these values directly. When
    running from setup.py we don't want to import _version.py, so the
    values are extracted with a regexp instead. This function is not used
    from _version.py itself.

    :param versionfile_abs: absolute path of the _version.py file.
    :returns: dict that may contain "refnames" and/or "full". A key is only
        present when its ``git_refnames = "..."`` / ``git_full = "..."``
        line was found. An unreadable file yields an empty dict.
    """
    keywords = {}
    try:
        # ``with`` closes the handle even if reading fails part-way; the
        # previous explicit close() was skipped in that case.
        with open(versionfile_abs, "r") as f:
            for line in f:
                stripped = line.strip()
                if stripped.startswith("git_refnames ="):
                    mo = re.search(r'=\s*"(.*)"', line)
                    if mo:
                        keywords["refnames"] = mo.group(1)
                if stripped.startswith("git_full ="):
                    mo = re.search(r'=\s*"(.*)"', line)
                    if mo:
                        keywords["full"] = mo.group(1)
    except EnvironmentError:
        # Best effort: a missing/unreadable file just means "no keywords".
        pass
    return keywords
952 |
953 |
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
    """Derive version info from git-archive keywords that were expanded.

    :param keywords: dict as returned by git_get_keywords(); "refnames" is
        the parenthesised, comma-separated list of ref names and "full" the
        full commit id.
    :param tag_prefix: prefix a tag must carry to be considered; it is
        stripped from the reported version.
    :param verbose: print diagnostic messages when true.
    :returns: versions dict with "version", "full-revisionid", "dirty"
        (always False here) and "error".
    :raises NotThisMethod: when there are no keywords, or they are still the
        unexpanded ``$Format`` placeholders.
    """
    if not keywords:
        raise NotThisMethod("no keywords at all, weird")
    refnames = keywords["refnames"].strip()
    if refnames.startswith("$Format"):
        if verbose:
            print("keywords are unexpanded, not using")
        raise NotThisMethod("unexpanded keywords, not a git-archive tarball")

    refs = {name.strip() for name in refnames.strip("()").split(",")}
    # Starting with git-1.8.3 tags are listed as "tag: foo-1.0" rather than
    # bare "foo-1.0"; when that marker is present, prefer those entries.
    marker = "tag: "
    tags = {name[len(marker):] for name in refs if name.startswith(marker)}
    if not tags:
        # Either git < 1.8.3 or there really are no tags. The old %d
        # expansion strips the refs/heads/ and refs/tags/ prefixes, so
        # branches and tags look alike; heuristically keep only names that
        # contain a digit, which drops "release", "stabilization", "HEAD",
        # "master" and similar branch names.
        tags = {name for name in refs if re.search(r'\d', name)}
        if verbose:
            print("discarding '%s', no digits" % ",".join(refs - tags))
    if verbose:
        print("likely tags: %s" % ",".join(sorted(tags)))

    # Sorted order prefers e.g. "2.0" over "2.0rc1".
    matching = [name for name in sorted(tags) if name.startswith(tag_prefix)]
    if matching:
        version = matching[0][len(tag_prefix):]
        if verbose:
            print("picking %s" % version)
        return {"version": version,
                "full-revisionid": keywords["full"].strip(),
                "dirty": False, "error": None}

    # No suitable tag: the version is unknown, but the full hash is known.
    if verbose:
        print("no suitable tags, using unknown + full revision id")
    return {"version": "0+unknown",
            "full-revisionid": keywords["full"].strip(),
            "dirty": False, "error": "no suitable tags"}
997 |
998 |
@register_vcs_handler("git", "pieces_from_vcs")
def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
    """Collect version "pieces" by running git inside a checked-out tree.

    This is only reached when the git-archive 'subst' keywords were *not*
    expanded and _version.py has not already been rewritten with a short
    version string, i.e. we are inside a source checkout.

    :param tag_prefix: prefix that version tags must start with; it is
        stripped to produce "closest-tag".
    :param root: top of the source tree (where ``.git`` is expected).
    :param verbose: print diagnostic messages when true.
    :param run_command: callable used to invoke git. A ``None`` return is
        treated as failure by the checks below.
    :returns: dict with "long", "short", "error" and "dirty"; when no error
        was recorded it also has "closest-tag" (None if untagged) and
        "distance".
    :raises NotThisMethod: if there is no ``.git`` or a git command fails.
    """
    if not os.path.exists(os.path.join(root, ".git")):
        if verbose:
            print("no .git in %s" % root)
        raise NotThisMethod("no .git directory")

    GITS = ["git"]
    if sys.platform == "win32":
        GITS = ["git.cmd", "git.exe"]
    # If there is a tag, this yields TAG-NUM-gHEX[-dirty]; if there are no
    # tags, it yields HEX[-dirty] (no NUM). --long was added in git-1.5.5.
    describe_out = run_command(GITS, ["describe", "--tags", "--dirty",
                                      "--always", "--long"],
                               cwd=root)
    if describe_out is None:
        raise NotThisMethod("'git describe' failed")
    describe_out = describe_out.strip()
    full_out = run_command(GITS, ["rev-parse", "HEAD"], cwd=root)
    if full_out is None:
        raise NotThisMethod("'git rev-parse' failed")
    full_out = full_out.strip()

    pieces = {}
    pieces["long"] = full_out
    pieces["short"] = full_out[:7]  # maybe improved later
    pieces["error"] = None

    # Parse describe_out: TAG-NUM-gHEX[-dirty] or HEX[-dirty].
    # TAG might itself contain hyphens.
    git_describe = describe_out

    # look for -dirty suffix
    dirty = git_describe.endswith("-dirty")
    pieces["dirty"] = dirty
    if dirty:
        git_describe = git_describe[:git_describe.rindex("-dirty")]

    # now we have TAG-NUM-gHEX or HEX
    if "-" in git_describe:
        # TAG-NUM-gHEX
        mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
        if not mo:
            # unparseable. Maybe git-describe is misbehaving?
            pieces["error"] = ("unable to parse git-describe output: '%s'"
                               % describe_out)
            return pieces

        # tag
        full_tag = mo.group(1)
        if not full_tag.startswith(tag_prefix):
            if verbose:
                fmt = "tag '%s' doesn't start with prefix '%s'"
                print(fmt % (full_tag, tag_prefix))
            pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
                               % (full_tag, tag_prefix))
            return pieces
        pieces["closest-tag"] = full_tag[len(tag_prefix):]

        # distance: number of commits since tag
        pieces["distance"] = int(mo.group(2))

        # commit: short hex revision ID
        pieces["short"] = mo.group(3)

    else:
        # HEX: no tags
        pieces["closest-tag"] = None
        count_out = run_command(GITS, ["rev-list", "HEAD", "--count"],
                                cwd=root)
        # Without this check a failed rev-list (None) would make int()
        # raise TypeError, which get_versions() does not catch, instead of
        # falling back to the next version source.
        if count_out is None:
            raise NotThisMethod("'git rev-list --count' failed")
        pieces["distance"] = int(count_out)  # total number of commits

    return pieces
1079 |
1080 |
def do_vcs_install(manifest_in, versionfile_source, ipy):
    """Register Versioneer's files with git and enable keyword expansion.

    Makes sure ``.gitattributes`` marks *versionfile_source* with
    ``export-subst``, then runs ``git add`` on MANIFEST.in, the version
    file, the package __init__.py (if *ipy* is given), versioneer.py itself
    and, when it was modified, ``.gitattributes``.

    :param manifest_in: path of MANIFEST.in.
    :param versionfile_source: path of the generated _version.py.
    :param ipy: path of the package __init__.py, or a false value to skip.
    """
    GITS = ["git"]
    if sys.platform == "win32":
        GITS = ["git.cmd", "git.exe"]
    files = [manifest_in, versionfile_source]
    if ipy:
        files.append(ipy)
    try:
        me = __file__
        if me.endswith(".pyc") or me.endswith(".pyo"):
            me = os.path.splitext(me)[0] + ".py"
        versioneer_file = os.path.relpath(me)
    except NameError:
        versioneer_file = "versioneer.py"
    files.append(versioneer_file)

    try:
        with open(".gitattributes", "r") as f:
            contents = f.read()
    except EnvironmentError:
        # Missing/unreadable: treat as empty and (re)create the entry.
        contents = ""

    present = False
    for line in contents.splitlines():
        parts = line.split()
        # Compare the path exactly; a prefix test would also accept e.g.
        # "pkg/_version.py.orig export-subst".
        if (parts and parts[0] == versionfile_source
                and "export-subst" in parts[1:]):
            present = True
            break

    if not present:
        with open(".gitattributes", "a") as f:
            # Don't glue our entry onto an unterminated last line.
            if contents and not contents.endswith("\n"):
                f.write("\n")
            f.write("%s export-subst\n" % versionfile_source)
        files.append(".gitattributes")
    run_command(GITS, ["add", "--"] + files)
1112 |
1113 |
def versions_from_parentdir(parentdir_prefix, root, verbose):
    """Guess the version from the name of the directory holding the tree.

    Source tarballs conventionally unpack into a directory whose name
    includes both the project name and a version string, e.g.
    ``myproject-1.2``.

    :param parentdir_prefix: expected leading part of the directory name.
    :param root: path whose final component is examined.
    :param verbose: print a hint when the name does not match.
    :returns: versions dict; "full-revisionid" is None and "dirty" False.
    :raises NotThisMethod: if the directory name lacks the prefix.
    """
    dirname = os.path.basename(root)
    if dirname.startswith(parentdir_prefix):
        return {"version": dirname[len(parentdir_prefix):],
                "full-revisionid": None,
                "dirty": False,
                "error": None}
    if verbose:
        print("guessing rootdir is '%s', but '%s' doesn't start with "
              "prefix '%s'" % (root, dirname, parentdir_prefix))
    raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
1126 |
# Template for the short, static _version.py written by
# write_to_version_file(); "%s" receives the JSON-encoded versions dict.
# The "''' # END VERSION_JSON" line is the marker versions_from_file()
# searches for, and the body of get_versions() must be indented or the
# generated module is not valid Python.
SHORT_VERSION_PY = """
# This file was generated by 'versioneer.py' (0.15) from
# revision-control system data, or from the parent directory name of an
# unpacked source archive. Distribution tarballs contain a pre-generated copy
# of this file.

import json
import sys

version_json = '''
%s
''' # END VERSION_JSON


def get_versions():
    return json.loads(version_json)
"""
1144 |
1145 |
def versions_from_file(filename):
    """Read the versions dict embedded in a short-form _version.py.

    :param filename: path of a _version.py containing a ``version_json``
        block terminated by the ``# END VERSION_JSON`` marker.
    :returns: the decoded JSON versions dict.
    :raises NotThisMethod: if the file cannot be read or has no
        ``version_json`` block.
    """
    try:
        with open(filename) as f:
            contents = f.read()
    except EnvironmentError:
        raise NotThisMethod("unable to read _version.py")
    # Accept CRLF line endings (Python 2 does not translate them in text
    # mode on POSIX) and any whitespace before the END marker (e.g. a
    # two-space PEP 8 inline comment), not just the exact single space.
    mo = re.search(r"version_json = '''\r?\n(.*)'''\s*# END VERSION_JSON",
                   contents, re.M | re.S)
    if not mo:
        raise NotThisMethod("no version_json in _version.py")
    return json.loads(mo.group(1))
1157 |
1158 |
def write_to_version_file(filename, versions):
    """Replace *filename* with a short _version.py embedding *versions*.

    The existing file is unlinked before writing rather than truncated in
    place; callers note it may be a hardlink (see the sdist command), so
    this presumably avoids writing through the link into the source tree.

    :param filename: path of the _version.py to (re)create.
    :param versions: versions dict; must contain "version".
    """
    import errno

    try:
        os.unlink(filename)
    except OSError as e:
        # A missing target is fine: we are about to create it anyway.
        if e.errno != errno.ENOENT:
            raise
    contents = json.dumps(versions, sort_keys=True,
                          indent=1, separators=(",", ": "))
    with open(filename, "w") as f:
        f.write(SHORT_VERSION_PY % contents)

    print("set %s to '%s'" % (filename, versions["version"]))
1167 |
1168 |
def plus_or_dot(pieces):
    """Return the separator that starts (or continues) the local segment.

    PEP 440 allows a single "+" in a version, so if the closest tag already
    contains one the local part is continued with "." instead.

    :param pieces: dict whose "closest-tag" may be a string, None, or absent.
    :returns: "." if the tag contains "+", otherwise "+".
    """
    # ``or ""`` also covers an explicit None (untagged trees), which the
    # old ``.get(key, "")`` default did not and raised TypeError on.
    closest_tag = pieces.get("closest-tag") or ""
    if "+" in closest_tag:
        return "."
    return "+"
1173 |
1174 |
def render_pep440(pieces):
    """Render TAG[+DISTANCE.gHEX[.dirty]] (PEP 440 local version label).

    A tagged build that is subsequently modified renders as
    TAG+0.gHEX.dirty. With no tag at all the result is
    0+untagged.DISTANCE.gHEX[.dirty].
    """
    tag = pieces["closest-tag"]
    if not tag:
        untagged = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        return (untagged + ".dirty") if pieces["dirty"] else untagged

    if not (pieces["distance"] or pieces["dirty"]):
        return tag

    local = "%d.g%s" % (pieces["distance"], pieces["short"])
    version = tag + plus_or_dot(pieces) + local
    if pieces["dirty"]:
        version += ".dirty"
    return version
1197 |
1198 |
def render_pep440_pre(pieces):
    """Render TAG[.post.devDISTANCE]; the dirty flag is not represented.

    With no tag the result is 0.post.devDISTANCE.
    """
    tag = pieces["closest-tag"]
    if not tag:
        return "0.post.dev%d" % pieces["distance"]
    distance = pieces["distance"]
    suffix = ".post.dev%d" % distance if distance else ""
    return tag + suffix
1213 |
1214 |
def render_pep440_post(pieces):
    """Render TAG[.postDISTANCE[.dev0]+gHEX]; ".dev0" marks a dirty tree.

    Note that .dev0 sorts backwards (a dirty tree looks "older" than the
    matching clean one), but dirty trees shouldn't be released anyway.
    With no tag the result is 0.postDISTANCE[.dev0]+gHEX.
    """
    tag = pieces["closest-tag"]
    dev = ".dev0" if pieces["dirty"] else ""
    if not tag:
        return "0.post%d%s+g%s" % (pieces["distance"], dev, pieces["short"])
    if not (pieces["distance"] or pieces["dirty"]):
        return tag
    return "%s.post%d%s%sg%s" % (tag, pieces["distance"], dev,
                                 plus_or_dot(pieces), pieces["short"])
1239 |
1240 |
def render_pep440_old(pieces):
    """Render TAG[.postDISTANCE[.dev0]]; ".dev0" marks a dirty tree.

    With no tag the result is 0.postDISTANCE[.dev0].
    """
    tag = pieces["closest-tag"]
    if tag and not (pieces["distance"] or pieces["dirty"]):
        return tag
    base = tag if tag else "0"
    dev = ".dev0" if pieces["dirty"] else ""
    return "%s.post%d%s" % (base, pieces["distance"], dev)
1259 |
1260 |
def render_git_describe(pieces):
    """Render TAG[-DISTANCE-gHEX][-dirty], like `git describe --tags --dirty --always`.

    With no tag the result is HEX[-dirty] (note: no "g" prefix).
    """
    tag = pieces["closest-tag"]
    if not tag:
        described = pieces["short"]
    elif pieces["distance"]:
        described = "%s-%d-g%s" % (tag, pieces["distance"], pieces["short"])
    else:
        described = tag
    if pieces["dirty"]:
        described += "-dirty"
    return described
1278 |
1279 |
def render_git_describe_long(pieces):
    """Render TAG-DISTANCE-gHEX[-dirty], like `git describe --tags --dirty --always --long`.

    When a tag exists the distance and hash are always included. With no
    tag the result is HEX[-dirty] (note: no "g" prefix).
    """
    tag = pieces["closest-tag"]
    if tag:
        described = "%s-%d-g%s" % (tag, pieces["distance"], pieces["short"])
    else:
        described = pieces["short"]
    return (described + "-dirty") if pieces["dirty"] else described
1296 |
1297 |
def render(pieces, style):
    """Turn VCS *pieces* into a versions dict using the named *style*.

    :param pieces: dict produced by a ``pieces_from_vcs`` handler.
    :param style: one of "pep440" (also used for None/""/"default"),
        "pep440-pre", "pep440-post", "pep440-old", "git-describe" or
        "git-describe-long".
    :returns: dict with "version", "full-revisionid", "dirty", "error".
        If *pieces* carries an error, version is "unknown".
    :raises ValueError: for an unrecognised style name.
    """
    error = pieces["error"]
    if error:
        return {"version": "unknown",
                "full-revisionid": pieces.get("long"),
                "dirty": None,
                "error": error}

    chosen = style if (style and style != "default") else "pep440"
    renderers = {
        "pep440": render_pep440,
        "pep440-pre": render_pep440_pre,
        "pep440-post": render_pep440_post,
        "pep440-old": render_pep440_old,
        "git-describe": render_git_describe,
        "git-describe-long": render_git_describe_long,
    }
    renderer = renderers.get(chosen)
    if renderer is None:
        raise ValueError("unknown style '%s'" % chosen)

    return {"version": renderer(pieces), "full-revisionid": pieces["long"],
            "dirty": pieces["dirty"], "error": None}
1325 |
1326 |
class VersioneerBadRootError(Exception):
    """Error type for an unusable project root directory.

    Not raised in this part of the module; presumably used by the
    root-finding logic defined elsewhere in versioneer.py.
    """
1329 |
1330 |
def get_versions(verbose=False):
    """Compute the project's version information.

    Sources are tried in order: expanded git-archive keywords in the
    version file, a short-form version file (as written into build/sdist
    trees), the VCS itself (e.g. 'git describe'), and finally the parent
    directory name. This serves developers using a source checkout, users
    of a tarball created by 'setup.py sdist', and users of a tarball or
    zipball created by 'git archive' or github's download-from-tag feature.

    :param verbose: print which source produced the result; also enabled
        by the ``verbose`` option in the configuration.
    :returns: dict with keys "version", "full-revisionid", "dirty" and
        "error". If every source fails, "version" is "0+unknown" and
        "error" is "unable to compute version".
    :raises AssertionError: if the [versioneer] configuration is incomplete
        or names an unknown VCS.
    """
    if "versioneer" in sys.modules:
        # see the discussion in get_cmdclass()
        del sys.modules["versioneer"]

    root = get_root()
    cfg = get_config_from_root(root)

    # These were ``assert`` statements, which ``python -O`` strips (an
    # unknown VCS then failed later with an opaque AttributeError). Raise
    # explicitly so the messages survive, keeping the AssertionError type
    # that existing callers may expect.
    if cfg.VCS is None:
        raise AssertionError("please set [versioneer]VCS= in setup.cfg")
    handlers = HANDLERS.get(cfg.VCS)
    if not handlers:
        raise AssertionError("unrecognized VCS '%s'" % cfg.VCS)
    verbose = verbose or cfg.verbose
    if cfg.versionfile_source is None:
        raise AssertionError("please set versioneer.versionfile_source")
    if cfg.tag_prefix is None:
        raise AssertionError("please set versioneer.tag_prefix")

    versionfile_abs = os.path.join(root, cfg.versionfile_source)

    # 1. expanded git-archive keywords
    get_keywords_f = handlers.get("get_keywords")
    from_keywords_f = handlers.get("keywords")
    if get_keywords_f and from_keywords_f:
        try:
            keywords = get_keywords_f(versionfile_abs)
            ver = from_keywords_f(keywords, cfg.tag_prefix, verbose)
            if verbose:
                print("got version from expanded keyword %s" % ver)
            return ver
        except NotThisMethod:
            pass

    # 2. short-form _version.py with embedded JSON
    try:
        ver = versions_from_file(versionfile_abs)
        if verbose:
            print("got version from file %s %s" % (versionfile_abs, ver))
        return ver
    except NotThisMethod:
        pass

    # 3. ask the VCS directly
    from_vcs_f = handlers.get("pieces_from_vcs")
    if from_vcs_f:
        try:
            pieces = from_vcs_f(cfg.tag_prefix, root, verbose)
            ver = render(pieces, cfg.style)
            if verbose:
                print("got version from VCS %s" % ver)
            return ver
        except NotThisMethod:
            pass

    # 4. parent directory name (unpacked tarballs)
    try:
        if cfg.parentdir_prefix:
            ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
            if verbose:
                print("got version from parentdir %s" % ver)
            return ver
    except NotThisMethod:
        pass

    if verbose:
        print("unable to compute version")

    return {"version": "0+unknown", "full-revisionid": None,
            "dirty": None, "error": "unable to compute version"}
1402 |
1403 |
def get_version():
    """Return only the version string from get_versions()."""
    versions = get_versions()
    return versions["version"]
1406 |
1407 |
def get_cmdclass():
    """Return a dict of command classes to pass as ``setup(cmdclass=...)``.

    The mapping always contains "version" and "sdist". It contains
    "build_py" normally, or "build_exe" instead when cx_Freeze has already
    been imported. The build/sdist commands rewrite the _version.py in
    their output tree with the computed version.
    """
    if "versioneer" in sys.modules:
        del sys.modules["versioneer"]
        # this fixes the "python setup.py develop" case (also 'install' and
        # 'easy_install .'), in which subdependencies of the main project are
        # built (using setup.py bdist_egg) in the same python process. Assume
        # a main project A and a dependency B, which use different versions
        # of Versioneer. A's setup.py imports A's Versioneer, leaving it in
        # sys.modules by the time B's setup.py is executed, causing B to run
        # with the wrong versioneer. Setuptools wraps the sub-dep builds in a
        # sandbox that restores sys.modules to its pre-build state, so the
        # parent is protected against the child's "import versioneer". By
        # removing ourselves from sys.modules here, before the child build
        # happens, we protect the child from the parent's versioneer too.
        # Also see https://github.com/warner/python-versioneer/issues/52

    cmds = {}

    # we add "version" to both distutils and setuptools
    from distutils.core import Command

    class cmd_version(Command):
        """``setup.py version``: print the computed version information."""
        description = "report generated version string"
        user_options = []
        boolean_options = []

        # distutils requires these hooks; this command takes no options.
        def initialize_options(self):
            pass

        def finalize_options(self):
            pass

        def run(self):
            vers = get_versions(verbose=True)
            print("Version: %s" % vers["version"])
            print(" full-revisionid: %s" % vers.get("full-revisionid"))
            print(" dirty: %s" % vers.get("dirty"))
            if vers["error"]:
                print(" error: %s" % vers["error"])
    cmds["version"] = cmd_version

    # we override "build_py" in both distutils and setuptools
    #
    # most invocation pathways end up running build_py:
    #  distutils/build -> build_py
    #  distutils/install -> distutils/build ->..
    #  setuptools/bdist_wheel -> distutils/install ->..
    #  setuptools/bdist_egg -> distutils/install_lib -> build_py
    #  setuptools/install -> bdist_egg ->..
    #  setuptools/develop -> ?

    from distutils.command.build_py import build_py as _build_py

    class cmd_build_py(_build_py):
        """build_py that rewrites the copied _version.py with fixed values."""
        def run(self):
            root = get_root()
            cfg = get_config_from_root(root)
            # Computed before building, from the source tree.
            versions = get_versions()
            _build_py.run(self)
            # now locate _version.py in the new build/ directory and replace
            # it with an updated value
            if cfg.versionfile_build:
                target_versionfile = os.path.join(self.build_lib,
                                                  cfg.versionfile_build)
                print("UPDATING %s" % target_versionfile)
                write_to_version_file(target_versionfile, versions)
    cmds["build_py"] = cmd_build_py

    if "cx_Freeze" in sys.modules:  # cx_freeze enabled?
        from cx_Freeze.dist import build_exe as _build_exe

        class cmd_build_exe(_build_exe):
            """build_exe that freezes a static _version.py.

            The in-tree version file is temporarily replaced by the short
            static form while the executable is built (presumably so the
            frozen app does not need the VCS at runtime), then rewritten
            with the long VCS-querying template afterwards.
            """
            def run(self):
                root = get_root()
                cfg = get_config_from_root(root)
                versions = get_versions()
                target_versionfile = cfg.versionfile_source
                print("UPDATING %s" % target_versionfile)
                write_to_version_file(target_versionfile, versions)

                _build_exe.run(self)
                # Restore the long-form _version.py in the source tree.
                os.unlink(target_versionfile)
                with open(cfg.versionfile_source, "w") as f:
                    LONG = LONG_VERSION_PY[cfg.VCS]
                    f.write(LONG %
                            {"DOLLAR": "$",
                             "STYLE": cfg.style,
                             "TAG_PREFIX": cfg.tag_prefix,
                             "PARENTDIR_PREFIX": cfg.parentdir_prefix,
                             "VERSIONFILE_SOURCE": cfg.versionfile_source,
                             })
        cmds["build_exe"] = cmd_build_exe
        # build_exe takes over the version-file handling on this path.
        del cmds["build_py"]

    # we override different "sdist" commands for both environments
    if "setuptools" in sys.modules:
        from setuptools.command.sdist import sdist as _sdist
    else:
        from distutils.command.sdist import sdist as _sdist

    class cmd_sdist(_sdist):
        """sdist that records the version and bakes it into the tarball."""
        def run(self):
            versions = get_versions()
            # Stashed for make_release_tree(), which runs during run().
            self._versioneer_generated_versions = versions
            # unless we update this, the command will keep using the old
            # version
            self.distribution.metadata.version = versions["version"]
            return _sdist.run(self)

        def make_release_tree(self, base_dir, files):
            root = get_root()
            cfg = get_config_from_root(root)
            _sdist.make_release_tree(self, base_dir, files)
            # now locate _version.py in the new base_dir directory
            # (remembering that it may be a hardlink) and replace it with an
            # updated value
            target_versionfile = os.path.join(base_dir, cfg.versionfile_source)
            print("UPDATING %s" % target_versionfile)
            write_to_version_file(target_versionfile,
                                  self._versioneer_generated_versions)
    cmds["sdist"] = cmd_sdist

    return cmds
1531 |
1532 |
# Printed to stderr by do_setup() when setup.cfg cannot be read or lacks a
# usable [versioneer] section/option. The sample uses a bare
# "tag_prefix =" because configparser keeps quote characters literally:
# 'tag_prefix = ""' would set the prefix to the two-character string '""'
# (unless the config loader special-cases it).
CONFIG_ERROR = """
setup.cfg is missing the necessary Versioneer configuration. You need
a section like:

[versioneer]
VCS = git
style = pep440
versionfile_source = src/myproject/_version.py
versionfile_build = myproject/_version.py
tag_prefix =
parentdir_prefix = myproject-

You will also need to edit your setup.py to use the results:

import versioneer
setup(version=versioneer.get_version(),
cmdclass=versioneer.get_cmdclass(), ...)

Please read the docstring in ./versioneer.py for configuration instructions,
edit setup.cfg, and re-run the installer or 'python versioneer.py setup'.
"""
1554 |
# Appended to setup.cfg by do_setup() when the file cannot be read or has
# no [versioneer] section. Every option is left commented out for the user
# to fill in.
SAMPLE_CONFIG = """
# See the docstring in versioneer.py for instructions. Note that you must
# re-run 'versioneer.py setup' after changing this section, and commit the
# resulting files.

[versioneer]
#VCS = git
#style = pep440
#versionfile_source =
#versionfile_build =
#tag_prefix =
#parentdir_prefix =

"""
1569 |
1570 | INIT_PY_SNIPPET = """
1571 | from ._version import get_versions
1572 | __version__ = get_versions()['version']
1573 | del get_versions
1574 | """
1575 |
1576 |
def do_setup():
    """Install Versioneer into the project found by get_root().

    The steps run in this order:

    1. Load the [versioneer] configuration from setup.cfg. If the file is
       unreadable or has no such section, SAMPLE_CONFIG is appended to it.
       On any configuration problem, CONFIG_ERROR is printed to stderr and
       the function gives up.
    2. Render LONG_VERSION_PY[cfg.VCS] into cfg.versionfile_source.
    3. Append INIT_PY_SNIPPET to the __init__.py beside that file, if the
       __init__.py exists and does not already contain the snippet.
    4. Add ``include`` lines for versioneer.py and versionfile_source to
       MANIFEST.in unless a simple ``include`` line already names them.
    5. Delegate VCS-specific changes to do_vcs_install().

    NOTE(review): versionfile_source and the __init__.py path are opened
    as given, i.e. relative to the current working directory, whereas
    setup.cfg and MANIFEST.in are joined with the project root. This
    presumably assumes setup is run from the project root.

    Returns 0 on success, or 1 if setup.cfg lacks the configuration.
    """
    root = get_root()
    try:
        cfg = get_config_from_root(root)
    except (EnvironmentError, configparser.NoSectionError,
            configparser.NoOptionError) as err:
        # Only a missing file/section earns the sample template; a missing
        # individual option is reported without touching setup.cfg.
        if isinstance(err, (EnvironmentError, configparser.NoSectionError)):
            print("Adding sample versioneer config to setup.cfg",
                  file=sys.stderr)
            with open(os.path.join(root, "setup.cfg"), "a") as cfg_file:
                cfg_file.write(SAMPLE_CONFIG)
        print(CONFIG_ERROR, file=sys.stderr)
        return 1

    source = cfg.versionfile_source

    print(" creating %s" % source)
    with open(source, "w") as out:
        # The template lookup happens after the file is opened, matching
        # the historical ordering (an unknown VCS leaves an empty file).
        template = LONG_VERSION_PY[cfg.VCS]
        substitutions = {
            "DOLLAR": "$",
            "STYLE": cfg.style,
            "TAG_PREFIX": cfg.tag_prefix,
            "PARENTDIR_PREFIX": cfg.parentdir_prefix,
            "VERSIONFILE_SOURCE": source,
        }
        out.write(template % substitutions)

    init_py = os.path.join(os.path.dirname(source), "__init__.py")
    if not os.path.exists(init_py):
        print(" %s doesn't exist, ok" % init_py)
        init_py = None
    else:
        try:
            with open(init_py, "r") as handle:
                existing = handle.read()
        except EnvironmentError:
            existing = ""
        if INIT_PY_SNIPPET in existing:
            print(" %s unmodified" % init_py)
        else:
            print(" appending to %s" % init_py)
            with open(init_py, "a") as handle:
                handle.write(INIT_PY_SNIPPET)

    # Both versioneer.py and versionfile_source must ship in sdists or pip
    # cannot install the package. Only plain "include a b c" lines are
    # recognised; anything fancier may cause a redundant (harmless) append.
    manifest_in = os.path.join(root, "MANIFEST.in")
    listed = set()
    try:
        with open(manifest_in, "r") as handle:
            for entry in handle:
                if entry.startswith("include "):
                    listed.update(entry.split()[1:])
    except EnvironmentError:
        pass

    if "versioneer.py" in listed:
        print(" 'versioneer.py' already in MANIFEST.in")
    else:
        print(" appending 'versioneer.py' to MANIFEST.in")
        with open(manifest_in, "a") as handle:
            handle.write("include versioneer.py\n")

    if source in listed:
        print(" versionfile_source already in MANIFEST.in")
    else:
        print(" appending versionfile_source ('%s') to MANIFEST.in" %
              source)
        with open(manifest_in, "a") as handle:
            handle.write("include %s\n" % source)

    # VCS-specific work (for git: marking _version.py for export-subst in
    # .gitattributes) lives in do_vcs_install.
    do_vcs_install(manifest_in, source, init_py)
    return 0
1656 |
1657 |
def scan_setup_py():
    """Sanity-check ./setup.py for the Versioneer integration points.

    setup.py is read from the current working directory. It is expected to
    mention ``import versioneer``, ``versioneer.get_cmdclass()`` and
    ``versioneer.get_version()``. It should no longer assign
    ``versioneer.VCS`` or ``versioneer.versionfile_source``, because that
    configuration now lives in setup.cfg. Advice is printed for each kind
    of problem found; the checks are plain substring matches.

    Returns the number of problem kinds detected (0, 1 or 2).
    """
    required = {
        "import versioneer": "import",
        "versioneer.get_cmdclass()": "cmdclass",
        "versioneer.get_version()": "get_version",
    }
    obsolete = ("versioneer.VCS", "versioneer.versionfile_source")

    present = set()
    uses_obsolete = False
    with open("setup.py", "r") as setup_py:
        for text in setup_py:
            present.update(label for needle, label in required.items()
                           if needle in text)
            if any(needle in text for needle in obsolete):
                uses_obsolete = True

    problems = 0
    if len(present) != len(required):
        for message in (
            "",
            "Your setup.py appears to be missing some important items",
            "(but I might be wrong). Please make sure it has something",
            "roughly like the following:",
            "",
            " import versioneer",
            " setup( version=versioneer.get_version(),",
            " cmdclass=versioneer.get_cmdclass(), ...)",
            "",
        ):
            print(message)
        problems += 1
    if uses_obsolete:
        for message in (
            "You should remove lines like 'versioneer.VCS = ' and",
            "'versioneer.versionfile_source = ' . This configuration",
            "now lives in setup.cfg, and should be removed from setup.py",
            "",
        ):
            print(message)
        problems += 1
    return problems
1692 |
1693 | if __name__ == "__main__":
1694 | cmd = sys.argv[1]
1695 | if cmd == "setup":
1696 | errors = do_setup()
1697 | errors += scan_setup_py()
1698 | if errors:
1699 | sys.exit(1)
1700 |
--------------------------------------------------------------------------------