├── .gitignore ├── .pylintrc ├── AUTHORS ├── CONTRIBUTING.rst ├── ChangeLog ├── LICENSE ├── Makefile ├── README.rst ├── circle.yml ├── deploy.py ├── deploy_requirements.txt ├── dev_requirements.txt ├── docs ├── Makefile ├── _static │ └── css │ │ └── custom.css ├── conf.py ├── contributing.rst ├── index.rst ├── installation.rst ├── release_notes.rst └── toc.rst ├── pytest_pgsql ├── __init__.py ├── database.py ├── errors.py ├── ext.py ├── plugin.py ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── database_test.py │ ├── ext_test.py │ ├── plugin_test.py │ └── time_test.py ├── time.py └── version.py ├── requirements.txt ├── setup.cfg ├── setup.py ├── test_requirements.txt └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python,sublimetext,vim,emacs,git,pycharm,macos 3 | 4 | ### Emacs ### 5 | # -*- mode: gitignore; -*- 6 | *~ 7 | \#*\# 8 | /.emacs.desktop 9 | /.emacs.desktop.lock 10 | *.elc 11 | auto-save-list 12 | tramp 13 | .\#* 14 | 15 | # Org-mode 16 | .org-id-locations 17 | *_archive 18 | 19 | # flymake-mode 20 | *_flymake.* 21 | 22 | # eshell files 23 | /eshell/history 24 | /eshell/lastdir 25 | 26 | # elpa packages 27 | /elpa/ 28 | 29 | # reftex files 30 | *.rel 31 | 32 | # AUCTeX auto folder 33 | /auto/ 34 | 35 | # cask packages 36 | .cask/ 37 | dist/ 38 | 39 | # Flycheck 40 | flycheck_*.el 41 | 42 | # server auth directory 43 | /server/ 44 | 45 | # projectiles files 46 | .projectile 47 | 48 | # directory configuration 49 | .dir-locals.el 50 | 51 | ### Git ### 52 | *.orig 53 | 54 | ### macOS ### 55 | *.DS_Store 56 | .AppleDouble 57 | .LSOverride 58 | 59 | # Icon must end with two \r 60 | Icon 61 | 62 | 63 | # Thumbnails 64 | ._* 65 | 66 | # Files that might appear in the root of a volume 67 | .DocumentRevisions-V100 68 | .fseventsd 69 | .Spotlight-V100 70 | .TemporaryItems 71 | .Trashes 72 | .VolumeIcon.icns 73 | .com.apple.timemachine.donotpresent 74 | 75 
| # Directories potentially created on remote AFP share 76 | .AppleDB 77 | .AppleDesktop 78 | Network Trash Folder 79 | Temporary Items 80 | .apdisk 81 | 82 | ### PyCharm ### 83 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 84 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 85 | 86 | # User-specific stuff: 87 | .idea/**/workspace.xml 88 | .idea/**/tasks.xml 89 | 90 | # Sensitive or high-churn files: 91 | .idea/**/dataSources/ 92 | .idea/**/dataSources.ids 93 | .idea/**/dataSources.xml 94 | .idea/**/dataSources.local.xml 95 | .idea/**/sqlDataSources.xml 96 | .idea/**/dynamic.xml 97 | .idea/**/uiDesigner.xml 98 | 99 | # Gradle: 100 | .idea/**/gradle.xml 101 | .idea/**/libraries 102 | 103 | # Mongo Explorer plugin: 104 | .idea/**/mongoSettings.xml 105 | 106 | ## File-based project format: 107 | *.iws 108 | 109 | ## Plugin-specific files: 110 | 111 | # IntelliJ 112 | /out/ 113 | 114 | # mpeltonen/sbt-idea plugin 115 | .idea_modules/ 116 | 117 | # JIRA plugin 118 | atlassian-ide-plugin.xml 119 | 120 | # Crashlytics plugin (for Android Studio and IntelliJ) 121 | com_crashlytics_export_strings.xml 122 | crashlytics.properties 123 | crashlytics-build.properties 124 | fabric.properties 125 | 126 | ### PyCharm Patch ### 127 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 128 | 129 | # *.iml 130 | # modules.xml 131 | # .idea/misc.xml 132 | # *.ipr 133 | 134 | ### Python ### 135 | # Byte-compiled / optimized / DLL files 136 | __pycache__/ 137 | *.py[cod] 138 | *$py.class 139 | 140 | # C extensions 141 | *.so 142 | 143 | # Distribution / packaging 144 | .Python 145 | env/ 146 | build/ 147 | develop-eggs/ 148 | downloads/ 149 | eggs/ 150 | .eggs/ 151 | lib/ 152 | lib64/ 153 | parts/ 154 | sdist/ 155 | var/ 156 | wheels/ 157 | *.egg-info/ 158 | .installed.cfg 159 | *.egg 160 | 161 | # PyInstaller 162 | # Usually these files are written by 
a python script from a template 163 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 164 | *.manifest 165 | *.spec 166 | 167 | # Installer logs 168 | pip-log.txt 169 | pip-delete-this-directory.txt 170 | 171 | # Unit test / coverage reports 172 | htmlcov/ 173 | .tox/ 174 | .coverage 175 | .coverage.* 176 | .cache 177 | nosetests.xml 178 | coverage.xml 179 | *,cover 180 | .hypothesis/ 181 | 182 | # Translations 183 | *.mo 184 | *.pot 185 | 186 | # Django stuff: 187 | *.log 188 | local_settings.py 189 | 190 | # Flask stuff: 191 | instance/ 192 | .webassets-cache 193 | 194 | # Scrapy stuff: 195 | .scrapy 196 | 197 | # Sphinx documentation 198 | docs/_build/ 199 | 200 | # PyBuilder 201 | target/ 202 | 203 | # Jupyter Notebook 204 | .ipynb_checkpoints 205 | 206 | # pyenv 207 | .python-version 208 | 209 | # celery beat schedule file 210 | celerybeat-schedule 211 | 212 | # dotenv 213 | .env 214 | 215 | # virtualenv 216 | .venv 217 | venv/ 218 | ENV/ 219 | 220 | # Spyder project settings 221 | .spyderproject 222 | 223 | # Rope project settings 224 | .ropeproject 225 | 226 | ### SublimeText ### 227 | # cache files for sublime text 228 | *.tmlanguage.cache 229 | *.tmPreferences.cache 230 | *.stTheme.cache 231 | 232 | # workspace files are user-specific 233 | *.sublime-workspace 234 | 235 | # project files should be checked into the repository, unless a significant 236 | # proportion of contributors will probably not be using SublimeText 237 | # *.sublime-project 238 | 239 | # sftp configuration file 240 | sftp-config.json 241 | 242 | # Package control specific files 243 | Package Control.last-run 244 | Package Control.ca-list 245 | Package Control.ca-bundle 246 | Package Control.system-ca-bundle 247 | Package Control.cache/ 248 | Package Control.ca-certs/ 249 | Package Control.merged-ca-bundle 250 | Package Control.user-ca-bundle 251 | oscrypto-ca-bundle.crt 252 | bh_unicode_properties.cache 253 | 254 | # Sublime-github package stores a github 
token in this file 255 | # https://packagecontrol.io/packages/sublime-github 256 | GitHub.sublime-settings 257 | 258 | ### Vim ### 259 | # swap 260 | [._]*.s[a-v][a-z] 261 | [._]*.sw[a-p] 262 | [._]s[a-v][a-z] 263 | [._]sw[a-p] 264 | # session 265 | Session.vim 266 | # temporary 267 | .netrwhist 268 | # auto-generated tag files 269 | tags 270 | 271 | # End of https://www.gitignore.io/api/python,sublimetext,vim,emacs,git,pycharm,macos -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # Pylint rules for pytest-pgsql 2 | 3 | [MESSAGES CONTROL] 4 | disable=locally-disabled, no-member, missing-docstring, invalid-name, protected-access, unused-argument, redefined-outer-name, too-few-public-methods, too-many-arguments 5 | 6 | [REPORTS] 7 | reports=no 8 | 9 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Circle CI 2 | Cooper Stimson <903744+6C1@users.noreply.github.com> 3 | Diego Argueta 4 | Diego Argueta 5 | Helen McLendon 6 | Jason Brownstein 7 | Wes Kendall 8 | Wes Kendall 9 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Contributing Guide 2 | ================== 3 | 4 | Setup 5 | ~~~~~ 6 | 7 | Set up your development environment with:: 8 | 9 | git clone git@github.com:CloverHealth/pytest-pgsql.git 10 | cd pytest-pgsql 11 | make setup 12 | 13 | ``make setup`` will setup a virtual environment managed by `pyenv `_ and install dependencies. 14 | 15 | Note that if you'd like to use something else to manage dependencies other than pyenv, call ``make dependencies`` instead of 16 | ``make setup``. 
17 | 18 | Testing and Validation 19 | ~~~~~~~~~~~~~~~~~~~~~~ 20 | 21 | Run the tests with:: 22 | 23 | make test 24 | 25 | Validate the code with:: 26 | 27 | make validate 28 | 29 | Documentation 30 | ~~~~~~~~~~~~~ 31 | 32 | `Sphinx `_ documentation can be built with:: 33 | 34 | make docs 35 | 36 | The static HTML files are stored in the ``docs/_build/html`` directory. A shortcut for opening them on OSX is:: 37 | 38 | make open_docs 39 | 40 | Releases and Versioning 41 | ~~~~~~~~~~~~~~~~~~~~~~~ 42 | 43 | Anything that is merged into the master branch will be automatically deployed to PyPI. 44 | Documentation will be published to `ReadTheDocs `_ soon. 45 | 46 | The following files will be generated and should *not* be edited by a user: 47 | 48 | * ``ChangeLog`` - Contains the commit messages of the releases. Please have readable commit messages in the 49 | master branch and squash and merge commits when necessary. 50 | * ``AUTHORS`` - Contains the contributing authors. 51 | * ``version.py`` - Automatically updated to include the version string. 52 | 53 | This project uses `Semantic Versioning `_ through `PBR `_. This means when you make a commit, you can add a message like:: 54 | 55 | sem-ver: feature, Added this functionality that does blah. 56 | 57 | Depending on the sem-ver tag, the version will be bumped in the right way when releasing the package. For more information 58 | about PBR, go to the `PBR docs `_. 59 | -------------------------------------------------------------------------------- /ChangeLog: -------------------------------------------------------------------------------- 1 | CHANGES 2 | ======= 3 | 4 | 1.1.3 5 | ----- 6 | 7 | 8 | 1.1.2 9 | ----- 10 | 11 | * adding deploy key and removing twine register step (#17) 12 | * missing an arg for the deploy script (#16) 13 | * Add postgresql.conf options (#15) 14 | 15 | 1.1.1 16 | ----- 17 | 18 | * Disable pylint false positive.
(#10) 19 | 20 | 1.1.0 21 | ----- 22 | 23 | * sem-ver: deprecation, Drop Python 3.3 and crash installation on unsupported versions of Python (#5) 24 | 25 | 1.0.4 26 | ----- 27 | 28 | * [BUG] Remove silently ignored --pg-driver CLI option. (#3) 29 | 30 | 1.0.3 31 | ----- 32 | 33 | * Fixed issue when running "make setup" 34 | 35 | 1.0.2 36 | ----- 37 | 38 | * Fix Github link in package metadata and remove custom URL creation based on driver 39 | 40 | 1.0.1 41 | ----- 42 | 43 | * Fixed README docs link 44 | * sem-ver: api-break, Initial release 45 | * Initial commit 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Clover Health Labs, LLC 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL CLOVER HEALTH LABS, LLC BE LIABLE FOR ANY 22 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 25 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for packaging and testing pytest_pgsql 2 | # 3 | # This Makefile has the following targets: 4 | # 5 | # pyenv - Sets up pyenv and a virtualenv that is automatically used 6 | # deactivate_pyenv - Deactivates the pyenv setup 7 | # dependencies - Installs all dependencies for a project (including mac dependencies) 8 | # setup - Sets up the entire development environment (pyenv and dependencies) 9 | # clean_docs - Clean the documentation folder 10 | # clean - Clean any generated files (including documentation) 11 | # open_docs - Open any docs generated with "make docs" 12 | # docs - Generate sphinx docs 13 | # validate - Run code validation 14 | # test - Run tests 15 | # run - Run any services for local development (databases, CSS compilation, airflow, etc) 16 | # version - Show the version of the package 17 | 18 | OS = $(shell uname -s) 19 | 20 | PACKAGE_NAME=pytest_pgsql 21 | MODULE_NAME=pytest_pgsql 22 | 23 | ifdef CIRCLECI 24 | TOX_POSARGS=-- --junitxml={env:CIRCLE_TEST_REPORTS}/pytest_pgsql/junit.xml 25 | # Use CircleCI's version 26 | PYTHON_VERSION= 27 | # Don't log pip install output since it can print the private repo url 28 | PIP_INSTALL_CMD=pip install -q 29 | # Do local installs without editable mode because of issues with CircleCI's venv 30 | 
PIP_LOCAL_INSTALL_CMD=pip install -q . 31 | else 32 | TOX_POSARGS= 33 | PIP_INSTALL_CMD=pip install 34 | PIP_LOCAL_INSTALL_CMD=pip install -e . 35 | endif 36 | 37 | 38 | # Print usage of main targets when user types "make" or "make help" 39 | help: 40 | @echo "Please choose one of the following targets: \n"\ 41 | " setup: Setup your development environment and install dependencies\n"\ 42 | " test: Run tests\n"\ 43 | " validate: Validate code and documentation\n"\ 44 | " docs: Build Sphinx documentation\n"\ 45 | " open_docs: Open built documentation\n"\ 46 | "\n"\ 47 | "View the Makefile for more documentation about all of the available commands" 48 | @exit 2 49 | 50 | 51 | # Sets up pyenv and the virtualenv that is managed by pyenv 52 | .PHONY: pyenv 53 | pyenv: 54 | ifeq (${OS}, Darwin) 55 | brew install pyenv pyenv-virtualenv 2> /dev/null || true 56 | # Ensure we remain up to date with pyenv so that new python versions are available for installation 57 | brew upgrade pyenv pyenv-virtualenv 2> /dev/null || true 58 | endif 59 | 60 | # Install all supported Python versions. There are more recent patch releases 61 | # for most of these but CircleCI doesn't have them preinstalled. Installing a 62 | # version of Python that isn't preinstalled slows down the build significantly. 63 | # 64 | # If you don't have these installed yet it's going to take a long time, but 65 | # you'll only need to do it once. 
66 | pyenv install -s 3.6.2 67 | pyenv install -s 3.5.2 68 | pyenv install -s 3.4.4 69 | 70 | # Set up the environments for Tox 71 | pyenv local 3.6.2 3.5.2 3.4.4 72 | 73 | 74 | # Deactivates pyenv and removes it from auto-using the virtualenv 75 | .PHONY: deactivate_pyenv 76 | deactivate_pyenv: 77 | rm .python-version 78 | 79 | 80 | # Builds all dependencies for a project 81 | .PHONY: dependencies 82 | dependencies: 83 | ${PIP_INSTALL_CMD} -U -r dev_requirements.txt # Use -U to ensure requirements are upgraded every time 84 | ${PIP_INSTALL_CMD} -r test_requirements.txt 85 | ${PIP_LOCAL_INSTALL_CMD} 86 | pip check 87 | 88 | 89 | # Performs the full development environment setup 90 | .PHONY: setup 91 | setup: pyenv dependencies 92 | 93 | 94 | # Clean the documentation folder 95 | .PHONY: clean_docs 96 | clean_docs: 97 | cd docs && make clean 98 | 99 | 100 | # Clean any auto-generated files 101 | .PHONY: clean 102 | clean: clean_docs 103 | python setup.py clean 104 | rm -rf build/ 105 | rm -rf dist/ 106 | rm -rf *.egg*/ 107 | rm -rf __pycache__/ 108 | rm -f MANIFEST 109 | find ${PACKAGE_NAME} -type f -name '*.pyc' -delete 110 | rm -rf coverage .coverage .coverage* 111 | 112 | 113 | # Open the build docs (only works on Mac) 114 | .PHONY: open_docs 115 | open_docs: 116 | open docs/_build/html/index.html 117 | 118 | 119 | # Build Sphinx autodocs 120 | .PHONY: docs 121 | docs: clean_docs # Ensure docs are clean, otherwise weird render errors can result 122 | sphinx-apidoc -f -e -M -o docs/ pytest_pgsql 'pytest_pgsql/tests' 'pytest_pgsql/version.py' && cd docs && make html 123 | 124 | # Run code validation 125 | .PHONY: validate 126 | validate: 127 | flake8 -v ${MODULE_NAME}/ 128 | pylint ${MODULE_NAME} 129 | make docs # Ensure docs can be built during validation 130 | 131 | 132 | # Run tests 133 | .PHONY: test 134 | test: 135 | tox ${TOX_POSARGS} 136 | coverage report 137 | 138 | .PHONY: test_single_version 139 | test_single_version: 140 | coverage run -a -m pytest 
--pg-conf-opt="track_commit_timestamp=True" --pg-extensions=btree_gin,,btree_gist pytest_pgsql/tests 141 | 142 | 143 | # Run any services for local development. For example, docker databases, CSS compilation watching, etc 144 | .PHONY: run 145 | run: 146 | @echo "No services need to be running for local development" 147 | 148 | 149 | # Distribution helpers for determining the version of the package 150 | VERSION=$(shell python setup.py --version | sed 's/\([0-9]*\.[0-9]*\.[0-9]*\).*$$/\1/') 151 | 152 | .PHONY: version 153 | version: 154 | @echo ${VERSION} 155 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | Clean PostgreSQL Databases for Your Tests 2 | ========================================= 3 | 4 | The following is a summary of the complete pytest_pgsql docs, which are available 5 | on `ReadTheDocs `_. 6 | 7 | What is ``pytest_pgsql``? 8 | ------------------------------ 9 | 10 | ``pytest_pgsql`` is a `pytest `_ plugin you can use to 11 | write unit tests that utilize a temporary PostgreSQL database that gets cleaned 12 | up automatically after every test runs, allowing each test to run on a completely 13 | clean database (with some limitations). 14 | 15 | The plugin gives you two fixtures you can use in your tests: ``postgresql_db`` and 16 | ``transacted_postgresql_db``. Both of these give you similar interfaces to access 17 | to the database, but have slightly different use cases (see below). 18 | 19 | Sample Usage 20 | ------------ 21 | 22 | You can use a session, connection, or engine - the choice is up to you. 
23 | ``postgresql_db`` and ``transacted_postgresql_db`` both give you a session, but 24 | ``postgresql_db`` exposes its engine and ``transacted_postgresql_db`` exposes its 25 | connection:: 26 | 27 | def test_orm(postgresql_db): 28 | instance = Person(name='Foo Bar') 29 | postgresql_db.session.add(instance) 30 | postgresql_db.session.commit() 31 | with postgresql_db.engine.connect() as conn: 32 | do_thing(conn) 33 | 34 | def test_connection(transacted_postgresql_db): 35 | instance = Person(name='Foo Bar') 36 | transacted_postgresql_db.session.add(instance) 37 | transacted_postgresql_db.session.commit() 38 | 39 | transacted_postgresql_db.connection.execute('DROP TABLE my_table') 40 | 41 | Features 42 | -------- 43 | 44 | The following is a non-exhaustive list of some of the features provided to you 45 | by the database fixtures. 46 | 47 | Manipulating Time 48 | ~~~~~~~~~~~~~~~~~ 49 | 50 | Both database fixtures use `freezegun `_ to 51 | allow you to freeze time inside a block of code. You can use it in a variety of 52 | ways: 53 | 54 | As a context manager:: 55 | 56 | with postgresql_db.time.freeze('December 31st 1999 11:59:59 PM') as freezer: 57 | # Time is frozen inside the database *and* Python. 58 | now = postgresql_db.session.execute('SELECT NOW()').scalar() 59 | assert now.date() == datetime.date(1999, 12, 31) 60 | assert datetime.date.today() == datetime.date(1999, 12, 31) 61 | 62 | # Advance time by 1 second so we roll over into the new year 63 | freezer.tick() 64 | 65 | now = postgresql_db.session.execute('SELECT NOW()').scalar() 66 | assert now.date() == datetime.date(2000, 1, 1) 67 | 68 | As a decorator:: 69 | 70 | @pytest_pgsql.freeze_time(datetime.datetime(2038, 1, 19, 3, 14, 7)) 71 | def test_freezing(postgresql_db): 72 | today = postgresql_db.session.execute( 73 | "SELECT EXTRACT('YEAR' FROM CURRENT_DATE)").scalar() 74 | assert today.year == 2038 75 | assert datetime.date.today() == datetime.date(2038, 1, 19) 76 | 77 | And more!
78 | 79 | General-Purpose Functions 80 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 81 | 82 | ``postgresql_db`` and ``transacted_postgresql_db`` provide some general-purpose 83 | functions to ease test setup and execution. 84 | 85 | - ``load_csv()`` loads a CSV file into an existing table. 86 | - ``run_sql_file()`` executes a SQL script, optionally performing variable binding. 87 | 88 | Extension Management 89 | ~~~~~~~~~~~~~~~~~~~~ 90 | 91 | Since version 9.1 Postgres supports `extensions `_. 92 | You can check for the presence of and install extensions like so:: 93 | 94 | >>> postgresql_db.is_extension_available('asdf') # Can I install this extension? 95 | False 96 | >>> postgresql_db.is_extension_available('uuid-ossp') # Maybe this one is supported... 97 | True 98 | >>> postgresql_db.install_extension('uuid-ossp') 99 | True 100 | >>> postgresql_db.is_extension_installed('uuid-ossp') 101 | True 102 | 103 | ``install_extension()`` has additional arguments to allow control over which schema 104 | the extension is installed in, what to do if the extension is already installed, 105 | and so on. See the documentation for descriptions of these features. 106 | 107 | Schemas and Tables 108 | ~~~~~~~~~~~~~~~~~~ 109 | 110 | You can create `table schemas `_ 111 | by calling ``create_schema()`` like so:: 112 | 113 | postgresql_db.create_schema('foo') # Create one schema 114 | postgresql_db.create_schema('foo', 'bar') # Create multiple ones 115 | 116 | To quickly see if a table schema exists, call ``has_schema()``:: 117 | 118 | >>> postgresql_db.has_schema('public') 119 | True 120 | 121 | Similarly, you can create tables in the database with ``create_table()``. You can 122 | pass SQLAlchemy ``Table`` instances or ORM declarative model classes:: 123 | 124 | # Just a regular Table. 125 | my_table = Table('abc', MetaData(), Column('def', Integer())) 126 | 127 | # A declarative model works too. 
128 | class MyORMModel(declarative_base()): 129 | id = Column(Integer, primary_key=True) 130 | 131 | # Pass a variable amount of tables to create 132 | postgresql_db.create_table(my_table, MyORMModel) 133 | 134 | Installation 135 | ============ 136 | 137 | Sorry, this library is not compatible with Python 2. Please be sure to use ``pip3`` instead of 138 | ``pip`` when installing:: 139 | 140 | pip3 install pytest-pgsql 141 | 142 | 143 | Contributing Guide 144 | ================== 145 | 146 | For information on setting up pytest_pgsql for development and contributing 147 | changes, view `CONTRIBUTING.rst `_. 148 | -------------------------------------------------------------------------------- /circle.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | aliases: 3 | - &docker_image circleci/python:3.6.2-stretch 4 | - &dependencies 5 | name: Make virtualenv and install dependencies 6 | command: | 7 | python3 -m venv ~/venv 8 | echo ". ~/venv/bin/activate" >> $BASH_ENV 9 | source $BASH_ENV 10 | make dependencies 11 | 12 | 13 | jobs: 14 | lint: 15 | docker: 16 | - image: *docker_image 17 | steps: 18 | - checkout 19 | - run: 20 | <<: *dependencies 21 | - run: make validate 22 | test: 23 | docker: 24 | - image: *docker_image 25 | environment: 26 | TEST_REPORTS: /tmp/test-reports 27 | steps: 28 | - checkout 29 | - run: 30 | name: Install postgres 31 | command: | 32 | sudo apt-get update && sudo apt install postgresql postgresql-contrib 33 | - run: 34 | <<: *dependencies 35 | - run: make test_single_version 36 | - store_test_results: 37 | path: /tmp/test-reports 38 | - store_artifacts: 39 | path: /tmp/test-reports 40 | 41 | deploy: 42 | docker: 43 | - image: *docker_image 44 | steps: 45 | - checkout 46 | - run: 47 | <<: *dependencies 48 | - run: pip install -q -r deploy_requirements.txt 49 | - add_ssh_keys: 50 | fingerprints: 51 | - "f9:ef:5e:40:d5:ed:e3:86:a1:18:3e:09:85:93:ef:3a" # CAN I DO THIS?? 
prob not 52 | - run: python3 deploy.py prod 53 | 54 | workflows: 55 | version: 2 56 | checks_and_deploy: 57 | jobs: 58 | - lint 59 | - test 60 | - deploy: 61 | requires: 62 | - lint 63 | - test 64 | filters: 65 | branches: 66 | only: 67 | - master -------------------------------------------------------------------------------- /deploy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Deploys the package to PyPI and uploads docs. This script is called in circle.yml. 5 | """ 6 | import os 7 | import shlex 8 | import subprocess 9 | import sys 10 | 11 | CIRCLECI_ENV_VAR = 'CIRCLECI' 12 | 13 | 14 | def _shell(cmd, check=True, stdin=None, stdout=None, stderr=None): # pragma: no cover 15 | """Runs a subprocess shell with check=True by default""" 16 | return subprocess.run(cmd, shell=True, check=check, stdin=stdin, stdout=stdout, stderr=stderr) 17 | 18 | 19 | def _pypi_push(dist): 20 | """Push created package to PyPI. 21 | 22 | Requires the following defined environment variables: 23 | - TWINE_USERNAME: The PyPI username to upload this package under 24 | - TWINE_PASSWORD: The password to the user's account 25 | 26 | Args: 27 | dist (str): 28 | The distribution to push. Must be a valid directory; shell globs are 29 | NOT allowed. 30 | """ 31 | _shell('twine upload ' + shlex.quote(dist + '/*')) 32 | 33 | 34 | def deploy(target): 35 | """Deploys the package and documentation. 36 | 37 | Proceeds in the following steps: 38 | 39 | 1. Ensures proper environment variables are set and checks that we are on Circle CI 40 | 2. Tags the repository with the new version 41 | 3. Creates a standard distribution and a wheel 42 | 4. Updates version.py to have the proper version 43 | 5. Commits the ChangeLog, AUTHORS, and version.py file 44 | 6. Pushes to PyPI 45 | 7. 
Pushes the tags and newly committed files 46 | 47 | Raises: 48 | `EnvironmentError`: 49 | - Not running on CircleCI 50 | - `*_PYPI_USERNAME` and/or `*_PYPI_PASSWORD` environment variables 51 | are missing 52 | - Attempting to deploy to production from a branch that isn't master 53 | """ 54 | # Ensure proper environment 55 | if not os.getenv(CIRCLECI_ENV_VAR): # pragma: no cover 56 | raise EnvironmentError('Must be on CircleCI to run this script') 57 | 58 | current_branch = os.getenv('CIRCLE_BRANCH') 59 | if (target == 'PROD') and (current_branch != 'master'): 60 | raise EnvironmentError( 61 | f'Refusing to deploy to production from branch {current_branch!r}. ' 62 | f'Production deploys can only be made from master.') 63 | 64 | if target in ('PROD', 'TEST'): 65 | pypi_username = os.getenv(f'{target}_PYPI_USERNAME') 66 | pypi_password = os.getenv(f'{target}_PYPI_PASSWORD') 67 | else: 68 | raise ValueError(f"Deploy target must be 'PROD' or 'TEST', got {target!r}.") 69 | 70 | if not (pypi_username and pypi_password): # pragma: no cover 71 | raise EnvironmentError( 72 | f"Missing '{target}_PYPI_USERNAME' and/or '{target}_PYPI_PASSWORD' " 73 | f"environment variables. These are required to push to PyPI.") 74 | 75 | # Twine requires these environment variables to be set. Subprocesses will 76 | # inherit these when we invoke them, so no need to pass them on the command 77 | # line. We want to avoid that in case something's logging each command run. 
78 | os.environ['TWINE_USERNAME'] = pypi_username 79 | os.environ['TWINE_PASSWORD'] = pypi_password 80 | 81 | # Set up git on circle to push to the current branch 82 | _shell('git config --global user.email "dev@cloverhealth.com"') 83 | _shell('git config --global user.name "Circle CI"') 84 | _shell('git config push.default current') 85 | 86 | # Obtain the version to deploy 87 | ret = _shell('make version', stdout=subprocess.PIPE) 88 | version = ret.stdout.decode('utf-8').strip() 89 | 90 | print(f'Deploying version {version!r}...') 91 | 92 | # Tag the version 93 | _shell(f'git tag -f -a {version} -m "Version {version}"') 94 | 95 | # Update the version 96 | _shell( 97 | f'sed -i.bak "s/^__version__ = .*/__version__ = {version!r}/" */version.py') 98 | 99 | # Create a standard distribution and a wheel 100 | _shell('python setup.py sdist bdist_wheel') 101 | 102 | # Add the updated ChangeLog and AUTHORS 103 | _shell('git add ChangeLog AUTHORS */version.py') 104 | 105 | # Start the commit message with "Merge" so that PBR will ignore it in the 106 | # ChangeLog. Use [skip ci] to ensure CircleCI doesn't recursively deploy. 107 | _shell('git commit --no-verify -m "Merge autogenerated files [skip ci]"') 108 | 109 | # Push the distributions to PyPI. 110 | _pypi_push('dist') 111 | 112 | # Push the tag and AUTHORS / ChangeLog after successful PyPI deploy 113 | _shell('git push --follow-tags') 114 | 115 | print(f'Deployment complete. 
Latest version is {version}.') 116 | 117 | 118 | if __name__ == '__main__': # pragma: no cover 119 | if len(sys.argv) != 2: 120 | raise RuntimeError('Require one argument indicating deploy target: `prod` or `test`.') 121 | 122 | deploy(sys.argv[1].upper()) 123 | -------------------------------------------------------------------------------- /deploy_requirements.txt: -------------------------------------------------------------------------------- 1 | pbr==3.1.1 # For semantic versioning 2 | twine==1.9.1 # For publishing to PyPI 3 | wheel==0.30.0 # For publishing wheels 4 | -------------------------------------------------------------------------------- /dev_requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements for doing development 2 | 3 | pip 4 | setuptools 5 | 6 | # For Sphinx documentation 7 | Sphinx==1.6.5 8 | sphinx-rtd-theme==0.2.4 9 | 10 | # For testing. Do *not* put these into test_requirements.txt since we don't want 11 | # to install tox in the environment created by tox. It won't get used. 12 | tox==2.9.1 13 | tox-pyenv==1.1.0 14 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -W 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = pytest-pgsql 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Temporary fix for https://github.com/rtfd/sphinx_rtd_theme/issues/417 */ 2 | 3 | .rst-content .highlight > pre { 4 | line-height: 18px; 5 | } 6 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # pytest_pgsql documentation build configuration file, created by 4 | # sphinx-quickstart on Tue Feb 28 09:45:59 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | # If extensions (or modules to document with autodoc) are in another directory, 16 | # add these directories to sys.path here. If the directory is relative to the 17 | # documentation root, use os.path.abspath to make it absolute, like shown here. 18 | # 19 | # import os 20 | # import sys 21 | # sys.path.insert(0, os.path.abspath('.')) 22 | 23 | import sphinx_rtd_theme 24 | 25 | import pytest_pgsql 26 | 27 | 28 | # -- General configuration ------------------------------------------------ 29 | 30 | # If your documentation needs a minimal Sphinx version, state it here. 31 | # 32 | # needs_sphinx = '1.0' 33 | 34 | # Add any Sphinx extension module names here, as strings. They can be 35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 36 | # ones. 
37 | extensions = ['sphinx.ext.autodoc', 38 | 'sphinx.ext.intersphinx', 39 | 'sphinx.ext.napoleon', 40 | 'sphinx.ext.todo', 41 | 'sphinx.ext.viewcode'] 42 | 43 | # Add any paths that contain templates here, relative to this directory. 44 | templates_path = ['_templates'] 45 | 46 | # The suffix(es) of source filenames. 47 | # You can specify multiple suffix as a list of string: 48 | # 49 | # source_suffix = ['.rst', '.md'] 50 | source_suffix = '.rst' 51 | 52 | # The master toctree document. 53 | master_doc = 'toc' 54 | 55 | # default role for "`" (makes it attempt to match against references within project) 56 | default_role = 'any' 57 | 58 | # General information about the project. 59 | project = u'pytest_pgsql' 60 | copyright = u'2017, Clover Health' 61 | author = u'Clover Health' 62 | 63 | # The version info for the project you're documenting, acts as replacement for 64 | # |version| and |release|, also used in various other places throughout the 65 | # built documents. 66 | # 67 | # The short X.Y version. 68 | version = str(pytest_pgsql.__version__) 69 | # The full version, including alpha/beta/rc tags. 70 | release = version 71 | 72 | # The language for content autogenerated by Sphinx. Refer to documentation 73 | # for a list of supported languages. 74 | # 75 | # This is also used if you do content translation via gettext catalogs. 76 | # Usually you set "language" from the command line for these cases. 77 | language = None 78 | 79 | # List of patterns, relative to source directory, that match files and 80 | # directories to ignore when looking for source files. 81 | # This patterns also effect to html_static_path and html_extra_path 82 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 83 | 84 | # The name of the Pygments (syntax highlighting) style to use. 85 | pygments_style = 'sphinx' 86 | 87 | # If true, `todo` and `todoList` produce output, else they produce nothing. 
88 | todo_include_todos = True 89 | 90 | 91 | # -- Options for HTML output ---------------------------------------------- 92 | 93 | # The theme to use for HTML and HTML Help pages. See the documentation for 94 | # a list of builtin themes. 95 | # 96 | html_theme = 'sphinx_rtd_theme' 97 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 98 | 99 | # Theme options are theme-specific and customize the look and feel of a theme 100 | # further. For a list of options available for each theme, see the 101 | # documentation. 102 | # 103 | # html_theme_options = {} 104 | 105 | # Add any paths that contain custom static files (such as style sheets) here, 106 | # relative to this directory. They are copied after the builtin static files, 107 | # so a file named "default.css" will overwrite the builtin "default.css". 108 | html_static_path = ['_static'] 109 | 110 | 111 | # -- Options for HTMLHelp output ------------------------------------------ 112 | 113 | # Output file base name for HTML help builder. 114 | htmlhelp_basename = 'pytest_pgsqldoc' 115 | 116 | 117 | # -- Options for LaTeX output --------------------------------------------- 118 | 119 | latex_elements = { 120 | # The paper size ('letterpaper' or 'a4paper'). 121 | # 122 | # 'papersize': 'letterpaper', 123 | 124 | # The font size ('10pt', '11pt' or '12pt'). 125 | # 126 | # 'pointsize': '10pt', 127 | 128 | # Additional stuff for the LaTeX preamble. 129 | # 130 | # 'preamble': '', 131 | 132 | # Latex figure (float) alignment 133 | # 134 | # 'figure_align': 'htbp', 135 | } 136 | 137 | # Grouping the document tree into LaTeX files. List of tuples 138 | # (source start file, target name, title, 139 | # author, documentclass [howto, manual, or own class]). 
140 | latex_documents = [( 141 | master_doc, 142 | 'cookiecutterrepo_name.tex', u'pytest_pgsql Documentation', 143 | u'Clover Health', 'manual' 144 | )] 145 | 146 | 147 | # -- Options for manual page output --------------------------------------- 148 | 149 | # One entry per manual page. List of tuples 150 | # (source start file, name, description, authors, manual section). 151 | man_pages = [ 152 | (master_doc, 'cookiecutterrepo_name', u'pytest_pgsql Documentation', 153 | [author], 1) 154 | ] 155 | 156 | 157 | # -- Options for Texinfo output ------------------------------------------- 158 | 159 | # Grouping the document tree into Texinfo files. List of tuples 160 | # (source start file, target name, title, author, 161 | # dir menu entry, description, category) 162 | texinfo_documents = [ 163 | (master_doc, 'cookiecutterrepo_name', u'pytest_pgsql Documentation', 164 | author, 'cookiecutterrepo_name', 'One line description of project.', 165 | 'Miscellaneous'), 166 | ] 167 | 168 | # Example configuration for intersphinx: refer to the Python standard library. 169 | intersphinx_mapping = { 170 | 'https://docs.python.org/3': None, 171 | 'http://www.sqlalchemy.org/docs/11': None, 172 | 'http://pytest.readthedocs.io/en/latest': None, 173 | } 174 | 175 | 176 | def setup(app): 177 | app.add_stylesheet('css/custom.css') 178 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Clean PostgreSQL Databases for Your Tests 2 | ========================================= 3 | 4 | What is ``pytest-pgsql``? 
5 | ------------------------------ 6 | 7 | ``pytest-pgsql`` is a `pytest `_ plugin you can use to 8 | write unit tests that utilize a temporary PostgreSQL database that gets cleaned 9 | up automatically after every test runs, allowing each test to run on a completely 10 | clean database (with limitations_). 11 | 12 | The plugin gives you two fixtures you can use in your tests: `postgresql_db` and 13 | `transacted_postgresql_db`. Both of these give you similar interfaces to access 14 | to the database, but have slightly different use cases (see below). 15 | 16 | General Usage 17 | ------------- 18 | 19 | You can use a session, connection, or engine - the choice is up to you. 20 | `postgresql_db` and `transacted_postgresql_db` both give you a session, but 21 | `postgresql_db` exposes its engine and `transacted_postgresql_db` exposes its 22 | connection:: 23 | 24 | def test_orm(postgresql_db): 25 | instance = Person(name='Foo Bar') 26 | postgresql_db.session.add(instance) 27 | postgresql_db.session.commit() 28 | with postgresql_db.engine.connect() as conn: 29 | do_thing(conn) 30 | 31 | def test_connection(transacted_postgresql_db): 32 | instance = Person(name='Foo Bar') 33 | transacted_postgresql_db.session.add(instance) 34 | transacted_postgresql_db.session.commit() 35 | 36 | transacted_postgresql_db.connection.execute('DROP TABLE my_table') 37 | 38 | Which Do I Use? 39 | ~~~~~~~~~~~~~~~ 40 | 41 | There are a few differences between the transacted and non-transacted versions 42 | of the PostgreSQL database fixture. They are as follows: 43 | 44 | **Transacted is faster** 45 | 46 | The major advantage of `transacted_postgresql_db` is that resetting the database 47 | to its original state is much faster than `postgresql_db`. Unfortunately, that 48 | comes at the expense of having to run the entire test in a single transaction. 
49 | This means you can't execute a ``COMMIT`` statement anywhere in your tests, 50 | or you'll risk causing nondeterministic bugs in your tests and possibly hiding 51 | bugs in your production code. 52 | 53 | For a full description of what can go wrong if you execute ``COMMIT`` and how to 54 | get around this limitation, see the :ref:`Tips` section. 55 | 56 | **Non-transacted is more flexible** 57 | 58 | `postgresql_db` is more flexible than its transacted counterpart because it 59 | doesn't have to run in a single transaction, but teardown is more time-consuming 60 | because every single table, schema, extension, etc. needs to be manually reset. 61 | Tests that can't run in one transaction (e.g. data needs to be shared across 62 | threads) must use the `postgresql_db` fixture. 63 | 64 | .. _limitations: 65 | 66 | Limitations 67 | ~~~~~~~~~~~ 68 | 69 | It's important to note that at the moment the fixtures *can't* revert some 70 | changes if a top-level commit [#]_ has been executed. As far as we know this only 71 | applies to objects (extensions, schemas, tables, etc.) that existed before the 72 | test started. 73 | 74 | The following is a non-exhaustive list of changes that cannot be reverted after 75 | a top-level commit: 76 | 77 | - Modifications to the structure of preexisting tables, including 78 | 79 | - Added/removed/modified rows 80 | - Schema changes, e.g. ``ALTER TABLE``, ``ALTER COLUMN``, etc. Tables that 81 | were renamed or moved to different schemas will be moved back. 82 | - Added/removed/modified constraints or indexes 83 | 84 | - Schemas, tables, and other objects that were dropped during the test cannot be 85 | fully restored. Schemas can be recreated but may have lost some of their 86 | original contents. 87 | - Database settings such as changes to the search path won't be reverted to 88 | defaults. 89 | - Ownership and permission changes will persist until the end of the test 90 | session.
91 | 92 | Utility Functions 93 | ----------------- 94 | 95 | There are a few utility functions each fixture gives you as well. The following 96 | examples use `postgresql_db`, but `transacted_postgresql_db` behaves similarly. 97 | 98 | Extensions 99 | ~~~~~~~~~~ 100 | 101 | Since version 9.1 Postgres supports `extensions `_. 102 | You can check for the presence of and install extensions like so:: 103 | 104 | >>> postgresql_db.is_extension_available('asdf') # Can I install this extension? 105 | False 106 | >>> postgresql_db.is_extension_available('uuid-ossp') # Maybe this one is supported... 107 | True 108 | >>> postgresql_db.install_extension('uuid-ossp') 109 | True 110 | >>> postgresql_db.is_extension_installed('uuid-ossp') 111 | True 112 | 113 | `install_extension` has additional arguments to allow control over which schema 114 | the extension is installed in, what to do if the extension is already installed, 115 | and so on. See the documentation for descriptions of these features. 116 | 117 | Schemas 118 | ~~~~~~~ 119 | 120 | You can create `table schemas `_ 121 | by calling `create_schema` like so:: 122 | 123 | postgresql_db.create_schema('foo') 124 | 125 | The function will throw an exception if the schema already exists. If you only 126 | want to create the schema if it doesn't already exist, pass ``True`` for the 127 | ``exists_ok`` argument:: 128 | 129 | postgresql_db.create_schema('foo', exists_ok=True) 130 | 131 | To quickly see if a table schema exists, call `has_schema`:: 132 | 133 | >>> postgresql_db.has_schema('public') 134 | True 135 | 136 | Note that multiple schemas can be created at once:: 137 | 138 | postgresql_db.create_schema('foo', 'bar') 139 | 140 | Tables 141 | ~~~~~~ 142 | 143 | Similarly, you can create tables in the database with `create_table`. You can 144 | pass :class:`sqlalchemy.Table` instances or ORM declarative model classes:: 145 | 146 | # Just a regular Table. 
147 | my_table = Table('abc', MetaData(), Column('def', Integer())) 148 | 149 | # A declarative model works too. 150 | class MyORMModel(declarative_base()): 151 | id = Column(Integer, primary_key=True) 152 | 153 | # Pass a variable amount of tables to create 154 | postgresql_db.create_table(my_table, MyORMModel) 155 | 156 | There are several ways to check to see if a table exists:: 157 | 158 | >>> postgresql_db.has_table('mytable') # 'mytable' in *any* schema 159 | True 160 | 161 | >>> postgresql_db.has_table('the_schema.the_table') # 'the_table' only in 'the_schema' 162 | False 163 | 164 | >>> table = Table('foo', MetaData(), Column('bar', Integer())) 165 | >>> postgresql_db.has_table(table) 166 | False 167 | >>> postgresql_db.create_table(table) 168 | >>> postgresql_db.has_table(table) 169 | True 170 | 171 | >>> postgresql_db.has_table(MyORMModelClass) 172 | True 173 | 174 | Manipulating Time 175 | ~~~~~~~~~~~~~~~~~ 176 | 177 | Both database fixtures use `freezegun `_ to 178 | allow you to freeze time inside a block of code. You can use it in a variety of 179 | ways: 180 | 181 | As a context manager:: 182 | 183 | with postgresql_db.time.freeze('December 31st 1999 11:59:59 PM') as freezer: 184 | # Time is frozen inside the database *and* Python. 185 | now = postgresql_db.session.execute('SELECT NOW()').scalar() 186 | assert now.date() == datetime.date(1999, 12, 31) 187 | assert datetime.date.today() == datetime.date(1999, 12, 31) 188 | 189 | # Advance time by 1 second so we roll over into the new year 190 | freezer.tick() 191 | 192 | now = postgresql_db.session.execute('SELECT NOW()').scalar() 193 | assert now.date() == datetime.date(2000, 1, 1) 194 | 195 | 196 | Manually calling the :meth:`~pytest_pgsql.time.SQLAlchemyFreezegun.freeze` 197 | and :meth:`~pytest_pgsql.time.SQLAlchemyFreezegun.unfreeze` functions:: 198 | 199 | postgresql_db.time.freeze(datetime.datetime(1999, 12, 31, 23, 59, 59)) 200 | ...
201 | postgresql_db.time.unfreeze() 202 | 203 | You can also freeze time for an entire test if you like using the `freeze_time` 204 | decorator:: 205 | 206 | @pytest_pgsql.freeze_time(datetime.datetime(2038, 1, 19, 3, 14, 7)) 207 | def test_freezing(postgresql_db): 208 | today = postgresql_db.session.execute( 209 | "SELECT EXTRACT('YEAR' FROM CURRENT_DATE)").scalar() 210 | assert today.year == 2038 211 | assert datetime.date.today() == datetime.date(2038, 1, 19) 212 | 213 | If you choose not to use the context manager but still need control over the 214 | flow of time, the ``FrozenDateTimeFactory`` instance can be accessed with the 215 | :attr:`~pytest_pgsql.time.SQLAlchemyFreezegun.freezer` attribute:: 216 | 217 | postgresql_db.time.freeze(datetime.datetime(1999, 12, 31, 23, 59, 59)) 218 | postgresql_db.time.freezer.tick() 219 | 220 | now = postgresql_db.session.execute('SELECT LOCALTIME').scalar() 221 | assert now == datetime.datetime(2000, 1, 1) 222 | 223 | postgresql_db.time.unfreeze() 224 | 225 | See the documentation for `SQLAlchemyFreezegun` for detailed information on what 226 | this feature can and can't do for you. 227 | 228 | General-Purpose Functions 229 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 230 | 231 | `postgresql_db` and `transacted_postgresql_db` provide some general-purpose 232 | functions to ease test setup and execution. 233 | 234 | - `load_csv` loads a CSV file into an existing table. 235 | - `run_sql_file` executes a SQL script, optionally performing variable binding. 236 | 237 | Fixture Customization 238 | --------------------- 239 | 240 | You may find that the default settings for the database fixtures are inadequate 241 | for your needs. You can customize how the engine and database fixtures are 242 | created with the use of facilities provided in the :mod:`~pytest_pgsql.ext` 243 | ("extension") module. 
244 | 245 | Customizing the Engine 246 | ~~~~~~~~~~~~~~~~~~~~~~ 247 | 248 | Suppose we want our database engine to transparently encode a 249 | :class:`~datetime.datetime` or :class:`decimal.Decimal` object in JSON for us. 250 | We can create our own engine that'll do so by using 251 | :func:`~pytest_pgsql.ext.create_engine_fixture`:: 252 | 253 | import pytest_pgsql 254 | import simplejson as json 255 | 256 | json_engine = pytest_pgsql.create_engine_fixture( 257 | 'json_engine', json_serializer=json.dumps, json_deserializer=json.loads) 258 | 259 | Great! Now we have a database engine that can encode and decode timestamps and 260 | fixed-point decimals without any manual conversion on our part. This is not the 261 | only way we can customize the engine--you can pass any keyword argument to 262 | :func:`~pytest_pgsql.ext.create_engine_fixture` that's valid for 263 | `sqlalchemy.create_engine`. See the documentation there for a full list of what 264 | you can do. 265 | 266 | So how do we use it with all the benefits we get from `postgresql_db` and 267 | `transacted_postgresql_db`? 268 | 269 | Customizing Database Fixtures 270 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 271 | 272 | You can create your own database fixture by choosing any subclass of 273 | `PostgreSQLTestDBBase` and invoking its 274 | :meth:`~pytest_pgsql.database.PostgreSQLTestDBBase.create_fixture` method, 275 | passing the name of your new fixture and the name of the engine fixture to use:: 276 | 277 | import pytest_pgsql 278 | 279 | simplejson_db = pytest_pgsql.PostgreSQLTestDB.create_fixture( 280 | 'simplejson_db', 'json_engine') 281 | 282 | We now have a function-scoped database fixture identical to `postgresql_db` but 283 | with more comprehensive JSON serialization! 
If I wanted a faster transacted 284 | version, I could use `TransactedPostgreSQLTestDB` as the base class instead:: 285 | 286 | import pytest_pgsql 287 | 288 | tsimplejson_db = pytest_pgsql.TransactedPostgreSQLTestDB.create_fixture( 289 | 'tsimplejson_db', 'json_engine') 290 | 291 | You can change how the fixture is created by passing any keyword arguments that 292 | are valid for the ``pytest.fixture`` decorator. For example, you can set the 293 | scope of the fixture to the module level by using the ``scope`` keyword argument:: 294 | 295 | simplejson_db = pytest_pgsql.PostgreSQLTestDB.create_fixture( 296 | 'simplejson_db', 'json_engine', scope='module') 297 | 298 | 299 | Now, in our tests we can use the fixtures directly:: 300 | 301 | import datetime 302 | import sqlalchemy as sqla 303 | import sqlalchemy.dialects.postgresql as sqla_pg 304 | 305 | def test_blah(simplejson_db): 306 | meta = sqla.MetaData(bind=simplejson_db.connection) 307 | table = sqla.Table('test', meta, sqla.Column('col', sqla_pg.JSON)) 308 | meta.create_all() 309 | 310 | simplejson_db.connection.execute(table.insert(), [ 311 | {'col': datetime.datetime.now()} 312 | ]) 313 | 314 | 315 | Command Line Options 316 | -------------------- 317 | 318 | ``--pg-extensions`` 319 | ~~~~~~~~~~~~~~~~~~~ 320 | 321 | If many of your tests are going to need one or more particular extensions, you 322 | can tell ``pytest_pgsql`` to install them at the beginning of the test session. 323 | This is *much* faster and more efficient than calling `install_extension` for 324 | each test. 325 | 326 | Pass a comma-separated list of the extensions you need on the command line like 327 | so: 328 | 329 | .. code-block:: sh 330 | 331 | # Install "uuid-ossp" and "pg_trgm" so all tests can use them 332 | pytest --pg-extensions=uuid-ossp,pg_trgm 333 | 334 | ``--pg-work-mem`` 335 | ~~~~~~~~~~~~~~~~~ 336 | 337 | The ``--pg-work-mem`` option allows you to tweak the amount of memory that sort 338 | operations can use.
The Postgres default is rather low (4MB at time of writing) 339 | so ``pytest_pgsql`` uses 32MB as its default. Try adjusting this value up 340 | or down to find the optimal value for your test suite, or use ``0`` to use the 341 | server default. 342 | 343 | .. code-block:: sh 344 | 345 | # Increase the amount of working memory to 64MB 346 | pytest --pg-work-mem=64 347 | 348 | # Disable tweaking and use the server default 349 | pytest --pg-work-mem=0 350 | 351 | For more information: 352 | 353 | * PostgreSQL documentation: `Resource Consumption `_ 354 | * PostgreSQL wiki: `Tuning your PostgreSQL Server `_ 355 | 356 | .. _tips: 357 | 358 | Tips 359 | ---- 360 | 361 | Be careful with ``COMMIT`` 362 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 363 | 364 | When using `transacted_postgresql_db`, do *not* use ``connection.execute()`` to 365 | commit changes made:: 366 | 367 | # This is fine... 368 | transacted_postgresql_db.session.commit() 369 | 370 | # So is this... 371 | transaction = transacted_postgresql_db.connection.begin() 372 | transaction.commit() 373 | 374 | # But this is not. 375 | transacted_postgresql_db.connection.execute('COMMIT') 376 | 377 | The problem with executing ``COMMIT`` in a transacted PostgreSQL testcase is that 378 | all tests assume they're running in a clean database. Committing persists changes 379 | made, so the database is no longer clean *for the rest of the session*. Let's 380 | see how that can be harmful: 381 | 382 | 1. Suppose we have a test A that creates some rows in table X and executes a 383 | ``COMMIT``. We now have one row in X. 384 | 385 | 2. Test B creates another row in X. Now there are two rows in the table, but 386 | test B thinks there's only one. 387 | 388 | 3. Test B does a search for all rows in X that fit some criterion, but there's a 389 | bug in the code and it unintentionally skips the first row it finds. If test 390 | A created a row meeting that criterion, then test B will pass *even though 391 | there's a bug in B's code*. 
392 | 393 | Furthermore, this will only happen *if* test A runs before test B. Thus, adding 394 | or removing any tests can change the order and make the error appear and disappear 395 | seemingly at random. 396 | 397 | .. [#] A *top-level commit* is a commit made on the outermost transaction of a 398 | session. SQLAlchemy allows you to nest transactions so that changes are only 399 | persisted to the database when the outermost one is committed. For more 400 | information, see 401 | `Using SAVEPOINT `_ 402 | in the SQLAlchemy docs. 403 | 404 | .. [#] `pg8000 `_ is one such driver that 405 | doesn't work. If your driver uses server-side prepared statements instead of 406 | doing the parametrization in Python, your driver *will not* work. This is 407 | because PostgreSQL's prepared statements don't support parametrizing ``IN`` 408 | clauses, something currently required by 409 | :meth:`~pytest_pgsql.database.PostgreSQLTestDBBase.is_dirty`. 410 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | System Requirements 5 | ------------------- 6 | 7 | * PostgreSQL 9.1 or greater must be installed. 8 | * Python 3.4 or greater. Compatibility with PyPy3.5 is untested and not guaranteed. 9 | * You must use ``psycopg2`` as your database driver. 10 | 11 | .. note:: 12 | 13 | Due to the way that ``tox`` works with environment setup, if your system's 14 | Python 3 version is 3.6.x and you installed any Python package that uses 15 | ``cli-helpers`` version 0.2.0 or greater, ``make setup`` will fail. This is 16 | due to a `known bug `_ in 17 | ``pbr`` and as of 2017-12-02 there is no workaround that won't potentially 18 | break other packages. 19 | 20 | Setup 21 | ----- 22 | 23 | .. 
code-block:: sh 24 | 25 | pip3 install pytest-pgsql 26 | -------------------------------------------------------------------------------- /docs/release_notes.rst: -------------------------------------------------------------------------------- 1 | Release Notes 2 | ============= 3 | 4 | .. literalinclude:: ../ChangeLog 5 | :language: rst 6 | :start-after: ======= 7 | -------------------------------------------------------------------------------- /docs/toc.rst: -------------------------------------------------------------------------------- 1 | Table of Contents 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | index 8 | installation 9 | release_notes 10 | contributing 11 | modules 12 | 13 | Indices and tables 14 | ================== 15 | 16 | * :ref:`genindex` 17 | * :ref:`modindex` 18 | * :ref:`search` 19 | -------------------------------------------------------------------------------- /pytest_pgsql/__init__.py: -------------------------------------------------------------------------------- 1 | """pytest_pgsql""" 2 | from pytest_pgsql.version import __version__ # flake8: noqa 3 | from pytest_pgsql.time import SQLAlchemyFreezegun 4 | from pytest_pgsql.time import freeze_time 5 | from pytest_pgsql.database import PostgreSQLTestDBBase 6 | from pytest_pgsql.database import PostgreSQLTestDB 7 | from pytest_pgsql.database import TransactedPostgreSQLTestDB 8 | from pytest_pgsql.ext import create_engine_fixture 9 | -------------------------------------------------------------------------------- /pytest_pgsql/database.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import csv 3 | 4 | import pytest 5 | import sqlalchemy as sqla 6 | import sqlalchemy.exc as sqla_exc 7 | import sqlalchemy.orm as sqla_orm 8 | 9 | import pytest_pgsql.time 10 | from pytest_pgsql import errors 11 | 12 | 13 | #: A query to get a snapshot of all existing tables and views and their OIDs. 
14 | TABLE_SNAPSHOT_QUERY = """ 15 | SELECT 16 | schemaname AS schema_name, 17 | tablename AS table_name, 18 | (schemaname || '.' || tablename)::regclass::oid AS table_oid 19 | FROM pg_tables 20 | """ 21 | 22 | 23 | def create_database_snapshot(connectable): 24 | """Create a snapshot of the current state of the database so that we can 25 | restore it to this state when the test exits. 26 | 27 | Arguments: 28 | connectable (`sqlalchemy.engine.Connectable`): 29 | The engine, connection, or other Connectable to use to take this 30 | snapshot. 31 | 32 | Returns (dict): 33 | A dictionary with three keys: 34 | 35 | * ``schemas``: A tuple of the names of all the schemas present. 36 | * ``tables``: A list of all of the tables present. Each item in the list 37 | is a dictionary with the schema name, table name, and table OID. 38 | * ``extensions``: A tuple of the names of all the extensions currently 39 | installed. 40 | """ 41 | execute = connectable.execute 42 | return { 43 | 'schemas': tuple(r['nspname'] for r in execute('SELECT nspname FROM pg_namespace')), 44 | 'tables': [dict(r) for r in execute(TABLE_SNAPSHOT_QUERY)], 45 | 'extensions': tuple( 46 | r['extname'] for r in execute('SELECT extname FROM pg_extension') 47 | ) 48 | } 49 | 50 | 51 | class PostgreSQLTestDBBase(metaclass=abc.ABCMeta): 52 | """Utility to wrap ``testing.postgresql`` and provide extra functionality. 53 | 54 | This is a base class and cannot be instantiated directly. Take a look at the 55 | main subclasses, `PostgreSQLTestDB` and `TransactedPostgreSQLTestDB`. 56 | 57 | Arguments: 58 | url (str): 59 | The connection URI of the PostgreSQL test database. 60 | 61 | connectable (`sqlalchemy.engine.Connectable`): 62 | The SQLAlchemy engine or connection to be used for accessing the 63 | test database. The ORM session will be backed by this connectable. 64 | 65 | restore_state (dict): 66 | Optional. A snapshot of the state to restore the database to after 67 | each test. 
If not given, the only cleanup that can be performed is a 68 | rollback of the current transaction. 69 | 70 | .. seealso:: `create_database_snapshot` 71 | """ 72 | def __init__(self, url, connectable, restore_state=None): 73 | self._conn = connectable 74 | self._sessionmaker = sqla_orm.sessionmaker(bind=connectable) 75 | self.session = self._sessionmaker() 76 | self.postgresql_url = url 77 | self.time = pytest_pgsql.time.SQLAlchemyFreezegun(connectable) 78 | self._restore_state = restore_state 79 | 80 | def is_dirty(self): 81 | """Determine if there are tables, schemas, or extensions installed that 82 | weren't there when the test started. 83 | 84 | If this returns ``True``, then a full teardown of the database is needed 85 | to return it to its original state. 86 | 87 | Returns (bool): 88 | ``True`` if the database needs to be cleaned up with `reset_db`, 89 | ``False`` otherwise. 90 | """ 91 | original_table_oids = tuple(t['table_oid'] for t in self._restore_state['tables']) 92 | 93 | query = sqla.text(""" 94 | SELECT 95 | EXISTS( 96 | SELECT 1 FROM pg_namespace 97 | WHERE nspname NOT IN :ignore_schemas 98 | LIMIT 1 99 | ) 100 | OR 101 | EXISTS( 102 | SELECT 1 FROM pg_tables 103 | WHERE (schemaname || '.' || tablename)::regclass::oid NOT IN :ignore_tables 104 | LIMIT 1 105 | -- Checking for OIDs in our snapshot that're missing from pg_tables 106 | -- will give us a list of all preexisting tables that are now 107 | -- missing. Do we care? 
108 | ) 109 | OR 110 | EXISTS( 111 | SELECT 1 FROM pg_extension 112 | WHERE extname NOT IN :ignore_extensions 113 | LIMIT 1 114 | ) 115 | """).bindparams(ignore_tables=original_table_oids, 116 | ignore_schemas=self._restore_state['schemas'], 117 | ignore_extensions=self._restore_state['extensions']) 118 | 119 | return self._conn.execute(query).scalar() 120 | 121 | def _clean_up_extensions(self): 122 | """Drop any extensions installed by the test.""" 123 | # Build a list of all extensions we installed during the tests and drop 124 | # them. 125 | new_extensions = self._conn.execute(sqla.text(""" 126 | SELECT extname FROM pg_extension 127 | WHERE extname NOT IN :ignore 128 | """).bindparams(ignore=self._restore_state['extensions'])) 129 | 130 | quote = self.id_quoter.quote 131 | drop_query = ';'.join( 132 | 'DROP EXTENSION IF EXISTS %s CASCADE' % quote(r['extname']) 133 | for r in new_extensions) 134 | if drop_query: 135 | self._conn.execute(drop_query) 136 | 137 | def _clean_up_schemas(self): 138 | """Drop all schemas created during this test. 139 | 140 | .. warning:: 141 | 142 | This should NOT be executed before we're sure that all preexisting 143 | tables are back in their original schemas. No checks are performed 144 | to ensure that preexisting tables we need to save aren't in any of 145 | the schemas we're about to drop. 146 | """ 147 | execute = self._conn.execute 148 | quote = self.id_quoter.quote 149 | 150 | extra_schemas = execute(sqla.text(""" 151 | SELECT 152 | nspname 153 | FROM pg_namespace 154 | WHERE 155 | nspname != 'pytest_pgsql' 156 | AND nspname NOT IN :schemas 157 | """).bindparams(schemas=self._restore_state['schemas'])) 158 | 159 | for schema in (r['nspname'] for r in extra_schemas): 160 | try: 161 | execute('DROP SCHEMA %s CASCADE' % quote(schema)) 162 | except sqla_exc.OperationalError: # pragma: no cover 163 | # Sometimes when we drop really large schemas the database will 164 | # crash because it runs out of memory. 
If that happens we gotta
        # drop all the tables in the schema one by one.

        # Recover from the exception.
        self.rollback()

        # Any tables still left in this schema are test debris; enumerate
        # them from the scratch snapshot table and drop them individually.
        extra_tables = execute(
            sqla.text("""
                SELECT
                  table_name
                FROM pytest_pgsql.current_tables
                WHERE schema_name = :name
            """)
            .bindparams(name=schema))

        for table in (r['table_name'] for r in extra_tables):
            execute('DROP TABLE %s.%s CASCADE'
                    % (quote(schema), quote(table)))

        execute('DROP SCHEMA %s CASCADE' % quote(schema))

    def _undo_table_renames(self):
        """Undo table renames and ensure preexisting tables are in their original
        schemas.

        Raises `errors.DatabaseRestoreFailedError` if any original table was
        dropped during the test, since a dropped table can't be recreated from
        the snapshot metadata alone.
        """
        execute = self._conn.execute
        quote = self.id_quoter.quote

        # We can't just rename them one by one because if two tables swapped
        # names that'd cause a collision. Instead, we rename each table to use
        # its schema and table OIDs so that the names are guaranteed(ish) to be
        # unique, *then* move everything back to where it was.

        # Build a list of all original tables that have been renamed or changed
        # schemas.
        rows = execute("""
            SELECT
                cur.*,
                orig.schema_name AS orig_schema,
                orig.table_name AS orig_table,
                floor(random() * 1000) AS rnd_i  -- See explanation below
            FROM pytest_pgsql.original_tables AS orig
            -- Use LEFT JOIN so table_oid will be null if a table was deleted.
            LEFT JOIN pytest_pgsql.current_tables AS cur
                ON orig.table_oid = cur.table_oid
            WHERE (
                orig.table_name != cur.table_name
                OR orig.schema_name != cur.schema_name
                OR cur.table_oid IS NULL
            )
            AND cur.schema_name != 'pytest_pgsql'
        """)
        moved_tables = [dict(r) for r in rows]

        # Detect original tables that were deleted, and crash if any were.
        deleted_tables = [t for t in moved_tables if t['table_oid'] is None]
        if deleted_tables:  # pragma: no cover
            raise errors.DatabaseRestoreFailedError(
                "Can't restore dropped table(s): " +
                ', '.join('{orig_schema}.{orig_table}'.format_map(t)
                          for t in deleted_tables))

        # Rename each table to something unique-ish - a combination of the
        # table's OID and a random number. Now, when we start moving/renaming
        # tables back to what they used to be, the chances of a collision are
        # minimal.
        rename_query = ';'.join(
            'ALTER TABLE %s.%s RENAME TO %s' % (
                quote(t['schema_name']),
                quote(t['table_name']),
                '_pgtu_{orig_table_oid}{rnd_i}'.format_map(t)
            )
            for t in moved_tables)
        if rename_query:  # pragma: no cover
            execute(rename_query)

        # All tables renamed, start moving them back to their original places.
        move_query = ';'.join("""
            ALTER TABLE {cur_schema}.{rntable} RENAME TO {orig_table};
            CREATE SCHEMA IF NOT EXISTS {orig_schema};
            ALTER TABLE {cur_schema}.{orig_table} SET SCHEMA {orig_schema};
        """.format(cur_schema=quote(t['schema_name']),
                   rntable='_pgtu_{orig_table_oid}{rnd_i}'.format_map(t),
                   orig_table=quote(t['orig_table']),
                   orig_schema=quote(t['orig_schema']))
            for t in moved_tables)
        if move_query:  # pragma: no cover
            execute(move_query)

    def _clean_up_tables(self):
        """Drop any tables created by the test.

        This should be executed *after* extra schemas were dropped to minimize
        the number of tables that have to be dropped individually.
        """
        execute = self._conn.execute
        quote = self.id_quoter.quote

        # Fully-qualified names of every table that existed when the snapshot
        # was taken -- anything not in this list is test debris.
        ignored_tables = tuple(
            '{schema_name}.{table_name}'.format_map(t)
            for t in self._restore_state['tables']
        )

        new_tables = execute(sqla.text("""
            SELECT
              schemaname,
              tablename
            FROM pg_tables
            WHERE schemaname || '.' || tablename NOT IN :ignore;
        """).bindparams(ignore=ignored_tables))

        # Drop everything in one round trip; CASCADE removes dependent
        # objects (views, FKs) along with each table.
        drop_query = ';'.join(
            'DROP TABLE %s.%s CASCADE' % (quote(r['schemaname']), quote(r['tablename']))
            for r in new_tables
        )
        if drop_query:  # pragma: no cover
            execute(drop_query)

    def restore_to_snapshot(self):
        """Restore the database to its original state.

        :raises `NoSnapshotAvailableError`:
            If the restore snapshot wasn't given to the constructor.
        """
        if not self._restore_state:
            raise errors.NoSnapshotAvailableError()

        self._clean_up_extensions()

        # Rebuild the scratch `pytest_pgsql` schema holding the "current" and
        # "original" table snapshots that the cleanup helpers below consult.
        # UNLOGGED skips WAL writes since this data is throwaway.
        self._conn.execute("""
            DROP SCHEMA IF EXISTS pytest_pgsql CASCADE;
            CREATE SCHEMA pytest_pgsql;

            CREATE UNLOGGED TABLE pytest_pgsql.current_tables AS {table_query};
            CREATE UNLOGGED TABLE pytest_pgsql.original_tables (
                LIKE pytest_pgsql.current_tables EXCLUDING ALL
            );
        """.format(table_query=TABLE_SNAPSHOT_QUERY))

        orig_tables = self.get_table('pytest_pgsql.original_tables')
        # pylint: disable=no-value-for-parameter
        self._conn.execute(orig_tables.insert().values(self._restore_state['tables']))
        # pylint: enable=no-value-for-parameter

        # Order matters here: renames must be undone first, and extra schemas
        # are dropped before individual tables so fewer tables need to be
        # dropped one at a time.
        self._undo_table_renames()
        self._clean_up_schemas()
        self._clean_up_tables()
        self._conn.execute('DROP SCHEMA pytest_pgsql CASCADE; COMMIT;')

    def reset_db(self):
        """Reset the database to its initial state."""
        self.time.unfreeze()
        self.rollback()
        if self._restore_state is not None:
            self.restore_to_snapshot()
    @property
    def id_quoter(self):
        """An :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` you can use
        to quote table names, identifiers, etc. to prevent SQL injection
        vulnerabilities.
        """
        # Built fresh from the connection's dialect on each access.
        return self._conn.dialect.preparer(self._conn.dialect)

    def is_extension_available(self, name):
        """Determine if the named extension is available for installation.

        Arguments:
            name (str):
                The name of the extension to search for.

        Returns (bool):
            ``True`` if the extension is available, ``False`` otherwise. Note
            that availability is no guarantee the extension will install
            successfully.
        """
        query = sqla.text(
            'SELECT EXISTS(SELECT 1 FROM pg_available_extensions WHERE name=:n LIMIT 1)'
        ).bindparams(n=name)
        return self._conn.execute(query).scalar()

    def install_extension(self, extension, if_available=False, exists_ok=False,
                          schema=None):
        """Install a PostgreSQL extension.

        Arguments:
            extension (str):
                The name of the extension to install.

            if_available (bool):
                Only attempt to install the extension if the PostgreSQL server
                supports it.

            exists_ok (bool):
                Don't bother installing the extension if it's already installed.

            schema (str):
                Optional. The name of the schema to install the extension into.
                If not given, it'll be installed in the default schema. Consult
                the PostgreSQL docs for `CREATE EXTENSION`__ for more info.

        Returns (bool):
            ``True`` if the extension was installed. If ``if_available`` is set,
            then this returns ``False`` if installation was skipped because the
            extension isn't available.

        .. note::
            Dependencies are *not* automatically installed.

        .. _pg_doc: https://www.postgresql.org/docs/current/static/sql-createextension.html
        __ pg_doc_
        """
        if if_available and not self.is_extension_available(extension):
            return False

        # IF NOT EXISTS turns "already installed" into a no-op instead of an
        # error.
        check = 'IF NOT EXISTS' if exists_ok else ''

        # Identifiers can't be bound parameters, so quote them explicitly to
        # avoid SQL injection.
        if schema:
            stmt = 'CREATE EXTENSION {check} {ext} WITH SCHEMA {schema}'.format(
                check=check,
                ext=self.id_quoter.quote_identifier(extension),
                schema=self.id_quoter.quote_schema(schema))
        else:
            stmt = 'CREATE EXTENSION {check} {ext}'.format(
                check=check,
                ext=self.id_quoter.quote_identifier(extension))

        self._conn.execute(stmt)
        return True

    def has_extension(self, extension):
        """Determine if the given extension has already been installed.

        Arguments:
            extension (str):
                The name of the extension to search for.

        Returns (bool):
            ``True`` if the extension is installed, ``False`` otherwise.

        .. note ::
            This is *not* the same as checking the availability of an extension.
            You'll need to use `is_extension_available` for that.
        """
        query = sqla.text(
            'SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname=:n LIMIT 1)'
        ).bindparams(n=extension)
        return self._conn.execute(query).scalar()

    def has_schema(self, schema):
        """Determine if the given schema exists in the database.

        Arguments:
            schema (str):
                The name of the schema to check for.

        Returns (bool):
            ``True`` if the schema exists, ``False`` otherwise.
        """
        query = sqla.text(
            'SELECT EXISTS(SELECT 1 FROM pg_namespace WHERE nspname=:s LIMIT 1)'
        ).bindparams(s=schema)
        return self._conn.execute(query).scalar()

    def has_table(self, table):
        """Determine if the given table exists in the database.

        ``table`` must reference a regular table, not a view, foreign table, or
        temporary table.

        Arguments:
            table:
                The table to search for. This can be any of the following:

                - A full table name with the schema: ``myschema.mytable``.
                - Just the table name: ``mytable``. This will search in *all*
                  schemas for a table with the given name, *not* the search
                  path.
                - A `sqlalchemy.schema.Table` object.
                - A SQLAlchemy ORM declarative model.

        Returns (bool):
            ``True`` if the table exists, ``False`` otherwise.
        """
        # Normalize the argument into (schema_name, table_name) strings.
        if isinstance(table, str):
            if '.' in table:
                schema_name, _sep, table_name = table.partition('.')
            else:
                schema_name = ''
                table_name = table
        elif isinstance(table, sqla.Table):
            schema_name = table.schema
            table_name = table.name
        elif hasattr(table, '__table__'):
            # Assume this is an ORM declarative model.
            schema_name = table.__table__.schema
            table_name = table.__table__.name
        else:
            raise TypeError(
                'Expected str, SQLAlchemy Table, or declarative model, got %r.'
                % type(table).__name__)

        subquery = 'SELECT 1 FROM pg_tables WHERE tablename = :t'
        params = {'t': table_name}

        # Only constrain the schema when one was given; otherwise any schema
        # matches.
        if schema_name:
            subquery += ' AND schemaname = :s'
            params['s'] = schema_name

        query = sqla.text('SELECT EXISTS (' + subquery + ' LIMIT 1)').bindparams(**params)
        return self._conn.execute(query).scalar()

    def create_schema(self, *schemas, exists_ok=False):
        """Create one or more schemas in the test database.

        Schemas are created in a single operation.

        Arguments:
            *schemas (str):
                The names of the schemas to create.

            exists_ok (bool):
                Don't throw an exception if the schema exists already.
487 | """ 488 | check = 'IF NOT EXISTS' if exists_ok else '' 489 | quoted_names = [self.id_quoter.quote_schema(s) for s in schemas] 490 | query = ';'.join('CREATE SCHEMA %s %s' % (check, s) for s in quoted_names) 491 | return self._conn.execute(query) 492 | 493 | def create_table(self, *tables): 494 | """Create a table in the database. 495 | 496 | If the table is in a schema and that schema does not exist, it will be 497 | created. 498 | 499 | Arguments: 500 | *tables: 501 | `sqlalchemy.schema.Table` instances or declarative model 502 | classes. 503 | """ 504 | for table in tables: 505 | if not isinstance(table, sqla.Table): 506 | table = table.__table__ 507 | 508 | if table.schema is not None: 509 | self.create_schema(table.schema, exists_ok=True) 510 | table.create(self._conn) 511 | 512 | def get_table(self, table, metadata=None): 513 | """Create a `sqlalchemy.schema.Table` instance from an existing table in 514 | the database. 515 | 516 | SQLAlchemy refers to this as `reflection 517 | `_. 518 | 519 | Arguments: 520 | table (str): 521 | The name of the table to reflect, including the schema name if 522 | applicable (e.g. ``'my_schema.the_table'``). 523 | 524 | metadata (`sqlalchemy.schema.MetaData`): 525 | The metadata to associate the table with. If not given, a new 526 | :class:`~sqlalchemy.schema.MetaData` object will be created and 527 | bound to the current connection or engine. 528 | 529 | Returns (`sqlalchemy.schema.Table`): 530 | The reflected table. 531 | """ 532 | if not metadata: 533 | metadata = sqla.MetaData(bind=self._conn) 534 | 535 | # If the metadata isn't bound to an engine or connection we need to pass 536 | # `autoload_with` and a Connectible. 537 | if not metadata.bind: 538 | kwargs = {'autoload_with': self._conn} 539 | else: 540 | kwargs = {} 541 | 542 | schema_name, _sep, table_name = table.partition('.') 543 | if table_name: 544 | # Caller passed in a table and a schema. 
545 | return sqla.Table(table_name, metadata, autoload=True, 546 | schema=schema_name, **kwargs) 547 | 548 | # Caller passed in a table name without a schema. The name of the table 549 | # will be in `schema_name` due to how `partition()` works. 550 | return sqla.Table(schema_name, metadata, autoload=True, **kwargs) 551 | 552 | def run_sql_file(self, source, **bindings): 553 | """Convenience method for running a SQL file, optionally filling in any 554 | bindings present in the file. 555 | 556 | If the ``bindings`` mapping is empty, the query is executed exactly as 557 | is in the file. If ``bindings`` contains values, the query text is 558 | wrapped inside a :class:`~sqlalchemy.sql.expression.TextClause` with a 559 | call to :func:`sqlalchemy.expression.text` before execution. As such, 560 | the only supported parametrization syntax is the one that uses colons, 561 | e.g.: 562 | 563 | .. code-block:: sql 564 | 565 | DELETE FROM users WHERE username = :user 566 | 567 | Arguments: 568 | source: 569 | The path to the SQL file to run, or a file-like object with a 570 | ``read()`` function. 571 | 572 | **bindings: 573 | Values to bind to the query once the file is loaded. If no 574 | values are given, no binding will be performed and the file will 575 | be executed exactly as is. 576 | 577 | Returns (`sqlalchemy.engine.ResultProxy`): 578 | The results of the SQL file's execution. 579 | """ 580 | if isinstance(source, str): 581 | with open(source, 'r') as fd: 582 | to_run = fd.read() 583 | else: 584 | to_run = source.read() 585 | 586 | if bindings: 587 | return self._conn.execute(sqla.text(to_run).bindparams(**bindings)) 588 | return self._conn.execute(sqla.text(to_run)) 589 | 590 | def load_csv(self, csv_source, table, dialect='excel', truncate=False, 591 | cascade=False): 592 | """Load an existing table with the contents of a CSV. 593 | 594 | Arguments: 595 | csv_source: 596 | The path to a CSV file, or a readable file-like object. 
597 | 598 | table: 599 | The name of the target table, a `sqlalchemy.schema.Table`, or a 600 | declarative model. 601 | 602 | dialect: 603 | Either a string naming one of the CSV dialects Python defines, 604 | or a `csv.Dialect` object to configure the CSV reader. 605 | 606 | truncate (bool): 607 | If ``True``, truncate the table and reset all sequences before 608 | loading anything so that it only contains rows from the CSV. The 609 | default is to only append rows to the table. 610 | 611 | cascade (bool): 612 | If ``True`` and ``truncate`` is also ``True``, then the truncate 613 | will cascade to rows in other tables that reference it. If 614 | ``False`` (the default), then the truncate won't cascade, and 615 | will throw an exception if there are any other tables with rows 616 | referencing this table. 617 | 618 | Returns (int): 619 | The number of rows inserted into the table. 620 | """ 621 | if isinstance(table, str): 622 | table_obj = self.get_table(table) 623 | elif hasattr(table, '__table__'): 624 | table_obj = table.__table__ 625 | else: 626 | table_obj = table 627 | 628 | if truncate: 629 | schema_name, _sep, table_name = table_obj.fullname.partition('.') 630 | quote = self.id_quoter.quote 631 | 632 | if table_name: 633 | quoted_table = '%s.%s' % (quote(schema_name), quote(table_name)) 634 | else: 635 | quoted_table = quote(schema_name) 636 | 637 | self._conn.execute('TRUNCATE TABLE ONLY %s RESTART IDENTITY %s' 638 | % (quoted_table, 'CASCADE' if cascade else '')) 639 | 640 | if isinstance(csv_source, str): 641 | with open(csv_source, 'r') as fdesc: 642 | data_rows = list(csv.DictReader(fdesc, dialect=dialect)) 643 | else: 644 | data_rows = list(csv.DictReader(csv_source, dialect=dialect)) 645 | 646 | self._conn.execute(table_obj.insert().values(data_rows)) 647 | return len(data_rows) 648 | 649 | @abc.abstractmethod 650 | def __enter__(self): 651 | """Start a transaction that will be rolled back upon exit.""" 652 | 653 | @abc.abstractmethod 654 | def 
__exit__(self, exc_type, exc_val, exc_tb): 655 | """Roll back all changes made while inside the context manager.""" 656 | 657 | @abc.abstractmethod 658 | def rollback(self): 659 | """Roll back the current transaction in a connection/engine/session 660 | agnostic way.""" 661 | 662 | @classmethod 663 | @abc.abstractmethod 664 | def create_fixture(cls, name, engine_name='pg_engine', 665 | use_restore_state=True, **fixture_kwargs): 666 | """Create a database fixture function using an instance of this class. 667 | 668 | Arguments: 669 | name (str): 670 | The name of the database fixture to create. This must be unique, 671 | so you can't use the names of any fixtures defined by this 672 | plugin. 673 | 674 | engine_name (str): 675 | The name of the engine fixture to use. The engine is lazily 676 | retrieved, so it only needs to be accessible at runtime. 677 | Default: ``pg_engine`` 678 | 679 | use_restore_state (bool): 680 | Whether to use a restore state. See the documentation for the 681 | ``restore_state`` constructor parameter for more details. 682 | Default: ``True`` 683 | 684 | **fixture_kwargs: 685 | Keyword arguments to pass to the ``pytest.fixture`` decorator. 686 | 687 | Returns (`callable`): 688 | A pytest fixture that returns an instance of a `PostgreSQLTestDBBase` 689 | subclass. 690 | """ 691 | 692 | 693 | class PostgreSQLTestDB(PostgreSQLTestDBBase): 694 | """A PostgreSQL test database that performs a full reset when a test finishes. 695 | 696 | Unless your test cannot run in one transaction, it's advised that you prefer 697 | `TransactedPostgreSQLTestDB` instead, since teardown is faster. 698 | 699 | Arguments: 700 | url (str): 701 | The connection URI of the PostgreSQL test database. 702 | 703 | engine (`sqlalchemy.engine.Engine`): 704 | The engine to use for database operations. 705 | 706 | restore_state (dict): 707 | Optional. A snapshot of the state to restore the database to after 708 | each test. 
If not given, the only cleanup that can be performed is a
            rollback of the current transaction.

    .. seealso:: `create_database_snapshot`
    """
    def __init__(self, url, engine, restore_state=None):
        super().__init__(url, engine, restore_state)
        self.engine = engine

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Only pay for the (slow) full restore when the test actually
        # committed changes.
        self.rollback()
        if self.is_dirty():
            self.reset_db()

    def rollback(self):
        return self.session.rollback()

    @classmethod
    def create_fixture(cls, name, engine_name='pg_engine',
                       use_restore_state=True, **fixture_kwargs):
        """See :meth:`PostgreSQLTestDBBase.create_fixture`."""
        @pytest.fixture(name=name, **fixture_kwargs)
        def _fixture(database_uri, request):
            # Resolve the engine fixture lazily by name at test run time.
            engine = request.getfixturevalue(engine_name)

            if use_restore_state:
                restore_state = request.getfixturevalue('database_snapshot')
            else:  # pragma: no cover
                restore_state = None

            with cls(database_uri, engine, restore_state) as inst:
                yield inst

        return _fixture


class TransactedPostgreSQLTestDB(PostgreSQLTestDBBase):
    """A PostgreSQL test database that rolls back the current transaction when a
    test finishes.

    Arguments:
        url (str):
            The connection URI of the PostgreSQL test database.

        connection (`sqlalchemy.engine.Connection`):
            The connection to use for database operations.

        restore_state (dict):
            Optional. A snapshot of the state to restore the database to after
            each test. If not given, the only cleanup that can be performed is a
            rollback of the current transaction. A rollback is usually enough to
            completely reset the database and this is only needed in the event
            of an accidental ``COMMIT`` being executed.

            Database integrity will *not* be checked after the rollback if a
            restore state isn't given.

    .. seealso:: `create_database_snapshot`
    """
    def __init__(self, url, connection, restore_state=None):
        super().__init__(url, connection, restore_state)
        self.connection = connection
        # Open the outer transaction that rollback() will repeatedly roll
        # back and restart.
        self._transaction = self.connection.begin()

    def reset_db(self):
        """Reset the database by rolling back the current transaction.

        If ``restore_state`` was passed to the constructor, an exception
        will be thrown if `is_dirty` returns ``True``. Database integrity is
        *not* verified if no restore state is given to the class.

        :raises `DatabaseIsDirtyError`: `is_dirty` returned ``True``
        """
        self.time.unfreeze()
        self.rollback()

        # Without a snapshot there's nothing to verify; a clean database
        # needs no further work either.
        if not self._restore_state or not self.is_dirty():
            return

        # Something was committed during the test -- diff the snapshots so
        # the error message says exactly what couldn't be cleaned up.
        new_snapshot = create_database_snapshot(self._conn)
        raise errors.DatabaseIsDirtyError.from_snapshots(self._restore_state,
                                                         new_snapshot)

    def __enter__(self):
        # Should already be inside a transaction so there's nothing to do here.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.reset_db()

    def rollback(self):
        """Roll back the current transaction and start a new one."""
        # Roll back session-level state first, then the outer transaction,
        # and immediately open a fresh one so the instance stays usable.
        self.session.rollback()
        self._transaction.rollback()
        self._transaction = self.connection.begin()

    @classmethod
    def create_fixture(cls, name, engine_name='pg_engine',
                       use_restore_state=True, **fixture_kwargs):
        """See :meth:`PostgreSQLTestDBBase.create_fixture`."""
        @pytest.fixture(name=name, **fixture_kwargs)
        def _fixture(database_uri, request):
            engine = request.getfixturevalue(engine_name)

            if use_restore_state:
                restore_state = request.getfixturevalue('database_snapshot')
            else:  # pragma: no cover
                restore_state = None

            # The connection is scoped to the fixture so each test gets its
            # own transaction to roll back.
            with engine.connect() as conn:
                with cls(database_uri, conn, restore_state) as inst:
                    yield inst

        return _fixture
--------------------------------------------------------------------------------
/pytest_pgsql/errors.py:
--------------------------------------------------------------------------------
"""Specialized errors."""

import collections


# (schema, table, oid) triple identifying one table within a snapshot.
TableInfo = collections.namedtuple('TableInfo', ['schema', 'table', 'oid'])


def _diff_snapshots(original_snapshot, current_snapshot):
    """Compare two database snapshots and return the differences.

    Arguments:
        original_snapshot (dict):
            The original snapshot of the database.

        current_snapshot (dict):
            The snapshot of the database in its current state.

    Returns (dict):
        A dictionary with information on what schemas, extensions, tables, etc.
        are left over or are missing.

    ..
seealso:: :func:`create_database_snapshot` 24 | """ 25 | new_schemas = set(current_snapshot['schemas']) 26 | old_schemas = set(original_snapshot['schemas']) 27 | new_ext = set(current_snapshot['extensions']) 28 | old_ext = set(original_snapshot['extensions']) 29 | new_tables = { 30 | TableInfo(t['schema_name'], t['table_name'], t['table_oid']) 31 | for t in current_snapshot['tables'] 32 | } 33 | old_tables = { 34 | TableInfo(t['schema_name'], t['table_name'], t['table_oid']) 35 | for t in original_snapshot['tables'] 36 | } 37 | 38 | return { 39 | 'extra_extensions': new_ext - old_ext, 40 | 'missing_extensions': old_ext - new_ext, 41 | 'extra_schemas': new_schemas - old_schemas, 42 | 'missing_schemas': old_schemas - new_schemas, 43 | 'extra_tables': new_tables - old_tables, 44 | 'missing_tables': old_tables - new_tables, 45 | } 46 | 47 | 48 | class Error(Exception): 49 | """The base class for all errors. 50 | 51 | This exception is not meant to be thrown directly. 52 | 53 | Arguments: 54 | message (str): 55 | Optional. The error message for the exception to be thrown. If not 56 | given, defaults to the first line of the exception class' docstring. 57 | """ 58 | def __init__(self, message=None): 59 | if not message: 60 | message = self.__doc__.splitlines()[0] 61 | super().__init__(message) 62 | 63 | 64 | class DatabaseRestoreFailedError(Error): 65 | """Generic base class for database reset failures.""" 66 | 67 | 68 | class DatabaseIsDirtyError(DatabaseRestoreFailedError): 69 | """Couldn't restore the database to its original state due to committed 70 | changes. 71 | 72 | Arguments: 73 | message (str): 74 | The error message. 75 | 76 | state_details (dict): 77 | Optional. A dictionary detailing what extensions, schemas, or tables 78 | are missing or are left over. Keys include: 79 | 80 | - ``extra_extensions``: A `set` of the names of extensions that 81 | weren't originally installed but still remain. 
82 | - ``missing_extensions``: A `set` of the names of extensions that 83 | were uninstalled. 84 | - ``extra_schemas``: A `set` of the names of schemas that weren't 85 | present initially but still remain. 86 | - ``missing_schemas``: A `set` of the names of schemas that were 87 | initially present but were dropped and can't be restored. 88 | - ``extra_tables``: A `set` of `TableInfo` objects for tables that 89 | weren't present at the beginning of the test session. 90 | - ``missing_tables``: A `set` of `TableInfo` objects for tables that 91 | were present at the beginning of the test session but were dropped 92 | and can't be restored. 93 | """ 94 | def __init__(self, message=None, state_details=None): 95 | super().__init__(message) 96 | self.state_details = state_details 97 | 98 | @classmethod 99 | def from_snapshots(cls, original_snapshot, current_snapshot): 100 | """Create an exception with an error message derived from the given 101 | snapshots. 102 | 103 | Arguments: 104 | original_snapshot (dict): 105 | The snapshot of the database taken right after it was created. 106 | 107 | current_snapshot (dict): 108 | The snapshot of the database in its current (dirty) state. 109 | 110 | .. seealso:: `pytest_pgsql.database.create_database_snapshot` 111 | """ 112 | state = _diff_snapshots(original_snapshot, current_snapshot) 113 | strings = { 114 | 'extra_extensions': ', '.join(state['extra_extensions']) or 'None', 115 | 'missing_extensions': ', '.join(state['missing_extensions']) or 'None', 116 | 'extra_schemas': ', '.join(state['extra_schemas']) or 'None', 117 | 'missing_schemas': ', '.join(state['missing_schemas']) or 'None', 118 | 'extra_tables': 119 | ', '.join('{0.schema}.{0.table}'.format(t) 120 | for t in state['extra_tables']) or 'None', 121 | 'missing_tables': 122 | ', '.join('{0.schema}.{0.table}'.format(t) 123 | for t in state['missing_tables']) or 'None', 124 | } 125 | 126 | return cls( 127 | "The database state wasn't reset successfully. 
Extra tables or " 128 | "schemas may remain, or preexisting tables and/or schemas may not " 129 | "have been restored:\n" 130 | " * Extra extensions: %(extra_extensions)s\n" 131 | " * Missing extensions: %(missing_extensions)s\n" 132 | " * Extra schemas: %(extra_schemas)s\n" 133 | " * Missing schemas: %(missing_schemas)s\n" 134 | " * Extra tables: %(extra_tables)s\n" 135 | " * Missing tables: %(missing_tables)s" % strings, 136 | state) 137 | 138 | 139 | class NoSnapshotAvailableError(DatabaseRestoreFailedError): 140 | """Can't restore the database - no snapshot was given to the class.""" 141 | -------------------------------------------------------------------------------- /pytest_pgsql/ext.py: -------------------------------------------------------------------------------- 1 | """Facilities for extending the database.""" 2 | 3 | import pytest 4 | import sqlalchemy as sqla 5 | 6 | 7 | def create_engine_fixture(name, scope='session', **engine_params): 8 | """A factory function that creates a fixture with a customized SQLAlchemy 9 | :class:`~sqlalchemy.engine.Engine`. 10 | 11 | Because setup and teardown will require additional time and resources if 12 | you're using both a custom *and* the default engine, if you need this engine 13 | in more than one module you might want to consider using this scoped at the 14 | session level, i.e. initialized and torn down once for the entire test run. 15 | The tradeoff is that if you use multiple engines, each custom one will use 16 | additional resources such as connection pools and memory for the entirety of 17 | the session. If you only need this custom engine in a few places, it may be 18 | more resource-efficient to scope this to an individual test, class, or 19 | module. 20 | 21 | Any extensions declared using the ``--pg-extensions`` command-line option 22 | will be installed as part of this engine's setup process. 23 | 24 | .. 
warning::
        Because an engine performs no cleanup itself, any changes made with an
        engine fixture directly are *not* rolled back and can result in the
        failure of other tests (usually with a
        :class:`~pytest_pgsql.errors.DatabaseIsDirtyError` at teardown).
        You should only use this in conjunction with
        :meth:`~pytest_pgsql.database.PostgreSQLTestDBBase.create_fixture` to
        create a *database* fixture that you'll use. Engine fixtures shouldn't
        be used directly.

    Arguments:
        name (str):
            The name of the fixture. It must be unique, so ``pg_engine`` is not
            allowed.

        scope (str):
            The scope that this customized engine should have. Valid values are:

            * ``class``: The engine is initialized and torn down for each test
              class that uses it.
            * ``function``: The engine is initialized and torn down for each
              test that uses it.
            * ``module``: The engine is initialized and torn down once per
              module that uses it.
            * ``session``: The engine is initialized and torn down once per
              pytest run.

            Default: ``session``.

        **engine_params:
            Keyword arguments to pass to :func:`sqlalchemy.create_engine`. (You
            cannot change the connection URL with this.)

    Usage:

    .. code-block:: python

        # conftest.py
        import simplejson as json

        # Create an engine fixture named `jengine`
        jengine = pytest_pgsql.create_engine_fixture(
            'jengine', json_serializer=json.dumps, json_deserializer=json.loads)

        # Create a new database fixture that uses our `jengine`.
        jdb = pytest_pgsql.PostgreSQLTestDB.create_fixture('jdb', 'jengine')

        # ----------------
        # test_json.py
        import datetime
        import sqlalchemy as sqla
        import sqlalchemy.dialects.postgresql as sqla_pg

        def test_blah(jdb):
            meta = sqla.MetaData(bind=jdb.connection)
            table = sqla.Table('test', meta, sqla.Column('col', sqla_pg.JSON))
            meta.create_all()

            jdb.connection.execute(table.insert(), [
                {'col': datetime.datetime.now()}
            ])
    """
    @pytest.fixture(name=name, scope=scope)
    def _engine_fixture(database_uri, request):
        engine = sqla.create_engine(database_uri, **engine_params)
        # Quote extension names manually -- identifiers can't be bound
        # parameters.
        quote_id = engine.dialect.preparer(engine.dialect).quote_identifier

        # Install every extension named on the command line before handing the
        # engine to the tests. Blank entries from stray commas are skipped.
        opt_string = request.config.getoption('--pg-extensions')
        to_install = (s.strip() for s in opt_string.split(','))

        query_string = ';'.join(
            'CREATE EXTENSION IF NOT EXISTS %s' % quote_id(ext)
            for ext in to_install if ext)

        if query_string:  # pragma: no cover
            engine.execute('BEGIN TRANSACTION; ' + query_string + '; COMMIT;')

        yield engine
        # Dispose of the connection pool once the fixture's scope ends.
        engine.dispose()

    return _engine_fixture
--------------------------------------------------------------------------------
/pytest_pgsql/plugin.py:
--------------------------------------------------------------------------------
"""This forms the core of the pytest plugin."""

import pytest
import testing.postgresql

from pytest_pgsql import database
from pytest_pgsql import ext


def pytest_addoption(parser):
    """Add configuration options for pytest_pgsql."""
    parser.addoption(
        '--pg-extensions', action='store', default='',
        help="A comma-separated list of PostgreSQL extensions to install at "
             "the beginning of the session for use by all tests. 
Example: " 16 | "--pg-extensions=uuid-ossp,pg_tgrm,pgcrypto") 17 | 18 | parser.addoption( 19 | '--pg-work-mem', type=int, default=32, 20 | help='Set the value of the `work_mem` setting, in megabytes. ' 21 | '`pytest_pgsql` defaults to 32. Adjusting this up or down can ' 22 | 'help performance; see the Postgres documentation for more details.') 23 | 24 | parser.addoption( 25 | '--pg-conf-opt', 26 | action='append', 27 | help='Set postgres config options for the test database. ' 28 | 'These are the options that are found in the postgres.conf file' 29 | 'Example: "--pg-conf-opt="track_commit_timestamp=True""') 30 | 31 | 32 | @pytest.fixture(scope='session') 33 | def database_uri(request): 34 | """A fixture giving the connection URI of the session-wide test database.""" 35 | # Note: due to the nature of the variable configs, the command line options 36 | # must be tested manually. 37 | 38 | work_mem = request.config.getoption('--pg-work-mem') 39 | if work_mem < 0: # pragma: no cover 40 | pytest.exit('ERROR: --pg-work-mem value must be >= 0. Got: %d' % work_mem) 41 | return 42 | elif work_mem == 0: # pragma: no cover 43 | # Disable memory tweak and use the server default. 44 | work_mem_setting = '' 45 | else: 46 | # User wants to change the working memory setting. 47 | work_mem_setting = '-c work_mem=%dMB ' % work_mem 48 | 49 | # pylint: disable=bad-continuation,deprecated-method 50 | conf_opts = request.config.getoption('--pg-conf-opt') 51 | if conf_opts: 52 | conf_opts_string = ' -c ' + ' -c '.join(conf_opts) 53 | else: 54 | conf_opts_string = '' 55 | 56 | with testing.postgresql.Postgresql( 57 | postgres_args='-c TimeZone=UTC ' 58 | '-c fsync=off ' 59 | '-c synchronous_commit=off ' 60 | '-c full_page_writes=off ' 61 | + work_mem_setting + 62 | '-c checkpoint_timeout=30min ' 63 | '-c bgwriter_delay=10000ms' 64 | + conf_opts_string 65 | ) as pgdb: 66 | yield pgdb.url() 67 | 68 | 69 | #: A SQLAlchemy engine shared by the transacted and non-transacted database fixtures. 
70 | #: 71 | #: .. seealso:: `pytest_pgsql.ext.create_engine_fixture` 72 | # pylint: disable=invalid-name 73 | pg_engine = ext.create_engine_fixture('pg_engine', scope='session') 74 | # pylint: enable=invalid-name 75 | 76 | 77 | @pytest.fixture(scope='session') 78 | def database_snapshot(pg_engine): 79 | """Create one database snapshot for the session. 80 | 81 | The database will be restored to this state after each test. 82 | 83 | .. note :: 84 | 85 | This is an implementation detail and should not be used directly except 86 | by derived fixtures. 87 | """ 88 | return database.create_database_snapshot(pg_engine) 89 | 90 | 91 | # pylint: disable=invalid-name 92 | 93 | #: Create a test database instance and cleans up after each test finishes. 94 | #: 95 | #: You should prefer the `transacted_postgresql_db` fixture unless your test 96 | #: cannot be run in a single transaction. The `transacted_postgresql_db` fixture 97 | #: leads to faster tests since it doesn't tear down the entire database between 98 | #: each test. 99 | postgresql_db = \ 100 | database.PostgreSQLTestDB.create_fixture('postgresql_db') 101 | 102 | 103 | #: Create a test database instance that rolls back the current transaction after 104 | #: each test finishes, verifying its integrity before returning. 105 | #: 106 | #: Read the warning in the main documentation page before using this fixture. 
107 | transacted_postgresql_db = \ 108 | database.TransactedPostgreSQLTestDB.create_fixture('transacted_postgresql_db') 109 | 110 | # pylint: enable=invalid-name 111 | -------------------------------------------------------------------------------- /pytest_pgsql/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CloverHealth/pytest-pgsql/75c7be051545ea20420557d90886cbd32eb46151/pytest_pgsql/tests/__init__.py -------------------------------------------------------------------------------- /pytest_pgsql/tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Common fixtures and functions for use in tests.""" 2 | 3 | import contextlib 4 | 5 | import pytest 6 | 7 | 8 | @contextlib.contextmanager 9 | def check_teardown(fixture, execute): 10 | yield fixture 11 | 12 | # Teardown hasn't been executed yet so we need to trigger it ourselves. 13 | fixture.rollback() 14 | fixture.reset_db() 15 | 16 | assert not fixture.is_dirty() 17 | 18 | 19 | @pytest.fixture 20 | def clean_tpgdb(transacted_postgresql_db): # pragma: no cover 21 | """A transacted_postgresql_db fixture that verifies its cleanliness...ish.""" 22 | execute = transacted_postgresql_db.connection.execute 23 | with check_teardown(transacted_postgresql_db, execute) as fixture: 24 | yield fixture 25 | 26 | 27 | @pytest.fixture 28 | def clean_pgdb(postgresql_db): # pragma: no cover 29 | """A postgresql_db fixture that verifies its cleanliness...ish.""" 30 | execute = postgresql_db.engine.execute 31 | with check_teardown(postgresql_db, execute) as fixture: 32 | yield fixture 33 | 34 | 35 | @pytest.fixture(params=['non-transacted', 'transacted']) 36 | def clean_db(database_uri, transacted_postgresql_db, postgresql_db, request): 37 | """Generic database - run a test with both the transacted and non-transacted 38 | databases.""" 39 | if request.param == 'transacted': 40 | yield 
transacted_postgresql_db 41 | else: 42 | yield postgresql_db 43 | -------------------------------------------------------------------------------- /pytest_pgsql/tests/database_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the database.""" 2 | 3 | import csv 4 | import datetime 5 | import io 6 | import random 7 | import tempfile 8 | 9 | import pytest 10 | import sqlalchemy as sqla 11 | import sqlalchemy.engine as sqla_eng 12 | import sqlalchemy.exc as sqla_exc 13 | import sqlalchemy.ext.declarative as sqla_decl 14 | from sqlalchemy import func as sqla_func 15 | import sqlalchemy.sql as sqla_sql 16 | 17 | import pytest_pgsql.time 18 | from pytest_pgsql import errors 19 | 20 | DeclBase = sqla_decl.declarative_base() 21 | 22 | 23 | class BasicModel(DeclBase): 24 | """A trivial ORM model for testing stuff.""" 25 | __tablename__ = 'basic_model' 26 | id = sqla.Column(sqla.Integer, primary_key=True) 27 | value = sqla.Column(sqla.Integer) 28 | 29 | 30 | BASIC_TABLE = sqla.Table( 31 | 'basic_table', 32 | sqla.MetaData(), 33 | sqla.Column('id', sqla.Integer), 34 | sqla.Column('value', sqla.Integer)) 35 | 36 | 37 | # We need to put this table in a separate schema to ensure that table names with 38 | # schemas are properly handled in the truncate statement created by load_csv(). 
REFERENCED_TABLE = sqla.Table(
    'referenced_table',
    sqla.MetaData(),
    sqla.Column('id', sqla.Integer, primary_key=True, autoincrement=True),
    sqla.Column('value', sqla.Integer),
    schema='some_schema')


REFERRING_TABLE = sqla.Table(
    'referring_table',
    sqla.MetaData(),
    sqla.Column('id', sqla.Integer, primary_key=True, autoincrement=True),
    sqla.Column('ref_id', sqla.ForeignKey(REFERENCED_TABLE.c.id)))


def random_identifier(prefix='_'):
    """Return a random SQL-safe identifier, e.g. ``_0482915``.

    FIX: `random.randrange(1E7)` passed a float stop value, which is
    deprecated since Python 3.10 and a TypeError in 3.12; use an int literal.
    """
    return prefix + '%06d' % random.randrange(10000000)


def get_basictable_rowcount(session, table=BASIC_TABLE):
    """Return the number of rows in ``table`` (BASIC_TABLE by default)."""
    count_query = sqla_sql.select([sqla_func.count()]).select_from(table)
    return session.execute(count_query).scalar()


@pytest.mark.parametrize('conn_name', [
    '_conn',
    'session',
])
def test_create_schema(clean_db, conn_name):
    """Create a schema in the test database."""
    schema_name = random_identifier()
    clean_db.create_schema(schema_name)

    conn = getattr(clean_db, conn_name)
    query = sqla.text(
        'SELECT EXISTS(SELECT 1 FROM pg_namespace WHERE nspname=:name LIMIT 1)'
    ).bindparams(name=schema_name)

    assert conn.execute(query).scalar() == 1


def test_create_multiple_schemas(clean_db):
    """Create multiple schemas in the test database."""
    schema_names = ['schema_%s' % i for i in range(5)]
    clean_db.create_schema(*schema_names)

    query = sqla.text(
        'SELECT COUNT(*) FROM pg_namespace WHERE nspname IN :names'
    ).bindparams(names=tuple(schema_names))

    assert clean_db.session.execute(query).scalar() == len(schema_names)


def test_create_schema_no_injection(clean_db):
    """Verify that create_schema is invulnerable to SQL injection."""
    # Create a table we can maliciously drop
    table = sqla.Table('test', sqla.MetaData(), sqla.Column('id', sqla.Integer))
    clean_db.create_table(table)
    assert clean_db.has_table(table)

    # Try creating a schema that will result in the following query:
    # CREATE SCHEMA "foo"; DROP TABLE test CASCADE; --"
    malicious_schema = 'foo"; DROP TABLE test CASCADE; --'
    clean_db.create_schema(malicious_schema)

    # Our table should be intact, and there should be no schema called "foo".
    assert clean_db.has_table(table)
    assert clean_db.has_schema(malicious_schema)
    assert not clean_db.has_schema('foo')


@pytest.mark.parametrize('conn_name', [
    '_conn',
    'session',
])
def test_has_schema(clean_db, conn_name):
    schema = random_identifier()
    clean_db.create_schema(schema)
    conn = getattr(clean_db, conn_name)

    # Make sure we get the same result executing the query directly and using
    # our function.
    query = sqla.text(
        'SELECT EXISTS(SELECT 1 FROM pg_namespace WHERE nspname=:name LIMIT 1)'
    ).bindparams(name=schema)

    assert conn.execute(query).scalar() is True
    assert clean_db.has_schema(schema)

    exists = conn.execute(
        "SELECT EXISTS(SELECT 1 FROM pg_namespace WHERE nspname='bogus' LIMIT 1)"
    ).scalar()

    assert exists is False
    assert not clean_db.has_schema('bogus')


@pytest.mark.parametrize('conn_name', [
    '_conn',
    'session',
])
def test_has_table(clean_db, conn_name):
    schema_name = random_identifier()
    table_name = random_identifier()
    conn = getattr(clean_db, conn_name)

    table = sqla.Table(
        table_name,
        sqla.MetaData(bind=conn),
        sqla.Column('id', sqla.Integer, primary_key=True),
        schema=schema_name)

    clean_db.create_schema(schema_name)
    clean_db.create_table(table)

    # Test all three ways to check for this table - the Table object, the table
    # name with the schema, and
the table name without the schema. 157 | assert clean_db.has_table(table) 158 | assert clean_db.has_table(table.fullname) 159 | assert clean_db.has_table(table_name) 160 | 161 | assert not clean_db.has_table('bogus_table') 162 | 163 | with pytest.raises(TypeError) as errinfo: 164 | clean_db.has_table(0) 165 | 166 | assert str(errinfo.value) == \ 167 | "Expected str, SQLAlchemy Table, or declarative model, got 'int'." 168 | 169 | 170 | def test_reset_db_transacted(clean_db): 171 | """Verify reset_db() blows away schemas we created in the transacted DB.""" 172 | schema = random_identifier() 173 | clean_db.create_schema(schema) 174 | clean_db.install_extension('pgcrypto') 175 | assert clean_db.has_schema(schema) 176 | 177 | 178 | def test_create_table(clean_db): 179 | """Create a single table.""" 180 | table = sqla.Table( 181 | 'test_table', 182 | sqla.MetaData(), 183 | sqla.Column('id', sqla.Integer, primary_key=True)) 184 | 185 | assert not clean_db.has_table('test_table') 186 | clean_db.create_table(table) 187 | assert clean_db.has_table('test_table') 188 | 189 | 190 | def test_create_multiple_tables(clean_db): 191 | """Create multiple tables.""" 192 | tables = [ 193 | sqla.Table( 194 | 'test_table_%s' % i, 195 | sqla.MetaData(), 196 | sqla.Column('id', sqla.Integer, primary_key=True)) 197 | for i in range(5) 198 | ] 199 | 200 | for table in tables: 201 | assert not clean_db.has_table(table.fullname) 202 | 203 | clean_db.create_table(*tables) 204 | 205 | for table in tables: 206 | assert clean_db.has_table(table.fullname) 207 | 208 | 209 | def test_create_decl_table(clean_db): 210 | """Create a declarative ORM model in the database.""" 211 | assert not clean_db.has_table(BasicModel) 212 | assert not clean_db.has_table(BasicModel.__table__) 213 | assert not clean_db.has_table(BasicModel.__tablename__) 214 | 215 | clean_db.create_table(BasicModel.__table__) 216 | 217 | assert clean_db.has_table(BasicModel) 218 | assert clean_db.has_table(BasicModel.__table__) 219 | 
assert clean_db.has_table(BasicModel.__tablename__) 220 | 221 | 222 | # NOTE: Don't swap the order of the parameterizations, and don't combine them. 223 | # We need the schema to change multiple times per connection, not change the 224 | # connection multiple times per schema. 225 | @pytest.mark.parametrize('conn_name', [ 226 | '_conn', 227 | 'session', 228 | ]) 229 | @pytest.mark.parametrize('schema', [ 230 | # Each schema name must occur twice consecutively so that SQLAlchemy will 231 | # explode if we have a collision. 232 | 'my_schema', 233 | 'my_schema', 234 | 'public', 235 | 'public', 236 | ]) 237 | def test_manual_create_table_teardown(clean_db, conn_name, schema): 238 | """Tables created manually should be deleted automatically. 239 | 240 | The idea here is to create a table using the connection. The test will end 241 | and the table should be deleted. After that, the session will attempt to 242 | create a table with the same name. No exception should be raised. 243 | 244 | We test this with both a custom schema and a preexisting schema to ensure 245 | that table deletions work for schemas we can't delete. 246 | """ 247 | # clean_db.create_schema('my_schema') 248 | # conn = getattr(clean_db, conn_name) 249 | # conn.execute('CREATE TABLE %s.thing (id SERIAL)' % schema) 250 | # assert clean_db.has_table('%s.thing' % schema) 251 | 252 | 253 | @pytest.mark.parametrize('conn_name', [ 254 | '_conn', 255 | 'session', 256 | ]) 257 | def test_has_extension_true_negative(clean_db, conn_name): 258 | """Verify we can accurately detect an uninstalled extension.""" 259 | assert not clean_db.has_extension('uuid-ossp') 260 | 261 | # If the extension isn't installed, attempting to generate a UUID will fail. 
262 | conn = getattr(clean_db, conn_name) 263 | with pytest.raises(sqla_exc.DatabaseError): 264 | conn.execute('SELECT uuid_generate_v4()') 265 | 266 | 267 | def test_create_extension_no_injection(clean_db): 268 | """Verify that install_extension is invulnerable to SQL injection.""" 269 | # Try installing an extension that will result in the following query: 270 | # CREATE EXTENSION "uuid-ossp"; DROP TABLE IF EXISTS test CASCADE; --" 271 | malicious_extension = 'uuid-ossp"; DROP TABLE IF EXISTS test CASCADE; --' 272 | 273 | # If we quoted this right then an exception should've been raised because 274 | # 'uuid-ossp"; DROP TABLE test CASCADE; --' is not a valid extension name. 275 | # If no exception is raised then the DROP succeeded. 276 | with pytest.raises(sqla_exc.DataError): 277 | clean_db.install_extension(malicious_extension) 278 | 279 | 280 | @pytest.mark.parametrize('conn_name', [ 281 | '_conn', 282 | 'session', 283 | ]) 284 | def test_create_extension(clean_db, conn_name): 285 | """Verify creating an extension works.""" 286 | assert not clean_db.has_extension('uuid-ossp') 287 | 288 | clean_db.install_extension('uuid-ossp') 289 | assert clean_db.has_extension('uuid-ossp') 290 | 291 | # This shouldn't blow up if we have the extension installed. 292 | conn = getattr(clean_db, conn_name) 293 | conn.execute('SELECT uuid_generate_v4()') 294 | 295 | 296 | def test_create_extension_in_schema(clean_db): 297 | """Ensure we can create an extension in a non-default schema.""" 298 | assert not clean_db.has_extension('uuid-ossp') 299 | clean_db.create_schema('foo') 300 | clean_db.install_extension('uuid-ossp', schema='foo') 301 | assert clean_db.has_extension('uuid-ossp') 302 | 303 | # Because has_extension() will return ``True`` wherever the extension is, 304 | # we have to see if it exists in a different schema another way. 
305 | clean_db.session.execute('SELECT foo.uuid_generate_v4()') 306 | 307 | 308 | def test_create_extension_exists_ok(clean_db): 309 | """The exists_ok argument should prevent crashing if an extension exists.""" 310 | assert not clean_db.has_extension('uuid-ossp') 311 | assert clean_db.install_extension('uuid-ossp') is True 312 | assert clean_db.has_extension('uuid-ossp') 313 | 314 | # Try to install the extension again. It should blow up if we don't have 315 | # exists_ok set. 316 | with pytest.raises(sqla_exc.ProgrammingError): 317 | clean_db.install_extension('uuid-ossp') 318 | 319 | # Recover from the exception 320 | clean_db.rollback() 321 | 322 | # Try installing the extension again. 323 | assert clean_db.install_extension('uuid-ossp', exists_ok=True) is True 324 | 325 | 326 | def test_create_extension_if_available(clean_db): 327 | """Using if_available will not install unsupported extensions.""" 328 | assert not clean_db.is_extension_available('asdf') 329 | 330 | # Try to install this bogus exception and make sure we don't swallow the 331 | # error somehow. 332 | with pytest.raises(sqla_exc.OperationalError): 333 | clean_db.install_extension('asdf') 334 | 335 | # Recover from the exception and try to install the extension again. 336 | clean_db.rollback() 337 | assert clean_db.install_extension('asfd', if_available=True) is False 338 | 339 | 340 | @pytest.mark.parametrize('ext_name,result', [ 341 | ('uuid-ossp', True), # Test with test name requiring quotes 342 | ('pgcrypto', True), # Test with test name not requiring quotes 343 | ('asdfjkl', False), # This shouldn't exist. 344 | ]) 345 | def test_has_extension(clean_db, ext_name, result): 346 | """Verify extension checking works.""" 347 | assert clean_db.is_extension_available(ext_name) is result 348 | 349 | 350 | def test_reflect_public_table_no_schema(clean_db): 351 | """Reflecting a table without the schema will succeed IF the table is in the 352 | search path, e.g. 
``public`` or ``pg_catalog``.""" 353 | assert clean_db.has_table('pg_index') 354 | reflected = clean_db.get_table('pg_index') 355 | 356 | # Verify we can use this reflected table. 357 | query = sqla_sql.select([sqla_func.count()]).select_from(reflected) 358 | n_rows = clean_db.session.execute(query).scalar() 359 | assert n_rows > 0 360 | 361 | 362 | def test_reflect_table_with_schema(clean_db): 363 | """Reflect a table with the schema.""" 364 | assert clean_db.has_table('pg_catalog.pg_index') 365 | reflected = clean_db.get_table('pg_catalog.pg_index') 366 | 367 | # Verify we can use this reflected table. 368 | query = sqla_sql.select([sqla_func.count()]).select_from(reflected) 369 | n_rows = clean_db.session.execute(query).scalar() 370 | assert n_rows > 0 371 | 372 | 373 | def test_reflect_table_with_metadata(clean_db): 374 | """Verify the metadata we pass in is bound to the table.""" 375 | meta = sqla.MetaData() 376 | reflected = clean_db.get_table('pg_catalog.pg_index', meta) 377 | assert reflected.metadata is meta 378 | 379 | 380 | @pytest_pgsql.freeze_time('2017-01-01') 381 | def test_run_sql_basic_filename(clean_tpgdb): 382 | """Test executing a basic SQL file, passing a filename to the function.""" 383 | with tempfile.NamedTemporaryFile('w+') as fd: 384 | fd.write('SELECT CURRENT_DATE') 385 | fd.flush() 386 | 387 | result = clean_tpgdb.run_sql_file(fd.name) 388 | assert isinstance(result, sqla_eng.ResultProxy) 389 | assert result.scalar() == datetime.date(2017, 1, 1) 390 | 391 | 392 | @pytest_pgsql.freeze_time('2017-01-01') 393 | def test_run_sql_basic_buffer(clean_tpgdb): 394 | """Test executing a basic SQL file, passing a buffer to the function.""" 395 | sql_file = io.StringIO('SELECT CURRENT_DATE') 396 | 397 | result = clean_tpgdb.run_sql_file(sql_file) 398 | assert isinstance(result, sqla_eng.ResultProxy) 399 | assert result.scalar() == datetime.date(2017, 1, 1) 400 | 401 | 402 | def test_run_sql_basic_bindings(clean_tpgdb): 403 | """Test executing a 
basic SQL file with bindings.""" 404 | sql_file = io.StringIO('SELECT CURRENT_DATE = :date') 405 | 406 | result = clean_tpgdb.run_sql_file(sql_file, date=datetime.date(1970, 1, 1)) 407 | assert isinstance(result, sqla_eng.ResultProxy) 408 | assert result.scalar() is False 409 | 410 | 411 | def test_run_sql_transacted_teardown_ok(clean_tpgdb): 412 | """Verify that teardown still works with the SQL execution in the transacted 413 | database.""" 414 | sql_file = io.StringIO(""" 415 | CREATE SCHEMA garbage; 416 | CREATE TABLE garbage.more_garbage(id SERIAL PRIMARY KEY); 417 | CREATE TABLE public.even_more_garbage(id SERIAL PRIMARY KEY); 418 | """) 419 | 420 | clean_tpgdb.run_sql_file(sql_file) 421 | # Assertions done by clean_tpgdb for us 422 | 423 | 424 | def test_run_sql_teardown_ok(clean_pgdb): 425 | """Verify that teardown still works with the SQL execution in the regular 426 | database.""" 427 | sql_file = io.StringIO(""" 428 | CREATE SCHEMA garbage; 429 | CREATE TABLE garbage.more_garbage(id SERIAL PRIMARY KEY); 430 | CREATE TABLE public.even_more_garbage(id SERIAL PRIMARY KEY); 431 | """) 432 | 433 | clean_pgdb.run_sql_file(sql_file) 434 | # Assertions done by clean_pgdb for us 435 | 436 | 437 | @pytest.fixture 438 | def basic_csv(): 439 | """Create a CSV file and return the data along with the file descriptor.""" 440 | csv_rows = [{'id': i, 'value': random.randrange(100)} for i in range(10)] 441 | 442 | # We have to use `NamedTemporaryFile` because the fs fixture doesn't appear 443 | # to work with Pandas (TypeError thrown when opening a file by name). 444 | with tempfile.NamedTemporaryFile('w+') as fd: 445 | writer = csv.DictWriter(fd, ('id', 'value')) 446 | writer.writeheader() 447 | writer.writerows(csv_rows) 448 | fd.flush() 449 | fd.seek(0) 450 | yield csv_rows, fd 451 | 452 | 453 | @pytest.mark.parametrize('count_mult,truncate,cascade', ( 454 | (1, True, False), # Truncate but don't cascade (should be okay for this). 
455 | (1, True, True), # Truncate and cascade (shouldn't matter). 456 | (2, False, False), # Don't truncate, don't cascade. 457 | (2, False, True), # Don't truncate, `cascade` should be ignored. 458 | )) 459 | def test_load_csv_basic(clean_tpgdb, basic_csv, count_mult, truncate, cascade): 460 | """Test basic load of a CSV, and that data is appended by default.""" 461 | csv_rows, csv_fd = basic_csv 462 | 463 | clean_tpgdb.create_table(BASIC_TABLE) 464 | assert clean_tpgdb.has_table(BASIC_TABLE) 465 | 466 | n_inserted = clean_tpgdb.load_csv(csv_fd, BASIC_TABLE) 467 | assert n_inserted == len(csv_rows) 468 | assert get_basictable_rowcount(clean_tpgdb.session) == n_inserted 469 | 470 | # Load data from the CSV again. We should now have exactly twice the number 471 | # of rows, since this is supposed to append by default. Also use the table's 472 | # name instead of the object itself. 473 | n_inserted = clean_tpgdb.load_csv(csv_fd.name, BASIC_TABLE.fullname, 474 | truncate=truncate, cascade=cascade) 475 | assert n_inserted == len(csv_rows) 476 | assert get_basictable_rowcount(clean_tpgdb.session) == count_mult * len(csv_rows) 477 | 478 | 479 | def test_load_csv_declarative(clean_tpgdb, basic_csv): 480 | """Try loading a CSV into a declarative model. 481 | 482 | TODO (dargueta): Somehow integrate this into ``test_load_csv_basic``. 
483 | """ 484 | csv_rows, csv_fd = basic_csv 485 | 486 | clean_tpgdb.create_table(BasicModel) 487 | assert clean_tpgdb.has_table(BasicModel) 488 | 489 | n_inserted = clean_tpgdb.load_csv(csv_fd.name, BasicModel) 490 | assert n_inserted == len(csv_rows) 491 | assert get_basictable_rowcount(clean_tpgdb.session, BasicModel) == n_inserted 492 | 493 | 494 | def test_load_csv_truncates_table(clean_tpgdb, basic_csv): 495 | """Verify truncating works when no tables reference the one being loaded.""" 496 | csv_rows, csv_fd = basic_csv 497 | 498 | clean_tpgdb.create_table(BASIC_TABLE) 499 | assert clean_tpgdb.has_table(BASIC_TABLE) 500 | 501 | # pylint: disable=no-value-for-parameter 502 | clean_tpgdb.session.execute(BASIC_TABLE.insert().values(csv_rows)) 503 | # pylint: enable=no-value-for-parameter 504 | 505 | # Load data from the CSV again. Because we're truncating we should still 506 | # have exactly the same number of rows. 507 | n_inserted = clean_tpgdb.load_csv(csv_fd.name, BASIC_TABLE, truncate=True) 508 | assert n_inserted == len(csv_rows) 509 | assert get_basictable_rowcount(clean_tpgdb.session) == len(csv_rows) 510 | 511 | 512 | @pytest.mark.parametrize('truncate,expected_exc', ( 513 | (False, sqla_exc.IntegrityError), # Don't truncate -> pkey violation 514 | (True, sqla_exc.NotSupportedError), # Truncate but don't cascade -> boom 515 | )) 516 | def test_load_csv_to_referenced_table_crash( 517 | clean_tpgdb, basic_csv, truncate, expected_exc): 518 | """Verify expected crashes when loading duplicates but not truncating, or 519 | not cascading when truncating.""" 520 | csv_rows, csv_fd = basic_csv 521 | 522 | clean_tpgdb.create_schema(REFERENCED_TABLE.schema) 523 | clean_tpgdb.create_table(REFERENCED_TABLE) 524 | clean_tpgdb.create_table(REFERRING_TABLE) 525 | 526 | assert clean_tpgdb.has_table(REFERENCED_TABLE) 527 | assert clean_tpgdb.has_table(REFERRING_TABLE) 528 | 529 | # pylint: disable=no-value-for-parameter 530 | 
clean_tpgdb.session.execute(REFERENCED_TABLE.insert().values(csv_rows)) 531 | REFERRING_TABLE.insert().values({'id': 1, 'ref_id': 1}) 532 | # pylint: enable=no-value-for-parameter 533 | 534 | # Try loading from the CSV. 535 | with pytest.raises(expected_exc): 536 | clean_tpgdb.load_csv(csv_fd.name, REFERENCED_TABLE, truncate=truncate) 537 | 538 | 539 | def test_load_csv_to_referenced_table_ok(clean_tpgdb, basic_csv): 540 | """Ensure truncation cascades to referring tables.""" 541 | csv_rows, csv_fd = basic_csv 542 | 543 | clean_tpgdb.create_table(REFERENCED_TABLE) 544 | clean_tpgdb.create_table(REFERRING_TABLE) 545 | 546 | assert clean_tpgdb.has_table(REFERENCED_TABLE) 547 | assert clean_tpgdb.has_table(REFERRING_TABLE) 548 | 549 | # pylint: disable=no-value-for-parameter 550 | clean_tpgdb.session.execute(REFERENCED_TABLE.insert().values(csv_rows)) 551 | REFERRING_TABLE.insert().values({'id': 1, 'ref_id': 1}) 552 | # pylint: enable=no-value-for-parameter 553 | 554 | # Load data from the CSV again. Because we're truncating we should still 555 | # have exactly the same number of rows. 
556 | n_inserted = clean_tpgdb.load_csv(csv_fd.name, REFERENCED_TABLE, 557 | truncate=True, cascade=True) 558 | assert n_inserted == len(csv_rows) 559 | assert get_basictable_rowcount(clean_tpgdb.session, REFERENCED_TABLE) == len(csv_rows) 560 | assert get_basictable_rowcount(clean_tpgdb.session, REFERRING_TABLE) == 0 561 | 562 | 563 | @pytest.mark.parametrize('create_stmt,drop_stmt', [ 564 | ('CREATE TABLE public.garbage (id SERIAL)', 'DROP TABLE public.garbage CASCADE'), 565 | ('CREATE SCHEMA garbage', 'DROP SCHEMA garbage CASCADE'), 566 | ('CREATE EXTENSION pgcrypto', 'DROP EXTENSION pgcrypto'), 567 | ]) 568 | def test_dirty_database_table(transacted_postgresql_db, create_stmt, drop_stmt): 569 | """Verify an exception is thrown when the database isn't cleaned up with 570 | just a rollback.""" 571 | transacted_postgresql_db.connection.execute(create_stmt + '; COMMIT;') 572 | 573 | with pytest.raises(errors.DatabaseIsDirtyError): 574 | transacted_postgresql_db.reset_db() 575 | 576 | transacted_postgresql_db.connection.execute(drop_stmt + '; COMMIT;') 577 | 578 | 579 | @pytest.mark.parametrize('db_class', [ 580 | pytest_pgsql.database.TransactedPostgreSQLTestDB, 581 | pytest_pgsql.database.PostgreSQLTestDB, 582 | ]) 583 | def test_restore_no_snapshot_transacted_fails(transacted_postgresql_db, db_class): 584 | """Blow up if the user tries restoring the database without a snapshot.""" 585 | db = db_class(transacted_postgresql_db.postgresql_url, 586 | transacted_postgresql_db.connection) 587 | with pytest.raises(errors.NoSnapshotAvailableError): 588 | db.restore_to_snapshot() 589 | 590 | 591 | @pytest.mark.parametrize('db_class', [ 592 | pytest_pgsql.database.TransactedPostgreSQLTestDB, 593 | pytest_pgsql.database.PostgreSQLTestDB, 594 | ]) 595 | def test_reset_db_no_snapshot_is_ok(transacted_postgresql_db, db_class, mocker): 596 | """Resetting without a snapshot should skip restore_to_snapshot().""" 597 | db = db_class(transacted_postgresql_db.postgresql_url, 598 | 
transacted_postgresql_db.connection) 599 | 600 | restore_mock = mocker.patch.object(db, 'restore_to_snapshot') 601 | db.reset_db() 602 | assert restore_mock.call_count == 0 603 | -------------------------------------------------------------------------------- /pytest_pgsql/tests/ext_test.py: -------------------------------------------------------------------------------- 1 | """Test extension stuff.""" 2 | 3 | import datetime 4 | 5 | import sqlalchemy as sqla 6 | import sqlalchemy.dialects.postgresql as sqla_pg 7 | 8 | import pytest_pgsql 9 | 10 | 11 | jengine = pytest_pgsql.create_engine_fixture('jengine', 12 | json_serializer=lambda _: '{}') 13 | jdb = pytest_pgsql.TransactedPostgreSQLTestDB.create_fixture('jdb', 'jengine') 14 | 15 | 16 | def test_uses_json(jdb): 17 | """Verify that the engine we create has a JSON serializer attached.""" 18 | meta = sqla.MetaData(bind=jdb.connection) 19 | table = sqla.Table('ext_test', meta, sqla.Column('col', sqla_pg.JSON)) 20 | meta.create_all() 21 | 22 | # If we have the serializer set on this engine, we should be able to insert 23 | # a datetime with no trouble. 24 | jdb.connection.execute(table.insert(), [ # pylint: disable=no-value-for-parameter 25 | {'col': datetime.datetime.now()} 26 | ]) 27 | 28 | all_rows = [dict(r) for r in jdb.connection.execute(table.select())] 29 | 30 | assert all_rows == [{'col': {}}] 31 | -------------------------------------------------------------------------------- /pytest_pgsql/tests/plugin_test.py: -------------------------------------------------------------------------------- 1 | """Basic tests for the fixtures.""" 2 | 3 | 4 | def test_extensions_option(pg_engine): 5 | """Verifies the --pg-extensions option works. 
6 | 7 | This test entirely relies on the test suite being executed with:: 8 | 9 | --pg-extensions=btree_gin,,btree_gist 10 | """ 11 | are_installed = pg_engine.execute(""" 12 | SELECT 13 | EXISTS(SELECT 1 FROM pg_extension WHERE extname='btree_gin' LIMIT 1) 14 | AND EXISTS(SELECT 1 FROM pg_extension WHERE extname='btree_gist' LIMIT 1) 15 | """).scalar() 16 | 17 | assert are_installed, \ 18 | "'btree_gin' and 'btree_gist' should've been installed automatically." 19 | -------------------------------------------------------------------------------- /pytest_pgsql/tests/time_test.py: -------------------------------------------------------------------------------- 1 | """Tests for the freezegun adaptation for the fixtures.""" 2 | 3 | import datetime 4 | 5 | import pytest 6 | import sqlalchemy.orm as sqla_orm 7 | 8 | import pytest_pgsql 9 | 10 | _PGFREEZE_DATETIME_TZ = datetime.datetime(2099, 12, 31, 23, 59, 59, 123000, 11 | datetime.timezone.utc) 12 | _PGFREEZE_TIME_TZ = _PGFREEZE_DATETIME_TZ.timetz() 13 | _PGFREEZE_DATE = _PGFREEZE_DATETIME_TZ.date() 14 | _PGFREEZE_DATETIME_NAIVE = _PGFREEZE_DATETIME_TZ.replace(tzinfo=None) 15 | _PGFREEZE_TIME_NAIVE = _PGFREEZE_TIME_TZ.replace(tzinfo=None) 16 | 17 | # These are identical to the above timestamps except they're rounded down to the 18 | # nearest second. 
19 | _PGFREEZE_DATETIME_TZ_SECS = _PGFREEZE_DATETIME_TZ.replace(microsecond=0) 20 | _PGFREEZE_TIME_TZ_SECS = _PGFREEZE_TIME_TZ.replace(microsecond=0) 21 | _PGFREEZE_DATETIME_NAIVE_SECS = _PGFREEZE_DATETIME_NAIVE.replace(microsecond=0) 22 | _PGFREEZE_TIME_NAIVE_SECS = _PGFREEZE_TIME_NAIVE.replace(microsecond=0) 23 | 24 | 25 | @pytest.mark.parametrize('expression', [ 26 | 'SELECT %s', 27 | "SELECT date_trunc('MICROSECONDS', %s)", 28 | ]) 29 | @pytest.mark.parametrize('function,expected', [ 30 | ('CURRENT_TIMESTAMP', _PGFREEZE_DATETIME_TZ), 31 | ('NOW()', _PGFREEZE_DATETIME_TZ), 32 | ('now ( )', _PGFREEZE_DATETIME_TZ), 33 | ('transaction_timestamp()', _PGFREEZE_DATETIME_TZ), 34 | ('StaTemEnt_TimeStaMp()', _PGFREEZE_DATETIME_TZ), 35 | ('LocalTimestamp', _PGFREEZE_DATETIME_NAIVE), 36 | 37 | # These ensure we can form syntactically valid statements with casts. 38 | ('current_timestamp::timestamp', _PGFREEZE_DATETIME_NAIVE), 39 | ('localtimestamp::timestamp(0)', _PGFREEZE_DATETIME_NAIVE_SECS), 40 | ('current_timestamp(3)::timestamptz(0)', _PGFREEZE_DATETIME_TZ_SECS), 41 | ]) 42 | def test_basic_timestamp(clean_db, expression, function, expected): 43 | """Basic test to see if the database freezegun works.""" 44 | clean_db.time.freeze(_PGFREEZE_DATETIME_TZ) 45 | assert clean_db.time.freezer is not None 46 | assert clean_db.time.is_frozen 47 | 48 | query = expression % function 49 | 50 | # Test first with the connectible... 51 | result = clean_db._conn.execute(query).scalar() # pylint: disable=protected-access 52 | assert result == expected, 'Freezing the connection failed.' 53 | 54 | # ...then with the session. 55 | result = clean_db.session.execute(query).scalar() 56 | assert result == expected, 'Freezing the session failed.' 
57 | 58 | clean_db.time.unfreeze() 59 | assert clean_db.time.freezer is None 60 | assert not clean_db.time.is_frozen 61 | 62 | 63 | @pytest.mark.parametrize('conn_name', [ 64 | '_conn', 65 | 'session', 66 | ]) 67 | @pytest.mark.parametrize('function,expected', [ 68 | ('CURRENT_TIMESTAMP(0)', _PGFREEZE_DATETIME_TZ_SECS), 69 | ('CURRENT_TIME (0) ', _PGFREEZE_TIME_TZ_SECS), 70 | ('localtime(0 )', _PGFREEZE_TIME_NAIVE_SECS), 71 | ('LocalTimestamp ( 0 )', _PGFREEZE_DATETIME_NAIVE_SECS), 72 | ('NOW()::TIMESTAMPTZ(0)', _PGFREEZE_DATETIME_TZ_SECS), 73 | ]) 74 | def test_precision_override(clean_db, function, expected, conn_name): 75 | """Verify we form valid datetimes even when using precision truncation.""" 76 | conn = getattr(clean_db, conn_name) 77 | with clean_db.time.freeze(_PGFREEZE_DATETIME_TZ): 78 | result = conn.execute('SELECT ' + function).scalar() 79 | assert result == expected, 'Freezing the engine failed.' 80 | 81 | 82 | @pytest.mark.parametrize('conn_name', [ 83 | '_conn', 84 | 'session', 85 | ]) 86 | def test_timeofday(clean_db, conn_name): 87 | """``timeofday()`` is an oddball - it returns a string, not a timestamp.""" 88 | conn = getattr(clean_db, conn_name) 89 | with clean_db.time.freeze(_PGFREEZE_DATETIME_TZ): 90 | result = conn.execute('SELECT TIMEOFDAY() -- :)').scalar() 91 | assert result == '2099-12-31 23:59:59.123000 +0000' 92 | 93 | 94 | @pytest.mark.parametrize('conn_name', [ 95 | '_conn', 96 | 'session', 97 | ]) 98 | @pytest.mark.parametrize('expression,expected', [ 99 | # The query modifier shouldn't replace the 'CURRENT_DATE' part of this. 
100 | ("SELECT 'CURRENT_DATETIME'::TEXT", 'CURRENT_DATETIME'), 101 | ]) 102 | def test_ignore_id(clean_db, conn_name, expression, expected): 103 | """These expressions should not be modified.""" 104 | conn = getattr(clean_db, conn_name) 105 | result = conn.execute(expression).scalar() 106 | assert result == expected 107 | 108 | 109 | @pytest.mark.parametrize('conn_name', [ 110 | '_conn', 111 | 'session', 112 | ]) 113 | def test_context(clean_db, conn_name): 114 | """Ensure that time resumes normally after the context manager exits.""" 115 | expected_ts = datetime.datetime(2525, 1, 1, 0, 0, 0, 116 | tzinfo=datetime.timezone.utc) 117 | 118 | conn = getattr(clean_db, conn_name) 119 | with clean_db.time.freeze(expected_ts): 120 | result = conn.execute('SELECT NOW()').scalar() 121 | assert result == expected_ts 122 | assert clean_db.time.freezer is not None 123 | assert clean_db.time.is_frozen 124 | 125 | assert clean_db.time.freezer is None 126 | assert not clean_db.time.is_frozen 127 | assert datetime.datetime.now(datetime.timezone.utc) != expected_ts 128 | result = conn.execute('SELECT NOW()').scalar() 129 | assert result != expected_ts 130 | 131 | 132 | def test_unfreeze_twice(clean_db): 133 | """Attempting to unfreeze when not frozen shouldn't throw exceptions. 134 | 135 | This verifies that the query execution hook is properly removed and won't 136 | cause an exception due to the freezer and hook being nulled out. 137 | """ 138 | clean_db.time.unfreeze() 139 | assert clean_db.time.freezer is None 140 | 141 | with clean_db.time.freeze('1900-01-01'): 142 | date = clean_db.session.execute('SELECT CURRENT_DATE').scalar() 143 | assert date == datetime.date(1900, 1, 1) 144 | assert clean_db.time.freezer is not None 145 | 146 | # Shouldn't be frozen anymore, but let's try unfreezing anyway. 
147 | clean_db.time.unfreeze() 148 | assert clean_db.time.freezer is None 149 | 150 | 151 | def test_tick(clean_db): 152 | """Verify we can use the freezer factory thing.""" 153 | expected = datetime.datetime(2016, 12, 31, 23, 59, 59) 154 | with clean_db.time.freeze(expected) as freezer: 155 | db_now = clean_db.session.execute('SELECT LOCALTIMESTAMP').scalar() 156 | assert datetime.datetime.now() == expected 157 | assert db_now == expected 158 | 159 | freezer.tick() 160 | 161 | expected = datetime.datetime(2017, 1, 1, 0, 0, 0) 162 | db_now = clean_db.session.execute('SELECT LOCALTIMESTAMP').scalar() 163 | assert datetime.datetime.now() == expected 164 | assert db_now == expected 165 | 166 | 167 | @pytest.mark.parametrize('conn_name', [ 168 | '_conn', 169 | 'session', 170 | ]) 171 | def test_change_time_twice(clean_db, conn_name): 172 | """Ensure we can change the time multiple times in a row.""" 173 | expected_ts = datetime.datetime(1066, 10, 14, 9, 0, 0, 174 | tzinfo=datetime.timezone.utc) 175 | 176 | conn = getattr(clean_db, conn_name) 177 | with clean_db.time.freeze(expected_ts): 178 | result = conn.execute('SELECT NOW()').scalar() 179 | assert result == expected_ts 180 | assert clean_db.time.freezer is not None 181 | assert clean_db.time.is_frozen 182 | 183 | # Time should be back to normal for the time being 184 | result = conn.execute('SELECT NOW()').scalar() 185 | assert result != expected_ts 186 | assert clean_db.time.freezer is None 187 | 188 | expected_ts = datetime.datetime(1234, 5, 6, 7, 8, 9, 189 | tzinfo=datetime.timezone.utc) 190 | with clean_db.time.freeze(expected_ts): 191 | result = conn.execute('SELECT NOW()').scalar() 192 | assert result == expected_ts 193 | assert clean_db.time.freezer is not None 194 | assert clean_db.time.is_frozen 195 | 196 | assert clean_db.time.freezer is None 197 | assert not clean_db.time.is_frozen 198 | 199 | # Make sure time is unfrozen again. 
200 | result = conn.execute('SELECT NOW()').scalar() 201 | assert result != expected_ts 202 | 203 | 204 | @pytest_pgsql.freeze_time('1111-11-11 11:11:11.111000 +0000') 205 | @pytest.mark.parametrize('conn_name', [ 206 | '_conn', 207 | 'session', 208 | ]) 209 | def test_decorator(clean_db, conn_name): 210 | """We're using the freezing decorator, all time should be frozen in here.""" 211 | expected_time = datetime.datetime(1111, 11, 11, 11, 11, 11, 111000, 212 | datetime.timezone.utc) 213 | 214 | # First make sure that we activated freezegun properly 215 | now = datetime.datetime.now(datetime.timezone.utc) 216 | assert now == expected_time, 'Freezegun not activated!' 217 | 218 | conn = getattr(clean_db, conn_name) 219 | 220 | # Freezegun is active, now make sure that the database works as expected. 221 | now = conn.execute('SELECT NOW()').scalar() 222 | assert now == expected_time, "Database time isn't frozen." 223 | 224 | 225 | def test_decorator_no_fixtures(): 226 | """The decorator should crash if there are no database fixtures.""" 227 | with pytest.raises(RuntimeError) as excinfo: 228 | @pytest_pgsql.freeze_time('1999-12-31') 229 | def bad_test(): 230 | """This test doesn't use a database fixture!""" 231 | 232 | bad_test() 233 | 234 | assert str(excinfo.value).endswith("'bad_test' has 0.") 235 | 236 | 237 | def test_decorator_too_many_fixtures(transacted_postgresql_db, postgresql_db): 238 | """The decorator should crash if there's more than one database fixture.""" 239 | with pytest.raises(RuntimeError) as excinfo: 240 | @pytest_pgsql.freeze_time('1999-12-31') 241 | def bad_test(fixture_a, fixture_b): 242 | """This test uses too many database fixtures!""" 243 | 244 | bad_test(transacted_postgresql_db, postgresql_db) 245 | 246 | assert str(excinfo.value).endswith("'bad_test' has 2.") 247 | 248 | 249 | def test_can_use_session_transacted(clean_db): 250 | """Ensure we can use a Session for freezing time in the transacted DB.""" 251 | freezer = 
pytest_pgsql.SQLAlchemyFreezegun(clean_db.session) 252 | with freezer.freeze('2000-01-01'): 253 | expected = datetime.date(2000, 1, 1) 254 | assert datetime.date.today() == expected 255 | 256 | now = clean_db.session.execute('SELECT CURRENT_DATE').scalar() 257 | assert now == expected 258 | 259 | # pylint: disable=protected-access 260 | now = clean_db._conn.execute('SELECT CURRENT_DATE').scalar() 261 | # pylint: enable=protected-access 262 | assert now == expected 263 | 264 | 265 | def test_unbound_session_crashes(): 266 | """Ensure attempting to use an unbound session will fail.""" 267 | sessionmaker = sqla_orm.sessionmaker() 268 | session = sessionmaker() 269 | 270 | with pytest.raises(TypeError): 271 | pytest_pgsql.SQLAlchemyFreezegun(session) 272 | -------------------------------------------------------------------------------- /pytest_pgsql/time.py: -------------------------------------------------------------------------------- 1 | """A utility class for freezing timestamps inside SQL queries.""" 2 | 3 | import datetime 4 | import functools 5 | import re 6 | 7 | import freezegun 8 | import sqlalchemy.event as sa_event 9 | import sqlalchemy.orm.session as sqla_session 10 | 11 | 12 | _TIMESTAMP_REPLACEMENT_FORMATS = ( 13 | # Functions 14 | (r'\b((NOW|CLOCK_TIMESTAMP|STATEMENT_TIMESTAMP|TRANSACTION_TIMESTAMP)\s*\(\s*\))', 15 | r"'{:%Y-%m-%d %H:%M:%S.%f %z}'::TIMESTAMPTZ"), 16 | (r'\b(TIMEOFDAY\s*\(\s*\))', r"'{:%Y-%m-%d %H:%M:%S.%f %z}'::TEXT"), 17 | 18 | # Keywords 19 | (r'\b(CURRENT_DATE)\b', r"'{:%Y-%m-%d}'::DATE"), 20 | (r'\b(CURRENT_TIME)\b', r"'{:%H:%M:%S.%f %z}'::TIMETZ"), 21 | (r'\b(CURRENT_TIMESTAMP)\b', r"'{:%Y-%m-%d %H:%M:%S.%f %z}'::TIMESTAMPTZ"), 22 | (r'\b(LOCALTIME)\b', r"'{:%H:%M:%S.%f}'::TIME"), 23 | (r'\b(LOCALTIMESTAMP)\b', r"'{:%Y-%m-%d %H:%M:%S.%f}'::TIMESTAMP"), 24 | ) 25 | 26 | 27 | class SQLAlchemyFreezegun(object): 28 | """Freeze timestamps in all SQL executed while this freezegun is active. 
29 | 30 | This works by hooking into SQLAlchemy's "before_cursor_execute" event and 31 | modifying the query to use a predetermined date/time/timestamp instead of 32 | calling ``NOW()`` or ``CURRENT_TIMESTAMP``. This gives reasonable assurance 33 | that all timestamps in database queries are predictable. 34 | 35 | You can use this as a context manager or by manually invoking functions:: 36 | 37 | def test_foo(postgresql_db): 38 | postgresql_db.time.freeze('2017-01-01 00:00:00') 39 | ... 40 | postgresql_db.time.unfreeze() 41 | 42 | 43 | def test_bar(postgresql_db): 44 | with postgresql_db.time.freeze('2017-01-01 00:00:00'): 45 | ... 46 | 47 | You also might be interested in the :func:`freeze_time` decorator. 48 | 49 | .. note :: 50 | Because this scans each query multiple times with a regular expression 51 | it can hurt performance considerably. You probably won't want to use 52 | this unless it's necessary. 53 | 54 | .. warning :: 55 | Because this works by modifying the query with regular expressions, it 56 | *is* fallible and unexpected behavior such as inexplicable syntax errors 57 | can occur. Some known cases in which it'll fail: 58 | 59 | **Triggers** 60 | 61 | Columns with an ``ON UPDATE CURRENT_TIMESTAMP`` trigger attached to them 62 | will always be set to the real time on an update, unless you override it 63 | in the insert statement. 64 | 65 | **Stored Procedures** 66 | 67 | Since only the query is modified, stored procedures will still use the 68 | real current time. 69 | 70 | **Strings** 71 | 72 | While rare and definitely ill-advised, when a query uses a keyword in a 73 | string constant it will still get replaced. For example, this: 74 | 75 | .. code-block:: sql 76 | 77 | SELECT CURRENT_DATE AS "current_date" 78 | 79 | will become this: 80 | 81 | .. 
code-block:: sql 82 | 83 | SELECT '2017-04-05'::DATE AS "'2017-04-05'::DATE" 84 | 85 | Arguments: 86 | connectable: 87 | A :class:`~sqlalchemy.engine.Connection`, 88 | :class:`~sqlalchemy.engine.Engine`, or a bound 89 | :class:`~sqlalchemy.orm.session.Session`. Only queries executed with 90 | this object will be modified. 91 | """ 92 | def __init__(self, connectable): 93 | # If the caller gives us a session, take the underlying connection or 94 | # engine instead. 95 | if isinstance(connectable, sqla_session.Session): 96 | if connectable.bind is None: 97 | raise TypeError("Can't use unbound `Session` object for freezing.") 98 | connectable = connectable.bind 99 | 100 | self._connectable = connectable 101 | self._query_hook = None 102 | self._freeze_time = None 103 | self._freezer_factory = None 104 | 105 | @property 106 | def is_frozen(self): 107 | """Is time currently being frozen?""" 108 | return self._freezer_factory is not None 109 | 110 | @property 111 | def freezer(self): 112 | """Return the currently active ``FrozenDateTimeFactory`` instance, or 113 | ``None`` if time is not being frozen.""" 114 | return self._freezer_factory 115 | 116 | def freeze(self, when=None, **freezegun_kwargs): 117 | """Start modifying timestamps in queries and Python code. 118 | 119 | Arguments: 120 | when (str|date|time|datetime): 121 | The point in time to freeze all date and time functions to. This 122 | will affect both PostgreSQL and Python. If a string is given, it 123 | must be a date and/or time in a format that Postgres recognizes. 124 | 125 | If not given, defaults to the current timestamp in UTC. 126 | 127 | **freezegun_kwargs: 128 | Any additional arguments to pass to ``freezegun.freeze_time()``. 129 | 130 | .. note :: 131 | If ``when`` is a `naive datetime `_, 132 | the default timezone is UTC, *not* the local timezone. 
133 | """ 134 | if not when: 135 | when = datetime.datetime.now(datetime.timezone.utc) 136 | 137 | self.unfreeze() 138 | self._freeze_time = freezegun.freeze_time(when, **freezegun_kwargs) 139 | self._freezer_factory = self._freeze_time.start() 140 | 141 | # pylint: disable=unused-argument 142 | @sa_event.listens_for(self._connectable, 'before_cursor_execute', 143 | retval=True) 144 | def _hook(conn, cursor, statement, parameters, context, executemany): 145 | """Query hook to modify all timestamps.""" 146 | # We use datetime.now() here because it should already be frozen. No 147 | # need to hardcode it. 148 | timestamp = datetime.datetime.now(datetime.timezone.utc) 149 | 150 | for regex, replacement in _TIMESTAMP_REPLACEMENT_FORMATS: 151 | statement = re.sub(regex, replacement.format(timestamp), 152 | statement, flags=re.IGNORECASE) 153 | 154 | return statement, parameters 155 | # pylint: enable=unused-argument 156 | 157 | # Set up our query modifier to listen for execution events 158 | self._query_hook = _hook 159 | sa_event.listen(self._connectable, 'before_cursor_execute', _hook) 160 | 161 | # This is correct. Do *not* change this to return self._freezer_factory 162 | # or you will break the context manager behavior. 
163 | return self 164 | 165 | def unfreeze(self): 166 | """Stop modifying timestamps in queries.""" 167 | if self._query_hook: 168 | sa_event.remove(self._connectable, 'before_cursor_execute', 169 | self._query_hook) 170 | self._query_hook = None 171 | 172 | if self._freeze_time is not None: 173 | self._freeze_time.stop() 174 | self._freeze_time = None 175 | self._freezer_factory = None 176 | 177 | def __enter__(self): 178 | """Start the time freeze when this is used as a context manager.""" 179 | self.freeze() 180 | return self._freezer_factory 181 | 182 | def __exit__(self, *error_args): 183 | """Exiting the context manager, stop freezing time.""" 184 | self.unfreeze() 185 | 186 | 187 | def _is_freezeable(obj): 188 | """Determine if obj has the same freezing interface as `PostgreSQLTestUtil`. 189 | 190 | For some reason isinstance doesn't work properly with fixtures, so checking 191 | ``isinstance(obj, PostgreSQLTestDB)`` will always fail. Instead, we check to 192 | see if obj.time.freeze()/unfreeze() are present, and that the `time` member 193 | has context manager behavior implemented. 194 | """ 195 | return ( 196 | hasattr(obj, 'time') and 197 | callable(getattr(obj.time, 'freeze', None)) and 198 | callable(getattr(obj.time, 'unfreeze', None)) and 199 | callable(getattr(obj.time, '__enter__', None)) and 200 | callable(getattr(obj.time, '__exit__', None)) 201 | ) 202 | 203 | 204 | def freeze_time(when): 205 | """Freeze time inside a test, including in queries made to the database. 206 | 207 | This differs from normal ``freezegun`` usage in that it also works for 208 | database queries, with some caveats (see `SQLAlchemyFreezegun`). 209 | 210 | The test modified by this decorator must use one and only one fixture that 211 | returns a `PostgreSQLTestDB` instance. This means you can't *implicitly* use 212 | a fixture with the ``pytest.mark.usefixtures`` decorator. 
213 | 214 | Sample usage:: 215 | 216 | @pytest_pgsql.freeze_time('2999-12-31') 217 | def test_baz(postgresql_db): 218 | assert datetime.date.today() == datetime.date(2999, 12, 31) 219 | 220 | now = postgresql_db.engine.execute('SELECT CURRENT_DATE').scalar() 221 | assert now == datetime.date(2999, 12, 31) 222 | 223 | Arguments: 224 | when (str|date|time|datetime): 225 | The timestamp to freeze all date and time functions to. This will 226 | affect both PostgreSQL and Python. 227 | """ 228 | def decorator(func): 229 | @functools.wraps(func) 230 | def test_function_wrapper(*args, **kwargs): 231 | # Get all fixtures passed to the test function; one and only one of 232 | # these must be a freezable database. 233 | databases = [a for a in args if _is_freezeable(a)] 234 | databases.extend(v for v in kwargs.values() if _is_freezeable(v)) 235 | 236 | if len(databases) != 1: 237 | func_name = getattr(func, '__name__', type(func).__name__) 238 | raise RuntimeError( 239 | 'You must use exactly *one* database fixture with the ' 240 | '`freeze_time` decorator. %r has %d.' 241 | % (func_name, len(databases))) 242 | 243 | with databases[0].time.freeze(when): 244 | return func(*args, **kwargs) 245 | return test_function_wrapper 246 | return decorator 247 | -------------------------------------------------------------------------------- /pytest_pgsql/version.py: -------------------------------------------------------------------------------- 1 | """__version__ is automatically updated with the deploy.py script. 
Don't touch this file"""
__version__ = '1.1.3'
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
freezegun>=0.3.6
pytest>=3.0.0
sqlalchemy>=1.1.0
testing.postgresql>=1.3.0
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
name = pytest-pgsql
author = Clover Health
author-email = dev@cloverhealth.com
summary = Pytest plugins and helpers for tests using a Postgres database.
description-file = README.rst
home-page = https://github.com/CloverHealth/pytest-pgsql
python-requires = >=3.4
license = Copyright Clover Health, Inc.
classifier =
    Development Status :: 5 - Production/Stable
    Framework :: Pytest
    Intended Audience :: Developers
    License :: OSI Approved :: BSD License
    Operating System :: OS Independent
    Programming Language :: Python
    Programming Language :: Python :: 3
    Programming Language :: Python :: 3.4
    Programming Language :: Python :: 3.5
    Programming Language :: Python :: 3.6
    Topic :: Database
    Topic :: Software Development :: Testing

[coverage:run]
branch = True
source = pytest_pgsql

[coverage:report]
exclude_lines =
    # Have to re-enable the standard pragma
    pragma: no cover

    # Don't cover defensive assertion code
    raise AssertionError
    raise NotImplementedError
show_missing = 1
fail_under = 100

[entry_points]
pytest11 =
    pytest_pgsql = pytest_pgsql.plugin

[files]
packages = pytest_pgsql

[flake8]
max-complexity = 10
max-line-length = 99

[pylint]
# Pylint rules are defined in .pylintrc since it has no setup.cfg configuration
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Packaging entry point; all real metadata lives in setup.cfg (pbr)."""
from setuptools import setup
import sys

# Check the Python version manually because pip < 9.0 doesn't check it for us.
if sys.version_info < (3, 4):
    raise RuntimeError('Unsupported version of Python: ' + sys.version)

setup(
    setup_requires=['pbr'],
    pbr=True,
)
--------------------------------------------------------------------------------
/test_requirements.txt:
--------------------------------------------------------------------------------
# Requirements needed for running the test suite. Don't include tox and
# tox-pyenv here because they're in dev_requirements.txt.

coverage==4.4.2
flake8==3.5.0
psycopg2==2.7.3
pylint==1.8.1
pytest-mock==1.6.3
pytest-profiling==1.2.11
pytest-randomly==1.2.2
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
[tox]
envlist = py34,py35,py36

[testenv]
usedevelop = true
deps = -rtest_requirements.txt
passenv = *

# Deliberately executing pytest with `--pg-extensions=btree_gin,,btree_gist,` to
# verify that it installs extensions and ignores empty strings as expected. Make
# sure you don't break it!
commands =
    coverage run -a -m pytest --pg-conf-opt="track_commit_timestamp=True" --pg-extensions=btree_gin,,btree_gist, {posargs} pytest_pgsql/tests
--------------------------------------------------------------------------------